Source code for astir.models.celltype

""" 
Cell Type Model
"""

import re
import warnings
from typing import Dict, Iterator, List, Tuple

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import stats
from sklearn.preprocessing import StandardScaler
from torch.autograd import Variable
from torch.distributions import (
    LowRankMultivariateNormal,
    MultivariateNormal,
    Normal,
    StudentT,
)
from torch.utils.data import DataLoader, Dataset
from tqdm import trange

from astir.data import SCDataset

from .abstract import AstirModel
from .celltype_recognet import TypeRecognitionNet


class CellTypeModel(AstirModel):
    """Class to perform statistical inference to assign cells to cell types.

    :param dset: the input gene expression dataframe
    :type dset: SCDataset
    :param random_seed: the random seed for parameter initialization, defaults to 1234
    :type random_seed: int, optional
    :param dtype: the data type of parameters, should be the same as `dset`,
        defaults to torch.float64
    :type dtype: torch.dtype, optional
    """

    def __init__(
        self,
        dset: SCDataset,
        random_seed: int = 1234,
        dtype: torch.dtype = torch.float64,
    ) -> None:
        super().__init__(dset, random_seed, dtype)
        # NOTE(review): fit_yield_loss reads/writes self._losses; this public
        # `losses` attribute is not referenced elsewhere in this file —
        # confirm against AstirModel whether it is still needed.
        self.losses = None  # losses after optimization
        self.cov_mat = None  # temporary -- remove
        self._assignment = None
        # Recognition network mapping expression -> cell-type probabilities
        # (C classes + 1 "Other" output, over G features).
        self._recog = TypeRecognitionNet(
            dset.get_n_classes(), dset.get_n_features()
        ).to(self._device, dtype=dtype)
        self._param_init()

    def _param_init(self) -> None:
        """ Initializes parameters and design matrices.
        """
        G = self._dset.get_n_features()
        C = self._dset.get_n_classes()
        # Establish data
        self._data = {
            # Uniform prior over the C cell types plus the "Other" class.
            "log_alpha": torch.log(torch.ones(C + 1, dtype=self._dtype) / (C + 1)).to(
                self._device
            ),
            # Marker matrix: which features mark which cell types.
            "rho": self._dset.get_marker_mat().to(self._device),
        }

        # Initialize mu, log_delta
        delta_init_mean = torch.log(
            torch.log(torch.tensor(3.0, dtype=self._dtype))
        )  # the log of the log of this is the multiplier
        # Small jitter around the mean so each (feature, class) delta starts
        # slightly different.
        t = torch.distributions.Normal(
            delta_init_mean.clone().detach().to(self._dtype),
            torch.tensor(0.1, dtype=self._dtype),
        )
        log_delta_init = t.sample((G, C + 1))

        # Center mu so that the marker-driven delta offsets average out.
        mu_init = torch.log(self._dset.get_mu()).to(self._device)
        mu_init = mu_init - (
            self._data["rho"] * torch.exp(log_delta_init).to(self._device)
        ).mean(1)
        mu_init = mu_init.reshape(-1, 1)

        # Create initialization dictionary
        initializations = {
            "mu": mu_init,
            "log_sigma": torch.log(self._dset.get_sigma()).to(self._device),
            "log_delta": log_delta_init,
            "p": torch.zeros((G, C + 1), dtype=self._dtype, device=self._device),
        }

        P = self._dset.get_design().shape[1]

        # Add additional columns of mu for anything in the design matrix
        initializations["mu"] = torch.cat(
            [
                initializations["mu"],
                torch.zeros((G, P - 1), dtype=self._dtype, device=self._device),
            ],
            1,
        )

        # Create trainable variables
        self._variables = {}
        for (n, v) in initializations.items():
            self._variables[n] = Variable(v.clone()).to(self._device)
            self._variables[n].requires_grad = True

    # @profile
    ## Declare pytorch forward fn
    def _forward(
        self, Y: torch.Tensor, X: torch.Tensor, design: torch.Tensor
    ) -> torch.Tensor:
        """One forward pass.

        :param Y: a sample from the dataset
        :type Y: torch.Tensor
        :param X: normalized sample data
        :type X: torch.Tensor
        :param design: the corresponding row of design matrix
        :type design: torch.Tensor
        :return: the cost (elbo) of the current pass
        :rtype: torch.Tensor
        """
        G = self._dset.get_n_features()
        C = self._dset.get_n_classes()
        N = Y.shape[0]

        # Replicate each cell's expression across the C+1 candidate classes.
        Y_spread = Y.reshape(-1, G, 1).repeat(1, 1, C + 1)

        delta_tilde = torch.exp(self._variables["log_delta"])  # + np.log(0.5)
        mean = delta_tilde * self._data["rho"]
        mean2 = torch.mm(design, self._variables["mu"].T)  ## N x P * P x G
        mean2 = mean2.reshape(-1, G, 1).repeat(1, 1, C + 1)
        mean = mean + mean2

        # now do the variance modelling
        p = torch.sigmoid(self._variables["p"])
        sigma = torch.exp(self._variables["log_sigma"])
        # Rank-1 covariance factor plus diagonal, per class.
        v1 = (self._data["rho"] * p).T * sigma
        v2 = torch.pow(sigma, 2) * (1 - torch.pow(self._data["rho"] * p, 2)).T
        v1 = v1.reshape(1, C + 1, G, 1).repeat(N, 1, 1, 1)  # extra 1 is the "rank"
        # 1e-6 keeps the diagonal strictly positive for numerical stability.
        v2 = v2.reshape(1, C + 1, G).repeat(N, 1, 1) + 1e-6

        dist = LowRankMultivariateNormal(
            loc=torch.exp(mean).permute(0, 2, 1), cov_factor=v1, cov_diag=v2
        )

        # Log-likelihood of each cell under each candidate cell type.
        log_p_y_on_c = dist.log_prob(Y_spread.permute(0, 2, 1))

        # gamma: recognition-net posterior over cell types (amortized q).
        gamma = self._recog.forward(X)

        # Evidence lower bound; minimized as -elbo by the optimizer.
        elbo = (
            gamma * (log_p_y_on_c + self._data["log_alpha"] - torch.log(gamma))
        ).sum()

        return -elbo

    # @profile
