"""
This module contains the GCNN Net building blocks.
"""
import torch
from torch import nn
from dunedn.networks.gcnn.gcnn_net_utils import (
pairwise_dist,
batched_index_select,
local_mask,
)
class ROI(nn.Module):
    """Binary segmentation network (ROI head).

    A ``PreProcessBlock`` is followed by eight hidden ``hc -> hc`` graph/conv
    layers and one final ``hc -> 1`` layer whose output goes through a
    sigmoid, so every output value lies in (0, 1).  The graph fed to the
    layers is rebuilt before hidden layers 0, 3 and 6; the final layer reuses
    the last graph built.

    Parameters
    ----------
    - kernel_size: int, kernel size of the pre-processing 2D convolutions
    - ic: int, input channel dimension size
    - hc: int, hidden channel dimension size
    - getgraph_fn: callable, builds the graph from a feature map
    - model: str, layer type forwarded to ``choose_conv`` ("gcnn" | "cnn")
    """

    def __init__(self, kernel_size, ic, hc, getgraph_fn, model):
        super(ROI, self).__init__()
        self.getgraph_fn = getgraph_fn
        self.pre_process_block = PreProcessBlock(
            kernel_size, ic, hc, getgraph_fn, model
        )
        self.gcs = nn.ModuleList([choose_conv(model, hc, hc) for _ in range(8)])
        self.gc_final = choose_conv(model, hc, 1)
        self.activ = nn.LeakyReLU(0.05)
        self.act = nn.Sigmoid()

    def forward(self, x):
        """
        Parameters
        ----------
        - x: torch.Tensor, input with ``ic`` channels

        Returns
        -------
        - torch.Tensor, single-channel map squashed by a sigmoid
        """
        feats = self.pre_process_block(x)
        for idx, layer in enumerate(self.gcs):
            # refresh the graph every third hidden layer (0, 3, 6)
            if not idx % 3:
                graph = self.getgraph_fn(feats)
            feats = self.activ(layer(feats, graph))
        return self.act(self.gc_final(feats, graph))
class PreProcessBlock(nn.Module):
    """Convolutional front end followed by a single graph/conv layer.

    Three ``Conv2d`` + LeakyReLU stages (the first maps ``ic`` -> ``oc``
    channels, the others ``oc`` -> ``oc``) feed one ``choose_conv`` layer,
    which is itself followed by LeakyReLU. The graph for that layer is built
    from the output of the convolutional stages.

    Parameters
    ----------
    - kernel_size: int, kernel size of the 2D convolutions
    - ic: int, input channel dimension size
    - oc: int, output channel dimension size
    - getgraph_fn: callable, builds the graph from a feature map
    - model: str, layer type forwarded to ``choose_conv`` ("gcnn" | "cnn")
    """

    def __init__(self, kernel_size, ic, oc, getgraph_fn, model):
        half = kernel_size // 2
        super(PreProcessBlock, self).__init__()
        self.getgraph_fn = getgraph_fn
        self.activ = nn.LeakyReLU(0.05)
        # One parameter-free LeakyReLU instance is shared by every stage;
        # the convolutions end up at indices 0, 2 and 4 of ``convs``.
        stages = []
        for in_channels in (ic, oc, oc):
            stages.append(
                nn.Conv2d(in_channels, oc, kernel_size, padding=(half, half))
            )
            stages.append(self.activ)
        self.convs = nn.Sequential(*stages)
        # NOTE(review): ``bn`` is never used in ``forward``; it is kept,
        # presumably so existing checkpoints (state_dict keys) still load.
        self.bn = nn.BatchNorm2d(oc)
        self.gc = choose_conv(model, oc, oc)

    def forward(self, x):
        """
        Parameters
        ----------
        - x: torch.Tensor, of shape=(N,ic,H,W)

        Returns
        -------
        - torch.Tensor with ``oc`` channels
        """
        feats = self.convs(x)
        return self.activ(self.gc(feats, self.getgraph_fn(feats)))
