dunedn.training package

Submodules

dunedn.training.callbacks module

This module implements objects to keep track of training details.

class dunedn.training.callbacks.Callback[source]

Bases: object

on_epoch_begin(logs=None)[source]
on_epoch_end(logs=None)[source]
on_eval_begin(logs=None)[source]
on_eval_end(logs=None)[source]
on_train_batch_begin(logs=None)[source]
on_train_batch_end(logs=None)[source]
on_train_begin(logs=None)[source]
on_train_end(logs=None)[source]
class dunedn.training.callbacks.CallbackList(callbacks: list[dunedn.training.callbacks.Callback])[source]

Bases: Callback

hook(hook_name: str, logs: dict)[source]

A utility function that calls the method named hook_name on each callback.

Parameters
  • hook_name (str) – The name of the method to be called.

  • logs (dict) – The dictionary to be logged.

on_epoch_begin(logs=None)[source]
on_epoch_end(logs=None)[source]
on_eval_begin(logs=None)[source]
on_eval_end(logs=None)[source]
on_train_batch_begin(logs=None)[source]
on_train_batch_end(logs=None)[source]
on_train_begin(logs=None)[source]
on_train_end(logs=None)[source]
class dunedn.training.callbacks.History[source]

Bases: Callback

append(logs=None)[source]
on_epoch_end(logs=None)[source]
on_train_batch_end(logs=None)[source]
on_train_begin(logs=None)[source]
reset()[source]
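The classes above follow the usual callback hook pattern. CallbackList forwards each event to every registered callback by method name, and History records what it is given. The sketch below is a minimal, self-contained re-implementation of that pattern, not dunedn's code. It assumes that hook() looks the method up with getattr and that History stores one list per logged key. The history attribute layout is hypothetical.

```python
from typing import Optional


class Callback:
    """Base class: every hook is a no-op by default (sketch, not dunedn's code)."""

    def on_train_begin(self, logs: Optional[dict] = None) -> None:
        pass

    def on_epoch_end(self, logs: Optional[dict] = None) -> None:
        pass


class CallbackList(Callback):
    """Forwards each hook call to every wrapped callback."""

    def __init__(self, callbacks: list[Callback]):
        self.callbacks = callbacks

    def hook(self, hook_name: str, logs: dict) -> None:
        # Look the method up by name on each callback and call it (assumed mechanism).
        for callback in self.callbacks:
            getattr(callback, hook_name)(logs)

    def on_train_begin(self, logs: Optional[dict] = None) -> None:
        self.hook("on_train_begin", logs or {})

    def on_epoch_end(self, logs: Optional[dict] = None) -> None:
        self.hook("on_epoch_end", logs or {})


class History(Callback):
    """Stores one list of values per logged key (hypothetical layout)."""

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        self.history: dict[str, list] = {}

    def append(self, logs: Optional[dict] = None) -> None:
        for key, value in (logs or {}).items():
            self.history.setdefault(key, []).append(value)

    def on_train_begin(self, logs: Optional[dict] = None) -> None:
        self.reset()

    def on_epoch_end(self, logs: Optional[dict] = None) -> None:
        self.append(logs)


history = History()
callbacks = CallbackList([history])
callbacks.on_train_begin()
callbacks.on_epoch_end({"loss": 0.5})
callbacks.on_epoch_end({"loss": 0.3})
print(history.history)  # {'loss': [0.5, 0.3]}
```

Dispatching by name keeps CallbackList's methods trivial, because each one only forwards its own name to hook().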

dunedn.training.denoise_training module

This module contains the wrapper function for the dunedn train command.

Example

Train help output:

$ dunedn train --help

usage: dunedn train [-h] [--model {cnn,gcnn,uscg}] [--output OUTPUT] [--force] [--interactive]

Train model loading settings from configcard.

optional arguments:
  -h, --help            show this help message and exit
  --model {cnn,gcnn,uscg}, -m {cnn,gcnn,uscg}
                        the model to train
  --output OUTPUT, -o OUTPUT
                        output folder
  --force               overwrite existing files if present
  --interactive, -i     triggers interactive mode
dunedn.training.denoise_training.add_arguments_training(parser)[source]

Adds training subparser arguments.

Parameters

parser (ArgumentParser) – Training subparser object.
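The options that add_arguments_training registers can be reconstructed from the train --help output above. The sketch below does this with the standard argparse module. Defaults, types and the exact help strings are assumptions, and the function here is a hypothetical stand-in, not the dunedn implementation.

```python
import argparse


def add_arguments_training(parser: argparse.ArgumentParser) -> None:
    """Hypothetical reconstruction of the options shown by `dunedn train --help`."""
    parser.add_argument(
        "--model", "-m", choices=["cnn", "gcnn", "uscg"], help="the model to train"
    )
    parser.add_argument("--output", "-o", help="output folder")  # default assumed None
    parser.add_argument(
        "--force", action="store_true", help="overwrite existing files if present"
    )
    parser.add_argument(
        "--interactive", "-i", action="store_true", help="triggers interactive mode"
    )


# Attach the options to a "train" subcommand, mirroring `dunedn train`.
parser = argparse.ArgumentParser(prog="dunedn")
subparsers = parser.add_subparsers(dest="command")
train_parser = subparsers.add_parser(
    "train", description="Train model loading settings from configcard."
)
add_arguments_training(train_parser)

args = parser.parse_args(["train", "-m", "uscg", "-o", "runs/uscg", "--force"])
print(args.model, args.output, args.force, args.interactive)  # uscg runs/uscg True False
```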

dunedn.training.denoise_training.training(args)[source]

Wrapper training function.

Parameters

args (argparse.Namespace) – Parsed command-line arguments.

Returns

  • float – Minimum loss over training.

  • float – Uncertainty over minimum loss.

  • str – Best checkpoint file name.

dunedn.training.denoise_training.training_main(modeltype: str, setup: dict)[source]

Main training function.

Parameters
  • modeltype (str) – The model to be trained. Available options: cnn | gcnn | uscg.

  • setup (dict) – Settings dictionary.

dunedn.training.losses module

This module implements several losses.

The main option is reduction, which can be either mean (default) or None.

class dunedn.training.losses.Loss(a=0.5, data_range=1.0, reduction='mean')[source]

Bases: ABC

Abstract loss function class.

class dunedn.training.losses.LossBce(ratio=0.5, reduction='mean')[source]

Bases: Loss

Binary cross entropy loss function.

Computes Xent(y_true, y_pred).

class dunedn.training.losses.LossBceDice(ratio=0.5, reduction='mean')[source]

Bases: Loss

Binary xent + soft dice loss function.

class dunedn.training.losses.LossCfnm(reduction='mean')[source]

Bases: Loss

Confusion matrix function.

class dunedn.training.losses.LossImae(a=0.84, data_range=1.0, reduction='mean')[source]

Bases: Loss

Mean absolute error on integrated charge loss function.

Computes IMAE(y_true, y_pred) = mean(|y_true.sum(-1) - y_pred.sum(-1)|).
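The IMAE formula above can be worked through with a short NumPy sketch. Each row is first summed along the last axis to get the integrated charge, and the mean absolute difference is then taken. NumPy arrays stand in for torch tensors here, and the function name is illustrative.

```python
import numpy as np


def imae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean absolute error between charges integrated along the last axis."""
    integrated_true = y_true.sum(axis=-1)
    integrated_pred = y_pred.sum(axis=-1)
    return float(np.mean(np.abs(integrated_true - integrated_pred)))


y_true = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
y_pred = np.array([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]])
# Integrated charges: [6, 0] vs [3, 1], so |diff| = [3, 1] and the mean is 2.0
print(imae(y_true, y_pred))  # 2.0
```

Unlike a per-pixel MAE, IMAE is zero whenever charge is only redistributed along the last axis and the total is unchanged.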

class dunedn.training.losses.LossMse(a=0.84, data_range=1.0, reduction='mean')[source]

Bases: Loss

Mean squared error loss function.

Computes L2(y_true, y_pred) = mean((y_true - y_pred)**2).
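The reduction option described at the top of this module can be illustrated with the MSE formula. In the NumPy sketch below, reduction='mean' returns a scalar and reduction=None returns the element-wise squared errors. The None behaviour is an assumption about what an unreduced loss returns, not something taken from the dunedn source.

```python
from typing import Optional, Union

import numpy as np


def mse(
    y_true: np.ndarray, y_pred: np.ndarray, reduction: Optional[str] = "mean"
) -> Union[float, np.ndarray]:
    squared_error = (y_true - y_pred) ** 2
    if reduction == "mean":
        return float(squared_error.mean())
    if reduction is None:
        return squared_error  # unreduced, same shape as the inputs (assumed)
    raise ValueError(f"unsupported reduction: {reduction!r}")


y_true = np.array([0.0, 1.0, 2.0, 3.0])
y_pred = np.array([0.0, 1.0, 2.0, 1.0])
print(mse(y_true, y_pred))        # 1.0
print(mse(y_true, y_pred, None))  # [0. 0. 0. 4.]
```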

class dunedn.training.losses.LossPsnr(reduction='mean')[source]

Bases: Loss

Peak signal-to-noise ratio function.
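PSNR is conventionally defined as 10 * log10(data_range**2 / MSE), measured in decibels. LossPsnr only exposes reduction, so several details of the NumPy sketch below are assumptions: the data_range argument, its default of 1.0, and the fact that the raw PSNR is returned rather than turned into a loss.

```python
import numpy as np


def psnr(y_true: np.ndarray, y_pred: np.ndarray, data_range: float = 1.0) -> float:
    """Peak signal-to-noise ratio in decibels (standard definition)."""
    mse = np.mean((y_true - y_pred) ** 2)
    if mse == 0:
        return float("inf")  # identical inputs: infinite PSNR
    return float(10.0 * np.log10(data_range**2 / mse))


y_true = np.zeros(4)
y_pred = np.full(4, 0.1)
print(psnr(y_true, y_pred))  # approximately 20 dB (MSE = 0.01)
```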

class dunedn.training.losses.LossSoftDice(reduction='mean')[source]

Bases: Loss

Soft dice loss function.

dice(x, y)[source]
Parameters
  • x (torch.Tensor) – Predicted tensor, of shape=(N,C,W,H).

  • y (torch.Tensor) – Target tensor, of shape=(N,C,W,H).
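A common form of the soft Dice coefficient is 2 * sum(x * y) / (sum(x) + sum(y)), computed per sample and channel, with the loss taken as 1 - dice. The NumPy sketch below assumes that form, plus a small smoothing term eps that avoids division by zero. dunedn's exact normalisation and smoothing may differ.

```python
import numpy as np


def soft_dice(x: np.ndarray, y: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Soft Dice coefficient per (sample, channel) for arrays of shape (N, C, W, H)."""
    spatial_axes = (2, 3)
    intersection = (x * y).sum(axis=spatial_axes)
    denominator = x.sum(axis=spatial_axes) + y.sum(axis=spatial_axes)
    return (2.0 * intersection + eps) / (denominator + eps)


def soft_dice_loss(x: np.ndarray, y: np.ndarray) -> float:
    # Average the per-(sample, channel) Dice scores, then turn them into a loss.
    return float(1.0 - soft_dice(x, y).mean())


target = np.zeros((1, 1, 2, 2))
target[0, 0, 0, 0] = 1.0
prediction = target.copy()
print(soft_dice_loss(prediction, target))  # 0.0 for a perfect match
```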

class dunedn.training.losses.LossSsim(a=0.84, data_range=1.0, reduction='mean')[source]

Bases: Loss

Statistical structural similarity loss function.

Computes Lssim(y_true, y_pred) = 1 - stat-ssim(y_true, y_pred).

class dunedn.training.losses.LossSsimL1(a=0.84, data_range=1.0, reduction='mean')[source]

Bases: Loss

Stat ssim + mean absolute error loss function.

Computes (a * L1 + (1 - a) * 1e-3 * Lssim)(y_true, y_pred).

class dunedn.training.losses.LossSsimL2(a=0.84, data_range=1.0, reduction='mean')[source]

Bases: Loss

Stat ssim + MSE loss function.

Computes (a * L2 + (1 - a) * 1e-3 * Lssim)(y_true, y_pred).
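The snippet below shows how the weights in the two combined losses interact. It evaluates a * L + (1 - a) * 1e-3 * Lssim for the default a=0.84, using made-up component values. The 1e-3 factor scales the SSIM term down by three orders of magnitude relative to the pixel-wise term L (L1 or L2). The numbers are illustrative only, not measured losses.

```python
def combined_loss(pixel_loss: float, ssim_loss: float, a: float = 0.84) -> float:
    """a * L + (1 - a) * 1e-3 * Lssim, with L being the L1 or L2 term."""
    return a * pixel_loss + (1.0 - a) * 1e-3 * ssim_loss


# Illustrative component values, not results from dunedn.
value = combined_loss(pixel_loss=0.01, ssim_loss=0.2)
print(value)  # approximately 0.008432 (0.0084 + 0.000032)
```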

class dunedn.training.losses.Ssim(a=0.84, data_range=1.0, reduction='mean')[source]

Bases: Loss

Statistical structural similarity function.

Computes ssim(y_true, y_pred) = stat-ssim(y_true, y_pred).

dunedn.training.losses.get_loss(loss)[source]

Utility function to retrieve a loss class from its name.

Parameters

loss (str) –

Available options:

  • mse

  • imae

  • ssim

  • ssim_l2

  • ssim_l1

  • bce

  • softdice

  • cfnm

Returns

The requested loss class.

Return type

Loss

Raises

NotImplementedError – If loss is not one of: mse, imae, ssim, ssim_l2, ssim_l1, bce, softdice, cfnm.
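A name-to-class lookup of this kind is typically a dictionary with a NotImplementedError fallback. The sketch below shows that pattern with empty placeholder classes. The class bodies and the name-to-class mapping are hypothetical and only mirror the option names and class names documented above.

```python
class Loss:
    """Hypothetical stand-in for the abstract Loss base class."""


# Empty placeholder subclasses named after the documented loss classes.
class LossMse(Loss): pass
class LossImae(Loss): pass
class LossSsim(Loss): pass
class LossSsimL2(Loss): pass
class LossSsimL1(Loss): pass
class LossBce(Loss): pass
class LossSoftDice(Loss): pass
class LossCfnm(Loss): pass


def get_loss(loss: str) -> type[Loss]:
    """Return the loss class registered under `loss` (sketch of the lookup)."""
    registry = {
        "mse": LossMse,
        "imae": LossImae,
        "ssim": LossSsim,
        "ssim_l2": LossSsimL2,
        "ssim_l1": LossSsimL1,
        "bce": LossBce,
        "softdice": LossSoftDice,
        "cfnm": LossCfnm,
    }
    if loss not in registry:
        options = ", ".join(registry)
        raise NotImplementedError(f"Loss {loss!r} not implemented, choose one of: {options}")
    return registry[loss]


print(get_loss("ssim_l1").__name__)  # LossSsimL1
```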

dunedn.training.losses.get_metric(metric)[source]

Utility function to retrieve a metric class from its name.

Parameters

metric (str) –

Available options:

  • mse

  • imae

  • ssim

  • psnr

  • bce

  • softdice

  • cfnm

Returns

The requested metric class.

Return type

Loss

Raises

NotImplementedError – If metric is not one of the available options listed above.

dunedn.training.metrics module

class dunedn.training.metrics.MetricsList(metrics: list[str])[source]

Bases: object

Wrapping class for a list of metrics.

Example

>>> from dunedn.training.metrics import MetricsList
>>> metrics = ["ssim", "psnr"]
>>> MetricsList(metrics)
combine_collection_induction_results(ires: dict, cres: dict)[source]

Combine computed metrics from different planes types.

Metric values are averaged, while standard deviations are summed in quadrature.

Parameters
  • ires (dict) – Computed metrics on induction planes.

  • cres (dict) – Computed metrics on collection planes.

Returns

res – The combined metrics.

Return type

dict
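The combination rule stated above can be sketched as follows: metric values are averaged, and standard deviations are summed in quadrature (sqrt(a**2 + b**2)). The key naming is an assumption made for this example: each metric name is paired with a hypothetical '<name>_std' entry. The dictionaries that dunedn actually produces may be laid out differently.

```python
import math


def combine_results(ires: dict, cres: dict) -> dict:
    """Average metric values; sum '<name>_std' entries in quadrature (key layout assumed)."""
    res = {}
    for key, ivalue in ires.items():
        cvalue = cres[key]
        if key.endswith("_std"):
            res[key] = math.sqrt(ivalue**2 + cvalue**2)
        else:
            res[key] = (ivalue + cvalue) / 2
    return res


ires = {"ssim": 0.8, "ssim_std": 0.03}  # induction planes (made-up values)
cres = {"ssim": 0.9, "ssim_std": 0.04}  # collection planes (made-up values)
print(combine_results(ires, cres))  # ssim approximately 0.85, ssim_std approximately 0.05
```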

compute_metrics(y_pred: Tensor, y_true: Tensor) → dict[source]

Computes values from the stored list of metrics.

Parameters
  • y_pred (torch.Tensor) – Prediction tensor, of shape=(N,C,H,W).

  • y_true (torch.Tensor) – Labels tensor, of shape=(N,C,H,W).

Returns

res_metrics – The computed metrics results in dictionary form.

Return type

dict

print_metrics(logger: Logger, logs: dict)[source]

Log the computed metrics.

Parameters
  • logger (Logger) – The logging object.

  • logs (dict) – The computed metrics values to be logged.

dunedn.training.ssim module

class dunedn.training.ssim.MS_SSIM(data_range: float = 255, reduction: bool = True, win_size: int = 11, win_sigma: float = 1.5, channel: int = 3, weights: Optional[list[float]] = None, k: list[float] = [0.01, 0.03])[source]

Bases: Module

Multiscale Structural Similarity class.

forward(x: Tensor, y: Tensor) → Tensor[source]

Computes the multiscale SSIM between two images.

Parameters
  • x (torch.Tensor) – Images, of shape=(N,C,H,W).

  • y (torch.Tensor) – Images, of shape=(N,C,H,W).

Returns

Multiscale SSIM results, of shape=(N,C,H,W).

Return type

torch.Tensor

training: bool
class dunedn.training.ssim.SSIM(data_range: float = 255, reduction: bool = True, win_size: int = 11, win_sigma: float = 1.5, channel: int = 3, k: list[float] = [0.01, 0.03], nonnegative_ssim: bool = False)[source]

Bases: Module

Structural Similarity class.

forward(x: Tensor, y: Tensor) → Tensor[source]

Computes the SSIM between two images.

Parameters
  • x (torch.Tensor) – Images, of shape=(N,C,H,W).

  • y (torch.Tensor) – Images, of shape=(N,C,H,W).

Returns

SSIM results, of shape=(N,C,H,W).

Return type

torch.Tensor

training: bool
class dunedn.training.ssim.STAT_SSIM(data_range: float = 255, reduction: bool = True, win_size: int = 11, win_sigma: float = 1.5, channel: int = 3, k: list[float] = [0.01, 0.03], nonnegative_ssim: bool = False)[source]

Bases: Module

forward(x: Tensor, y: Tensor) → Tensor[source]

Computes the statistical SSIM between two images.

Parameters
  • x (torch.Tensor) – Images, of shape=(N,C,H,W).

  • y (torch.Tensor) – Images, of shape=(N,C,H,W).

Returns

Statistical SSIM results, of shape=(N,C,H,W).

Return type

torch.Tensor

training: bool
dunedn.training.ssim.gaussian_filter(inputs: Tensor, win: Tensor) → Tensor[source]

Blurs the input with a 1-D kernel, using valid padding.

Parameters
  • inputs (torch.Tensor) – A batch of tensors to be blurred, of shape=(N,C,H,W).

  • win (torch.Tensor) – 1-D Gaussian kernel, of shape=(1, 1, size).

Returns

Blurred tensors.

Return type

torch.Tensor
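The following NumPy sketch builds a normalised 1-D Gaussian window of shape (1, 1, size), the shape gaussian_filter expects. It then applies the window separably, first along H and then along W, with valid padding, so each spatial dimension shrinks by size - 1. Building the window this way is an assumption, and the helper names are illustrative, not part of dunedn.

```python
import numpy as np


def gaussian_window(size: int, sigma: float) -> np.ndarray:
    """Normalised 1-D Gaussian kernel, returned with shape (1, 1, size)."""
    coords = np.arange(size, dtype=float) - (size - 1) / 2
    g = np.exp(-(coords**2) / (2 * sigma**2))
    return (g / g.sum()).reshape(1, 1, size)


def blur_valid(inputs: np.ndarray, win: np.ndarray) -> np.ndarray:
    """Blur (N, C, H, W) inputs along H, then W, with valid padding."""
    kernel = win.ravel()
    # np.convolve flips the kernel, which is harmless for a symmetric Gaussian.
    along_h = np.apply_along_axis(np.convolve, 2, inputs, kernel, mode="valid")
    return np.apply_along_axis(np.convolve, 3, along_h, kernel, mode="valid")


win = gaussian_window(size=11, sigma=1.5)
x = np.random.default_rng(0).random((2, 1, 32, 32))
print(blur_valid(x, win).shape)  # (2, 1, 22, 22)
```

Because the window sums to one, a constant image is left unchanged by the blur, apart from the border pixels that valid padding removes.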

dunedn.training.ssim.ms_ssim(x: Tensor, y: Tensor, data_range: float = 255.0, reduction: bool = True, win_size: int = 11, win_sigma: int = 3, win: Optional[Tensor] = None, weights: Optional[list[float]] = None, k: Tuple[int, int] = (1e-13, 1e-13)) → Tensor[source]

Interface for Multiscale Structural Similarity function.

Parameters
  • x (torch.Tensor) – Images, of shape=(N,C,H,W).

  • y (torch.Tensor) – Images, of shape=(N,C,H,W).

  • data_range (float) – Value range of input images. Usually 1.0 or 255.

  • reduction (bool) – If reduction=True, ssim of all images will be averaged as a scalar.

  • win_size (int) – The size of the gaussian kernel.

  • win_sigma (float) – Standard deviation of the Gaussian kernel, in pixel units.

  • win (torch.Tensor) – 1-D gauss kernel, of shape=(1, 1, size).

  • weights (list[float]) – Weights for different levels in the multiscale computation.

  • k (list[float]) – Small constants that keep the SSIM fractions numerically stable.

Returns

Multiscale SSIM results, of shape=(N,C,H,W).

Return type

torch.Tensor

dunedn.training.ssim.ssim(x: Tensor, y: Tensor, data_range: float = 255.0, reduction: bool = True, win_size: int = 11, win_sigma: int = 3, win: Optional[Tensor] = None, k: Tuple[int, int] = (1e-13, 1e-13), nonnegative_ssim: bool = False) → Tensor[source]

Interface for Structural Similarity function.

Parameters
  • x (torch.Tensor) – Images, of shape=(N,C,H,W).

  • y (torch.Tensor) – Images, of shape=(N,C,H,W).

  • data_range (float) – Value range of input images. Usually 1.0 or 255.

  • reduction (bool) – If reduction=True, ssim of all images will be averaged as a scalar.

  • win_size (int) – The size of the gaussian kernel.

  • win_sigma (float) – Standard deviation of the Gaussian kernel, in pixel units.

  • win (torch.Tensor) – 1-D gauss kernel, of shape=(1, 1, size).

  • k (list[float]) – Small constants that keep the SSIM fractions numerically stable.

  • nonnegative_ssim (bool) – Whether to force the SSIM response to be nonnegative by applying a ReLU.

Returns

SSIM results, of shape=(N,C,H,W).

Return type

torch.Tensor
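The textbook single-scale SSIM compares local means, variances and the covariance, all computed under a Gaussian window. The NumPy sketch below implements that formula with valid padding. It averages the SSIM map into a scalar, as with reduction=True. Its defaults mirror those of the SSIM class above (win_size=11, win_sigma=1.5, k=(0.01, 0.03), data_range=255). Two points are assumed rather than taken from dunedn: how k enters the constants, C1 = (k1 * data_range)**2 and C2 = (k2 * data_range)**2 as in the standard formulation, and the helper names. Note also that the ssim() function itself defaults to win_sigma=3 and k=(1e-13, 1e-13).

```python
import numpy as np


def _window(size: int, sigma: float) -> np.ndarray:
    """Normalised 1-D Gaussian kernel."""
    coords = np.arange(size, dtype=float) - (size - 1) / 2
    g = np.exp(-(coords**2) / (2 * sigma**2))
    return g / g.sum()


def _blur(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Separable valid-padding blur over the last two axes of (N, C, H, W)."""
    out = np.apply_along_axis(np.convolve, 2, img, kernel, mode="valid")
    return np.apply_along_axis(np.convolve, 3, out, kernel, mode="valid")


def ssim(
    x: np.ndarray,
    y: np.ndarray,
    data_range: float = 255.0,
    win_size: int = 11,
    win_sigma: float = 1.5,
    k: tuple[float, float] = (0.01, 0.03),
) -> float:
    """Mean SSIM over all windows, images and channels (textbook formula)."""
    kernel = _window(win_size, win_sigma)
    c1 = (k[0] * data_range) ** 2
    c2 = (k[1] * data_range) ** 2

    mu_x = _blur(x, kernel)
    mu_y = _blur(y, kernel)
    # Local (co)variances: E[xy] - E[x]E[y] under the Gaussian window.
    sigma_xx = _blur(x * x, kernel) - mu_x**2
    sigma_yy = _blur(y * y, kernel) - mu_y**2
    sigma_xy = _blur(x * y, kernel) - mu_x * mu_y

    luminance = (2 * mu_x * mu_y + c1) / (mu_x**2 + mu_y**2 + c1)
    contrast_structure = (2 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2)
    return float((luminance * contrast_structure).mean())


rng = np.random.default_rng(0)
x = rng.random((1, 1, 32, 32))
noisy = np.clip(x + 0.2 * rng.standard_normal(x.shape), 0.0, 1.0)
print(ssim(x, x, data_range=1.0))      # 1.0 for identical images
print(ssim(x, noisy, data_range=1.0))  # lower for the noisy copy
```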

dunedn.training.ssim.stat_gaussian_filter(inputs: Tensor, win: Tensor)[source]

Blurs the input with a 1-D kernel, using same padding.

Parameters
  • inputs (torch.Tensor) – A batch of tensors to be blurred, of shape=(N,C,H,W).

  • win (torch.Tensor) – 1-D Gaussian kernel, of shape=(1, 1, size).

Returns

Blurred tensors, of shape=(N,C,H,W).

Return type

torch.Tensor
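In practice, stat_gaussian_filter (same padding) and gaussian_filter (valid padding) differ in output size. The one-dimensional NumPy comparison below illustrates this. How dunedn fills the borders is not stated above (zeros, reflection, or something else), so the zero padding of np.convolve is used here purely as an assumption.

```python
import numpy as np

kernel = np.array([0.25, 0.5, 0.25])  # tiny normalised smoothing kernel
signal = np.arange(8, dtype=float)

valid = np.convolve(signal, kernel, mode="valid")  # length 8 - 3 + 1 = 6
same = np.convolve(signal, kernel, mode="same")    # length 8, zero-padded edges

print(valid.shape, same.shape)  # (6,) (8,)
```

Away from the borders the two modes agree: same[1:-1] equals valid. Only the edge values depend on the padding choice.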

dunedn.training.ssim.stat_ssim(x, y, data_range=255, reduction=True, win_size=11, win_sigma=3, win=None, k=(1e-13, 1e-13), nonnegative_ssim=False) → Tensor[source]

Interface for Statistical Structural Similarity function.

Parameters
  • x (torch.Tensor) – Images, of shape=(N,C,H,W).

  • y (torch.Tensor) – Images, of shape=(N,C,H,W).

  • data_range (float) – Value range of input images. Usually 1.0 or 255.

  • reduction (bool) – If reduction=True, ssim of all images will be averaged as a scalar.

  • win_size (int) – The size of the gaussian kernel.

  • win_sigma (float) – Standard deviation of the Gaussian kernel, in pixel units.

  • win (torch.Tensor) – 1-D gauss kernel, of shape=(1, 1, size).

  • k (list[float]) – Small constants that keep the SSIM fractions numerically stable.

  • nonnegative_ssim (bool) – Whether to force the SSIM response to be nonnegative by applying a ReLU.

Returns

Statistical SSIM results, of shape=(N,C,H,W).

Return type

torch.Tensor

Module contents