[docs] def fit( self, max_epochs: int = 50, learning_rate: float = 1e-3, batch_size: int = 128, delta_loss: float = 1e-3, msg: str = "", ) -> None: for l in self.fit_yield_loss(max_epochs, learning_rate, batch_size, delta_loss, msg): pass
[docs] def fit_yield_loss( self, max_epochs: int = 50, learning_rate: float = 1e-3, batch_size: int = 128, delta_loss: float = 1e-3, msg: str = "", ) -> None: """ Runs train loops until the convergence reaches delta_loss for\ delta_loss_batch sizes or for max_epochs number of times :param max_epochs: number of train loop iterations, defaults to 50 :param learning_rate: the learning rate, defaults to 0.01 :param batch_size: the batch size, defaults to 128 :param delta_loss: stops iteration once the loss rate reaches\ delta_loss, defaults to 0.001 :param msg: iterator bar message, defaults to empty string """ # Make dataloader dataloader = DataLoader( self._dset, batch_size=min(batch_size, len(self._dset)), shuffle=True ) # Run training loop losses = [] per = 1 # Construct optimizer opt_params = list(self._variables.values()) + list(self._recog.parameters()) optimizer = torch.optim.Adam(opt_params, lr=learning_rate) _, exprs_X, _ = self._dset[:] # calls dset.get_item iterator = trange( max_epochs, desc="training restart" + msg, unit="epochs", bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]", ) for ep in iterator: # for ep in range(max_epochs): L = None loss = torch.tensor(0.0, dtype=self._dtype) for batch in dataloader: Y, X, design = batch optimizer.zero_grad() L = self._forward(Y, X, design) L.backward() optimizer.step() with torch.no_grad(): loss = loss + L if len(losses) > 0: per = abs((loss - losses[-1]) / losses[-1]) losses.append(loss) # iterator.set_postfix_str("current loss: " + str(round(float(loss), 1))) yield round(float(loss), 1) if per <= delta_loss: self._is_converged = True # iterator.close() break # Save output self._assignment = pd.DataFrame( self._recog.forward(exprs_X).detach().cpu().numpy() ) self._assignment.columns = self._dset.get_classes() + ["Other"] self._assignment.index = self._dset.get_cell_names() if self._losses is None: self._losses = torch.tensor(losses) else: self._losses = torch.cat( 
(self._losses.view(self._losses.shape[0]), torch.tensor(losses)), dim=0 )
[docs] def predict(self, new_dset: pd.DataFrame) -> np.array: """Feed `new_dset` to the recognition net to get a prediction. :param new_dset: the dataset to be predicted :type new_dset: pd.DataFrame :return: the resulting cell type assignment :rtype: np.array """ _, exprs_X, _ = new_dset[:] g = pd.DataFrame(self._recog.forward(exprs_X).detach().cpu().numpy()) return g
[docs] def get_assignment(self) -> np.array: """Get the final assignment of the dataset. :return: the final assignment of the dataset :rtype: np.array """ if self._assignment is None: raise Exception("The type model has not been trained yet") return self._assignment
[docs] def get_recognet(self) -> TypeRecognitionNet: """ Getter for the recognition net. :return: the trained recognition net """ return self._recog
def _most_likely_celltype( self, row: pd.DataFrame, threshold: float, cell_types: List[str] ) -> str: """Given a row of the assignment matrix, return the most likely cell type :param row: the row of cell assignment matrix to be evaluated :type row: pd.DataFrame :param threshold: the higher bound of the maximun probability to classify a cell as `Unknown` :type threshold: float :param cell_types: the names of cell types, in the same order as the features of the row :type cell_types: List[str] :return: the most likely cell type of this cell :rtype: str """ row = row.values max_prob = np.max(row) if max_prob < threshold: return "Unknown" return cell_types[np.argmax(row)]
[docs] def get_celltypes(self, threshold=0.7) -> pd.DataFrame: """ Get the most likely cell types A cell is assigned to a cell type if the probability is greater than threshold. If no cell types have a probability higher than threshold, then "Unknown" is returned :param threshold: the probability threshold above which a cell is assigned to a cell type :return: a data frame with most likely cell types for each """ type_probability = self.get_assignment() cell_types = list(type_probability.columns) cell_type_assignments = type_probability.apply( self._most_likely_celltype, axis=1, threshold=threshold, cell_types=cell_types, ) cell_type_assignments = pd.DataFrame(cell_type_assignments) cell_type_assignments.columns = ["cell_type"] return cell_type_assignments
def _compare_marker_between_types( self, curr_type, celltype_to_compare, marker, cell_types, alpha: float = 0.05 ): """For a given cell type and two proteins, ensure marker is expressed at higher level using t-test """ current_marker_ind = np.array(self._dset.get_features()) == marker cells_x = np.array(cell_types) == curr_type cells_y = np.array(cell_types) == celltype_to_compare x = self._dset.get_exprs().detach().cpu().numpy()[cells_x, current_marker_ind] y = self._dset.get_exprs().detach().cpu().numpy()[cells_y, current_marker_ind] stat = np.NaN pval = np.Inf note = "Only 1 cell in a type: comparison not possible" if len(x) > 1 and len(y) > 1: tt = stats.ttest_ind(x, y) stat = tt.statistic pval = tt.pvalue note = None if not (stat > 0 and pval < alpha): rdict = { "current_marker": marker, "curr_type": curr_type, "celltype_to_compare": celltype_to_compare, "mean_A": x.mean(), "mean_Y": y.mean(), "p-val": pval, "note": note, } return rdict return None
    def plot_clustermap(
        self,
        plot_name: str = "celltype_protein_cluster.png",
        threshold: float = 0.7,
        figsize: Tuple[float, float] = (7, 5),
    ) -> None:
        """Save the heatmap of protein content in cells with cell types labeled.

        :param plot_name: name of the plot, extension(e.g. .png or .jpg) is
            needed, defaults to "celltype_protein_cluster.png"
        :type plot_name: str, optional
        :param threshold: the probability threshold above which a cell is
            assigned to a cell type, defaults to 0.7
        :type threshold: float, optional
        :param figsize: width and height of the figure, defaults to (7, 5)
        :type figsize: Tuple[float, float], optional
        """
        expr_df = self._dset.get_exprs_df()
        # Standardize each protein column (zero mean, unit variance) so the
        # heatmap colours are comparable across features.
        scaler = StandardScaler()
        for feature in expr_df.columns:
            expr_df[feature] = scaler.fit_transform(
                expr_df[feature].values.reshape((expr_df[feature].shape[0], 1))
            )
        # Label each cell with its most likely type, then group cells by type.
        expr_df["cell_type"] = self.get_celltypes(threshold=threshold)
        expr_df = expr_df.sort_values(by=["cell_type"])
        types = expr_df.pop("cell_type")
        types_uni = types.unique()
        # One colour per cell type for the column-colour strip above the map.
        lut = dict(zip(types_uni, sns.color_palette("BrBG", len(types_uni))))
        col_colors = pd.DataFrame(types.map(lut))
        cm = sns.clustermap(
            expr_df.T,
            xticklabels=False,
            cmap="vlag",
            col_cluster=False,  # keep cells in cell-type order; cluster features only
            col_colors=col_colors,
            figsize=figsize,
        )
        # Zero-size bars exist only to create legend handles for the colours.
        for t in types_uni:
            cm.ax_col_dendrogram.bar(0, 0, color=lut[t], label=t, linewidth=0)
        cm.ax_col_dendrogram.legend(
            title="Cell Types", loc="center", ncol=3, bbox_to_anchor=(0.8, 0.8)
        )
        cm.savefig(plot_name, dpi=150)
