dunedn.networks.gcnn package
Submodules
dunedn.networks.gcnn.gcnn_dataloading module
This module implements dataset loading for the CNN and GCNN networks.
- class dunedn.networks.gcnn.gcnn_dataloading.BaseGcnnDataset(dataset_type: str, task: str = 'dn', channel: str = 'collection', dsetup: Optional[dict] = None, batch_size: int = 128)
Bases: Dataset, ABC
Loads the dataset for CNN and GCNN networks.
- class dunedn.networks.gcnn.gcnn_dataloading.GcnnDataset(dataset_type: str, task: str, channel: str, dsetup: dict, batch_size: int)
Bases: BaseGcnnDataset
Loads the dataset for CNN and GCNN networks.
- class dunedn.networks.gcnn.gcnn_dataloading.GcnnPlanesDataset(noisy: ndarray, task: str, channel: str, dsetup: dict, batch_size: int)
Bases: BaseGcnnDataset
Loads the dataset for CNN and GCNN networks.
dunedn.networks.gcnn.gcnn_net module
This module contains the GcnnNet model class.
GcnnNet also implements the CNN variant.
- class dunedn.networks.gcnn.gcnn_net.GcnnNet(model: str, task: str, crop_edge: int, input_channels: int, hidden_channels: int, k: Optional[int] = None)
Bases: AbstractNet
Graph Convolutional Neural Network implementation.
- forward(x: Tensor) → Tensor
Gcnn forward pass.
- Parameters
x (torch.Tensor) – Input tensor of shape=(N,C,H,W).
- Returns
output – Output tensor of shape=(N,C,H,W).
- Return type
torch.Tensor
- onnx_export(fname: Path)
Export model to ONNX format.
- Parameters
fname (Path) – The path to save the .onnx network.
- predict(generator: BaseGcnnDataset, dev: str = 'cpu', no_metrics: bool = False, verbose: int = 1, profiler: Optional[BatchProfiler] = None) → Tuple[Tensor, list[Tuple[float, float]], float]
Gcnn network inference.
- Parameters
generator (GcnnDataset) – The inference dataset generator.
dev (str) – The device hosting the computation. Defaults to “cpu”.
no_metrics (bool) – Whether to skip metric computation. Defaults to False, so metrics are computed.
verbose (int) –
Switch to log information. Defaults to 1. Available options:
0: no logs.
1: display progress bar.
profiler (BatchProfiler) – The profiler object to record batch inference time.
- Returns
y_pred (torch.Tensor) – Denoised planes, of shape=(N,1,H,W).
logs (dict) – The computed metrics results in dictionary form.
- train_batch(noisy: Tensor, clear: Tensor, dev: str)
Makes one batch update.
- Parameters
noisy (torch.Tensor) – Noisy inputs batch, of shape=(B,1,H,W).
clear (torch.Tensor) – Clear target batch, of shape=(B,1,H,W).
dev (str) – The device hosting the computation.
- Returns
step_logs – The step logs as a dictionary.
- Return type
dict
- train_epoch(train_loader: DataLoader, dev: str = 'cpu') → dict
Trains the network for one epoch.
- Parameters
train_loader (torch.utils.data.DataLoader) – The training dataloader.
dev (str) – The device hosting the computation. Defaults to “cpu”.
- Returns
epoch_logs – The dictionary of epoch logs.
- Return type
dict
- training: bool
dunedn.networks.gcnn.gcnn_net_blocks module
This module contains the GCNN Net building blocks.
- class dunedn.networks.gcnn.gcnn_net_blocks.Conv(ic, oc)
Bases: Module
GConv layer.
- forward(x, graph)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class dunedn.networks.gcnn.gcnn_net_blocks.GConv(ic, oc)
Bases: Module
GConv layer.
- forward(x, graph)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class dunedn.networks.gcnn.gcnn_net_blocks.HPF(ic, oc, getgraph_fn, model)
Bases: Module
High-pass filter.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class dunedn.networks.gcnn.gcnn_net_blocks.LPF(ic, oc, getgraph_fn, model)
Bases: Module
Low-pass filter.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class dunedn.networks.gcnn.gcnn_net_blocks.NonLocalAggregator(input_channels, out_channels)
Bases: Module
NonLocalAggregator layer.
- forward(x, graph)
- Parameters
x
graph
- Returns
Tensor of shape=(N,C,H,W).
- Return type
torch.Tensor
- training: bool
- class dunedn.networks.gcnn.gcnn_net_blocks.NonLocalGraph(k, crop_size)
Bases: object
Non-local graph layer.
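A non-local graph links each pixel to the k pixels whose feature vectors are closest to it, regardless of how far apart they sit in the plane. The following is a minimal pure-Python sketch of such a k-nearest-neighbor graph, assuming squared Euclidean distance on per-pixel features and no self-loops; the function `knn_graph` is hypothetical and does not reproduce NonLocalGraph's internals (for instance, how it uses `crop_size` or whether a pixel counts as its own neighbor).

```python
from typing import List, Sequence


def knn_graph(features: Sequence[Sequence[float]], k: int) -> List[List[int]]:
    """Hypothetical sketch: for each pixel, return the indices of its k nearest
    pixels in feature space (squared Euclidean distance, self excluded)."""
    n = len(features)
    if not 0 < k < n:
        raise ValueError("k must satisfy 0 < k < number of pixels")
    graph = []
    for i, fi in enumerate(features):
        # Squared distances are enough for ranking, so no sqrt is taken.
        dists = [
            (sum((a - b) ** 2 for a, b in zip(fi, fj)), j)
            for j, fj in enumerate(features)
            if j != i
        ]
        dists.sort()  # ascending distance; ties broken by the smaller index
        graph.append([j for _, j in dists[:k]])
    return graph


# Four pixels with 1-D features forming two clusters: {0, 1} and {2, 3}.
print(knn_graph([[0.0], [0.1], [5.0], [5.2]], k=1))  # [[1], [0], [3], [2]]
```

This brute-force search costs O(P²) distance evaluations for P pixels, so building the graph on fixed-size crops rather than on a full plane keeps the cost bounded.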
- class dunedn.networks.gcnn.gcnn_net_blocks.PostProcessBlock(ic, hc, getgraph_fn, model)
Bases: Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class dunedn.networks.gcnn.gcnn_net_blocks.PreProcessBlock(kernel_size, ic, oc, getgraph_fn, model)
Bases: Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class dunedn.networks.gcnn.gcnn_net_blocks.ROI(kernel_size, ic, hc, getgraph_fn, model)
Bases: Module
U-Net-style binary segmentation.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
dunedn.networks.gcnn.gcnn_net_utils module
This module contains the utility functions for CNN and GCNN networks.
- class dunedn.networks.gcnn.gcnn_net_utils.Converter(crop_size: Tuple[int])
Bases: object
Groups the image-to-tiles converter functions.
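Converter groups the functions that split a plane into fixed-size crops (tiles) and stitch them back together. Below is a minimal pure-Python sketch of that round trip, assuming a 2-D plane stored as a list of rows whose height and width are exact multiples of the crop size (the condition that calculate_pad's padding is meant to establish) and a row-major tile order; `planes_to_tiles` and `tiles_to_plane` are hypothetical names, not Converter's actual methods.

```python
from typing import List, Tuple

Plane = List[List[float]]


def planes_to_tiles(plane: Plane, crop_size: Tuple[int, int]) -> List[Plane]:
    """Hypothetical sketch: split a plane into non-overlapping crops, row-major."""
    ch, cw = crop_size
    h, w = len(plane), len(plane[0])
    if h % ch or w % cw:
        raise ValueError("plane shape must be a multiple of crop_size; pad first")
    tiles = []
    for top in range(0, h, ch):
        for left in range(0, w, cw):
            tiles.append([row[left:left + cw] for row in plane[top:top + ch]])
    return tiles


def tiles_to_plane(tiles: List[Plane], plane_shape: Tuple[int, int]) -> Plane:
    """Inverse of planes_to_tiles: stitch row-major crops back into one plane."""
    h, w = plane_shape
    ch, cw = len(tiles[0]), len(tiles[0][0])
    tiles_per_row = w // cw
    plane = [[0.0] * w for _ in range(h)]
    for idx, tile in enumerate(tiles):
        top = (idx // tiles_per_row) * ch
        left = (idx % tiles_per_row) * cw
        for r, row in enumerate(tile):
            plane[top + r][left:left + cw] = row  # copy one crop row into place
    return plane


# A 4x4 plane holding 0..15, split into four 2x2 crops.
plane = [[float(4 * r + c) for c in range(4)] for r in range(4)]
tiles = planes_to_tiles(plane, (2, 2))
print(len(tiles), tiles[0])  # 4 [[0.0, 1.0], [4.0, 5.0]]
```

Because the crops do not overlap and the order is fixed, stitching is an exact inverse of splitting.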
- dunedn.networks.gcnn.gcnn_net_utils.batched_index_select(t, dim, inds)
Selects the K-nearest-neighbor indices for each pixel, respecting the batch dimension.
- Parameters
t
dim
inds
- Returns
Index tensor of shape=(N,H*W*K,C).
- Return type
torch.Tensor
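The point of a batched selection is that neighbor indices are resolved within each batch element, never across them. The pure-Python sketch below mirrors the documented output shape, assuming features of shape (N, H*W, C) and neighbor indices of shape (N, H*W, K); `batched_index_select_sketch` is a hypothetical illustration, not the torch implementation (which also takes a `dim` argument).

```python
from typing import List


def batched_index_select_sketch(
    t: List[List[List[float]]],   # (N, P, C): per-batch pixel features, P = H*W
    inds: List[List[List[int]]],  # (N, P, K): neighbor indices into the P axis
) -> List[List[List[float]]]:     # (N, P*K, C)
    out = []
    for feats, nbrs in zip(t, inds):
        # Indices in nbrs refer only to pixels of the same batch element.
        out.append([feats[j] for pixel_nbrs in nbrs for j in pixel_nbrs])
    return out


# One batch element, 3 pixels with 2 channels each, K = 2 neighbors per pixel.
t = [[[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]]
inds = [[[1, 2], [0, 2], [0, 1]]]
res = batched_index_select_sketch(t, inds)
print(len(res[0]))  # 6, i.e. H*W*K = 3*2
```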
- dunedn.networks.gcnn.gcnn_net_utils.calculate_pad(plane_size, crop_size)
Given the plane and crop shapes, computes the padding needed to obtain an exact tiling.
- Parameters
plane_size
crop_size
- Returns
The plane padding, as [pre h, post h, pre w, post w].
- Return type
list
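For an exact tiling, each plane dimension must become a multiple of the matching crop dimension, so the total padding along that axis is the smallest non-negative amount that achieves this, `(-size) % crop`. The sketch below returns it in the documented [pre h, post h, pre w, post w] order; how the total is split between the two sides is an assumption here (as evenly as possible, with any odd unit on the post side), and `calculate_pad_sketch` is a hypothetical name.

```python
from typing import List, Sequence


def calculate_pad_sketch(plane_size: Sequence[int], crop_size: Sequence[int]) -> List[int]:
    """Hypothetical sketch of [pre_h, post_h, pre_w, post_w] padding for exact tiling.

    Assumes an even split of the padding, with an odd leftover going to the post side.
    """
    pad = []
    for size, crop in zip(plane_size, crop_size):
        total = (-size) % crop  # smallest total with (size + total) % crop == 0
        pre = total // 2
        pad.extend([pre, total - pre])
    return pad


print(calculate_pad_sketch((5, 7), (4, 4)))  # [1, 2, 0, 1]
print(calculate_pad_sketch((8, 8), (4, 4)))  # [0, 0, 0, 0]
```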
dunedn.networks.gcnn.training module
This module provides functions for CNN and GCNN networks training and loading.
- dunedn.networks.gcnn.training.gcnn_training(modeltype: str, setup: dict)
GCNN network training.
- Parameters
modeltype (str) – The model to be trained. Available options: cnn | gcnn.
setup (dict) – Settings dictionary.
- dunedn.networks.gcnn.training.load_and_compile_gcnn_network(channel: str, msetup: dict, checkpoint_filepath: Optional[Path] = None) → GcnnNet
Loads a CNN or GCNN network.
- Parameters
channel (str) – Available options: induction | collection.
msetup (dict) – The model setup dictionary.
checkpoint_filepath (Path) – The .pth checkpoint containing network weights to be loaded.
- Returns
network – The loaded neural network.
- Return type
GcnnNet
dunedn.networks.gcnn.utils module
This module implements utility functions for the networks.gcnn subpackage.
- dunedn.networks.gcnn.utils.gcnn_inference_pass(test_loader: DataLoader, network: AbstractNet, dev: str, verbose: int = 1, profiler: Optional[BatchProfiler] = None) → Tensor
Passes data through a CNN or GCNN network and returns the outputs.
- Parameters
test_loader (torch.utils.data.DataLoader) – The inference dataset generator.
network (AbstractNet) – The denoising network.
dev (str) – The device hosting the computation.
verbose (int) –
Switch to log information. Defaults to 1. Available options:
0: no logs.
1: display progress bar.
profiler (BatchProfiler) – The profiler object to record batch inference time.
- Returns
output – Denoised data, of shape=(N,1,H,W).
- Return type
torch.Tensor
- dunedn.networks.gcnn.utils.make_dict_compatible(state_dict: OrderedDict)
Transforms state_dict keys to match new GcnnNet attributes format.
Changed in version 2.0.0:
Remove “.module” in front of the saved weight names.
Remove extra attributes due to deprecated normalization layer.
Transform layers names to lowercase.
- Parameters
state_dict (OrderedDict) – The original dictionary containing network saved weights.
- Returns
new_state_dict – The dictionary updated version.
- Return type
OrderedDict
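The three documented key transformations can be sketched as a single pass over the state dict. This is a minimal illustration under stated assumptions, not the package's code: the wrapper prefix is assumed to be `module.` (as `torch.nn.DataParallel` writes it when a wrapped model is saved), and because the deprecated normalization keys are not listed here, they are filtered by a caller-supplied predicate `is_deprecated`, which is hypothetical, as is the example layer name `Norm`.

```python
from collections import OrderedDict
from typing import Callable, Mapping


def make_dict_compatible_sketch(
    state_dict: Mapping[str, object],
    is_deprecated: Callable[[str], bool] = lambda key: False,
) -> "OrderedDict[str, object]":
    """Hypothetical sketch of the documented state_dict key clean-up."""
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # 1. Drop the wrapper prefix (assumed to be "module.").
        if key.startswith("module."):
            key = key[len("module."):]
        # 2. Skip weights of the deprecated normalization layer (assumed predicate).
        if is_deprecated(key):
            continue
        # 3. Lowercase the layer names.
        new_state_dict[key.lower()] = value
    return new_state_dict


old = OrderedDict([("module.LPF1.conv.weight", 1), ("module.Norm.bias", 2)])
new = make_dict_compatible_sketch(old, is_deprecated=lambda k: k.startswith("Norm."))
print(list(new))  # ['lpf1.conv.weight']
```

Using an `OrderedDict` keeps the original parameter order, which matches the input type documented above.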