# Source code for dunedn.training.losses

"""
    This module implements several losses and metrics.

    The main option is `reduction`, which can be either `mean` (default) or
    `none`; `LossBce` additionally accepts `sum`.
"""
import torch
import numpy as np
from torch import nn
from abc import ABC, abstractmethod
from dunedn.training.ssim import stat_ssim
from dunedn.utils.utils import confusion_matrix


# Tiny positive constant (float64 machine epsilon) added to log arguments and
# to denominators in the losses below, so that log(0) and 0/0 are avoided.
# It is wrapped in a tensor so each use site can move it to the inputs'
# device with `EPS.to(x.device)`.
EPS = torch.Tensor([torch.finfo(torch.float64).eps])


class Loss(ABC):
    """Abstract base class shared by the losses and metrics of this module.

    Concrete subclasses implement ``__call__(y_pred, y_true)``.
    """

    def __init__(self, a=0.5, data_range=1.0, reduction="mean"):
        """
        Parameters
        ----------
        a: float
            Relative weight of the loss contributions.
        data_range: float
            Data interval.
        reduction: str
            Available options mean | none.
        """
        self.reduction = reduction
        self.data_range = data_range
        self.a = a

    @abstractmethod
    def __call__(self, y_pred, y_true):
        """Compute the loss function."""
class LossMse(Loss):
    """Mean squared error loss function.

    Computes ``L2(y_true, y_pred) = mean((y_true - y_pred)**2)``.
    """

    def __init__(self, a=0.84, data_range=1.0, reduction="mean"):
        """
        Parameters
        ----------
        a: float
            Stored for interface compatibility with the other losses; not
            used by this loss.
        data_range: float
            Stored for interface compatibility; not used by this loss.
        reduction: str
            ``"mean"`` returns a scalar; any other value returns the
            per-sample mean, of shape=(N,).
        """
        # Forward every argument to the base class: previously `a` and
        # `data_range` were silently dropped and the base defaults stored.
        super(LossMse, self).__init__(a, data_range, reduction)
        self.loss = nn.MSELoss(reduction="none")
        self.name = "MSE"

    def __call__(self, y_pred, y_true):
        """
        Parameters
        ----------
        y_pred: torch.Tensor
            Predicted tensor, of shape=(N,C,W,H).
        y_true: torch.Tensor
            Target tensor, of shape=(N,C,W,H).

        Returns
        -------
        torch.Tensor
            Scalar if ``reduction == "mean"``, otherwise shape=(N,).
        """
        loss = self.loss(y_pred, y_true)
        if self.reduction == "mean":
            return loss.mean()
        # Average the element-wise errors over (C, W, H) for each sample.
        return loss.mean([1, 2, 3])
class LossImae(Loss):
    """Mean absolute error on integrated charge loss function.

    Computes ``IMAE(y_true, y_pred) = mean(|y_true.sum(-1) - y_pred.sum(-1)|)``.
    """

    def __init__(self, a=0.84, data_range=1.0, reduction="mean"):
        """
        Parameters
        ----------
        a: float
            Stored for interface compatibility with the other losses; not
            used by this loss.
        data_range: float
            Stored for interface compatibility; not used by this loss.
        reduction: str
            ``"mean"`` returns a scalar; any other value returns one value
            per sample, of shape=(N,).
        """
        # Forward every argument to the base class: previously `a` and
        # `data_range` were silently dropped and the base defaults stored.
        super(LossImae, self).__init__(a, data_range, reduction)
        self.loss = nn.L1Loss(reduction="none")
        self.name = "IMAE"

    def __call__(self, y_pred, y_true):
        """
        Parameters
        ----------
        y_pred: torch.Tensor
            Predicted tensor, of shape=(N,C,W,H).
        y_true: torch.Tensor
            Target tensor, of shape=(N,C,W,H).

        Returns
        -------
        torch.Tensor
            Scalar if ``reduction == "mean"``, otherwise shape=(N,).
        """
        # Integrate (sum) along the last axis before comparing.
        loss = self.loss(y_pred.sum(-1), y_true.sum(-1))
        if self.reduction == "mean":
            return loss.mean()
        # Flatten everything but the batch axis and average per sample.
        return loss.reshape([loss.shape[0], -1]).mean(-1)
