dunedn.networks.uscg package

Submodules

dunedn.networks.uscg.training module

This module provides functions for USCG network training and loading.

dunedn.networks.uscg.training.load_and_compile_uscg_network(channel: str, msetup: dict, checkpoint_filepath: Optional[Path] = None) → UscgNet

Loads a USCG network.

Parameters
  • channel (str) – Available options: induction | collection.

  • msetup (dict) – The model setup dictionary.

  • checkpoint_filepath (Optional[Path]) – The .pth checkpoint containing the network weights to be loaded.

Returns

network – The loaded neural network.

Return type

UscgNet
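Restoring weights from a .pth checkpoint generally follows the standard PyTorch pattern of loading a state_dict into an already constructed module. The following is a minimal sketch, assuming the checkpoint stores a plain state_dict; the stand-in TinyNet module and the load_weights helper are hypothetical and not part of dunedn:

```python
from pathlib import Path
from typing import Optional

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for UscgNet, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


def load_weights(network: nn.Module, checkpoint_filepath: Optional[Path] = None) -> nn.Module:
    """Loads a state_dict into network when a checkpoint path is given."""
    if checkpoint_filepath is not None:
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines
        state_dict = torch.load(checkpoint_filepath, map_location="cpu")
        network.load_state_dict(state_dict)
    return network
```

With checkpoint_filepath left as None, the network is returned with its freshly initialized weights.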

dunedn.networks.uscg.training.uscg_training(setup: dict)

USCG network training.

Parameters

setup (dict) – Settings dictionary.

dunedn.networks.uscg.uscg_dataloading module

This module implements dataset loading for the USCG network.

class dunedn.networks.uscg.uscg_dataloading.BaseUscgDataset(dataset_type: str, task: str, channel: str, dsetup: dict, batch_size: int)

Bases: Dataset

Loads the dataset for the USCG network.

class dunedn.networks.uscg.uscg_dataloading.UscgDataset(dataset_type: str, task: str, channel: str, dsetup: dict, batch_size: int)

Bases: BaseUscgDataset

Loads the dataset for the USCG network.

get_planes_from_setup() → Tuple[Tensor, Tensor]

Gets the planes from the folder pointed to by dsetup.

Returns

  • noisy (torch.Tensor) – The noisy planes, of shape=(N,1,H,W).

  • clear (torch.Tensor) – The clear planes, of shape=(N,1,H,W).

class dunedn.networks.uscg.uscg_dataloading.UscgPlanesDataset(noisy: ndarray, task: str, channel: str, dsetup: dict, batch_size: int)

Bases: BaseUscgDataset

Loads planes in dataset form for UscgNet inference.
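Wrapping an in-memory array of noisy planes for inference can be sketched with torch.utils.data.Dataset. This is a minimal sketch, assuming the input array has shape (N,H,W) and that each item should be a (1,H,W) float tensor; the class name NoisyPlanesSketch is hypothetical, and the real class may batch or window the data differently:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class NoisyPlanesSketch(Dataset):
    """Hypothetical inference dataset over noisy planes of shape (N,H,W)."""

    def __init__(self, noisy: np.ndarray):
        # convert to float32 and add a channel axis: (N,H,W) -> (N,1,H,W)
        self.noisy = torch.as_tensor(noisy, dtype=torch.float32).unsqueeze(1)

    def __len__(self) -> int:
        return self.noisy.shape[0]

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.noisy[index]
```

A DataLoader built on top of it, e.g. DataLoader(NoisyPlanesSketch(np.zeros((4, 8, 16))), batch_size=2), yields batches of shape (2,1,8,16).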

dunedn.networks.uscg.uscg_net module

This module contains the UscgNet model class.

class dunedn.networks.uscg.uscg_net.UscgNet(channel: str = 'collection', out_channels: int = 1, h_induction: int = 800, h_collection: int = 960, w: int = 6000, stride: int = 1000, pretrained: bool = True, node_size: list[int] = [28, 28], dropout: float = 0.5, enhance_diag: bool = True, aux_pred: bool = True)

Bases: AbstractNet

U-shaped Self Constructing Graph Network.

forward(x)

USCG Net forward pass.

Parameters

x (torch.Tensor) – The input tensor.

Returns

output – The output tensor, of shape=(N,C,H,W).

Return type

torch.Tensor

onnx_export(fname: Path)

Export model to ONNX format.

Parameters

fname (Path) – The path to save the .onnx network.

predict(generator: UscgDataset, dev: str = 'cpu', no_metrics: bool = False, verbose: int = 1) → Tuple[Tensor, dict]

USCG network inference.

Parameters
  • generator (UscgDataset) – The inference dataset generator.

  • dev (str) – The device hosting the computation. Defaults to “cpu”.

  • no_metrics (bool) – Whether to skip metric computation. Defaults to False, i.e. metrics are computed.

  • verbose (int) –

    Switch to log information. Defaults to 1. Available options:

    • 0: no logs.

    • 1: display progress bar.

Returns

  • y_pred (torch.Tensor) – Denoised planes, of shape=(N,1,H,W).

  • logs (dict) – The computed metrics results in dictionary form.

train_batch(noisy: Tensor, clear: Tensor, dev: str)

Makes one batch update.

Parameters
  • noisy (torch.Tensor) – Noisy inputs batch, of shape=(B,1,H,W).

  • clear (torch.Tensor) – Clear target batch, of shape=(B,1,H,W).

  • dev (str) – The device hosting the computation.

Returns

step_logs – The step logs as a dictionary.

Return type

dict
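A single batch update consists of a forward pass, a loss evaluation, backpropagation and an optimizer step. The following is a minimal sketch, assuming an MSE loss and an Adam optimizer (the loss and optimizer actually used by UscgNet are not documented here); train_batch_sketch and the one-layer model are hypothetical:

```python
import torch
from torch import nn


def train_batch_sketch(
    network: nn.Module,
    optimizer: torch.optim.Optimizer,
    noisy: torch.Tensor,
    clear: torch.Tensor,
    dev: str = "cpu",
) -> dict:
    """Performs one optimization step on a (B,1,H,W) batch and returns step logs."""
    network.train()
    noisy, clear = noisy.to(dev), clear.to(dev)
    optimizer.zero_grad()
    y_pred = network(noisy)
    loss = nn.functional.mse_loss(y_pred, clear)  # assumed loss function
    loss.backward()
    optimizer.step()
    return {"loss": loss.item()}


network = nn.Conv2d(1, 1, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
logs = train_batch_sketch(network, optimizer, torch.randn(2, 1, 8, 8), torch.randn(2, 1, 8, 8))
```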

train_epoch(train_loader: DataLoader, dev: str = 'cpu') → dict

Trains the network for one epoch.

Parameters
  • train_loader (torch.utils.data.DataLoader) – The training dataloader.

  • dev (str) – The device hosting the computation. Defaults to “cpu”.

Returns

epoch_logs – The dictionary of epoch logs.

Return type

dict
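An epoch iterates over the training dataloader and aggregates the per-step logs. A minimal sketch, assuming the epoch logs are the mean of the step logs; train_epoch_sketch and its step_fn callable (standing in for a batch update such as train_batch) are hypothetical:

```python
from typing import Callable

import torch
from torch.utils.data import DataLoader, TensorDataset


def train_epoch_sketch(
    train_loader: DataLoader,
    step_fn: Callable[[torch.Tensor, torch.Tensor, str], dict],
    dev: str = "cpu",
) -> dict:
    """Runs step_fn on every (noisy, clear) batch and averages the step logs."""
    totals: dict = {}
    n_steps = 0
    for noisy, clear in train_loader:
        step_logs = step_fn(noisy, clear, dev)
        for key, value in step_logs.items():
            totals[key] = totals.get(key, 0.0) + value
        n_steps += 1
    return {key: value / max(n_steps, 1) for key, value in totals.items()}


# usage with a dummy step reporting the mean absolute difference of each batch
data = TensorDataset(torch.ones(4, 1, 2, 2), torch.zeros(4, 1, 2, 2))
loader = DataLoader(data, batch_size=2)
epoch_logs = train_epoch_sketch(loader, lambda x, y, dev: {"loss": (x - y).abs().mean().item()})
```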

training: bool

dunedn.networks.uscg.uscg_net_blocks module

This module contains the USCG Net building blocks.

class dunedn.networks.uscg.uscg_net_blocks.BatchNorm_GCN(num_features)

Bases: BatchNorm1d

Batch normalization over GCN features.

affine: bool
eps: float
forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

momentum: float
num_features: int
track_running_stats: bool
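BatchNorm1d normalizes over dimension 1 of a (B,C,L) input, whereas graph features are commonly laid out as (B,N,F), with N nodes and F features per node. A minimal sketch of how a BatchNorm1d subclass can normalize each feature across nodes by permuting axes, assuming that (B,N,F) layout; BatchNormGCNSketch is hypothetical:

```python
import torch
from torch import nn


class BatchNormGCNSketch(nn.BatchNorm1d):
    """Batch normalization over node features laid out as (B,N,F)."""

    def __init__(self, num_features: int):
        super().__init__(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B,N,F) -> (B,F,N) so that BatchNorm1d normalizes each of the F features
        x = x.permute(0, 2, 1)
        x = super().forward(x)
        # back to (B,N,F)
        return x.permute(0, 2, 1)


bn = BatchNormGCNSketch(num_features=6)
out = bn(torch.randn(4, 10, 6))
```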
class dunedn.networks.uscg.uscg_net_blocks.GCN_Layer(in_features, out_features, bnorm=True, activation=ReLU(), dropout=None)

Bases: Module

forward(data)

training: bool
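A typical graph-convolution layer aggregates node features through an adjacency matrix, applies a learned linear map, and then optional batch normalization, activation and dropout, matching the constructor arguments listed above. A minimal sketch, assuming data is a (features, adjacency) pair of shapes (B,N,F_in) and (B,N,N) and that the adjacency is passed through unchanged; the real GCN_Layer may order these steps differently, and GCNLayerSketch is hypothetical:

```python
from typing import Optional, Tuple

import torch
from torch import nn


class GCNLayerSketch(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bnorm: bool = True,
        activation: Optional[nn.Module] = None,
        dropout: Optional[float] = None,
    ):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # BatchNorm1d expects (B,C,L), so features are permuted in forward
        self.bnorm = nn.BatchNorm1d(out_features) if bnorm else None
        # None defaults to ReLU, avoiding a shared module instance as default argument
        self.activation = activation if activation is not None else nn.ReLU()
        self.dropout = nn.Dropout(dropout) if dropout else None

    def forward(self, data: Tuple[torch.Tensor, torch.Tensor]):
        x, adj = data  # x: (B,N,F_in), adj: (B,N,N)
        y = self.linear(torch.bmm(adj, x))  # aggregate neighbours, then project: (B,N,F_out)
        if self.bnorm is not None:
            y = self.bnorm(y.permute(0, 2, 1)).permute(0, 2, 1)
        y = self.activation(y)
        if self.dropout is not None:
            y = self.dropout(y)
        return y, adj
```

Returning the adjacency alongside the features lets several such layers be chained, for instance inside nn.Sequential.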
class dunedn.networks.uscg.uscg_net_blocks.Pooling_Block(c, h, w)

Bases: Module

forward(x)

training: bool
class dunedn.networks.uscg.uscg_net_blocks.Recombination_Layer

Bases: Module

This layer recombines an output with a residual connection through a (1,1) convolution. It first concatenates the two inputs along the channel dimension and then applies the convolution.

forward(x, y)

training: bool
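The concatenate-then-(1,1)-convolution recombination described for Recombination_Layer can be sketched as follows. Since Recombination_Layer takes no constructor arguments, its channel counts are presumably fixed inside the class; in this minimal sketch they are exposed as hypothetical in_ch and out_ch parameters, both inputs are assumed to share the same (N,C,H,W) shape, and RecombinationSketch is hypothetical:

```python
import torch
from torch import nn


class RecombinationSketch(nn.Module):
    def __init__(self, in_ch: int = 1, out_ch: int = 1):
        super().__init__()
        # (1,1) convolution mixing the 2*in_ch concatenated channels pixel by pixel
        self.conv = nn.Conv2d(2 * in_ch, out_ch, kernel_size=1)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # concatenate output and residual along the channel dimension
        xy = torch.cat([x, y], dim=1)
        return self.conv(xy)


layer = RecombinationSketch(in_ch=3, out_ch=3)
out = layer(torch.randn(2, 3, 8, 8), torch.randn(2, 3, 8, 8))
```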
class dunedn.networks.uscg.uscg_net_blocks.SCG_Block(in_ch, hidden_ch=6, node_size=(32, 32), add_diag=True, dropout=0.2)

Bases: Module

forward(gx)

laplacian_matrix(A, self_loop=False)

Computes the normalized Laplacian matrix of the adjacency matrix A, of shape=(B,N,N).

training: bool
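The docstring of laplacian_matrix does not state which normalization is used. A common choice in graph convolutional networks is the symmetric normalization D^(-1/2) A D^(-1/2), optionally after adding self-loops (A + I), where D is the diagonal degree matrix; the normalized Laplacian proper is I minus this matrix, but GCN-style layers usually propagate with the normalized adjacency itself. A minimal batched sketch under that assumption; normalized_adjacency is hypothetical and eps guards against zero-degree nodes:

```python
import torch


def normalized_adjacency(A: torch.Tensor, self_loop: bool = False, eps: float = 1e-5) -> torch.Tensor:
    """Symmetrically normalizes a batch of adjacency matrices of shape (B,N,N)."""
    if self_loop:
        # add the identity to every matrix in the batch
        A = A + torch.eye(A.size(1), device=A.device, dtype=A.dtype).unsqueeze(0)
    deg_inv_sqrt = (A.sum(dim=-1) + eps).pow(-0.5)  # (B,N)
    # scale row i by d_i^(-1/2) and column j by d_j^(-1/2)
    return deg_inv_sqrt.unsqueeze(-1) * A * deg_inv_sqrt.unsqueeze(-2)
```

For a fully connected two-node graph without self-loops, every node has degree 2, so each entry becomes approximately 1/2.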

dunedn.networks.uscg.uscg_net_utils module

This module contains the utility functions for USCG network.

dunedn.networks.uscg.uscg_net_utils.weight_xavier_init(*models)
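No description is given for weight_xavier_init. Its name and variadic *models argument suggest that it applies Xavier (Glorot) initialization to the layers of each given model. A minimal sketch of that pattern, in which the use of xavier_normal_, zero biases and unit/zero batch-norm parameters is an assumption; xavier_init_sketch is hypothetical:

```python
import torch
from torch import nn


def xavier_init_sketch(*models: nn.Module) -> None:
    """Initializes conv, linear and batch-norm parameters of every given model in place."""
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)


net_a = nn.Sequential(nn.Conv2d(1, 4, 3), nn.BatchNorm2d(4))
net_b = nn.Linear(8, 2)
xavier_init_sketch(net_a, net_b)
```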

dunedn.networks.uscg.utils module

This module implements utility functions for the networks.uscg subpackage.

dunedn.networks.uscg.utils.make_dict_compatible(state_dict: OrderedDict)

Transforms the state_dict keys to match the new GcnnNet attribute format.

Changed in version 2.0.0:

  • Removes “.module” in front of the saved weight names.

  • Removes extra attributes due to the deprecated normalization layer.

  • Transforms layer names to lowercase.

Parameters

state_dict (OrderedDict) – The original dictionary containing network saved weights.

Returns

new_state_dict – The updated version of the dictionary.

Return type

OrderedDict
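The key transformations listed for make_dict_compatible amount to plain string rewriting over the state_dict. A minimal sketch, assuming the prefix to strip is the "module." prefix that torch.nn.DataParallel adds to parameter names (the list above writes it as “.module”) and that deprecated entries are identified by a caller-supplied collection of key substrings; make_keys_compatible_sketch, its deprecated parameter and the example keys are hypothetical:

```python
from collections import OrderedDict
from typing import Iterable


def make_keys_compatible_sketch(
    state_dict: OrderedDict,
    prefix: str = "module.",
    deprecated: Iterable[str] = (),
) -> OrderedDict:
    """Strips a key prefix, drops deprecated entries and lowercases key names."""
    deprecated = tuple(deprecated)
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        if any(pattern in key for pattern in deprecated):
            continue  # skip attributes of the removed layer
        new_state_dict[key.lower()] = value
    return new_state_dict


old = OrderedDict([("module.Conv1.weight", 1), ("module.OldNorm.scale", 2), ("module.Conv1.bias", 3)])
new = make_keys_compatible_sketch(old, deprecated=["OldNorm"])
```

The insertion order of the remaining keys is preserved, which keeps the dictionary aligned with the module's parameter registration order.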

dunedn.networks.uscg.utils.time_windows(planes: Tensor, w: int, stride: int) → Tuple[Tensor, list[torch.Tensor], list[list[int]]]

Takes time windows of the given width and stride from a batch of planes.

Parameters
  • planes (torch.Tensor) – Planes of shape=(N,C,H,W).

  • w (int) – Width of the time windows.

  • stride (int) – Step between the start indices of consecutive time windows.

Returns

  • divisions (torch.Tensor) – Number of times each pixel should be processed by the denoising network, of shape=(N,C,H,W).

  • windows (list[torch.Tensor]) – Time windows to be processed, each of shape=(N,C,H,w).

  • idxs (list[list[int]]) – Start-end time indices for the corresponding window. Each element is [start idx, end idx].
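The windowing performed by time_windows can be sketched as follows. This is a minimal sketch, assuming windows start every stride columns along the last (time) axis and that, when the last regular window does not reach the end of the plane, one extra window aligned with the right edge is added; how the actual function handles that edge is not documented here, and time_windows_sketch is hypothetical:

```python
from typing import List, Tuple

import torch


def time_windows_sketch(
    planes: torch.Tensor, w: int, stride: int
) -> Tuple[torch.Tensor, List[torch.Tensor], List[List[int]]]:
    """Splits (N,C,H,W) planes into windows of width w along the time axis."""
    width = planes.shape[-1]
    # if w exceeds the plane width, a single window covering the whole plane is used
    starts = list(range(0, max(width - w, 0) + 1, stride))
    if starts[-1] + w < width:
        starts.append(width - w)  # extra window flush with the right edge
    idxs = [[start, min(start + w, width)] for start in starts]
    windows = [planes[..., start:end] for start, end in idxs]
    # count how many windows cover each pixel
    divisions = torch.zeros_like(planes)
    for start, end in idxs:
        divisions[..., start:end] += 1
    return divisions, windows, idxs
```

Because windows may overlap, dividing the re-assembled sum of the processed windows by divisions is one natural way to average overlapping predictions.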

dunedn.networks.uscg.utils.uscg_inference_pass(test_loader: DataLoader, network: AbstractNet, dev: str, verbose: int = 1) → Tensor

Passes the data through the USCG network and returns its outputs.

Parameters
  • test_loader (torch.utils.data.DataLoader) – The inference dataset generator.

  • network (AbstractNet) – The denoising network.

  • dev (str) – The device hosting the computation.

  • verbose (int) –

    Switch to log information. Defaults to 1. Available options:

    • 0: no logs.

    • 1: display progress bar.

Returns

output – Denoised data, of shape=(N,1,H,W).

Return type

torch.Tensor
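An inference pass of this kind typically switches the network to evaluation mode, disables gradient tracking and concatenates the per-batch outputs. A minimal sketch, assuming the loader yields input tensors (or tuples whose first element is the input) and leaving out both the progress bar selected by verbose and any time-window stitching the actual function may perform; inference_pass_sketch is hypothetical:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def inference_pass_sketch(test_loader: DataLoader, network: nn.Module, dev: str = "cpu") -> torch.Tensor:
    """Runs the network on every batch and returns the outputs concatenated along N."""
    network.to(dev)
    network.eval()  # e.g. disables dropout, batch norm uses running statistics
    outputs = []
    with torch.no_grad():
        for batch in test_loader:
            noisy = batch[0] if isinstance(batch, (list, tuple)) else batch
            outputs.append(network(noisy.to(dev)).cpu())
    return torch.cat(outputs, dim=0)


loader = DataLoader(TensorDataset(torch.randn(5, 1, 4, 6)), batch_size=2)
output = inference_pass_sketch(loader, nn.Conv2d(1, 1, kernel_size=3, padding=1))
```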

Module contents