Source code for dunedn.networks.gcnn.utils

"""This module implements utility functions for the `networks.gcnn` subpackage."""
from collections import OrderedDict
from dunedn.networks.utils import BatchProfiler
from tqdm.auto import tqdm
import torch
from ..abstract_net import AbstractNet


[docs]def make_dict_compatible(state_dict: OrderedDict): """Transforms `state_dict` keys to match new GcnnNet attributes format. *Changed in version 2.0.0:* - Remove ".module" in front of the saved weights name. - Remove extra attributes due to deprecated normalization layer. - Transform layers names to lowercase. Parameters ---------- state_dict: OrderedDict The original dictionary containing network saved weights. Returns ------- new_state_dict: OrderedDict The dictionary updated version. """ extras = ["module.norm_fn.med", "module.norm_fn.Min", "module.norm_fn.Max"] for extra in extras: if state_dict.get(extra, None) is not None: state_dict.move_to_end(extra) state_dict.popitem() changes_from = [ "PreProcessBlocks", "PostProcessBlocks", "PreProcessBlock", "PostProcessBlock", ] changes_to = [ "pre_process_blocks", "post_process_blocks", "pre_process_block", "post_process_block", ] new_state_dict = OrderedDict() for k, v in state_dict.items(): for change_from, change_to in zip(changes_from, changes_to): if change_from in k: k = k.replace(change_from, change_to) new_k = k.lower().replace("module.", "") new_state_dict[new_k] = v return new_state_dict
[docs]def gcnn_inference_pass( test_loader: torch.utils.data.DataLoader, network: AbstractNet, dev: str, verbose: int = 1, profiler: BatchProfiler = None, ) -> torch.Tensor: """Consumes data through CNN or GCNN network and gives outputs. Parameters ---------- test_loader: torch.utils.data.DataLoader The inference dataset generator. network: AbstractNet The denoising network. dev: str The device hosting the computation. verbose: int Switch to log information. Defaults to 1. Available options: - 0: no logs. - 1: display progress bar. profiler: BatchProfiler The profiler object to record batch inference time. Returns ------- output: torch.Tensor Denoised data, of shape=(N,1,H,W). """ network.eval() network.to(dev) outs = [] wrap = tqdm(test_loader, desc="gcnn.predict") if verbose else test_loader if profiler is not None: wrap = profiler.set_iterable(wrap) for noisy, _ in wrap: out = network(noisy.to(dev)).detach().cpu() outs.append(out) output = torch.cat(outs) network.to("cpu") return output