dunedn.networks.onnx package
Submodules
dunedn.networks.onnx.onnx_abstract_net module
- class dunedn.networks.onnx.onnx_abstract_net.OnnxNetwork(ckpt: Path, metrics: MetricsList, providers: Optional[list[str]] = None)
Bases: InferenceSession
dunedn.networks.onnx.onnx_gcnn_net module
- class dunedn.networks.onnx.onnx_gcnn_net.OnnxGcnnNetwork(ckpt: Path, metrics: MetricsList, providers: Optional[list[str]] = None)
Bases: OnnxNetwork
- predict(generator: GcnnDataset, profiler: Optional[BatchProfiler] = None) → Tensor
ONNX GCNN network inference.
- Parameters
generator (GcnnDataset) – The inference generator.
profiler (BatchProfiler) – The profiler object to record batch inference time.
- Returns
Output tensor of shape=(N,C,H,W).
- Return type
torch.Tensor
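The profiler parameter records how long each batch takes to run. The interface of dunedn's BatchProfiler is not documented here, so the following is only a minimal sketch of a per-batch timer under that assumption; the class `SimpleBatchProfiler` and its `record`/`mean` methods are hypothetical and not part of dunedn.

```python
import time
from contextlib import contextmanager
from typing import Iterator, List


class SimpleBatchProfiler:
    """Hypothetical per-batch wall-clock timer (not dunedn's BatchProfiler)."""

    def __init__(self) -> None:
        self.batch_times: List[float] = []

    @contextmanager
    def record(self) -> Iterator[None]:
        # Time the body of the `with` block and store the elapsed seconds,
        # even if the batch raises an exception.
        start = time.perf_counter()
        try:
            yield
        finally:
            self.batch_times.append(time.perf_counter() - start)

    def mean(self) -> float:
        # Average batch time in seconds; 0.0 if nothing was recorded.
        if not self.batch_times:
            return 0.0
        return sum(self.batch_times) / len(self.batch_times)
```

In use, each batch's session call would be wrapped in `with profiler.record():`, and `profiler.batch_times` then holds one entry per batch. A context manager keeps the timing code out of the inference loop itself.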
dunedn.networks.onnx.uscg_onnx_net module
This module implements the ONNX port of the USCG network.
dunedn.networks.onnx.utils module
- dunedn.networks.onnx.utils.gcnn_onnx_inference_pass(test_loader: DataLoader, ort_session: InferenceSession, verbose: int = 1, profiler: Optional[BatchProfiler] = None) → Tensor
- Parameters
test_loader (torch.utils.data.DataLoader) – The inference data loader.
ort_session (ort.InferenceSession) – The onnxruntime inference session.
verbose (int) – Logging verbosity. Defaults to 1. Available options:
0: no logs.
1: display progress bar.
profiler (BatchProfiler) – The profiler object to record batch inference time.
- Returns
Output tensor of shape=(N,C,H,W).
- Return type
torch.Tensor
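An ONNX inference pass of this kind typically iterates over the loader, runs the session on each batch, and stacks the per-batch outputs along the first axis to obtain shape (N, C, H, W). The following is a minimal sketch of that pattern under stated assumptions, not the dunedn implementation: it uses NumPy arrays instead of torch tensors, assumes each batch is a single input array, replaces BatchProfiler with an optional list of timings, and uses a hypothetical `FakeSession` that mimics only the `get_inputs()` and `run(output_names, input_feed)` methods of `onnxruntime.InferenceSession`, so it runs without a model file. The function name `onnx_inference_pass` is also hypothetical.

```python
import sys
import time
from typing import Iterable, List, Optional

import numpy as np


class FakeSession:
    """Hypothetical stand-in for ort.InferenceSession (identity network)."""

    class _Arg:
        def __init__(self, name: str) -> None:
            self.name = name

    def get_inputs(self):
        return [self._Arg("input")]

    def run(self, output_names, input_feed):
        # Like onnxruntime, return a list of output arrays.
        return [input_feed["input"] * 1.0]


def onnx_inference_pass(
    loader: Iterable[np.ndarray],
    session,
    verbose: int = 1,
    batch_times: Optional[List[float]] = None,
) -> np.ndarray:
    input_name = session.get_inputs()[0].name
    outputs = []
    for i, batch in enumerate(loader):
        start = time.perf_counter()
        # output_names=None requests every graph output; keep the first.
        out = session.run(None, {input_name: batch})[0]
        if batch_times is not None:
            batch_times.append(time.perf_counter() - start)
        outputs.append(out)
        if verbose == 1:
            # Minimal progress indicator (verbose=0 prints nothing).
            print(f"\rbatch {i + 1}", end="", file=sys.stderr)
    if verbose == 1:
        print(file=sys.stderr)
    # Stack (B, C, H, W) batch outputs into a single (N, C, H, W) array.
    return np.concatenate(outputs, axis=0)
```

For example, three batches of shape (2, 1, 4, 4) passed through `FakeSession()` yield an array of shape (6, 1, 4, 4) and three recorded timings. With a real model, `FakeSession()` would be replaced by an `onnxruntime.InferenceSession` built from the exported checkpoint.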