├── .gitignore
├── DGD_schematic.png
├── LICENSE
├── README.md
├── examples
│   ├── scDGD_training_mousebrain5k.ipynb
│   └── scDGD_training_mousebrain5k_predicting_prob.ipynb
├── pretrained_models
│   ├── PBMC_Zheng2017
│   │   ├── data_obs.csv
│   │   ├── data_var.csv
│   │   ├── load_PBMC_Zheng2017.ipynb
│   │   └── model
│   │       ├── L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_decoder.pt
│   │       ├── L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_gmm.pt
│   │       ├── L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_heldoutTestRepresentation.pt
│   │       ├── L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_representation.pt
│   │       ├── L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_testRepresentation.pt
│   │       └── dgd_hyperparameters.json
│   └── mousebrain_1M
│       ├── data_info.zip
│       ├── load_mousebrain1M.ipynb
│       └── model
│           ├── dgd_hyperparameters.json
│           ├── dgd_mousebrain_1m_a1_rs0_decoder.pt
│           └── dgd_mousebrain_1m_a1_rs0_gmm.pt
├── scDGD
│   ├── __init__.py
│   ├── classes
│   │   ├── __init__.py
│   │   ├── data.py
│   │   ├── output_distributions.py
│   │   ├── prior.py
│   │   └── representation.py
│   ├── functions
│   │   ├── __init__.py
│   │   ├── analysis.py
│   │   ├── data.py
│   │   └── train.py
│   └── models
│       ├── __init__.py
│       └── models.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | # mac related
163 | *.DS_Store
--------------------------------------------------------------------------------
/DGD_schematic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/DGD_schematic.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Center-for-Health-Data-Science
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # scDGD
2 |
3 | scDGD is an application of our encoder-less generative model, the Deep Generative Decoder (DGD), to single-cell transcriptomics data.
4 |
5 | It learns low-dimensional representations of full transcriptomics matrices without feature selection. The resulting embeddings are of higher quality than those of comparable methods such as scVI, and data reconstruction is highly data-efficient, outperforming scVI and scVAE, especially on very small data sets.
6 |
7 | For more information about the underlying method and our results, check out the [paper](https://academic.oup.com/bioinformatics/article/39/9/btad497/7241685).
8 |
9 |
10 |
11 | ## Installation
12 |
13 | You can install the package via
14 | ```
15 | pip install git+https://github.com/Center-for-Health-Data-Science/scDGD
16 | ```
17 |
18 | ## How to use it
19 |
20 | In our experience, scDGD can be applied to data sets with as few as 500 cells and as many as one million.
21 |
22 | Check out the notebook showing an example of how to use scDGD:
23 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Center-for-Health-Data-Science/scDGD/blob/HEAD/examples/scDGD_training_mousebrain5k.ipynb)
24 |
25 | We have also uploaded pre-trained models for the PBMC (Zheng et al. 2017) and the 10X 1 million mouse brain data sets, along with notebooks showing how to load them [in pretrained_models](https://github.com/Center-for-Health-Data-Science/scDGD/tree/main/pretrained_models).
26 |
27 | ## Reference
28 |
29 | If you use scDGD in your research, please consider citing
30 |
31 | ```
32 | @article{schuster_deep_2023,
33 | title = {The Deep Generative Decoder: MAP estimation of representations improves modelling of single-cell RNA data},
34 | volume = {39},
35 | issn = {1367-4811},
36 | shorttitle = {The Deep Generative Decoder},
37 | url = {https://doi.org/10.1093/bioinformatics/btad497},
38 | doi = {10.1093/bioinformatics/btad497},
39 | number = {9},
40 | journal = {Bioinformatics},
41 | author = {Schuster, Viktoria and Krogh, Anders},
42 | month = sep,
43 | year = {2023}
44 | }
45 | ```
--------------------------------------------------------------------------------
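As a quick orientation beyond the Colab notebook, the following is a minimal sketch (not part of the repository) of what loading one of the pretrained models could look like once the `pretrained_models/PBMC_Zheng2017` folder has been downloaded. It only uses the loaders defined in `scDGD/classes/prior.py` and `scDGD/functions/analysis.py`; the local path is an assumption.

```python
from scDGD.classes import GaussianMixture
from scDGD.functions.analysis import load_embedding

# assumed local copy of the pretrained PBMC model from this repository
save_dir = "pretrained_models/PBMC_Zheng2017/model/"

gmm = GaussianMixture.load(save_dir=save_dir)   # mixture-of-Gaussians prior over the latent space
embedding = load_embedding(save_dir=save_dir)   # per-cell representations (n_cells x 20)
print(embedding.shape, gmm.Nmix)
```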
/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_decoder.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_decoder.pt
--------------------------------------------------------------------------------
/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_gmm.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_gmm.pt
--------------------------------------------------------------------------------
/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_heldoutTestRepresentation.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_heldoutTestRepresentation.pt
--------------------------------------------------------------------------------
/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_representation.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_representation.pt
--------------------------------------------------------------------------------
/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_testRepresentation.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/PBMC_Zheng2017/model/L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422_testRepresentation.pt
--------------------------------------------------------------------------------
/pretrained_models/PBMC_Zheng2017/model/dgd_hyperparameters.json:
--------------------------------------------------------------------------------
1 | {"latent": 20,
2 | "hidden": [100, 100, 100],
3 | "learning_rates": [[1e-3, 1e-2, 1e-2], [1e-4, 1e-2, 1e-2]],
4 | "lr_name": "322",
5 | "optim_beta": [0.5, 0.7],
6 | "weight_decay": 1e-4,
7 | "dropout": 0.1,
8 | "gmm_type": "diagonal",
9 | "n_components": 18,
10 | "r_init": 2,
11 | "scaling_type": "max",
12 | "output": "mean",
13 | "activation": "softplus",
14 | "mp_scale": 1,
15 | "hardness": 1,
16 | "sd_mean": 0.01,
17 | "sd_sd": 1,
18 | "dirichlet_a": 1,
19 | "n_genes": 32738,
20 | "name": "L20G18_mp1-1_sd0.01_h100-100-100_diri1_step1_genes32738_lr322_max_bs512_rs0_nbRplus1_continued_lr422"}
--------------------------------------------------------------------------------
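These fields are consumed by `GaussianMixture.load()` in `scDGD/classes/prior.py`. The sketch below, assuming the folder layout above, shows how they map onto the mixture prior; it mirrors the `load` classmethod rather than adding new functionality.

```python
import json
from scDGD.classes import GaussianMixture

with open("pretrained_models/PBMC_Zheng2017/model/dgd_hyperparameters.json", "r") as fp:
    params = json.load(fp)

# same mapping as GaussianMixture.load(): JSON fields -> constructor arguments
gmm = GaussianMixture(
    Nmix=params["n_components"],                         # 18 mixture components
    dim=params["latent"],                                # 20-dimensional latent space
    type="diagonal",
    alpha=params["dirichlet_a"],                         # Dirichlet prior on mixture weights
    mean_init=(params["mp_scale"], params["hardness"]),  # softball prior on component means
    sd_init=[params["sd_mean"], params["sd_sd"]],        # prior on component standard deviations
)
```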
/pretrained_models/mousebrain_1M/data_info.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/mousebrain_1M/data_info.zip
--------------------------------------------------------------------------------
/pretrained_models/mousebrain_1M/model/dgd_hyperparameters.json:
--------------------------------------------------------------------------------
1 | {"latent": 20,
2 | "hidden": [100, 100, 100],
3 | "learning_rates": [[1e-3, 1e-2, 1e-2], [1e-4, 1e-2, 1e-2]],
4 | "lr_name": "322",
5 | "lr_schedule": [0,500],
6 | "optim_beta": [0.5, 0.7],
7 | "weight_decay": 1e-4,
8 | "dropout": 0.1,
9 | "gmm_type": "diagonal",
10 | "n_components": 27,
11 | "r_init": 2,
12 | "scaling_type": "max",
13 | "output": "mean",
14 | "activation": "softplus",
15 | "mp_scale": 1,
16 | "hardness": 1,
17 | "sd_mean": 0.007,
18 | "sd_sd": 1,
19 | "dirichlet_a": 1,
20 | "n_genes": 27998,
21 | "name": "dgd_mousebrain_1m_a1_rs0"}
--------------------------------------------------------------------------------
/pretrained_models/mousebrain_1M/model/dgd_mousebrain_1m_a1_rs0_decoder.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/mousebrain_1M/model/dgd_mousebrain_1m_a1_rs0_decoder.pt
--------------------------------------------------------------------------------
/pretrained_models/mousebrain_1M/model/dgd_mousebrain_1m_a1_rs0_gmm.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Center-for-Health-Data-Science/scDGD/1314903f66b6d290d041f0c2cd73b258ea2c3d6b/pretrained_models/mousebrain_1M/model/dgd_mousebrain_1m_a1_rs0_gmm.pt
--------------------------------------------------------------------------------
/scDGD/__init__.py:
--------------------------------------------------------------------------------
1 | from . import classes, models, functions
--------------------------------------------------------------------------------
/scDGD/classes/__init__.py:
--------------------------------------------------------------------------------
1 | from scDGD.classes.data import scDataset
2 | from scDGD.classes.prior import GaussianMixture
--------------------------------------------------------------------------------
/scDGD/classes/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | import scipy.sparse
5 |
6 | class scDataset(Dataset):
7 |     """A torch Dataset for single-cell count matrices and their metadata.
8 |     """
9 | def __init__(self, sparse_mtrx, meta_data, label_type, scaling_type='max', gene_selection=None, subset=None, sparse=False):
10 |         """
11 |         This is a custom data set for single-cell transcriptomics data.
12 |         It takes a sparse matrix of gene expression data and a pandas dataframe of metadata.
13 |
14 |         Args:
15 |         sparse_mtrx:
16 |             a scipy.sparse matrix of gene expression data with rows representing cells and columns representing transcripts
17 |         meta_data:
18 |             a pandas dataframe of metadata with rows representing cells and columns representing metadata
19 |         label_type:
20 |             the column name in meta_data holding the characteristic one wants to observe in clustering,
21 |             usually the cell type
22 |         scaling_type:
23 |             a string specifying how to scale the data. Options are 'mean', 'max' and 'sum',
24 |             which scale each cell by its mean, maximum or total count, respectively
25 |         gene_selection (optional):
26 |             a list of indices specifying which genes to use from the sparse matrix if feature selection is to be performed
27 |         subset (optional):
28 |             a list of indices specifying which cells to use from the sparse matrix if subsampling is to be performed
29 |         sparse (optional):
30 |             a boolean indicating whether the data should be kept in sparse format or converted to a dense tensor.
31 |             For small data sets, converting to dense format is recommended for faster training.
32 |             For large ones, keeping the data sparse helps keep memory usage in check.
33 |         """
34 |
35 | self.scaling_type = scaling_type
36 | self.meta = meta_data
37 |
38 | if gene_selection is not None:
39 | sparse_mtrx = sparse_mtrx.tocsc()[:,gene_selection].tocoo()
40 | if subset is not None:
41 | sparse_mtrx = sparse_mtrx.tocsr()[subset]
42 | if sparse:
43 | self.sparse = True
44 | sparse_mtrx = sparse_mtrx.tocsr()
45 | self.data = sparse_mtrx
46 | if self.scaling_type == 'mean':
47 | self.library = torch.tensor(sparse_mtrx.mean(axis=-1).toarray())
48 | elif self.scaling_type == 'max':
49 | self.library = torch.tensor(sparse_mtrx.max(axis=-1).toarray())
50 | elif self.scaling_type == 'sum':
51 | self.library = torch.tensor(sparse_mtrx.sum(axis=-1).toarray())
52 |
53 | else:
54 | self.sparse = False
55 | self.data = torch.Tensor(sparse_mtrx.todense())
56 | if self.scaling_type == 'mean':
57 | self.library = torch.mean(self.data, dim=-1).unsqueeze(1)
58 | elif self.scaling_type == 'max':
59 | self.library = torch.max(self.data, dim=-1).values.unsqueeze(1)
60 | elif self.scaling_type == 'sum':
61 | self.library = torch.sum(self.data, dim=-1).unsqueeze(1)
62 |
63 | self.n_genes = self.data.shape[1]
64 |
65 | self.label_type = label_type
66 |
67 | def __len__(self):
68 | return(self.data.shape[0])
69 |
70 | def __getitem__(self, idx=None):
71 | if idx is not None:
72 | if self.sparse:
73 | expression = self.data[idx]
74 | else:
75 | expression = self.data[idx]
76 | else:
77 | expression = self.data
78 | idx = torch.arange(self.data.shape[0])
79 | lib = self.library[idx]
80 | return expression, lib, idx
81 |
82 | def get_labels(self, idx=None):
83 | if torch.is_tensor(idx):
84 | idx = idx.tolist()
85 | if idx is None:
86 | idx = np.arange(self.__len__())
87 | return np.asarray(np.array(self.meta[self.label_type])[idx])
88 |
89 | def get_labels_numerical(self, idx=None):
90 | if torch.is_tensor(idx):
91 | idx = idx.tolist()
92 | if idx is None:
93 | idx = np.arange(self.__len__())
94 | label_ids = np.argmax(np.expand_dims(np.asarray(self.meta[self.label_type]),0)>=np.expand_dims(idx,1),axis=1)
95 | return label_ids
96 |
97 |
98 | ###
99 | # functions used for the sparse option
100 | # here the data will have to be transformed to dense format when a batch is called
101 | ###
102 |
103 | def sparse_coo_to_tensor(mtrx):
104 | return torch.FloatTensor(mtrx.todense())
105 |
106 | def collate_sparse_batches(batch):
107 | data_batch, library_batch, idx_batch = zip(*batch)
108 | data_batch = scipy.sparse.vstack(list(data_batch))
109 | data_batch = sparse_coo_to_tensor(data_batch)
110 | library_batch = torch.stack(list(library_batch), dim=0)
111 | idx_batch = list(idx_batch)
112 | return data_batch, library_batch, idx_batch
113 |
--------------------------------------------------------------------------------
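A minimal sketch of how `scDataset` and the sparse collate helpers above might be used together; the toy count matrix and the `cell_type` column are made up for illustration.

```python
import pandas as pd
import scipy.sparse
import torch
from scDGD.classes.data import scDataset, collate_sparse_batches

# toy data: 100 cells x 2000 genes plus a metadata frame with a cell type column
counts = scipy.sparse.random(100, 2000, density=0.05, format="coo") * 10
meta = pd.DataFrame({"cell_type": ["B cell"] * 50 + ["T cell"] * 50})

# keep the matrix sparse and let the collate function densify each mini-batch
dataset = scDataset(counts, meta, label_type="cell_type", scaling_type="max", sparse=True)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=32, shuffle=True, collate_fn=collate_sparse_batches
)

expression, library, idx = next(iter(loader))
print(expression.shape, library.shape)  # torch.Size([32, 2000]) torch.Size([32, 1])
```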
/scDGD/classes/output_distributions.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.distributions as D
5 |
6 |
7 | def logNBdensity(k, m, r):
8 | """
9 |     Negative Binomial NB(k; m, r), where m is the mean and r is the "number of failures"
10 |     r can be a real number (and so can k)
11 |     k and m are tensors of the same shape
12 |     r is a tensor of shape (1, n_genes)
13 |     Returns the log NB in the same shape as k
14 | """
15 | # remember that gamma(n+1)=n!
16 | eps = 1.0e-10
17 | x = torch.lgamma(k + r)
18 | x -= torch.lgamma(r)
19 | x -= torch.lgamma(k + 1)
20 | x += k * torch.log(m * (r + m + eps) ** (-1) + eps)
21 | x += r * torch.log(r * (r + m + eps) ** (-1))
22 | return x
23 |
24 |
25 | class NBLayer(nn.Module):
26 | """
27 |     A negative binomial output layer over scaled mean values, with a learned dispersion parameter r per gene.
28 | mhat = m/M, where M is the scaling factor
29 |
30 | The scaled value mhat is typically the output value from the NN,
31 |     but with output="prob" the network output is instead interpreted as the
32 |     NB success probability p, from which the mean is derived as m = r*(1-p)/p.
33 |
34 |
35 | m = M*mhat
36 | """
37 |
38 | # def __init__(self, out_dim, r_init, scaling_type='max', output='mean', activation='sigmoid', reduction='none'):
39 | def __init__(
40 | self, out_dim, r_init, output="mean", activation="sigmoid", reduction="none"
41 | ):
42 | super(NBLayer, self).__init__()
43 |
44 | # initialize parameter for r
45 | # real-valued positive parameters are usually used as their log equivalent
46 | self.log_r = torch.nn.Parameter(
47 | torch.full(fill_value=math.log(r_init - 1), size=(1, out_dim)),
48 | requires_grad=True,
49 | )
50 | # self.scaling_type = scaling_type
51 | self.output = output
52 | self.activation = activation
53 | self.reduction = reduction
54 |
55 | # self.activ_layer = nn.ModuleList()
56 | if self.activation == "sigmoid":
57 | self.activ_layer = nn.Sigmoid()
58 | elif self.activation == "softmax":
59 | self.activ_layer = nn.Softmax(dim=-1)
60 | elif self.activation == "softplus":
61 | self.activ_layer = nn.Softplus()
62 | else:
63 | raise ValueError("Activation function not recognized")
64 |
65 | @property
66 | def r(self):
67 | return torch.exp(self.log_r) + 1
68 |
69 | def forward(self, x):
70 | if self.output == "mean":
71 | return self.activ_layer(x)
72 | elif self.output == "prob":
73 | p = self.activ_layer(x)
74 | mean = self.r * (1 - p) / p
75 | return mean
76 |
77 | # Convert to m from scaled variables
78 | def rescale(self, M, mhat):
79 | return M * mhat
80 |
81 | def loss(self, x, M, mhat, gene_id=None):
82 | # k has shape (nbatch,dim), M has shape (nbatch,1)
83 | # mhat has dim (nbatch,dim)
84 | # r has dim (1,dim)
85 | if gene_id is not None:
86 | loss = -logNBdensity(x, self.rescale(M, mhat), self.r[0, gene_id])
87 | else:
88 | loss = -logNBdensity(x, self.rescale(M, mhat), self.r)
89 | if self.reduction == "none":
90 | return loss
91 | elif self.reduction == "sum":
92 | return loss.sum()
93 |
94 | # The logprob of the tensor
95 | def logprob(self, x, M, mhat):
96 | return logNBdensity(x, self.rescale(M, mhat), self.r)
97 |
98 | def sample(self, nsample, M, mhat):
99 | # Note that torch.distributions.NegativeBinomial returns FLOAT and not int
100 | with torch.no_grad():
101 | m = self.rescale(M, mhat)
102 |             # success probability p = m/(m+r), so that the sampled mean is m (consistent with logNBdensity)
103 |             probs = m / (m + self.r)
104 | nb = torch.distributions.NegativeBinomial(self.r, probs=probs)
105 | return nb.sample([nsample]).squeeze()
106 |
--------------------------------------------------------------------------------
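A small, self-contained check (not part of the package) of the parameterization documented in `logNBdensity`: NB(k; m, r) with mean m corresponds to torch's `NegativeBinomial` with `total_count=r` and success probability p = m / (m + r).

```python
import torch
from scDGD.classes.output_distributions import logNBdensity

k = torch.tensor([0.0, 3.0, 10.0])   # observed counts
m = torch.tensor([2.0, 2.0, 2.0])    # means
r = torch.tensor([5.0])              # "number of failures" (dispersion)

ours = logNBdensity(k, m, r)
ref = torch.distributions.NegativeBinomial(total_count=r, probs=m / (m + r)).log_prob(k)
print(torch.allclose(ours, ref, atol=1e-4))  # True, up to the eps used in logNBdensity
```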
/scDGD/classes/prior.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.distributions as D
5 | import numpy as np
6 | import json
7 |
8 |
9 | class gaussian:
10 | """
11 | This is a simple Gaussian prior used for initializing mixture model means
12 | """
13 |
14 | def __init__(self, dim, mean, stddev):
15 | self.dim = dim
16 | self.mean = mean
17 | self.stddev = stddev
18 | self.g = torch.distributions.normal.Normal(mean, stddev)
19 |
20 | def sample(self, n):
21 | return self.g.sample((n, self.dim))
22 |
23 | def log_prob(self, x):
24 | return self.g.log_prob(x)
25 |
26 |
27 | class softball:
28 | """
29 | Almost uniform prior for the m-dimensional ball.
30 | Logistic function makes a soft (differentiable) boundary.
31 | Returns a prior function and a sample function.
32 | The prior takes a tensor with a batch of z
33 | vectors (last dim) and returns a tensor of prior log-probabilities.
34 | The sample function returns n samples from the prior (approximate
35 | samples uniform from the m-ball). NOTE: APPROXIMATE SAMPLING.
36 | """
37 |
38 | def __init__(self, dim, radius, a=1):
39 | self.dim = dim
40 | self.radius = radius
41 | self.a = a
42 | self.norm = math.lgamma(1 + dim * 0.5) - dim * (
43 | math.log(radius) + 0.5 * math.log(math.pi)
44 | )
45 |
46 | def sample(self, n):
47 | # Return n random samples
48 | # Approximate: We sample uniformly from n-ball
49 | with torch.no_grad():
50 | # Gaussian sample
51 | sample = torch.randn((n, self.dim))
52 | # n random directions
53 | sample.div_(sample.norm(dim=-1, keepdim=True))
54 | # n random lengths
55 | local_len = self.radius * torch.pow(torch.rand((n, 1)), 1.0 / self.dim)
56 | sample.mul_(local_len.expand(-1, self.dim))
57 | return sample
58 |
59 | def log_prob(self, z):
60 | # Return log probabilities of elements of tensor (last dim assumed to be z vectors)
61 | return self.norm - torch.log(
62 | 1 + torch.exp(self.a * (z.norm(dim=-1) / self.radius - 1))
63 | )
64 |
65 |
66 | class GaussianMixture(nn.Module):
67 | def __init__(
68 | self,
69 | Nmix,
70 | dim,
71 | type="diagonal",
72 | alpha=1,
73 | mean_init=(1.0, 1.0),
74 | sd_init=[1.0, 1.0],
75 | ):
76 | """
77 | A mixture of multi-variate Gaussians
78 |
79 | Nmix is the number of components in the mixture
80 | dim is the dimension of the space
81 | type can be "fixed", "isotropic" or "diagonal", which refers to the covariance matrices
82 |         mean_init defines the softball prior on the component means
83 |             - mean_init[0] is the radius of the soft ball
84 |             - mean_init[1] is the sharpness (a) of its soft boundary
85 |             - the component means are sampled from this prior at initialization
86 |         sd_init defines the Gaussian prior on logbeta, the negative log variance of the components
87 |             - logbeta = log (1/sigma^2)
88 |             - the prior is a Gaussian constructed from the sd_init parameters
89 |             - for interpretability, sd_init gives the desired mean and (approximate) sd of the component standard deviation
90 |             - all logbetas are initialized to the same value; the prior only acts as a regularizer during training
91 |             - with type 'fixed', the component standard deviations are kept at their initial value
92 |         alpha determines the Dirichlet prior on mixture coefficients
93 |         Mixture coefficients are initialized uniformly
94 |         The means are sampled from their prior; all other parameters are initialized deterministically
95 | """
96 | super(GaussianMixture, self).__init__()
97 | self.dim = dim
98 | self.Nmix = Nmix
99 | # self.init = init
100 |
101 | # Means with shape: Nmix,dim
102 | self.mean = nn.Parameter(torch.empty(Nmix, dim), requires_grad=True)
103 | self.mean_prior = softball(self.dim, mean_init[0], mean_init[1])
104 |
105 | # Dirichlet prior on mixture
106 | self.alpha = alpha
107 | self.dirichlet_constant = math.lgamma(Nmix * alpha) - Nmix * math.lgamma(alpha)
108 |
109 | # Log inverse variance with shape (Nmix,dim) or (Nmix,1)
110 | self.sd_init = sd_init
111 | self.sd_init[0] = 0.2 * (mean_init[0] / self.Nmix)
112 | self.betafactor = dim * 0.5 # rename this!
113 |         self.bdim = 1  # if 'diagonal', the dimension of logbeta is = dim
114 | if type == "fixed":
115 | # No gradient needed for training
116 |             # This is a column vector to be correctly broadcast in the std dev tensor
117 | self.logbeta = nn.Parameter(
118 | torch.empty(Nmix, self.bdim), requires_grad=False
119 | )
120 | self.logbeta_prior = None
121 | else:
122 | if type == "diagonal":
123 | self.betafactor = 0.5
124 | self.bdim = dim
125 | elif type != "isotropic":
126 | raise ValueError(
127 | "type must be 'isotropic' (default), 'diagonal', or 'fixed'"
128 | )
129 |
130 | self.logbeta = nn.Parameter(
131 | torch.empty(Nmix, self.bdim), requires_grad=True
132 | )
133 |
134 | # Mixture coefficients. These are weights for softmax
135 | self.weight = nn.Parameter(torch.empty(Nmix), requires_grad=True)
136 | self.init_params()
137 |
138 | # -dim*0.5*log(2pi)
139 | self.pi_term = -0.5 * self.dim * math.log(2 * math.pi)
140 |
141 | def init_params(self):
142 | with torch.no_grad():
143 | # Means are sampled from the prior
144 | self.mean.copy_(self.mean_prior.sample(self.Nmix))
145 | self.logbeta.fill_(-2 * math.log(self.sd_init[0]))
146 | self.logbeta_prior = gaussian(
147 | self.bdim, -2 * math.log(self.sd_init[0]), self.sd_init[1]
148 | )
149 |
150 | # Weights are initialized to 1, corresponding to uniform mixture coeffs
151 | self.weight.fill_(1)
152 |
153 | def forward(self, x, label=None):
154 | # The beta values are obtained from logbeta
155 | halfbeta = 0.5 * torch.exp(self.logbeta)
156 |
157 | # y = logp = - 0.5*log (2pi) -0.5*beta(x-mean[i])^2 + 0.5*log(beta)
158 | # sum terms for each component (sum is over last dimension)
159 | # y is one-dim with length Nmix
160 | # x is unsqueezed to (nsample,1,dim), so broadcasting of mean (Nmix,dim) works
161 | y = (
162 | self.pi_term
163 | - (x.unsqueeze(-2) - self.mean).square().mul(halfbeta).sum(-1)
164 | + self.betafactor * self.logbeta.sum(-1)
165 | )
166 | # For each component multiply by mixture probs
167 | y += torch.log_softmax(self.weight, dim=0)
168 | y = torch.logsumexp(y, dim=-1)
169 | y = y + self.prior() # += gives cuda error
170 |
171 | return y
172 |
173 | def log_prob(self, x): # Add label?
174 |         return self.forward(x)
175 |
176 | def mixture_probs(self):
177 | return torch.softmax(self.weight, dim=-1)
178 |
179 | def covariance(self):
180 | return torch.exp(-self.logbeta)
181 |
182 | def prior(self):
183 | """Calculate log prob of prior on mean, logbeta, and mixture coeff"""
184 | # Mixture
185 | p = self.dirichlet_constant # /self.Nmix
186 | if self.alpha != 1:
187 | p = p + (self.alpha - 1.0) * (
188 | self.mixture_probs().log().sum()
189 | ) # /self.Nmix
190 | # Means
191 | p = p + self.mean_prior.log_prob(self.mean).sum() # /self.Nmix
192 | # logbeta
193 | if self.logbeta_prior is not None:
194 | p = (
195 | p + self.logbeta_prior.log_prob(self.logbeta).sum()
196 | ) # /(self.Nmix*self.dim)
197 | return p
198 |
199 | def Distribution(self):
200 | with torch.no_grad():
201 | mix = D.Categorical(probs=torch.softmax(self.weight, dim=-1))
202 | comp = D.Independent(D.Normal(self.mean, torch.exp(-0.5 * self.logbeta)), 1)
203 | return D.MixtureSameFamily(mix, comp)
204 |
205 | def sample(self, nsample):
206 | with torch.no_grad():
207 | gmm = self.Distribution()
208 | return gmm.sample(torch.tensor([nsample]))
209 |
210 | def component_sample(self, nsample):
211 | """Returns a sample from each component. Tensor shape (nsample,nmix,dim)"""
212 | with torch.no_grad():
213 | comp = D.Independent(D.Normal(self.mean, torch.exp(-0.5 * self.logbeta)), 1)
214 | return comp.sample(torch.tensor([nsample]))
215 |
216 | def sample_probs(self, x):
217 | halfbeta = 0.5 * torch.exp(self.logbeta)
218 | y = (
219 | self.pi_term
220 | - (x.unsqueeze(-2) - self.mean).square().mul(halfbeta).sum(-1)
221 | + self.betafactor * self.logbeta.sum(-1)
222 | )
223 | y += torch.log_softmax(self.weight, dim=0)
224 | return torch.exp(y)
225 |
226 | """This section is for learning new data points"""
227 |
228 | def sample_new_points(self, n_points, option="random", n_new=1):
229 | """
230 | Generates samples for each new data point
231 | - n_points defines the number of new data points to learn
232 | - option defines which of 2 schemes to use
233 | - random: sample n_new vectors from each component
234 | -> Nmix * n_new values per new point
235 | - mean: take the mean of each component as initial representation values
236 | -> Nmix values per new point
237 | The order of repetition in both options is [a,a,a, b,b,b, c,c,c] on data point ID.
238 | """
239 | self.new_samples = n_new
240 | multiplier = self.Nmix
241 | if option == "random":
242 | out = self.component_sample(n_points * n_new)
243 | multiplier *= n_new
244 |
245 | elif option == "mean":
246 | with torch.no_grad():
247 | out = torch.repeat_interleave(
248 | self.mean.clone().cpu().detach().unsqueeze(0), n_points, dim=0
249 | )
250 | else:
251 | print(
252 | "Please specify how to initialize new representations correctly \nThe options are 'random' and 'mean'."
253 | )
254 | return out.view(n_points * self.new_samples * self.Nmix, self.dim)
255 |
256 | def reshape_targets(self, y, y_type="true"):
257 | """
258 | Since we have multiple representations for the same new data point,
259 | we need to reshape the output a little to calculate the losses
260 | Depending on the y_type, y can be
261 | - the true targets (y_type: 'true')
262 | - the model predictions (y_type: 'predicted') (can also be used for rep.z in dataloader loop)
263 | - the 4-dimensional representation or the loss (y_type: 'reverse')
264 | """
265 |
266 | if y_type == "true":
267 | if len(y.shape) > 2:
268 | raise ValueError(
269 | "Unexpected shape in input to function reshape_targets. Expected 2 dimensions, got "
270 | + str(len(y.shape))
271 | )
272 | return (
273 | y.unsqueeze(1).unsqueeze(1).expand(-1, self.new_samples, self.Nmix, -1)
274 | )
275 | elif y_type == "predicted":
276 | if len(y.shape) > 2:
277 | raise ValueError(
278 | "Unexpected shape in input to function reshape_targets. Expected 2 dimensions, got "
279 | + str(len(y.shape))
280 | )
281 | n_points = int(
282 | torch.numel(y) / (self.new_samples * self.Nmix * y.shape[-1])
283 | )
284 | return y.view(n_points, self.new_samples, self.Nmix, y.shape[-1])
285 |         elif y_type == "reverse":
286 | if len(y.shape) < 4:
287 | # this case is for when the losses are of shape (n_points,self.new_samples,self.Nmix)
288 | return y.view(y.shape[0] * self.new_samples * self.Nmix)
289 | else:
290 | return y.view(y.shape[0] * self.new_samples * self.Nmix, y.shape[-1])
291 | else:
292 | raise ValueError(
293 |                 "The y_type in function reshape_targets was incorrect. Please choose between 'true', 'predicted', and 'reverse'."
294 | )
295 |
296 | def choose_best_representations(self, x, losses):
297 | """
298 | Selects the representation for each new datapoint that maximizes the objective
299 | - x are the newly learned representations
300 | - x and losses have to have the same shape in the first dimension
301 | make sure that the losses are only summed over the output dimension
302 | Outputs new representation values
303 | """
304 | n_points = int(torch.numel(losses) / (self.new_samples * self.Nmix))
305 |
306 | best_sample = torch.argmin(
307 | losses.view(-1, self.new_samples * self.Nmix), dim=1
308 | ).squeeze(-1)
309 | best_rep = x.view(n_points, self.new_samples * self.Nmix, self.dim)[
310 | range(n_points), best_sample
311 | ]
312 | # best_rep = torch.diagonal(x.view(n_points,self.new_samples*self.Nmix,self.dim)[:,best_sample],dim1=0,dim2=1).transpose(0,1)
313 |
314 | return best_rep
315 |
316 | def choose_old_or_new(self, z_new, loss_new, z_old, loss_old):
317 | if (len(z_new.shape) == 2) and (len(z_old.shape) == 2):
318 | z_conc = torch.cat((z_new.unsqueeze(1), z_old.unsqueeze(1)), dim=1)
319 | else:
320 | raise ValueError(
321 | "Unexpected shape in input to function choose_old_or_new. Expected 2 dimensions for z_new and z_old, got "
322 | + str(len(z_new.shape))
323 | + " and "
324 | + str(len(z_old.shape))
325 | )
326 |
327 | len_loss_new = len(loss_new.shape)
328 | for l in range(3 - len_loss_new):
329 | loss_new = loss_new.unsqueeze(1)
330 | len_loss_old = len(loss_old.shape)
331 | for l in range(3 - len_loss_old):
332 | loss_old = loss_old.unsqueeze(1)
333 | losses = torch.cat((loss_new, loss_old), dim=1)
334 |
335 | best_sample = torch.argmin(losses, dim=1).squeeze(-1)
336 | # print(str(best_sample.sum().item())+' out of '+str(z_new.shape[0])+' samples were resampled.')
337 |
338 | best_rep = z_conc[range(z_conc.shape[0]), best_sample]
339 |
340 | return best_rep, round((best_sample.sum().item() / z_new.shape[0]) * 100, 2)
341 |
342 | def clustering(self, z):
343 | """compute the cluster assignment (as int) for each sample"""
344 | return torch.argmax(self.sample_probs(torch.tensor(z)), dim=-1).to(torch.int16)
345 |
346 | @classmethod
347 | def load(cls, save_dir="./"):
348 | # get saved hyper-parameters
349 | with open(save_dir + "dgd_hyperparameters.json", "r") as fp:
350 | param_dict = json.load(fp)
351 |
352 | gmm = cls(
353 | Nmix=param_dict["n_components"],
354 | dim=param_dict["latent"],
355 | type="diagonal",
356 | alpha=param_dict["dirichlet_a"],
357 | mean_init=(param_dict["mp_scale"], param_dict["hardness"]),
358 | sd_init=[param_dict["sd_mean"], param_dict["sd_sd"]],
359 | )
360 |
361 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
362 | gmm.load_state_dict(
363 | torch.load(save_dir + param_dict["name"] + "_gmm.pt", map_location=device)
364 | )
365 | return gmm
366 |
367 |
368 | class GaussianMixtureSupervised(GaussianMixture):
369 | def __init__(
370 | self,
371 | Nclass,
372 | dim,
373 | Ncpc=1,
374 | type="diagonal",
375 | alpha=1,
376 | mean_init=(1.0, 1.0),
377 | sd_init=(1.0, 1.0),
378 | ):
379 | super(GaussianMixtureSupervised, self).__init__(
380 | Ncpc * Nclass,
381 | dim,
382 | type=type,
383 | alpha=alpha,
384 | mean_init=mean_init,
385 | sd_init=sd_init,
386 | )
387 | self.dim = dim
388 |
389 | self.Nclass = Nclass
390 | self.Ncpc = Ncpc
391 |
392 | def forward(self, x, label=None):
393 | if label is None:
394 | y = super().forward(x)
395 | return y
396 |
397 | if 999 in label:
398 | # first get normal loss
399 | idx_unsup = [i for i in range(len(label)) if label[i] == 999]
400 | y_unsup = super().forward(x[idx_unsup])
401 | # Otherwise use the component corresponding to the label
402 | idx_sup = [i for i in range(len(label)) if label[i] != 999]
403 | halfbeta = 0.5 * torch.exp(self.logbeta)
404 | # Pick only the Nclc components belonging class
405 | # y_sup = self.pi_term - (x.unsqueeze(-2).unsqueeze(-2) - self.mean.view(self.Nclass,self.Ncpc,-1)).square().mul(halfbeta.unsqueeze(-2))[label[idx_sup]].sum(-1).sum(-1) + self.betafactor*self.logbeta.view(self.Nclass,self.Ncpc,-1).sum(-1).sum(-1)
406 | y_sup = (
407 | self.pi_term
408 | - (
409 | x.unsqueeze(-2).unsqueeze(-2)
410 | - self.mean.view(self.Nclass, self.Ncpc, -1)
411 | )
412 | .square()
413 | .mul(halfbeta.unsqueeze(-2))
414 | .sum(-1)
415 | + (self.betafactor * self.logbeta.view(self.Nclass, self.Ncpc, -1)).sum(
416 | -1
417 | )
418 | )
419 | y_sup += torch.log_softmax(self.weight.view(self.Nclass, self.Ncpc), dim=-1)
420 | y_sup = y_sup.sum(-1)
421 | y_sup = torch.abs(
422 | y_sup[(idx_sup, label[idx_sup])] * self.Nclass
423 | ) # this is replacement for logsumexp of supervised samples
424 | # put together y
425 | y = torch.empty((x.shape[0]), dtype=torch.float32).to(x.device)
426 | y[idx_unsup] = y_unsup
427 | y[idx_sup] = y_sup
428 | else:
429 | halfbeta = 0.5 * torch.exp(self.logbeta)
430 | # Pick only the Nclc components belonging class
431 | y = (
432 | self.pi_term
433 | - (
434 | x.unsqueeze(-2).unsqueeze(-2)
435 | - self.mean.view(self.Nclass, self.Ncpc, -1)
436 | )
437 | .square()
438 | .mul(halfbeta.unsqueeze(-2))
439 | .sum(-1)
440 | + (self.betafactor * self.logbeta.view(self.Nclass, self.Ncpc, -1)).sum(
441 | -1
442 | )
443 | )
444 | y += torch.log_softmax(self.weight.view(self.Nclass, self.Ncpc), dim=-1)
445 | y = y.sum(-1)
446 | # y = torch.abs(y[(np.arange(y.shape[0]),label)] * self.Nclass) # this is replacement for logsumexp of supervised samples
447 | y = (
448 | y[(np.arange(y.shape[0]), label)] * self.Nclass
449 | ) # this is replacement for logsumexp of supervised samples
450 |
451 | y = y + super().prior()
452 | return y
453 |
454 | def log_prob(self, x, label=None):
455 |         return self.forward(x, label=label)
456 |
457 | def label_mixture_probs(self, label):
458 | return torch.softmax(self.weight[label], dim=-1)
459 |
--------------------------------------------------------------------------------
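A minimal sketch of using the `GaussianMixture` prior on its own; the hyperparameter values are taken loosely from the PBMC model above and the batch of representations is random.

```python
import torch
from scDGD.classes import GaussianMixture

# 18 diagonal-covariance Gaussians in a 20-dimensional latent space
gmm = GaussianMixture(Nmix=18, dim=20, type="diagonal", alpha=1,
                      mean_init=(1.0, 1.0), sd_init=[0.01, 1.0])

z = torch.randn(5, 20)                # a batch of representation vectors
log_p = gmm(z)                        # log p(z) under the mixture plus its hyper-priors
clusters = gmm.clustering(z.numpy())  # hard component assignment per sample
print(log_p.shape, gmm.mixture_probs().shape, clusters.shape)
# torch.Size([5]) torch.Size([18]) torch.Size([5])
```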
/scDGD/classes/representation.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.distributions as D
5 |
6 | class RepresentationLayer(torch.nn.Module):
7 | '''
8 | Implements a representation layer, that accumulates pytorch gradients.
9 |
10 | Representations are vectors in nrep-dimensional real space. By default
11 | they will be initialized as a tensor of dimension nsample x nrep from a
12 | normal distribution (mean and variance given by init).
13 |
14 | One can also supply a tensor to initialize the representations (values=tensor).
15 |     The representations will then have the same dimensions; it is assumed that
16 |     the first dimension is nsample (and the last is nrep).
17 |
18 | forward() takes a sample index and returns the representation.
19 |
20 | Representations are "linear", so a representation layer may be followed
21 | by an activation function.
22 |
23 | To update representations, the pytorch optimizers do not always work well,
24 |     so the module comes with its own SGD update (self.update(lr,mom,...)).
25 |
26 |     If the loss has reduction="sum", things work well. With reduction="mean", the
27 | gradients become very small and the learning rate needs to be rescaled
28 | accordingly (batchsize*output_dim).
29 |
30 |     Do not forget to call zero_grad() before each epoch (not inside the loop
31 |     as with the weights).
32 |
33 | '''
34 | def __init__(self,
35 | nrep, # Dimension of representation
36 | nsample, # Number of training samples
37 | init=(0.,1.),# Normal distribution mean and stddev for
38 | # initializing representations
39 | values=None # If values is given, the other parameters are ignored
40 | ):
41 | super(RepresentationLayer, self).__init__()
42 | self.dz = None
43 | if values is None:
44 | self.nrep=nrep
45 | self.nsample=nsample
46 | self.mean, self.stddev = init[0],init[1]
47 | self.init_random(self.mean,self.stddev)
48 | else:
49 | # Initialize representations from a tensor with values
50 | self.nrep = values.shape[-1]
51 | self.nsample = values.shape[0]
52 | self.mean, self.stddev = None, None
53 | # Is this the way to copy values to a parameter?
54 | self.z = torch.nn.Parameter(torch.zeros_like(values), requires_grad=True)
55 | with torch.no_grad():
56 | self.z.copy_(values)
57 |
58 | def init_random(self,mean,stddev):
59 | # Generate random representations
60 | self.z = torch.nn.Parameter(torch.normal(mean,stddev,size=(self.nsample,self.nrep), requires_grad=True))
61 |
62 | def forward(self, idx=None):
63 | if idx is None:
64 | return self.z
65 | else:
66 | return self.z[idx]
67 |
68 | # Index can be whatever it can be for a torch.tensor (e.g. tensor of idxs)
69 | def __getitem__(self,idx):
70 | return self.z[idx]
71 |
72 | def fix(self):
73 | self.z.requires_grad = False
74 |
75 | def unfix(self):
76 | self.z.requires_grad = True
77 |
78 | def zero_grad(self): # Used only if the update function is used
79 | if self.z.grad is not None:
80 | self.z.grad.detach_()
81 | self.z.grad.zero_()
82 |
83 | def update(self,idx=None,lr=0.001,mom=0.9,wd=None):
84 | if self.dz is None:
85 | self.dz = torch.zeros(self.z.size()).to(self.z.device)
86 | with torch.no_grad():
87 | # Update z
88 | # dz(k,j) = sum_i grad(k,i) w(i,j) step(z(j))
89 | self.dz[idx] = self.dz[idx].mul(mom) - self.z.grad[idx].mul(lr)
90 | if wd is not None:
91 | self.dz[idx] -= wd*self.z[idx]
92 | self.z[idx] += self.dz[idx]
93 |
94 | def rescale(self):
95 | z_flat = torch.flatten(self.z.cpu().detach())
96 | sd, m = torch.std_mean(z_flat)
97 | with torch.no_grad():
98 | self.z -= m
99 | self.z /= sd
--------------------------------------------------------------------------------
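A minimal sketch of the update pattern described in the docstring above (sum-reduced loss, `zero_grad()` once per epoch, the layer's own SGD step); the target tensor is a stand-in for a real training signal.

```python
import torch
from scDGD.classes.representation import RepresentationLayer

# 200 samples with a 20-dimensional representation each
rep = RepresentationLayer(nrep=20, nsample=200, init=(0.0, 0.1))
target = torch.zeros(200, 20)  # placeholder objective

rep.zero_grad()                          # once per epoch, as noted in the docstring
loss = (rep() - target).square().sum()   # reduction="sum" keeps gradients well scaled
loss.backward()
rep.update(idx=torch.arange(200), lr=0.01, mom=0.9)
```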
/scDGD/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from scDGD.functions.data import prepate_data
2 | from scDGD.functions.train import dgd_train
--------------------------------------------------------------------------------
/scDGD/functions/analysis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from sklearn import preprocessing
4 | import torch
5 | from sklearn.metrics import confusion_matrix
6 | import operator
7 | import sklearn.metrics
8 | import json
9 | from scDGD.classes import GaussianMixture
10 | import seaborn as sns
11 | import matplotlib.pyplot as plt
12 | import umap
13 |
14 |
15 | def order_matrix_by_max_per_class(mtrx, class_labels, comp_order=None):
16 | if comp_order is not None:
17 | temp_mtrx = np.zeros(mtrx.shape)
18 | for i in range(mtrx.shape[1]):
19 | temp_mtrx[:, i] = mtrx[:, comp_order[i]]
20 | mtrx = temp_mtrx
21 | max_id_per_class = np.argmax(mtrx, axis=1)
22 | max_coordinates = list(zip(np.arange(mtrx.shape[0]), max_id_per_class))
23 | max_coordinates.sort(key=operator.itemgetter(1))
24 | new_class_order = [x[0] for x in max_coordinates]
25 | new_mtrx = np.zeros(mtrx.shape)
26 |     # reindexing mtrx worked in tests but not in application; reverting to a safe for-loop
27 | for i in range(mtrx.shape[0]):
28 | new_mtrx[i, :] = mtrx[new_class_order[i], :]
29 | # mtrx = mtrx[new_class_order,:]
30 | return new_mtrx, [class_labels[i] for i in new_class_order]
31 |
32 |
33 | def gmm_clustering(r, gmm, labels):
34 | # transform categorical labels into numerical
35 | le = preprocessing.LabelEncoder()
36 | le.fit(labels)
37 | true_labels = le.transform(labels)
38 | # compute probabilities per sample and component (n_sample,n_mix_comp)
39 | probs_per_sample_and_component = gmm.sample_probs(torch.tensor(r))
40 | # get index (i.e. component id) of max prob per sample
41 | cluster_labels = (
42 | torch.max(probs_per_sample_and_component, dim=-1).indices.cpu().detach()
43 | )
44 | return cluster_labels
45 |
46 |
47 | def compute_distances(mtrx):
48 | distances = sklearn.metrics.pairwise.euclidean_distances(mtrx)
49 | return distances
50 |
51 |
52 | def get_connectivity_from_threshold(mtrx, threshold):
53 | connectivity_mtrx = np.zeros(mtrx.shape)
54 | idx = np.where(mtrx <= threshold)
55 | connectivity_mtrx[idx[0], idx[1]] = 1
56 | np.fill_diagonal(connectivity_mtrx, 0)
57 | return connectivity_mtrx
58 |
59 |
60 | def rank_distances(mtrx):
61 | ranks = np.argsort(mtrx, axis=-1)
62 | # testing advanced indexing for ranking
63 | m, n = mtrx.shape
64 | # Initialize output array
65 | out = np.empty((m, n), dtype=int)
66 | # Use sidx as column indices, while a range array for the row indices
67 | # to select one element per row. Since sidx is a 2D array of indices
68 | # we need to use a 2D extended range array for the row indices
69 | out[np.arange(m)[:, None], ranks] = np.arange(n)
70 | return out
71 | # return ranks
72 |
73 |
74 | def get_node_degrees(mtrx):
75 | return np.sum(mtrx, 1)
76 |
77 |
78 | def get_secondary_degrees(mtrx):
79 | out = np.zeros(mtrx.shape[0])
80 | for i in range(mtrx.shape[0]):
81 | direct_neighbors = np.where(mtrx[i] == 1)[0]
82 | out[i] = mtrx[direct_neighbors, :].sum()
83 | return out
84 |
85 |
86 | def find_start_node(d1, d2):
87 | minimum_first_degree = np.where(d1 == d1.min())[0]
88 | if len(minimum_first_degree) > 1:
89 | minimum_second_degree_subset = np.where(
90 | d2[minimum_first_degree] == d2[minimum_first_degree].min()
91 | )[0][0]
92 | return minimum_first_degree[minimum_second_degree_subset]
93 | else:
94 | return minimum_first_degree[0]
95 |
96 |
97 | def find_next_node(c, r, i):
98 | connected_nodes = np.where(c[i, :] == 1)[0]
99 | if len(connected_nodes) > 0:
100 | connected_nodes_paired = list(zip(connected_nodes, r[i, connected_nodes]))
101 | connected_nodes_paired.sort(key=operator.itemgetter(1))
102 | connected_nodes = [
103 | connected_nodes_paired[x][0] for x in range(len(connected_nodes))
104 | ]
105 | return connected_nodes
106 |
107 |
108 | def traverse_through_graph(connectiv_mtrx, ranks, first_degrees, second_degrees):
109 | # create 2 lists of node ids
110 | # the first one keeps track of the nodes we have used already (and stores them in the desired order)
111 | # the second one keeps track of the nodes we still have to sort
112 | node_order = []
113 | nodes_to_be_distributed = list(np.arange(connectiv_mtrx.shape[0]))
114 |
115 | start_node = find_start_node(first_degrees, second_degrees)
116 | node_order.append(start_node)
117 | nodes_to_be_distributed.remove(start_node)
118 |
119 | count_turns = 0
120 | while len(nodes_to_be_distributed) > 0:
121 | next_nodes = find_next_node(connectiv_mtrx, ranks, node_order[-1])
122 | next_nodes = list(set(next_nodes).difference(set(node_order)))
123 | if len(next_nodes) < 1:
124 | next_nodes = [
125 | nodes_to_be_distributed[
126 | find_start_node(
127 | first_degrees[nodes_to_be_distributed],
128 | second_degrees[nodes_to_be_distributed],
129 | )
130 | ]
131 | ]
132 | for n in next_nodes:
133 | if n not in node_order:
134 | node_order.append(n)
135 | nodes_to_be_distributed.remove(n)
136 | count_turns = 0
137 | count_turns += 1
138 | if count_turns >= 10:
139 | break
140 |
141 | return node_order
142 |
143 |
144 | def order_components_as_graph_traversal(gmm):
145 | distance_mtrx = compute_distances(gmm.mean.detach().cpu().numpy())
146 | threshold = round(np.percentile(distance_mtrx.flatten(), 30), 2)
147 |
148 | connectivity_mtrx = get_connectivity_from_threshold(distance_mtrx, threshold)
149 | rank_mtrx = rank_distances(distance_mtrx)
150 |
151 | node_degrees = get_node_degrees(connectivity_mtrx)
152 | secondary_node_degrees = get_secondary_degrees(connectivity_mtrx)
153 |
154 | new_node_order = traverse_through_graph(
155 | connectivity_mtrx, rank_mtrx, node_degrees, secondary_node_degrees
156 | )
157 | return new_node_order
158 |
159 |
160 | def clustering_matrix(gmm, rep, labels, norm=True):
161 | classes = list(np.unique(labels))
162 | true_labels = np.asarray([classes.index(i) for i in labels])
163 | cluster_labels = gmm_clustering(rep, gmm, labels)
164 | # get absolute confusion matrix
165 | cm1 = confusion_matrix(true_labels, cluster_labels)
166 |
167 | class_counts = [np.where(true_labels == i)[0].shape[0] for i in range(len(classes))]
168 | cm2 = cm1.astype(np.float64)
169 | for i in range(len(class_counts)):
170 | # percent_sum = 0
171 | for j in range(gmm.Nmix):
172 | if norm:
173 | cm2[i, j] = cm2[i, j] * 100 / class_counts[i]
174 | else:
175 | cm2[i, j] = cm2[i, j]
176 | cm2 = cm2.round()
177 |
178 | # get an order of components based on connectivity graph
179 | component_order = order_components_as_graph_traversal(gmm)
180 |
181 | # take the non-empty entries
182 | cm2 = cm2[: len(classes), : gmm.Nmix]
183 |
184 | cm3, classes_reordered = order_matrix_by_max_per_class(
185 | cm2, classes, component_order
186 | )
187 | out = pd.DataFrame(data=cm3, index=classes_reordered, columns=component_order)
188 | return out
189 |
190 |
191 | def load_embedding(save_dir="./"):
192 | # get parameter dict
193 | with open(save_dir + "dgd_hyperparameters.json", "r") as fp:
194 | param_dict = json.load(fp)
195 |
196 | embedding = torch.load(
197 | save_dir + param_dict["name"] + "_representation.pt",
198 | map_location=torch.device("cpu"),
199 | )
200 | return embedding["z"].detach().cpu().numpy()
201 |
202 |
203 | def load_labels(save_dir="./", split="train"):
204 | # load train-val-test split
205 | obs = pd.read_csv(save_dir + "data_obs.csv", index_col=0)
206 | return obs[obs["train_val_test"] == split]["cell_type"].values
207 |
208 |
209 | def plot_cluster_heatmap(save_dir="./"):
210 | gmm = GaussianMixture.load(save_dir=save_dir)
211 | embedding = load_embedding(save_dir=save_dir)
212 | labels = load_labels()
213 |
214 | df_relative_clustering = clustering_matrix(gmm, embedding, labels)
215 | df_clustering = clustering_matrix(gmm, embedding, labels, norm=False)
216 |
217 | annotations = df_relative_clustering.to_numpy(dtype=np.float64).copy()
218 | annotations[annotations < 1] = None
219 | df_relative_clustering = df_relative_clustering.fillna(0)
220 | fig, ax = plt.subplots(figsize=(0.5*annotations.shape[0], 0.5*annotations.shape[1]))
221 | cmap = sns.color_palette("GnBu", as_cmap=True)
222 | sns.heatmap(
223 | df_relative_clustering,
224 | annot=annotations,
225 | cmap=cmap,
226 | annot_kws={"size": 6},
227 | cbar_kws={"shrink": 0.5, "location": "bottom"},
228 | xticklabels=True,
229 | yticklabels=True,
230 | mask=np.isnan(annotations),
231 | alpha=0.8,
232 | )
233 | ylabels = [
234 | df_clustering.index[x] + " (" + str(int(df_clustering.sum(axis=1)[x])) + ")"
235 | for x in range(df_clustering.shape[0])
236 | ]
237 | plt.yticks(
238 | ticks=np.arange(len(ylabels)) + 0.5, labels=ylabels, rotation=0, fontsize=8
239 | )
240 | plt.tick_params(axis="x", rotation=0, labelsize=8)
241 | plt.ylabel("Cell type")
242 | plt.xlabel("GMM component ID")
243 | plt.title("percentage of cell type in GMM cluster")
244 | plt.show()
245 |
246 |
247 | def plot_latent_umap(save_dir="./", n_neighbors=15, min_dist=0.5):
248 | gmm = GaussianMixture.load(save_dir=save_dir)
249 | embedding = load_embedding(save_dir=save_dir)
250 | labels = load_labels()
251 |
252 | # make umap
253 | reducer = umap.UMAP(n_neighbors=n_neighbors, n_components=2, min_dist=min_dist)
254 | projected = reducer.fit_transform(embedding)
255 | plot_data = pd.DataFrame(projected, columns=["UMAP1", "UMAP2"])
256 | plot_data["cell type"] = labels
257 | plot_data["cell type"] = plot_data["cell type"].astype("category")
258 | plot_data["cluster"] = (
259 | gmm.clustering(embedding).cpu().detach().numpy()
260 | ) # .astype(str)
261 | plot_data["cluster"] = plot_data["cluster"].astype("category")
262 |
263 | # make a plot with two subplots
264 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
265 | # adjust spacing between plots
266 | fig.subplots_adjust(wspace=1.0)
267 | # make text size smaller
268 | plt.rcParams.update({"font.size": 6})
269 | sns.scatterplot(data=plot_data, x="UMAP1", y="UMAP2", hue="cell type", ax=ax1, s=1)
270 | sns.scatterplot(data=plot_data, x="UMAP1", y="UMAP2", hue="cluster", ax=ax2, s=1)
271 | ax1.set_title("cell type")
272 | ax1.legend(bbox_to_anchor=(1.02, 1), loc=2, borderaxespad=0.0, frameon=False)
273 | ax2.legend(bbox_to_anchor=(1.02, 1), loc=2, borderaxespad=0.0, frameon=False)
274 | ax2.set_title("cluster")
275 | plt.show()
276 |
--------------------------------------------------------------------------------
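A sketch of how the two plotting helpers above are meant to be called for a downloaded pretrained model; the working directory and paths are assumptions. Since `load_labels` reads `data_obs.csv` from the current directory, the snippet assumes it is run from `pretrained_models/PBMC_Zheng2017/`.

```python
from scDGD.functions.analysis import plot_cluster_heatmap, plot_latent_umap

# run from pretrained_models/PBMC_Zheng2017/ so that data_obs.csv is in the working directory
plot_cluster_heatmap(save_dir="model/")
plot_latent_umap(save_dir="model/", n_neighbors=15, min_dist=0.5)
```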
/scDGD/functions/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn.model_selection import train_test_split
4 | from scDGD.classes import scDataset
5 |
6 | def prepate_data(adata, label_column, train_fraction=0.8, include_test=True, scaling_type='max', batch_size=256, num_w=0):
7 | '''
8 | Prepares the pytorch data sets and loaders for training and testing
9 |
10 | For integrating a new data set, set the train_fraction to 1. Otherwise there should always be something left for validation.
11 | If include_test is True, the split will also include a held-out test set. Otherwise, it will only be train and validation.
12 | '''
13 |
14 | ###
15 | # first create a data split
16 | ###
17 | labels = adata.obs[label_column]
18 |
19 | train_mode = True
20 | if 'train_val_test' not in adata.obs.keys():
21 | if train_fraction < 1.0:
22 | if include_test:
23 | train_indices, test_indices = train_test_split(np.arange(len(labels)), test_size=(1.0-train_fraction)/2, stratify=labels)
24 | train_indices, val_indices = train_test_split(train_indices, test_size=(((1.0-train_fraction)/2)/(1.0-(1.0-train_fraction)/2)), stratify=labels[train_indices])
25 | # add the split to the anndata object
26 | train_val_test = [''] * len(labels)
27 | train_val_test = ['train' if i in train_indices else train_val_test[i] for i in range(len(labels))]
28 | train_val_test = ['validation' if i in val_indices else train_val_test[i] for i in range(len(labels))]
29 | train_val_test = ['test' if i in test_indices else train_val_test[i] for i in range(len(labels))]
30 | else:
31 | train_indices, val_indices = train_test_split(np.arange(len(labels)), test_size=(1.0-train_fraction), stratify=labels)
32 | train_val_test = [''] * len(labels)
33 | train_val_test = ['train' if i in train_indices else train_val_test[i] for i in range(len(labels))]
34 | train_val_test = ['validation' if i in val_indices else train_val_test[i] for i in range(len(labels))]
35 | else:
36 | train_mode = False
37 | train_val_test = 'test'
38 | adata.obs['train_val_test'] = train_val_test
39 | else:
40 | if len(set(adata.obs['train_val_test'])) == 1:
41 | train_mode = False
42 | adata.obs['label'] = labels
43 | # make sure to afterwards also return the adata object so that the data split can be re-used
44 |
45 | ###
46 | # then create the data sets and loaders
47 | ###
48 |
49 | if train_mode:
50 | trainset = scDataset(
51 | adata.X,
52 | adata.obs,
53 | scaling_type=scaling_type,
54 | subset=np.where(adata.obs['train_val_test']=='train')[0],
55 | label_type='label'
56 | )
57 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_w)
58 | validationset = scDataset(
59 | adata.X,
60 | adata.obs,
61 | scaling_type=scaling_type,
62 | subset=np.where(adata.obs['train_val_test']=='validation')[0],
63 | label_type='label'
64 | )
65 | validationloader = torch.utils.data.DataLoader(validationset, batch_size=batch_size, shuffle=True, num_workers=num_w)
66 | if len(set(adata.obs['train_val_test'])) == 3:
67 | testset = scDataset(
68 | adata.X,
69 | adata.obs,
70 | scaling_type=scaling_type,
71 | subset=np.where(adata.obs['train_val_test']=='test')[0],
72 | label_type='label'
73 | )
74 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=num_w)
75 | else:
76 | testset, testloader = None, None
77 |     else:
78 |         # no training split available: only a test set is built, so the train and validation loaders stay None
79 |         trainloader, validationloader = None, None
80 |         testset = scDataset(
81 |             adata.X,
82 |             adata.obs,
83 |             scaling_type=scaling_type,
84 |             subset=np.where(adata.obs['train_val_test']=='test')[0],
85 |             label_type='label'
86 |         )
87 |         testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=num_w)
88 | 
89 |     return adata, trainloader, validationloader, testloader
90 | 
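91 | # Illustrative usage sketch (assumptions: `adata` is an AnnData object with raw
92 | # counts in adata.X and a cell-type annotation column; the file name and column
93 | # name below are hypothetical, not part of this package):
94 | #
95 | #   import scanpy as sc
96 | #   adata = sc.read_h5ad("my_data.h5ad")
97 | #   adata, trainloader, validationloader, testloader = prepate_data(
98 | #       adata, label_column="cell_type", train_fraction=0.8,
99 | #       include_test=True, scaling_type="max", batch_size=256
100 | #   )
101 | #
102 | # The returned adata carries the split in adata.obs['train_val_test'], so it can
103 | # be stored and reused to reproduce the same split later.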
--------------------------------------------------------------------------------
/scDGD/functions/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from sklearn.metrics import confusion_matrix
7 | from scipy.optimize import linear_sum_assignment as linear_assignment
8 | from sklearn import preprocessing
9 |
10 | from scDGD.classes.representation import RepresentationLayer
11 |
12 | def set_random_seed(seed):
13 | torch.manual_seed(seed)
14 | np.random.seed(seed)
15 | if torch.cuda.is_available():
16 | torch.cuda.manual_seed_all(seed)
17 |
18 | def reshape_scaling_factor(x, out_dim):
19 | start_dim = len(x.shape)
20 | for i in range(out_dim - start_dim):
21 | x = x.unsqueeze(1)
22 | return x
23 |
24 | def _make_cost_m(cm):
25 | s = np.max(cm)
26 | return (- cm + s)
27 |
28 | def gmm_clustering(gmm, rep):
29 | halfbeta = 0.5*torch.exp(gmm.logbeta)
30 | y = gmm.pi_term - (rep.unsqueeze(-2)-gmm.mean).square().mul(halfbeta).sum(-1) + gmm.betafactor*gmm.logbeta.sum(-1)
31 | # For each component multiply by mixture probs
32 | y += torch.log_softmax(gmm.weight,dim=0)
33 | return torch.exp(y)
34 |
35 | def gmm_cluster_acc(r, gmm, labels):
36 | le = preprocessing.LabelEncoder()
37 | le.fit(labels)
38 | true_labels = le.transform(labels)
39 | clustering = gmm_clustering(gmm, r.z.detach())
40 | cluster_labels = torch.max(clustering, dim=-1).indices.cpu().detach()
41 | cm = confusion_matrix(true_labels, cluster_labels)
42 | indexes = linear_assignment(_make_cost_m(cm))
43 | cm2 = cm[:,indexes[1]]
44 | acc2 = np.trace(cm2) / np.sum(cm2)
45 | return acc2
46 |
47 | def dgd_train(model, gmm, train_loader, validation_loader, n_epochs=500,
48 | export_dir='./', export_name='scDGD',
49 | lr_schedule_epochs=[0,300],lr_schedule=[[1e-3,1e-2,1e-2],[1e-4,1e-2,1e-2]], optim_betas=[0.5,0.7], wd=1e-4,
50 | acc_save_threshold=0.5,supervision_labels=None, wandb_logging=False):
51 |
52 | if wandb_logging:
53 | import wandb
54 |
55 | # prepare for saving the model
56 | if export_name is not None:
57 | if not os.path.exists(export_dir+export_name):
58 | os.makedirs(export_dir+export_name)
59 |
60 | # get some info from the data
61 | nsample = len(train_loader.dataset)
62 | nsample_test = len(validation_loader.dataset)
63 | out_dim = train_loader.dataset.n_genes
64 | latent = gmm.dim
65 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
66 |
67 | model=model.to(device)
68 | gmm=gmm.to(device)
69 |
70 | ###
71 | # set up representations and optimizers
72 | ###
73 |
74 | if lr_schedule_epochs is None:
75 | lr = lr_schedule[0]
76 | lr_rep = lr_schedule[1]
77 | lr_gmm = lr_schedule[2]
78 | else:
79 | lr = lr_schedule[0][0]
80 | lr_rep = lr_schedule[0][1]
81 | lr_gmm = lr_schedule[0][2]
82 |
83 | rep = RepresentationLayer(nrep=latent,nsample=nsample,values=torch.zeros(size=(nsample,latent))).to(device)
84 | test_rep = RepresentationLayer(nrep=latent,nsample=nsample_test,values=torch.zeros(size=(nsample_test,latent))).to(device)
85 | rep_optimizer = torch.optim.Adam(rep.parameters(), lr=lr_rep, weight_decay=wd,betas=(optim_betas[0],optim_betas[1]))
86 | testrep_optimizer = torch.optim.Adam(test_rep.parameters(), lr=lr_rep, weight_decay=wd,betas=(optim_betas[0],optim_betas[1]))
87 |
88 | if gmm is not None:
89 | gmm_optimizer = torch.optim.Adam(gmm.parameters(), lr=lr_gmm, weight_decay=wd,betas=(optim_betas[0],optim_betas[1]))
90 | model_optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd,betas=(optim_betas[0],optim_betas[1]))
91 |
92 | ###
93 | # start training
94 | ###
95 |
96 | # keep track of the losses and other metrics
97 | train_avg = []
98 | recon_avg = []
99 | dist_avg = []
100 | test_avg = []
101 | recon_test_avg = []
102 | dist_test_avg = []
103 | cluster_accuracies = []
104 |     best_gmm_cluster, best_gmm_epoch = 0, 0
105 |
106 | for epoch in range(n_epochs):
107 |
108 | # in case there is a scheduled change in learning rates, change them if the specified epoch is reached
109 | if lr_schedule_epochs is not None:
110 | if epoch in lr_schedule_epochs:
111 |                 lr_idx = lr_schedule_epochs.index(epoch)
112 | lr_decoder = lr_schedule[lr_idx][0]
113 | model_optimizer = torch.optim.Adam(model.parameters(), lr=lr_decoder, weight_decay=wd, betas=(optim_betas[0],optim_betas[1]))
114 | lr_rep = lr_schedule[lr_idx][1]
115 | lr_gmm = lr_schedule[lr_idx][2]
116 | if gmm is not None:
117 | gmm_optimizer = torch.optim.Adam(gmm.parameters(), lr=lr_gmm, weight_decay=wd, betas=(optim_betas[0],optim_betas[1]))
118 |
119 | # collect losses and other metrics for the epoch
120 | train_avg.append(0)
121 | recon_avg.append(0)
122 | dist_avg.append(0)
123 | test_avg.append(0)
124 | recon_test_avg.append(0)
125 | dist_test_avg.append(0)
126 |
127 | # train
128 | model.train()
129 | rep_optimizer.zero_grad()
130 | # standard mini batching
131 | for x,lib,i in train_loader:
132 | gmm_optimizer.zero_grad()
133 | model_optimizer.zero_grad()
134 |
135 | x = x.to(device)
136 | lib = lib.to(device)
137 | z = rep(i)
138 | y = model(z)
139 |
140 | # compute losses
141 | recon_loss_x = model.nb.loss(x, lib, y).sum()
142 | if supervision_labels is not None:
143 | sup_i = supervision_labels[i]
144 | gmm_error = -gmm(z,sup_i).sum()
145 | else:
146 | gmm_error = -gmm(z).sum()
147 | loss = recon_loss_x.clone() + gmm_error.clone()
148 |
149 | # backpropagate and update
150 | loss.backward()
151 | gmm_optimizer.step()
152 | model_optimizer.step()
153 |
154 | # log losses
155 | train_avg[-1] += loss.item()/(nsample*out_dim)
156 | recon_avg[-1] += recon_loss_x.item()/(nsample*out_dim)
157 | dist_avg[-1] += gmm_error.item()/(nsample*latent)
158 |
159 | # update representations
160 | rep_optimizer.step()
161 |
162 | # validation run
163 | model.eval()
164 | testrep_optimizer.zero_grad()
165 | # same as above, but without updates for gmm and model
166 | for x,lib,i in validation_loader:
167 | x = x.to(device)
168 | lib = lib.to(device)
169 | z = test_rep(i)
170 | y = model(z)
171 | recon_loss_x = model.nb.loss(x, lib, y).sum()
172 | gmm_error = -gmm(z).sum()
173 | loss = recon_loss_x.clone() + gmm_error.clone()
174 | loss.backward()
175 |
176 | test_avg[-1] += loss.item()/(nsample_test*out_dim)
177 | recon_test_avg[-1] += recon_loss_x.item()/(nsample_test*out_dim)
178 | dist_test_avg[-1] += gmm_error.item()/(nsample_test*latent)
179 | testrep_optimizer.step()
180 |
181 | cluster_accuracies.append(gmm_cluster_acc(rep, gmm, train_loader.dataset.get_labels()))
182 |
183 | save_here = False
184 | if best_gmm_cluster < cluster_accuracies[-1]:
185 | best_gmm_cluster = cluster_accuracies[-1]
186 | best_gmm_epoch = epoch
187 | if best_gmm_cluster >= acc_save_threshold:
188 | save_here = True
189 | elif epoch == n_epochs-1:
190 | save_here = True
191 |
192 | if wandb_logging:
193 | wandb.log({"loss_train": train_avg[-1],
194 | "loss_test": test_avg[-1],
195 | "loss_recon_train": recon_avg[-1],
196 | "loss_recon_test": recon_test_avg[-1],
197 | "loss_gmm_train": dist_avg[-1],
198 | "loss_gmm_test": dist_test_avg[-1],
199 | "cluster_accuracy": cluster_accuracies[-1],
200 | "epoch": epoch})
201 | else:
202 | # print progress every 10 epochs
203 | if epoch % 10 == 0:
204 | print("epoch "+str(epoch)+": train loss "+str(train_avg[-1])+", validation loss "+str(test_avg[-1])+", cluster accuracy "+str(cluster_accuracies[-1]))
205 |
206 | if export_name is not None:
207 | if save_here:
208 | if epoch == n_epochs-1:
209 |                     print("model saved at epoch "+str(epoch)+" because the end of training was reached")
210 |                 else:
211 |                     print("model saved at epoch "+str(epoch)+" for reaching the highest cluster accuracy so far: "+str(cluster_accuracies[-1]))
212 | torch.save(model.state_dict(), export_dir+export_name+'/'+export_name+'_decoder.pt')
213 | torch.save(rep.state_dict(), export_dir+export_name+'/'+export_name+'_representation.pt')
214 | torch.save(test_rep.state_dict(), export_dir+export_name+'/'+export_name+'_valRepresentation.pt')
215 | torch.save(gmm.state_dict(), export_dir+export_name+'/'+export_name+'_gmm.pt')
216 |
217 | if wandb_logging:
218 | wandb.run.summary["best_gmm_cluster"] = best_gmm_cluster
219 | wandb.run.summary["best_gmm_epoch"] = best_gmm_epoch
220 | # create a history that is returned after training
221 | history = pd.DataFrame(
222 | {'train_loss': train_avg,
223 | 'test_loss': test_avg,
224 | 'train_recon_loss': recon_avg,
225 | 'test_recon_loss': recon_test_avg,
226 | 'train_gmm_loss': dist_avg,
227 | 'test_gmm_loss': dist_test_avg,
228 | 'cluster_accuracy': cluster_accuracies,
229 | 'epoch': np.arange(1,n_epochs+1)
230 | })
231 | return model, rep, test_rep, gmm, history
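232 | 
233 | # Illustrative usage sketch (assumptions: `model` is a scDGD.models.DGD decoder,
234 | # `gmm` is a Gaussian mixture prior instance such as the GaussianMixture class
235 | # used elsewhere in this package, and the loaders come from
236 | # scDGD.functions.data.prepate_data; the export directory name is hypothetical):
237 | #
238 | #   set_random_seed(0)
239 | #   model, rep, test_rep, gmm, history = dgd_train(
240 | #       model, gmm, trainloader, validationloader,
241 | #       n_epochs=500, export_dir="./results/", export_name="scDGD_run",
242 | #       wandb_logging=False
243 | #   )
244 | #
245 | # `history` is a pandas DataFrame with one row per epoch (losses and cluster
246 | # accuracy) that can be used to monitor convergence.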
--------------------------------------------------------------------------------
/scDGD/models/__init__.py:
--------------------------------------------------------------------------------
1 | from scDGD.models.models import DGD
--------------------------------------------------------------------------------
/scDGD/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from scDGD.classes.output_distributions import NBLayer
4 | import json
5 |
6 |
7 |
8 | class DGD(nn.Module):
9 | def __init__(
10 | self,
11 | out,
12 | latent=20,
13 | hidden=[100, 100, 100],
14 | r_init=2,
15 | output_prediction_type="mean",
16 | output_activation="softplus"
17 | ):
18 | super(DGD, self).__init__()
19 |
20 | self.main = nn.ModuleList()
21 |
22 | if type(hidden) is not int:
23 | n_hidden = len(hidden)
24 | self.main.append(nn.Linear(latent, hidden[0]))
25 | self.main.append(nn.ReLU(True))
26 | for i in range(n_hidden - 1):
27 | self.main.append(nn.Linear(hidden[i], hidden[i + 1]))
28 | self.main.append(nn.ReLU(True))
29 | self.main.append(nn.Linear(hidden[-1], out))
30 | else:
31 | self.main.append(nn.Linear(latent, hidden))
32 | self.main.append(nn.ReLU(True))
33 | self.main.append(nn.Linear(hidden, out))
34 |
35 | self.nb = NBLayer(
36 | out,
37 | r_init=r_init,
38 | output=output_prediction_type,
39 | activation=output_activation
40 | )
41 |
42 | def forward(self, z):
43 | for i in range(len(self.main)):
44 | z = self.main[i](z)
45 | return self.nb(z)
46 |
47 | @classmethod
48 | def load(cls, save_dir="./"):
49 | # get saved hyper-parameters
50 | with open(save_dir + "dgd_hyperparameters.json", "r") as fp:
51 | param_dict = json.load(fp)
52 |
53 |         # map the saved hyper-parameters onto the constructor arguments
54 |         model = cls(
55 |             out=param_dict["n_genes"],
56 |             latent=param_dict["latent"],
57 |             hidden=param_dict["hidden"],
58 |             r_init=param_dict["r_init"],
59 |             output_prediction_type=param_dict["output"],
60 |             output_activation=param_dict["activation"],
61 |         )
62 |
63 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
64 | model.load_state_dict(
65 | torch.load(
66 | save_dir + param_dict["name"] + "_decoder.pt", map_location=device
67 | )
68 | )
69 | return model
70 |
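71 | # Illustrative sketch (the gene, batch and latent sizes below are arbitrary
72 | # assumptions, not defaults of this package): build a decoder and push a random
73 | # batch of representations through it.
74 | #
75 | #   decoder = DGD(out=2000, latent=20, hidden=[100, 100, 100])
76 | #   z = torch.randn(8, 20)   # 8 hypothetical representation vectors
77 | #   y = decoder(z)           # output of the negative-binomial output layer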
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | #import
3 | setup(name='scDGD',
4 | version="0.2",
5 | #description='Classes for internal representations',
6 | author='Viktoria Schuster',
7 | #url='https://github.com/viktoriaschuster/DGD_paper_experiments',
8 | packages=['scDGD','scDGD.classes','scDGD.models','scDGD.functions']
9 | #packages=find_packages()#,
10 | #package_dir = {'dgdExp': 'src'}
11 | )
12 |
13 | # install by changing into this directory and running
14 | # python3 setup.py build
15 | # python3 setup.py install --user
16 | # (--user installs into the user site-packages and avoids the need for system-wide write access)
--------------------------------------------------------------------------------