[docs] def diagnostics(self, cell_type_assignments: list, alpha: float) -> pd.DataFrame: """Run diagnostics on cell type assignments See :meth:`astir.Astir.diagnostics_celltype` for full documentation """ problems = [] # Want to construct a data frame that models rho with # cell type names on the columns and feature names on the rows g_df = pd.DataFrame(self._data["rho"].detach().cpu().numpy()) g_df.columns = self._dset.get_classes() + ["Other"] g_df.index = self._dset.get_features() for curr_type in self._dset.get_classes(): if not curr_type in cell_type_assignments: continue current_markers = g_df.index[g_df[curr_type] == 1] for current_marker in current_markers: # find all the cell types that shouldn't highly express this marker celltypes_to_compare = g_df.columns[g_df.loc[current_marker] == 0] for celltype_to_compare in celltypes_to_compare: if not celltype_to_compare in cell_type_assignments: continue is_problem = self._compare_marker_between_types( curr_type, celltype_to_compare, current_marker, cell_type_assignments, alpha, ) if is_problem is not None: problems.append(is_problem) col_names = [ "feature", "should be expressed higher in", "than", "mean cell type 1", "mean cell type 2", "p-value", "note", ] df_issues = None if len(problems) > 0: df_issues = pd.DataFrame(problems) df_issues.columns = col_names else: df_issues = pd.DataFrame(columns=col_names) return df_issues