"""This module contains the UscgNet model class."""
import logging
from math import ceil
from pathlib import Path
from time import time as tm
from typing import Tuple, Union

import torch
from torch import nn
from torchvision.models import resnext50_32x4d

from ..abstract_net import AbstractNet
from .uscg_dataloading import UscgDataset
from .uscg_net_blocks import (
    SCG_Block,
    GCN_Layer,
    Pooling_Block,
    Recombination_Layer,
)
from .utils import uscg_inference_pass, time_windows
from dunedn import PACKAGE
logger = logging.getLogger(PACKAGE + ".uscg")
class UscgNet(AbstractNet):
    """U-shaped Self Constructing Graph Network.

    The network chains three down-sampling stages (two slices of a
    ResNeXt-50 backbone followed by a pooling block), a self-constructing
    graph block with two graph-convolution layers, and three up-sampling
    stages that recombine the graph output with the per-stage skip tensors.
    """

    def __init__(
        self,
        channel: str = "collection",
        out_channels: int = 1,
        h_induction: int = 800,
        h_collection: int = 960,
        w: int = 6000,
        stride: int = 1000,
        pretrained: bool = True,
        node_size: list[int] = [28, 28],
        dropout: float = 0.5,
        enhance_diag: bool = True,
        aux_pred: bool = True,
    ):
        """
        Parameters
        ----------
        channel: str
            Available options induction | collection. Any value other than
            "induction" selects the collection height.
        out_channels: int
            Output image channels number.
        h_induction: int
            Induction input image height.
        h_collection: int
            Collection input image height.
        w: int
            Input image width.
        stride: int
            Steps between time windows.
        pretrained: bool
            Whether to download the weights of the pretrained resnet or not.
        node_size: list
            [height, width] of the image input of SCG block.
        dropout: float
            Percentage of neurons turned off in graph layer.
        enhance_diag: bool
            SCG_block flag.
        aux_pred: bool
            SCG_block flag. When set, the auxiliary SCG output is added to
            the graph layers output in the forward pass.
        """
        super().__init__()
        self.out_channels = out_channels
        self.channel = channel
        self.h_collection = h_collection
        self.h_induction = h_induction
        self.w = w
        self.stride = stride
        self.pretrained = pretrained
        # Copy the argument: the default value is a list shared by every
        # call, so storing it by reference would alias it across instances.
        self.node_size = list(node_size)
        self.dropout = dropout
        self.enhance_diag = enhance_diag
        self.aux_pred = aux_pred

        self.h = self.h_induction if self.channel == "induction" else self.h_collection
        self.input_shape = (1, self.h, self.w)

        resnet = resnext50_32x4d(pretrained=self.pretrained, progress=True)
        # The 1x1 convolution maps the single-channel input to the three
        # channels the backbone stem expects.
        resnet_12 = nn.Sequential(
            nn.Conv2d(1, 3, 1),
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
        )
        resnet_34 = nn.Sequential(
            resnet.layer3, resnet.layer4, nn.Conv2d(2048, 1024, 1)
        )

        # The last down-sampling stage feeds the SCG block, hence it is sized
        # by `node_size` (formerly hard-coded to 28x28, the default value).
        self.downsamples = nn.ModuleList(
            [resnet_12, resnet_34, Pooling_Block(1024, *self.node_size)]
        )
        self.upsamples = nn.ModuleList(
            [
                Pooling_Block(1, ceil(self.h / 32), ceil(self.w / 32)),
                Pooling_Block(1, ceil(self.h / 8), ceil(self.w / 8)),
                Pooling_Block(1, self.h, self.w),
            ]
        )

        self.gcns = nn.Sequential(
            GCN_Layer(
                1024, 128, bnorm=True, activation=nn.ReLU(True), dropout=self.dropout
            ),
            GCN_Layer(128, self.out_channels, bnorm=False, activation=None),
        )
        self.scg = SCG_Block(
            in_ch=1024,
            hidden_ch=self.out_channels,
            node_size=self.node_size,
            add_diag=self.enhance_diag,
            dropout=self.dropout,
        )

        # One 1x1 convolution per down-sampling stage, collapsing the stage
        # output to a single channel for the skip connection.
        self.adapts = nn.ModuleList(
            [
                nn.Conv2d(512, 1, 1, bias=False),
                nn.Conv2d(1024, 1, 1, bias=False),
                nn.Conv2d(1024, 1, 1, bias=False),
            ]
        )
        self.recombs = nn.ModuleList([Recombination_Layer() for _ in range(3)])
        # NOTE(review): not used by `forward`; kept so that the set of
        # registered sub-modules is unchanged.
        self.last_recomb = Recombination_Layer()

    def forward(self, x):
        """USCG Net forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor of shape=(N,C,H,W).

        Returns
        -------
        torch.Tensor or tuple
            In evaluation mode, the output tensor, obtained by multiplying
            the up-sampled graph output by the network input. In training
            mode, the tuple (output tensor, loss), where loss is the third
            value returned by the SCG block.
        """
        net_input = x

        # downsampling: keep one single-channel skip tensor per stage
        skips = []
        for adapt, downsample in zip(self.adapts, self.downsamples):
            x = downsample(x)
            skips.append(adapt(x))

        # graph
        batch_size, nb_channels, _, _ = x.size()
        a, x, loss, z_hat = self.scg(x)
        x, _ = self.gcns((x.reshape(batch_size, -1, nb_channels), a))
        if self.aux_pred:
            x += z_hat
        x = x.reshape(batch_size, self.out_channels, *self.node_size)

        # upsampling: skips and recombination layers are consumed deepest
        # stage first
        for skip, recomb, upsample in zip(
            reversed(skips), reversed(self.recombs), self.upsamples
        ):
            x = upsample(recomb(x, skip))

        output = x * net_input
        if self.training:
            return output, loss
        return output

    def predict(
        self,
        generator: UscgDataset,
        dev: str = "cpu",
        no_metrics: bool = False,
        verbose: int = 1,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        """Uscg network inference.

        Parameters
        ----------
        generator: UscgDataset
            The inference dataset generator.
        dev: str
            The device hosting the computation. Defaults is "cpu".
        no_metrics: bool
            Whether to skip metric computation. Defaults to False, so metrics
            are indeed computed.
        verbose: int
            Switch to log information. Defaults to 1. Available options:
            - 0: no logs.
            - 1: display progress bar.

        Returns
        -------
        y_pred: torch.Tensor
            Denoised planes, of shape=(N,1,H,W).
        logs: dict
            The computed metrics results in dictionary form, with the
            inference time added under the "time" key. Only returned when
            `no_metrics` is False; otherwise `y_pred` is returned alone.
        """
        self.check_network_is_compiled()

        # convert planes to crops
        test_loader = torch.utils.data.DataLoader(
            dataset=generator,
            batch_size=generator.batch_size,
        )

        # inference pass
        start = tm()
        y_pred = uscg_inference_pass(test_loader, self, dev, verbose)
        inference_time = tm() - start

        if no_metrics:
            return y_pred

        # compute metrics
        y_true = generator.clear
        logs = self.metrics_list.compute_metrics(y_pred, y_true)
        logs.update({"time": inference_time})
        return y_pred, logs

    def train_epoch(
        self,
        train_loader: torch.utils.data.DataLoader,
        dev: str = "cpu",
    ) -> dict:
        """Trains the network for one epoch.

        Each (noisy, clear) pair yielded by the loader is split into time
        windows and one optimization step is made per window. Step logs are
        handed to the callback list only.

        Parameters
        ----------
        train_loader: torch.utils.data.DataLoader
            The training dataloader.
        dev: str
            The device hosting the computation.

        Returns
        -------
        epoch_logs: dict
            The dictionary of epoch logs. Currently always empty.
        """
        self.train()
        for noisy, clear in train_loader:
            _, cwindows, _ = time_windows(clear, self.w, self.stride)
            _, nwindows, _ = time_windows(noisy, self.w, self.stride)
            for nwindow, cwindow in zip(nwindows, cwindows):
                self.callback_list.on_train_batch_begin()
                step_logs = self.train_batch(nwindow, cwindow, dev)
                self.callback_list.on_train_batch_end(step_logs)
        epoch_logs = {}
        return epoch_logs

    def train_batch(self, noisy: torch.Tensor, clear: torch.Tensor, dev: str):
        """Makes one batch update.

        The network must be in training mode, since the forward pass is
        expected to return the (prediction, loss) pair.

        Parameters
        ----------
        noisy: torch.Tensor
            Noisy inputs batch, of shape=(B,1,H,W).
        clear: torch.Tensor
            Clear target batch, of shape=(B,1,H,W)
        dev: str
            The device hosting the computation.

        Returns
        -------
        step_logs: dict
            The step logs as a dictionary, including the total loss under
            the "loss" key.
        """
        clear = clear.to(dev)
        noisy = noisy.to(dev)

        self.optimizer.zero_grad()
        # Call the module rather than `forward` so registered hooks run.
        y_pred, graph_loss = self(noisy)
        loss = self.loss_fn(y_pred, clear) + graph_loss
        loss.backward()
        self.optimizer.step()

        # Metrics need no gradients: detach so they do not track the graph.
        step_logs = self.metrics_list.compute_metrics(y_pred.detach(), clear)
        step_logs.update({"loss": loss.item()})
        return step_logs

    def onnx_export(self, fname: Path):
        """Export model to ONNX format.

        Parameters
        ----------
        fname: Path
            The path to save the `.onnx` network.

        Raises
        ------
        NotImplementedError
            Always: the export is not supported for this network.
        """
        raise NotImplementedError(
            "Currently, UscgNet cannot be exported to `onnx` format."
        )