class LossSsim(Loss):
    """Statistical structural similarity loss function.

    Computes ``Lssim(y_true, y_pred) = 1 - stat-ssim(y_true, y_pred)``.
    """

    def __init__(self, a=0.84, data_range=1.0, reduction="mean"):
        super().__init__(a, data_range, reduction)
        self.name = "Lssim"

    def __call__(self, x_pred, y_true):
        """
        Parameters
        ----------
        x_pred: torch.Tensor
            Predicted tensor, of shape=(N,C,W,H).
        y_true: torch.Tensor
            Target tensor, of shape=(N,C,W,H).

        Returns
        -------
        torch.Tensor
            ``1 - stat_ssim(...)``; the shape follows how ``stat_ssim``
            handles ``self.reduction``.
        """
        similarity = stat_ssim(
            x_pred,
            y_true,
            data_range=self.data_range,
            reduction=self.reduction,
        )
        return 1 - similarity
class Ssim(Loss):
    """Statistical structural similarity function (a metric, not a loss).

    Computes ``ssim(y_true, y_pred) = stat-ssim(y_true, y_pred)``.
    """

    def __init__(self, a=0.84, data_range=1.0, reduction="mean"):
        super().__init__(a, data_range, reduction)
        self.name = "ssim"

    def __call__(self, y_pred, y_true):
        """
        Parameters
        ----------
        y_pred: torch.Tensor
            Predicted tensor, of shape=(N,C,W,H).
        y_true: torch.Tensor
            Target tensor, of shape=(N,C,W,H).

        Returns
        -------
        torch.Tensor
            Output of ``stat_ssim``; its shape depends on ``self.reduction``.
        """
        options = {"data_range": self.data_range, "reduction": self.reduction}
        return stat_ssim(y_pred, y_true, **options)
class LossSsimL2(Loss):
    """Stat ssim + MSE loss function.

    Computes ``(a * L2 + (1 - a) * 1e-3 * Lssim)(y_true, y_pred)``.
    """

    def __init__(self, a=0.84, data_range=1.0, reduction="mean"):
        super().__init__(a, data_range, reduction)
        self.name = "ssim_l2"

    def __call__(self, y_pred, y_true):
        """
        Parameters
        ----------
        y_pred: torch.Tensor
            Predicted tensor, of shape=(N,C,W,H).
        y_true: torch.Tensor
            Target tensor, of shape=(N,C,W,H).

        Returns
        -------
        torch.Tensor
            Weighted sum of the MSE term and the ``1 - stat_ssim`` term.
        """
        mse_term = nn.functional.mse_loss(y_pred, y_true, reduction=self.reduction)
        if self.reduction == "none":
            # Element-wise errors -> one value per sample.
            mse_term = mse_term.mean([1, 2, 3])
        ssim_term = 1 - stat_ssim(
            y_pred,
            y_true,
            data_range=self.data_range,
            reduction=self.reduction,
        )
        weight = self.a
        return weight * mse_term + (1 - weight) * 1e-3 * ssim_term
class LossSsimL1(Loss):
    """Stat ssim + mean absolute error loss function.

    Computes ``(a * L1 + (1 - a) * 1e-3 * Lssim)(y_true, y_pred)``.
    """

    def __init__(self, a=0.84, data_range=1.0, reduction="mean"):
        super().__init__(a, data_range, reduction)
        self.name = "ssim_l1"

    def __call__(self, y_pred, y_true):
        """
        Parameters
        ----------
        y_pred: torch.Tensor
            Predicted tensor, of shape=(N,C,W,H).
        y_true: torch.Tensor
            Target tensor, of shape=(N,C,W,H).

        Returns
        -------
        torch.Tensor
            Weighted sum of the L1 term and the ``1 - stat_ssim`` term.
        """
        abs_err = torch.abs(y_pred - y_true)
        if self.reduction == "none":
            # One value per sample.
            abs_err = abs_err.mean([1, 2, 3])
        elif self.reduction == "mean":
            abs_err = abs_err.mean()
        ssim_term = 1 - stat_ssim(
            y_pred,
            y_true,
            data_range=self.data_range,
            reduction=self.reduction,
        )
        weight = self.a
        return weight * abs_err + (1 - weight) * 1e-3 * ssim_term
class LossBce(Loss):
    """Binary cross entropy loss function.

    Computes ``Xent(y_true, y_pred)``, with the positive and negative terms
    reweighted by ``1 / ratio`` and ``1 / (1 - ratio)`` respectively.
    """

    def __init__(self, ratio=0.5, reduction="mean"):
        """
        Parameters
        ----------
        ratio: float
            Number of positive against negative examples in the training
            set, used to reweight the cross entropy.
        reduction: str
            Available options mean | sum | none.
        """
        super().__init__(0, 0, reduction)
        self.ratio = ratio
        self.name = "xent"

    def __call__(self, y_pred, y_true):
        """
        Parameters
        ----------
        y_pred: torch.Tensor
            Predicted tensor, of shape=(N,C,W,H).
        y_true: torch.Tensor
            Target tensor, of shape=(N,C,W,H).

        Returns
        -------
        torch.Tensor
            Scalar for ``mean``/``sum``; element-wise loss otherwise.
        """
        eps = EPS.to(y_pred.device)

        def safe_log(t):
            # Shift by eps so that a prediction of exactly 0 (or 1) stays finite.
            return torch.log(t + eps)

        positive = y_true * safe_log(y_pred) / self.ratio
        negative = (1 - y_true) * safe_log(1 - y_pred) / (1 - self.ratio)
        loss = -positive - negative

        if self.reduction == "sum":
            return loss.sum()
        if self.reduction == "mean":
            return loss.mean()
        return loss
