├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── data ├── VISp_PERSIST_metadata.csv └── VISp_markers.toml ├── notebooks ├── 00_data_proc.ipynb ├── 01_persist_supervised.ipynb ├── 02_persist_unsupervised.ipynb ├── 03_persist_pbmc3k_scanpy.ipynb └── demo.ipynb ├── persist ├── __init__.py ├── data.py ├── layers.py ├── models.py ├── selection.py └── utils.py ├── setup.py └── tests └── test_expression_dataset.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | .DS_Store 4 | .ipynb_checkpoints 5 | .eggs 6 | *.egg-info 7 | build 8 | dist 9 | propose.egg-info 10 | /*.ipynb 11 | *.pkl 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ian Covert 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PERSIST 2 | 3 | **Predictive and robust gene selection for spatial transcriptomics (PERSIST)** is a computational approach to select target genes for FISH studies. PERSIST relies on a reference scRNA-seq dataset, and it uses deep learning to identify genes that are predictive of the genome-wide expression profile or any other target of interest (e.g., transcriptomic cell types). PERSIST also binarizes gene expression levels during the selection process, which helps account for the measurement shift between expression counts obtained by scRNA-seq and FISH. 4 | 5 | See the related [publication](https://www.nature.com/articles/s41467-023-37392-1) for details. 
6 | 
7 | To cite:
8 | ```bib
9 | @article{covert2023predictive,
10 |   title={Predictive and robust gene selection for spatial transcriptomics},
11 |   author={Covert, Ian and Gala, Rohan and Wang, Tim and Svoboda, Karel and S{\"u}mb{\"u}l, Uygar and Lee, Su-In},
12 |   journal={Nature Communications},
13 |   volume={14},
14 |   number={1},
15 |   pages={2091},
16 |   year={2023},
17 |   publisher={Nature Publishing Group UK London}
18 | }
19 | ```
20 | 
21 | ## Installation
22 | 
23 | You can install the package by cloning the repository and installing it with pip as follows:
24 | 
25 | ```bash
26 | pip install -e .
27 | ```
28 | 
29 | This will automatically install any missing dependencies, which may take a couple of minutes. Please be sure to install a version of [PyTorch](https://pytorch.org/get-started/locally/) that is compatible with your GPU, as we highly recommend using a GPU to accelerate training.
30 | 
31 | ## Usage
32 | 
33 | PERSIST is designed to offer flexibility while requiring minimal tuning. For a demonstration of how it's used, from data preparation through gene selection, please see the following Jupyter notebooks:
34 | 
35 | - [00_data_proc.ipynb](https://github.com/iancovert/persist/blob/main/notebooks/00_data_proc.ipynb) shows how to download and pre-process one of the datasets used in our paper (the VISp SmartSeq v4 dataset from [Tasic et al., 2018](https://www.nature.com/articles/s41586-018-0654-5))
36 | - [01_persist_supervised.ipynb](https://github.com/iancovert/persist/blob/main/notebooks/01_persist_supervised.ipynb) shows how to use PERSIST to select genes that are maximally predictive of cell type labels (the supervised case)
37 | - [02_persist_unsupervised.ipynb](https://github.com/iancovert/persist/blob/main/notebooks/02_persist_unsupervised.ipynb) shows how to use PERSIST to select genes that are maximally predictive of the genome-wide expression profile (the unsupervised case)
38 | - [03_persist_pbmc3k_scanpy.ipynb](https://github.com/iancovert/persist/blob/main/notebooks/03_persist_pbmc3k_scanpy.ipynb) shows supervised gene selection for a PBMC dataset downloaded directly from Scanpy
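For a quick sense of the API before opening the notebooks, here is a minimal sketch of the supervised workflow. The names (`PERSIST`, `ExpressionDataset`, `eliminate`, `select`) follow `persist/__init__.py` and `persist/selection.py`, but the synthetic data and the particular target/gene counts are placeholders for illustration only:

```python
import numpy as np
import torch
import torch.nn as nn
from persist import PERSIST, ExpressionDataset

# Placeholder data: a (cells x genes) count matrix and integer cell type labels.
X = np.random.poisson(1.0, size=(1000, 2000)).astype(np.float32)
y = np.random.randint(0, 10, size=1000)

# Wrap train/validation splits as ExpressionDataset objects.
train_dataset = ExpressionDataset(X[:800], y[:800])
val_dataset = ExpressionDataset(X[800:], y[800:])

# Supervised selection predicts labels with CrossEntropyLoss; the unsupervised
# case instead reconstructs binarized expression using persist.HurdleLoss.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
selector = PERSIST(train_dataset, val_dataset,
                   loss_fn=nn.CrossEntropyLoss(), device=device)

# Optionally narrow the candidate pool, then select a precise number of genes.
# Both calls return the selected gene indices and the trained model.
candidate_inds, _ = selector.eliminate(target=500)
selected_inds, model = selector.select(num_genes=32)
```

39 | --------------------------------------------------------------------------------
/data/VISp_markers.toml:
--------------------------------------------------------------------------------
1 | # Marker gene list was obtained from Tasic et al. 2018 as displayed in Figures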
4c, 5e, and 5f: 2 | markers = [ "Slc30a3", "Cux2", "Rorb", "Deptor", "Scnn1a", "Rspo1", "Hsd11b1", "Batf3", "Oprk1", "Osr1", "Car3", "Fam84b", "Chrna6", "Pvalb", "Pappa2", "Foxp2", "Slc17a8", "Trhr", "Tshz2", "Rapgef3", "Trh", "Gpr139", "Nxph4", "Rprm", "Crym", "Sst", "Chodl", "Nos1", "Mme", "Tac1", "Tacr3", "Calb2", "Nr2f2", "Myh8", "Tac2", "Hpse", "Crhr2", "Crh", "Esm1", "Rxfp1", "Nts", "Gabrg1", "Th", "Calb1", "Akr1c18", "Sema3e", "Gpr149", "Reln", "Tpbg", "Cpne5", "Vipr2", "Nkx2-1", "Lamp5", "Ndnf", "Krt73", "Fam19a1", "Pax6", "Ntn1", "Plch2", "Lsp1", "Lhx6", "Vip", "Sncg", "Nptx2", "Gpr50", "Itih5", "Serpinf1", "Igfbp6", "Gpc3", "Lmo1", "Ptprt", "Rspo4", "Chat", "Crispld2", "Col15a1", "Pde1a",] 3 | -------------------------------------------------------------------------------- /notebooks/00_data_proc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "758da1a7-f1fa-4001-8f35-9a3ad32560af", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import pandas as pd\n", 12 | "import anndata as ad\n", 13 | "import scanpy as sc\n", 14 | "from scipy.sparse import csr_matrix" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "id": "406f7632", 20 | "metadata": {}, 21 | "source": [ 22 | "**Count matrix and metadata for VISp dataset**\n", 23 | " - Download the count data from Allen Institute portal\n", 24 | " - Convert to AnnData format - see [this getting started with AnnData tutorial](https://anndata-tutorials.readthedocs.io/en/latest/getting-started.html)\n", 25 | " - Save the resulting object for later use" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "id": "a3a398cd", 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | " % Total % Received % Xferd Average Speed Time Time Time Current\n", 39 | " Dload Upload Total Spent Left Speed\n", 40 | "100 291M 100 291M 0 0 51.1M 0 0:00:05 0:00:05 --:--:-- 53.1M\n", 41 | "Archive: ../data/VISp.zip\n", 42 | " inflating: ../data/VISp/mouse_VISp_2018-06-14_exon-matrix.csv \n", 43 | " inflating: ../data/VISp/mouse_VISp_2018-06-14_genes-rows.csv \n", 44 | " inflating: ../data/VISp/mouse_VISp_2018-06-14_intron-matrix.csv \n", 45 | " inflating: ../data/VISp/mouse_VISp_2018-06-14_readme.txt \n", 46 | " inflating: ../data/VISp/mouse_VISp_2018-06-14_samples-columns.csv \n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "# Download count matrices from https://portal.brain-map.org/atlases-and-data/rnaseq/mouse-v1-and-alm-smart-seq or use the shell commands below.\n", 52 | "!curl -o ../data/VISp.zip https://celltypes.brain-map.org/api/v2/well_known_file_download/694413985\n", 53 | "!unzip -d ../data/VISp ../data/VISp.zip\n", 54 | "!rm ../data/VISp.zip" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "id": "c577e965", 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/html": [ 66 | "
\n", 67 | "\n", 80 | "\n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | "
seq_nameclasssubclasscluster
sample_id
F1S4_160108_001_A01LS-15006_S09_E1-50GABAergicVipVip Arhgap36 Hmcn1
F1S4_160108_001_B01LS-15006_S10_E1-50GABAergicLamp5Lamp5 Lsp1
F1S4_160108_001_C01LS-15006_S11_E1-50GABAergicLamp5Lamp5 Lsp1
F1S4_160108_001_D01LS-15006_S12_E1-50GABAergicVipVip Crispld2 Htr2c
F1S4_160108_001_E01LS-15006_S13_E1-50GABAergicLamp5Lamp5 Plch2 Dock5
\n", 135 | "
" 136 | ], 137 | "text/plain": [ 138 | " seq_name class subclass \\\n", 139 | "sample_id \n", 140 | "F1S4_160108_001_A01 LS-15006_S09_E1-50 GABAergic Vip \n", 141 | "F1S4_160108_001_B01 LS-15006_S10_E1-50 GABAergic Lamp5 \n", 142 | "F1S4_160108_001_C01 LS-15006_S11_E1-50 GABAergic Lamp5 \n", 143 | "F1S4_160108_001_D01 LS-15006_S12_E1-50 GABAergic Vip \n", 144 | "F1S4_160108_001_E01 LS-15006_S13_E1-50 GABAergic Lamp5 \n", 145 | "\n", 146 | " cluster \n", 147 | "sample_id \n", 148 | "F1S4_160108_001_A01 Vip Arhgap36 Hmcn1 \n", 149 | "F1S4_160108_001_B01 Lamp5 Lsp1 \n", 150 | "F1S4_160108_001_C01 Lamp5 Lsp1 \n", 151 | "F1S4_160108_001_D01 Vip Crispld2 Htr2c \n", 152 | "F1S4_160108_001_E01 Lamp5 Plch2 Dock5 " 153 | ] 154 | }, 155 | "execution_count": 3, 156 | "metadata": {}, 157 | "output_type": "execute_result" 158 | } 159 | ], 160 | "source": [ 161 | "# Load VISp dataset\n", 162 | "filename = '../data/VISp/mouse_VISp_2018-06-14_exon-matrix.csv'\n", 163 | "expr_df = pd.read_csv(filename, header=0, index_col=0, delimiter=',').transpose()\n", 164 | "expr = expr_df.values\n", 165 | "\n", 166 | "# Find gene names\n", 167 | "filename = '../data/VISp/mouse_VISp_2018-06-14_genes-rows.csv'\n", 168 | "genes_df = pd.read_csv(filename, header=0, index_col=0, delimiter=',')\n", 169 | "gene_symbol = genes_df.index.values\n", 170 | "gene_ids = genes_df['gene_entrez_id'].values\n", 171 | "gene_names = np.array([gene_symbol[np.where(gene_ids == name)[0][0]] for name in expr_df.columns])\n", 172 | "\n", 173 | "# Get metadata and save restrict to relevant fields\n", 174 | "filename = '../data/VISp/mouse_VISp_2018-06-14_samples-columns.csv'\n", 175 | "obs = pd.read_csv(filename, header=0, index_col=0, delimiter=',', encoding='iso-8859-1')\n", 176 | "\n", 177 | "obs = obs.reset_index()\n", 178 | "obs = obs[['sample_name','seq_name','class','subclass','cluster']]\n", 179 | "obs = obs.rename(columns={'sample_name':'sample_id'})\n", 180 | "obs = obs.set_index('sample_id')\n", 181 | "obs.head()" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 4, 187 | "id": "2443ffbd", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "# compose and store anndata object for efficient read/write\n", 192 | "adata = ad.AnnData(X=csr_matrix(expr))\n", 193 | "adata.var_names = gene_names\n", 194 | "adata.var.index.set_names('genes', inplace=True)\n", 195 | "adata.obs = obs\n", 196 | "adata.write('../data/VISp.h5ad')" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "id": "c5a11dd0", 202 | "metadata": {}, 203 | "source": [ 204 | "**Filtering samples**\n", 205 | "\n", 206 | "The next code block is optional, and requires `VISp_PERSIST_metadata.csv` which contains:\n", 207 | "- cell type labels at different resolutions of the taxonomy (see manuscript for details)\n", 208 | "- sample ids to filter out non-neuronal cells\n", 209 | "\n", 210 | "In the following, we will\n", 211 | "1. restrict cells only to those samples specified in `VISp_PERSIST_metadata.csv`\n", 212 | "2. append metadata from `VISp_PERSIST_metadata.csv` to the AnnData object\n", 213 | "3. normalize counts, determine highly variable genes using scanpy functions\n", 214 | "3. 
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 5,
220 | "id": "10b09db4",
221 | "metadata": {},
222 | "outputs": [
223 |  {
224 |   "name": "stdout",
225 |   "output_type": "stream",
226 |   "text": [
227 |    "             seq_name       cell_types_98 cell_types_50 cell_types_25\n",
228 |    "0  LS-15006_S09_E1-50  Vip Arhgap36 Hmcn1           n70           n66\n",
229 |    "1  LS-15006_S10_E1-50          Lamp5 Lsp1    Lamp5 Lsp1           n78\n",
230 |    "\n",
231 |    "old shape: (15413, 45768)\n",
232 |    "new shape: (13349, 45768)\n"
233 |   ]
234 |  }
235 | ],
236 | "source": [
237 |  "adata = ad.read_h5ad('../data/VISp.h5ad')\n",
238 |  "persist_df = pd.read_csv('../data/VISp_PERSIST_metadata.csv')\n",
239 |  "print(persist_df.head(2))\n",
240 |  "print(f'\\nold shape: {adata.shape}')\n",
241 |  "\n",
242 |  "adata = adata[adata.obs['seq_name'].isin(persist_df['seq_name']), :]\n",
243 |  "obs = adata.obs.copy().reset_index()\n",
244 |  "obs = obs.merge(right=persist_df, how='left', left_on='seq_name', right_on='seq_name')\n",
245 |  "obs = obs.set_index('sample_id')\n",
246 |  "\n",
247 |  "adata = ad.AnnData(X=adata.X, obs=obs, var=adata.var)\n",
248 |  "print(f'new shape: {adata.shape}')"
249 | ]
250 | },
251 | {
252 | "cell_type": "markdown",
253 | "id": "3b1398f1",
254 | "metadata": {},
255 | "source": [
256 |  "**Normalization and preliminary gene selection**"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 6,
262 | "id": "11401b0c",
263 | "metadata": {},
264 | "outputs": [],
265 | "source": [
266 |  "# compute CPM from adata.X (inplace=False returns the result without modifying adata.X)\n",
267 |  "adata.layers['log1pcpm'] = sc.pp.normalize_total(adata, target_sum=1e6, inplace=False)['X']\n",
268 |  "\n",
269 |  "# log-transforms the data in layers['log1pcpm'] in place\n",
270 |  "sc.pp.log1p(adata, layer='log1pcpm')\n"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": 7,
276 | "id": "2b8cd61f",
277 | "metadata": {},
278 | "outputs": [],
279 | "source": [
280 |  "import toml\n",
281 |  "dat = toml.load('../data/VISp_markers.toml')"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": 8,
287 | "id": "c42e06e6",
288 | "metadata": {},
289 | "outputs": [],
290 | "source": [
291 |  "# introduces \"highly_variable\" column to adata.var\n",
292 |  "sc.pp.highly_variable_genes(adata, \n",
293 |  "                            layer='log1pcpm', \n",
294 |  "                            flavor='cell_ranger',\n",
295 |  "                            n_top_genes=10000, \n",
296 |  "                            inplace=True)"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": 9,
302 | "id": "a5cdf670",
303 | "metadata": {},
304 | "outputs": [
305 |  {
306 |   "data": {
307 |    "text/plain": [
308 |     "AnnData object with n_obs × n_vars = 13349 × 45768\n",
309 |     "    obs: 'seq_name', 'class', 'subclass', 'cluster', 'cell_types_98', 'cell_types_50', 'cell_types_25'\n",
310 |     "    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'markers'\n",
311 |     "    uns: 'log1p', 'hvg'\n",
312 |     "    layers: 'log1pcpm'"
313 |    ]
314 |   },
315 |   "execution_count": 9,
316 |   "metadata": {},
317 |   "output_type": "execute_result"
318 |  }
319 | ],
320 | "source": [
321 |  "# Create new field with marker genes\n",
322 |  "adata.var['markers'] = np.isin(adata.var.index.values,dat['markers'])\n",
323 |  "adata"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 10,
329 | "id": "bb0bb3c2",
330 | "metadata": {},
331 | "outputs": [
332 |  {
333 |   "data": {
334 |    "text/plain": [
335 |     "<13349x45768 sparse matrix\n",
336 |     "\twith ... stored elements in Compressed Sparse Row format>"
337 |    ]
338 |   },
339 |   "execution_count": 10,
340 |   "metadata": {},
341 |   "output_type": "execute_result"
"execute_result" 342 | } 343 | ], 344 | "source": [ 345 | "# This is a sparse matrix\n", 346 | "adata.layers['log1pcpm']" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 11, 352 | "id": "dc70f25f", 353 | "metadata": {}, 354 | "outputs": [ 355 | { 356 | "data": { 357 | "text/plain": [ 358 | "array([[0. , 0. , 3.84258223, 4.40539198, 0. ],\n", 359 | " [0. , 0. , 4.16454502, 4.52873471, 0.42111822],\n", 360 | " [0. , 0. , 3.82509693, 3.56268307, 0. ],\n", 361 | " [0. , 0. , 3.93543299, 0. , 0. ],\n", 362 | " [0. , 0. , 5.40677164, 4.62215864, 0. ]])" 363 | ] 364 | }, 365 | "execution_count": 11, 366 | "metadata": {}, 367 | "output_type": "execute_result" 368 | } 369 | ], 370 | "source": [ 371 | "# For sparse matrix `M`, `M.toarray()` to convert it to dense array\n", 372 | "adata.layers['log1pcpm'][:5,:5].toarray()" 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "id": "f8b21c66", 378 | "metadata": {}, 379 | "source": [ 380 | "**Parting notes:**\n", 381 | "\n", 382 | "The anndata object created in this way has a few different fields that we will end up using with PERSIST\n", 383 | "1. The raw counts are in `adata.X`\n", 384 | "2. The normalized counts (log1p of CPM values) are in `adata.layers['log1pcpm']`\n", 385 | "3. All metadata (cell type labels etc. for supervised mode in PERSIST) is in `adata.obs`\n", 386 | "4. A coarse selection of genes is in `adata.var['highly_variable']`\n", 387 | "5. Marker genes defined by Tasic et al. are indicated by `adata.var['markers']`" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 12, 393 | "id": "53129c25", 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "# adata_hvg is a view. We'll convert it to a new AnnData object and write it out. \n", 398 | "adata_hvg = ad.AnnData(X=adata.X,\n", 399 | " obs=adata.obs, \n", 400 | " var=adata.var[['highly_variable']],\n", 401 | " layers=adata.layers, uns=adata.uns)\n", 402 | "adata_hvg.write('../data/VISp_filtered_cells.h5ad')" 403 | ] 404 | } 405 | ], 406 | "metadata": { 407 | "kernelspec": { 408 | "display_name": "Python 3.8.15 ('persist')", 409 | "language": "python", 410 | "name": "python3" 411 | }, 412 | "language_info": { 413 | "codemirror_mode": { 414 | "name": "ipython", 415 | "version": 3 416 | }, 417 | "file_extension": ".py", 418 | "mimetype": "text/x-python", 419 | "name": "python", 420 | "nbconvert_exporter": "python", 421 | "pygments_lexer": "ipython3", 422 | "version": "3.11.0" 423 | }, 424 | "vscode": { 425 | "interpreter": { 426 | "hash": "660691d0bb1e24e4a68343475da76b12317d18d7509dab1fd0158534dd4eebe4" 427 | } 428 | } 429 | }, 430 | "nbformat": 4, 431 | "nbformat_minor": 5 432 | } 433 | -------------------------------------------------------------------------------- /persist/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data 2 | from . import utils 3 | from . import layers 4 | from . import models 5 | from . 
6 | from persist.selection import PERSIST
7 | from persist.data import GeneSet, load_set
8 | from persist.data import ExpressionDataset, HDF5ExpressionDataset
9 | from persist.utils import HurdleLoss, MSELoss, Accuracy
10 | --------------------------------------------------------------------------------
/persist/data.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import pickle
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | from sklearn.model_selection import train_test_split
6 | from scipy.sparse import issparse
7 | 
8 | 
9 | class ExpressionDataset(Dataset):
10 |     '''
11 |     Gene expression dataset capable of using subsets of inputs.
12 | 
13 |     Args:
14 |       data: array of inputs with size (samples, dim).
15 |       output: array of labels with size (samples,) or (samples, dim).
16 |     '''
17 | 
18 |     def __init__(self, data, output):
19 |         self.input_size = data.shape[1]
20 |         self._data = data.astype(np.float32)
21 |         if len(output.shape) == 1:
22 |             # Output is classification labels.
23 |             self.output_size = len(np.unique(output))
24 |             self._output = output.astype(np.int_)
25 |         else:
26 |             # Output is regression values.
27 |             self.output_size = output.shape[1]
28 |             self._output = output.astype(np.float32)
29 |         self.set_inds(None)
30 |         self.set_output_inds(None)
31 | 
32 |         # Patch to work with sparse data without converting to dense beforehand.
33 |         self.data_issparse = issparse(data)
34 |         self.output_issparse = issparse(output)
35 | 
36 |     def set_inds(self, inds, delete_remaining=False):
37 |         '''
38 |         Set input indices to be returned.
39 | 
40 |         Args:
41 |           inds: list/array of selected indices.
42 |           delete_remaining: whether to permanently delete the indices that are
43 |             not selected.
44 |         '''
45 |         if inds is None:
46 |             assert not delete_remaining
47 |             inds = np.arange(self._data.shape[1])
48 | 
49 |             # Set input and inds.
50 |             self.inds = inds
51 |             self.data = self._data
52 |             self.input_size = len(inds)
53 |         else:
54 |             # Verify inds.
55 |             inds = np.sort(inds)
56 |             assert len(inds) > 0
57 |             assert (inds[0] >= 0) and (inds[-1] < self._data.shape[1])
58 |             assert np.all(np.unique(inds, return_counts=True)[1] == 1)
59 |             if len(inds) == self.input_size:
60 |                 assert not delete_remaining
61 | 
62 |             self.input_size = len(inds)
63 |             if delete_remaining:
64 |                 # Reset data and input size.
65 |                 self._data = self._data[:, inds]
66 |                 self.data = self._data
67 |                 self.inds = np.arange(len(inds))
68 |             else:
69 |                 # Set input and inds.
70 |                 self.inds = inds
71 |                 self.data = self._data[:, inds]
72 | 
73 |     def set_output_inds(self, inds, delete_remaining=False):
74 |         '''
75 |         Set output inds to be returned. Only for use with multivariate outputs.
76 | 
77 |         Args:
78 |           inds: list/array of selected indices.
79 |           delete_remaining: whether to permanently delete the indices that are
80 |             not selected.
81 |         '''
82 |         if inds is None:
83 |             assert not delete_remaining
84 | 
85 |             # Set output and inds.
86 |             self.output = self._output
87 |             if len(self._output.shape) == 1:
88 |                 # Classification labels.
89 |                 self.output_size = len(np.unique(self._output))
90 |                 self.output_inds = None
91 |             else:
92 |                 # Regression labels.
93 |                 self.output_size = self._output.shape[1]
94 |                 self.output_inds = np.arange(self.output_size)
95 |         else:
96 |             # Verify that there are multiple output inds.
97 |             assert len(self._output.shape) == 2
98 | 
99 |             # Verify inds (non-empty, in range, and duplicate-free).
100 |             inds = np.sort(inds)
101 |             assert len(inds) > 0
102 |             assert (inds[0] >= 0) and (inds[-1] < self._output.shape[1])
103 |             assert np.all(np.unique(inds, return_counts=True)[1] == 1)
104 |             if len(inds) == self.output_size:
105 |                 assert not delete_remaining
106 | 
107 |             self.output_size = len(inds)
108 |             if delete_remaining:
109 |                 # Reset output and output size.
110 |                 self._output = self._output[:, inds]
111 |                 self.output = self._output
112 |                 self.output_inds = np.arange(len(inds))
113 |             else:
114 |                 # Set output and inds.
115 |                 self.output_inds = inds
116 |                 self.output = self._output[:, inds]
117 | 
118 |     @property
119 |     def max_input_size(self):
120 |         return self._data.shape[1]
121 | 
122 |     def __len__(self):
123 |         return self._data.shape[0]
124 | 
125 |     def __getitem__(self, index):
126 |         if self.data_issparse:
127 |             data = self.data[index].toarray().flatten()
128 |         else:
129 |             data = self.data[index]
130 | 
131 |         if self.output_issparse:
132 |             output = self.output[index].toarray().flatten()
133 |         else:
134 |             output = self.output[index]
135 |         return data, output
136 | 
137 | 
138 | class HDF5ExpressionDataset(Dataset):
139 |     '''
140 |     HDF5-backed dataset wrapper, capable of using subsets of inputs.
141 | 
142 |     Args:
143 |       filename: HDF5 filename.
144 |       data_name: key for data array.
145 |       label_name: key for labels.
146 |       sample_inds: list of indices for rows to be sampled.
147 |       initialize: whether to initialize by opening the HDF5 file. This should
148 |         only be done when using a data loader with no worker threads.
149 |     '''
150 | 
151 |     def __init__(self, filename, data_name, label_name,
152 |                  sample_inds=None, initialize=False):
153 |         # Set up data variables.
154 |         self.filename = filename
155 |         self.data_name = data_name
156 |         self.label_name = label_name
157 | 
158 |         # Set sample inds.
159 |         hf = h5py.File(filename, 'r')
160 |         data = hf[self.data_name]
161 |         labels = hf[self.label_name]
162 |         if sample_inds is None:
163 |             sample_inds = np.arange(len(data))
164 |         self.sample_inds = sample_inds
165 | 
166 |         # Set input, output size.
167 |         self.input_size = data.shape[1]
168 |         if labels.ndim == 1:
169 |             # Classification labels.
170 |             self.output_size = len(np.unique(labels))
171 |             self.multiple_outputs = False
172 |         else:
173 |             # Regression labels.
174 |             self.output_size = labels.shape[1]
175 |             self.multiple_outputs = True
176 |         hf.close()
177 | 
178 |         # Set input inds.
179 |         self.all_inds = np.arange(self.input_size)
180 |         self.set_inds(None)
181 | 
182 |         # Set output inds.
183 |         if self.multiple_outputs:
184 |             self.all_output_inds = np.arange(self.output_size)
185 |         else:
186 |             self.all_output_inds = None
187 |         self.set_output_inds(None)
188 | 
189 |         # Initialize.
190 |         if initialize:
191 |             self.init_worker(0)
192 | 
193 |     def set_inds(self, inds, delete_remaining=False):
194 |         '''
195 |         Set input indices to be returned.
196 | 
197 |         Args:
198 |           inds: list/array of selected indices.
199 |           delete_remaining: whether to permanently delete the indices that are
200 |             not selected.
201 |         '''
202 |         if inds is None:
203 |             assert not delete_remaining
204 |             self.inds = np.arange(len(self.all_inds))
205 |             self.relative_inds = self.all_inds
206 |             self.input_size = len(self.all_inds)
207 |         else:
208 |             # Verify inds.
209 | inds = np.sort(inds) 210 | assert len(inds) > 0 211 | assert (inds[0] >= 0) and (inds[-1] < len(self.all_inds)) 212 | assert np.all(np.unique(inds, return_counts=True)[1] == 1) 213 | 214 | self.input_size = len(inds) 215 | if delete_remaining: 216 | self.inds = np.arange(len(inds)) 217 | self.all_inds = self.all_inds[inds] 218 | self.relative_inds = self.all_inds 219 | else: 220 | self.inds = inds 221 | self.relative_inds = self.all_inds[inds] 222 | 223 | def set_output_inds(self, inds, delete_remaining=False): 224 | ''' 225 | Set output inds to be returned. Only for use with multivariate outputs. 226 | 227 | Args: 228 | inds: list/array of selected indices. 229 | delete_remaining: whether to permanently delete the indices that are 230 | not selected. 231 | ''' 232 | if inds is None: 233 | assert not delete_remaining 234 | if self.multiple_outputs: 235 | self.output_inds = np.arange(len(self.all_output_inds)) 236 | self.output_relative_inds = self.all_output_inds 237 | self.output_size = len(self.all_output_inds) 238 | else: 239 | self.output_inds = None 240 | self.output_relative_inds = None 241 | else: 242 | # Verify that there are multiple output inds. 243 | assert self.multiple_outputs 244 | 245 | # Verify inds. 246 | inds = np.sort(inds) 247 | assert len(inds) > 0 248 | assert (inds[0] >= 0) and (inds[-1] < len(self.all_output_inds)) 249 | assert np.all(np.unique(inds, return_counts=True)[1] == 1) 250 | if len(inds) == self.output_size: 251 | assert not delete_remaining 252 | 253 | self.output_size = len(inds) 254 | if delete_remaining: 255 | self.output_inds = np.arange(len(inds)) 256 | self.all_output_inds = self.all_output_inds[inds] 257 | self.output_relative_inds = self.all_output_inds 258 | else: 259 | self.output_inds = inds 260 | self.output_relative_inds = self.all_output_inds[inds] 261 | 262 | def init_worker(self, worker_id): 263 | '''Initialize worker in data loader thread.''' 264 | self.h5 = h5py.File(self.filename, 'r', swmr=True) 265 | 266 | @property 267 | def max_input_size(self): 268 | return len(self.all_inds) 269 | 270 | def __len__(self): 271 | return len(self.sample_inds) 272 | 273 | def __getitem__(self, index): 274 | # Possibly initialize worker to open HDF5 file. 
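# (h5py file handles generally cannot be shared across DataLoader worker
# processes, so each worker lazily opens its own handle here or via
# init_worker; the SWMR flag then allows several concurrent readers.)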
275 |         if not hasattr(self, 'h5'):
276 |             self.init_worker(0)
277 | 
278 |         index = self.sample_inds[index]
279 |         data = self.h5[self.data_name][index][self.relative_inds]
280 |         labels = self.h5[self.label_name][index]
281 |         if self.output_relative_inds is not None:
282 |             labels = labels[self.output_relative_inds]
283 |         return data, labels
284 | 
285 | 
286 | def split_data(data, seed=123, val_portion=0.1, test_portion=0.1):
287 |     '''Split data into train, val, test.'''
288 |     N = data.shape[0]
289 |     N_val = int(val_portion * N)
290 |     N_test = int(test_portion * N)
291 |     train, test = train_test_split(data, test_size=N_test, random_state=seed)
292 |     train, val = train_test_split(train, test_size=N_val, random_state=seed+1)
293 |     return train, val, test
294 | 
295 | 
296 | def bootstrapped_dataset(dataset, seed=None):
297 |     '''Sample a bootstrapped dataset.'''
298 |     if isinstance(dataset, ExpressionDataset):
299 |         data = dataset.data
300 |         labels = dataset.output
301 |         if seed is not None:
302 |             np.random.seed(seed)
303 |         N = len(data)
304 |         inds = np.random.choice(N, size=N, replace=True)
305 |         return ExpressionDataset(data[inds], labels[inds])
306 |     elif isinstance(dataset, HDF5ExpressionDataset):
307 |         # Seed here as well so resampling is reproducible for HDF5 datasets.
308 |         if seed is not None:
309 |             np.random.seed(seed)
310 |         inds = dataset.sample_inds
311 |         inds = np.random.choice(inds, size=len(inds), replace=True)
312 |         return HDF5ExpressionDataset(dataset.filename, dataset.data_name,
313 |                                      dataset.label_name, sample_inds=inds)
314 |     else:
315 |         raise ValueError('dataset must be ExpressionDataset or '
316 |                          'HDF5ExpressionDataset')
317 | 
318 | 
319 | class GeneSet:
320 |     '''
321 |     Set of genes, represented by their indices and names.
322 | 
323 |     Args:
324 |       inds: gene indices.
325 |       names: gene names.
326 |     '''
327 | 
328 |     def __init__(self, inds, names):
329 |         self.inds = inds
330 |         self.names = names
331 | 
332 |     def save(self, filename):
333 |         '''Save object by pickling.'''
334 |         with open(filename, 'wb') as f:
335 |             pickle.dump(self, f)
336 | 
337 | 
338 | def load_set(filename):
339 |     '''Load GeneSet object.'''
340 |     with open(filename, 'rb') as f:
341 |         subset = pickle.load(f)
342 |     if isinstance(subset, GeneSet):
343 |         return subset
344 |     else:
345 |         raise ValueError('object is not GeneSet')
346 | --------------------------------------------------------------------------------
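As a quick, hedged sketch of how the dataset utilities above fit together (the toy shapes and pickle filename are illustrative; `split_data` is imported from `persist.data` directly since it is not re-exported in `__init__.py`):

```python
import numpy as np
from persist.data import ExpressionDataset, GeneSet, load_set, split_data

# Toy data: 100 cells x 20 genes, 3 classes (placeholder values).
X = np.random.rand(100, 20).astype(np.float32)
y = np.random.randint(0, 3, size=100)

# Row-wise train/val/test split (10% val, 10% test by default).
train_X, val_X, test_X = split_data(X, seed=0)

# Datasets can be restricted to a subset of input genes.
dataset = ExpressionDataset(X, y)
dataset.set_inds([1, 5, 7])
x0, y0 = dataset[0]
assert len(x0) == 3  # only the three selected genes are returned

# Selected genes can be saved and reloaded as a GeneSet.
genes = GeneSet(inds=np.array([1, 5, 7]), names=np.array(['g1', 'g5', 'g7']))
genes.save('selected_genes.pkl')
genes = load_set('selected_genes.pkl')
```

/persist/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | def clamp_probs(probs):
7 |     '''Clamp probabilities to ensure stable logs.'''
8 |     eps = torch.finfo(probs.dtype).eps
9 |     return torch.clamp(probs, min=eps, max=1-eps)
10 | 
11 | 
12 | def concrete_sample(logits, temperature, shape=torch.Size([])):
13 |     '''
14 |     Sampling for Concrete distribution (see eq. 10 of Maddison et al., 2017).
15 | 
16 |     Args:
17 |       logits: Concrete logits parameters.
18 |       temperature: Concrete temperature parameter.
19 |       shape: sample shape.
20 |     '''
21 |     uniform_shape = torch.Size(shape) + logits.shape
22 |     u = clamp_probs(torch.rand(uniform_shape, dtype=torch.float32,
23 |                                device=logits.device))
24 |     gumbels = - torch.log(- torch.log(u))
25 |     scores = (logits + gumbels) / temperature
26 |     return scores.softmax(dim=-1)
27 | 
28 | 
29 | def bernoulli_concrete_sample(logits, temperature, shape=torch.Size([])):
30 |     '''
31 |     Sampling for BinConcrete distribution (see PyTorch source code, differs
32 |     slightly from eq. 16 of Maddison et al., 2017).
33 | 
34 |     Args:
35 |       logits: tensor of BinConcrete logits parameters.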
36 | temperature: BinConcrete temperature parameter. 37 | shape: sample shape. 38 | ''' 39 | uniform_shape = torch.Size(shape) + logits.shape 40 | u = clamp_probs(torch.rand(uniform_shape, dtype=torch.float32, 41 | device=logits.device)) 42 | return torch.sigmoid((F.logsigmoid(logits) - F.logsigmoid(-logits) 43 | + torch.log(u) - torch.log(1 - u)) / temperature) 44 | 45 | 46 | class BinaryMask(nn.Module): 47 | ''' 48 | Input layer that selects features by learning a k-hot mask. 49 | 50 | Args: 51 | input_size: number of inputs. 52 | num_selections: number of features to select. 53 | temperature: temperature for Concrete samples. 54 | gamma: used to map learned parameters to logits (helps convergence speed). 55 | ''' 56 | 57 | def __init__(self, 58 | input_size, 59 | num_selections, 60 | temperature=10.0, 61 | gamma=1/3): 62 | super().__init__() 63 | self._logits = nn.Parameter( 64 | torch.zeros(num_selections, input_size, dtype=torch.float32, 65 | requires_grad=True)) 66 | self.input_size = input_size 67 | self.num_selections = num_selections 68 | self.output_size = input_size 69 | self.temperature = temperature 70 | self.gamma = gamma 71 | 72 | @property 73 | def logits(self): 74 | return self._logits / self.gamma 75 | 76 | @property 77 | def probs(self): 78 | return (self.logits).softmax(dim=1) 79 | 80 | def sample(self, n_samples): 81 | '''Sample approximate k-hot vectors.''' 82 | samples = concrete_sample( 83 | self.logits, self.temperature, torch.Size([n_samples])) 84 | return torch.max(samples, dim=-2).values 85 | 86 | def forward(self, x): 87 | '''Sample and apply mask.''' 88 | m = self.sample(len(x)) 89 | x = x * m 90 | return x, m 91 | 92 | def get_inds(self): 93 | '''Get selected indices.''' 94 | inds = torch.argmax(self.logits, dim=1) 95 | return torch.sort(inds)[0].cpu().data.numpy() 96 | 97 | def extra_repr(self): 98 | return (f'input_size={self.input_size}, temperature={self.temperature},' 99 | f' num_selections={self.num_selections}') 100 | 101 | 102 | class BinaryGates(nn.Module): 103 | ''' 104 | Input layer that selects features by learning binary gates for each feature, 105 | similar to [1]. 106 | 107 | [1] Dropout Feature Ranking for Deep Learning Models (Chang et al., 2017) 108 | 109 | Args: 110 | input_size: number of inputs. 111 | temperature: temperature for BinConcrete samples. 112 | init: initial value for each gate's probability of being 1. 113 | gamma: used to map learned parameters to logits (helps convergence speed). 114 | ''' 115 | 116 | def __init__(self, 117 | input_size, 118 | temperature=0.1, 119 | init=0.99, 120 | gamma=1/2): 121 | super().__init__() 122 | init_logit = - torch.log(1 / torch.tensor(init) - 1) * gamma 123 | self._logits = nn.Parameter(torch.full( 124 | (input_size,), init_logit, dtype=torch.float32, requires_grad=True)) 125 | self.input_size = input_size 126 | self.output_size = input_size 127 | self.temperature = temperature 128 | self.gamma = gamma 129 | 130 | @property 131 | def logits(self): 132 | return self._logits / self.gamma 133 | 134 | @property 135 | def probs(self): 136 | return torch.sigmoid(self.logits) 137 | 138 | def sample(self, n_samples): 139 | '''Sample approximate binary masks.''' 140 | return bernoulli_concrete_sample( 141 | self.logits, self.temperature, torch.Size([n_samples])) 142 | 143 | def forward(self, x): 144 | '''Sample and apply mask.''' 145 | m = self.sample(len(x)) 146 | x = x * m 147 | return x, m 148 | 149 | def get_inds(self, num_features=None, threshold=None): 150 | ''' 151 | Get selected indices. 
152 | 
153 |     Args:
154 |       num_features: number of top features to return.
155 |       threshold: probability threshold for determining selected features.
156 |     '''
157 |         if (num_features is None) == (threshold is None):
158 |             raise ValueError('exactly one of num_features and threshold must be'
159 |                              ' specified')
160 | 
161 |         if num_features is not None:
162 |             inds = torch.argsort(self.probs)[-num_features:]
163 |         else:
164 |             inds = (self.probs > threshold).nonzero()[:, 0]
165 |         return torch.sort(inds)[0].cpu().data.numpy()
166 | 
167 |     def extra_repr(self):
168 |         return f'input_size={self.input_size}, temperature={self.temperature}'
169 | 
170 | 
171 | class ConcreteSelector(nn.Module):
172 |     '''
173 |     Input layer that selects features by learning a binary matrix, based on [2].
174 | 
175 |     [2] Concrete Autoencoders for Differentiable Feature Selection and
176 |     Reconstruction (Balin et al., 2019)
177 | 
178 |     Args:
179 |       input_size: number of inputs.
180 |       num_selections: number of features to select.
181 |       temperature: temperature for Concrete samples.
182 |       gamma: used to map learned parameters to logits (helps convergence speed).
183 |     '''
184 | 
185 |     def __init__(self,
186 |                  input_size,
187 |                  num_selections,
188 |                  temperature=10.0,
189 |                  gamma=1/3):
190 |         super().__init__()
191 |         self._logits = nn.Parameter(
192 |             torch.zeros(num_selections, input_size, dtype=torch.float32,
193 |                         requires_grad=True))
194 |         self.input_size = input_size
195 |         self.num_selections = num_selections
196 |         self.output_size = num_selections
197 |         self.temperature = temperature
198 |         self.gamma = gamma
199 | 
200 |     @property
201 |     def logits(self):
202 |         return self._logits / self.gamma
203 | 
204 |     @property
205 |     def probs(self):
206 |         return self.logits.softmax(dim=1)
207 | 
208 |     def sample(self, n_samples):
209 |         '''Sample approximate binary matrices.'''
210 |         return concrete_sample(
211 |             self.logits, self.temperature, torch.Size([n_samples]))
212 | 
213 |     def forward(self, x):
214 |         '''Sample and apply selector matrix.'''
215 |         M = self.sample(len(x))
216 |         x = torch.matmul(M, x.unsqueeze(2)).squeeze(2)
217 |         return x, M
218 | 
219 |     def get_inds(self):
220 |         '''Get selected indices.'''
221 |         inds = torch.argmax(self.logits, dim=1)
222 |         return torch.sort(inds)[0].cpu().data.numpy()
223 | 
224 |     def extra_repr(self):
225 |         return (f'input_size={self.input_size}, temperature={self.temperature},'
226 |                 f' num_selections={self.num_selections}')
227 | --------------------------------------------------------------------------------
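To make the selection layers above concrete, here is a small hedged demo (the shapes and values are arbitrary, and before training the reported indices are meaningless):

```python
import torch
from persist.layers import BinaryMask

# 10 input features, of which 3 should be selected.
layer = BinaryMask(input_size=10, num_selections=3)
x = torch.randn(4, 10)

# Forward pass samples a relaxed k-hot mask per example and applies it.
masked, m = layer(x)
print(m.shape)           # torch.Size([4, 10])
print(layer.get_inds())  # argmax positions of the logits, sorted
```

/persist/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import numpy as np
5 | from persist import layers
6 | from copy import deepcopy
7 | from tqdm.auto import tqdm
8 | from torch.utils.data import DataLoader
9 | 
10 | 
11 | def restore_parameters(model, best_model):
12 |     '''Copy parameter values from best_model to model.'''
13 |     for params, best_params in zip(model.parameters(), best_model.parameters()):
14 |         params.data = best_params
15 | 
16 | 
17 | def input_layer_penalty(input_layer, m):
18 |     if isinstance(input_layer, layers.BinaryGates):
19 |         return torch.mean(torch.sum(m, dim=1))
20 |     else:
21 |         raise ValueError('only BinaryGates layer has penalty')
22 | 
23 | 
24 | def input_layer_fix(input_layer):
25 |     '''Fix collisions in the input layer.'''
26 |     required_fix = False
27 | 
28 |     if isinstance(input_layer, (layers.BinaryMask, layers.ConcreteSelector)):
29 |         # Extract logits.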
30 | logits = input_layer._logits 31 | argmax = torch.argmax(logits, dim=1).cpu().data.numpy() 32 | 33 | # Locate collisions and reinitialize. 34 | for i in range(len(argmax) - 1): 35 | if argmax[i] in argmax[i+1:]: 36 | required_fix = True 37 | logits.data[i] = torch.randn( 38 | logits[i].shape, dtype=logits.dtype, device=logits.device) 39 | return required_fix 40 | 41 | return required_fix 42 | 43 | 44 | def input_layer_summary(input_layer, n_samples=256): 45 | '''Generate summary string for input layer's convergence.''' 46 | with torch.no_grad(): 47 | if isinstance(input_layer, layers.BinaryMask): 48 | m = input_layer.sample(n_samples) 49 | mean = torch.mean(m, dim=0) 50 | sorted_mean = torch.sort(mean, descending=True).values 51 | relevant = sorted_mean[:input_layer.num_selections] 52 | return 'Max = {:.2f}, Mean = {:.2f}, Min = {:.2f}'.format( 53 | relevant[0].item(), torch.mean(relevant).item(), 54 | relevant[-1].item()) 55 | 56 | elif isinstance(input_layer, layers.ConcreteSelector): 57 | M = input_layer.sample(n_samples) 58 | mean = torch.mean(M, dim=0) 59 | relevant = torch.max(mean, dim=1).values 60 | return 'Max = {:.2f}, Mean = {:.2f}, Min = {:.2f}'.format( 61 | torch.max(relevant).item(), torch.mean(relevant).item(), 62 | torch.min(relevant).item()) 63 | 64 | elif isinstance(input_layer, layers.BinaryGates): 65 | m = input_layer.sample(n_samples) 66 | mean = torch.mean(m, dim=0) 67 | dist = torch.min(mean, 1 - mean) 68 | return 'Mean dist = {:.2f}, Max dist = {:.2f}, Num sel = {}'.format( 69 | torch.mean(dist).item(), 70 | torch.max(dist).item(), 71 | int(torch.sum((mean > 0.5).float()).item())) 72 | 73 | 74 | def input_layer_converged(input_layer, tol=1e-2, n_samples=256): 75 | '''Determine whether the input layer has converged.''' 76 | with torch.no_grad(): 77 | if isinstance(input_layer, layers.BinaryMask): 78 | m = input_layer.sample(n_samples) 79 | mean = torch.mean(m, dim=0) 80 | return ( 81 | torch.sort(mean).values[-input_layer.num_selections].item() 82 | > 1 - tol) 83 | 84 | elif isinstance(input_layer, layers.BinaryGates): 85 | m = input_layer.sample(n_samples) 86 | mean = torch.mean(m, dim=0) 87 | return torch.max(torch.min(mean, 1 - mean)).item() < tol 88 | 89 | elif isinstance(input_layer, layers.ConcreteSelector): 90 | M = input_layer.sample(n_samples) 91 | mean = torch.mean(M, dim=0) 92 | return torch.min(torch.max(mean, dim=1).values).item() > 1 - tol 93 | 94 | 95 | def warmstart_model(model, inds): 96 | ''' 97 | Create model for subset of features by removing parameters. 98 | 99 | Args: 100 | model: model to copy. 101 | inds: indices for features to retain. 102 | ''' 103 | sub_model = deepcopy(model.mlp) 104 | device = next(model.parameters()).device 105 | 106 | # Resize input layer. 107 | layer = sub_model.fc[0] 108 | new_layer = nn.Linear(len(inds), layer.out_features).to(device) 109 | new_layer.weight.data = layer.weight[:, inds] 110 | new_layer.bias.data = layer.bias 111 | sub_model.fc[0] = new_layer 112 | 113 | return sub_model 114 | 115 | 116 | class MLP(nn.Module): 117 | ''' 118 | Multilayer perceptron (MLP) model. 119 | 120 | Args: 121 | input_size: number of inputs. 122 | output_size: number of outputs. 123 | hidden: list of hidden layer widths. 124 | activation: nonlinearity between layers. 125 | ''' 126 | 127 | def __init__(self, 128 | input_size, 129 | output_size, 130 | hidden, 131 | activation=nn.ReLU()): 132 | super().__init__() 133 | 134 | # Fully connected layers. 
135 | self.input_size = input_size 136 | self.output_size = output_size 137 | fc_layers = [nn.Linear(d_in, d_out) for d_in, d_out in 138 | zip([input_size] + hidden, hidden + [output_size])] 139 | self.fc = nn.ModuleList(fc_layers) 140 | 141 | # Activation function. 142 | self.activation = activation 143 | 144 | def forward(self, x): 145 | for fc in self.fc[:-1]: 146 | x = fc(x) 147 | x = self.activation(x) 148 | 149 | return self.fc[-1](x) 150 | 151 | def fit(self, 152 | train_dataset, 153 | val_dataset, 154 | mbsize, 155 | max_nepochs, 156 | loss_fn, 157 | lr=1e-3, 158 | min_lr=1e-5, 159 | lr_factor=0.5, 160 | optimizer='Adam', 161 | lookback=10, 162 | bar=False, 163 | verbose=True): 164 | ''' 165 | Train the model. 166 | 167 | Args: 168 | train_dataset: training dataset. 169 | val_dataset: validation dataset. 170 | mbsize: minibatch size. 171 | max_nepochs: maximum number of epochs. 172 | loss_fn: loss function. 173 | lr: learning rate. 174 | min_lr: minimum learning rate. 175 | lr_factor: learning rate decrease factor. 176 | optimizer: optimizer type. 177 | lookback: number of epochs to wait for improvement before stopping. 178 | bar: whether to display tqdm progress bar. 179 | verbose: verbosity. 180 | ''' 181 | # Set up optimizer. 182 | optimizer = optim.Adam(self.parameters(), lr=lr) 183 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 184 | optimizer, factor=lr_factor, patience=lookback // 2, min_lr=min_lr, 185 | verbose=verbose) 186 | 187 | # Set up data loaders. 188 | has_init = hasattr(train_dataset, 'init_worker') 189 | if has_init: 190 | train_init = train_dataset.init_worker 191 | val_init = val_dataset.init_worker 192 | else: 193 | train_init = None 194 | val_init = None 195 | train_loader = DataLoader( 196 | train_dataset, batch_size=mbsize, shuffle=True, drop_last=True, 197 | worker_init_fn=train_init, num_workers=4) 198 | val_loader = DataLoader( 199 | val_dataset, batch_size=mbsize, worker_init_fn=val_init, 200 | num_workers=4) 201 | 202 | # Determine device. 203 | device = next(self.parameters()).device 204 | 205 | # For tracking loss. 206 | self.train_loss = [] 207 | self.val_loss = [] 208 | best_model = None 209 | best_loss = np.inf 210 | best_epoch = None 211 | 212 | # Bar setup. 213 | if bar: 214 | tqdm_bar = tqdm( 215 | total=max_nepochs, desc='Training epochs', leave=True) 216 | 217 | # Begin training. 218 | for epoch in range(max_nepochs): 219 | # For tracking mean train loss. 220 | train_loss = 0 221 | N = 0 222 | 223 | for x, y in train_loader: 224 | # Move to device. 225 | x = x.to(device) 226 | y = y.to(device) 227 | 228 | # Forward pass. 229 | pred = self.forward(x) 230 | 231 | # Calculate loss. 232 | loss = loss_fn(pred, y) 233 | 234 | # Update mean train loss. 235 | train_loss = ( 236 | (N * train_loss + mbsize * loss.item()) / (N + mbsize)) 237 | N += mbsize 238 | 239 | # Gradient step. 240 | loss.backward() 241 | optimizer.step() 242 | self.zero_grad() 243 | 244 | # Check progress. 245 | with torch.no_grad(): 246 | # Calculate loss. 247 | self.eval() 248 | val_loss = self.validate(val_loader, loss_fn).item() 249 | self.train() 250 | 251 | # Update learning rate. 252 | scheduler.step(val_loss) 253 | 254 | # Record loss. 255 | self.train_loss.append(train_loss) 256 | self.val_loss.append(val_loss) 257 | 258 | if verbose: 259 | print(f'{"-" * 8}Epoch = {epoch + 1}{"-" * 8}') 260 | print(f'Train loss = {train_loss:.4f}') 261 | print(f'Val loss = {val_loss:.4f}') 262 | 263 | # Update bar. 
264 |             if bar:
265 |                 tqdm_bar.update(1)
266 | 
267 |             # Check for early stopping.
268 |             if val_loss < best_loss:
269 |                 best_loss = val_loss
270 |                 best_model = deepcopy(self)
271 |                 best_epoch = epoch
272 |             elif (epoch - best_epoch) == lookback:
273 |                 # Skip bar to end.
274 |                 if bar:
275 |                     tqdm_bar.n = max_nepochs
276 | 
277 |                 if verbose:
278 |                     print('Stopping early')
279 |                 break
280 | 
281 |         # Restore model parameters.
282 |         restore_parameters(self, best_model)
283 | 
284 |     def validate(self, loader, loss_fn):
285 |         '''Calculate average loss.'''
286 |         device = next(self.parameters()).device
287 |         mean_loss = 0
288 |         N = 0
289 | 
290 |         with torch.no_grad():
291 |             for x, y in loader:
292 |                 # Move to GPU.
293 |                 x = x.to(device=device)
294 |                 y = y.to(device=device)
295 |                 n = len(x)
296 | 
297 |                 # Calculate loss.
298 |                 pred = self.forward(x)
299 |                 loss = loss_fn(pred, y)
300 |                 mean_loss = (N * mean_loss + n * loss) / (N + n)
301 |                 N += n
302 | 
303 |         return mean_loss
304 | 
305 | 
306 | class SelectorMLP(nn.Module):
307 |     '''
308 |     MLP model with embedded selector layer.
309 | 
310 |     Args:
311 |       input_layer: selection layer type (e.g., BinaryMask).
312 |       input_size: number of inputs.
313 |       output_size: number of outputs.
314 |       hidden: list of hidden layer widths.
315 |       activation: nonlinearity between layers.
316 |       preselected_inds: feature indices that are already selected.
317 |       num_selections: number of features to select (for BinaryMask and
318 |         ConcreteSelector layers).
319 |       kwargs: additional arguments for input layers.
320 |     '''
321 | 
322 |     def __init__(self,
323 |                  input_layer,
324 |                  input_size,
325 |                  output_size,
326 |                  hidden,
327 |                  activation=nn.ReLU(),
328 |                  preselected_inds=[],
329 |                  num_selections=None,
330 |                  **kwargs):
331 |         # Verify arguments.
332 |         super().__init__()
333 |         if num_selections is None:
334 |             if input_layer in ('binary_mask', 'concrete_selector'):
335 |                 raise ValueError(
336 |                     f'must specify num_selections for {input_layer} layer')
337 |         else:
338 |             if input_layer in ('binary_gates',):
339 |                 raise ValueError('num_selections cannot be specified for '
340 |                                  f'{input_layer} layer')
341 | 
342 |         # Set up for pre-selected features.
343 |         preselected_inds = np.sort(preselected_inds)
344 |         assert len(preselected_inds) < input_size
345 |         self.preselected = np.array(
346 |             [i in preselected_inds for i in range(input_size)])
347 |         preselected_size = len(preselected_inds)
348 |         self.has_preselected = preselected_size > 0
349 | 
350 |         # Set up input layer.
351 |         if input_layer == 'binary_mask':
352 |             mlp_input_size = input_size
353 |             self.input_layer = layers.BinaryMask(
354 |                 input_size - preselected_size, num_selections, **kwargs)
355 |         elif input_layer == 'binary_gates':
356 |             mlp_input_size = input_size
357 |             self.input_layer = layers.BinaryGates(
358 |                 input_size - preselected_size, **kwargs)
359 |         elif input_layer == 'concrete_selector':
360 |             mlp_input_size = num_selections + preselected_size
361 |             self.input_layer = layers.ConcreteSelector(
362 |                 input_size - preselected_size, num_selections, **kwargs)
363 |         else:
364 |             raise ValueError('unsupported input layer: {}'.format(input_layer))
365 | 
366 |         # Create MLP.
367 |         self.mlp = MLP(mlp_input_size, output_size, hidden, activation)
368 | 
369 |     def forward(self, x):
370 |         '''Apply input layer and return MLP output.'''
371 |         if self.has_preselected:
372 |             pre = x[:, self.preselected]
373 |             x, m = self.input_layer(x[:, ~self.preselected])
374 |             x = torch.cat([pre, x], dim=1)
375 |         else:
376 |             x, m = self.input_layer(x)
377 |         pred = self.mlp(x)
378 |         return pred, x, m
379 | 
380 |     def fit(self,
381 |             train_dataset,
382 |             val_dataset,
383 |             lr,
384 |             mbsize,
385 |             max_nepochs,
386 |             start_temperature,
387 |             end_temperature,
388 |             loss_fn,
389 |             eta=0,
390 |             lam=0,
391 |             optimizer='Adam',
392 |             lookback=10,
393 |             bar=False,
394 |             verbose=True):
395 |         '''
396 |         Train the model.
397 | 
398 |         Args:
399 |           train_dataset: training dataset.
400 |           val_dataset: validation dataset.
401 |           lr: learning rate.
402 |           mbsize: minibatch size.
403 |           max_nepochs: maximum number of epochs.
404 |           start_temperature: initial temperature for Concrete samples.
405 |           end_temperature: final temperature after annealing.
406 |           loss_fn: loss function.
407 |           eta: penalty parameter for number of expressed genes.
408 |           lam: penalty parameter for number of selected genes (BinaryGates only).
409 |           optimizer: optimizer type.
410 |           lookback: number of epochs to wait for improvement before stopping.
411 |           bar: whether to display tqdm progress bar.
412 |           verbose: verbosity.
413 |         '''
414 |         # Verify arguments.
415 |         if lam != 0:
416 |             if not isinstance(self.input_layer, layers.BinaryGates):
417 |                 raise ValueError('lam should only be specified when using '
418 |                                  'BinaryGates layer')
419 |         else:
420 |             if isinstance(self.input_layer, layers.BinaryGates):
421 |                 raise ValueError('lam must be specified when using '
422 |                                  'BinaryGates layer')
423 |         if eta > 0:
424 |             if isinstance(self.input_layer, layers.BinaryGates):
425 |                 raise ValueError('eta cannot be specified when using '
426 |                                  'BinaryGates layer')
427 | 
428 |         if end_temperature > start_temperature:
429 |             raise ValueError('temperature should be annealed downwards, must '
430 |                              'have end_temperature <= start_temperature')
431 |         elif end_temperature == start_temperature:
432 |             loss_early_stopping = True
433 |         else:
434 |             loss_early_stopping = False
435 | 
436 |         # Set up optimizer.
437 |         optimizer = optim.Adam(self.parameters(), lr=lr)
438 | 
439 |         # Set up data loaders.
440 |         has_init = hasattr(train_dataset, 'init_worker')
441 |         if has_init:
442 |             train_init = train_dataset.init_worker
443 |             val_init = val_dataset.init_worker
444 |         else:
445 |             train_init = None
446 |             val_init = None
447 |         train_loader = DataLoader(train_dataset, batch_size=mbsize,
448 |                                   shuffle=True, drop_last=True,
449 |                                   worker_init_fn=train_init, num_workers=4)
450 |         val_loader = DataLoader(val_dataset, batch_size=mbsize,
451 |                                 worker_init_fn=val_init, num_workers=4)
452 | 
453 |         # Determine device.
454 |         device = next(self.parameters()).device
455 | 
456 |         # Set temperature and determine rate for decreasing.
457 |         self.input_layer.temperature = start_temperature
458 |         r = np.power(end_temperature / start_temperature,
459 |                      1 / ((len(train_dataset) // mbsize) * max_nepochs))
460 | 
461 |         # For tracking loss.
462 |         self.train_loss = []
463 |         self.val_loss = []
464 |         best_loss = np.inf
465 |         best_epoch = -1
466 | 
467 |         # Bar setup.
468 |         if bar:
469 |             tqdm_bar = tqdm(
470 |                 total=max_nepochs, desc='Training epochs', leave=True)
471 | 
472 |         # Begin training.
473 |         for epoch in range(max_nepochs):
474 |             # For tracking mean train loss.
475 |             train_loss = 0
476 |             N = 0
477 | 
478 |             for x, y in train_loader:
479 |                 # Move to device.
480 | x = x.to(device) 481 | y = y.to(device) 482 | 483 | # Calculate loss. 484 | pred, x, m = self.forward(x) 485 | loss = loss_fn(pred, y) 486 | 487 | # Calculate penalty if necessary. 488 | if lam > 0: 489 | penalty = input_layer_penalty(self.input_layer, m) 490 | loss = loss + lam * penalty 491 | 492 | # Add expression penalty if necessary. 493 | if eta > 0: 494 | expressed = torch.mean(torch.sum(x, dim=1)) 495 | loss = loss + eta * expressed 496 | 497 | # Update mean train loss. 498 | train_loss = ( 499 | (N * train_loss + mbsize * loss.item()) / (N + mbsize)) 500 | N += mbsize 501 | 502 | # Gradient step. 503 | loss.backward() 504 | optimizer.step() 505 | self.zero_grad() 506 | 507 | # Adjust temperature. 508 | self.input_layer.temperature *= r 509 | 510 | # Check progress. 511 | with torch.no_grad(): 512 | # Calculate loss. 513 | self.eval() 514 | val_loss, val_expressed = self.validate( 515 | val_loader, loss_fn, lam, eta) 516 | val_loss, val_expressed = val_loss.item(), val_expressed.item() 517 | self.train() 518 | 519 | # Record loss. 520 | self.train_loss.append(train_loss) 521 | self.val_loss.append(val_loss) 522 | 523 | if verbose: 524 | print(f'{"-" * 8}Epoch = {epoch + 1}{"-" * 8}') 525 | print(f'Train loss = {train_loss:.4f}') 526 | print(f'Val loss = {val_loss:.4f}') 527 | if eta > 0: 528 | print(f'Mean expressed genes = {val_expressed:.4f}') 529 | print(input_layer_summary(self.input_layer)) 530 | 531 | # Update bar. 532 | if bar: 533 | tqdm_bar.update(1) 534 | 535 | # Fix input layer if necessary. 536 | required_fix = input_layer_fix(self.input_layer) 537 | 538 | if not required_fix: 539 | # Stop early if input layer is converged. 540 | if input_layer_converged(self.input_layer, n_samples=mbsize): 541 | if verbose: 542 | print('Stopping early: input layer converged') 543 | break 544 | 545 | # Stop early if loss converged. 546 | if loss_early_stopping: 547 | if val_loss < best_loss: 548 | best_loss = val_loss 549 | best_epoch = epoch 550 | elif (epoch - best_epoch) == lookback: 551 | # Skip bar to end. 552 | if bar: 553 | tqdm_bar.n = max_nepochs 554 | 555 | if verbose: 556 | print('Stopping early: loss converged') 557 | break 558 | 559 | def validate(self, loader, loss_fn, lam, eta): 560 | '''Calculate average loss.''' 561 | device = next(self.parameters()).device 562 | mean_loss = 0 563 | mean_expressed = 0 564 | N = 0 565 | with torch.no_grad(): 566 | for x, y in loader: 567 | # Move to GPU. 568 | x = x.to(device=device) 569 | y = y.to(device=device) 570 | n = len(x) 571 | 572 | # Calculate loss. 573 | pred, x, m = self.forward(x) 574 | loss = loss_fn(pred, y) 575 | 576 | # Add penalty term. 577 | if lam > 0: 578 | penalty = input_layer_penalty(self.input_layer, m) 579 | loss = loss + lam * penalty 580 | 581 | # Add expression penalty term. 582 | expressed = torch.mean(torch.sum(x, dim=1)) 583 | if eta > 0: 584 | loss = loss + eta * expressed 585 | 586 | mean_loss = (N * mean_loss + n * loss) / (N + n) 587 | mean_expressed = (N * mean_expressed + n * expressed) / (N + n) 588 | N += n 589 | 590 | return mean_loss, mean_expressed 591 | -------------------------------------------------------------------------------- /persist/selection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from persist import models, utils 5 | 6 | 7 | class PERSIST: 8 | ''' 9 | Class for using the predictive and robust gene selection for spatial 10 | transcriptomics (PERSIST) method. 
11 | 12 | Args: 13 | train_dataset: dataset of training examples (ExpressionDataset). 14 | val_dataset: dataset of validation examples. 15 | loss_fn: loss function, such as HurdleLoss(), nn.MSELoss() or 16 | nn.CrossEntropyLoss(). 17 | device: torch device, such as torch.device('cuda', 0). 18 | preselected_inds: list of indices that must be selected. 19 | hidden: number of hidden units per layer (list of ints). 20 | activation: activation function between layers. 21 | ''' 22 | def __init__(self, 23 | train_dataset, 24 | val_dataset, 25 | loss_fn, 26 | device, 27 | preselected_inds=[], 28 | hidden=[128, 128], 29 | activation=nn.ReLU()): 30 | # TODO add verification for dataset type. 31 | self.train = train_dataset 32 | self.val = val_dataset 33 | self.loss_fn = loss_fn 34 | 35 | # Architecture parameters. 36 | self.hidden = hidden 37 | self.activation = activation 38 | 39 | # Set device. 40 | assert isinstance(device, torch.device) 41 | self.device = device 42 | 43 | # Set preselected genes. 44 | self.preselected = np.sort(preselected_inds).astype(int) 45 | 46 | # Initialize candidate genes. 47 | self.set_genes() 48 | 49 | def get_genes(self): 50 | '''Get currently selected genes, not including preselected genes.''' 51 | return self.candidates 52 | 53 | def set_genes(self, candidates=None): 54 | '''Restrict the subset of genes.''' 55 | if candidates is None: 56 | # All genes but pre-selected ones. 57 | candidates = np.array( 58 | [i for i in range(self.train.max_input_size) 59 | if i not in self.preselected]) 60 | 61 | else: 62 | # Ensure that candidates do not overlap with pre-selected genes. 63 | assert len(np.intersect1d(candidates, self.preselected)) == 0 64 | self.candidates = candidates 65 | 66 | # Set genes in datasets. 67 | included = np.sort(np.concatenate([candidates, self.preselected])) 68 | self.train.set_inds(included) 69 | self.val.set_inds(included) 70 | 71 | # Set relative indices for pre-selected genes. 72 | self.preselected_relative = np.array( 73 | [np.where(included == ind)[0][0] for ind in self.preselected]) 74 | 75 | def eliminate(self, 76 | target, 77 | lam_init=None, 78 | mbsize=64, 79 | max_nepochs=250, 80 | lr=1e-3, 81 | tol=0.2, 82 | start_temperature=10.0, 83 | end_temperature=0.01, 84 | optimizer='Adam', 85 | lookback=10, 86 | max_trials=10, 87 | bar=True, 88 | verbose=False): 89 | ''' 90 | Narrow the set of candidate genes: train a model with the BinaryGates 91 | layer and annealed penalty to eliminate a large portion of the inputs. 92 | 93 | Args: 94 | target: target number of genes to select (in addition to pre-selected 95 | genes). 96 | lam_init: initial lambda value. 97 | mbsize: minibatch size. 98 | max_nepochs: maximum number of epochs. 99 | lr: learning rate. 100 | tol: tolerance around gene target number. 101 | start_temperature: starting temperature for BinConcrete samples. 102 | end_temperature: final temperature value. 103 | optimizer: optimization algorithm. 104 | lookback: number of epochs to wait for improvement before stopping. 105 | max_trials: maximum number of training rounds before returning an 106 | error. 107 | bar: whether to display tqdm progress bar for each round of training. 108 | verbose: verbosity. 109 | ''' 110 | # Reset candidate genes. 
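# (eliminate() should start from the full candidate pool; if set_genes()
# previously narrowed the datasets, the checks below detect this and reset.)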
111 |         all_inds = np.arange(self.train.max_input_size)
112 |         all_candidates = np.array_equal(
113 |             self.candidates, np.setdiff1d(all_inds, self.preselected))
114 |         all_train_inds = np.array_equal(self.train.inds, all_inds)
115 |         all_val_inds = np.array_equal(self.val.inds, all_inds)
116 |         if not (all_candidates and all_train_inds and all_val_inds):
117 |             print('resetting candidate genes')
118 |             self.set_genes()
119 | 
120 |         # Initialize architecture.
121 |         if isinstance(self.loss_fn, utils.HurdleLoss):
122 |             output_size = 2 * self.train.output_size
123 |         elif isinstance(self.loss_fn, (nn.CrossEntropyLoss, nn.MSELoss)):
124 |             output_size = self.train.output_size
125 |         else:
126 |             output_size = self.train.output_size
127 |             print(f'Unknown loss function, assuming {self.loss_fn} requires '
128 |                   f'{self.train.output_size} outputs')
129 | 
130 |         model = models.SelectorMLP(input_layer='binary_gates',
131 |                                    input_size=self.train.input_size,
132 |                                    output_size=output_size,
133 |                                    hidden=self.hidden,
134 |                                    activation=self.activation,
135 |                                    preselected_inds=self.preselected_relative)
136 |         model = model.to(self.device)
137 | 
138 |         # Determine lam_init, if necessary.
139 |         if lam_init is None:
140 |             if isinstance(self.loss_fn, utils.HurdleLoss):
141 |                 print('using HurdleLoss, starting with lam = 0.01')
142 |                 lam_init = 0.01
143 |             elif isinstance(self.loss_fn, nn.MSELoss):
144 |                 print('using MSELoss, starting with lam = 0.01')
145 |                 lam_init = 0.01
146 |             elif isinstance(self.loss_fn, nn.CrossEntropyLoss):
147 |                 print('using CrossEntropyLoss, starting with lam = 0.0001')
148 |                 lam_init = 0.0001
149 |             else:
150 |                 print('unknown loss function, starting with lam = 0.0001')
151 |                 lam_init = 0.0001
152 |         else:
153 |             print(f'trying lam = {lam_init:.6f}')
154 | 
155 |         # Prepare for training and lambda search.
156 |         assert 0 < target < self.train.input_size
157 |         assert 0.1 <= tol < 0.5
158 |         assert lam_init > 0
159 |         lam_list = [0]
160 |         num_remaining = self.train.input_size
161 |         num_remaining_list = [num_remaining]
162 |         lam = lam_init
163 |         trials = 0
164 | 
165 |         # Iterate until num_remaining is near the target value.
166 |         while np.abs(num_remaining - target) > target * tol:
167 |             # Ensure not done.
168 |             if trials == max_trials:
169 |                 raise ValueError(
170 |                     'reached maximum number of trials without selecting the '
171 |                     'desired number of genes! The results may have large '
172 |                     'variance due to small dataset size, or the initial lam '
173 |                     'value may be bad')
174 |             trials += 1
175 | 
176 |             # Train.
177 |             model.fit(self.train,
178 |                       self.val,
179 |                       lr,
180 |                       mbsize,
181 |                       max_nepochs,
182 |                       start_temperature=start_temperature,
183 |                       end_temperature=end_temperature,
184 |                       loss_fn=self.loss_fn,
185 |                       lam=lam,
186 |                       optimizer=optimizer,
187 |                       lookback=lookback,
188 |                       bar=bar,
189 |                       verbose=verbose)
190 | 
191 |             # Extract inds.
192 |             inds = model.input_layer.get_inds(threshold=0.5)
193 |             num_remaining = len(inds)
194 |             print(f'lam = {lam:.6f} yielded {num_remaining} genes')
195 | 
196 |             if np.abs(num_remaining - target) <= target * tol:
197 |                 print(f'done, lam = {lam:.6f} yielded {num_remaining} genes')
198 | 
199 |             else:
200 |                 # Guess next lam value.
201 |                 next_lam = modified_secant_method(
202 |                     lam, 1 / (1 + num_remaining), 1 / (1 + target),
203 |                     np.array(lam_list), 1 / (1 + np.array(num_remaining_list)))
204 | 
205 |                 # Clip lam value for stability.
206 |                 next_lam = np.clip(next_lam, a_min=0.1 * lam, a_max=10 * lam)
207 | 
208 |                 # Possibly reinitialize model.
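                # Note: the asymmetry below is deliberate. Only an overshoot
                # (too few genes remaining) triggers reinitialization, since
                # warm starting from an over-sparsified model tends to keep
                # genes switched off; otherwise the warm model is reused.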
209 |                 if num_remaining < target * (1 - tol):
210 |                     # BinaryGates layer is not great at allowing features
211 |                     # back in after inducing too much sparsity.
212 |                     print('Reinitializing model for next iteration')
213 |                     model = models.SelectorMLP(
214 |                         input_layer='binary_gates',
215 |                         input_size=self.train.input_size,
216 |                         output_size=output_size,
217 |                         hidden=self.hidden,
218 |                         activation=self.activation,
219 |                         preselected_inds=self.preselected_relative)
220 |                     model = model.to(self.device)
221 |                 else:
222 |                     print('Warm starting model for next iteration')
223 | 
224 |                 # Prepare for next iteration.
225 |                 lam_list.append(lam)
226 |                 num_remaining_list.append(num_remaining)
227 |                 lam = next_lam
228 |                 print(f'next attempt is lam = {lam:.6f}')
229 | 
230 |         # Set eligible genes.
231 |         true_inds = self.candidates[inds]
232 |         self.set_genes(true_inds)
233 |         return true_inds, model
234 | 
235 |     def select(self,
236 |                num_genes,
237 |                mbsize=64,
238 |                max_nepochs=250,
239 |                lr=1e-3,
240 |                start_temperature=10.0,
241 |                end_temperature=0.01,
242 |                optimizer='Adam',
243 |                bar=True,
244 |                verbose=False):
245 |         '''
246 |         Select genetic probes: train a model with a BinaryMask layer to
247 |         select a precise number of model inputs.
248 | 
249 |         Args:
250 |           num_genes: number of genes to select (in addition to pre-selected
251 |             genes).
252 |           mbsize: minibatch size.
253 |           max_nepochs: maximum number of epochs.
254 |           lr: learning rate.
255 |           start_temperature: starting temperature value for Concrete samples.
256 |           end_temperature: final temperature value.
257 |           optimizer: optimization algorithm.
258 |           bar: whether to display tqdm progress bar.
259 |           verbose: verbosity.
260 |         '''
261 |         # Possibly reset candidate genes.
262 |         included_inds = np.sort(
263 |             np.concatenate([self.candidates, self.preselected]))
264 |         candidate_train_inds = np.array_equal(self.train.inds, included_inds)
265 |         candidate_val_inds = np.array_equal(self.val.inds, included_inds)
266 |         if not (candidate_train_inds and candidate_val_inds):
267 |             print('setting candidate genes in datasets')
268 |             self.set_genes(self.candidates)
269 | 
270 |         # Initialize architecture.
271 |         if isinstance(self.loss_fn, utils.HurdleLoss):
272 |             output_size = 2 * self.train.output_size
273 |         elif isinstance(self.loss_fn, (nn.CrossEntropyLoss, nn.MSELoss)):
274 |             output_size = self.train.output_size
275 |         else:
276 |             output_size = self.train.output_size
277 |             print(f'assuming loss function {self.loss_fn} requires '
278 |                   f'{self.train.output_size} outputs')
279 | 
280 |         input_size = len(self.candidates) + len(self.preselected)
281 |         model = models.SelectorMLP(input_layer='binary_mask',
282 |                                    input_size=input_size,
283 |                                    output_size=output_size,
284 |                                    hidden=self.hidden,
285 |                                    activation=self.activation,
286 |                                    preselected_inds=self.preselected_relative,
287 |                                    num_selections=num_genes).to(self.device)
288 | 
289 |         # Train.
290 |         model.fit(self.train,
291 |                   self.val,
292 |                   lr,
293 |                   mbsize,
294 |                   max_nepochs,
295 |                   start_temperature,
296 |                   end_temperature,
297 |                   loss_fn=self.loss_fn,
298 |                   optimizer=optimizer,
299 |                   bar=bar,
300 |                   verbose=verbose)
301 | 
302 |         # Return genes.
303 |         inds = model.input_layer.get_inds()
304 |         true_inds = self.candidates[inds]
305 |         print(f'done, selected {len(inds)} genes')
306 |         return true_inds, model
307 | 
308 | 
309 | def modified_secant_method(x0, y0, y1, x, y):
310 |     '''
311 |     A modified version of the secant method, used here to determine the
312 |     correct lam value.
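    Concretely, each call proposes the next guess as
    x1 = x0 + (y1 - y0) / slope, where slope is a weighted least squares
    estimate computed from all previous evaluations.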
313 |     Note that we use x = lam and y = 1 / (1 + num_remaining) rather than
314 |     y = num_remaining, because this gives better results.
315 | 
316 |     The standard secant method uses the two previous points to calculate a
317 |     finite difference rather than an exact derivative (as in Newton's
318 |     method). Here, we use a robustified derivative estimate: we find the
319 |     line that passes through the most recent point (x0, y0) and minimizes
320 |     a weighted least squares loss over all previous points (x, y). This
321 |     improves robustness to nearby guesses (small |x - x'|) and noisy
322 |     evaluations.
323 | 
324 |     Args:
325 |       x0: most recent x.
326 |       y0: most recent y.
327 |       y1: target y value.
328 |       x: all previous xs.
329 |       y: all previous ys.
330 |     '''
331 |     # Get robust slope estimate.
332 |     weights = 1 / np.abs(x - x0)
333 |     slope = (
334 |         np.sum(weights * (x - x0) * (y - y0)) /
335 |         np.sum(weights * (x - x0) ** 2))
336 | 
337 |     # Clip slope to minimum value.
338 |     slope = np.clip(slope, a_min=1e-6, a_max=None)
339 | 
340 |     # Guess x1.
341 |     x1 = x0 + (y1 - y0) / slope
342 |     return x1
--------------------------------------------------------------------------------
/persist/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from copy import deepcopy
4 | 
5 | 
6 | class HurdleLoss(nn.BCEWithLogitsLoss):
7 |     '''
8 |     Hurdle loss that incorporates ZCELoss for each output, as well as MSE for
9 |     each output that surpasses the threshold value. This can be understood as
10 |     the negative log-likelihood of a hurdle distribution.
11 | 
12 |     Args:
13 |       lam: weight for the ZCELoss term (the hurdle).
14 |       thresh: threshold that an output must surpass to be considered turned on.
15 |     '''
16 |     def __init__(self, lam=10.0, thresh=0):
17 |         super().__init__()
18 |         self.lam = lam
19 |         self.thresh = thresh
20 | 
21 |     def forward(self, pred, target):
22 |         # Verify prediction shape.
23 |         if pred.shape[1] != 2 * target.shape[1]:
24 |             raise ValueError(
25 |                 'Predictions have incorrect shape! For HurdleLoss, the'
26 |                 ' predictions must have twice the dimensionality of targets'
27 |                 ' ({})'.format(target.shape[1] * 2))
28 | 
29 |         # Reshape predictions and extract distribution parameters.
30 |         pred = pred.reshape(*pred.shape[:-1], -1, 2)
31 |         pred = pred.permute(-1, *torch.arange(len(pred.shape))[:-1])
32 |         mu = pred[0]
33 |         p_logit = pred[1]
34 | 
35 |         # Calculate loss.
36 |         zero_target = (target <= self.thresh).float().detach()
37 |         hurdle_loss = super().forward(p_logit, zero_target)
38 |         mse = (1 - zero_target) * (target - mu) ** 2
39 | 
40 |         loss = self.lam * hurdle_loss + mse
41 |         return torch.mean(torch.sum(loss, dim=-1))
42 | 
43 | 
44 | class ZCELoss(nn.BCEWithLogitsLoss):
45 |     '''
46 |     Binary classification loss on whether outputs surpass a threshold. Expects
47 |     logits.
48 | 
49 |     Args:
50 |       thresh: threshold that an output must surpass to be considered on.
51 |     '''
52 |     def __init__(self, thresh=0):
53 |         super().__init__(reduction='none')
54 |         self.thresh = thresh
55 | 
56 |     def forward(self, pred, target):
57 |         zero_target = (target <= self.thresh).float()
58 |         loss = super().forward(pred, zero_target)
59 |         return torch.mean(torch.sum(loss, dim=-1))
60 | 
61 | 
62 | class ZeroAccuracy(nn.Module):
63 |     '''
64 |     Classification accuracy on whether outputs surpass a threshold. Expects
65 |     logits.
66 | 
67 |     Args:
68 |       thresh: threshold that an output must surpass to be considered turned on.
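    Note: pred is expected to contain logits for the "zero" class, so
    pred > 0 is interpreted as predicting target <= thresh, matching the
    zero_target convention used by HurdleLoss and ZCELoss.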
69 |     '''
70 |     def __init__(self, thresh=0):
71 |         super().__init__()
72 |         self.thresh = thresh
73 | 
74 |     def forward(self, pred, target):
75 |         zero_pred = (pred > 0).float()
76 |         zero_target = (target <= self.thresh).float()
77 |         acc = (zero_pred == zero_target).float()
78 |         return torch.mean(acc)
79 | 
80 | 
81 | class MSELoss(nn.Module):
82 |     '''MSE loss that sums over output dimensions.'''
83 |     def __init__(self):
84 |         super().__init__()
85 | 
86 |     def forward(self, pred, target):
87 |         loss = torch.sum((pred - target) ** 2, dim=-1)
88 |         return torch.mean(loss)
89 | 
90 | 
91 | class Accuracy(nn.Module):
92 |     '''0-1 classification accuracy.'''
93 |     def __init__(self):
94 |         super().__init__()
95 | 
96 |     def forward(self, pred, target):
97 |         return (torch.argmax(pred, dim=1) == target).float().mean()
98 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | 
3 | requirements = [
4 |     'numpy',
5 |     'torch',
6 |     'tqdm',
7 |     'scikit-learn',
8 |     'h5py',
9 |     'anndata',
10 |     'scanpy',
11 |     'matplotlib',
12 |     'scipy',
13 |     'pandas',
14 |     'toml'
15 | ]
16 | 
17 | setuptools.setup(
18 |     name='persist',
19 |     version='0.0.1',
20 |     author='Ian Covert',
21 |     author_email='icovert@cs.washington.edu',
22 |     description='PERSIST: predictive and robust gene selection for spatial transcriptomics',
23 |     long_description='''
24 |         Predictive and robust gene selection for spatial transcriptomics
25 |         (PERSIST) is a computational approach to select a small number of
26 |         target genes for FISH experiments. PERSIST relies on a reference
27 |         scRNA-seq dataset, and it uses deep learning to identify genes that are
28 |         predictive of the genome-wide expression profile or any other target of
29 |         interest (e.g., transcriptomic cell types). PERSIST binarizes gene
30 |         expression levels during the selection process to account for the
31 |         measurement shift between scRNA-seq and FISH expression counts.
32 |     ''',
33 |     long_description_content_type='text/markdown',
34 |     url='https://github.com/iancovert/persist',
35 |     packages=['persist'],
36 |     install_requires=requirements,
37 |     classifiers=[
38 |         'Programming Language :: Python :: 3',
39 |         'License :: OSI Approved :: MIT License',
40 |         'Operating System :: OS Independent',
41 |         'Intended Audience :: Science/Research',
42 |         'Topic :: Scientific/Engineering'
43 |     ],
44 |     python_requires='>=3.6',
45 | )
--------------------------------------------------------------------------------
/tests/test_expression_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.sparse import csr_matrix
3 | from persist import ExpressionDataset
4 | 
5 | 
6 | def test_expression_dataset():
7 |     """Tests that ExpressionDataset gives identical results for dense and
8 |     sparse inputs."""
9 |     np.random.seed(0)
10 |     labels = np.array([1, 2, 3]).astype(np.int_)
11 |     x = np.random.rand(3, 6)
12 |     x[x < 0.5] = 0
13 |     sparse_x = csr_matrix(x)
14 | 
15 |     # Check that dense entries can be recovered.
16 |     for i in range(x.shape[0]):
17 |         assert np.all(x[i] == sparse_x[i].toarray()), f"Row {i} is not the same"
18 | 
19 |     dense_dataset = ExpressionDataset(x, labels)
20 |     sparse_dataset = ExpressionDataset(sparse_x, labels)
21 | 
22 |     # ExpressionDataset flags:
23 |     assert not dense_dataset.data_issparse, "expecting non-sparse data"
24 |     assert sparse_dataset.data_issparse, "expecting sparse data"
25 | 
26 |     # ExpressionDataset emits X and label; we want identical results in
27 |     # both cases.
28 |     for (x1, l1), (x2, l2) in zip(dense_dataset, sparse_dataset):
29 |         assert l1 == l2, "Labels returned by ExpressionDataset are not the same"
30 |         assert np.all(x1 == x2), "Data returned by ExpressionDataset is not the same"
--------------------------------------------------------------------------------