Source code for dunedn.networks.uscg.uscg_net_blocks

"""
    This module contains the USCG Net building blocks.
"""
import torch
from torch import nn


class SCG_Block(nn.Module):
    """SCG building block: derive a dense graph from a feature map.

    Two convolutional branches (``mu`` and ``logvar``) turn the input feature
    map into per-node latent embeddings. The adjacency matrix is the rectified
    inner product of those embeddings, optionally reinforced on its diagonal,
    and is returned after symmetric degree normalization.

    Parameters:
        in_ch: int, number of channels of the input feature map
        hidden_ch: int, number of channels of the latent node embeddings
        node_size: tuple of two ints; their product is the number of graph
            nodes and must match ``H * W`` of the tensor passed to ``forward``
        add_diag: bool, if True the diagonal of the adjacency matrix, scaled
            by ``gamma``, is added back onto the adjacency matrix
        dropout: float, dropout probability applied after each convolution
    """

    def __init__(
        self, in_ch, hidden_ch=6, node_size=(32, 32), add_diag=True, dropout=0.2
    ):
        super(SCG_Block, self).__init__()
        self.node_size = node_size
        self.hidden = hidden_ch
        self.nodes = node_size[0] * node_size[1]
        self.add_diag = add_diag

        # NOTE: ``pool`` is created here but never applied inside ``forward``.
        self.pool = nn.AdaptiveMaxPool2d(node_size)

        # Both branches keep the spatial size of their input
        # (3x3 convolution with padding 1, and 1x1 convolution with stride 1).
        self.mu = nn.Sequential(
            nn.Conv2d(in_ch, hidden_ch, 3, padding=1, bias=True),
            nn.Dropout(dropout),
        )
        self.logvar = nn.Sequential(
            nn.Conv2d(in_ch, hidden_ch, 1, 1, bias=True),
            nn.Dropout(dropout),
        )

    def forward(self, gx):
        """Build the normalized adjacency matrix for a batch of feature maps.

        Parameters:
            gx: 4-D tensor (B, in_ch, H, W) with ``H * W == self.nodes``

        Returns:
            A: tensor (B, nodes, nodes), normalized adjacency matrix
            gx: the input tensor, returned unchanged
            loss: scalar tensor ``kl_loss - dl_loss`` in training mode,
                None in evaluation mode
            z_hat: tensor (B, nodes, hidden) computed as
                ``gamma.mean() * mu * (1 - log_var)``
        """
        # Unpacking four sizes also enforces a 4-D input, as before.
        B, _, _, _ = gx.size()
        mu, log_var = self.mu(gx), self.logvar(gx)
        latent_shape = (B, self.nodes, self.hidden)

        if self.training:
            # Despite its name, the ``logvar`` branch output is exponentiated
            # directly and used as the standard deviation of the sample.
            std = torch.exp(log_var.reshape(latent_shape))
            eps = torch.randn_like(std)
            z = mu.reshape(latent_shape) + std * eps
        else:
            z = mu.reshape(latent_shape)

        # Rectified inner products between node embeddings.
        A = torch.relu(torch.matmul(z, z.permute(0, 2, 1)))

        Ad = torch.diagonal(A, dim1=1, dim2=2)  # (B, nodes)
        mean = torch.mean(Ad, dim=1)
        gamma = torch.sqrt(1 + 1.0 / mean).unsqueeze(-1).unsqueeze(-1)  # (B, 1, 1)

        loss = None
        if self.training:
            # Log of the diagonal entries smaller than one, weighted by the
            # batch-averaged gamma and normalized by the size of ``A``.
            dl_loss = (
                gamma.mean()
                * torch.log(Ad[Ad < 1] + 1.0e-7).sum()
                / (A.size(0) * A.size(1) * A.size(2))
            )
            # Sum over the channel axis, then average over the remaining axes.
            kl_loss = (
                -0.5
                / self.nodes
                * torch.mean(
                    torch.sum(1 + 2 * log_var - mu.pow(2) - log_var.exp().pow(2), 1)
                )
            )
            loss = kl_loss - dl_loss

        if self.add_diag:
            # Batched diagonal matrices in a single call instead of a Python
            # loop over the batch followed by a concatenation.
            A = A + gamma * torch.diag_embed(Ad)

        A = self.laplacian_matrix(A, self_loop=True)

        z_hat = (
            gamma.mean()
            * mu.reshape(latent_shape)
            * (1.0 - log_var.reshape(latent_shape))
        )
        return A, gx, loss, z_hat

    def laplacian_matrix(self, A, self_loop=False):
        """Symmetrically normalize a batch of adjacency matrices.

        Computes ``D^{-1/2} A D^{-1/2}``, where ``D`` holds the sums of ``A``
        over dimension 1 plus ``1e-5``.

        Parameters:
            A: tensor (B, N, N), adjacency matrices
            self_loop: bool, if True the identity matrix is added to ``A``
                before the normalization

        Returns:
            tensor (B, N, N), the normalized matrices
        """
        if self_loop:
            A = A + torch.eye(A.size(1), device=A.device).unsqueeze(0)
        # The small offset keeps the inverse square root finite for
        # nodes whose degree is zero.
        deg_inv_sqrt = (torch.sum(A, 1) + 1e-5).pow(-0.5)
        LA = deg_inv_sqrt.unsqueeze(-1) * A * deg_inv_sqrt.unsqueeze(-2)
        return LA
