dunedn.networks package

Subpackages

Submodules

dunedn.networks.abstract_net module

class dunedn.networks.abstract_net.AbstractNet(**kwargs)[source]

Bases: Module, ABC

Abstract network implementation.

Allows network distribution on multiple devices.

Example

Instantiate an AbstractNet daughter class such as GcnnNet:

>>> network = GcnnNet(setup)
>>> print(type(network))
<class 'dunedn.networks.gcnn_net.GcnnNet'>

Distribute the network on multiple devices:

>>> device_ids = [0, 1, 2, 3]
>>> device_ids
[0, 1, 2, 3]
>>> dp = network.to_data_parallel(device_ids=device_ids)
>>> print(type(dp))
<class 'dunedn.networks.model_utils.MyDataParallel'>

check_network_is_compiled()[source]

Checks whether the network is compiled.

Raises

RuntimeError – If the network is not compiled.

compile(loss: Loss, optimizer: Optimizer, metrics: list[str])[source]

Compiles network.

Adds loss function, optimizer and metrics functions as attributes.

Parameters
  • loss (Loss) – The network loss function.

  • optimizer (torch.optim.Optimizer) – The optimizer used to update network’s parameters.

  • metrics (list[str]) – List of metrics names.
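The interplay between compile, is_compiled and check_network_is_compiled can be pictured with a small, framework-free sketch. The class CompilableSketch, its attribute names and the placeholder string arguments are hypothetical and only illustrate the pattern; the actual AbstractNet implementation may store these objects differently.

```python
class CompilableSketch:
    """Hypothetical stand-in for an AbstractNet subclass (not dunedn code)."""

    def __init__(self):
        # Nothing is attached until compile() is called.
        self.loss = None
        self.optimizer = None
        self.metrics = None

    @property
    def is_compiled(self) -> bool:
        # Assumption: "compiled" means all three attributes have been set.
        attrs = (self.loss, self.optimizer, self.metrics)
        return all(attr is not None for attr in attrs)

    def compile(self, loss, optimizer, metrics):
        # Store the training ingredients as plain attributes.
        self.loss = loss
        self.optimizer = optimizer
        self.metrics = list(metrics)

    def check_network_is_compiled(self):
        if not self.is_compiled:
            raise RuntimeError("Network is not compiled: call compile() first.")


net = CompilableSketch()
try:
    net.check_network_is_compiled()
    raised = False
except RuntimeError:
    raised = True

# Placeholder strings stand in for real loss and optimizer objects.
net.compile(loss="mse", optimizer="sgd", metrics=["mse"])
net.check_network_is_compiled()  # no exception once compiled
```

In this sketch the check is a guard meant to be called at the start of methods that need the loss and optimizer, so that a missing compile() call fails early with a clear message.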

fit(train_generator: Dataset, epochs: int, val_generator: Optional[Dataset] = None, dev: str = 'cpu', callbacks: Optional[list[dunedn.training.callbacks.Callback]] = None)[source]

Main training function.

Example

Example with a GCNN network.

Load a runcard.

>>> from pathlib import Path
>>> import dunedn
>>> runcard_path = Path("default.yaml")
>>> setup = dunedn.utils.utils.load_runcard(runcard_path)

Instantiate the network.

>>> network = dunedn.networks.gcnn.train.load_and_compile_gcnn_network(
... "collection", setup["model"]["gcnn"])

Load the training generator.

>>> train_generator = dunedn.networks.gcnn.gcnn_dataloading.GcnnDataset(
... "train", dsetup=setup["dataset"])

Train for one epoch.

>>> history = network.fit(train_generator, epochs=1)

Parameters
  • train_generator (torch.utils.data.Dataset) – The train dataset generator.

  • epochs (int) – Number of epochs to train the network for.

  • val_generator (torch.utils.data.Dataset) – The validation dataset generator.

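  • dev (str) – The device hosting the computation. Defaults to “cpu”.

  • callbacks (list[dunedn.training.callbacks.Callback], optional) – The callbacks to apply during training. Defaults to None.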

Returns

history – The history callback containing training details.

Return type

History

property is_compiled

abstract predict(generator: Dataset, device: str) -> Tensor[source]

Network inference.

Parameters
  • generator (torch.utils.data.Dataset) – The inference generator.

  • device (str) – Device hosting computation.

Returns

y_pred – Prediction tensor. Placed on “cpu” to save GPU memory.

Return type

torch.Tensor

to_data_parallel(device_ids: list) -> MyDataParallel[source]

Returns the model wrapped in the MyDataParallel class.

Parameters

device_ids (list) – List of devices to place the model on. The first device in the list is the master device.

Returns

The wrapped network for distributed training.

Return type

MyDataParallel

abstract train_epoch(train_loader: DataLoader, dev: str = 'cpu') -> list[float][source]

Trains the network for one epoch.

Parameters
  • train_loader (torch.utils.data.DataLoader) – The training dataloader.

  • dev (str) – The device hosting the computation.

  • callback_list – callbacks implementing on_train_epoch_begin, on_train_epoch_end

Returns

epoch_history – Dictionary containing the epoch history: the quantities computed at each optimization iteration, with their uncertainties. Keys:

  • loss (list[Tuple(float, float)])

  • metrics (list[Tuple(float, float)])

Return type

dict

training: bool

dunedn.networks.model_utils module

This module contains utility functions for networks in general.

class dunedn.networks.model_utils.MyDDP(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, static_graph=False)[source]

Bases: DistributedDataParallel

DistributedDataParallel wrapper that gives access to the wrapped model’s attributes.

training: bool

class dunedn.networks.model_utils.MyDataParallel(module, device_ids=None, output_device=None, dim=0)[source]

Bases: DataParallel

DataParallel wrapper that gives access to the wrapped model’s attributes.

training: bool
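A common way to let a DataParallel wrapper expose the attributes of the model it wraps is to override __getattr__ and fall back to the inner module. The sketch below shows that pattern under stated assumptions: AttrForwardingDataParallel and TinyNet are hypothetical names, and this is not claimed to be the MyDataParallel source.

```python
from torch import nn


class AttrForwardingDataParallel(nn.DataParallel):
    """Hypothetical sketch; not the dunedn MyDataParallel source."""

    def __getattr__(self, name):
        try:
            # Parameters, buffers and submodules of the wrapper itself.
            return super().__getattr__(name)
        except AttributeError:
            # Fall back to the wrapped model's attributes and methods.
            return getattr(self.module, name)


class TinyNet(nn.Module):
    """Toy model with a custom attribute and a custom method."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        self.model_name = "tiny"

    def forward(self, x):
        return self.linear(x)

    def describe(self):
        return f"{self.model_name} network"


# With no device_ids given, DataParallel uses all visible GPUs, or simply
# holds the module unchanged when no GPU is available.
wrapped = AttrForwardingDataParallel(TinyNet())
description = wrapped.describe()  # resolved on the inner TinyNet
```

Without the override, wrapped.describe() would raise AttributeError and the call would have to be spelled wrapped.module.describe().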

dunedn.networks.utils module

This module implements utility functions for all the networks.

class dunedn.networks.utils.BatchProfiler(drop_last=False)[source]

Bases: object

Class to profile the steps of a for loop.

Useful to profile each batch prediction during DnModel inference.

Example

>>> from dunedn.networks.utils import BatchProfiler
>>> from time import sleep
>>> bp = BatchProfiler()
>>> wrap = bp.set_iterable(range(10))
>>> for i in wrap:
...     print(i)
...     sleep(0.01)
>>> msg = bp.print_stats()
>>> print(msg)

property deltas: ndarray

Computes the wall time intervals between steps.

Sets the nb_batches attribute.

Returns

deltas – The resulting time intervals, of shape=(nb_intervals,).

Return type

np.ndarray
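One way to obtain per-step wall-time intervals is to wrap the iterable in a generator that records a timestamp each time the loop body hands control back. The following is a simplified sketch of that idea: StepTimerSketch is a hypothetical class, it returns a plain list instead of an np.ndarray, and it does not model the drop_last option.

```python
import time


class StepTimerSketch:
    """Hypothetical, simplified profiler; not the dunedn BatchProfiler."""

    def __init__(self):
        self.timestamps = []

    def set_iterable(self, iterable):
        def timed():
            # Reference timestamp taken just before the first step.
            self.timestamps = [time.perf_counter()]
            for item in iterable:
                yield item
                # Control returns here once the loop body has finished.
                self.timestamps.append(time.perf_counter())

        return timed()

    @property
    def deltas(self):
        # Differences between consecutive timestamps: one interval per step.
        return [b - a for a, b in zip(self.timestamps, self.timestamps[1:])]


timer = StepTimerSketch()
for _ in timer.set_iterable(range(5)):
    time.sleep(0.001)  # stands in for a batch prediction

deltas = timer.deltas
```

Each interval covers one full iteration of the loop body, so five steps yield five intervals.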

get_stats() -> Tuple[float, float][source]

Computes the mean of the timings and the standard error of that mean.

Returns

  • mean (float) – The average batch inference time.

  • err (float) – The uncertainty on the batch inference step average time.
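For reference, the mean and the standard error of the mean of a set of timings can be computed as in the sketch below. The function name mean_and_sem is hypothetical, and the use of the sample standard deviation (n - 1 in the denominator) is an assumption: the source does not state which convention get_stats follows.

```python
import math
import statistics


def mean_and_sem(deltas):
    """Return the mean of `deltas` and the standard error of that mean."""
    n = len(deltas)
    mean = statistics.fmean(deltas)
    # Sample standard deviation (n - 1 in the denominator) divided by sqrt(n).
    err = statistics.stdev(deltas) / math.sqrt(n)
    return mean, err


# Four made-up batch timings, in seconds.
mean, err = mean_and_sem([0.010, 0.012, 0.011, 0.013])
```

The error shrinks as 1/sqrt(n), so profiling more batches gives a tighter estimate of the average step time.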

print_stats() -> str[source]

Human-readable message on profiled inference.

Returns

message – The message with profiling information.

Return type

str

set_iterable(iterable: Iterable)[source]

Sets the iterable to be profiled.

Parameters

iterable (Iterable) – The iterable to be profiled.

dunedn.networks.utils.apply_median_subtraction(planes: ndarray) -> ndarray[source]

Applies median subtraction to the input planes.

Parameters

planes (np.ndarray) – The data to be normalized.

Returns

output – The median subtracted data.

Return type

np.ndarray
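Median subtraction can be illustrated with NumPy as follows. This is a minimal sketch under an explicit assumption: one median is computed per plane (over all axes except the first) and subtracted from that plane. The source does not say over which axes apply_median_subtraction takes the median, and median_subtract_sketch is a hypothetical name.

```python
import numpy as np


def median_subtract_sketch(planes: np.ndarray) -> np.ndarray:
    """Subtract from each plane its own median (assumed convention)."""
    # One median per sample, computed over all remaining axes.
    axes = tuple(range(1, planes.ndim))
    medians = np.median(planes, axis=axes, keepdims=True)
    return planes - medians


# Two toy planes of shape (N, 1, H, W) = (2, 1, 2, 3).
planes = np.array(
    [
        [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]],
        [[[10.0, 10.0, 10.0], [10.0, 20.0, 30.0]]],
    ]
)
out = median_subtract_sketch(planes)
```

After the subtraction every plane has zero median, which removes a constant baseline while leaving localized signals in place.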

dunedn.networks.utils.get_hits_from_clear_images(planes: ndarray, threshold: float) -> ndarray[source]

Segments the input images into signal and background pixels.

Parameters
  • planes (np.ndarray) – The clear planes, of shape=(N,1,H,W).

  • threshold (float) – Threshold above which a pixel is considered to contain signal.

Returns

hits – The signal-background segmented image, of shape=(N,1,H,W).

Return type

np.ndarray
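A threshold-based segmentation of this kind can be sketched in a few lines of NumPy. The function threshold_hits_sketch is hypothetical, and two details are assumptions: the comparison is strict and applied to the raw pixel value (not its absolute value), and the mask is returned as float32.

```python
import numpy as np


def threshold_hits_sketch(planes: np.ndarray, threshold: float) -> np.ndarray:
    """Label pixels above `threshold` as signal (1) and the rest as background (0)."""
    # Element-wise comparison keeps the input shape.
    return (planes > threshold).astype(np.float32)


# One toy plane of shape (N, 1, H, W) = (1, 1, 2, 2).
clear = np.array([[[[0.0, 0.4], [3.5, 7.0]]]])
hits = threshold_hits_sketch(clear, threshold=3.5)
```

With a strict comparison, a pixel exactly equal to the threshold is labelled as background.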

dunedn.networks.utils.get_supported_models()[source]

Returns the names of the supported models.

Returns

The list of currently implemented models.

Return type

list

dunedn.networks.utils.print_cfnm(cfnm, channel)[source]

Prints confusion matrix.

Parameters
  • cfnm (list) – Computed confusion matrix.

  • channel (str) – Available options: readout | collection.

Returns

msg – The confusion matrix representation to be printed.

Return type

str

dunedn.networks.utils.print_epoch_logs(logger: Logger, metrics_names: list[str], logs: dict)[source]

Prints logs dictionary on epoch end.

Parameters
  • logger (Logger) – The logging object.

  • metrics_names (list[str]) – The list of metrics to be printed.

  • logs (dict) – The computed metrics values to be logged.

Module contents