# Source code for dunedn.tests.test_networks

"""
    Ensures DUNEdn network objects run the forward pass without errors.
"""
import logging
from pathlib import Path
import torch
from dunedn.configdn import PACKAGE
from dunedn.networks.gcnn.training import load_and_compile_gcnn_network
from dunedn.networks.uscg.training import load_and_compile_uscg_network
from dunedn.networks.utils import get_supported_models
from dunedn.utils.utils import load_runcard

# Module-level logger used by the test functions in this file; its name is
# the package name (PACKAGE) followed by the ".test" suffix.
logger = logging.getLogger(PACKAGE + ".test")


def run_test_uscg(setup: dict):
    """Run the USCG network forward-pass test.

    Builds the ``"collection"`` USCG network from the runcard settings, feeds
    it a random tensor and checks that the output has the same shape as the
    input.

    Parameters
    ----------
    setup: dict
        Runcard settings. Only ``setup["model"]["uscg"]`` is read; the
        dictionary passed by the caller is not modified.

    Raises
    ------
    AssertionError
        If model input and output shapes do not match.
    """
    batch_size = 1
    # fixed input size used for the test (presumably small to keep it light)
    height, width = 64, 128

    # Apply the size overrides on copies: writing into the nested dictionaries
    # directly would silently alter the runcard owned by the caller.
    base_setup = setup["model"]["uscg"]
    net_dict = {**base_setup["net_dict"], "h_collection": height, "w": width}
    msetup = {**base_setup, "net_dict": net_dict}

    # random single-channel input
    dummy_dataset = torch.rand(batch_size, 1, height, width)

    # load uscg model
    model = load_and_compile_uscg_network("collection", msetup)
    model.eval()

    # forward pass: gradients are not needed for a shape check
    with torch.no_grad():
        output = model(dummy_dataset)

    # Explicit check instead of a bare ``assert``, which would be stripped
    # when Python runs with ``-O`` and carries no message on failure.
    if dummy_dataset.shape != output.shape:
        logger.critical(
            "Assertion fail: uscg model input and output shapes do not match"
        )
        raise AssertionError(
            "uscg model input and output shapes do not match: "
            f"{tuple(dummy_dataset.shape)} != {tuple(output.shape)}"
        )
def run_test_cnn(setup: dict):
    """Run the CNN network forward-pass test.

    Builds the ``"collection"`` network from the ``cnn`` runcard settings,
    feeds it a random tensor and checks that the output has the same shape as
    the input.

    Parameters
    ----------
    setup: dict
        Runcard settings. Reads ``setup["model"]["cnn"]`` and
        ``setup["dataset"]["crop_size"]``.

    Raises
    ------
    AssertionError
        If model input and output shapes do not match.
    """
    # model settings for the cnn network
    msetup = setup["model"]["cnn"]

    # random single-channel input, one crop per batch element
    dummy_dataset = torch.rand(msetup["batch_size"], 1, *setup["dataset"]["crop_size"])

    # NOTE(review): the cnn model is built through the GCNN loader; presumably
    # that loader handles both variants from the settings - confirm.
    model = load_and_compile_gcnn_network("collection", msetup)
    model.eval()

    # forward pass: gradients are not needed for a shape check
    with torch.no_grad():
        output = model(dummy_dataset)

    # Explicit check instead of a bare ``assert``, which would be stripped
    # when Python runs with ``-O`` and carries no message on failure.
    if dummy_dataset.shape != output.shape:
        logger.critical(
            "Assertion fail: CNN model input and output shapes do not match"
        )
        raise AssertionError(
            "CNN model input and output shapes do not match: "
            f"{tuple(dummy_dataset.shape)} != {tuple(output.shape)}"
        )
def run_test_gcnn(setup: dict):
    """Run the GCNN network forward-pass test.

    Builds the ``"collection"`` GCNN network from the runcard settings, feeds
    it a random tensor and checks that the output has the same shape as the
    input.

    Parameters
    ----------
    setup: dict
        Runcard settings. Reads ``setup["model"]["gcnn"]`` and
        ``setup["dataset"]["crop_size"]``.

    Raises
    ------
    AssertionError
        If model input and output shapes do not match.
    """
    # model settings for the gcnn network
    msetup = setup["model"]["gcnn"]

    # random single-channel input, one crop per batch element
    dummy_dataset = torch.rand(msetup["batch_size"], 1, *setup["dataset"]["crop_size"])

    # load gcnn model
    model = load_and_compile_gcnn_network("collection", msetup)
    model.eval()

    # forward pass: gradients are not needed for a shape check
    with torch.no_grad():
        output = model(dummy_dataset)

    # Explicit check instead of a bare ``assert``, which would be stripped
    # when Python runs with ``-O`` and carries no message on failure.
    if dummy_dataset.shape != output.shape:
        logger.critical(
            "Assertion fail: GCNN model input and output shapes do not match"
        )
        raise AssertionError(
            "GCNN model input and output shapes do not match: "
            f"{tuple(dummy_dataset.shape)} != {tuple(output.shape)}"
        )
def run_test(modeltype: str):
    """
    Dispatch the forward-pass test matching the given model type.

    The default runcard is loaded and an info message is logged before the
    model type is looked up.

    Parameters
    ----------
    modeltype: str
        Available options uscg | cnn | gcnn.

    Raises
    ------
    NotImplementedError
        If ``modeltype`` is not one of the available options.
    """
    runcard = load_runcard(Path("runcards/default.yaml"))
    logger.info("Running forward-pass test on %s model", modeltype)

    # one test runner per supported model type
    runners = {
        "cnn": run_test_cnn,
        "gcnn": run_test_gcnn,
        "uscg": run_test_uscg,
    }
    runner = runners.get(modeltype)
    if runner is None:
        raise NotImplementedError(f"Modeltype not implemented, got {modeltype}")
    runner(runcard)
def test_networks():
    """Run the forward-pass test on every model type reported as supported."""
    supported_models = get_supported_models()
    for name in supported_models:
        run_test(name)
# Run all the network tests when this module is executed as a script.
if __name__ == "__main__":
    test_networks()