├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── VISp_PERSIST_metadata.csv
│   └── VISp_markers.toml
├── notebooks
│   ├── 00_data_proc.ipynb
│   ├── 01_persist_supervised.ipynb
│   ├── 02_persist_unsupervised.ipynb
│   ├── 03_persist_pbmc3k_scanpy.ipynb
│   └── demo.ipynb
├── persist
│   ├── __init__.py
│   ├── data.py
│   ├── layers.py
│   ├── models.py
│   ├── selection.py
│   └── utils.py
├── setup.py
└── tests
    └── test_expression_dataset.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-language=Python
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.pyc
3 | .DS_Store
4 | .ipynb_checkpoints
5 | .eggs
6 | *.egg-info
7 | build
8 | dist
9 | propose.egg-info
10 | /*.ipynb
11 | *.pkl
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Ian Covert
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PERSIST
2 |
3 | **Predictive and robust gene selection for spatial transcriptomics (PERSIST)** is a computational approach to select target genes for FISH studies. PERSIST relies on a reference scRNA-seq dataset, and it uses deep learning to identify genes that are predictive of the genome-wide expression profile or any other target of interest (e.g., transcriptomic cell types). PERSIST also binarizes gene expression levels during the selection process, which helps account for the measurement shift between expression counts obtained by scRNA-seq and FISH.
4 |
5 | See the related [publication](https://www.nature.com/articles/s41467-023-37392-1) for details.
6 |
7 | To cite:
8 | ```bib
9 | @article{covert2023predictive,
10 | title={Predictive and robust gene selection for spatial transcriptomics},
11 | author={Covert, Ian and Gala, Rohan and Wang, Tim and Svoboda, Karel and S{\"u}mb{\"u}l, Uygar and Lee, Su-In},
12 | journal={Nature Communications},
13 | volume={14},
14 | number={1},
15 | pages={2091},
16 | year={2023},
17 | publisher={Nature Publishing Group UK London}
18 | }
19 | ```
20 |
21 | ## Installation
22 |
23 | You can install the package by cloning the repository and pip installing it as follows:
24 |
25 | ```bash
26 | pip install -e .
27 | ```
28 |
29 | This will automatically install any missing dependencies, which may take a couple of minutes. Please be sure to install a version of [PyTorch](https://pytorch.org/get-started/locally/) that is compatible with your GPU, as we highly recommend using a GPU to accelerate training.
30 |
31 | ## Usage
32 |
33 | PERSIST is designed to offer flexibility while requiring minimal tuning. For a demonstration of how it's used, from data preparation through gene selection, please see the following Jupyter notebooks:
34 |
35 | - [00_data_proc.ipynb](https://github.com/iancovert/persist/blob/main/notebooks/00_data_proc.ipynb) shows how to download and pre-process one of the datasets used in our paper (the VISp SmartSeq v4 dataset from [Tasic et al., 2018](https://www.nature.com/articles/s41586-018-0654-5))
36 | - [01_persist_supervised.ipynb](https://github.com/iancovert/persist/blob/main/notebooks/01_persist_supervised.ipynb) shows how to use PERSIST to select genes that are maximally predictive of cell type labels (the supervised case)
37 | - [02_persist_unsupervised.ipynb](https://github.com/iancovert/persist/blob/main/notebooks/02_persist_unsupervised.ipynb) shows how to use PERSIST to select genes that are maximally predictive of the genome-wide expression profile (the unsupervised case)
38 | - [03_persist_pbmc3k_scanpy.ipynb](https://github.com/iancovert/persist/blob/main/notebooks/03_persist_pbmc3k_scanpy.ipynb) shows how to perform supervised gene selection on a PBMC dataset downloaded directly through scanpy
39 |
--------------------------------------------------------------------------------
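
To make the README's workflow concrete before diving into the source files below, here is a minimal sketch of the exported API (`ExpressionDataset` and `PERSIST`, with the `eliminate` step defined in `persist/selection.py`). The synthetic data and every hyperparameter value are illustrative assumptions, not recommendations; see the notebooks for real usage, including the final selection step that follows `eliminate`.

```python
import numpy as np
import torch
import torch.nn as nn
from persist import ExpressionDataset, PERSIST

# Synthetic stand-in for a (cells x genes) count matrix with cell type labels.
rng = np.random.default_rng(0)
counts = rng.poisson(1.0, size=(1000, 500)).astype(np.float32)
labels = rng.integers(0, 10, size=1000)

# Row-wise train/validation split.
perm = rng.permutation(1000)
train = ExpressionDataset(counts[perm[:800]], labels[perm[:800]])
val = ExpressionDataset(counts[perm[800:]], labels[perm[800:]])

# Supervised mode: select genes that predict the labels.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
selector = PERSIST(train, val, loss_fn=nn.CrossEntropyLoss(), device=device)

# Narrow the candidate gene pool using the BinaryGates layer.
selector.eliminate(target=50, max_nepochs=100)
```
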
/data/VISp_markers.toml:
--------------------------------------------------------------------------------
1 | # Marker gene list was obtained from Tasic et al. 2018 as displayed in Figures 4c, 5e, and 5f:
2 | markers = [ "Slc30a3", "Cux2", "Rorb", "Deptor", "Scnn1a", "Rspo1", "Hsd11b1", "Batf3", "Oprk1", "Osr1", "Car3", "Fam84b", "Chrna6", "Pvalb", "Pappa2", "Foxp2", "Slc17a8", "Trhr", "Tshz2", "Rapgef3", "Trh", "Gpr139", "Nxph4", "Rprm", "Crym", "Sst", "Chodl", "Nos1", "Mme", "Tac1", "Tacr3", "Calb2", "Nr2f2", "Myh8", "Tac2", "Hpse", "Crhr2", "Crh", "Esm1", "Rxfp1", "Nts", "Gabrg1", "Th", "Calb1", "Akr1c18", "Sema3e", "Gpr149", "Reln", "Tpbg", "Cpne5", "Vipr2", "Nkx2-1", "Lamp5", "Ndnf", "Krt73", "Fam19a1", "Pax6", "Ntn1", "Plch2", "Lsp1", "Lhx6", "Vip", "Sncg", "Nptx2", "Gpr50", "Itih5", "Serpinf1", "Igfbp6", "Gpc3", "Lmo1", "Ptprt", "Rspo4", "Chat", "Crispld2", "Col15a1", "Pde1a",]
3 |
--------------------------------------------------------------------------------
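
These markers can be forced into any PERSIST run through the `preselected_inds` argument of the `PERSIST` class in `persist/selection.py`. A short sketch of the index lookup; the `gene_names` array here is a hypothetical stand-in (in the notebooks the gene symbols live in `adata.var.index.values`):

```python
import numpy as np
import toml

# Hypothetical stand-in for the expression matrix's column names.
gene_names = np.array(['Sst', 'Vip', 'Pvalb', 'Gad1', 'Slc17a7'])

# Map marker symbols onto column indices of the expression matrix.
markers = toml.load('data/VISp_markers.toml')['markers']
preselected_inds = np.where(np.isin(gene_names, markers))[0]
print(preselected_inds)  # indices of markers present in gene_names
```
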
/notebooks/00_data_proc.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "758da1a7-f1fa-4001-8f35-9a3ad32560af",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import numpy as np\n",
11 | "import pandas as pd\n",
12 | "import anndata as ad\n",
13 | "import scanpy as sc\n",
14 | "from scipy.sparse import csr_matrix"
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "id": "406f7632",
20 | "metadata": {},
21 | "source": [
22 | "**Count matrix and metadata for VISp dataset**\n",
23 | " - Download the count data from Allen Institute portal\n",
24 | " - Convert to AnnData format - see [this getting started with AnnData tutorial](https://anndata-tutorials.readthedocs.io/en/latest/getting-started.html)\n",
25 | " - Save the resulting object for later use"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "id": "a3a398cd",
32 | "metadata": {},
33 | "outputs": [
34 | {
35 | "name": "stdout",
36 | "output_type": "stream",
37 | "text": [
38 | " % Total % Received % Xferd Average Speed Time Time Time Current\n",
39 | " Dload Upload Total Spent Left Speed\n",
40 | "100 291M 100 291M 0 0 51.1M 0 0:00:05 0:00:05 --:--:-- 53.1M\n",
41 | "Archive: ../data/VISp.zip\n",
42 | " inflating: ../data/VISp/mouse_VISp_2018-06-14_exon-matrix.csv \n",
43 | " inflating: ../data/VISp/mouse_VISp_2018-06-14_genes-rows.csv \n",
44 | " inflating: ../data/VISp/mouse_VISp_2018-06-14_intron-matrix.csv \n",
45 | " inflating: ../data/VISp/mouse_VISp_2018-06-14_readme.txt \n",
46 | " inflating: ../data/VISp/mouse_VISp_2018-06-14_samples-columns.csv \n"
47 | ]
48 | }
49 | ],
50 | "source": [
51 | "# Download count matrices from https://portal.brain-map.org/atlases-and-data/rnaseq/mouse-v1-and-alm-smart-seq or use the shell commands below.\n",
52 | "!curl -o ../data/VISp.zip https://celltypes.brain-map.org/api/v2/well_known_file_download/694413985\n",
53 | "!unzip -d ../data/VISp ../data/VISp.zip\n",
54 | "!rm ../data/VISp.zip"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 3,
60 | "id": "c577e965",
61 | "metadata": {},
62 | "outputs": [
63 | {
64 | "data": {
65 | "text/html": [
66 | "
\n",
67 | "\n",
80 | "
\n",
81 | " \n",
82 | " \n",
83 | " | \n",
84 | " seq_name | \n",
85 | " class | \n",
86 | " subclass | \n",
87 | " cluster | \n",
88 | "
\n",
89 | " \n",
90 | " sample_id | \n",
91 | " | \n",
92 | " | \n",
93 | " | \n",
94 | " | \n",
95 | "
\n",
96 | " \n",
97 | " \n",
98 | " \n",
99 | " F1S4_160108_001_A01 | \n",
100 | " LS-15006_S09_E1-50 | \n",
101 | " GABAergic | \n",
102 | " Vip | \n",
103 | " Vip Arhgap36 Hmcn1 | \n",
104 | "
\n",
105 | " \n",
106 | " F1S4_160108_001_B01 | \n",
107 | " LS-15006_S10_E1-50 | \n",
108 | " GABAergic | \n",
109 | " Lamp5 | \n",
110 | " Lamp5 Lsp1 | \n",
111 | "
\n",
112 | " \n",
113 | " F1S4_160108_001_C01 | \n",
114 | " LS-15006_S11_E1-50 | \n",
115 | " GABAergic | \n",
116 | " Lamp5 | \n",
117 | " Lamp5 Lsp1 | \n",
118 | "
\n",
119 | " \n",
120 | " F1S4_160108_001_D01 | \n",
121 | " LS-15006_S12_E1-50 | \n",
122 | " GABAergic | \n",
123 | " Vip | \n",
124 | " Vip Crispld2 Htr2c | \n",
125 | "
\n",
126 | " \n",
127 | " F1S4_160108_001_E01 | \n",
128 | " LS-15006_S13_E1-50 | \n",
129 | " GABAergic | \n",
130 | " Lamp5 | \n",
131 | " Lamp5 Plch2 Dock5 | \n",
132 | "
\n",
133 | " \n",
134 | "
\n",
135 | "
"
136 | ],
137 | "text/plain": [
138 | " seq_name class subclass \\\n",
139 | "sample_id \n",
140 | "F1S4_160108_001_A01 LS-15006_S09_E1-50 GABAergic Vip \n",
141 | "F1S4_160108_001_B01 LS-15006_S10_E1-50 GABAergic Lamp5 \n",
142 | "F1S4_160108_001_C01 LS-15006_S11_E1-50 GABAergic Lamp5 \n",
143 | "F1S4_160108_001_D01 LS-15006_S12_E1-50 GABAergic Vip \n",
144 | "F1S4_160108_001_E01 LS-15006_S13_E1-50 GABAergic Lamp5 \n",
145 | "\n",
146 | " cluster \n",
147 | "sample_id \n",
148 | "F1S4_160108_001_A01 Vip Arhgap36 Hmcn1 \n",
149 | "F1S4_160108_001_B01 Lamp5 Lsp1 \n",
150 | "F1S4_160108_001_C01 Lamp5 Lsp1 \n",
151 | "F1S4_160108_001_D01 Vip Crispld2 Htr2c \n",
152 | "F1S4_160108_001_E01 Lamp5 Plch2 Dock5 "
153 | ]
154 | },
155 | "execution_count": 3,
156 | "metadata": {},
157 | "output_type": "execute_result"
158 | }
159 | ],
160 | "source": [
161 | "# Load VISp dataset\n",
162 | "filename = '../data/VISp/mouse_VISp_2018-06-14_exon-matrix.csv'\n",
163 | "expr_df = pd.read_csv(filename, header=0, index_col=0, delimiter=',').transpose()\n",
164 | "expr = expr_df.values\n",
165 | "\n",
166 | "# Find gene names\n",
167 | "filename = '../data/VISp/mouse_VISp_2018-06-14_genes-rows.csv'\n",
168 | "genes_df = pd.read_csv(filename, header=0, index_col=0, delimiter=',')\n",
169 | "gene_symbol = genes_df.index.values\n",
170 | "gene_ids = genes_df['gene_entrez_id'].values\n",
171 | "gene_names = np.array([gene_symbol[np.where(gene_ids == name)[0][0]] for name in expr_df.columns])\n",
172 | "\n",
173 | "# Get metadata and save restrict to relevant fields\n",
174 | "filename = '../data/VISp/mouse_VISp_2018-06-14_samples-columns.csv'\n",
175 | "obs = pd.read_csv(filename, header=0, index_col=0, delimiter=',', encoding='iso-8859-1')\n",
176 | "\n",
177 | "obs = obs.reset_index()\n",
178 | "obs = obs[['sample_name','seq_name','class','subclass','cluster']]\n",
179 | "obs = obs.rename(columns={'sample_name':'sample_id'})\n",
180 | "obs = obs.set_index('sample_id')\n",
181 | "obs.head()"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": 4,
187 | "id": "2443ffbd",
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "# compose and store anndata object for efficient read/write\n",
192 | "adata = ad.AnnData(X=csr_matrix(expr))\n",
193 | "adata.var_names = gene_names\n",
194 | "adata.var.index.set_names('genes', inplace=True)\n",
195 | "adata.obs = obs\n",
196 | "adata.write('../data/VISp.h5ad')"
197 | ]
198 | },
199 | {
200 | "cell_type": "markdown",
201 | "id": "c5a11dd0",
202 | "metadata": {},
203 | "source": [
204 | "**Filtering samples**\n",
205 | "\n",
206 | "The next code block is optional, and requires `VISp_PERSIST_metadata.csv` which contains:\n",
207 | "- cell type labels at different resolutions of the taxonomy (see manuscript for details)\n",
208 | "- sample ids to filter out non-neuronal cells\n",
209 | "\n",
210 | "In the following, we will\n",
211 | "1. restrict cells only to those samples specified in `VISp_PERSIST_metadata.csv`\n",
212 | "2. append metadata from `VISp_PERSIST_metadata.csv` to the AnnData object\n",
213 | "3. normalize counts, determine highly variable genes using scanpy functions\n",
214 | "3. save a filtered AnnData object into a .h5ad file for subsequent use"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 5,
220 | "id": "10b09db4",
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "name": "stdout",
225 | "output_type": "stream",
226 | "text": [
227 | " seq_name cell_types_98 cell_types_50 cell_types_25\n",
228 | "0 LS-15006_S09_E1-50 Vip Arhgap36 Hmcn1 n70 n66\n",
229 | "1 LS-15006_S10_E1-50 Lamp5 Lsp1 Lamp5 Lsp1 n78\n",
230 | "\n",
231 | "old shape: (15413, 45768)\n",
232 | "new shape: (13349, 45768)\n"
233 | ]
234 | }
235 | ],
236 | "source": [
237 | "adata = ad.read_h5ad('../data/VISp.h5ad')\n",
238 | "persist_df = pd.read_csv('../data/VISp_PERSIST_metadata.csv')\n",
239 | "print(persist_df.head(2))\n",
240 | "print(f'\\nold shape: {adata.shape}')\n",
241 | "\n",
242 | "adata = adata[adata.obs['seq_name'].isin(persist_df['seq_name']), :]\n",
243 | "obs = adata.obs.copy().reset_index()\n",
244 | "obs = obs.merge(right=persist_df, how='left', left_on='seq_name', right_on='seq_name')\n",
245 | "obs = obs.set_index('sample_id')\n",
246 | "\n",
247 | "adata = ad.AnnData(X=adata.X, obs=obs, var=adata.var)\n",
248 | "print(f'new shape: {adata.shape}')"
249 | ]
250 | },
251 | {
252 | "cell_type": "markdown",
253 | "id": "3b1398f1",
254 | "metadata": {},
255 | "source": [
256 | "**Normalization and preliminary gene selection**"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 6,
262 | "id": "11401b0c",
263 | "metadata": {},
264 | "outputs": [],
265 | "source": [
266 | "# transforms data in adata.X\n",
267 | "adata.layers['log1pcpm'] = sc.pp.normalize_total(adata, target_sum=1e6, inplace=False)['X']\n",
268 | "\n",
269 | "# transforms data in layers['lognorm'] inplace\n",
270 | "sc.pp.log1p(adata, layer='log1pcpm')\n"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": 7,
276 | "id": "2b8cd61f",
277 | "metadata": {},
278 | "outputs": [],
279 | "source": [
280 | "import toml\n",
281 | "dat = toml.load('../data/VISp_markers.toml')"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": 8,
287 | "id": "c42e06e6",
288 | "metadata": {},
289 | "outputs": [],
290 | "source": [
291 | "# introduces \"highly_variable\" column to adata.var\n",
292 | "sc.pp.highly_variable_genes(adata, \n",
293 | " layer='log1pcpm', \n",
294 | " flavor='cell_ranger',\n",
295 | " n_top_genes=10000, \n",
296 | " inplace=True)"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": 9,
302 | "id": "a5cdf670",
303 | "metadata": {},
304 | "outputs": [
305 | {
306 | "data": {
307 | "text/plain": [
308 | "AnnData object with n_obs × n_vars = 13349 × 45768\n",
309 | " obs: 'seq_name', 'class', 'subclass', 'cluster', 'cell_types_98', 'cell_types_50', 'cell_types_25'\n",
310 | " var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'markers'\n",
311 | " uns: 'log1p', 'hvg'\n",
312 | " layers: 'log1pcpm'"
313 | ]
314 | },
315 | "execution_count": 9,
316 | "metadata": {},
317 | "output_type": "execute_result"
318 | }
319 | ],
320 | "source": [
321 | "# Create new field with marker genes\n",
322 | "adata.var['markers'] = np.isin(adata.var.index.values,dat['markers'])\n",
323 | "adata"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 10,
329 | "id": "bb0bb3c2",
330 | "metadata": {},
331 | "outputs": [
332 | {
333 | "data": {
334 | "text/plain": [
335 | ""
337 | ]
338 | },
339 | "execution_count": 10,
340 | "metadata": {},
341 | "output_type": "execute_result"
342 | }
343 | ],
344 | "source": [
345 | "# This is a sparse matrix\n",
346 | "adata.layers['log1pcpm']"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": 11,
352 | "id": "dc70f25f",
353 | "metadata": {},
354 | "outputs": [
355 | {
356 | "data": {
357 | "text/plain": [
358 | "array([[0. , 0. , 3.84258223, 4.40539198, 0. ],\n",
359 | " [0. , 0. , 4.16454502, 4.52873471, 0.42111822],\n",
360 | " [0. , 0. , 3.82509693, 3.56268307, 0. ],\n",
361 | " [0. , 0. , 3.93543299, 0. , 0. ],\n",
362 | " [0. , 0. , 5.40677164, 4.62215864, 0. ]])"
363 | ]
364 | },
365 | "execution_count": 11,
366 | "metadata": {},
367 | "output_type": "execute_result"
368 | }
369 | ],
370 | "source": [
371 | "# For sparse matrix `M`, `M.toarray()` to convert it to dense array\n",
372 | "adata.layers['log1pcpm'][:5,:5].toarray()"
373 | ]
374 | },
375 | {
376 | "cell_type": "markdown",
377 | "id": "f8b21c66",
378 | "metadata": {},
379 | "source": [
380 | "**Parting notes:**\n",
381 | "\n",
382 | "The anndata object created in this way has a few different fields that we will end up using with PERSIST\n",
383 | "1. The raw counts are in `adata.X`\n",
384 | "2. The normalized counts (log1p of CPM values) are in `adata.layers['log1pcpm']`\n",
385 | "3. All metadata (cell type labels etc. for supervised mode in PERSIST) is in `adata.obs`\n",
386 | "4. A coarse selection of genes is in `adata.var['highly_variable']`\n",
387 | "5. Marker genes defined by Tasic et al. are indicated by `adata.var['markers']`"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": 12,
393 | "id": "53129c25",
394 | "metadata": {},
395 | "outputs": [],
396 | "source": [
397 | "# adata_hvg is a view. We'll convert it to a new AnnData object and write it out. \n",
398 | "adata_hvg = ad.AnnData(X=adata.X,\n",
399 | " obs=adata.obs, \n",
400 | " var=adata.var[['highly_variable']],\n",
401 | " layers=adata.layers, uns=adata.uns)\n",
402 | "adata_hvg.write('../data/VISp_filtered_cells.h5ad')"
403 | ]
404 | }
405 | ],
406 | "metadata": {
407 | "kernelspec": {
408 | "display_name": "Python 3.8.15 ('persist')",
409 | "language": "python",
410 | "name": "python3"
411 | },
412 | "language_info": {
413 | "codemirror_mode": {
414 | "name": "ipython",
415 | "version": 3
416 | },
417 | "file_extension": ".py",
418 | "mimetype": "text/x-python",
419 | "name": "python",
420 | "nbconvert_exporter": "python",
421 | "pygments_lexer": "ipython3",
422 | "version": "3.11.0"
423 | },
424 | "vscode": {
425 | "interpreter": {
426 | "hash": "660691d0bb1e24e4a68343475da76b12317d18d7509dab1fd0158534dd4eebe4"
427 | }
428 | }
429 | },
430 | "nbformat": 4,
431 | "nbformat_minor": 5
432 | }
433 |
--------------------------------------------------------------------------------
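
The fields saved at the end of this notebook map directly onto the package's dataset class. A hedged sketch (assuming the `VISp_filtered_cells.h5ad` file written above, and using `cell_types_25` as the supervised target) of turning the AnnData object into `ExpressionDataset` objects:

```python
import anndata as ad
import numpy as np
from persist import ExpressionDataset
from persist.data import split_data

adata = ad.read_h5ad('../data/VISp_filtered_cells.h5ad')

# Inputs: normalized counts restricted to the coarse HVG pre-selection.
hvg_cols = np.where(adata.var['highly_variable'].values)[0]
X = adata.layers['log1pcpm'][:, hvg_cols]  # sparse CSR is supported

# Outputs: integer-coded cell type labels for the supervised mode.
labels = adata.obs['cell_types_25'].astype('category').cat.codes.values

# Split row indices so that inputs and labels stay aligned.
train_inds, val_inds, test_inds = split_data(np.arange(adata.n_obs))
train = ExpressionDataset(X[train_inds], labels[train_inds])
val = ExpressionDataset(X[val_inds], labels[val_inds])
test = ExpressionDataset(X[test_inds], labels[test_inds])
```
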
/persist/__init__.py:
--------------------------------------------------------------------------------
1 | from . import data
2 | from . import utils
3 | from . import layers
4 | from . import models
5 | from . import selection
6 | from persist.selection import PERSIST
7 | from persist.data import GeneSet, load_set
8 | from persist.data import ExpressionDataset, HDF5ExpressionDataset
9 | from persist.utils import HurdleLoss, MSELoss, Accuracy
10 |
--------------------------------------------------------------------------------
/persist/data.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import pickle
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | from sklearn.model_selection import train_test_split
6 | from scipy.sparse import issparse
7 |
8 |
9 | class ExpressionDataset(Dataset):
10 | '''
11 | Gene expression dataset capable of using subsets of inputs.
12 |
13 | Args:
14 | data: array of inputs with size (samples, dim).
15 | output: array of outputs with size (samples,) for classification labels or (samples, dim) for regression targets.
16 | '''
17 |
18 | def __init__(self, data, output):
19 | self.input_size = data.shape[1]
20 | self._data = data.astype(np.float32)
21 | if len(output.shape) == 1:
22 | # Output is classification labels.
23 | self.output_size = len(np.unique(output))
24 | self._output = output.astype(np.int_)
25 | else:
26 | # Output is regression values.
27 | self.output_size = output.shape[1]
28 | self._output = output.astype(np.float32)
29 | self.set_inds(None)
30 | self.set_output_inds(None)
31 |
32 | # patch to work with sparse data without converting to dense beforehand
33 | self.data_issparse = issparse(data)
34 | self.output_issparse = issparse(output)
35 |
36 | def set_inds(self, inds, delete_remaining=False):
37 | '''
38 | Set input indices to be returned.
39 |
40 | Args:
41 | inds: list/array of selected indices.
42 | delete_remaining: whether to permanently delete the indices that are
43 | not selected.
44 | '''
45 | if inds is None:
46 | assert not delete_remaining
47 | inds = np.arange(self._data.shape[1])
48 |
49 | # Set input and inds.
50 | self.inds = inds
51 | self.data = self._data
52 | self.input_size = len(inds)
53 | else:
54 | # Verify inds.
55 | inds = np.sort(inds)
56 | assert len(inds) > 0
57 | assert (inds[0] >= 0) and (inds[-1] < self._data.shape[1])
58 | assert np.all(np.unique(inds, return_counts=True)[1] == 1)
59 | if len(inds) == self.input_size:
60 | assert not delete_remaining
61 |
62 | self.input_size = len(inds)
63 | if delete_remaining:
64 | # Reset data and input size.
65 | self._data = self._data[:, inds]
66 | self.data = self._data
67 | self.inds = np.arange(len(inds))
68 | else:
69 | # Set input and inds.
70 | self.inds = inds
71 | self.data = self._data[:, inds]
72 |
73 | def set_output_inds(self, inds, delete_remaining=False):
74 | '''
75 | Set output inds to be returned. Only for use with multivariate outputs.
76 |
77 | Args:
78 | inds: list/array of selected indices.
79 | delete_remaining: whether to permanently delete the indices that are
80 | not selected.
81 | '''
82 | if inds is None:
83 | assert not delete_remaining
84 |
85 | # Set output and inds.
86 | self.output = self._output
87 | if len(self._output.shape) == 1:
88 | # Classification labels.
89 | self.output_size = len(np.unique(self._output))
90 | self.output_inds = None
91 | else:
92 | # Regression labels.
93 | self.output_size = self._output.shape[1]
94 | self.output_inds = np.arange(self.output_size)
95 | else:
96 | # Verify that there are multiple output inds.
97 | assert len(self._output.shape) == 2
98 |
99 | # Verify inds.
100 | inds = np.sort(inds)
101 | assert len(inds) > 0
102 | assert (inds[0] >= 0) and (inds[-1] < self._output.shape[1])
103 | assert np.all(np.unique(inds, return_counts=True)[1] == 1)
104 | if len(inds) == self.output_size:
105 | assert not delete_remaining
106 |
107 | self.output_size = len(inds)
108 | if delete_remaining:
109 | # Reset output and output size.
110 | self._output = self._output[:, inds]
111 | self.output = self._output
112 | self.output_inds = np.arange(len(inds))
113 | else:
114 | # Set output and inds.
115 | self.output_inds = inds
116 | self.output = self._output[:, inds]
117 |
118 | @property
119 | def max_input_size(self):
120 | return self._data.shape[1]
121 |
122 | def __len__(self):
123 | return self._data.shape[0]
124 |
125 | def __getitem__(self, index):
126 | if self.data_issparse:
127 | data = self.data[index].toarray().flatten()
128 | else:
129 | data = self.data[index]
130 |
131 | if self.output_issparse:
132 | output = self.output[index].toarray().flatten()
133 | else:
134 | output = self.output[index]
135 | return data, output
136 |
137 |
138 | class HDF5ExpressionDataset(Dataset):
139 | '''
140 | Dataset wrapper, capable of using subsets of inputs.
141 |
142 | Args:
143 | filename: HDF5 filename.
144 | data_name: key for data array.
145 | label_name: key for labels.
146 | sample_inds: list of indices for rows to be sampled.
147 | initialize: whether to initialize by opening the HDF5 file. This should
148 | only be done when using a data loader with no worker threads.
149 | '''
150 |
151 | def __init__(self, filename, data_name, label_name,
152 | sample_inds=None, initialize=False):
153 | # Set up data variables.
154 | self.filename = filename
155 | self.data_name = data_name
156 | self.label_name = label_name
157 |
158 | # Set sample inds.
159 | hf = h5py.File(filename, 'r')
160 | data = hf[self.data_name]
161 | labels = hf[self.label_name]
162 | if sample_inds is None:
163 | sample_inds = np.arange(len(data))
164 | self.sample_inds = sample_inds
165 |
166 | # Set input, output size.
167 | self.input_size = data.shape[1]
168 | if labels.ndim == 1:
169 | # Classification labels.
170 | self.output_size = len(np.unique(labels))
171 | self.multiple_outputs = False
172 | else:
173 | # Regression labels.
174 | self.output_size = labels.shape[1]
175 | self.multiple_outputs = True
176 | hf.close()
177 |
178 | # Set input inds.
179 | self.all_inds = np.arange(self.input_size)
180 | self.set_inds(None)
181 |
182 | # Set output inds.
183 | if self.multiple_outputs:
184 | self.all_output_inds = np.arange(self.output_size)
185 | else:
186 | self.all_output_inds = None
187 | self.set_output_inds(None)
188 |
189 | # Initialize.
190 | if initialize:
191 | self.init_worker(0)
192 |
193 | def set_inds(self, inds, delete_remaining=False):
194 | '''
195 | Set input indices to be returned.
196 |
197 | Args:
198 | inds: list/array of selected indices.
199 | delete_remaining: whether to permanently delete the indices that are
200 | not selected.
201 | '''
202 | if inds is None:
203 | assert not delete_remaining
204 | self.inds = np.arange(len(self.all_inds))
205 | self.relative_inds = self.all_inds
206 | self.input_size = len(self.all_inds)
207 | else:
208 | # Verify inds.
209 | inds = np.sort(inds)
210 | assert len(inds) > 0
211 | assert (inds[0] >= 0) and (inds[-1] < len(self.all_inds))
212 | assert np.all(np.unique(inds, return_counts=True)[1] == 1)
213 |
214 | self.input_size = len(inds)
215 | if delete_remaining:
216 | self.inds = np.arange(len(inds))
217 | self.all_inds = self.all_inds[inds]
218 | self.relative_inds = self.all_inds
219 | else:
220 | self.inds = inds
221 | self.relative_inds = self.all_inds[inds]
222 |
223 | def set_output_inds(self, inds, delete_remaining=False):
224 | '''
225 | Set output inds to be returned. Only for use with multivariate outputs.
226 |
227 | Args:
228 | inds: list/array of selected indices.
229 | delete_remaining: whether to permanently delete the indices that are
230 | not selected.
231 | '''
232 | if inds is None:
233 | assert not delete_remaining
234 | if self.multiple_outputs:
235 | self.output_inds = np.arange(len(self.all_output_inds))
236 | self.output_relative_inds = self.all_output_inds
237 | self.output_size = len(self.all_output_inds)
238 | else:
239 | self.output_inds = None
240 | self.output_relative_inds = None
241 | else:
242 | # Verify that there are multiple output inds.
243 | assert self.multiple_outputs
244 |
245 | # Verify inds.
246 | inds = np.sort(inds)
247 | assert len(inds) > 0
248 | assert (inds[0] >= 0) and (inds[-1] < len(self.all_output_inds))
249 | assert np.all(np.unique(inds, return_counts=True)[1] == 1)
250 | if len(inds) == self.output_size:
251 | assert not delete_remaining
252 |
253 | self.output_size = len(inds)
254 | if delete_remaining:
255 | self.output_inds = np.arange(len(inds))
256 | self.all_output_inds = self.all_output_inds[inds]
257 | self.output_relative_inds = self.all_output_inds
258 | else:
259 | self.output_inds = inds
260 | self.output_relative_inds = self.all_output_inds[inds]
261 |
262 | def init_worker(self, worker_id):
263 | '''Initialize worker in data loader thread.'''
264 | self.h5 = h5py.File(self.filename, 'r', swmr=True)
265 |
266 | @property
267 | def max_input_size(self):
268 | return len(self.all_inds)
269 |
270 | def __len__(self):
271 | return len(self.sample_inds)
272 |
273 | def __getitem__(self, index):
274 | # Possibly initialize worker to open HDF5 file.
275 | if not hasattr(self, 'h5'):
276 | self.init_worker(0)
277 |
278 | index = self.sample_inds[index]
279 | data = self.h5[self.data_name][index][self.relative_inds]
280 | labels = self.h5[self.label_name][index]
281 | if self.output_relative_inds is not None:
282 | labels = labels[self.output_relative_inds]
283 | return data, labels
284 |
285 |
286 | def split_data(data, seed=123, val_portion=0.1, test_portion=0.1):
287 | '''Split data into train, val, test.'''
288 | N = data.shape[0]
289 | N_val = int(val_portion * N)
290 | N_test = int(test_portion * N)
291 | train, test = train_test_split(data, test_size=N_test, random_state=seed)
292 | train, val = train_test_split(train, test_size=N_val, random_state=seed+1)
293 | return train, val, test
294 |
295 |
296 | def bootstrapped_dataset(dataset, seed=None):
297 | '''Sample a bootstrapped dataset.'''
298 | if seed is not None:
299 | np.random.seed(seed)
300 | if isinstance(dataset, ExpressionDataset):
301 | data = dataset.data
302 | labels = dataset.output
303 | N = len(data)
304 | inds = np.random.choice(N, size=N, replace=True)
305 | return ExpressionDataset(data[inds], labels[inds])
306 | elif isinstance(dataset, HDF5ExpressionDataset):
307 | inds = dataset.sample_inds
308 | inds = np.random.choice(inds, size=len(inds), replace=True)
309 | return HDF5ExpressionDataset(dataset.filename, dataset.data_name,
310 | dataset.label_name, sample_inds=inds)
311 | else:
312 | raise ValueError('dataset must be ExpressionDataset or '
313 | 'HDF5ExpressionDataset')
314 |
315 |
316 | class GeneSet:
317 | '''
318 | Set of genes, represented by their indices and names.
319 |
320 | Args:
321 | inds: gene indices.
322 | names: gene names.
323 | '''
324 |
325 | def __init__(self, inds, names):
326 | self.inds = inds
327 | self.names = names
328 |
329 | def save(self, filename):
330 | '''Save object by pickling.'''
331 | with open(filename, 'wb') as f:
332 | pickle.dump(self, f)
333 |
334 |
335 | def load_set(filename):
336 | '''Load GeneSet object.'''
337 | with open(filename, 'rb') as f:
338 | subset = pickle.load(f)
339 | if isinstance(subset, GeneSet):
340 | return subset
341 | else:
342 | raise ValueError('object is not GeneSet')
343 |
--------------------------------------------------------------------------------
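
To make the `set_inds` semantics above concrete, a small self-contained example (toy arrays, purely illustrative) showing the difference between a reversible restriction and `delete_remaining=True`:

```python
import numpy as np
from persist.data import ExpressionDataset

X = np.arange(20, dtype=np.float32).reshape(4, 5)  # 4 samples, 5 features
y = np.array([0, 1, 0, 1])
ds = ExpressionDataset(X, y)
print(len(ds), ds.input_size)  # 4 5
print(ds[0])                   # (array([0., 1., 2., 3., 4.], ...), 0)

# Reversible restriction: only columns 1 and 3 are returned, but the
# full matrix is retained and can be restored with set_inds(None).
ds.set_inds([1, 3])
print(ds.input_size, ds[0][0])  # 2 [1. 3.]
ds.set_inds(None)

# Destructive restriction: the unselected columns are dropped for good.
ds.set_inds([1, 3], delete_remaining=True)
print(ds.max_input_size)  # 2
```
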
/persist/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def clamp_probs(probs):
7 | '''Clamp probabilities to ensure stable logs.'''
8 | eps = torch.finfo(probs.dtype).eps
9 | return torch.clamp(probs, min=eps, max=1-eps)
10 |
11 |
12 | def concrete_sample(logits, temperature, shape=torch.Size([])):
13 | '''
14 | Sampling for Concrete distribution (see eq. 10 of Maddison et al., 2017).
15 |
16 | Args:
17 | logits: Concrete logits parameters.
18 | temperature: Concrete temperature parameter.
19 | shape: sample shape.
20 | '''
21 | uniform_shape = torch.Size(shape) + logits.shape
22 | u = clamp_probs(torch.rand(uniform_shape, dtype=torch.float32,
23 | device=logits.device))
24 | gumbels = - torch.log(- torch.log(u))
25 | scores = (logits + gumbels) / temperature
26 | return scores.softmax(dim=-1)
27 |
28 |
29 | def bernoulli_concrete_sample(logits, temperature, shape=torch.Size([])):
30 | '''
31 | Sampling for BinConcrete distribution (see PyTorch source code, differs
32 | slightly from eq. 16 of Maddison et al., 2017).
33 |
34 | Args:
35 | logits: tensor of BinConcrete logits parameters.
36 | temperature: BinConcrete temperature parameter.
37 | shape: sample shape.
38 | '''
39 | uniform_shape = torch.Size(shape) + logits.shape
40 | u = clamp_probs(torch.rand(uniform_shape, dtype=torch.float32,
41 | device=logits.device))
42 | return torch.sigmoid((F.logsigmoid(logits) - F.logsigmoid(-logits)
43 | + torch.log(u) - torch.log(1 - u)) / temperature)
44 |
45 |
46 | class BinaryMask(nn.Module):
47 | '''
48 | Input layer that selects features by learning a k-hot mask.
49 |
50 | Args:
51 | input_size: number of inputs.
52 | num_selections: number of features to select.
53 | temperature: temperature for Concrete samples.
54 | gamma: used to map learned parameters to logits (helps convergence speed).
55 | '''
56 |
57 | def __init__(self,
58 | input_size,
59 | num_selections,
60 | temperature=10.0,
61 | gamma=1/3):
62 | super().__init__()
63 | self._logits = nn.Parameter(
64 | torch.zeros(num_selections, input_size, dtype=torch.float32,
65 | requires_grad=True))
66 | self.input_size = input_size
67 | self.num_selections = num_selections
68 | self.output_size = input_size
69 | self.temperature = temperature
70 | self.gamma = gamma
71 |
72 | @property
73 | def logits(self):
74 | return self._logits / self.gamma
75 |
76 | @property
77 | def probs(self):
78 | return (self.logits).softmax(dim=1)
79 |
80 | def sample(self, n_samples):
81 | '''Sample approximate k-hot vectors.'''
82 | samples = concrete_sample(
83 | self.logits, self.temperature, torch.Size([n_samples]))
84 | return torch.max(samples, dim=-2).values
85 |
86 | def forward(self, x):
87 | '''Sample and apply mask.'''
88 | m = self.sample(len(x))
89 | x = x * m
90 | return x, m
91 |
92 | def get_inds(self):
93 | '''Get selected indices.'''
94 | inds = torch.argmax(self.logits, dim=1)
95 | return torch.sort(inds)[0].cpu().data.numpy()
96 |
97 | def extra_repr(self):
98 | return (f'input_size={self.input_size}, temperature={self.temperature},'
99 | f' num_selections={self.num_selections}')
100 |
101 |
102 | class BinaryGates(nn.Module):
103 | '''
104 | Input layer that selects features by learning binary gates for each feature,
105 | similar to [1].
106 |
107 | [1] Dropout Feature Ranking for Deep Learning Models (Chang et al., 2017)
108 |
109 | Args:
110 | input_size: number of inputs.
111 | temperature: temperature for BinConcrete samples.
112 | init: initial value for each gate's probability of being 1.
113 | gamma: used to map learned parameters to logits (helps convergence speed).
114 | '''
115 |
116 | def __init__(self,
117 | input_size,
118 | temperature=0.1,
119 | init=0.99,
120 | gamma=1/2):
121 | super().__init__()
122 | init_logit = - torch.log(1 / torch.tensor(init) - 1) * gamma
123 | self._logits = nn.Parameter(torch.full(
124 | (input_size,), init_logit, dtype=torch.float32, requires_grad=True))
125 | self.input_size = input_size
126 | self.output_size = input_size
127 | self.temperature = temperature
128 | self.gamma = gamma
129 |
130 | @property
131 | def logits(self):
132 | return self._logits / self.gamma
133 |
134 | @property
135 | def probs(self):
136 | return torch.sigmoid(self.logits)
137 |
138 | def sample(self, n_samples):
139 | '''Sample approximate binary masks.'''
140 | return bernoulli_concrete_sample(
141 | self.logits, self.temperature, torch.Size([n_samples]))
142 |
143 | def forward(self, x):
144 | '''Sample and apply mask.'''
145 | m = self.sample(len(x))
146 | x = x * m
147 | return x, m
148 |
149 | def get_inds(self, num_features=None, threshold=None):
150 | '''
151 | Get selected indices.
152 |
153 | Args:
154 | num_features: number of top features to return.
155 | threshold: probability threshold for determining selected features.
156 | '''
157 | if (num_features is None) == (threshold is None):
158 | raise ValueError('exactly one of num_features and threshold must be'
159 | ' specified')
160 |
161 | if num_features is not None:
162 | inds = torch.argsort(self.probs)[-num_features:]
163 | else:
164 | inds = (self.probs > threshold).nonzero()[:, 0]
165 | return torch.sort(inds)[0].cpu().data.numpy()
166 |
167 | def extra_repr(self):
168 | return f'input_size={self.input_size}, temperature={self.temperature}'
169 |
170 |
171 | class ConcreteSelector(nn.Module):
172 | '''
173 | Input layer that selects features by learning a binary matrix, based on [2].
174 |
175 | [2] Concrete Autoencoders for Differentiable Feature Selection and
176 | Reconstruction (Balin et al., 2019)
177 |
178 | Args:
179 | input_size: number of inputs.
180 | num_selections: number of features to select.
181 | temperature: temperature for Concrete samples.
182 | gamma: used to map learned parameters to logits (helps convergence speed).
183 | '''
184 |
185 | def __init__(self,
186 | input_size,
187 | num_selections,
188 | temperature=10.0,
189 | gamma=1/3):
190 | super().__init__()
191 | self._logits = nn.Parameter(
192 | torch.zeros(num_selections, input_size, dtype=torch.float32,
193 | requires_grad=True))
194 | self.input_size = input_size
195 | self.num_selections = num_selections
196 | self.output_size = num_selections
197 | self.temperature = temperature
198 | self.gamma = gamma
199 |
200 | @property
201 | def logits(self):
202 | return self._logits / self.gamma
203 |
204 | @property
205 | def probs(self):
206 | return self.logits.softmax(dim=1)
207 |
208 | def sample(self, n_samples):
209 | '''Sample approximate binary matrices.'''
210 | return concrete_sample(
211 | self.logits, self.temperature, torch.Size([n_samples]))
212 |
213 | def forward(self, x):
214 | '''Sample and apply selector matrix.'''
215 | M = self.sample(len(x))
216 | x = torch.matmul(M, x.unsqueeze(2)).squeeze(2)
217 | return x, M
218 |
219 | def get_inds(self):
220 | '''Get selected indices.'''
221 | inds = torch.argmax(self.logits, dim=1)
222 | return torch.sort(inds)[0].cpu().data.numpy()
223 |
224 | def extra_repr(self):
225 | return (f'input_size={self.input_size}, temperature={self.temperature},'
226 | f' num_selections={self.num_selections}')
227 |
--------------------------------------------------------------------------------
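
A quick illustration (random inputs, purely illustrative) of how these layers behave. `BinaryMask` multiplies the input by an approximately k-hot mask and preserves dimensionality, while `ConcreteSelector` projects down to the selected features; annealing the temperature toward zero pushes the samples toward exact 0/1 values, which is what `models.input_layer_converged` checks:

```python
import torch
from persist.layers import BinaryMask, ConcreteSelector

torch.manual_seed(0)
x = torch.randn(8, 100)  # batch of 8 inputs with 100 features

# BinaryMask: output keeps the input dimensionality.
mask_layer = BinaryMask(input_size=100, num_selections=10)
out, m = mask_layer(x)
print(out.shape, m.shape)  # torch.Size([8, 100]) torch.Size([8, 100])

# ConcreteSelector: output is reduced to num_selections features.
sel_layer = ConcreteSelector(input_size=100, num_selections=10)
out, M = sel_layer(x)
print(out.shape, M.shape)  # torch.Size([8, 10]) torch.Size([8, 10, 100])

# At low temperature, each sampled row is close to one-hot.
mask_layer.temperature = 0.01
print(mask_layer.sample(4).max(dim=-1).values)  # entries near 1
```
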
/persist/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import numpy as np
5 | from persist import layers
6 | from copy import deepcopy
7 | from tqdm.auto import tqdm
8 | from torch.utils.data import DataLoader
9 |
10 |
11 | def restore_parameters(model, best_model):
12 | '''Copy parameter values from best_model to model.'''
13 | for params, best_params in zip(model.parameters(), best_model.parameters()):
14 | params.data = best_params.data
15 |
16 |
17 | def input_layer_penalty(input_layer, m):
18 | if isinstance(input_layer, layers.BinaryGates):
19 | return torch.mean(torch.sum(m, dim=1))
20 | else:
21 | raise ValueError('only BinaryGates layer has penalty')
22 |
23 |
24 | def input_layer_fix(input_layer):
25 | '''Fix collisions in the input layer.'''
26 | required_fix = False
27 |
28 | if isinstance(input_layer, (layers.BinaryMask, layers.ConcreteSelector)):
29 | # Extract logits.
30 | logits = input_layer._logits
31 | argmax = torch.argmax(logits, dim=1).cpu().data.numpy()
32 |
33 | # Locate collisions and reinitialize.
34 | for i in range(len(argmax) - 1):
35 | if argmax[i] in argmax[i+1:]:
36 | required_fix = True
37 | logits.data[i] = torch.randn(
38 | logits[i].shape, dtype=logits.dtype, device=logits.device)
39 | return required_fix
40 |
41 | return required_fix
42 |
43 |
44 | def input_layer_summary(input_layer, n_samples=256):
45 | '''Generate summary string for input layer's convergence.'''
46 | with torch.no_grad():
47 | if isinstance(input_layer, layers.BinaryMask):
48 | m = input_layer.sample(n_samples)
49 | mean = torch.mean(m, dim=0)
50 | sorted_mean = torch.sort(mean, descending=True).values
51 | relevant = sorted_mean[:input_layer.num_selections]
52 | return 'Max = {:.2f}, Mean = {:.2f}, Min = {:.2f}'.format(
53 | relevant[0].item(), torch.mean(relevant).item(),
54 | relevant[-1].item())
55 |
56 | elif isinstance(input_layer, layers.ConcreteSelector):
57 | M = input_layer.sample(n_samples)
58 | mean = torch.mean(M, dim=0)
59 | relevant = torch.max(mean, dim=1).values
60 | return 'Max = {:.2f}, Mean = {:.2f}, Min = {:.2f}'.format(
61 | torch.max(relevant).item(), torch.mean(relevant).item(),
62 | torch.min(relevant).item())
63 |
64 | elif isinstance(input_layer, layers.BinaryGates):
65 | m = input_layer.sample(n_samples)
66 | mean = torch.mean(m, dim=0)
67 | dist = torch.min(mean, 1 - mean)
68 | return 'Mean dist = {:.2f}, Max dist = {:.2f}, Num sel = {}'.format(
69 | torch.mean(dist).item(),
70 | torch.max(dist).item(),
71 | int(torch.sum((mean > 0.5).float()).item()))
72 |
73 |
74 | def input_layer_converged(input_layer, tol=1e-2, n_samples=256):
75 | '''Determine whether the input layer has converged.'''
76 | with torch.no_grad():
77 | if isinstance(input_layer, layers.BinaryMask):
78 | m = input_layer.sample(n_samples)
79 | mean = torch.mean(m, dim=0)
80 | return (
81 | torch.sort(mean).values[-input_layer.num_selections].item()
82 | > 1 - tol)
83 |
84 | elif isinstance(input_layer, layers.BinaryGates):
85 | m = input_layer.sample(n_samples)
86 | mean = torch.mean(m, dim=0)
87 | return torch.max(torch.min(mean, 1 - mean)).item() < tol
88 |
89 | elif isinstance(input_layer, layers.ConcreteSelector):
90 | M = input_layer.sample(n_samples)
91 | mean = torch.mean(M, dim=0)
92 | return torch.min(torch.max(mean, dim=1).values).item() > 1 - tol
93 |
94 |
95 | def warmstart_model(model, inds):
96 | '''
97 | Create model for subset of features by removing parameters.
98 |
99 | Args:
100 | model: model to copy.
101 | inds: indices for features to retain.
102 | '''
103 | sub_model = deepcopy(model.mlp)
104 | device = next(model.parameters()).device
105 |
106 | # Resize input layer.
107 | layer = sub_model.fc[0]
108 | new_layer = nn.Linear(len(inds), layer.out_features).to(device)
109 | new_layer.weight.data = layer.weight[:, inds]
110 | new_layer.bias.data = layer.bias
111 | sub_model.fc[0] = new_layer
112 |
113 | return sub_model
114 |
115 |
116 | class MLP(nn.Module):
117 | '''
118 | Multilayer perceptron (MLP) model.
119 |
120 | Args:
121 | input_size: number of inputs.
122 | output_size: number of outputs.
123 | hidden: list of hidden layer widths.
124 | activation: nonlinearity between layers.
125 | '''
126 |
127 | def __init__(self,
128 | input_size,
129 | output_size,
130 | hidden,
131 | activation=nn.ReLU()):
132 | super().__init__()
133 |
134 | # Fully connected layers.
135 | self.input_size = input_size
136 | self.output_size = output_size
137 | fc_layers = [nn.Linear(d_in, d_out) for d_in, d_out in
138 | zip([input_size] + hidden, hidden + [output_size])]
139 | self.fc = nn.ModuleList(fc_layers)
140 |
141 | # Activation function.
142 | self.activation = activation
143 |
144 | def forward(self, x):
145 | for fc in self.fc[:-1]:
146 | x = fc(x)
147 | x = self.activation(x)
148 |
149 | return self.fc[-1](x)
150 |
151 | def fit(self,
152 | train_dataset,
153 | val_dataset,
154 | mbsize,
155 | max_nepochs,
156 | loss_fn,
157 | lr=1e-3,
158 | min_lr=1e-5,
159 | lr_factor=0.5,
160 | optimizer='Adam',
161 | lookback=10,
162 | bar=False,
163 | verbose=True):
164 | '''
165 | Train the model.
166 |
167 | Args:
168 | train_dataset: training dataset.
169 | val_dataset: validation dataset.
170 | mbsize: minibatch size.
171 | max_nepochs: maximum number of epochs.
172 | loss_fn: loss function.
173 | lr: learning rate.
174 | min_lr: minimum learning rate.
175 | lr_factor: learning rate decrease factor.
176 | optimizer: optimizer type.
177 | lookback: number of epochs to wait for improvement before stopping.
178 | bar: whether to display tqdm progress bar.
179 | verbose: verbosity.
180 | '''
181 | # Set up optimizer (the optimizer argument is currently unused; Adam is always used).
182 | optimizer = optim.Adam(self.parameters(), lr=lr)
183 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
184 | optimizer, factor=lr_factor, patience=lookback // 2, min_lr=min_lr,
185 | verbose=verbose)
186 |
187 | # Set up data loaders.
188 | has_init = hasattr(train_dataset, 'init_worker')
189 | if has_init:
190 | train_init = train_dataset.init_worker
191 | val_init = val_dataset.init_worker
192 | else:
193 | train_init = None
194 | val_init = None
195 | train_loader = DataLoader(
196 | train_dataset, batch_size=mbsize, shuffle=True, drop_last=True,
197 | worker_init_fn=train_init, num_workers=4)
198 | val_loader = DataLoader(
199 | val_dataset, batch_size=mbsize, worker_init_fn=val_init,
200 | num_workers=4)
201 |
202 | # Determine device.
203 | device = next(self.parameters()).device
204 |
205 | # For tracking loss.
206 | self.train_loss = []
207 | self.val_loss = []
208 | best_model = None
209 | best_loss = np.inf
210 | best_epoch = None
211 |
212 | # Bar setup.
213 | if bar:
214 | tqdm_bar = tqdm(
215 | total=max_nepochs, desc='Training epochs', leave=True)
216 |
217 | # Begin training.
218 | for epoch in range(max_nepochs):
219 | # For tracking mean train loss.
220 | train_loss = 0
221 | N = 0
222 |
223 | for x, y in train_loader:
224 | # Move to device.
225 | x = x.to(device)
226 | y = y.to(device)
227 |
228 | # Forward pass.
229 | pred = self.forward(x)
230 |
231 | # Calculate loss.
232 | loss = loss_fn(pred, y)
233 |
234 | # Update mean train loss.
235 | train_loss = (
236 | (N * train_loss + mbsize * loss.item()) / (N + mbsize))
237 | N += mbsize
238 |
239 | # Gradient step.
240 | loss.backward()
241 | optimizer.step()
242 | self.zero_grad()
243 |
244 | # Check progress.
245 | with torch.no_grad():
246 | # Calculate loss.
247 | self.eval()
248 | val_loss = self.validate(val_loader, loss_fn).item()
249 | self.train()
250 |
251 | # Update learning rate.
252 | scheduler.step(val_loss)
253 |
254 | # Record loss.
255 | self.train_loss.append(train_loss)
256 | self.val_loss.append(val_loss)
257 |
258 | if verbose:
259 | print(f'{"-" * 8}Epoch = {epoch + 1}{"-" * 8}')
260 | print(f'Train loss = {train_loss:.4f}')
261 | print(f'Val loss = {val_loss:.4f}')
262 |
263 | # Update bar.
264 | if bar:
265 | tqdm_bar.update(1)
266 |
267 | # Check for early stopping.
268 | if val_loss < best_loss:
269 | best_loss = val_loss
270 | best_model = deepcopy(self)
271 | best_epoch = epoch
272 | elif (epoch - best_epoch) == lookback:
273 | # Skip bar to end.
274 | if bar:
275 | tqdm_bar.n = max_nepochs
276 |
277 | if verbose:
278 | print('Stopping early')
279 | break
280 |
281 | # Restore model parameters.
282 | restore_parameters(self, best_model)
283 |
284 | def validate(self, loader, loss_fn):
285 | '''Calculate average loss.'''
286 | device = next(self.parameters()).device
287 | mean_loss = 0
288 | N = 0
289 |
290 | with torch.no_grad():
291 | for x, y in loader:
292 | # Move to GPU.
293 | x = x.to(device=device)
294 | y = y.to(device=device)
295 | n = len(x)
296 |
297 | # Calculate loss.
298 | pred = self.forward(x)
299 | loss = loss_fn(pred, y)
300 | mean_loss = (N * mean_loss + n * loss) / (N + n)
301 | N += n
302 |
303 | return mean_loss
304 |
305 |
306 | class SelectorMLP(nn.Module):
307 | '''
308 | MLP model with embedded selector layer.
309 |
310 | Args:
311 | input_layer: selection layer type (e.g., BinaryMask).
312 | input_size: number of inputs.
313 | output_size: number of outputs.
314 | hidden: list of hidden layer widths.
315 | activation: nonlinearity between layers.
316 | preselected_inds: feature indices that are already selected.
317 | num_selections: number of features to select (for BinaryMask and
318 | ConcreteSelector layers).
319 | kwargs: additional arguments for input layers.
320 | '''
321 |
322 | def __init__(self,
323 | input_layer,
324 | input_size,
325 | output_size,
326 | hidden,
327 | activation=nn.ReLU(),
328 | preselected_inds=[],
329 | num_selections=None,
330 | **kwargs):
331 | # Verify arguments.
332 | super().__init__()
333 | if num_selections is None:
334 | if input_layer in ('binary_mask', 'concrete_selector'):
335 | raise ValueError(
336 | f'must specify num_selections for {input_layer} layer')
337 | else:
338 | if input_layer in ('binary_gates',):
339 | raise ValueError('num_selections cannot be specified for '
340 | f'{input_layer} layer')
341 |
342 | # Set up for pre-selected features.
343 | preselected_inds = np.sort(preselected_inds)
344 | assert len(preselected_inds) < input_size
345 | self.preselected = np.array(
346 | [i in preselected_inds for i in range(input_size)])
347 | preselected_size = len(preselected_inds)
348 | self.has_preselected = preselected_size > 0
349 |
350 | # Set up input layer.
351 | if input_layer == 'binary_mask':
352 | mlp_input_size = input_size
353 | self.input_layer = layers.BinaryMask(
354 | input_size - preselected_size, num_selections, **kwargs)
355 | elif input_layer == 'binary_gates':
356 | mlp_input_size = input_size
357 | self.input_layer = layers.BinaryGates(
358 | input_size - preselected_size, **kwargs)
359 | elif input_layer == 'concrete_selector':
360 | mlp_input_size = num_selections + preselected_size
361 | self.input_layer = layers.ConcreteSelector(
362 | input_size - preselected_size, num_selections, **kwargs)
363 | else:
364 | raise ValueError('unsupported input layer: {}'.format(input_layer))
365 |
366 | # Create MLP.
367 | self.mlp = MLP(mlp_input_size, output_size, hidden, activation)
368 |
369 | def forward(self, x):
370 | '''Apply input layer and return MLP output.'''
371 | if self.has_preselected:
372 | pre = x[:, self.preselected]
373 | x, m = self.input_layer(x[:, ~self.preselected])
374 | x = torch.cat([pre, x], dim=1)
375 | else:
376 | x, m = self.input_layer(x)
377 | pred = self.mlp(x)
378 | return pred, x, m
379 |
380 | def fit(self,
381 | train_dataset,
382 | val_dataset,
383 | lr,
384 | mbsize,
385 | max_nepochs,
386 | start_temperature,
387 | end_temperature,
388 | loss_fn,
389 | eta=0,
390 | lam=0,
391 | optimizer='Adam',
392 | lookback=10,
393 | bar=False,
394 | verbose=True):
395 | '''
396 | Train the model.
397 |
398 | Args:
399 | train_dataset: training dataset.
400 | val_dataset: validation dataset.
401 | lr: learning rate.
402 | mbsize: minibatch size.
403 | max_nepochs: maximum number of epochs.
404 | start_temperature: initial temperature for Concrete samples.
405 | end_temperature: final temperature (annealed downward during training).
406 | loss_fn: loss function.
407 | eta: penalty parameter for number of expressed genes.
408 | lam: penalty parameter for the number of selected genes (BinaryGates layer only).
409 | optimizer: optimizer type.
410 | lookback: number of epochs to wait for improvement before stopping.
411 | bar: whether to display tqdm progress bar.
412 | verbose: verbosity.
413 | '''
414 | # Verify arguments.
415 | if lam != 0:
416 | if not isinstance(self.input_layer, layers.BinaryGates):
417 | raise ValueError('lam should only be specified when using '
418 | 'BinaryGates layer')
419 | else:
420 | if isinstance(self.input_layer, layers.BinaryGates):
421 | raise ValueError('lam must be specified when using '
422 | 'BinaryGates layer')
423 | if eta > 0:
424 | if isinstance(self.input_layer, layers.BinaryGates):
425 | raise ValueError('eta cannot be specified when using '
426 | 'BinaryGates layer')
427 |
428 | if end_temperature > start_temperature:
429 | raise ValueError('temperature should be annealed downwards, must '
430 | 'have end_temperature <= start_temperature')
431 | elif end_temperature == start_temperature:
432 | loss_early_stopping = True
433 | else:
434 | loss_early_stopping = False
435 |
436 | # Set up optimizer.
437 | optimizer = optim.Adam(self.parameters(), lr=lr)
438 |
439 | # Set up data loaders.
440 | has_init = hasattr(train_dataset, 'init_worker')
441 | if has_init:
442 | train_init = train_dataset.init_worker
443 | val_init = val_dataset.init_worker
444 | else:
445 | train_init = None
446 | val_init = None
447 | train_loader = DataLoader(train_dataset, batch_size=mbsize,
448 | shuffle=True, drop_last=True,
449 | worker_init_fn=train_init, num_workers=4)
450 | val_loader = DataLoader(val_dataset, batch_size=mbsize,
451 | worker_init_fn=val_init, num_workers=4)
452 |
453 | # Determine device.
454 | device = next(self.parameters()).device
455 |
456 | # Set temperature and determine rate for decreasing.
457 | self.input_layer.temperature = start_temperature
458 | r = np.power(end_temperature / start_temperature,
459 | 1 / ((len(train_dataset) // mbsize) * max_nepochs))
460 |
461 | # For tracking loss.
462 | self.train_loss = []
463 | self.val_loss = []
464 | best_loss = np.inf
465 | best_epoch = -1
466 |
467 | # Bar setup.
468 | if bar:
469 | tqdm_bar = tqdm(
470 | total=max_nepochs, desc='Training epochs', leave=True)
471 |
472 | # Begin training.
473 | for epoch in range(max_nepochs):
474 | # For tracking mean train loss.
475 | train_loss = 0
476 | N = 0
477 |
478 | for x, y in train_loader:
479 | # Move to device.
480 | x = x.to(device)
481 | y = y.to(device)
482 |
483 | # Calculate loss.
484 | pred, x, m = self.forward(x)
485 | loss = loss_fn(pred, y)
486 |
487 | # Calculate penalty if necessary.
488 | if lam > 0:
489 | penalty = input_layer_penalty(self.input_layer, m)
490 | loss = loss + lam * penalty
491 |
492 | # Add expression penalty if necessary.
493 | if eta > 0:
494 | expressed = torch.mean(torch.sum(x, dim=1))
495 | loss = loss + eta * expressed
496 |
497 | # Update mean train loss.
498 | train_loss = (
499 | (N * train_loss + mbsize * loss.item()) / (N + mbsize))
500 | N += mbsize
501 |
502 | # Gradient step.
503 | loss.backward()
504 | optimizer.step()
505 | self.zero_grad()
506 |
507 | # Adjust temperature.
508 | self.input_layer.temperature *= r
509 |
510 | # Check progress.
511 | with torch.no_grad():
512 | # Calculate loss.
513 | self.eval()
514 | val_loss, val_expressed = self.validate(
515 | val_loader, loss_fn, lam, eta)
516 | val_loss, val_expressed = val_loss.item(), val_expressed.item()
517 | self.train()
518 |
519 | # Record loss.
520 | self.train_loss.append(train_loss)
521 | self.val_loss.append(val_loss)
522 |
523 | if verbose:
524 | print(f'{"-" * 8}Epoch = {epoch + 1}{"-" * 8}')
525 | print(f'Train loss = {train_loss:.4f}')
526 | print(f'Val loss = {val_loss:.4f}')
527 | if eta > 0:
528 | print(f'Mean expressed genes = {val_expressed:.4f}')
529 | print(input_layer_summary(self.input_layer))
530 |
531 | # Update bar.
532 | if bar:
533 | tqdm_bar.update(1)
534 |
535 | # Fix input layer if necessary.
536 | required_fix = input_layer_fix(self.input_layer)
537 |
538 | if not required_fix:
539 | # Stop early if input layer is converged.
540 | if input_layer_converged(self.input_layer, n_samples=mbsize):
541 | if verbose:
542 | print('Stopping early: input layer converged')
543 | break
544 |
545 | # Stop early if loss converged.
546 | if loss_early_stopping:
547 | if val_loss < best_loss:
548 | best_loss = val_loss
549 | best_epoch = epoch
550 | elif (epoch - best_epoch) == lookback:
551 | # Skip bar to end.
552 | if bar:
553 | tqdm_bar.n = max_nepochs
554 |
555 | if verbose:
556 | print('Stopping early: loss converged')
557 | break
558 |
559 | def validate(self, loader, loss_fn, lam, eta):
560 | '''Calculate average loss.'''
561 | device = next(self.parameters()).device
562 | mean_loss = 0
563 | mean_expressed = 0
564 | N = 0
565 | with torch.no_grad():
566 | for x, y in loader:
567 | # Move to GPU.
568 | x = x.to(device=device)
569 | y = y.to(device=device)
570 | n = len(x)
571 |
572 | # Calculate loss.
573 | pred, x, m = self.forward(x)
574 | loss = loss_fn(pred, y)
575 |
576 | # Add penalty term.
577 | if lam > 0:
578 | penalty = input_layer_penalty(self.input_layer, m)
579 | loss = loss + lam * penalty
580 |
581 | # Add expression penalty term.
582 | expressed = torch.mean(torch.sum(x, dim=1))
583 | if eta > 0:
584 | loss = loss + eta * expressed
585 |
586 | mean_loss = (N * mean_loss + n * loss) / (N + n)
587 | mean_expressed = (N * mean_expressed + n * expressed) / (N + n)
588 | N += n
589 |
590 | return mean_loss, mean_expressed
591 |
--------------------------------------------------------------------------------
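
To tie the pieces together, a hedged end-to-end sketch (synthetic data; small, illustrative settings) of training a `SelectorMLP` directly. `PERSIST.eliminate` in `persist/selection.py` does the analogous thing with the `binary_gates` layer; here the `binary_mask` layer with an explicit `num_selections` is used instead:

```python
import numpy as np
import torch.nn as nn
from persist.data import ExpressionDataset
from persist.models import SelectorMLP

# Synthetic data where the labels depend on features 3 and 17.
rng = np.random.default_rng(0)
X = rng.poisson(1.0, size=(512, 50)).astype(np.float32)
y = (X[:, 3] + X[:, 17] > 2).astype(np.int_)

train = ExpressionDataset(X[:400], y[:400])
val = ExpressionDataset(X[400:], y[400:])

model = SelectorMLP(input_layer='binary_mask',
                    input_size=50,
                    output_size=2,
                    hidden=[32, 32],
                    num_selections=2)
model.fit(train, val,
          lr=1e-3,
          mbsize=64,
          max_nepochs=20,
          start_temperature=10.0,
          end_temperature=0.01,
          loss_fn=nn.CrossEntropyLoss(),
          verbose=False)

# Indices the mask has converged toward (more epochs may be needed
# for full convergence on real data).
print(model.input_layer.get_inds())
```
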
/persist/selection.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | from persist import models, utils
5 |
6 |
7 | class PERSIST:
8 | '''
9 | Class for using the predictive and robust gene selection for spatial
10 | transcriptomics (PERSIST) method.
11 |
12 | Args:
13 | train_dataset: dataset of training examples (ExpressionDataset).
14 | val_dataset: dataset of validation examples.
15 | loss_fn: loss function, such as HurdleLoss(), nn.MSELoss() or
16 | nn.CrossEntropyLoss().
17 | device: torch device, such as torch.device('cuda', 0).
18 | preselected_inds: list of indices that must be selected.
19 | hidden: number of hidden units per layer (list of ints).
20 | activation: activation function between layers.
21 | '''
22 | def __init__(self,
23 | train_dataset,
24 | val_dataset,
25 | loss_fn,
26 | device,
27 | preselected_inds=[],
28 | hidden=[128, 128],
29 | activation=nn.ReLU()):
30 | # TODO add verification for dataset type.
31 | self.train = train_dataset
32 | self.val = val_dataset
33 | self.loss_fn = loss_fn
34 |
35 | # Architecture parameters.
36 | self.hidden = hidden
37 | self.activation = activation
38 |
39 | # Set device.
40 | assert isinstance(device, torch.device)
41 | self.device = device
42 |
43 | # Set preselected genes.
44 | self.preselected = np.sort(preselected_inds).astype(int)
45 |
46 | # Initialize candidate genes.
47 | self.set_genes()
48 |
49 | def get_genes(self):
50 | '''Get currently selected genes, not including preselected genes.'''
51 | return self.candidates
52 |
53 | def set_genes(self, candidates=None):
54 | '''Restrict the subset of genes.'''
55 | if candidates is None:
56 | # All genes but pre-selected ones.
57 | candidates = np.array(
58 | [i for i in range(self.train.max_input_size)
59 | if i not in self.preselected])
60 |
61 | else:
62 | # Ensure that candidates do not overlap with pre-selected genes.
63 | assert len(np.intersect1d(candidates, self.preselected)) == 0
64 | self.candidates = candidates
65 |
66 | # Set genes in datasets.
67 | included = np.sort(np.concatenate([candidates, self.preselected]))
68 | self.train.set_inds(included)
69 | self.val.set_inds(included)
70 |
71 | # Set relative indices for pre-selected genes.
72 | self.preselected_relative = np.array(
73 | [np.where(included == ind)[0][0] for ind in self.preselected])
74 |
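`set_genes` restricts both datasets to `included`, the sorted union of candidate and pre-selected genes, so the pre-selected indices must be re-expressed relative to the restricted input. A small worked example with made-up indices:

```python
import numpy as np

candidates = np.array([2, 5, 9])
preselected = np.array([0, 7])

# Sorted union, as in set_genes.
included = np.sort(np.concatenate([candidates, preselected]))
assert included.tolist() == [0, 2, 5, 7, 9]

# Position of each pre-selected gene within the restricted input.
preselected_relative = np.array(
    [np.where(included == ind)[0][0] for ind in preselected])
assert preselected_relative.tolist() == [0, 3]
```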
75 | def eliminate(self,
76 | target,
77 | lam_init=None,
78 | mbsize=64,
79 | max_nepochs=250,
80 | lr=1e-3,
81 | tol=0.2,
82 | start_temperature=10.0,
83 | end_temperature=0.01,
84 | optimizer='Adam',
85 | lookback=10,
86 | max_trials=10,
87 | bar=True,
88 | verbose=False):
89 | '''
90 | Narrow the set of candidate genes: train a model with the BinaryGates
91 | layer and annealed penalty to eliminate a large portion of the inputs.
92 |
93 | Args:
94 | target: target number of genes to select (in addition to pre-selected
95 | genes).
96 | lam_init: initial lambda value.
97 | mbsize: minibatch size.
98 | max_nepochs: maximum number of epochs.
99 | lr: learning rate.
100 | tol: tolerance around gene target number.
101 | start_temperature: starting temperature for BinConcrete samples.
102 | end_temperature: final temperature value.
103 | optimizer: optimization algorithm.
104 | lookback: number of epochs to wait for improvement before stopping.
105 | max_trials: maximum number of training rounds before returning an
106 | error.
107 | bar: whether to display tqdm progress bar for each round of training.
108 | verbose: verbosity.
109 | '''
110 | # Reset candidate genes.
111 | all_inds = np.arange(self.train.max_input_size)
112 | all_candidates = np.array_equal(
113 | self.candidates, np.setdiff1d(all_inds, self.preselected))
114 | all_train_inds = np.array_equal(self.train.inds, all_inds)
115 | all_val_inds = np.array_equal(self.val.inds, all_inds)
116 | if not (all_candidates and all_train_inds and all_val_inds):
117 | print('resetting candidate genes')
118 | self.set_genes()
119 |
120 | # Initialize architecture.
121 | if isinstance(self.loss_fn, utils.HurdleLoss):
122 | output_size = 2 * self.train.output_size
123 | elif isinstance(self.loss_fn, (nn.CrossEntropyLoss, nn.MSELoss)):
124 | output_size = self.train.output_size
125 | else:
126 | output_size = self.train.output_size
127 | print(f'Unknown loss function, assuming {self.loss_fn} requires '
128 | f'{self.train.output_size} outputs')
129 |
130 | model = models.SelectorMLP(input_layer='binary_gates',
131 | input_size=self.train.input_size,
132 | output_size=output_size,
133 | hidden=self.hidden,
134 | activation=self.activation,
135 | preselected_inds=self.preselected_relative)
136 | model = model.to(self.device)
137 |
138 | # Determine lam_init, if necessary.
139 | if lam_init is None:
140 | if isinstance(self.loss_fn, utils.HurdleLoss):
141 | print('using HurdleLoss, starting with lam = 0.01')
142 | lam_init = 0.01
143 | elif isinstance(self.loss_fn, nn.MSELoss):
144 | print('using MSELoss, starting with lam = 0.01')
145 | lam_init = 0.01
146 | elif isinstance(self.loss_fn, nn.CrossEntropyLoss):
147 | print('using CrossEntropyLoss, starting with lam = 0.0001')
148 | lam_init = 0.0001
149 | else:
150 | print('unknown loss function, starting with lam = 0.0001')
151 | lam_init = 0.0001
152 | else:
153 | print(f'trying lam = {lam_init:.6f}')
154 |
155 | # Prepare for training and lambda search.
156 | assert 0 < target < self.train.input_size
157 | assert 0.1 <= tol < 0.5
158 | assert lam_init > 0
159 | lam_list = [0]
160 | num_remaining = self.train.input_size
161 | num_remaining_list = [num_remaining]
162 | lam = lam_init
163 | trials = 0
164 |
165 | # Iterate until num_remaining is near the target value.
166 | while np.abs(num_remaining - target) > target * tol:
167 | # Ensure not done.
168 | if trials == max_trials:
169 | raise ValueError(
170 | 'reached maximum number of trials without selecting the '
171 | 'desired number of genes! The results may have large '
172 | 'variance due to small dataset size, or the initial lam '
173 | 'value may be bad')
174 | trials += 1
175 |
176 | # Train.
177 | model.fit(self.train,
178 | self.val,
179 | lr,
180 | mbsize,
181 | max_nepochs,
182 | start_temperature=start_temperature,
183 | end_temperature=end_temperature,
184 | loss_fn=self.loss_fn,
185 | lam=lam,
186 | optimizer=optimizer,
187 | lookback=lookback,
188 | bar=bar,
189 | verbose=verbose)
190 |
191 | # Extract inds.
192 | inds = model.input_layer.get_inds(threshold=0.5)
193 | num_remaining = len(inds)
194 | print(f'lam = {lam:.6f} yielded {num_remaining} genes')
195 |
196 | if np.abs(num_remaining - target) <= target * tol:
197 | print(f'done, lam = {lam:.6f} yielded {num_remaining} genes')
198 |
199 | else:
200 | # Guess next lam value.
201 | next_lam = modified_secant_method(
202 | lam, 1 / (1 + num_remaining), 1 / (1 + target),
203 | np.array(lam_list), 1 / (1 + np.array(num_remaining_list)))
204 |
205 | # Clip lam value for stability.
206 | next_lam = np.clip(next_lam, a_min=0.1 * lam, a_max=10 * lam)
207 |
208 | # Possibly reinitialize model.
209 | if num_remaining < target * (1 - tol):
210 | # BinaryGates layer is not great at allowing features
211 | # back in after inducing too much sparsity.
212 | print('Reinitializing model for next iteration')
213 | model = models.SelectorMLP(
214 | input_layer='binary_gates',
215 | input_size=self.train.input_size,
216 | output_size=output_size,
217 | hidden=self.hidden,
218 | activation=self.activation,
219 | preselected_inds=self.preselected_relative)
220 | model = model.to(self.device)
221 | else:
222 | print('Warm starting model for next iteration')
223 |
224 | # Prepare for next iteration.
225 | lam_list.append(lam)
226 | num_remaining_list.append(num_remaining)
227 | lam = next_lam
228 | print(f'next attempt is lam = {lam:.6f}')
229 |
230 | # Set eligible genes.
231 | true_inds = self.candidates[inds]
232 | self.set_genes(true_inds)
233 | return true_inds, model
234 |
235 | def select(self,
236 | num_genes,
237 | mbsize=64,
238 | max_nepochs=250,
239 | lr=1e-3,
240 | start_temperature=10.0,
241 | end_temperature=0.01,
242 | optimizer='Adam',
243 | bar=True,
244 | verbose=False):
245 | '''
246 | Select genetic probes: train a model with BinaryMask layer to select
247 | a precise number of model inputs.
248 |
249 | Args:
250 | num_genes: number of genes to select (in addition to pre-selected
251 | genes).
252 | mbsize: minibatch size.
253 | max_nepochs: maximum number of epochs.
254 | lr: learning rate.
255 | start_temperature: starting temperature value for Concrete samples.
256 | end_temperature: final temperature value.
257 | optimizer: optimization algorithm.
258 | bar: whether to display tqdm progress bar.
259 | verbose: verbosity.
260 | '''
261 | # Possibly reset candidate genes.
262 | included_inds = np.sort(
263 | np.concatenate([self.candidates, self.preselected]))
264 | candidate_train_inds = np.array_equal(self.train.inds, included_inds)
265 | candidate_val_inds = np.array_equal(self.val.inds, included_inds)
266 | if not (candidate_train_inds and candidate_val_inds):
267 | print('setting candidate genes in datasets')
268 | self.set_genes(self.candidates)
269 |
270 | # Initialize architecture.
271 | if isinstance(self.loss_fn, utils.HurdleLoss):
272 | output_size = 2 * self.train.output_size
273 | elif isinstance(self.loss_fn, (nn.CrossEntropyLoss, nn.MSELoss)):
274 | output_size = self.train.output_size
275 | else:
276 | output_size = self.train.output_size
277 |             print(f'Unknown loss function, assuming {self.loss_fn} requires '
278 | f'{self.train.output_size} outputs')
279 |
280 | input_size = len(self.candidates) + len(self.preselected)
281 | model = models.SelectorMLP(input_layer='binary_mask',
282 | input_size=input_size,
283 | output_size=output_size,
284 | hidden=self.hidden,
285 | activation=self.activation,
286 | preselected_inds=self.preselected_relative,
287 | num_selections=num_genes).to(self.device)
288 |
289 | # Train.
290 | model.fit(self.train,
291 | self.val,
292 | lr,
293 | mbsize,
294 | max_nepochs,
295 | start_temperature,
296 | end_temperature,
297 | loss_fn=self.loss_fn,
298 | optimizer=optimizer,
299 | bar=bar,
300 | verbose=verbose)
301 |
302 | # Return genes.
303 | inds = model.input_layer.get_inds()
304 | true_inds = self.candidates[inds]
305 | print(f'done, selected {len(inds)} genes')
306 | return true_inds, model
307 |
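Taken together, `eliminate` and `select` form a two-stage workflow: the annealed BinaryGates penalty first narrows the candidate pool to roughly the target size, then the BinaryMask layer picks an exact panel from the survivors. A minimal usage sketch, assuming `train_dataset` and `val_dataset` are prepared `ExpressionDataset` objects and that the package exposes `PERSIST` and `HurdleLoss` at the top level (as it does `ExpressionDataset` in the tests below):

```python
import torch
from persist import PERSIST, HurdleLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

selector = PERSIST(train_dataset, val_dataset,
                   loss_fn=HurdleLoss(), device=device)

# Stage 1: narrow the candidates to roughly 500 genes (within tol).
candidates, elim_model = selector.eliminate(target=500)

# Stage 2: select exactly 32 genes from the surviving candidates.
genes, model = selector.select(num_genes=32)
```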
308 |
309 | def modified_secant_method(x0, y0, y1, x, y):
310 | '''
311 |     A modified version of the secant method, used here to determine the correct
312 | value. Note that we use x = lam and y = 1 / (1 + num_remaining) rather than
313 | y = num_remaining, because this gives better results.
314 |
315 | The standard secant method uses the two previous points to calculate a
316 | finite difference rather than an exact derivative (as in Newton's method).
317 |     Here, we use a robustified derivative estimator: we find the line through
318 |     the most recent point (x0, y0) whose slope minimizes a weighted least
319 |     squares loss over all previous points (x, y). This improves robustness to
320 |     nearby guesses (small |x - x0|) and noisy evaluations.
321 |
322 | Args:
323 | x0: most recent x.
324 | y0: most recent y.
325 | y1: target y value.
326 | x: all previous xs.
327 | y: all previous ys.
328 | '''
329 | # Get robust slope estimate.
330 | weights = 1 / np.abs(x - x0)
331 | slope = (
332 | np.sum(weights * (x - x0) * (y - y0)) /
333 | np.sum(weights * (x - x0) ** 2))
334 |
335 | # Clip slope to minimum value.
336 | slope = np.clip(slope, a_min=1e-6, a_max=None)
337 |
338 | # Guess x1.
339 | x1 = x0 + (y1 - y0) / slope
340 | return x1
341 |
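Concretely, with x = lam and y = 1 / (1 + num_remaining), the estimated slope is positive (a larger penalty keeps fewer genes), so the next guess moves lam toward the target count; `eliminate` then clips the result to [0.1 * lam, 10 * lam]. A toy call with made-up search history:

```python
import numpy as np
from persist.selection import modified_secant_method

# Previous trials (excluding the most recent one, as in eliminate).
lam_prev = np.array([0.0])
y_prev = 1 / (1 + np.array([1000]))  # y = 1 / (1 + num_remaining)

x0, y0 = 0.01, 1 / (1 + 600)  # most recent trial: lam = 0.01 kept 600 genes
y1 = 1 / (1 + 500)            # target: 500 genes

next_lam = modified_secant_method(x0, y0, y1, lam_prev, y_prev)
assert next_lam > x0  # fewer genes needed, so lam increases
```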
--------------------------------------------------------------------------------
/persist/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from copy import deepcopy
4 |
5 |
6 | class HurdleLoss(nn.BCEWithLogitsLoss):
7 | '''
8 |     Hurdle loss that applies a binary cross-entropy term to each output, plus
9 |     an MSE term for each output that surpasses the threshold value. This can
10 |     be understood as the negative log-likelihood of a hurdle distribution.
11 |
12 | Args:
13 | lam: weight for the ZCELoss term (the hurdle).
14 | thresh: threshold that an output must surpass to be considered turned on.
15 | '''
16 | def __init__(self, lam=10.0, thresh=0):
17 |         super().__init__(reduction='none')
18 | self.lam = lam
19 | self.thresh = thresh
20 |
21 | def forward(self, pred, target):
22 | # Verify prediction shape.
23 | if pred.shape[1] != 2 * target.shape[1]:
24 | raise ValueError(
25 | 'Predictions have incorrect shape! For HurdleLoss, the'
26 | ' predictions must have twice the dimensionality of targets'
27 | ' ({})'.format(target.shape[1] * 2))
28 |
29 | # Reshape predictions, get distributional.
30 | pred = pred.reshape(*pred.shape[:-1], -1, 2)
31 | pred = pred.permute(-1, *torch.arange(len(pred.shape))[:-1])
32 | mu = pred[0]
33 | p_logit = pred[1]
34 |
35 | # Calculate loss.
36 | zero_target = (target <= self.thresh).float().detach()
37 | hurdle_loss = super().forward(p_logit, zero_target)
38 | mse = (1 - zero_target) * (target - mu) ** 2
39 |
40 | loss = self.lam * hurdle_loss + mse
41 | return torch.mean(torch.sum(loss, dim=-1))
42 |
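The shape check above means a model paired with `HurdleLoss` must emit two values per target gene; `forward` splits them into a mean prediction `mu` and an expression logit `p_logit`. A minimal sketch with random tensors:

```python
import torch
from persist.utils import HurdleLoss

loss_fn = HurdleLoss(lam=10.0)

batch, genes = 8, 5
target = torch.relu(torch.randn(batch, genes))  # nonnegative "expression"
pred = torch.randn(batch, 2 * genes)            # two outputs per gene

# BCE on whether each gene is off (target <= thresh), plus MSE on the
# expressed entries, summed over genes and averaged over the batch.
loss = loss_fn(pred, target)
assert loss.dim() == 0
```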
43 |
44 | class ZCELoss(nn.BCEWithLogitsLoss):
45 | '''
46 | Binary classification loss on whether outputs surpass a threshold. Expects
47 | logits.
48 |
49 | Args:
50 | thresh: threshold that an output must surpass to be considered on.
51 | '''
52 |     def __init__(self, thresh=0):
53 |         super().__init__(reduction='none')
54 |         self.thresh = thresh
55 |     def forward(self, pred, target):
56 |         zero_target = (target <= self.thresh).float()
57 | loss = super().forward(pred, zero_target)
58 | return torch.mean(torch.sum(loss, dim=-1))
59 |
60 |
61 | class ZeroAccuracy(nn.Module):
62 | '''
63 | Classification accuracy on whether outputs surpass a threshold. Expects
64 | logits.
65 |
66 | Args:
67 | thresh: threshold that an output must surpass to be considered turned on.
68 | '''
69 | def __init__(self, thresh=0):
70 | super().__init__()
71 | self.thresh = thresh
72 |
73 | def forward(self, pred, target):
74 | zero_pred = (pred > 0).float()
75 | zero_target = (target <= self.thresh).float()
76 | acc = (zero_pred == zero_target).float()
77 | return torch.mean(acc)
78 |
79 |
80 | class MSELoss(nn.Module):
81 | '''MSE loss that sums over output dimensions.'''
82 | def __init__(self):
83 | super().__init__()
84 |
85 | def forward(self, pred, target):
86 | loss = torch.sum((pred - target) ** 2, dim=-1)
87 | return torch.mean(loss)
88 |
89 |
90 | class Accuracy(nn.Module):
91 | '''0-1 classification loss.'''
92 | def __init__(self):
93 | super().__init__()
94 |
95 | def forward(self, pred, target):
96 | return (torch.argmax(pred, dim=1) == target).float().mean()
97 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | requirements = [
4 | 'numpy',
5 | 'torch',
6 | 'tqdm',
7 | 'scikit-learn',
8 | 'h5py',
9 | 'anndata',
10 | 'scanpy',
11 | 'matplotlib',
12 | 'scipy',
13 | 'pandas',
14 | 'toml'
15 | ]
16 |
17 | setuptools.setup(
18 | name='persist',
19 | version='0.0.1',
20 | author='Ian Covert',
21 | author_email='icovert@cs.washington.edu',
22 | description='PERSIST: predictive and robust gene selection for spatial transcriptomics',
23 | long_description='''
24 | Predictive and robust gene selection for spatial transcriptomics
25 | (PERSIST) is a computational approach to select a small number of
26 | target genes for FISH experiments. PERSIST relies on a reference
27 | scRNA-seq dataset, and it uses deep learning to identify genes that are
28 | predictive of the genome-wide expression profile or any other target of
29 | interest (e.g., transcriptomic cell types). PERSIST binarizes gene
30 | expression levels during the selection process to account for the
31 | measurement shift between scRNA-seq and FISH expression counts.
32 | ''',
33 | long_description_content_type='text/markdown',
34 | url='https://github.com/iancovert/persist',
35 | packages=['persist'],
36 | install_requires=requirements,
37 | classifiers=[
38 | 'Programming Language :: Python :: 3',
39 | 'License :: OSI Approved :: MIT License',
40 | 'Operating System :: OS Independent',
41 | 'Intended Audience :: Science/Research',
42 | 'Topic :: Scientific/Engineering'
43 | ],
44 | python_requires='>=3.6',
45 | )
46 |
--------------------------------------------------------------------------------
/tests/test_expression_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.sparse import csr_matrix
3 | from persist import ExpressionDataset
4 |
5 |
6 | def test_expression_dataset():
7 |     """Tests that ExpressionDataset treats dense and sparse inputs identically."""
8 | np.random.seed(0)
9 | labels = np.array([1, 2, 3]).astype(np.int_)
10 | x = np.random.rand(3, 6)
11 | x[x < 0.5] = 0
12 | sparse_x = csr_matrix(x)
13 |
14 | # check dense entries can be recovered
15 | for i in range(x.shape[0]):
16 | assert np.all(x[i] == sparse_x[i].toarray()), f"Row {i} is not the same"
17 |
18 | dense_dataset = ExpressionDataset(x, labels)
19 | sparse_dataset = ExpressionDataset(sparse_x, labels)
20 |
21 | # ExpressionDataset flags:
22 | assert not (dense_dataset.data_issparse), "expecting non-sparse data"
23 | assert sparse_dataset.data_issparse, "expecting sparse data"
24 |
25 | # ExpressionDataset emits X and label; we want identical results in both cases.
26 | for (x1, l1), (x2, l2) in zip(dense_dataset, sparse_dataset):
27 | assert l1 == l2, "Labels returned by ExpressionDataset are not the same"
28 | assert np.all(x1 == x2), "Data returned by ExpressionDataset is not the same"
29 |
30 | return
31 |
--------------------------------------------------------------------------------