class HPF(nn.Module):
    """High Pass Filter.

    A 3x3 ``Conv2d`` + BatchNorm + LeakyReLU stage is followed by three
    graph/conv layers (``ic->ic``, ``ic->oc``, ``oc->oc``), each followed by
    LeakyReLU. All three layers share one graph, built once from the output
    of the first stage.

    Parameters
    ----------
    - ic: int, input channel dimension size
    - oc: int, output channel dimension size
    - getgraph_fn: callable, builds the graph from a feature map
    - model: str, layer type forwarded to ``choose_conv`` ("gcnn" | "cnn")
    """

    def __init__(self, ic, oc, getgraph_fn, model):
        super(HPF, self).__init__()
        self.getgraph_fn = getgraph_fn
        self.conv = nn.Sequential(
            nn.Conv2d(ic, ic, 3, padding=1),
            nn.BatchNorm2d(ic),
            nn.LeakyReLU(0.05),
        )
        channel_pairs = ((ic, ic), (ic, oc), (oc, oc))
        self.gcs = nn.ModuleList(
            [choose_conv(model, c_in, c_out) for c_in, c_out in channel_pairs]
        )
        self.act = nn.LeakyReLU(0.05)

    def forward(self, x):
        """
        Parameters
        ----------
        - x: torch.Tensor, of shape=(N,ic,H,W)

        Returns
        -------
        - torch.Tensor with ``oc`` channels
        """
        out = self.conv(x)
        graph = self.getgraph_fn(out)
        for layer in self.gcs:
            out = self.act(layer(out, graph))
        return out
class LPF(nn.Module):
    """Low Pass Filter.

    This is a residual block. A 5x5 ``Conv2d`` + BatchNorm + LeakyReLU stage
    is followed by three graph/conv layers (``ic->ic``, ``ic->oc``,
    ``oc->oc``). Each layer is followed by BatchNorm and LeakyReLU, and all
    three share one graph. The result is added back to the block input.

    Parameters
    ----------
    - ic: int, input channel dimension size
    - oc: int, output channel dimension size
    - getgraph_fn: callable, builds the graph from a feature map
    - model: str, layer type forwarded to ``choose_conv`` ("gcnn" | "cnn")
    """

    def __init__(self, ic, oc, getgraph_fn, model):
        super(LPF, self).__init__()
        self.getgraph_fn = getgraph_fn
        self.conv = nn.Sequential(
            nn.Conv2d(ic, ic, 5, padding=2),
            nn.BatchNorm2d(ic),
            nn.LeakyReLU(0.05),
        )
        channel_pairs = ((ic, ic), (ic, oc), (oc, oc))
        self.gcs = nn.ModuleList(
            [choose_conv(model, c_in, c_out) for c_in, c_out in channel_pairs]
        )
        # one BatchNorm per graph/conv layer, sized by that layer's output
        self.bns = nn.ModuleList(
            [nn.BatchNorm2d(c_out) for _, c_out in channel_pairs]
        )
        self.act = nn.LeakyReLU(0.05)

    def forward(self, x):
        """
        Parameters
        ----------
        - x: torch.Tensor, of shape=(N,ic,H,W)

        Returns
        -------
        - torch.Tensor, ``x`` plus the filtered features
        """
        out = self.conv(x)
        graph = self.getgraph_fn(out)
        for layer, norm in zip(self.gcs, self.bns):
            out = self.act(norm(layer(out, graph)))
        # NOTE(review): the residual sum needs the filtered features (oc
        # channels) to match or broadcast against x (ic channels);
        # presumably ic == oc at the call sites.
        return x + out
