├── .DS_Store ├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── depr ├── scripts │ ├── .gitkeep │ ├── veres-fate.sh │ ├── weinreb-fate.sh │ └── weinreb-interpolate.sh ├── veres.py └── weinreb.py ├── docs ├── .DS_Store ├── .bundle │ └── config ├── .gitignore ├── Gemfile ├── Gemfile.lock ├── _config.yml ├── _data │ ├── examples.json │ ├── navigation.yml │ └── tools.json ├── _includes │ ├── head.html │ ├── header.html │ └── navbar.html ├── _layouts │ ├── cli.html │ ├── core.html │ ├── document.html │ └── home.html ├── _site │ ├── file_formats │ │ └── index.html │ └── index.md ├── about.html ├── assets │ ├── .DS_Store │ ├── css │ │ └── styles.scss │ ├── gifs │ │ └── trajectories.gif │ └── img │ │ ├── .DS_Store │ │ ├── implementation.png │ │ └── model.png ├── documentation.html ├── file_formats.md ├── index.md ├── notebooks.html └── quickstart.html ├── notebooks ├── .gitkeep └── estimate-growth-rates.ipynb ├── prescient ├── .gitkeep ├── __init__.py ├── __main__.py ├── __pycache__ │ ├── train.cpython-37.pyc │ └── veres.cpython-37.pyc ├── commands │ ├── __init__.py │ ├── perturbation_analysis.py │ ├── process_data.py │ ├── simulate_trajectories.py │ └── train_model.py ├── perturb │ ├── __init__.py │ └── pert.py ├── simulate │ ├── __init__.py │ └── sim.py ├── train │ ├── __init__.py │ ├── model.py │ ├── run.py │ └── util.py └── utils.py ├── pyproject.toml ├── setup.cfg └── setup.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gifford-lab/prescient/50971c7d495e8763eaa60f83af91f51555ed7ece/.DS_Store -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-detectable=false 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Jupyter Notebook 55 | .ipynb_checkpoints 56 | 57 | # Environments 58 | .env 59 | .venv 60 | env/ 61 | venv/ 62 | ENV/ 63 | env.bak/ 64 | venv.bak/ 65 | 66 | # Visual Studio Code 67 | *.code-workspace 68 | .vscode 69 | 70 | tmp/ 71 | .DS_Store 72 | 73 | experiments/ 74 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Gifford Lab, MIT CSAIL 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PyPI version](https://badge.fury.io/py/prescient.svg)](https://badge.fury.io/py/prescient) 2 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 3 | [![DOI](https://zenodo.org/badge/290219865.svg)](https://zenodo.org/badge/latestdoi/290219865) 4 | 5 | # prescient 6 | Software for PRESCIENT (Potential eneRgy undErlying Single Cell gradIENTs), a generative model for modeling single-cell time-series. 7 | + Current paper version: https://www.biorxiv.org/content/10.1101/2020.08.26.269332v1 8 | + For paper pre-processing scripts, training bash scripts, pre-trained models, and visualization notebooks please visit https://github.com/gifford-lab/prescient-analysis. 9 | 10 | ## Documentation 11 | Documentation is available at https://cgs.csail.mit.edu/prescient. 
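A minimal install sketch, assuming the PyPI package name matches the badge above and that the argparse entry point in prescient/__main__.py honors --help (subcommand names taken from prescient/commands/):

```bash
# Install the released package from PyPI; GPU acceleration additionally requires a
# CUDA-enabled PyTorch build (see Requirements below).
pip install prescient

# List the available subcommands: process_data, train_model,
# simulate_trajectories, perturbation_analysis.
python -m prescient --help
```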
12 | 13 | 14 | 15 | ## Requirements 16 | 17 | + pytorch 1.4.0 18 | + geomloss 0.2.3, pykeops 1.3 19 | + numpy, scipy, pandas, sklearn, tqdm, annoy 20 | + scanpy, pyreadr, anndata 21 | + Recommended: An Nvidia GPU with CUDA support for GPU acceleration (see paper for more details on computational resources) 22 | 23 | ## Bugs & Suggestions 24 | 25 | Please report any bugs, problems, suggestions or requests as a [Github issue](https://github.com/gifford-lab/prescient/issues) 26 | -------------------------------------------------------------------------------- /depr/scripts/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gifford-lab/prescient/50971c7d495e8763eaa60f83af91f51555ed7ece/depr/scripts/.gitkeep -------------------------------------------------------------------------------- /depr/scripts/veres-fate.sh: -------------------------------------------------------------------------------- 1 | # example script for training 5 seeds of 2 layer 400 unit models 2 | # on Veres et al. dataset with estimated proliferation rates 3 | 4 | python=/data/gl/g5/yhtgrace/env/miniconda3/envs/cuda10.1/bin/python 5 | train_dt=0.1 6 | train_sd=0.1 7 | train_tau=1e-6 8 | train_batch=0.1 9 | train_clip=0.1 10 | train_lr=0.001 11 | pretrain_epochs=500 12 | train_epochs=2500 13 | k_dim=400 14 | layers=2 15 | save=100 16 | 17 | data_dir=./data/Veres2019 18 | out_dir=./experiments/veres-fate 19 | weight_path=./data/Veres2019_growth-kegg.pt 20 | 21 | device=0 22 | seed=1 23 | screen -dm bash -c "cd ../../; $python src/veres.py --task fate --train --data_dir $data_dir --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_sd $train_sd --train_dt $train_dt --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device --weight_path $weight_path" 24 | 25 | device=1 26 | seed=2 27 | screen -dm bash -c "cd ../../; $python src/veres.py --task fate --train --data_dir $data_dir --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_sd $train_sd --train_dt $train_dt --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device --weight_path $weight_path" 28 | 29 | device=2 30 | seed=3 31 | screen -dm bash -c "cd ../../; $python src/veres.py --task fate --train --data_dir $data_dir --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_sd $train_sd --train_dt $train_dt --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device --weight_path $weight_path" 32 | 33 | device=3 34 | seed=4 35 | screen -dm bash -c "cd ../../; $python src/veres.py --task fate --train --data_dir $data_dir --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_sd $train_sd --train_dt $train_dt --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device --weight_path $weight_path" 36 | 37 | device=4 38 | seed=5 39 | screen -dm bash -c "cd ../../; $python src/veres.py --task fate --train --data_dir $data_dir --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr 
--train_sd $train_sd --train_dt $train_dt --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device --weight_path $weight_path" 40 | -------------------------------------------------------------------------------- /depr/scripts/weinreb-fate.sh: -------------------------------------------------------------------------------- 1 | # example script for training 5 seeds of 2 layer 400 unit models 2 | # on Weinreb et al. dataset on all data with estimated proliferation rates 3 | # for clonal fate bias prediction 4 | 5 | python=/data/gl/g5/yhtgrace/env/miniconda3/envs/cuda10.1/bin/python 6 | train_dt=0.1 7 | train_sd=0.1 8 | train_tau=1e-6 9 | train_batch=0.1 10 | train_clip=0.1 11 | train_lr=0.005 12 | train_epochs=2500 13 | pretrain_epochs=500 14 | k_dim=400 15 | layers=2 16 | save=100 17 | 18 | data_dir=./data/Weinreb2020_fate 19 | weight_path=./data/Weinreb2020_growth-all_kegg.pt # also specifies the training mask 20 | out_dir=./experiments/weinreb-fate 21 | 22 | device=0 23 | seed=1 24 | #echo "cd ../../; $python src/weinreb.py --task fate --train --data_dir $data_dir --weight_path $weight_path --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_sd $train_sd --train_dt $train_dt --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device" 25 | screen -dm bash -c "cd ../../; $python src/weinreb.py --task fate --train --data_dir $data_dir --weight_path $weight_path --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_sd $train_sd --train_dt $train_dt --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device" 26 | 27 | device=1 28 | seed=2 29 | screen -dm bash -c "cd ../../; $python src/weinreb.py --task fate --train --data_dir $data_dir --weight_path $weight_path --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_sd $train_sd --train_dt $train_dt --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device" 30 | 31 | device=2 32 | seed=3 33 | screen -dm bash -c "cd ../../; $python src/weinreb.py --task fate --train --data_dir $data_dir --weight_path $weight_path --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_sd $train_sd --train_dt $train_dt --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device" 34 | 35 | device=3 36 | seed=4 37 | screen -dm bash -c "cd ../../; $python src/weinreb.py --task fate --train --data_dir $data_dir --weight_path $weight_path --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_sd $train_sd --train_dt $train_dt --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device" 38 | 39 | device=4 40 | seed=5 41 | screen -dm bash -c "cd ../../; $python src/weinreb.py --task fate --train --data_dir $data_dir --weight_path $weight_path --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr 
--train_sd $train_sd --train_dt $train_dt --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device" 42 | 43 | -------------------------------------------------------------------------------- /depr/scripts/weinreb-interpolate.sh: -------------------------------------------------------------------------------- 1 | # example script for training 5 seeds of 2 layer 400 unit models 2 | # on Weinreb et al. dataset on only lineage tracing data with ground truth proliferation rates 3 | # for interpolation task 4 | 5 | python=/data/gl/g5/yhtgrace/env/miniconda3/envs/cuda10.1/bin/python 6 | train_dt=0.1 7 | train_sd=0.1 8 | train_tau=1e-6 9 | train_batch=0.1 10 | train_clip=0.1 11 | train_lr=0.005 12 | train_epochs=2500 13 | pretrain_epochs=500 14 | k_dim=400 15 | layers=2 16 | save=100 17 | 18 | data_path=./data/Weinreb2020_impute.pt 19 | weight_path=./data/Weinreb2020_growth-d26.pt 20 | out_dir=./experiments/weinreb-interpolate 21 | 22 | device=0 23 | seed=1 24 | screen -dm bash -c "cd ../../; $python src/weinreb.py --task interpolate --train --data_path $data_path --weight_path $weight_path --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_dt $train_dt --train_sd $train_sd --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device" 25 | 26 | device=1 27 | seed=2 28 | screen -dm bash -c "cd ../../; $python src/weinreb.py --task interpolate --train --data_path $data_path --weight_path $weight_path --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_dt $train_dt --train_sd $train_sd --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device" 29 | 30 | device=2 31 | seed=3 32 | screen -dm bash -c "cd ../../; $python src/weinreb.py --task interpolate --train --data_path $data_path --weight_path $weight_path --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_dt $train_dt --train_sd $train_sd --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device" 33 | 34 | device=3 35 | seed=4 36 | screen -dm bash -c "cd ../../; $python src/weinreb.py --task interpolate --train --data_path $data_path --weight_path $weight_path --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_dt $train_dt --train_sd $train_sd --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device" 37 | 38 | device=4 39 | seed=5 40 | screen -dm bash -c "cd ../../; $python src/weinreb.py --task interpolate --train --data_path $data_path --weight_path $weight_path --out_dir $out_dir --train_batch $train_batch --train_clip $train_clip --train_lr $train_lr --train_dt $train_dt --train_sd $train_sd --train_tau $train_tau --pretrain_epochs $pretrain_epochs --train_epochs $train_epochs --k_dim $k_dim --layers $layers --save $save --seed $seed --device $device" 41 | 42 | -------------------------------------------------------------------------------- /depr/veres.py: -------------------------------------------------------------------------------- 1 | # 
training on veres et al. dataset 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, optim 6 | 7 | import annoy 8 | import tqdm 9 | 10 | from geomloss import SamplesLoss 11 | 12 | import numpy as np 13 | import pandas as pd 14 | import scipy.stats 15 | 16 | from collections import OrderedDict, Counter 17 | from joblib import Parallel, delayed 18 | from types import SimpleNamespace 19 | from time import strftime, localtime 20 | 21 | import argparse 22 | import copy 23 | import glob 24 | import itertools 25 | import json 26 | import os 27 | import sys 28 | 29 | import train 30 | 31 | def init_config(args): 32 | 33 | config = SimpleNamespace( 34 | 35 | seed = args.seed, 36 | timestamp = strftime("%a, %d %b %Y %H:%M:%S", localtime()), 37 | 38 | # data parameters 39 | data_dir = args.data_dir, 40 | data_path = args.data_path, 41 | weight_path = args.weight_path, 42 | weight = args.weight, 43 | 44 | # model parameters 45 | activation = args.activation, 46 | layers = args.layers, 47 | k_dim = args.k_dim, 48 | 49 | # pretraining parameters 50 | pretrain_burnin = 50, 51 | pretrain_sd = 0.1, 52 | pretrain_lr = 1e-9, 53 | pretrain_epochs = args.pretrain_epochs, 54 | 55 | # training parameters 56 | train_dt = args.train_dt, 57 | train_sd = args.train_sd, 58 | train_batch_size = args.train_batch, 59 | ns = 2000, 60 | train_burnin = 100, 61 | train_tau = args.train_tau, 62 | train_epochs = args.train_epochs, 63 | train_lr = args.train_lr, 64 | train_clip = args.train_clip, 65 | save = args.save, 66 | 67 | # loss parameters 68 | sinkhorn_scaling = 0.7, 69 | sinkhorn_blur = 0.1, 70 | 71 | # file parameters 72 | out_dir = args.out_dir, 73 | out_name = args.out_dir.split('/')[-1], 74 | pretrain_pt = os.path.join(args.out_dir, 'pretrain.pt'), 75 | train_pt = os.path.join(args.out_dir, 'train.{}.pt'), 76 | train_log = os.path.join(args.out_dir, 'train.log'), 77 | done_log = os.path.join(args.out_dir, 'done.log'), 78 | config_pt = os.path.join(args.out_dir, 'config.pt'), 79 | ) 80 | 81 | config.train_t = [] 82 | config.test_t = [] 83 | 84 | if not os.path.exists(args.out_dir): 85 | print('Making directory at {}'.format(args.out_dir)) 86 | os.makedirs(args.out_dir) 87 | else: 88 | print('Directory exists at {}'.format(args.out_dir)) 89 | 90 | return config 91 | 92 | def load_data(config, base_dir = "."): 93 | 94 | data_pt = torch.load(os.path.join(base_dir, config.data_path)) 95 | x = data_pt['xp'] 96 | y = data_pt['y'] 97 | 98 | config.x_dim = x[0].shape[-1] 99 | config.t = y[-1] - y[0] 100 | 101 | y_start = y[config.start_t] 102 | y_ = [y_ for y_ in y if y_ > y_start] 103 | 104 | weight_pt = torch.load(os.path.join(base_dir, config.weight_path)) 105 | 106 | w_ = weight_pt['w'][config.start_t] 107 | w = {(y_start, yy): torch.from_numpy(np.exp((yy - y_start)*w_)) for yy in y_} 108 | 109 | return x, y, w 110 | 111 | def train_fate(args): 112 | 113 | a = copy.copy(args) 114 | 115 | # data 116 | 117 | a.data_path = os.path.join(a.data_dir, 'fate_train.pt') 118 | 119 | weight = os.path.basename(a.weight_path) 120 | weight = weight.split('.')[0].split('-')[-1] 121 | a.weight = weight 122 | 123 | # out directory 124 | 125 | name = ( 126 | "{weight}-" 127 | "{activation}_{layers}_{k_dim}-" 128 | "{train_tau}" 129 | ).format(**a.__dict__) 130 | 131 | a.out_dir = os.path.join(args.out_dir, name, 'seed_{}'.format(a.seed)) 132 | config = init_config(a) 133 | 134 | config.start_t = 0 135 | config.train_t = [1, 2, 3, 4, 5, 6, 7] 136 | 137 | x, y, w = load_data(config) 138 | 139 | return x, y, 
w, config 140 | 141 | def evaluate_fit(args, config): 142 | 143 | log_path = os.path.join(config.out_dir, 'interpolate.log') 144 | if os.path.exists(log_path): 145 | print(log_path, 'exists. Skipping.') 146 | return 147 | 148 | x, y, w = load_data(config) 149 | 150 | # -- initialize 151 | device, kwargs = train.init(args) 152 | model = train.AutoGenerator(config) 153 | 154 | ot_solver = SamplesLoss("sinkhorn", p = 2, blur = config.sinkhorn_blur, 155 | scaling = config.sinkhorn_scaling) 156 | 157 | losses_xy = [] 158 | train_pts = sorted(glob.glob(config.train_pt.format('*'))) 159 | for train_pt in train_pts: 160 | 161 | checkpoint = torch.load(train_pt) 162 | print('Loading model from {}'.format(train_pt)) 163 | model.load_state_dict(checkpoint['model_state_dict']) 164 | model.to(device) 165 | print(model) 166 | 167 | name = os.path.basename(train_pt).split('.')[1] 168 | 169 | # -- evaluate 170 | torch.manual_seed(0) 171 | np.random.seed(0) 172 | 173 | for t_cur in config.train_t: 174 | 175 | t_prev = config.start_t 176 | y_prev = int(y[t_prev]) 177 | y_cur = int(y[t_cur]) 178 | 179 | time_elapsed = y_cur - y_prev 180 | num_steps = int(np.round(time_elapsed / config.train_dt)) 181 | 182 | dat_prev = x[t_prev].to(device) 183 | w_prev = train.get_weight(w[(y_prev, y_cur)], time_elapsed).cpu().numpy() 184 | 185 | x_s = [] 186 | x_i_ = train.weighted_samp(dat_prev, args.evaluate_n, w_prev) 187 | 188 | for i in range(int(args.evaluate_n / config.ns)): 189 | 190 | x_i = x_i_[i*config.ns:(i+1)*config.ns,] 191 | 192 | for _ in range(num_steps): 193 | z = torch.randn(x_i.shape[0], x_i.shape[1]) * config.train_sd 194 | z = z.to(device) 195 | x_i = model._step(x_i, dt = config.train_dt, z = z) 196 | 197 | x_s.append(x_i.detach()) 198 | 199 | x_s = torch.cat(x_s) 200 | 201 | loss_xy = ([name, t_cur] + 202 | [ot_solver(x_s, x[t_].to(device)).item() for t_ in range(len(x))]) 203 | losses_xy.append(loss_xy) 204 | 205 | losses_xy = pd.DataFrame(losses_xy, columns = ['epoch', 't_cur'] + y) 206 | losses_xy.to_csv(log_path, sep = '\t', index = False) 207 | print('Wrote results to', log_path) 208 | 209 | 210 | def main(): 211 | 212 | parser = argparse.ArgumentParser() 213 | parser.add_argument('-s', '--seed', type = int, default = 0) 214 | parser.add_argument('--no-cuda', action = 'store_true') 215 | parser.add_argument('--device', default = 7, type = int) 216 | parser.add_argument('--out_dir', default = './experiments') 217 | # -- data options 218 | parser.add_argument('--data_path') 219 | parser.add_argument('--data_dir') 220 | parser.add_argument('--weight_path', default = None) 221 | # -- model options 222 | parser.add_argument('--loss', default = 'euclidean') 223 | parser.add_argument('--k_dim', default = 500, type = int) 224 | parser.add_argument('--activation', default = 'softplus') 225 | parser.add_argument('--layers', default = 1, type = int) 226 | # -- pretrain options 227 | parser.add_argument('--pretrain_lr', default = 1e-9, type = float) 228 | parser.add_argument('--pretrain_epochs', default = 500, type = int) 229 | # -- train options 230 | parser.add_argument('--train_epochs', default = 5000, type = int) 231 | parser.add_argument('--train_lr', default = 0.01, type = float) 232 | parser.add_argument('--train_dt', default = 0.1, type = float) 233 | parser.add_argument('--train_sd', default = 0.5, type = float) 234 | parser.add_argument('--train_tau', default = 0, type = float) 235 | parser.add_argument('--train_batch', default = 0.1, type = float) 236 | parser.add_argument('--train_clip', default 
= 0.25, type = float) 237 | parser.add_argument('--save', default = 100, type = int) 238 | # -- test options 239 | parser.add_argument('--evaluate_n', default = 10000, type = int) 240 | parser.add_argument('--evaluate_data') 241 | parser.add_argument('--evaluate-baseline', action = 'store_true') 242 | # -- run options 243 | parser.add_argument('--task', default = 'fate') 244 | parser.add_argument('--train', action = 'store_true') 245 | parser.add_argument('--evaluate') 246 | parser.add_argument('--config') 247 | args = parser.parse_args() 248 | 249 | if args.task == 'fate': 250 | 251 | if args.train: 252 | 253 | args.pretrain = True 254 | args.train = True 255 | 256 | train.run(args, train_fate) 257 | 258 | if args.evaluate == 'fit': 259 | 260 | if args.config: 261 | config = SimpleNamespace(**torch.load(args.config)) 262 | evaluate_fit(args, config) 263 | else: 264 | print('Please provide a config file') 265 | 266 | 267 | 268 | 269 | 270 | if __name__ == '__main__': 271 | main() 272 | -------------------------------------------------------------------------------- /depr/weinreb.py: -------------------------------------------------------------------------------- 1 | # training and evaluation for interpolation and fate prediction tasks 2 | # on weinreb et al. dataset 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, optim 7 | 8 | import annoy 9 | import tqdm 10 | 11 | from geomloss import SamplesLoss 12 | 13 | import numpy as np 14 | import pandas as pd 15 | import scipy.stats 16 | 17 | from collections import OrderedDict, Counter 18 | from joblib import Parallel, delayed 19 | from types import SimpleNamespace 20 | from time import strftime, localtime 21 | 22 | import argparse 23 | import copy 24 | import glob 25 | import itertools 26 | import json 27 | import os 28 | import sys 29 | 30 | import train 31 | 32 | FATE_DIR = "data/Klein2020_fate" 33 | FATE_TRAIN_PATH= os.path.join(FATE_DIR, "fate_train.pt") 34 | FATE_ANN = os.path.join(FATE_DIR, "50_20_10") 35 | FATE_TEST_PATH = os.path.join(FATE_DIR, "fate_test.pt") 36 | IMPUTE_DATA_PATH = "data/Klein2020_impute.pt" 37 | WEIGHT_DIR = 'data/Klein2020_weights' 38 | 39 | def init_config(args): 40 | 41 | config = SimpleNamespace( 42 | 43 | seed = args.seed, 44 | timestamp = strftime("%a, %d %b %Y %H:%M:%S", localtime()), 45 | 46 | # data parameters 47 | data_dir = args.data_dir, 48 | data_path = args.data_path, 49 | weight_path = args.weight_path, 50 | weight = args.weight, 51 | 52 | # model parameters 53 | activation = args.activation, 54 | layers = args.layers, 55 | k_dim = args.k_dim, 56 | 57 | # pretraining parameters 58 | pretrain_burnin = 50, 59 | pretrain_sd = 0.1, 60 | pretrain_lr = 1e-9, 61 | pretrain_epochs = args.pretrain_epochs, 62 | 63 | # training parameters 64 | train_dt = args.train_dt, 65 | train_sd = args.train_sd, 66 | train_batch_size = args.train_batch, 67 | ns = 2000, 68 | train_burnin = 100, 69 | train_tau = args.train_tau, 70 | train_epochs = args.train_epochs, 71 | train_lr = args.train_lr, 72 | train_clip = args.train_clip, 73 | save = args.save, 74 | 75 | # loss parameters 76 | sinkhorn_scaling = 0.7, 77 | sinkhorn_blur = 0.1, 78 | 79 | # file parameters 80 | out_dir = args.out_dir, 81 | out_name = args.out_dir.split('/')[-1], 82 | pretrain_pt = os.path.join(args.out_dir, 'pretrain.pt'), 83 | train_pt = os.path.join(args.out_dir, 'train.{}.pt'), 84 | train_log = os.path.join(args.out_dir, 'train.log'), 85 | done_log = os.path.join(args.out_dir, 'done.log'), 86 | config_pt = 
os.path.join(args.out_dir, 'config.pt'), 87 | ) 88 | 89 | config.train_t = [] 90 | config.test_t = [] 91 | 92 | if not os.path.exists(args.out_dir): 93 | print('Making directory at {}'.format(args.out_dir)) 94 | os.makedirs(args.out_dir) 95 | else: 96 | print('Directory exists at {}'.format(args.out_dir)) 97 | 98 | return config 99 | 100 | def load_data(config, base_dir = "."): 101 | 102 | data_pt = torch.load(os.path.join(base_dir, config.data_path)) 103 | x = data_pt['xp'] 104 | y = data_pt['y'] 105 | 106 | config.x_dim = x[0].shape[-1] 107 | config.t = y[-1] - y[0] 108 | 109 | y_start = y[config.start_t] 110 | y_ = [y_ for y_ in y if y_ > y_start] 111 | 112 | weight_pt = torch.load(os.path.join(base_dir, config.weight_path)) 113 | x = [xx[m] for xx, m in zip(x, weight_pt['m'])] 114 | 115 | w_ = weight_pt['w'][config.start_t] 116 | w = {(y_start, yy): torch.from_numpy(np.exp((yy - y_start)*w_)) for yy in y_} 117 | 118 | return x, y, w 119 | 120 | def train_fate(args): 121 | 122 | a = copy.copy(args) 123 | 124 | # data 125 | 126 | a.data_path = os.path.join(a.data_dir, 'fate_train.pt') 127 | 128 | weight = os.path.basename(a.weight_path) 129 | weight = weight.split('.')[0].split('-')[-1] 130 | a.weight = weight 131 | 132 | # out directory 133 | 134 | name = ( 135 | "{weight}-" 136 | "{activation}_{layers}_{k_dim}-" 137 | "{train_tau}" 138 | ).format(**a.__dict__) 139 | 140 | a.out_dir = os.path.join(args.out_dir, name, 'seed_{}'.format(a.seed)) 141 | config = init_config(a) 142 | 143 | config.start_t = 0 144 | config.train_t = [1, 2] 145 | 146 | x, y, w = load_data(config) 147 | 148 | return x, y, w, config 149 | 150 | def evaluate_fate(args, config): 151 | 152 | # -- load data 153 | 154 | data_pt = torch.load(os.path.join(config.data_dir, 'fate_test.pt')) 155 | x = data_pt['x'] 156 | y = data_pt['y'] 157 | t = data_pt['t'] 158 | 159 | ay_path = os.path.join(config.data_dir, '50_20_10') 160 | ay = annoy.AnnoyIndex(config.x_dim, 'euclidean') 161 | ay.load(ay_path + '.ann') 162 | with open(ay_path + '.txt', 'r') as f: 163 | cy = np.array([line.strip() for line in f]) 164 | 165 | # -- initialize 166 | 167 | device, kwargs = train.init(args) 168 | 169 | # -- model 170 | 171 | model = train.AutoGenerator(config) 172 | 173 | log_str = '{} {:.5f} {:.3e} {:.5f} {:.3e} {:d}' 174 | log_handle = open(os.path.join(config.out_dir, 'fate.log'), 'w') 175 | 176 | names_ = [] 177 | scores_ = [] 178 | masks_ = [] 179 | 180 | train_pts = sorted(glob.glob(config.train_pt.format('*'))) 181 | for train_pt in train_pts: 182 | 183 | name = os.path.basename(train_pt).split('.')[1] 184 | 185 | checkpoint = torch.load(train_pt) 186 | print('Loading model from {}'.format(train_pt)) 187 | model.load_state_dict(checkpoint['model_state_dict']) 188 | model.to(device) 189 | print(model) 190 | 191 | # -- evaluate 192 | torch.manual_seed(0) 193 | 194 | time_elapsed = config.t 195 | num_steps = int(np.round(time_elapsed / config.train_dt)) 196 | 197 | scores = [] 198 | mask = [] 199 | pbar = tqdm.tqdm(range(len(x)), desc = "[fate:{}]".format(name)) 200 | for i in pbar: 201 | 202 | # expand data point 203 | x_i = x[i].expand(config.ns, -1).to(device) 204 | 205 | # simulate forward 206 | for _ in range(num_steps): 207 | z = torch.randn(x_i.shape[0], x_i.shape[1]) * config.train_sd 208 | z = z.to(device) 209 | x_i = model._step(x_i, dt = config.train_dt, z = z) 210 | x_i_ = x_i.detach().cpu().numpy() 211 | 212 | # predict 213 | yp = [] 214 | for j in range(x_i_.shape[0]): 215 | nn = cy[ay.get_nns_by_vector(x_i_[j], 20)] 216 | 
nn = Counter(nn).most_common(2) 217 | label, num = nn[0] 218 | if len(nn) > 1: 219 | _, num2 = nn[1] 220 | if num == num2: # deal with ties by setting it to the default class 221 | label = 'Other' 222 | yp.append(label) 223 | yp = Counter(yp) 224 | 225 | # may want to save yp instead 226 | num_neu = yp['Neutrophil'] + 1 # use pseudocounts for scoring 227 | num_total = yp['Neutrophil'] + yp['Monocyte'] + 2 228 | score = num_neu / num_total 229 | scores.append(score) 230 | num_total = yp['Neutrophil'] + yp['Monocyte'] 231 | mask.append(num_total > 0) 232 | 233 | scores = np.array(scores) 234 | mask = np.array(mask) 235 | 236 | r, pval = scipy.stats.pearsonr(y, scores) 237 | r_masked, pval_masked = scipy.stats.pearsonr(y[mask], scores[mask]) 238 | 239 | log = log_str.format(name, r, pval, r_masked, pval_masked, mask.sum()) 240 | log_handle.write(log + '\n') 241 | print(log) 242 | 243 | names_.append(name) 244 | scores_.append(scores) 245 | masks_.append(mask) 246 | 247 | log_handle.close() 248 | 249 | torch.save({ 250 | 'scores': scores_, 251 | 'mask': masks_, 252 | 'names': names_ 253 | }, os.path.join(config.out_dir, 'fate.pt')) 254 | 255 | def train_interpolate(args, data_path = IMPUTE_DATA_PATH): 256 | 257 | a = copy.copy(args) 258 | 259 | weight = os.path.basename(a.weight_path) 260 | weight = weight.split('.')[0].split('-')[-1] 261 | a.weight = weight 262 | 263 | name = ( 264 | "{weight}-" 265 | "{activation}_{layers}_{k_dim}-" 266 | "{train_dt}_{train_sd}_{train_tau}-" 267 | "{train_batch}_{train_clip}_{train_lr}" 268 | ).format(**a.__dict__) 269 | 270 | a.out_dir = os.path.join(args.out_dir, name, 'seed_{}'.format(a.seed)) 271 | config = init_config(a) 272 | 273 | config.start_t = 0 274 | config.train_t = [2] 275 | config.test_t = [1] 276 | 277 | x, y, w = load_data(config) 278 | 279 | return x, y, w, config 280 | 281 | def evaluate_interpolate_model(args, config): 282 | 283 | if not os.path.exists(config.done_log): 284 | print(config.done_log, 'does not exist. Skipping.') 285 | return 286 | 287 | log_path = os.path.join(config.out_dir, 'interpolate.log') 288 | if os.path.exists(log_path): 289 | print(log_path, 'exists. 
Skipping.') 290 | return 291 | 292 | x, y, w = load_data(config) 293 | 294 | # -- initialize 295 | device, kwargs = train.init(args) 296 | model = train.AutoGenerator(config) 297 | 298 | ot_solver = SamplesLoss("sinkhorn", p = 2, blur = config.sinkhorn_blur, 299 | scaling = config.sinkhorn_scaling) 300 | 301 | losses_xy = [] 302 | train_pts = sorted(glob.glob(config.train_pt.format('*'))) 303 | for train_pt in train_pts: 304 | 305 | checkpoint = torch.load(train_pt) 306 | print('Loading model from {}'.format(train_pt)) 307 | model.load_state_dict(checkpoint['model_state_dict']) 308 | model.to(device) 309 | print(model) 310 | 311 | name = os.path.basename(train_pt).split('.')[1] 312 | 313 | # -- evaluate 314 | 315 | def _evaluate_impute_model(t_cur): 316 | 317 | torch.manual_seed(0) 318 | np.random.seed(0) 319 | 320 | t_prev = config.start_t 321 | y_prev = int(y[t_prev]) 322 | y_cur = int(y[t_cur]) 323 | time_elapsed = y_cur - y_prev 324 | num_steps = int(np.round(time_elapsed / config.train_dt)) 325 | 326 | dat_prev = x[t_prev].to(device) 327 | dat_cur = x[t_cur].to(device) 328 | w_prev = train.get_weight(w[(y_prev, y_cur)], time_elapsed).cpu().numpy() 329 | 330 | x_s = [] 331 | x_i_ = train.weighted_samp(dat_prev, args.evaluate_n, w_prev) 332 | 333 | for i in range(int(args.evaluate_n / config.ns)): 334 | 335 | x_i = x_i_[i*config.ns:(i+1)*config.ns,] 336 | 337 | for _ in range(num_steps): 338 | z = torch.randn(x_i.shape[0], x_i.shape[1]) * config.train_sd 339 | z = z.to(device) 340 | x_i = model._step(x_i, dt = config.train_dt, z = z) 341 | 342 | x_s.append(x_i.detach()) 343 | 344 | x_s = torch.cat(x_s) 345 | 346 | loss_xy = ot_solver(x_s, dat_cur) 347 | return loss_xy 348 | 349 | for t in config.train_t: 350 | y_ = y[t] 351 | loss_xy = _evaluate_impute_model(t).item() 352 | losses_xy.append((name, 'train', y_, loss_xy)) 353 | try: 354 | for t in config.test_t: 355 | y_ = y[t] 356 | loss_xy = _evaluate_impute_model(t).item() 357 | losses_xy.append((name, 'test', y_, loss_xy)) 358 | except AttributeError: 359 | continue 360 | 361 | losses_xy = pd.DataFrame(losses_xy, columns = ['epoch', 'eval', 't', 'loss']) 362 | losses_xy.to_csv(log_path, sep = '\t', index = False) 363 | print('Wrote results to', log_path) 364 | 365 | def evaluate_interpolate_model_baseline(args, config): 366 | 367 | if not os.path.exists(config.done_log): 368 | print(config.done_log, 'does not exist. Skipping.') 369 | return 370 | 371 | log_path = os.path.join(config.out_dir, 'baseline.log') 372 | if os.path.exists(log_path): 373 | print(log_path, 'exists. 
Skipping.') 374 | return 375 | 376 | x, y, w = load_data(config) 377 | 378 | # -- initialize 379 | device, kwargs = train.init(args) 380 | model = train.AutoGenerator(config) 381 | 382 | ot_solver = SamplesLoss("sinkhorn", p = 2, blur = config.sinkhorn_blur, 383 | scaling = config.sinkhorn_scaling) 384 | 385 | losses_xy = [] 386 | train_pts = sorted(glob.glob(config.train_pt.format('*'))) 387 | for train_pt in train_pts: 388 | 389 | checkpoint = torch.load(train_pt) 390 | print('Loading model from {}'.format(train_pt)) 391 | model.load_state_dict(checkpoint['model_state_dict']) 392 | model.to(device) 393 | print(model) 394 | 395 | name = os.path.basename(train_pt).split('.')[1] 396 | 397 | # -- evaluate 398 | torch.manual_seed(0) 399 | np.random.seed(0) 400 | 401 | t_cur = 1 402 | 403 | t_prev = config.start_t 404 | y_prev = int(y[t_prev]) 405 | y_cur = int(y[t_cur]) 406 | time_elapsed = y_cur - y_prev 407 | num_steps = int(np.round(time_elapsed / config.train_dt)) 408 | 409 | dat_prev = x[t_prev].to(device) 410 | w_prev = train.get_weight(w[(y_prev, y_cur)], time_elapsed).cpu().numpy() 411 | 412 | x_s = [] 413 | x_i_ = train.weighted_samp(dat_prev, args.evaluate_n, w_prev) 414 | 415 | for i in range(int(args.evaluate_n / config.ns)): 416 | 417 | x_i = x_i_[i*config.ns:(i+1)*config.ns,] 418 | 419 | for _ in range(num_steps): 420 | z = torch.randn(x_i.shape[0], x_i.shape[1]) * config.train_sd 421 | z = z.to(device) 422 | x_i = model._step(x_i, dt = config.train_dt, z = z) 423 | 424 | x_s.append(x_i.detach()) 425 | 426 | x_s = torch.cat(x_s) 427 | 428 | loss_xy = [name] + [ot_solver(x_s, x[t_].to(device)).item() for t_ in range(len(x))] 429 | losses_xy.append(loss_xy) 430 | 431 | losses_xy = pd.DataFrame(losses_xy, columns = ['epoch'] + y) 432 | losses_xy.to_csv(log_path, sep = '\t', index = False) 433 | print('Wrote results to', log_path) 434 | 435 | def evaluate_interpolate_data(args, config): 436 | 437 | x, y, w = load_data(config) 438 | 439 | device, kwargs = train.init(args) 440 | 441 | pt = torch.load(args.evaluate_data) 442 | x_i = torch.from_numpy(pt['sim_xp']).float().to(device) 443 | y_j = x[1].to(device) 444 | 445 | ot_solver = SamplesLoss("sinkhorn", p = 2, blur = config.sinkhorn_blur, 446 | scaling = config.sinkhorn_scaling) 447 | loss_xy = ot_solver(x_i, y_j) 448 | 449 | import pdb; pdb.set_trace() 450 | 451 | def main(): 452 | 453 | parser = argparse.ArgumentParser() 454 | parser.add_argument('-s', '--seed', type = int, default = 0) 455 | parser.add_argument('--no-cuda', action = 'store_true') 456 | parser.add_argument('--device', default = 7, type = int) 457 | parser.add_argument('--out_dir', default = './experiments') 458 | # -- data options 459 | parser.add_argument('--data_path') 460 | parser.add_argument('--data_dir') 461 | parser.add_argument('--weight_path', default = None) 462 | # -- model options 463 | parser.add_argument('--loss', default = 'euclidean') 464 | parser.add_argument('--k_dim', default = 500, type = int) 465 | parser.add_argument('--activation', default = 'softplus') 466 | parser.add_argument('--layers', default = 1, type = int) 467 | # -- pretrain options 468 | parser.add_argument('--pretrain_lr', default = 1e-9, type = float) 469 | parser.add_argument('--pretrain_epochs', default = 500, type = int) 470 | # -- train options 471 | parser.add_argument('--train_epochs', default = 5000, type = int) 472 | parser.add_argument('--train_lr', default = 0.01, type = float) 473 | parser.add_argument('--train_dt', default = 0.1, type = float) 474 | 
parser.add_argument('--train_sd', default = 0.5, type = float) 475 | parser.add_argument('--train_tau', default = 0, type = float) 476 | parser.add_argument('--train_batch', default = 0.1, type = float) 477 | parser.add_argument('--train_clip', default = 0.25, type = float) 478 | parser.add_argument('--save', default = 100, type = int) 479 | # -- test options 480 | parser.add_argument('--evaluate_n', default = 10000, type = int) 481 | parser.add_argument('--evaluate_data') 482 | parser.add_argument('--evaluate-baseline', action = 'store_true') 483 | # -- run options 484 | parser.add_argument('--task', default = 'fate') 485 | parser.add_argument('--train', action = 'store_true') 486 | parser.add_argument('--evaluate') 487 | parser.add_argument('--config') 488 | args = parser.parse_args() 489 | 490 | if args.task == 'fate': 491 | 492 | if args.train: 493 | 494 | args.pretrain = True 495 | args.train = True 496 | 497 | train.run(args, train_fate) 498 | 499 | if args.evaluate == 'model': 500 | 501 | if args.config: 502 | config = SimpleNamespace(**torch.load(args.config)) 503 | evaluate_fate(args, config) 504 | else: 505 | print('Please provide a config file') 506 | 507 | elif args.task == 'interpolate': 508 | 509 | if args.train: 510 | 511 | args.pretrain = True 512 | args.train = True 513 | 514 | config = train.run(args, train_interpolate) 515 | 516 | elif args.evaluate: 517 | 518 | if args.evaluate == 'model': 519 | evaluate = evaluate_interpolate_model 520 | elif args.evaluate == 'data': 521 | evaluate = evaluate_interpolate_data 522 | elif args.evaluate == 'baseline': 523 | evaluate = evaluate_interpolate_model_baseline 524 | else: 525 | raise NotImplementedError 526 | 527 | if args.config: 528 | config = SimpleNamespace(**torch.load(args.config)) 529 | else: 530 | print("Please provide a config file") 531 | 532 | evaluate(args, config) 533 | 534 | 535 | 536 | 537 | if __name__ == '__main__': 538 | main() 539 | -------------------------------------------------------------------------------- /docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gifford-lab/prescient/50971c7d495e8763eaa60f83af91f51555ed7ece/docs/.DS_Store -------------------------------------------------------------------------------- /docs/.bundle/config: -------------------------------------------------------------------------------- 1 | --- 2 | BUNDLE_PATH: "vendor/bundle" 3 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | /_site/ 2 | Gemfile 3 | Gemfile.lock 4 | /vendor/ 5 | -------------------------------------------------------------------------------- /docs/Gemfile: -------------------------------------------------------------------------------- 1 | source 'https://rubygems.org' 2 | gem 'github-pages', "~> 214" 3 | -------------------------------------------------------------------------------- /docs/Gemfile.lock: -------------------------------------------------------------------------------- 1 | GEM 2 | remote: https://rubygems.org/ 3 | specs: 4 | activesupport (6.0.3.6) 5 | concurrent-ruby (~> 1.0, >= 1.0.2) 6 | i18n (>= 0.7, < 2) 7 | minitest (~> 5.1) 8 | tzinfo (~> 1.1) 9 | zeitwerk (~> 2.2, >= 2.2.2) 10 | addressable (2.7.0) 11 | public_suffix (>= 2.0.2, < 5.0) 12 | coffee-script (2.4.1) 13 | coffee-script-source 14 | execjs 15 | coffee-script-source (1.11.1) 16 | colorator (1.1.0) 17 | commonmarker 
(0.17.13) 18 | ruby-enum (~> 0.5) 19 | concurrent-ruby (1.1.8) 20 | dnsruby (1.61.5) 21 | simpleidn (~> 0.1) 22 | em-websocket (0.5.2) 23 | eventmachine (>= 0.12.9) 24 | http_parser.rb (~> 0.6.0) 25 | ethon (0.12.0) 26 | ffi (>= 1.3.0) 27 | eventmachine (1.2.7) 28 | execjs (2.7.0) 29 | faraday (1.3.0) 30 | faraday-net_http (~> 1.0) 31 | multipart-post (>= 1.2, < 3) 32 | ruby2_keywords 33 | faraday-net_http (1.0.1) 34 | ffi (1.15.0) 35 | forwardable-extended (2.6.0) 36 | gemoji (3.0.1) 37 | github-pages (214) 38 | github-pages-health-check (= 1.17.0) 39 | jekyll (= 3.9.0) 40 | jekyll-avatar (= 0.7.0) 41 | jekyll-coffeescript (= 1.1.1) 42 | jekyll-commonmark-ghpages (= 0.1.6) 43 | jekyll-default-layout (= 0.1.4) 44 | jekyll-feed (= 0.15.1) 45 | jekyll-gist (= 1.5.0) 46 | jekyll-github-metadata (= 2.13.0) 47 | jekyll-mentions (= 1.6.0) 48 | jekyll-optional-front-matter (= 0.3.2) 49 | jekyll-paginate (= 1.1.0) 50 | jekyll-readme-index (= 0.3.0) 51 | jekyll-redirect-from (= 0.16.0) 52 | jekyll-relative-links (= 0.6.1) 53 | jekyll-remote-theme (= 0.4.3) 54 | jekyll-sass-converter (= 1.5.2) 55 | jekyll-seo-tag (= 2.7.1) 56 | jekyll-sitemap (= 1.4.0) 57 | jekyll-swiss (= 1.0.0) 58 | jekyll-theme-architect (= 0.1.1) 59 | jekyll-theme-cayman (= 0.1.1) 60 | jekyll-theme-dinky (= 0.1.1) 61 | jekyll-theme-hacker (= 0.1.2) 62 | jekyll-theme-leap-day (= 0.1.1) 63 | jekyll-theme-merlot (= 0.1.1) 64 | jekyll-theme-midnight (= 0.1.1) 65 | jekyll-theme-minimal (= 0.1.1) 66 | jekyll-theme-modernist (= 0.1.1) 67 | jekyll-theme-primer (= 0.5.4) 68 | jekyll-theme-slate (= 0.1.1) 69 | jekyll-theme-tactile (= 0.1.1) 70 | jekyll-theme-time-machine (= 0.1.1) 71 | jekyll-titles-from-headings (= 0.5.3) 72 | jemoji (= 0.12.0) 73 | kramdown (= 2.3.1) 74 | kramdown-parser-gfm (= 1.1.0) 75 | liquid (= 4.0.3) 76 | mercenary (~> 0.3) 77 | minima (= 2.5.1) 78 | nokogiri (>= 1.10.4, < 2.0) 79 | rouge (= 3.26.0) 80 | terminal-table (~> 1.4) 81 | github-pages-health-check (1.17.0) 82 | addressable (~> 2.3) 83 | dnsruby (~> 1.60) 84 | octokit (~> 4.0) 85 | public_suffix (>= 2.0.2, < 5.0) 86 | typhoeus (~> 1.3) 87 | html-pipeline (2.14.0) 88 | activesupport (>= 2) 89 | nokogiri (>= 1.4) 90 | http_parser.rb (0.6.0) 91 | i18n (0.9.5) 92 | concurrent-ruby (~> 1.0) 93 | jekyll (3.9.0) 94 | addressable (~> 2.4) 95 | colorator (~> 1.0) 96 | em-websocket (~> 0.5) 97 | i18n (~> 0.7) 98 | jekyll-sass-converter (~> 1.0) 99 | jekyll-watch (~> 2.0) 100 | kramdown (>= 1.17, < 3) 101 | liquid (~> 4.0) 102 | mercenary (~> 0.3.3) 103 | pathutil (~> 0.9) 104 | rouge (>= 1.7, < 4) 105 | safe_yaml (~> 1.0) 106 | jekyll-avatar (0.7.0) 107 | jekyll (>= 3.0, < 5.0) 108 | jekyll-coffeescript (1.1.1) 109 | coffee-script (~> 2.2) 110 | coffee-script-source (~> 1.11.1) 111 | jekyll-commonmark (1.3.1) 112 | commonmarker (~> 0.14) 113 | jekyll (>= 3.7, < 5.0) 114 | jekyll-commonmark-ghpages (0.1.6) 115 | commonmarker (~> 0.17.6) 116 | jekyll-commonmark (~> 1.2) 117 | rouge (>= 2.0, < 4.0) 118 | jekyll-default-layout (0.1.4) 119 | jekyll (~> 3.0) 120 | jekyll-feed (0.15.1) 121 | jekyll (>= 3.7, < 5.0) 122 | jekyll-gist (1.5.0) 123 | octokit (~> 4.2) 124 | jekyll-github-metadata (2.13.0) 125 | jekyll (>= 3.4, < 5.0) 126 | octokit (~> 4.0, != 4.4.0) 127 | jekyll-mentions (1.6.0) 128 | html-pipeline (~> 2.3) 129 | jekyll (>= 3.7, < 5.0) 130 | jekyll-optional-front-matter (0.3.2) 131 | jekyll (>= 3.0, < 5.0) 132 | jekyll-paginate (1.1.0) 133 | jekyll-readme-index (0.3.0) 134 | jekyll (>= 3.0, < 5.0) 135 | jekyll-redirect-from (0.16.0) 136 | jekyll (>= 3.3, < 
5.0) 137 | jekyll-relative-links (0.6.1) 138 | jekyll (>= 3.3, < 5.0) 139 | jekyll-remote-theme (0.4.3) 140 | addressable (~> 2.0) 141 | jekyll (>= 3.5, < 5.0) 142 | jekyll-sass-converter (>= 1.0, <= 3.0.0, != 2.0.0) 143 | rubyzip (>= 1.3.0, < 3.0) 144 | jekyll-sass-converter (1.5.2) 145 | sass (~> 3.4) 146 | jekyll-seo-tag (2.7.1) 147 | jekyll (>= 3.8, < 5.0) 148 | jekyll-sitemap (1.4.0) 149 | jekyll (>= 3.7, < 5.0) 150 | jekyll-swiss (1.0.0) 151 | jekyll-theme-architect (0.1.1) 152 | jekyll (~> 3.5) 153 | jekyll-seo-tag (~> 2.0) 154 | jekyll-theme-cayman (0.1.1) 155 | jekyll (~> 3.5) 156 | jekyll-seo-tag (~> 2.0) 157 | jekyll-theme-dinky (0.1.1) 158 | jekyll (~> 3.5) 159 | jekyll-seo-tag (~> 2.0) 160 | jekyll-theme-hacker (0.1.2) 161 | jekyll (> 3.5, < 5.0) 162 | jekyll-seo-tag (~> 2.0) 163 | jekyll-theme-leap-day (0.1.1) 164 | jekyll (~> 3.5) 165 | jekyll-seo-tag (~> 2.0) 166 | jekyll-theme-merlot (0.1.1) 167 | jekyll (~> 3.5) 168 | jekyll-seo-tag (~> 2.0) 169 | jekyll-theme-midnight (0.1.1) 170 | jekyll (~> 3.5) 171 | jekyll-seo-tag (~> 2.0) 172 | jekyll-theme-minimal (0.1.1) 173 | jekyll (~> 3.5) 174 | jekyll-seo-tag (~> 2.0) 175 | jekyll-theme-modernist (0.1.1) 176 | jekyll (~> 3.5) 177 | jekyll-seo-tag (~> 2.0) 178 | jekyll-theme-primer (0.5.4) 179 | jekyll (> 3.5, < 5.0) 180 | jekyll-github-metadata (~> 2.9) 181 | jekyll-seo-tag (~> 2.0) 182 | jekyll-theme-slate (0.1.1) 183 | jekyll (~> 3.5) 184 | jekyll-seo-tag (~> 2.0) 185 | jekyll-theme-tactile (0.1.1) 186 | jekyll (~> 3.5) 187 | jekyll-seo-tag (~> 2.0) 188 | jekyll-theme-time-machine (0.1.1) 189 | jekyll (~> 3.5) 190 | jekyll-seo-tag (~> 2.0) 191 | jekyll-titles-from-headings (0.5.3) 192 | jekyll (>= 3.3, < 5.0) 193 | jekyll-watch (2.2.1) 194 | listen (~> 3.0) 195 | jemoji (0.12.0) 196 | gemoji (~> 3.0) 197 | html-pipeline (~> 2.2) 198 | jekyll (>= 3.0, < 5.0) 199 | kramdown (2.3.1) 200 | rexml 201 | kramdown-parser-gfm (1.1.0) 202 | kramdown (~> 2.0) 203 | liquid (4.0.3) 204 | listen (3.5.1) 205 | rb-fsevent (~> 0.10, >= 0.10.3) 206 | rb-inotify (~> 0.9, >= 0.9.10) 207 | mercenary (0.3.6) 208 | minima (2.5.1) 209 | jekyll (>= 3.5, < 5.0) 210 | jekyll-feed (~> 0.9) 211 | jekyll-seo-tag (~> 2.1) 212 | minitest (5.14.4) 213 | multipart-post (2.1.1) 214 | nokogiri (1.11.2-arm64-darwin) 215 | racc (~> 1.4) 216 | nokogiri (1.11.2-x86_64-darwin) 217 | racc (~> 1.4) 218 | octokit (4.20.0) 219 | faraday (>= 0.9) 220 | sawyer (~> 0.8.0, >= 0.5.3) 221 | pathutil (0.16.2) 222 | forwardable-extended (~> 2.6) 223 | public_suffix (4.0.6) 224 | racc (1.5.2) 225 | rb-fsevent (0.10.4) 226 | rb-inotify (0.10.1) 227 | ffi (~> 1.0) 228 | rexml (3.2.5) 229 | rouge (3.26.0) 230 | ruby-enum (0.9.0) 231 | i18n 232 | ruby2_keywords (0.0.4) 233 | rubyzip (2.3.0) 234 | safe_yaml (1.0.5) 235 | sass (3.7.4) 236 | sass-listen (~> 4.0.0) 237 | sass-listen (4.0.0) 238 | rb-fsevent (~> 0.9, >= 0.9.4) 239 | rb-inotify (~> 0.9, >= 0.9.7) 240 | sawyer (0.8.2) 241 | addressable (>= 2.3.5) 242 | faraday (> 0.8, < 2.0) 243 | simpleidn (0.2.1) 244 | unf (~> 0.1.4) 245 | terminal-table (1.8.0) 246 | unicode-display_width (~> 1.1, >= 1.1.1) 247 | thread_safe (0.3.6) 248 | typhoeus (1.4.0) 249 | ethon (>= 0.9.0) 250 | tzinfo (1.2.9) 251 | thread_safe (~> 0.1) 252 | unf (0.1.4) 253 | unf_ext 254 | unf_ext (0.0.7.7) 255 | unicode-display_width (1.7.0) 256 | zeitwerk (2.4.2) 257 | 258 | PLATFORMS 259 | universal-darwin-19 260 | 261 | DEPENDENCIES 262 | github-pages (~> 214) 263 | 264 | BUNDLED WITH 265 | 2.2.4 266 | 
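The docs/ directory is a GitHub Pages (Jekyll) site: the Gemfile above pins the github-pages gem, and docs/.bundle/config vendors dependencies into vendor/bundle. A sketch for previewing the site locally, assuming Ruby and Bundler are installed (the serve step and its default port are standard Jekyll behavior rather than anything specified in this repo):

```bash
# Build and preview the documentation site locally.
cd docs
bundle install            # installs github-pages and its pinned dependencies into vendor/bundle
bundle exec jekyll serve  # builds the site and serves it, by default at http://127.0.0.1:4000
```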
-------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | remote_theme: yous/whiteglass 2 | title: prescient 3 | description: Simulation of single cell differentiation trajectories using longitudinal scRNA-seq. 4 | repository: gifford-lab/prescient 5 | show_downloads: true 6 | markdown: kramdown 7 | -------------------------------------------------------------------------------- /docs/_data/examples.json: -------------------------------------------------------------------------------- 1 | { 2 | "process_data": { 3 | "text": "This command takes a normalized expression CSV, metadata CSV, and pre-computed weight torch file as input and produces a PRESCIENT training torch object.", 4 | "code": "-d data/Veres2019/Stage_5.Seurat.csv -m data/Veres2019/GSE114412_Stage_5.all.cell_metadata.csv --growth_path data/Veres2019/Veres2019_growth-kegg.pt -o './' --tp_col 'CellWeek' --celltype_col 'Assigned_cluster'" 5 | }, 6 | "train_model": { 7 | "text": "This command trains a PRESCIENT model using a PRESCIENT training torch object.", 8 | "code": "-i data.pt --out_dir /experiments/ --weight_name 'kegg-growth' --seed 3 --layers 2 --k_dim 200 --train_tau 1e-06" 9 | }, 10 | "simulate_trajectories": { 11 | "text": "This command generates simulated trajectories from randomly initialized cells using a PRESCIENT model and training torch object.", 12 | "code": "-i data.pt --model_path /experiments/kegg-growth-softplus_2_200-1e-06/ --num_steps 10 -o experiments/ --seed 2" 13 | }, 14 | "perturbation_analysis": { 15 | "text": "This command runs forward simulations of unperturbed cells and cells with perturbations of selected genes.", 16 | "code": "-i ../Downloads/data.pt -p 'GENE1,GENE2,GENE3' -z 5 --model_path /experiments/kegg-softplus_2_200-1e-06/ --num_steps 10 --seed 2 -o experiments/" 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /docs/_data/navigation.yml: -------------------------------------------------------------------------------- 1 | main: 2 | - title: "Quickstart" 3 | url: /quickstart/ 4 | - title: "Inputs" 5 | url: /file_formats/ 6 | - title: "Documentation" 7 | url: /documentation/ 8 | - title: "Notebooks" 9 | url: /notebooks/ 10 | -------------------------------------------------------------------------------- /docs/_data/tools.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "desc": "Process normalized expression dataframe into compatible PRESCIENT file format.", 4 | "name": "process_data", 5 | "epilog": null, 6 | "params": [ 7 | { 8 | "name": "data_path", 9 | "default": null, 10 | "choices": null, 11 | "help": "Path to normalized expression CSV.", 12 | "required": true 13 | }, 14 | { 15 | "name": "out_dir", 16 | "default": null, 17 | "choices": null, 18 | "help": "Directory to store PRESCIENT torch object.", 19 | "required": true 20 | }, 21 | { 22 | "name": "meta_path", 23 | "default": null, 24 | "choices": null, 25 | "help": "Path to metadata CSV containing timepoint and celltype annotation data.", 26 | "required": true 27 | }, 28 | { 29 | "name": "tp_col", 30 | "default": null, 31 | "choices": null, 32 | "help": "Column name of timepoint feature in metadata provided as string.", 33 | "required": false 34 | }, 35 | { 36 | "name": "celltype_col", 37 | "default": null, 38 | "choices": null, 39 | "help": "Column name of timepoint feature in metadata provided as 
string.", 40 | "required": false 41 | }, 42 | { 43 | "name": "num_pcs", 44 | "default": 50, 45 | "choices": null, 46 | "help": "Define number of PCs to compute for input to training.", 47 | "required": false 48 | }, 49 | { 50 | "name": "num_neighbors_umap", 51 | "default": 10, 52 | "choices": null, 53 | "help": "Define number of neighbors for UMAP trasformation (UMAP used only for visualization.)", 54 | "required": false 55 | }, 56 | { 57 | "name": "growth_path", 58 | "default": null, 59 | "choices": null, 60 | "help": "Path to torch pt file containg pre-computed growth weights. See vignette notebooks for generating growth rate vector.", 61 | "required": false 62 | } 63 | 64 | ] 65 | }, 66 | { 67 | "desc": "Train a PRESCIENT model using a PRESCIENT data object as input.", 68 | "name": "train_model", 69 | "epilog": null, 70 | "params": [ 71 | { 72 | "name": "data_path", 73 | "default": null, 74 | "choices": null, 75 | "help": "Path to PRESCIENT data torch object produced by process_data.", 76 | "required": true 77 | }, 78 | { 79 | "name": "weight_name", 80 | "default": null, 81 | "choices": null, 82 | "help": "Descriptive name of weight vector being used provided as string for model filename.", 83 | "required": true 84 | }, 85 | { 86 | "name": "loss", 87 | "default": "euclidean", 88 | "choices": null, 89 | "help": "Designate distance function for loss.", 90 | "required": false 91 | }, 92 | { 93 | "name": "k_dim", 94 | "default": 500, 95 | "choices": null, 96 | "help": "Designate activation function for layers of NN.", 97 | "required": false 98 | }, 99 | { 100 | "name": "activation", 101 | "default": "softplus", 102 | "choices": null, 103 | "help": "Designate hidden units of fully connected layers in model.", 104 | "required": false 105 | }, 106 | { 107 | "name": "layers", 108 | "default": 2, 109 | "choices": null, 110 | "help": "Number of layers for neural network parameterizing the potential function.", 111 | "required": false 112 | }, 113 | { 114 | "name": "pretrain_lr", 115 | "default": 1e-9, 116 | "choices": null, 117 | "help": "Learning rate for Adam optimizer during pretraining.", 118 | "required": false 119 | }, 120 | { 121 | "name": "pretrain_epochs", 122 | "default": 500, 123 | "choices": null, 124 | "help": "Number of epochs for pretraining with contrastive divergence.", 125 | "required": false 126 | }, 127 | { 128 | "name": "train_epochs", 129 | "default": 2500, 130 | "choices": null, 131 | "help": "Number of epochs for training.", 132 | "required": false 133 | }, 134 | { 135 | "name": "train_lr", 136 | "default": 0.01, 137 | "choices": null, 138 | "help": "Learning rate for Adam optimizer during training.", 139 | "required": false 140 | }, 141 | { 142 | "name": "train_dt", 143 | "default": 0.1, 144 | "choices": null, 145 | "help": "Timestep for simulations during training.", 146 | "required": false 147 | }, 148 | { 149 | "name": "train_sd", 150 | "default": 0.5, 151 | "choices": null, 152 | "help": "Standard deviation of Gaussian noise for simulation steps.", 153 | "required": false 154 | }, 155 | { 156 | "name": "train_tau", 157 | "default": 1e-6, 158 | "choices": null, 159 | "help": "Tau hyperparameter of PRESCIENT.", 160 | "required": false 161 | }, 162 | { 163 | "name": "train_batch", 164 | "default": 0.1, 165 | "choices": null, 166 | "help": "Batch size (fraction) for training.", 167 | "required": false 168 | }, 169 | { 170 | "name": "train_clip", 171 | "default": 0.25, 172 | "choices": null, 173 | "help": "Gradient clipping threshold for training.", 174 | "required": false 
175 | }, 176 | { 177 | "name": "save", 178 | "default": 100, 179 | "choices": null, 180 | "help": "Save model every n epochs as torch dict.", 181 | "required": false 182 | } 183 | ] 184 | }, 185 | { 186 | "desc": "Simulate cellular trajectories using a trained PRESCIENT model and a PRESCIENT data object.", 187 | "name": "simulate_trajectories", 188 | "epilog": null, 189 | "params": [ 190 | { 191 | "name": "data_path", 192 | "default": null, 193 | "choices": null, 194 | "help": "Path to PRESCIENT training file (stored in out_dir of process_data command).", 195 | "required": true 196 | }, 197 | { 198 | "name": "model_path", 199 | "default": null, 200 | "choices": null, 201 | "help": "Path to directory containing PRESCIENT model for simulation.", 202 | "required": true 203 | }, 204 | { 205 | "name": "out_path", 206 | "default": null, 207 | "choices": null, 208 | "help": "Path to directory for storing output.", 209 | "required": true 210 | }, 211 | { 212 | "name": "num_sims", 213 | "default": 10, 214 | "choices": null, 215 | "help": "Number of simulations (random initializations of n cells) to run.", 216 | "required": false 217 | }, 218 | { 219 | "name": "num_cells", 220 | "default": 200, 221 | "choices": null, 222 | "help": "Number of cells per simulation.", 223 | "required": false 224 | }, 225 | { 226 | "name": "num_steps", 227 | "default": null, 228 | "choices": null, 229 | "help": "Number of steps forward in time. If not provided, steps will be calculated based on start and end point + train dt.", 230 | "required": false 231 | }, 232 | { 233 | "name": "seed", 234 | "default": 1, 235 | "choices": null, 236 | "help": "Choose the seed of the trained model to use for simulations.", 237 | "required": false 238 | }, 239 | { 240 | "name": "epoch", 241 | "default": "002500", 242 | "choices": null, 243 | "help": "Choose which epoch of the chosen model to use for simulations. Provide this value as str.", 244 | "required": false 245 | }, 246 | { 247 | "name": "gpu", 248 | "default": null, 249 | "choices": null, 250 | "help": "If available, assign GPU device number (requires CUDA). Provide as int.", 251 | "required": false 252 | }, 253 | { 254 | "name": "celltype_subset", 255 | "default": null, 256 | "choices": null, 257 | "help": "Randomly sample initial cells from a particular celltype defined in metadata. Provide celltype as str as appears in metadata.", 258 | "required": false 259 | }, 260 | { 261 | "name": "tp_subset", 262 | "default": null, 263 | "choices": null, 264 | "help": "Randomly sample initial cells from a particular timepoint. Provide timepoint as int or as appears in metadata.", 265 | "required": false 266 | } 267 | ] 268 | }, 269 | { 270 | "desc": "Simulate unperturbed and perturbed simulations of cells using a trained PRESCIENT model and a PRESCIENT data object.", 271 | "name": "perturbation_analysis", 272 | "epilog": null, 273 | "params": [ 274 | { 275 | "name": "perturb_genes", 276 | "default": null, 277 | "choices": null, 278 | "help": "Provide a gene or list of genes to be perturbed as a string (commas, no spaces). 
Must be in the feature set used to train models.", 279 | "required": true 280 | }, 281 | { 282 | "name": "z_score", 283 | "default": 5.0, 284 | "choices": null, 285 | "help": "Set magnitude of z_score perturbation.", 286 | "required": true 287 | }, 288 | { 289 | "name": "data_path", 290 | "default": null, 291 | "choices": null, 292 | "help": "Path to PRESCIENT training file (stored in out_dir of process_data command).", 293 | "required": true 294 | }, 295 | { 296 | "name": "model_path", 297 | "default": null, 298 | "choices": null, 299 | "help": "Path to directory containing PRESCIENT model for simulation.", 300 | "required": true 301 | }, 302 | { 303 | "name": "out_path", 304 | "default": null, 305 | "choices": null, 306 | "help": "Path to directory for storing output.", 307 | "required": true 308 | }, 309 | { 310 | "name": "num_sims", 311 | "default": 10, 312 | "choices": null, 313 | "help": "Number of simulations (random initializations of n cells) to run.", 314 | "required": false 315 | }, 316 | { 317 | "name": "num_cells", 318 | "default": 200, 319 | "choices": null, 320 | "help": "Number of cells per simulation.", 321 | "required": false 322 | }, 323 | { 324 | "name": "num_steps", 325 | "default": nulls, 326 | "choices": null, 327 | "help": "Number of steps forward in time. If not provided, steps will be calculated based on start and end point + train dt.", 328 | "required": false 329 | }, 330 | { 331 | "name": "seed", 332 | "default": 1, 333 | "choices": null, 334 | "help": "Choose the seed of the trained model to use for simulations.", 335 | "required": false 336 | }, 337 | { 338 | "name": "epoch", 339 | "default": 002500, 340 | "choices": null, 341 | "help": "Choose which epoch of the chosen model to use for simulations.", 342 | "required": false 343 | }, 344 | { 345 | "name": "gpu", 346 | "default": null, 347 | "choices": null, 348 | "help": "If available, assign GPU device number (requires CUDA). Provide as int.", 349 | "required": false 350 | }, 351 | { 352 | "name": "celltype_subset", 353 | "default": null, 354 | "choices": null, 355 | "help": "Randomly sample initial cells from a particular celltype defined in metadata. Provide celltype as str as appears in metadata.", 356 | "required": false 357 | }, 358 | { 359 | "name": "tp_subset", 360 | "default": null, 361 | "choices": null, 362 | "help": "Randomly sample initial cells from a particular timepoint. Provide timepoint as int or as appears in metadata.", 363 | "required": false 364 | } 365 | ] 366 | } 367 | 368 | ] 369 | -------------------------------------------------------------------------------- /docs/_includes/head.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 16 | 17 | 18 | 20 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /docs/_includes/header.html: -------------------------------------------------------------------------------- 1 | 15 | 16 |
17 |

{{default: site.github.repository_name-site.title}}

18 | 32 |
33 | -------------------------------------------------------------------------------- /docs/_includes/navbar.html: -------------------------------------------------------------------------------- 1 | 21 | -------------------------------------------------------------------------------- /docs/_layouts/cli.html: -------------------------------------------------------------------------------- 1 | --- 2 | layout: core 3 | noheader: false 4 | --- 5 |
6 |
7 |
8 | 9 |
10 | 11 |
12 |
13 | {{ content }} 14 |
15 |
16 |
17 |
18 | -------------------------------------------------------------------------------- /docs/_layouts/core.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | {% include head.html %} 4 | 11 | 12 | 14 | 15 | 16 | 17 | 18 | {% unless page.noheader %} 19 | {% include header.html %} 20 | {% endunless %} 21 | 22 | {{ content }} 23 | 24 | 25 | -------------------------------------------------------------------------------- /docs/_layouts/document.html: -------------------------------------------------------------------------------- 1 | --- 2 | layout: core 3 | noheader: false 4 | --- 5 |
6 |
7 |
8 |
9 | {{ content }} 10 |
11 |
12 |
13 | 14 | -------------------------------------------------------------------------------- /docs/_layouts/home.html: -------------------------------------------------------------------------------- 1 | --- 2 | layout: core 3 | title: Home 4 | --- 5 |
6 |
7 |
8 |
9 |
10 |

Introduction

11 |

12 | PRESCIENT (Potential eneRgy undeRlying Single-Cell gradIENTs) is a tool for simulating cellular 13 | differentiation trajectories with arbitrary cell state initializations. 14 | PRESCIENT frames differentiation as a diffusion process given by a stochastic differential equation whose drift term is 15 | parameterized by a generative neural network. PRESCIENT models can simulate cellular differentiation trajectories 16 | for out-of-sample (i.e. not seen during training) cells, enabling robust fate prediction and perturbational analysis. 17 | Here, we package PRESCIENT as a command-line tool and PyPI package. 18 |
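Concretely, simulated cells are advanced by repeated small stochastic steps along the learned drift. A minimal sketch of that update (cf. the _step method in prescient/train/model.py), where drift_fn stands in for the gradient of the trained potential network:

```python
import torch

# One Euler-Maruyama step of the learned diffusion process; a sketch only.
# `drift_fn` stands in for the gradient of the trained potential network,
# and dt / sd correspond to the train_dt / train_sd settings.
def euler_maruyama_step(x, drift_fn, dt=0.1, sd=0.5):
    z = torch.randn_like(x) * sd                  # Gaussian noise for this step
    return x + drift_fn(x) * dt + z * (dt ** 0.5)
```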

19 |
20 |
21 | 22 |
23 | 24 |
25 | 26 |
27 |
28 |
29 |

Installation

30 | PyPI version 31 |

32 | We recommend using pip to install PRESCIENT. For a stable version:
33 | 34 | pip install prescient 35 | 36 |

37 |

38 | For the latest version:
39 | pip install git+https://github.com/gifford-lab/prescient.git 40 |

41 |

42 | Source code available at:
43 | github.com/gifford-lab/prescient 44 |

45 | 46 |
47 |
48 |
49 | 50 |
51 |
52 |
53 |

Citing

54 | 55 |

56 | Generative modeling of single-cell time series with PRESCIENT enables prediction of cell trajectories with interventions
57 | by Grace Hui-Ting Yeo, Sachit D. Saksena, and David K. Gifford 58 |

59 |

60 | Code and notebooks for paper analyses and figures available at:
61 | github.com/gifford-lab/prescient-analysis 62 |

63 |
64 |
65 |
66 |
67 |
68 | -------------------------------------------------------------------------------- /docs/_site/file_formats/index.html: -------------------------------------------------------------------------------- 1 |

File formats

2 | 3 |

PRESCIENT takes as input longitudinal scRNA-seq data. For training, all that is needed is normalized gene expression, time-point labels, and cell type annotations. These inputs are used to generate a PRESCIENT data file using the prescient process_data command. Below, we describe the accepted input formats. For pre-processing, we recommend using Seurat or scanpy. PRESCIENT accepts the following formats: .csv, .tsv, .txt, .h5ad of a scanpy anndata object, or an .rds file of a Seurat object.

4 | 5 |

CSV, TSV, TXT

6 |

A post-processed gene expression file in .csv, .tsv, or .txt in the following format will work to create a PRESCIENT data object: 7 | | id | gene_1 | gene_2 | gene_3 | … | gene_n | 8 | |-------- |-------- |-------- |-------- |----- |-------- | 9 | | cell_1 | 0.0 | 0.121 | 0.0 | | 0.0 | 10 | | cell_2 | 0.234 | 0.0 | 0.0 | | 0.0 | 11 | | cell_3 | 0.0 | 0.0 | 0.0 | | 1.2 |
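For illustration, such an expression table and its accompanying metadata (with the columns later passed as --tp_col and --celltype_col) can be written with pandas; all names below are placeholders:

```python
import pandas as pd

# Placeholder gene, cell, and file names; the metadata columns are the ones
# later passed to `prescient process_data` as --tp_col and --celltype_col.
expr = pd.DataFrame(
    [[0.0, 0.121, 0.0], [0.234, 0.0, 0.0], [0.0, 0.0, 1.2]],
    index=["cell_1", "cell_2", "cell_3"],
    columns=["gene_1", "gene_2", "gene_3"],
)
meta = pd.DataFrame(
    {"timepoint": [0, 1, 2],
     "cell_type": ["undifferentiated", "neutrophil", "monocyte"]},
    index=expr.index,
)
expr.to_csv("expression.csv")   # passed via `prescient process_data -d`
meta.to_csv("metadata.csv")     # passed via `prescient process_data -m`
```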

12 | 13 |

Scanpy AnnData

14 |

If pre-processing is done with Scanpy, you can provide the AnnData object (.h5ad) directly to the prescient process_data command. The AnnData object should contain normalized expression in .X and timepoint and cell type annotations in .obs (the columns passed as --tp_col and --celltype_col), as sketched below:
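A rough sketch of the fields process_data pulls out of an .h5ad (based on prescient/commands/process_data.py); "timepoint" and "cell_type" are example .obs column names:

```python
import scanpy as sc

# Sketch of what `prescient process_data` reads from an .h5ad input;
# the .obs column names are examples for --tp_col and --celltype_col.
adata = sc.read_h5ad("your_data.h5ad")
expr = adata.X                                   # normalized (not scaled) expression
genes = adata.var_names                          # gene features
tps = adata.obs["timepoint"].values.astype(int)  # integer timepoint label per cell
celltype = adata.obs["cell_type"].values         # cell type annotation per cell
```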

15 | 33 | 34 |

Seurat object

35 |

If pre-processing is done with Seurat, export the normalized expression and metadata to .csv and provide the files as directed in the CSV section above (direct .rds input is not yet implemented in process_data). 36 |

37 | -------------------------------------------------------------------------------- /docs/_site/index.md: -------------------------------------------------------------------------------- 1 | # prescient 2 | -------------------------------------------------------------------------------- /docs/about.html: -------------------------------------------------------------------------------- 1 | --- 2 | title: About 3 | noheader: false 4 | permalink: about/ 5 | layout: document 6 | location: About 7 | --- 8 |

Potential EneRgy undErlying Single- 9 | Cell gradIENTs (PRESCIENT)

10 |
11 |

12 | 13 | 14 |
15 | 16 |
17 | 18 |
19 | 20 |
21 | -------------------------------------------------------------------------------- /docs/assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gifford-lab/prescient/50971c7d495e8763eaa60f83af91f51555ed7ece/docs/assets/.DS_Store -------------------------------------------------------------------------------- /docs/assets/css/styles.scss: -------------------------------------------------------------------------------- 1 | --- 2 | --- 3 | @import "{{ site.theme }}"; 4 | 5 | .navbar-right { 6 | float: center; 7 | } 8 | 9 | a { color: #d64e8a; } /* CSS link color */ 10 | 11 | /* code { background-color: lightgrey; } */ 12 | -------------------------------------------------------------------------------- /docs/assets/gifs/trajectories.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gifford-lab/prescient/50971c7d495e8763eaa60f83af91f51555ed7ece/docs/assets/gifs/trajectories.gif -------------------------------------------------------------------------------- /docs/assets/img/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gifford-lab/prescient/50971c7d495e8763eaa60f83af91f51555ed7ece/docs/assets/img/.DS_Store -------------------------------------------------------------------------------- /docs/assets/img/implementation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gifford-lab/prescient/50971c7d495e8763eaa60f83af91f51555ed7ece/docs/assets/img/implementation.png -------------------------------------------------------------------------------- /docs/assets/img/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gifford-lab/prescient/50971c7d495e8763eaa60f83af91f51555ed7ece/docs/assets/img/model.png -------------------------------------------------------------------------------- /docs/documentation.html: -------------------------------------------------------------------------------- 1 | --- 2 | title: Documentation 3 | noheader: false 4 | permalink: documentation/ 5 | layout: document 6 | location: documentation 7 | --- 8 |

Command line interface documentation

9 |

10 | PRESCIENT is primarily implemented as a command-line tool. Access the manual for each command using the syntax prescient command -h . 11 | Run the commands below via prescient command [params]. 12 |
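The same commands can also be driven from Python; a small sketch (all argument values are placeholders) that builds a command's parser and calls its main(), much as the prescient entry point does:

```python
from prescient.commands import simulate_trajectories

# Build the command's argparse parser and call its main() directly,
# mirroring how prescient/__main__.py dispatches commands; values are placeholders.
parser = simulate_trajectories.create_parser()
args = parser.parse_args([
    "-i", "/path/to/data.pt",
    "--model_path", "/path/to/trained/model_directory",
    "--seed", "1",
    "-o", "/path/to/output_dir",
])
simulate_trajectories.main(args)
```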

13 | 14 |
15 | {% for tool in site.data.tools %} 16 |
17 |

{{ tool.name }}

18 | 19 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | {% for param in tool.params %} 30 | 31 | 32 | 34 | 37 | 38 | {% endfor %} 39 |
ParameterDescription
{% if param.required %} {% endif %} {{param.name}} {% if param.required %} {% endif %} 33 | {{param.help}}{% if param.choices %}
Choices: {{param.choices | join: ", "}} {% endif %}{% if 35 | param.default %}
Default: {{param.default}} {% endif %} 36 |
40 | {% if site.data.examples[tool.name] %} 41 | Example Usage:
42 | prescient {{tool.name }} {{ site.data.examples[tool.name].code }}
43 | {{ site.data.examples[tool.name].text }} 44 | {% endif %} 45 |
46 |
47 | {% endfor %} 48 | 49 |
50 |

Links to resources for running CLI with Google Cloud SDK

51 |

52 | If you do not have local access to GPUs but want to use them for 53 | training and simulations from the command line (CPUs also work), 54 | we recommend any cloud computing service that provides 55 | NVIDIA GPUs with CUDA support. For an easier approach, we have provided a short demo in the notebooks tab 56 | of using free cloud GPUs in a notebook via Google Colab. We recommend this approach, as the setup process for the Google Cloud SDK 57 | can be involved. That said, here is a list of Google Cloud web tutorials for setting up a Google Cloud account, 58 | the gcloud command-line interface, creating a GPU instance, and running an interactive shell: 59 |

60 |
    61 |
  1. Setting up account and billing for buying GPU compute time
  62 |
  2. Downloading Google Cloud SDK (gcloud command)
  63 |
  3. Creating a virtual machine (VM) with mounted GPU
  64 |
  4. Using gcloud interactive shell
  65 |
66 | -------------------------------------------------------------------------------- /docs/file_formats.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: File Formats 3 | permalink: file_formats/ 4 | layout: document 5 | location: file_formats 6 | --- 7 | 8 | ## File formats 9 | 10 | PRESCIENT takes as input longitudinal scRNA-seq data. For training, all that is needed is normalized gene expression, time-point labels, and cell type annotations. These inputs are used to generate a **PRESCIENT torch object** using the `prescient process_data` (see below). Below, we describe accepted formatting for inputs. For pre-processing, we recommend using Seurat or scanpy. PRESCIENT accepts the following formats: **.csv**, **.tsv**, **.txt**, **.h5ad** of a scanpy anndata object, or an **.rds** file of a Seurat object. 11 | 12 | ## Normalized expression 13 | A post-processed gene expression file in .csv, .tsv, or .txt in the following format will work to create a PRESCIENT data object: 14 | 15 | | id | gene_1 | gene_2 | gene_3 | ... | gene_n | 16 | |:-------- |:-------- |:-------- |:-------- |:----- |:-------- | 17 | | cell_1 | 0.0 | 0.121 | 0.0 | | 0.0 | 18 | | cell_2 | 0.234 | 0.0 | 0.0 | | 0.0 | 19 | | cell_3 | 0.0 | 0.0 | 0.0 | | 1.2 | 20 | 21 | ## Metadata 22 | 23 | | id | timepoint | cell_type | 24 | |-------- |----------- |------------------ | 25 | | cell_1 | 0 | undifferentiated | 26 | | cell_2 | 1 | neutrophil | 27 | | cell_3 | 2 | monocyte | 28 | 29 | 50 | 51 | ## PRESCIENT torch object 52 | The `prescient process_data` command will generate a torch pt file `data_pt` (serialized dictionary) that contains all the necessary information for downstream training, simulations, and perturbations. It will contain the following information: 53 | 54 | - data_pt["data"]: Numpy ndarray of normalized expression. 55 | - data_pt["celltype"]: List of celltype labels for each cell. 56 | - data_pt["genes"]: List of gene features. 57 | - data_pt["tps"]: Timepoint assignment for each cell in dataset from metadata. 58 | - data_pt["x"]: Torch tensors of normalied expression split by timepoint. 59 | - data_pt["xp"]: Torch tensors of cell PCs split by timepoint. 60 | - data_pt["xu"]: Torch tensors of cell UMAPs split by timepoint. 61 | - data_pt["pca"]: sklearn.decomposition.PCA object fit to normalized expression and used to produce PCs. 62 | - data_pt["um"]: umap.UMAP object fit to PCs used to produce UMAP dims. 63 | - data_pt["y"]: List of timepoints. 64 | - data_pt["w"]: Torch tensors of pre-computed growth weights split by timepoint. 65 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: home 3 | location: Home 4 | --- 5 | -------------------------------------------------------------------------------- /docs/notebooks.html: -------------------------------------------------------------------------------- 1 | --- 2 | title: Notebooks 3 | noheader: false 4 | permalink: notebooks/ 5 | layout: document 6 | location: Notebooks 7 | --- 8 |

Running PRESCIENT using cloud GPU resources

9 |
10 |

11 | Training PRESCIENT models using CPU can, in some cases, be time intensive. 12 | For this reason, here we provide the path of least resistance to running the 13 | whole training pipeline using cloud computing services. This Google Colab notebook provides a free demo of 14 | how PRESCIENT can be run with GPU acceleration. PRESCIENT is also available via PyPI for installation 15 | and usage on a local workstation, on a compute server, or on other cloud services, including 16 | other cloud notebook servers such as 17 | Gcloud AI platform and 18 | AWS with Jupyter. 19 | Our example uses Google Colab, which offers free access to public GPUs hosted by Google. 20 | 21 |

22 | 25 | 26 | 32 |
33 | 34 |

Usage and analysis notebooks

35 |

36 | Here, we present examples of individual pre-processing and analysis steps using PRESCIENT. This list of notebooks will grow as we add functionality and update PRESCIENT. 37 |
38 | To demonstrate, we use longitudinal scRNA-seq data from Veres et al. 2019 Stage 5 pancreatic 39 | beta-cell differentiation as an example dataset. We present only a demo version of this analysis here for simplicity; please refer to the paper analysis GitHub repo 40 | for the in-depth analyses used in the paper. 41 | To run the following notebooks with the example data, download the following directories containing raw data and trained PRESCIENT models: 42 |

46 | 47 |

48 | 49 |

Estimating growth rates

50 |
51 |

52 | PRESCIENT models can incorporate proliferation during training by computing a "growth weight" for each cell in a scRNA-seq dataset. 53 | Incorporating growth was shown to substantially improve model performance. When lineage tracing is available, empirical growth rates can be computed; however, since the primary 54 | use case for PRESCIENT is in the absence of lineage tracing, we provide a function for estimating growth from proliferation/death gene signatures. 55 | Here, these gene sets are provided in the downloaded data folder and used in the notebook, but you 56 | can use any gene set (gst) from common sources like MSigDB. 57 |
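A rough sketch of the programmatic call the notebook builds on, using prescient.utils.get_growth_weights; the file names, paths, and the "timepoint" column are placeholders, the signature CSVs are assumed to have a gene_symbol column, and the result is saved in the {"w": ...} format expected by process_data --growth_path:

```python
import pandas as pd
import sklearn.decomposition
import sklearn.preprocessing
from prescient.utils import get_growth_weights

# Load normalized expression and metadata (placeholder file and column names).
expr = pd.read_csv("expression.csv", index_col=0)
meta = pd.read_csv("metadata.csv")
genes = list(expr.columns)

# Scale expression and compute PCs for the smoothing step.
x = pd.DataFrame(sklearn.preprocessing.StandardScaler().fit_transform(expr), columns=genes)
xp = sklearn.decomposition.PCA(n_components=50).fit_transform(x)

# Estimate growth rates from birth/death signatures (CSVs with a gene_symbol column)
# and write growth_weights.pt for `prescient process_data --growth_path`.
g, g_l = get_growth_weights(x, xp, meta, "timepoint", genes,
                            birth_gst="birth_signature.csv",
                            death_gst="death_signature.csv",
                            outfile="growth_weights.pt")
```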

58 | 61 |
62 | 63 | 82 | -------------------------------------------------------------------------------- /docs/quickstart.html: -------------------------------------------------------------------------------- 1 | --- 2 | title: Quickstart 3 | noheader: false 4 | permalink: quickstart/ 5 | layout: document 6 | location: Quickstart 7 | --- 8 | 9 |

Quickstart

10 |

Here, we provide the path of least resistance (the command-line interface) 11 | to training a PRESCIENT model and running perturbational analyses. To install PRESCIENT, refer to 12 | the homepage.

13 | 14 |

Create PRESCIENT torch object

15 |
16 |

17 | First, we recommend looking at how to prepare inputs for PRESCIENT 18 | and bring your scRNA-seq data into an accepted format. For estimating growth weights, please refer to the notebooks tab. 19 |

20 | Run the following to estimate growth rates and create a PRESCIENT training pyTorch object: 21 |
22 | prescient process_data -d /path/to/your_data.csv -o /path/for/output/ -m /path/to/metadata.csv --tp_col "timepoint colname" --celltype_col "annotation colname" --growth_path /path/to/growth_weights.pt 23 |
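The command writes a single data.pt torch object to the output directory; a quick way to inspect it (the path is a placeholder):

```python
import torch

# Quick inspection of the object written by `prescient process_data`;
# the path is a placeholder for your --out_dir.
data_pt = torch.load("/path/for/output/data.pt")
print(sorted(data_pt.keys()))   # data, genes, celltype, tps, x, xp, xu, y, pca, um, w
print(data_pt["xp"][0].shape)   # cells at the first timepoint x num_pcs
```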
24 |

25 |
26 |
27 |

Train PRESCIENT model

28 |
29 |

To train a PRESCIENT model, it is beneficial to use GPU acceleration with CUDA support. PRESCIENT models can be trained on CPUs but will take longer to train. 30 | For a demo on running PRESCIENT with free GPU cloud resources on Google Colab, please refer to the notebooks tab. 31 |
32 |

33 | Next, train a basic PRESCIENT model with default parameters using the following command and the data.pt file produced by the process_data command:
34 | prescient train_model -i /path/to/data.pt --out_dir /experiments/ --weight_name 'kegg-growth' 35 |
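Checkpoints are written under --out_dir in a subdirectory named after the weight vector and model hyperparameters, one seed per subfolder. A sketch of re-loading a trained model outside the CLI, mirroring the steps the simulate and perturbation commands take (the directory name, seed, and epoch below are placeholders for your run):

```python
import os
import torch
from types import SimpleNamespace
from prescient.train.model import AutoGenerator

# Placeholder paths: checkpoints land under
# {out_dir}/{weight_name}-{activation}_{layers}_{k_dim}-{train_tau}/seed_{seed}/.
model_dir = "/experiments/kegg-growth-softplus_1_500-1e-06"
config = SimpleNamespace(**torch.load(os.path.join(model_dir, "seed_2/config.pt")))
net = AutoGenerator(config)
ckpt = torch.load(os.path.join(model_dir, "seed_2/train.epoch_002500.pt"),
                  map_location=torch.device("cpu"))
net.load_state_dict(ckpt["model_state_dict"])
```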

36 |

37 | For more options to control model architecture and hyperparameters, 38 | please refer to CLI documentation.

39 |
40 | 41 |
42 |
43 | 44 | 45 |

Simulate trajectories

46 |
47 |

Now, with a trained PRESCIENT model and the original PRESCIENT data object, you can simulate trajectories of cells with arbitrary initializations. 48 | To do so, run the simulate_trajectories command.

49 |

50 | In the following example, the command will randomly sample 200 initial cells (the default; use --tp_subset or --celltype_subset to restrict the starting population) at 51 | and simulate them forward to the final timepoint: 52 | prescient simulate_trajectories -i /path/to/data.pt --model_path /path/to/trained/model_directory -o /path/to/output_dir --seed 2 53 |

54 |

This will produce a PRESCIENT simulation object containing the following:
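Based on the simulate_trajectories source, the saved object holds a single "sims" entry: a list with one numpy array per simulation of shape [num_steps + 1, num_cells, num_pcs] (the sampled initial cells plus each forward step). A minimal sketch of loading it:

```python
import torch

# The filename is a shortened placeholder for the auto-generated name.
out = torch.load("/path/to/output_dir/model_directory/seed_2_..._simulation.pt")
sims = out["sims"]                # one entry per simulation
print(len(sims), sims[0].shape)   # each entry: [num_steps + 1, num_cells, num_pcs]
```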

55 | 58 |

59 | For more control over choosing cells, number of steps, etc. please refer to CLI documentation. 60 |

61 |
62 | 63 |
64 |

Run perturbation simulations

65 | 66 |
67 |

68 | One of the advantages of training a PRESCIENT model is the ability to simulate the trajectory of out-of-sample 69 | or perturbed initial cells. To do this, individual genes or sets of genes are perturbed by setting their values to a chosen z-score in scaled 70 | expression space. The perturbation_analysis command induces the perturbation and generates simulated trajectories of both unperturbed and perturbed cells 71 | for comparison, as sketched below. 72 |
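For intuition, a rough sketch of the perturbation step itself, mirroring prescient.perturb.z_score_perturbation (argument names here are illustrative):

```python
import sklearn.preprocessing

# Sketch of the z-score perturbation: perturbed genes are set to a fixed z-score
# in scaled expression space, then projected into PC space with the PCA object
# stored in the PRESCIENT data file.
def perturb(expr, genes, perturb_genes, z_score, pca):
    x = sklearn.preprocessing.StandardScaler().fit_transform(expr)
    x_perturbed = x.copy()
    for gene in perturb_genes:
        if gene in genes:                         # only genes in the training feature set
            x_perturbed[:, genes.index(gene)] = z_score
    return pca.transform(x_perturbed)             # perturbed initial cells in PC space
```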

73 |

74 | In the following example, GENE1, GENE2, and GENE3 are perturbed in 10 random samples of 200 cells with a z-score of 5 and simulated forward to the final timepoint with a trained PRESCIENT model (--num_pcs must match the value given to process_data):
75 | prescient perturbation_analysis -i /path/to/data.pt -p 'GENE1,GENE2,GENE3' -z 5 --num_pcs 50 --model_path /path/to/trained/model_directory --seed 2 -o /path/to/output_dir 76 |

77 | 78 |

This will produce a PRESCIENT simulation object containing the following:
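Based on the perturbation_analysis source, the saved object holds the perturbed gene string plus matched unperturbed and perturbed simulations. A minimal sketch of loading it:

```python
import torch

# The filename is a shortened placeholder for the auto-generated name.
pert_out = torch.load("/path/to/output_dir/model_directory/seed_2_..._perturb_simulation.pt")
print(pert_out["perturbed_genes"])         # the gene string passed via -p
unperturbed = pert_out["unperturbed_sim"]  # list of [num_steps + 1, num_cells, num_pcs] arrays
perturbed = pert_out["perturbed_sim"]      # same layout, from perturbed initial cells
```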

79 | 84 |

85 | For more control over choosing cells, number of steps, etc. please refer to CLI documentation. 86 |

87 |
88 | -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gifford-lab/prescient/50971c7d495e8763eaa60f83af91f51555ed7ece/notebooks/.gitkeep -------------------------------------------------------------------------------- /prescient/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gifford-lab/prescient/50971c7d495e8763eaa60f83af91f51555ed7ece/prescient/.gitkeep -------------------------------------------------------------------------------- /prescient/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | import prescient.train 3 | import prescient.simulate 4 | import prescient.perturb 5 | -------------------------------------------------------------------------------- /prescient/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from prescient.commands import * 4 | 5 | def main(): 6 | command_list = [process_data, train_model, simulate_trajectories, perturbation_analysis] 7 | tool_parser = argparse.ArgumentParser(description='Run a prescient command.') 8 | command_list_strings = list(map(lambda x: x.__name__[len('prescient.commands.'):], command_list)) 9 | tool_parser.add_argument('command', help='prescient command', choices=command_list_strings) 10 | tool_parser.add_argument('command_args', help='command arguments', nargs=argparse.REMAINDER) 11 | prescient_args = tool_parser.parse_args() 12 | command_name = prescient_args.command 13 | command_args = prescient_args.command_args 14 | cmd = command_list[command_list_strings.index(command_name)] 15 | sys.argv[0] = cmd.__file__ 16 | parser = cmd.create_parser() 17 | args = parser.parse_args(command_args) 18 | cmd.main(args) 19 | 20 | 21 | 22 | if __name__ == '__main__': 23 | args = sys.argv 24 | # if "--help" in args or len(args) == 1: 25 | # print("CVE") 26 | main() 27 | -------------------------------------------------------------------------------- /prescient/__pycache__/train.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gifford-lab/prescient/50971c7d495e8763eaa60f83af91f51555ed7ece/prescient/__pycache__/train.cpython-37.pyc -------------------------------------------------------------------------------- /prescient/__pycache__/veres.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gifford-lab/prescient/50971c7d495e8763eaa60f83af91f51555ed7ece/prescient/__pycache__/veres.cpython-37.pyc -------------------------------------------------------------------------------- /prescient/commands/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .process_data import * 3 | from .train_model import * 4 | from .simulate_trajectories import * 5 | from .perturbation_analysis import * 6 | -------------------------------------------------------------------------------- /prescient/commands/perturbation_analysis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import prescient.simulate as traj 3 | from prescient.train.model import * 4 | import 
prescient.perturb as pert 5 | import prescient.simulate as traj 6 | 7 | 8 | def create_parser(): 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | # perturbation parameters 13 | parser.add_argument("-p", "--perturb_genes", required=True, help="Provide a gene or list of genes to be perturbed as a string (commas, no spaces).") 14 | parser.add_argument("-z", "--z_score", default=5.0, required=True, help="Set magnitude of perturbation as z-score.") 15 | 16 | # simulation parameters 17 | parser.add_argument("-i", "--data_path", required=True, help="Path to PRESCIENT data file stored as a torch pt.") 18 | parser.add_argument("--model_path", required=True, help="Path to model directory.") 19 | parser.add_argument("--seed", default=1, required=True, help="Choose the seed of the trained model to use for simulations.") 20 | parser.add_argument("--epoch", default="002500", type=str, required=False, help="Choose which epoch of the model to use for simulations.") 21 | parser.add_argument("--num_sims", default=10, help="Number of simulations to run.") 22 | parser.add_argument("--num_cells", default=200, help="Number of cells per simulation.") 23 | parser.add_argument("--num_steps", default=None, required=False, help="Define number of forward steps of size dt to take.") 24 | parser.add_argument("--num_pcs", required=True, help="Number of PC's. Must match the number given to process_data.") 25 | parser.add_argument("--gpu", default=None, required=False) 26 | parser.add_argument("--celltype_subset", default=None, required=False, help="Randomly sample initial cells from a particular celltype defined in metadata.") 27 | parser.add_argument("--tp_subset", default=None, required=False, help="Randomly sample initial cells from a particular timepoint.") 28 | parser.add_argument("-o", "--out_path", required=True, default=None, help="Path to output directory.") 29 | return parser 30 | 31 | 32 | def main(args): 33 | 34 | # load data 35 | data_pt = torch.load(args.data_path) 36 | genes = data_pt["genes"].tolist() 37 | expr = data_pt["data"] 38 | pca = data_pt["pca"] 39 | xp = pca.transform(expr) 40 | xp = xp[:,0:int(args.num_pcs)] 41 | 42 | # generate perturbations PRESCIENT data file 43 | xp_perturb = pert.z_score_perturbation(genes, args.perturb_genes, expr, pca, args.z_score) 44 | xp_perturb = xp_perturb[:,0:int(args.num_pcs)] 45 | # torch device 46 | if args.gpu != None: 47 | device = torch.device('cuda:{}'.format(args.gpu)) 48 | else: 49 | device = torch.device('cpu') 50 | 51 | # load model 52 | config_path = os.path.join(str(args.model_path), 'seed_{}/config.pt'.format(args.seed)) 53 | config = SimpleNamespace(**torch.load(config_path)) 54 | net = AutoGenerator(config) 55 | 56 | train_pt = os.path.join(args.model_path, 'seed_{}/train.epoch_{}.pt'.format(args.seed, args.epoch)) 57 | checkpoint = torch.load(train_pt, map_location=device) 58 | net.load_state_dict(checkpoint['model_state_dict']) 59 | net.to(device) 60 | 61 | # Either use assigned number of steps or calculate steps, both using the stepsize used for training 62 | if args.num_steps == None: 63 | t = data_pt["y"][-1]-data_pt["y"][0] 64 | num_steps = int(np.round(t / config.train_dt)) 65 | else: 66 | num_steps = int(args.num_steps) 67 | 68 | # simulate forward 69 | std_out = traj.simulate(xp, data_pt["tps"], data_pt["celltype"], data_pt["w"], net, config, int(args.num_sims), int(args.num_cells), num_steps, device, args.tp_subset, args.celltype_subset) 70 | 71 | perturbed_out = traj.simulate(xp_perturb, data_pt["tps"], data_pt["celltype"], 
data_pt["w"], net, config, int(args.num_sims), int(args.num_cells), num_steps, device, args.tp_subset, args.celltype_subset) 72 | 73 | out_path = os.path.join(args.out_path, args.model_path.split("/")[-1], 'seed_{}_train.epoch_{}_num.sims_{}_num.cells_{}_num.steps_{}_subsets_{}_{}_perturb_simulation.pt'.format(args.seed, args.epoch, args.num_sims, args.num_cells, num_steps, args.tp_subset, args.celltype_subset)) 74 | # save PRESCIENT perturbation file 75 | torch.save({"perturbed_genes": args.perturb_genes, 76 | "unperturbed_sim": std_out, 77 | "perturbed_sim": perturbed_out}, 78 | out_path) 79 | 80 | 81 | if __name__=="__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /prescient/commands/process_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import argparse 5 | import pyreadr 6 | import scanpy as sc 7 | import anndata 8 | import sklearn 9 | import umap 10 | 11 | import annoy 12 | import torch 13 | 14 | 15 | def read_data(args): 16 | """ 17 | - Load csv preprocessed with scanpy or Seurat. 18 | - Must be a csv with format n_cells x n_genes with normalized (not scaled!) expression. 19 | - Must have meta data with n_cells x n_metadata and include timepoints and assigned cell type labels. 20 | 21 | Inputs: 22 | ------- 23 | path: path to csv or rds file of processed scRNA-seq dataset. 24 | meta: path to metadata csv. 25 | """ 26 | ext = os.path.splitext(args.data_path)[1] 27 | # load in expression dataframe 28 | if ext == ".csv" or ext == ".txt" or ext == ".tsv": 29 | if args.meta_path == None: 30 | raise ValueError("If csv/tsv/txt provided, you must provide a path to csv metadata with timepoint and cell_type") 31 | if args.tp_col == None or args.celltype_col == None: 32 | raise ValueError("If csv/tsv/txt provided, you must provide --tp_col and --celltype_col.") 33 | 34 | expr = pd.read_csv(args.data_path, index_col=0) 35 | meta = pd.read_csv(args.meta_path) 36 | genes = expr.columns 37 | expr = expr.to_numpy() 38 | tps = meta[args.tp_col].values.astype(int) 39 | celltype = meta[args.celltype_col].values 40 | 41 | if ext == ".h5ad": 42 | adata = sc.read_h5ad(args.data_path) 43 | expr = adata.X 44 | meta = adata.obs.copy() 45 | try: 46 | expr = expr.toarray() # In case it is in a sparse format; I do not know if this would be an obstacle downstream. 47 | except: 48 | pass 49 | if args.tp_col == None or args.celltype_col == None: 50 | raise ValueError("If h5ad input is provided, you must provide --tp_col and --celltype_col.") 51 | assert args.tp_col in adata.obs.columns, f"Expected a timepoint column in .obs called {args.tp_col}, but did not find it. Update the --tp_col arg?" 52 | assert args.celltype_col in adata.obs.columns, f"Expected a cell_type column in .obs called {args.celltype_col}, but did not find it. Update the --celltype_col arg?" 53 | tps = meta[args.tp_col].values.astype(int) 54 | celltype = meta[args.celltype_col].values 55 | genes = adata.var_names 56 | 57 | # todo: implement Seurat object functionality 58 | if ext == ".rds": 59 | raise NotImplementedError 60 | 61 | if args.fix_non_consecutive: 62 | converter = {orig:new for orig, new in zip(np.sort(np.unique(tps)), np.arange(0, len(np.unique(tps))))} 63 | tps = np.array([converter[orig] for orig in tps]) 64 | assert np.all(np.sort(np.unique(tps)) == np.arange(0, len(np.unique(tps)))), "Timepoints must be labeled 0, 1, 2, ... 
T consecutively; no gaps are allowed" 65 | 66 | # transformations 67 | scaler = sklearn.preprocessing.StandardScaler() 68 | pca = sklearn.decomposition.PCA(n_components = args.num_pcs) 69 | um = umap.UMAP(n_components = 2, metric = 'euclidean', n_neighbors = args.num_neighbors_umap) 70 | 71 | x = scaler.fit_transform(expr) 72 | xp = pca.fit_transform(x) 73 | xu = um.fit_transform(xp) 74 | 75 | y = list(np.sort(np.unique(tps))) 76 | 77 | x_ = [torch.from_numpy(x[(tps == d),:]).float() for d in y] 78 | xp_ = [torch.from_numpy(xp[(tps == d),:]).float() for d in y] 79 | xu_ = [torch.from_numpy(xu[(tps == d),:]).float() for d in y] 80 | 81 | return expr, x_, xp_, xu_, y, pca, um, tps, celltype, genes 82 | 83 | def create_parser(): 84 | parser = argparse.ArgumentParser() 85 | 86 | # file I/0 87 | parser.add_argument('-d', '--data_path', type=str, required=True, 88 | help="Path to dataframe of expression values.") 89 | parser.add_argument('-o', '--out_dir', type=str, required=True, 90 | help="Path to output directory to store final PRESCIENT data file.") 91 | parser.add_argument('-m', '--meta_path', type=str, required=False, 92 | help="Path to metadata containing timepoint and celltype annotation data.") 93 | 94 | # column names 95 | parser.add_argument('--tp_col', type=str, required=False, 96 | help="Column name of timepoint feature in metadate provided as string.") 97 | parser.add_argument('--celltype_col', type=str, required=False, 98 | help="Column name of celltype feature in metadata provided as string.") 99 | 100 | # option to fix non-consecutive timepoint labels 101 | parser.add_argument('--fix_non_consecutive', action="store_true", default=False, 102 | help="If provided, quantitative timepoints will be overwritted, e.g. 1, 4, 10 becomes 0,1,2.") 103 | 104 | # dimensionality reduction growth_parameters 105 | parser.add_argument('--num_pcs', type=int, default=50, required=False, 106 | help="Define number of PCs to compute for input to training.") 107 | parser.add_argument('--num_neighbors_umap', type=int, default=10, required=False, 108 | help="Define number of neighbors for UMAP trasformation (UMAP used only for visualization.)") 109 | 110 | # proliferation scores 111 | parser.add_argument('--growth_path', type=str, 112 | help="Path to torch pt file containg pre-computed growth weights. See vignette notebooks for generating growth rate vector.") 113 | return parser 114 | 115 | def main(args): 116 | """ 117 | Outputs: 118 | -------- 119 | Saves a PRESCIENT file to out_path. Does not output file. 
120 | data.pt: 121 | |- x: scaled expression 122 | |- xp: n PC space 123 | |- xu: UMAP space 124 | |- pca: sklearn pca object for pca tranformation 125 | |- um: umap object for umap transformation 126 | |- y: timepoints 127 | |- genes: features 128 | |- w: growth weights 129 | |- celltype: vector of celltype labels 130 | """ 131 | expr, x, xp, xu, y, pca, um, tps, celltype, genes = read_data(args) 132 | 133 | w_pt = torch.load(args.growth_path) 134 | w = w_pt["w"] 135 | 136 | 137 | # write as a torch object 138 | torch.save({ 139 | "data": expr, 140 | "genes": genes, 141 | "celltype": celltype, 142 | "tps": tps, 143 | "x":x, 144 | "xp":xp, 145 | "xu": xu, 146 | "y": y, 147 | "pca": pca, 148 | "um":um, 149 | "w":w 150 | }, args.out_dir+"data.pt") 151 | 152 | if __name__ == '__main__': 153 | main() 154 | -------------------------------------------------------------------------------- /prescient/commands/simulate_trajectories.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import prescient.simulate as traj 4 | from prescient.train.model import * 5 | 6 | def create_parser(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("-i", "--data_path", required=True, help="Path to PRESCIENT data file stored as a torch pt.") 9 | parser.add_argument("--model_path", required=True, help="Path to directory containing PRESCIENT model for simulation.") 10 | parser.add_argument("--seed", default=1, required=True, help="Choose the seed of the trained model to use for simulations.") 11 | parser.add_argument("--epoch", default="002500", type=str, required=False, help="Choose which epoch of the model to use for simulations.") 12 | parser.add_argument("--num_sims", default=10, help="Number of simulations to run.") 13 | parser.add_argument("--num_cells", default=200, help="Number of cells per simulation.") 14 | parser.add_argument("--num_steps", default=None, required=False, help="Define number of forward steps of size dt to take.") 15 | parser.add_argument("--gpu", default=None, required=False, help="If available, assign GPU device number.") 16 | parser.add_argument("--celltype_subset", default=None, required=False, help="Randomly sample initial cells from a particular celltype defined in metadata.") 17 | parser.add_argument("--tp_subset", type=int, default=None, required=False, help="Randomly sample initial cells from a particular timepoint.") 18 | parser.add_argument("-o", "--out_path", required=True, default=None, help="Path to output directory.") 19 | return parser 20 | 21 | def main(args): 22 | 23 | # load data 24 | data_pt = torch.load(args.data_path) 25 | expr = data_pt["data"] 26 | pca = data_pt["pca"] 27 | xp = pca.transform(expr) 28 | 29 | # torch device 30 | if args.gpu != None: 31 | device = torch.device('cuda:{}'.format(args.gpu)) 32 | else: 33 | device = torch.device('cpu') 34 | 35 | # load model 36 | config_path = os.path.join(str(args.model_path), 'seed_{}/config.pt'.format(args.seed)) 37 | config = SimpleNamespace(**torch.load(config_path)) 38 | net = AutoGenerator(config) 39 | 40 | train_pt = os.path.join(args.model_path, 'seed_{}/train.epoch_{}.pt'.format(args.seed, args.epoch)) 41 | checkpoint = torch.load(train_pt, map_location=device) 42 | net.load_state_dict(checkpoint['model_state_dict']) 43 | 44 | net.to(device) 45 | 46 | # Either use assigned number of steps or calculate steps, both using the stepsize used for training 47 | if args.num_steps == None: 48 | num_steps = int(np.round(data_pt["y"] / 
config.train_dt)) 49 | else: 50 | num_steps = int(args.num_steps) 51 | 52 | # simulate forward 53 | out = traj.simulate(xp, data_pt["tps"], data_pt["celltype"], data_pt["w"], net, config, args.num_sims, int(args.num_cells), num_steps, device, args.tp_subset, args.celltype_subset) 54 | 55 | # write simulation data to file 56 | out_path = os.path.join(args.out_path, args.model_path.split("/")[-1], 'seed_{}_train.epoch_{}_num.sims_{}_num.cells_{}_num.steps_{}_subsets_{}_{}_simulation.pt'.format(args.seed, args.epoch, args.num_sims, args.num_cells, num_steps, args.tp_subset, args.celltype_subset)) 57 | torch.save({ 58 | "sims": out 59 | }, out_path) 60 | 61 | if __name__=="__main__": 62 | main() 63 | -------------------------------------------------------------------------------- /prescient/commands/train_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import copy 5 | import numpy as np 6 | import torch 7 | import itertools 8 | import json 9 | import sklearn.decomposition 10 | 11 | from geomloss import SamplesLoss 12 | from collections import OrderedDict 13 | from types import SimpleNamespace 14 | from time import strftime, localtime 15 | 16 | import prescient.train as train 17 | 18 | 19 | def create_parser(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--no-cuda', action = 'store_true') 22 | parser.add_argument('--gpu', default = 7, type = int, help="Designate GPU number as an integer (compatible with CUDA).") 23 | parser.add_argument('--out_dir', default = './experiments', help="Directory for storing training output.") 24 | parser.add_argument('--seed', type = int, default = 2, help="Set seed for training process.") 25 | # -- data options 26 | parser.add_argument('-i', '--data_path', required=True, help="Input PRESCIENT data torch file.") 27 | parser.add_argument('--weight_name', default = None, help="Designate descriptive name of growth parameters for filename.") 28 | # -- model options 29 | parser.add_argument('--loss', default = 'euclidean', help="Designate distance function for loss.") 30 | parser.add_argument('--k_dim', default = 500, type = int, help="Designate hidden units of NN.") 31 | parser.add_argument('--activation', default = 'softplus', help="Designate activation function for layers of NN.") 32 | parser.add_argument('--layers', default = 1, type = int, help="Choose number of layers for neural network parameterizing the potential function.") 33 | # -- pretrain options 34 | parser.add_argument('--pretrain_epochs', default = 500, type = int, help="Number of epochs for pretraining with contrastive divergence.") 35 | # -- train options 36 | parser.add_argument('--train_epochs', default = 2500, type = int, help="Number of epochs for training.") 37 | parser.add_argument('--train_lr', default = 0.01, type = float, help="Learning rate for Adam optimizer during training.") 38 | parser.add_argument('--train_dt', default = 0.1, type = float, help="Timestep for simulations during training.") 39 | parser.add_argument('--train_sd', default = 0.5, type = float, help="Standard deviation of Gaussian noise for simulation steps.") 40 | parser.add_argument('--train_tau', default = 1e-6, type = float, help="Tau hyperparameter of PRESCIENT.") 41 | parser.add_argument('--train_batch', default = 0.1, type = float, help="Batch size for training.") 42 | parser.add_argument('--train_clip', default = 0.25, type = float, help="Gradient clipping threshold for training.") 43 | 
parser.add_argument('--save', default = 100, type = int, help="Save model every n epochs.") 44 | # -- run options 45 | parser.add_argument('--pretrain', type=bool, default=True, help="If True, pretraining will run.") 46 | parser.add_argument('--train', type=bool, default=True, help="If True, training will run with existing pretraining torch file.") 47 | parser.add_argument('--config') 48 | return parser 49 | 50 | 51 | def init_config(args): 52 | 53 | config = SimpleNamespace( 54 | 55 | seed = args.seed, 56 | timestamp = strftime("%a, %d %b %Y %H:%M:%S", localtime()), 57 | 58 | # data parameters 59 | data_path = args.data_path, 60 | weight = args.weight, 61 | 62 | # model parameters 63 | activation = args.activation, 64 | layers = args.layers, 65 | k_dim = args.k_dim, 66 | 67 | # pretraining parameters 68 | pretrain_burnin = 50, 69 | pretrain_sd = 0.1, 70 | pretrain_lr = 1e-9, 71 | pretrain_epochs = args.pretrain_epochs, 72 | 73 | # training parameters 74 | train_dt = args.train_dt, 75 | train_sd = args.train_sd, 76 | train_batch_size = args.train_batch, 77 | ns = 2000, 78 | train_burnin = 100, 79 | train_tau = args.train_tau, 80 | train_epochs = args.train_epochs, 81 | train_lr = args.train_lr, 82 | train_clip = args.train_clip, 83 | save = args.save, 84 | 85 | # loss parameters 86 | sinkhorn_scaling = 0.7, 87 | sinkhorn_blur = 0.1, 88 | 89 | # file parameters 90 | out_dir = args.out_dir, 91 | out_name = args.out_dir.split('/')[-1], 92 | pretrain_pt = os.path.join(args.out_dir, 'pretrain.pt'), 93 | train_pt = os.path.join(args.out_dir, 'train.{}.pt'), 94 | train_log = os.path.join(args.out_dir, 'train.log'), 95 | done_log = os.path.join(args.out_dir, 'done.log'), 96 | config_pt = os.path.join(args.out_dir, 'config.pt'), 97 | ) 98 | 99 | config.train_t = [] 100 | config.test_t = [] 101 | 102 | if not os.path.exists(args.out_dir): 103 | print('Making directory at {}'.format(args.out_dir)) 104 | os.makedirs(args.out_dir) 105 | else: 106 | print('Directory exists at {}'.format(args.out_dir)) 107 | return config 108 | 109 | def load_data(args): 110 | return torch.load(args.data_path) 111 | 112 | def train_init(args): 113 | 114 | a = copy.copy(args) 115 | 116 | # data 117 | data_pt = load_data(args) 118 | x = data_pt["xp"] 119 | y = data_pt["y"] 120 | weight = data_pt["w"] 121 | if args.weight_name != None: 122 | a.weight = args.weight_name 123 | # weight = os.path.basename(a.weight_path) 124 | # weight = weight.split('.')[0].split('-')[-1] 125 | 126 | 127 | # out directory 128 | name = ( 129 | "{weight}-" 130 | "{activation}_{layers}_{k_dim}-" 131 | "{train_tau}" 132 | ).format(**a.__dict__) 133 | 134 | a.out_dir = os.path.join(args.out_dir, name, 'seed_{}'.format(a.seed)) 135 | config = init_config(a) 136 | 137 | config.x_dim = x[0].shape[-1] 138 | config.t = y[-1] - y[0] 139 | 140 | config.start_t = y[0] 141 | config.train_t = y[1:] 142 | y_start = y[config.start_t] 143 | y_ = [y_ for y_ in y if y_ > y_start] 144 | 145 | w_ = weight[config.start_t] 146 | w = {(y_start, yy): torch.from_numpy(np.exp((yy - y_start)*w_)) for yy in y_} 147 | 148 | return x, y, w, config 149 | 150 | def main(args): 151 | train.run(args, train_init) 152 | 153 | if __name__=="__main__": 154 | main() 155 | -------------------------------------------------------------------------------- /prescient/perturb/__init__.py: -------------------------------------------------------------------------------- 1 | from .pert import * 2 | -------------------------------------------------------------------------------- 
/prescient/perturb/pert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import sklearn 4 | 5 | def z_score_perturbation(genes, perturb_genes, x, pca, z_score): 6 | scaler=sklearn.preprocessing.StandardScaler() 7 | x=scaler.fit_transform(x) 8 | # x_ is perturbed expression profile 9 | x_ = x 10 | perturb_genes=perturb_genes.split(",") 11 | idx=[] 12 | # perturb genes that appear in highly variable gene list 13 | for elt in perturb_genes: 14 | if (elt in genes): 15 | idx.append(genes.index(elt)) 16 | for elt in idx: 17 | x_[:,elt]= z_score 18 | xp = pca.transform(x_) 19 | return xp 20 | -------------------------------------------------------------------------------- /prescient/simulate/__init__.py: -------------------------------------------------------------------------------- 1 | from .sim import * 2 | -------------------------------------------------------------------------------- /prescient/simulate/sim.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | import os 4 | import sys 5 | import argparse 6 | import random 7 | import joblib 8 | import json 9 | import tqdm 10 | import torch 11 | 12 | import numpy as np 13 | import pandas as pd 14 | import sklearn 15 | 16 | from types import SimpleNamespace 17 | from collections import Counter 18 | 19 | def simulate(xp, tps, celltype_annotations, w, model, config, num_sims, num_cells, num_steps, device, tp_subset, celltype_subset): 20 | """ 21 | Use trained PRESCIENT model to simulate cell trajectories with arbitrary initializations. 22 | """ 23 | # load data 24 | xp = torch.from_numpy(xp) 25 | 26 | # make meta dataframe 27 | # TO-DO implement weight sampling strategy 28 | dict = {"tp": tps, "celltype": celltype_annotations} 29 | meta = pd.DataFrame(dict) 30 | 31 | 32 | all_sims = [] 33 | pbar = tqdm.tqdm(range(num_sims)) 34 | for s in pbar: 35 | # sample cells based on timepoint or celltype or both 36 | try: 37 | if tp_subset is not None and celltype_subset is not None: 38 | idx = pd.DataFrame(meta[(meta["tp"]==tp_subset) & (meta["celltype"]==celltype_subset)]).sample(int(num_cells)).index 39 | elif tp_subset is not None: 40 | idx = pd.DataFrame(meta[meta["tp"]==tp_subset]).sample(int(num_cells)).index 41 | elif celltype_subset is not None: 42 | idx = pd.DataFrame(meta[meta["celltype"]==celltype_subset]).sample(int(num_cells)).index 43 | else: 44 | idx = meta.sample(num_cells).index 45 | except ValueError: 46 | print(meta.head(), flush=True) 47 | raise ValueError(f"Those cells were not found in the metadata. Wrong 'tp_subset' or 'celltype_subset'? Timepoint given was {tp_subset} and celltype given was {celltype_subset}. 
Metadata examples should be printed to stdout.") 48 | # map tensor to device 49 | xp_i = xp[idx].to(device) 50 | 51 | # store inital value 52 | xp_i_ = xp_i.detach().cpu().numpy() 53 | xps_i = [xp_i_] # n 54 | 55 | # simulate all cells forward through time 56 | for _ in range(num_steps): 57 | # initialize latent vector 58 | z = torch.randn(xp_i.shape[0], xp_i.shape[1]) * config.train_sd 59 | z = z.to(device) 60 | 61 | # step forward with trained model 62 | xp_i = model._step(xp_i.float(), dt=config.train_dt, z=z) 63 | 64 | # store next step 65 | xp_i_ = xp_i.detach().cpu().numpy() 66 | xps_i.append(xp_i_) 67 | 68 | # group timepoints 69 | xps = np.stack(xps_i) #[n_cells x n_steps] 70 | all_sims.append(xps) #[n_sims x n_cells x n_steps] 71 | 72 | pbar.set_description('[simulate] {}'.format(s)) 73 | return all_sims 74 | -------------------------------------------------------------------------------- /prescient/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * 2 | from .run import * 3 | from .util import * 4 | -------------------------------------------------------------------------------- /prescient/train/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn, optim 4 | 5 | import numpy as np 6 | 7 | from geomloss import SamplesLoss 8 | 9 | import tqdm 10 | 11 | from collections import OrderedDict 12 | from types import SimpleNamespace 13 | from time import strftime, localtime 14 | 15 | import argparse 16 | import itertools 17 | import json 18 | import os 19 | import sys 20 | 21 | import sklearn.decomposition 22 | 23 | from .util import * 24 | 25 | # ---- PRESCIENT 26 | 27 | class IntReLU(nn.Module): 28 | 29 | def __init__(self, input_dim): 30 | super(IntReLU, self).__init__() 31 | 32 | def forward(self, x): 33 | return torch.max(torch.zeros_like(x), 0.5 * (x**2)) # + self.c) 34 | 35 | 36 | class AutoGenerator(nn.Module): 37 | 38 | def __init__(self, config): 39 | super(AutoGenerator, self).__init__() 40 | 41 | self.x_dim = config.x_dim 42 | self.k_dim = config.k_dim 43 | self.layers = config.layers 44 | 45 | self.activation = config.activation 46 | if self.activation == 'relu': 47 | self.act = nn.LeakyReLU 48 | elif self.activation == 'softplus': 49 | self.act = nn.Softplus 50 | elif self.activation == 'intrelu': # broken, wip 51 | raise NotImplementedError 52 | elif self.activation == 'none': 53 | self.act = None 54 | else: 55 | raise NotImplementedError 56 | 57 | self.net_ = [] 58 | for i in range(self.layers): 59 | # add linear layer 60 | if i == 0: 61 | self.net_.append(('linear{}'.format(i+1), nn.Linear(self.x_dim, self.k_dim))) 62 | else: 63 | self.net_.append(('linear{}'.format(i+1), nn.Linear(self.k_dim, self.k_dim))) 64 | # add activation 65 | if self.activation == 'intrelu': 66 | raise NotImplementedError 67 | elif self.activation == 'none': 68 | pass 69 | else: 70 | self.net_.append(('{}{}'.format(self.activation, i+1), self.act())) 71 | self.net_.append(('linear', nn.Linear(self.k_dim, 1, bias = False))) 72 | self.net_ = OrderedDict(self.net_) 73 | self.net = nn.Sequential(self.net_) 74 | 75 | net_params = list(self.net.parameters()) 76 | net_params[-1].data = torch.zeros(net_params[-1].data.shape) # initialize 77 | 78 | def _step(self, x, dt, z): 79 | sqrtdt = np.sqrt(dt) 80 | return x + self._drift(x) * dt + z * sqrtdt 81 | 82 | def _pot(self, x): 83 | return self.net(x) 84 | 85 | def 
_drift(self, x): 86 | x_ = x.requires_grad_() 87 | pot = self._pot(x_) 88 | 89 | drift = torch.autograd.grad(pot, x_, torch.ones_like(pot), 90 | create_graph = True)[0] 91 | return drift 92 | 93 | # ---- loss 94 | 95 | class OTLoss(): 96 | 97 | def __init__(self, config, device): 98 | 99 | self.ot_solver = SamplesLoss("sinkhorn", p = 2, blur = config.sinkhorn_blur, 100 | scaling = config.sinkhorn_scaling, debias = True) 101 | self.device = device 102 | 103 | def __call__(self, a_i, x_i, b_j, y_j, requires_grad = True): 104 | 105 | a_i = a_i.to(self.device) 106 | x_i = x_i.to(self.device) 107 | b_j = b_j.to(self.device) 108 | y_j = y_j.to(self.device) 109 | 110 | if requires_grad: 111 | a_i.requires_grad_() 112 | x_i.requires_grad_() 113 | b_j.requires_grad_() 114 | 115 | loss_xy = self.ot_solver(a_i, x_i, b_j, y_j) 116 | return loss_xy 117 | -------------------------------------------------------------------------------- /prescient/train/run.py: -------------------------------------------------------------------------------- 1 | # shared functions and classes, including the model and `run` 2 | # which implements the main pre-training and training loop 3 | import os 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn, optim 8 | 9 | import numpy as np 10 | 11 | import tqdm 12 | from time import strftime, localtime 13 | 14 | from .model import * 15 | from .util import * 16 | 17 | def run(args, init_task): 18 | 19 | # ---- initialize 20 | 21 | device, kwargs = init(args) 22 | torch.manual_seed(args.seed) 23 | np.random.seed(args.seed) 24 | 25 | x, y, w, config = init_task(args) 26 | 27 | # ---- model 28 | 29 | model = AutoGenerator(config) 30 | print(model) 31 | model.zero_grad() 32 | 33 | # ---- loss 34 | 35 | if args.loss == 'euclidean': 36 | loss = OTLoss(config, device) 37 | else: 38 | raise NotImplementedError 39 | 40 | torch.save(config.__dict__, config.config_pt) 41 | 42 | if args.pretrain: 43 | 44 | if os.path.exists(config.done_log): 45 | 46 | print(config.done_log, ' exists. Skipping.') 47 | 48 | else: 49 | 50 | model.to(device) 51 | x_last = x[config.train_t[-1]].to(device) # use the last available training point 52 | optimizer = optim.SGD(list(model.parameters()), lr = config.pretrain_lr) 53 | 54 | pbar = tqdm.tqdm(range(config.pretrain_epochs)) 55 | for epoch in pbar: 56 | 57 | pp, _ = p_samp(x_last, config.ns) 58 | 59 | dt = config.t / config.pretrain_burnin 60 | pp, pos_fv, neg_fv = fit_regularizer(x_last, pp, 61 | config.pretrain_burnin, dt, config.pretrain_sd, 62 | model, device) 63 | fv_tot = pos_fv + neg_fv 64 | 65 | fv_tot.backward() 66 | optimizer.step() 67 | model.zero_grad() 68 | 69 | pbar.set_description('[{}|pretrain] {} {:.3f}'.format( 70 | config.out_name, epoch, fv_tot.item())) 71 | 72 | torch.save({ 73 | 'model_state_dict': model.state_dict(), 74 | }, config.pretrain_pt) 75 | 76 | if args.train: 77 | 78 | if os.path.exists(config.done_log): 79 | 80 | print(config.done_log, ' exists. 
Skipping.') 81 | 82 | else: 83 | 84 | checkpoint = torch.load(config.pretrain_pt) 85 | model.load_state_dict(checkpoint['model_state_dict']) 86 | model.to(device) 87 | 88 | optimizer = optim.Adam(list(model.parameters()), lr = config.train_lr) 89 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.9) 90 | optimizer.zero_grad() 91 | 92 | pbar = tqdm.tqdm(range(config.train_epochs)) 93 | x_last = x[config.train_t[-1]].to(device) # use the last available training point 94 | # fit on time points 95 | 96 | best_train_loss_xy = np.inf 97 | log_handle = open(config.train_log, 'w') 98 | 99 | for epoch in pbar: 100 | 101 | losses_xy = [] 102 | config.train_epoch = epoch 103 | 104 | for j in config.train_t: 105 | 106 | t_cur = j 107 | t_prev = config.start_t 108 | dat_cur = x[t_cur] 109 | dat_prev = x[t_prev] 110 | y_cur = y[t_cur] 111 | y_prev = y[t_prev] 112 | time_elapsed = y_cur - y_prev 113 | 114 | w_prev = get_weight(w[(y_prev, y_cur)], time_elapsed) 115 | 116 | x_i, a_i = p_samp(dat_prev, int(dat_prev.shape[0] * args.train_batch), 117 | w_prev) 118 | x_i = x_i.to(device) 119 | num_steps = int(np.round(time_elapsed / config.train_dt)) 120 | for _ in range(num_steps): 121 | z = torch.randn(x_i.shape[0], x_i.shape[1]) * config.train_sd 122 | z = z.to(device) 123 | x_i = model._step(x_i, dt = config.train_dt, z = z) 124 | 125 | y_j, b_j = p_samp(dat_cur, int(dat_cur.shape[0] * args.train_batch)) 126 | 127 | loss_xy = loss(a_i, x_i, b_j, y_j) 128 | losses_xy.append(loss_xy.item()) 129 | 130 | #[F_i, G_j, dx_i] = torch.autograd.grad( Loss_xy, [a_i, b_j, x_i] ) 131 | 132 | loss_xy.backward() 133 | 134 | train_loss_xy = np.mean(losses_xy) 135 | 136 | # fit regularizer 137 | 138 | if config.train_tau > 0: 139 | 140 | pp, _ = p_samp(x_last, config.ns) 141 | 142 | dt = config.t / config.train_burnin 143 | pp, pos_fv, neg_fv = fit_regularizer(x_last, pp, 144 | config.train_burnin, dt, config.train_sd, 145 | model, device) 146 | fv_tot = pos_fv + neg_fv 147 | fv_tot *= config.train_tau 148 | fv_tot.backward() 149 | 150 | # step 151 | 152 | if config.train_clip > 0: 153 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.train_clip) 154 | optimizer.step() 155 | scheduler.step() 156 | model.zero_grad() 157 | 158 | # report 159 | 160 | desc = "[{}|train] {}".format(config.out_name, epoch + 1) 161 | if len(losses_xy) < 10: 162 | for l_xy in losses_xy: 163 | desc += " {:.6f}".format(l_xy) 164 | desc += " {:.6f}".format(train_loss_xy) 165 | desc += " {:.6f}".format(best_train_loss_xy) 166 | pbar.set_description(desc) 167 | log_handle.write(desc + '\n') 168 | log_handle.flush() 169 | 170 | if train_loss_xy < best_train_loss_xy: 171 | best_train_loss_xy = train_loss_xy 172 | 173 | torch.save({ 174 | 'model_state_dict': model.state_dict(), 175 | 'epoch': config.train_epoch + 1, 176 | }, config.train_pt.format('best')) 177 | 178 | # save model every x epochs 179 | 180 | if (config.train_epoch + 1) % config.save == 0: 181 | epoch_ = str(config.train_epoch + 1).rjust(6, '0') 182 | torch.save({ 183 | 'model_state_dict': model.state_dict(), 184 | 'epoch': config.train_epoch + 1, 185 | }, config.train_pt.format('epoch_{}'.format(epoch_))) 186 | 187 | log_handle.close() 188 | 189 | log_handle = open(config.done_log, 'w') 190 | timestamp = strftime("%a, %d %b %Y %H:%M:%S", localtime()) 191 | log_handle.write(config.timestamp + '\n') 192 | log_handle.close() 193 | -------------------------------------------------------------------------------- /prescient/train/util.py: 
-------------------------------------------------------------------------------- 1 | # shared helper functions for the training loop in run.py: weighted sampling, 2 | # the potential-based regularizer, PCA projection, and device setup 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, optim 7 | 8 | import numpy as np 9 | 10 | from collections import OrderedDict 11 | from types import SimpleNamespace 12 | from time import strftime, localtime 13 | 14 | import argparse 15 | import itertools 16 | import json 17 | import os 18 | import sys 19 | 20 | import sklearn.decomposition 21 | 22 | # ---- convenience functions 23 | 24 | def p_samp(p, num_samp, w = None): 25 | repflag = p.shape[0] < num_samp 26 | p_sub = np.random.choice(p.shape[0], size = num_samp, replace = repflag) 27 | if w is None: 28 | w_ = torch.ones(len(p_sub)) 29 | else: 30 | w_ = w[p_sub].clone() 31 | w_ = w_ / w_.sum() 32 | 33 | return p[p_sub,:].clone(), w_ 34 | 35 | def fit_regularizer(samples, pp, burnin, dt, sd, model, device): 36 | 37 | factor = samples.shape[0] / pp.shape[0] 38 | 39 | z = torch.randn(burnin, pp.shape[0], pp.shape[1]) * sd 40 | z = z.to(device) 41 | 42 | for i in range(burnin): 43 | pp = model._step(pp, dt, z = z[i,:,:]) 44 | 45 | pos_fv = -1 * model._pot(samples).sum() 46 | neg_fv = factor * model._pot(pp.detach()).sum() 47 | 48 | return pp, pos_fv, neg_fv 49 | 50 | def pca_transform(x): 51 | 52 | pca = sklearn.decomposition.PCA(n_components = 50) 53 | # keep track of how to break up the array after concat 54 | x_ = torch.cat(x) 55 | x_ = pca.fit_transform(x_) 56 | x_breaks = np.append([0], np.cumsum([len(x_) for x_ in x])) 57 | x_tmp = [] 58 | for i in range(len(x_breaks) - 1): 59 | ii = x_breaks[i] 60 | jj = x_breaks[i+1] 61 | x_tmp.append(torch.from_numpy(x_[ii:jj]).float()) 62 | x = x_tmp 63 | 64 | return x 65 | 66 | def get_weight(w, time_elapsed): 67 | return w 68 | 69 | def init(args): 70 | 71 | args.cuda = torch.cuda.is_available() 72 | device = torch.device('cuda:{}'.format(args.gpu) if args.cuda else 'cpu') 73 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 74 | 75 | return device, kwargs 76 | 77 | def weighted_samp(p, num_samp, w): 78 | ix = list(torch.utils.data.WeightedRandomSampler(w, num_samp)) 79 | return p[ix,:].clone() 80 | -------------------------------------------------------------------------------- /prescient/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | from annoy import AnnoyIndex 6 | import torch 7 | from collections import Counter # used by classify_cells below 8 | def plot_interpolation(): # TO-DO 9 | pass 10 | 11 | def plot_trajectory(): # TO-DO 12 | pass 13 | 14 | def make_trajectory_animation(): # TO-DO 15 | pass 16 | 17 | def plot_fate_streamplot(): # TO-DO 18 | pass 19 | 20 | def train_ann(data_pt): 21 | yc = data_pt['celltype'] 22 | xtr = data_pt["xp"] 23 | n_trees = 10 24 | n_neighbors = 20 25 | t = AnnoyIndex(xtr.shape[1], 'euclidean') 26 | for i in range(xtr.shape[0]): 27 | t.add_item(i, xtr[i]) 28 | t.build(n_trees) 29 | return t 30 | 31 | def classify_cells(args, data_pt, all_sims_timepoints, ann_dir): 32 | n_neighbors=10 33 | meta = data_pt["meta"] 34 | yc = data_pt["celltype"] 35 | xp_df = pd.DataFrame(data_pt["xp"], yc) 36 | u = AnnoyIndex(all_sims_timepoints[0][0].shape[1], 'euclidean') # all_sims_timepoints[0][0][0].shape[1], 'euclidean') 37 | u.load(ann_dir) 38 | yp_all=[] 39 | for timepoint in all_sims_timepoints: 40 | yp=[] 41 | for i in 
range(len(timepoint)): 42 | yt=[] 43 | for j in range(len(timepoint[0])): 44 | nn = xp_df.iloc[u.get_nns_by_vector(timepoint[i][j], n_neighbors)] 45 | nn = Counter(nn.index).most_common(2) 46 | label, num = nn[0] 47 | yt.append(label) 48 | yp.append(yt) 49 | yp_all.append(yp) 50 | return yp_all 51 | 52 | def compute_growth(L0, L, k, birth_score, death_score, birth_smoothed_score, death_smoothed_score): 53 | L = float(L) 54 | L0 = float(L0) 55 | k = float(k) 56 | 57 | kb = np.log(k) / np.min(birth_score) 58 | kd = np.log(k) / np.min(death_score) 59 | 60 | b = birth_smoothed_score 61 | d = death_smoothed_score 62 | 63 | b = L0 + L / (1 + np.exp(-kb * b)) 64 | d = L0 + L / (1 + np.exp(-kd * d)) 65 | g = b - d 66 | return g 67 | 68 | def get_growth_weights(x, xp, metadata, tp_col, genes, birth_gst, death_gst, outfile, 69 | n_neighbors=20, beta=0.1, L0=0.3, L=1.1, k=0.001): 70 | """ 71 | Estimate growth using KEGG gene annotations. Implements smoothing procedure. 72 | 73 | Inputs: 74 | ------- 75 | x: numpy ndarray of scaled gene expression. 76 | xp: numpy ndarray of PCs. 77 | genes: list or numpy array of highly variable gene symbols. 78 | birth_gst: path to csv of birth signature annotations. 79 | death_gst: path to csv of death signature annotations. 80 | outfile: provide name of outfile pt. 81 | 82 | Outputs: 83 | -------- 84 | weights: growth rates vector. 85 | """ 86 | gst = pd.read_csv(birth_gst, index_col=0) 87 | birth_gst = [g for g in gst['gene_symbol'].unique() if g in genes] 88 | gst = pd.read_csv(death_gst, index_col = 0) 89 | death_gst = [g for g in gst['gene_symbol'].unique() if g in genes] 90 | 91 | birth_gst = [g for g in birth_gst if g not in death_gst] 92 | death_gst = [g for g in death_gst if g not in birth_gst] 93 | 94 | # smoothing procedure for growth 95 | ay = AnnoyIndex(xp.shape[1], 'euclidean') 96 | for i in range(xp.shape[0]): 97 | ay.add_item(i, xp[i]) 98 | ay.build(10) 99 | 100 | prev_score = x[birth_gst].mean(axis = 1).values 101 | cur_score = np.zeros(prev_score.shape) 102 | 103 | for _ in range(5): 104 | for i in range(len(prev_score)): 105 | xn = prev_score[ay.get_nns_by_item(i, 20)] 106 | cur_score[i] = (beta * xn[0]) + ((1 - beta) * xn[1:].mean(axis = 0)) 107 | prev_score = cur_score 108 | 109 | birth_score = x[birth_gst].mean(axis = 1).values 110 | birth_smoothed_score = cur_score 111 | 112 | # smooth death score 113 | 114 | prev_score = x[death_gst].mean(axis = 1).values 115 | cur_score = np.zeros(prev_score.shape) 116 | 117 | for _ in range(5): 118 | for i in range(len(prev_score)): 119 | xn = prev_score[ay.get_nns_by_item(i, 20)] 120 | cur_score[i] = (beta * xn[0]) + ((1 - beta) * xn[1:].mean(axis = 0)) 121 | prev_score = cur_score 122 | 123 | death_score = x[death_gst].mean(axis = 1).values 124 | death_smoothed_score = cur_score 125 | 126 | # compute growth 127 | g = compute_growth(L0, L, k, 128 | birth_score, death_score, 129 | birth_smoothed_score, death_smoothed_score) 130 | y_l = sorted(metadata[tp_col].unique()) 131 | g_l = [g[(metadata[tp_col] == y_).values] for y_ in y_l] 132 | 133 | # write growth 134 | torch.save({ 135 | "w": g_l 136 | }, outfile) 137 | 138 | return g, g_l 139 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 
-------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_files = LICENSE 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="prescient", 8 | version="0.1.0", 9 | 10 | 11 | author="Sachit Saksena; Grace Hui-Ting Yeo", 12 | author_email="sachit@mit.edu", 13 | license="MIT License", 14 | description="Method for simulating single cells using longitudinal scRNA-seq.", 15 | long_description=long_description, 16 | long_description_content_type="text/markdown", 17 | url="https://github.com/gifford-lab/prescient", 18 | packages=setuptools.find_packages(), 19 | classifiers=[ 20 | "Intended Audience :: Science/Research", 21 | "License :: OSI Approved :: MIT License", 22 | "Operating System :: OS Independent", 23 | "Programming Language :: Python :: 3" 24 | ], 25 | include_package_data=True, 26 | install_requires=[ 27 | "scanpy>=1.7", 28 | "pyreadr>=0.0", 29 | "matplotlib>=3.3", 30 | "annoy>=1.17.0", 31 | "numpy>=1.14", 32 | "pandas>=0.25", 33 | "scikit-learn>=0.21", 34 | "scipy>=1.3", 35 | "setuptools>=41.6", 36 | "torch>=1.5", 37 | "torchvision>=0.7", 38 | "geomloss==0.2.3", 39 | "pykeops>=1.3" 40 | ], 41 | python_requires=">=3.4", 42 | entry_points={ 43 | "console_scripts": ["prescient=prescient.__main__:main"] 44 | } 45 | ) 46 | --------------------------------------------------------------------------------
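A minimal usage sketch (not part of the repository) of the Sinkhorn-based `OTLoss` defined in `prescient/train/model.py`, evaluated on two random weighted point clouds. The `SimpleNamespace` config, the blur/scaling values, and the cloud sizes are illustrative assumptions rather than the project's defaults.

```python
# Sketch: evaluating the Sinkhorn OT loss between a simulated and an observed
# weighted point cloud on CPU. Blur/scaling values and sizes are assumptions.
import torch
from types import SimpleNamespace

from prescient.train.model import OTLoss

config = SimpleNamespace(sinkhorn_blur=0.1, sinkhorn_scaling=0.7)  # assumed values
ot_loss = OTLoss(config, torch.device("cpu"))

x_i = torch.randn(256, 50)             # simulated cells in 50-D PC space
y_j = torch.randn(300, 50)             # observed cells at the target timepoint
a_i = torch.full((256,), 1.0 / 256)    # uniform weights on the simulated cloud
b_j = torch.full((300,), 1.0 / 300)    # uniform weights on the observed cloud

loss_xy = ot_loss(a_i, x_i, b_j, y_j)  # requires_grad_ is turned on internally
loss_xy.backward()                     # gradients w.r.t. a_i, x_i and b_j
print(float(loss_xy))
```

During training, `run` draws `p_samp` minibatches, pushes them through `model._step` for the elapsed number of steps, and only then calls the loss, so the gradient on `x_i` propagates back into the potential network.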
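Similarly, a small sketch of the weighted subsampling helper `p_samp` from `prescient/train/util.py`; the tensor shapes and the random weight vector are synthetic stand-ins for the PC matrix and growth weights used in the real pipeline.

```python
# Sketch: drawing a weighted minibatch of cells with p_samp.
# The data and weights below are synthetic stand-ins.
import torch

from prescient.train.util import p_samp

x_t = torch.randn(5000, 50)   # cells at one timepoint, 50 principal components
w_t = torch.rand(5000)        # e.g. growth-derived proliferation weights

x_batch, w_batch = p_samp(x_t, 512, w_t)
print(x_batch.shape)          # torch.Size([512, 50])
print(float(w_batch.sum()))   # ~1.0: weights are renormalized inside p_samp
```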
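Finally, a sketch of the logistic transform implemented by `compute_growth` in `prescient/utils.py`, which maps smoothed birth and death signature scores to per-cell growth rates as g = b - d, with each score squashed through a curve of the form L0 + L / (1 + exp(-kb * s)) whose steepness kb is set from k and the minimum raw score. The synthetic scores below replace the KEGG-signature averaging and kNN smoothing that `get_growth_weights` normally performs; the parameter values are that function's documented defaults.

```python
# Sketch: per-cell growth rates from synthetic birth/death signature scores.
# L0, L and k mirror get_growth_weights' defaults; real scores come from
# averaging scaled expression over KEGG gene sets and kNN smoothing.
import numpy as np

from prescient.utils import compute_growth

rng = np.random.default_rng(0)
birth_score = rng.normal(size=1000)   # raw birth-signature scores
death_score = rng.normal(size=1000)   # raw death-signature scores

g = compute_growth(L0=0.3, L=1.1, k=0.001,
                   birth_score=birth_score,
                   death_score=death_score,
                   birth_smoothed_score=birth_score,   # smoothing skipped here
                   death_smoothed_score=death_score)
print(g.shape, float(g.min()), float(g.max()))
```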