dunedn.networks.uscg package
Submodules
dunedn.networks.uscg.training module
This module provides functions for USCG network training and loading.
- dunedn.networks.uscg.training.load_and_compile_uscg_network(channel: str, msetup: dict, checkpoint_filepath: Optional[Path] = None) → UscgNet
Loads a USCG network.
- Parameters
channel (str) – Available options: induction | collection.
msetup (dict) – The model setup dictionary.
checkpoint_filepath (Optional[Path]) – The .pth checkpoint containing network weights to be loaded. Defaults to None.
- Returns
network – The loaded neural network.
- Return type
UscgNet
dunedn.networks.uscg.uscg_dataloading module
This module implements dataset loading for the USCG network.
- class dunedn.networks.uscg.uscg_dataloading.BaseUscgDataset(dataset_type: str, task: str, channel: str, dsetup: dict, batch_size: int)
Bases: Dataset
Loads the dataset for the USCG network.
- class dunedn.networks.uscg.uscg_dataloading.UscgDataset(dataset_type: str, task: str, channel: str, dsetup: dict, batch_size: int)
Bases: BaseUscgDataset
Loads the dataset for the USCG network.
- class dunedn.networks.uscg.uscg_dataloading.UscgPlanesDataset(noisy: ndarray, task: str, channel: str, dsetup: dict, batch_size: int)
Bases: BaseUscgDataset
Loads planes in dataset form for UscgNet network inference.
dunedn.networks.uscg.uscg_net module
This module contains the UscgNet model class.
- class dunedn.networks.uscg.uscg_net.UscgNet(channel: str = 'collection', out_channels: int = 1, h_induction: int = 800, h_collection: int = 960, w: int = 6000, stride: int = 1000, pretrained: bool = True, node_size: list[int] = [28, 28], dropout: float = 0.5, enhance_diag: bool = True, aux_pred: bool = True)
Bases: AbstractNet
U-shaped Self-Constructing Graph Network.
- forward(x)
USCG Net forward pass.
- Parameters
x (torch.Tensor) – The input tensor.
- Returns
output – The output tensor, of shape=(N,C,H,W).
- Return type
torch.Tensor
- onnx_export(fname: Path)
Export model to ONNX format.
- Parameters
fname (Path) – The path to save the .onnx network.
- predict(generator: UscgDataset, dev: str = 'cpu', no_metrics: bool = False, verbose: int = 1) → Tuple[Tensor, dict]
USCG network inference.
- Parameters
generator (UscgDataset) – The inference dataset generator.
dev (str) – The device hosting the computation. Defaults to “cpu”.
no_metrics (bool) – Whether to skip metric computation. Defaults to False, i.e. metrics are computed.
verbose (int) –
Switch to log information. Defaults to 1. Available options:
0: no logs.
1: display progress bar.
- Returns
y_pred (torch.Tensor) – Denoised planes, of shape=(N,1,H,W).
logs (dict) – The computed metrics results in dictionary form.
- train_batch(noisy: Tensor, clear: Tensor, dev: str)
Makes one batch update.
- Parameters
noisy (torch.Tensor) – Noisy inputs batch, of shape=(B,1,H,W).
clear (torch.Tensor) – Clear target batch, of shape=(B,1,H,W).
dev (str) – The device hosting the computation.
- Returns
step_logs – The step logs as a dictionary.
- Return type
dict
- train_epoch(train_loader: DataLoader, dev: str = 'cpu') → dict
Trains the network for one epoch.
- Parameters
train_loader (torch.utils.data.DataLoader) – The training dataloader.
dev (str) – The device hosting the computation. Defaults to “cpu”.
- Returns
epoch_logs – The dictionary of epoch logs.
- Return type
dict
- training: bool
dunedn.networks.uscg.uscg_net_blocks module
This module contains the USCG Net building blocks.
- class dunedn.networks.uscg.uscg_net_blocks.BatchNorm_GCN(num_features)
Bases: BatchNorm1d
Batch normalization over GCN features.
- affine: bool
- eps: float
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- momentum: float
- num_features: int
- track_running_stats: bool
- class dunedn.networks.uscg.uscg_net_blocks.GCN_Layer(in_features, out_features, bnorm=True, activation=ReLU(), dropout=None)
Bases: Module
- forward(data)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
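As an illustration of what a graph-convolution layer such as GCN_Layer computes, the NumPy sketch below applies the standard propagation rule H' = ReLU(D^-1/2 (A + I) D^-1/2 H W). The symmetric normalization with self-loops, the separate features/adjacency arguments and the function names are assumptions, not the documented behavior of GCN_Layer.forward(data); batch normalization and dropout are left out.

```python
import numpy as np


def normalize_adjacency(adj: np.ndarray) -> np.ndarray:
    """Assumed normalization: D^-1/2 (A + I) D^-1/2 for a non-negative adjacency."""
    a_hat = adj + np.eye(adj.shape[-1])        # add self-loops
    deg = a_hat.sum(axis=-1)                   # node degrees (always > 0 here)
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Scale row i by d_i^-1/2 and column j by d_j^-1/2.
    return a_hat * d_inv_sqrt[..., :, None] * d_inv_sqrt[..., None, :]


def gcn_layer_sketch(x: np.ndarray, adj: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Hypothetical single graph convolution followed by ReLU.

    x: node features, shape (num_nodes, in_features)
    adj: adjacency matrix, shape (num_nodes, num_nodes)
    weight: projection matrix, shape (in_features, out_features)
    """
    support = x @ weight                        # project node features
    out = normalize_adjacency(adj) @ support    # aggregate over neighbours
    return np.maximum(out, 0.0)                 # ReLU, the layer's default activation
```

With an all-zero adjacency each node only sees itself, so the sketch reduces to a per-node linear map followed by ReLU.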
- class dunedn.networks.uscg.uscg_net_blocks.Pooling_Block(c, h, w)
Bases: Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class dunedn.networks.uscg.uscg_net_blocks.Recombination_Layer
Bases: Module
This layer recombines an output with a residual connection through a (1,1) convolution: it first concatenates the two inputs along the channel dimension and then applies the convolution.
- forward(x, y)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
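A (1,1) convolution is a per-pixel linear map over channels, so the recombination described for Recombination_Layer can be sketched in NumPy as a concatenation followed by a channel-mixing einsum. The weight and bias arrays are hypothetical stand-ins for the layer's learned parameters, whose channel counts are not documented here; in PyTorch the same step would usually be written as torch.cat([x, y], dim=1) followed by an nn.Conv2d with kernel_size=1.

```python
import numpy as np


def recombine_sketch(x: np.ndarray, y: np.ndarray,
                     weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Concatenate two (N, C, H, W) arrays along channels, then apply a (1,1) convolution.

    weight: hypothetical kernel, shape (C_out, C_x + C_y)
    bias: hypothetical bias, shape (C_out,)
    """
    stacked = np.concatenate([x, y], axis=1)            # (N, C_x + C_y, H, W)
    out = np.einsum("oc,nchw->nohw", weight, stacked)   # mix channels pixel by pixel
    return out + bias[None, :, None, None]              # broadcast bias over N, H, W
```

Because the kernel covers a single pixel, every output pixel depends only on the channels of the two inputs at that same position.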
- class dunedn.networks.uscg.uscg_net_blocks.SCG_Block(in_ch, hidden_ch=6, node_size=(32, 32), add_diag=True, dropout=0.2)
Bases: Module
- forward(gx)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
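SCG_Block's arguments (node_size, hidden_ch, add_diag) suggest that it turns a feature map into a grid of graph nodes and builds an adjacency matrix between them. The NumPy sketch below is a simplified, hypothetical version of that idea: it average-pools the map onto a node_size grid, projects node features with a fixed matrix, forms a non-negative adjacency as ReLU(Z Zᵀ) and, when add_diag is set, adds self-loops. The real block may construct the embeddings and the diagonal term differently; build_graph_sketch and proj are illustrative names only.

```python
import numpy as np


def build_graph_sketch(feature_map: np.ndarray, node_size: tuple,
                       proj: np.ndarray, add_diag: bool = True):
    """Hypothetical self-constructed graph from a (C, H, W) feature map.

    proj: projection to a latent space, shape (C, hidden)
    Assumes H and W are divisible by the node grid sizes.
    """
    c, h, w = feature_map.shape
    nh, nw = node_size
    # Average-pool each (h/nh, w/nw) block: every block becomes one node.
    pooled = feature_map.reshape(c, nh, h // nh, nw, w // nw).mean(axis=(2, 4))
    nodes = pooled.reshape(c, nh * nw).T        # (num_nodes, C)
    z = nodes @ proj                            # latent embeddings (num_nodes, hidden)
    adj = np.maximum(z @ z.T, 0.0)              # symmetric, non-negative adjacency
    if add_diag:
        # Assumption: diagonal enhancement modelled as plain self-loops.
        adj = adj + np.eye(adj.shape[0])
    return nodes, adj
```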
dunedn.networks.uscg.uscg_net_utils module
This module contains the utility functions for USCG network.
dunedn.networks.uscg.utils module
This module implements utility functions for the networks.uscg subpackage.
- dunedn.networks.uscg.utils.make_dict_compatible(state_dict: OrderedDict)
Transforms state_dict keys to match new GcnnNet attributes format.
Changed in version 2.0.0:
Remove “.module” in front of the saved weights name.
Remove extra attributes due to deprecated normalization layer.
Transform layers names to lowercase.
- Parameters
state_dict (OrderedDict) – The original dictionary containing network saved weights.
- Returns
new_state_dict – The dictionary updated version.
- Return type
OrderedDict
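The three key transformations listed for make_dict_compatible can be sketched on a plain OrderedDict. This is a hypothetical re-implementation: it assumes the prefix to strip is "module." (the prefix torch.nn.DataParallel adds to saved keys), and, because the attributes of the deprecated normalization layer are not named here, it takes the substrings of keys to drop as an argument.

```python
from collections import OrderedDict


def make_dict_compatible_sketch(state_dict: OrderedDict,
                                drop_substrings=()) -> OrderedDict:
    """Hypothetical key renaming for old checkpoints.

    - strips a leading "module." prefix (assumed DataParallel wrapper prefix),
    - drops keys containing any of drop_substrings (stand-in for the
      deprecated normalization-layer attributes),
    - lowercases the remaining keys, preserving their order.
    """
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith("module."):
            key = key[len("module."):]
        if any(sub in key for sub in drop_substrings):
            continue
        new_state_dict[key.lower()] = value
    return new_state_dict
```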
- dunedn.networks.uscg.utils.time_windows(planes: Tensor, w: int, stride: int) → Tuple[Tensor, list[torch.Tensor], list[list[int]]]
Takes time windows of given width and stride from planes.
- Parameters
planes (torch.Tensor) – Planes of shape=(N,C,H,W).
w (int) – Width of the time windows.
stride (int) – Step between the start indices of consecutive time windows.
- Returns
divisions (torch.Tensor) – Number of times each pixel should be processed by denoising network, of shape=(N,C,H,W).
windows (list[torch.Tensor]) – Time windows to be processed, each of shape=(N,C,H,w).
idxs (list[list[int]]) – Start-end time indices for the corresponding window. Each element is [start idx, end idx].
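The divisions output implies that per-window results are later summed back in place and averaged over overlapping time ticks. The NumPy sketch below illustrates both steps under assumptions: time_windows_sketch and stitch are hypothetical names, NumPy arrays stand in for tensors, and when the windows do not reach the last tick an extra right-aligned window is added, which may not match the real edge handling.

```python
import numpy as np


def time_windows_sketch(planes: np.ndarray, w: int, stride: int):
    """Split (N, C, H, W) planes into time windows of width w spaced by stride."""
    width = planes.shape[-1]
    if not 0 < w <= width or stride <= 0:
        raise ValueError("need 0 < w <= W and stride > 0")
    starts = list(range(0, width - w + 1, stride))
    if starts[-1] + w < width:
        # Assumption: add a final window aligned to the right edge.
        starts.append(width - w)
    idxs = [[s, s + w] for s in starts]
    windows = [planes[..., s:e] for s, e in idxs]
    divisions = np.zeros(planes.shape)
    for s, e in idxs:
        divisions[..., s:e] += 1  # count the windows covering each tick
    return divisions, windows, idxs


def stitch(outputs, idxs, divisions):
    """Sum per-window outputs back in place, then average overlaps via divisions."""
    total = np.zeros(divisions.shape)
    for out, (s, e) in zip(outputs, idxs):
        total[..., s:e] += out
    return total / divisions
```

With W=10, w=4 and stride=3 the windows are [0, 4], [3, 7] and [6, 10]; ticks 3 and 6 are covered twice, so their divisions entry is 2, and stitching the unmodified windows returns the original planes.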
- dunedn.networks.uscg.utils.uscg_inference_pass(test_loader: DataLoader, network: AbstractNet, dev: str, verbose: int = 1) → Tensor
Passes data through the USCG network and returns the outputs.
- Parameters
test_loader (torch.utils.data.DataLoader) – The inference dataset generator.
network (AbstractNet) – The denoising network.
dev (str) – The device hosting the computation.
verbose (int) –
Switch to log information. Defaults to 1. Available options:
0: no logs.
1: display progress bar.
- Returns
output – Denoised data, of shape=(N,1,H,W).
- Return type
torch.Tensor