# Source code for dunedn.networks.gcnn.training

""" This module provides functions for CNN and GCNN networks training and loading."""
import logging
from pathlib import Path
import torch
from .gcnn_dataloading import GcnnDataset
from .gcnn_net import GcnnNet
from .utils import make_dict_compatible
from dunedn import PACKAGE
from dunedn.training.losses import get_loss
from dunedn.training.metrics import DN_METRICS

# Module-level logger; its name is the PACKAGE constant followed by ".gcnn".
logger = logging.getLogger(PACKAGE + ".gcnn")


def load_and_compile_gcnn_network(
    channel: str, msetup: dict, checkpoint_filepath: Path = None
) -> GcnnNet:
    """Build a CNN or GCNN network, optionally restore weights, and compile it.

    Parameters
    ----------
    channel: str
        Available options induction | collection. This argument is not read
        inside this function.
    msetup: dict
        The model setup dictionary. The keys ``net_dict``, ``loss_fn``,
        ``lr`` and ``amsgrad`` are read here.
    checkpoint_filepath: Path
        The `.pth` checkpoint containing network weights to be loaded.
        When falsy (e.g. ``None``), no weights are loaded.

    Returns
    -------
    network: GcnnNet
        The network, compiled with its loss, optimizer and ``DN_METRICS``.
    """
    net_kwargs = msetup["net_dict"]
    network = GcnnNet(**net_kwargs)

    if checkpoint_filepath:
        logger.info(f"Loading weights at {checkpoint_filepath}")
        # Checkpoint storages are remapped to the CPU at load time.
        cpu = torch.device("cpu")
        raw_weights = torch.load(checkpoint_filepath, map_location=cpu)
        # The raw state dict goes through make_dict_compatible before loading.
        network.load_state_dict(make_dict_compatible(raw_weights))

    # loss: get_loss returns a callable, which is invoked with no arguments
    loss_factory = get_loss(msetup["loss_fn"])
    criterion = loss_factory()

    # optimizer
    trainable_params = list(network.parameters())
    adam = torch.optim.Adam(
        trainable_params, msetup["lr"], amsgrad=msetup["amsgrad"]
    )

    network.compile(criterion, adam, DN_METRICS)
    return network
def gcnn_training(modeltype: str, setup: dict):
    """GCNN network training.

    Loads and compiles the network, builds the train / val / test datasets,
    fits the network and finally prints the metrics computed on the test set.

    Parameters
    ----------
    modeltype: str
        The model to be trained. Available options: cnn | gcnn.
    setup: dict
        Settings dictionary. Keys read here: ``model`` (holding the
        ``modeltype`` sub-dictionary and ``epochs``), ``task``, ``dataset``
        and ``dev``. From the model sub-dictionary, ``ckpt``, ``batch_size``
        and ``test_batch_size`` are read.
    """
    # model loading
    assert modeltype in ["cnn", "gcnn"], (
        f"modeltype must be 'cnn' or 'gcnn', got {modeltype!r}"
    )
    msetup = setup["model"][modeltype]
    # TODO: remove channel (collection | induction) hard coding
    channel = "collection"
    # The loader takes (channel, msetup, checkpoint_filepath) only; the device
    # is not one of its parameters and is handed to ``network.fit`` below.
    network = load_and_compile_gcnn_network(
        channel, msetup, checkpoint_filepath=msetup["ckpt"]
    )

    # data loading
    gen_kwargs = {
        "task": setup["task"],
        "channel": channel,
        "dsetup": setup["dataset"],
    }
    train_generator = GcnnDataset(
        "train", batch_size=msetup["batch_size"], **gen_kwargs
    )
    val_generator = GcnnDataset(
        "val", batch_size=msetup["test_batch_size"], **gen_kwargs
    )
    test_generator = GcnnDataset(
        "test", batch_size=msetup["test_batch_size"], **gen_kwargs
    )

    # training
    network.fit(
        train_generator,
        epochs=setup["model"]["epochs"],
        val_generator=val_generator,
        dev=setup["dev"],
    )

    # testing
    _, logs = network.predict(test_generator)
    network.metrics_list.print_metrics(logger, logs)