dunedn.networks package
Subpackages
Submodules
dunedn.networks.abstract_net module
- class dunedn.networks.abstract_net.AbstractNet(**kwargs)[source]
Bases: Module, ABC
Abstract network implementation.
Allows network distribution on multiple devices.
Example
Instantiate an AbstractNet daughter class such as GcnnNet:
>>> network = GcnnNet(setup)
>>> print(type(network))
<class 'dunedn.networks.gcnn_net.GcnnNet'>
Distribute the network on multiple devices:
>>> device_ids = [0, 1, 2, 3]
>>> device_ids
[0, 1, 2, 3]
>>> dp = network.to_data_parallel(device_ids=device_ids)
>>> print(type(dp))
<class 'dunedn.networks.model_utils.MyDataParallel'>
- check_network_is_compiled()[source]
Checks whether the network is compiled.
- Raises
RuntimeError – If the network is not compiled.
- compile(loss: Loss, optimizer: Optimizer, metrics: list[str])[source]
Compiles network.
Adds loss function, optimizer and metrics functions as attributes.
- Parameters
loss (Loss) – The network loss function.
optimizer (torch.optim.Optimizer) – The optimizer used to update the network’s parameters.
metrics (list[str]) – List of metrics names.
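The compile-then-check pattern described above can be sketched in plain Python. This is only an illustration under stated assumptions: `CompilableSketch` and its `_compiled` flag are hypothetical names, and the way AbstractNet actually tracks its compiled state is not shown in this reference.

```python
class CompilableSketch:
    """Hypothetical stand-in that shows the compile / check pattern only."""

    def __init__(self):
        self.loss = None
        self.optimizer = None
        self.metrics = None
        self._compiled = False  # assumption: a simple boolean flag

    @property
    def is_compiled(self):
        """True once compile() has been called."""
        return self._compiled

    def compile(self, loss, optimizer, metrics):
        """Store loss, optimizer and metric names as attributes."""
        self.loss = loss
        self.optimizer = optimizer
        self.metrics = list(metrics)
        self._compiled = True

    def check_network_is_compiled(self):
        """Raise RuntimeError if compile() has not been called yet."""
        if not self.is_compiled:
            raise RuntimeError("Network is not compiled: call compile() first.")
```

Guarding training entry points with such a check turns a missing loss or optimizer into an early, explicit error instead of a failure deep inside the training loop.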
- fit(train_generator: Dataset, epochs: int, val_generator: Optional[Dataset] = None, dev: str = 'cpu', callbacks: Optional[list[dunedn.training.callbacks.Callback]] = None)[source]
Main training function.
Example
Example with a GCNN network.
Load a runcard.
>>> from pathlib import Path
>>> import dunedn
>>> runcard_path = Path("default.yaml")
>>> setup = dunedn.utils.utils.load_runcard(runcard_path)
Instantiate the network.
>>> network = dunedn.networks.gcnn.train.load_and_compile_gcnn_network(
...     "collection", setup["model"]["gcnn"])
Load the training generator.
>>> train_generator = dunedn.networks.gcnn.gcnn_dataloading.GcnnDataset(
...     "train", dsetup=setup["dataset"])
Train for one epoch.
>>> history = network.fit(train_generator, epochs=1)
- Parameters
train_generator (torch.utils.data.Dataset) – The train dataset generator.
epochs (int) – Number of epochs to train network on.
val_generator (torch.utils.data.Dataset) – The validation dataset generator.
dev (str) – The device hosting the computation. Defaults to “cpu”.
- Returns
history – The history callback containing training details.
- Return type
- property is_compiled
- abstract predict(generator: Dataset, device: str) Tensor[source]
Network inference.
- Parameters
generator (torch.utils.data.Dataset) – The inference generator.
device (str) – Device hosting computation.
- Returns
y_pred – Prediction tensor. Placed on “cpu” for GPU memory saving.
- Return type
torch.Tensor
- to_data_parallel(device_ids: list) MyDataParallel[source]
Returns the model wrapped in the MyDataParallel class.
- Parameters
device_ids (list) – List of devices to place the model on. The first device in the list is the master device.
- Returns
The wrapped network for distributed training.
- Return type
MyDataParallel
- abstract train_epoch(train_loader: DataLoader, dev: str = 'cpu') list[float][source]
Trains the network for one epoch.
- Parameters
train_loader (torch.utils.data.DataLoader) – The training dataloader.
dev (str) – The device hosting the computation.
callback_list – callbacks implementing on_train_epoch_begin, on_train_epoch_end
- Returns
epoch_history – Dictionary containing the epoch history: the quantities computed at each optimization iteration, with their uncertainties. Keys:
loss (list[Tuple(float, float)])
metrics (list[Tuple(float, float)])
- Return type
dict
- training: bool
dunedn.networks.model_utils module
This module contains utility functions for networks in general.
- class dunedn.networks.model_utils.MyDDP(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, static_graph=False)[source]
Bases: DistributedDataParallel
Distributed Data Parallel wrapper that allows calling the model’s attributes.
- training: bool
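A wrapper that "allows calling the model's attributes" is commonly built by forwarding failed attribute lookups to the wrapped module. The sketch below shows that idea in plain Python, without torch. `ForwardingWrapper` and `ToyNet` are hypothetical names, and whether MyDDP is implemented this way is an assumption.

```python
class ForwardingWrapper:
    """Hypothetical wrapper that exposes the attributes of the wrapped object."""

    def __init__(self, module):
        self.module = module

    def __getattr__(self, name):
        # __getattr__ runs only when normal lookup on the wrapper fails.
        # Fetch `module` through object.__getattribute__ so that a wrapper
        # without `module` raises AttributeError instead of recursing.
        module = object.__getattribute__(self, "module")
        return getattr(module, name)


class ToyNet:
    """Hypothetical model with one custom attribute and one custom method."""

    name = "toy"

    def predict(self, x):
        return 2 * x
```

With this in place, `ForwardingWrapper(ToyNet()).predict(3)` reaches the method of the wrapped object, so code written against the bare model keeps working after wrapping.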
dunedn.networks.utils module
This module implements utility functions for all the networks.
- class dunedn.networks.utils.BatchProfiler(drop_last=False)[source]
Bases: object
Class to profile the steps of a for loop.
Useful to profile each batch prediction during DnModel inference.
Example
>>> from dunedn.networks.utils import BatchProfiler
>>> from time import sleep
>>> bp = BatchProfiler()
>>> wrap = bp.set_iterable(range(10))
>>> for i in wrap:
...     print(i)
...     sleep(0.01)
>>> msg = bp.print_stats()
>>> print(msg)
- property deltas: ndarray
Computes the wall time intervals between steps.
Sets the nb_batches attribute.
- Returns
deltas – The result time intervals, of shape=(nb intervals,).
- Return type
np.ndarray
- get_stats() Tuple[float, float][source]
Computes the average of the timings and the standard error of the mean.
- Returns
mean (float) – The average batch inference time.
err (float) – The uncertainty on the batch inference step average time.
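The timing logic described for deltas and get_stats can be sketched with the standard library alone. The helpers `timed` and `interval_stats` are hypothetical and are not part of dunedn. The use of the sample standard deviation (n - 1 in the denominator) for the standard error is also an assumption, since the reference does not say which estimator BatchProfiler uses.

```python
import math
import statistics
import time


def timed(iterable, timestamps):
    """Yield items from `iterable`, recording a wall-clock time around each step."""
    timestamps.append(time.perf_counter())
    for item in iterable:
        yield item
        # Execution resumes here once the caller's loop body has finished.
        timestamps.append(time.perf_counter())


def interval_stats(timestamps):
    """Return (mean, standard error of the mean) of consecutive time intervals."""
    deltas = [later - earlier for earlier, later in zip(timestamps, timestamps[1:])]
    mean = statistics.mean(deltas)
    # Needs at least two intervals: statistics.stdev raises otherwise.
    err = statistics.stdev(deltas) / math.sqrt(len(deltas))
    return mean, err
```

Iterating over `timed(range(10), stamps)` fills `stamps` with eleven timestamps, so `interval_stats(stamps)` summarizes ten step durations.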
- dunedn.networks.utils.apply_median_subtraction(planes: ndarray) ndarray[source]
Applies median subtraction to the input planes.
- Parameters
planes (np.ndarray) – The data to be normalized.
- Returns
output – The median subtracted data.
- Return type
np.ndarray
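Median subtraction can be sketched with NumPy as follows. This assumes planes of shape (N,1,H,W) and one median per plane taken over all of its pixels. The reference does not state the axes used by apply_median_subtraction, so both the axis choice and the helper name `subtract_plane_median` are assumptions.

```python
import numpy as np


def subtract_plane_median(planes):
    """Subtract from each plane its own median (assumed shape: (N, 1, H, W))."""
    # keepdims=True keeps shape (N, 1, 1, 1), so the medians broadcast
    # against the planes.
    medians = np.median(planes, axis=(1, 2, 3), keepdims=True)
    return planes - medians
```

After the subtraction each plane has zero median, which removes a constant baseline offset without being pulled by a few large signal pixels the way a mean would be.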
- dunedn.networks.utils.get_hits_from_clear_images(planes: ndarray, threshold: float) ndarray[source]
Segments the input images into signal and background pixels.
- Parameters
planes (np.ndarray) – The clear planes, of shape=(N,1,H,W).
threshold (float) – Threshold above which a pixel is considered to contain signal.
- Returns
hits – The signal-background segmented image, of shape=(N,1,H,W).
- Return type
np.ndarray
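A threshold segmentation of this kind can be sketched in a few lines of NumPy. The helper name `segment_hits`, the strict `>` comparison and the float32 output with 1 for signal and 0 for background are assumptions. The reference only states that pixels above the threshold count as signal.

```python
import numpy as np


def segment_hits(planes, threshold):
    """Mark pixels strictly above `threshold` as signal (1) and the rest as background (0)."""
    # The boolean mask has the same shape as the input planes.
    return (planes > threshold).astype(np.float32)
```

Because the comparison is element-wise, the output keeps the shape of the input planes.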
- dunedn.networks.utils.get_supported_models()[source]
Returns the names of the supported models.
- Returns
The list of currently implemented models.
- Return type
list
- dunedn.networks.utils.print_cfnm(cfnm, channel)[source]
Prints confusion matrix.
- Parameters
cfnm (list) – Computed confusion matrix.
channel (str) – The channel type. Available options: readout | collection.
- Returns
msg – The confusion matrix representation to be printed.
- Return type
str
- dunedn.networks.utils.print_epoch_logs(logger: Logger, metrics_names: list[str], logs: dict)[source]
Prints the logs dictionary at the end of an epoch.
- Parameters
logger (Logger) – The logging object.
metrics_names (list[str]) – The list of metrics to be printed.
logs (dict) – The computed metrics values to be logged.
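An epoch-end logging helper with the same three inputs could look like the sketch below. The helpers `format_epoch_logs` and `log_epoch`, the `name: value` layout and the assumption that every logged value is a plain float are all hypothetical. The message format produced by print_epoch_logs is not documented here.

```python
import logging


def format_epoch_logs(metrics_names, logs):
    """Build one line with the requested metrics, in the order given."""
    return " | ".join(f"{name}: {logs[name]:.4f}" for name in metrics_names)


def log_epoch(logger, metrics_names, logs):
    """Send the formatted metrics line to `logger` at INFO level."""
    logger.info(format_epoch_logs(metrics_names, logs))
```

Passing the metric names separately from the logs dictionary lets the caller choose which entries are printed and in which order, even when the dictionary holds additional keys.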