class PostProcessBlock(nn.Module):
    """Three graph/conv layers mapping ``3*hc + 1`` channels down to ``ic``.

    The channel widths go ``3*hc+1 -> 2*hc -> hc -> ic``. The first two
    layers are followed by BatchNorm + LeakyReLU. The last layer is left
    linear (identity norm and activation). The graph is rebuilt from the
    current features before every layer.

    Parameters
    ----------
    - ic: int, output channel dimension size
    - hc: int, hidden channel dimension size
    - getgraph_fn: callable, builds the graph from a feature map
    - model: str, layer type forwarded to ``choose_conv`` ("gcnn" | "cnn")
    """

    def __init__(self, ic, hc, getgraph_fn, model):
        super(PostProcessBlock, self).__init__()
        self.getgraph_fn = getgraph_fn
        widths = (hc * 3 + 1, hc * 2, hc, ic)
        self.gcs = nn.ModuleList(
            [
                choose_conv(model, c_in, c_out)
                for c_in, c_out in zip(widths[:-1], widths[1:])
            ]
        )
        self.bns = nn.ModuleList(
            [nn.BatchNorm2d(hc * 2), nn.BatchNorm2d(hc), nn.Identity()]
        )
        self.acts = nn.ModuleList(
            [nn.LeakyReLU(0.05), nn.LeakyReLU(0.05), nn.Identity()]
        )

    def forward(self, x):
        """
        Parameters
        ----------
        - x: torch.Tensor, with ``3*hc + 1`` channels

        Returns
        -------
        - torch.Tensor with ``ic`` channels
        """
        for layer, norm, activation in zip(self.gcs, self.bns, self.acts):
            x = activation(norm(layer(x, self.getgraph_fn(x))))
        return x
class NonLocalGraph:
    """Non-local graph layer.

    Calling an instance on a feature map returns, for every pixel, the
    difference between its feature vector and the feature vectors of ``k``
    selected neighbour pixels. The graph is built from detached features,
    so no gradient flows through neighbour selection.
    """

    def __init__(self, k, crop_size):
        """
        Parameters
        ----------
        - k: int, nearest neighbor number.
        - crop_size: tuple, (edge_h, edge_w)
        """
        self.k = k
        # Built once from the crop size and reused on every call; its exact
        # semantics are defined by ``local_mask`` in gcnn_net_utils.
        self.local_mask = local_mask(crop_size)

    def __call__(self, arr):
        """
        Parameters
        ----------
        - arr: torch.Tensor, input tensor of shape=(N,C,H,W)

        Returns
        -------
        - torch.Tensor, output tensor of shape=(N,H*W,K,C). Entry [n, p, j]
          is the feature of pixel p minus the feature of its j-th selected
          neighbour.
        """
        feats = arr.detach().permute(0, 2, 3, 1)
        b, h, w, f = feats.shape
        hw = h * w
        # reshape (not view): a permuted tensor is not always viewable,
        # e.g. when ``arr`` itself comes from a transpose.
        feats = feats.reshape(b, hw, f)
        # NOTE(review): despite its name, the result of ``pairwise_dist`` is
        # used as neighbour indices (k per pixel); confirm in gcnn_net_utils.
        idx = pairwise_dist(feats, self.k, self.local_mask)
        selected = batched_index_select(feats, 1, idx.reshape(idx.shape[0], -1))
        selected = selected.reshape(b, hw, self.k, f)
        return feats.unsqueeze(2) - selected
# ==============================================================================
# functions and classes to be called within this module only
def choose_conv(model, ic, oc):
    """
    Utility function to retrieve GConv or Conv layer from its name.

    Parameters
    ----------
    - model: str, available options gcnn | cnn
    - ic: int, input channel dimension size
    - oc: int, output channel dimension size

    Returns
    -------
    - torch.nn.Module, the layer instance: ``GConv`` for "gcnn", ``Conv``
      for "cnn"

    Raises
    ------
    - NotImplementedError if model is not in ['gcnn', 'cnn']
    """
    if model == "gcnn":
        return GConv(ic, oc)
    if model == "cnn":
        return Conv(ic, oc)
    # name the rejected value so configuration mistakes are easy to spot
    raise NotImplementedError(
        "Operation not implemented: model must be 'gcnn' or 'cnn', "
        "got {!r}".format(model)
    )
