Source code for dunedn.networks.uscg.uscg_net_utils

"""
    This module contains utility functions for the USCG network.
"""
from torch import nn


def weight_xavier_init(*models):
    """Initialize the parameters of the given models in place.

    Every sub-module returned by ``model.modules()`` is visited:

    - ``nn.Conv2d`` and ``nn.Linear``: the weight is initialized with
      ``nn.init.orthogonal_`` and the bias, when present, is set to zero.
    - ``nn.BatchNorm2d``: the weight is set to one and the bias to zero,
      when the layer has them.

    All other module types are left untouched.

    NOTE: despite the function name, the weights receive an orthogonal
    initialization, not a Xavier one. The name is kept for backward
    compatibility with existing callers.

    Parameters
    ----------
    *models
        Any number of objects exposing a ``modules()`` iterator over
        ``nn.Module`` instances (e.g. ``nn.Module`` objects).
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                # Batch-norm layers built with ``affine=False`` have no
                # learnable parameters: weight and bias are both None.
                if module.weight is not None:
                    nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)