├── .gitignore ├── DGD_schematic.png ├── LICENSE ├── README.md ├── examples ├── scDGD_training_mousebrain5k.ipynb └── scDGD_training_mousebrain5k_predicting_prob.ipynb ├── pretrained_models ├── PBMC_Zheng2017 │ ├── data_obs.csv │ ├── data_var.csv │ ├── load_PBMC_Zheng2017.ipynb │ └── model │ │ ├── L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_decoder.pt │ │ ├── L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_gmm.pt │ │ ├── L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_heldoutTestRepresentation.pt │ │ ├── L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_representation.pt │ │ ├── L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_testRepresentation.pt │ │ └── dgd_hyperparameters.json └── mousebrain_1M │ ├── data_info.zip │ ├── load_mousebrain1M.ipynb │ └── model │ ├── dgd_hyperparameters.json │ ├── dgd_mousebrain_1m_a1_rs0_decoder.pt │ └── dgd_mousebrain_1m_a1_rs0_gmm.pt ├── scDGD ├── __init__.py ├── classes │ ├── __init__.py │ ├── data.py │ ├── output_distributions.py │ ├── prior.py │ └── representation.py ├── functions │ ├── __init__.py │ ├── analysis.py │ ├── data.py │ └── train.py └── models │ ├── __init__.py │ └── models.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # mac related 163 | *.DS_Store -------------------------------------------------------------------------------- /DGD_schematic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/DGD_schematic.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Center-for-Health-Data-Science 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # scDGD 2 | 3 | scDGD is an application of our encoder-less generative model, the Deep Generative Decoder (DGD), to single-cell transcriptomics data. 4 | 5 | It learns low-dimensional representations of full transcriptomics matrices without feature selection. The low-dimensional embeddings are of higher quality than comparable methods such as scVI and the data reconstruction is highly data efficient, outperforming scVI and scVAE, especially on very small data sets. 6 | 7 | For more information about the underlying method and our results, check out the [paper](https://academic.oup.com/bioinformatics/article/39/9/btad497/7241685). 8 | 9 | 10 | 11 | ## Installation 12 | 13 | You can install the package via 14 | ``` 15 | pip install git+https://github.com/Center-for-Health-Data-Science/scDGD 16 | ``` 17 | 18 | ## How to use it 19 | 20 | From our experience, scDGD can be applied to data sets with as few as 500 cells and as many as one million. 21 | 22 | Check out the notebook showing an example of how to use scDGD: 23 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Center-for-Health-Data-Science/scDGD/blob/HEAD/examples/scDGD_training_mousebrain5k.ipynb) 24 | 25 | We have also uploaded pre-trained models for the PBMC (Zheng et al. 2017) and the 10X 1 million mouse brain data sets, along with notebooks showing how to load them [in pretrained_models](https://github.com/Center-for-Health-Data-Science/scDGD/tree/main/pretrained_models). 
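
As a minimal sketch of what loading one of these pretrained models can look like (the directory path below is a placeholder, and the notebooks in `pretrained_models/` remain the authoritative walkthrough):

```python
from scDGD.classes import GaussianMixture
from scDGD.functions.analysis import load_embedding, plot_latent_umap

# folder containing dgd_hyperparameters.json and the saved *_gmm.pt / *_representation.pt files
model_dir = "pretrained_models/PBMC_Zheng2017/model/"  # placeholder path

gmm = GaussianMixture.load(save_dir=model_dir)    # restore the Gaussian mixture prior
embedding = load_embedding(save_dir=model_dir)    # training representations as a (cells x latent) array

clusters = gmm.clustering(embedding)              # GMM component assignment per cell

# UMAP of the latent space coloured by cell type and by GMM cluster;
# expects data_obs.csv (shipped with the PBMC model) in the working directory
plot_latent_umap(save_dir=model_dir)
```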
26 | 27 | ## Reference 28 | 29 | If you use scDGD in your research, please consider citing 30 | 31 | ``` 32 | @article{schuster_deep_2023, 33 | title = {The Deep Generative Decoder: MAP estimation of representations improves modelling of single-cell RNA data}, 34 | volume = {39}, 35 | issn = {1367-4811}, 36 | shorttitle = {The Deep Generative Decoder}, 37 | url = {https://doi.org/10.1093/bioinformatics/btad497}, 38 | doi = {10.1093/bioinformatics/btad497}, 39 | number = {9}, 40 | journal = {Bioinformatics}, 41 | author = {Schuster, Viktoria and Krogh, Anders}, 42 | month = sep, 43 | year = {2023} 44 | } 45 | ``` -------------------------------------------------------------------------------- /pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_decoder.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_decoder.pt -------------------------------------------------------------------------------- /pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_gmm.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_gmm.pt -------------------------------------------------------------------------------- /pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_heldoutTestRepresentation.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_heldoutTestRepresentation.pt -------------------------------------------------------------------------------- /pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_representation.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_representation.pt -------------------------------------------------------------------------------- /pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_testRepresentation.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_testRepresentation.pt -------------------------------------------------------------------------------- /pretrained_models/PBMC_Zheng2017/model/dgd_hyperparameters.json: -------------------------------------------------------------------------------- 1 | {"latent": 20, 2 | "hidden": [100, 100, 100], 3 | "learning_rates": [[1e-3, 1e-2, 1e-2], [1e-4, 1e-2, 1e-2]], 4 | "lr_name": "322", 5 | "optim_beta": [0.5, 0.7], 6 | "weight_decay": 1e-4, 7 | "dropout": 0.1, 8 | "gmm_type": "diagonal", 9 | "n_components": 18, 10 | "r_init": 2, 11 | "scaling_type": "max", 12 | "output": "mean", 13 | "activation": "softplus", 14 | "mp_scale": 1, 15 | "hardness": 1, 16 | "sd_mean": 0.01, 17 | "sd_sd": 1, 18 | "dirichlet_a": 1, 19 | "n_genes": 32738, 20 | "name": "L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422"} -------------------------------------------------------------------------------- /pretrained_models/mousebrain_1M/data_info.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/mousebrain_1M/data_info.zip -------------------------------------------------------------------------------- /pretrained_models/mousebrain_1M/model/dgd_hyperparameters.json: -------------------------------------------------------------------------------- 1 | {"latent": 20, 2 | "hidden": [100, 100, 100], 3 | "learning_rates": [[1e-3, 1e-2, 1e-2], [1e-4, 1e-2, 1e-2]], 4 | "lr_name": "322", 5 | "lr_schedule": [0,500], 6 | "optim_beta": [0.5, 0.7], 7 | "weight_decay": 1e-4, 8 | "dropout": 0.1, 9 | "gmm_type": "diagonal", 10 | "n_components": 27, 11 | "r_init": 2, 12 | "scaling_type": "max", 13 | "output": "mean", 14 | "activation": "softplus", 15 | "mp_scale": 1, 16 | "hardness": 1, 17 | "sd_mean": 0.007, 18 | "sd_sd": 1, 19 | "dirichlet_a": 1, 20 | "n_genes": 27998, 21 | "name": "dgd_mousebrain_1m_a1_rs0"} -------------------------------------------------------------------------------- /pretrained_models/mousebrain_1M/model/dgd_mousebrain_1m_a1_rs0_decoder.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/mousebrain_1M/model/dgd_mousebrain_1m_a1_rs0_decoder.pt -------------------------------------------------------------------------------- /pretrained_models/mousebrain_1M/model/dgd_mousebrain_1m_a1_rs0_gmm.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/mousebrain_1M/model/dgd_mousebrain_1m_a1_rs0_gmm.pt -------------------------------------------------------------------------------- /scDGD/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import classes, models, functions -------------------------------------------------------------------------------- /scDGD/classes/__init__.py: -------------------------------------------------------------------------------- 1 | from scDGD.classes.data import scDataset 2 | from scDGD.classes.prior import GaussianMixture -------------------------------------------------------------------------------- /scDGD/classes/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import scipy.sparse 5 | 6 | class scDataset(Dataset): 7 | """ 8 | """ 9 | def __init__(self, sparse_mtrx, meta_data, label_type, scaling_type='max', gene_selection=None, subset=None, sparse=False): 10 | """ 11 | Args: 12 | This is a custom data set for single cell transcriptomics data. 13 | It takes a sparse matrix of gene expression data and a pandas dataframe of metadata. 14 | 15 | sparse_mtrx: 16 | a scipy.sparse matrix of gene expression data with rows representing cells and columns representing transcripts 17 | meta_data: 18 | a pandas dataframe of metadata with rows representing cells and columns representing metadata 19 | scaling_type: 20 | a string specifying the type of scaling to use for the data. Options are 'mean' and 'max' 21 | this will either scale the data by the mean or max of each cell 22 | gene_selection (optional): 23 | a list of indices specifying which genes to use from the sparse matrix if feature selection is to be performed 24 | subset (optional): 25 | a list of indices specifying which cells to use from the sparse matrix if subsampling is to be performed 26 | label_type (optional): 27 | the label type of the characteristic that one wants to observe in clustering. It is usually used for cell type. 28 | It indicates the column name of the meta data provided. 29 | sparse (optional): 30 | a boolean indicating whether the data should be kept in sparse format or converted to a dense tensor. 31 | For small data sets, it is recommended to transform the data in dense format for faster training. 32 | For large ones, this helps keep the memory used in check. 
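        Illustrative example (a sketch, assuming `adata` is an AnnData object with raw counts
        and a 'cell_type' column in `adata.obs`):

            dataset = scDataset(adata.X, adata.obs, label_type='cell_type', scaling_type='max')
            loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)

        With sparse=True, pass collate_fn=collate_sparse_batches (defined at the bottom of this
        file) to the DataLoader so that each batch is densified on the fly.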
33 | """ 34 | 35 | self.scaling_type = scaling_type 36 | self.meta = meta_data 37 | 38 | if gene_selection is not None: 39 | sparse_mtrx = sparse_mtrx.tocsc()[:,gene_selection].tocoo() 40 | if subset is not None: 41 | sparse_mtrx = sparse_mtrx.tocsr()[subset] 42 | if sparse: 43 | self.sparse = True 44 | sparse_mtrx = sparse_mtrx.tocsr() 45 | self.data = sparse_mtrx 46 | if self.scaling_type == 'mean': 47 | self.library = torch.tensor(sparse_mtrx.mean(axis=-1).toarray()) 48 | elif self.scaling_type == 'max': 49 | self.library = torch.tensor(sparse_mtrx.max(axis=-1).toarray()) 50 | elif self.scaling_type == 'sum': 51 | self.library = torch.tensor(sparse_mtrx.sum(axis=-1).toarray()) 52 | 53 | else: 54 | self.sparse = False 55 | self.data = torch.Tensor(sparse_mtrx.todense()) 56 | if self.scaling_type == 'mean': 57 | self.library = torch.mean(self.data, dim=-1).unsqueeze(1) 58 | elif self.scaling_type == 'max': 59 | self.library = torch.max(self.data, dim=-1).values.unsqueeze(1) 60 | elif self.scaling_type == 'sum': 61 | self.library = torch.sum(self.data, dim=-1).unsqueeze(1) 62 | 63 | self.n_genes = self.data.shape[1] 64 | 65 | self.label_type = label_type 66 | 67 | def __len__(self): 68 | return(self.data.shape[0]) 69 | 70 | def __getitem__(self, idx=None): 71 | if idx is not None: 72 | if self.sparse: 73 | expression = self.data[idx] 74 | else: 75 | expression = self.data[idx] 76 | else: 77 | expression = self.data 78 | idx = torch.arange(self.data.shape[0]) 79 | lib = self.library[idx] 80 | return expression, lib, idx 81 | 82 | def get_labels(self, idx=None): 83 | if torch.is_tensor(idx): 84 | idx = idx.tolist() 85 | if idx is None: 86 | idx = np.arange(self.__len__()) 87 | return np.asarray(np.array(self.meta[self.label_type])[idx]) 88 | 89 | def get_labels_numerical(self, idx=None): 90 | if torch.is_tensor(idx): 91 | idx = idx.tolist() 92 | if idx is None: 93 | idx = np.arange(self.__len__()) 94 | label_ids = np.argmax(np.expand_dims(np.asarray(self.meta[self.label_type]),0)>=np.expand_dims(idx,1),axis=1) 95 | return label_ids 96 | 97 | 98 | ### 99 | # functions used for the sparse option 100 | # here the data will have to be transformed to dense format when a batch is called 101 | ### 102 | 103 | def sparse_coo_to_tensor(mtrx): 104 | return torch.FloatTensor(mtrx.todense()) 105 | 106 | def collate_sparse_batches(batch): 107 | data_batch, library_batch, idx_batch = zip(*batch) 108 | data_batch = scipy.sparse.vstack(list(data_batch)) 109 | data_batch = sparse_coo_to_tensor(data_batch) 110 | library_batch = torch.stack(list(library_batch), dim=0) 111 | idx_batch = list(idx_batch) 112 | return data_batch, library_batch, idx_batch 113 | -------------------------------------------------------------------------------- /scDGD/classes/output_distributions.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributions as D 5 | 6 | 7 | def logNBdensity(k, m, r): 8 | """ 9 | Negative Binomial NB(k;m,r), where m is the mean and r is "number of failures" 10 | r can be real number (and so can k) 11 | k, and m are tensors of same shape 12 | r is tensor of shape (1, n_genes) 13 | Returns the log NB in same shape as k 14 | """ 15 | # remember that gamma(n+1)=n! 
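    # Written out, with lgamma denoting the log-Gamma function:
    #   log NB(k; m, r) = lgamma(k+r) - lgamma(r) - lgamma(k+1)
    #                     + k*log(m/(r+m)) + r*log(r/(r+m))
    # The small eps below only guards against log(0) when m = 0.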
16 | eps = 1.0e-10 17 | x = torch.lgamma(k + r) 18 | x -= torch.lgamma(r) 19 | x -= torch.lgamma(k + 1) 20 | x += k * torch.log(m * (r + m + eps) ** (-1) + eps) 21 | x += r * torch.log(r * (r + m + eps) ** (-1)) 22 | return x 23 | 24 | 25 | class NBLayer(nn.Module): 26 | """ 27 | A negative binomial of scaled values of m and learned parameters for r. 28 | mhat = m/M, where M is the scaling factor 29 | 30 | The scaled value mhat is typically the output value from the NN, 31 | but we are expanding to the option of modelling the probability. 32 | 33 | If rhat=None, it is assumed that the last half of mhat contains rhat. 34 | 35 | m = M*mhat 36 | """ 37 | 38 | # def __init__(self, out_dim, r_init, scaling_type='max', output='mean', activation='sigmoid', reduction='none'): 39 | def __init__( 40 | self, out_dim, r_init, output="mean", activation="sigmoid", reduction="none" 41 | ): 42 | super(NBLayer, self).__init__() 43 | 44 | # initialize parameter for r 45 | # real-valued positive parameters are usually used as their log equivalent 46 | self.log_r = torch.nn.Parameter( 47 | torch.full(fill_value=math.log(r_init - 1), size=(1, out_dim)), 48 | requires_grad=True, 49 | ) 50 | # self.scaling_type = scaling_type 51 | self.output = output 52 | self.activation = activation 53 | self.reduction = reduction 54 | 55 | # self.activ_layer = nn.ModuleList() 56 | if self.activation == "sigmoid": 57 | self.activ_layer = nn.Sigmoid() 58 | elif self.activation == "softmax": 59 | self.activ_layer = nn.Softmax(dim=-1) 60 | elif self.activation == "softplus": 61 | self.activ_layer = nn.Softplus() 62 | else: 63 | raise ValueError("Activation function not recognized") 64 | 65 | @property 66 | def r(self): 67 | return torch.exp(self.log_r) + 1 68 | 69 | def forward(self, x): 70 | if self.output == "mean": 71 | return self.activ_layer(x) 72 | elif self.output == "prob": 73 | p = self.activ_layer(x) 74 | mean = self.r * (1 - p) / p 75 | return mean 76 | 77 | # Convert to m from scaled variables 78 | def rescale(self, M, mhat): 79 | return M * mhat 80 | 81 | def loss(self, x, M, mhat, gene_id=None): 82 | # k has shape (nbatch,dim), M has shape (nbatch,1) 83 | # mhat has dim (nbatch,dim) 84 | # r has dim (1,dim) 85 | if gene_id is not None: 86 | loss = -logNBdensity(x, self.rescale(M, mhat), self.r[0, gene_id]) 87 | else: 88 | loss = -logNBdensity(x, self.rescale(M, mhat), self.r) 89 | if self.reduction == "none": 90 | return loss 91 | elif self.reduction == "sum": 92 | return loss.sum() 93 | 94 | # The logprob of the tensor 95 | def logprob(self, x, M, mhat): 96 | return logNBdensity(x, self.rescale(M, mhat), self.r) 97 | 98 | def sample(self, nsample, M, mhat): 99 | # Note that torch.distributions.NegativeBinomial returns FLOAT and not int 100 | with torch.no_grad(): 101 | m = self.rescale(M, mhat) 102 | # probs = m/(m+torch.exp(self.log_r)) 103 | probs = self.r / (m + self.r) 104 | nb = torch.distributions.NegativeBinomial(self.r, probs=probs) 105 | return nb.sample([nsample]).squeeze() 106 | -------------------------------------------------------------------------------- /scDGD/classes/prior.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributions as D 5 | import numpy as np 6 | import json 7 | 8 | 9 | class gaussian: 10 | """ 11 | This is a simple Gaussian prior used for initializing mixture model means 12 | """ 13 | 14 | def __init__(self, dim, mean, stddev): 15 | self.dim = dim 16 | self.mean = 
mean 17 | self.stddev = stddev 18 | self.g = torch.distributions.normal.Normal(mean, stddev) 19 | 20 | def sample(self, n): 21 | return self.g.sample((n, self.dim)) 22 | 23 | def log_prob(self, x): 24 | return self.g.log_prob(x) 25 | 26 | 27 | class softball: 28 | """ 29 | Almost uniform prior for the m-dimensional ball. 30 | Logistic function makes a soft (differentiable) boundary. 31 | Returns a prior function and a sample function. 32 | The prior takes a tensor with a batch of z 33 | vectors (last dim) and returns a tensor of prior log-probabilities. 34 | The sample function returns n samples from the prior (approximate 35 | samples uniform from the m-ball). NOTE: APPROXIMATE SAMPLING. 36 | """ 37 | 38 | def __init__(self, dim, radius, a=1): 39 | self.dim = dim 40 | self.radius = radius 41 | self.a = a 42 | self.norm = math.lgamma(1 + dim * 0.5) - dim * ( 43 | math.log(radius) + 0.5 * math.log(math.pi) 44 | ) 45 | 46 | def sample(self, n): 47 | # Return n random samples 48 | # Approximate: We sample uniformly from n-ball 49 | with torch.no_grad(): 50 | # Gaussian sample 51 | sample = torch.randn((n, self.dim)) 52 | # n random directions 53 | sample.div_(sample.norm(dim=-1, keepdim=True)) 54 | # n random lengths 55 | local_len = self.radius * torch.pow(torch.rand((n, 1)), 1.0 / self.dim) 56 | sample.mul_(local_len.expand(-1, self.dim)) 57 | return sample 58 | 59 | def log_prob(self, z): 60 | # Return log probabilities of elements of tensor (last dim assumed to be z vectors) 61 | return self.norm - torch.log( 62 | 1 + torch.exp(self.a * (z.norm(dim=-1) / self.radius - 1)) 63 | ) 64 | 65 | 66 | class GaussianMixture(nn.Module): 67 | def __init__( 68 | self, 69 | Nmix, 70 | dim, 71 | type="diagonal", 72 | alpha=1, 73 | mean_init=(1.0, 1.0), 74 | sd_init=[1.0, 1.0], 75 | ): 76 | """ 77 | A mixture of multi-variate Gaussians 78 | 79 | Nmix is the number of components in the mixture 80 | dim is the dimension of the space 81 | type can be "fixed", "isotropic" or "diagonal", which refers to the covariance matrices 82 | mean_prior is a prior class with a log_prob and sample function 83 | - Standard normal if not specified. 
84 | - Other option is ('softball',,) 85 | If there is no mean_prior specified, a default Gaussian will be chosen with 86 | - mean_init[0] as mean and mean_init[1] as standard deviation 87 | logbeta_prior is a prior class for the negative log variance of the mixture components 88 | - logbeta = log (1/sigma^2) 89 | - If it is not specified, we make this prior a Gaussian from sd_init parameters 90 | - For the sake of interpretability, the sd_init parameters represent the desired mean and (approximately) sd of the standard deviation 91 | - the difference btw giving a prior beforehand and giving only init values is that with a given prior, the logbetas will be sampled from it, otherwise they will be initialized the same 92 | alpha determines the Dirichlet prior on mixture coefficients 93 | Mixture coefficients are initialized uniformly 94 | Other parameters are sampled from prior 95 | """ 96 | super(GaussianMixture, self).__init__() 97 | self.dim = dim 98 | self.Nmix = Nmix 99 | # self.init = init 100 | 101 | # Means with shape: Nmix,dim 102 | self.mean = nn.Parameter(torch.empty(Nmix, dim), requires_grad=True) 103 | self.mean_prior = softball(self.dim, mean_init[0], mean_init[1]) 104 | 105 | # Dirichlet prior on mixture 106 | self.alpha = alpha 107 | self.dirichlet_constant = math.lgamma(Nmix * alpha) - Nmix * math.lgamma(alpha) 108 | 109 | # Log inverse variance with shape (Nmix,dim) or (Nmix,1) 110 | self.sd_init = sd_init 111 | self.sd_init[0] = 0.2 * (mean_init[0] / self.Nmix) 112 | self.betafactor = dim * 0.5 # rename this! 113 | self.bdim = 1 # If 'diagonal' the dimension of lobbeta is = dim 114 | if type == "fixed": 115 | # No gradient needed for training 116 | # This is a column vector to be correctly broadcastet in std dev tensor 117 | self.logbeta = nn.Parameter( 118 | torch.empty(Nmix, self.bdim), requires_grad=False 119 | ) 120 | self.logbeta_prior = None 121 | else: 122 | if type == "diagonal": 123 | self.betafactor = 0.5 124 | self.bdim = dim 125 | elif type != "isotropic": 126 | raise ValueError( 127 | "type must be 'isotropic' (default), 'diagonal', or 'fixed'" 128 | ) 129 | 130 | self.logbeta = nn.Parameter( 131 | torch.empty(Nmix, self.bdim), requires_grad=True 132 | ) 133 | 134 | # Mixture coefficients. 
These are weights for softmax 135 | self.weight = nn.Parameter(torch.empty(Nmix), requires_grad=True) 136 | self.init_params() 137 | 138 | # -dim*0.5*log(2pi) 139 | self.pi_term = -0.5 * self.dim * math.log(2 * math.pi) 140 | 141 | def init_params(self): 142 | with torch.no_grad(): 143 | # Means are sampled from the prior 144 | self.mean.copy_(self.mean_prior.sample(self.Nmix)) 145 | self.logbeta.fill_(-2 * math.log(self.sd_init[0])) 146 | self.logbeta_prior = gaussian( 147 | self.bdim, -2 * math.log(self.sd_init[0]), self.sd_init[1] 148 | ) 149 | 150 | # Weights are initialized to 1, corresponding to uniform mixture coeffs 151 | self.weight.fill_(1) 152 | 153 | def forward(self, x, label=None): 154 | # The beta values are obtained from logbeta 155 | halfbeta = 0.5 * torch.exp(self.logbeta) 156 | 157 | # y = logp = - 0.5*log (2pi) -0.5*beta(x-mean[i])^2 + 0.5*log(beta) 158 | # sum terms for each component (sum is over last dimension) 159 | # y is one-dim with length Nmix 160 | # x is unsqueezed to (nsample,1,dim), so broadcasting of mean (Nmix,dim) works 161 | y = ( 162 | self.pi_term 163 | - (x.unsqueeze(-2) - self.mean).square().mul(halfbeta).sum(-1) 164 | + self.betafactor * self.logbeta.sum(-1) 165 | ) 166 | # For each component multiply by mixture probs 167 | y += torch.log_softmax(self.weight, dim=0) 168 | y = torch.logsumexp(y, dim=-1) 169 | y = y + self.prior() # += gives cuda error 170 | 171 | return y 172 | 173 | def log_prob(self, x): # Add label? 174 | self.forward(x) 175 | 176 | def mixture_probs(self): 177 | return torch.softmax(self.weight, dim=-1) 178 | 179 | def covariance(self): 180 | return torch.exp(-self.logbeta) 181 | 182 | def prior(self): 183 | """Calculate log prob of prior on mean, logbeta, and mixture coeff""" 184 | # Mixture 185 | p = self.dirichlet_constant # /self.Nmix 186 | if self.alpha != 1: 187 | p = p + (self.alpha - 1.0) * ( 188 | self.mixture_probs().log().sum() 189 | ) # /self.Nmix 190 | # Means 191 | p = p + self.mean_prior.log_prob(self.mean).sum() # /self.Nmix 192 | # logbeta 193 | if self.logbeta_prior is not None: 194 | p = ( 195 | p + self.logbeta_prior.log_prob(self.logbeta).sum() 196 | ) # /(self.Nmix*self.dim) 197 | return p 198 | 199 | def Distribution(self): 200 | with torch.no_grad(): 201 | mix = D.Categorical(probs=torch.softmax(self.weight, dim=-1)) 202 | comp = D.Independent(D.Normal(self.mean, torch.exp(-0.5 * self.logbeta)), 1) 203 | return D.MixtureSameFamily(mix, comp) 204 | 205 | def sample(self, nsample): 206 | with torch.no_grad(): 207 | gmm = self.Distribution() 208 | return gmm.sample(torch.tensor([nsample])) 209 | 210 | def component_sample(self, nsample): 211 | """Returns a sample from each component. 
Tensor shape (nsample,nmix,dim)""" 212 | with torch.no_grad(): 213 | comp = D.Independent(D.Normal(self.mean, torch.exp(-0.5 * self.logbeta)), 1) 214 | return comp.sample(torch.tensor([nsample])) 215 | 216 | def sample_probs(self, x): 217 | halfbeta = 0.5 * torch.exp(self.logbeta) 218 | y = ( 219 | self.pi_term 220 | - (x.unsqueeze(-2) - self.mean).square().mul(halfbeta).sum(-1) 221 | + self.betafactor * self.logbeta.sum(-1) 222 | ) 223 | y += torch.log_softmax(self.weight, dim=0) 224 | return torch.exp(y) 225 | 226 | """This section is for learning new data points""" 227 | 228 | def sample_new_points(self, n_points, option="random", n_new=1): 229 | """ 230 | Generates samples for each new data point 231 | - n_points defines the number of new data points to learn 232 | - option defines which of 2 schemes to use 233 | - random: sample n_new vectors from each component 234 | -> Nmix * n_new values per new point 235 | - mean: take the mean of each component as initial representation values 236 | -> Nmix values per new point 237 | The order of repetition in both options is [a,a,a, b,b,b, c,c,c] on data point ID. 238 | """ 239 | self.new_samples = n_new 240 | multiplier = self.Nmix 241 | if option == "random": 242 | out = self.component_sample(n_points * n_new) 243 | multiplier *= n_new 244 | 245 | elif option == "mean": 246 | with torch.no_grad(): 247 | out = torch.repeat_interleave( 248 | self.mean.clone().cpu().detach().unsqueeze(0), n_points, dim=0 249 | ) 250 | else: 251 | print( 252 | "Please specify how to initialize new representations correctly \nThe options are 'random' and 'mean'." 253 | ) 254 | return out.view(n_points * self.new_samples * self.Nmix, self.dim) 255 | 256 | def reshape_targets(self, y, y_type="true"): 257 | """ 258 | Since we have multiple representations for the same new data point, 259 | we need to reshape the output a little to calculate the losses 260 | Depending on the y_type, y can be 261 | - the true targets (y_type: 'true') 262 | - the model predictions (y_type: 'predicted') (can also be used for rep.z in dataloader loop) 263 | - the 4-dimensional representation or the loss (y_type: 'reverse') 264 | """ 265 | 266 | if y_type == "true": 267 | if len(y.shape) > 2: 268 | raise ValueError( 269 | "Unexpected shape in input to function reshape_targets. Expected 2 dimensions, got " 270 | + str(len(y.shape)) 271 | ) 272 | return ( 273 | y.unsqueeze(1).unsqueeze(1).expand(-1, self.new_samples, self.Nmix, -1) 274 | ) 275 | elif y_type == "predicted": 276 | if len(y.shape) > 2: 277 | raise ValueError( 278 | "Unexpected shape in input to function reshape_targets. Expected 2 dimensions, got " 279 | + str(len(y.shape)) 280 | ) 281 | n_points = int( 282 | torch.numel(y) / (self.new_samples * self.Nmix * y.shape[-1]) 283 | ) 284 | return y.view(n_points, self.new_samples, self.Nmix, y.shape[-1]) 285 | elif "reverse": 286 | if len(y.shape) < 4: 287 | # this case is for when the losses are of shape (n_points,self.new_samples,self.Nmix) 288 | return y.view(y.shape[0] * self.new_samples * self.Nmix) 289 | else: 290 | return y.view(y.shape[0] * self.new_samples * self.Nmix, y.shape[-1]) 291 | else: 292 | raise ValueError( 293 | "The y_type in function reshape_targets was incorrect. Please choose between 'true' and 'predicted'." 
294 | ) 295 | 296 | def choose_best_representations(self, x, losses): 297 | """ 298 | Selects the representation for each new datapoint that maximizes the objective 299 | - x are the newly learned representations 300 | - x and losses have to have the same shape in the first dimension 301 | make sure that the losses are only summed over the output dimension 302 | Outputs new representation values 303 | """ 304 | n_points = int(torch.numel(losses) / (self.new_samples * self.Nmix)) 305 | 306 | best_sample = torch.argmin( 307 | losses.view(-1, self.new_samples * self.Nmix), dim=1 308 | ).squeeze(-1) 309 | best_rep = x.view(n_points, self.new_samples * self.Nmix, self.dim)[ 310 | range(n_points), best_sample 311 | ] 312 | # best_rep = torch.diagonal(x.view(n_points,self.new_samples*self.Nmix,self.dim)[:,best_sample],dim1=0,dim2=1).transpose(0,1) 313 | 314 | return best_rep 315 | 316 | def choose_old_or_new(self, z_new, loss_new, z_old, loss_old): 317 | if (len(z_new.shape) == 2) and (len(z_old.shape) == 2): 318 | z_conc = torch.cat((z_new.unsqueeze(1), z_old.unsqueeze(1)), dim=1) 319 | else: 320 | raise ValueError( 321 | "Unexpected shape in input to function choose_old_or_new. Expected 2 dimensions for z_new and z_old, got " 322 | + str(len(z_new.shape)) 323 | + " and " 324 | + str(len(z_old.shape)) 325 | ) 326 | 327 | len_loss_new = len(loss_new.shape) 328 | for l in range(3 - len_loss_new): 329 | loss_new = loss_new.unsqueeze(1) 330 | len_loss_old = len(loss_old.shape) 331 | for l in range(3 - len_loss_old): 332 | loss_old = loss_old.unsqueeze(1) 333 | losses = torch.cat((loss_new, loss_old), dim=1) 334 | 335 | best_sample = torch.argmin(losses, dim=1).squeeze(-1) 336 | # print(str(best_sample.sum().item())+' out of '+str(z_new.shape[0])+' samples were resampled.') 337 | 338 | best_rep = z_conc[range(z_conc.shape[0]), best_sample] 339 | 340 | return best_rep, round((best_sample.sum().item() / z_new.shape[0]) * 100, 2) 341 | 342 | def clustering(self, z): 343 | """compute the cluster assignment (as int) for each sample""" 344 | return torch.argmax(self.sample_probs(torch.tensor(z)), dim=-1).to(torch.int16) 345 | 346 | @classmethod 347 | def load(cls, save_dir="./"): 348 | # get saved hyper-parameters 349 | with open(save_dir + "dgd_hyperparameters.json", "r") as fp: 350 | param_dict = json.load(fp) 351 | 352 | gmm = cls( 353 | Nmix=param_dict["n_components"], 354 | dim=param_dict["latent"], 355 | type="diagonal", 356 | alpha=param_dict["dirichlet_a"], 357 | mean_init=(param_dict["mp_scale"], param_dict["hardness"]), 358 | sd_init=[param_dict["sd_mean"], param_dict["sd_sd"]], 359 | ) 360 | 361 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 362 | gmm.load_state_dict( 363 | torch.load(save_dir + param_dict["name"] + "_gmm.pt", map_location=device) 364 | ) 365 | return gmm 366 | 367 | 368 | class GaussianMixtureSupervised(GaussianMixture): 369 | def __init__( 370 | self, 371 | Nclass, 372 | dim, 373 | Ncpc=1, 374 | type="diagonal", 375 | alpha=1, 376 | mean_init=(1.0, 1.0), 377 | sd_init=(1.0, 1.0), 378 | ): 379 | super(GaussianMixtureSupervised, self).__init__( 380 | Ncpc * Nclass, 381 | dim, 382 | type=type, 383 | alpha=alpha, 384 | mean_init=mean_init, 385 | sd_init=sd_init, 386 | ) 387 | self.dim = dim 388 | 389 | self.Nclass = Nclass 390 | self.Ncpc = Ncpc 391 | 392 | def forward(self, x, label=None): 393 | if label is None: 394 | y = super().forward(x) 395 | return y 396 | 397 | if 999 in label: 398 | # first get normal loss 399 | idx_unsup = [i for i in 
range(len(label)) if label[i] == 999] 400 | y_unsup = super().forward(x[idx_unsup]) 401 | # Otherwise use the component corresponding to the label 402 | idx_sup = [i for i in range(len(label)) if label[i] != 999] 403 | halfbeta = 0.5 * torch.exp(self.logbeta) 404 | # Pick only the Nclc components belonging class 405 | # y_sup = self.pi_term - (x.unsqueeze(-2).unsqueeze(-2) - self.mean.view(self.Nclass,self.Ncpc,-1)).square().mul(halfbeta.unsqueeze(-2))[label[idx_sup]].sum(-1).sum(-1) + self.betafactor*self.logbeta.view(self.Nclass,self.Ncpc,-1).sum(-1).sum(-1) 406 | y_sup = ( 407 | self.pi_term 408 | - ( 409 | x.unsqueeze(-2).unsqueeze(-2) 410 | - self.mean.view(self.Nclass, self.Ncpc, -1) 411 | ) 412 | .square() 413 | .mul(halfbeta.unsqueeze(-2)) 414 | .sum(-1) 415 | + (self.betafactor * self.logbeta.view(self.Nclass, self.Ncpc, -1)).sum( 416 | -1 417 | ) 418 | ) 419 | y_sup += torch.log_softmax(self.weight.view(self.Nclass, self.Ncpc), dim=-1) 420 | y_sup = y_sup.sum(-1) 421 | y_sup = torch.abs( 422 | y_sup[(idx_sup, label[idx_sup])] * self.Nclass 423 | ) # this is replacement for logsumexp of supervised samples 424 | # put together y 425 | y = torch.empty((x.shape[0]), dtype=torch.float32).to(x.device) 426 | y[idx_unsup] = y_unsup 427 | y[idx_sup] = y_sup 428 | else: 429 | halfbeta = 0.5 * torch.exp(self.logbeta) 430 | # Pick only the Nclc components belonging class 431 | y = ( 432 | self.pi_term 433 | - ( 434 | x.unsqueeze(-2).unsqueeze(-2) 435 | - self.mean.view(self.Nclass, self.Ncpc, -1) 436 | ) 437 | .square() 438 | .mul(halfbeta.unsqueeze(-2)) 439 | .sum(-1) 440 | + (self.betafactor * self.logbeta.view(self.Nclass, self.Ncpc, -1)).sum( 441 | -1 442 | ) 443 | ) 444 | y += torch.log_softmax(self.weight.view(self.Nclass, self.Ncpc), dim=-1) 445 | y = y.sum(-1) 446 | # y = torch.abs(y[(np.arange(y.shape[0]),label)] * self.Nclass) # this is replacement for logsumexp of supervised samples 447 | y = ( 448 | y[(np.arange(y.shape[0]), label)] * self.Nclass 449 | ) # this is replacement for logsumexp of supervised samples 450 | 451 | y = y + super().prior() 452 | return y 453 | 454 | def log_prob(self, x, label=None): 455 | self.forward(x, label=label) 456 | 457 | def label_mixture_probs(self, label): 458 | return torch.softmax(self.weight[label], dim=-1) 459 | -------------------------------------------------------------------------------- /scDGD/classes/representation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributions as D 5 | 6 | class RepresentationLayer(torch.nn.Module): 7 | ''' 8 | Implements a representation layer, that accumulates pytorch gradients. 9 | 10 | Representations are vectors in nrep-dimensional real space. By default 11 | they will be initialized as a tensor of dimension nsample x nrep from a 12 | normal distribution (mean and variance given by init). 13 | 14 | One can also supply a tensor to initialize the representations (values=tensor). 15 | The representations will then have the same dimension and will assumes that 16 | the first dimension is nsample (and the last is nrep). 17 | 18 | forward() takes a sample index and returns the representation. 19 | 20 | Representations are "linear", so a representation layer may be followed 21 | by an activation function. 22 | 23 | To update representations, the pytorch optimizers do not always work well, 24 | so the module comes with it's own SGD update (self.update(lr,mom,...)). 
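
    A minimal sketch of typical use (variable names are placeholders):

        rep = RepresentationLayer(nrep=20, nsample=len(train_set))
        rep_optimizer = torch.optim.Adam(rep.parameters(), lr=1e-2)
        z = rep(batch_indices)   # representations of the current batch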
25 | 26 | If the loss has reduction="sum", things work well. If it is ="mean", the 27 | gradients become very small and the learning rate needs to be rescaled 28 | accordingly (batchsize*output_dim). 29 | 30 | Do not forget to call the zero_grad() before each epoch (not inside the loop 31 | like with the weights). 32 | 33 | ''' 34 | def __init__(self, 35 | nrep, # Dimension of representation 36 | nsample, # Number of training samples 37 | init=(0.,1.),# Normal distribution mean and stddev for 38 | # initializing representations 39 | values=None # If values is given, the other parameters are ignored 40 | ): 41 | super(RepresentationLayer, self).__init__() 42 | self.dz = None 43 | if values is None: 44 | self.nrep=nrep 45 | self.nsample=nsample 46 | self.mean, self.stddev = init[0],init[1] 47 | self.init_random(self.mean,self.stddev) 48 | else: 49 | # Initialize representations from a tensor with values 50 | self.nrep = values.shape[-1] 51 | self.nsample = values.shape[0] 52 | self.mean, self.stddev = None, None 53 | # Is this the way to copy values to a parameter? 54 | self.z = torch.nn.Parameter(torch.zeros_like(values), requires_grad=True) 55 | with torch.no_grad(): 56 | self.z.copy_(values) 57 | 58 | def init_random(self,mean,stddev): 59 | # Generate random representations 60 | self.z = torch.nn.Parameter(torch.normal(mean,stddev,size=(self.nsample,self.nrep), requires_grad=True)) 61 | 62 | def forward(self, idx=None): 63 | if idx is None: 64 | return self.z 65 | else: 66 | return self.z[idx] 67 | 68 | # Index can be whatever it can be for a torch.tensor (e.g. tensor of idxs) 69 | def __getitem__(self,idx): 70 | return self.z[idx] 71 | 72 | def fix(self): 73 | self.z.requires_grad = False 74 | 75 | def unfix(self): 76 | self.z.requires_grad = True 77 | 78 | def zero_grad(self): # Used only if the update function is used 79 | if self.z.grad is not None: 80 | self.z.grad.detach_() 81 | self.z.grad.zero_() 82 | 83 | def update(self,idx=None,lr=0.001,mom=0.9,wd=None): 84 | if self.dz is None: 85 | self.dz = torch.zeros(self.z.size()).to(self.z.device) 86 | with torch.no_grad(): 87 | # Update z 88 | # dz(k,j) = sum_i grad(k,i) w(i,j) step(z(j)) 89 | self.dz[idx] = self.dz[idx].mul(mom) - self.z.grad[idx].mul(lr) 90 | if wd is not None: 91 | self.dz[idx] -= wd*self.z[idx] 92 | self.z[idx] += self.dz[idx] 93 | 94 | def rescale(self): 95 | z_flat = torch.flatten(self.z.cpu().detach()) 96 | sd, m = torch.std_mean(z_flat) 97 | with torch.no_grad(): 98 | self.z -= m 99 | self.z /= sd -------------------------------------------------------------------------------- /scDGD/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from scDGD.functions.data import prepate_data 2 | from scDGD.functions.train import dgd_train -------------------------------------------------------------------------------- /scDGD/functions/analysis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn import preprocessing 4 | import torch 5 | from sklearn.metrics import confusion_matrix 6 | import operator 7 | import sklearn.metrics 8 | import json 9 | from scDGD.classes import GaussianMixture 10 | import seaborn as sns 11 | import matplotlib.pyplot as plt 12 | import umap 13 | 14 | 15 | def order_matrix_by_max_per_class(mtrx, class_labels, comp_order=None): 16 | if comp_order is not None: 17 | temp_mtrx = np.zeros(mtrx.shape) 18 | for i in range(mtrx.shape[1]): 19 | temp_mtrx[:, 
i] = mtrx[:, comp_order[i]] 20 | mtrx = temp_mtrx 21 | max_id_per_class = np.argmax(mtrx, axis=1) 22 | max_coordinates = list(zip(np.arange(mtrx.shape[0]), max_id_per_class)) 23 | max_coordinates.sort(key=operator.itemgetter(1)) 24 | new_class_order = [x[0] for x in max_coordinates] 25 | new_mtrx = np.zeros(mtrx.shape) 26 | # reindexing mtrx worked on test but not in application, reverting to stupid safe for-loop 27 | for i in range(mtrx.shape[0]): 28 | new_mtrx[i, :] = mtrx[new_class_order[i], :] 29 | # mtrx = mtrx[new_class_order,:] 30 | return new_mtrx, [class_labels[i] for i in new_class_order] 31 | 32 | 33 | def gmm_clustering(r, gmm, labels): 34 | # transform categorical labels into numerical 35 | le = preprocessing.LabelEncoder() 36 | le.fit(labels) 37 | true_labels = le.transform(labels) 38 | # compute probabilities per sample and component (n_sample,n_mix_comp) 39 | probs_per_sample_and_component = gmm.sample_probs(torch.tensor(r)) 40 | # get index (i.e. component id) of max prob per sample 41 | cluster_labels = ( 42 | torch.max(probs_per_sample_and_component, dim=-1).indices.cpu().detach() 43 | ) 44 | return cluster_labels 45 | 46 | 47 | def compute_distances(mtrx): 48 | distances = sklearn.metrics.pairwise.euclidean_distances(mtrx) 49 | return distances 50 | 51 | 52 | def get_connectivity_from_threshold(mtrx, threshold): 53 | connectivity_mtrx = np.zeros(mtrx.shape) 54 | idx = np.where(mtrx <= threshold) 55 | connectivity_mtrx[idx[0], idx[1]] = 1 56 | np.fill_diagonal(connectivity_mtrx, 0) 57 | return connectivity_mtrx 58 | 59 | 60 | def rank_distances(mtrx): 61 | ranks = np.argsort(mtrx, axis=-1) 62 | # testing advanced indexing for ranking 63 | m, n = mtrx.shape 64 | # Initialize output array 65 | out = np.empty((m, n), dtype=int) 66 | # Use sidx as column indices, while a range array for the row indices 67 | # to select one element per row. 
Since sidx is a 2D array of indices 68 | # we need to use a 2D extended range array for the row indices 69 | out[np.arange(m)[:, None], ranks] = np.arange(n) 70 | return out 71 | # return ranks 72 | 73 | 74 | def get_node_degrees(mtrx): 75 | return np.sum(mtrx, 1) 76 | 77 | 78 | def get_secondary_degrees(mtrx): 79 | out = np.zeros(mtrx.shape[0]) 80 | for i in range(mtrx.shape[0]): 81 | direct_neighbors = np.where(mtrx[i] == 1)[0] 82 | out[i] = mtrx[direct_neighbors, :].sum() 83 | return out 84 | 85 | 86 | def find_start_node(d1, d2): 87 | minimum_first_degree = np.where(d1 == d1.min())[0] 88 | if len(minimum_first_degree) > 1: 89 | minimum_second_degree_subset = np.where( 90 | d2[minimum_first_degree] == d2[minimum_first_degree].min() 91 | )[0][0] 92 | return minimum_first_degree[minimum_second_degree_subset] 93 | else: 94 | return minimum_first_degree[0] 95 | 96 | 97 | def find_next_node(c, r, i): 98 | connected_nodes = np.where(c[i, :] == 1)[0] 99 | if len(connected_nodes) > 0: 100 | connected_nodes_paired = list(zip(connected_nodes, r[i, connected_nodes])) 101 | connected_nodes_paired.sort(key=operator.itemgetter(1)) 102 | connected_nodes = [ 103 | connected_nodes_paired[x][0] for x in range(len(connected_nodes)) 104 | ] 105 | return connected_nodes 106 | 107 | 108 | def traverse_through_graph(connectiv_mtrx, ranks, first_degrees, second_degrees): 109 | # create 2 lists of node ids 110 | # the first one keeps track of the nodes we have used already (and stores them in the desired order) 111 | # the second one keeps track of the nodes we still have to sort 112 | node_order = [] 113 | nodes_to_be_distributed = list(np.arange(connectiv_mtrx.shape[0])) 114 | 115 | start_node = find_start_node(first_degrees, second_degrees) 116 | node_order.append(start_node) 117 | nodes_to_be_distributed.remove(start_node) 118 | 119 | count_turns = 0 120 | while len(nodes_to_be_distributed) > 0: 121 | next_nodes = find_next_node(connectiv_mtrx, ranks, node_order[-1]) 122 | next_nodes = list(set(next_nodes).difference(set(node_order))) 123 | if len(next_nodes) < 1: 124 | next_nodes = [ 125 | nodes_to_be_distributed[ 126 | find_start_node( 127 | first_degrees[nodes_to_be_distributed], 128 | second_degrees[nodes_to_be_distributed], 129 | ) 130 | ] 131 | ] 132 | for n in next_nodes: 133 | if n not in node_order: 134 | node_order.append(n) 135 | nodes_to_be_distributed.remove(n) 136 | count_turns = 0 137 | count_turns += 1 138 | if count_turns >= 10: 139 | break 140 | 141 | return node_order 142 | 143 | 144 | def order_components_as_graph_traversal(gmm): 145 | distance_mtrx = compute_distances(gmm.mean.detach().cpu().numpy()) 146 | threshold = round(np.percentile(distance_mtrx.flatten(), 30), 2) 147 | 148 | connectivity_mtrx = get_connectivity_from_threshold(distance_mtrx, threshold) 149 | rank_mtrx = rank_distances(distance_mtrx) 150 | 151 | node_degrees = get_node_degrees(connectivity_mtrx) 152 | secondary_node_degrees = get_secondary_degrees(connectivity_mtrx) 153 | 154 | new_node_order = traverse_through_graph( 155 | connectivity_mtrx, rank_mtrx, node_degrees, secondary_node_degrees 156 | ) 157 | return new_node_order 158 | 159 | 160 | def clustering_matrix(gmm, rep, labels, norm=True): 161 | classes = list(np.unique(labels)) 162 | true_labels = np.asarray([classes.index(i) for i in labels]) 163 | cluster_labels = gmm_clustering(rep, gmm, labels) 164 | # get absolute confusion matrix 165 | cm1 = confusion_matrix(true_labels, cluster_labels) 166 | 167 | class_counts = [np.where(true_labels == i)[0].shape[0] 
for i in range(len(classes))] 168 | cm2 = cm1.astype(np.float64) 169 | for i in range(len(class_counts)): 170 | # percent_sum = 0 171 | for j in range(gmm.Nmix): 172 | if norm: 173 | cm2[i, j] = cm2[i, j] * 100 / class_counts[i] 174 | else: 175 | cm2[i, j] = cm2[i, j] 176 | cm2 = cm2.round() 177 | 178 | # get an order of components based on connectivity graph 179 | component_order = order_components_as_graph_traversal(gmm) 180 | 181 | # take the non-empty entries 182 | cm2 = cm2[: len(classes), : gmm.Nmix] 183 | 184 | cm3, classes_reordered = order_matrix_by_max_per_class( 185 | cm2, classes, component_order 186 | ) 187 | out = pd.DataFrame(data=cm3, index=classes_reordered, columns=component_order) 188 | return out 189 | 190 | 191 | def load_embedding(save_dir="./"): 192 | # get parameter dict 193 | with open(save_dir + "dgd_hyperparameters.json", "r") as fp: 194 | param_dict = json.load(fp) 195 | 196 | embedding = torch.load( 197 | save_dir + param_dict["name"] + "_representation.pt", 198 | map_location=torch.device("cpu"), 199 | ) 200 | return embedding["z"].detach().cpu().numpy() 201 | 202 | 203 | def load_labels(save_dir="./", split="train"): 204 | # load train-val-test split 205 | obs = pd.read_csv(save_dir + "data_obs.csv", index_col=0) 206 | return obs[obs["train_val_test"] == split]["cell_type"].values 207 | 208 | 209 | def plot_cluster_heatmap(save_dir="./"): 210 | gmm = GaussianMixture.load(save_dir=save_dir) 211 | embedding = load_embedding(save_dir=save_dir) 212 | labels = load_labels() 213 | 214 | df_relative_clustering = clustering_matrix(gmm, embedding, labels) 215 | df_clustering = clustering_matrix(gmm, embedding, labels, norm=False) 216 | 217 | annotations = df_relative_clustering.to_numpy(dtype=np.float64).copy() 218 | annotations[annotations < 1] = None 219 | df_relative_clustering = df_relative_clustering.fillna(0) 220 | fig, ax = plt.subplots(figsize=(0.5*annotations.shape[0], 0.5*annotations.shape[1])) 221 | cmap = sns.color_palette("GnBu", as_cmap=True) 222 | sns.heatmap( 223 | df_relative_clustering, 224 | annot=annotations, 225 | cmap=cmap, 226 | annot_kws={"size": 6}, 227 | cbar_kws={"shrink": 0.5, "location": "bottom"}, 228 | xticklabels=True, 229 | yticklabels=True, 230 | mask=np.isnan(annotations), 231 | alpha=0.8, 232 | ) 233 | ylabels = [ 234 | df_clustering.index[x] + " (" + str(int(df_clustering.sum(axis=1)[x])) + ")" 235 | for x in range(df_clustering.shape[0]) 236 | ] 237 | plt.yticks( 238 | ticks=np.arange(len(ylabels)) + 0.5, labels=ylabels, rotation=0, fontsize=8 239 | ) 240 | plt.tick_params(axis="x", rotation=0, labelsize=8) 241 | plt.ylabel("Cell type") 242 | plt.xlabel("GMM component ID") 243 | plt.title("percentage of cell type in GMM cluster") 244 | plt.show() 245 | 246 | 247 | def plot_latent_umap(save_dir="./", n_neighbors=15, min_dist=0.5): 248 | gmm = GaussianMixture.load(save_dir=save_dir) 249 | embedding = load_embedding(save_dir=save_dir) 250 | labels = load_labels() 251 | 252 | # make umap 253 | reducer = umap.UMAP(n_neighbors=n_neighbors, n_components=2, min_dist=min_dist) 254 | projected = reducer.fit_transform(embedding) 255 | plot_data = pd.DataFrame(projected, columns=["UMAP1", "UMAP2"]) 256 | plot_data["cell type"] = labels 257 | plot_data["cell type"] = plot_data["cell type"].astype("category") 258 | plot_data["cluster"] = ( 259 | gmm.clustering(embedding).cpu().detach().numpy() 260 | ) # .astype(str) 261 | plot_data["cluster"] = plot_data["cluster"].astype("category") 262 | 263 | # make a plot with two subplots 264 | fig, (ax1, 
ax2) = plt.subplots(1, 2, figsize=(12, 4)) 265 | # adjust spacing between plots 266 | fig.subplots_adjust(wspace=1.0) 267 | # make text size smaller 268 | plt.rcParams.update({"font.size": 6}) 269 | sns.scatterplot(data=plot_data, x="UMAP1", y="UMAP2", hue="cell type", ax=ax1, s=1) 270 | sns.scatterplot(data=plot_data, x="UMAP1", y="UMAP2", hue="cluster", ax=ax2, s=1) 271 | ax1.set_title("cell type") 272 | ax1.legend(bbox_to_anchor=(1.02, 1), loc=2, borderaxespad=0.0, frameon=False) 273 | ax2.legend(bbox_to_anchor=(1.02, 1), loc=2, borderaxespad=0.0, frameon=False) 274 | ax2.set_title("cluster") 275 | plt.show() 276 | -------------------------------------------------------------------------------- /scDGD/functions/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.model_selection import train_test_split 4 | from scDGD.classes import scDataset 5 | 6 | def prepate_data(adata, label_column, train_fraction=0.8, include_test=True, scaling_type='max', batch_size=256, num_w=0): 7 | ''' 8 | Prepares the pytorch data sets and loaders for training and testing 9 | 10 | For integrating a new data set, set the train_fraction to 1. Otherwise there should always be something left for validation. 11 | If include_test is True, the split will also include a held-out test set. Otherwise, it will only be train and validation. 12 | ''' 13 | 14 | ### 15 | # first create a data split 16 | ### 17 | labels = adata.obs[label_column] 18 | 19 | train_mode = True 20 | if 'train_val_test' not in adata.obs.keys(): 21 | if train_fraction < 1.0: 22 | if include_test: 23 | train_indices, test_indices = train_test_split(np.arange(len(labels)), test_size=(1.0-train_fraction)/2, stratify=labels) 24 | train_indices, val_indices = train_test_split(train_indices, test_size=(((1.0-train_fraction)/2)/(1.0-(1.0-train_fraction)/2)), stratify=labels[train_indices]) 25 | # add the split to the anndata object 26 | train_val_test = [''] * len(labels) 27 | train_val_test = ['train' if i in train_indices else train_val_test[i] for i in range(len(labels))] 28 | train_val_test = ['validation' if i in val_indices else train_val_test[i] for i in range(len(labels))] 29 | train_val_test = ['test' if i in test_indices else train_val_test[i] for i in range(len(labels))] 30 | else: 31 | train_indices, val_indices = train_test_split(np.arange(len(labels)), test_size=(1.0-train_fraction), stratify=labels) 32 | train_val_test = [''] * len(labels) 33 | train_val_test = ['train' if i in train_indices else train_val_test[i] for i in range(len(labels))] 34 | train_val_test = ['validation' if i in val_indices else train_val_test[i] for i in range(len(labels))] 35 | else: 36 | train_mode = False 37 | train_val_test = 'test' 38 | adata.obs['train_val_test'] = train_val_test 39 | else: 40 | if len(set(adata.obs['train_val_test'])) == 1: 41 | train_mode = False 42 | adata.obs['label'] = labels 43 | # make sure to afterwards also return the adata object so that the data split can be re-used 44 | 45 | ### 46 | # then create the data sets and loaders 47 | ### 48 | 49 | if train_mode: 50 | trainset = scDataset( 51 | adata.X, 52 | adata.obs, 53 | scaling_type=scaling_type, 54 | subset=np.where(adata.obs['train_val_test']=='train')[0], 55 | label_type='label' 56 | ) 57 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_w) 58 | validationset = scDataset( 59 | adata.X, 60 | adata.obs, 61 | scaling_type=scaling_type, 
62 |             subset=np.where(adata.obs['train_val_test']=='validation')[0],
63 |             label_type='label'
64 |         )
65 |         validationloader = torch.utils.data.DataLoader(validationset, batch_size=batch_size, shuffle=True, num_workers=num_w)
66 |         if len(set(adata.obs['train_val_test'])) == 3:
67 |             testset = scDataset(
68 |                 adata.X,
69 |                 adata.obs,
70 |                 scaling_type=scaling_type,
71 |                 subset=np.where(adata.obs['train_val_test']=='test')[0],
72 |                 label_type='label'
73 |             )
74 |             testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=num_w)
75 |         else:
76 |             testset, testloader = None, None
77 |     else:
78 |         testset = scDataset(
79 |             adata.X,
80 |             adata.obs,
81 |             scaling_type=scaling_type,
82 |             subset=np.where(adata.obs['train_val_test']=='test')[0],
83 |             label_type='label'
84 |         )
85 |         testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=num_w)
86 | 
87 |     return adata, trainloader, validationloader, testloader
88 | 
89 | 
90 | 
91 | 
92 | 
--------------------------------------------------------------------------------
/scDGD/functions/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import pandas as pd
5 | 
6 | from sklearn.metrics import confusion_matrix
7 | from scipy.optimize import linear_sum_assignment as linear_assignment
8 | from sklearn import preprocessing
9 | 
10 | from scDGD.classes.representation import RepresentationLayer
11 | 
12 | def set_random_seed(seed):
13 |     torch.manual_seed(seed)
14 |     np.random.seed(seed)
15 |     if torch.cuda.is_available():
16 |         torch.cuda.manual_seed_all(seed)
17 | 
18 | def reshape_scaling_factor(x, out_dim):
19 |     start_dim = len(x.shape)
20 |     for i in range(out_dim - start_dim):
21 |         x = x.unsqueeze(1)
22 |     return x
23 | 
24 | def _make_cost_m(cm):
25 |     s = np.max(cm)
26 |     return (- cm + s)
27 | 
28 | def gmm_clustering(gmm, rep):
29 |     halfbeta = 0.5*torch.exp(gmm.logbeta)
30 |     y = gmm.pi_term - (rep.unsqueeze(-2)-gmm.mean).square().mul(halfbeta).sum(-1) + gmm.betafactor*gmm.logbeta.sum(-1)
31 |     # For each component multiply by mixture probs
32 |     y += torch.log_softmax(gmm.weight,dim=0)
33 |     return torch.exp(y)
34 | 
35 | def gmm_cluster_acc(r, gmm, labels):
36 |     le = preprocessing.LabelEncoder()
37 |     le.fit(labels)
38 |     true_labels = le.transform(labels)
39 |     clustering = gmm_clustering(gmm, r.z.detach())
40 |     cluster_labels = torch.max(clustering, dim=-1).indices.cpu().detach()
41 |     cm = confusion_matrix(true_labels, cluster_labels)
42 |     indexes = linear_assignment(_make_cost_m(cm))
43 |     cm2 = cm[:,indexes[1]]
44 |     acc2 = np.trace(cm2) / np.sum(cm2)
45 |     return acc2
46 | 
47 | def dgd_train(model, gmm, train_loader, validation_loader, n_epochs=500,
48 |               export_dir='./', export_name='scDGD',
49 |               lr_schedule_epochs=[0,300],lr_schedule=[[1e-3,1e-2,1e-2],[1e-4,1e-2,1e-2]], optim_betas=[0.5,0.7], wd=1e-4,
50 |               acc_save_threshold=0.5,supervision_labels=None, wandb_logging=False):
51 | 
52 |     if wandb_logging:
53 |         import wandb
54 | 
55 |     # prepare for saving the model
56 |     if export_name is not None:
57 |         if not os.path.exists(export_dir+export_name):
58 |             os.makedirs(export_dir+export_name)
59 | 
60 |     # get some info from the data
61 |     nsample = len(train_loader.dataset)
62 |     nsample_test = len(validation_loader.dataset)
63 |     out_dim = train_loader.dataset.n_genes
64 |     latent = gmm.dim
65 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
66 | 
67 |     model=model.to(device)
68 |     gmm=gmm.to(device)
69 | 
70 |     ###
71 |     # set up representations and optimizers
72 |     ###
73 | 
74 |     if lr_schedule_epochs is None:
75 |         lr = lr_schedule[0]
76 |         lr_rep = lr_schedule[1]
77 |         lr_gmm = lr_schedule[2]
78 |     else:
79 |         lr = lr_schedule[0][0]
80 |         lr_rep = lr_schedule[0][1]
81 |         lr_gmm = lr_schedule[0][2]
82 | 
83 |     rep = RepresentationLayer(nrep=latent,nsample=nsample,values=torch.zeros(size=(nsample,latent))).to(device)
84 |     test_rep = RepresentationLayer(nrep=latent,nsample=nsample_test,values=torch.zeros(size=(nsample_test,latent))).to(device)
85 |     rep_optimizer = torch.optim.Adam(rep.parameters(), lr=lr_rep, weight_decay=wd,betas=(optim_betas[0],optim_betas[1]))
86 |     testrep_optimizer = torch.optim.Adam(test_rep.parameters(), lr=lr_rep, weight_decay=wd,betas=(optim_betas[0],optim_betas[1]))
87 | 
88 |     if gmm is not None:
89 |         gmm_optimizer = torch.optim.Adam(gmm.parameters(), lr=lr_gmm, weight_decay=wd,betas=(optim_betas[0],optim_betas[1]))
90 |     model_optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd,betas=(optim_betas[0],optim_betas[1]))
91 | 
92 |     ###
93 |     # start training
94 |     ###
95 | 
96 |     # keep track of the losses and other metrics
97 |     train_avg = []
98 |     recon_avg = []
99 |     dist_avg = []
100 |     test_avg = []
101 |     recon_test_avg = []
102 |     dist_test_avg = []
103 |     cluster_accuracies = []
104 |     best_gmm_cluster, best_gmm_epoch = 0, 0
105 | 
106 |     for epoch in range(n_epochs):
107 | 
108 |         # in case there is a scheduled change in learning rates, change them if the specified epoch is reached
109 |         if lr_schedule_epochs is not None:
110 |             if epoch in lr_schedule_epochs:
111 |                 lr_idx = [x for x in range(len(lr_schedule_epochs)) if lr_schedule_epochs[x] == epoch][0]
112 |                 lr_decoder = lr_schedule[lr_idx][0]
113 |                 model_optimizer = torch.optim.Adam(model.parameters(), lr=lr_decoder, weight_decay=wd, betas=(optim_betas[0],optim_betas[1]))
114 |                 lr_rep = lr_schedule[lr_idx][1]
115 |                 lr_gmm = lr_schedule[lr_idx][2]
116 |                 if gmm is not None:
117 |                     gmm_optimizer = torch.optim.Adam(gmm.parameters(), lr=lr_gmm, weight_decay=wd, betas=(optim_betas[0],optim_betas[1]))
118 | 
119 |         # collect losses and other metrics for the epoch
120 |         train_avg.append(0)
121 |         recon_avg.append(0)
122 |         dist_avg.append(0)
123 |         test_avg.append(0)
124 |         recon_test_avg.append(0)
125 |         dist_test_avg.append(0)
126 | 
127 |         # train
128 |         model.train()
129 |         rep_optimizer.zero_grad()
130 |         # standard mini batching
131 |         for x,lib,i in train_loader:
132 |             gmm_optimizer.zero_grad()
133 |             model_optimizer.zero_grad()
134 | 
135 |             x = x.to(device)
136 |             lib = lib.to(device)
137 |             z = rep(i)
138 |             y = model(z)
139 | 
140 |             # compute losses
141 |             recon_loss_x = model.nb.loss(x, lib, y).sum()
142 |             if supervision_labels is not None:
143 |                 sup_i = supervision_labels[i]
144 |                 gmm_error = -gmm(z,sup_i).sum()
145 |             else:
146 |                 gmm_error = -gmm(z).sum()
147 |             loss = recon_loss_x.clone() + gmm_error.clone()
148 | 
149 |             # backpropagate and update
150 |             loss.backward()
151 |             gmm_optimizer.step()
152 |             model_optimizer.step()
153 | 
154 |             # log losses
155 |             train_avg[-1] += loss.item()/(nsample*out_dim)
156 |             recon_avg[-1] += recon_loss_x.item()/(nsample*out_dim)
157 |             dist_avg[-1] += gmm_error.item()/(nsample*latent)
158 | 
159 |         # update representations
160 |         rep_optimizer.step()
161 | 
162 |         # validation run
163 |         model.eval()
164 |         testrep_optimizer.zero_grad()
165 |         # same as above, but without updates for gmm and model
166 |         for x,lib,i in validation_loader:
167 |             x = x.to(device)
168 |             lib = lib.to(device)
169 |             z = test_rep(i)
170 |             y = model(z)
171 |             recon_loss_x = model.nb.loss(x, lib, y).sum()
172 |             gmm_error = -gmm(z).sum()
173 |             loss = recon_loss_x.clone() + gmm_error.clone()
174 |             loss.backward()
175 | 
176 |             test_avg[-1] += loss.item()/(nsample_test*out_dim)
177 |             recon_test_avg[-1] += recon_loss_x.item()/(nsample_test*out_dim)
178 |             dist_test_avg[-1] += gmm_error.item()/(nsample_test*latent)
179 |         testrep_optimizer.step()
180 | 
181 |         cluster_accuracies.append(gmm_cluster_acc(rep, gmm, train_loader.dataset.get_labels()))
182 | 
183 |         save_here = False
184 |         if best_gmm_cluster < cluster_accuracies[-1]:
185 |             best_gmm_cluster = cluster_accuracies[-1]
186 |             best_gmm_epoch = epoch
187 |             if best_gmm_cluster >= acc_save_threshold:
188 |                 save_here = True
189 |         elif epoch == n_epochs-1:
190 |             save_here = True
191 | 
192 |         if wandb_logging:
193 |             wandb.log({"loss_train": train_avg[-1],
194 |                        "loss_test": test_avg[-1],
195 |                        "loss_recon_train": recon_avg[-1],
196 |                        "loss_recon_test": recon_test_avg[-1],
197 |                        "loss_gmm_train": dist_avg[-1],
198 |                        "loss_gmm_test": dist_test_avg[-1],
199 |                        "cluster_accuracy": cluster_accuracies[-1],
200 |                        "epoch": epoch})
201 |         else:
202 |             # print progress every 10 epochs
203 |             if epoch % 10 == 0:
204 |                 print("epoch "+str(epoch)+": train loss "+str(train_avg[-1])+", validation loss "+str(test_avg[-1])+", cluster accuracy "+str(cluster_accuracies[-1]))
205 | 
206 |         if export_name is not None:
207 |             if save_here:
208 |                 if epoch == n_epochs-1:
209 |                     print("model saved at epoch "+str(epoch)+" for having reached the end of training")
210 |                 else:
211 |                     print("model saved at epoch "+str(epoch)+" for having the highest accuracy so far of "+str(cluster_accuracies[-1]))
212 |                 torch.save(model.state_dict(), export_dir+export_name+'/'+export_name+'_decoder.pt')
213 |                 torch.save(rep.state_dict(), export_dir+export_name+'/'+export_name+'_representation.pt')
214 |                 torch.save(test_rep.state_dict(), export_dir+export_name+'/'+export_name+'_valRepresentation.pt')
215 |                 torch.save(gmm.state_dict(), export_dir+export_name+'/'+export_name+'_gmm.pt')
216 | 
217 |     if wandb_logging:
218 |         wandb.run.summary["best_gmm_cluster"] = best_gmm_cluster
219 |         wandb.run.summary["best_gmm_epoch"] = best_gmm_epoch
220 |     # create a history that is returned after training
221 |     history = pd.DataFrame(
222 |         {'train_loss': train_avg,
223 |          'test_loss': test_avg,
224 |          'train_recon_loss': recon_avg,
225 |          'test_recon_loss': recon_test_avg,
226 |          'train_gmm_loss': dist_avg,
227 |          'test_gmm_loss': dist_test_avg,
228 |          'cluster_accuracy': cluster_accuracies,
229 |          'epoch': np.arange(1,n_epochs+1)
230 |          })
231 |     return model, rep, test_rep, gmm, history
--------------------------------------------------------------------------------
/scDGD/models/__init__.py:
--------------------------------------------------------------------------------
1 | from scDGD.models.models import DGD
--------------------------------------------------------------------------------
/scDGD/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from scDGD.classes.output_distributions import NBLayer
4 | import json
5 | 
6 | 
7 | 
8 | class DGD(nn.Module):
9 |     def __init__(
10 |         self,
11 |         out,
12 |         latent=20,
13 |         hidden=[100, 100, 100],
14 |         r_init=2,
15 |         output_prediction_type="mean",
16 |         output_activation="softplus"
17 |     ):
18 |         super(DGD, self).__init__()
19 | 
20 |         self.main = nn.ModuleList()
21 | 
22 |         if type(hidden) is not int:
23 |             n_hidden = len(hidden)
24 |             self.main.append(nn.Linear(latent, hidden[0]))
25 |             self.main.append(nn.ReLU(True))
26 |             for i in range(n_hidden - 1):
27 |                 self.main.append(nn.Linear(hidden[i], hidden[i + 1]))
28 |                 self.main.append(nn.ReLU(True))
29 |             self.main.append(nn.Linear(hidden[-1], out))
30 |         else:
31 |             self.main.append(nn.Linear(latent, hidden))
32 |             self.main.append(nn.ReLU(True))
33 |             self.main.append(nn.Linear(hidden, out))
34 | 
35 |         self.nb = NBLayer(
36 |             out,
37 |             r_init=r_init,
38 |             output=output_prediction_type,
39 |             activation=output_activation
40 |         )
41 | 
42 |     def forward(self, z):
43 |         for i in range(len(self.main)):
44 |             z = self.main[i](z)
45 |         return self.nb(z)
46 | 
47 |     @classmethod
48 |     def load(cls, save_dir="./"):
49 |         # get saved hyper-parameters
50 |         with open(save_dir + "dgd_hyperparameters.json", "r") as fp:
51 |             param_dict = json.load(fp)
52 | 
53 |         model = cls(
54 |             out=param_dict["n_genes"],
55 |             latent=param_dict["latent"],
56 |             hidden=param_dict["hidden"],
57 |             r_init=param_dict["r_init"],
58 |             # keyword names must match the constructor signature above
59 |             output_prediction_type=param_dict["output"],
60 |             output_activation=param_dict["activation"],
61 |         )
62 | 
63 |         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
64 |         model.load_state_dict(
65 |             torch.load(
66 |                 save_dir + param_dict["name"] + "_decoder.pt", map_location=device
67 |             )
68 |         )
69 |         return model
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | #import
3 | setup(name='scDGD',
4 |       version="0.2",
5 |       #description='Classes for internal representations',
6 |       author='Viktoria Schuster',
7 |       #url='https://github.com/viktoriaschuster/DGD_paper_experiments',
8 |       packages=['scDGD','scDGD.classes','scDGD.models','scDGD.functions']
9 |       #packages=find_packages()#,
10 |       #package_dir = {'dgdExp': 'src'}
11 |       )
12 | 
13 | # install by going to the package directory and calling
14 | # python3 setup.py build
15 | # python3 setup.py install --user
16 | # (--user installs into the user site-packages, which is needed when you do not have sudo rights)
--------------------------------------------------------------------------------
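
The functions above fit together roughly as sketched below. This is only a sketch, not code from the repository: the .h5ad file name and the "cell_type" column are placeholders, and the GaussianMixture prior (defined in scDGD/classes/prior.py, which is not shown in this section) is left as a placeholder because its constructor is not visible here. The example notebooks under examples/ remain the reference for a complete run.

# training sketch, under the assumptions stated above
import scanpy as sc

from scDGD.models import DGD
from scDGD.functions.data import prepate_data
from scDGD.functions.train import dgd_train, set_random_seed

set_random_seed(0)
adata = sc.read_h5ad("mousebrain5k.h5ad")  # placeholder file

# build the loaders; the split is written to adata.obs["train_val_test"] so it can be re-used
adata, trainloader, validationloader, testloader = prepate_data(
    adata, label_column="cell_type", train_fraction=0.8, include_test=True
)

# decoder from the 20-dimensional latent space to the gene space
model = DGD(out=trainloader.dataset.n_genes, latent=20, hidden=[100, 100, 100])

gmm = ...  # placeholder: GaussianMixture prior over the latent space (see scDGD/classes/prior.py)

model, rep, test_rep, gmm, history = dgd_train(
    model, gmm, trainloader, validationloader, n_epochs=500,
    export_dir="./", export_name="scDGD_run"
)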
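
A pretrained model can be restored with DGD.load and the latent space inspected with the plotting helpers shown earlier (presumably exported from scDGD/functions/analysis.py). Again a sketch: it assumes that dgd_hyperparameters.json, the saved *_decoder.pt / *_gmm.pt / *_representation.pt files and data_obs.csv all sit in the directory passed as save_dir (in the shipped PBMC model, data_obs.csv lives one level above the model/ folder, so the files may need to be placed together); the load_* notebooks in pretrained_models/ remain the authoritative reference.

# loading sketch, under the assumptions stated above
from scDGD.models import DGD
from scDGD.functions.analysis import plot_cluster_heatmap, plot_latent_umap

save_dir = "path/to/model/"  # placeholder directory containing the saved files listed above

decoder = DGD.load(save_dir=save_dir)    # rebuilds the decoder from the stored hyper-parameters
plot_cluster_heatmap(save_dir=save_dir)  # cell types vs. GMM components
plot_latent_umap(save_dir=save_dir)      # UMAP of the stored training representations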