├── LICENSE
├── README.md
├── gene_sets
│   ├── Mouse_TF_targets.txt
│   ├── c2.cp.v7.4.symbols.gmt
│   ├── c2.cp.wikipathways.v7.4.symbols.gmt
│   ├── c5.go.bp.v7.4.symbols.gmt
│   └── c8.all.v7.4.symbols.gmt
├── setup.py
├── tutorails
│   ├── Data_preprocessing.ipynb
│   ├── HuBMAP_datasets_ID.xlsx
│   ├── UNIFAN_cluster_annotations.ipynb
│   ├── UNIFAN_example.ipynb
│   └── getExample.py
├── unifan-main.png
├── unifan-pretrain.png
└── unifan
    ├── __init__.py
    ├── annocluster.py
    ├── autoencoder.py
    ├── classifier.py
    ├── datasets.py
    ├── main.py
    ├── networks.py
    ├── trainer.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Dora Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | UNIFAN (**Un**supervised S**i**ngle-cell **F**unctional **An**notation) simultaneously clusters and annotates cells with known biological processes (including pathways). For each single cell, UNIFAN first infers gene set activity scores (denoted by r in the figure below) associated with this cell using the input gene sets.
3 |
4 | 
5 |
6 | Next, UNIFAN clusters cells using the learned gene set activity scores (r) and a reduced-dimension representation of the expression of genes in the cell. The gene set activity scores are used by an “annotator” to guide the clustering such that cells sharing similar biological processes are more likely to be grouped together. This design allows the method to focus on the key processes when clustering cells, and so it can overcome issues related to noise and dropout while simultaneously selecting marker gene sets which can be used to annotate clusters. To allow the selection of marker genes for each cluster, we also add a subset of the most variable genes selected using Seurat v3 (Stuart et al., [2019](https://doi.org/10.1016/j.cell.2019.05.031)) as features for the annotator.
7 |
8 | 
9 |
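   | The clustering and annotation are tied together through a single training objective: reconstruction losses from both the continuous encodings and the cluster centroids, plus an annotator term that pulls cells toward the centroids the annotator considers likely. A minimal conceptual sketch of this objective (tensor names are illustrative; the actual computation is `AnnoCluster.loss` in [annocluster.py](unifan/annocluster.py); `tau` corresponds to the `-u/--tau` option below, while `zeta` weights the centroid reconstruction inside the model):
   |
   | ```python
   | def annocluster_objective(mse_e, mse_q, z_dist, annotator_prob, tau=10.0, zeta=1.0):
   |     # reconstruction from the continuous encodings plus weighted reconstruction from centroids
   |     reconstruction = mse_e + zeta * mse_q
   |     # distances to centroids weighted by the annotator's cluster probabilities
   |     annotator_term = (z_dist * annotator_prob).mean()
   |     return reconstruction + tau * annotator_term
   | ```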
10 | ## Table of Contents
11 | - [Get started](#Get-started)
12 | - [Command-line tools](#Command-line)
13 | - [Tutorials](#Tutorials)
14 | - [Updates log](#Updates-log)
15 | - [Learn-more](#Learn-more)
16 | - [Credits](#Credits)
17 |
18 | # Get-started
19 | ## Prerequisites
20 | * Python >= 3.6
21 | * Python side-packages:
22 |   - pytorch >= 1.9.0
23 |   - numpy >= 1.19.2
24 |   - pandas >= 1.1.5
25 |   - scanpy >= 1.7.2
26 |   - leidenalg >= 0.8.4
27 |   - tqdm >= 4.61.1
28 |   - scikit-learn >= 0.24.2
29 |   - umap-learn >= 0.5.1
30 |   - matplotlib >= 3.3.4
31 |   - seaborn >= 0.11.0
32 |
33 | ## Installation
34 |
35 | ### Install within a virtual environment
36 |
37 | It is recommended to use a virtual environment/package manager such as [Anaconda](https://www.anaconda.com/). After successfully installing Anaconda/Miniconda, create an environment as follows:
38 |
39 | ```shell
40 | conda create -n myenv python=3.6
41 | ```
42 |
43 | You can then install and run the package in the virtual environment. Activate the virtual environment by:
44 |
45 | ```shell
46 | conda activate myenv
47 | ```
48 |
49 | Make sure you have **pip** installed in your environment. You may check by running:
50 |
51 | ```shell
52 | conda list
53 | ```
54 |
55 | If not installed, then:
56 |
57 | ```shell
58 | conda install pip
59 | ```
60 | ### Install Pytorch
61 |
62 | UNIFAN is built on PyTorch and supports both CPU and GPU. Make sure you have PyTorch (>= 1.9.0) installed in your virtual environment. If not, please visit [PyTorch](https://pytorch.org/) and install the appropriate version.
63 |
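   | You can check your current installation, and whether a GPU is visible to it, with for example:
   |
   | ```python
   | import torch
   |
   | print(torch.__version__)          # should be >= 1.9.0
   | print(torch.cuda.is_available())  # True if training can use a GPU (see the -c/--cuda flag below)
   | ```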
64 | ### Install UNIFAN
65 |
66 | Install by:
67 |
68 | ```shell
69 | pip install git+https://github.com/doraadong/UNIFAN.git
70 | ```
71 |
72 | If you want to upgrade UNIFAN to the latest version, then first uninstall it by:
73 |
74 | ```shell
75 | pip uninstall unifan
76 | ```
77 |
78 | And then just run the pip install command again.
79 |
80 | # Command-line
81 |
82 | You may import UNIFAN as a package and use it in your code (see [Tutorials](#Tutorials) for details), or you may train models using the following command-line tool.
83 |
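   | For example, a minimal sketch of constructing the two models directly in Python (the constructor arguments follow the signatures in `unifan/autoencoder.py` and `unifan/annocluster.py`; the values shown are illustrative, and the tutorials cover the full training workflow):
   |
   | ```python
   | from unifan.autoencoder import autoencoder
   | from unifan.annocluster import AnnoCluster
   |
   | # autoencoder used to pretrain the low-dimensional representation z
   | ae = autoencoder(input_dim=2000, z_dim=32, reconstruction_network="gaussian",
   |                  decoding_network="gaussian")
   |
   | # clustering model guided by the annotator
   | model = AnnoCluster(input_dim=2000, z_dim=32, tau=10.0, n_clusters=16)
   | ```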
84 | ## Run UNIFAN
85 |
86 | Run UNIFAN by (example arguments shown):
87 |
88 | ```shell
89 | main.py -i ../example/input/Limb_Muscle.h5ad -o ../example/output -p tabula_muris -t Limb_Muscle -l cell_ontology_class -e ../gene_sets/
90 | ```
91 | The usage of this command is listed as follows. Note that only the first five inputs are required:
92 |
93 | ```shell
94 | usage: main.py [-h] -i INPUT -o OUTPUT -p PROJECT -t TISSUE [-e GENESETSPATH]
95 | [-l LABEL] [-v VARIABLE] [-r PRIOR]
96 | [-f {gene_sets,gene,gene_gene_sets}] [-a ALPHA] [-b BETA]
97 | [-g GAMMA] [-u TAU] [-d DIM] [-s BATCH] [-na NANNO]
98 | [-ns NSCORE] [-nu NAUTO] [-nc NCLUSTER] [-nze NZENCO]
99 | [-nzd NZDECO] [-dze DIMZENCO] [-dzd DIMZDECO] [-nre NRENCO]
100 | [-dre DIMRENCO] [-drd DIMRDECO]
101 | [-n {sigmoid,non-negative,gaussian}] [-m SEED] [-c CUDA]
102 | [-w NWORKERS]
103 |
104 | optional arguments:
105 | -h, --help show this help message and exit
106 | -i INPUT, --input INPUT
107 | string, path to the input expression data, default
108 | '../input/data.h5ad'
109 | -o OUTPUT, --output OUTPUT
110 | string, path to the output folder, default
111 | '../output/'
112 | -p PROJECT, --project PROJECT
113 | string, identifier for the project, e.g., tabula_muris
114 | -t TISSUE, --tissue TISSUE
115 | string, tissue where the input data is sampled from
116 | -e GENESETSPATH, --geneSetsPath GENESETSPATH
117 | string, path to the folder where gene sets can be
118 | found, default='../gene_sets/'
119 | -l LABEL, --label LABEL
120 | string, optional, the column / field name of the
121 | ground truth label, if available; used for evaluation
122 | only; default None
123 | -v VARIABLE, --variable VARIABLE
124 | string, optional, the column / field name of the
125 | highly variable genes; default 'highly_variable'
126 | -r PRIOR, --prior PRIOR
127 | string, optional, gene set file names used to learn
128 | the gene set activity scores, use '+' to separate
129 | multiple gene set names, default
130 | c5.go.bp.v7.4.symbols.gmt+c2.cp.v7.4.symbols.gmt+TF-
131 | DNA
132 | -f {gene_sets,gene,gene_gene_sets}, --features {gene_sets,gene,gene_gene_sets}
133 | string, optional, features used for the annotator, any
134 | of 'gene_sets', 'gene' or 'gene_gene_sets', default
135 | 'gene_gene_sets'
136 | -a ALPHA, --alpha ALPHA
137 | float, optional, hyperparameter for the L1 term in the
138 | set cover loss, default 1e-2
139 | -b BETA, --beta BETA float, optional, hyperparameter for the set cover term
140 | in the set cover loss, default 1e-5
141 | -g GAMMA, --gamma GAMMA
142 | float, optional, hyperparameter for the exclusive L1
143 | term, default 1e-3
144 | -u TAU, --tau TAU float, optional, hyperparameter for the annotator
145 | loss, default 10
146 | -d DIM, --dim DIM integer, optional, dimension for the low-dimensional
147 | representation, default 32
148 | -s BATCH, --batch BATCH
149 | integer, optional, batch size for training except for
150 | pretraining annotator (fixed at 32), default 128
151 | -na NANNO, --nanno NANNO
152 | integer, optional, number of epochs to pretrain the
153 | annotator, default 50
154 | -ns NSCORE, --nscore NSCORE
155 | integer, optional, number of epochs to train the gene
156 | set activity model, default 70
157 | -nu NAUTO, --nauto NAUTO
158 | integer, optional, number of epochs to pretrain the
159 | annocluster model, default 50
160 | -nc NCLUSTER, --ncluster NCLUSTER
161 | integer, optional, number of epochs to train the
162 | annocluster model, default 25
163 | -nze NZENCO, --nzenco NZENCO
164 | float, optional, number of hidden layers for encoder
165 | of annocluster, default 3
166 | -nzd NZDECO, --nzdeco NZDECO
167 | float, optional, number of hidden layers for decoder
168 | of annocluster, default 2
169 | -dze DIMZENCO, --dimzenco DIMZENCO
170 | integer, optional, number of nodes for hidden layers
171 | for encoder of annocluster, default 128
172 | -dzd DIMZDECO, --dimzdeco DIMZDECO
173 | integer, optional, number of nodes for hidden layers
174 | for decoder of annocluster, default 128
175 | -nre NRENCO, --nrenco NRENCO
176 | integer, optional, number of hidden layers for the
177 | encoder of gene set activity scores model, default 5
178 | -dre DIMRENCO, --dimrenco DIMRENCO
179 | integer, optional, number of nodes for hidden layers
180 | for encoder of gene set activity scores model, default
181 | 128
182 | -drd DIMRDECO, --dimrdeco DIMRDECO
183 | integer, optional, number of nodes for hidden layers
184 | for decoder of gene set activity scores model, default
185 | 128
186 | -n {sigmoid,non-negative,gaussian}, --network {sigmoid,non-negative,gaussian}
187 | string, optional, the encoder for the gene set
188 | activity model, any of 'sigmoid', 'non-negative' or
189 | 'gaussian', default 'non-negative'
190 | -m SEED, --seed SEED integer, optional, random seed for the initialization,
191 | default 0
192 | -c CUDA, --cuda CUDA boolean, optional, if use GPU for neural network
193 | training, default False
194 | -w NWORKERS, --nworkers NWORKERS
195 | integer, optional, number of workers for dataloader,
196 | default 8
197 |
198 | ```
199 |
200 |
201 | # Tutorials
202 |
203 | GitHub rendering disables some functionality of Jupyter notebooks. We recommend using [nbviewer](https://nbviewer.jupyter.org/) to view the following tutorials.
204 |
205 | ## Run UNIFAN on example data
206 | In the [UNIFAN training tutorial](tutorails/UNIFAN_example.ipynb), we illustrate how to run UNIFAN step by step on the example data: Limb_Muscle from Tabula Muris.
207 |
208 | ### Download and Preprocess the Input Data
209 | You may download the gene sets in [gene_sets](gene_sets). By default, we use the GO terms for biological processes (c5.go.bp.v7.4.symbols.gmt), canonical pathways (c2.cp.v7.4.symbols.gmt) and the TF-DNA interactions data (Mouse_TF_targets.txt).
210 |
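   | Each `.gmt` file is tab-separated: one gene set per line, holding the set name, a description and then the member genes. UNIFAN loads these through `getGeneSetMatrix` in `unifan/utils.py`; a minimal sketch of reading such a file yourself:
   |
   | ```python
   | def read_gmt(path):
   |     """Return {gene set name: list of member genes} from a .gmt file."""
   |     gene_sets = {}
   |     with open(path) as f:
   |         for line in f:
   |             name, _description, *genes = line.rstrip("\n").split("\t")
   |             gene_sets[name] = genes
   |     return gene_sets
   | ```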
211 | UNIFAN takes AnnData files as input. See [AnnData](https://anndata.readthedocs.io/en/latest/) for details. To prepare the example data (Limb_Muscle in Tabula Muris), first download the [Tabula Muris senis data](https://figshare.com/ndownloader/files/24351086). Then run the Python script [getExample.py](tutorails/getExample.py) to preprocess the count data using the following command:
212 |
213 | ```shell
214 | python getExample.py -p ./facs.h5ad -i ../example/input -t Limb_Muscle
215 |
216 | ```
217 | The usage of this command is listed as follows:
218 |
219 | ```shell
220 | usage: getExample.py [-h] -p PATH -i FOLDER -t TISSUE [-k TOPK]
221 |
222 | optional arguments:
223 | -h, --help show this help message and exit
224 | -p PATH, --path PATH string, path to the downloaded data, default
225 | './facs.h5ad'
226 | -i FOLDER, --folder FOLDER
227 | string, path to the folder to save the data, default
228 | '../example/input'
229 | -t TISSUE, --tissue TISSUE
230 | string, specify the output tissue; if using the
231 | default None, then all tissues will be outputted and
232 | saved separately in the folder; default None
233 | -k TOPK, --topk TOPK integer, optional, number of most variable genes,
234 | default 2000
235 |
236 | ```
237 |
238 | We also provide a [data preprocessing notebook](tutorails/Data_preprocessing.ipynb) showing how we preprocessed the other datasets used in the manuscript.
239 |
240 | ## Analyze results and annotate clusters
241 | In the [cluster annotation tutorial](tutorails/UNIFAN_cluster_annotations.ipynb), we illustrate how to use the coefficients learned by UNIFAN to annotate clusters. In particular, we show how to select representative gene sets / genes for each cluster, evaluate whether the selected genes are likely marker genes, and visualize the annotations.
242 |
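   | Conceptually, selecting representative gene sets for a cluster amounts to ranking the annotator's coefficients for that cluster. A hedged sketch of the idea (variable names are illustrative, not the tutorial's API):
   |
   | ```python
   | import numpy as np
   |
   | def top_gene_sets(coef, names, cluster, k=10):
   |     # coef: (n_clusters, n_features) coefficient matrix learned by the annotator
   |     order = np.argsort(coef[cluster])[::-1][:k]
   |     return [(names[i], coef[cluster, i]) for i in order]
   | ```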
243 | # Updates-log
244 | * 10-11-2022:
245 |   - Added a tutorial on preprocessing the datasets used in the manuscript
246 |
247 | # Learn-more
248 | Check our paper at [Genome Research](https://genome.cshlp.org/content/early/2022/06/28/gr.276609.122.long). Link to [preprint](https://www.biorxiv.org/content/10.1101/2021.11.20.469410v2).
249 |
250 | # Credits
251 | The software is an implementation of the method UNIFAN, jointly developed by [Dongshunyi "Dora" Li](https://github.com/doraadong) and Ziv Bar-Joseph from the [Systems Biology Group @ Carnegie Mellon University](http://sb.cs.cmu.edu/) and [Jun Ding](https://github.com/phoenixding) from McGill University.
252 |
253 | # Contacts
254 | * dongshul at andrew.cmu.edu
255 |
256 | # License
257 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
258 |
259 |
260 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # partly borrowed from https://github.com/navdeep-G/setup.py/blob/master/setup.py
5 | import io
6 | import os
7 | from setuptools import setup, find_packages
8 | import pathlib
9 |
10 | # Package meta-data.
11 | NAME = 'unifan'
12 | DESCRIPTION = 'Unsupervised cell functional annotation'
13 | URL = "https://github.com/doraadong/UNIFAN"
14 | EMAIL = 'dongshul@andrew.cmu.edu'
15 | AUTHOR = 'Dora Li'
16 | REQUIRES_PYTHON = '>=3.6'
17 | VERSION = '1.0.0'
18 |
19 | # What packages are required for this module to be executed?
20 | REQUIRED = ["torch", "numpy>=1.19.2", "pandas>=1.1.5", "scanpy>=1.7.2", "leidenalg>=0.8.4", "tqdm>=4.61.1",
21 | "scikit-learn>=0.24.2", "umap-learn>=0.5.1", "matplotlib>=3.3.4", "seaborn>=0.11.0"]
22 |
23 | # The rest you shouldn't have to touch too much :)
24 | # ------------------------------------------------
25 | # Except, perhaps the License and Trove Classifiers!
26 | # If you do change the License, remember to change the Trove Classifier for that!
27 |
28 | here = os.path.abspath(os.path.dirname(__file__))
29 |
30 | # Import the README and use it as the long-description.
31 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file!
32 | try:
33 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
34 | long_description = '\n' + f.read()
35 | except FileNotFoundError:
36 | long_description = DESCRIPTION
37 |
38 | # Load the package's __version__.py module as a dictionary.
39 | about = {}
40 | if not VERSION:
41 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_")
42 | with open(os.path.join(here, project_slug, '__version__.py')) as f:
43 | exec(f.read(), about)
44 | else:
45 | about['__version__'] = VERSION
46 |
47 | setup(
48 | name=NAME,
49 | version=about['__version__'],
50 | description=DESCRIPTION,
51 | long_description=long_description,
52 | long_description_content_type='text/markdown',
53 | author=AUTHOR,
54 | author_email=EMAIL,
55 | python_requires=REQUIRES_PYTHON,
56 | url=URL,
57 | packages=['unifan'],
58 | scripts=['unifan/main.py'],
59 | install_requires=REQUIRED,
60 | include_package_data=True,
61 | license='MIT',
62 | classifiers=[
63 | 'License :: OSI Approved :: MIT License',
64 | 'Programming Language :: Python :: 3.6',
65 | ],
66 | )
67 |
--------------------------------------------------------------------------------
/tutorails/Data_preprocessing.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os, sys\n",
10 | "import argparse\n",
11 | "import time\n",
12 | "from os.path import exists\n",
13 | "import collections\n",
14 | "from typing import Iterable\n",
15 | "import pickle \n",
16 | "from collections import Counter\n",
17 | "\n",
18 | "import scanpy as sc\n",
19 | "import pandas as pd\n",
20 | "import numpy as np\n",
21 | "from matplotlib import pyplot as plt"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "# Preprocess input data \n",
29 | "\n",
30 | "Here we show how we preprocessed the datasets we used in our manuscript, except for the Tabula Muris datasets which we already covered in our examples (see [getExample.py](https://github.com/doraadong/UNIFAN/blob/main/tutorails/getExample.py) for details). Other than the pacakges imported above, we also use [mygene](https://pypi.org/project/mygene/) package to convert ENSEMBL IDs to gene symbols for the HuBMAP datasets. \n",
31 | "\n",
32 | "**Table of Content**\n",
33 | "\n",
34 | "1. [Preprocess pbmc28K](#1)\n",
35 | "2. [Preprocess pbmc68K](#2)\n",
36 | "3. [Preprocess Atlas lung](#3)\n",
37 | "4. [Preprocess HuBMAP datasets](#4)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 3,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "cell_cutoffs = {'lung': 500, 'lymph_node':200, 'spleen':200, 'thymus':200, 'pbmc68k':200, 'pbmc28k':200}\n",
47 | "gene_cutoffs = {'lung': 5, 'lymph_node':3, 'spleen':3, 'thymus':3, 'pbmc68k':3, 'pbmc28k':3}"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 4,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "save_raw = False\n",
57 | "save_processed = True\n",
58 | "\n",
59 | "topk = 2000\n",
60 | "select_genes = False\n",
61 | "\n",
62 | "if select_genes: \n",
63 | " condition = f\"_top{topk}\"\n",
64 | "else:\n",
65 | " condition = f\"\""
66 | ]
67 | },
68 | {
69 | "cell_type": "markdown",
70 | "metadata": {},
71 | "source": [
72 | "## 1. Preprocess pbmc28k\n",
73 | ""
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 6,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "tissue = \"pbmc28k\"\n",
83 | "data_name = \"pbmc\"\n",
84 | "\n",
85 | "parent_folder = f\"../input/{data_name}\"\n",
86 | "labels = pd.read_csv(f\"{parent_folder}/{tissue}/barcodes_to_cell_types.tsv\", sep=\"\\t\")\n",
87 | "\n",
88 | "barcodes = [i.split('_')[0] for i in labels['barcode']]\n",
89 | "lanes = [i.split('_')[1] for i in labels['barcode']]\n",
90 | "labels['code'] = barcodes\n",
91 | "labels['lane'] = lanes"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 7,
97 | "metadata": {},
98 | "outputs": [
99 | {
100 | "name": "stdout",
101 | "output_type": "stream",
102 | "text": [
103 | "AnnData object with n_obs × n_vars = 4309 × 32738\n",
104 | " var: 'gene_ids'\n",
105 | "AnnData object with n_obs × n_vars = 3844 × 32738\n",
106 | " var: 'gene_ids'\n",
107 | "AnnData object with n_obs × n_vars = 4998 × 32738\n",
108 | " var: 'gene_ids'\n",
109 | "AnnData object with n_obs × n_vars = 3392 × 32738\n",
110 | " var: 'gene_ids'\n",
111 | "AnnData object with n_obs × n_vars = 2868 × 32738\n",
112 | " var: 'gene_ids'\n",
113 | "AnnData object with n_obs × n_vars = 3685 × 32738\n",
114 | " var: 'gene_ids'\n",
115 | "AnnData object with n_obs × n_vars = 3198 × 32738\n",
116 | " var: 'gene_ids'\n",
117 | "AnnData object with n_obs × n_vars = 2561 × 32738\n",
118 | " var: 'gene_ids'\n"
119 | ]
120 | }
121 | ],
122 | "source": [
123 | "full_data = None\n",
124 | "pre_genes = None \n",
125 | "\n",
126 | "for i in range(1, 9):\n",
127 | " cur_lane = f\"lane_{i}\"\n",
128 | " \n",
129 | " parent_folder = f\"../input/{data_name}\"\n",
130 | " adata = sc.read_10x_mtx(f\"{parent_folder}/{tissue}/{cur_lane}/\", var_names='gene_symbols', cache=True) \n",
131 | " adata.var_names_make_unique() \n",
132 | " \n",
133 | " print(adata)\n",
134 | " \n",
135 | " # get observation idx\n",
136 | " _codes = [i.split('-')[0] for i in adata.obs.index.values]\n",
137 | " # all barcodes having labels are in the expression file \n",
138 | " assert len(set(labels[labels['lane'] == f\"lane{i}\"].code.values) - set(_codes)) == 0\n",
139 | "\n",
140 | " _others = [i.split('-')[1] for i in adata.obs.index.values]\n",
141 | " assert (np.array(_others) == '1').all()\n",
142 | " \n",
143 | " # get labels \n",
144 | " cur_labels = labels[labels['lane'] == f\"lane{i}\"].copy().reset_index()\n",
145 | " cur_labels = cur_labels[['code', 'cell_type']]\n",
146 | "\n",
147 | " df_codes = pd.DataFrame(_codes)\n",
148 | " df_codes.columns = ['code']\n",
149 | " df_codes = df_codes.merge(cur_labels, left_on = 'code', right_on = 'code', how = 'left')\n",
150 | "\n",
151 | " # same order as in adata\n",
152 | " assert (df_codes['code'].values == np.array(_codes)).all()\n",
153 | " \n",
154 | " adata.obs['label'] = df_codes['cell_type'].values\n",
155 | " \n",
156 | " if full_data is not None:\n",
157 | " full_data = full_data.concatenate(adata)\n",
158 | " else:\n",
159 | " full_data = adata\n",
160 | " \n",
161 | " # check if gene id same for all data\n",
162 | " cur_genes = adata.var['gene_ids'].values\n",
163 | " if pre_genes is not None:\n",
164 | " assert (pre_genes == cur_genes).all()\n",
165 | " pre_genes = cur_genes"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": 8,
171 | "metadata": {
172 | "scrolled": false
173 | },
174 | "outputs": [],
175 | "source": [
176 | "# keep only those with labels\n",
177 | "full_data = full_data[~full_data.obs.label.isna()]"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": 9,
183 | "metadata": {
184 | "scrolled": true
185 | },
186 | "outputs": [
187 | {
188 | "data": {
189 | "text/plain": [
190 | "View of AnnData object with n_obs × n_vars = 25185 × 32738\n",
191 | " obs: 'label', 'batch'\n",
192 | " var: 'gene_ids'"
193 | ]
194 | },
195 | "execution_count": 9,
196 | "metadata": {},
197 | "output_type": "execute_result"
198 | }
199 | ],
200 | "source": [
201 | "full_data"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 10,
207 | "metadata": {},
208 | "outputs": [
209 | {
210 | "name": "stderr",
211 | "output_type": "stream",
212 | "text": [
213 | "Trying to set attribute `.obs` of view, copying.\n",
214 | "... storing 'label' as categorical\n"
215 | ]
216 | }
217 | ],
218 | "source": [
219 | "# filtering \n",
220 | "sc.pp.filter_cells(full_data, min_genes=cell_cutoffs[tissue])\n",
221 | "sc.pp.filter_genes(full_data, min_cells=gene_cutoffs[tissue])\n",
222 | "\n",
223 | "if save_raw:\n",
224 | " filename = f\"{tissue}_raw{condition}.h5ad\"\n",
225 | " full_data.write_h5ad(os.path.join(parent_folder, filename))\n",
226 | "\n",
227 | "if save_processed:\n",
228 | " # selecting genes using Seurat v3 method \n",
229 | " sc.pp.highly_variable_genes(full_data, n_top_genes=topk, flavor='seurat_v3')\n",
230 | " if select_genes:\n",
231 | " full_data = full_data[:, full_data.var.highly_variable]\n",
232 | "\n",
233 | " # normalize \n",
234 | " sc.pp.normalize_total(full_data, target_sum=1e4)\n",
235 | " sc.pp.log1p(full_data)\n",
236 | "\n",
237 | " # adata.raw = adata\n",
238 | " sc.pp.scale(full_data, max_value=10)\n",
239 | "\n",
240 | " # save \n",
241 | " try:\n",
242 | " full_data.__dict__['_raw'].__dict__['_var'] = full_data.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})\n",
243 | " except AttributeError:\n",
244 | " pass \n",
245 | "\n",
246 | " filename = f\"{tissue}_processed{condition}.h5ad\"\n",
247 | " full_data.write_h5ad(os.path.join(parent_folder, filename))"
248 | ]
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "metadata": {},
253 | "source": [
254 | "## 2. Preprocessing pbmc68k\n",
255 | ""
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": 6,
261 | "metadata": {},
262 | "outputs": [
263 | {
264 | "name": "stderr",
265 | "output_type": "stream",
266 | "text": [
267 | "... storing 'label' as categorical\n"
268 | ]
269 | }
270 | ],
271 | "source": [
272 | "tissue = \"pbmc68k\"\n",
273 | "data_name = \"pbmc\"\n",
274 | "\n",
275 | "parent_folder = f\"../input/{data_name}\"\n",
276 | "full_data = sc.read_10x_mtx(f\"{parent_folder}/{tissue}/filtered_matrices_mex/hg19/\", var_names='gene_symbols',cache=True) \n",
277 | "full_data.var_names_make_unique() \n",
278 | "\n",
279 | "# get labels \n",
280 | "labels = pd.read_csv(f\"{parent_folder}/{tissue}/zheng17-cell-labels.txt\", sep = \"\\t\")\n",
281 | "\n",
282 | "# check if barcodes same\n",
283 | "barcodes = pd.read_csv(\"../input/pbmc/pbmc68k/filtered_matrices_mex/hg19/barcodes.tsv\", header = None)\n",
284 | "assert (labels['barcode'].values == barcodes[0].values).all()\n",
285 | "\n",
286 | "full_data.obs['label'] = labels['bulk_labels'].values\n",
287 | "\n",
288 | "# filtering \n",
289 | "sc.pp.filter_cells(full_data, min_genes=cell_cutoffs[tissue]) \n",
290 | "sc.pp.filter_genes(full_data, min_cells=gene_cutoffs[tissue])\n",
291 | "\n",
292 | "if save_raw:\n",
293 | " filename = f\"{tissue}_raw{condition}.h5ad\"\n",
294 | " full_data.write_h5ad(os.path.join(parent_folder, filename))\n",
295 | " \n",
296 | "if save_processed:\n",
297 | " # selecting genes using Seurat v3 method \n",
298 | " sc.pp.highly_variable_genes(full_data, n_top_genes=topk, flavor='seurat_v3')\n",
299 | " if select_genes:\n",
300 | " full_data = full_data[:, full_data.var.highly_variable]\n",
301 | "\n",
302 | " # normalie \n",
303 | " sc.pp.normalize_total(full_data, target_sum=1e4)\n",
304 | " sc.pp.log1p(full_data)\n",
305 | "\n",
306 | " # adata.raw = adata\n",
307 | " sc.pp.scale(full_data, max_value=10)\n",
308 | "\n",
309 | " # save \n",
310 | " try:\n",
311 | " full_data.__dict__['_raw'].__dict__['_var'] = full_data.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})\n",
312 | " except AttributeError:\n",
313 | " pass \n",
314 | "\n",
315 | " filename = f\"{tissue}_processed{condition}.h5ad\"\n",
316 | " full_data.write_h5ad(os.path.join(parent_folder, filename))"
317 | ]
318 | },
319 | {
320 | "cell_type": "markdown",
321 | "metadata": {},
322 | "source": [
323 | "## 3. Preprocess Atlas lung\n",
324 | "\n",
325 | "\n",
326 | "Before running the following code, make sure you have done the following: \n",
327 | "\n",
328 | "1. Download the following two files from [GSE136831](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE136831)\n",
329 | " * GSE136831_AllCells.Samples.CellType.MetadataTable.txt\n",
330 | " * GSE136831_RawCounts_Sparse.mtx \n",
331 | "2. Take only samples where \"Disease_Identity\" == \"Control\"\n",
332 | "3. Make a AnnData (i.e. lung.h5ad) using the raw counts as the expression and the meta data as the observation matrix. Make sure to include the following two columns in the observation matrix:\n",
333 | " * CellType_Category - coarse-grained label \n",
334 | " * Manuscript_Identity - fine-grained label "
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": null,
340 | "metadata": {},
341 | "outputs": [],
342 | "source": [
343 | "tissue = \"lung\"\n",
344 | "\n",
345 | "filename = f\"{tissue}.h5ad\"\n",
346 | "parent_folder = \"../input/ipf\"\n",
347 | "\n",
348 | "full_data = sc.read(os.path.join(parent_folder, filename), dtype='float64')\n",
349 | "\n",
350 | "# filtering \n",
351 | "sc.pp.filter_cells(full_data, min_genes=cell_cutoffs[tissue]) \n",
352 | "sc.pp.filter_genes(full_data, min_cells=gene_cutoffs[tissue])\n",
353 | "\n",
354 | "if save_raw:\n",
355 | " # filter using genes in the preprocessed data\n",
356 | " filename = f\"{tissue}_processed{condition}.h5ad\"\n",
357 | " proccessed_data = sc.read(os.path.join(parent_folder, filename), dtype='float64', backed=\"r\")\n",
358 | "\n",
359 | " print(full_data.var.index.is_unique)\n",
360 | " cur_genes = list(full_data.var.index.values)\n",
361 | "\n",
362 | " idx_genes = [cur_genes.index(i) for i in proccessed_data.var.index.values]\n",
363 | " full_data = full_data[:, idx_genes]\n",
364 | "\n",
365 | " assert (full_data.var.index.values == proccessed_data.var.index.values).all()\n",
366 | " assert (full_data.obs.index.values == proccessed_data.obs.index.values).all()\n",
367 | "\n",
368 | " try:\n",
369 | " full_data.__dict__['_raw'].__dict__['_var'] = full_data.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})\n",
370 | " except AttributeError:\n",
371 | " pass \n",
372 | "\n",
373 | "\n",
374 | " filename = f\"{tissue}_raw{condition}.h5ad\"\n",
375 | " full_data.write_h5ad(os.path.join(parent_folder, filename))\n",
376 | "\n",
377 | "if save_processed:\n",
378 | "\n",
379 | " # normalie \n",
380 | " sc.pp.normalize_total(full_data, target_sum=1e4)\n",
381 | " sc.pp.log1p(full_data)\n",
382 | "\n",
383 | " # given the large number of genes for this data; filter to keep only highly dispersed genes \n",
384 | " sc.pp.highly_variable_genes(full_data, min_mean=0, max_mean = 1000, min_disp=0.01)\n",
385 | " full_data = full_data[:, full_data.var.highly_variable]\n",
386 | "\n",
387 | " # adata.raw = adata\n",
388 | " sc.pp.scale(full_data, max_value=10)\n",
389 | "\n",
390 | " # save \n",
391 | " try:\n",
392 | " full_data.__dict__['_raw'].__dict__['_var'] = full_data.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})\n",
393 | " except AttributeError:\n",
394 | " pass \n",
395 | "\n",
396 | " filename = f\"{tissue}_processed{condition}.h5ad\"\n",
397 | " full_data.write_h5ad(os.path.join(parent_folder, filename))"
398 | ]
399 | },
400 | {
401 | "cell_type": "markdown",
402 | "metadata": {},
403 | "source": [
404 | "##### add highly variables genes selected by Seurat_v3 for lung"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": 5,
410 | "metadata": {},
411 | "outputs": [],
412 | "source": [
413 | "tissue = \"lung\"\n",
414 | "\n",
415 | "filename = f\"{tissue}.h5ad\"\n",
416 | "parent_folder = \"../input/ipf\"\n",
417 | "\n",
418 | "full_data = sc.read(os.path.join(parent_folder, filename), dtype='float64')\n",
419 | "\n",
420 | "# filtering \n",
421 | "sc.pp.filter_cells(full_data, min_genes=cell_cutoffs[tissue]) \n",
422 | "sc.pp.filter_genes(full_data, min_cells=gene_cutoffs[tissue])\n",
423 | "\n",
424 | "# selecting genes using Seurat v3 method \n",
425 | "sc.pp.highly_variable_genes(full_data, n_top_genes=topk, flavor='seurat_v3') \n",
426 | "genes_variable = full_data.var.index[full_data.var.highly_variable]"
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": 33,
432 | "metadata": {},
433 | "outputs": [],
434 | "source": [
435 | "filename = f\"{tissue}_processed.h5ad\"\n",
436 | "temp = sc.read(os.path.join(parent_folder, filename), dtype='float64', backed = \"r\")"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": 17,
442 | "metadata": {},
443 | "outputs": [
444 | {
445 | "name": "stdout",
446 | "output_type": "stream",
447 | "text": [
448 | "Number of Seurat selected genes not in dispersion-selected: 100\n"
449 | ]
450 | }
451 | ],
452 | "source": [
453 | "print(f\"Number of Seurat selected genes not in dispersion-selected: {len(set(genes_variable) - set(temp.var.index))}\")"
454 | ]
455 | },
456 | {
457 | "cell_type": "code",
458 | "execution_count": 20,
459 | "metadata": {},
460 | "outputs": [],
461 | "source": [
462 | "highly_variable_seurat = np.repeat(0, temp.shape[1])\n",
463 | "\n",
464 | "for i in range(temp.shape[1]):\n",
465 | " g = temp.var.index.values[i]\n",
466 | " if g in genes_variable:\n",
467 | " highly_variable_seurat[i] = 1\n",
468 | " \n",
469 | "temp.var[\"highly_variable_seurat\"] = highly_variable_seurat.astype(bool)"
470 | ]
471 | },
472 | {
473 | "cell_type": "code",
474 | "execution_count": 32,
475 | "metadata": {},
476 | "outputs": [],
477 | "source": [
478 | "filename = f\"{tissue}_processed.h5ad\"\n",
479 | "temp.write_h5ad(os.path.join(parent_folder, filename))"
480 | ]
481 | },
482 | {
483 | "cell_type": "markdown",
484 | "metadata": {},
485 | "source": [
486 | "## 4. Preprocess HuBMAP datasets\n",
487 | "\n",
488 | "\n",
489 | "For each tissue, we made the data by concatenating multiple datasets from HuBMAP. Please see [HuBMAP_datasets_ID.xlsx](https://github.com/doraadong/UNIFAN/blob/main/tutorails/HuBMAP_datasets_ID.xlsx) for the dataset IDs for the corresponding tissues. All datasets are preprocessed using standardized [HuBMAP scRNA-seq pipeline](https://github.com/hubmapconsortium/salmon-rnaseq)."
490 | ]
491 | },
492 | {
493 | "cell_type": "code",
494 | "execution_count": 6,
495 | "metadata": {},
496 | "outputs": [
497 | {
498 | "name": "stderr",
499 | "output_type": "stream",
500 | "text": [
501 | "... storing 'tissue' as categorical\n",
502 | "... storing 'predicted.id' as categorical\n",
503 | "... storing 'celltype' as categorical\n",
504 | "... storing 'tissue' as categorical\n",
505 | "... storing 'predicted.id' as categorical\n",
506 | "... storing 'celltype' as categorical\n"
507 | ]
508 | },
509 | {
510 | "name": "stdout",
511 | "output_type": "stream",
512 | "text": [
513 | "True\n"
514 | ]
515 | }
516 | ],
517 | "source": [
518 | "for tissue in [\"spleen\", \"thymus\", \"lymph_node\"]:\n",
519 | " \n",
520 | " filename = f\"{tissue}.h5ad\"\n",
521 | " parent_folder = \"../input/hubmap\"\n",
522 | "\n",
523 | " full_data = sc.read(os.path.join(parent_folder, filename), dtype='float64')\n",
524 | "\n",
525 | " # filtering \n",
526 | " sc.pp.filter_cells(full_data, min_genes=cell_cutoffs[tissue]) \n",
527 | " sc.pp.filter_genes(full_data, min_cells=gene_cutoffs[tissue])\n",
528 | " \n",
529 | " if save_raw: \n",
530 | " try:\n",
531 | " full_data.__dict__['_raw'].__dict__['_var'] = full_data.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})\n",
532 | " except AttributeError:\n",
533 | " pass \n",
534 | " \n",
535 | " filename = f\"{tissue}_raw{condition}.h5ad\"\n",
536 | " full_data.write_h5ad(os.path.join(parent_folder, filename))\n",
537 | " \n",
538 | " if save_processed:\n",
539 | " # selecting genes using Seurat v3 method \n",
540 | " sc.pp.highly_variable_genes(full_data, n_top_genes=topk, flavor='seurat_v3')\n",
541 | "\n",
542 | " if select_genes:\n",
543 | " full_data = full_data[:, full_data.var.highly_variable]\n",
544 | "\n",
545 | " # normalie \n",
546 | " sc.pp.normalize_total(full_data, target_sum=1e4)\n",
547 | " sc.pp.log1p(full_data)\n",
548 | "\n",
549 | " # adata.raw = adata\n",
550 | " sc.pp.scale(full_data, max_value=10)\n",
551 | "\n",
552 | " # convert ensemble to gene symbols \n",
553 | " import mygene\n",
554 | " mg = mygene.MyGeneInfo()\n",
555 | "\n",
556 | " genes = [g.split('.')[0] for g in full_data.var.index.values]\n",
557 | "\n",
558 | " out = mg.querymany(genes, scopes='ensembl.gene', fields='symbol', species='human', returnall=True)\n",
559 | "\n",
560 | "\n",
561 | " # check when not selecting genes \n",
562 | " for i in out['out']:\n",
563 | " if i['query'] == 'ENSG00000229425':\n",
564 | " print(i)\n",
565 | "\n",
566 | " for i in out['out']:\n",
567 | " if i['query'] == 'ENSG00000130723':\n",
568 | " print(i)\n",
569 | "\n",
570 | " en2symbol = {}\n",
571 | " for i in out['out']:\n",
572 | " if 'symbol' in i.keys():\n",
573 | " if i['query'] not in en2symbol.keys():\n",
574 | " en2symbol[i['query']] = [i['symbol']]\n",
575 | " else:\n",
576 | " en2symbol[i['query']].append(i['symbol'])\n",
577 | " else:\n",
578 | " en2symbol[i['query']] = [i['query']]\n",
579 | "\n",
580 | " symbols = []\n",
581 | " for g in genes:\n",
582 | " symbols.append(en2symbol[g][0]) # take the first one \n",
583 | "\n",
584 | " full_data.var.index = symbols\n",
585 | "\n",
586 | "\n",
587 | " # save \n",
588 | " try:\n",
589 | " full_data.__dict__['_raw'].__dict__['_var'] = full_data.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})\n",
590 | " except AttributeError:\n",
591 | " pass \n",
592 | "\n",
593 | " filename = f\"{tissue}_processed{condition}.h5ad\"\n",
594 | " full_data.write_h5ad(os.path.join(parent_folder, filename))"
595 | ]
596 | }
597 | ],
598 | "metadata": {
599 | "kernelspec": {
600 | "display_name": "Python [conda env:anno] *",
601 | "language": "python",
602 | "name": "conda-env-anno-py"
603 | },
604 | "language_info": {
605 | "codemirror_mode": {
606 | "name": "ipython",
607 | "version": 3
608 | },
609 | "file_extension": ".py",
610 | "mimetype": "text/x-python",
611 | "name": "python",
612 | "nbconvert_exporter": "python",
613 | "pygments_lexer": "ipython3",
614 | "version": "3.6.13"
615 | }
616 | },
617 | "nbformat": 4,
618 | "nbformat_minor": 2
619 | }
620 |
--------------------------------------------------------------------------------
/tutorails/HuBMAP_datasets_ID.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/doraadong/UNIFAN/96e0811d52e3af4b0d40dfa5fa02cfa0d28c9be6/tutorails/HuBMAP_datasets_ID.xlsx
--------------------------------------------------------------------------------
/tutorails/getExample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import scanpy as sc
5 |
6 |
7 | """
8 |
9 | Process the Tabula Muris senis data.
10 |
11 | Download data from: https://figshare.com/ndownloader/files/24351086
12 |
13 | """
14 |
15 | def main():
16 | # parse command-line arguments
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('-p', '--path', required=True, type=str,
19 | default='./facs.h5ad', help="string, path to the downloaded data, "
20 | "default './facs.h5ad'")
21 | parser.add_argument('-i', '--folder', required=True, type=str,
22 | default='../example/input', help="string, path to the folder to save the data, "
23 | "default '../example/input'")
24 |     parser.add_argument('-t', '--tissue', required=False, type=str,
25 |                         default=None, help="string, specify the output tissue; if using the default None, then all "
26 |                                            "tissues will be outputted and saved separately in the folder; default None")
27 | parser.add_argument('-k', '--topk', required=False, default=2000, type=int,
28 | help="integer, optional, number of most variable genes, default 2000")
29 |
30 | args = parser.parse_args()
31 | print(args)
32 |
33 | parent_folder = args.folder
34 | filepath = args.path
35 | tissue = args.tissue
36 | topk = args.topk
37 |
38 | adata = sc.read(filepath, dtype='float64')
39 |
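   |     # basic quality filtering on the raw counts before normalization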
40 | sc.pp.filter_cells(adata, min_counts=5000)
41 | sc.pp.filter_cells(adata, min_genes=500)
42 | sc.pp.filter_genes(adata, min_cells=5)
43 |
44 | # get unnormalized version to infer highly variable genes
45 | full_data_unnorm = adata[adata.obs["age"] == "3m", :].copy() # equivalent to Tabula Muris data (MARS)
46 |
47 |     # exclude Brain_Myeloid and Marrow (as in MARS)
48 | full_data_unnorm = full_data_unnorm[full_data_unnorm.obs["tissue"] != "Marrow", :].copy()
49 | full_data_unnorm = full_data_unnorm[full_data_unnorm.obs["tissue"] != "Brain_Myeloid", :].copy()
50 |
51 | sc.pp.normalize_total(adata, target_sum=1e4)
52 |
53 | sc.pp.log1p(adata)
54 |
55 | sc.pp.scale(adata, max_value=10)
56 |
57 | full_data = adata[adata.obs["age"] == "3m", :] # equivalent to Tabula Muris data (MARS)
58 |
59 |     # exclude Brain_Myeloid and Marrow (as in MARS)
60 | full_data = full_data[full_data.obs["tissue"] != "Marrow", :]
61 | full_data = full_data[full_data.obs["tissue"] != "Brain_Myeloid", :]
62 |
63 | if tissue is None:
64 | # save each tissue separately
65 | for tissue in set(full_data.obs['tissue'].values):
66 | print(f"Saving output for {tissue}")
67 |
68 | # get most variable genes using unnormalized
69 | subset_unnorm = full_data_unnorm[full_data_unnorm.obs['tissue'] == tissue].copy()
70 |
71 | # selecting genes using Seurat v3 method
72 | sc.pp.highly_variable_genes(subset_unnorm, n_top_genes=topk, flavor='seurat_v3')
73 |
74 | # keep only the current tissue
75 | subset = full_data[full_data.obs['tissue'] == tissue].copy()
76 | subset.var["highly_variable"] = subset_unnorm.var["highly_variable"]
77 |
78 | # write full data
79 | subset.write_h5ad(os.path.join(parent_folder, f"{tissue}_facts_processed_3m.h5ad"))
80 | else:
81 | # get most variable genes using unnormalized
82 | subset_unnorm = full_data_unnorm[full_data_unnorm.obs['tissue'] == tissue].copy()
83 |
84 | # selecting genes using Seurat v3 method
85 | sc.pp.highly_variable_genes(subset_unnorm, n_top_genes=topk, flavor='seurat_v3')
86 |
87 | # keep only the current tissue
88 | subset = full_data[full_data.obs['tissue'] == tissue].copy()
89 | subset.var["highly_variable"] = subset_unnorm.var["highly_variable"]
90 |
91 | # write full data
92 | subset.write_h5ad(os.path.join(parent_folder, f"{tissue}_facts_processed_3m.h5ad"))
93 |
94 |
95 |
96 | if __name__ == '__main__':
97 | main()
98 |
--------------------------------------------------------------------------------
/unifan-main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/doraadong/UNIFAN/96e0811d52e3af4b0d40dfa5fa02cfa0d28c9be6/unifan-main.png
--------------------------------------------------------------------------------
/unifan-pretrain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/doraadong/UNIFAN/96e0811d52e3af4b0d40dfa5fa02cfa0d28c9be6/unifan-pretrain.png
--------------------------------------------------------------------------------
/unifan/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | dir_path = os.path.dirname(os.path.realpath(__file__))
5 | sys.path.append(dir_path)
6 |
7 | __all__ = ['main', 'annocluster', 'autoencoder', 'classifier', 'datasets', 'networks', 'trainer', 'utils']
8 | for i in __all__:
9 | __import__(i)
--------------------------------------------------------------------------------
/unifan/annocluster.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | torch.backends.cudnn.benchmark = True
4 |
5 | from unifan.networks import Encoder, Decoder, Set2Gene
6 |
7 |
8 | class AnnoCluster(nn.Module):
9 |
10 | """
11 | Clustering with annotator.
12 |
13 | Parameters
14 | ----------
15 | input_dim: integer
16 | number of input features
17 | z_dim: integer
18 | number of low-dimensional features
19 | gene_set_dim: integer
20 | number of gene sets
21 | tau: float
22 | hyperparameter to weight the annotator loss
23 | zeta: float
24 | hyperparameter to weight the reconstruction loss from embeddings (discrete representations)
25 | encoder_dim: integer
26 | dimension of hidden layer for encoders
27 | emission_dim: integer
28 | dimension of hidden layer for decoders
29 | num_layers_encoder: integer
30 | number of hidden layers for encoder
31 | num_layers_decoder: integer
32 | number of hidden layers for decoder
33 | dropout_rate: float
34 | use_t_dist: boolean
35 | if using t distribution kernel to transform the euclidean distances between encodings and centroids
36 | regulating_probability: string
37 | the type of probability to regulating the clustering (by distance) results
38 | centroids: torch.Tensor
39 | embeddings in the low-dimensional space for the cluster centroids
40 | gene_set_table: torch.Tensor
41 | gene set relationship table
42 |
43 | """
44 |
45 | def __init__(self, input_dim: int = 10000, z_dim: int = 32, gene_set_dim: int = 335,
46 | tau: float = 1.0, zeta: float = 1.0, n_clusters: int = 16,
47 | encoder_dim: int = 128, emission_dim: int = 128, num_layers_encoder: int = 1,
48 | num_layers_decoder: int = 1, dropout_rate: float = 0.1, use_t_dist: bool = True,
49 | reconstruction_network: str = "gaussian", decoding_network: str = "gaussian",
50 | regulating_probability: str = "classifier", centroids: torch.Tensor = None,
51 | gene_set_table: torch.Tensor = None, use_cuda: bool = False):
52 |
53 | super().__init__()
54 |
55 | # initialize parameters
56 | self.z_dim = z_dim
57 | self.reconstruction_network = reconstruction_network
58 | self.decoding_network = decoding_network
59 | self.tau = tau
60 | self.zeta = zeta
61 | self.n_clusters = n_clusters
62 | self.use_t_dist = use_t_dist
63 | self.regulating_probability = regulating_probability
64 |
65 | if regulating_probability not in ["classifier"]:
66 |             raise NotImplementedError(f"The current implementation only supports 'classifier' "
67 |                                       f"for the regulating probability.")
68 |
69 | # initialize centroids embeddings
70 | if centroids is not None:
71 | self.embeddings = nn.Parameter(centroids, requires_grad=True)
72 | else:
73 | self.embeddings = nn.Parameter(torch.randn(self.n_clusters, self.z_dim) * 0.05, requires_grad=True)
74 |
75 | # initialize loss
76 | self.mse_loss = nn.MSELoss()
77 | self.nLL_loss = nn.NLLLoss()
78 |
79 | # instantiate encoder for z
80 | if self.reconstruction_network == "gaussian":
81 | self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim,
82 | dropout_rate=dropout_rate)
83 | else:
84 |             raise NotImplementedError(f"The current implementation only supports 'gaussian' for the encoder.")
85 |
86 | # instantiate decoder for emission
87 | if self.decoding_network == 'gaussian':
88 | self.decoder_e = Decoder(z_dim, input_dim, num_layers=num_layers_decoder, hidden_dim=emission_dim)
89 | self.decoder_q = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim)
90 | elif self.decoding_network == 'geneSet':
91 | self.decoder_e = Set2Gene(gene_set_table)
92 | self.decoder_q = Set2Gene(gene_set_table)
93 | else:
94 |             raise NotImplementedError(f"The current implementation only supports 'gaussian' or "
95 |                                       f"'geneSet' for the emission decoder.")
96 |
97 | self.use_cuda = use_cuda
98 | if use_cuda:
99 | self.cuda()
100 |
101 | def forward(self, x):
102 |
103 | # get encoding
104 | z_e, _ = self.encoder(x)
105 |
106 |         # get the index of the embedding closest to the encoding
107 | k, z_dist, dist_prob = self._get_clusters(z_e)
108 |
109 | # get embeddings (discrete representations)
110 | z_q = self._get_embeddings(k)
111 |
112 | # decode embedding (discrete representation) and encoding
113 | x_q, _ = self.decoder_q(z_q)
114 | x_e, _ = self.decoder_e(z_e)
115 |
116 | return x_e, x_q, z_e, z_q, k, z_dist, dist_prob
117 |
118 | def _get_clusters(self, z_e):
119 | """
120 |
121 | Assign each sample to a cluster based on euclidean distances.
122 |
123 | Parameters
124 | ----------
125 | z_e: torch.Tensor
126 | low-dimensional encodings
127 |
128 | Returns
129 | -------
130 | k: torch.Tensor
131 | cluster assignments
132 | z_dist: torch.Tensor
133 | distances between encodings and centroids
134 | dist_prob: torch.Tensor
135 | probability of closeness of encodings to centroids transformed by t-distribution
136 |
137 | """
138 |
139 | _z_dist = (z_e.unsqueeze(1) - self.embeddings.unsqueeze(0)) ** 2
140 | z_dist = torch.sum(_z_dist, dim=-1)
141 |
142 | if self.use_t_dist:
143 | dist_prob = self._t_dist_sim(z_dist, df=10)
144 | k = torch.argmax(dist_prob, dim=-1)
145 | else:
146 | k = torch.argmin(z_dist, dim=-1)
147 | dist_prob = None
148 |
149 | return k, z_dist, dist_prob
150 |
151 | def _t_dist_sim(self, z_dist, df=10):
152 | """
153 | Transform distances using t-distribution kernel.
154 |
155 | Parameters
156 | ----------
157 | z_dist: torch.Tensor
158 | distances between encodings and centroids
159 |
160 | Returns
161 | -------
162 | dist_prob: torch.Tensor
163 | probability of closeness of encodings to centroids transformed by t-distribution
164 |
165 | """
166 |
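   |         # Student's t kernel: similarity = (1 + dist / df) ** (-(df + 1) / 2), normalized over clusters for each cell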
167 | _factor = - ((df + 1) / 2)
168 | dist_prob = torch.pow((1 + z_dist / df), _factor)
169 | dist_prob = dist_prob / dist_prob.sum(axis=1).unsqueeze(1)
170 |
171 | return dist_prob
172 |
173 | def _get_embeddings(self, k):
174 | """
175 |
176 | Get the embeddings (discrete representations).
177 |
178 | Parameters
179 | ----------
180 | k: torch.Tensor
181 | cluster assignments
182 |
183 | Returns
184 | -------
185 | z_q: torch.Tensor
186 | low-dimensional embeddings (discrete representations)
187 |
188 | """
189 |
190 | k = k.long()
191 | _z_q = []
192 | for i in range(len(k)):
193 | _z_q.append(self.embeddings[k[i]])
194 |
195 | z_q = torch.stack(_z_q)
196 |
197 | return z_q
198 |
199 |
200 | def _loss_reconstruct(self, x, x_e, x_q):
201 | """
202 | Calculate reconstruction loss.
203 |
204 | Parameters
205 | -----------
206 | x: torch.Tensor
207 | original observation in full-dimension
208 | x_e: torch.Tensor
209 | reconstructed observation encodings
210 | x_q: torch.Tensor
211 | reconstructed observation from embeddings (discrete representations)
212 | """
213 |
214 | l_e = self.mse_loss(x, x_e)
215 | l_q = self.mse_loss(x, x_q)
216 | mse_l = l_e + l_q * self.zeta
217 | return mse_l
218 |
219 | def _loss_z_prob(self, z_dist, prior_prob=None):
220 | """
221 | Calculate annotator loss.
222 |
223 | Parameters
224 | ----------
225 | z_dist: torch.Tensor
226 | distances between encodings and centroids
227 | prior_prob: torch.Tensor
228 | probability learned from other source (e.g. prior) about cluster assignment
229 |
230 | """
231 | if self.regulating_probability == "classifier":
232 |
233 | weighted_z_dist_prob = z_dist * prior_prob
234 | prob_z_l = torch.mean(weighted_z_dist_prob)
235 |
236 | else:
237 |             raise NotImplementedError(f"The current implementation only supports "
238 |                                       f"'classifier' for the prob_z_l method.")
239 |
240 | return prob_z_l
241 |
242 | def loss(self, x, x_e, x_q, z_dist, prior_prob=None):
243 |
244 | mse_l = self._loss_reconstruct(x, x_e, x_q)
245 | prob_z_l = self._loss_z_prob(z_dist, prior_prob)
246 |
247 | l = mse_l + self.tau * prob_z_l
248 |
249 | return l
250 |
251 |
252 |
--------------------------------------------------------------------------------
/unifan/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | torch.backends.cudnn.benchmark = True
5 |
6 | from unifan.networks import Encoder, Decoder, Set2Gene, LinearCoder, NonNegativeCoder, SigmoidCoder
7 |
8 |
9 | class autoencoder(nn.Module):
10 | """
11 |
12 | Autoencoder used for pre-training.
13 |
14 | Parameters
15 | ----------
16 | input_dim: integer
17 | number of input features
18 | z_dim: integer
19 | number of low-dimensional features
20 | gene_set_dim: integer
21 | number of gene sets
22 | encoder_dim: integer
23 | dimension of hidden layer for encoders
24 | emission_dim: integer
25 | dimension of hidden layer for decoders
26 | num_layers_encoder: integer
27 | number of hidden layers for encoder
28 | num_layers_decoder: integer
29 | number of hidden layers for decoder
30 | dropout_rate: float
31 | gene_set_table: torch.Tensor
32 | gene set relationship table
33 |
34 | """
35 |
36 | def __init__(self, input_dim: int = 10000, z_dim: int = 32, gene_set_dim: int = 335, encoder_dim: int = 128,
37 | emission_dim: int = 128, num_layers_encoder: int = 1, num_layers_decoder: int = 1,
38 | dropout_rate: float = 0.1, reconstruction_network: str = "non-negative",
39 | decoding_network: str = "geneSet", gene_set_table: torch.Tensor = None, use_cuda: bool = False):
40 |
41 | super().__init__()
42 |
43 | # initialize parameters
44 | self.z_dim = z_dim
45 | self.reconstruction_network = reconstruction_network
46 | self.decoding_network = decoding_network
47 |
48 | # initialize loss
49 | self.mse_loss = nn.MSELoss()
50 |
51 | # initialize encoder and decoder
52 | if self.reconstruction_network == 'linear' and self.decoding_network == 'linear':
53 | self.encoder = LinearCoder(input_dim, z_dim)
54 | self.decoder_e = LinearCoder(z_dim, input_dim)
55 | else:
56 |
57 | if self.reconstruction_network == 'non-negative':
58 | # instantiate encoder for z
59 | self.encoder = NonNegativeCoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim,
60 | dropout_rate=dropout_rate)
61 | elif self.reconstruction_network == 'sigmoid':
62 | # instantiate encoder for z
63 | self.encoder = SigmoidCoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim,
64 | dropout_rate=dropout_rate)
65 | elif self.reconstruction_network == "gaussian":
66 | # instantiate encoder for z, using standard encoder
67 | self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim,
68 | dropout_rate=dropout_rate)
69 |
70 | else:
71 |                 raise NotImplementedError(f"The current implementation only supports 'gaussian', "
72 |                                           f"'non-negative' or 'sigmoid' for the encoder.")
73 |
74 | # instantiate decoder for emission
75 | if self.decoding_network == 'gaussian':
76 | self.decoder_e = Decoder(z_dim, input_dim, num_layers=num_layers_decoder, hidden_dim=emission_dim)
77 | elif self.decoding_network == 'geneSet':
78 | self.decoder_e = Set2Gene(gene_set_table)
79 | else:
80 |                 raise NotImplementedError(f"The current implementation only supports 'gaussian' or "
81 |                                           f"'geneSet' for the emission decoder.")
82 |
83 | self.use_cuda = use_cuda
84 | if use_cuda:
85 | self.cuda()
86 |
87 | def forward(self, data):
88 |
89 | x = data
90 |
91 | # get encoding
92 | z_e, _ = self.encoder(x)
93 |
94 | # decode encoding
95 | x_e, _ = self.decoder_e(z_e)
96 |
97 | return x_e, z_e
98 |
99 | def _loss_reconstruct(self, x, x_e):
100 | """
101 | Calculate reconstruction loss.
102 |
103 | Parameters
104 | ----------
105 | x: torch.Tensor
106 | original data
107 | x_e: torch.Tensor
108 | reconstructed data
109 |
110 | Returns
111 | -------
112 | mse_l: torch.Tensor
113 | reconstruction loss
114 |
115 | """
116 | l_e = self.mse_loss(x, x_e)
117 | mse_l = l_e
118 |
119 | return mse_l
120 |
121 | def loss(self, x, x_e):
122 | l = self._loss_reconstruct(x, x_e)
123 | return l
124 |
--------------------------------------------------------------------------------
/unifan/classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | torch.backends.cudnn.benchmark = True
5 |
6 | from unifan.networks import Decode2Labels
7 |
8 |
9 | class classifier(nn.Module):
10 | """
11 |
12 | A classifier.
13 |
14 | Parameters
15 | ----------
16 | z_dim: integer
17 | number of input features
18 | output_dim: integer
19 |         number of outputs (number of label classes)
20 | emission_dim: integer
21 | dimension of hidden layer
22 | num_layers: integer
23 | number of hidden layers
24 |
25 | """
26 |
27 | def __init__(self, z_dim: int = 335, output_dim: int = 10, emission_dim: int = 128, num_layers: int = 1,
28 | use_cuda=False):
29 | super().__init__()
30 |
31 | # initialize loss
32 | self.loss_function = nn.NLLLoss()
33 |
34 | # instantiate decoder for emission
35 | self.decoder = Decode2Labels(z_dim, output_dim)
36 |
37 | self.use_cuda = use_cuda
38 | if use_cuda:
39 | self.cuda()
40 |
41 | def forward(self, x):
42 | y_pre = self.decoder(x)
43 | return y_pre
44 |
45 | def loss(self, y_pre, y_true):
46 | l = self.loss_function(y_pre, y_true)
47 | return l
48 |
--------------------------------------------------------------------------------
/unifan/datasets.py:
--------------------------------------------------------------------------------
1 | import scanpy as sc
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 |
5 |
6 | class AnnDataset(Dataset):
7 | def __init__(self, filepath: str, label_name: str = None, second_filepath: str = None,
8 | variable_gene_name: str = None):
9 | """
10 |
11 | Anndata dataset.
12 |
13 | Parameters
14 | ----------
15 | label_name: string
16 |             name of the obs column holding the cell type annotation; optional, default None
17 | second_filepath: string
18 | path to another input file other than the main one; e.g. path to predicted clusters or
19 |             side information; only numpy arrays are supported
20 |
21 | """
22 |
23 | super().__init__()
24 |
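   |         # read the AnnData file in backed mode ("r") so the matrix stays on disk and is loaded lazily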
25 | self.data = sc.read(filepath, dtype='float64', backed="r")
26 |
27 | genes = self.data.var.index.values
28 | self.genes_upper = [g.upper() for g in genes]
29 | if label_name is not None:
30 | self.clusters_true = self.data.obs[label_name].values
31 | else:
32 | self.clusters_true = None
33 |
34 | self.N = self.data.shape[0]
35 | self.G = len(self.genes_upper)
36 |
37 | self.secondary_data = None
38 | if second_filepath is not None:
39 | self.secondary_data = np.load(second_filepath)
40 |             assert len(self.secondary_data) == self.N, "The secondary file must have the same length as the main one"
41 |
42 | if variable_gene_name is not None:
43 | _idx = np.where(self.data.var[variable_gene_name].values)[0]
44 | self.exp_variable_genes = self.data.X[:, _idx]
45 | self.variable_genes_names = self.data.var.index.values[_idx]
46 |
47 | def __len__(self):
48 | return self.N
49 |
50 | def __getitem__(self, idx):
51 | main = self.data[idx].X.flatten()
52 |
53 | if self.secondary_data is not None:
54 | secondary = self.secondary_data[idx].flatten()
55 | return main, secondary
56 | else:
57 | return main
58 |
59 |
60 | class NumpyDataset(Dataset):
61 | def __init__(self, filepath: str, second_filepath: str = None):
62 | """
63 |
64 | Numpy array dataset.
65 |
66 | Parameters
67 | ----------
68 | second_filepath: string
69 |             path to another input file besides the main one, e.g., predicted clusters or
70 |             side information; only numpy arrays are supported
71 |
72 | """
73 | super().__init__()
74 |
75 | self.data = np.load(filepath)
76 | self.N = self.data.shape[0]
77 | self.G = self.data.shape[1]
78 |
79 | self.secondary_data = None
80 | if second_filepath is not None:
81 | self.secondary_data = np.load(second_filepath)
82 |             assert len(self.secondary_data) == self.N, "The secondary file must have the same length as the main one"
83 |
84 | def __len__(self):
85 | return self.N
86 |
87 | def __getitem__(self, idx):
88 | main = self.data[idx].flatten()
89 |
90 | if self.secondary_data is not None:
91 | secondary = self.secondary_data[idx].flatten()
92 | return main, secondary
93 | else:
94 | return main
95 |
--------------------------------------------------------------------------------
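A small sketch of how these datasets pair with a DataLoader, the way main.py uses NumpyDataset for classifier pretraining; the /tmp paths and toy arrays are hypothetical.

```python
import numpy as np
import torch
from unifan.datasets import NumpyDataset

# toy files (hypothetical paths); the secondary file must match the main one in length
np.save("/tmp/r.npy", np.random.rand(100, 50))
np.save("/tmp/clusters.npy", np.random.randint(0, 5, 100))

dataset = NumpyDataset("/tmp/r.npy", second_filepath="/tmp/clusters.npy")
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

X_batch, y_batch = next(iter(loader))  # __getitem__ yields (main, secondary) pairs
y_batch = torch.flatten(y_batch)       # the Trainer flattens the secondary targets the same way
```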
/unifan/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | import gc
4 | import itertools
5 | import argparse
6 |
7 | import torch
8 | import scanpy as sc
9 | import pandas as pd
10 | import numpy as np
11 |
12 | from unifan.datasets import AnnDataset, NumpyDataset
13 | from unifan.annocluster import AnnoCluster
14 | from unifan.autoencoder import autoencoder
15 | from unifan.classifier import classifier
16 | from unifan.utils import getGeneSetMatrix, str2bool
17 | from unifan.trainer import Trainer
18 |
19 |
20 | def main():
21 | # parse command-line arguments
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('-i', '--input', required=True, type=str,
24 | default='../input/data.h5ad', help="string, path to the input expression data, "
25 | "default '../input/data.h5ad'")
26 | parser.add_argument('-o', '--output', required=True, type=str,
27 | default='../output/', help="string, path to the output folder, default '../output/'")
28 | parser.add_argument('-p', '--project', required=True, type=str,
29 | default='data', help="string, identifier for the project, e.g., tabula_muris")
30 | parser.add_argument('-t', '--tissue', required=True, type=str,
31 | default='tissue', help="string, tissue where the input data is sampled from")
32 | parser.add_argument('-e', '--geneSetsPath', required=True, type=str,
33 | default='../gene_sets/', help="string, path to the folder where gene sets can be found, "
34 | "default='../gene_sets/'")
35 | parser.add_argument('-l', '--label', required=False, type=str,
36 | default=None, help="string, optional, the column / field name of the ground truth label, if "
37 | "available; used for evaluation only; default None")
38 | parser.add_argument('-v', '--variable', required=False, type=str,
39 | default='highly_variable', help="string, optional, the column / field name of the highly "
40 | "variable genes; default 'highly_variable'")
41 | parser.add_argument('-r', '--prior', required=False, type=str,
42 | default='c5.go.bp.v7.4.symbols.gmt+c2.cp.v7.4.symbols.gmt+TF-DNA',
43 | help="string, optional, gene set file names used to learn the gene set activity scores, "
44 | "use '+' to separate multiple gene set names, "
45 | "default c5.go.bp.v7.4.symbols.gmt+c2.cp.v7.4.symbols.gmt+TF-DNA")
46 | parser.add_argument('-f', '--features', required=False, default='gene_gene_sets', type=str,
47 | choices=['gene_sets', 'gene', 'gene_gene_sets'],
48 | help="string, optional, features used for the annotator, any of 'gene_sets', 'gene' or "
49 | "'gene_gene_sets', default 'gene_gene_sets'")
50 | parser.add_argument('-a', '--alpha', required=False, default=1e-2, type=float,
51 | help="float, optional, hyperparameter for the L1 term in the set cover loss, default 1e-2")
52 | parser.add_argument('-b', '--beta', required=False, default=1e-5, type=float,
53 | help="float, optional, hyperparameter for the set cover term in the set cover loss, "
54 | "default 1e-5")
55 | parser.add_argument('-g', '--gamma', required=False, default=1e-3, type=float,
56 | help="float, optional, hyperparameter for the exclusive L1 term, default 1e-3")
57 | parser.add_argument('-u', '--tau', required=False, default=10, type=float,
58 | help="float, optional, hyperparameter for the annotator loss, default 10")
59 | parser.add_argument('-d', '--dim', required=False, default=32, type=int,
60 | help="integer, optional, dimension for the low-dimensional representation, default 32")
61 | parser.add_argument('-s', '--batch', required=False, default=128, type=int,
62 | help="integer, optional, batch size for training except for pretraining annotator "
63 | "(fixed at 32), default 128")
64 | parser.add_argument('-na', '--nanno', required=False, default=50, type=int,
65 | help="integer, optional, number of epochs to pretrain the annotator, default 50")
66 | parser.add_argument('-ns', '--nscore', required=False, default=70, type=int,
67 | help="integer, optional, number of epochs to train the gene set activity model, default 70")
68 | parser.add_argument('-nu', '--nauto', required=False, default=50, type=int,
69 | help="integer, optional, number of epochs to pretrain the AnnoCluster model, default 50")
70 | parser.add_argument('-nc', '--ncluster', required=False, default=25, type=int,
71 | help="integer, optional, number of epochs to train the AnnoCluster model, default 25")
72 | parser.add_argument('-nze', '--nzenco', required=False, default=3, type=int,
73 | help="float, optional, number of hidden layers for encoder of AnnoCluster, default 3")
74 | parser.add_argument('-nzd', '--nzdeco', required=False, default=2, type=int,
75 | help="float, optional, number of hidden layers for decoder of AnnoCluster, default 2")
76 | parser.add_argument('-dze', '--dimzenco', required=False, default=128, type=int,
77 | help="integer, optional, number of nodes for hidden layers for encoder of AnnoCluster, "
78 | "default 128")
79 | parser.add_argument('-dzd', '--dimzdeco', required=False, default=128, type=int,
80 | help="integer, optional, number of nodes for hidden layers for decoder of AnnoCluster, "
81 | "default 128")
82 | parser.add_argument('-nre', '--nrenco', required=False, default=5, type=int,
83 | help="integer, optional, number of hidden layers for the encoder of gene set activity scores "
84 | "model, default 5")
85 | parser.add_argument('-dre', '--dimrenco', required=False, default=128, type=int,
86 | help="integer, optional, number of nodes for hidden layers for encoder of gene set activity "
87 | "scores model, default 128")
88 | parser.add_argument('-drd', '--dimrdeco', required=False, default=128, type=int,
89 | help="integer, optional, number of nodes for hidden layers for decoder of gene set activity "
90 | "scores model, default 128")
91 | parser.add_argument('-n', '--network', required=False, choices=['sigmoid', 'non-negative', 'gaussian'], type=str,
92 | default='non-negative', help="string, optional, the encoder for the gene set activity model, "
93 | "any of 'sigmoid', 'non-negative' or 'gaussian', "
94 | "default 'non-negative'")
95 | parser.add_argument('-m', '--seed', required=False, default=0, type=int,
96 | help="integer, optional, random seed for the initialization, default 0")
97 | parser.add_argument('-c', '--cuda', required=False, type=str2bool,
98 | default=False, help="boolean, optional, if use GPU for neural network training, default False")
99 | parser.add_argument('-w', '--nworkers', required=False, default=8, type=int,
100 | help="integer, optional, number of workers for dataloader, default 8")
101 |
102 |
103 | args = parser.parse_args()
104 | print(args)
105 |
106 | data_filepath = args.input
107 | output_path = args.output
108 | gene_sets_path = args.geneSetsPath
109 | project = args.project
110 | tissue = args.tissue
111 | label_name = args.label
112 | variable_gene_name = args.variable
113 |
114 | prior_name = args.prior
115 | features_type = args.features
116 | alpha = args.alpha
117 | beta = args.beta
118 | weight_decay = args.gamma
119 | tau = args.tau
120 | z_dim = args.dim
121 |
122 | batch_size = args.batch
123 | num_epochs_classifier = args.nanno
124 | num_epochs_r = args.nscore
125 | num_epochs_z = args.nauto
126 | r_epoch = num_epochs_r - 1
127 | z_epoch = num_epochs_z - 1
128 | num_epochs_annocluster = args.ncluster
129 |
130 | z_encoder_layers = args.nzenco
131 | z_decoder_layers = args.nzdeco
132 | z_encoder_dim = args.dimzenco
133 | z_decoder_dim = args.dimzdeco
134 | r_encoder_layers = args.nrenco
135 | r_decoder_layers = 1
136 | r_encoder_dim = args.dimrenco
137 | r_decoder_dim = args.dimrdeco
138 | rnetwork = args.network
139 |
140 | random_seed = args.seed
141 |
142 | use_cuda = args.cuda
143 | num_workers = args.nworkers
144 |
145 | # ------ training conditions
146 | device = torch.device("cuda" if use_cuda else "cpu")
147 | if use_cuda:
148 | pin_memory = True
149 | non_blocking = True
150 | else:
151 | pin_memory = False
152 | non_blocking = False
153 |
154 | if '+' in prior_name:
155 | prior_names_list = prior_name.split('+')
156 |
157 | # ------ prepare for output
158 | output_parent_path = os.path.join(output_path, f"{project}/{tissue}/")
159 |
160 | r_folder = f"{output_parent_path}r"
161 | input_r_ae_path = os.path.join(r_folder, f"r_model_{r_epoch}.pickle")
162 | input_r_path = os.path.join(r_folder, f"r_{r_epoch}.npy")
163 | input_r_names_path = os.path.join(r_folder, f"r_names_{r_epoch}.npy")
164 |
165 | pretrain_z_folder = f"{output_parent_path}pretrain_z"
166 | input_z_path = os.path.join(pretrain_z_folder, f"pretrain_z_{z_epoch}.npy")
167 | input_ae_path = os.path.join(pretrain_z_folder, f"pretrain_z_model_{z_epoch}.pickle")
168 | input_cluster_path = os.path.join(pretrain_z_folder, f"cluster_{z_epoch}.npy")
169 |
170 | pretrain_annotator_folder = f"{output_parent_path}pretrain_annotator"
171 | annocluster_folder = f"{output_parent_path}annocluster_{features_type}"
172 |
173 | # ------ load data
174 | if features_type in ["gene", "gene_gene_sets"]:
175 | expression_only = AnnDataset(data_filepath, label_name=label_name, variable_gene_name=variable_gene_name)
176 | exp_variable_genes = expression_only.exp_variable_genes
177 | variable_genes_names = expression_only.variable_genes_names
178 | else:
179 | expression_only = AnnDataset(data_filepath, label_name=label_name)
180 | exp_variable_genes = None
181 | variable_genes_names = None
182 |
183 | genes_upper = expression_only.genes_upper
184 | N = expression_only.N
185 | G = expression_only.G
186 |
187 | # ------ process prior data
188 | # generate gene_set_matrix
189 | if '+' in prior_name:
190 | _matrix_list = []
191 | _keys_list = []
192 | for _name in prior_names_list:
193 | _matrix, _keys = getGeneSetMatrix(_name, genes_upper, gene_sets_path)
194 | _matrix_list.append(_matrix)
195 | _keys_list.append(_keys)
196 |
197 | gene_set_matrix = np.concatenate(_matrix_list, axis=0)
198 | keys_all = list(itertools.chain(*_keys_list))
199 |
200 | del _matrix_list
201 | del _keys_list
202 | gc.collect()
203 |
204 | else:
205 | gene_set_matrix, keys_all = getGeneSetMatrix(prior_name, genes_upper, gene_sets_path)
206 |
207 | # ------ set-up for the set cover loss
208 | if beta != 0:
209 | # get the gene set matrix with only genes covered
210 | genes_covered = np.sum(gene_set_matrix, axis=0)
211 | gene_covered_matrix = gene_set_matrix[:, genes_covered != 0]
212 | gene_covered_matrix = torch.from_numpy(gene_covered_matrix).to(device, non_blocking=non_blocking).float()
213 | beta_list = torch.from_numpy(np.repeat(beta, gene_covered_matrix.shape[1])).to(device,
214 | non_blocking=non_blocking).float()
215 |
216 | del genes_covered
217 | gc.collect()
218 | else:
219 | gene_covered_matrix = None
220 | beta_list = None
221 |
222 | gene_set_dim = gene_set_matrix.shape[0]
223 | gene_set_matrix = torch.from_numpy(gene_set_matrix).to(device, non_blocking=non_blocking)
224 |
225 | # ------ Train gene set activity scores (r) model ------
226 | if features_type == "gene":
227 | z_gene_set = exp_variable_genes
228 | set_names = list(variable_genes_names)
229 | else:
230 |
231 | model_gene_set = autoencoder(input_dim=G, z_dim=gene_set_dim, gene_set_dim=gene_set_dim,
232 | encoder_dim=r_encoder_dim, emission_dim=r_decoder_dim,
233 | num_layers_encoder=r_encoder_layers, num_layers_decoder=r_decoder_layers,
234 | reconstruction_network=rnetwork, decoding_network='geneSet',
235 | gene_set_table=gene_set_matrix, use_cuda=use_cuda)
236 |
237 | if os.path.isfile(input_r_path):
238 | print(f"Inferred r exists. No need to train the gene set activity scores model.")
239 | z_gene_set = np.load(input_r_path)
240 | else:
241 | if os.path.isfile(input_r_ae_path):
242 | model_gene_set.load_state_dict(torch.load(input_r_ae_path, map_location=device)['state_dict'])
243 |
244 | trainer = Trainer(dataset=expression_only, model=model_gene_set, model_name="r", batch_size=batch_size,
245 | num_epochs=num_epochs_r, save_infer=True, output_folder=r_folder, num_workers=num_workers,
246 | use_cuda=use_cuda)
247 | if os.path.isfile(input_r_ae_path):
248 |                 print(
249 |                     f"The trained r model exists but inferred r does not. Inferring r; no need to retrain the "
250 |                     f"gene set activity scores model.")
251 | z_gene_set = trainer.infer_r(alpha=alpha, beta=beta, beta_list=beta_list,
252 | gene_covered_matrix=gene_covered_matrix)
253 | np.save(input_r_path, z_gene_set)
254 | else:
255 | print(f"Start training the gene set activity scores model ... ")
256 | trainer.train(alpha=alpha, beta=beta, beta_list=beta_list, gene_covered_matrix=gene_covered_matrix)
257 | z_gene_set = np.load(input_r_path)
258 |
259 | z_gene_set = torch.from_numpy(z_gene_set)
260 |
261 |         # keep only gene sets with non-zero activity scores
262 |         idx_non_0_gene_sets = np.where(z_gene_set.numpy().sum(axis=0) != 0)[0]
263 |
264 |         # get the names of the kept gene sets
265 |         set_names = np.array(keys_all)[idx_non_0_gene_sets]
266 |
267 |         z_gene_set = z_gene_set[:, idx_non_0_gene_sets]
268 |         print(f"After filtering, we have {z_gene_set.shape[1]} gene sets")
269 |
270 | # add also selected genes if using "gene_gene_sets"
271 | if features_type == "gene_gene_sets":
272 | z_gene_set = np.concatenate([z_gene_set, exp_variable_genes], axis=1)
273 | set_names = list(set_names) + list(variable_genes_names)
274 | else:
275 | pass
276 |
277 | print(f"z_gene_set: {features_type}: {z_gene_set.shape}")
278 | print(f"z_gene_set: {features_type}: {len(set_names)}")
279 |
280 |     # save feature names (strip the old extension before appending the filter tag)
281 |     input_r_names_path = input_r_names_path.replace(".npy", f"_filtered_{features_type}.npy")
282 |     np.save(input_r_names_path, set_names)
283 |
284 |     # save processed features
285 |     input_r_path = input_r_path.replace(".npy", f"_filtered_{features_type}.npy")
286 |     np.save(input_r_path, z_gene_set)
287 | gene_set_dim = z_gene_set.shape[1]
288 |
289 | try:
290 | z_gene_set = z_gene_set.numpy()
291 | except AttributeError:
292 | pass
293 |
294 | # ------ Pretrain annocluster & initialize clustering ------
295 | model_autoencoder = autoencoder(input_dim=G, z_dim=z_dim, gene_set_dim=gene_set_dim,
296 | encoder_dim=z_encoder_dim, emission_dim=z_decoder_dim,
297 | num_layers_encoder=z_encoder_layers, num_layers_decoder=z_decoder_layers,
298 | reconstruction_network='gaussian', decoding_network='gaussian',
299 | use_cuda=use_cuda)
300 |
301 | if os.path.isfile(input_z_path) and os.path.isfile(input_ae_path):
302 | print(f"Both pretrained autoencoder and inferred z exist. No need to pretrain the annocluster model.")
303 | z_init = np.load(input_z_path)
304 | model_autoencoder.load_state_dict(torch.load(input_ae_path, map_location=device)['state_dict'])
305 | else:
306 | if os.path.isfile(input_ae_path):
307 | model_autoencoder.load_state_dict(torch.load(input_ae_path, map_location=device)['state_dict'])
308 |
309 | trainer = Trainer(dataset=expression_only, model=model_autoencoder, model_name="pretrain_z",
310 | batch_size=batch_size,
311 | num_epochs=num_epochs_z, save_infer=True, output_folder=pretrain_z_folder,
312 | num_workers=num_workers,
313 | use_cuda=use_cuda)
314 |
315 | if os.path.isfile(input_ae_path):
316 | print(f"Only pretrained autoencoder exists. Need to infer z and no need to pretrain the annocluster model.")
317 | z_init = trainer.infer_z()
318 | np.save(input_z_path, z_init)
319 | else:
320 | print(f"Start training pretrain the annocluster model ... ")
321 | trainer.train()
322 | z_init = np.load(input_z_path)
323 |
324 | z_init = torch.from_numpy(z_init)
325 |
326 | try:
327 | z_init = z_init.numpy()
328 | except AttributeError:
329 | pass
330 |
331 |     # initialize clusters using Leiden clustering
332 | adata = sc.AnnData(X=z_init)
333 | adata.obsm['X_unifan'] = z_init
334 | sc.pp.neighbors(adata, n_pcs=z_dim, use_rep='X_unifan', random_state=random_seed)
335 | sc.tl.leiden(adata, resolution=1, random_state=random_seed)
336 | clusters_pre = adata.obs['leiden'].astype('int').values # original as string
337 |
338 | # save for the dataset for classifier training
339 | np.save(input_cluster_path, clusters_pre)
340 |
341 | # initialize centroids
342 | try:
343 | df_cluster = pd.DataFrame(z_init.detach().cpu().numpy())
344 | except AttributeError:
345 | df_cluster = pd.DataFrame(z_init)
346 |
347 | cluster_labels = np.unique(clusters_pre)
348 | M = len(set(cluster_labels)) # set as number of clusters
349 | df_cluster['cluster'] = clusters_pre
350 |
351 | # get centroids
352 | centroids = df_cluster.groupby('cluster').mean().values
353 | centroids_torch = torch.from_numpy(centroids)
354 |
355 | # ------ pretrain annotator (classification) ------
356 |     cls_times = 1  # count how many times classification has been run
357 |     cls_training_accuracy = 1  # initialize to 1 so that the loop runs at least once
358 |     weight_decay_candidates = [50, 20, 10, 5.5, 5, 4.5, 4, 3.5, 3, 2.5, 2, 1, 5e-1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
359 |     idx_starting_weight_decay = weight_decay_candidates.index(weight_decay)
360 |
361 |     while cls_training_accuracy >= 0.99:
362 |         # assign a new weight decay (kept the same on the first run)
363 |         weight_decay = weight_decay_candidates[idx_starting_weight_decay - cls_times + 1]
364 |
365 |         print(f"Run classifier round {cls_times} with weight decay {weight_decay}")
366 |
367 | prior_cluster = NumpyDataset(input_r_path, input_cluster_path)
368 |
369 | model_classifier = classifier(output_dim=M, z_dim=gene_set_dim, emission_dim=128, use_cuda=use_cuda)
370 |
371 | trainer = Trainer(dataset=prior_cluster, model=model_classifier, model_name="pretrain_annotator", batch_size=32,
372 | num_epochs=num_epochs_classifier, save_infer=False, output_folder=pretrain_annotator_folder,
373 | num_workers=num_workers, use_cuda=use_cuda)
374 |
375 | trainer.train(weight_decay=weight_decay)
376 | clusters_classifier = trainer.infer_annotator()
377 |
378 | cls_training_accuracy = (clusters_classifier == clusters_pre).sum() / N
379 | print(f"Cluster accuracy on training: \n {cls_training_accuracy}")
380 |
381 | cls_times += 1
382 |
383 | # ------ clustering ------
384 | num_epochs = num_epochs_annocluster
385 | use_pretrain = True
386 |
387 | model_annocluster = AnnoCluster(input_dim=G, z_dim=z_dim, gene_set_dim=gene_set_dim, tau=tau, n_clusters=M,
388 | encoder_dim=z_encoder_dim, emission_dim=z_decoder_dim,
389 | num_layers_encoder=z_encoder_layers, num_layers_decoder=z_decoder_layers,
390 | use_t_dist=True, reconstruction_network='gaussian', decoding_network='gaussian',
391 | centroids=centroids_torch, gene_set_table=gene_set_matrix, use_cuda=use_cuda)
392 |
393 | if use_pretrain:
394 | pretrained_state_dict = model_autoencoder.state_dict()
395 |
396 | # load pretrained AnnoCluster model
397 | state_dict = model_annocluster.state_dict()
398 | for k, v in state_dict.items():
399 | if k in pretrained_state_dict.keys():
400 | state_dict[k] = pretrained_state_dict[k]
401 |
402 | model_annocluster.load_state_dict(state_dict)
403 |
404 | # reload dataset, loading gene set activity scores together
405 | expression_prior = AnnDataset(data_filepath, second_filepath=input_r_path, label_name=label_name)
406 |
407 | trainer = Trainer(dataset=expression_prior, model=model_annocluster, model_2nd=model_classifier,
408 | model_name="annocluster", batch_size=batch_size, num_epochs=num_epochs_annocluster,
409 | save_infer=True, output_folder=annocluster_folder, num_workers=num_workers, use_cuda=use_cuda)
410 | trainer.train(weight_decay=weight_decay)
411 |
412 |
413 | if __name__ == '__main__':
414 | main()
415 |
--------------------------------------------------------------------------------
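A hypothetical invocation sketch for the pipeline above; the paths, project, and tissue names are placeholders, and running via `python -m` assumes the package has been installed.

```python
# Hypothetical invocation (paths and names are placeholders):
#
#   python -m unifan.main -i ../input/data.h5ad -o ../output/ \
#       -p tabula_muris -t Lung -e ../gene_sets/ -l cell_ontology_class
#
# Outputs then land under ../output/tabula_muris/Lung/ in the subfolders built in
# main(): r/ (gene set activity scores), pretrain_z/ (pretrained autoencoder,
# inferred z, and initial Leiden clusters), pretrain_annotator/, and
# annocluster_gene_gene_sets/ (final clusters, for the default --features).
```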
/unifan/networks.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class FullyConnectedLayers(nn.Module):
8 | """
9 | Parameters
10 | ----------
11 | input_dim: integer
12 | number of input features
13 | output_dim: integer
14 | number of output features
15 | num_layers: integer
16 | number of hidden layers
17 | hidden_dim: integer
18 | dimension of hidden layer
19 | dropout_rate: float
20 | bias: boolean
21 |         whether to apply bias to the linear layers
22 |     batch_norm: boolean
23 |         whether to apply batch normalization
24 |
25 | """
26 |
27 | def __init__(self, input_dim: int, output_dim: int, num_layers: int = 1, hidden_dim: int = 128,
28 | dropout_rate: float = 0.1, bias: bool = True, batch_norm: bool = False):
29 | super().__init__()
30 | layers_dim = [input_dim] + [hidden_dim for i in range(num_layers - 1)] + [output_dim]
31 |
32 | self.all_layers = nn.Sequential(collections.OrderedDict(
33 | [(f"Layer {i}", nn.Sequential(
34 | nn.Linear(input_dim, output_dim, bias=bias),
35 |                 nn.BatchNorm1d(output_dim) if batch_norm else None,  # None placeholders are skipped in forward()
36 | nn.ReLU(),
37 | nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None))
38 | for i, (input_dim, output_dim) in enumerate(zip(layers_dim[:-1], layers_dim[1:]))]))
39 |
40 | def forward(self, x: torch.Tensor):
41 |
42 | for layers in self.all_layers:
43 | for layer in layers:
44 | if layer is not None:
45 | x = layer(x)
46 |
47 | return x
48 |
49 |
50 | class Encoder(nn.Module):
51 | """
52 |
53 | A standard encoder.
54 |
55 | Parameters
56 | ----------
57 | input_dim: integer
58 | number of input features
59 | output_dim: integer
60 | number of output features
61 | num_layers: integer
62 | number of hidden layers
63 | hidden_dim: integer
64 | dimension of hidden layer
65 | dropout_rate: float
66 |
67 | """
68 |
69 | def __init__(self, input_dim: int, output_dim: int, num_layers: int = 1, hidden_dim: int = 128,
70 | dropout_rate: float = 0.1):
71 | super().__init__()
72 |
73 | self.encoder = FullyConnectedLayers(input_dim=input_dim, output_dim=hidden_dim, num_layers=num_layers,
74 | hidden_dim=hidden_dim, dropout_rate=dropout_rate)
75 | self.mean_layer = nn.Linear(hidden_dim, output_dim)
76 | self.var_layer = nn.Linear(hidden_dim, output_dim)
77 |
78 | def forward(self, x: torch.Tensor):
79 | """
80 |
81 | Parameters
82 | ----------
83 | x: torch.Tensor
84 |
85 | Returns
86 | -------
87 | q_m: torch.Tensor
88 | estimated mean
89 | q_v: torch.Tensor
90 | estimated variance
91 |
92 | """
93 |
94 | q = self.encoder(x)
95 | q_m = self.mean_layer(q)
96 | q_v = torch.exp(self.var_layer(q))
97 | return q_m, q_v
98 |
99 |
100 | class LinearCoder(nn.Module):
101 |
102 | """
103 |
104 | A single-layer linear encoder.
105 |
106 | Parameters
107 | ----------
108 | input_dim: integer
109 | number of input features
110 | output_dim: integer
111 | number of output features
112 | """
113 |
114 | def __init__(self, input_dim: int, output_dim: int):
115 | super().__init__()
116 |
117 | self.encoder = nn.Linear(input_dim, output_dim)
118 |
119 | def forward(self, x: torch.Tensor):
120 | q = self.encoder(x)
121 | return q, None
122 |
123 |
124 | class NonNegativeCoder(nn.Module):
125 |
126 | """
127 |
128 |     An encoder outputting non-negative values (using ReLU at the output layer).
129 |
130 | Parameters
131 | ----------
132 | input_dim: integer
133 | number of input features
134 | output_dim: integer
135 | number of output features
136 | num_layers: integer
137 | number of hidden layers
138 | hidden_dim: integer
139 | dimension of hidden layer
140 | dropout_rate: float
141 |
142 | """
143 |
144 | def __init__(self, input_dim: int, output_dim: int, num_layers: int = 1, hidden_dim: int = 128,
145 | dropout_rate: float = 0.1):
146 | super().__init__()
147 |
148 | self.encoder = FullyConnectedLayers(input_dim=input_dim, output_dim=hidden_dim, num_layers=num_layers,
149 | hidden_dim=hidden_dim, dropout_rate=dropout_rate)
150 | self.mean_layer = FullyConnectedLayers(input_dim=hidden_dim, output_dim=output_dim, num_layers=1,
151 | hidden_dim=hidden_dim, dropout_rate=dropout_rate)
152 |
153 | def forward(self, x: torch.Tensor):
154 | q = self.encoder(x)
155 | q = self.mean_layer(q)
156 | return q, None
157 |
158 |
159 | class SigmoidCoder(nn.Module):
160 | """
161 |
162 |     An encoder using a sigmoid at the output layer.
163 |
164 | Parameters
165 | ----------
166 | input_dim: integer
167 | number of input features
168 | output_dim: integer
169 | number of output features
170 | num_layers: integer
171 | number of hidden layers
172 | hidden_dim: integer
173 | dimension of hidden layer
174 | dropout_rate: float
175 |
176 | """
177 |
178 | def __init__(self, input_dim: int, output_dim: int, num_layers: int = 1, hidden_dim: int = 128,
179 | dropout_rate: float = 0.1):
180 | super().__init__()
181 |
182 | self.encoder = FullyConnectedLayers(input_dim=input_dim, output_dim=hidden_dim, num_layers=num_layers,
183 | hidden_dim=hidden_dim, dropout_rate=dropout_rate)
184 | self.mean_layer = nn.Sequential(nn.Linear(hidden_dim, output_dim), nn.Sigmoid())
185 |
186 | def forward(self, x: torch.Tensor):
187 | q = self.encoder(x)
188 | q = self.mean_layer(q)
189 | return q, None
190 |
191 |
192 | class Decoder(nn.Module):
193 | """
194 |
195 | A standard decoder.
196 |
197 | Parameters
198 | ----------
199 | input_dim: integer
200 | number of input features
201 | output_dim: integer
202 | number of output features
203 | num_layers: integer
204 | number of hidden layers
205 | hidden_dim: integer
206 | dimension of hidden layer
207 | dropout_rate: float
208 |
209 | """
210 |
211 | def __init__(self, input_dim: int, output_dim: int, num_layers: int = 1, hidden_dim: int = 128,
212 | dropout_rate: float = 0.1):
213 | super().__init__()
214 | self.decoder = FullyConnectedLayers(input_dim=input_dim, output_dim=hidden_dim, num_layers=num_layers,
215 | hidden_dim=hidden_dim, dropout_rate=dropout_rate)
216 |
217 | self.mean_layer = nn.Linear(hidden_dim, output_dim)
218 | self.var_layer = nn.Linear(hidden_dim, output_dim)
219 |
220 | def forward(self, x: torch.Tensor):
221 | """
222 |
223 | Parameters
224 | ----------
225 | x: torch.Tensor
226 |
227 | Returns
228 | -------
229 | p_m: torch.Tensor
230 | estimated mean
231 | p_v: torch.Tensor
232 | estimated variance
233 | """
234 |
235 | p = self.decoder(x)
236 | p_m = self.mean_layer(p)
237 | p_v = torch.exp(self.var_layer(p))
238 | return p_m, p_v
239 |
240 |
241 |
242 | class Set2Gene(nn.Module):
243 |
244 | """
245 |     Decode by a linear combination based on the known membership between gene sets (input) and genes (output).
246 |
247 | Parameters
248 | ----------
249 |     tf_gene_table: torch.Tensor
250 |         number of gene sets (equal to the input dimension) x number of genes
251 |
252 | """
253 |
254 | def __init__(self, tf_gene_table: torch.Tensor):
255 | super().__init__()
256 | self.tf_gene_table = tf_gene_table
257 |
258 | def forward(self, x: torch.Tensor):
259 | p_m = torch.mm(x.double(), self.tf_gene_table)
260 | return p_m, None
261 |
262 |
263 | class Decode2Labels(nn.Module):
264 | """
265 |
266 | A linear classifier (logistic classifier).
267 |
268 | Parameters
269 | ----------
270 | input_dim: integer
271 | number of input features
272 | output_dim: integer
273 | number of output features
274 | """
275 |
276 | def __init__(self, input_dim: int, output_dim: int, bias: bool = False):
277 | super().__init__()
278 | self.predictor = nn.Sequential(nn.Linear(input_dim, output_dim, bias=bias), nn.LogSoftmax(dim=-1))
279 |
280 | def forward(self, x: torch.Tensor):
281 | labels = self.predictor(x)
282 | return labels
283 |
284 |
--------------------------------------------------------------------------------
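Two toy checks (sizes invented) of the building blocks above: FullyConnectedLayers stacks input -> hidden -> ... -> output, and its custom forward() skips the None placeholders left in each inner Sequential when batch norm or dropout is disabled; Set2Gene is a fixed linear map from gene set space to gene space.

```python
import torch
from unifan.networks import FullyConnectedLayers, Set2Gene

# num_layers=3 builds 100 -> 128 -> 128 -> 32; dropout_rate=0 leaves None slots
mlp = FullyConnectedLayers(input_dim=100, output_dim=32, num_layers=3,
                           hidden_dim=128, dropout_rate=0.0, batch_norm=False)
out = mlp(torch.rand(4, 100))  # None slots are skipped in forward()

# Set2Gene: (batch x sets) @ (sets x genes) -> (batch x genes), in double precision
table = torch.rand(32, 100).double()
p_m, _ = Set2Gene(table)(torch.rand(4, 32))

print(out.shape, p_m.shape)  # torch.Size([4, 32]) torch.Size([4, 100])
```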
/unifan/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import pickle
4 |
5 | from tqdm import tqdm
6 | import torch
7 | import torch.nn as nn
8 | from sklearn.metrics.cluster import adjusted_rand_score, adjusted_mutual_info_score
9 | from sklearn.model_selection import train_test_split
10 | import seaborn as sns
11 | import numpy as np
12 | import pandas as pd
13 | from matplotlib import pyplot as plt
14 | import umap
15 |
16 | torch.backends.cudnn.benchmark = True
17 |
18 |
19 | class Trainer(nn.Module):
20 | """
21 |
22 | Train NN models.
23 |
24 |
25 | Parameters
26 | ----------
27 | dataset: PyTorch Dataset
28 |     model: PyTorch NN model
29 |     model_2nd: another PyTorch NN model to be trained jointly (only supported for "annocluster")
30 | model_name: string
31 | name of the model, any of "r", "pretrain_z", "annocluster", "pretrain_annotator"
32 | percent_training: float
33 |         percentage of data used for training; the rest is used for validation
34 |     checkpoint_freq: integer
35 |         frequency (in epochs) of saving models during training
36 |     val_freq: integer
37 |         frequency (in epochs) of running evaluation (both the training and validation sets are evaluated)
38 |     visualize_freq: integer
39 |         frequency (in epochs) of inferring the low-dimensional representation and visualizing it using UMAP
40 |     save_visual: boolean
41 |         whether to run visualization and save the figures to the output folder
42 |     save_checkpoint: boolean
43 |         whether to save checkpoint models during training
44 |     save_infer: boolean
45 |         whether to run inference (low-dimensional representation for autoencoders; clusters for the clustering
46 |         model; predictions for classifiers) when training finishes
47 | output_folder: string
48 | folder to save all outputs
49 |
50 | """
51 | def __init__(self, dataset, model, model_2nd=None, model_name: str = None, batch_size: int = 128,
52 | num_epochs: int = 50, percent_training: float = 1.0, learning_rate: float = 0.0005,
53 | decay_factor: float = 0.9, num_workers: int = 8, use_cuda: bool = False, checkpoint_freq: int = 20,
54 | val_freq: int = 10, visualize_freq: int = 10, save_visual: bool = False, save_checkpoint: bool = False,
55 | save_infer: bool = False, output_folder: str = None):
56 |
57 | super().__init__()
58 |
59 | # device
60 | self.num_workers = num_workers
61 | self.use_cuda = use_cuda
62 | self.device = torch.device("cuda" if use_cuda else "cpu")
63 | self.pin_memory = True if use_cuda else False
64 | self.non_blocking = True if use_cuda else False
65 |
66 | # model
67 | _support_models = ["r", "pretrain_z", "annocluster", "pretrain_annotator"]
68 | if model_name not in _support_models:
69 |             raise NotImplementedError(f"The current implementation only supports training "
70 |                                       f"for {','.join(_support_models)}.")
71 | self.model_name = model_name
72 | self.model = model.to(self.device)
73 | self.model_2nd = model_2nd.to(self.device) if model_2nd is not None else None
74 |
75 | # optimization
76 | self.batch_size = batch_size
77 | self.num_epochs = num_epochs
78 | self.learning_rate = learning_rate
79 | self.decay_factor = decay_factor
80 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
81 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1000, self.decay_factor)
82 | if model_2nd is not None:
83 | self.optimizer_2nd = torch.optim.Adam(self.model_2nd.parameters(), lr=self.learning_rate)
84 | self.scheduler_2nd = torch.optim.lr_scheduler.StepLR(self.optimizer_2nd, 1000, self.decay_factor)
85 | else:
86 | self.optimizer_2nd, self.scheduler_2nd = None, None
87 |
88 | # evaluation
89 | self.checkpoint_freq = checkpoint_freq
90 | self.val_freq = val_freq
91 | self.visual_frequency = visualize_freq
92 | self.output_folder = output_folder
93 | self.percent_training = percent_training
94 | self.save_checkpoint = save_checkpoint
95 | self.save_visual = save_visual
96 | self.save_infer = save_infer
97 |
98 | if output_folder is not None:
99 | if not os.path.exists(output_folder):
100 | os.makedirs(output_folder)
101 |
102 | # data
103 | self.dataset = dataset
104 |
105 | # prepare for data loader
106 | train_length = int(self.dataset.N * self.percent_training)
107 | val_length = self.dataset.N - train_length
108 |
109 | train_data, val_data = torch.utils.data.random_split(self.dataset, (train_length, val_length))
110 | self.dataloader_all = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size, shuffle=False,
111 | num_workers=self.num_workers, pin_memory=self.pin_memory)
112 | self.dataloader_train = torch.utils.data.DataLoader(train_data, batch_size=self.batch_size, shuffle=True,
113 | num_workers=self.num_workers, pin_memory=self.pin_memory)
114 | if self.percent_training != 1:
115 | self.dataloader_val = torch.utils.data.DataLoader(val_data, batch_size=self.batch_size, shuffle=True,
116 | num_workers=self.num_workers, pin_memory=self.pin_memory)
117 | else:
118 | self.dataloader_val = None
119 |
120 | # initialize functions for each model
121 | self.train_functions = {"r": self.train_epoch_r, "pretrain_z": self.train_epoch_z,
122 | "pretrain_annotator": self.train_epoch_annotator,
123 | "annocluster": self.train_epoch_cluster}
124 | self.train_stats_dicts = {"r": {k: [] for k in ['train_loss', 'val_loss', 'train_sparsity', 'train_set_cover']},
125 | "pretrain_z": {k: [] for k in ['train_loss', 'val_loss']},
126 | "pretrain_annotator": {k: [] for k in ['train_loss', 'val_loss']},
127 | "annocluster": {k: [] for k in ['train_loss', 'val_loss', 'train_mse', 'train_mse_e',
128 | 'train_mse_q', 'train_prob_z_l', 'ARI', 'NMI']}}
129 |
130 | self.evaluate_functions = {"r": self.evaluate_r, "pretrain_z": self.evaluate_z,
131 | "pretrain_annotator": self.evaluate_annotator,
132 | "annocluster": self.evaluate_cluster}
133 |
134 | self.infer_functions = {"r": self.infer_r, "pretrain_z": self.infer_z,
135 | "pretrain_annotator": self.infer_annotator,
136 | "annocluster": self.infer_cluster}
137 |
138 | # initialize the parameter list (used for regularization)
139 | if self.model_name in ["pretrain_annotator", "annocluster"]:
140 | self.annotator_param_list = nn.ParameterList()
141 | if self.model_name == "pretrain_annotator":
142 | for p in self.model.named_parameters():
143 | self.annotator_param_list.append(p[1])
144 | else:
145 | assert self.model_2nd is not None
146 | for p in self.model_2nd.named_parameters():
147 | self.annotator_param_list.append(p[1])
148 |
149 | if self.model_name == "pretrain_annotator" and self.save_visual:
150 |             raise NotImplementedError(f"The current implementation only supports visualization "
151 |                                       f"for {','.join(_support_models[:-1])}.")
152 |
153 | if self.model_name != "annocluster" and self.model_2nd is not None:
154 |             raise NotImplementedError(f"The current implementation only supports training two models for annocluster.")
155 |
156 | def train(self, **kwargs):
157 |
158 | """
159 | Train the model. Will save the trained model & evaluation results as default.
160 |
161 | Parameters
162 | ----------
163 |         kwargs: keyword arguments specific to each model (e.g., alpha for the gene set activity scores model)
164 |
165 | """
166 |
167 | for epoch in range(self.num_epochs):
168 | self.model.train()
169 | if self.model_2nd is not None:
170 | self.model_2nd.train()
171 |
172 | self.train_functions[self.model_name](**kwargs)
173 |
174 | if epoch % self.val_freq == 0:
175 | with torch.no_grad():
176 | self.model.eval()
177 | if self.model_2nd is not None:
178 | self.model_2nd.eval()
179 |
180 | self.evaluate_functions[self.model_name](**kwargs)
181 |
182 | if self.save_visual and epoch % self.visual_frequency == 0:
183 | with torch.no_grad():
184 | self.model.eval()
185 | if self.model_2nd is not None:
186 | self.model_2nd.eval()
187 | if self.model_name == "annocluster":
188 | _X, _clusters = self.infer_functions[self.model_name](**kwargs)
189 | else:
190 | _X = self.infer_functions[self.model_name](**kwargs)
191 | _clusters = None
192 | self.visualize_UMAP(_X, epoch, self.output_folder, clusters_true=self.dataset.clusters_true,
193 | clusters_pre=_clusters)
194 |
195 | # save model & inference
196 | if self.save_checkpoint and epoch % self.checkpoint_freq == 0:
197 | # save model
198 | _state = {'epoch': epoch,
199 | 'state_dict': self.model.state_dict(),
200 | 'optimizer': self.optimizer.state_dict()}
201 | if self.model_2nd is not None:
202 | _state = {'epoch': epoch,
203 | 'state_dict': self.model.state_dict(),
204 | 'optimizer': self.optimizer.state_dict(),
205 | 'state_dict_2': self.model_2nd.state_dict(),
206 | 'optimizer_2': self.optimizer_2nd.state_dict()}
207 |
208 | torch.save(_state, os.path.join(self.output_folder, f"{self.model_name}_model_{epoch}.pickle"))
209 |
210 | del _state
211 | gc.collect()
212 |
213 |         # save the final model after training
214 | _state = {'epoch': epoch,
215 | 'state_dict': self.model.state_dict(),
216 | 'optimizer': self.optimizer.state_dict()}
217 | if self.model_2nd is not None:
218 | _state = {'epoch': epoch,
219 | 'state_dict': self.model.state_dict(),
220 | 'optimizer': self.optimizer.state_dict(),
221 | 'state_dict_2': self.model_2nd.state_dict(),
222 | 'optimizer_2': self.optimizer_2nd.state_dict()}
223 |
224 | torch.save(_state, os.path.join(self.output_folder, f"{self.model_name}_model_{epoch}.pickle"))
225 |
226 | del _state
227 | gc.collect()
228 |
229 | # inference and save
230 | if self.save_infer:
231 | if self.model_name == "annocluster":
232 | _X, _clusters = self.infer_functions[self.model_name](**kwargs)
233 | np.save(os.path.join(self.output_folder, f"{self.model_name}_clusters_pre_{epoch}.npy"), _clusters)
234 | else:
235 | _X = self.infer_functions[self.model_name](**kwargs)
236 |
237 | np.save(os.path.join(self.output_folder, f"{self.model_name}_{epoch}.npy"), _X)
238 |
239 | # save training stats
240 | with open(os.path.join(self.output_folder, f"stats_{epoch}.pickle"), 'wb') as handle:
241 | pickle.dump(self.train_stats_dicts[self.model_name], handle, protocol=pickle.HIGHEST_PROTOCOL)
242 |
243 | def train_epoch_r(self, **kwargs):
244 | for batch_idx, (X_batch) in enumerate(tqdm(self.dataloader_train)):
245 | self.process_minibatch_r(X_batch=X_batch, **kwargs)
246 |
247 | def train_epoch_z(self, **kwargs):
248 | for batch_idx, (X_batch) in enumerate(tqdm(self.dataloader_train)):
249 | self.process_minibatch_z(X_batch=X_batch, **kwargs)
250 |
251 | def train_epoch_annotator(self, **kwargs):
252 | for batch_idx, (X_batch, y_batch) in enumerate(tqdm(self.dataloader_train)):
253 | y_batch = torch.flatten(y_batch)
254 | self.process_minibatch_annotator(X_batch, y_batch, **kwargs)
255 |
256 | def train_epoch_cluster(self, **kwargs):
257 | for batch_idx, (X_batch, gene_set_batch) in enumerate(tqdm(self.dataloader_train)):
258 | self.process_minibatch_cluster(X_batch, gene_set_batch, **kwargs)
259 |
260 | def process_minibatch_cluster(self, X_batch, gene_set_batch, weight_decay: float = 0):
261 | """
262 | Process minibatch for the annocluster model.
263 |
264 | Parameters
265 | ----------
266 | X_batch: torch.Tensor
267 | gene expression
268 | gene_set_batch: torch.Tensor
269 | gene set activity scores
270 | weight_decay: float
271 | hyperparameter gamma regularizing exclusive lasso penalty
272 |
273 | """
274 | X_batch = X_batch.to(self.device, non_blocking=self.non_blocking).float()
275 | gene_set_batch = gene_set_batch.to(self.device, non_blocking=self.non_blocking).float()
276 |
277 | if self.model.training and self.model_2nd.training:
278 | self.optimizer.zero_grad(set_to_none=True)
279 | self.optimizer_2nd.zero_grad(set_to_none=True)
280 |
281 | x_e, x_q, z_e, z_q, k, z_dist, dist_prob = self.model(X_batch)
282 | y_pre = self.model_2nd(gene_set_batch)
283 | pre_prob = torch.exp(y_pre)
284 |
285 | l = self.model.loss(X_batch, x_e, x_q, z_dist, prior_prob=pre_prob)
286 | l_prob = self.model_2nd.loss(y_pre, k)
287 |
288 | l_prob += weight_decay * torch.sum(torch.square(torch.sum(torch.abs(self.annotator_param_list[0]), dim=0)))
289 |
290 | if self.model.training and self.model_2nd.training:
291 | l.backward(retain_graph=True)
292 | self.optimizer.step()
293 |
294 | l_prob.backward(retain_graph=True)
295 | self.optimizer_2nd.step()
296 |
297 | self.scheduler.step()
298 | self.scheduler_2nd.step()
299 |
300 | return l.detach().cpu().item(), l_prob.detach().cpu().numpy(), k.detach().cpu().numpy(), \
301 | z_e.detach().cpu().numpy()
302 |
303 | def evaluate_cluster(self, **kwargs):
304 | """
305 |
306 |         Evaluate the total loss and each loss term on the training data, and the total loss on the
307 |         validation set, for the annocluster model.
308 |
309 | """
310 |
311 |
312 | train_stats = self.train_stats_dicts[self.model_name]
313 |
314 | _loss = []
315 | _l_e = []
316 | _l_q = []
317 | _mse_l = []
318 | _prob_z_l = []
319 |
320 | for batch_idx, (X_batch, gene_set_batch) in enumerate(tqdm(self.dataloader_train)):
321 | X_batch = X_batch.to(self.device, non_blocking=self.non_blocking).float()
322 | gene_set_batch = gene_set_batch.to(self.device, non_blocking=self.non_blocking).float()
323 |
324 | x_e, x_q, z_e, z_q, k, z_dist, dist_prob = self.model(X_batch)
325 | y_pre = self.model_2nd(gene_set_batch)
326 | pre_prob = torch.exp(y_pre)
327 |
328 | l = self.model.loss(X_batch, x_e, x_q, z_dist, prior_prob=pre_prob).detach().cpu().item()
329 | l_prob = self.model_2nd.loss(y_pre, k).detach().cpu().item()
330 |
331 | # calculate each loss
332 | l_e = self.model.mse_loss(X_batch, x_e).detach().cpu().item()
333 | l_q = self.model.mse_loss(X_batch, x_q).detach().cpu().item()
334 | mse_l = self.model._loss_reconstruct(X_batch, x_e, x_q).detach().cpu().item()
335 | prob_z_l = self.model._loss_z_prob(z_dist, prior_prob=pre_prob).detach().cpu().item()
336 |
337 | _loss.append(l)
338 | _l_e.append(l_e)
339 | _l_q.append(l_q)
340 | _mse_l.append(mse_l)
341 | _prob_z_l.append(prob_z_l)
342 |
343 | train_stats['train_loss'].append(np.mean(_loss))
344 | train_stats['train_mse'].append(np.mean(_mse_l))
345 | train_stats['train_mse_e'].append(np.mean(_l_e))
346 | train_stats['train_mse_q'].append(np.mean(_l_q))
347 | train_stats['train_prob_z_l'].append(np.mean(_prob_z_l))
348 |
349 | if self.percent_training != 1:
350 | _loss = []
351 | for batch_idx, (X_batch, gene_set_batch) in enumerate(tqdm(self.dataloader_val)):
352 | l, _, _, _ = self.process_minibatch_cluster(X_batch, gene_set_batch, **kwargs)
353 | _loss.append(l)
354 |
355 | train_stats['val_loss'].append(np.mean(_loss))
356 | else:
357 | train_stats['val_loss'].append(np.nan)
358 |
359 | def infer_cluster(self, **kwargs):
360 | """
361 |
362 | Get z_e and clusters from the annocluster model. Also calculate ARI and NMI scores comparing with
363 | the ground truth (if available).
364 |
365 | Returns
366 | -------
367 | z_annocluster: numpy array
368 | z_e of cells
369 | clusters_pre: numpy array
370 | cluster assignments
371 |
372 | """
373 | train_stats = self.train_stats_dicts[self.model_name]
374 |
375 | with torch.no_grad():
376 | self.model.eval()
377 | if self.model_2nd is not None:
378 | self.model_2nd.eval()
379 |
380 | # inference
381 | k_list = []
382 | z_e_list = []
383 |
384 | for batch_idx, (X_batch, gene_set_batch) in enumerate(tqdm(self.dataloader_all)):
385 | _, _, k, z_e = self.process_minibatch_cluster(X_batch, gene_set_batch, **kwargs)
386 | k_list.append(k)
387 | z_e_list.append(z_e)
388 |
389 | clusters_pre = np.concatenate(k_list)
390 | z_annocluster = np.concatenate(z_e_list)
391 |
392 | if self.dataset.clusters_true is not None:
393 | if self.dataset.N > 5e4:
394 | idx_stratified, _ = train_test_split(range(self.dataset.N), test_size=0.5,
395 | stratify=self.dataset.clusters_true)
396 | else:
397 | idx_stratified = range(self.dataset.N)
398 |
399 | # metrics
400 | ari_smaller = adjusted_rand_score(clusters_pre[idx_stratified],
401 | self.dataset.clusters_true[idx_stratified])
402 | nmi_smaller = adjusted_mutual_info_score(clusters_pre, self.dataset.clusters_true)
403 | print(f"annocluster: ARI for smaller cluster: {ari_smaller}")
404 | print(f"annocluster: NMI for smaller cluster: {nmi_smaller}")
405 | else:
406 | ari_smaller = np.nan
407 | nmi_smaller = np.nan
408 |
409 | train_stats["ARI"].append(ari_smaller)
410 | train_stats["NMI"].append(nmi_smaller)
411 |
412 | return z_annocluster, clusters_pre
413 |
414 |
415 | def process_minibatch_annotator(self, X_batch, y_batch, weight_decay: float = 0):
416 |
417 | """
418 | Process minibatch for the annotator model.
419 |
420 | Parameters
421 | ----------
422 | X_batch: torch.Tensor
423 | gene expression
424 | y_batch: torch.Tensor
425 | cluster assignment
426 | weight_decay: float
427 | hyperparameter gamma regularizing exclusive lasso penalty
428 |
429 | """
430 |
431 | X_batch = X_batch.to(self.device, non_blocking=self.non_blocking).float()
432 | y_batch = y_batch.to(self.device, non_blocking=self.non_blocking)
433 |
434 | if self.model.training:
435 | self.optimizer.zero_grad(set_to_none=True)
436 |
437 | y_pre = self.model(X_batch)
438 | l = self.model.loss(y_pre, y_batch)
439 |
440 | l += weight_decay * torch.sum(torch.square(torch.sum(torch.abs(self.annotator_param_list[0]), dim=0)))
441 |
442 | if self.model.training:
443 | l.backward()
444 | self.optimizer.step()
445 | self.scheduler.step()
446 |
447 | return l.detach().cpu().item(), y_pre.detach().cpu().numpy()
448 |
449 | def evaluate_annotator(self, **kwargs):
450 | train_stats = self.train_stats_dicts[self.model_name]
451 | _loss = []
452 | for batch_idx, (X_batch, y_batch) in enumerate(tqdm(self.dataloader_train)):
453 | y_batch = torch.flatten(y_batch)
454 | l, _ = self.process_minibatch_annotator(X_batch, y_batch, **kwargs)
455 | _loss.append(l)
456 |
457 | train_stats['train_loss'].append(np.mean(_loss))
458 |
459 | if self.percent_training != 1:
460 | _loss = []
461 | for batch_idx, (X_batch, y_batch) in enumerate(tqdm(self.dataloader_val)):
462 | y_batch = torch.flatten(y_batch)
463 | l, _ = self.process_minibatch_annotator(X_batch, y_batch, **kwargs)
464 | _loss.append(l)
465 |
466 | train_stats['val_loss'].append(np.mean(_loss))
467 | else:
468 | train_stats['val_loss'].append(np.nan)
469 |
470 | def infer_annotator(self):
471 | """
472 |
473 |         Get the predicted labels for all the data.
474 |
475 |         Returns
476 | ------
477 | clusters_classifier: numpy array
478 | predicted labels from the trained annotator
479 |
480 | """
481 | with torch.no_grad():
482 | self.model.eval()
483 | tf_prob = self.model(
484 | torch.from_numpy(self.dataset.data).to(self.device, non_blocking=self.non_blocking).float())
485 |
486 | clusters_prob_pre = torch.exp(tf_prob)
487 | clusters_classifier = np.argmax(clusters_prob_pre.detach().cpu().numpy(), axis=1)
488 |
489 | return clusters_classifier
490 |
491 | def process_minibatch_r(self, X_batch, alpha: float = 0, beta: float = 0, beta_list: torch.Tensor = None,
492 | gene_covered_matrix: torch.Tensor = None):
493 | """
494 |
495 | Process minibatch for gene set activity scores model (named as r).
496 |
497 | Parameters
498 | ----------
499 | X_batch: torch.Tensor
500 | gene expression
501 | alpha: float
502 | hyperparameter regularizing L1 term in the set cover loss
503 | beta: float
504 |             hyperparameter regularizing the set cover term in the set cover loss
505 |         beta_list: torch.Tensor
506 |             beta values for all genes
507 |         gene_covered_matrix: torch.Tensor
508 |             gene set membership matrix restricted to genes covered by at least one of the available sets
509 |
510 | """
511 |
512 |
513 | X_batch = X_batch.to(self.device, non_blocking=self.non_blocking).float()
514 |
515 | if self.model.training:
516 | self.optimizer.zero_grad(set_to_none=True)
517 |
518 | x_e, z_e = self.model(X_batch)
519 | l = self.model.loss(X_batch.float(), x_e.float())
520 |
521 | # add sparsity regularization on the output
522 | if alpha != 0:
523 | sparsity_penalty = alpha * torch.mean(torch.abs(z_e))
524 | l += sparsity_penalty
525 | else:
526 | sparsity_penalty = torch.zeros(1)
527 |
528 |         if beta != 0:
529 |             cover_penalty = - torch.mean(torch.matmul(torch.mm(z_e, gene_covered_matrix),
530 |                                                       beta_list))
531 |             l += cover_penalty
532 |         else:
533 |             cover_penalty = torch.zeros(1)
534 |
535 |         if self.model.training:
536 |             l.backward()
537 |             self.optimizer.step()
538 |             self.scheduler.step()
539 |
540 |         return l.detach().cpu().item(), z_e.detach().cpu().numpy(), \
541 |             sparsity_penalty.detach().cpu().numpy(), cover_penalty.detach().cpu().numpy()
542 |
543 | def evaluate_r(self, **kwargs):
544 | train_stats = self.train_stats_dicts[self.model_name]
545 | _loss = []
546 | _sp = []
547 | _cp = []
548 |
549 | for batch_idx, (X_batch) in enumerate(tqdm(self.dataloader_train)):
550 | l, _, sp, cp = self.process_minibatch_r(X_batch, **kwargs)
551 | _loss.append(l)
552 | _sp.append(sp)
553 | _cp.append(cp)
554 |
555 | train_stats['train_loss'].append(np.mean(_loss))
556 | train_stats['train_sparsity'].append(np.mean(_sp))
557 | train_stats['train_set_cover'].append(np.mean(_cp))
558 |
559 | if self.percent_training != 1:
560 | _loss = []
561 | for batch_idx, (X_batch) in enumerate(tqdm(self.dataloader_val)):
562 | l, _, _, _ = self.process_minibatch_r(X_batch, **kwargs)
563 | _loss.append(l)
564 |
565 | train_stats['val_loss'].append(np.mean(_loss))
566 | else:
567 | train_stats['val_loss'].append(np.nan)
568 |
569 | def infer_r(self, **kwargs):
570 |
571 | """
572 |
573 | Returns
574 | -------
575 | z_gene_set: numpy array
576 | gene set activity scores
577 |
578 | """
579 | with torch.no_grad():
580 | self.model.eval()
581 |
582 | z_e_list = []
583 | _loss = []
584 | print(f"Start inferring ...")
585 | for batch_idx, (X_batch) in enumerate(self.dataloader_all):
586 | l, z_e, _, _ = self.process_minibatch_r(X_batch, **kwargs)
587 |
588 | z_e_list.append(z_e)
589 | _loss.append(l)
590 |
591 | print(f"Finish inferring ...")
592 | z_gene_set = np.concatenate(z_e_list)
593 |
594 | del z_e_list
595 | del _loss
596 | del z_e
597 |
598 | gc.collect()
599 |
600 | return z_gene_set
601 |
602 | def process_minibatch_z(self, X_batch):
603 |
604 | """
605 |         Process minibatch for pretraining the annocluster autoencoder (named pretrain_z).
606 |
607 | """
608 |
609 | X_batch = X_batch.to(self.device, non_blocking=self.non_blocking).float()
610 |
611 | if self.model.training:
612 | self.optimizer.zero_grad(set_to_none=True)
613 |
614 | x_e, z_e = self.model(X_batch)
615 |
616 | l = self.model.loss(X_batch.float(), x_e.float())
617 |
618 | if self.model.training:
619 | l.backward()
620 | self.optimizer.step()
621 | self.scheduler.step()
622 | return l.detach().cpu().item(), z_e.detach().cpu().numpy()
623 |
624 | def evaluate_z(self, **kwargs):
625 | train_stats = self.train_stats_dicts[self.model_name]
626 |
627 | _loss = []
628 | for batch_idx, (X_batch) in enumerate(tqdm(self.dataloader_train)):
629 | l, _ = self.process_minibatch_z(X_batch)
630 | _loss.append(l)
631 |
632 | train_stats['train_loss'].append(np.mean(_loss))
633 |
634 | if self.percent_training != 1:
635 | _loss = []
636 | for batch_idx, (X_batch) in enumerate(tqdm(self.dataloader_val)):
637 | l, _ = self.process_minibatch_z(X_batch)
638 | _loss.append(l)
639 |
640 | train_stats['val_loss'].append(np.mean(_loss))
641 | else:
642 | train_stats['val_loss'].append(np.nan)
643 |
644 | def infer_z(self):
645 |
646 | """
647 |
648 | Returns
649 | -------
650 | z_init: numpy array
651 | z_e from the pretrain model
652 |
653 | """
654 |
655 | with torch.no_grad():
656 | self.model.eval()
657 |
658 | z_e_list = []
659 |
660 | for batch_idx, (X_batch) in enumerate(tqdm(self.dataloader_all)):
661 | _, z_e = self.process_minibatch_z(X_batch)
662 | z_e_list.append(z_e)
663 |
664 | z_init = np.concatenate(z_e_list)
665 |
666 | del z_e_list
667 | gc.collect()
668 |
669 | return z_init
670 |
671 | @staticmethod
672 | def visualize_UMAP(X, epoch:int, output_folder: str, clusters_true=None, clusters_pre=None,
673 | color_palette:str = "tab20"):
674 | """
675 |
676 | Visualize the low-dimensional representations using UMAP and save figures.
677 |
678 | Parameters
679 | ----------
680 | X: numpy array
681 | low-dimensional representations
682 | epoch: integer
683 |             epoch of the model from which the low-dimensional representations are inferred
684 | clusters_true: numpy array
685 | ground truth labels
686 | clusters_pre: numpy array
687 | cluster assignment
688 |
689 | """
690 |
691 | print(f"Start visualizing using UMAP...")
692 | umap_original = umap.UMAP().fit_transform(X)
693 |
694 | # color by cluster
695 | hues = {'label': clusters_true, 'cluster': clusters_pre}
696 |
697 | for k, v in hues.items():
698 | df_plot = pd.DataFrame(umap_original)
699 | if v is None:
700 | df_plot['label'] = np.repeat("Label not available", df_plot.shape[0])
701 | else:
702 | df_plot['label'] = v
703 |                 df_plot['label'] = df_plot['label'].astype('str')  # assign back; .astype() is not in-place
704 | df_plot.columns = ['dim_1', 'dim_2', 'label']
705 |
706 | plt.figure(figsize=(10, 10))
707 | sns.scatterplot(x='dim_1', y='dim_2', hue='label', data=df_plot, palette=color_palette,
708 | legend=True)
709 | plt.title(f"Encoding (r) colored by {k}")
710 | plt.savefig(os.path.join(output_folder, f"r_{epoch}_{k}.png"), bbox_inches="tight", format="png")
711 | plt.close()
712 |
713 | del X
714 | del umap_original
715 | del df_plot
716 |
717 | gc.collect()
718 |
719 |
--------------------------------------------------------------------------------
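The weight_decay term added in process_minibatch_annotator and process_minibatch_cluster is an exclusive lasso penalty on the annotator's weight matrix (clusters x features): the L1 norm is taken across clusters for each feature, squared, and summed over features, which pushes each feature to be used by only a few clusters. A standalone sketch with a toy matrix:

```python
import torch

W = torch.rand(10, 335)  # toy annotator weights: clusters x features
gamma = 1e-3             # the `weight_decay` hyperparameter (gamma in main.py)

# exclusive lasso: per-feature L1 across clusters, squared, then summed over features,
# matching the expression applied to annotator_param_list[0] above
penalty = gamma * torch.sum(torch.square(torch.sum(torch.abs(W), dim=0)))
print(penalty.item())
```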
/unifan/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Helper classes / functions.
3 |
4 | """
5 | import os
6 | import argparse
7 | import gc
8 |
9 | import numpy as np
10 |
11 |
12 | def str2bool(v):
13 | """
14 |     Helper to parse boolean arguments.
15 | Extracted from: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
16 | Author: @Maxim
17 | """
18 |
19 | if isinstance(v, bool):
20 | return v
21 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
22 | return True
23 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
24 | return False
25 | else:
26 | raise argparse.ArgumentTypeError('Boolean value expected.')
27 |
28 |
29 | def gen_tf_gene_table(genes, tf_list, dTD):
30 | """
31 |
32 | Adapted from:
33 | Author: Jun Ding
34 | Project: SCDIFF2
35 | Ref: Ding, J., Aronow, B. J., Kaminski, N., Kitzmiller, J., Whitsett, J. A., & Bar-Joseph, Z.
36 | (2018). Reconstructing differentiation networks and their regulation from time series
37 | single-cell expression data. Genome research, 28(3), 383-395.
38 |
39 | """
40 | gene_names = [g.upper() for g in genes]
41 | TF_names = [g.upper() for g in tf_list]
42 | tf_gene_table = dict.fromkeys(tf_list)
43 |
44 | for i, tf in enumerate(tf_list):
45 | tf_gene_table[tf] = np.zeros(len(gene_names))
46 | _genes = dTD[tf]
47 |
48 | _existed_targets = list(set(_genes).intersection(gene_names))
49 | _idx_targets = map(lambda x: gene_names.index(x), _existed_targets)
50 |
51 | for _g in _idx_targets:
52 | tf_gene_table[tf][_g] = 1
53 |
54 | del gene_names
55 | del TF_names
56 | del _genes
57 | del _existed_targets
58 | del _idx_targets
59 |
60 | gc.collect()
61 |
62 | return tf_gene_table
63 |
64 |
65 | def getGeneSetMatrix(_name, genes_upper, gene_sets_path):
66 | """
67 |
68 | Adapted from:
69 | Author: Jun Ding
70 | Project: SCDIFF2
71 | Ref: Ding, J., Aronow, B. J., Kaminski, N., Kitzmiller, J., Whitsett, J. A., & Bar-Joseph, Z.
72 | (2018). Reconstructing differentiation networks and their regulation from time series
73 | single-cell expression data. Genome research, 28(3), 383-395.
74 |
75 | """
76 | if _name[-3:] == 'gmt':
77 | print(f"GMT file {_name} loading ... ")
78 | filename = _name
79 | filepath = os.path.join(gene_sets_path, f"{filename}")
80 |
81 | with open(filepath) as genesets:
82 | pathway2gene = {line.strip().split("\t")[0]: line.strip().split("\t")[2:]
83 | for line in genesets.readlines()}
84 |
85 |         print(f"Number of gene sets in {_name}: {len(pathway2gene)}")
86 |
87 |         gs = []
88 |         for k, v in pathway2gene.items():
89 |             gs += v
90 |
91 |         print(f"Number of genes in {_name} covered by the input: {len(set(gs).intersection(genes_upper))}")
92 |
93 | pathway_list = pathway2gene.keys()
94 | pathway_gene_table = gen_tf_gene_table(genes_upper, pathway_list, pathway2gene)
95 | gene_set_matrix = np.array(list(pathway_gene_table.values()))
96 | keys = pathway_gene_table.keys()
97 |
98 | del pathway2gene
99 | del gs
100 | del pathway_list
101 | del pathway_gene_table
102 |
103 | gc.collect()
104 |
105 |
106 | elif _name == 'TF-DNA':
107 |
108 | # get TF-DNA dictionary
109 | # TF->DNA
110 | def getdTD(tfDNA):
111 | dTD = {}
112 | with open(tfDNA, 'r') as f:
113 | tfRows = f.readlines()
114 | tfRows = [item.strip().split() for item in tfRows]
115 | for row in tfRows:
116 | itf = row[0].upper()
117 | itarget = row[1].upper()
118 | if itf not in dTD:
119 | dTD[itf] = [itarget]
120 | else:
121 | dTD[itf].append(itarget)
122 |
123 | del tfRows
124 | del itf
125 | del itarget
126 | gc.collect()
127 |
128 | return dTD
129 |
130 | from collections import defaultdict
131 |
132 | def getdDT(dTD):
133 | gene_tf_dict = defaultdict(lambda: [])
134 | for key, val in dTD.items():
135 | for v in val:
136 | gene_tf_dict[v.upper()] += [key.upper()]
137 |
138 | return gene_tf_dict
139 |
140 | tfDNA_file = os.path.join(gene_sets_path, f"Mouse_TF_targets.txt")
141 | dTD = getdTD(tfDNA_file)
142 | dDT = getdDT(dTD)
143 |
144 | tf_list = list(sorted(dTD.keys()))
145 | tf_list.remove('TF')
146 |
147 | tf_gene_table = gen_tf_gene_table(genes_upper, tf_list, dTD)
148 | gene_set_matrix = np.array(list(tf_gene_table.values()))
149 | keys = tf_gene_table.keys()
150 |
151 | del dTD
152 | del dDT
153 | del tf_list
154 | del tf_gene_table
155 |
156 | gc.collect()
157 |
158 |     else:
159 |         raise ValueError(f"Unsupported gene set name: {_name}")  # `keys` would otherwise be undefined
160 |
161 | return gene_set_matrix, keys
162 |
163 |
164 |
165 |
--------------------------------------------------------------------------------
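A small end-to-end check of getGeneSetMatrix with a toy two-line GMT file (the file name, path, and genes are invented); each GMT row is tab-separated as set name, description, then member genes, matching the parsing above.

```python
import os
import numpy as np
from unifan.utils import getGeneSetMatrix

genes_upper = ["GENE1", "GENE2", "GENE3"]

# write a toy GMT file: <set name>\t<description>\t<gene>...
os.makedirs("/tmp/gene_sets", exist_ok=True)
with open("/tmp/gene_sets/toy.gmt", "w") as f:
    f.write("SET_A\tna\tGENE1\tGENE3\n")
    f.write("SET_B\tna\tGENE2\n")

matrix, keys = getGeneSetMatrix("toy.gmt", genes_upper, "/tmp/gene_sets")
print(list(keys))  # ['SET_A', 'SET_B']
print(matrix)      # [[1. 0. 1.], [0. 1. 0.]] membership, gene sets x genes
```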