"""ONNX Runtime network wrapper (module ``dunedn.networks.onnx.onnx_abstract_net``)."""

from pathlib import Path
from typing import Callable, Optional

import onnxruntime as ort

from dunedn.training.metrics import MetricsList


class OnnxNetwork(ort.InferenceSession):
    """ONNX Runtime inference session bundled with a list of metrics.

    Thin subclass of ``onnxruntime.InferenceSession``: the model is loaded
    from ``ckpt`` by the parent constructor, and ``metrics`` is stored on the
    instance for later use.
    """

    def __init__(
        self,
        ckpt: Path,
        metrics: MetricsList,
        providers: Optional[list[str]] = None,
    ):
        """
        Parameters
        ----------
        ckpt: Path
            `.onnx` file path.
        metrics: MetricsList
            List of callable metrics.
        providers: list[str], optional
            Execution providers, forwarded unchanged to
            ``onnxruntime.InferenceSession`` (``None`` is passed through as-is).
        """
        # NOTE: older onnxruntime releases only accept ``str`` or ``bytes``
        # model paths and raise TypeError for ``pathlib.Path``; convert so the
        # documented ``Path`` argument loads on any version.
        if isinstance(ckpt, Path):
            ckpt = str(ckpt)
        super().__init__(ckpt, providers=providers)
        # List of callable metrics attached to this network.
        self.metrics = metrics