class LossSoftDice(Loss):
    """Soft dice loss function."""

    def __init__(self, reduction="mean"):
        """
        Parameters
        ----------
        reduction: str
            'mean' | 'none'
        """
        super().__init__(0, 0, reduction)
        self.name = "softdice"

    def dice(self, x, y):
        """Soft dice ratio of the inputs plus that of their complements.

        Parameters
        ----------
        x: torch.Tensor
            Predicted tensor, of shape=(N,C,W,H).
        y: torch.Tensor
            Target tensor, of shape=(N,C,W,H).

        Returns
        -------
        torch.Tensor
            One ratio per (sample, channel), summed over the last two axes.
        """
        eps = EPS.to(x.device)
        axes = (-1, -2)
        x_comp = 1 - x
        y_comp = 1 - y
        direct = ((x * y).sum(axes) + eps) / ((x * x + y * y).sum(axes) + eps)
        complement = ((x_comp * y_comp).sum(axes) + eps) / (
            (x_comp * x_comp + y_comp * y_comp).sum(axes) + eps
        )
        return direct + complement

    def __call__(self, y_pred, y_true):
        """
        Parameters
        ----------
        y_pred: torch.Tensor
            Predicted tensor, of shape=(N,C,W,H).
        y_true: torch.Tensor
            Target tensor, of shape=(N,C,W,H).

        Returns
        -------
        torch.Tensor
            Scalar if ``reduction == "mean"``, else one value per
            (sample, channel).
        """
        loss = 1 - self.dice(y_pred, y_true)
        return loss.mean() if self.reduction == "mean" else loss
class LossBceDice(Loss):
    """Binary xent + soft dice loss function."""

    def __init__(self, ratio=0.5, reduction="mean"):
        """
        Parameters
        ----------
        ratio: float
            Positive/negative reweighting ratio forwarded to ``LossBce``.
        reduction: str
            'mean' | 'none'
        """
        super(LossBceDice, self).__init__(0, 0, reduction)
        self.bce = LossBce(ratio, reduction="none")
        self.dice = LossSoftDice(reduction="none")
        self.name = "xent_softdice"

    def __call__(self, y_pred, y_true):
        """
        Parameters
        ----------
        y_pred: torch.Tensor
            Predicted tensor, of shape=(N,C,W,H).
        y_true: torch.Tensor
            Target tensor, of shape=(N,C,W,H).

        Returns
        -------
        torch.Tensor
            Scalar if ``reduction == "mean"``, otherwise shape=(N,).
        """
        shape = [y_pred.shape[0], -1]
        # Element-wise cross entropy averaged per sample -> shape (N,).
        bce = self.bce(y_pred, y_true).reshape(shape).mean(-1)
        # dice() yields one ratio per (sample, channel), i.e. shape (N, C).
        # Average over channels so the term is (N,) like `bce`; adding the
        # raw (N, C) tensor to (N,) broadcast to (N, N) when C == 1 and
        # failed or mis-broadcast when C > 1.
        dice = -torch.log(self.dice.dice(y_pred, y_true)).reshape(shape).mean(-1)
        loss = bce + dice
        if self.reduction == "mean":
            return loss.mean()
        return loss
class LossPsnr(Loss):
    """Peak signal to noise ratio function."""

    def __init__(self, reduction="mean"):
        """
        Parameters
        ----------
        reduction: str
            'none' returns one value per image; anything else averages.
        """
        super(LossPsnr, self).__init__(reduction=reduction)
        self.mse = nn.MSELoss(reduction="none")
        self.name = "psnr"

    def __call__(self, y_noisy, y_clear):
        """
        Parameters
        ----------
        y_noisy: torch.Tensor
            Predicted (noisy) tensor, of shape=(N,C,W,H).
        y_clear: torch.Tensor
            Target (clear) tensor, of shape=(N,C,W,H).

        Returns
        -------
        torch.Tensor
            Shape=(N,) if ``reduction == "none"``, else a scalar. Images whose
            clear target has a peak value of 0 get a PSNR of 0.
        """
        nimages = y_clear.shape[0]
        x1 = y_clear.reshape(nimages, -1)
        x2 = y_noisy.reshape(nimages, -1)
        mse = self.mse(x1, x2).mean(-1)
        m2 = x1.max(-1).values ** 2
        eps = EPS.to(x1.device)

        is_blank = m2 == 0
        # Substitute 1 for zero peaks *before* the log: torch.where
        # back-propagates through both branches, and the -inf from log10(0)
        # in the unselected branch would turn every gradient into NaN.
        safe_m2 = torch.where(is_blank, torch.ones_like(m2), m2)
        ratio = 10 * torch.log10(safe_m2 / (mse + eps))
        psnr = torch.where(is_blank, torch.zeros_like(ratio), ratio)

        if self.reduction == "none":
            return psnr
        return psnr.mean()
