├── LICENSE
├── README.md
├── gene_sets
│   ├── Mouse_TF_targets.txt
│   ├── c2.cp.v7.4.symbols.gmt
│   ├── c2.cp.wikipathways.v7.4.symbols.gmt
│   ├── c5.go.bp.v7.4.symbols.gmt
│   └── c8.all.v7.4.symbols.gmt
├── setup.py
├── tutorails
│   ├── Data_preprocessing.ipynb
│   ├── HuBMAP_datasets_ID.xlsx
│   ├── UNIFAN_cluster_annotations.ipynb
│   ├── UNIFAN_example.ipynb
│   └── getExample.py
├── unifan-main.png
├── unifan-pretrain.png
└── unifan
    ├── __init__.py
    ├── annocluster.py
    ├── autoencoder.py
    ├── classifier.py
    ├── datasets.py
    ├── main.py
    ├── networks.py
    ├── trainer.py
    └── utils.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Dora Li
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | UNIFAN (**Un**supervised S**i**ngle-cell **F**unctional **An**notation) simultaneously clusters and annotates cells with known biological processes (including pathways). For each single cell, UNIFAN first infers gene set activity scores (denoted by r in the figure below) associated with this cell using the input gene sets. 
3 | 
4 | ![flowchart](./unifan-pretrain.png)
5 | 
6 | Next, UNIFAN clusters cells by using the learned gene set activity scores (r) together with a reduced-dimension representation of the expression of genes in the cell. The gene set activity scores are used by an “annotator” to guide the clustering such that cells sharing similar biological processes are more likely to be grouped together. This design allows the method to focus on the key processes when clustering cells and so can overcome issues related to noise and dropout, while simultaneously selecting marker gene sets which can be used to annotate clusters. To allow the selection of marker genes for each cluster, we also add a subset of the most variable genes, selected using Seurat v3 (Stuart et al., [2019](https://doi.org/10.1016/j.cell.2019.05.031)), as features for the annotator. 
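For intuition, these two stages map onto the package's `autoencoder` module (which pretrains the gene set activity scores r) and its `AnnoCluster` module (which clusters with the annotator's guidance), both under `unifan/`. Below is a minimal, illustrative sketch only: the batch size, gene/gene-set counts and the random membership table are made-up placeholders (in practice the table is parsed from the files in `gene_sets/`, its expected orientation is defined by `Set2Gene` in `unifan/networks.py`, and training is orchestrated by `unifan/trainer.py` and `main.py` rather than called directly like this):

```python
import torch

from unifan.autoencoder import autoencoder
from unifan.annocluster import AnnoCluster

# Toy batch: 128 cells x 10,000 genes; hypothetical binary membership table of
# 335 gene sets x 10,000 genes (a stand-in for the parsed .gmt / TF-DNA files).
x = torch.randn(128, 10000)
gene_set_table = (torch.rand(335, 10000) < 0.01).float()

# Stage 1: learn gene set activity scores r with a non-negative encoder that
# reconstructs expression through the gene set table.
score_model = autoencoder(input_dim=10000, z_dim=335, gene_set_dim=335,
                          reconstruction_network="non-negative",
                          decoding_network="geneSet", gene_set_table=gene_set_table)
x_reconstructed, r = score_model(x)  # r: per-cell gene set activity scores

# Stage 2: cluster cells in a low-dimensional space z; the annotator trained on
# r enters the clustering loss through the weight tau.
cluster_model = AnnoCluster(input_dim=10000, z_dim=32, gene_set_dim=335,
                            tau=10.0, n_clusters=16)
x_e, x_q, z_e, z_q, k, z_dist, dist_prob = cluster_model(x)  # k: cluster assignments
```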
7 | 
8 | ![flowchart](./unifan-main.png)
9 | 
10 | ## Table of Contents
11 | - [Get started](#Get-started)
12 | - [Command-line tools](#Command-line)
13 | - [Tutorials](#Tutorials)
14 | - [Updates log](#Updates-log)
15 | - [Learn-more](#Learn-more)
16 | - [Credits](#Credits)
17 | 
18 | # Get-started
19 | ## Prerequisites 
20 | * Python >= 3.6
21 | * Python side-packages:
22 | -- pytorch >= 1.9.0
23 | -- numpy >= 1.19.2
24 | -- pandas >= 1.1.5
25 | -- scanpy >= 1.7.2
26 | -- leidenalg >= 0.8.4
27 | -- tqdm >= 4.61.1
28 | -- scikit-learn >= 0.24.2
29 | -- umap-learn >= 0.5.1
30 | -- matplotlib >= 3.3.4 
31 | -- seaborn >= 0.11.0
32 | 
33 | ## Installation
34 | 
35 | ### Install within a virtual environment 
36 | 
37 | It is recommended to use a virtual environment / package manager such as [Anaconda](https://www.anaconda.com/). After successfully installing Anaconda/Miniconda, create an environment as follows:
38 | 
39 | ```shell
40 | conda create -n myenv python=3.6
41 | ```
42 | 
43 | You can then install and run the package in the virtual environment. Activate the virtual environment by:
44 | 
45 | ```shell
46 | conda activate myenv
47 | ```
48 | 
49 | Make sure you have **pip** installed in your environment. You may check by:
50 | 
51 | ```shell
52 | conda list
53 | ```
54 | 
55 | If it is not installed, then:
56 | 
57 | ```shell
58 | conda install pip
59 | ```
60 | ### Install PyTorch
61 | 
62 | UNIFAN is built on PyTorch and supports both CPU and GPU. Make sure you have PyTorch (>= 1.9.0) installed in your virtual environment. If not, please visit [PyTorch](https://pytorch.org/) and install the appropriate version.
63 | 
64 | ### Install UNIFAN
65 | 
66 | Install by: 
67 | 
68 | ```shell
69 | pip install git+https://github.com/doraadong/UNIFAN.git
70 | ```
71 | 
72 | If you want to upgrade UNIFAN to the latest version, first uninstall it by:
73 | 
74 | ```shell
75 | pip uninstall unifan
76 | ```
77 | 
78 | And then run the pip install command again.
79 | 
80 | # Command-line 
81 | 
82 | You may import UNIFAN as a package and use it in your code (see [Tutorials](#Tutorials) for details), or you may train models using the following command-line tool. 
83 | 
84 | ## Run UNIFAN
85 | 
86 | Run UNIFAN by (the arguments shown are examples):
87 | 
88 | ```shell
89 | main.py -i ../example/input/Limb_Muscle.h5ad -o ../example/output -p tabula_muris -t Limb_Muscle -l cell_ontology_class -e ../gene_sets/ 
90 | ```
91 | The usage of this command is listed as follows.
Note only the first 5 inputs are required: 92 | 93 | ```shell 94 | usage: main.py [-h] -i INPUT -o OUTPUT -p PROJECT -t TISSUE [-e GENESETSPATH] 95 | [-l LABEL] [-v VARIABLE] [-r PRIOR] 96 | [-f {gene_sets,gene,gene_gene_sets}] [-a ALPHA] [-b BETA] 97 | [-g GAMMA] [-u TAU] [-d DIM] [-s BATCH] [-na NANNO] 98 | [-ns NSCORE] [-nu NAUTO] [-nc NCLUSTER] [-nze NZENCO] 99 | [-nzd NZDECO] [-dze DIMZENCO] [-dzd DIMZDECO] [-nre NRENCO] 100 | [-dre DIMRENCO] [-drd DIMRDECO] 101 | [-n {sigmoid,non-negative,gaussian}] [-m SEED] [-c CUDA] 102 | [-w NWORKERS] 103 | 104 | optional arguments: 105 | -h, --help show this help message and exit 106 | -i INPUT, --input INPUT 107 | string, path to the input expression data, default 108 | '../input/data.h5ad' 109 | -o OUTPUT, --output OUTPUT 110 | string, path to the output folder, default 111 | '../output/' 112 | -p PROJECT, --project PROJECT 113 | string, identifier for the project, e.g., tabula_muris 114 | -t TISSUE, --tissue TISSUE 115 | string, tissue where the input data is sampled from 116 | -e GENESETSPATH, --geneSetsPath GENESETSPATH 117 | string, path to the folder where gene sets can be 118 | found, default='../gene_sets/' 119 | -l LABEL, --label LABEL 120 | string, optional, the column / field name of the 121 | ground truth label, if available; used for evaluation 122 | only; default None 123 | -v VARIABLE, --variable VARIABLE 124 | string, optional, the column / field name of the 125 | highly variable genes; default 'highly_variable' 126 | -r PRIOR, --prior PRIOR 127 | string, optional, gene set file names used to learn 128 | the gene set activity scores, use '+' to separate 129 | multiple gene set names, default 130 | c5.go.bp.v7.4.symbols.gmt+c2.cp.v7.4.symbols.gmt+TF- 131 | DNA 132 | -f {gene_sets,gene,gene_gene_sets}, --features {gene_sets,gene,gene_gene_sets} 133 | string, optional, features used for the annotator, any 134 | of 'gene_sets', 'gene' or 'gene_gene_sets', default 135 | 'gene_gene_sets' 136 | -a ALPHA, --alpha ALPHA 137 | float, optional, hyperparameter for the L1 term in the 138 | set cover loss, default 1e-2 139 | -b BETA, --beta BETA float, optional, hyperparameter for the set cover term 140 | in the set cover loss, default 1e-5 141 | -g GAMMA, --gamma GAMMA 142 | float, optional, hyperparameter for the exclusive L1 143 | term, default 1e-3 144 | -u TAU, --tau TAU float, optional, hyperparameter for the annotator 145 | loss, default 10 146 | -d DIM, --dim DIM integer, optional, dimension for the low-dimensional 147 | representation, default 32 148 | -s BATCH, --batch BATCH 149 | integer, optional, batch size for training except for 150 | pretraining annotator (fixed at 32), default 128 151 | -na NANNO, --nanno NANNO 152 | integer, optional, number of epochs to pretrain the 153 | annotator, default 50 154 | -ns NSCORE, --nscore NSCORE 155 | integer, optional, number of epochs to train the gene 156 | set activity model, default 70 157 | -nu NAUTO, --nauto NAUTO 158 | integer, optional, number of epochs to pretrain the 159 | annocluster model, default 50 160 | -nc NCLUSTER, --ncluster NCLUSTER 161 | integer, optional, number of epochs to train the 162 | annocluster model, default 25 163 | -nze NZENCO, --nzenco NZENCO 164 | float, optional, number of hidden layers for encoder 165 | of annocluster, default 3 166 | -nzd NZDECO, --nzdeco NZDECO 167 | float, optional, number of hidden layers for decoder 168 | of annocluster, default 2 169 | -dze DIMZENCO, --dimzenco DIMZENCO 170 | integer, optional, number of nodes for hidden layers 171 | 
                        for encoder of annocluster, default 128
172 |   -dzd DIMZDECO, --dimzdeco DIMZDECO
173 |                         integer, optional, number of nodes for hidden layers
174 |                         for decoder of annocluster, default 128
175 |   -nre NRENCO, --nrenco NRENCO
176 |                         integer, optional, number of hidden layers for the
177 |                         encoder of gene set activity scores model, default 5
178 |   -dre DIMRENCO, --dimrenco DIMRENCO
179 |                         integer, optional, number of nodes for hidden layers
180 |                         for encoder of gene set activity scores model, default
181 |                         128
182 |   -drd DIMRDECO, --dimrdeco DIMRDECO
183 |                         integer, optional, number of nodes for hidden layers
184 |                         for decoder of gene set activity scores model, default
185 |                         128
186 |   -n {sigmoid,non-negative,gaussian}, --network {sigmoid,non-negative,gaussian}
187 |                         string, optional, the encoder for the gene set
188 |                         activity model, any of 'sigmoid', 'non-negative' or
189 |                         'gaussian', default 'non-negative'
190 |   -m SEED, --seed SEED  integer, optional, random seed for the initialization,
191 |                         default 0
192 |   -c CUDA, --cuda CUDA  boolean, optional, if use GPU for neural network
193 |                         training, default False
194 |   -w NWORKERS, --nworkers NWORKERS
195 |                         integer, optional, number of workers for dataloader,
196 |                         default 8
197 | 
198 | ```
199 | 
200 | 
201 | # Tutorials
202 | 
203 | GitHub rendering disables some functionality of Jupyter notebooks. We recommend using [nbviewer](https://nbviewer.jupyter.org/) to view the following tutorials.
204 | 
205 | ## Run UNIFAN on example data
206 | In the [UNIFAN training tutorial](tutorails/UNIFAN_example.ipynb), we illustrate how to run UNIFAN step-by-step on the example data: Limb_Muscle from Tabula Muris.
207 | 
208 | ### Download and Preprocess the Input Data
209 | You may download the gene sets from the [gene_sets](gene_sets) folder. By default, we use the GO terms for biological processes (c5.go.bp.v7.4.symbols.gmt), canonical pathways (c2.cp.v7.4.symbols.gmt) and the TF-DNA interactions data (Mouse_TF_targets.txt).
210 | 
211 | UNIFAN takes AnnData files as input. See [AnnData](https://anndata.readthedocs.io/en/latest/) for details. To prepare the example data (Limb_Muscle in Tabula Muris), first download the [Tabula Muris senis data](https://figshare.com/ndownloader/files/24351086). Then run the Python script [getExample.py](tutorails/getExample.py) to preprocess the count data using the following command:
212 | 
213 | ```shell
214 | python getExample.py -p ./facs.h5ad -i ../example/input -t Limb_Muscle
215 | 
216 | ```
217 | The usage of this command is listed as follows:
218 | 
219 | ```shell
220 | usage: getExample.py [-h] -p PATH -i FOLDER -t TISSUE [-k TOPK]
221 | 
222 | optional arguments:
223 |   -h, --help            show this help message and exit
224 |   -p PATH, --path PATH  string, path to the downloaded data, default
225 |                         './facs.h5ad'
226 |   -i FOLDER, --folder FOLDER
227 |                         string, path to the folder to save the data, default
228 |                         '../example/input'
229 |   -t TISSUE, --tissue TISSUE
230 |                         string, specify the output tissue; if using the
231 |                         default None, then all tissues will be outputted and
232 |                         saved separately in the folder; default None
233 |   -k TOPK, --topk TOPK  integer, optional, number of most variable genes,
234 |                         default 2000
235 | 
236 | ```
237 | 
238 | We also provide a [Data preprocessing](tutorails/Data_preprocessing.ipynb) tutorial showing how we preprocessed the other datasets used in the manuscript. 
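Before training, it can help to sanity-check the preprocessed AnnData. Here is a minimal sketch; the file name below follows getExample.py's output pattern and the label column matches the `-l` argument in the command-line example above, so adjust both to your setup:

```python
import scanpy as sc

# File name assumed from getExample.py's output pattern; adjust to your paths.
adata = sc.read("../example/input/Limb_Muscle_facts_processed_3m.h5ad")

print(adata.shape)                                      # cells x genes
print(adata.var["highly_variable"].sum())               # genes selected by Seurat v3
print(adata.obs["cell_ontology_class"].value_counts())  # labels used with -l (optional)
```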
239 | 
240 | ## Analyze results and annotate clusters
241 | In the [cluster annotation tutorial](tutorails/UNIFAN_cluster_annotations.ipynb), we illustrate how to use the coefficients learned by UNIFAN to annotate clusters. In particular, we show how to select the gene sets / genes that best represent each cluster, evaluate whether the selected genes are likely marker genes, and visualize the annotations.
242 | 
243 | # Updates-log
244 | * 10-11-2022: 
245 | -- Add tutorial on preprocessing the datasets used in the manuscript
246 | 
247 | # Learn-more
248 | Check our paper at [Genome Research](https://genome.cshlp.org/content/early/2022/06/28/gr.276609.122.long). Link to [preprint](https://www.biorxiv.org/content/10.1101/2021.11.20.469410v2). 
249 | 
250 | # Credits
251 | The software is an implementation of the method UNIFAN, jointly developed by [Dongshunyi "Dora" Li](https://github.com/doraadong) and Ziv Bar-Joseph from the [Systems Biology Group @ Carnegie Mellon University](http://sb.cs.cmu.edu/) and [Jun Ding](https://github.com/phoenixding) from McGill University.
252 | 
253 | # Contacts
254 | * dongshul at andrew.cmu.edu 
255 | 
256 | # License 
257 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
258 | 
259 | 
260 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | # partly borrowed from https://github.com/navdeep-G/setup.py/blob/master/setup.py
5 | import io
6 | import os
7 | from setuptools import setup, find_packages
8 | import pathlib
9 | 
10 | # Package meta-data.
11 | NAME = 'unifan'
12 | DESCRIPTION = 'Unsupervised cell functional annotation'
13 | URL = "https://github.com/doraadong/UNIFAN"
14 | EMAIL = 'dongshul@andrew.cmu.edu'
15 | AUTHOR = 'Dora Li'
16 | REQUIRES_PYTHON = '>=3.6'
17 | VERSION = '1.0.0'
18 | 
19 | # What packages are required for this module to be executed?
20 | REQUIRED = ["torch", "numpy>=1.19.2", "pandas>=1.1.5", "scanpy>=1.7.2", "leidenalg>=0.8.4", "tqdm>=4.61.1",
21 |             "scikit-learn>=0.24.2", "umap-learn>=0.5.1", "matplotlib>=3.3.4", "seaborn>=0.11.0"]
22 | 
23 | # The rest you shouldn't have to touch too much :)
24 | # ------------------------------------------------
25 | # Except, perhaps the License and Trove Classifiers!
26 | # If you do change the License, remember to change the Trove Classifier for that!
27 | 
28 | here = os.path.abspath(os.path.dirname(__file__))
29 | 
30 | # Import the README and use it as the long-description.
31 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file!
32 | try:
33 |     with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
34 |         long_description = '\n' + f.read()
35 | except FileNotFoundError:
36 |     long_description = DESCRIPTION
37 | 
38 | # Load the package's __version__.py module as a dictionary. 
39 | about = {}
40 | if not VERSION:
41 |     project_slug = NAME.lower().replace("-", "_").replace(" ", "_")
42 |     with open(os.path.join(here, project_slug, '__version__.py')) as f:
43 |         exec(f.read(), about)
44 | else:
45 |     about['__version__'] = VERSION
46 | 
47 | setup(
48 |     name=NAME,
49 |     version=about['__version__'],
50 |     description=DESCRIPTION,
51 |     long_description=long_description,
52 |     long_description_content_type='text/markdown',
53 |     author=AUTHOR,
54 |     author_email=EMAIL,
55 |     python_requires=REQUIRES_PYTHON,
56 |     url=URL,
57 |     packages=['unifan'],
58 |     scripts=['unifan/main.py'],
59 |     install_requires=REQUIRED,
60 |     include_package_data=True,
61 |     license='MIT',
62 |     classifiers=[
63 |         'License :: OSI Approved :: MIT License',
64 |         'Programming Language :: Python :: 3.6',
65 |     ],
66 | )
67 | 
--------------------------------------------------------------------------------
/tutorails/Data_preprocessing.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os, sys\n",
10 | "import argparse\n",
11 | "import time\n",
12 | "from os.path import exists\n",
13 | "import collections\n",
14 | "from typing import Iterable\n",
15 | "import pickle \n",
16 | "from collections import Counter\n",
17 | "\n",
18 | "import scanpy as sc\n",
19 | "import pandas as pd\n",
20 | "import numpy as np\n",
21 | "from matplotlib import pyplot as plt"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "# Preprocess input data \n",
29 | "\n",
30 | "Here we show how we preprocessed the datasets we used in our manuscript, except for the Tabula Muris datasets which we already covered in our examples (see [getExample.py](https://github.com/doraadong/UNIFAN/blob/main/tutorails/getExample.py) for details). Other than the packages imported above, we also use the [mygene](https://pypi.org/project/mygene/) package to convert ENSEMBL IDs to gene symbols for the HuBMAP datasets. \n",
31 | "\n",
32 | "**Table of Contents**\n",
33 | "\n",
34 | "1. [Preprocess pbmc28k](#1)\n",
35 | "2. [Preprocess pbmc68k](#2)\n",
36 | "3. [Preprocess Atlas lung](#3)\n",
37 | "4. [Preprocess HuBMAP datasets](#4)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 3,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "cell_cutoffs = {'lung': 500, 'lymph_node':200, 'spleen':200, 'thymus':200, 'pbmc68k':200, 'pbmc28k':200}\n",
47 | "gene_cutoffs = {'lung': 5, 'lymph_node':3, 'spleen':3, 'thymus':3, 'pbmc68k':3, 'pbmc28k':3}"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 4,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "save_raw = False\n",
57 | "save_processed = True\n",
58 | "\n",
59 | "topk = 2000\n",
60 | "select_genes = False\n",
61 | "\n",
62 | "if select_genes: \n",
63 | "    condition = f\"_top{topk}\"\n",
64 | "else:\n",
65 | "    condition = f\"\""
66 | ]
67 | },
68 | {
69 | "cell_type": "markdown",
70 | "metadata": {},
71 | "source": [
72 | "## 1. 
Preprocess pbmc28k\n", 73 | "" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 6, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "tissue = \"pbmc28k\"\n", 83 | "data_name = \"pbmc\"\n", 84 | "\n", 85 | "parent_folder = f\"../input/{data_name}\"\n", 86 | "labels = pd.read_csv(f\"{parent_folder}/{tissue}/barcodes_to_cell_types.tsv\", sep=\"\\t\")\n", 87 | "\n", 88 | "barcodes = [i.split('_')[0] for i in labels['barcode']]\n", 89 | "lanes = [i.split('_')[1] for i in labels['barcode']]\n", 90 | "labels['code'] = barcodes\n", 91 | "labels['lane'] = lanes" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 7, 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "name": "stdout", 101 | "output_type": "stream", 102 | "text": [ 103 | "AnnData object with n_obs × n_vars = 4309 × 32738\n", 104 | " var: 'gene_ids'\n", 105 | "AnnData object with n_obs × n_vars = 3844 × 32738\n", 106 | " var: 'gene_ids'\n", 107 | "AnnData object with n_obs × n_vars = 4998 × 32738\n", 108 | " var: 'gene_ids'\n", 109 | "AnnData object with n_obs × n_vars = 3392 × 32738\n", 110 | " var: 'gene_ids'\n", 111 | "AnnData object with n_obs × n_vars = 2868 × 32738\n", 112 | " var: 'gene_ids'\n", 113 | "AnnData object with n_obs × n_vars = 3685 × 32738\n", 114 | " var: 'gene_ids'\n", 115 | "AnnData object with n_obs × n_vars = 3198 × 32738\n", 116 | " var: 'gene_ids'\n", 117 | "AnnData object with n_obs × n_vars = 2561 × 32738\n", 118 | " var: 'gene_ids'\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "full_data = None\n", 124 | "pre_genes = None \n", 125 | "\n", 126 | "for i in range(1, 9):\n", 127 | " cur_lane = f\"lane_{i}\"\n", 128 | " \n", 129 | " parent_folder = f\"../input/{data_name}\"\n", 130 | " adata = sc.read_10x_mtx(f\"{parent_folder}/{tissue}/{cur_lane}/\", var_names='gene_symbols', cache=True) \n", 131 | " adata.var_names_make_unique() \n", 132 | " \n", 133 | " print(adata)\n", 134 | " \n", 135 | " # get observation idx\n", 136 | " _codes = [i.split('-')[0] for i in adata.obs.index.values]\n", 137 | " # all barcodes having labels are in the expression file \n", 138 | " assert len(set(labels[labels['lane'] == f\"lane{i}\"].code.values) - set(_codes)) == 0\n", 139 | "\n", 140 | " _others = [i.split('-')[1] for i in adata.obs.index.values]\n", 141 | " assert (np.array(_others) == '1').all()\n", 142 | " \n", 143 | " # get labels \n", 144 | " cur_labels = labels[labels['lane'] == f\"lane{i}\"].copy().reset_index()\n", 145 | " cur_labels = cur_labels[['code', 'cell_type']]\n", 146 | "\n", 147 | " df_codes = pd.DataFrame(_codes)\n", 148 | " df_codes.columns = ['code']\n", 149 | " df_codes = df_codes.merge(cur_labels, left_on = 'code', right_on = 'code', how = 'left')\n", 150 | "\n", 151 | " # same order as in adata\n", 152 | " assert (df_codes['code'].values == np.array(_codes)).all()\n", 153 | " \n", 154 | " adata.obs['label'] = df_codes['cell_type'].values\n", 155 | " \n", 156 | " if full_data is not None:\n", 157 | " full_data = full_data.concatenate(adata)\n", 158 | " else:\n", 159 | " full_data = adata\n", 160 | " \n", 161 | " # check if gene id same for all data\n", 162 | " cur_genes = adata.var['gene_ids'].values\n", 163 | " if pre_genes is not None:\n", 164 | " assert (pre_genes == cur_genes).all()\n", 165 | " pre_genes = cur_genes" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 8, 171 | "metadata": { 172 | "scrolled": false 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "# keep only those with labels\n", 177 | 
"full_data = full_data[~full_data.obs.label.isna()]" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 9, 183 | "metadata": { 184 | "scrolled": true 185 | }, 186 | "outputs": [ 187 | { 188 | "data": { 189 | "text/plain": [ 190 | "View of AnnData object with n_obs × n_vars = 25185 × 32738\n", 191 | " obs: 'label', 'batch'\n", 192 | " var: 'gene_ids'" 193 | ] 194 | }, 195 | "execution_count": 9, 196 | "metadata": {}, 197 | "output_type": "execute_result" 198 | } 199 | ], 200 | "source": [ 201 | "full_data" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 10, 207 | "metadata": {}, 208 | "outputs": [ 209 | { 210 | "name": "stderr", 211 | "output_type": "stream", 212 | "text": [ 213 | "Trying to set attribute `.obs` of view, copying.\n", 214 | "... storing 'label' as categorical\n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "# filtering \n", 220 | "sc.pp.filter_cells(full_data, min_genes=cell_cutoffs[tissue])\n", 221 | "sc.pp.filter_genes(full_data, min_cells=gene_cutoffs[tissue])\n", 222 | "\n", 223 | "if save_raw:\n", 224 | " filename = f\"{tissue}_raw{condition}.h5ad\"\n", 225 | " full_data.write_h5ad(os.path.join(parent_folder, filename))\n", 226 | "\n", 227 | "if save_processed:\n", 228 | " # selecting genes using Seurat v3 method \n", 229 | " sc.pp.highly_variable_genes(full_data, n_top_genes=topk, flavor='seurat_v3')\n", 230 | " if select_genes:\n", 231 | " full_data = full_data[:, full_data.var.highly_variable]\n", 232 | "\n", 233 | " # normalize \n", 234 | " sc.pp.normalize_total(full_data, target_sum=1e4)\n", 235 | " sc.pp.log1p(full_data)\n", 236 | "\n", 237 | " # adata.raw = adata\n", 238 | " sc.pp.scale(full_data, max_value=10)\n", 239 | "\n", 240 | " # save \n", 241 | " try:\n", 242 | " full_data.__dict__['_raw'].__dict__['_var'] = full_data.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})\n", 243 | " except AttributeError:\n", 244 | " pass \n", 245 | "\n", 246 | " filename = f\"{tissue}_processed{condition}.h5ad\"\n", 247 | " full_data.write_h5ad(os.path.join(parent_folder, filename))" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "## 2. Preprocessing pbmc68k\n", 255 | "" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 6, 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "name": "stderr", 265 | "output_type": "stream", 266 | "text": [ 267 | "... 
storing 'label' as categorical\n"
268 | ]
269 | }
270 | ],
271 | "source": [
272 | "tissue = \"pbmc68k\"\n",
273 | "data_name = \"pbmc\"\n",
274 | "\n",
275 | "parent_folder = f\"../input/{data_name}\"\n",
276 | "full_data = sc.read_10x_mtx(f\"{parent_folder}/{tissue}/filtered_matrices_mex/hg19/\", var_names='gene_symbols', cache=True) \n",
277 | "full_data.var_names_make_unique() \n",
278 | "\n",
279 | "# get labels \n",
280 | "labels = pd.read_csv(f\"{parent_folder}/{tissue}/zheng17-cell-labels.txt\", sep = \"\\t\")\n",
281 | "\n",
282 | "# check if barcodes are the same\n",
283 | "barcodes = pd.read_csv(\"../input/pbmc/pbmc68k/filtered_matrices_mex/hg19/barcodes.tsv\", header = None)\n",
284 | "assert (labels['barcode'].values == barcodes[0].values).all()\n",
285 | "\n",
286 | "full_data.obs['label'] = labels['bulk_labels'].values\n",
287 | "\n",
288 | "# filtering \n",
289 | "sc.pp.filter_cells(full_data, min_genes=cell_cutoffs[tissue]) \n",
290 | "sc.pp.filter_genes(full_data, min_cells=gene_cutoffs[tissue])\n",
291 | "\n",
292 | "if save_raw:\n",
293 | "    filename = f\"{tissue}_raw{condition}.h5ad\"\n",
294 | "    full_data.write_h5ad(os.path.join(parent_folder, filename))\n",
295 | "    \n",
296 | "if save_processed:\n",
297 | "    # selecting genes using Seurat v3 method \n",
298 | "    sc.pp.highly_variable_genes(full_data, n_top_genes=topk, flavor='seurat_v3')\n",
299 | "    if select_genes:\n",
300 | "        full_data = full_data[:, full_data.var.highly_variable]\n",
301 | "\n",
302 | "    # normalize \n",
303 | "    sc.pp.normalize_total(full_data, target_sum=1e4)\n",
304 | "    sc.pp.log1p(full_data)\n",
305 | "\n",
306 | "    # adata.raw = adata\n",
307 | "    sc.pp.scale(full_data, max_value=10)\n",
308 | "\n",
309 | "    # save \n",
310 | "    try:\n",
311 | "        full_data.__dict__['_raw'].__dict__['_var'] = full_data.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})\n",
312 | "    except AttributeError:\n",
313 | "        pass \n",
314 | "\n",
315 | "    filename = f\"{tissue}_processed{condition}.h5ad\"\n",
316 | "    full_data.write_h5ad(os.path.join(parent_folder, filename))"
317 | ]
318 | },
319 | {
320 | "cell_type": "markdown",
321 | "metadata": {},
322 | "source": [
323 | "## 3. Preprocess Atlas lung\n",
324 | "\n",
325 | "\n",
326 | "Before running the following code, make sure you have done the following: \n",
327 | "\n",
328 | "1. Download the following two files from [GSE136831](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE136831)\n",
329 | "    * GSE136831_AllCells.Samples.CellType.MetadataTable.txt\n",
330 | "    * GSE136831_RawCounts_Sparse.mtx \n",
331 | "2. Take only samples where \"Disease_Identity\" == \"Control\"\n",
332 | "3. Make an AnnData object (i.e., lung.h5ad) using the raw counts as the expression matrix and the metadata as the observation matrix. 
Make sure to include the following two columns in the observation matrix:\n",
333 | "    * CellType_Category - coarse-grained label \n",
334 | "    * Manuscript_Identity - fine-grained label "
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": null,
340 | "metadata": {},
341 | "outputs": [],
342 | "source": [
343 | "tissue = \"lung\"\n",
344 | "\n",
345 | "filename = f\"{tissue}.h5ad\"\n",
346 | "parent_folder = \"../input/ipf\"\n",
347 | "\n",
348 | "full_data = sc.read(os.path.join(parent_folder, filename), dtype='float64')\n",
349 | "\n",
350 | "# filtering \n",
351 | "sc.pp.filter_cells(full_data, min_genes=cell_cutoffs[tissue]) \n",
352 | "sc.pp.filter_genes(full_data, min_cells=gene_cutoffs[tissue])\n",
353 | "\n",
354 | "if save_raw:\n",
355 | "    # filter using genes in the preprocessed data\n",
356 | "    filename = f\"{tissue}_processed{condition}.h5ad\"\n",
357 | "    processed_data = sc.read(os.path.join(parent_folder, filename), dtype='float64', backed=\"r\")\n",
358 | "\n",
359 | "    print(full_data.var.index.is_unique)\n",
360 | "    cur_genes = list(full_data.var.index.values)\n",
361 | "\n",
362 | "    idx_genes = [cur_genes.index(i) for i in processed_data.var.index.values]\n",
363 | "    full_data = full_data[:, idx_genes]\n",
364 | "\n",
365 | "    assert (full_data.var.index.values == processed_data.var.index.values).all()\n",
366 | "    assert (full_data.obs.index.values == processed_data.obs.index.values).all()\n",
367 | "\n",
368 | "    try:\n",
369 | "        full_data.__dict__['_raw'].__dict__['_var'] = full_data.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})\n",
370 | "    except AttributeError:\n",
371 | "        pass \n",
372 | "\n",
373 | "\n",
374 | "    filename = f\"{tissue}_raw{condition}.h5ad\"\n",
375 | "    full_data.write_h5ad(os.path.join(parent_folder, filename))\n",
376 | "\n",
377 | "if save_processed:\n",
378 | "\n",
379 | "    # normalize \n",
380 | "    sc.pp.normalize_total(full_data, target_sum=1e4)\n",
381 | "    sc.pp.log1p(full_data)\n",
382 | "\n",
383 | "    # given the large number of genes for this data, filter to keep only highly dispersed genes \n",
384 | "    sc.pp.highly_variable_genes(full_data, min_mean=0, max_mean = 1000, min_disp=0.01)\n",
385 | "    full_data = full_data[:, full_data.var.highly_variable]\n",
386 | "\n",
387 | "    # adata.raw = adata\n",
388 | "    sc.pp.scale(full_data, max_value=10)\n",
389 | "\n",
390 | "    # save \n",
391 | "    try:\n",
392 | "        full_data.__dict__['_raw'].__dict__['_var'] = full_data.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})\n",
393 | "    except AttributeError:\n",
394 | "        pass \n",
395 | "\n",
396 | "    filename = f\"{tissue}_processed{condition}.h5ad\"\n",
397 | "    full_data.write_h5ad(os.path.join(parent_folder, filename))"
398 | ]
399 | },
400 | {
401 | "cell_type": "markdown",
402 | "metadata": {},
403 | "source": [
404 | "##### Add highly variable genes selected by Seurat_v3 for lung"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": 5,
410 | "metadata": {},
411 | "outputs": [],
412 | "source": [
413 | "tissue = \"lung\"\n",
414 | "\n",
415 | "filename = f\"{tissue}.h5ad\"\n",
416 | "parent_folder = \"../input/ipf\"\n",
417 | "\n",
418 | "full_data = sc.read(os.path.join(parent_folder, filename), dtype='float64')\n",
419 | "\n",
420 | "# filtering \n",
421 | "sc.pp.filter_cells(full_data, min_genes=cell_cutoffs[tissue]) \n",
422 | "sc.pp.filter_genes(full_data, min_cells=gene_cutoffs[tissue])\n",
423 | "\n",
424 | "# selecting genes using Seurat v3 method \n",
425 | "sc.pp.highly_variable_genes(full_data, n_top_genes=topk, flavor='seurat_v3') \n", 426 | "genes_variable = full_data.var.index[full_data.var.highly_variable]" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 33, 432 | "metadata": {}, 433 | "outputs": [], 434 | "source": [ 435 | "filename = f\"{tissue}_processed.h5ad\"\n", 436 | "temp = sc.read(os.path.join(parent_folder, filename), dtype='float64', backed = \"r\")" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 17, 442 | "metadata": {}, 443 | "outputs": [ 444 | { 445 | "name": "stdout", 446 | "output_type": "stream", 447 | "text": [ 448 | "Number of Seurat selected genes not in dispersion-selected: 100\n" 449 | ] 450 | } 451 | ], 452 | "source": [ 453 | "print(f\"Number of Seurat selected genes not in dispersion-selected: {len(set(genes_variable) - set(temp.var.index))}\")" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 20, 459 | "metadata": {}, 460 | "outputs": [], 461 | "source": [ 462 | "highly_variable_seurat = np.repeat(0, temp.shape[1])\n", 463 | "\n", 464 | "for i in range(temp.shape[1]):\n", 465 | " g = temp.var.index.values[i]\n", 466 | " if g in genes_variable:\n", 467 | " highly_variable_seurat[i] = 1\n", 468 | " \n", 469 | "temp.var[\"highly_variable_seurat\"] = highly_variable_seurat.astype(bool)" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 32, 475 | "metadata": {}, 476 | "outputs": [], 477 | "source": [ 478 | "filename = f\"{tissue}_processed.h5ad\"\n", 479 | "temp.write_h5ad(os.path.join(parent_folder, filename))" 480 | ] 481 | }, 482 | { 483 | "cell_type": "markdown", 484 | "metadata": {}, 485 | "source": [ 486 | "## 4. Preprocess HuBMAP datasets\n", 487 | "\n", 488 | "\n", 489 | "For each tissue, we made the data by concatenating multiple datasets from HuBMAP. Please see [HuBMAP_datasets_ID.xlsx](https://github.com/doraadong/UNIFAN/blob/main/tutorails/HuBMAP_datasets_ID.xlsx) for the dataset IDs for the corresponding tissues. All datasets are preprocessed using standardized [HuBMAP scRNA-seq pipeline](https://github.com/hubmapconsortium/salmon-rnaseq)." 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 6, 495 | "metadata": {}, 496 | "outputs": [ 497 | { 498 | "name": "stderr", 499 | "output_type": "stream", 500 | "text": [ 501 | "... storing 'tissue' as categorical\n", 502 | "... storing 'predicted.id' as categorical\n", 503 | "... storing 'celltype' as categorical\n", 504 | "... storing 'tissue' as categorical\n", 505 | "... storing 'predicted.id' as categorical\n", 506 | "... 
storing 'celltype' as categorical\n"
507 | ]
508 | },
509 | {
510 | "name": "stdout",
511 | "output_type": "stream",
512 | "text": [
513 | "True\n"
514 | ]
515 | }
516 | ],
517 | "source": [
518 | "for tissue in [\"spleen\", \"thymus\", \"lymph_node\"]:\n",
519 | "    \n",
520 | "    filename = f\"{tissue}.h5ad\"\n",
521 | "    parent_folder = \"../input/hubmap\"\n",
522 | "\n",
523 | "    full_data = sc.read(os.path.join(parent_folder, filename), dtype='float64')\n",
524 | "\n",
525 | "    # filtering \n",
526 | "    sc.pp.filter_cells(full_data, min_genes=cell_cutoffs[tissue]) \n",
527 | "    sc.pp.filter_genes(full_data, min_cells=gene_cutoffs[tissue])\n",
528 | "    \n",
529 | "    if save_raw: \n",
530 | "        try:\n",
531 | "            full_data.__dict__['_raw'].__dict__['_var'] = full_data.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})\n",
532 | "        except AttributeError:\n",
533 | "            pass \n",
534 | "        \n",
535 | "        filename = f\"{tissue}_raw{condition}.h5ad\"\n",
536 | "        full_data.write_h5ad(os.path.join(parent_folder, filename))\n",
537 | "    \n",
538 | "    if save_processed:\n",
539 | "        # selecting genes using Seurat v3 method \n",
540 | "        sc.pp.highly_variable_genes(full_data, n_top_genes=topk, flavor='seurat_v3')\n",
541 | "\n",
542 | "        if select_genes:\n",
543 | "            full_data = full_data[:, full_data.var.highly_variable]\n",
544 | "\n",
545 | "        # normalize \n",
546 | "        sc.pp.normalize_total(full_data, target_sum=1e4)\n",
547 | "        sc.pp.log1p(full_data)\n",
548 | "\n",
549 | "        # adata.raw = adata\n",
550 | "        sc.pp.scale(full_data, max_value=10)\n",
551 | "\n",
552 | "        # convert Ensembl IDs to gene symbols \n",
553 | "        import mygene\n",
554 | "        mg = mygene.MyGeneInfo()\n",
555 | "\n",
556 | "        genes = [g.split('.')[0] for g in full_data.var.index.values]\n",
557 | "\n",
558 | "        out = mg.querymany(genes, scopes='ensembl.gene', fields='symbol', species='human', returnall=True)\n",
559 | "\n",
560 | "\n",
561 | "        # check when not selecting genes \n",
562 | "        for i in out['out']:\n",
563 | "            if i['query'] == 'ENSG00000229425':\n",
564 | "                print(i)\n",
565 | "\n",
566 | "        for i in out['out']:\n",
567 | "            if i['query'] == 'ENSG00000130723':\n",
568 | "                print(i)\n",
569 | "\n",
570 | "        en2symbol = {}\n",
571 | "        for i in out['out']:\n",
572 | "            if 'symbol' in i.keys():\n",
573 | "                if i['query'] not in en2symbol.keys():\n",
574 | "                    en2symbol[i['query']] = [i['symbol']]\n",
575 | "                else:\n",
576 | "                    en2symbol[i['query']].append(i['symbol'])\n",
577 | "            else:\n",
578 | "                en2symbol[i['query']] = [i['query']]\n",
579 | "\n",
580 | "        symbols = []\n",
581 | "        for g in genes:\n",
582 | "            symbols.append(en2symbol[g][0]) # take the first one \n",
583 | "\n",
584 | "        full_data.var.index = symbols\n",
585 | "\n",
586 | "\n",
587 | "        # save \n",
588 | "        try:\n",
589 | "            full_data.__dict__['_raw'].__dict__['_var'] = full_data.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})\n",
590 | "        except AttributeError:\n",
591 | "            pass \n",
592 | "\n",
593 | "        filename = f\"{tissue}_processed{condition}.h5ad\"\n",
594 | "        full_data.write_h5ad(os.path.join(parent_folder, filename))"
595 | ]
596 | }
597 | ],
598 | "metadata": {
599 | "kernelspec": {
600 | "display_name": "Python [conda env:anno] *",
601 | "language": "python",
602 | "name": "conda-env-anno-py"
603 | },
604 | "language_info": {
605 | "codemirror_mode": {
606 | "name": "ipython",
607 | "version": 3
608 | },
609 | "file_extension": ".py",
610 | "mimetype": "text/x-python",
611 | "name": "python",
612 | "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", 614 | "version": "3.6.13" 615 | } 616 | }, 617 | "nbformat": 4, 618 | "nbformat_minor": 2 619 | } 620 | -------------------------------------------------------------------------------- /tutorails/HuBMAP_datasets_ID.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/doraadong/UNIFAN/96e0811d52e3af4b0d40dfa5fa02cfa0d28c9be6/tutorails/HuBMAP_datasets_ID.xlsx -------------------------------------------------------------------------------- /tutorails/getExample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import scanpy as sc 5 | 6 | 7 | """ 8 | 9 | Process the Tabula Muris senis data. 10 | 11 | Download data from: https://figshare.com/ndownloader/files/24351086 12 | 13 | """ 14 | 15 | def main(): 16 | # parse command-line arguments 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('-p', '--path', required=True, type=str, 19 | default='./facs.h5ad', help="string, path to the downloaded data, " 20 | "default './facs.h5ad'") 21 | parser.add_argument('-i', '--folder', required=True, type=str, 22 | default='../example/input', help="string, path to the folder to save the data, " 23 | "default '../example/input'") 24 | parser.add_argument('-t', '--tissue', required=True, type=str, 25 | default=None, help="string, specify the output tissue; if using the default None, then all " 26 | "tissues will be outputted and saved separately in the folder; default None") 27 | parser.add_argument('-k', '--topk', required=False, default=2000, type=int, 28 | help="integer, optional, number of most variable genes, default 2000") 29 | 30 | args = parser.parse_args() 31 | print(args) 32 | 33 | parent_folder = args.folder 34 | filepath = args.path 35 | tissue = args.tissue 36 | topk = args.topk 37 | 38 | adata = sc.read(filepath, dtype='float64') 39 | 40 | sc.pp.filter_cells(adata, min_counts=5000) 41 | sc.pp.filter_cells(adata, min_genes=500) 42 | sc.pp.filter_genes(adata, min_cells=5) 43 | 44 | # get unnormalized version to infer highly variable genes 45 | full_data_unnorm = adata[adata.obs["age"] == "3m", :].copy() # equivalent to Tabula Muris data (MARS) 46 | 47 | # not include Brain_Myeloid and Marrow (MARS) 48 | full_data_unnorm = full_data_unnorm[full_data_unnorm.obs["tissue"] != "Marrow", :].copy() 49 | full_data_unnorm = full_data_unnorm[full_data_unnorm.obs["tissue"] != "Brain_Myeloid", :].copy() 50 | 51 | sc.pp.normalize_total(adata, target_sum=1e4) 52 | 53 | sc.pp.log1p(adata) 54 | 55 | sc.pp.scale(adata, max_value=10) 56 | 57 | full_data = adata[adata.obs["age"] == "3m", :] # equivalent to Tabula Muris data (MARS) 58 | 59 | # not include Brain_Myeloid and Marrow (MARS) 60 | full_data = full_data[full_data.obs["tissue"] != "Marrow", :] 61 | full_data = full_data[full_data.obs["tissue"] != "Brain_Myeloid", :] 62 | 63 | if tissue is None: 64 | # save each tissue separately 65 | for tissue in set(full_data.obs['tissue'].values): 66 | print(f"Saving output for {tissue}") 67 | 68 | # get most variable genes using unnormalized 69 | subset_unnorm = full_data_unnorm[full_data_unnorm.obs['tissue'] == tissue].copy() 70 | 71 | # selecting genes using Seurat v3 method 72 | sc.pp.highly_variable_genes(subset_unnorm, n_top_genes=topk, flavor='seurat_v3') 73 | 74 | # keep only the current tissue 75 | subset = full_data[full_data.obs['tissue'] == tissue].copy() 76 | subset.var["highly_variable"] = 
subset_unnorm.var["highly_variable"] 77 | 78 | # write full data 79 | subset.write_h5ad(os.path.join(parent_folder, f"{tissue}_facts_processed_3m.h5ad")) 80 | else: 81 | # get most variable genes using unnormalized 82 | subset_unnorm = full_data_unnorm[full_data_unnorm.obs['tissue'] == tissue].copy() 83 | 84 | # selecting genes using Seurat v3 method 85 | sc.pp.highly_variable_genes(subset_unnorm, n_top_genes=topk, flavor='seurat_v3') 86 | 87 | # keep only the current tissue 88 | subset = full_data[full_data.obs['tissue'] == tissue].copy() 89 | subset.var["highly_variable"] = subset_unnorm.var["highly_variable"] 90 | 91 | # write full data 92 | subset.write_h5ad(os.path.join(parent_folder, f"{tissue}_facts_processed_3m.h5ad")) 93 | 94 | 95 | 96 | if __name__ == '__main__': 97 | main() 98 | -------------------------------------------------------------------------------- /unifan-main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/doraadong/UNIFAN/96e0811d52e3af4b0d40dfa5fa02cfa0d28c9be6/unifan-main.png -------------------------------------------------------------------------------- /unifan-pretrain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/doraadong/UNIFAN/96e0811d52e3af4b0d40dfa5fa02cfa0d28c9be6/unifan-pretrain.png -------------------------------------------------------------------------------- /unifan/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | dir_path = os.path.dirname(os.path.realpath(__file__)) 5 | sys.path.append(dir_path) 6 | 7 | __all__ = ['main', 'annocluster', 'autoencoder', 'classifier', 'datasets', 'networks', 'trainer', 'utils'] 8 | for i in __all__: 9 | __import__(i) -------------------------------------------------------------------------------- /unifan/annocluster.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | torch.backends.cudnn.benchmark = True 4 | 5 | from unifan.networks import Encoder, Decoder, Set2Gene 6 | 7 | 8 | class AnnoCluster(nn.Module): 9 | 10 | """ 11 | Clustering with annotator. 
12 | 13 | Parameters 14 | ---------- 15 | input_dim: integer 16 | number of input features 17 | z_dim: integer 18 | number of low-dimensional features 19 | gene_set_dim: integer 20 | number of gene sets 21 | tau: float 22 | hyperparameter to weight the annotator loss 23 | zeta: float 24 | hyperparameter to weight the reconstruction loss from embeddings (discrete representations) 25 | encoder_dim: integer 26 | dimension of hidden layer for encoders 27 | emission_dim: integer 28 | dimension of hidden layer for decoders 29 | num_layers_encoder: integer 30 | number of hidden layers for encoder 31 | num_layers_decoder: integer 32 | number of hidden layers for decoder 33 | dropout_rate: float 34 | use_t_dist: boolean 35 | if using t distribution kernel to transform the euclidean distances between encodings and centroids 36 | regulating_probability: string 37 | the type of probability to regulating the clustering (by distance) results 38 | centroids: torch.Tensor 39 | embeddings in the low-dimensional space for the cluster centroids 40 | gene_set_table: torch.Tensor 41 | gene set relationship table 42 | 43 | """ 44 | 45 | def __init__(self, input_dim: int = 10000, z_dim: int = 32, gene_set_dim: int = 335, 46 | tau: float = 1.0, zeta: float = 1.0, n_clusters: int = 16, 47 | encoder_dim: int = 128, emission_dim: int = 128, num_layers_encoder: int = 1, 48 | num_layers_decoder: int = 1, dropout_rate: float = 0.1, use_t_dist: bool = True, 49 | reconstruction_network: str = "gaussian", decoding_network: str = "gaussian", 50 | regulating_probability: str = "classifier", centroids: torch.Tensor = None, 51 | gene_set_table: torch.Tensor = None, use_cuda: bool = False): 52 | 53 | super().__init__() 54 | 55 | # initialize parameters 56 | self.z_dim = z_dim 57 | self.reconstruction_network = reconstruction_network 58 | self.decoding_network = decoding_network 59 | self.tau = tau 60 | self.zeta = zeta 61 | self.n_clusters = n_clusters 62 | self.use_t_dist = use_t_dist 63 | self.regulating_probability = regulating_probability 64 | 65 | if regulating_probability not in ["classifier"]: 66 | raise NotImplementedError(f"The current implementation only support 'classifier', " 67 | f" for regulating probability.") 68 | 69 | # initialize centroids embeddings 70 | if centroids is not None: 71 | self.embeddings = nn.Parameter(centroids, requires_grad=True) 72 | else: 73 | self.embeddings = nn.Parameter(torch.randn(self.n_clusters, self.z_dim) * 0.05, requires_grad=True) 74 | 75 | # initialize loss 76 | self.mse_loss = nn.MSELoss() 77 | self.nLL_loss = nn.NLLLoss() 78 | 79 | # instantiate encoder for z 80 | if self.reconstruction_network == "gaussian": 81 | self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim, 82 | dropout_rate=dropout_rate) 83 | else: 84 | raise NotImplementedError(f"The current implementation only support 'gaussian' for encoder.") 85 | 86 | # instantiate decoder for emission 87 | if self.decoding_network == 'gaussian': 88 | self.decoder_e = Decoder(z_dim, input_dim, num_layers=num_layers_decoder, hidden_dim=emission_dim) 89 | self.decoder_q = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim) 90 | elif self.decoding_network == 'geneSet': 91 | self.decoder_e = Set2Gene(gene_set_table) 92 | self.decoder_q = Set2Gene(gene_set_table) 93 | else: 94 | raise NotImplementedError(f"The current implementation only support 'gaussian', " 95 | f"'geneSet' for emission decoder.") 96 | 97 | self.use_cuda = use_cuda 98 | if use_cuda: 99 | 
            self.cuda()
100 | 
101 |     def forward(self, x):
102 | 
103 |         # get encoding
104 |         z_e, _ = self.encoder(x)
105 | 
106 |         # get the index of the embedding closest to the encoding
107 |         k, z_dist, dist_prob = self._get_clusters(z_e)
108 | 
109 |         # get embeddings (discrete representations)
110 |         z_q = self._get_embeddings(k)
111 | 
112 |         # decode embedding (discrete representation) and encoding
113 |         x_q, _ = self.decoder_q(z_q)
114 |         x_e, _ = self.decoder_e(z_e)
115 | 
116 |         return x_e, x_q, z_e, z_q, k, z_dist, dist_prob
117 | 
118 |     def _get_clusters(self, z_e):
119 |         """
120 | 
121 |         Assign each sample to a cluster based on Euclidean distances.
122 | 
123 |         Parameters
124 |         ----------
125 |         z_e: torch.Tensor
126 |             low-dimensional encodings
127 | 
128 |         Returns
129 |         -------
130 |         k: torch.Tensor
131 |             cluster assignments
132 |         z_dist: torch.Tensor
133 |             distances between encodings and centroids
134 |         dist_prob: torch.Tensor
135 |             probability of closeness of encodings to centroids transformed by t-distribution
136 | 
137 |         """
138 | 
139 |         _z_dist = (z_e.unsqueeze(1) - self.embeddings.unsqueeze(0)) ** 2
140 |         z_dist = torch.sum(_z_dist, dim=-1)
141 | 
142 |         if self.use_t_dist:
143 |             dist_prob = self._t_dist_sim(z_dist, df=10)
144 |             k = torch.argmax(dist_prob, dim=-1)
145 |         else:
146 |             k = torch.argmin(z_dist, dim=-1)
147 |             dist_prob = None
148 | 
149 |         return k, z_dist, dist_prob
150 | 
151 |     def _t_dist_sim(self, z_dist, df=10):
152 |         """
153 |         Transform distances using t-distribution kernel.
154 | 
155 |         Parameters
156 |         ----------
157 |         z_dist: torch.Tensor
158 |             distances between encodings and centroids
159 | 
160 |         Returns
161 |         -------
162 |         dist_prob: torch.Tensor
163 |             probability of closeness of encodings to centroids transformed by t-distribution
164 | 
165 |         """
166 | 
167 |         _factor = - ((df + 1) / 2)
168 |         dist_prob = torch.pow((1 + z_dist / df), _factor)
169 |         dist_prob = dist_prob / dist_prob.sum(axis=1).unsqueeze(1)
170 | 
171 |         return dist_prob
172 | 
173 |     def _get_embeddings(self, k):
174 |         """
175 | 
176 |         Get the embeddings (discrete representations).
177 | 
178 |         Parameters
179 |         ----------
180 |         k: torch.Tensor
181 |             cluster assignments
182 | 
183 |         Returns
184 |         -------
185 |         z_q: torch.Tensor
186 |             low-dimensional embeddings (discrete representations)
187 | 
188 |         """
189 | 
190 |         k = k.long()
191 |         _z_q = []
192 |         for i in range(len(k)):
193 |             _z_q.append(self.embeddings[k[i]])
194 | 
195 |         z_q = torch.stack(_z_q)
196 | 
197 |         return z_q
198 | 
199 | 
200 |     def _loss_reconstruct(self, x, x_e, x_q):
201 |         """
202 |         Calculate reconstruction loss.
203 | 
204 |         Parameters
205 |         ----------
206 |         x: torch.Tensor
207 |             original observation in full-dimension
208 |         x_e: torch.Tensor
209 |             observation reconstructed from the encodings
210 |         x_q: torch.Tensor
211 |             observation reconstructed from the embeddings (discrete representations)
212 |         """
213 | 
214 |         l_e = self.mse_loss(x, x_e)
215 |         l_q = self.mse_loss(x, x_q)
216 |         mse_l = l_e + l_q * self.zeta
217 |         return mse_l
218 | 
219 |     def _loss_z_prob(self, z_dist, prior_prob=None):
220 |         """
221 |         Calculate annotator loss.
222 | 
223 |         Parameters
224 |         ----------
225 |         z_dist: torch.Tensor
226 |             distances between encodings and centroids
227 |         prior_prob: torch.Tensor
228 |             probability learned from another source (e.g. 
prior) about cluster assignment 229 | 230 | """ 231 | if self.regulating_probability == "classifier": 232 | 233 | weighted_z_dist_prob = z_dist * prior_prob 234 | prob_z_l = torch.mean(weighted_z_dist_prob) 235 | 236 | else: 237 | raise NotImplementedError(f"The current implementation only support " 238 | f"'classifier' for prob_z_l method.") 239 | 240 | return prob_z_l 241 | 242 | def loss(self, x, x_e, x_q, z_dist, prior_prob=None): 243 | 244 | mse_l = self._loss_reconstruct(x, x_e, x_q) 245 | prob_z_l = self._loss_z_prob(z_dist, prior_prob) 246 | 247 | l = mse_l + self.tau * prob_z_l 248 | 249 | return l 250 | 251 | 252 | -------------------------------------------------------------------------------- /unifan/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | torch.backends.cudnn.benchmark = True 5 | 6 | from unifan.networks import Encoder, Decoder, Set2Gene, LinearCoder, NonNegativeCoder, SigmoidCoder 7 | 8 | 9 | class autoencoder(nn.Module): 10 | """ 11 | 12 | Autoencoder used for pre-training. 13 | 14 | Parameters 15 | ---------- 16 | input_dim: integer 17 | number of input features 18 | z_dim: integer 19 | number of low-dimensional features 20 | gene_set_dim: integer 21 | number of gene sets 22 | encoder_dim: integer 23 | dimension of hidden layer for encoders 24 | emission_dim: integer 25 | dimension of hidden layer for decoders 26 | num_layers_encoder: integer 27 | number of hidden layers for encoder 28 | num_layers_decoder: integer 29 | number of hidden layers for decoder 30 | dropout_rate: float 31 | gene_set_table: torch.Tensor 32 | gene set relationship table 33 | 34 | """ 35 | 36 | def __init__(self, input_dim: int = 10000, z_dim: int = 32, gene_set_dim: int = 335, encoder_dim: int = 128, 37 | emission_dim: int = 128, num_layers_encoder: int = 1, num_layers_decoder: int = 1, 38 | dropout_rate: float = 0.1, reconstruction_network: str = "non-negative", 39 | decoding_network: str = "geneSet", gene_set_table: torch.Tensor = None, use_cuda: bool = False): 40 | 41 | super().__init__() 42 | 43 | # initialize parameters 44 | self.z_dim = z_dim 45 | self.reconstruction_network = reconstruction_network 46 | self.decoding_network = decoding_network 47 | 48 | # initialize loss 49 | self.mse_loss = nn.MSELoss() 50 | 51 | # initialize encoder and decoder 52 | if self.reconstruction_network == 'linear' and self.decoding_network == 'linear': 53 | self.encoder = LinearCoder(input_dim, z_dim) 54 | self.decoder_e = LinearCoder(z_dim, input_dim) 55 | else: 56 | 57 | if self.reconstruction_network == 'non-negative': 58 | # instantiate encoder for z 59 | self.encoder = NonNegativeCoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim, 60 | dropout_rate=dropout_rate) 61 | elif self.reconstruction_network == 'sigmoid': 62 | # instantiate encoder for z 63 | self.encoder = SigmoidCoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim, 64 | dropout_rate=dropout_rate) 65 | elif self.reconstruction_network == "gaussian": 66 | # instantiate encoder for z, using standard encoder 67 | self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim, 68 | dropout_rate=dropout_rate) 69 | 70 | else: 71 | raise NotImplementedError(f"The current implementation only support 'gaussian', " 72 | f"'non-negative' or 'sigmoid' for encoder.") 73 | 74 | # instantiate decoder for emission 75 | if self.decoding_network == 'gaussian': 76 | 
self.decoder_e = Decoder(z_dim, input_dim, num_layers=num_layers_decoder, hidden_dim=emission_dim) 77 | elif self.decoding_network == 'geneSet': 78 | self.decoder_e = Set2Gene(gene_set_table) 79 | else: 80 | raise NotImplementedError(f"The current implementation only support 'gaussian', " 81 | f"'geneSet' for emission decoder.") 82 | 83 | self.use_cuda = use_cuda 84 | if use_cuda: 85 | self.cuda() 86 | 87 | def forward(self, data): 88 | 89 | x = data 90 | 91 | # get encoding 92 | z_e, _ = self.encoder(x) 93 | 94 | # decode encoding 95 | x_e, _ = self.decoder_e(z_e) 96 | 97 | return x_e, z_e 98 | 99 | def _loss_reconstruct(self, x, x_e): 100 | """ 101 | Calculate reconstruction loss. 102 | 103 | Parameters 104 | ---------- 105 | x: torch.Tensor 106 | original data 107 | x_e: torch.Tensor 108 | reconstructed data 109 | 110 | Returns 111 | ------- 112 | mse_l: torch.Tensor 113 | reconstruction loss 114 | 115 | """ 116 | l_e = self.mse_loss(x, x_e) 117 | mse_l = l_e 118 | 119 | return mse_l 120 | 121 | def loss(self, x, x_e): 122 | l = self._loss_reconstruct(x, x_e) 123 | return l 124 | -------------------------------------------------------------------------------- /unifan/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | torch.backends.cudnn.benchmark = True 5 | 6 | from unifan.networks import Decode2Labels 7 | 8 | 9 | class classifier(nn.Module): 10 | """ 11 | 12 | A classifier. 13 | 14 | Parameters 15 | ---------- 16 | z_dim: integer 17 | number of input features 18 | output_dim: integer 19 | number of output (number of types of labels) 20 | emission_dim: integer 21 | dimension of hidden layer 22 | num_layers: integer 23 | number of hidden layers 24 | 25 | """ 26 | 27 | def __init__(self, z_dim: int = 335, output_dim: int = 10, emission_dim: int = 128, num_layers: int = 1, 28 | use_cuda=False): 29 | super().__init__() 30 | 31 | # initialize loss 32 | self.loss_function = nn.NLLLoss() 33 | 34 | # instantiate decoder for emission 35 | self.decoder = Decode2Labels(z_dim, output_dim) 36 | 37 | self.use_cuda = use_cuda 38 | if use_cuda: 39 | self.cuda() 40 | 41 | def forward(self, x): 42 | y_pre = self.decoder(x) 43 | return y_pre 44 | 45 | def loss(self, y_pre, y_true): 46 | l = self.loss_function(y_pre, y_true) 47 | return l 48 | -------------------------------------------------------------------------------- /unifan/datasets.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class AnnDataset(Dataset): 7 | def __init__(self, filepath: str, label_name: str = None, second_filepath: str = None, 8 | variable_gene_name: str = None): 9 | """ 10 | 11 | Anndata dataset. 12 | 13 | Parameters 14 | ---------- 15 | label_name: string 16 | name of the cell type annotation, default 'label' 17 | second_filepath: string 18 | path to another input file other than the main one; e.g. 
path to predicted clusters or
19 |             side information; only numpy arrays are supported
20 | 
21 |         """
22 | 
23 |         super().__init__()
24 | 
25 |         self.data = sc.read(filepath, dtype='float64', backed="r")
26 | 
27 |         genes = self.data.var.index.values
28 |         self.genes_upper = [g.upper() for g in genes]
29 |         if label_name is not None:
30 |             self.clusters_true = self.data.obs[label_name].values
31 |         else:
32 |             self.clusters_true = None
33 | 
34 |         self.N = self.data.shape[0]
35 |         self.G = len(self.genes_upper)
36 | 
37 |         self.secondary_data = None
38 |         if second_filepath is not None:
39 |             self.secondary_data = np.load(second_filepath)
40 |             assert len(self.secondary_data) == self.N, "The secondary file must have the same length as the main data"
41 | 
42 |         if variable_gene_name is not None:
43 |             _idx = np.where(self.data.var[variable_gene_name].values)[0]
44 |             self.exp_variable_genes = self.data.X[:, _idx]
45 |             self.variable_genes_names = self.data.var.index.values[_idx]
46 | 
47 |     def __len__(self):
48 |         return self.N
49 | 
50 |     def __getitem__(self, idx):
51 |         main = self.data[idx].X.flatten()
52 | 
53 |         if self.secondary_data is not None:
54 |             secondary = self.secondary_data[idx].flatten()
55 |             return main, secondary
56 |         else:
57 |             return main
58 | 
59 | 
60 | class NumpyDataset(Dataset):
61 |     def __init__(self, filepath: str, second_filepath: str = None):
62 |         """
63 | 
64 |         Numpy array dataset.
65 | 
66 |         Parameters
67 |         ----------
68 |         second_filepath: string
69 |             path to another input file other than the main one; e.g. path to predicted clusters or
70 |             side information; only numpy arrays are supported
71 | 
72 |         """
73 |         super().__init__()
74 | 
75 |         self.data = np.load(filepath)
76 |         self.N = self.data.shape[0]
77 |         self.G = self.data.shape[1]
78 | 
79 |         self.secondary_data = None
80 |         if second_filepath is not None:
81 |             self.secondary_data = np.load(second_filepath)
82 |             assert len(self.secondary_data) == self.N, "The secondary file must have the same length as the main data"
83 | 
84 |     def __len__(self):
85 |         return self.N
86 | 
87 |     def __getitem__(self, idx):
88 |         main = self.data[idx].flatten()
89 | 
90 |         if self.secondary_data is not None:
91 |             secondary = self.secondary_data[idx].flatten()
92 |             return main, secondary
93 |         else:
94 |             return main
95 | 
--------------------------------------------------------------------------------
/unifan/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | import gc
4 | import itertools
5 | import argparse
6 | 
7 | import torch
8 | import scanpy as sc
9 | import pandas as pd
10 | import numpy as np
11 | 
12 | from unifan.datasets import AnnDataset, NumpyDataset
13 | from unifan.annocluster import AnnoCluster
14 | from unifan.autoencoder import autoencoder
15 | from unifan.classifier import classifier
16 | from unifan.utils import getGeneSetMatrix, str2bool
17 | from unifan.trainer import Trainer
18 | 
19 | 
20 | def main():
21 |     # parse command-line arguments
22 |     parser = argparse.ArgumentParser()
23 |     parser.add_argument('-i', '--input', required=True, type=str,
24 |                         default='../input/data.h5ad', help="string, path to the input expression data, "
25 |                                                            "default '../input/data.h5ad'")
26 |     parser.add_argument('-o', '--output', required=True, type=str,
27 |                         default='../output/', help="string, path to the output folder, default '../output/'")
28 |     parser.add_argument('-p', '--project', required=True, type=str,
29 |                         default='data', help="string, identifier for the project, e.g., tabula_muris")
30 |     parser.add_argument('-t', '--tissue', 
-------------------------------------------------------------------------------- /unifan/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import gc 4 | import itertools 5 | import argparse 6 | 7 | import torch 8 | import scanpy as sc 9 | import pandas as pd 10 | import numpy as np 11 | 12 | from unifan.datasets import AnnDataset, NumpyDataset 13 | from unifan.annocluster import AnnoCluster 14 | from unifan.autoencoder import autoencoder 15 | from unifan.classifier import classifier 16 | from unifan.utils import getGeneSetMatrix, str2bool 17 | from unifan.trainer import Trainer 18 | 19 | 20 | def main(): 21 | # parse command-line arguments 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('-i', '--input', required=True, type=str, 24 | default='../input/data.h5ad', help="string, path to the input expression data, " 25 | "default '../input/data.h5ad'") 26 | parser.add_argument('-o', '--output', required=True, type=str, 27 | default='../output/', help="string, path to the output folder, default '../output/'") 28 | parser.add_argument('-p', '--project', required=True, type=str, 29 | default='data', help="string, identifier for the project, e.g., tabula_muris") 30 | parser.add_argument('-t', '--tissue', required=True, type=str, 31 | default='tissue', help="string, tissue where the input data is sampled from") 32 | parser.add_argument('-e', '--geneSetsPath', required=True, type=str, 33 | default='../gene_sets/', help="string, path to the folder where gene sets can be found, " 34 | "default='../gene_sets/'") 35 | parser.add_argument('-l', '--label', required=False, type=str, 36 | default=None, help="string, optional, the column / field name of the ground truth label, if " 37 | "available; used for evaluation only; default None") 38 | parser.add_argument('-v', '--variable', required=False, type=str, 39 | default='highly_variable', help="string, optional, the column / field name of the highly " 40 | "variable genes; default 'highly_variable'") 41 | parser.add_argument('-r', '--prior', required=False, type=str, 42 | default='c5.go.bp.v7.4.symbols.gmt+c2.cp.v7.4.symbols.gmt+TF-DNA', 43 | help="string, optional, gene set file names used to learn the gene set activity scores, " 44 | "use '+' to separate multiple gene set names, " 45 | "default c5.go.bp.v7.4.symbols.gmt+c2.cp.v7.4.symbols.gmt+TF-DNA") 46 | parser.add_argument('-f', '--features', required=False, default='gene_gene_sets', type=str, 47 | choices=['gene_sets', 'gene', 'gene_gene_sets'], 48 | help="string, optional, features used for the annotator, any of 'gene_sets', 'gene' or " 49 | "'gene_gene_sets', default 'gene_gene_sets'") 50 | parser.add_argument('-a', '--alpha', required=False, default=1e-2, type=float, 51 | help="float, optional, hyperparameter for the L1 term in the set cover loss, default 1e-2") 52 | parser.add_argument('-b', '--beta', required=False, default=1e-5, type=float, 53 | help="float, optional, hyperparameter for the set cover term in the set cover loss, " 54 | "default 1e-5") 55 | parser.add_argument('-g', '--gamma', required=False, default=1e-3, type=float, 56 | help="float, optional, hyperparameter for the exclusive L1 term, default 1e-3") 57 | parser.add_argument('-u', '--tau', required=False, default=10, type=float, 58 | help="float, optional, hyperparameter for the annotator loss, default 10") 59 | parser.add_argument('-d', '--dim', required=False, default=32, type=int, 60 | help="integer, optional, dimension for the low-dimensional representation, default 32") 61 | parser.add_argument('-s', '--batch', required=False, default=128, type=int, 62 | help="integer, optional, batch size for training except for pretraining annotator " 63 | "(fixed at 32), default 128") 64 | parser.add_argument('-na', '--nanno', required=False, default=50, type=int, 65 | help="integer, optional, number of epochs to pretrain the annotator, default 50") 66 | parser.add_argument('-ns', '--nscore', required=False, default=70, type=int, 67 | help="integer, optional, number of epochs to train the gene set activity model, default 70") 68 | parser.add_argument('-nu', '--nauto', required=False, default=50, type=int, 69 | help="integer, optional, number of epochs to pretrain the AnnoCluster model, default 50") 70 | parser.add_argument('-nc', '--ncluster', required=False, default=25, type=int, 71 | help="integer, optional, number of epochs to train the AnnoCluster model, default 25") 72 | parser.add_argument('-nze', '--nzenco', required=False, default=3, type=int, 73 | help="integer, optional, number of hidden layers for encoder of AnnoCluster, default 3") 74 | parser.add_argument('-nzd', '--nzdeco', required=False, default=2, type=int, 75 | help="integer, optional, number of hidden layers for decoder of AnnoCluster, default 2") 76 | 
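# NOTE (added for clarity): the four hyperparameters above map onto the loss
# terms in trainer.py: -a/--alpha weighs the L1 sparsity term and -b/--beta the
# set cover term of the gene set activity (r) model; -g/--gamma weighs the
# exclusive-lasso penalty on the annotator; -u/--tau weighs the annotator term
# inside the AnnoCluster loss.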
parser.add_argument('-dze', '--dimzenco', required=False, default=128, type=int, 77 | help="integer, optional, number of nodes for hidden layers for encoder of AnnoCluster, " 78 | "default 128") 79 | parser.add_argument('-dzd', '--dimzdeco', required=False, default=128, type=int, 80 | help="integer, optional, number of nodes for hidden layers for decoder of AnnoCluster, " 81 | "default 128") 82 | parser.add_argument('-nre', '--nrenco', required=False, default=5, type=int, 83 | help="integer, optional, number of hidden layers for the encoder of gene set activity scores " 84 | "model, default 5") 85 | parser.add_argument('-dre', '--dimrenco', required=False, default=128, type=int, 86 | help="integer, optional, number of nodes for hidden layers for encoder of gene set activity " 87 | "scores model, default 128") 88 | parser.add_argument('-drd', '--dimrdeco', required=False, default=128, type=int, 89 | help="integer, optional, number of nodes for hidden layers for decoder of gene set activity " 90 | "scores model, default 128") 91 | parser.add_argument('-n', '--network', required=False, choices=['sigmoid', 'non-negative', 'gaussian'], type=str, 92 | default='non-negative', help="string, optional, the encoder for the gene set activity model, " 93 | "any of 'sigmoid', 'non-negative' or 'gaussian', " 94 | "default 'non-negative'") 95 | parser.add_argument('-m', '--seed', required=False, default=0, type=int, 96 | help="integer, optional, random seed for the initialization, default 0") 97 | parser.add_argument('-c', '--cuda', required=False, type=str2bool, 98 | default=False, help="boolean, optional, if use GPU for neural network training, default False") 99 | parser.add_argument('-w', '--nworkers', required=False, default=8, type=int, 100 | help="integer, optional, number of workers for dataloader, default 8") 101 | 102 | 103 | args = parser.parse_args() 104 | print(args) 105 | 106 | data_filepath = args.input 107 | output_path = args.output 108 | gene_sets_path = args.geneSetsPath 109 | project = args.project 110 | tissue = args.tissue 111 | label_name = args.label 112 | variable_gene_name = args.variable 113 | 114 | prior_name = args.prior 115 | features_type = args.features 116 | alpha = args.alpha 117 | beta = args.beta 118 | weight_decay = args.gamma 119 | tau = args.tau 120 | z_dim = args.dim 121 | 122 | batch_size = args.batch 123 | num_epochs_classifier = args.nanno 124 | num_epochs_r = args.nscore 125 | num_epochs_z = args.nauto 126 | r_epoch = num_epochs_r - 1 127 | z_epoch = num_epochs_z - 1 128 | num_epochs_annocluster = args.ncluster 129 | 130 | z_encoder_layers = args.nzenco 131 | z_decoder_layers = args.nzdeco 132 | z_encoder_dim = args.dimzenco 133 | z_decoder_dim = args.dimzdeco 134 | r_encoder_layers = args.nrenco 135 | r_decoder_layers = 1 136 | r_encoder_dim = args.dimrenco 137 | r_decoder_dim = args.dimrdeco 138 | rnetwork = args.network 139 | 140 | random_seed = args.seed 141 | 142 | use_cuda = args.cuda 143 | num_workers = args.nworkers 144 | 145 | # ------ training conditions 146 | device = torch.device("cuda" if use_cuda else "cpu") 147 | if use_cuda: 148 | pin_memory = True 149 | non_blocking = True 150 | else: 151 | pin_memory = False 152 | non_blocking = False 153 | 154 | if '+' in prior_name: 155 | prior_names_list = prior_name.split('+') 156 | 157 | # ------ prepare for output 158 | output_parent_path = os.path.join(output_path, f"{project}/{tissue}/") 159 | 160 | r_folder = f"{output_parent_path}r" 161 | input_r_ae_path = os.path.join(r_folder, 
f"r_model_{r_epoch}.pickle") 162 | input_r_path = os.path.join(r_folder, f"r_{r_epoch}.npy") 163 | input_r_names_path = os.path.join(r_folder, f"r_names_{r_epoch}.npy") 164 | 165 | pretrain_z_folder = f"{output_parent_path}pretrain_z" 166 | input_z_path = os.path.join(pretrain_z_folder, f"pretrain_z_{z_epoch}.npy") 167 | input_ae_path = os.path.join(pretrain_z_folder, f"pretrain_z_model_{z_epoch}.pickle") 168 | input_cluster_path = os.path.join(pretrain_z_folder, f"cluster_{z_epoch}.npy") 169 | 170 | pretrain_annotator_folder = f"{output_parent_path}pretrain_annotator" 171 | annocluster_folder = f"{output_parent_path}annocluster_{features_type}" 172 | 173 | # ------ load data 174 | if features_type in ["gene", "gene_gene_sets"]: 175 | expression_only = AnnDataset(data_filepath, label_name=label_name, variable_gene_name=variable_gene_name) 176 | exp_variable_genes = expression_only.exp_variable_genes 177 | variable_genes_names = expression_only.variable_genes_names 178 | else: 179 | expression_only = AnnDataset(data_filepath, label_name=label_name) 180 | exp_variable_genes = None 181 | variable_genes_names = None 182 | 183 | genes_upper = expression_only.genes_upper 184 | N = expression_only.N 185 | G = expression_only.G 186 | 187 | # ------ process prior data 188 | # generate gene_set_matrix 189 | if '+' in prior_name: 190 | _matrix_list = [] 191 | _keys_list = [] 192 | for _name in prior_names_list: 193 | _matrix, _keys = getGeneSetMatrix(_name, genes_upper, gene_sets_path) 194 | _matrix_list.append(_matrix) 195 | _keys_list.append(_keys) 196 | 197 | gene_set_matrix = np.concatenate(_matrix_list, axis=0) 198 | keys_all = list(itertools.chain(*_keys_list)) 199 | 200 | del _matrix_list 201 | del _keys_list 202 | gc.collect() 203 | 204 | else: 205 | gene_set_matrix, keys_all = getGeneSetMatrix(prior_name, genes_upper, gene_sets_path) 206 | 207 | # ------ set-up for the set cover loss 208 | if beta != 0: 209 | # get the gene set matrix with only genes covered 210 | genes_covered = np.sum(gene_set_matrix, axis=0) 211 | gene_covered_matrix = gene_set_matrix[:, genes_covered != 0] 212 | gene_covered_matrix = torch.from_numpy(gene_covered_matrix).to(device, non_blocking=non_blocking).float() 213 | beta_list = torch.from_numpy(np.repeat(beta, gene_covered_matrix.shape[1])).to(device, 214 | non_blocking=non_blocking).float() 215 | 216 | del genes_covered 217 | gc.collect() 218 | else: 219 | gene_covered_matrix = None 220 | beta_list = None 221 | 222 | gene_set_dim = gene_set_matrix.shape[0] 223 | gene_set_matrix = torch.from_numpy(gene_set_matrix).to(device, non_blocking=non_blocking) 224 | 225 | # ------ Train gene set activity scores (r) model ------ 226 | if features_type == "gene": 227 | z_gene_set = exp_variable_genes 228 | set_names = list(variable_genes_names) 229 | else: 230 | 231 | model_gene_set = autoencoder(input_dim=G, z_dim=gene_set_dim, gene_set_dim=gene_set_dim, 232 | encoder_dim=r_encoder_dim, emission_dim=r_decoder_dim, 233 | num_layers_encoder=r_encoder_layers, num_layers_decoder=r_decoder_layers, 234 | reconstruction_network=rnetwork, decoding_network='geneSet', 235 | gene_set_table=gene_set_matrix, use_cuda=use_cuda) 236 | 237 | if os.path.isfile(input_r_path): 238 | print(f"Inferred r exists. 
No need to train the gene set activity scores model.") 239 | z_gene_set = np.load(input_r_path) 240 | else: 241 | if os.path.isfile(input_r_ae_path): 242 | model_gene_set.load_state_dict(torch.load(input_r_ae_path, map_location=device)['state_dict']) 243 | 244 | trainer = Trainer(dataset=expression_only, model=model_gene_set, model_name="r", batch_size=batch_size, 245 | num_epochs=num_epochs_r, save_infer=True, output_folder=r_folder, num_workers=num_workers, 246 | use_cuda=use_cuda) 247 | if os.path.isfile(input_r_ae_path): 248 | print( 249 | f"Inferred r model exists but r does not. Need to infer r and no need to train the gene set activity " 250 | f"scores model.") 251 | z_gene_set = trainer.infer_r(alpha=alpha, beta=beta, beta_list=beta_list, 252 | gene_covered_matrix=gene_covered_matrix) 253 | np.save(input_r_path, z_gene_set) 254 | else: 255 | print(f"Start training the gene set activity scores model ... ") 256 | trainer.train(alpha=alpha, beta=beta, beta_list=beta_list, gene_covered_matrix=gene_covered_matrix) 257 | z_gene_set = np.load(input_r_path) 258 | 259 | z_gene_set = torch.from_numpy(z_gene_set) 260 | 261 | # filter r to keep only non-zero values 262 | idx_non_0_gene_sets = np.where(z_gene_set.numpy().sum(axis=0) != 0)[0] 263 | 264 | # get kept gene set names 265 | set_names = np.array(keys_all)[idx_non_0_gene_sets] 266 | 267 | z_gene_set = z_gene_set[:, idx_non_0_gene_sets] 268 | print(f"After filtering, we have {z_gene_set.shape[1]} gene sets") 269 | 270 | # also add the selected genes if using "gene_gene_sets" 271 | if features_type == "gene_gene_sets": 272 | z_gene_set = np.concatenate([z_gene_set, exp_variable_genes], axis=1) 273 | set_names = list(set_names) + list(variable_genes_names) 274 | else: 275 | pass 276 | 277 | print(f"z_gene_set: {features_type}: {z_gene_set.shape}") 278 | print(f"z_gene_set: {features_type}: {len(set_names)}") 279 | 280 | # save feature names 281 | input_r_names_path = f"{input_r_names_path}_filtered_{features_type}.npy" 282 | np.save(input_r_names_path, set_names) 283 | 284 | # save processed features 285 | input_r_path = f"{input_r_path}_filtered_{features_type}.npy" 286 | np.save(input_r_path, z_gene_set) 287 | gene_set_dim = z_gene_set.shape[1] 288 | 289 | try: 290 | z_gene_set = z_gene_set.numpy() 291 | except AttributeError: 292 | pass 293 | 
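# NOTE (added for clarity): at this point r has been filtered down to gene sets
# with non-zero activity and, for "gene_gene_sets", augmented with the selected
# highly variable genes; the matrix and its matching feature names were saved
# above and can be reloaded later with, e.g.,
#   set_names = np.load(input_r_names_path)
#   z_gene_set = np.load(input_r_path)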
294 | # ------ Pretrain annocluster & initialize clustering ------ 295 | model_autoencoder = autoencoder(input_dim=G, z_dim=z_dim, gene_set_dim=gene_set_dim, 296 | encoder_dim=z_encoder_dim, emission_dim=z_decoder_dim, 297 | num_layers_encoder=z_encoder_layers, num_layers_decoder=z_decoder_layers, 298 | reconstruction_network='gaussian', decoding_network='gaussian', 299 | use_cuda=use_cuda) 300 | 301 | if os.path.isfile(input_z_path) and os.path.isfile(input_ae_path): 302 | print(f"Both pretrained autoencoder and inferred z exist. No need to pretrain the annocluster model.") 303 | z_init = np.load(input_z_path) 304 | model_autoencoder.load_state_dict(torch.load(input_ae_path, map_location=device)['state_dict']) 305 | else: 306 | if os.path.isfile(input_ae_path): 307 | model_autoencoder.load_state_dict(torch.load(input_ae_path, map_location=device)['state_dict']) 308 | 309 | trainer = Trainer(dataset=expression_only, model=model_autoencoder, model_name="pretrain_z", 310 | batch_size=batch_size, 311 | num_epochs=num_epochs_z, save_infer=True, output_folder=pretrain_z_folder, 312 | num_workers=num_workers, 313 | use_cuda=use_cuda) 314 | 315 | if os.path.isfile(input_ae_path): 316 | print(f"Only pretrained autoencoder exists. Need to infer z and no need to pretrain the annocluster model.") 317 | z_init = trainer.infer_z() 318 | np.save(input_z_path, z_init) 319 | else: 320 | print(f"Start pretraining the autoencoder for the annocluster model ... ") 321 | trainer.train() 322 | z_init = np.load(input_z_path) 323 | 324 | z_init = torch.from_numpy(z_init) 325 | 326 | try: 327 | z_init = z_init.numpy() 328 | except AttributeError: 329 | pass 330 | 331 | # initialize using leiden clustering 332 | adata = sc.AnnData(X=z_init) 333 | adata.obsm['X_unifan'] = z_init 334 | sc.pp.neighbors(adata, n_pcs=z_dim, use_rep='X_unifan', random_state=random_seed) 335 | sc.tl.leiden(adata, resolution=1, random_state=random_seed) 336 | clusters_pre = adata.obs['leiden'].astype('int').values # originally stored as strings 337 | 338 | # save the clusters as labels for classifier training 339 | np.save(input_cluster_path, clusters_pre) 340 | 341 | # initialize centroids 342 | try: 343 | df_cluster = pd.DataFrame(z_init.detach().cpu().numpy()) 344 | except AttributeError: 345 | df_cluster = pd.DataFrame(z_init) 346 | 347 | cluster_labels = np.unique(clusters_pre) 348 | M = len(set(cluster_labels)) # set as number of clusters 349 | df_cluster['cluster'] = clusters_pre 350 | 351 | # get centroids 352 | centroids = df_cluster.groupby('cluster').mean().values 353 | centroids_torch = torch.from_numpy(centroids) 354 | 355 | # ------ pretrain annotator (classification) ------ 356 | cls_times = 1 # count how many times the classifier has been run 357 | cls_training_accuracy = 1 # initialize to 1 so that the loop runs at least once 358 | weight_decay_candidates = [50, 20, 10, 5.5, 5, 4.5, 4, 3.5, 3, 2.5, 2, 1, 5e-1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 359 | idx_starting_weight_decay = weight_decay_candidates.index(weight_decay) 360 | 361 | while cls_training_accuracy >= 0.99: 362 | # assign a new weight decay (kept the same for the first run) 363 | weight_decay = weight_decay_candidates[idx_starting_weight_decay - cls_times + 1] 364 | 365 | print(f"Run classifier the {cls_times}th time with {weight_decay}") 366 | 367 | prior_cluster = NumpyDataset(input_r_path, input_cluster_path) 368 | 369 | model_classifier = classifier(output_dim=M, z_dim=gene_set_dim, emission_dim=128, use_cuda=use_cuda) 370 | 371 | trainer = Trainer(dataset=prior_cluster, model=model_classifier, model_name="pretrain_annotator", batch_size=32, 372 | num_epochs=num_epochs_classifier, save_infer=False, output_folder=pretrain_annotator_folder, 373 | num_workers=num_workers, use_cuda=use_cuda) 374 | 375 | trainer.train(weight_decay=weight_decay) 376 | clusters_classifier = trainer.infer_annotator() 377 | 378 | cls_training_accuracy = (clusters_classifier == clusters_pre).sum() / N 379 | print(f"Cluster accuracy on training: \n {cls_training_accuracy}") 380 | 381 | cls_times += 1 382 | 383 | # ------ clustering ------ 384 | 
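# NOTE (added for clarity): this final stage fine-tunes AnnoCluster jointly with
# the annotator. The annotator keeps its exclusive-lasso penalty,
#   weight_decay * sum_j (sum_i |W_ij|)**2   (i over clusters, j over features),
# see Trainer.process_minibatch_cluster; this pushes each gene set feature to be
# used by only a few clusters, which is what makes the selected features usable
# as cluster-specific markers.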
num_epochs = num_epochs_annocluster 385 | use_pretrain = True 386 | 387 | model_annocluster = AnnoCluster(input_dim=G, z_dim=z_dim, gene_set_dim=gene_set_dim, tau=tau, n_clusters=M, 388 | encoder_dim=z_encoder_dim, emission_dim=z_decoder_dim, 389 | num_layers_encoder=z_encoder_layers, num_layers_decoder=z_decoder_layers, 390 | use_t_dist=True, reconstruction_network='gaussian', decoding_network='gaussian', 391 | centroids=centroids_torch, gene_set_table=gene_set_matrix, use_cuda=use_cuda) 392 | 393 | if use_pretrain: 394 | pretrained_state_dict = model_autoencoder.state_dict() 395 | 396 | # load pretrained AnnoCluster model 397 | state_dict = model_annocluster.state_dict() 398 | for k, v in state_dict.items(): 399 | if k in pretrained_state_dict.keys(): 400 | state_dict[k] = pretrained_state_dict[k] 401 | 402 | model_annocluster.load_state_dict(state_dict) 403 | 404 | # reload dataset, loading gene set activity scores together 405 | expression_prior = AnnDataset(data_filepath, second_filepath=input_r_path, label_name=label_name) 406 | 407 | trainer = Trainer(dataset=expression_prior, model=model_annocluster, model_2nd=model_classifier, 408 | model_name="annocluster", batch_size=batch_size, num_epochs=num_epochs_annocluster, 409 | save_infer=True, output_folder=annocluster_folder, num_workers=num_workers, use_cuda=use_cuda) 410 | trainer.train(weight_decay=weight_decay) 411 | 412 | 413 | if __name__ == '__main__': 414 | main() 415 | -------------------------------------------------------------------------------- /unifan/networks.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class FullyConnectedLayers(nn.Module): 8 | """ 9 | Parameters 10 | ---------- 11 | input_dim: integer 12 | number of input features 13 | output_dim: integer 14 | number of output features 15 | num_layers: integer 16 | number of hidden layers 17 | hidden_dim: integer 18 | dimension of hidden layer 19 | dropout_rate: float 20 | bias: boolean 21 | if apply bias to the linear layers 22 | batch_norm: boolean 23 | if apply batch normalization 24 | 25 | """ 26 | 27 | def __init__(self, input_dim: int, output_dim: int, num_layers: int = 1, hidden_dim: int = 128, 28 | dropout_rate: float = 0.1, bias: bool = True, batch_norm: bool = False): 29 | super().__init__() 30 | layers_dim = [input_dim] + [hidden_dim for i in range(num_layers - 1)] + [output_dim] 31 | 32 | self.all_layers = nn.Sequential(collections.OrderedDict( 33 | [(f"Layer {i}", nn.Sequential( 34 | nn.Linear(input_dim, output_dim, bias=bias), 35 | nn.BatchNorm1d(output_dim) if batch_norm else None, 36 | nn.ReLU(), 37 | nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None)) 38 | for i, (input_dim, output_dim) in enumerate(zip(layers_dim[:-1], layers_dim[1:]))])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | 42 | for layers in self.all_layers: 43 | for layer in layers: 44 | if layer is not None: 45 | x = layer(x) 46 | 47 | return x 48 | 49 | 50 | class Encoder(nn.Module): 51 | """ 52 | 53 | A standard encoder. 
54 | 55 | Parameters 56 | ---------- 57 | input_dim: integer 58 | number of input features 59 | output_dim: integer 60 | number of output features 61 | num_layers: integer 62 | number of hidden layers 63 | hidden_dim: integer 64 | dimension of hidden layer 65 | dropout_rate: float 66 | 67 | """ 68 | 69 | def __init__(self, input_dim: int, output_dim: int, num_layers: int = 1, hidden_dim: int = 128, 70 | dropout_rate: float = 0.1): 71 | super().__init__() 72 | 73 | self.encoder = FullyConnectedLayers(input_dim=input_dim, output_dim=hidden_dim, num_layers=num_layers, 74 | hidden_dim=hidden_dim, dropout_rate=dropout_rate) 75 | self.mean_layer = nn.Linear(hidden_dim, output_dim) 76 | self.var_layer = nn.Linear(hidden_dim, output_dim) 77 | 78 | def forward(self, x: torch.Tensor): 79 | """ 80 | 81 | Parameters 82 | ---------- 83 | x: torch.Tensor 84 | 85 | Returns 86 | ------- 87 | q_m: torch.Tensor 88 | estimated mean 89 | q_v: torch.Tensor 90 | estimated variance 91 | 92 | """ 93 | 94 | q = self.encoder(x) 95 | q_m = self.mean_layer(q) 96 | q_v = torch.exp(self.var_layer(q)) 97 | return q_m, q_v 98 | 99 | 100 | class LinearCoder(nn.Module): 101 | 102 | """ 103 | 104 | A single-layer linear encoder. 105 | 106 | Parameters 107 | ---------- 108 | input_dim: integer 109 | number of input features 110 | output_dim: integer 111 | number of output features 112 | """ 113 | 114 | def __init__(self, input_dim: int, output_dim: int): 115 | super().__init__() 116 | 117 | self.encoder = nn.Linear(input_dim, output_dim) 118 | 119 | def forward(self, x: torch.Tensor): 120 | q = self.encoder(x) 121 | return q, None 122 | 123 | 124 | class NonNegativeCoder(nn.Module): 125 | 126 | """ 127 | 128 | An encoder outputting non-negative values (using ReLU for the output layer). 129 | 130 | Parameters 131 | ---------- 132 | input_dim: integer 133 | number of input features 134 | output_dim: integer 135 | number of output features 136 | num_layers: integer 137 | number of hidden layers 138 | hidden_dim: integer 139 | dimension of hidden layer 140 | dropout_rate: float 141 | 142 | """ 143 | 144 | def __init__(self, input_dim: int, output_dim: int, num_layers: int = 1, hidden_dim: int = 128, 145 | dropout_rate: float = 0.1): 146 | super().__init__() 147 | 148 | self.encoder = FullyConnectedLayers(input_dim=input_dim, output_dim=hidden_dim, num_layers=num_layers, 149 | hidden_dim=hidden_dim, dropout_rate=dropout_rate) 150 | self.mean_layer = FullyConnectedLayers(input_dim=hidden_dim, output_dim=output_dim, num_layers=1, 151 | hidden_dim=hidden_dim, dropout_rate=dropout_rate) 152 | 153 | def forward(self, x: torch.Tensor): 154 | q = self.encoder(x) 155 | q = self.mean_layer(q) 156 | return q, None 157 | 158 | 
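# --- Added note (not part of the original file): with the default
# --- '-n non-negative' setting in main.py, NonNegativeCoder above is the
# --- encoder producing the gene set activity scores r; its ReLU output layer
# --- guarantees r >= 0. A minimal sketch:
#
#   coder = NonNegativeCoder(input_dim=2000, output_dim=335)
#   r, _ = coder(torch.randn(4, 2000))
#   assert (r >= 0).all()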
159 | class SigmoidCoder(nn.Module): 160 | """ 161 | 162 | An encoder using sigmoid for the output layer. 163 | 164 | Parameters 165 | ---------- 166 | input_dim: integer 167 | number of input features 168 | output_dim: integer 169 | number of output features 170 | num_layers: integer 171 | number of hidden layers 172 | hidden_dim: integer 173 | dimension of hidden layer 174 | dropout_rate: float 175 | 176 | """ 177 | 178 | def __init__(self, input_dim: int, output_dim: int, num_layers: int = 1, hidden_dim: int = 128, 179 | dropout_rate: float = 0.1): 180 | super().__init__() 181 | 182 | self.encoder = FullyConnectedLayers(input_dim=input_dim, output_dim=hidden_dim, num_layers=num_layers, 183 | hidden_dim=hidden_dim, dropout_rate=dropout_rate) 184 | self.mean_layer = nn.Sequential(nn.Linear(hidden_dim, output_dim), nn.Sigmoid()) 185 | 186 | def forward(self, x: torch.Tensor): 187 | q = self.encoder(x) 188 | q = self.mean_layer(q) 189 | return q, None 190 | 191 | 192 | class Decoder(nn.Module): 193 | """ 194 | 195 | A standard decoder. 196 | 197 | Parameters 198 | ---------- 199 | input_dim: integer 200 | number of input features 201 | output_dim: integer 202 | number of output features 203 | num_layers: integer 204 | number of hidden layers 205 | hidden_dim: integer 206 | dimension of hidden layer 207 | dropout_rate: float 208 | 209 | """ 210 | 211 | def __init__(self, input_dim: int, output_dim: int, num_layers: int = 1, hidden_dim: int = 128, 212 | dropout_rate: float = 0.1): 213 | super().__init__() 214 | self.decoder = FullyConnectedLayers(input_dim=input_dim, output_dim=hidden_dim, num_layers=num_layers, 215 | hidden_dim=hidden_dim, dropout_rate=dropout_rate) 216 | 217 | self.mean_layer = nn.Linear(hidden_dim, output_dim) 218 | self.var_layer = nn.Linear(hidden_dim, output_dim) 219 | 220 | def forward(self, x: torch.Tensor): 221 | """ 222 | 223 | Parameters 224 | ---------- 225 | x: torch.Tensor 226 | 227 | Returns 228 | ------- 229 | p_m: torch.Tensor 230 | estimated mean 231 | p_v: torch.Tensor 232 | estimated variance 233 | """ 234 | 235 | p = self.decoder(x) 236 | p_m = self.mean_layer(p) 237 | p_v = torch.exp(self.var_layer(p)) 238 | return p_m, p_v 239 | 240 | 241 | 242 | class Set2Gene(nn.Module): 243 | 244 | """ 245 | Decode by a linear combination of the known gene set memberships, mapping gene set activities (input) to genes (output). 246 | 247 | Parameters 248 | ---------- 249 | tf_gene_table: torch.Tensor 250 | number of gene sets x number of genes (the first dimension equals the dimension of the input) 251 | 252 | """ 253 | 254 | def __init__(self, tf_gene_table: torch.Tensor): 255 | super().__init__() 256 | self.tf_gene_table = tf_gene_table 257 | 258 | def forward(self, x: torch.Tensor): 259 | p_m = torch.mm(x.double(), self.tf_gene_table) 260 | return p_m, None 261 | 262 | 263 | class Decode2Labels(nn.Module): 264 | """ 265 | 266 | A linear classifier (logistic classifier).
267 | 268 | Parameters 269 | ---------- 270 | input_dim: integer 271 | number of input features 272 | output_dim: integer 273 | number of output features 274 | """ 275 | 276 | def __init__(self, input_dim: int, output_dim: int, bias: bool = False): 277 | super().__init__() 278 | self.predictor = nn.Sequential(nn.Linear(input_dim, output_dim, bias=bias), nn.LogSoftmax(dim=-1)) 279 | 280 | def forward(self, x: torch.Tensor): 281 | labels = self.predictor(x) 282 | return labels 283 | 284 | -------------------------------------------------------------------------------- /unifan/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import pickle 4 | 5 | from tqdm import tqdm 6 | import torch 7 | import torch.nn as nn 8 | from sklearn.metrics.cluster import adjusted_rand_score, adjusted_mutual_info_score 9 | from sklearn.model_selection import train_test_split 10 | import seaborn as sns 11 | import numpy as np 12 | import pandas as pd 13 | from matplotlib import pyplot as plt 14 | import umap 15 | 16 | torch.backends.cudnn.benchmark = True 17 | 18 | 19 | class Trainer(nn.Module): 20 | """ 21 | 22 | Train NN models. 23 | 24 | 25 | Parameters 26 | ---------- 27 | dataset: PyTorch Dataset 28 | model: Pytorch NN model 29 | model_2nd: another Pytorch NN model to be trained together 30 | model_name: string 31 | name of the model, any of "r", "pretrain_z", "annocluster", "pretrain_annotator" 32 | percent_training: float 33 | percentage of data used for training, the rest used for validation 34 | checkpoint_freq: integer 35 | frequency of saving models during training 36 | val_freq: integer 37 | frequency of conducting evaluation (both training and validation set will be evaluated) 38 | visualize_freq: integer 39 | frequency of inferring low-dimensional representation and visualize it using UMAP 40 | save_visual: boolean 41 | if conduct visualization and save the figures to the output folder 42 | save_checkpoint: boolean 43 | if saving checkpoint models during training 44 | save_infer: boolean 45 | if conduct inference (low-dimensional representation for autoencoders; clusters for clustering model; 46 | predicted results for classifiers) when the training finished 47 | output_folder: string 48 | folder to save all outputs 49 | 50 | """ 51 | def __init__(self, dataset, model, model_2nd=None, model_name: str = None, batch_size: int = 128, 52 | num_epochs: int = 50, percent_training: float = 1.0, learning_rate: float = 0.0005, 53 | decay_factor: float = 0.9, num_workers: int = 8, use_cuda: bool = False, checkpoint_freq: int = 20, 54 | val_freq: int = 10, visualize_freq: int = 10, save_visual: bool = False, save_checkpoint: bool = False, 55 | save_infer: bool = False, output_folder: str = None): 56 | 57 | super().__init__() 58 | 59 | # device 60 | self.num_workers = num_workers 61 | self.use_cuda = use_cuda 62 | self.device = torch.device("cuda" if use_cuda else "cpu") 63 | self.pin_memory = True if use_cuda else False 64 | self.non_blocking = True if use_cuda else False 65 | 66 | # model 67 | _support_models = ["r", "pretrain_z", "annocluster", "pretrain_annotator"] 68 | if model_name not in _support_models: 69 | raise NotImplementedError(f"The current implementation only support training " 70 | f"for {','.join(_support_models)}.") 71 | self.model_name = model_name 72 | self.model = model.to(self.device) 73 | self.model_2nd = model_2nd.to(self.device) if model_2nd is not None else None 74 | 75 | # optimization 76 | self.batch_size = 
batch_size 77 | self.num_epochs = num_epochs 78 | self.learning_rate = learning_rate 79 | self.decay_factor = decay_factor 80 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) 81 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1000, self.decay_factor) 82 | if model_2nd is not None: 83 | self.optimizer_2nd = torch.optim.Adam(self.model_2nd.parameters(), lr=self.learning_rate) 84 | self.scheduler_2nd = torch.optim.lr_scheduler.StepLR(self.optimizer_2nd, 1000, self.decay_factor) 85 | else: 86 | self.optimizer_2nd, self.scheduler_2nd = None, None 87 | 88 | # evaluation 89 | self.checkpoint_freq = checkpoint_freq 90 | self.val_freq = val_freq 91 | self.visual_frequency = visualize_freq 92 | self.output_folder = output_folder 93 | self.percent_training = percent_training 94 | self.save_checkpoint = save_checkpoint 95 | self.save_visual = save_visual 96 | self.save_infer = save_infer 97 | 98 | if output_folder is not None: 99 | if not os.path.exists(output_folder): 100 | os.makedirs(output_folder) 101 | 102 | # data 103 | self.dataset = dataset 104 | 105 | # prepare for data loader 106 | train_length = int(self.dataset.N * self.percent_training) 107 | val_length = self.dataset.N - train_length 108 | 109 | train_data, val_data = torch.utils.data.random_split(self.dataset, (train_length, val_length)) 110 | self.dataloader_all = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size, shuffle=False, 111 | num_workers=self.num_workers, pin_memory=self.pin_memory) 112 | self.dataloader_train = torch.utils.data.DataLoader(train_data, batch_size=self.batch_size, shuffle=True, 113 | num_workers=self.num_workers, pin_memory=self.pin_memory) 114 | if self.percent_training != 1: 115 | self.dataloader_val = torch.utils.data.DataLoader(val_data, batch_size=self.batch_size, shuffle=True, 116 | num_workers=self.num_workers, pin_memory=self.pin_memory) 117 | else: 118 | self.dataloader_val = None 119 | 120 | # initialize functions for each model 121 | self.train_functions = {"r": self.train_epoch_r, "pretrain_z": self.train_epoch_z, 122 | "pretrain_annotator": self.train_epoch_annotator, 123 | "annocluster": self.train_epoch_cluster} 124 | self.train_stats_dicts = {"r": {k: [] for k in ['train_loss', 'val_loss', 'train_sparsity', 'train_set_cover']}, 125 | "pretrain_z": {k: [] for k in ['train_loss', 'val_loss']}, 126 | "pretrain_annotator": {k: [] for k in ['train_loss', 'val_loss']}, 127 | "annocluster": {k: [] for k in ['train_loss', 'val_loss', 'train_mse', 'train_mse_e', 128 | 'train_mse_q', 'train_prob_z_l', 'ARI', 'NMI']}} 129 | 130 | self.evaluate_functions = {"r": self.evaluate_r, "pretrain_z": self.evaluate_z, 131 | "pretrain_annotator": self.evaluate_annotator, 132 | "annocluster": self.evaluate_cluster} 133 | 134 | self.infer_functions = {"r": self.infer_r, "pretrain_z": self.infer_z, 135 | "pretrain_annotator": self.infer_annotator, 136 | "annocluster": self.infer_cluster} 137 | 138 | # initialize the parameter list (used for regularization) 139 | if self.model_name in ["pretrain_annotator", "annocluster"]: 140 | self.annotator_param_list = nn.ParameterList() 141 | if self.model_name == "pretrain_annotator": 142 | for p in self.model.named_parameters(): 143 | self.annotator_param_list.append(p[1]) 144 | else: 145 | assert self.model_2nd is not None 146 | for p in self.model_2nd.named_parameters(): 147 | self.annotator_param_list.append(p[1]) 148 | 149 | if self.model_name == "pretrain_annotator" and self.save_visual: 150 | raise 
NotImplementedError(f"The current implementation only support visualizing " 151 | f"for {','.join(_support_models)[:-1]}.") 152 | 153 | if self.model_name != "annocluster" and self.model_2nd is not None: 154 | raise NotImplementedError(f"The current implementation only support two models training for annocluster") 155 | 156 | def train(self, **kwargs): 157 | 158 | """ 159 | Train the model. Will save the trained model & evaluation results as default. 160 | 161 | Parameters 162 | ---------- 163 | kwargs: keyword arguements specific to each model (e.g. alpha for gene set activity scores model) 164 | 165 | """ 166 | 167 | for epoch in range(self.num_epochs): 168 | self.model.train() 169 | if self.model_2nd is not None: 170 | self.model_2nd.train() 171 | 172 | self.train_functions[self.model_name](**kwargs) 173 | 174 | if epoch % self.val_freq == 0: 175 | with torch.no_grad(): 176 | self.model.eval() 177 | if self.model_2nd is not None: 178 | self.model_2nd.eval() 179 | 180 | self.evaluate_functions[self.model_name](**kwargs) 181 | 182 | if self.save_visual and epoch % self.visual_frequency == 0: 183 | with torch.no_grad(): 184 | self.model.eval() 185 | if self.model_2nd is not None: 186 | self.model_2nd.eval() 187 | if self.model_name == "annocluster": 188 | _X, _clusters = self.infer_functions[self.model_name](**kwargs) 189 | else: 190 | _X = self.infer_functions[self.model_name](**kwargs) 191 | _clusters = None 192 | self.visualize_UMAP(_X, epoch, self.output_folder, clusters_true=self.dataset.clusters_true, 193 | clusters_pre=_clusters) 194 | 195 | # save model & inference 196 | if self.save_checkpoint and epoch % self.checkpoint_freq == 0: 197 | # save model 198 | _state = {'epoch': epoch, 199 | 'state_dict': self.model.state_dict(), 200 | 'optimizer': self.optimizer.state_dict()} 201 | if self.model_2nd is not None: 202 | _state = {'epoch': epoch, 203 | 'state_dict': self.model.state_dict(), 204 | 'optimizer': self.optimizer.state_dict(), 205 | 'state_dict_2': self.model_2nd.state_dict(), 206 | 'optimizer_2': self.optimizer_2nd.state_dict()} 207 | 208 | torch.save(_state, os.path.join(self.output_folder, f"{self.model_name}_model_{epoch}.pickle")) 209 | 210 | del _state 211 | gc.collect() 212 | 213 | # save model 214 | _state = {'epoch': epoch, 215 | 'state_dict': self.model.state_dict(), 216 | 'optimizer': self.optimizer.state_dict()} 217 | if self.model_2nd is not None: 218 | _state = {'epoch': epoch, 219 | 'state_dict': self.model.state_dict(), 220 | 'optimizer': self.optimizer.state_dict(), 221 | 'state_dict_2': self.model_2nd.state_dict(), 222 | 'optimizer_2': self.optimizer_2nd.state_dict()} 223 | 224 | torch.save(_state, os.path.join(self.output_folder, f"{self.model_name}_model_{epoch}.pickle")) 225 | 226 | del _state 227 | gc.collect() 228 | 229 | # inference and save 230 | if self.save_infer: 231 | if self.model_name == "annocluster": 232 | _X, _clusters = self.infer_functions[self.model_name](**kwargs) 233 | np.save(os.path.join(self.output_folder, f"{self.model_name}_clusters_pre_{epoch}.npy"), _clusters) 234 | else: 235 | _X = self.infer_functions[self.model_name](**kwargs) 236 | 237 | np.save(os.path.join(self.output_folder, f"{self.model_name}_{epoch}.npy"), _X) 238 | 239 | # save training stats 240 | with open(os.path.join(self.output_folder, f"stats_{epoch}.pickle"), 'wb') as handle: 241 | pickle.dump(self.train_stats_dicts[self.model_name], handle, protocol=pickle.HIGHEST_PROTOCOL) 242 | 243 | def train_epoch_r(self, **kwargs): 244 | for batch_idx, (X_batch) in 
enumerate(tqdm(self.dataloader_train)): 245 | self.process_minibatch_r(X_batch=X_batch, **kwargs) 246 | 247 | def train_epoch_z(self, **kwargs): 248 | for batch_idx, (X_batch) in enumerate(tqdm(self.dataloader_train)): 249 | self.process_minibatch_z(X_batch=X_batch, **kwargs) 250 | 251 | def train_epoch_annotator(self, **kwargs): 252 | for batch_idx, (X_batch, y_batch) in enumerate(tqdm(self.dataloader_train)): 253 | y_batch = torch.flatten(y_batch) 254 | self.process_minibatch_annotator(X_batch, y_batch, **kwargs) 255 | 256 | def train_epoch_cluster(self, **kwargs): 257 | for batch_idx, (X_batch, gene_set_batch) in enumerate(tqdm(self.dataloader_train)): 258 | self.process_minibatch_cluster(X_batch, gene_set_batch, **kwargs) 259 | 260 | def process_minibatch_cluster(self, X_batch, gene_set_batch, weight_decay: float = 0): 261 | """ 262 | Process minibatch for the annocluster model. 263 | 264 | Parameters 265 | ---------- 266 | X_batch: torch.Tensor 267 | gene expression 268 | gene_set_batch: torch.Tensor 269 | gene set activity scores 270 | weight_decay: float 271 | hyperparameter gamma regularizing exclusive lasso penalty 272 | 273 | """ 274 | X_batch = X_batch.to(self.device, non_blocking=self.non_blocking).float() 275 | gene_set_batch = gene_set_batch.to(self.device, non_blocking=self.non_blocking).float() 276 | 277 | if self.model.training and self.model_2nd.training: 278 | self.optimizer.zero_grad(set_to_none=True) 279 | self.optimizer_2nd.zero_grad(set_to_none=True) 280 | 281 | x_e, x_q, z_e, z_q, k, z_dist, dist_prob = self.model(X_batch) 282 | y_pre = self.model_2nd(gene_set_batch) 283 | pre_prob = torch.exp(y_pre) 284 | 285 | l = self.model.loss(X_batch, x_e, x_q, z_dist, prior_prob=pre_prob) 286 | l_prob = self.model_2nd.loss(y_pre, k) 287 | 288 | l_prob += weight_decay * torch.sum(torch.square(torch.sum(torch.abs(self.annotator_param_list[0]), dim=0))) 289 | 290 | if self.model.training and self.model_2nd.training: 291 | l.backward(retain_graph=True) 292 | self.optimizer.step() 293 | 294 | l_prob.backward(retain_graph=True) 295 | self.optimizer_2nd.step() 296 | 297 | self.scheduler.step() 298 | self.scheduler_2nd.step() 299 | 300 | return l.detach().cpu().item(), l_prob.detach().cpu().numpy(), k.detach().cpu().numpy(), \ 301 | z_e.detach().cpu().numpy() 302 | 303 | def evaluate_cluster(self, **kwargs): 304 | """ 305 | 306 | Evaluate the total loss and each loss term for the training data and the total loss for the validation set 307 | for the annocluster model.
308 | 309 | """ 310 | 311 | 312 | train_stats = self.train_stats_dicts[self.model_name] 313 | 314 | _loss = [] 315 | _l_e = [] 316 | _l_q = [] 317 | _mse_l = [] 318 | _prob_z_l = [] 319 | 320 | for batch_idx, (X_batch, gene_set_batch) in enumerate(tqdm(self.dataloader_train)): 321 | X_batch = X_batch.to(self.device, non_blocking=self.non_blocking).float() 322 | gene_set_batch = gene_set_batch.to(self.device, non_blocking=self.non_blocking).float() 323 | 324 | x_e, x_q, z_e, z_q, k, z_dist, dist_prob = self.model(X_batch) 325 | y_pre = self.model_2nd(gene_set_batch) 326 | pre_prob = torch.exp(y_pre) 327 | 328 | l = self.model.loss(X_batch, x_e, x_q, z_dist, prior_prob=pre_prob).detach().cpu().item() 329 | l_prob = self.model_2nd.loss(y_pre, k).detach().cpu().item() 330 | 331 | # calculate each loss 332 | l_e = self.model.mse_loss(X_batch, x_e).detach().cpu().item() 333 | l_q = self.model.mse_loss(X_batch, x_q).detach().cpu().item() 334 | mse_l = self.model._loss_reconstruct(X_batch, x_e, x_q).detach().cpu().item() 335 | prob_z_l = self.model._loss_z_prob(z_dist, prior_prob=pre_prob).detach().cpu().item() 336 | 337 | _loss.append(l) 338 | _l_e.append(l_e) 339 | _l_q.append(l_q) 340 | _mse_l.append(mse_l) 341 | _prob_z_l.append(prob_z_l) 342 | 343 | train_stats['train_loss'].append(np.mean(_loss)) 344 | train_stats['train_mse'].append(np.mean(_mse_l)) 345 | train_stats['train_mse_e'].append(np.mean(_l_e)) 346 | train_stats['train_mse_q'].append(np.mean(_l_q)) 347 | train_stats['train_prob_z_l'].append(np.mean(_prob_z_l)) 348 | 349 | if self.percent_training != 1: 350 | _loss = [] 351 | for batch_idx, (X_batch, gene_set_batch) in enumerate(tqdm(self.dataloader_val)): 352 | l, _, _, _ = self.process_minibatch_cluster(X_batch, gene_set_batch, **kwargs) 353 | _loss.append(l) 354 | 355 | train_stats['val_loss'].append(np.mean(_loss)) 356 | else: 357 | train_stats['val_loss'].append(np.nan) 358 | 359 | def infer_cluster(self, **kwargs): 360 | """ 361 | 362 | Get z_e and clusters from the annocluster model. Also calculate ARI and NMI scores comparing with 363 | the ground truth (if available). 
364 | 365 | Returns 366 | ------- 367 | z_annocluster: numpy array 368 | z_e of cells 369 | clusters_pre: numpy array 370 | cluster assignments 371 | 372 | """ 373 | train_stats = self.train_stats_dicts[self.model_name] 374 | 375 | with torch.no_grad(): 376 | self.model.eval() 377 | if self.model_2nd is not None: 378 | self.model_2nd.eval() 379 | 380 | # inference 381 | k_list = [] 382 | z_e_list = [] 383 | 384 | for batch_idx, (X_batch, gene_set_batch) in enumerate(tqdm(self.dataloader_all)): 385 | _, _, k, z_e = self.process_minibatch_cluster(X_batch, gene_set_batch, **kwargs) 386 | k_list.append(k) 387 | z_e_list.append(z_e) 388 | 389 | clusters_pre = np.concatenate(k_list) 390 | z_annocluster = np.concatenate(z_e_list) 391 | 392 | if self.dataset.clusters_true is not None: 393 | if self.dataset.N > 5e4: 394 | idx_stratified, _ = train_test_split(range(self.dataset.N), test_size=0.5, 395 | stratify=self.dataset.clusters_true) 396 | else: 397 | idx_stratified = range(self.dataset.N) 398 | 399 | # metrics 400 | ari_smaller = adjusted_rand_score(clusters_pre[idx_stratified], 401 | self.dataset.clusters_true[idx_stratified]) 402 | nmi_smaller = adjusted_mutual_info_score(clusters_pre, self.dataset.clusters_true) 403 | print(f"annocluster: ARI (on a stratified subsample when N > 5e4): {ari_smaller}") 404 | print(f"annocluster: NMI: {nmi_smaller}") 405 | else: 406 | ari_smaller = np.nan 407 | nmi_smaller = np.nan 408 | 409 | train_stats["ARI"].append(ari_smaller) 410 | train_stats["NMI"].append(nmi_smaller) 411 | 412 | return z_annocluster, clusters_pre 413 | 414 | 415 | def process_minibatch_annotator(self, X_batch, y_batch, weight_decay: float = 0): 416 | 417 | """ 418 | Process minibatch for the annotator model. 419 | 420 | Parameters 421 | ---------- 422 | X_batch: torch.Tensor 423 | gene expression 424 | y_batch: torch.Tensor 425 | cluster assignment 426 | weight_decay: float 427 | hyperparameter gamma regularizing exclusive lasso penalty 428 | 429 | """ 430 | 431 | X_batch = X_batch.to(self.device, non_blocking=self.non_blocking).float() 432 | y_batch = y_batch.to(self.device, non_blocking=self.non_blocking) 433 | 434 | if self.model.training: 435 | self.optimizer.zero_grad(set_to_none=True) 436 | 437 | y_pre = self.model(X_batch) 438 | l = self.model.loss(y_pre, y_batch) 439 | 440 | l += weight_decay * torch.sum(torch.square(torch.sum(torch.abs(self.annotator_param_list[0]), dim=0))) 441 | 442 | if self.model.training: 443 | l.backward() 444 | self.optimizer.step() 445 | self.scheduler.step() 446 | 447 | return l.detach().cpu().item(), y_pre.detach().cpu().numpy() 448 | 449 | def evaluate_annotator(self, **kwargs): 450 | train_stats = self.train_stats_dicts[self.model_name] 451 | _loss = [] 452 | for batch_idx, (X_batch, y_batch) in enumerate(tqdm(self.dataloader_train)): 453 | y_batch = torch.flatten(y_batch) 454 | l, _ = self.process_minibatch_annotator(X_batch, y_batch, **kwargs) 455 | _loss.append(l) 456 | 457 | train_stats['train_loss'].append(np.mean(_loss)) 458 | 459 | if self.percent_training != 1: 460 | _loss = [] 461 | for batch_idx, (X_batch, y_batch) in enumerate(tqdm(self.dataloader_val)): 462 | y_batch = torch.flatten(y_batch) 463 | l, _ = self.process_minibatch_annotator(X_batch, y_batch, **kwargs) 464 | _loss.append(l) 465 | 466 | train_stats['val_loss'].append(np.mean(_loss)) 467 | else: 468 | train_stats['val_loss'].append(np.nan) 469 | 470 | def infer_annotator(self): 471 | """ 472 | 473 | Get the predicted labels for all the data.
474 | 475 | Returns 476 | ------- 477 | clusters_classifier: numpy array 478 | predicted labels from the trained annotator 479 | 480 | """ 481 | with torch.no_grad(): 482 | self.model.eval() 483 | tf_prob = self.model( 484 | torch.from_numpy(self.dataset.data).to(self.device, non_blocking=self.non_blocking).float()) 485 | 486 | clusters_prob_pre = torch.exp(tf_prob) 487 | clusters_classifier = np.argmax(clusters_prob_pre.detach().cpu().numpy(), axis=1) 488 | 489 | return clusters_classifier 490 | 491 | def process_minibatch_r(self, X_batch, alpha: float = 0, beta: float = 0, beta_list: torch.Tensor = None, 492 | gene_covered_matrix: torch.Tensor = None): 493 | """ 494 | 495 | Process minibatch for the gene set activity scores model (named as r). 496 | 497 | Parameters 498 | ---------- 499 | X_batch: torch.Tensor 500 | gene expression 501 | alpha: float 502 | hyperparameter regularizing the L1 term in the set cover loss 503 | beta: float 504 | hyperparameter regularizing the set cover term in the set cover loss 505 | beta_list: torch.Tensor 506 | beta values for all genes 507 | gene_covered_matrix: torch.Tensor 508 | gene set membership matrix restricted to genes covered by at least one of the available sets 509 | 510 | """ 511 | 512 | 513 | X_batch = X_batch.to(self.device, non_blocking=self.non_blocking).float() 514 | 515 | if self.model.training: 516 | self.optimizer.zero_grad(set_to_none=True) 517 | 518 | x_e, z_e = self.model(X_batch) 519 | l = self.model.loss(X_batch.float(), x_e.float()) 520 | 521 | # add sparsity regularization on the output 522 | if alpha != 0: 523 | sparsity_penalty = alpha * torch.mean(torch.abs(z_e)) 524 | l += sparsity_penalty 525 | else: 526 | sparsity_penalty = torch.zeros(1) 527 | 528 | if beta != 0: 529 | cover_penalty = - torch.mean(torch.matmul(torch.mm(z_e, gene_covered_matrix), 530 | beta_list)) 531 | l += cover_penalty 532 | else: 533 | cover_penalty = torch.zeros(1) 534 | 535 | if self.model.training: 536 | l.backward() 537 | self.optimizer.step() 538 | self.scheduler.step() 539 | 540 | return l.detach().cpu().item(), z_e.detach().cpu().numpy(), \ 541 | sparsity_penalty.detach().cpu().numpy(), cover_penalty.detach().cpu().numpy() 542 | 543 | def evaluate_r(self, **kwargs): 544 | train_stats = self.train_stats_dicts[self.model_name] 545 | _loss = [] 546 | _sp = [] 547 | _cp = [] 548 | 549 | for batch_idx, (X_batch) in enumerate(tqdm(self.dataloader_train)): 550 | l, _, sp, cp = self.process_minibatch_r(X_batch, **kwargs) 551 | _loss.append(l) 552 | _sp.append(sp) 553 | _cp.append(cp) 554 | 555 | train_stats['train_loss'].append(np.mean(_loss)) 556 | train_stats['train_sparsity'].append(np.mean(_sp)) 557 | train_stats['train_set_cover'].append(np.mean(_cp)) 558 | 559 | if self.percent_training != 1: 560 | _loss = [] 561 | for batch_idx, (X_batch) in enumerate(tqdm(self.dataloader_val)): 562 | l, _, _, _ = self.process_minibatch_r(X_batch, **kwargs) 563 | _loss.append(l) 564 | 565 | train_stats['val_loss'].append(np.mean(_loss)) 566 | else: 567 | train_stats['val_loss'].append(np.nan) 568 | 
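# NOTE (added for clarity): per minibatch the r model above minimizes
#   MSE(x, x_e) + alpha * mean(|r|) - mean((r @ gene_covered_matrix) @ beta_list)
# i.e. reconstruction plus an L1 sparsity term, minus a set cover reward for
# activating gene sets that cover the genes (see process_minibatch_r).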
| 591 | print(f"Finish inferring ...") 592 | z_gene_set = np.concatenate(z_e_list) 593 | 594 | del z_e_list 595 | del _loss 596 | del z_e 597 | 598 | gc.collect() 599 | 600 | return z_gene_set 601 | 602 | def process_minibatch_z(self, X_batch): 603 | 604 | """ 605 | Process minibatch for pretraining the autocluster model (named as pretrain_z) 606 | 607 | """ 608 | 609 | X_batch = X_batch.to(self.device, non_blocking=self.non_blocking).float() 610 | 611 | if self.model.training: 612 | self.optimizer.zero_grad(set_to_none=True) 613 | 614 | x_e, z_e = self.model(X_batch) 615 | 616 | l = self.model.loss(X_batch.float(), x_e.float()) 617 | 618 | if self.model.training: 619 | l.backward() 620 | self.optimizer.step() 621 | self.scheduler.step() 622 | return l.detach().cpu().item(), z_e.detach().cpu().numpy() 623 | 624 | def evaluate_z(self, **kwargs): 625 | train_stats = self.train_stats_dicts[self.model_name] 626 | 627 | _loss = [] 628 | for batch_idx, (X_batch) in enumerate(tqdm(self.dataloader_train)): 629 | l, _ = self.process_minibatch_z(X_batch) 630 | _loss.append(l) 631 | 632 | train_stats['train_loss'].append(np.mean(_loss)) 633 | 634 | if self.percent_training != 1: 635 | _loss = [] 636 | for batch_idx, (X_batch) in enumerate(tqdm(self.dataloader_val)): 637 | l, _ = self.process_minibatch_z(X_batch) 638 | _loss.append(l) 639 | 640 | train_stats['val_loss'].append(np.mean(_loss)) 641 | else: 642 | train_stats['val_loss'].append(np.nan) 643 | 644 | def infer_z(self): 645 | 646 | """ 647 | 648 | Returns 649 | ------- 650 | z_init: numpy array 651 | z_e from the pretrain model 652 | 653 | """ 654 | 655 | with torch.no_grad(): 656 | self.model.eval() 657 | 658 | z_e_list = [] 659 | 660 | for batch_idx, (X_batch) in enumerate(tqdm(self.dataloader_all)): 661 | _, z_e = self.process_minibatch_z(X_batch) 662 | z_e_list.append(z_e) 663 | 664 | z_init = np.concatenate(z_e_list) 665 | 666 | del z_e_list 667 | gc.collect() 668 | 669 | return z_init 670 | 671 | @staticmethod 672 | def visualize_UMAP(X, epoch:int, output_folder: str, clusters_true=None, clusters_pre=None, 673 | color_palette:str = "tab20"): 674 | """ 675 | 676 | Visualize the low-dimensional representations using UMAP and save figures. 
677 | 678 | Parameters 679 | ---------- 680 | X: numpy array 681 | low-dimensional representations 682 | epoch: integer 683 | epoch of the model based on which the low-dimensional representations are inferred 684 | clusters_true: numpy array 685 | ground truth labels 686 | clusters_pre: numpy array 687 | cluster assignment 688 | 689 | """ 690 | 691 | print(f"Start visualizing using UMAP...") 692 | umap_original = umap.UMAP().fit_transform(X) 693 | 694 | # color by cluster 695 | hues = {'label': clusters_true, 'cluster': clusters_pre} 696 | 697 | for k, v in hues.items(): 698 | df_plot = pd.DataFrame(umap_original) 699 | if v is None: 700 | df_plot['label'] = np.repeat("Label not available", df_plot.shape[0]) 701 | else: 702 | df_plot['label'] = v 703 | df_plot['label'] = df_plot['label'].astype('str') 704 | df_plot.columns = ['dim_1', 'dim_2', 'label'] 705 | 706 | plt.figure(figsize=(10, 10)) 707 | sns.scatterplot(x='dim_1', y='dim_2', hue='label', data=df_plot, palette=color_palette, 708 | legend=True) 709 | plt.title(f"Encoding (r) colored by {k}") 710 | plt.savefig(os.path.join(output_folder, f"r_{epoch}_{k}.png"), bbox_inches="tight", format="png") 711 | plt.close() 712 | 713 | del X 714 | del umap_original 715 | del df_plot 716 | 717 | gc.collect() 718 | 719 | -------------------------------------------------------------------------------- /unifan/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper classes / functions. 3 | 4 | """ 5 | import os 6 | import argparse 7 | import gc 8 | 9 | import numpy as np 10 | 11 | 12 | def str2bool(v): 13 | """ 14 | Helper to pass boolean arguments. 15 | Extracted from: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 16 | Author: @Maxim 17 | """ 18 | 19 | if isinstance(v, bool): 20 | return v 21 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 22 | return True 23 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 24 | return False 25 | else: 26 | raise argparse.ArgumentTypeError('Boolean value expected.') 27 | 28 | 
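# --- Added note (not part of the original file): the function defined below
# --- builds one binary membership row per TF / gene set over the input gene
# --- list; a minimal example of its behavior:
#
#   gen_tf_gene_table(["G1", "G2"], ["TF_A"], {"TF_A": ["G2"]})
#   # -> {"TF_A": array([0., 1.])}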
29 | def gen_tf_gene_table(genes, tf_list, dTD): 30 | """ 31 | 32 | Adapted from: 33 | Author: Jun Ding 34 | Project: SCDIFF2 35 | Ref: Ding, J., Aronow, B. J., Kaminski, N., Kitzmiller, J., Whitsett, J. A., & Bar-Joseph, Z. 36 | (2018). Reconstructing differentiation networks and their regulation from time series 37 | single-cell expression data. Genome research, 28(3), 383-395. 38 | 39 | """ 40 | gene_names = [g.upper() for g in genes] 41 | TF_names = [g.upper() for g in tf_list] 42 | tf_gene_table = dict.fromkeys(tf_list) 43 | 44 | for i, tf in enumerate(tf_list): 45 | tf_gene_table[tf] = np.zeros(len(gene_names)) 46 | _genes = dTD[tf] 47 | 48 | _existed_targets = list(set(_genes).intersection(gene_names)) 49 | _idx_targets = map(lambda x: gene_names.index(x), _existed_targets) 50 | 51 | for _g in _idx_targets: 52 | tf_gene_table[tf][_g] = 1 53 | 54 | del gene_names 55 | del TF_names 56 | del _genes 57 | del _existed_targets 58 | del _idx_targets 59 | 60 | gc.collect() 61 | 62 | return tf_gene_table 63 | 64 | 65 | def getGeneSetMatrix(_name, genes_upper, gene_sets_path): 66 | """ 67 | 68 | Adapted from: 69 | Author: Jun Ding 70 | Project: SCDIFF2 71 | Ref: Ding, J., Aronow, B. J., Kaminski, N., Kitzmiller, J., Whitsett, J. A., & Bar-Joseph, Z. 72 | (2018). Reconstructing differentiation networks and their regulation from time series 73 | single-cell expression data. Genome research, 28(3), 383-395. 74 | 75 | """ 76 | if _name[-3:] == 'gmt': 77 | print(f"GMT file {_name} loading ... ") 78 | filename = _name 79 | filepath = os.path.join(gene_sets_path, f"{filename}") 80 | 81 | with open(filepath) as genesets: 82 | pathway2gene = {line.strip().split("\t")[0]: line.strip().split("\t")[2:] 83 | for line in genesets.readlines()} 84 | 85 | print(len(pathway2gene)) 86 | 87 | gs = [] 88 | for k, v in pathway2gene.items(): 89 | gs += v 90 | 91 | print(f"Number of genes in {_name}: {len(set(gs).intersection(genes_upper))}") 92 | 93 | pathway_list = pathway2gene.keys() 94 | pathway_gene_table = gen_tf_gene_table(genes_upper, pathway_list, pathway2gene) 95 | gene_set_matrix = np.array(list(pathway_gene_table.values())) 96 | keys = pathway_gene_table.keys() 97 | 98 | del pathway2gene 99 | del gs 100 | del pathway_list 101 | del pathway_gene_table 102 | 103 | gc.collect() 104 | 105 | 106 | elif _name == 'TF-DNA': 107 | 108 | # get TF-DNA dictionary 109 | # TF->DNA 110 | def getdTD(tfDNA): 111 | dTD = {} 112 | with open(tfDNA, 'r') as f: 113 | tfRows = f.readlines() 114 | tfRows = [item.strip().split() for item in tfRows] 115 | for row in tfRows: 116 | itf = row[0].upper() 117 | itarget = row[1].upper() 118 | if itf not in dTD: 119 | dTD[itf] = [itarget] 120 | else: 121 | dTD[itf].append(itarget) 122 | 123 | del tfRows 124 | del itf 125 | del itarget 126 | gc.collect() 127 | 128 | return dTD 129 | 130 | from collections import defaultdict 131 | 132 | def getdDT(dTD): 133 | gene_tf_dict = defaultdict(lambda: []) 134 | for key, val in dTD.items(): 135 | for v in val: 136 | gene_tf_dict[v.upper()] += [key.upper()] 137 | 138 | return gene_tf_dict 139 | 140 | tfDNA_file = os.path.join(gene_sets_path, f"Mouse_TF_targets.txt") 141 | dTD = getdTD(tfDNA_file) 142 | dDT = getdDT(dTD) 143 | 144 | tf_list = list(sorted(dTD.keys())) 145 | tf_list.remove('TF') 146 | 147 | tf_gene_table = gen_tf_gene_table(genes_upper, tf_list, dTD) 148 | gene_set_matrix = np.array(list(tf_gene_table.values())) 149 | keys = tf_gene_table.keys() 150 | 151 | del dTD 152 | del dDT 153 | del tf_list 154 | del tf_gene_table 155 | 156 | gc.collect() 157 | 158 | else: 159 | gene_set_matrix = None 160 | keys = None 161 | return gene_set_matrix, keys 162 | 163 | 164 | 165 | --------------------------------------------------------------------------------
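# --- Added usage sketch (not part of the repository): how getGeneSetMatrix can
# --- be used programmatically; the file names below are placeholders and the
# --- .gmt file is assumed to sit in ../gene_sets/.
#
#   import scanpy as sc
#   from unifan.utils import getGeneSetMatrix
#
#   adata = sc.read("data.h5ad")
#   genes_upper = [g.upper() for g in adata.var.index.values]
#   gene_set_matrix, keys = getGeneSetMatrix("c5.go.bp.v7.4.symbols.gmt",
#                                            genes_upper, "../gene_sets/")
#   print(gene_set_matrix.shape)                     # (num gene sets, num genes)
#   print((gene_set_matrix.sum(axis=0) > 0).mean())  # fraction of genes covered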