Source code for dunedn.networks.onnx.utils

from dunedn.networks.utils import BatchProfiler
from tqdm.auto import tqdm
import numpy as np
import torch
import onnxruntime as ort


def gcnn_onnx_inference_pass(
    test_loader: torch.utils.data.DataLoader,
    ort_session: "ort.InferenceSession",
    verbose: int = 1,
    profiler: "BatchProfiler" = None,
) -> torch.Tensor:
    """Run batched inference over a dataset with an ONNX Runtime session.

    Each batch yielded by ``test_loader`` is converted to a ``float32`` NumPy
    array and fed to the session under the input name ``"input"``. The first
    session output of every batch is collected, and all outputs are
    concatenated along the first dimension.

    Parameters
    ----------
    test_loader: torch.utils.data.DataLoader
        The inference dataset generator. Each item must unpack into
        ``(noisy, _)``; only the first element is fed to the model.
    ort_session: ort.InferenceSession
        The onnxruntime inference session. The model is fed a single input
        named ``"input"``.
    verbose: int
        Switch to log information. Defaults to 1. Available options:

        - 0: no logs.
        - 1: display progress bar.
    profiler: BatchProfiler
        Optional profiler object to record batch inference time.

    Returns
    -------
    torch.Tensor
        Output tensor of shape=(N,C,H,W), with ``float32`` dtype.

    Raises
    ------
    RuntimeError
        If ``test_loader`` yields no batches (raised by ``torch.cat``).
    """
    outs = []
    wrap = tqdm(test_loader, desc="onnx.predict") if verbose else test_loader
    if profiler is not None:
        # Hand iteration over to the profiler so it can record batch timings.
        wrap = profiler.set_iterable(wrap)
    for noisy, _ in wrap:
        # ONNX Runtime consumes host NumPy arrays: drop autograd history and
        # move off any accelerator before converting, otherwise ``.numpy()``
        # raises. ``copy=False`` skips the cast when already float32.
        batch = noisy.detach().cpu().numpy().astype(np.float32, copy=False)
        out = ort_session.run(None, {"input": batch})[0]
        # ``from_numpy`` avoids the extra copy made by the legacy
        # ``torch.Tensor`` constructor; ``.float()`` keeps the float32 result
        # and is a no-op when the model already returns float32.
        outs.append(torch.from_numpy(out).float())
    return torch.cat(outs)