class GConv(nn.Module):
    """GConv layer.

    Averages two branches applied to the same input: a local 3x3 ``Conv2d``
    and a ``NonLocalAggregator`` that consumes the supplied graph.

    Parameters
    ----------
    - ic: int, input channel dimension size
    - oc: int, output channel dimension size
    """

    def __init__(self, ic, oc):
        super(GConv, self).__init__()
        self.conv1 = nn.Conv2d(ic, oc, 3, padding=1)
        self.nla = NonLocalAggregator(ic, oc)

    def forward(self, x, graph):
        """
        Parameters
        ----------
        - x: torch.Tensor, of shape=(N,ic,H,W)
        - graph: torch.Tensor, forwarded to the NonLocalAggregator branch

        Returns
        -------
        - torch.Tensor, element-wise mean of the two branches (``oc`` channels)
        """
        local = self.conv1(x)
        non_local = self.nla(x, graph)
        return torch.stack([local, non_local]).mean(dim=0)
class Conv(nn.Module):
    """Conv layer: the graph-free counterpart of ``GConv``.

    Averages a 3x3 and a 5x5 ``Conv2d`` applied to the same input; their
    paddings (1 and 2) keep the spatial size unchanged. ``forward`` takes
    the same ``(x, graph)`` arguments as ``GConv`` so ``choose_conv`` can
    return either layer, but ``graph`` is ignored.
    """

    def __init__(self, ic, oc):
        """
        Parameters
        ----------
        - ic: int, input channel dimension size
        - oc: int, output channel dimension size
        """
        super(Conv, self).__init__()
        self.conv1 = nn.Conv2d(ic, oc, 3, padding=1)
        self.conv2 = nn.Conv2d(ic, oc, 5, padding=2)

    def forward(self, x, graph):
        """
        Parameters
        ----------
        - x: torch.Tensor, of shape=(N,ic,H,W)
        - graph: unused; accepted only for interface compatibility with GConv

        Returns
        -------
        - torch.Tensor of shape=(N,oc,H,W)
        """
        return torch.mean(torch.stack([self.conv1(x), self.conv2(x)]), dim=0)
class NonLocalAggregator(nn.Module):
    """NonLocalAggregator layer.

    For each pixel, adds a linear projection of its own features to the
    mean, over its K neighbours, of a linear projection of the neighbour
    differences supplied in ``graph``.
    """

    def __init__(self, input_channels, out_channels):
        """
        Parameters
        ----------
        - input_channels: int, input channel dimension size
        - out_channels: int, output channel dimension size
        """
        super(NonLocalAggregator, self).__init__()
        # projects the neighbour-difference features (last dim of graph)
        self.diff_fc = nn.Linear(input_channels, out_channels)
        # projects each pixel's own features
        self.w_self = nn.Linear(input_channels, out_channels)

    def forward(self, x, graph):
        """
        Parameters
        ----------
        - x: torch.Tensor, of shape=(N,C,H,W) with C == input_channels
        - graph: torch.Tensor, expected shape=(N,H*W,K,C), e.g. the output of
          ``NonLocalGraph`` on a feature map with the same H and W as x

        Returns
        -------
        - torch.Tensor of shape=(N,out_channels,H,W)
        """
        b, f, h, w = x.shape
        # reshape (not view): the permuted tensor is not viewable for every
        # input stride pattern (e.g. a transposed x).
        flat = x.permute(0, 2, 3, 1).reshape(b, h * w, f)
        # average the projected differences over the K neighbours
        agg_neigh = torch.mean(self.diff_fc(graph), dim=-2)
        agg_self = self.w_self(flat)
        x_new = agg_neigh + agg_self
        return x_new.reshape(b, h, w, x_new.shape[-1]).permute(0, 3, 1, 2)