class GCN_Layer(nn.Module):
    """Graph-convolution layer: aggregate with the adjacency, then transform.

    Parameters:
        in_features: int, size of the input node features
        out_features: int, size of the output node features
        bnorm: bool, if True a ``BatchNorm_GCN`` follows the linear layer
        activation: module applied after the normalization, or None to skip it
        dropout: float dropout probability, or None to skip dropout
    """

    def __init__(
        self, in_features, out_features, bnorm=True, activation=nn.ReLU(), dropout=None
    ):
        super().__init__()
        self.bnorm = bnorm

        # Stage order is linear -> batch norm -> activation -> dropout; it
        # fixes the indices of the sub-modules inside ``self.fc``.
        stages = [nn.Linear(in_features, out_features)]
        if bnorm:
            stages += [BatchNorm_GCN(out_features)]
        if activation is not None:
            stages += [activation]
        if dropout is not None:
            stages += [nn.Dropout(dropout)]
        self.fc = nn.Sequential(*stages)

    def forward(self, data):
        """Apply the layer to a ``(features, adjacency)`` pair.

        Parameters:
            data: sequence of two tensors ``(x, A)`` that are multiplied with
                ``torch.bmm(A, x)``

        Returns:
            list ``[y, A]`` with the transformed features and the adjacency
            passed through unchanged, so that layers can be chained
        """
        features, adjacency = data
        aggregated = torch.bmm(adjacency, features)
        return [self.fc(aggregated), adjacency]
class Pooling_Block(nn.Module):
    """Adaptive max pooling followed by a 3x3 convolution, batch norm and ReLU."""

    def __init__(self, c, h, w):
        """
        Parameters:
            c: int, image channels
            h: int, output height
            w: int, output width
        """
        super().__init__()
        stages = (
            nn.AdaptiveMaxPool2d((h, w)),
            # The channel count is preserved; the bias is dropped because
            # the batch norm that follows has its own shift.
            nn.Conv2d(c, c, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(c),
            nn.ReLU(),
        )
        self.pooling = nn.Sequential(*stages)

    def forward(self, x):
        """Return the pooled and re-projected feature map, of shape (B, c, h, w)."""
        pooled = self.pooling(x)
        return pooled
class Recombination_Layer(nn.Module):
    """
    Merges an output with its residual connection.

    The two inputs are stacked along the channel dimension and a (1,1)
    convolution reduces the two resulting channels to a single one, so each
    input is expected to carry exactly one channel.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=1)

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` channel-wise and mix them with the convolution."""
        stacked = torch.cat((x, y), dim=1)
        return self.conv(stacked)
# ============================================================================== # functions and classes to be called within this module only
class BatchNorm_GCN(nn.BatchNorm1d):
    """Batch normalization for inputs that keep their features on the last axis."""

    def __init__(self, num_features):
        super().__init__(num_features)

    def forward(self, x):
        """Normalize ``x`` of shape (B, N, num_features) and return the same layout.

        ``nn.BatchNorm1d`` normalizes dimension 1, so the last two axes are
        swapped before the call and swapped back afterwards.
        """
        channels_first = x.permute(0, 2, 1)
        normalized = super().forward(channels_first)
        return normalized.permute(0, 2, 1)