# Source code for dunedn.networks.gcnn.gcnn_net

"""
    This module contains the GcnnNet model class.

    GcnnNet also implements the plain CNN variant.
"""
import logging
from pathlib import Path
from time import time as tm
from typing import Tuple, Union

import torch
from torch import nn

from dunedn import PACKAGE

from ..abstract_net import AbstractNet
from ..utils import BatchProfiler
from .gcnn_dataloading import BaseGcnnDataset
from .gcnn_net_blocks import (
    PreProcessBlock,
    ROI,
    HPF,
    LPF,
    PostProcessBlock,
    NonLocalGraph,
)
from .utils import gcnn_inference_pass

logger = logging.getLogger(PACKAGE + ".gcnn")


class GcnnNet(AbstractNet):
    """Graph Convolutional Neural Network implementation.

    The same architecture serves both the ``"gcnn"`` and the ``"cnn"``
    variants: for ``"cnn"`` no graph is built (the graph function shared by
    all blocks returns ``None``).
    """

    def __init__(
        self,
        model: str,
        task: str,
        crop_edge: int,
        input_channels: int,
        hidden_channels: int,
        k: int = None,
    ):
        """
        Parameters
        ----------
        model: str
            Available options cnn | gcnn.
        task: str
            Available options dn | roi.
        crop_edge: int
            Crop edge size.
        input_channels: int
            Input channel dimension size.
        hidden_channels: int
            Convolutions hidden filters number.
        k: int
            Nearest neighbor number. None if model is cnn.
        """
        super().__init__()
        self.crop_size = (crop_edge,) * 2
        self.model = model
        self.task = task
        ic = input_channels
        hc = hidden_channels
        self.k = k
        self.input_shape = (1,) + self.crop_size

        # Graph builder shared by every block; the cnn variant builds no graph.
        self.getgraph_fn = (
            NonLocalGraph(k, self.crop_size)
            if self.model == "gcnn"
            else lambda x: None
        )

        self.roi = ROI(7, ic, hc, self.getgraph_fn, self.model)
        # Three parallel pre-processing branches differing only in their first
        # argument (5, 7, 9) -- presumably the kernel size; TODO confirm.
        self.pre_process_blocks = nn.ModuleList(
            [
                PreProcessBlock(5, ic, hc, self.getgraph_fn, self.model),
                PreProcessBlock(7, ic, hc, self.getgraph_fn, self.model),
                PreProcessBlock(9, ic, hc, self.getgraph_fn, self.model),
            ]
        )
        # `forward` concatenates the three branches and the ROI hits, hence
        # (presumably) hc * 3 + 1 channels: 3 x hc branch channels + 1 ROI
        # channel -- verify against the block definitions.
        self.lpfs = nn.ModuleList(
            [
                LPF(hc * 3 + 1, hc * 3 + 1, self.getgraph_fn, self.model)
                for _ in range(4)
            ]
        )
        self.hpf = HPF(hc * 3 + 1, hc * 3 + 1, self.getgraph_fn, self.model)
        self.post_process_block = PostProcessBlock(
            ic, hc, self.getgraph_fn, self.model
        )
        # Non-trainable scalar parameters (registered, so they follow `.to()`).
        self.aa = nn.Parameter(torch.Tensor([0]), requires_grad=False)
        self.bb = nn.Parameter(torch.Tensor([1]), requires_grad=False)
        # Residual combination used between the filtering stages.
        self.combine = lambda x, y: x + y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gcnn forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor of shape=(N,C,H,W).

        Returns
        -------
        output: torch.Tensor
            Output tensor of shape=(N,C,H,W). When ``task == "roi"`` the ROI
            block output is returned directly.
        """
        hits = self.roi(x)
        if self.task == "roi":
            return hits

        # Concatenate the pre-processing branches and the ROI hits along the
        # channel axis.
        y = torch.cat([block(x) for block in self.pre_process_blocks], dim=1)
        y = torch.cat([y, hits], 1)

        # High-pass output is re-injected after every low-pass stage.
        y_hpf = self.hpf(y)
        y = self.combine(y, y_hpf)
        for lpf in self.lpfs:
            y = self.combine(lpf(y), y_hpf)

        # The post-processing output multiplies the input element-wise.
        output = self.post_process_block(y) * x
        return output

    def predict(
        self,
        generator: BaseGcnnDataset,
        dev: str = "cpu",
        no_metrics: bool = False,
        verbose: int = 1,
        profiler: BatchProfiler = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        """Gcnn network inference.

        Parameters
        ----------
        generator: BaseGcnnDataset
            The inference dataset generator.
        dev: str
            The device hosting the computation. Defaults to "cpu".
        no_metrics: bool
            Whether to skip metric computation. Defaults to False, so metrics
            are indeed computed.
        verbose: int
            Switch to log information. Defaults to 1. Available options:

            - 0: no logs.
            - 1: display progress bar.

        profiler: BatchProfiler
            The profiler object to record batch inference time.

        Returns
        -------
        y_pred: torch.Tensor
            Denoised planes, of shape=(N,1,H,W). Returned alone when
            ``no_metrics`` is True.
        logs: dict
            The computed metrics results in dictionary form, plus the
            inference wall time under ``"time"`` (seconds). Only returned,
            together with ``y_pred``, when ``no_metrics`` is False.
        """
        self.check_network_is_compiled()

        # Convert planes to crops; always restore plane mode afterwards, even
        # if inference fails, so the generator is not left in crop mode.
        # TODO: think about a `with` statement for the test_loader object, as
        # it shouldn't be possible to call the inference without `to_crops()`
        # and `to_planes()` methods
        generator.to_crops()
        try:
            test_loader = torch.utils.data.DataLoader(
                dataset=generator,
                batch_size=generator.batch_size,
            )

            # inference pass
            start = tm()
            output = gcnn_inference_pass(
                test_loader, self, dev, verbose, profiler=profiler
            )
            inference_time = tm() - start

            # convert back to planes
            y_pred = generator.converter.tiles2planes(output)
        finally:
            generator.to_planes()

        if no_metrics:
            return y_pred

        # compute metrics
        y_true = generator.clear
        logs = self.metrics_list.compute_metrics(y_pred, y_true)
        logs.update({"time": inference_time})
        return y_pred, logs

    def train_epoch(
        self,
        train_loader: torch.utils.data.DataLoader,
        dev: str = "cpu",
    ) -> dict:
        """Trains the network for one epoch.

        Parameters
        ----------
        train_loader: torch.utils.data.DataLoader
            The training dataloader, yielding ``(clear, noisy)`` pairs.
        dev: str
            The device hosting the computation.

        Returns
        -------
        epoch_logs: dict
            The dictionary of epoch logs (currently always empty).
        """
        logger.debug("Training epoch")
        self.train()
        for clear, noisy in train_loader:
            self.callback_list.on_train_batch_begin()
            step_logs = self.train_batch(noisy, clear, dev)
            self.callback_list.on_train_batch_end(step_logs)
        epoch_logs = {}
        return epoch_logs

    def train_batch(self, noisy: torch.Tensor, clear: torch.Tensor, dev: str):
        """Makes one batch update.

        Parameters
        ----------
        noisy: torch.Tensor
            Noisy inputs batch, of shape=(B,1,H,W).
        clear: torch.Tensor
            Clear target batch, of shape=(B,1,H,W).
        dev: str
            The device hosting the computation.

        Returns
        -------
        step_logs: dict
            The step logs as a dictionary, including the scalar ``"loss"``.
        """
        # Both inputs and targets must live on the computation device,
        # otherwise the loss and the metrics mix devices.
        noisy = noisy.to(dev)
        clear = clear.to(dev)

        self.optimizer.zero_grad()
        y_pred = self(noisy)
        # Reduce once so that backward() and item() see the same scalar, even
        # when `loss_fn` returns an unreduced (per-element) loss.
        loss = self.loss_fn(y_pred, clear).mean()
        loss.backward()
        self.optimizer.step()

        # Metrics do not need the autograd graph.
        step_logs = self.metrics_list.compute_metrics(y_pred.detach(), clear)
        step_logs.update({"loss": loss.item()})
        return step_logs

    def onnx_export(self, fname: Path):
        """Export model to ONNX format.

        Parameters
        ----------
        fname: Path
            The path to save the `.onnx` network.
        """
        # Dummy input on the same device as the model weights, so export also
        # works for a model that has been moved to GPU.
        device = next(self.parameters()).device
        inputs = torch.randn(1, 1, *self.crop_size, device=device)

        # export network
        torch.onnx.export(
            self,
            inputs,
            fname,
            verbose=False,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )