Source code for dunedn.networks.abstract_net

from abc import ABC, abstractmethod
import logging
from time import time as tm
import torch
from .model_utils import MyDataParallel
from .utils import print_epoch_logs
from dunedn.training.callbacks import History, Callback, CallbackList
from dunedn.training.losses import Loss
from dunedn.training.metrics import MetricsList
from dunedn import PACKAGE

logger = logging.getLogger(PACKAGE + ".train")


[docs]class AbstractNet(torch.nn.Module, ABC): """Abstract network implementation. Allows network distribution on multiple devices. Example ------- Instantiate an AbstractNet daughter class such as GcnnNet: >>> network = GcnnNet(setup) >>> print(type(network)) <class 'dunedn.networks.gcnn_net.GcnnNet'> Distribute the network on multiple devices: >>> device_ids = [0, 1, 2, 3] >>> device_ids [0, 1, 2, 3] >>> dp = network.to_data_parallel(device_ids=device_ids) >>> print(type(dp)) <class 'dunedn.networks.model_utils.MyDataParallel'> """ def __init__(self, **kwargs): super().__init__(**kwargs) self.__is_compiled = False self.history = History() @property def is_compiled(self): return self.__is_compiled @is_compiled.setter def is_compiled(self, value): self.__is_compiled = value
[docs] def to_data_parallel(self, device_ids: list) -> MyDataParallel: """Returns the model wrapped by MyDataParallel class. Parameters ---------- device_ids: list List of devices to place the model to. The first device in the list is the master device. Returns ------- MyDataParallel The wrapped network for distributed training. """ return MyDataParallel(self, device_ids=device_ids)
[docs] def check_network_is_compiled(self): """Checks wether the object is compiled or not. Raises ------ RuntimeError If the network is not compiled. """ if self.__is_compiled == False: raise RuntimeError( "Model must be compiled before training/testing. " "Please call `model.compile()` method." )
[docs] def compile(self, loss: Loss, optimizer: torch.optim.Optimizer, metrics: list[str]): """Compiles network. Adds loss function, optimizer and metrics functions as attributes. Parameters ---------- loss: Loss The network loss function. optimizer: torch.optim.Optimizer The optimizer used to update network's parameters. metrics: list[str] List of metrics names. """ self.loss_fn = loss # TODO: pass non defaults arguments to loss function self.optimizer = optimizer self.metrics_list = MetricsList(metrics) self.is_compiled = True
[docs] def fit( self, train_generator: torch.utils.data.Dataset, epochs: int, val_generator: torch.utils.data.Dataset = None, dev: str = "cpu", callbacks: list[Callback] = None, ): """Main training function. Example ------- Wcample with a GCNN network. Load a runcard. >>> import dunedn >>> runcard_path = Path("default.yaml") >>> setup = dunedn.utils.utils.load_runcard(runcard_path) Instantiate the network. >>> network = dunedn.networks.gcnn.train.load_and_compile_gcnn_network( ... "collection", setup["model"]["gcnn"]) Load the training generator. >>> train_generator = dunedn.networks.gcnn.gcnn_dataloading.GcnnDataset( ... "train", dsetup=setup["dataset"]) Train for one epoch. >>> history = network.fit(train_generator, epochs=1) Parameters ---------- train_generator: torch.utils.data.Dataset The train dataset generator. epochs: int Number of epochs to train network on. val_generator: torch.utils.data.Dataset The validation dataset generator. dev: str The device hosting the computation. Defaults is "cpu". Returns ------- history: History The history callback containing training details. """ self.check_network_is_compiled() # train_sampler = DistributedSampler(dataset=train_data, shuffle=True) train_loader = torch.utils.data.DataLoader( dataset=train_generator, shuffle=True, # sampler=train_sampler, batch_size=train_generator.batch_size, ) # create CallbackList object, add History object to it if callbacks is None: callbacks = [] self.history = History() callbacks.append(self.history) self.callback_list = CallbackList(callbacks) # model on device self.to(dev) logger.info(f"Training for {epochs} epochs") self.callback_list.on_train_begin() for epoch in range(epochs): logger.info(f"Epoch {epoch}/{epochs}") self.callback_list.on_epoch_begin() # training epoch start = tm() # epoch_logs = self.train_epoch(train_loader, dev) epoch_logs = {} epoch_time = tm() - start epoch_logs.update({"epoch": epoch, "epoch_time": epoch_time}) # validation if val_generator is not None: _, logs = self.predict(val_generator, dev, verbose=0) val_logs = {f"val_{k}": v for k, v in logs.items()} epoch_logs.update(val_logs) print_epoch_logs(logger, self.metrics_names, epoch_logs) self.callback_list.on_epoch_end(epoch_logs) training_logs = None self.callback_list.on_train_end(training_logs) # model to cpu to save memory self.to("cpu") return self.history
[docs] @abstractmethod def predict(self, generator: torch.utils.data.Dataset, device: str) -> torch.Tensor: """Network inference. Parameters ---------- generator: torch.utils.data.Dataset The inference generator. device: str Device hosting computation. Returns ------- y_pred: torch.Tensor Prediction tensor. Placed on "cpu" for GPU memory saving. """ pass
[docs] @abstractmethod def train_epoch( self, train_loader: torch.utils.data.DataLoader, dev: str = "cpu", ) -> list[float]: """Trains the network for one epoch. Parameters ---------- train_loader: torch.utils.data.DataLoader The training dataloader. dev: str The device hosting the computation. callback_list: callbacks implementing on_train_epoch_begin, on_train_epoch_end Returns ------- epoch_history: dict Dictionary containing epoch history. Computed quantities at each optimization iteration, with their uncertainties. Keys: - loss (list[Tuple(float, float)]) - metrics (list[Tuple(float, float)]) """ pass