Source code for dunedn.networks.uscg.utils

"""This module implements utility functions for the `networks.uscg` subpackage."""
from typing import Tuple
from tqdm.auto import tqdm
from math import ceil
import torch
from collections import OrderedDict
from ..abstract_net import AbstractNet


[docs]def make_dict_compatible(state_dict: OrderedDict): """Transforms `state_dict` keys to match new GcnnNet attributes format. *Changed in version 2.0.0:* - Remove ".module" in front of the saved weights name. - Remove extra attributes due to deprecated normalization layer. - Transform layers names to lowercase. Parameters ---------- state_dict: OrderedDict The original dictionary containing network saved weights. Returns ------- new_state_dict: OrderedDict The dictionary updated version. """ new_state_dict = OrderedDict() for k, v in state_dict.items(): new_k = k.lower().replace("module.", "") new_state_dict[new_k] = v return new_state_dict
[docs]def time_windows( planes: torch.Tensor, w: int, stride: int ) -> Tuple[torch.Tensor, list[torch.Tensor], list[list[int]]]: """Takes time windows of given width and stride from a planes. Parameters ---------- planes: torch.Tensor Planes of shape=(N,C,H,W). w: int Width of the time windows. stride: int Steps between time windows. Returns ------- divisions: torch.Tensor Number of times each pixel should be processed by denoising network, of shape=(N,C,H,W). windows: list[torch.Tensor] Time windows to be processed, each of shape=(N,C,H,w) idxs: list[list[int]] Start-end time indices for the correspondent window. Each elements is [start idx, end idx]. """ _, _, _, width = planes.size() n = ceil((width - w) / stride) + 1 base = torch.arange(n).unsqueeze(1) * stride idxs = torch.Tensor([[0, w]]).long() + base windows = [] divisions = torch.zeros_like(planes) for start, end in idxs: window = planes[..., start:end] windows.append(window) divisions[..., start:end] += 1 return divisions, windows, idxs
[docs]def uscg_inference_pass( test_loader: torch.utils.data.DataLoader, network: AbstractNet, dev: str, verbose: int = 1, ) -> torch.Tensor: """Consumes data through USCG network and gives outputs. Parameters ---------- test_loader: torch.utils.data.Dataloader The inference dataset generator. network: AbstractNet The denoising network. dev: str The device hosting the computation. verbose: int Switch to log information. Defaults to 1. Available options: - 0: no logs. - 1: display progress bar. Returns ------- output: torch.Tensor Denoised data, of shape=(N,1,H,W). """ w = network.w network.eval() network.to(dev) outs = [] wrap = tqdm(test_loader, desc="uscg.predict") if verbose else test_loader for noisy, _ in wrap: div, nwindows, idxs = time_windows(noisy, w, network.stride) out = torch.zeros_like(noisy) for nwindow, (start, end) in zip(nwindows, idxs): out[..., start:end] = network(nwindow.to(dev)).detach().cpu() outs.append(out / div) output = torch.cat(outs) network.to("cpu") return output