├── .gitignore
├── LICENSE
├── README.md
├── drug_response_prediction
│   ├── README.md
│   ├── drp.py
│   ├── environment.yml
│   ├── model.py
│   ├── plot.ipynb
│   └── run_all.sh
├── environment.yml
├── model
│   ├── __init__.py
│   ├── assets
│   │   ├── args.json
│   │   └── vocab.json
│   ├── cancergpt.py
│   ├── data_collator.py
│   ├── dataset.py
│   ├── embedding.py
│   ├── utils.py
│   └── vocab.py
├── training_data
│   └── preprocessing.ipynb
├── tutorial
│   ├── README.md
│   ├── __init__.py
│   └── embeddings_tutorial.ipynb
└── zero_shot_batch_integration
    ├── README.md
    ├── generate_embedding.py
    └── plot.ipynb
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Conda environment files 7 | *.env 8 | 9 | # Jupyter Notebook checkpoints 10 | .ipynb_checkpoints/ 11 | 12 | # VSCode settings 13 | .vscode/ 14 | training_data/.cache 15 | training_data/data 16 | training_data/downloads 17 | 18 | save 19 | example 20 | 21 | *.png 22 | *.h5ad 23 | *.csv 24 | *.pth 25 | *.tar -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Boeva Lab, ETH Zurich 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CancerFoundation: A single-cell RNA sequencing foundation model to decipher drug resistance in cancer 2 | 3 | [![Preprint](https://img.shields.io/badge/preprint-available-brightgreen)](https://www.biorxiv.org/content/10.1101/2024.11.01.621087v1)   4 | [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/BoevaLab/CancerFoundation/blob/main/LICENSE) 5 | 6 | We present **CancerFoundation**, a novel single-cell RNA-seq foundation model (scFM) trained exclusively on malignant cells. Despite being trained on only one million total cells, a fraction of the data used by existing models, CancerFoundation outperforms other scFMs in key tasks such as zero-shot batch integration and drug response prediction. During training, we employ tissue and technology-aware oversampling and domain-invariant training to enhance performance on underrepresented cancer types and sequencing technologies. We propose survival prediction as a new downstream task to evaluate the generalizability of single-cell foundation models to bulk RNA data and their applicability to patient stratification. CancerFoundation demonstrates superior batch integration performance and shows significant improvements in predicting drug responses for both unseen cell lines and drugs. 
These results highlight the potential of focused, smaller foundation models in advancing drug discovery and our understanding of cancer biology. 7 | 8 | ## Installation 9 | 10 | ### Prerequisites 11 | 12 | Make sure you have [Conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) installed on your machine. 13 | 14 | ### Step-by-Step Guide 15 | 16 | 1. **Clone the repository**: 17 | 18 | ```bash 19 | git clone https://github.com/BoevaLab/CancerFoundation.git 20 | cd CancerFoundation 21 | ``` 22 | 2. **Create the Conda environment**: 23 | ```bash 24 | conda env create -f environment.yml 25 | ``` 26 | 3. **Activate environment**: 27 | ```bash 28 | conda activate cancerfoundation 29 | ``` 30 | 4. **Download pretrained model weights**: 31 | 32 | Download the pretrained model from [this link](https://polybox.ethz.ch/index.php/s/pZR9VH7uEHwO5CL), unzip it, and place it in the `model/assets` directory. 33 | 34 | ## Generate embeddings 35 | See `tutorial/embeddings_tutorial.ipynb` for a tutorial on generating CancerFoundation embeddings for your own scRNA-seq data. 36 | 37 | ## Drug response prediction 38 | Refer to `drug_response_prediction/README.md` for instructions on performing drug response prediction. 39 | 40 | ## Zero-shot batch integration 41 | Refer to `zero_shot_batch_integration/README.md` for instructions on performing and evaluating zero-shot batch integration. 42 | -------------------------------------------------------------------------------- /drug_response_prediction/README.md: -------------------------------------------------------------------------------- 1 | ## Drug response prediction 2 | 3 | ### Installation 4 | 5 | #### Prerequisites 6 | 7 | Make sure you have [Conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) installed on your machine. 8 | 9 | #### Step-by-step guide 10 | 1.
**Create the Conda environment**: 11 | ```bash 12 | conda env create -f environment.yml 13 | ``` 14 | 2. **Activate environment**: 15 | ```bash 16 | conda activate cancergpt_drp 17 | ``` 18 | 19 | 3. **Download data**: 20 | 21 | Download the data from [this link](https://polybox.ethz.ch/index.php/s/UxANnzU9q3WGlNA), unzip it, and make sure the three folders (`cell_line`, `drug`, `embedding`) end up under `./data` in this directory: `drp.py` reads its inputs from `./data/cell_line`, `./data/drug`, `./data/embedding`, and `./data/ground_truth.csv`. 22 | 23 | 4. **(Optional) Download results for hold-out cell lines**: 24 | Download the results for hold-out cell lines from [this link](https://polybox.ethz.ch/index.php/s/ir0KNi2QXQnhrN1). 25 | 26 | ### Hold-out cell line 27 | 28 | Run the following command for drug response prediction on hold-out cell lines with **CancerGPT** embeddings: 29 | ```bash 30 | python drp.py --embedding_path "./data/embedding/CancerGPT_embedding.csv" --gpu_id 0 31 | ``` 32 | For **scFoundation** embeddings, use: 33 | ```bash 34 | python drp.py --embedding_path "./data/embedding/scFoundation_embedding.csv" --gpu_id 0 35 | ``` 36 | For **raw gene expression** data (DeepCDR), use: 37 | ```bash 38 | python drp.py --embedding_path "./data/embedding/gene_expression.csv" --gpu_id 0 39 | ``` 40 | 41 | Alternatively, download the results from the link in Step 4. The corresponding plots can be generated with the `plot.ipynb` notebook. 42 | 43 | ### Hold-out drug 44 | 45 | To derive results for all drugs and all embeddings, run the following: 46 | ```bash 47 | bash run_all.sh 48 | ``` 49 | 50 | Results are saved in the `results` folder, one CSV file per hold-out drug and embedding (`hold_out_drug_<PubChem CID>_<embedding>.csv`).
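To compare embeddings after `run_all.sh` has finished, the result files can be summarised with a short script. The following is a minimal sketch, not part of this repository and not necessarily how `plot.ipynb` aggregates results. It assumes the file naming `hold_out_{task}_{embedding}.csv` and the `prediction`/`ground_truth` columns that `drp.py` writes; the helper names `parse_result_name` and `summarize_results` are hypothetical.

```python
# Hypothetical helper (not part of this repository): summarise drp.py result files.
import glob
import os

import pandas as pd


def parse_result_name(filename: str):
    """Split 'hold_out_{task}_{embedding}.csv' into (task, drug_id, embedding)."""
    stem = os.path.basename(filename)[len("hold_out_"):-len(".csv")]
    if stem.startswith("cell_line_"):
        return "cell_line", None, stem[len("cell_line_"):]
    if stem.startswith("drug_"):
        # PubChem CIDs are numeric, so the first '_' after 'drug_' ends the drug ID.
        drug_id, embedding = stem[len("drug_"):].split("_", 1)
        return "drug", drug_id, embedding
    raise ValueError(f"Unexpected result file name: {filename}")


def summarize_results(results_dir: str = "./results") -> pd.DataFrame:
    """Pearson correlation (PCC) between prediction and ground truth for each result file."""
    rows = []
    for path in sorted(glob.glob(os.path.join(results_dir, "hold_out_*.csv"))):
        task, drug_id, embedding = parse_result_name(path)
        df = pd.read_csv(path)
        # Series.corr defaults to the Pearson correlation coefficient.
        pcc = df["prediction"].corr(df["ground_truth"])
        rows.append({"task": task, "drug": drug_id, "embedding": embedding,
                     "n": len(df), "pcc": pcc})
    return pd.DataFrame(rows, columns=["task", "drug", "embedding", "n", "pcc"])


if __name__ == "__main__":
    summary = summarize_results()
    # Mean of the per-drug PCCs for each embedding (hold-out drug setting).
    drug_rows = summary[summary["task"] == "drug"]
    print(drug_rows.groupby("embedding")["pcc"].mean())
```

Averaging per-drug PCCs weights every drug equally; pooling all predictions into a single PCC instead would weight each drug by its number of tested cell lines.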
-------------------------------------------------------------------------------- /drug_response_prediction/drp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import os 4 | import numpy as np 5 | import csv 6 | import pandas as pd 7 | from model import MLP, UGCNN, CombinedMLP 8 | from scipy.stats import pearsonr 9 | 10 | import hickle as hkl 11 | 12 | from tqdm import tqdm 13 | import torch 14 | from torch_geometric.data import Data 15 | from torch_geometric.loader import DataLoader 16 | import torch.optim as optim 17 | import copy 18 | import warnings 19 | warnings.filterwarnings("ignore", category=UserWarning) 20 | 21 | parser = argparse.ArgumentParser(description='Drug response prediction.') 22 | parser.add_argument('--gpu_id', dest='gpu_id', type=int, 23 | default=0, help='GPU device ID. Use -1 for CPU.') 24 | parser.add_argument('--embedding_path', type=str, required=True, 25 | help='Path to the gene expression embeddings (CSV indexed by cell line).') 26 | parser.add_argument('--test_drug', type=str, 27 | default=None, help='PubChem CID of the hold-out drug.') 28 | parser.add_argument('--val_drug', type=str, default=None, 29 | help='PubChem CID of the drug used for early stopping.') 30 | 31 | args = parser.parse_args() 32 | 33 | TCGA_label_set = ["ALL", "BLCA", "BRCA", "CESC", "DLBC", "LIHC", "LUAD", 34 | "ESCA", "GBM", "HNSC", "KIRC", "LAML", "LCML", "LGG", 35 | "LUSC", "MESO", "MM", "NB", "OV", "PAAD", "SCLC", "SKCM", 36 | "STAD", "THCA", 'COAD/READ'] 37 | DPATH = './data' 38 | Drug_info_file = f"{DPATH}/drug/metadata.csv" 39 | # '%s/CCLE/Cell_lines_annotations_20181226.txt' % DPATH 40 | Cell_line_info_file = f"{DPATH}/cell_line/metadata.txt" 41 | # '%s/GDSC/drug_graph_feat' % DPATH 42 | Drug_feature_file = f"{DPATH}/drug/graph_feature" 43 | # '%s/CCLE/GDSC_IC50.csv' % DPATH 44 | Cancer_response_exp_file = f"{DPATH}/ground_truth.csv" 45 | Max_atoms = 100 46 | 47 | 48 | assert (args.test_drug is None and args.val_drug
is None) or ( 49 | args.test_drug is not None and args.val_drug is not None and args.test_drug != args.val_drug), "test_drug and val_drug must either both be None or both be set to different drugs." 50 | 51 | 52 | def MetadataGenerate(Drug_info_file, Cell_line_info_file, Drug_feature_file, Gene_expression_file, filtered): 53 | # drug_id --> pubchem_id 54 | reader = csv.reader(open(Drug_info_file, 'r')) 55 | rows = [item for item in reader] 56 | drugid2pubchemid = {item[0]: item[5] for item in rows if item[5].isdigit()} 57 | 58 | # map cellline --> cancer type 59 | cellline2cancertype = {} 60 | for line in open(Cell_line_info_file).readlines()[1:]: 61 | cellline_id = line.split('\t')[1] 62 | TCGA_label = line.strip().split('\t')[-1] 63 | cellline2cancertype[cellline_id] = TCGA_label 64 | 65 | # load drug features 66 | drug_pubchem_id_set = [] 67 | drug_feature = {} 68 | for each in os.listdir(Drug_feature_file): 69 | drug_pubchem_id_set.append(each.split('.')[0]) 70 | feat_mat, adj_list, degree_list = hkl.load( 71 | '%s/%s' % (Drug_feature_file, each)) 72 | drug_feature[each.split('.')[0]] = [feat_mat, adj_list, degree_list] 73 | assert len(drug_pubchem_id_set) == len(drug_feature.values()) 74 | 75 | # load gene expression features (or embeddings) 76 | gexpr_feature = pd.read_csv( 77 | Gene_expression_file, sep=',', header=0, index_col=[0]) 78 | 79 | experiment_data = pd.read_csv( 80 | Cancer_response_exp_file, sep=',', header=0, index_col=[0]) 81 | # filter experiment data 82 | drug_match_list = [item for item in experiment_data.index if item.split(':')[ 83 | 1] in drugid2pubchemid.keys()] 84 | experiment_data_filtered = experiment_data.loc[drug_match_list] 85 | 86 | data_idx = [] 87 | for each_drug in experiment_data_filtered.index: 88 | for each_cellline in experiment_data_filtered.columns: 89 | pubchem_id = drugid2pubchemid[each_drug.split(':')[-1]] 90 | if str(pubchem_id) in drug_pubchem_id_set and each_cellline in gexpr_feature.index: 91 | if not np.isnan(experiment_data_filtered.loc[each_drug, each_cellline]) and each_cellline in
cellline2cancertype.keys(): 92 | ln_IC50 = float( 93 | experiment_data_filtered.loc[each_drug, each_cellline]) 94 | data_idx.append( 95 | (each_cellline, pubchem_id, ln_IC50, cellline2cancertype[each_cellline])) 96 | nb_celllines = len(set([item[0] for item in data_idx])) 97 | nb_drugs = len(set([item[1] for item in data_idx])) 98 | print('%d instances across %d cell lines and %d drugs were generated.' % 99 | (len(data_idx), nb_celllines, nb_drugs)) 100 | return drug_feature, gexpr_feature, data_idx 101 | 102 | 103 | def DataSplit(data_idx, ratio=0.95): 104 | data_train_idx, data_test_idx = [], [] 105 | for each_type in TCGA_label_set: 106 | data_subtype_idx = [item for item in data_idx if item[-1] == each_type] 107 | train_list = random.sample( 108 | data_subtype_idx, int(ratio*len(data_subtype_idx))) 109 | test_list = [ 110 | item for item in data_subtype_idx if item not in train_list] 111 | data_train_idx += train_list 112 | data_test_idx += test_list 113 | return data_train_idx, data_test_idx 114 | 115 | 116 | def DrugSplit(data_idx, drugtype): 117 | data_train_idx, data_test_idx = [], [] 118 | data_test_idx = [item for item in data_idx if item[1] == drugtype] 119 | data_train_idx = [item for item in data_idx if item[1] != drugtype] 120 | return data_train_idx, data_test_idx 121 | 122 | 123 | def CalculateGraphFeat(feat_mat, adj_list): 124 | edge_index = [] 125 | 126 | # Convert adjacency list to edge_index format 127 | for node, neighbors in enumerate(adj_list): 128 | for neighbor in neighbors: 129 | edge_index.append([node, neighbor]) 130 | edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous() 131 | 132 | # Create a data object 133 | data = Data(x=torch.tensor(feat_mat).float(), 134 | edge_index=edge_index) 135 | 136 | return data 137 | 138 | 139 | def FeatureExtract(data_idx, drug_feature, gexpr_feature): 140 | cancer_type_list = [] 141 | nb_instance = len(data_idx) 142 | nb_gexpr_features = gexpr_feature.shape[1] 143 | drug_data = [[] for 
item in range(nb_instance)] 144 | gexpr_data = torch.zeros((nb_instance, nb_gexpr_features)).float() 145 | target = torch.zeros(nb_instance).float() 146 | for idx in tqdm(range(nb_instance)): 147 | cell_line_id, pubchem_id, ln_IC50, cancer_type = data_idx[idx] 148 | 149 | feat_mat, adj_list, _ = drug_feature[str(pubchem_id)] 150 | 151 | drug_data[idx] = CalculateGraphFeat(feat_mat, adj_list) 152 | 153 | gexpr_data[idx, :] = torch.tensor( 154 | gexpr_feature.loc[cell_line_id].values) 155 | target[idx] = ln_IC50 156 | cancer_type_list.append([cancer_type, cell_line_id, pubchem_id]) 157 | return drug_data, gexpr_data, target, cancer_type_list 158 | 159 | 160 | def place_on_device(gpu_id=0, *tensors_or_models): 161 | # Determine the device based on the gpu_id 162 | if gpu_id == -1: 163 | device = torch.device('cpu') 164 | elif gpu_id >= torch.cuda.device_count(): 165 | raise ValueError(f"GPU with ID {gpu_id} is not available.") 166 | else: 167 | device = torch.device(f'cuda:{gpu_id}') 168 | 169 | # Move each tensor or model to the specified device and return them as a list 170 | return [tensor_or_model.to(device) for tensor_or_model in tensors_or_models] 171 | 172 | 173 | def main(): 174 | random.seed(0) 175 | drug_feature, gexpr_feature, data_idx = MetadataGenerate( 176 | Drug_info_file, Cell_line_info_file, Drug_feature_file, args.embedding_path, False) 177 | 178 | gexpr_dim = gexpr_feature.shape[-1] 179 | 180 | if args.test_drug is None and args.val_drug is None: 181 | data_train_idx, data_real_idx = DataSplit(data_idx, ratio=0.95)  # data_real_idx: hold-out test set (~5%) 182 | data_train_idx, data_test_idx = DataSplit( 183 | data_train_idx, ratio=1-0.05/0.95)  # data_test_idx: validation set for early stopping (~5% of all data) 184 | else: 185 | data_train_idx, data_real_idx = DrugSplit(data_idx, args.test_drug) 186 | assert len(data_real_idx) > 0, "Test drug doesn't exist." 187 | data_train_idx, data_test_idx = DrugSplit( 188 | data_train_idx, args.val_drug) 189 | assert len(data_test_idx) > 0, "Validation drug doesn't exist."
190 | 191 | # Extract features for training and test 192 | X_drug_data_train, X_gexpr_data_train, Y_train, _ = FeatureExtract( 193 | data_train_idx, drug_feature, gexpr_feature) 194 | 195 | batch_size = 32 196 | train_loader = DataLoader(list(zip(X_drug_data_train, X_gexpr_data_train, Y_train)), 197 | batch_size=batch_size, shuffle=True, num_workers=8) 198 | 199 | X_drug_data_test, X_gexpr_data_test, Y_test, _ = FeatureExtract( 200 | data_test_idx, drug_feature, gexpr_feature) 201 | 202 | test_loader = DataLoader(list(zip(X_drug_data_test, X_gexpr_data_test, Y_test)), 203 | batch_size=batch_size, shuffle=False, num_workers=8) 204 | 205 | X_drug_data_real, X_gexpr_data_real, Y_real, _ = FeatureExtract( 206 | data_real_idx, drug_feature, gexpr_feature) 207 | 208 | real_loader = DataLoader(list(zip(X_drug_data_real, X_gexpr_data_real, Y_real)), 209 | batch_size=batch_size, shuffle=False, num_workers=8) 210 | 211 | drug_gcnn = UGCNN(input_dim=75, hidden_dims=[ 212 | 256, 256, 256], out_channels=100) 213 | gexpr_mlp = MLP(out_dim=100, input_dim=gexpr_dim) 214 | comb_mlp = CombinedMLP(input_dim=200) 215 | 216 | drug_gcnn, gexpr_mlp, comb_mlp = place_on_device( 217 | args.gpu_id, drug_gcnn, gexpr_mlp, comb_mlp) 218 | 219 | parameters = list(drug_gcnn.parameters()) + \ 220 | list(gexpr_mlp.parameters()) + list(comb_mlp.parameters()) 221 | 222 | optimizer = optim.Adam(parameters, lr=0.001, betas=(0.9, 0.999), 223 | eps=1e-08, weight_decay=0, amsgrad=False) 224 | criterion = torch.nn.MSELoss() 225 | 226 | def compute_pcc(y_true, y_pred): 227 | y_true = y_true.flatten() 228 | y_pred = y_pred.flatten() 229 | return pearsonr(y_true, y_pred)[0] 230 | 231 | # Training loop 232 | epochs = 500 233 | patience = 10 234 | best_val_loss = -float('inf') 235 | best_drug_gcnn = None 236 | best_gexpr_mlp = None 237 | best_comb_mlp = None 238 | patience_counter = 0 239 | for epoch in range(epochs): 240 | running_loss = 0.0 241 | progress_bar = tqdm( 242 | train_loader, desc=f'Epoch 
{epoch+1}/{epochs}', leave=False) 243 | 244 | drug_gcnn.train() 245 | gexpr_mlp.train() 246 | comb_mlp.train() 247 | for i, (drug, gexpr, target) in enumerate(progress_bar): 248 | target, drug, gexpr = place_on_device( 249 | args.gpu_id, target, drug, gexpr) 250 | optimizer.zero_grad() 251 | drug = drug_gcnn(drug) 252 | gexpr = gexpr_mlp(gexpr) 253 | out = comb_mlp(torch.cat((drug, gexpr), dim=1)) 254 | 255 | loss = criterion(out.view(-1), target.view(-1)) 256 | loss.backward() 257 | optimizer.step() 258 | running_loss += loss.item() 259 | 260 | # Update progress bar 261 | progress_bar.set_postfix(loss=running_loss/(i+1)) 262 | 263 | drug_gcnn.eval() 264 | gexpr_mlp.eval() 265 | comb_mlp.eval() 266 | val_loss = 0.0 267 | all_targets = [] 268 | all_outputs = [] 269 | with torch.no_grad(): 270 | for drug, gexpr, target in test_loader: 271 | target, drug, gexpr = place_on_device( 272 | args.gpu_id, target, drug, gexpr) 273 | drug = drug_gcnn(drug) 274 | gexpr = gexpr_mlp(gexpr) 275 | out = comb_mlp(torch.cat((drug, gexpr), dim=1)) 276 | loss = criterion(out.view(-1), target.view(-1)) 277 | val_loss += loss.item() 278 | all_targets.append(target.view(-1).cpu().numpy()) 279 | all_outputs.append(out.view(-1).cpu().numpy()) 280 | 281 | val_loss /= len(test_loader) 282 | 283 | all_targets = np.concatenate(all_targets) 284 | all_outputs = np.concatenate(all_outputs) 285 | pcc = compute_pcc(all_targets, all_outputs) 286 | 287 | print( 288 | f'Epoch {epoch+1}, Validation Loss: {val_loss:.4f}, PCC: {pcc:.4f}') 289 | 290 | if pcc > best_val_loss: 291 | best_drug_gcnn = copy.deepcopy(drug_gcnn) 292 | best_gexpr_mlp = copy.deepcopy(gexpr_mlp) 293 | best_comb_mlp = copy.deepcopy(comb_mlp) 294 | best_val_loss = pcc 295 | patience_counter = 0 296 | else: 297 | patience_counter += 1 298 | 299 | if patience_counter >= patience: 300 | print('Early stopping') 301 | break 302 | 303 | drug_gcnn.eval() 304 | gexpr_mlp.eval() 305 | comb_mlp.eval() 306 | all_targets = [] 307 | all_outputs = 
[] 308 | with torch.no_grad(): 309 | for drug, gexpr, target in real_loader: 310 | target, drug, gexpr = place_on_device( 311 | args.gpu_id, target, drug, gexpr) 312 | drug = best_drug_gcnn(drug) 313 | gexpr = best_gexpr_mlp(gexpr) 314 | out = best_comb_mlp(torch.cat((drug, gexpr), dim=1)) 315 | loss = criterion(out.view(-1), target.view(-1)) 316 | val_loss += loss.item() 317 | all_targets.append(target.view(-1).cpu().numpy()) 318 | all_outputs.append(out.view(-1).cpu().numpy()) 319 | 320 | pcc_real = compute_pcc(np.concatenate(all_targets), 321 | np.concatenate(all_outputs)) 322 | print(f"PCC on hold-out data: {pcc_real}.") 323 | 324 | embedding_identifier = args.embedding_path[args.embedding_path.rfind( 325 | "/")+1: args.embedding_path.rfind(".csv")] 326 | 327 | os.makedirs("./results", exist_ok=True) 328 | 329 | task = "cell_line" if ( 330 | args.val_drug is None and args.test_drug is None) else f"drug_{args.test_drug}" 331 | 332 | data = { 333 | 'cell_line': [entry[0] for entry in data_real_idx], 334 | 'pubchem_cid': [entry[1] for entry in data_real_idx], 335 | 'cancer_type': [entry[3] for entry in data_real_idx], 336 | 'prediction': np.concatenate(all_outputs).tolist(), 337 | 'ground_truth': np.concatenate(all_targets).tolist(), 338 | } 339 | 340 | df = pd.DataFrame(data) 341 | df.to_csv( 342 | f"./results/hold_out_{task}_{embedding_identifier}.csv", index=False) 343 | print( 344 | f"Results saved to ./results/hold_out_{task}_{embedding_identifier}.csv") 345 | 346 | 347 | if __name__ == '__main__': 348 | main() 349 | -------------------------------------------------------------------------------- /drug_response_prediction/environment.yml: -------------------------------------------------------------------------------- 1 | name: cancergpt_drp 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - asttokens=2.4.1=pyhd8ed1ab_0 11 | - backcall=0.2.0=pyh9f0ad1d_0 
12 | - blas=1.0=mkl 13 | - brotli-python=1.0.9=py39h5a03fae_7 14 | - bzip2=1.0.8=h7f98852_4 15 | - ca-certificates=2024.8.30=hbcca054_0 16 | - certifi=2024.8.30=pyhd8ed1ab_0 17 | - cffi=1.14.6=py39he32792d_0 18 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 19 | - comm=0.2.2=pyhd8ed1ab_0 20 | - cuda-cudart=11.7.99=0 21 | - cuda-cupti=11.7.101=0 22 | - cuda-libraries=11.7.1=0 23 | - cuda-nvrtc=11.7.99=0 24 | - cuda-nvtx=11.7.91=0 25 | - cuda-runtime=11.7.1=0 26 | - cuda-version=12.6=3 27 | - debugpy=1.6.7=py39h6a678d5_0 28 | - decorator=5.1.1=pyhd8ed1ab_0 29 | - entrypoints=0.4=pyhd8ed1ab_0 30 | - executing=2.1.0=pyhd8ed1ab_0 31 | - ffmpeg=4.3=hf484d3e_0 32 | - filelock=3.13.1=pyhd8ed1ab_0 33 | - freetype=2.12.1=h4a9f257_0 34 | - giflib=5.2.2=h5eee18b_0 35 | - gmp=6.2.1=h58526e2_0 36 | - gmpy2=2.1.2=py39heeb90bb_0 37 | - gnutls=3.6.15=he1e5248_0 38 | - h2=4.1.0=py39hf3d152e_0 39 | - hpack=4.0.0=pyh9f0ad1d_0 40 | - hyperframe=6.0.1=pyhd8ed1ab_0 41 | - idna=3.7=pyhd8ed1ab_0 42 | - intel-openmp=2021.4.0=h06a4308_3561 43 | - ipykernel=6.29.5=pyh3099207_0 44 | - ipython=8.12.0=pyh41d4057_0 45 | - jedi=0.19.1=pyhd8ed1ab_0 46 | - jinja2=3.1.4=pyhd8ed1ab_0 47 | - jpeg=9e=h166bdaf_1 48 | - jupyter_client=7.3.4=pyhd8ed1ab_0 49 | - jupyter_core=5.7.2=pyh31011fe_1 50 | - lame=3.100=h7f98852_1001 51 | - lcms2=2.12=h3be6417_0 52 | - ld_impl_linux-64=2.40=hf3520f5_7 53 | - lerc=3.0=h9c3ff4c_0 54 | - libcublas=11.10.3.66=0 55 | - libcufft=10.7.2.124=h4fbf590_0 56 | - libcufile=1.11.1.6=0 57 | - libcurand=10.3.7.77=0 58 | - libcusolver=11.4.0.1=0 59 | - libcusparse=11.7.4.91=0 60 | - libdeflate=1.17=h5eee18b_1 61 | - libffi=3.3=h58526e2_2 62 | - libgcc-ng=11.2.0=h1234567_1 63 | - libgomp=11.2.0=h1234567_1 64 | - libiconv=1.16=h516909a_0 65 | - libidn2=2.3.4=h5eee18b_0 66 | - libjpeg-turbo=2.0.0=h9bf148f_0 67 | - libnpp=11.7.4.75=0 68 | - libnvjpeg=11.8.0.2=0 69 | - libpng=1.6.39=h5eee18b_0 70 | - libsodium=1.0.18=h36c2ea0_1 71 | - libstdcxx-ng=11.2.0=he4da1e4_16 72 | - 
libtasn1=4.19.0=h5eee18b_0 73 | - libtiff=4.5.1=h6a678d5_0 74 | - libunistring=0.9.10=h7f98852_0 75 | - libwebp=1.3.2=h11a3e52_0 76 | - libwebp-base=1.3.2=h5eee18b_1 77 | - llvm-openmp=14.0.6=h9e868ea_0 78 | - lz4-c=1.9.4=h6a678d5_1 79 | - markupsafe=2.1.3=py39h5eee18b_0 80 | - matplotlib-inline=0.1.7=pyhd8ed1ab_0 81 | - mkl=2021.4.0=h06a4308_640 82 | - mkl-service=2.4.0=py39h7e14d7c_0 83 | - mkl_fft=1.3.1=py39h0c7bc48_1 84 | - mkl_random=1.2.2=py39hde0f152_0 85 | - mpc=1.1.0=h04dde30_1009 86 | - mpfr=4.0.2=he80fd80_1 87 | - mpmath=1.3.0=pyhd8ed1ab_0 88 | - ncurses=6.4=h6a678d5_0 89 | - nest-asyncio=1.6.0=pyhd8ed1ab_0 90 | - nettle=3.7.3=hbbd107a_1 91 | - networkx=3.2.1=pyhd8ed1ab_0 92 | - numpy=1.24.3=py39h14f4228_0 93 | - numpy-base=1.24.3=py39h31eccc5_0 94 | - openh264=2.1.1=h780b84a_0 95 | - openjpeg=2.5.2=he7f1fd0_0 96 | - openssl=1.1.1w=h7f8727e_0 97 | - packaging=24.1=pyhd8ed1ab_0 98 | - parso=0.8.4=pyhd8ed1ab_0 99 | - pexpect=4.9.0=pyhd8ed1ab_0 100 | - pickleshare=0.7.5=py_1003 101 | - pillow=10.4.0=py39h5eee18b_0 102 | - pip=24.2=pyh8b19718_1 103 | - platformdirs=4.3.6=pyhd8ed1ab_0 104 | - prompt-toolkit=3.0.48=pyha770c72_0 105 | - prompt_toolkit=3.0.48=hd8ed1ab_0 106 | - psutil=5.9.0=py39hb9d737c_1 107 | - ptyprocess=0.7.0=pyhd3deb0d_0 108 | - pure_eval=0.2.3=pyhd8ed1ab_0 109 | - pycparser=2.22=pyhd8ed1ab_0 110 | - pygments=2.18.0=pyhd8ed1ab_0 111 | - pysocks=1.7.1=pyha2e5f31_6 112 | - python=3.9.12=h12debd9_1 113 | - python_abi=3.9=2_cp39 114 | - pytorch-cuda=11.7=h778d358_5 115 | - pytorch-mutex=1.0=cpu 116 | - pyyaml=6.0.2=py39h5eee18b_0 117 | - pyzmq=25.1.2=py39h6a678d5_0 118 | - readline=8.2=h5eee18b_0 119 | - requests=2.32.3=pyhd8ed1ab_0 120 | - setuptools=75.1.0=pyhd8ed1ab_0 121 | - six=1.16.0=pyh6c4a22f_0 122 | - sqlite=3.45.3=h5eee18b_0 123 | - stack_data=0.6.2=pyhd8ed1ab_0 124 | - sympy=1.13.2=pyh04b8f61_3 125 | - tk=8.6.14=h39e8969_0 126 | - torchaudio=2.5.0=py39_cpu 127 | - torchvision=0.20.0=py39_cpu 128 | - tornado=6.1=py39hb9d737c_3 129 | - 
traitlets=5.14.3=pyhd8ed1ab_0 130 | - typing_extensions=4.11.0=pyha770c72_0 131 | - urllib3=2.2.3=pyhd8ed1ab_0 132 | - wcwidth=0.2.13=pyhd8ed1ab_0 133 | - wheel=0.44.0=pyhd8ed1ab_0 134 | - xz=5.4.6=h5eee18b_1 135 | - yaml=0.2.5=h7f98852_2 136 | - zeromq=4.3.5=h6a678d5_0 137 | - zlib=1.2.13=h5eee18b_1 138 | - zstandard=0.23.0=py39h2c38b39_0 139 | - zstd=1.5.6=hc292b87_0 140 | - pip: 141 | - aiohappyeyeballs==2.4.3 142 | - aiohttp==3.10.10 143 | - aiosignal==1.3.1 144 | - anndata==0.10.9 145 | - array-api-compat==1.9 146 | - async-timeout==4.0.3 147 | - attrs==24.2.0 148 | - contourpy==1.3.0 149 | - cycler==0.12.1 150 | - dill==0.3.9 151 | - exceptiongroup==1.2.2 152 | - fonttools==4.54.1 153 | - frozenlist==1.5.0 154 | - fsspec==2024.9.0 155 | - get-annotations==0.1.2 156 | - h5py==3.12.1 157 | - hickle==5.0.3 158 | - importlib-resources==6.4.5 159 | - joblib==1.4.2 160 | - kiwisolver==1.4.7 161 | - legacy-api-wrap==1.4 162 | - llvmlite==0.43.0 163 | - matplotlib==3.9.2 164 | - multidict==6.1.0 165 | - natsort==8.4.0 166 | - numba==0.60.0 167 | - nvidia-cublas-cu12==12.1.3.1 168 | - nvidia-cuda-cupti-cu12==12.1.105 169 | - nvidia-cuda-nvrtc-cu12==12.1.105 170 | - nvidia-cuda-runtime-cu12==12.1.105 171 | - nvidia-cudnn-cu12==8.9.2.26 172 | - nvidia-cufft-cu12==11.0.2.54 173 | - nvidia-curand-cu12==10.3.2.106 174 | - nvidia-cusolver-cu12==11.4.5.107 175 | - nvidia-cusparse-cu12==12.1.0.106 176 | - nvidia-nccl-cu12==2.19.3 177 | - nvidia-nvjitlink-cu12==12.6.77 178 | - nvidia-nvtx-cu12==12.1.105 179 | - pandas==2.2.3 180 | - patsy==0.5.6 181 | - propcache==0.2.0 182 | - pynndescent==0.5.13 183 | - pyparsing==3.2.0 184 | - python-dateutil==2.9.0.post0 185 | - pytz==2024.2 186 | - scanpy==1.10.3 187 | - scikit-learn==1.5.2 188 | - scipy==1.13.1 189 | - seaborn==0.13.2 190 | - session-info==1.0.0 191 | - statsmodels==0.14.4 192 | - stdlib-list==0.11.0 193 | - threadpoolctl==3.5.0 194 | - torch==2.2.1 195 | - torch-geometric==2.6.1 196 | - torchdata==0.7.1 197 | - 
torchtext==0.17.1 198 | - tqdm==4.66.5 199 | - triton==2.2.0 200 | - tzdata==2024.2 201 | - umap-learn==0.5.6 202 | - yarl==1.17.0 203 | - zipp==3.20.2 204 | 205 | -------------------------------------------------------------------------------- /drug_response_prediction/model.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GCNConv, global_mean_pool, BatchNorm 5 | import torch 6 | 7 | 8 | class UGCNN(nn.Module): 9 | def __init__(self, hidden_dims: List[int], input_dim: int, out_channels: int): 10 | super(UGCNN, self).__init__() 11 | dims = [input_dim] + hidden_dims + [out_channels] 12 | self.conv_list = nn.ModuleList([]) 13 | for i in range(len(dims)-1): 14 | self.conv_list.extend( 15 | [ 16 | GCNConv(dims[i], dims[i+1]) 17 | ] 18 | ) 19 | self.norm_list = nn.ModuleList([]) 20 | for i in range(len(dims)-1): 21 | self.norm_list.append( 22 | nn.Sequential( 23 | BatchNorm(dims[i+1]), 24 | nn.ReLU(), 25 | nn.Dropout(p=0.1) 26 | ) 27 | ) 28 | 29 | def forward(self, data): 30 | x, edge_index, batch = data.x, data.edge_index, data.batch 31 | 32 | for (conv, norm) in zip(self.conv_list, self.norm_list): 33 | x = norm(conv(x, edge_index)) 34 | 35 | # Use global mean pooling for the entire graph 36 | x = global_mean_pool(x, batch) 37 | 38 | return x 39 | # return F.log_softmax(x, dim=1) 40 | 41 | 42 | class MLP(nn.Module): 43 | def __init__(self, out_dim: int, input_dim: int, hidden_dim: int = 256): 44 | super(MLP, self).__init__() 45 | self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim) 46 | self.fc2 = nn.Linear(in_features=hidden_dim, out_features=out_dim) 47 | self.bn1 = nn.BatchNorm1d(hidden_dim) 48 | self.dropout = nn.Dropout(p=0.1) 49 | 50 | def forward(self, x: torch.Tensor): 51 | x = self.fc1(x) 52 | x = F.tanh(x) 53 | x = self.bn1(x) 54 | x = self.dropout(x) 55 | x = self.fc2(x) 56
| return F.relu(x) 57 | 58 | 59 | class CombinedMLP(nn.Module): 60 | def __init__(self, input_dim: int): 61 | super(CombinedMLP, self).__init__() 62 | self.combined_fc1 = nn.Linear(input_dim, 300) 63 | self.combined_conv1 = nn.Conv2d(1, 30, kernel_size=(1, 150)) 64 | self.combined_conv2 = nn.Conv2d(30, 10, kernel_size=(1, 5)) 65 | self.combined_conv3 = nn.Conv2d(10, 5, kernel_size=(1, 5)) 66 | self.combined_pool1 = nn.MaxPool2d(kernel_size=(1, 2)) 67 | self.combined_pool2 = nn.MaxPool2d(kernel_size=(1, 3)) 68 | self.combined_pool3 = nn.MaxPool2d(kernel_size=(1, 3)) 69 | self.final_fc = nn.Linear(30, 1) 70 | self.dropout = nn.Dropout(p=0.1) 71 | self.final_dropout = nn.Dropout(0.2) 72 | 73 | def forward(self, x: torch.Tensor): 74 | x = self.combined_fc1(x) 75 | x = F.tanh(x) 76 | x = self.dropout(x) 77 | x = x.unsqueeze(1).unsqueeze(2) 78 | 79 | x = self.combined_conv1(x) 80 | 81 | x = F.relu(x) 82 | x = self.combined_pool1(x) 83 | x = self.combined_conv2(x) 84 | x = F.relu(x) 85 | x = self.combined_pool2(x) 86 | x = self.combined_conv3(x) 87 | x = F.relu(x) 88 | x = self.combined_pool3(x) 89 | x = self.dropout(x) 90 | x = x.view(x.size(0), -1) 91 | x = self.final_dropout(x) 92 | 93 | return self.final_fc(x) 94 | -------------------------------------------------------------------------------- /drug_response_prediction/run_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | items=(123631 20635522 10341154 11640390 6253 9956222 82146 46931012 637858 24825971 1401 57339144 44462760 10451420 25167777 24889392 11667893 160355 3463933 42642645 5328940 10127622 9813758 5746 5459322 11485656 3218 10184653 10200390 44819241 5311510 159324 5289247 6710780 9952773 3005532 44632017 9967941 46844147 6445562 6918638 11316960 176167 25222038 24756910 6918289 7251185 16722836 16760646 9549184 53340664 2733526 31703 11404337 46930998 10390396 176870 76044 46843772 5329102 16038120 24894414 54676905 24776445 300471 46943432 
44182295 9810884 10074640 5330286 49806720 54685215 11617559 9910224 6753378 4993 11626927 11626560 11634725 46885626 5494449 9863776 3062316 65110 10459196 56962337 9826528 44143370 11455910 9911830 16747388 6852167 16725726 5384616 44551660 53302361 10172943 11713159 5208 46907787 448013 11647372 148124 5113032 126565 17755052 5280757 560326 46191454 25022668 104842 216239 49836027 11373846 446378 56965967 11676786 44450571 462382 11609586 46224516 23725625 10461815 66577006 644241 5460769 11338033 46883536 85668777 644215 10196499 216326 8249 9884685 156422 444795 5311 11152667 9858940 24978538 9956119 3385 11707110 6505803 6450551 6445533 5311497 6918848 5394 84691 9907093 36314 4263900 9931953 10113978 9868037 16095342 25124816 11977753 11960529 25195352 44137675 11844351 521106 5327091 71271629 2375 3796 51371303 10302451 24826799 16720766 10109823 10384072 208908 387447 11433190 6914657 5717801 10027278 42640 11493598 176158 24180719 6450816 11625818 60750 11624601 25126798 49821040 9903786 9874913 78243717 447912 53394750 44224160 126941 24951314 2314623 25102847 9943465 11327430 704473 24772860 2726824 25257557 9826308 11364421 4261 16663089 5278396 36462 6918454 9938202 9914412 10096043 5291 24785538 11754511 11282283 25262965 10077147 11178236) 4 | 5 | # Get the number of items in the list 6 | num_items=${#items[@]} 7 | 8 | # Iterate through each item 9 | for ((i = 0; i < num_items; i++)); do 10 | current_item=${items[$i]} 11 | echo "Current drug: $current_item" 12 | 13 | # Create a temporary array excluding the current item 14 | temp_items=("${items[@]:0:i}" "${items[@]:((i + 1))}") 15 | 16 | # Get the number of items in the temporary array 17 | temp_num_items=${#temp_items[@]} 18 | 19 | # Generate a random index for the temporary array 20 | random_index=$((RANDOM % temp_num_items)) 21 | 22 | # Select a random item from the temporary array 23 | random_item=${temp_items[$random_index]} 24 | 25 | python drp.py --embedding_path 
"./data/embedding/CancerGPT_embedding.csv" --test_drug "$current_item" --val_drug "$random_item" --gpu_id 0 26 | python drp.py --embedding_path "./data/embedding/gene_expression.csv" --test_drug "$current_item" --val_drug "$random_item" --gpu_id 0 27 | python drp.py --embedding_path "./data/embedding/scFoundation_embedding.csv" --test_drug "$current_item" --val_drug "$random_item" --gpu_id 0 28 | done -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: cancerfoundation 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - asttokens=2.4.1=pyhd8ed1ab_0 11 | - backcall=0.2.0=pyh9f0ad1d_0 12 | - blas=1.0=mkl 13 | - brotli-python=1.0.9=py39h5a03fae_7 14 | - bzip2=1.0.8=h7f98852_4 15 | - ca-certificates=2024.8.30=hbcca054_0 16 | - certifi=2024.8.30=pyhd8ed1ab_0 17 | - cffi=1.14.6=py39he32792d_0 18 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 19 | - comm=0.2.2=pyhd8ed1ab_0 20 | - cuda-cudart=11.7.99=0 21 | - cuda-cupti=11.7.101=0 22 | - cuda-libraries=11.7.1=0 23 | - cuda-nvrtc=11.7.99=0 24 | - cuda-nvtx=11.7.91=0 25 | - cuda-runtime=11.7.1=0 26 | - cuda-version=12.6=3 27 | - debugpy=1.6.7=py39h6a678d5_0 28 | - decorator=5.1.1=pyhd8ed1ab_0 29 | - entrypoints=0.4=pyhd8ed1ab_0 30 | - executing=2.1.0=pyhd8ed1ab_0 31 | - ffmpeg=4.3=hf484d3e_0 32 | - filelock=3.13.1=pyhd8ed1ab_0 33 | - freetype=2.12.1=h4a9f257_0 34 | - giflib=5.2.2=h5eee18b_0 35 | - gmp=6.2.1=h58526e2_0 36 | - gmpy2=2.1.2=py39heeb90bb_0 37 | - gnutls=3.6.15=he1e5248_0 38 | - h2=4.1.0=py39hf3d152e_0 39 | - hpack=4.0.0=pyh9f0ad1d_0 40 | - hyperframe=6.0.1=pyhd8ed1ab_0 41 | - idna=3.7=pyhd8ed1ab_0 42 | - intel-openmp=2021.4.0=h06a4308_3561 43 | - ipykernel=6.29.5=pyh3099207_0 44 | - ipython=8.12.0=pyh41d4057_0 45 | - jedi=0.19.1=pyhd8ed1ab_0 46 | - jinja2=3.1.4=pyhd8ed1ab_0 47 | 
- jpeg=9e=h166bdaf_1 48 | - jupyter_client=7.3.4=pyhd8ed1ab_0 49 | - jupyter_core=5.7.2=pyh31011fe_1 50 | - lame=3.100=h7f98852_1001 51 | - lcms2=2.12=h3be6417_0 52 | - ld_impl_linux-64=2.40=hf3520f5_7 53 | - lerc=3.0=h9c3ff4c_0 54 | - libcublas=11.10.3.66=0 55 | - libcufft=10.7.2.124=h4fbf590_0 56 | - libcufile=1.11.1.6=0 57 | - libcurand=10.3.7.77=0 58 | - libcusolver=11.4.0.1=0 59 | - libcusparse=11.7.4.91=0 60 | - libdeflate=1.17=h5eee18b_1 61 | - libffi=3.3=h58526e2_2 62 | - libgcc-ng=11.2.0=h1234567_1 63 | - libgomp=11.2.0=h1234567_1 64 | - libiconv=1.16=h516909a_0 65 | - libidn2=2.3.4=h5eee18b_0 66 | - libjpeg-turbo=2.0.0=h9bf148f_0 67 | - libnpp=11.7.4.75=0 68 | - libnvjpeg=11.8.0.2=0 69 | - libpng=1.6.39=h5eee18b_0 70 | - libsodium=1.0.18=h36c2ea0_1 71 | - libstdcxx-ng=11.2.0=he4da1e4_16 72 | - libtasn1=4.19.0=h5eee18b_0 73 | - libtiff=4.5.1=h6a678d5_0 74 | - libunistring=0.9.10=h7f98852_0 75 | - libwebp=1.3.2=h11a3e52_0 76 | - libwebp-base=1.3.2=h5eee18b_1 77 | - llvm-openmp=14.0.6=h9e868ea_0 78 | - lz4-c=1.9.4=h6a678d5_1 79 | - markupsafe=2.1.3=py39h5eee18b_0 80 | - matplotlib-inline=0.1.7=pyhd8ed1ab_0 81 | - mkl=2021.4.0=h06a4308_640 82 | - mkl-service=2.4.0=py39h7e14d7c_0 83 | - mkl_fft=1.3.1=py39h0c7bc48_1 84 | - mkl_random=1.2.2=py39hde0f152_0 85 | - mpc=1.1.0=h04dde30_1009 86 | - mpfr=4.0.2=he80fd80_1 87 | - mpmath=1.3.0=pyhd8ed1ab_0 88 | - ncurses=6.4=h6a678d5_0 89 | - nest-asyncio=1.6.0=pyhd8ed1ab_0 90 | - nettle=3.7.3=hbbd107a_1 91 | - networkx=3.2.1=pyhd8ed1ab_0 92 | - numpy=1.24.3=py39h14f4228_0 93 | - numpy-base=1.24.3=py39h31eccc5_0 94 | - openh264=2.1.1=h780b84a_0 95 | - openjpeg=2.5.2=he7f1fd0_0 96 | - openssl=1.1.1w=h7f8727e_0 97 | - packaging=24.1=pyhd8ed1ab_0 98 | - parso=0.8.4=pyhd8ed1ab_0 99 | - pexpect=4.9.0=pyhd8ed1ab_0 100 | - pickleshare=0.7.5=py_1003 101 | - pillow=10.4.0=py39h5eee18b_0 102 | - pip=24.2=pyh8b19718_1 103 | - platformdirs=4.3.6=pyhd8ed1ab_0 104 | - prompt-toolkit=3.0.48=pyha770c72_0 105 | - 
prompt_toolkit=3.0.48=hd8ed1ab_0 106 | - psutil=5.9.0=py39hb9d737c_1 107 | - ptyprocess=0.7.0=pyhd3deb0d_0 108 | - pure_eval=0.2.3=pyhd8ed1ab_0 109 | - pycparser=2.22=pyhd8ed1ab_0 110 | - pygments=2.18.0=pyhd8ed1ab_0 111 | - pysocks=1.7.1=pyha2e5f31_6 112 | - python=3.9.12=h12debd9_1 113 | - python_abi=3.9=2_cp39 114 | - pytorch-cuda=11.7=h778d358_5 115 | - pytorch-mutex=1.0=cpu 116 | - pyyaml=6.0.2=py39h5eee18b_0 117 | - pyzmq=25.1.2=py39h6a678d5_0 118 | - readline=8.2=h5eee18b_0 119 | - requests=2.32.3=pyhd8ed1ab_0 120 | - setuptools=75.1.0=pyhd8ed1ab_0 121 | - six=1.16.0=pyh6c4a22f_0 122 | - sqlite=3.45.3=h5eee18b_0 123 | - stack_data=0.6.2=pyhd8ed1ab_0 124 | - sympy=1.13.2=pyh04b8f61_3 125 | - tk=8.6.14=h39e8969_0 126 | - torchaudio=2.5.0=py39_cpu 127 | - torchvision=0.20.0=py39_cpu 128 | - tornado=6.1=py39hb9d737c_3 129 | - traitlets=5.14.3=pyhd8ed1ab_0 130 | - typing_extensions=4.11.0=pyha770c72_0 131 | - urllib3=2.2.3=pyhd8ed1ab_0 132 | - wcwidth=0.2.13=pyhd8ed1ab_0 133 | - wheel=0.44.0=pyhd8ed1ab_0 134 | - xz=5.4.6=h5eee18b_1 135 | - yaml=0.2.5=h7f98852_2 136 | - zeromq=4.3.5=h6a678d5_0 137 | - zlib=1.2.13=h5eee18b_1 138 | - zstandard=0.23.0=py39h2c38b39_0 139 | - zstd=1.5.6=hc292b87_0 140 | - pip: 141 | - absl-py==2.1.0 142 | - anndata==0.10.9 143 | - array-api-compat==1.9 144 | - chex==0.1.87 145 | - contourpy==1.3.0 146 | - cycler==0.12.1 147 | - exceptiongroup==1.2.2 148 | - fonttools==4.54.1 149 | - fsspec==2024.9.0 150 | - get-annotations==0.1.2 151 | - h5py==3.12.1 152 | - igraph==0.11.8 153 | - importlib-metadata==8.5.0 154 | - importlib-resources==6.4.5 155 | - jax==0.4.30 156 | - jaxlib==0.4.30 157 | - joblib==1.4.2 158 | - kiwisolver==1.4.7 159 | - legacy-api-wrap==1.4 160 | - llvmlite==0.43.0 161 | - markdown-it-py==3.0.0 162 | - matplotlib==3.9.2 163 | - mdurl==0.1.2 164 | - ml-dtypes==0.5.0 165 | - natsort==8.4.0 166 | - numba==0.60.0 167 | - nvidia-cublas-cu12==12.1.3.1 168 | - nvidia-cuda-cupti-cu12==12.1.105 169 | - 
nvidia-cuda-nvrtc-cu12==12.1.105 170 | - nvidia-cuda-runtime-cu12==12.1.105 171 | - nvidia-cudnn-cu12==8.9.2.26 172 | - nvidia-cufft-cu12==11.0.2.54 173 | - nvidia-curand-cu12==10.3.2.106 174 | - nvidia-cusolver-cu12==11.4.5.107 175 | - nvidia-cusparse-cu12==12.1.0.106 176 | - nvidia-nccl-cu12==2.19.3 177 | - nvidia-nvjitlink-cu12==12.6.77 178 | - nvidia-nvtx-cu12==12.1.105 179 | - opt-einsum==3.4.0 180 | - pandas==2.2.3 181 | - patsy==0.5.6 182 | - plottable==0.1.5 183 | - pynndescent==0.5.13 184 | - pyparsing==3.2.0 185 | - python-dateutil==2.9.0.post0 186 | - pytz==2024.2 187 | - rich==13.9.3 188 | - scanpy==1.10.3 189 | - scib-metrics==0.5.1 190 | - scikit-learn==1.5.2 191 | - scipy==1.13.1 192 | - seaborn==0.13.2 193 | - session-info==1.0.0 194 | - statsmodels==0.14.4 195 | - stdlib-list==0.11.0 196 | - texttable==1.7.0 197 | - threadpoolctl==3.5.0 198 | - toolz==1.0.0 199 | - torch==2.2.1 200 | - torchdata==0.7.1 201 | - torchtext==0.17.1 202 | - tqdm==4.66.5 203 | - triton==2.2.0 204 | - tzdata==2024.2 205 | - umap-learn==0.5.6 206 | - zipp==3.20.2 207 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import embedding, cancergpt 2 | -------------------------------------------------------------------------------- /model/assets/args.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_bins": 51, 3 | "max_seq_len": 1200, 4 | "trunc_by_sample": true, 5 | "nlayers": 6, 6 | "nheads": 8, 7 | "embsize": 256, 8 | "d_hid": 512, 9 | "dropout": 0.2, 10 | "n_layers_cls": 3 11 | } -------------------------------------------------------------------------------- /model/cancergpt.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import math 3 | from typing import Dict, List, Mapping, Optional, Tuple, Any, Union 4 | import warnings 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 9 | 10 | 11 | class CancerGPT(nn.Module): 12 | def __init__( 13 | self, 14 | ntoken: int, 15 | d_model: int, 16 | nhead: int, 17 | d_hid: int, 18 | nlayers: int, 19 | vocab: Any = None, 20 | dropout: float = 0.5, 21 | pad_token: str = "", 22 | n_input_bins: Optional[int] = None, 23 | ): 24 | super().__init__() 25 | self.d_model = d_model 26 | self.n_input_bins = n_input_bins 27 | 28 | self.gene_encoder = GeneEncoder( 29 | ntoken, d_model, padding_idx=vocab[pad_token]) 30 | self.value_encoder = ContinuousValueEncoder(d_model, dropout) 31 | 32 | encoder_layers = TransformerEncoderLayer( 33 | d_model, nhead, d_hid, dropout, batch_first=True 34 | ) 35 | self.encoder = TransformerEncoder( 36 | encoder_layers, nlayers) 37 | 38 | def encode( 39 | self, 40 | src: Tensor, 41 | values: Tensor, 42 | src_key_padding_mask: Tensor, 43 | ) -> Tensor: 44 | src = self.gene_encoder(src) 45 | 46 | values = self.value_encoder(values) 47 | 48 | total_embs = src + values 49 | 50 | output = self.encoder( 51 | total_embs, src_key_padding_mask=src_key_padding_mask 52 | ) 53 | 54 | return output 55 | 56 | 57 | class GeneEncoder(nn.Module): 58 | def __init__( 
59 | self, 60 | num_embeddings: int, 61 | embedding_dim: int, 62 | padding_idx: Optional[int] = None, 63 | ): 64 | super().__init__() 65 | self.embedding = nn.Embedding( 66 | num_embeddings, embedding_dim, padding_idx=padding_idx 67 | ) 68 | self.enc_norm = nn.LayerNorm(embedding_dim) 69 | 70 | def forward(self, x: Tensor) -> Tensor: 71 | x = self.embedding(x) # (batch, seq_len, embsize) 72 | x = self.enc_norm(x) 73 | return x 74 | 75 | 76 | class ContinuousValueEncoder(nn.Module): 77 | """ 78 | Encode real number values to a vector using neural nets projection. 79 | """ 80 | 81 | def __init__(self, d_model: int, dropout: float = 0.1, max_value: int = 512): 82 | super().__init__() 83 | self.dropout = nn.Dropout(p=dropout) 84 | self.linear1 = nn.Linear(1, d_model) 85 | self.activation = nn.ReLU() 86 | self.linear2 = nn.Linear(d_model, d_model) 87 | self.norm = nn.LayerNorm(d_model) 88 | self.max_value = max_value 89 | 90 | def forward(self, x: Tensor) -> Tensor: 91 | """ 92 | Args: 93 | x: Tensor, shape [batch_size, seq_len] 94 | """ 95 | # expand last dimension 96 | x = x.unsqueeze(-1) 97 | # clip x to [-inf, max_value] 98 | x = torch.clamp(x, max=self.max_value) 99 | x = self.activation(self.linear1(x)) 100 | x = self.linear2(x) 101 | x = self.norm(x) 102 | return self.dropout(x) 103 | -------------------------------------------------------------------------------- /model/data_collator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, List, Mapping, Optional, Tuple, Union 3 | 4 | import torch 5 | import numpy as np 6 | 7 | 8 | def _digitize(x: np.ndarray, bins: np.ndarray, side="both") -> np.ndarray: 9 | """ 10 | Digitize the data into bins. This method spreads data uniformly when bins 11 | have same values. 12 | 13 | Args: 14 | 15 | x (:class:`np.ndarray`): 16 | The data to digitize. 
17 | bins (:class:`np.ndarray`): 18 | The bins to use for digitization, in increasing order. 19 | side (:class:`str`, optional): 20 | The side to use for digitization. If "one", only the left side is used. If 21 | "both", both sides are used and values lying on a bin edge are assigned randomly between them. Defaults to "both". 22 | 23 | Returns: 24 | 25 | :class:`np.ndarray`: 26 | The digitized data. 27 | """ 28 | assert x.ndim == 1 and bins.ndim == 1 29 | 30 | left_digits = np.digitize(x, bins) 31 | if side == "one": 32 | return left_digits 33 | 34 | right_digits = np.digitize(x, bins, right=True) 35 | 36 | rands = np.random.rand(len(x)) # uniform random numbers 37 | 38 | digits = rands * (right_digits - left_digits) + left_digits 39 | digits = np.ceil(digits).astype(np.int64) 40 | return digits 41 | 42 | 43 | def binning( 44 | row: Union[np.ndarray, torch.Tensor], n_bins: int 45 | ) -> Union[np.ndarray, torch.Tensor]: 46 | """Bin the row into n_bins levels using quantiles of its non-zero values.""" 47 | dtype = row.dtype 48 | return_np = False if isinstance(row, torch.Tensor) else True 49 | row = row.cpu().numpy() if isinstance(row, torch.Tensor) else row 50 | # TODO: use torch.quantile and torch.bucketize 51 | 52 | if row.max() == 0: 53 | print( 54 | "WARNING: The input data contains a row of zeros. Please make sure this is expected."
55 | ) 56 | return ( 57 | np.zeros_like(row, dtype=dtype) 58 | if return_np 59 | else torch.zeros(row.shape, dtype=dtype) 60 | ) 61 | 62 | if row.min() <= 0: 63 | non_zero_ids = row.nonzero() 64 | non_zero_row = row[non_zero_ids] 65 | bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1)) 66 | non_zero_digits = _digitize(non_zero_row, bins) 67 | binned_row = np.zeros_like(row, dtype=np.int64) 68 | binned_row[non_zero_ids] = non_zero_digits 69 | else: 70 | bins = np.quantile(row, np.linspace(0, 1, n_bins - 1)) 71 | binned_row = _digitize(row, bins) 72 | return torch.from_numpy(binned_row) if not return_np else binned_row.astype(dtype) 73 | 74 | 75 | @dataclass 76 | class DataCollator: 77 | """ 78 | Data collator for the masked value learning task. It pads, truncates or samples the sequences to 79 | the maximum length in the batch (capped at max_length) and, if do_mlm is set, masks the gene expression values. 80 | 81 | Args: 82 | do_padding (:obj:`bool`): whether to pad the sequences to the max length. 83 | pad_token_id (:obj:`int`, optional): the token id to use for padding. 84 | This is required if do_padding is True. 85 | pad_value (:obj:`int`): the value to use for padding the expression 86 | values to the max length. 87 | do_mlm (:obj:`bool`): whether to do masking with MLM. 88 | do_binning (:obj:`bool`): whether to bin the expression values. 89 | mlm_probability (:obj:`float`): the probability of masking with MLM. 90 | mask_value (:obj:`int`): the value to fill at the expression positions 91 | that are masked. 92 | max_length (:obj:`int`, optional): the maximum length of the sequences. 93 | This is required if do_padding is True. 94 | sampling (:obj:`bool`): whether to do sampling instead of truncation if 95 | length > max_length. 96 | keep_first_n_tokens (:obj:`int`): the number of tokens in the beginning 97 | of the sequence to keep unchanged from sampling. This is useful when 98 | special tokens have been added to the beginning of the sequence. 99 | Defaults to 1.
100 | """ 101 | 102 | do_padding: bool = True 103 | pad_token_id: Optional[int] = None 104 | pad_value: int = 0 105 | do_mlm: bool = True 106 | do_binning: bool = True 107 | mlm_probability: float = 0.15 108 | mask_value: int = -1 109 | max_length: Optional[int] = None 110 | sampling: bool = True 111 | keep_first_n_tokens: int = 1 112 | 113 | def __post_init__(self): 114 | if self.do_padding: 115 | if self.pad_token_id is None: 116 | raise ValueError("`pad_token_id` is required if `do_padding`.") 117 | if self.max_length is None: 118 | raise ValueError("`max_length` is required if `do_padding`.") 119 | 120 | if self.mlm_probability <= 0 or self.mlm_probability >= 1: 121 | raise ValueError("`mlm_probability` must be between 0 and 1.") 122 | 123 | if self.keep_first_n_tokens < 0 or self.keep_first_n_tokens > self.max_length: 124 | raise ValueError( 125 | "`keep_first_n_tokens` must be between 0 and `max_length` " 126 | f"({self.max_length})." 127 | ) 128 | 129 | def __call__( 130 | self, examples: List[Dict[str, torch.Tensor]] 131 | ) -> Dict[str, torch.Tensor]: 132 | """ 133 | Each example is like: 134 | {'id': tensor(184117), 135 | 'genes': tensor([36572, 17868, ..., 17072]), 136 | 'expressions': tensor([ 0., 2., ..., 18.])} 137 | """ 138 | if not isinstance(examples[0], Mapping): 139 | raise NotImplementedError("Only mapping-style examples are supported.") 140 | 141 | device = examples[0]["genes"].device 142 | 143 | max_ori_len = max(len(example["genes"]) for example in examples) 144 | _max_length = self.max_length if max_ori_len >= self.max_length else max_ori_len 145 | 146 | # pad and truncate 147 | padded_genes = [] 148 | padded_expressions = [] 149 | for i in range(len(examples)): 150 | genes = examples[i]["genes"] 151 | expressions = examples[i]["expressions"] 152 | if self.do_binning: 153 | expressions[self.keep_first_n_tokens:] = binning( 154 | row=expressions[self.keep_first_n_tokens:], 155 | n_bins=51, 156 | ) 157 | genes, expressions = self._sample_or_truncate_plus_pad( 158 | genes, expressions,
_max_length 159 | ) # torch tensors of length _max_length 160 | padded_genes.append(genes) 161 | padded_expressions.append(expressions) 162 | 163 | padded_genes = torch.stack(padded_genes, dim=0).to(device) 164 | padded_expressions = torch.stack(padded_expressions, dim=0).to(device) 165 | 166 | data_dict = { 167 | "gene": padded_genes, 168 | "expr": padded_expressions, 169 | } 170 | 171 | # mask 172 | if self.do_mlm: 173 | masked_expressions = self._mask(padded_expressions) 174 | else: 175 | masked_expressions = padded_expressions 176 | data_dict["masked_expr"] = masked_expressions 177 | 178 | return data_dict 179 | 180 | def _mask(self, expressions: torch.Tensor) -> torch.Tensor: 181 | """ 182 | Mask the expression values with MLM. 183 | """ 184 | device = expressions.device 185 | shape = expressions.shape 186 | 187 | probability_matrix = torch.full(shape, self.mlm_probability, device=device) 188 | # set the masking probability of padded positions to 0 189 | probability_matrix[expressions.eq(self.pad_value)] = 0 190 | if self.keep_first_n_tokens > 0: 191 | probability_matrix[:, : self.keep_first_n_tokens] = 0 192 | 193 | mask = torch.bernoulli(probability_matrix).bool() 194 | mask = mask.to(device) 195 | 196 | masked_expressions = expressions.masked_fill(mask, self.mask_value) 197 | return masked_expressions 198 | 199 | def _sample_or_truncate_plus_pad( 200 | self, 201 | genes: torch.LongTensor, 202 | expressions: torch.Tensor, 203 | max_length: int, 204 | ) -> Tuple[torch.LongTensor, torch.Tensor]: 205 | assert len(genes) == len(expressions) 206 | if len(genes) == max_length: 207 | return genes, expressions 208 | if len(genes) > max_length: # sample or truncate 209 | if self.sampling: 210 | return self._sample(genes, expressions, max_length) 211 | else: 212 | return genes[:max_length], expressions[:max_length] 213 | else: # pad 214 | return self._pad(genes, expressions, max_length) 215 | 216 | def _sample( 217 | self, 218 | genes: 
torch.LongTensor, 219 | expressions: torch.Tensor, 220 |
max_length: int, 221 | ) -> Tuple[torch.LongTensor, torch.Tensor]: 222 | # NOTE: the fastest way to sample in torch has been benchmarked here 223 | # https://discuss.pytorch.org/t/torch-equivalent-of-numpy-random-choice/16146/19 224 | # it shows that randperm on the GPU is the fastest. 225 | # NOTE: the current implementation also permutes the order of the genes 226 | # and expressions, which probably acts as a useful augmentation. 227 | device = genes.device 228 | if self.keep_first_n_tokens == 0: 229 | indices = torch.randperm(len(genes), device=device)[:max_length] 230 | return genes[indices], expressions[indices] 231 | 232 | # keep the first n tokens unchanged 233 | _n = self.keep_first_n_tokens 234 | indices = torch.randperm( 235 | len(genes) - _n, device=device)[: max_length - _n] 236 | indices = torch.cat([torch.arange(_n, device=device), indices + _n], dim=0) 237 | return genes[indices], expressions[indices] 238 | 239 | def _pad( 240 | self, 241 | genes: torch.LongTensor, 242 | expressions: torch.Tensor, 243 | max_length: int, 244 | ): 245 | device = genes.device 246 | genes = torch.cat( 247 | [ 248 | genes, 249 | torch.full( 250 | (max_length - len(genes),), 251 | self.pad_token_id, 252 | dtype=genes.dtype, 253 | device=device, 254 | ), 255 | ] 256 | ) 257 | expressions = torch.cat( 258 | [ 259 | expressions, 260 | torch.full( 261 | (max_length - len(expressions),), 262 | self.pad_value, 263 | dtype=expressions.dtype, 264 | device=device, 265 | ), 266 | ] 267 | ) 268 | return genes, expressions 269 | -------------------------------------------------------------------------------- /model/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class Dataset(torch.utils.data.Dataset): 6 | def __init__(self, count_matrix, gene_ids, vocab, pad_value, batch_ids=None): 7 | self.count_matrix = count_matrix 8 | self.gene_ids = gene_ids 9 | self.batch_ids = batch_ids 10 | 
self.vocab = vocab 11 |
self.pad_value = pad_value 12 | 13 | def __len__(self): 14 | return len(self.count_matrix) 15 | 16 | def __getitem__(self, idx): 17 | row = self.count_matrix[idx] 18 | values = row 19 | genes = self.gene_ids 20 | genes = np.insert(genes, 0, self.vocab[""]) 21 | values = np.insert(values, 0, self.pad_value) 22 | genes = torch.from_numpy(genes).long() 23 | values = torch.from_numpy(values).float() 24 | output = { 25 | "id": idx, 26 | "genes": genes, 27 | "expressions": values, 28 | } 29 | if self.batch_ids is not None: 30 | output["batch_labels"] = self.batch_ids[idx] 31 | return output 32 | -------------------------------------------------------------------------------- /model/embedding.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | from anndata import AnnData 5 | import os 6 | from model.cancergpt import CancerGPT 7 | from model.data_collator import DataCollator 8 | from model.dataset import Dataset 9 | 10 | from .utils import load_pretrained 11 | from model.vocab import GeneVocab 12 | import scanpy as sc 13 | import torch 14 | import numpy as np 15 | import json 16 | from torch.utils.data import DataLoader, SequentialSampler 17 | PathLike = Union[str, os.PathLike] 18 | 19 | 20 | def embed( 21 | adata_or_file: Union[AnnData, PathLike], 22 | model_dir: PathLike, 23 | batch_key: Optional[str] = None, 24 | max_length: int = 1200, 25 | batch_size: int = 64, 26 | obs_to_save: Optional[list] = None, 27 | device: Union[str, torch.device] = "cuda", 28 | normalize: bool = True, 29 | ) -> AnnData: 30 | if isinstance(adata_or_file, AnnData): 31 | adata = adata_or_file 32 | else: 33 | adata = sc.read_h5ad(adata_or_file) 34 | 35 | if isinstance(obs_to_save, str): 36 | assert obs_to_save in adata.obs, f"obs_to_save {obs_to_save} not in adata.obs" 37 | obs_to_save = [obs_to_save] 38 | 39 | if device == "cuda": 40 | device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") 41 | if not torch.cuda.is_available(): 42 | print("WARNING: CUDA is not available. Using CPU instead.") 43 | 44 | # LOAD MODEL 45 | model_dir = Path(model_dir) 46 | vocab_file = model_dir / "vocab.json" 47 | model_config_file = model_dir / "args.json" 48 | model_file = model_dir / "model.pth" 49 | pad_token = "" 50 | pad_value = -2 51 | 52 | # vocabulary 53 | vocab = GeneVocab.from_file(vocab_file) 54 | 55 | adata.var["genes"] = adata.var.index 56 | 57 | adata.var["id_in_vocab"] = [ 58 | vocab[gene] if gene in vocab else -1 for gene in adata.var["genes"] 59 | ] 60 | 61 | adata = adata[:, adata.var["id_in_vocab"] >= 0] 62 | 63 | sc.pp.highly_variable_genes( 64 | adata, n_top_genes=max_length-1, flavor='cell_ranger', batch_key=batch_key) 65 | adata = adata[:, adata.var['highly_variable']] 66 | adata.var["genes"] = adata.var.index 67 | 68 | with open(model_config_file, "r") as f: 69 | model_configs = json.load(f) 70 | 71 | vocab.set_default_index(vocab[""]) 72 | genes = adata.var["genes"].tolist() 73 | gene_ids = np.array(vocab(genes), dtype=int) 74 | 75 | model = CancerGPT( 76 | ntoken=len(vocab), 77 | d_model=model_configs["embsize"], 78 | nhead=model_configs["nheads"], 79 | d_hid=model_configs["d_hid"], 80 | nlayers=model_configs["nlayers"], 81 | vocab=vocab, 82 | dropout=model_configs["dropout"], 83 | pad_token=pad_token, 84 | ) 85 | 86 | model = load_pretrained(model, torch.load( 87 | model_file, map_location=device), verbose=False) 88 | 89 | model.to(device) 90 | model.eval() 91 | 92 | count_matrix = adata.X 93 | count_matrix = ( 94 | count_matrix if isinstance( 95 | count_matrix, np.ndarray) else count_matrix.A 96 | ) 97 | 98 | dataset = Dataset( 99 | count_matrix, gene_ids, vocab, pad_value=pad_value 100 | ) 101 | collator = DataCollator( 102 | do_padding=True, 103 | pad_token_id=vocab[pad_token], 104 | pad_value=pad_value, 105 | do_mlm=False, 106 | do_binning=True, 107 | max_length=max_length, 108 | sampling=True, 109 | 
keep_first_n_tokens=1, 110 | ) 111 | data_loader = DataLoader( 112 | dataset, 113 | batch_size=batch_size, 114 | sampler=SequentialSampler(dataset), 115 | collate_fn=collator, 116 | drop_last=False, 117 | num_workers=min(len(os.sched_getaffinity(0)), batch_size), 118 | pin_memory=True, 119 | ) 120 | 121 | device = next(model.parameters()).device 122 | cell_embeddings = np.zeros( 123 | (len(dataset), model_configs["embsize"]), dtype=np.float32 124 | ) 125 | with torch.no_grad(), torch.cuda.amp.autocast(enabled=device.type == "cuda"): 126 | count = 0 127 | for data_dict in tqdm(data_loader, desc="Embedding cells"): 128 | input_gene_ids = data_dict["gene"].to(device) 129 | src_key_padding_mask = input_gene_ids.eq( 130 | vocab[pad_token] 131 | ) 132 | 133 | embeddings = model.encode( 134 | input_gene_ids, 135 | data_dict["expr"].to(device), 136 | src_key_padding_mask=src_key_padding_mask, 137 | ) 138 | 139 | # use the output at position 0 (the special token prepended by Dataset) as the cell embedding 140 | embeddings = embeddings[:, 0, :] 141 | embeddings = embeddings.cpu().numpy() 142 | cell_embeddings[count: count + len(embeddings)] = embeddings 143 | count += len(embeddings) 144 | 145 | if normalize: 146 | cell_embeddings = cell_embeddings / np.linalg.norm( 147 | cell_embeddings, axis=1, keepdims=True 148 | ) 149 | 150 | adata.obsm["CancerGPT"] = cell_embeddings 151 | return adata 152 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Mapping, Optional 2 | import torch 3 | import re 4 | 5 | 6 | def load_pretrained( 7 | model: torch.nn.Module, 8 | pretrained_params: Mapping[str, torch.Tensor], 9 | strict: bool = False, 10 | prefix: Optional[List[str]] = None, 11 | verbose: bool = True, 12 | ) -> torch.nn.Module: 13 | """ 14 | Load pretrained weights into the model, skipping 
keys that are missing from the model or whose shapes do not match. 15 | 16 | Args: 17 | model (torch.nn.Module): The model to load the weights into.
18 | pretrained_params (Mapping[str, torch.Tensor]): The pretrained parameters. 19 | strict (bool): Currently unused; keys that are missing from the model or whose 20 | shapes do not match are always skipped. Defaults to False. 21 | prefix (List[str]): Currently unused; all keys in :attr:`pretrained_params` are considered. Defaults to None. 22 | verbose (bool): Whether to print shape mismatches and keys that could not be loaded. Defaults to True. 23 | 24 | Returns: 25 | torch.nn.Module: The model with pretrained weights. 26 | """ 27 | 28 | model_new_params = model.state_dict() 29 | 30 | not_identical = [] 31 | updated = [] 32 | for key in pretrained_params.keys(): 33 | if key in model_new_params.keys() and model_new_params[key].shape == pretrained_params[key].shape: 34 | model_new_params[key] = pretrained_params[key] 35 | updated.append(key) 36 | else: 37 | if key in model_new_params.keys(): 38 | if verbose: 39 | print( 40 | f"Same key but not same shape: {model_new_params[key].shape} : {pretrained_params[key].shape}") 41 | not_identical.append(key) 42 | 43 | def modify_string(input_string): 44 | # Match checkpoint keys with a doubled `self_attn.` segment, e.g. encoder.layers.0.self_attn.self_attn.in_proj_weight 45 | check_pattern = r"^encoder\.layers\.\d+\.self_attn\.self_attn\..*" 46 | 47 | # Check if the input string matches the pattern 48 | if re.match(check_pattern, input_string): 49 | # Pattern to perform the replacement 50 | modify_pattern = r"(encoder\.layers\.\d+\.self_attn\.)self_attn\.(.*)" 51 | 52 | # Drop the inner `self_attn.` so the key matches nn.TransformerEncoderLayer naming 53 | modified_string = re.sub(modify_pattern, r"\1\2", input_string) 54 | return modified_string 55 | else: 56 | return input_string 57 | 58 | for key in not_identical[:]: 59 | key_new = modify_string(key) 60 | if key_new in model_new_params.keys() and model_new_params[key_new].shape == pretrained_params[key].shape: 61 | model_new_params[key_new] = pretrained_params[key] 62 | updated.append(key_new) 63 | 
not_identical.remove(key) 64 | 65 | if verbose and len(not_identical) > 0: 66 |
print("Couldn't load the following keys: ", not_identical) 67 | 68 | model.load_state_dict(model_new_params) 69 | return model 70 | -------------------------------------------------------------------------------- /model/vocab.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | from pathlib import Path 4 | from collections import Counter, OrderedDict 5 | from typing import Dict, Iterable, List, Optional, Tuple, Union 6 | from typing_extensions import Self 7 | 8 | import torchtext.vocab as torch_vocab 9 | from torchtext.vocab import Vocab 10 | 11 | 12 | class GeneVocab(Vocab): 13 | """ 14 | Vocabulary for genes. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | gene_list_or_vocab: Union[List[str], Vocab], 20 | specials: Optional[List[str]] = None, 21 | special_first: bool = True, 22 | default_token: Optional[str] = "", 23 | ) -> None: 24 | """ 25 | Initialize the vocabulary. 26 | Note: adding specials only works when initializing from a gene list. 27 | 28 | Args: 29 | gene_list_or_vocab (List[str] or Vocab): List of gene names or a 30 | Vocab object. 31 | specials (List[str]): List of special tokens. 32 | special_first (bool): Whether to add special tokens to the beginning 33 | of the vocabulary. 34 | default_token (str): Default token, by default will set to "", 35 | if "" is in the vocabulary. 36 | """ 37 | if isinstance(gene_list_or_vocab, Vocab): 38 | _vocab = gene_list_or_vocab 39 | if specials is not None: 40 | raise ValueError( 41 | "received non-empty specials when initializing from a Vocab object." 42 | ) 43 | elif isinstance(gene_list_or_vocab, list): 44 | _vocab = self._build_vocab_from_iterator( 45 | gene_list_or_vocab, 46 | specials=specials, 47 | special_first=special_first, 48 | ) 49 | else: 50 | raise ValueError( 51 | "gene_list_or_vocab must be a list of gene names or a Vocab object."
52 | ) 53 | super().__init__(_vocab.vocab) 54 | if default_token is not None and default_token in self: 55 | self.set_default_token(default_token) 56 | 57 | @classmethod 58 | def from_file(cls, file_path: Union[Path, str]) -> Self: 59 | """ 60 | Load the vocabulary from a file. The file should be either a pickle or a 61 | json file of token to index mapping. 62 | """ 63 | if isinstance(file_path, str): 64 | file_path = Path(file_path) 65 | if file_path.suffix == ".pkl": 66 | with file_path.open("rb") as f: 67 | vocab = pickle.load(f) 68 | return cls(vocab) 69 | elif file_path.suffix == ".json": 70 | with file_path.open("r") as f: 71 | token2idx = json.load(f) 72 | return cls.from_dict(token2idx) 73 | else: 74 | raise ValueError( 75 | f"{file_path} is not a valid file type. " 76 | "Only .pkl and .json are supported." 77 | ) 78 | 79 | @classmethod 80 | def from_dict( 81 | cls, 82 | token2idx: Dict[str, int], 83 | default_token: Optional[str] = "", 84 | ) -> Self: 85 | """ 86 | Load the vocabulary from a dictionary. 87 | 88 | Args: 89 | token2idx (Dict[str, int]): Dictionary mapping tokens to indices. 90 | """ 91 | # initiate an empty vocabulary first 92 | _vocab = cls([]) 93 | 94 | # add the tokens to the vocabulary, GeneVocab requires consecutive indices 95 | for t, i in sorted(token2idx.items(), key=lambda x: x[1]): 96 | _vocab.insert_token(t, i) 97 | 98 | if default_token is not None and default_token in _vocab: 99 | _vocab.set_default_token(default_token) 100 | 101 | return _vocab 102 | 103 | def _build_vocab_from_iterator( 104 | self, 105 | iterator: Iterable, 106 | min_freq: int = 1, 107 | specials: Optional[List[str]] = None, 108 | special_first: bool = True, 109 | ) -> Vocab: 110 | """ 111 | Build a Vocab from an iterator. This function is modified from 112 | torchtext.vocab.build_vocab_from_iterator. The original function always 113 | splits tokens into characters, which is not what we want. 
114 | 115 | Args: 116 | iterator (Iterable): Iterator used to build Vocab. Must yield list 117 | or iterator of tokens. 118 | min_freq (int): The minimum frequency needed to include a token in 119 | the vocabulary. 120 | specials (List[str]): Special symbols to add. The order of supplied 121 | tokens will be preserved. 122 | special_first (bool): Whether to add special tokens to the beginning of the vocabulary. 123 | 124 | Returns: 125 | torchtext.vocab.Vocab: A `Vocab` object 126 | """ 127 | 128 | counter = Counter() 129 | counter.update(iterator) 130 | 131 | if specials is not None: 132 | for tok in specials: 133 | del counter[tok] 134 | 135 | sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0]) 136 | sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True) 137 | ordered_dict = OrderedDict(sorted_by_freq_tuples) 138 | 139 | if specials is not None: 140 | if special_first: 141 | specials = specials[::-1] 142 | for symbol in specials: 143 | ordered_dict.update({symbol: min_freq}) 144 | ordered_dict.move_to_end(symbol, last=not special_first) 145 | 146 | word_vocab = torch_vocab.vocab(ordered_dict, min_freq=min_freq) 147 | return word_vocab 148 | 149 | @property 150 | def pad_token(self) -> Optional[str]: 151 | """ 152 | Get the pad token. 153 | """ 154 | if getattr(self, "_pad_token", None) is None: 155 | self._pad_token = None 156 | return self._pad_token 157 | 158 | @pad_token.setter 159 | def pad_token(self, pad_token: str) -> None: 160 | """ 161 | Set the pad token. Will not add the pad token to the vocabulary. 162 | 163 | Args: 164 | pad_token (str): Pad token, should be in the vocabulary. 165 | """ 166 | if pad_token not in self: 167 | raise ValueError(f"{pad_token} is not in the vocabulary.") 168 | self._pad_token = pad_token 169 | 170 | def save_json(self, file_path: Union[Path, str]) -> None: 171 | """ 172 | Save the vocabulary to a json file.
173 | """ 174 | if isinstance(file_path, str): 175 | file_path = Path(file_path) 176 | with file_path.open("w") as f: 177 | json.dump(self.get_stoi(), f, indent=2) 178 | 179 | def set_default_token(self, default_token: str) -> None: 180 | """ 181 | Set the default token. 182 | 183 | Args: 184 | default_token (str): Default token. 185 | """ 186 | if default_token not in self: 187 | raise ValueError(f"{default_token} is not in the vocabulary.") 188 | self.set_default_index(self[default_token]) 189 | -------------------------------------------------------------------------------- /training_data/preprocessing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pathlib\n", 10 | "import scanpy as sc\n", 11 | "import pandas as pd" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 11, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "metadata = pd.read_csv(\"study_overview.csv\", index_col=0)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 21, 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | "data": { 30 | "text/html": [ 31 | "
\n", [HTML table markup of the DataFrame preview, lost during text extraction; the same 276 rows × 12 columns appear in the text/plain output below] 247 | "
" 248 | ], 249 | "text/plain": [ 250 | " Tissue \\\n", 251 | "Spalte 1 \n", 252 | "Alchahin_2022_kidney kidney \n", 253 | "Aubin_2022_brain brain \n", 254 | "Aynaud_2020_othermodels othermodels \n", 255 | "Azizi_2018_breast breast \n", 256 | "Barrett_2024_brain brain \n", 257 | "... ... \n", 258 | "Travaglini_2020_lung lung \n", 259 | "Vieira_2019_lung lung \n", 260 | "Wu_2021_lung lung \n", 261 | "Zilionis_2019_lung lung \n", 262 | "Leader_2021_lung lung \n", 263 | "\n", 264 | " Article \\\n", 265 | "Spalte 1 \n", 266 | "Alchahin_2022_kidney https://pubmed.ncbi.nlm.nih.gov/36180422\\n3828... \n", 267 | "Aubin_2022_brain https://pubmed.ncbi.nlm.nih.gov/35803925/ \n", 268 | "Aynaud_2020_othermodels https://www.sciencedirect.com/science/article/... \n", 269 | "Azizi_2018_breast https://www.sciencedirect.com/science/article/... \n", 270 | "Barrett_2024_brain https://pubmed.ncbi.nlm.nih.gov/38216553/ \n", 271 | "... ... \n", 272 | "Travaglini_2020_lung https://doi.org/10.1016/j.ccell.2022.10.021 \n", 273 | "Vieira_2019_lung https://doi.org/10.1016/j.ccell.2022.10.022 \n", 274 | "Wu_2021_lung https://doi.org/10.1016/j.ccell.2022.10.023 \n", 275 | "Zilionis_2019_lung https://doi.org/10.1016/j.ccell.2022.10.024 \n", 276 | "Leader_2021_lung https://doi.org/10.1016/j.ccell.2022.10.025 \n", 277 | "\n", 278 | " Disease Technology \\\n", 279 | "Spalte 1 \n", 280 | "Alchahin_2022_kidney RCC 10x \n", 281 | "Aubin_2022_brain Pediatric ependymoma Drop-seq \n", 282 | "Aynaud_2020_othermodels Ewing sarcoma 10x, Fluidigm C1 \n", 283 | "Azizi_2018_breast Breast cancer 10x, InDrop \n", 284 | "Barrett_2024_brain Sporadic Vestibular Schwannomas 10x \n", 285 | "... ... ... 
\n", 286 | "Travaglini_2020_lung Lung Cancer NaN \n", 287 | "Vieira_2019_lung Lung Cancer NaN \n", 288 | "Wu_2021_lung Lung Cancer GEXSCOPE technology \n", 289 | "Zilionis_2019_lung Lung Cancer NaN \n", 290 | "Leader_2021_lung Lung Cancer NaN \n", 291 | "\n", 292 | " N_samples Source author_last_name year \\\n", 293 | "Spalte 1 \n", 294 | "Alchahin_2022_kidney 15.0 CancerSCEM Alchahin 2022 \n", 295 | "Aubin_2022_brain 8.0 CancerSCEM Aubin 2022 \n", 296 | "Aynaud_2020_othermodels 14.0 3ca Aynaud 2020 \n", 297 | "Azizi_2018_breast 61.0 3ca Azizi 2018 \n", 298 | "Barrett_2024_brain 15.0 CancerSCEM Barrett 2024 \n", 299 | "... ... ... ... ... \n", 300 | "Travaglini_2020_lung NaN CellxGene(Salcher) NaN 2020 \n", 301 | "Vieira_2019_lung NaN CellxGene(Salcher) NaN 2019 \n", 302 | "Wu_2021_lung NaN CellxGene(Salcher) NaN 2021 \n", 303 | "Zilionis_2019_lung NaN CellxGene(Salcher) NaN 2019 \n", 304 | "Leader_2021_lung NaN CellxGene(Salcher) NaN 2021 \n", 305 | "\n", 306 | " Project ID Include Reason index \n", 307 | "Spalte 1 \n", 308 | "Alchahin_2022_kidney ccRCC-061 Yes NaN Alchahin_kidney \n", 309 | "Aubin_2022_brain EPN-102 Yes NaN Aubin_brain \n", 310 | "Aynaud_2020_othermodels CCCA-32 Yes NaN Aynaud_othermodels \n", 311 | "Azizi_2018_breast CCCA-65 Yes NaN Azizi_breast \n", 312 | "Barrett_2024_brain VS-105 Yes NaN Barrett_brain \n", 313 | "... ... ... ... ... 
\n", 314 | "Travaglini_2020_lung CG-43 Yes NaN Travaglini_lung \n", 315 | "Vieira_2019_lung CG-44 Yes NaN Vieira_lung \n", 316 | "Wu_2021_lung CG-45 Yes NaN Wu_lung \n", 317 | "Zilionis_2019_lung CG-46 No Dublicated 3ca Zilionis_lung \n", 318 | "Leader_2021_lung CG-47 Yes NaN Leader_lung \n", 319 | "\n", 320 | "[276 rows x 12 columns]" 321 | ] 322 | }, 323 | "execution_count": 21, 324 | "metadata": {}, 325 | "output_type": "execute_result" 326 | } 327 | ], 328 | "source": [ 329 | "metadata" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 22, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "metadata[\"index\"] = metadata.index.str.split(\"_\").str[0] + \"_\" + metadata.index.str.split(\"_\").str[2]\n", 339 | "metadata = metadata[(metadata[\"Source\"]==\"3ca\") & (metadata[\"Include\"]==\"Yes\")]" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 30, 345 | "metadata": {}, 346 | "outputs": [ 347 | { 348 | "name": "stdout", 349 | "output_type": "stream", 350 | "text": [ 351 | "Aynaud_2020_othermodels\n", 352 | "Azizi_2018_breast\n", 353 | "Chen_2021_colorectal\n", 354 | "Cillo_2020_head-and-neck\n", 355 | "Cui_2021_othermodels\n", 356 | "Dost_2020_othermodels\n", 357 | "Durante_2020_othermodels\n", 358 | "Ebinger_2016_othermodels\n", 359 | "Ferrari de Andrade_2019_skin\n", 360 | "Franses_2017_othermodels\n", 361 | "Franses_2020_othermodels\n", 362 | "Gaiti_2019_hematologic\n", 363 | "Gaydosik_2019_hematologic\n", 364 | "Gonzalez_2022_othermodels\n", 365 | "Gulati_2020_breast\n", 366 | "Guo_2018_lung\n", 367 | "Hollern_2019_othermodels\n", 368 | "Hovestadt_2019_brain\n", 369 | "Hwang_2022_pancreas\n", 370 | "Ireland_2020_lung\n", 371 | "Jansky_2021_othermodels\n", 372 | "Jerby-Arnon_2021_sarcoma\n", 373 | "Jordan_2016_breast\n", 374 | "Jordan_2016_othermodels\n", 375 | "Kumar_2022_othermodels\n", 376 | "Li_2019_skin\n", 377 | "Ligorio_2019_othermodels\n", 378 | "Mahuron_2020_skin\n", 379 | 
"Massalha_2020_othermodels\n", 380 | "Mercatelli_2021_othermodels\n", 381 | "Miyamoto_2015_othermodels\n", 382 | "Pal_2021_breast\n", 383 | "Pelka_2021_colorectal\n", 384 | "Raghavan_2021_othermodels\n", 385 | "Rendeiro_2020_hematologic\n", 386 | "Riether_2020_hematologic\n", 387 | "Sade-Feldman_2018_skin\n", 388 | "Savas_2018_breast\n", 389 | "Sharma_2020_liverbiliary\n", 390 | "Sun_2021_othermodels\n", 391 | "Tang-Huau_2018_ovarian\n", 392 | "Yao_2020_othermodels\n", 393 | "Zhang_2018_colorectal\n", 394 | "Zhang_2019_liverbiliary\n", 395 | "Zhang_2019_othermodels\n", 396 | "Zheng_2017_liverbiliary\n" 397 | ] 398 | } 399 | ], 400 | "source": [ 401 | "for idx, row in metadata.iterrows():\n", 402 | " if list(pathlib.Path(\"data/3ca\").glob(f\"{row['index']}*\")):\n", 403 | " pass \n", 404 | " else:\n", 405 | " print(idx)\n", 406 | " " 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": 27, 412 | "metadata": {}, 413 | "outputs": [ 414 | { 415 | "data": { 416 | "text/plain": [ 417 | "[]" 418 | ] 419 | }, 420 | "execution_count": 27, 421 | "metadata": {}, 422 | "output_type": "execute_result" 423 | } 424 | ], 425 | "source": [] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": null, 430 | "metadata": {}, 431 | "outputs": [], 432 | "source": [] 433 | } 434 | ], 435 | "metadata": { 436 | "kernelspec": { 437 | "display_name": "cancerfoundation", 438 | "language": "python", 439 | "name": "python3" 440 | }, 441 | "language_info": { 442 | "codemirror_mode": { 443 | "name": "ipython", 444 | "version": 3 445 | }, 446 | "file_extension": ".py", 447 | "mimetype": "text/x-python", 448 | "name": "python", 449 | "nbconvert_exporter": "python", 450 | "pygments_lexer": "ipython3", 451 | "version": "3.12.8" 452 | } 453 | }, 454 | "nbformat": 4, 455 | "nbformat_minor": 2 456 | } 457 | -------------------------------------------------------------------------------- /tutorial/README.md: 
-------------------------------------------------------------------------------- 1 | ## Tutorial on generating embeddings 2 | Download the scRNA-seq glioblastoma dataset [1] from [this link](https://polybox.ethz.ch/index.php/s/Q9hNm5oTgHh5tJW); `embeddings_tutorial.ipynb` uses it as a worked example of generating embeddings. 3 | 4 | 5 | ## References 6 | [1] Neftel, Cyril, Julie Laffy, Mariella G. Filbin, Toshiro Hara, Marni E. Shore, Gilbert J. Rahme, Alyssa R. Richman, et al. "An integrative model of cellular states, plasticity, and genetics for glioblastoma." _Cell_ 178, no. 4 (2019): 835–849. Elsevier. -------------------------------------------------------------------------------- /tutorial/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoevaLab/CancerFoundation/130375b3d0f6de09279df0f76cd2c0e4df9b89cf/tutorial/__init__.py -------------------------------------------------------------------------------- /tutorial/embeddings_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.insert(0, \"../\")\n", 11 | "import scanpy as sc\n", 12 | "from model.embedding import embed" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 3, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "model_dir = \"../model/assets\"\n", 22 | "adata_path = \"../data/neftel_ss2.h5ad\" # INSERT the path to your anndata object here\n", 23 | "\n", 24 | "adata = sc.read_h5ad(adata_path)\n", 25 | "batch_key = \"sample\" # The batch identity is used for highly variable gene selection\n", 26 | "bio_key = \"subtype\"" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "embed_adata = 
embed(\n", 36 | " adata_or_file=adata,\n", 37 | " model_dir=model_dir,\n", 38 | " batch_key=batch_key,\n", 39 | " batch_size=64,\n", 40 | ")\n" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "sc.pp.neighbors(embed_adata, use_rep=\"CancerGPT\")\n", 50 | "sc.tl.umap(embed_adata)\n", 51 | "fig = sc.pl.umap(embed_adata,\n", 52 | " color=[bio_key],\n", 53 | " frameon=False,\n", 54 | " palette=sc.pl.palettes.default_20,\n", 55 | " legend_loc=None,\n", 56 | " return_fig=True,\n", 57 | " title=[\"Subtype\"])" 58 | ] 59 | } 60 | ], 61 | "metadata": { 62 | "kernelspec": { 63 | "display_name": "cancergpt2", 64 | "language": "python", 65 | "name": "python3" 66 | }, 67 | "language_info": { 68 | "codemirror_mode": { 69 | "name": "ipython", 70 | "version": 3 71 | }, 72 | "file_extension": ".py", 73 | "mimetype": "text/x-python", 74 | "name": "python", 75 | "nbconvert_exporter": "python", 76 | "pygments_lexer": "ipython3", 77 | "version": "3.9.12" 78 | } 79 | }, 80 | "nbformat": 4, 81 | "nbformat_minor": 2 82 | } 83 | -------------------------------------------------------------------------------- /zero_shot_batch_integration/README.md: -------------------------------------------------------------------------------- 1 | ## Zero-shot batch integration 2 | 3 | ### Installation 4 | Set up the conda environment as described in the parent directory's README (`../README.md`). 5 | 6 | ### Step-by-step guide 7 | 8 | 1. **Download baseline data**: 9 | 10 | Download the scRNA-seq glioblastoma dataset [1] and the baseline embeddings from [this link](https://polybox.ethz.ch/index.php/s/ddfx3WfKLYkK712), then unzip the archive; `generate_embedding.py` expects the dataset at `./data/neftel_ss2.h5ad`. Alternatively, you can generate the baseline embeddings yourself by following the zero-shot integration tutorial in the scGPT GitHub repository [here](https://github.com/bowang-lab/scGPT/blob/main/tutorials/zero-shot/Tutorial_ZeroShot_Integration.ipynb). 11 | 12 | 2.
**Generate CancerGPT embeddings**: 13 | 14 | Run the following command to generate CancerGPT embeddings: 15 | ```bash 16 | python generate_embedding.py 17 | ``` 18 | 19 | 3. **Plot the results**: 20 | 21 | Refer to the `plot.ipynb` notebook to generate plots for the different embeddings. 22 | 23 | ## References 24 | [1] Neftel, Cyril, Julie Laffy, Mariella G. Filbin, Toshiro Hara, Marni E. Shore, Gilbert J. Rahme, Alyssa R. Richman, et al. "An integrative model of cellular states, plasticity, and genetics for glioblastoma." _Cell_ 178, no. 4 (2019): 835–849. Elsevier. -------------------------------------------------------------------------------- /zero_shot_batch_integration/generate_embedding.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "../") 3 | from model.embedding import embed 4 | import scanpy as sc 5 | import os 6 | 7 | model_dir = "../model/assets" 8 | adata_path = "./data/neftel_ss2.h5ad" 9 | 10 | adata = sc.read_h5ad(adata_path) 11 | batch_key = "sample" # The batch identity is used for highly variable gene selection 12 | bio_key = "subtype" 13 | 14 | embed_adata = embed( 15 | adata_or_file=adata, 16 | model_dir=model_dir, 17 | batch_key=batch_key, 18 | batch_size=64, 19 | ) 20 | os.makedirs("./data", exist_ok=True) 21 | embed_adata.write_h5ad("./data/CancerGPT_neftel_ss2.h5ad") 22 | --------------------------------------------------------------------------------
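The token ordering implemented in `GeneVocab._build_vocab_from_iterator` (`model/vocab.py`) can be illustrated without torchtext. The following is a plain-Python approximation, not the repository's code: `build_token_order` is a hypothetical helper that returns the token-to-index mapping the method would produce. Tokens are ranked by descending count with ties broken alphabetically, and the specials are placed first (or last) in the order given. Because a gene list normally contains each gene once, every count is 1, so the genes end up in alphabetical order after the specials.

```python
from collections import Counter
from typing import Dict, Iterable, List, Optional


def build_token_order(
    tokens: Iterable[str],
    specials: Optional[List[str]] = None,
    special_first: bool = True,
) -> Dict[str, int]:
    """Hypothetical sketch of the ordering used by GeneVocab._build_vocab_from_iterator."""
    counter = Counter(tokens)
    specials = list(specials or [])
    for tok in specials:
        counter.pop(tok, None)  # specials are positioned explicitly below

    # Sort alphabetically first, then stably by count (descending),
    # so tokens with equal counts keep their alphabetical order.
    ranked = sorted(counter.items(), key=lambda kv: kv[0])
    ranked.sort(key=lambda kv: kv[1], reverse=True)
    names = [tok for tok, _ in ranked]

    # Specials keep the order in which they were supplied.
    names = specials + names if special_first else names + specials
    return {tok: idx for idx, tok in enumerate(names)}


if __name__ == "__main__":
    # "[PAD]" and "[CLS]" are placeholder specials for illustration only.
    print(build_token_order(["TP53", "EGFR", "MYC", "EGFR"], specials=["[PAD]", "[CLS]"]))
    # -> {'[PAD]': 0, '[CLS]': 1, 'EGFR': 2, 'MYC': 3, 'TP53': 4}
```

Returning a plain dict mirrors the `token2idx` mapping that `GeneVocab.save_json` writes and `GeneVocab.from_dict` reads. The real method instead hands an `OrderedDict` to `torchtext.vocab.vocab`, which keeps insertion order.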
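The availability check in `training_data/preprocessing.ipynb` derives a short key from each study ID (for example, `Alchahin_2022_kidney` becomes `Alchahin_kidney`). It then prints the studies that have no file starting with that key under `data/3ca`. Below is a minimal stdlib sketch of the same logic, written without the empty `pass` branch. The function names are hypothetical, and it assumes, as the notebook does, that every ID has the form Author_year_tissue, so the tissue is the third underscore-separated field.

```python
from pathlib import Path
from typing import Iterable, List


def study_key(study_id: str) -> str:
    """Map "Author_year_tissue" to "Author_tissue", mirroring the notebook's index column."""
    parts = study_id.split("_")
    return parts[0] + "_" + parts[2]


def missing_studies(study_ids: Iterable[str], data_dir: Path) -> List[str]:
    """Return the study IDs that have no file in data_dir whose name starts with their key."""
    return [
        sid
        for sid in study_ids
        # any() stops at the first matching file instead of building a full list
        if not any(data_dir.glob(f"{study_key(sid)}*"))
    ]
```

In the notebook, this would be called after the `Source == "3ca"` and `Include == "Yes"` filter, e.g. `missing_studies(metadata.index, Path("data/3ca"))`. IDs containing spaces, such as `Ferrari de Andrade_2019_skin`, still split correctly because only underscores are used as separators.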