# Source code for dunedn.training.metrics

from logging import Logger
from math import sqrt
import torch
from .losses import get_metric

# Metric names for the "DN" group (presumably denoising evaluation -- confirm).
# Names are presumably those accepted by `get_metric` from `.losses`.
DN_METRICS = [
    "ssim",
    "psnr",
    "mse",
    "imae",
]

# Metric names for the "ROI" group (presumably region-of-interest evaluation
# -- confirm).
ROI_METRICS = [
    "xent",
    "softdice",
]


class MetricsList:
    """Wrapping class for a list of metrics.

    Example
    -------
    >>> from dunedn.training.metrics import MetricsList
    >>> metrics = ["ssim", "psnr"]
    >>> MetricsList(metrics)
    """

    def __init__(self, metrics: list[str]):
        """
        Parameters
        ----------
        metrics: list[str]
            The list of metrics names. Each name is resolved with
            ``get_metric`` and instantiated with ``reduction="none"``
            (presumably so that un-reduced, per-sample values are returned;
            confirm against ``.losses``).
        """
        self.metrics = [get_metric(metric)(reduction="none") for metric in metrics]
        # Names are taken from the instantiated metric objects, not from the
        # input strings.
        self.names = [metric.name for metric in self.metrics]

    def combine_collection_induction_results(self, ires: dict, cres: dict):
        """Combine computed metrics from different plane types.

        Metric values are averaged. Standard deviations are summed in
        quadrature, i.e. propagated as for the mean of two (assumed
        independent) estimates: ``0.5 * sqrt(istd**2 + cstd**2)``.

        A value (or its ``_std``) is emitted only when it is present in both
        inputs.

        Parameters
        ----------
        ires: dict
            Computed metrics on induction planes.
        cres: dict
            Computed metrics on collection planes.

        Returns
        -------
        res: dict
            The combined metrics.
        """
        res = {}
        for name in self.names:
            ivalue = ires.get(name)
            cvalue = cres.get(name)
            if ivalue is not None and cvalue is not None:
                res[name] = (ivalue + cvalue) * 0.5

            std_name = f"{name}_std"
            ivalue_std = ires.get(std_name)
            cvalue_std = cres.get(std_name)
            if ivalue_std is not None and cvalue_std is not None:
                # Std of (a + b) / 2 is sqrt(sa^2 + sb^2) / 2, not (sa + sb) / 2.
                res[std_name] = 0.5 * sqrt(ivalue_std**2 + cvalue_std**2)
        return res

    def compute_metrics(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> dict:
        """Computes values from the stored list of metrics.

        For each metric, the mean of its un-reduced output over the last axis
        is stored under ``name``. The standard error of that mean (sample
        standard deviation divided by the square root of the number of
        averaged values) is stored under ``f"{name}_std"``.

        Parameters
        ----------
        y_pred: torch.tensor
            Prediction tensor, of shape=(N,C,H,W).
        y_true: torch.tensor
            Labels tensor, of shape=(N,C,H,W).

        Returns
        -------
        res_metrics: dict
            The computed metrics results in dictionary form.
        """
        # Shape (num_metrics, ...). Each metric is assumed to return one value
        # per sample, i.e. shape (N,) -- TODO confirm against `.losses`.
        results = torch.stack(
            [metric(y_pred, y_true) for metric in self.metrics], dim=0
        )
        # The standard error must use the number of averaged values (the
        # reduced axis), not the number of metrics.
        nvalues = results.shape[-1]
        res_mean = results.mean(-1)
        # torch.std is unbiased by default, so this is NaN when nvalues == 1.
        res_std = results.std(-1) / sqrt(nvalues)

        res_metrics = {name: mean.item() for name, mean in zip(self.names, res_mean)}
        res_metrics.update(
            {f"{name}_std": std.item() for name, std in zip(self.names, res_std)}
        )
        return res_metrics

    def print_metrics(self, logger: Logger, logs: dict):
        """Log the computed metrics.

        Emits a single ``logger.info`` message with one line per metric in
        ``self.names``, formatted as ``mean +/- std``. A metric whose mean is
        missing from ``logs`` is shown as ``n/a``. When only the std is
        missing, the ``+/- std`` part is omitted.

        Parameters
        ----------
        logger: Logger
            The logging object.
        logs: dict
            The computed metrics values to be logged.
        """
        lines = ["Prediction metrics:"]
        for name in self.names:
            mean = logs.get(name)
            std = logs.get(f"{name}_std")
            if mean is None:
                lines.append(f"{name:>10}: n/a")
            elif std is None:
                lines.append(f"{name:>10}: {mean:.3f}")
            else:
                lines.append(f"{name:>10}: {mean:.3f} +/- {std:.3f}")
        logger.info("\n".join(lines))