Source code for dunedn.networks.onnx.onnx_gcnn_net

from pathlib import Path
from dunedn.networks.utils import BatchProfiler
import torch
from ..gcnn.gcnn_dataloading import GcnnDataset
from .onnx_abstract_net import OnnxNetwork
from .utils import gcnn_onnx_inference_pass
from dunedn.training.metrics import MetricsList


class OnnxGcnnNetwork(OnnxNetwork):
    """GCNN network served through the ONNX inference path.

    Thin specialization of ``OnnxNetwork``: construction is delegated
    unchanged to the parent class, while ``predict`` drives a ``GcnnDataset``
    through ``gcnn_onnx_inference_pass``.
    """

    def __init__(self, ckpt: Path, metrics: MetricsList, providers: list[str] = None):
        """
        Parameters
        ----------
        ckpt: Path
            `.onnx` file path.
        metrics: MetricsList
            List of callable metrics.
        providers: list[str]
            List of providers. Forwarded as-is to ``OnnxNetwork``; defaults to
            ``None``.
        """
        super().__init__(ckpt, metrics, providers=providers)

    def predict(
        self, generator: GcnnDataset, profiler: BatchProfiler = None
    ) -> torch.Tensor:
        """ONNX GCNN network inference.

        The dataset is switched with ``generator.to_crops()`` before it is
        iterated and switched back with ``generator.to_planes()`` afterwards.
        The switch back happens even when inference fails, so the caller's
        dataset is never left in the crops state.

        Parameters
        ----------
        generator: GcnnDataset
            The inference generator.
        profiler: BatchProfiler
            The profiler object to record batch inference time. Forwarded
            as-is to ``gcnn_onnx_inference_pass``; defaults to ``None``.

        Returns
        -------
        torch.Tensor
            Output tensor of shape=(N,C,H,W).
        """
        generator.to_crops()
        try:
            test_loader = torch.utils.data.DataLoader(
                dataset=generator,
                batch_size=generator.batch_size,
            )
            output = gcnn_onnx_inference_pass(test_loader, self, profiler=profiler)
            # Recombine the per-tile outputs while the dataset is still in
            # the crops state, exactly as the success path always did.
            y_pred = generator.converter.tiles2planes(output)
        finally:
            # Always undo `to_crops()`, including on error, so the generator
            # can be reused by the caller.
            generator.to_planes()
        return y_pred