# Source code for astir.models.cellstate

"""
Cell State Model
"""
from typing import Dict, Iterator, List, Optional, Tuple, Union
import warnings

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import yaml
from tqdm import trange
from torch.autograd import Variable
from torch.utils.data import DataLoader

from .abstract import AstirModel
from astir.data import SCDataset
from .cellstate_recognet import StateRecognitionNet


class CellStateModel(AstirModel):
    """Statistical inference of the activation of states (pathways) across cells.

    A recognition network maps each cell's expression to a Gaussian
    approximate posterior over state activations ``z``. The generative model
    reconstructs expression as ``mu + z @ (rho * exp(log_w))`` with a single
    shared noise scale. Both are trained jointly by minimising a single-sample
    estimate of the negative ELBO.

    :param dset: the input single-cell dataset
    :param const: forwarded to :class:`StateRecognitionNet`, defaults to 2
    :param dropout_rate: forwarded to :class:`StateRecognitionNet`,
        defaults to 0
    :param batch_norm: forwarded to :class:`StateRecognitionNet`,
        defaults to False
    :param random_seed: seed number to reproduce results, defaults to 42
    :param dtype: torch datatype to use in the model, defaults to
        torch.float64
    """

    def __init__(
        self,
        dset: SCDataset,
        const: int = 2,
        dropout_rate: float = 0,
        batch_norm: bool = False,
        random_seed: int = 42,
        dtype: torch.dtype = torch.float64,
    ) -> None:
        super().__init__(dset, random_seed, dtype)

        # Seed every RNG in use and force deterministic cuDNN kernels so runs
        # with the same seed are reproducible. The cuDNN flags are global
        # (process-wide) settings, not per-model ones.
        self.random_seed = random_seed
        torch.manual_seed(self.random_seed)
        torch.cuda.manual_seed_all(self.random_seed)
        torch.cuda.manual_seed(self.random_seed)
        np.random.seed(self.random_seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._dset = dset
        # Created lazily by the first call to fit().
        self._optimizer = None
        # Per-epoch loss history, accumulated across calls to fit().
        # NOTE(review): self._dtype is expected to be set by AstirModel.__init__.
        self._losses = torch.empty(0, dtype=self._dtype)
        self._param_init(const, dropout_rate, batch_norm)

        # Convergence flag
        self._is_converged = False

    def _param_init(self, const: int, dropout_rate: float, batch_norm: bool) -> None:
        """Initialise the trainable variables, the fixed data and the recognition net.

        :param const: forwarded to :class:`StateRecognitionNet`
        :param dropout_rate: forwarded to :class:`StateRecognitionNet`
        :param batch_norm: forwarded to :class:`StateRecognitionNet`
        """
        C = self._dset.get_n_classes()
        G = self._dset.get_n_features()

        initializations = {
            # A single noise scale shared by all features (mean of the data sigma).
            "log_sigma": torch.log(self._dset.get_sigma().mean()),
            "mu": torch.reshape(self._dset.get_mu(), (1, -1)),
        }

        # State-to-feature weights drawn from Uniform(0, 1.5) and kept in log
        # space so that exp(log_w) stays positive during optimisation.
        d = torch.distributions.Uniform(
            torch.tensor(0.0, dtype=self._dtype), torch.tensor(1.5, dtype=self._dtype)
        )
        initializations["log_w"] = torch.log(d.sample((C, G)))

        # Independent leaf tensors on the model device; these are optimised.
        self._variables = {
            n: i.to(self._device).detach().clone().requires_grad_()
            for (n, i) in initializations.items()
        }

        # Fixed (non-trainable) inputs: the transposed marker matrix.
        self._data = {
            "rho": self._dset.get_marker_mat().T.to(self._device),
        }

        self._recog = StateRecognitionNet(
            C, G, const=const, dropout_rate=dropout_rate, batch_norm=batch_norm
        ).to(device=self._device, dtype=self._dtype)

    def _loss_fn(
        self,
        mu_z: torch.Tensor,
        std_z: torch.Tensor,
        z_sample: torch.Tensor,
        y_in: torch.Tensor,
    ) -> torch.Tensor:
        """Return the single-sample negative ELBO, averaged over the batch.

        :param mu_z: the predicted mean of z
        :param std_z: the predicted *log* standard deviation of z (it is
            exponentiated before use)
        :param z_sample: the sampled z values
        :param y_in: the input data
        :return: the loss
        """
        S = y_in.shape[0]

        # log q(z | y): the approximate posterior
        q_z_dist = torch.distributions.Normal(loc=mu_z, scale=torch.exp(std_z))
        log_q_z = q_z_dist.log_prob(z_sample)

        # log p(y | z): entries where rho is zero contribute nothing to the mean
        rho_w = torch.mul(self._data["rho"], torch.exp(self._variables["log_w"]))
        mean = self._variables["mu"] + torch.matmul(z_sample, rho_w)
        std = torch.exp(self._variables["log_sigma"]).reshape(1, -1)
        p_y_given_z_dist = torch.distributions.Normal(loc=mean, scale=std)
        log_p_y_given_z = p_y_given_z_dist.log_prob(y_in)

        # log p(z): standard normal prior
        p_z_dist = torch.distributions.Normal(0, 1)
        log_p_z = p_z_dist.log_prob(z_sample)

        loss = (1 / S) * (
            torch.sum(log_q_z) - torch.sum(log_p_y_given_z) - torch.sum(log_p_z)
        )
        return loss

    def _forward(
        self, Y: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """One forward pass using the reparameterisation trick.

        :param Y: dataset to do forward pass on
        :return: mu_z, std_z (log standard deviation), z_sample
        """
        mu_z, std_z = self._recog(Y)
        std = torch.exp(std_z)
        eps = torch.randn_like(std)
        z_sample = eps * std + mu_z

        return mu_z, std_z, z_sample

    def fit(
        self,
        max_epochs: int = 50,
        learning_rate: float = 1e-3,
        batch_size: int = 128,
        delta_loss: float = 1e-3,
        delta_loss_batch: int = 10,
        msg: str = "",
    ) -> Iterator[float]:
        """Train the model, yielding the rounded loss after every epoch.

        This method is a generator: training only progresses while the caller
        iterates over it. Each epoch is one pass over the dataset in
        mini-batches. The recorded epoch loss is the cell-weighted mean of the
        mini-batch losses. Training stops early once the relative decrease of
        the mean of the last ``delta_loss_batch`` epoch losses (which may
        include losses from earlier calls) falls within ``[0, delta_loss)``.
        The new epoch losses are appended to the stored history when the
        generator finishes, is closed early, or raises.

        If the model has already converged, nothing is yielded.

        :param max_epochs: maximum number of epochs, defaults to 50
        :param learning_rate: the Adam learning rate, defaults to 1e-3; only
            used when the optimizer is created (on the first call)
        :param batch_size: the mini-batch size, defaults to 128
        :param delta_loss: relative-improvement threshold for convergence,
            defaults to 1e-3
        :param delta_loss_batch: number of epochs averaged for the
            convergence check, defaults to 10
        :param msg: suffix for the progress-bar description, defaults to
            empty string
        :return: an iterator over the per-epoch losses rounded to one decimal
        """
        # Returns early if the model has already converged
        if self._is_converged:
            return

        if delta_loss_batch >= max_epochs:
            warnings.warn("Delta loss batch size is greater than the number of epochs")

        # Create an optimizer if there is no optimizer
        if self._optimizer is None:
            opt_params = list(self._recog.parameters()) + list(self._variables.values())
            self._optimizer = torch.optim.Adam(opt_params, lr=learning_rate)

        # Seed the moving average from the history of previous calls.
        if self._losses.shape[0] >= delta_loss_batch:
            prev_mean = torch.mean(self._losses[-delta_loss_batch:]).item()
        else:
            prev_mean = None

        losses: List[float] = []
        delta_cond_met = False

        iterator = trange(
            max_epochs,
            desc="training restart" + msg,
            unit="epochs",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]",
        )
        train_iterator = DataLoader(
            self._dset, batch_size=min(batch_size, len(self._dset))
        )

        try:
            for ep in iterator:
                epoch_total = 0.0
                n_seen = 0
                for _, x_in, _ in train_iterator:
                    self._optimizer.zero_grad()
                    mu_z, std_z, z_samples = self._forward(x_in)
                    loss = self._loss_fn(mu_z, std_z, z_samples, x_in)
                    loss.backward()
                    self._optimizer.step()

                    # The loss is a per-cell average; weight by batch size so
                    # a small trailing batch does not dominate the epoch loss.
                    n_batch = x_in.shape[0]
                    epoch_total += loss.detach().item() * n_batch
                    n_seen += n_batch
                losses.append(epoch_total / n_seen)

                # Mean of the last `delta_loss_batch` epoch losses, borrowing
                # from the stored history when this call has too few epochs.
                start_index = ep - delta_loss_batch + 1
                end_index = start_index + delta_loss_batch
                if start_index >= 0:
                    window = losses[start_index:end_index]
                    curr_mean = sum(window) / len(window)
                elif self._losses.shape[0] >= -start_index:
                    window_t = torch.cat(
                        (
                            self._losses[start_index:],
                            torch.tensor(losses[:end_index], dtype=self._dtype),
                        )
                    )
                    curr_mean = torch.mean(window_t).item()
                else:
                    curr_mean = None

                if prev_mean is not None and curr_mean is not None and prev_mean != 0:
                    # Divide by |prev_mean|: the negative ELBO can be negative,
                    # which would otherwise flip the sign of the improvement.
                    curr_delta_loss = (prev_mean - curr_mean) / abs(prev_mean)
                    delta_cond_met = 0 <= curr_delta_loss < delta_loss

                yield round(losses[ep], 1)

                prev_mean = curr_mean
                if delta_cond_met:
                    self._is_converged = True
                    break
        finally:
            # Runs on normal completion, early close by the caller, or error,
            # so the losses of completed epochs are never dropped.
            iterator.close()
            self._losses = torch.cat(
                (self._losses, torch.tensor(losses, dtype=self._dtype))
            )

    def get_recognet(self) -> StateRecognitionNet:
        """Getter for the recognition net.

        :return: the trained recognition net
        """
        return self._recog

    def get_final_mu_z(self, new_dset: Optional[SCDataset] = None) -> torch.Tensor:
        """Return the mean of the predicted z values for each cell.

        The recognition net is used in whatever train/eval mode it is
        currently in. NOTE(review): if it uses dropout or batch norm, the
        result may be stochastic unless it has been switched to eval mode.

        :param new_dset: returns the predicted z values of this dataset on the
            existing model. If None, it predicts using the existing dataset
        :return: the mean of the predicted z values for each cell
        """
        dset = self._dset if new_dset is None else new_dset
        # Use the second element of the dataset items, the same input fit() trains on.
        _, x_in, _ = dset[:]
        final_mu_z, _, _ = self._forward(x_in)
        return final_mu_z

    def get_correlations(self) -> np.ndarray:
        """Return the Pearson correlation between every state and every feature.

        :return: a ``(C, G)`` array whose ``[c, g]`` entry correlates the
            inferred activation (mean of z) of state ``c`` with the
            expression of feature ``g``
        """
        state_assignment = self.get_final_mu_z().detach().cpu().numpy()
        # Convert once, rather than copying a column per (state, feature) pair.
        exprs = self._dset.get_exprs().detach().cpu().numpy()
        C = self._dset.get_n_classes()
        G = self._dset.get_n_features()

        corr_mat = np.zeros((C, G))
        for c in range(C):
            states = state_assignment[:, c]
            for g in range(G):
                corr_mat[c, g] = np.corrcoef(exprs[:, g], states)[0, 1]
        return corr_mat

    def diagnostics(self) -> pd.DataFrame:
        """Run diagnostics on cell state assignments.

        For every state, flag each non-marker feature whose correlation with
        the state exceeds that of the state's least-correlated marker.

        :return: one row per flagged (state, non-marker feature) pair with
            columns ``pathway``, ``protein A``, ``correlation of protein A``,
            ``protein B``, ``correlation of protein B`` and ``note``; an empty
            frame with those columns when nothing is flagged
        """
        feature_names = self._dset.get_features()
        state_names = self._dset.get_classes()
        corr_mat = self.get_correlations()

        # Marker matrix in the same layout as corr_mat; 0 marks a non-marker.
        marker_mat = self._dset.get_marker_mat().T.cpu().numpy()
        non_marker_mat = 1 - marker_mat

        # Marker correlations; non-markers masked to +inf so they never win the min.
        marker_corr = marker_mat * corr_mat
        marker_corr[marker_mat == 0] = np.inf
        min_marker_corr = np.min(marker_corr, axis=1).reshape(-1, 1)
        min_marker_proteins = np.take(feature_names, np.argmin(marker_corr, axis=1))

        # Non-marker correlations; markers masked to -inf so they are never flagged.
        non_marker_corr = non_marker_mat * corr_mat
        non_marker_corr[non_marker_mat == 0] = -np.inf

        # Non-marker correlations above the smallest marker correlation of the state.
        bad_corr_marker = np.array(non_marker_corr > min_marker_corr, dtype=np.int32)

        col_names = [
            "pathway",
            "protein A",
            "correlation of protein A",
            "protein B",
            "correlation of protein B",
            "note",
        ]
        problems = []
        for state_index, protein_index in np.argwhere(bad_corr_marker > 0):
            state = state_names[state_index]
            marker_protein = min_marker_proteins[state_index]
            non_marker_protein = feature_names[protein_index]
            problems.append(
                {
                    "pathway": state,
                    "protein A": marker_protein,
                    "correlation of protein A": min_marker_corr[state_index][0],
                    "protein B": non_marker_protein,
                    "correlation of protein B": non_marker_corr[
                        state_index, protein_index
                    ],
                    "note": "{} is marker for {} but {} isn't".format(
                        marker_protein, state, non_marker_protein
                    ),
                }
            )

        if problems:
            return pd.DataFrame(problems, columns=col_names)
        return pd.DataFrame(columns=col_names)

    def get_losses(self) -> torch.Tensor:
        """Getter for losses.

        :return: a torch tensor with one loss per training epoch run so far
        :raises Exception: if the model has not been trained yet
        """
        if self._losses is None or self._losses.shape[0] == 0:
            raise Exception("The state model has not been trained yet")
        return self._losses

    def get_scdataset(self) -> SCDataset:
        """Returns the input dataset.

        :return: self._dset
        """
        return self._dset

    def is_converged(self) -> bool:
        """Returns True if the model converged.

        :return: self._is_converged
        """
        return self._is_converged

    def get_data(self) -> Dict[str, torch.Tensor]:
        """Returns data parameter.

        :return: self._data
        """
        return self._data

    def get_variables(self) -> Dict[str, torch.Tensor]:
        """Returns all variables.

        :return: self._variables
        """
        return self._variables
class NotClassifiableError(RuntimeError):
    """Signal that the supplied input data cannot be classified."""