"""
State Recognition Neural Network Model
"""
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
from typing import Tuple
# The recognition net
[docs]class StateRecognitionNet(nn.Module):
""" State Recognition Neural Network to get mean of z and standard
deviation of z. The neural network architecture looks like this: G ->
const * C -> const * C -> G (for mu) or -> G (for std). With batch
normal layers after each activation output layers and dropout
activation units
:param C: number of classes
:param G: number of proteins
:param const: the size of the hidden layers are const times proportional
to C
:param dropout_rate: the dropout rate
:param batch_norm: apply batch normal layers if True
"""
def __init__(self, C: int, G: int, const: int=2,
dropout_rate: float=0, batch_norm: bool=False) -> None:
super(StateRecognitionNet, self).__init__()
self.batch_norm = batch_norm
hidden_layer_size = math.ceil(const * C)
# First hidden layer
self.linear1 = nn.Linear(G, hidden_layer_size).float()
self.dropout1 = nn.Dropout(dropout_rate)
# Second hidden layer
self.linear2 = nn.Linear(hidden_layer_size, hidden_layer_size).float()
self.dropout2 = nn.Dropout(dropout_rate)
# Output layer for mu
self.linear3_mu = nn.Linear(hidden_layer_size, C).float()
self.dropout_mu = nn.Dropout(dropout_rate)
# Output layer for std
self.linear3_std = nn.Linear(hidden_layer_size, C).float()
self.dropout_std = nn.Dropout(dropout_rate)
# Batch normal layers
if self.batch_norm:
self.bn1 = nn.BatchNorm1d(num_features=hidden_layer_size).float()
self.bn2 = nn.BatchNorm1d(num_features=hidden_layer_size).float()
self.bn_out_mu = nn.BatchNorm1d(num_features=C).float()
self.bn_out_std = nn.BatchNorm1d(num_features=C).float()
[docs] def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
""" One forward pass of the StateRecognitionNet
"""
# Input --linear1--> Hidden1
x = self.linear1(x)
if self.batch_norm:
x = self.bn1(x)
x = F.relu(x)
# x = self.dropout1(x)
# Hidden1 --linear2--> Hidden2
x = self.linear2(x)
if self.batch_norm:
x = self.bn2(x)
x = F.relu(x)
# x = self.dropout2(x)
# Hidden2 --linear3_mu--> mu
mu_z = self.linear3_mu(x)
if self.batch_norm:
mu_z = self.bn_out_mu(mu_z)
# mu_z = self.dropout_mu(mu_z)
# Hidden2 --linear3_std--> std
std_z = self.linear3_std(x)
if self.batch_norm:
std_z = self.bn_out_std(std_z)
# std_z = self.dropout_std(std_z)
return mu_z, std_z