Source code for dunedn.networks.uscg.uscg_dataloading

"""This module implements dataset loading for the USCG network."""
import logging
from typing import Tuple
import numpy as np
import torch
from ..utils import get_hits_from_clear_images
from dunedn import PACKAGE
from dunedn.utils.utils import median_subtraction

# Module-level logger, namespaced under the package name.
# NOTE(review): the suffix is ".gcnn" although this module implements dataset
# loading for the USCG network — confirm the logger name is intended.
logger = logging.getLogger(PACKAGE + ".gcnn")


class BaseUscgDataset(torch.utils.data.Dataset):
    """Common base class for the USCG datasets.

    Only stores the configuration shared by the concrete datasets. It does
    not define ``noisy`` itself: subclasses must set that attribute, since
    ``__len__`` reads it.
    """

    def __init__(
        self,
        dataset_type: str,
        task: str,
        channel: str,
        dsetup: dict,
        batch_size: int,
    ):
        """
        Parameters
        ----------
        dataset_type: str
            Available options train | val | test
        task: str
            Available options dn | roi.
        channel: str
            Available options induction | collection
        dsetup: dict
            The dataset settings dictionary. Must contain the ``threshold``
            key.
        batch_size: int
            The number of examples to be batched. Coerced with ``int()``.
        """
        self.dsetup = dsetup
        self.dataset_type, self.task, self.channel = dataset_type, task, channel
        # training mode is enabled for the "train" split only
        self.training = dataset_type == "train"
        self.batch_size = int(batch_size)
        self.threshold = self.dsetup["threshold"]

    def __len__(self):
        """Return the number of examples stored in ``self.noisy``."""
        n_examples = len(self.noisy)
        return n_examples
class UscgDataset(BaseUscgDataset):
    """Dataset of paired noisy / clear planes loaded from disk.

    The planes are read from ``<data_folder>/<dataset_type>/planes`` at
    construction time and kept in memory as tensors.
    """

    def __init__(
        self,
        dataset_type: str,
        task: str,
        channel: str,
        dsetup: dict,
        batch_size: int,
    ):
        """
        Parameters
        ----------
        dataset_type: str
            Available options train | val | test
        task: str
            Available options dn | roi.
        channel: str
            Available options induction | collection
        dsetup: dict
            The dataset settings dictionary. Must contain the ``threshold``
            and ``data_folder`` keys.
        batch_size: int
            The number of examples to be batched.
        """
        super().__init__(
            dataset_type,
            task,
            channel,
            dsetup,
            batch_size,
        )
        # `data_folder` must support the `/` operator
        # (presumably a pathlib.Path — confirm against the caller).
        self.data_folder = self.dsetup["data_folder"] / dataset_type
        self.planes_folder = self.data_folder / "planes"
        self.noisy, self.clear = self.get_planes_from_setup()

    def get_planes_from_setup(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get planes from folder pointed by `dsetup`.

        The noisy planes go through `median_subtraction`. For the ``roi``
        task the clear planes are replaced by the output of
        `get_hits_from_clear_images`, and ``self.balance_ratio`` is set to the
        fraction of non-zero entries in that output. ``balance_ratio`` is not
        set for any other task.

        Returns
        -------
        noisy: torch.Tensor
            The noisy planes, of shape=(N,1,H,W).
        clear: torch.Tensor
            The clear planes, of shape=(N,1,H,W).
        """
        fname = self.planes_folder / f"{self.channel}_noisy.npy"
        noisy = np.load(fname)
        noisy = median_subtraction(noisy)

        fname = self.planes_folder / f"{self.channel}_clear.npy"
        clear = np.load(fname)

        if self.task == "roi":
            # assumes the helper returns a numpy array — TODO confirm
            clear = get_hits_from_clear_images(clear, self.threshold)
            # `ndarray.size` is an integer attribute, not a method: calling it
            # raised "TypeError: 'int' object is not callable".
            self.balance_ratio = np.count_nonzero(clear) / clear.size

        noisy = torch.Tensor(noisy)
        clear = torch.Tensor(clear)
        return noisy, clear

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the noisy / clear pair stored at position `index`.

        Returns
        -------
        noisy: torch.Tensor
            A single noisy example, of shape=(1,H,W).
        clear: torch.Tensor
            A single clear example, of shape=(1,H,W).
        """
        return self.noisy[index], self.clear[index]
class UscgPlanesDataset(BaseUscgDataset):
    """Exposes in-memory noisy planes as a dataset for network inference.

    No labels are available: every item is paired with a placeholder ``0``.
    The dataset type is always ``"test"``.
    """

    def __init__(
        self,
        noisy: np.ndarray,
        task: str,
        channel: str,
        dsetup: dict,
        batch_size: int,
    ):
        """
        Parameters
        ----------
        noisy: np.ndarray
            The noisy planes for inference.
        task: str
            Available options dn | roi.
        channel: str
            Available options induction | collection
        dsetup: dict
            The dataset settings dictionary.
        batch_size: int
            The number of examples to be batched.
        """
        # cast to single precision before median subtraction
        planes = noisy.astype(np.float32)
        self.noisy = median_subtraction(planes)
        super().__init__(
            dataset_type="test",
            task=task,
            channel=channel,
            dsetup=dsetup,
            batch_size=batch_size,
        )

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """Return the noisy example at `index` with a dummy label.

        Returns
        -------
        noisy: torch.Tensor
            A single noisy example, of shape=(1,H,W).
            NOTE(review): the item is taken as-is from the output of
            `median_subtraction`; confirm that it is a tensor and not a
            numpy array.
        int
            Always 0, dummy output for labels.
        """
        dummy_label = 0
        return self.noisy[index], dummy_label