# Source code for astir.models.cellstate_recognet

"""
State Recognition Neural Network Model
"""

import torch
import torch.nn.functional as F
import torch.nn as nn
import math

from typing import Tuple


# The recognition net
class StateRecognitionNet(nn.Module):
    """State Recognition Neural Network that outputs the mean of z and the
    standard-deviation parameter of z.

    The architecture is a two-hidden-layer MLP with two output heads::

        G -> ceil(const * C) -> ceil(const * C) -> C   (mu head)
                                                -> C   (std head)

    Each hidden layer is ``Linear -> [BatchNorm1d] -> ReLU``. Each output head
    is ``Linear -> [BatchNorm1d]``, with no activation. The bracketed
    batch-norm layers exist only when ``batch_norm`` is True.

    :param C: number of classes (the size of each output head)
    :param G: number of proteins (the input feature size of ``linear1``)
    :param const: the hidden layers have ``ceil(const * C)`` units,
        defaults to 2
    :param dropout_rate: dropout probability given to the ``dropout*``
        modules, defaults to 0. NOTE: those modules are constructed but not
        applied in :meth:`forward`, so this value currently does not affect
        the output.
    :param batch_norm: insert batch-norm layers if True, defaults to False
    """

    def __init__(
        self,
        C: int,
        G: int,
        const: int = 2,
        dropout_rate: float = 0,
        batch_norm: bool = False,
    ) -> None:
        super().__init__()
        self.batch_norm = batch_norm
        # ceil() keeps the layer width an integer if a fractional const is given.
        hidden_layer_size = math.ceil(const * C)

        # First hidden layer
        self.linear1 = nn.Linear(G, hidden_layer_size).float()
        self.dropout1 = nn.Dropout(dropout_rate)

        # Second hidden layer
        self.linear2 = nn.Linear(hidden_layer_size, hidden_layer_size).float()
        self.dropout2 = nn.Dropout(dropout_rate)

        # Output head for mu (C features)
        self.linear3_mu = nn.Linear(hidden_layer_size, C).float()
        self.dropout_mu = nn.Dropout(dropout_rate)

        # Output head for std (C features)
        self.linear3_std = nn.Linear(hidden_layer_size, C).float()
        self.dropout_std = nn.Dropout(dropout_rate)

        # Batch-norm layers are only created (and only used) when requested.
        if self.batch_norm:
            self.bn1 = nn.BatchNorm1d(num_features=hidden_layer_size).float()
            self.bn2 = nn.BatchNorm1d(num_features=hidden_layer_size).float()
            self.bn_out_mu = nn.BatchNorm1d(num_features=C).float()
            self.bn_out_std = nn.BatchNorm1d(num_features=C).float()

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """One forward pass of the StateRecognitionNet.

        :param x: input batch whose last dimension is G (the input size of
            ``linear1``). Presumably shape (N, G); with ``batch_norm=True``
            the BatchNorm1d layers expect the features on dim 1. TODO confirm
            against callers.
        :return: ``(mu_z, std_z)``, each with last dimension C. ``std_z`` is
            the raw output of the std head. No positivity transform is
            applied here; presumably the caller handles that.
        """
        # NOTE(review): self.dropout1/2/_mu/_std are never applied here (the
        # original calls were commented out), so dropout_rate has no effect.
        # Confirm whether this is intentional before wiring them in.

        # Input --linear1--> Hidden1
        x = self.linear1(x)
        if self.batch_norm:
            x = self.bn1(x)
        x = F.relu(x)

        # Hidden1 --linear2--> Hidden2
        x = self.linear2(x)
        if self.batch_norm:
            x = self.bn2(x)
        x = F.relu(x)

        # Hidden2 --linear3_mu--> mu (no activation)
        mu_z = self.linear3_mu(x)
        if self.batch_norm:
            mu_z = self.bn_out_mu(mu_z)

        # Hidden2 --linear3_std--> std (no activation)
        std_z = self.linear3_std(x)
        if self.batch_norm:
            std_z = self.bn_out_std(std_z)

        return mu_z, std_z