dunedn.networks.gcnn package

Submodules

dunedn.networks.gcnn.gcnn_dataloading module

This module implements dataset loading for the CNN and GCNN networks.

class dunedn.networks.gcnn.gcnn_dataloading.BaseGcnnDataset(dataset_type: str, task: str = 'dn', channel: str = 'collection', dsetup: Optional[dict] = None, batch_size: int = 128)

Bases: Dataset, ABC

Loads the dataset for CNN and GCNN networks.

class dunedn.networks.gcnn.gcnn_dataloading.GcnnDataset(dataset_type: str, task: str, channel: str, dsetup: dict, batch_size: int)

Bases: BaseGcnnDataset

Loads the dataset for CNN and GCNN networks.

to_crops()

Converts planes into crops.

Note

This method should not be called when training.

to_planes()

Converts crops into planes.

class dunedn.networks.gcnn.gcnn_dataloading.GcnnPlanesDataset(noisy: ndarray, task: str, channel: str, dsetup: dict, batch_size: int)

Bases: BaseGcnnDataset

Loads the dataset for CNN and GCNN networks.

to_crops()

Converts planes into crops.

Note

This method should not be called when training.

to_planes()

Converts crops into planes.

dunedn.networks.gcnn.gcnn_net module

This module contains the GcnnNet model class.

GcnnNet also implements the CNN variant.

class dunedn.networks.gcnn.gcnn_net.GcnnNet(model: str, task: str, crop_edge: int, input_channels: int, hidden_channels: int, k: Optional[int] = None)

Bases: AbstractNet

Graph Convolutional Neural Network implementation.

forward(x: Tensor) → Tensor

Gcnn forward pass.

Parameters

x (torch.Tensor) – Input tensor of shape=(N,C,H,W).

Returns

output – Output tensor of shape=(N,C,H,W).

Return type

torch.Tensor

onnx_export(fname: Path)

Export model to ONNX format.

Parameters

fname (Path) – The path to save the .onnx network.

predict(generator: BaseGcnnDataset, dev: str = 'cpu', no_metrics: bool = False, verbose: int = 1, profiler: Optional[BatchProfiler] = None) → Tuple[Tensor, list[Tuple[float, float]], float]

GCNN network inference.

Parameters
  • generator (BaseGcnnDataset) – The inference dataset generator.

  • dev (str) – The device hosting the computation. Defaults to “cpu”.

  • no_metrics (bool) – Whether to skip metric computation. Defaults to False, i.e. metrics are computed.

  • verbose (int) –

    Switch to log information. Defaults to 1. Available options:

    • 0: no logs.

    • 1: display progress bar.

  • profiler (BatchProfiler) – The profiler object to record batch inference time.

Returns

  • y_pred (torch.Tensor) – Denoised planes, of shape=(N,1,H,W).

  • logs (dict) – The computed metrics results in dictionary form.

train_batch(noisy: Tensor, clear: Tensor, dev: str)

Makes one batch update.

Parameters
  • noisy (torch.Tensor) – Noisy inputs batch, of shape=(B,1,H,W).

  • clear (torch.Tensor) – Clear target batch, of shape=(B,1,H,W).

  • dev (str) – The device hosting the computation.

Returns

step_logs – The step logs as a dictionary.

Return type

dict
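The batch update can be pictured with the minimal sketch below. It assumes a mean-squared-error loss, an SGD optimizer and a "loss" key in the step logs. GcnnNet presumably keeps its optimizer and loss as attributes; here `train_batch_sketch`, a hypothetical helper, receives them explicitly and a single `nn.Conv2d` stands in for the network.

```python
import torch
from torch import nn


def train_batch_sketch(network, optimizer, loss_fn, noisy, clear, dev="cpu"):
    """Performs one optimization step on a (B,1,H,W) batch and returns the step logs."""
    network.train()
    noisy, clear = noisy.to(dev), clear.to(dev)
    optimizer.zero_grad()            # clear gradients left by the previous step
    loss = loss_fn(network(noisy), clear)
    loss.backward()                  # backpropagate
    optimizer.step()                 # update the weights
    return {"loss": loss.item()}


torch.manual_seed(0)
net = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # stand-in for GcnnNet
opt = torch.optim.SGD(net.parameters(), lr=0.1)
noisy = torch.randn(4, 1, 8, 8)
clear = torch.zeros(4, 1, 8, 8)
step_logs = train_batch_sketch(net, opt, nn.MSELoss(), noisy, clear)
```

Calling the helper again on the same batch should report a lower loss, since the first call already moved the weights towards the target.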

train_epoch(train_loader: DataLoader, dev: str = 'cpu') → dict

Trains the network for one epoch.

Parameters
  • train_loader (torch.utils.data.DataLoader) – The training dataloader.

  • dev (str) – The device hosting the computation. Defaults to “cpu”.

Returns

epoch_logs – The dictionary of epoch logs.

Return type

dict
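A sketch of one training epoch follows, assuming that `train_loader` yields `(noisy, clear)` pairs and that the epoch logs hold the sample-weighted mean loss. `train_epoch_sketch` and the toy data are illustrative only and do not reproduce the GcnnNet implementation.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_epoch_sketch(network, optimizer, loss_fn, train_loader, dev="cpu"):
    """Runs one pass over train_loader and returns the mean training loss."""
    network.train()
    total, n_samples = 0.0, 0
    for noisy, clear in train_loader:
        noisy, clear = noisy.to(dev), clear.to(dev)
        optimizer.zero_grad()
        loss = loss_fn(network(noisy), clear)
        loss.backward()
        optimizer.step()
        # weight each batch by its size so a smaller last batch is not over-counted
        total += loss.item() * noisy.shape[0]
        n_samples += noisy.shape[0]
    return {"loss": total / n_samples}


torch.manual_seed(0)
dataset = TensorDataset(torch.randn(10, 1, 8, 8), torch.zeros(10, 1, 8, 8))
loader = DataLoader(dataset, batch_size=4, shuffle=True)
net = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # stand-in for GcnnNet
opt = torch.optim.SGD(net.parameters(), lr=0.1)
epoch_logs = train_epoch_sketch(net, opt, nn.MSELoss(), loader)
```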

training: bool

dunedn.networks.gcnn.gcnn_net_blocks module

This module contains the GCNN Net building blocks.

class dunedn.networks.gcnn.gcnn_net_blocks.Conv(ic, oc)

Bases: Module

Conv layer.

forward(x, graph)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool

class dunedn.networks.gcnn.gcnn_net_blocks.GConv(ic, oc)

Bases: Module

GConv layer.

forward(x, graph)

training: bool

class dunedn.networks.gcnn.gcnn_net_blocks.HPF(ic, oc, getgraph_fn, model)

Bases: Module

High-pass filter.

forward(x)

training: bool

class dunedn.networks.gcnn.gcnn_net_blocks.LPF(ic, oc, getgraph_fn, model)

Bases: Module

Low-pass filter.

forward(x)

training: bool

class dunedn.networks.gcnn.gcnn_net_blocks.NonLocalAggregator(input_channels, out_channels)

Bases: Module

NonLocalAggregator layer.

forward(x, graph)

Parameters
  • x

  • graph

Returns

Output tensor of shape=(N,C,H,W).

Return type

torch.Tensor

training: bool

class dunedn.networks.gcnn.gcnn_net_blocks.NonLocalGraph(k, crop_size)

Bases: object

Non-local graph layer.

class dunedn.networks.gcnn.gcnn_net_blocks.PostProcessBlock(ic, hc, getgraph_fn, model)

Bases: Module

forward(x)

training: bool

class dunedn.networks.gcnn.gcnn_net_blocks.PreProcessBlock(kernel_size, ic, oc, getgraph_fn, model)

Bases: Module

forward(x)

training: bool

class dunedn.networks.gcnn.gcnn_net_blocks.ROI(kernel_size, ic, hc, getgraph_fn, model)

Bases: Module

U-Net-style binary segmentation.

forward(x)

training: bool

dunedn.networks.gcnn.gcnn_net_blocks.choose_conv(model, ic, oc)

Utility function to retrieve the GConv or Conv layer from its name.

Parameters
  • model – The layer type, either “gcnn” (GConv) or “cnn” (Conv).

  • ic – Number of input channels.

  • oc – Number of output channels.

Returns

The layer instance.

Return type

torch.nn.Module

Raises

NotImplementedError – If model is not in ['gcnn', 'cnn'].
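The name-based dispatch can be sketched with a dictionary lookup. `ConvStub` and `GConvStub` are hypothetical placeholders that only mimic the `(ic, oc)` constructor and the `forward(x, graph)` call of Conv and GConv; they do not reproduce the real layers, and the error message is an assumption.

```python
import torch
from torch import nn


class ConvStub(nn.Module):
    """Placeholder for the plain convolution layer; it ignores the graph."""

    def __init__(self, ic, oc):
        super().__init__()
        self.op = nn.Conv2d(ic, oc, kernel_size=3, padding=1)

    def forward(self, x, graph=None):
        return self.op(x)


class GConvStub(nn.Module):
    """Placeholder for the graph convolution layer (a 1x1 conv stands in here)."""

    def __init__(self, ic, oc):
        super().__init__()
        self.op = nn.Conv2d(ic, oc, kernel_size=1)

    def forward(self, x, graph=None):
        # the real layer would aggregate the neighbours described by `graph`
        return self.op(x)


def choose_conv_sketch(model, ic, oc):
    """Maps the layer name to a class and instantiates it."""
    layers = {"gcnn": GConvStub, "cnn": ConvStub}
    if model not in layers:
        raise NotImplementedError(f"{model!r} is not in {list(layers)}")
    return layers[model](ic, oc)


layer = choose_conv_sketch("cnn", ic=1, oc=8)
out = layer(torch.randn(2, 1, 16, 16), None)  # shape=(2, 8, 16, 16)
```

Keeping the same `forward(x, graph)` signature for both layers is what lets the caller swap CNN and GCNN variants without changing the surrounding blocks.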

dunedn.networks.gcnn.gcnn_net_utils module

This module contains the utility functions for CNN and GCNN networks.

class dunedn.networks.gcnn.gcnn_net_utils.Converter(crop_size: Tuple[int])

Bases: object

Groups the image-to-tiles converter functions.

planes2tiles(planes: Tensor) → Tensor

Parameters

planes (torch.Tensor) – Planes, of shape=(N,C,H,W).

Returns

Tiles, of shape=(N’,C,edge_h,edge_w), with N’ = N * ceil(H/edge_h) * ceil(W/edge_w).

Return type

torch.Tensor

tiles2planes(splits: Tensor) → Tensor

Parameters

splits (torch.Tensor) – Tiles, of shape=(N’,C,edge_h,edge_w).

Returns

Planes, of shape=(N,C,H,W).

Return type

torch.Tensor
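A possible planes-to-tiles round trip is sketched below. It assumes the planes are zero-padded at the bottom and right edges up to a multiple of the crop edge, and that the converter remembers the original plane shape for the inverse transform. `ConverterSketch` is hypothetical; the real Converter may place the padding differently.

```python
import math

import torch
import torch.nn.functional as F


class ConverterSketch:
    """Hypothetical planes <-> tiles converter for a fixed (edge_h, edge_w) crop."""

    def __init__(self, crop_size):
        self.edge_h, self.edge_w = crop_size

    def planes2tiles(self, planes):
        n, c, h, w = planes.shape
        self.plane_hw = (h, w)  # remembered for the inverse transform
        self.nb_h = math.ceil(h / self.edge_h)
        self.nb_w = math.ceil(w / self.edge_w)
        # zero-pad bottom and right so both sizes become multiples of the crop edge
        pad = (0, self.nb_w * self.edge_w - w, 0, self.nb_h * self.edge_h - h)
        padded = F.pad(planes, pad)
        tiles = padded.reshape(n, c, self.nb_h, self.edge_h, self.nb_w, self.edge_w)
        # (N,C,nb_h,edge_h,nb_w,edge_w) -> (N,nb_h,nb_w,C,edge_h,edge_w)
        tiles = tiles.permute(0, 2, 4, 1, 3, 5)
        return tiles.reshape(-1, c, self.edge_h, self.edge_w)

    def tiles2planes(self, splits):
        c = splits.shape[1]
        n = splits.shape[0] // (self.nb_h * self.nb_w)
        planes = splits.reshape(n, self.nb_h, self.nb_w, c, self.edge_h, self.edge_w)
        # undo the permutation, then merge the tile grid back into full planes
        planes = planes.permute(0, 3, 1, 4, 2, 5)
        planes = planes.reshape(n, c, self.nb_h * self.edge_h, self.nb_w * self.edge_w)
        h, w = self.plane_hw
        return planes[..., :h, :w]  # crop away the padding


torch.manual_seed(0)
planes = torch.randn(2, 1, 10, 7)       # N=2 planes with H=10, W=7
converter = ConverterSketch((4, 4))
tiles = converter.planes2tiles(planes)  # N' = 2 * ceil(10/4) * ceil(7/4) = 12
```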

dunedn.networks.gcnn.gcnn_net_utils.batched_index_select(t, dim, inds)

Selects the K nearest neighbors indices for each pixel, respecting the batch dimension.

Parameters
  • t

  • dim

  • inds

Returns

Index tensor of shape=(N,H*W*K,C).

Return type

torch.Tensor
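Assuming `t` holds pixel features of shape (N,H*W,C), `inds` holds K neighbor indices per pixel of shape (N,H*W,K) and `dim=1`, the selection reduces to a single `torch.gather` call. `batched_index_select_sketch` is a hypothetical reading of the documented shapes, not the dunedn code.

```python
import torch


def batched_index_select_sketch(t, dim, inds):
    """Gathers, for every pixel, the features of its K neighbours.

    t:    (N, H*W, C) pixel features
    dim:  axis of t holding the pixels (1 in this sketch)
    inds: (N, H*W, K) neighbour indices along that axis
    returns (N, H*W*K, C)
    """
    n, m, k = inds.shape
    c = t.shape[-1]
    # flatten the neighbour lists and repeat each index over the C channels
    index = inds.reshape(n, m * k, 1).expand(n, m * k, c)
    return torch.gather(t, dim, index)


t = torch.arange(24.0).reshape(2, 4, 3)                      # N=2, H*W=4, C=3
inds = torch.tensor([[[1, 2], [0, 3], [3, 0], [2, 1]]] * 2)  # K=2 neighbours
out = batched_index_select_sketch(t, 1, inds)                # shape=(2, 8, 3)
```

Rows i*K to i*K+K-1 of the output hold the features of the K neighbors of pixel i, in the order given by `inds`.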

dunedn.networks.gcnn.gcnn_net_utils.calculate_pad(plane_size, crop_size)

Given the plane and crop shapes, computes the padding needed to obtain an exact tiling.

Parameters
  • plane_size

  • crop_size

Returns

The plane padding, as [pre h, post h, pre w, post w].

Return type

list
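A sketch of the padding computation, assuming the missing rows and columns are split as evenly as possible between the two sides, with any odd extra pixel going to the post side. `calculate_pad_sketch` is hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def calculate_pad_sketch(plane_size, crop_size):
    """Padding that makes an (H, W) plane an exact multiple of an (h, w) crop."""
    pad = []
    for size, edge in zip(plane_size, crop_size):
        total = math.ceil(size / edge) * edge - size
        pre = total // 2  # assumption: split as evenly as possible
        pad.extend([pre, total - pre])
    return pad  # [pre h, post h, pre w, post w]


pad = calculate_pad_sketch((10, 7), (4, 4))
pre_h, post_h, pre_w, post_w = pad
# F.pad lists the last dimension first, so the order must be swapped
padded = F.pad(torch.zeros(1, 1, 10, 7), (pre_w, post_w, pre_h, post_h))
```

Note that `torch.nn.functional.pad` expects the width pair before the height pair, which is why the list is reordered before padding.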

dunedn.networks.gcnn.gcnn_net_utils.local_mask(crop_size)

Computes the mask that removes local pixels from the computation.

Parameters

crop_size

Returns

Local mask of shape=(1,H*W,H*W).

Return type

torch.Tensor
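The docstring does not define which pixels count as local. The sketch below assumes they are the ones inside a (2r+1) x (2r+1) window around each pixel, where `radius` r is a hypothetical parameter, and returns a boolean mask whose `True` entries mark pixel pairs to exclude. The real function may use a different window or an additive mask instead.

```python
import torch


def local_mask_sketch(crop_size, radius=1):
    """True where pixel j lies inside the (2*radius+1)^2 window around pixel i.

    `radius` is a hypothetical parameter; returns a bool tensor (1, H*W, H*W).
    """
    h, w = crop_size
    rows = torch.arange(h).repeat_interleave(w)  # row of each flattened pixel
    cols = torch.arange(w).repeat(h)             # column of each flattened pixel
    drow = (rows[:, None] - rows[None, :]).abs()
    dcol = (cols[:, None] - cols[None, :]).abs()
    return ((drow <= radius) & (dcol <= radius)).unsqueeze(0)


mask = local_mask_sketch((3, 4), radius=1)  # 3x4 crop -> shape=(1, 12, 12)
```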

dunedn.networks.gcnn.gcnn_net_utils.pairwise_dist(arr, k, local_mask)

Computes the pairwise Euclidean distances between pixels.

Parameters
  • arr

  • k

  • local_mask

Returns

Pairwise pixel distances of shape=(N,H*W,H*W).

Return type

torch.Tensor
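A sketch of the distance computation with `torch.cdist`, assuming `arr` holds pixel features of shape (N,H*W,C) and `local_mask` is a boolean tensor of shape (1,H*W,H*W) whose `True` entries are pushed to infinity. How the real function uses `k` is not documented here; the usage lines assume it selects the k nearest pixels with `torch.topk`.

```python
import torch


def pairwise_dist_sketch(arr, local_mask):
    """Euclidean distances between all pixel pairs of every sample.

    arr:        (N, H*W, C) pixel features
    local_mask: (1, H*W, H*W) bool, True for pairs to exclude
    returns     (N, H*W, H*W), excluded pairs set to +inf
    """
    dist = torch.cdist(arr, arr)  # batched Euclidean distances
    return dist.masked_fill(local_mask, float("inf"))  # mask broadcasts over N


torch.manual_seed(0)
arr = torch.randn(2, 12, 3)                          # N=2, H*W=12, C=3
mask = torch.eye(12, dtype=torch.bool).unsqueeze(0)  # here: exclude self-pairs only
dist = pairwise_dist_sketch(arr, mask)
k = 4
knn = dist.topk(k, dim=-1, largest=False).indices    # shape=(N, H*W, k)
```

Setting excluded pairs to +inf, rather than dropping them, keeps the (N,H*W,H*W) shape so that a smallest-k selection skips them automatically.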

dunedn.networks.gcnn.training module

This module provides functions for CNN and GCNN networks training and loading.

dunedn.networks.gcnn.training.gcnn_training(modeltype: str, setup: dict)

GCNN network training.

Parameters
  • modeltype (str) – The model to be trained. Available options: cnn | gcnn.

  • setup (dict) – Settings dictionary.

dunedn.networks.gcnn.training.load_and_compile_gcnn_network(channel: str, msetup: dict, checkpoint_filepath: Optional[Path] = None) → GcnnNet

Loads a CNN or GCNN network.

Parameters
  • channel (str) – Available options: induction | collection.

  • msetup (dict) – The model setup dictionary.

  • checkpoint_filepath (Path) – The .pth checkpoint containing network weights to be loaded.

Returns

network – The loaded neural network.

Return type

GcnnNet

dunedn.networks.gcnn.utils module

This module implements utility functions for the networks.gcnn subpackage.

dunedn.networks.gcnn.utils.gcnn_inference_pass(test_loader: DataLoader, network: AbstractNet, dev: str, verbose: int = 1, profiler: Optional[BatchProfiler] = None) → Tensor

Passes the data through a CNN or GCNN network and returns the outputs.

Parameters
  • test_loader (torch.utils.data.DataLoader) – The inference dataset generator.

  • network (AbstractNet) – The denoising network.

  • dev (str) – The device hosting the computation.

  • verbose (int) –

    Switch to log information. Defaults to 1. Available options:

    • 0: no logs.

    • 1: display progress bar.

  • profiler (BatchProfiler) – The profiler object to record batch inference time.

Returns

output – Denoised data, of shape=(N,1,H,W).

Return type

torch.Tensor
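A minimal sketch of the inference loop, assuming the loader yields batches that contain only the noisy inputs. The progress bar (`verbose`) and the `profiler` hooks of the real function are left out, `inference_pass_sketch` is hypothetical, and a single `nn.Conv2d` stands in for the network.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def inference_pass_sketch(test_loader, network, dev="cpu"):
    """Runs the network over every batch and concatenates the outputs."""
    network.eval()
    outputs = []
    with torch.no_grad():  # no gradients are needed at inference time
        for (noisy,) in test_loader:
            outputs.append(network(noisy.to(dev)).cpu())
    return torch.cat(outputs)


loader = DataLoader(TensorDataset(torch.randn(10, 1, 8, 8)), batch_size=4)
net = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # stand-in denoising network
output = inference_pass_sketch(loader, net)      # shape=(10, 1, 8, 8)
```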

dunedn.networks.gcnn.utils.make_dict_compatible(state_dict: OrderedDict)

Transforms state_dict keys to match the new GcnnNet attribute format.

Changed in version 2.0.0:

  • Removes “.module” in front of the saved weight names.

  • Removes extra attributes due to the deprecated normalization layer.

  • Transforms layer names to lowercase.

Parameters

state_dict (OrderedDict) – The original dictionary containing network saved weights.

Returns

new_state_dict – The dictionary updated version.

Return type

OrderedDict
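The three key transformations can be sketched as below. Two assumptions are involved: the “.module” prefix is taken to be the `module.` prefix that `torch.nn.DataParallel` adds to parameter names, and `drop_keys`, a hypothetical argument, lists substrings identifying the attributes of the deprecated normalization layer, whose real names are not given here. The key names in the usage lines are made up for illustration.

```python
from collections import OrderedDict


def make_dict_compatible_sketch(state_dict, prefix="module.", drop_keys=()):
    """Returns a renamed copy of `state_dict` (hypothetical helper)."""
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if any(s in key for s in drop_keys):
            continue  # drop leftovers of the deprecated normalization layer
        if key.startswith(prefix):
            key = key[len(prefix):]  # strip the wrapper prefix
        new_state_dict[key.lower()] = value  # lowercase layer names
    return new_state_dict


old = OrderedDict(
    [("module.LPF.conv.weight", 1), ("module.Norm.scale", 2), ("module.HPF.conv.bias", 3)]
)
new = make_dict_compatible_sketch(old, drop_keys=("Norm",))
# new keys: "lpf.conv.weight", "hpf.conv.bias"
```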

Module contents