Source code for astir.models.abstract

from typing import Tuple, List, Dict
import warnings

import torch
import numpy as np

from astir.data import SCDataset


class AstirModel:
    """Abstract base class for Astir's statistical-inference models.

    This class is the superclass of `CellTypeModel` and `CellStateModel` and
    is not supposed to be instantiated directly. It validates the constructor
    arguments shared by all models, seeds the random number generators, and
    holds the common state (data, variables, losses, device and convergence
    flag). Subclasses are expected to override `_param_init`, `_forward` and
    `fit`.
    """

    def __init__(
        self, dset: SCDataset, random_seed: int, dtype: torch.dtype
    ) -> None:
        """Validate arguments, seed the RNGs and initialise shared state.

        :param dset: the dataset the model works on; its dtype (as reported
            by `dset.get_dtype()`) must equal `dtype`
        :param random_seed: seed applied to torch (CPU and CUDA) and numpy
        :param dtype: floating point precision, torch.float32 or torch.float64
        :raises NotClassifiableError: if `random_seed` is not an int, if
            `dtype` is not float32/float64, or if `dtype` differs from the
            dataset's dtype
        """
        # Validate every argument before touching the global RNG state, so a
        # rejected call leaves no side effects behind.
        if not isinstance(random_seed, int):
            raise NotClassifiableError("Random seed is expected to be an integer.")
        if dtype not in (torch.float32, torch.float64):
            raise NotClassifiableError(
                "dtype must be one of torch.float32 and torch.float64."
            )
        if dtype != dset.get_dtype():
            raise NotClassifiableError("dtype must be the same as `dset`.")

        # torch.manual_seed already seeds every device; manual_seed_all is kept
        # to make the CUDA seeding explicit (it subsumes torch.cuda.manual_seed).
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed_all(random_seed)
        np.random.seed(random_seed)

        self._dtype = dtype
        # Populated by subclasses; None until then.
        self._data = None
        self._variables = None
        self._losses = None
        self._dset = dset
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._is_converged = False

    def get_losses(self) -> float:
        """Getter for the losses recorded during training.

        :return: `self._losses`, as stored by the subclass's training code
        :rtype: float (NOTE(review): annotated as float; the actual type is
            whatever the subclass stores — confirm)
        :raises RuntimeError: if the model has not been trained yet
        """
        if self._losses is None:
            raise RuntimeError("The model has not been trained yet")
        return self._losses

    def get_scdataset(self) -> SCDataset:
        """Getter for the `SCDataset` passed to the constructor.

        :return: `self._dset`
        :rtype: SCDataset
        """
        return self._dset

    def get_data(self):
        """Getter for the model's data.

        :return: `self._data`; `None` until a subclass assigns it
        """
        return self._data

    def get_variables(self):
        """Getter for the model's variables.

        :return: `self._variables`; `None` until a subclass assigns it
        """
        return self._variables

    def is_converged(self) -> bool:
        """Report whether the model has converged.

        :return: `self._is_converged` (False right after construction)
        """
        return self._is_converged

    def _param_init(self) -> None:
        """Initialise model parameters; must be overridden by subclasses."""
        raise NotImplementedError("AbstractModel is not supposed to be instantiated.")

    def _forward(
        self, Y: torch.Tensor, X: torch.Tensor, design: torch.Tensor
    ) -> torch.Tensor:
        """Model forward pass; must be overridden by subclasses."""
        raise NotImplementedError("AbstractModel is not supposed to be instantiated.")

    def fit(
        self,
        max_epochs: int,
        learning_rate: float,
        batch_size: int,
        delta_loss: float,
        msg: str,
    ) -> None:
        """Train the model; must be overridden by subclasses.

        The parameter meanings are defined by the subclass implementations
        (NOTE(review): presumably `delta_loss` is a convergence threshold and
        `msg` a progress label — confirm against subclasses).

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("AbstractModel is not supposed to be instantiated.")
class NotClassifiableError(RuntimeError):
    """Error signalling that a model cannot classify the given input.

    Raised when the input data is not classifiable. `AstirModel.__init__`
    also raises it when the random seed or dtype arguments are invalid.
    """