class LossCfnm(Loss):
    """Confusion matrix function."""

    def __init__(self, reduction="mean"):
        """
        Parameters
        ----------
        reduction: str
            Stored for interface compatibility with the other losses; the
            statistics below are always aggregated over the batch.
        """
        # Initialise the base class so instances carry `a`, `data_range`
        # and `reduction` like every other Loss (previously skipped).
        super(LossCfnm, self).__init__(reduction=reduction)
        self.name = "cfnm"

    def __call__(self, y_pred, y_true):
        """
        Parameters
        ----------
        y_pred: torch.Tensor
            Predicted tensor, one entry per sample along axis 0.
        y_true: torch.Tensor
            Target tensor of the same shape; nonzero entries mark hits.

        Returns
        -------
        tuple
            ``(tp, fp, fn, tn)``, each a ``[mean, std / sqrt(n)]`` pair taken
            over the batch from columns 0-3 of the per-sample matrices.
        """
        n = len(y_pred)
        # detach(): .numpy() raises on tensors that require grad.
        outputs = y_pred.detach().cpu().numpy().reshape([n, -1])
        targets = y_true.detach().cpu().numpy().reshape([n, -1])

        rows = []
        for out, tgt in zip(outputs, targets):
            mask = tgt.astype(bool)
            rows.append(confusion_matrix(out[mask], out[~mask], 0.5))
        cfnm = np.stack(rows)

        # NOTE(review): every sample is normalised by the *first* sample's
        # total; this equals per-sample normalisation only if
        # confusion_matrix returns raw counts summing to the pixel count --
        # confirm against its implementation.
        cfnm = cfnm / cfnm[0, :].sum()

        sqrt_n = np.sqrt(n)
        tp, fp, fn, tn = (
            [cfnm[:, i].mean(), cfnm[:, i].std() / sqrt_n] for i in range(4)
        )
        return tp, fp, fn, tn
def get_loss(loss):
    """Return the loss class registered under the name ``loss``.

    Parameters
    ----------
    loss: str
        Available options:
        - mse
        - imae
        - ssim
        - ssim_l2
        - ssim_l1
        - bce
        - softdice
        - bce_dice
        - psnr
        - cfnm

    Returns
    -------
    Loss
        The queried loss class (not an instance).

    Raises
    ------
    NotImplementedError
        If ``loss`` is none of the options listed above.
    """
    options = (
        ("mse", LossMse),
        ("imae", LossImae),
        ("ssim", LossSsim),
        ("ssim_l2", LossSsimL2),
        ("ssim_l1", LossSsimL1),
        ("bce", LossBce),
        ("softdice", LossSoftDice),
        ("bce_dice", LossBceDice),
        ("psnr", LossPsnr),
        ("cfnm", LossCfnm),
    )
    for key, loss_class in options:
        if loss == key:
            return loss_class
    raise NotImplementedError("Loss function not implemented")
def get_metric(metric):
    """Return the metric class registered under the name ``metric``.

    Parameters
    ----------
    metric: str
        Available options:
        - mse
        - imae
        - ssim (returns the similarity itself, not ``1 - ssim``)
        - psnr
        - xent (alias: bce)
        - softdice
        - cfnm

    Returns
    -------
    Loss
        The queried class (not an instance).

    Raises
    ------
    NotImplementedError
        If ``metric`` is none of the options listed above.
    """
    if metric == "mse":
        return LossMse
    elif metric == "imae":
        return LossImae
    elif metric == "ssim":
        return Ssim
    elif metric == "psnr":
        return LossPsnr
    elif metric in ("xent", "bce"):
        # "bce" was documented but previously rejected; accept it as an
        # alias of "xent" (the name LossBce reports) for compatibility.
        return LossBce
    elif metric == "softdice":
        return LossSoftDice
    elif metric == "cfnm":
        return LossCfnm
    else:
        raise NotImplementedError("Loss function not implemented")