├── .gitignore
├── LICENSE
├── LangCell-CE-annotation
│   ├── fewshot.py
│   ├── finetune.py
│   └── utils.py
├── LangCell-annotation-fewshot
│   ├── fewshot.py
│   └── utils.py
├── LangCell-annotation-zeroshot
│   ├── utils.py
│   └── zero-shot.ipynb
├── README.md
├── assets
│   └── image.png
├── data_preprocess
│   ├── preprocess.py
│   └── utils.py
├── geneformer_001
│   ├── MANIFEST.in
│   ├── build
│   │   └── lib
│   │       └── geneformer
│   │           ├── __init__.py
│   │           ├── collator_for_classification.py
│   │           ├── emb_extractor.py
│   │           ├── gene_median_dictionary.pkl
│   │           ├── gene_name_id_dict.pkl
│   │           ├── in_silico_perturber.py
│   │           ├── in_silico_perturber_stats.py
│   │           ├── pretrainer.py
│   │           ├── token_dictionary.pkl
│   │           └── tokenizer.py
│   ├── geneformer.egg-info
│   │   ├── PKG-INFO
│   │   ├── SOURCES.txt
│   │   ├── dependency_links.txt
│   │   ├── requires.txt
│   │   └── top_level.txt
│   ├── geneformer
│   │   ├── __init__.py
│   │   ├── collator_for_classification.py
│   │   ├── emb_extractor.py
│   │   ├── gene_median_dictionary.pkl
│   │   ├── gene_name_id_dict.pkl
│   │   ├── in_silico_perturber.py
│   │   ├── in_silico_perturber_stats.py
│   │   ├── pretrainer.py
│   │   ├── token_dictionary.pkl
│   │   └── tokenizer.py
│   └── setup.py
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
1 | */__pycache__
2 | */ckpt
3 | */ckpts
4 | */data
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Suyuan Zhao
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/LangCell-CE-annotation/fewshot.py:
--------------------------------------------------------------------------------
1 | # %%
2 | # imports
3 | from collections import Counter
4 | import pickle
5 | import subprocess
6 | 
7 | from datasets import load_from_disk
8 | from sklearn.metrics import accuracy_score, f1_score
9 | from transformers import BertForSequenceClassification
10 | from transformers import Trainer
11 | from transformers.training_args import TrainingArguments
12 | import argparse
13 | from utils import LangCellDataCollatorForCellClassification as DataCollatorForCellClassification
14 | import os
15 | 
16 | # %%
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--model_path", type=str, default='/path/to/')
19 | parser.add_argument("--data_path", type=str, default='/path/to/')
20 | parser.add_argument("--output_path", type=str, default=None)
21 | parser.add_argument("--epochs", type=int, default=20)
22 | parser.add_argument("--nshot", type=int, default=1)
23 | parser.add_argument("--device", type=int, default=0)
24 | args = parser.parse_args()
25 | model_path = args.model_path
26 | data_path = args.data_path
27 | output_path = './output/' + model_path.split('/')[-1] + '/fewshot/' + data_path.split('/')[-1].split('.')[0] + '/' + str(args.nshot) + '/'
28 | output_path = args.output_path if args.output_path else output_path
29 | subprocess.call(f'mkdir -p {output_path}', shell=True)  # -p: the default output path is nested
30 | epochs = args.epochs
31 | nshot = args.nshot
32 | 
33 | GPU_NUMBER = [args.device]
34 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
35 | os.environ["NCCL_DEBUG"] = "INFO"
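# Note: `data_path` must point to a tokenized HuggingFace dataset saved with
# `save_to_disk` (see data_preprocess/preprocess.py), containing `input_ids`,
# `length`, and one label column ("celltype", "cell_type", "str_labels" or "labels").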
36 | # %%
37 | dataset = load_from_disk(data_path)
38 | 
39 | trainset_organ_shuffled = dataset.shuffle(seed=1)
40 | for label_name in ["celltype", "cell_type", "str_labels", "labels"]:
41 |     if label_name in trainset_organ_shuffled.column_names:
42 |         break
43 | trainset_organ_shuffled = trainset_organ_shuffled.rename_column(label_name, "label")
44 | target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
45 | target_name_id_dict = dict(zip(target_names, [i for i in range(len(target_names))]))
46 | 
47 | # change labels to numerical ids
48 | def classes_to_ids(example):
49 |     example["label"] = target_name_id_dict[example["label"]]
50 |     return example
51 | labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)
52 | 
53 | # n-shot trainset: greedily take the first `nshot` examples of each class
54 | label_num = len(target_name_id_dict.keys())
55 | type2trainlist = {}
56 | for i in range(label_num):
57 |     type2trainlist[i] = []
58 | if nshot >= 1:
59 |     for i, l in enumerate(labeled_trainset["label"]):
60 |         if len(type2trainlist[l]) < nshot:
61 |             type2trainlist[l].append(i)
62 |             br = True
63 |             for k in type2trainlist.keys():
64 |                 if len(type2trainlist[k]) < nshot:
65 |                     br = False
66 |                     break
67 |             if br:
68 |                 break
69 | train_idx = []
70 | for k in type2trainlist.keys():
71 |     train_idx += type2trainlist[k]
72 | test_idx = list(set(range(len(labeled_trainset))) - set(train_idx))
73 | 
74 | labeled_train_split = labeled_trainset.select(train_idx).shuffle(42)
75 | labeled_eval_split = labeled_trainset.select(test_idx).shuffle(42)
76 | labeled_eval_split_subset = labeled_eval_split
77 | 
78 | train_set = labeled_train_split
79 | eval_set = labeled_eval_split_subset
80 | label_name_id = target_name_id_dict
81 | # %%
82 | def compute_metrics(pred):
83 |     labels = pred.label_ids
84 |     preds = pred.predictions.argmax(-1)
85 |     # calculate accuracy and macro f1 using sklearn's function
86 |     acc = accuracy_score(labels, preds)
87 |     macro_f1 = f1_score(labels, preds, average='macro')
88 |     return {
89 |       'accuracy': acc,
90 |       'macro_f1': macro_f1
91 |     }
92 | 
93 | # %%
94 | # set model parameters
95 | # max input size
96 | max_input_size = 2 ** 11  # 2048
97 | max_lr = 5e-3
98 | freeze_layers = 0
99 | num_gpus = 1
100 | num_proc = 16
101 | geneformer_batch_size = 16
102 | epochs = epochs
103 | optimizer = "adamw"
104 | 
105 | # %%
106 | import torch.nn as nn
107 | organ_trainset = train_set
108 | organ_evalset = eval_set
109 | organ_label_dict = label_name_id
110 | 
111 | # set logging steps
112 | steps_per_epoch = round(len(organ_trainset)/geneformer_batch_size)
113 | if steps_per_epoch == 0:
114 |     steps_per_epoch = 1
115 | logging_steps = steps_per_epoch * 5
116 | 
117 | # reload pretrained model
118 | model = BertForSequenceClassification.from_pretrained(model_path,
119 |                                                       num_labels=len(organ_label_dict.keys()),
120 |                                                       output_attentions=False,
121 |                                                       output_hidden_states=False).cuda()
122 | # freeze the backbone: only the classifier head is trained
123 | for name, param in model.named_parameters():
124 |     if "classifier" not in name:
125 |         param.requires_grad = False
126 | 
127 | from geneformer import TranscriptomeTokenizer
128 | tk = TranscriptomeTokenizer()
129 | config = model.bert.config
130 | if config.vocab_size == len(tk.gene_token_dict) - 1:  # tokenizer has one extra token: grow the embedding matrix by one row
131 |     embedding_layer = nn.Embedding(config.vocab_size + 1, config.hidden_size, padding_idx=config.pad_token_id)
132 |     for param, param_pretrain in zip(embedding_layer.parameters(), model.bert.embeddings.word_embeddings.parameters()):
133 |         param.data[:-1] = param_pretrain.data
134 |     model.bert.embeddings.word_embeddings = embedding_layer
135 | elif config.vocab_size != len(tk.gene_token_dict):
136 |     raise Exception("Vocab size does not match.")
137 | 
138 | 
139 | # define output directory path
140 | output_dir = output_path
141 | 
142 | # ensure not overwriting previously saved model
143 | saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
144 | if os.path.isfile(saved_model_test):
145 |     raise Exception("Model already saved to this directory.")
146 | 
147 | # make output directory
148 | subprocess.call(f'mkdir -p {output_dir}', shell=True)
149 | 
150 | # set training arguments
151 | training_args = {
152 |     "learning_rate": max_lr,
153 |     "do_train": True,
154 |     "do_eval": True,
155 |     "evaluation_strategy": "steps",
156 |     "eval_steps": logging_steps,
157 |     "save_total_limit": 1,
158 |     "logging_steps": logging_steps,
159 |     "group_by_length": True,
160 |     "length_column_name": "length",
161 |     "disable_tqdm": False,
162 |     # "lr_scheduler_type": lr_schedule_fn,
163 |     # "warmup_steps": warmup_steps,
164 |     "weight_decay": 0.001,
165 |     "per_device_train_batch_size": geneformer_batch_size,
166 |     "per_device_eval_batch_size": geneformer_batch_size,
167 |     "num_train_epochs": epochs,
168 |     "load_best_model_at_end": False,
169 |     "output_dir": output_dir,
170 | }
171 | 
172 | training_args_init = TrainingArguments(**training_args)
173 | 
174 | # create the trainer
175 | trainer = Trainer(
176 |     model=model,
177 |     args=training_args_init,
178 |     data_collator=DataCollatorForCellClassification(),
179 |     train_dataset=organ_trainset,
180 |     eval_dataset=organ_evalset,
181 |     compute_metrics=compute_metrics
182 | )
183 | # train the cell type classifier
184 | 
185 | # %%
186 | trainer.train()
187 | predictions = trainer.predict(organ_evalset)
188 | with open(f"{output_dir}predictions.pickle", "wb") as fp:
189 |     pickle.dump(predictions, fp)
190 | trainer.save_metrics("eval", predictions.metrics)
191 | trainer.save_model(output_dir)
192 | 
193 | 
194 | 
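Example invocations for the two scripts in this directory (fewshot.py above, finetune.py below); the paths are placeholders and mirror the commands in the README's Usage section:
```
python fewshot.py --model_path /path/to/LangCell-CE --data_path /path/to/tokenized_dataset --nshot 5 --device 0
python finetune.py --model_path /path/to/LangCell-CE --data_path /path/to/tokenized_dataset --device 0
```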
--------------------------------------------------------------------------------
/LangCell-CE-annotation/finetune.py:
--------------------------------------------------------------------------------
1 | # %%
2 | # imports
3 | from collections import Counter
4 | import pickle
5 | import subprocess
6 | 
7 | from datasets import load_from_disk
8 | from sklearn.metrics import accuracy_score, f1_score
9 | from transformers import BertForSequenceClassification
10 | from transformers import Trainer
11 | from transformers.training_args import TrainingArguments
12 | import argparse
13 | from utils import LangCellDataCollatorForCellClassification as DataCollatorForCellClassification
14 | import os
15 | 
16 | # %%
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--model_path", type=str, default='/path/to/')
19 | parser.add_argument("--data_path", type=str, default='/path/to/')
20 | parser.add_argument("--output_path", type=str, default=None)
21 | parser.add_argument("--epochs", type=int, default=20)
22 | parser.add_argument("--test_size", type=float, default=0.33)
23 | parser.add_argument("--device", type=int, default=0)
24 | args = parser.parse_args()
25 | model_path = args.model_path
26 | data_path = args.data_path
27 | output_path = './output/' + model_path.split('/')[-1] + '/' + data_path.split('/')[-1].split('.')[0] + '/'
28 | output_path = args.output_path if args.output_path else output_path
29 | epochs = args.epochs
30 | test_size = args.test_size
31 | 
32 | GPU_NUMBER = [args.device]
33 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
34 | os.environ["NCCL_DEBUG"] = "INFO"
35 | # %%
36 | dataset = load_from_disk(data_path)
37 | 
38 | trainset_organ_shuffled = dataset.shuffle(seed=1)
39 | for label_name in ["celltype", "cell_type", "str_labels", "labels"]:
40 |     if label_name in trainset_organ_shuffled.column_names:
41 |         break
42 | trainset_organ_shuffled = trainset_organ_shuffled.rename_column(label_name, "label")
43 | target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
44 | target_name_id_dict = dict(zip(target_names, [i for i in range(len(target_names))]))
45 | 
46 | # change labels to numerical ids
47 | def classes_to_ids(example):
48 |     example["label"] = target_name_id_dict[example["label"]]
49 |     return example
50 | labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)
51 | train_size = round(len(labeled_trainset)*(1-test_size))
52 | labeled_train_split = labeled_trainset.select([i for i in range(0, train_size)])
53 | labeled_eval_split = labeled_trainset.select([i for i in range(train_size, len(labeled_trainset))])
54 | trained_labels = list(Counter(labeled_train_split["label"]).keys())
55 | 
56 | # evaluate only on labels that appear in the training split
57 | def if_trained_label(example):
58 |     return example["label"] in trained_labels
59 | labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)
60 | 
61 | train_set = labeled_train_split
62 | eval_set = labeled_eval_split_subset
63 | label_name_id = target_name_id_dict
64 | # %%
65 | def compute_metrics(pred):
66 |     labels = pred.label_ids
67 |     preds = pred.predictions.argmax(-1)
68 |     # calculate accuracy and macro f1 using sklearn's function
69 |     acc = accuracy_score(labels, preds)
70 |     macro_f1 = f1_score(labels, preds, average='macro')
71 |     return {
72 |       'accuracy': acc,
73 |       'macro_f1': macro_f1
74 |     }
75 | 
76 | # %%
77 | # set model parameters
78 | # max input size
79 | max_input_size = 2 ** 11  # 2048
80 | max_lr = 5e-5
81 | freeze_layers = 0
82 | num_gpus = 1
83 | num_proc = 16
84 | geneformer_batch_size = 16
85 | lr_schedule_fn = "linear"
86 | warmup_steps = 500
87 | epochs = epochs
88 | optimizer = "adamw"
89 | 
90 | # %%
91 | import torch.nn as nn
92 | organ_trainset = train_set
93 | organ_evalset = eval_set
94 | organ_label_dict = label_name_id
95 | 
96 | # set logging steps (at least 1, to guard against very small datasets)
97 | logging_steps = max(round(len(organ_trainset)/geneformer_batch_size/2), 1)
98 | 
99 | # reload pretrained model
100 | model = BertForSequenceClassification.from_pretrained(model_path,
101 |                                                       num_labels=len(organ_label_dict.keys()),
102 |                                                       output_attentions=False,
103 |                                                       output_hidden_states=False).cuda()
104 | 
105 | from geneformer import TranscriptomeTokenizer
106 | tk = TranscriptomeTokenizer()
107 | config = model.bert.config
108 | if config.vocab_size == len(tk.gene_token_dict) - 1:  # tokenizer has one extra token: grow the embedding matrix by one row
109 |     embedding_layer = nn.Embedding(config.vocab_size + 1, config.hidden_size, padding_idx=config.pad_token_id)
110 |     for param, param_pretrain in zip(embedding_layer.parameters(), model.bert.embeddings.word_embeddings.parameters()):
111 |         param.data[:-1] = param_pretrain.data
112 |     model.bert.embeddings.word_embeddings = embedding_layer
113 | elif config.vocab_size != len(tk.gene_token_dict):
114 |     raise Exception("Vocab size does not match.")
115 | 
116 | 
117 | # define output directory path
118 | output_dir = output_path
119 | 
120 | # ensure not overwriting previously saved model
121 | saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
122 | if os.path.isfile(saved_model_test):
123 |     raise Exception("Model already saved to this directory.")
124 | 
125 | # make output directory
126 | subprocess.call(f'mkdir -p {output_dir}', shell=True)
127 | 
128 | # set training arguments
129 | training_args = {
130 |     "learning_rate": max_lr,
131 |     "do_train": True,
132 |     "do_eval": True,
133 |     "evaluation_strategy": "epoch",
134 |     "save_total_limit": 1,
135 |     "logging_steps": logging_steps,
136 |     "group_by_length": True,
137 |     "length_column_name": "length",
138 |     "disable_tqdm": False,
139 |     "lr_scheduler_type": lr_schedule_fn,
140 |     "warmup_steps": warmup_steps,
141 |     "weight_decay": 0.001,
142 |     "per_device_train_batch_size": geneformer_batch_size,
143 |     "per_device_eval_batch_size": geneformer_batch_size,
144 |     "num_train_epochs": epochs,
145 |     "load_best_model_at_end": False,
146 |     "output_dir": output_dir,
147 | }
148 | 
149 | training_args_init = TrainingArguments(**training_args)
150 | 
151 | # create the trainer
152 | trainer = Trainer(
153 |     model=model,
154 |     args=training_args_init,
155 |     data_collator=DataCollatorForCellClassification(),
156 |     train_dataset=organ_trainset,
157 |     eval_dataset=organ_evalset,
158 |     compute_metrics=compute_metrics
159 | )
160 | # train the cell type classifier
161 | 
162 | # %%
163 | trainer.train()
164 | predictions = trainer.predict(organ_evalset)
165 | with open(f"{output_dir}predictions.pickle", "wb") as fp:
166 |     pickle.dump(predictions, fp)
167 | trainer.save_metrics("eval", predictions.metrics)
168 | trainer.save_model(output_dir)
169 | 
170 | 
--------------------------------------------------------------------------------
/LangCell-annotation-fewshot/fewshot.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | from datasets import load_from_disk
4 | import torch.nn as nn, torch.nn.functional as F
5 | import torch, json
6 | from transformers import BertTokenizer, BertModel
7 | from utils import BertModel as MedBertModel
8 | from utils import LangCellDataCollatorForCellClassification as DataCollatorForCellClassification
9 | from tqdm import tqdm
10 | from torch.utils.data import DataLoader
11 | import argparse
12 | import subprocess
13 | from sklearn.metrics import accuracy_score, f1_score
14 | 
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--data_path", type=str, default="/path/to/")
17 | parser.add_argument("--output_path", type=str, default=None)
18 | parser.add_argument("--epochs", type=int, default=10)
19 | parser.add_argument("--train_batchsize", type=int, default=12)
20 | parser.add_argument("--test_batchsize", type=int, default=64)
21 | parser.add_argument("--nshot", type=int, default=1)
22 | parser.add_argument("--seed", type=int, default=2024)
23 | parser.add_argument("--device", type=int, default=0)
24 | args = parser.parse_args()
25 | # model_path = args.model_path
26 | data_path = args.data_path
27 | epochs = args.epochs
28 | train_batchsize = args.train_batchsize
29 | test_batchsize = args.test_batchsize
30 | seed = args.seed
31 | nshot = args.nshot
32 | output_path = 'output/ctm_' + str(nshot) + '-shot/'
33 | output_path = args.output_path if args.output_path else output_path
34 | subprocess.call(f'mkdir -p {output_path}', shell=True)  # -p: the default output path is nested
35 | 
36 | GPU_NUMBER = [args.device]
37 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
38 | 
39 | # %%
40 | class Pooler(nn.Module):
41 |     def __init__(self, config, pretrained_proj, proj_dim):
42 |         super().__init__()
43 |         self.proj = nn.Linear(config.hidden_size, proj_dim)
44 |         self.proj.load_state_dict(torch.load(pretrained_proj))
45 |     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
46 |         pooled_output = hidden_states[:, 0]  # [CLS]-position output
47 |         pooled_output = F.normalize(self.proj(pooled_output), dim=-1)
48 |         return pooled_output
49 | 
50 | model = BertModel.from_pretrained('/path/to/')
51 | model.pooler = Pooler(model.config, pretrained_proj='/path/to/', proj_dim=256)
52 | proj = model.pooler.proj
53 | # model = model.module
54 | model = model.to("cuda")
55 | 
56 | text_pretrained_model = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'
57 | tokenizer = BertTokenizer.from_pretrained(text_pretrained_model)
58 | tokenizer.add_special_tokens({'bos_token':'[DEC]'})
59 | tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
60 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
61 | text_encoder = MedBertModel.from_pretrained('/path/to/', add_pooling_layer=True)
62 | text_encoder.pooler = Pooler(text_encoder.config, pretrained_proj='/path/to/', proj_dim=256)
63 | text_encoder = text_encoder.to("cuda")
64 | 
65 | ctm_head = nn.Linear(text_encoder.config.hidden_size, 2)
66 | ctm_head.load_state_dict(torch.load('/path/to/'))
67 | ctm_head = ctm_head.to("cuda")
68 | 
69 | def text_encode(text):
70 |     text = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt').to('cuda')
71 |     text = text_encoder(**text).pooler_output
72 |     # text = F.normalize(model.text_projector(text))
73 |     return text
74 | 
75 | def cell_encode(cell_input_ids, cell_atts):
76 |     cell = model(cell_input_ids.to("cuda"), cell_atts.to("cuda"))
77 |     cell_last_h = cell.last_hidden_state
78 |     cell_pooler = cell.pooler_output
79 |     return cell_last_h, cell_pooler
80 | 
81 | def ctm(text, cell_emb, cell_atts):
82 |     # n texts, n cells -> n matching scores
83 |     text = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt').to('cuda')
84 |     output = text_encoder(**text,
85 |                           encoder_hidden_states = cell_emb.to("cuda"),
86 |                           encoder_attention_mask = cell_atts.to("cuda"),
87 |                           return_dict = True,
88 |                           mode = 'multimodal',
89 |                           )
90 |     logits = ctm_head(output.last_hidden_state[:, 0, :])
91 |     logits = F.softmax(logits, dim=-1)[..., 1]  # [n], probability of "match"
92 |     return logits
93 | 
94 | # %%
95 | dataset = load_from_disk(data_path)
96 | dataset_sub = dataset.shuffle(seed)#.select(range(300))
97 | for label_name in ["celltype", "cell_type", "str_labels", "labels"]:
98 |     if label_name in dataset_sub.column_names:
99 |         break
100 | if label_name != "celltype":
101 |     dataset_sub = dataset_sub.rename_column(label_name, "celltype")
102 | dataset_sub = dataset_sub.filter(lambda example: example['celltype'] != 'Other')
103 | 
104 | ontology = json.load(open('/path/to/'))
105 | name2id = {ontology[id]['name']: id for id in ontology}
106 | def gettextfromname(name, description=True):
107 |     ontology_name = name.lower()
108 |     id = name2id[ontology_name]
109 |     s = "cell type: "
110 |     s += ontology[id]['name'] + '. '
111 |     if description:
112 |         if ontology[id]['def'] != []:
113 |             s += ontology[id]['def'] + '; '
114 |     return s
115 | 
116 | types = list(set(dataset_sub['celltype']))
117 | texts = [gettextfromname(typename) for typename in types]
118 | type2num = dict([(type, i) for i, type in enumerate(types)])
119 | 
120 | # %%
121 | def classes_to_ids(example):
122 |     example["label"] = type2num[example["celltype"]]
123 |     return example
124 | dataset = dataset_sub.map(classes_to_ids, num_proc=16)
125 | dataset = dataset.remove_columns(['celltype', 'length'])
126 | 
127 | # split: greedily take the first `nshot` examples of each class for training
128 | label_num = len(type2num.keys())
129 | type2trainlist = {}
130 | for i in range(label_num):
131 |     type2trainlist[i] = []
132 | if nshot >= 1:
133 |     for i, l in enumerate(dataset["label"]):
134 |         if len(type2trainlist[l]) < nshot:
135 |             type2trainlist[l].append(i)
136 |             br = True
137 |             for k in type2trainlist.keys():
138 |                 if len(type2trainlist[k]) < nshot:
139 |                     br = False
140 |                     break
141 |             if br:
142 |                 break
143 | train_idx = []
144 | for k in type2trainlist.keys():
145 |     train_idx += type2trainlist[k]
146 | test_idx = list(set(range(len(dataset))) - set(train_idx))
147 | 
148 | traindataset = dataset.select(train_idx).shuffle(seed)
149 | testdataset = dataset.select(test_idx).shuffle(seed)
150 | # traindataset, train_ind = extract_data_based_on_class(dataset, train_cls)
151 | # testdataset, test_ind = extract_data_based_on_class(dataset, test_cls)
152 | 
153 | # train_batchsize, test_batchsize = 4, 64
154 | eval_num = 500
155 | train_loader = DataLoader(traindataset, batch_size=train_batchsize,
156 |                           collate_fn=DataCollatorForCellClassification(), shuffle=False)
157 | test_loader = DataLoader(testdataset, batch_size=test_batchsize,
158 |                          collate_fn=DataCollatorForCellClassification(), shuffle=False)
159 | eval_loader = DataLoader(testdataset.select(range(eval_num)), batch_size=test_batchsize,
160 |                          collate_fn=DataCollatorForCellClassification(), shuffle=False)
161 | 
162 | # %%
163 | model.train()
164 | text_encoder.train()
165 | loss_fn = nn.CrossEntropyLoss()
166 | optimizer1 = torch.optim.Adam(model.parameters(), lr=1e-5)
167 | optimizer2 = torch.optim.Adam(text_encoder.parameters(), lr=1e-5)
168 | 
169 | for epoch in range(epochs):
170 |     print('epoch:', epoch)
171 |     for i, d in tqdm(enumerate(train_loader)):
172 |         model.train()
173 |         text_encoder.train()
174 |         text_embs = torch.cat([text_encode(text) for text in texts], 0).T.cuda()
175 |         cell_last_h, cellemb = cell_encode(d['input_ids'], d['attention_mask'])  # batchsize * 256
176 |         # text_embs: 256 * class_num
177 |         sim = (cellemb @ text_embs) / 0.05  # batchsize * class_num, temperature 0.05
178 |         loss_sim = loss_fn(sim, d['labels'].cuda())
179 | 
180 |         ctm_logit = torch.zeros_like(sim)
181 |         for text_idx, text in enumerate(texts):
182 |             text_list = [text] * sim.shape[0]
183 |             ctm_logit[:, text_idx] = ctm(text_list, cell_last_h, d['attention_mask'])
184 |         loss_ctm = loss_fn(ctm_logit, d['labels'].cuda())
185 | 
186 |         loss = loss_sim + loss_ctm
187 |         optimizer1.zero_grad()
188 |         optimizer2.zero_grad()
189 |         loss.backward()
190 |         optimizer1.step()
191 |         optimizer2.step()
192 | 
193 | 
194 | # %%
195 | cell_embs = torch.zeros(len(testdataset), 256)
196 | model.eval()
197 | text_encoder.eval()
198 | preds = torch.zeros(len(testdataset))
199 | sim_logits = torch.zeros(len(testdataset), text_embs.shape[-1])
200 | ctm_logits = torch.zeros(len(testdataset), text_embs.shape[-1])
201 | logits = torch.zeros(len(testdataset), text_embs.shape[-1])
202 | labels = torch.tensor(testdataset['label'])
203 | text_embs = torch.cat([text_encode(text) for text in texts], 0).T.cuda()
204 | with torch.no_grad():
205 |     for i, d in tqdm(enumerate(test_loader)):
206 |         cell_last_h, cellemb = cell_encode(d['input_ids'], d['attention_mask'])  # batchsize * 256
207 |         sim = (cellemb @ text_embs) / 0.05  # batchsize * class_num
208 |         sim_logit = F.softmax(sim, dim=-1)
209 | 
210 |         # ctm
211 |         ctm_logit = torch.zeros_like(sim_logit)
212 |         for text_idx, text in enumerate(texts):
213 |             text_list = [text] * sim_logit.shape[0]
214 |             ctm_logit[:, text_idx] = ctm(text_list, cell_last_h, d['attention_mask'])
215 |         ctm_logit = F.softmax(ctm_logit, dim=-1)
216 | 
217 |         sim_logits[i * test_batchsize: (i + 1) * test_batchsize] = sim_logit.cpu()
218 |         ctm_logits[i * test_batchsize: (i + 1) * test_batchsize] = ctm_logit.cpu()
219 |         logit = (sim_logit + ctm_logit) / 2
220 |         pred = logit.argmax(dim=-1)
221 |         logits[i * test_batchsize: (i + 1) * test_batchsize] = logit.cpu()
222 |         cell_embs[i * test_batchsize: (i + 1) * test_batchsize] = cellemb.cpu()
223 |         preds[i * test_batchsize: (i + 1) * test_batchsize] = pred.cpu()
224 | 
225 | torch.save({'cell_embs': cell_embs, 'text_embs': text_embs,
226 |             'sim_logits': sim_logits, 'ctm_logits': ctm_logits,
227 |             'preds': preds, 'labels': labels, 'logits': logits},
228 |            output_path + 'result.pt')
229 | 
230 | # %%
231 | 
232 | from sklearn.metrics import f1_score, accuracy_score
233 | 
234 | for k in [0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]:
235 |     preds_k = (k * sim_logits + (1 - k) * ctm_logits).argmax(dim=-1)
236 |     print(k, '\n', accuracy_score(labels, preds_k), f1_score(labels, preds_k, average='macro'))
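# Note: the test-time prediction above corresponds to k = 0.5, i.e. a simple
# average of the similarity head and the cell-text matching (CTM) head. The
# sweep over k explores other mixtures; if it is used for model selection, k
# should be chosen on a held-out validation split rather than on the test set.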
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LangCell: Language-Cell Pre-training for Cell Identity Understanding
2 | 
3 | 
4 | Cell identity encompasses various semantic aspects of a cell, including cell type, pathway information, disease information, and more, which are essential for biologists to gain insights into its biological characteristics. Understanding cell identity from transcriptomic data, for example by annotating cell types, has become an important task in bioinformatics.
5 | Because these semantic aspects are determined by human experts, AI models cannot effectively carry out cell identity understanding tasks without supervision signals provided by pairs of single-cell data and labels.
6 | The single-cell pre-trained language models (PLMs) currently used for this task are trained on only a single modality, transcriptomic data, and lack an understanding of cell identity knowledge. As a result, they have to be fine-tuned for downstream tasks and struggle when labeled data with the desired semantic labels are lacking.
7 | To address this issue, we propose an innovative solution that constructs a unified representation of single-cell data and natural language during the pre-training phase, allowing the model to directly incorporate insights related to cell identity.
8 | More specifically, we introduce **LangCell**, the first **Lang**uage-**Cell** pre-training framework.
9 | LangCell utilizes texts enriched with cell identity information to gain a profound comprehension of cross-modal knowledge.
10 | Results from experiments conducted on different benchmarks show that LangCell is the only single-cell PLM that works effectively in zero-shot cell identity understanding scenarios, and it also significantly outperforms existing models in few-shot and fine-tuning scenarios.
11 | 
12 | More information can be found at [https://arxiv.org/abs/2405.06708](https://arxiv.org/abs/2405.06708).
13 | 
14 | LangCell will soon be added to the OpenBioMed toolkit: [https://github.com/PharMolix/OpenBioMed](https://github.com/PharMolix/OpenBioMed).
15 | 
16 | ![LangCell](assets/image.png)
17 | 
18 | # News
19 | - [2024/12/30] Released the pre-training dataset [**scLibrary**](https://huggingface.co/datasets/Toycat/scLibrary/tree/main)
20 | 
21 | # Install
22 | 
23 | [![python >3.9.18](https://img.shields.io/badge/python-3.9.18-brightgreen)](https://www.python.org/)
24 | ```
25 | pip install -r requirements.txt
26 | ```
27 | 
28 | # Checkpoint
29 | 
30 | The model's checkpoint is divided into five modules: text_bert, cell_bert, text_proj, cell_proj, and ctm_head. Users can select and load only the modules that their downstream task requires. Among them, cell_bert is a standard Huggingface BertModel; text_bert is a multifunctional encoder provided in utils.py; cell_proj and text_proj are linear layers that map the model outputs at the [CLS] position of cells and texts into a unified feature space; and ctm_head is a linear layer that maps the output of text_bert to matching scores for Cell-Text Matching. For specific loading methods, please refer to the usage in `LangCell-annotation-zeroshot/zero-shot.ipynb`, or see the sketch after the download link below.
31 | 
32 | [Download checkpoint](https://drive.google.com/drive/folders/1cuhVG9v0YoAnjW-t_WMpQQguajumCBTp)
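A minimal loading sketch for the five modules (the individual file names under the checkpoint folder and the 256-dimensional projection size are assumptions mirroring `LangCell-annotation-fewshot/fewshot.py`; adjust them to the actual checkpoint contents):
```
import torch
import torch.nn as nn
from transformers import BertModel
from utils import BertModel as MedBertModel  # multifunctional text encoder from utils.py

cell_bert = BertModel.from_pretrained('/path/to/cell_bert')  # standard Huggingface BertModel
text_bert = MedBertModel.from_pretrained('/path/to/text_bert', add_pooling_layer=True)

# linear layers mapping the [CLS]-position outputs into the unified feature space
cell_proj = nn.Linear(cell_bert.config.hidden_size, 256)
cell_proj.load_state_dict(torch.load('/path/to/cell_proj.bin'))
text_proj = nn.Linear(text_bert.config.hidden_size, 256)
text_proj.load_state_dict(torch.load('/path/to/text_proj.bin'))

# linear layer mapping text_bert's output to (no-match, match) scores for Cell-Text Matching
ctm_head = nn.Linear(text_bert.config.hidden_size, 2)
ctm_head.load_state_dict(torch.load('/path/to/ctm_head.bin'))
```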
33 | 
34 | # Pre-training Dataset
35 | 
36 | We constructed a cell-text dataset, **scLibrary**, containing 27.5 million scRNA-seq entries along with their descriptions. Specifically, we obtained raw scRNA-seq data and corresponding metadata from CELLxGENE. We selected eight critical aspects of cell identity that could contain essential insights, including cell type, developmental stage, and disease information, to obtain descriptions that are as comprehensive as possible from the Open Biological and Biomedical Ontology Foundry (OBO Foundry).
37 | 
38 | [Download scLibrary](https://huggingface.co/datasets/Toycat/scLibrary/tree/main)
39 | 
40 | # Usage
41 | 
42 | - **Data preprocess**
43 | Similar to the example in `data_preprocess/preprocess.py`, you can use `scanpy` to read any single-cell data and process it into a format accepted by the model. The processing method is similar to `Geneformer`. For more detailed instructions, please refer to [Geneformer's tokenizing scRNAseq data example](https://huggingface.co/ctheodoris/Geneformer/blob/main/examples/tokenizing_scRNAseq_data.ipynb).
44 | 
45 | 
46 | - **LangCell zero-shot cell type annotation**
47 | We strongly recommend that users unfamiliar with LangCell start by experiencing this core task to quickly understand the features and usage of LangCell. We have prepared a [demo dataset](https://drive.google.com/drive/folders/1cuhVG9v0YoAnjW-t_WMpQQguajumCBTp?usp=sharing) for this task; you just need to download the dataset and run `LangCell-annotation-zeroshot/zero-shot.ipynb`. A sketch of the scoring the notebook performs is shown below.
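A hedged sketch of the zero-shot scoring, reusing `text_encode`, `cell_encode`, and `ctm` as defined in `LangCell-annotation-fewshot/fewshot.py` (the 0.05 temperature and the equal-weight average of the two heads are taken from that script; `input_ids`/`attention_mask` stand for a batch of tokenized cells):
```
import torch
import torch.nn.functional as F

# candidate cell-type descriptions, e.g. built from obo.json with gettextfromname
texts = ["cell type: b cell. ...", "cell type: t cell. ..."]
text_embs = torch.cat([text_encode(t) for t in texts], 0).T     # 256 x class_num

cell_last_h, cell_emb = cell_encode(input_ids, attention_mask)  # batchsize x 256
sim_logit = F.softmax((cell_emb @ text_embs) / 0.05, dim=-1)    # similarity head

ctm_logit = torch.zeros_like(sim_logit)                         # cell-text matching head
for j, t in enumerate(texts):
    ctm_logit[:, j] = ctm([t] * sim_logit.shape[0], cell_last_h, attention_mask)
ctm_logit = F.softmax(ctm_logit, dim=-1)

pred = ((sim_logit + ctm_logit) / 2).argmax(dim=-1)             # average the two heads
```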
48 | 
49 | - **LangCell few-shot cell type annotation**
50 | LangCell's performance can be further enhanced by few-shot training on a very small amount of labeled data. You can run the code using the following commands:
51 | ```
52 | cd LangCell-annotation-fewshot/
53 | 
54 | python fewshot.py --data_path [data_path] --model_path [model_path] --nshot [nshot] --device [device]
55 | ```
56 | 
57 | - **LangCell-CE cell type annotation**
58 | Experiments show that fine-tuning only LangCell's Cell Encoder (LangCell-CE) can also achieve excellent performance on downstream tasks. You can run fine-tuning and few-shot experiments with LangCell-CE using the following commands:
59 | ```
60 | cd LangCell-CE-annotation/
61 | 
62 | python finetune.py --data_path [data_path] --model_path [model_path] --device [device]
63 | 
64 | python fewshot.py --data_path [data_path] --model_path [model_path] --nshot [nshot] --device [device]
65 | ```
66 | 
67 | - **Textual descriptions of cell identities**
68 | We have uploaded the OBO Foundry file "obo.json" [here](https://drive.google.com/drive/folders/1cuhVG9v0YoAnjW-t_WMpQQguajumCBTp), which contains textual descriptions of common cell identities. You can use these as examples to write textual descriptions for new cell types.
69 | 
70 | - **We will update more experimental code for LangCell in the future.**
71 | 
72 | 
73 | # Citation
74 | If you find LangCell helpful to your research, please consider giving this repository a 🌟star and 📎citing the following article. Thank you for your support!
75 | ```
76 | @misc{zhao2024langcell,
77 |       title={LangCell: Language-Cell Pre-training for Cell Identity Understanding}, 
78 |       author={Suyuan Zhao and Jiahuan Zhang and Yizhen Luo and Yushuai Wu and Zaiqing Nie},
79 |       year={2024},
80 |       eprint={2405.06708},
81 |       archivePrefix={arXiv},
82 |       primaryClass={q-bio.GN}
83 | }
84 | ```
85 | 
--------------------------------------------------------------------------------
/assets/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PharMolix/LangCell/69e41ef2fae485b67703294b50f3150bb1d5bb9b/assets/image.png
--------------------------------------------------------------------------------
/data_preprocess/preprocess.py:
--------------------------------------------------------------------------------
1 | from utils import LangCellTranscriptomeTokenizer
2 | import scanpy as sc
3 | 
4 | data = sc.read_h5ad('/path/to/adata.h5ad')
5 | data.obs['n_counts'] = data.X.sum(axis=1)  # total counts per cell, required by the tokenizer
6 | data.var['ensembl_id'] = data.var['feature_id']  # the tokenizer expects Ensembl gene ids in `ensembl_id`
7 | 
8 | tk = LangCellTranscriptomeTokenizer(dict([(k, k) for k in data.obs.keys()]), nproc=4)
9 | tokenized_cells, cell_metadata = tk.tokenize_anndata(data)
10 | tokenized_dataset = tk.create_dataset(tokenized_cells, cell_metadata)
11 | 
12 | tokenized_dataset.save_to_disk('/path/to/tokenized_dataset')
--------------------------------------------------------------------------------
/geneformer_001/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include geneformer/gene_median_dictionary.pkl
2 | include geneformer/token_dictionary.pkl
3 | include geneformer/gene_name_id_dict.pkl
--------------------------------------------------------------------------------
/geneformer_001/build/lib/geneformer/__init__.py:
--------------------------------------------------------------------------------
1 | from . import tokenizer
2 | from . import pretrainer
3 | from . import collator_for_classification
4 | from . import in_silico_perturber
5 | from . import in_silico_perturber_stats
6 | from .tokenizer import TranscriptomeTokenizer
7 | from .pretrainer import GeneformerPretrainer
8 | from .collator_for_classification import DataCollatorForGeneClassification
9 | from .collator_for_classification import DataCollatorForCellClassification
10 | from .emb_extractor import EmbExtractor
11 | from .in_silico_perturber import InSilicoPerturber
12 | from .in_silico_perturber_stats import InSilicoPerturberStats
--------------------------------------------------------------------------------
/geneformer_001/build/lib/geneformer/collator_for_classification.py:
--------------------------------------------------------------------------------
1 | """
2 | Geneformer collator for gene and cell classification.
3 | 
4 | Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
5 | """ 6 | import numpy as np 7 | import torch 8 | import warnings 9 | from enum import Enum 10 | from typing import Dict, List, Optional, Union 11 | 12 | from transformers import ( 13 | DataCollatorForTokenClassification, 14 | SpecialTokensMixin, 15 | BatchEncoding, 16 | ) 17 | from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj 18 | from transformers.utils.generic import _is_tensorflow, _is_torch 19 | 20 | from .pretrainer import token_dictionary 21 | 22 | EncodedInput = List[int] 23 | logger = logging.get_logger(__name__) 24 | VERY_LARGE_INTEGER = int( 25 | 1e30 26 | ) # This is used to set the max input length for a model with infinite size input 27 | LARGE_INTEGER = int( 28 | 1e20 29 | ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER 30 | 31 | # precollator functions 32 | 33 | class ExplicitEnum(Enum): 34 | """ 35 | Enum with more explicit error message for missing values. 36 | """ 37 | 38 | @classmethod 39 | def _missing_(cls, value): 40 | raise ValueError( 41 | "%r is not a valid %s, please select one of %s" 42 | % (value, cls.__name__, str(list(cls._value2member_map_.keys()))) 43 | ) 44 | 45 | class TruncationStrategy(ExplicitEnum): 46 | """ 47 | Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for 48 | tab-completion in an IDE. 49 | """ 50 | 51 | ONLY_FIRST = "only_first" 52 | ONLY_SECOND = "only_second" 53 | LONGEST_FIRST = "longest_first" 54 | DO_NOT_TRUNCATE = "do_not_truncate" 55 | 56 | 57 | 58 | class PaddingStrategy(ExplicitEnum): 59 | """ 60 | Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion 61 | in an IDE. 62 | """ 63 | 64 | LONGEST = "longest" 65 | MAX_LENGTH = "max_length" 66 | DO_NOT_PAD = "do_not_pad" 67 | 68 | 69 | 70 | class TensorType(ExplicitEnum): 71 | """ 72 | Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for 73 | tab-completion in an IDE. 74 | """ 75 | 76 | PYTORCH = "pt" 77 | TENSORFLOW = "tf" 78 | NUMPY = "np" 79 | JAX = "jax" 80 | 81 | 82 | class PrecollatorForGeneAndCellClassification(SpecialTokensMixin): 83 | mask_token = "" 84 | mask_token_id = token_dictionary.get("") 85 | pad_token = "" 86 | pad_token_id = token_dictionary.get("") 87 | padding_side = "right" 88 | all_special_ids = [ 89 | token_dictionary.get(""), 90 | token_dictionary.get("") 91 | ] 92 | model_input_names = ["input_ids"] 93 | 94 | def _get_padding_truncation_strategies( 95 | self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs 96 | ): 97 | """ 98 | Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy 99 | and pad_to_max_length) and behaviors. 
100 | """ 101 | old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate") 102 | old_pad_to_max_length = kwargs.pop("pad_to_max_length", False) 103 | 104 | # Backward compatibility for previous behavior, maybe we should deprecate it: 105 | # If you only set max_length, it activates truncation for max_length 106 | if max_length is not None and padding is False and truncation is False: 107 | if verbose: 108 | if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False): 109 | logger.warning( 110 | "Truncation was not explicitly activated but `max_length` is provided a specific value, " 111 | "please use `truncation=True` to explicitly truncate examples to max length. " 112 | "Defaulting to 'longest_first' truncation strategy. " 113 | "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy " 114 | "more precisely by providing a specific strategy to `truncation`." 115 | ) 116 | self.deprecation_warnings["Truncation-not-explicitly-activated"] = True 117 | truncation = "longest_first" 118 | 119 | # Get padding strategy 120 | if padding is False and old_pad_to_max_length: 121 | if verbose: 122 | warnings.warn( 123 | "The `pad_to_max_length` argument is deprecated and will be removed in a future version, " 124 | "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or " 125 | "use `padding='max_length'` to pad to a max length. In this case, you can give a specific " 126 | "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the " 127 | "maximal input size of the model (e.g. 512 for Bert).", 128 | FutureWarning, 129 | ) 130 | if max_length is None: 131 | padding_strategy = PaddingStrategy.LONGEST 132 | else: 133 | padding_strategy = PaddingStrategy.MAX_LENGTH 134 | elif padding is not False: 135 | if padding is True: 136 | padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch 137 | elif not isinstance(padding, PaddingStrategy): 138 | padding_strategy = PaddingStrategy(padding) 139 | elif isinstance(padding, PaddingStrategy): 140 | padding_strategy = padding 141 | else: 142 | padding_strategy = PaddingStrategy.DO_NOT_PAD 143 | 144 | # Get truncation strategy 145 | if truncation is False and old_truncation_strategy != "do_not_truncate": 146 | if verbose: 147 | warnings.warn( 148 | "The `truncation_strategy` argument is deprecated and will be removed in a future version, " 149 | "use `truncation=True` to truncate examples to a max length. You can give a specific " 150 | "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the " 151 | "maximal input size of the model (e.g. 512 for Bert). 
" 152 | " If you have pairs of inputs, you can give a specific truncation strategy selected among " 153 | "`truncation='only_first'` (will only truncate the first sentence in the pairs) " 154 | "`truncation='only_second'` (will only truncate the second sentence in the pairs) " 155 | "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).", 156 | FutureWarning, 157 | ) 158 | truncation_strategy = TruncationStrategy(old_truncation_strategy) 159 | elif truncation is not False: 160 | if truncation is True: 161 | truncation_strategy = ( 162 | TruncationStrategy.LONGEST_FIRST 163 | ) # Default to truncate the longest sequences in pairs of inputs 164 | elif not isinstance(truncation, TruncationStrategy): 165 | truncation_strategy = TruncationStrategy(truncation) 166 | elif isinstance(truncation, TruncationStrategy): 167 | truncation_strategy = truncation 168 | else: 169 | truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE 170 | 171 | # Set max length if needed 172 | if max_length is None: 173 | if padding_strategy == PaddingStrategy.MAX_LENGTH: 174 | if self.model_max_length > LARGE_INTEGER: 175 | if verbose: 176 | if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False): 177 | logger.warning( 178 | "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. " 179 | "Default to no padding." 180 | ) 181 | self.deprecation_warnings["Asking-to-pad-to-max_length"] = True 182 | padding_strategy = PaddingStrategy.DO_NOT_PAD 183 | else: 184 | max_length = self.model_max_length 185 | 186 | if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE: 187 | if self.model_max_length > LARGE_INTEGER: 188 | if verbose: 189 | if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False): 190 | logger.warning( 191 | "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. " 192 | "Default to no truncation." 193 | ) 194 | self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True 195 | truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE 196 | else: 197 | max_length = self.model_max_length 198 | 199 | # Test if we have a padding token 200 | if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0): 201 | raise ValueError( 202 | "Asking to pad but the tokenizer does not have a padding token. " 203 | "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` " 204 | "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`." 205 | ) 206 | 207 | # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided 208 | if ( 209 | truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE 210 | and padding_strategy != PaddingStrategy.DO_NOT_PAD 211 | and pad_to_multiple_of is not None 212 | and max_length is not None 213 | and (max_length % pad_to_multiple_of != 0) 214 | ): 215 | raise ValueError( 216 | f"Truncation and padding are both activated but " 217 | f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})." 
218 | ) 219 | 220 | return padding_strategy, truncation_strategy, max_length, kwargs 221 | 222 | def pad( 223 | self, 224 | encoded_inputs: Union[ 225 | BatchEncoding, 226 | List[BatchEncoding], 227 | Dict[str, EncodedInput], 228 | Dict[str, List[EncodedInput]], 229 | List[Dict[str, EncodedInput]], 230 | ], 231 | class_type, # options: "gene" or "cell" 232 | padding: Union[bool, str, PaddingStrategy] = True, 233 | max_length: Optional[int] = None, 234 | pad_to_multiple_of: Optional[int] = None, 235 | return_attention_mask: Optional[bool] = True, 236 | return_tensors: Optional[Union[str, TensorType]] = None, 237 | verbose: bool = True, 238 | ) -> BatchEncoding: 239 | """ 240 | Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length 241 | in the batch. 242 | 243 | Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``, 244 | ``self.pad_token_id`` and ``self.pad_token_type_id``) 245 | 246 | .. note:: 247 | 248 | If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the 249 | result will use the same type unless you provide a different tensor type with ``return_tensors``. In the 250 | case of PyTorch tensors, you will lose the specific device of your tensors however. 251 | 252 | Args: 253 | encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`): 254 | Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str, 255 | List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str, 256 | List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as 257 | well as in a PyTorch Dataloader collate function. 258 | 259 | Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), 260 | see the note above for the return type. 261 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): 262 | Select a strategy to pad the returned sequences (according to the model's padding side and padding 263 | index) among: 264 | 265 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a 266 | single sequence if provided). 267 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 268 | maximum acceptable input length for the model if that argument is not provided. 269 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 270 | different lengths). 271 | max_length (:obj:`int`, `optional`): 272 | Maximum length of the returned list and optionally padding length (see above). 273 | pad_to_multiple_of (:obj:`int`, `optional`): 274 | If set will pad the sequence to a multiple of the provided value. 275 | 276 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability 277 | >= 7.5 (Volta). 278 | return_attention_mask (:obj:`bool`, `optional`): 279 | Whether to return the attention mask. If left to the default, will return the attention mask according 280 | to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. 281 | 282 | `What are attention masks? 
<../glossary.html#attention-mask>`__
283 |             return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
284 |                 If set, will return tensors instead of list of python integers. Acceptable values are:
285 | 
286 |                 * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
287 |                 * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
288 |                 * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
289 |             verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
290 |                 Whether or not to print more information and warnings.
291 |         """
292 |         # If we have a list of dicts, let's convert it in a dict of lists
293 |         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
294 |         if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
295 |             encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
296 | 
297 |         # The model's main input name, usually `input_ids`, has to be passed for padding
298 |         if self.model_input_names[0] not in encoded_inputs:
299 |             raise ValueError(
300 |                 "You should supply an encoding or a list of encodings to this method "
301 |                 f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
302 |             )
303 | 
304 |         required_input = encoded_inputs[self.model_input_names[0]]
305 | 
306 |         if not required_input:
307 |             if return_attention_mask:
308 |                 encoded_inputs["attention_mask"] = []
309 |             return encoded_inputs
310 | 
311 |         # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
312 |         # and rebuild them afterwards if no return_tensors is specified
313 |         # Note that we lose the specific device the tensor may be on for PyTorch
314 | 
315 |         first_element = required_input[0]
316 |         if isinstance(first_element, (list, tuple)):
317 |             # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
318 |             index = 0
319 |             while len(required_input[index]) == 0:
320 |                 index += 1
321 |             if index < len(required_input):
322 |                 first_element = required_input[index][0]
323 |         # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
324 |         if not isinstance(first_element, (int, list, tuple)):
325 |             if is_tf_available() and _is_tensorflow(first_element):
326 |                 return_tensors = "tf" if return_tensors is None else return_tensors
327 |             elif is_torch_available() and _is_torch(first_element):
328 |                 return_tensors = "pt" if return_tensors is None else return_tensors
329 |             elif isinstance(first_element, np.ndarray):
330 |                 return_tensors = "np" if return_tensors is None else return_tensors
331 |             else:
332 |                 raise ValueError(
333 |                     f"type of {first_element} unknown: {type(first_element)}. "
334 |                     f"Should be one of a python, numpy, pytorch or tensorflow object."
335 | ) 336 | 337 | for key, value in encoded_inputs.items(): 338 | encoded_inputs[key] = to_py_obj(value) 339 | 340 | # Convert padding_strategy in PaddingStrategy 341 | padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies( 342 | padding=padding, max_length=max_length, verbose=verbose 343 | ) 344 | 345 | required_input = encoded_inputs[self.model_input_names[0]] 346 | if required_input and not isinstance(required_input[0], (list, tuple)): 347 | encoded_inputs = self._pad( 348 | encoded_inputs, 349 | class_type=class_type, 350 | max_length=max_length, 351 | padding_strategy=padding_strategy, 352 | pad_to_multiple_of=pad_to_multiple_of, 353 | return_attention_mask=return_attention_mask, 354 | ) 355 | return BatchEncoding(encoded_inputs, tensor_type=return_tensors) 356 | 357 | batch_size = len(required_input) 358 | assert all( 359 | len(v) == batch_size for v in encoded_inputs.values() 360 | ), "Some items in the output dictionary have a different batch size than others." 361 | 362 | if padding_strategy == PaddingStrategy.LONGEST: 363 | max_length = max(len(inputs) for inputs in required_input) 364 | padding_strategy = PaddingStrategy.MAX_LENGTH 365 | 366 | batch_outputs = {} 367 | for i in range(batch_size): 368 | inputs = dict((k, v[i]) for k, v in encoded_inputs.items()) 369 | outputs = self._pad( 370 | inputs, 371 | class_type=class_type, 372 | max_length=max_length, 373 | padding_strategy=padding_strategy, 374 | pad_to_multiple_of=pad_to_multiple_of, 375 | return_attention_mask=return_attention_mask, 376 | ) 377 | 378 | for key, value in outputs.items(): 379 | if key not in batch_outputs: 380 | batch_outputs[key] = [] 381 | batch_outputs[key].append(value) 382 | if class_type == "cell": 383 | del batch_outputs["label"] 384 | return BatchEncoding(batch_outputs, tensor_type=return_tensors) 385 | 386 | def _pad( 387 | self, 388 | encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], 389 | class_type, # options: "gene" or "cell" 390 | max_length: Optional[int] = None, 391 | padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST, 392 | pad_to_multiple_of: Optional[int] = None, 393 | return_attention_mask: Optional[bool] = True, 394 | ) -> dict: 395 | """ 396 | Pad encoded inputs (on left/right and up to predefined length or max length in the batch) 397 | 398 | Args: 399 | encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). 400 | max_length: maximum length of the returned list and optionally padding length (see below). 401 | Will truncate by taking into account the special tokens. 402 | padding_strategy: PaddingStrategy to use for padding. 403 | 404 | - PaddingStrategy.LONGEST Pad to the longest sequence in the batch 405 | - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) 406 | - PaddingStrategy.DO_NOT_PAD: Do not pad 407 | The tokenizer padding sides are defined in self.padding_side: 408 | 409 | - 'left': pads on the left of the sequences 410 | - 'right': pads on the right of the sequences 411 | pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. 412 | This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability 413 | >= 7.5 (Volta). 
414 | return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) 415 | """ 416 | # Load from model defaults 417 | if return_attention_mask is None: 418 | return_attention_mask = "attention_mask" in self.model_input_names 419 | 420 | required_input = encoded_inputs[self.model_input_names[0]] 421 | 422 | if padding_strategy == PaddingStrategy.LONGEST: 423 | max_length = len(required_input) 424 | 425 | if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): 426 | max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of 427 | 428 | needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length 429 | 430 | if needs_to_be_padded: 431 | difference = max_length - len(required_input) 432 | if self.padding_side == "right": 433 | if return_attention_mask: 434 | encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference 435 | if "token_type_ids" in encoded_inputs: 436 | encoded_inputs["token_type_ids"] = ( 437 | encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference 438 | ) 439 | if "special_tokens_mask" in encoded_inputs: 440 | encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference 441 | encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference 442 | if class_type == "gene": 443 | encoded_inputs["labels"] = encoded_inputs["labels"] + [-100] * difference 444 | elif self.padding_side == "left": 445 | if return_attention_mask: 446 | encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input) 447 | if "token_type_ids" in encoded_inputs: 448 | encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ 449 | "token_type_ids" 450 | ] 451 | if "special_tokens_mask" in encoded_inputs: 452 | encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] 453 | encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input 454 | if class_type == "gene": 455 | encoded_inputs["labels"] = [-100] * difference + encoded_inputs["labels"] 456 | else: 457 | raise ValueError("Invalid padding strategy:" + str(self.padding_side)) 458 | elif return_attention_mask and "attention_mask" not in encoded_inputs: 459 | encoded_inputs["attention_mask"] = [1] * len(required_input) 460 | 461 | return encoded_inputs 462 | 463 | def get_special_tokens_mask( 464 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 465 | ) -> List[int]: 466 | """ 467 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding 468 | special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. 469 | Args: 470 | token_ids_0 (:obj:`List[int]`): 471 | List of ids of the first sequence. 472 | token_ids_1 (:obj:`List[int]`, `optional`): 473 | List of ids of the second sequence. 474 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): 475 | Whether or not the token list is already formatted with special tokens for the model. 476 | Returns: 477 | A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 478 | """ 479 | assert already_has_special_tokens and token_ids_1 is None, ( 480 | "You cannot use ``already_has_special_tokens=False`` with this tokenizer. 
" 481 | "Please use a slow (full python) tokenizer to activate this argument." 482 | "Or set `return_special_tokens_mask=True` when calling the encoding method " 483 | "to get the special tokens mask in any tokenizer. " 484 | ) 485 | 486 | all_special_ids = self.all_special_ids # cache the property 487 | 488 | special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0] 489 | 490 | return special_tokens_mask 491 | 492 | def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]: 493 | """ 494 | Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the 495 | vocabulary. 496 | Args: 497 | tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s). 498 | Returns: 499 | :obj:`int` or :obj:`List[int]`: The token id or list of token ids. 500 | """ 501 | if tokens is None: 502 | return None 503 | 504 | if isinstance(tokens, str): 505 | return self._convert_token_to_id_with_added_voc(tokens) 506 | 507 | ids = [] 508 | for token in tokens: 509 | ids.append(self._convert_token_to_id_with_added_voc(token)) 510 | return ids 511 | 512 | def _convert_token_to_id_with_added_voc(self, token): 513 | if token is None: 514 | return None 515 | 516 | return token_dictionary.get(token) 517 | 518 | def __len__(self): 519 | return len(token_dictionary) 520 | 521 | 522 | # collator functions 523 | 524 | class DataCollatorForGeneClassification(DataCollatorForTokenClassification): 525 | """ 526 | Data collator that will dynamically pad the inputs received, as well as the labels. 527 | Args: 528 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): 529 | The tokenizer used for encoding the data. 530 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): 531 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 532 | among: 533 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 534 | sequence if provided). 535 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 536 | maximum acceptable input length for the model if that argument is not provided. 537 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 538 | different lengths). 539 | max_length (:obj:`int`, `optional`): 540 | Maximum length of the returned list and optionally padding length (see above). 541 | pad_to_multiple_of (:obj:`int`, `optional`): 542 | If set will pad the sequence to a multiple of the provided value. 543 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 544 | 7.5 (Volta). 545 | label_pad_token_id (:obj:`int`, `optional`, defaults to -100): 546 | The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions). 
547 | """ 548 | 549 | tokenizer = PrecollatorForGeneAndCellClassification() 550 | class_type = "gene" 551 | padding: Union[bool, str, PaddingStrategy] = True 552 | max_length: Optional[int] = None 553 | pad_to_multiple_of: Optional[int] = None 554 | label_pad_token_id: int = -100 555 | 556 | def __init__(self, *args, **kwargs) -> None: 557 | super().__init__( 558 | tokenizer=self.tokenizer, 559 | padding=self.padding, 560 | max_length=self.max_length, 561 | pad_to_multiple_of=self.pad_to_multiple_of, 562 | label_pad_token_id=self.label_pad_token_id, 563 | *args, **kwargs) 564 | 565 | def _prepare_batch(self, features): 566 | label_name = "label" if "label" in features[0].keys() else "labels" 567 | labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None 568 | batch = self.tokenizer.pad( 569 | features, 570 | class_type=self.class_type, 571 | padding=self.padding, 572 | max_length=self.max_length, 573 | pad_to_multiple_of=self.pad_to_multiple_of, 574 | return_tensors="pt", 575 | ) 576 | return batch 577 | 578 | def __call__(self, features): 579 | batch = self._prepare_batch(features) 580 | 581 | batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()} 582 | return batch 583 | 584 | 585 | class DataCollatorForCellClassification(DataCollatorForGeneClassification): 586 | 587 | class_type = "cell" 588 | 589 | def _prepare_batch(self, features): 590 | 591 | batch = super()._prepare_batch(features) 592 | 593 | # Special handling for labels. 594 | # Ensure that tensor is created with the correct type 595 | # (it should be automatically the case, but let's make sure of it.) 596 | first = features[0] 597 | if "label" in first and first["label"] is not None: 598 | label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"] 599 | dtype = torch.long if isinstance(label, int) else torch.float 600 | batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype) 601 | 602 | return batch 603 | -------------------------------------------------------------------------------- /geneformer_001/build/lib/geneformer/emb_extractor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Geneformer embedding extractor. 
3 |
4 | Usage:
5 |   from geneformer import EmbExtractor
6 |   embex = EmbExtractor(model_type="CellClassifier",
7 |                        num_classes=3,
8 |                        emb_mode="cell",
9 |                        cell_emb_style="mean_pool",
10 |                        filter_data={"cell_type":["cardiomyocyte"]},
11 |                        max_ncells=1000,
12 |                        emb_layer=-1,
13 |                        emb_label=["disease","cell_type"],
14 |                        labels_to_plot=["disease","cell_type"],
15 |                        forward_batch_size=100,
16 |                        nproc=16)
17 |   embs = embex.extract_embs("path/to/model",
18 |                             "path/to/input_data",
19 |                             "path/to/output_directory",
20 |                             "output_prefix")
21 |   embex.plot_embs(embs=embs,
22 |                   plot_style="heatmap",
23 |                   max_ncells_to_plot=1000,
24 |                   output_directory="path/to/output_directory",
25 |                   output_prefix="output_prefix")
26 |
27 | """
28 |
29 | # imports
30 | import logging
31 | import anndata
32 | import matplotlib.pyplot as plt
33 | import numpy as np
34 | import pandas as pd
35 | import pickle
36 | import scanpy as sc
37 | import seaborn as sns
38 | import torch
39 | from collections import Counter
40 | from pathlib import Path
41 | from tqdm.notebook import trange
42 | from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
43 |
44 | from .tokenizer import TOKEN_DICTIONARY_FILE
45 |
46 | from .in_silico_perturber import downsample_and_sort, \
47 |     gen_attention_mask, \
48 |     get_model_input_size, \
49 |     load_and_filter, \
50 |     load_model, \
51 |     mean_nonpadding_embs, \
52 |     pad_tensor_list, \
53 |     quant_layers
54 |
55 | logger = logging.getLogger(__name__)
56 |
57 | # extract embeddings in minibatches, mean-pooling per cell when emb_mode is "cell"
58 | def get_embs(model,
59 |              filtered_input_data,
60 |              emb_mode,
61 |              layer_to_quant,
62 |              pad_token_id,
63 |              forward_batch_size):
64 |
65 |     model_input_size = get_model_input_size(model)
66 |     total_batch_length = len(filtered_input_data)
67 |     if ((total_batch_length-1)/forward_batch_size).is_integer():
68 |         forward_batch_size = forward_batch_size-1  # avoid a final minibatch containing a single cell
69 |
70 |     embs_list = []
71 |     for i in trange(0, total_batch_length, forward_batch_size):
72 |         max_range = min(i+forward_batch_size, total_batch_length)
73 |
74 |         minibatch = filtered_input_data.select([i for i in range(i, max_range)])
75 |         max_len = max(minibatch["length"])
76 |         original_lens = torch.tensor(minibatch["length"]).to("cuda")
77 |         minibatch.set_format(type="torch")
78 |
79 |         input_data_minibatch = minibatch["input_ids"]
80 |         input_data_minibatch = pad_tensor_list(input_data_minibatch,
81 |                                                max_len,
82 |                                                pad_token_id,
83 |                                                model_input_size)
84 |
85 |         with torch.no_grad():
86 |             outputs = model(
87 |                 input_ids = input_data_minibatch.to("cuda"),
88 |                 attention_mask = gen_attention_mask(minibatch)
89 |             )
90 |
91 |         embs_i = outputs.hidden_states[layer_to_quant]
92 |
93 |         if emb_mode == "cell":
94 |             mean_embs = mean_nonpadding_embs(embs_i, original_lens)
95 |             embs_list += [mean_embs]
96 |
97 |         del outputs
98 |         del minibatch
99 |         del input_data_minibatch
100 |         del embs_i
101 |         del mean_embs
102 |         torch.cuda.empty_cache()
103 |
104 |     embs_stack = torch.cat(embs_list)
105 |     return embs_stack
106 |
107 | def label_embs(embs, downsampled_data, emb_labels):
108 |     embs_df = pd.DataFrame(embs.cpu())
109 |     if emb_labels is not None:
110 |         for label in emb_labels:
111 |             emb_label = downsampled_data[label]
112 |             embs_df[label] = emb_label
113 |     return embs_df
114 |
115 | def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
116 |     only_embs_df = embs_df.iloc[:,:emb_dims]
117 |     only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
118 |     only_embs_df.columns = pd.RangeIndex(0, 
only_embs_df.shape[1], name=None).astype(str) 119 | vars_dict = {"embs": only_embs_df.columns} 120 | obs_dict = {"cell_id": list(only_embs_df.index), 121 | f"{label}": list(embs_df[label])} 122 | adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict) 123 | sc.tl.pca(adata, svd_solver='arpack') 124 | sc.pp.neighbors(adata) 125 | sc.tl.umap(adata) 126 | sns.set(rc={'figure.figsize':(10,10)}, font_scale=2.3) 127 | sns.set_style("white") 128 | default_kwargs_dict = {"palette":"Set2", "size":200} 129 | if kwargs_dict is not None: 130 | default_kwargs_dict.update(kwargs_dict) 131 | 132 | sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict) 133 | 134 | 135 | def gen_heatmap_class_colors(labels, df): 136 | pal = sns.cubehelix_palette(len(Counter(labels).keys()), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2) 137 | lut = dict(zip(map(str, Counter(labels).keys()), pal)) 138 | colors = pd.Series(labels, index=df.index).map(lut) 139 | return colors 140 | 141 | def gen_heatmap_class_dict(classes, label_colors_series): 142 | class_color_dict_df = pd.DataFrame({"classes": classes, "color": label_colors_series}) 143 | class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"]) 144 | return dict(zip(class_color_dict_df["classes"],class_color_dict_df["color"])) 145 | 146 | def make_colorbar(embs_df, label): 147 | 148 | labels = list(embs_df[label]) 149 | 150 | cell_type_colors = gen_heatmap_class_colors(labels, embs_df) 151 | label_colors = pd.DataFrame(cell_type_colors, columns=[label]) 152 | 153 | for i,row in label_colors.iterrows(): 154 | colors=row[0] 155 | if len(colors)!=3 or any(np.isnan(colors)): 156 | print(i,colors) 157 | 158 | label_colors.isna().sum() 159 | 160 | # create dictionary for colors and classes 161 | label_color_dict = gen_heatmap_class_dict(labels, label_colors[label]) 162 | return label_colors, label_color_dict 163 | 164 | def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict): 165 | sns.set_style("white") 166 | sns.set(font_scale=2) 167 | plt.figure(figsize=(15, 15), dpi=150) 168 | label_colors, label_color_dict = make_colorbar(embs_df, label) 169 | 170 | default_kwargs_dict = {"row_cluster": True, 171 | "col_cluster": True, 172 | "row_colors": label_colors, 173 | "standard_scale": 1, 174 | "linewidths": 0, 175 | "xticklabels": False, 176 | "yticklabels": False, 177 | "figsize": (15,15), 178 | "center": 0, 179 | "cmap": "magma"} 180 | 181 | if kwargs_dict is not None: 182 | default_kwargs_dict.update(kwargs_dict) 183 | g = sns.clustermap(embs_df.iloc[:,0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict) 184 | 185 | plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right") 186 | 187 | for label_color in list(label_color_dict.keys()): 188 | g.ax_col_dendrogram.bar(0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0) 189 | 190 | l1 = g.ax_col_dendrogram.legend(title=f"{label}", 191 | loc="lower center", 192 | ncol=4, 193 | bbox_to_anchor=(0.5, 1), 194 | facecolor="white") 195 | 196 | plt.savefig(output_file, bbox_inches='tight') 197 | 198 | class EmbExtractor: 199 | valid_option_dict = { 200 | "model_type": {"Pretrained","GeneClassifier","CellClassifier"}, 201 | "num_classes": {int}, 202 | "emb_mode": {"cell","gene"}, 203 | "cell_emb_style": {"mean_pool"}, 204 | "filter_data": {None, dict}, 205 | "max_ncells": {None, int}, 206 | "emb_layer": {-1, 0}, 207 | "emb_label": {None, list}, 208 | "labels_to_plot": {None, list}, 209 | "forward_batch_size": {int}, 210 | 
"nproc": {int}, 211 | } 212 | def __init__( 213 | self, 214 | model_type="Pretrained", 215 | num_classes=0, 216 | emb_mode="cell", 217 | cell_emb_style="mean_pool", 218 | filter_data=None, 219 | max_ncells=1000, 220 | emb_layer=-1, 221 | emb_label=None, 222 | labels_to_plot=None, 223 | forward_batch_size=100, 224 | nproc=4, 225 | token_dictionary_file=TOKEN_DICTIONARY_FILE, 226 | ): 227 | """ 228 | Initialize embedding extractor. 229 | 230 | Parameters 231 | ---------- 232 | model_type : {"Pretrained","GeneClassifier","CellClassifier"} 233 | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier. 234 | num_classes : int 235 | If model is a gene or cell classifier, specify number of classes it was trained to classify. 236 | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier. 237 | emb_mode : {"cell","gene"} 238 | Whether to output cell or gene embeddings. 239 | cell_emb_style : "mean_pool" 240 | Method for summarizing cell embeddings. 241 | Currently only option is mean pooling of gene embeddings for given cell. 242 | filter_data : None, dict 243 | Default is to extract embeddings from all input data. 244 | Otherwise, dictionary specifying .dataset column name and list of values to filter by. 245 | max_ncells : None, int 246 | Maximum number of cells to extract embeddings from. 247 | Default is 1000 cells randomly sampled from input data. 248 | If None, will extract embeddings from all cells. 249 | emb_layer : {-1, 0} 250 | Embedding layer to extract. 251 | The last layer is most specifically weighted to optimize the given learning objective. 252 | Generally, it is best to extract the 2nd to last layer to get a more general representation. 253 | -1: 2nd to last layer 254 | 0: last layer 255 | emb_label : None, list 256 | List of column name(s) in .dataset to add as labels to embedding output. 257 | labels_to_plot : None, list 258 | Cell labels to plot. 259 | Shown as color bar in heatmap. 260 | Shown as cell color in umap. 261 | Plotting umap requires labels to plot. 262 | forward_batch_size : int 263 | Batch size for forward pass. 264 | nproc : int 265 | Number of CPU processes to use. 266 | token_dictionary_file : Path 267 | Path to pickle file containing token dictionary (Ensembl ID:token). 268 | """ 269 | 270 | self.model_type = model_type 271 | self.num_classes = num_classes 272 | self.emb_mode = emb_mode 273 | self.cell_emb_style = cell_emb_style 274 | self.filter_data = filter_data 275 | self.max_ncells = max_ncells 276 | self.emb_layer = emb_layer 277 | self.emb_label = emb_label 278 | self.labels_to_plot = labels_to_plot 279 | self.forward_batch_size = forward_batch_size 280 | self.nproc = nproc 281 | 282 | self.validate_options() 283 | 284 | # load token dictionary (Ensembl IDs:token) 285 | with open(token_dictionary_file, "rb") as f: 286 | self.gene_token_dict = pickle.load(f) 287 | 288 | self.pad_token_id = self.gene_token_dict.get("") 289 | 290 | 291 | def validate_options(self): 292 | # first disallow options under development 293 | if self.emb_mode == "gene": 294 | logger.error( 295 | "Extraction and plotting of gene-level embeddings currently under development. 
" \ 296 | "Current valid option for 'emb_mode': 'cell'" 297 | ) 298 | raise 299 | 300 | # confirm arguments are within valid options and compatible with each other 301 | for attr_name,valid_options in self.valid_option_dict.items(): 302 | attr_value = self.__dict__[attr_name] 303 | if type(attr_value) not in {list, dict}: 304 | if attr_value in valid_options: 305 | continue 306 | valid_type = False 307 | for option in valid_options: 308 | if (option in [int,list,dict]) and isinstance(attr_value, option): 309 | valid_type = True 310 | break 311 | if valid_type: 312 | continue 313 | logger.error( 314 | f"Invalid option for {attr_name}. " \ 315 | f"Valid options for {attr_name}: {valid_options}" 316 | ) 317 | raise 318 | 319 | if self.filter_data is not None: 320 | for key,value in self.filter_data.items(): 321 | if type(value) != list: 322 | self.filter_data[key] = [value] 323 | logger.warning( 324 | "Values in filter_data dict must be lists. " \ 325 | f"Changing {key} value to list ([{value}]).") 326 | 327 | def extract_embs(self, 328 | model_directory, 329 | input_data_file, 330 | output_directory, 331 | output_prefix): 332 | """ 333 | Extract embeddings from input data and save as results in output_directory. 334 | 335 | Parameters 336 | ---------- 337 | model_directory : Path 338 | Path to directory containing model 339 | input_data_file : Path 340 | Path to directory containing .dataset inputs 341 | output_directory : Path 342 | Path to directory where embedding data will be saved as csv 343 | output_prefix : str 344 | Prefix for output file 345 | """ 346 | 347 | filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file) 348 | downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells) 349 | model = load_model(self.model_type, self.num_classes, model_directory) 350 | layer_to_quant = quant_layers(model)+self.emb_layer 351 | embs = get_embs(model, 352 | downsampled_data, 353 | self.emb_mode, 354 | layer_to_quant, 355 | self.pad_token_id, 356 | self.forward_batch_size) 357 | embs_df = label_embs(embs, downsampled_data, self.emb_label) 358 | 359 | # save embeddings to output_path 360 | output_path = (Path(output_directory) / output_prefix).with_suffix(".csv") 361 | embs_df.to_csv(output_path) 362 | 363 | return embs_df 364 | 365 | def plot_embs(self, 366 | embs, 367 | plot_style, 368 | output_directory, 369 | output_prefix, 370 | max_ncells_to_plot=1000, 371 | kwargs_dict=None): 372 | 373 | """ 374 | Plot embeddings, coloring by provided labels. 375 | 376 | Parameters 377 | ---------- 378 | embs : pandas.core.frame.DataFrame 379 | Pandas dataframe containing embeddings output from extract_embs 380 | plot_style : str 381 | Style of plot: "heatmap" or "umap" 382 | output_directory : Path 383 | Path to directory where plots will be saved as pdf 384 | output_prefix : str 385 | Prefix for output file 386 | max_ncells_to_plot : None, int 387 | Maximum number of cells to plot. 388 | Default is 1000 cells randomly sampled from embeddings. 389 | If None, will plot embeddings from all cells. 390 | kwargs_dict : dict 391 | Dictionary of kwargs to pass to plotting function. 392 | """ 393 | 394 | if plot_style not in ["heatmap","umap"]: 395 | logger.error( 396 | "Invalid option for 'plot_style'. " \ 397 | "Valid options: {'heatmap','umap'}" 398 | ) 399 | raise 400 | 401 | if (plot_style == "umap") and (self.labels_to_plot is None): 402 | logger.error( 403 | "Plotting UMAP requires 'labels_to_plot'. 
" 404 | ) 405 | raise 406 | 407 | if max_ncells_to_plot > self.max_ncells: 408 | max_ncells_to_plot = self.max_ncells 409 | logger.warning( 410 | "max_ncells_to_plot must be <= max_ncells. " \ 411 | f"Changing max_ncells_to_plot to {self.max_ncells}.") 412 | 413 | if (max_ncells_to_plot is not None) \ 414 | and (max_ncells_to_plot < self.max_ncells): 415 | embs = embs.sample(max_ncells_to_plot, axis=0) 416 | 417 | if self.emb_label is None: 418 | label_len = 0 419 | else: 420 | label_len = len(self.emb_label) 421 | 422 | emb_dims = embs.shape[1] - label_len 423 | 424 | if self.emb_label is None: 425 | emb_labels = None 426 | else: 427 | emb_labels = embs.columns[emb_dims:] 428 | 429 | if plot_style == "umap": 430 | for label in self.labels_to_plot: 431 | if label not in emb_labels: 432 | logger.warning( 433 | f"Label {label} from labels_to_plot " \ 434 | f"not present in provided embeddings dataframe.") 435 | continue 436 | output_prefix_label = "_" + output_prefix + f"_umap_{label}" 437 | output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf") 438 | plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict) 439 | 440 | if plot_style == "heatmap": 441 | for label in self.labels_to_plot: 442 | if label not in emb_labels: 443 | logger.warning( 444 | f"Label {label} from labels_to_plot " \ 445 | f"not present in provided embeddings dataframe.") 446 | continue 447 | output_prefix_label = output_prefix + f"_heatmap_{label}" 448 | output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf") 449 | plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict) 450 | -------------------------------------------------------------------------------- /geneformer_001/build/lib/geneformer/gene_median_dictionary.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PharMolix/LangCell/69e41ef2fae485b67703294b50f3150bb1d5bb9b/geneformer_001/build/lib/geneformer/gene_median_dictionary.pkl -------------------------------------------------------------------------------- /geneformer_001/build/lib/geneformer/gene_name_id_dict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PharMolix/LangCell/69e41ef2fae485b67703294b50f3150bb1d5bb9b/geneformer_001/build/lib/geneformer/gene_name_id_dict.pkl -------------------------------------------------------------------------------- /geneformer_001/build/lib/geneformer/token_dictionary.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PharMolix/LangCell/69e41ef2fae485b67703294b50f3150bb1d5bb9b/geneformer_001/build/lib/geneformer/token_dictionary.pkl -------------------------------------------------------------------------------- /geneformer_001/build/lib/geneformer/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Geneformer tokenizer. 
3 | 4 | Input data: 5 | Required format: raw counts scRNAseq data without feature selection as .loom file 6 | Required row (gene) attribute: "ensembl_id"; Ensembl ID for each gene 7 | Required col (cell) attribute: "n_counts"; total read counts in that cell 8 | Optional col (cell) attribute: "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria 9 | Optional col (cell) attributes: any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below 10 | 11 | Usage: 12 | from geneformer import TranscriptomeTokenizer 13 | tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4) 14 | tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix") 15 | """ 16 | 17 | import pickle 18 | from pathlib import Path 19 | 20 | import logging 21 | 22 | import warnings 23 | warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") 24 | 25 | import loompy as lp 26 | import numpy as np 27 | from datasets import Dataset 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl" 32 | TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl" 33 | 34 | 35 | def tokenize_cell(gene_vector, gene_tokens): 36 | """ 37 | Convert normalized gene expression vector to tokenized rank value encoding. 38 | """ 39 | # create array of gene vector with token indices 40 | # mask undetected genes 41 | nonzero_mask = np.nonzero(gene_vector)[0] 42 | # sort by median-scaled gene values 43 | sorted_indices = np.argsort(-gene_vector[nonzero_mask]) 44 | # tokenize 45 | sentence_tokens = gene_tokens[nonzero_mask][sorted_indices] 46 | return sentence_tokens 47 | 48 | 49 | class TranscriptomeTokenizer: 50 | def __init__( 51 | self, 52 | custom_attr_name_dict=None, 53 | nproc=1, 54 | gene_median_file=GENE_MEDIAN_FILE, 55 | token_dictionary_file=TOKEN_DICTIONARY_FILE, 56 | ): 57 | """ 58 | Initialize tokenizer. 59 | 60 | Parameters 61 | ---------- 62 | custom_attr_name_dict : None, dict 63 | Dictionary of custom attributes to be added to the dataset. 64 | Keys are the names of the attributes in the loom file. 65 | Values are the names of the attributes in the dataset. 66 | nproc : int 67 | Number of processes to use for dataset mapping. 68 | gene_median_file : Path 69 | Path to pickle file containing dictionary of non-zero median 70 | gene expression values across Genecorpus-30M. 71 | token_dictionary_file : Path 72 | Path to pickle file containing token dictionary (Ensembl IDs:token). 
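
        For example, to carry the .loom column attribute "cell_type" into the
        tokenized dataset under the column name "celltype" (hypothetical names):

            tk = TranscriptomeTokenizer({"cell_type": "celltype"}, nproc=4)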
73 | """ 74 | # dictionary of custom attributes {output dataset column name: input .loom column name} 75 | self.custom_attr_name_dict = custom_attr_name_dict 76 | 77 | # number of processes for dataset mapping 78 | self.nproc = nproc 79 | 80 | # load dictionary of gene normalization factors 81 | # (non-zero median value of expression across Genecorpus-30M) 82 | with open(gene_median_file, "rb") as f: 83 | self.gene_median_dict = pickle.load(f) 84 | 85 | # load token dictionary (Ensembl IDs:token) 86 | with open(token_dictionary_file, "rb") as f: 87 | self.gene_token_dict = pickle.load(f) 88 | 89 | # gene keys for full vocabulary 90 | self.gene_keys = list(self.gene_median_dict.keys()) 91 | 92 | # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization 93 | self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys))) 94 | 95 | def tokenize_data(self, loom_data_directory, output_directory, output_prefix): 96 | """ 97 | Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory. 98 | 99 | Parameters 100 | ---------- 101 | loom_data_directory : Path 102 | Path to directory containing loom files 103 | output_directory : Path 104 | Path to directory where tokenized data will be saved as .dataset 105 | output_prefix : str 106 | Prefix for output .dataset 107 | """ 108 | tokenized_cells, cell_metadata = self.tokenize_files(Path(loom_data_directory)) 109 | tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata) 110 | 111 | output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset") 112 | tokenized_dataset.save_to_disk(output_path) 113 | 114 | def tokenize_files(self, loom_data_directory): 115 | tokenized_cells = [] 116 | if self.custom_attr_name_dict is not None: 117 | loom_cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()] 118 | cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()} 119 | 120 | # loops through directories to tokenize .loom files 121 | file_found = 0 122 | for loom_file_path in loom_data_directory.glob("*.loom"): 123 | file_found = 1 124 | print(f"Tokenizing {loom_file_path}") 125 | file_tokenized_cells, file_cell_metadata = self.tokenize_file( 126 | loom_file_path 127 | ) 128 | tokenized_cells += file_tokenized_cells 129 | if self.custom_attr_name_dict is not None: 130 | for k in loom_cell_attr: 131 | cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k] 132 | else: 133 | cell_metadata = None 134 | 135 | if file_found == 0: 136 | logger.error( 137 | f"No .loom files found in directory {loom_data_directory}.") 138 | raise 139 | return tokenized_cells, cell_metadata 140 | 141 | def tokenize_file(self, loom_file_path): 142 | if self.custom_attr_name_dict is not None: 143 | file_cell_metadata = { 144 | attr_key: [] for attr_key in self.custom_attr_name_dict.keys() 145 | } 146 | 147 | with lp.connect(str(loom_file_path)) as data: 148 | # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors 149 | coding_miRNA_loc = np.where( 150 | [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]] 151 | )[0] 152 | norm_factor_vector = np.array( 153 | [ 154 | self.gene_median_dict[i] 155 | for i in data.ra["ensembl_id"][coding_miRNA_loc] 156 | ] 157 | ) 158 | coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc] 159 | coding_miRNA_tokens = np.array( 160 | [self.gene_token_dict[i] for i in coding_miRNA_ids] 161 | ) 162 | 163 | # define coordinates of 
cells passing filters for inclusion (e.g. QC) 164 | try: 165 | data.ca["filter_pass"] 166 | except AttributeError: 167 | var_exists = False 168 | else: 169 | var_exists = True 170 | 171 | if var_exists is True: 172 | filter_pass_loc = np.where( 173 | [True if i == 1 else False for i in data.ca["filter_pass"]] 174 | )[0] 175 | elif var_exists is False: 176 | print( 177 | f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells." 178 | ) 179 | filter_pass_loc = np.array([i for i in range(data.shape[1])]) 180 | 181 | # scan through .loom files and tokenize cells 182 | tokenized_cells = [] 183 | for (_ix, _selection, view) in data.scan(items=filter_pass_loc, axis=1): 184 | # select subview with protein-coding and miRNA genes 185 | subview = view.view[coding_miRNA_loc, :] 186 | 187 | # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision 188 | # and normalize by gene normalization factors 189 | subview_norm_array = ( 190 | subview[:, :] 191 | / subview.ca.n_counts 192 | * 10_000 193 | / norm_factor_vector[:, None] 194 | ) 195 | # tokenize subview gene vectors 196 | tokenized_cells += [ 197 | tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens) 198 | for i in range(subview_norm_array.shape[1]) 199 | ] 200 | 201 | # add custom attributes for subview to dict 202 | if self.custom_attr_name_dict is not None: 203 | for k in file_cell_metadata.keys(): 204 | file_cell_metadata[k] += subview.ca[k].tolist() 205 | else: 206 | file_cell_metadata = None 207 | 208 | return tokenized_cells, file_cell_metadata 209 | 210 | def create_dataset(self, tokenized_cells, cell_metadata): 211 | # create dict for dataset creation 212 | dataset_dict = {"input_ids": tokenized_cells} 213 | if self.custom_attr_name_dict is not None: 214 | dataset_dict.update(cell_metadata) 215 | 216 | # create dataset 217 | output_dataset = Dataset.from_dict(dataset_dict) 218 | 219 | # truncate dataset 220 | def truncate(example): 221 | example["input_ids"] = example["input_ids"][0:2048] 222 | return example 223 | 224 | output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc) 225 | 226 | # measure lengths of dataset 227 | def measure_length(example): 228 | example["length"] = len(example["input_ids"]) 229 | return example 230 | 231 | output_dataset_truncated_w_length = output_dataset_truncated.map( 232 | measure_length, num_proc=self.nproc 233 | ) 234 | 235 | return output_dataset_truncated_w_length 236 | -------------------------------------------------------------------------------- /geneformer_001/geneformer.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: geneformer 3 | Version: 0.0.1 4 | Summary: Geneformer is a transformer model pretrained on a large-scale corpus of ~30 million single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology. 
5 | Author: Christina Theodoris 6 | Author-email: christina.theodoris@gladstone.ucsf.edu 7 | -------------------------------------------------------------------------------- /geneformer_001/geneformer.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | MANIFEST.in 2 | setup.py 3 | geneformer/__init__.py 4 | geneformer/collator_for_classification.py 5 | geneformer/emb_extractor.py 6 | geneformer/gene_median_dictionary.pkl 7 | geneformer/gene_name_id_dict.pkl 8 | geneformer/in_silico_perturber.py 9 | geneformer/in_silico_perturber_stats.py 10 | geneformer/pretrainer.py 11 | geneformer/token_dictionary.pkl 12 | geneformer/tokenizer.py 13 | geneformer.egg-info/PKG-INFO 14 | geneformer.egg-info/SOURCES.txt 15 | geneformer.egg-info/dependency_links.txt 16 | geneformer.egg-info/requires.txt 17 | geneformer.egg-info/top_level.txt -------------------------------------------------------------------------------- /geneformer_001/geneformer.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /geneformer_001/geneformer.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | loompy 3 | numpy 4 | transformers 5 | -------------------------------------------------------------------------------- /geneformer_001/geneformer.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | geneformer 2 | -------------------------------------------------------------------------------- /geneformer_001/geneformer/__init__.py: -------------------------------------------------------------------------------- 1 | from . import tokenizer 2 | from . import pretrainer 3 | from . import collator_for_classification 4 | from . import in_silico_perturber 5 | from . import in_silico_perturber_stats 6 | from .tokenizer import TranscriptomeTokenizer 7 | from .pretrainer import GeneformerPretrainer 8 | from .collator_for_classification import DataCollatorForGeneClassification 9 | from .collator_for_classification import DataCollatorForCellClassification 10 | from .emb_extractor import EmbExtractor 11 | from .in_silico_perturber import InSilicoPerturber 12 | from .in_silico_perturber_stats import InSilicoPerturberStats -------------------------------------------------------------------------------- /geneformer_001/geneformer/collator_for_classification.py: -------------------------------------------------------------------------------- 1 | """ 2 | Geneformer collator for gene and cell classification. 3 | 4 | Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification. 
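
A minimal sketch of the cell-classification path (made-up token ids; real inputs come
from a tokenized .dataset with one integer label per cell):

    collator = DataCollatorForCellClassification()
    batch = collator([
        {"input_ids": [5, 3, 8], "label": 0},
        {"input_ids": [2, 7], "label": 1},
    ])
    # input_ids/attention_mask are padded int64 tensors, and the per-cell labels are
    # gathered into batch["labels"], ready for BertForSequenceClassification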
5 | """ 6 | import numpy as np 7 | import torch 8 | import warnings 9 | from enum import Enum 10 | from typing import Dict, List, Optional, Union 11 | 12 | from transformers import ( 13 | DataCollatorForTokenClassification, 14 | SpecialTokensMixin, 15 | BatchEncoding, 16 | ) 17 | from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj 18 | from transformers.utils.generic import _is_tensorflow, _is_torch 19 | 20 | from .pretrainer import token_dictionary 21 | 22 | EncodedInput = List[int] 23 | logger = logging.get_logger(__name__) 24 | VERY_LARGE_INTEGER = int( 25 | 1e30 26 | ) # This is used to set the max input length for a model with infinite size input 27 | LARGE_INTEGER = int( 28 | 1e20 29 | ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER 30 | 31 | # precollator functions 32 | 33 | class ExplicitEnum(Enum): 34 | """ 35 | Enum with more explicit error message for missing values. 36 | """ 37 | 38 | @classmethod 39 | def _missing_(cls, value): 40 | raise ValueError( 41 | "%r is not a valid %s, please select one of %s" 42 | % (value, cls.__name__, str(list(cls._value2member_map_.keys()))) 43 | ) 44 | 45 | class TruncationStrategy(ExplicitEnum): 46 | """ 47 | Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for 48 | tab-completion in an IDE. 49 | """ 50 | 51 | ONLY_FIRST = "only_first" 52 | ONLY_SECOND = "only_second" 53 | LONGEST_FIRST = "longest_first" 54 | DO_NOT_TRUNCATE = "do_not_truncate" 55 | 56 | 57 | 58 | class PaddingStrategy(ExplicitEnum): 59 | """ 60 | Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion 61 | in an IDE. 62 | """ 63 | 64 | LONGEST = "longest" 65 | MAX_LENGTH = "max_length" 66 | DO_NOT_PAD = "do_not_pad" 67 | 68 | 69 | 70 | class TensorType(ExplicitEnum): 71 | """ 72 | Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for 73 | tab-completion in an IDE. 74 | """ 75 | 76 | PYTORCH = "pt" 77 | TENSORFLOW = "tf" 78 | NUMPY = "np" 79 | JAX = "jax" 80 | 81 | 82 | class PrecollatorForGeneAndCellClassification(SpecialTokensMixin): 83 | mask_token = "" 84 | mask_token_id = token_dictionary.get("") 85 | pad_token = "" 86 | pad_token_id = token_dictionary.get("") 87 | padding_side = "right" 88 | all_special_ids = [ 89 | token_dictionary.get(""), 90 | token_dictionary.get("") 91 | ] 92 | model_input_names = ["input_ids"] 93 | 94 | def _get_padding_truncation_strategies( 95 | self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs 96 | ): 97 | """ 98 | Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy 99 | and pad_to_max_length) and behaviors. 
100 | """ 101 | old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate") 102 | old_pad_to_max_length = kwargs.pop("pad_to_max_length", False) 103 | 104 | # Backward compatibility for previous behavior, maybe we should deprecate it: 105 | # If you only set max_length, it activates truncation for max_length 106 | if max_length is not None and padding is False and truncation is False: 107 | if verbose: 108 | if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False): 109 | logger.warning( 110 | "Truncation was not explicitly activated but `max_length` is provided a specific value, " 111 | "please use `truncation=True` to explicitly truncate examples to max length. " 112 | "Defaulting to 'longest_first' truncation strategy. " 113 | "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy " 114 | "more precisely by providing a specific strategy to `truncation`." 115 | ) 116 | self.deprecation_warnings["Truncation-not-explicitly-activated"] = True 117 | truncation = "longest_first" 118 | 119 | # Get padding strategy 120 | if padding is False and old_pad_to_max_length: 121 | if verbose: 122 | warnings.warn( 123 | "The `pad_to_max_length` argument is deprecated and will be removed in a future version, " 124 | "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or " 125 | "use `padding='max_length'` to pad to a max length. In this case, you can give a specific " 126 | "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the " 127 | "maximal input size of the model (e.g. 512 for Bert).", 128 | FutureWarning, 129 | ) 130 | if max_length is None: 131 | padding_strategy = PaddingStrategy.LONGEST 132 | else: 133 | padding_strategy = PaddingStrategy.MAX_LENGTH 134 | elif padding is not False: 135 | if padding is True: 136 | padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch 137 | elif not isinstance(padding, PaddingStrategy): 138 | padding_strategy = PaddingStrategy(padding) 139 | elif isinstance(padding, PaddingStrategy): 140 | padding_strategy = padding 141 | else: 142 | padding_strategy = PaddingStrategy.DO_NOT_PAD 143 | 144 | # Get truncation strategy 145 | if truncation is False and old_truncation_strategy != "do_not_truncate": 146 | if verbose: 147 | warnings.warn( 148 | "The `truncation_strategy` argument is deprecated and will be removed in a future version, " 149 | "use `truncation=True` to truncate examples to a max length. You can give a specific " 150 | "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the " 151 | "maximal input size of the model (e.g. 512 for Bert). 
" 152 | " If you have pairs of inputs, you can give a specific truncation strategy selected among " 153 | "`truncation='only_first'` (will only truncate the first sentence in the pairs) " 154 | "`truncation='only_second'` (will only truncate the second sentence in the pairs) " 155 | "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).", 156 | FutureWarning, 157 | ) 158 | truncation_strategy = TruncationStrategy(old_truncation_strategy) 159 | elif truncation is not False: 160 | if truncation is True: 161 | truncation_strategy = ( 162 | TruncationStrategy.LONGEST_FIRST 163 | ) # Default to truncate the longest sequences in pairs of inputs 164 | elif not isinstance(truncation, TruncationStrategy): 165 | truncation_strategy = TruncationStrategy(truncation) 166 | elif isinstance(truncation, TruncationStrategy): 167 | truncation_strategy = truncation 168 | else: 169 | truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE 170 | 171 | # Set max length if needed 172 | if max_length is None: 173 | if padding_strategy == PaddingStrategy.MAX_LENGTH: 174 | if self.model_max_length > LARGE_INTEGER: 175 | if verbose: 176 | if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False): 177 | logger.warning( 178 | "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. " 179 | "Default to no padding." 180 | ) 181 | self.deprecation_warnings["Asking-to-pad-to-max_length"] = True 182 | padding_strategy = PaddingStrategy.DO_NOT_PAD 183 | else: 184 | max_length = self.model_max_length 185 | 186 | if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE: 187 | if self.model_max_length > LARGE_INTEGER: 188 | if verbose: 189 | if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False): 190 | logger.warning( 191 | "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. " 192 | "Default to no truncation." 193 | ) 194 | self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True 195 | truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE 196 | else: 197 | max_length = self.model_max_length 198 | 199 | # Test if we have a padding token 200 | if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0): 201 | raise ValueError( 202 | "Asking to pad but the tokenizer does not have a padding token. " 203 | "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` " 204 | "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`." 205 | ) 206 | 207 | # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided 208 | if ( 209 | truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE 210 | and padding_strategy != PaddingStrategy.DO_NOT_PAD 211 | and pad_to_multiple_of is not None 212 | and max_length is not None 213 | and (max_length % pad_to_multiple_of != 0) 214 | ): 215 | raise ValueError( 216 | f"Truncation and padding are both activated but " 217 | f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})." 
218 | ) 219 | 220 | return padding_strategy, truncation_strategy, max_length, kwargs 221 | 222 | def pad( 223 | self, 224 | encoded_inputs: Union[ 225 | BatchEncoding, 226 | List[BatchEncoding], 227 | Dict[str, EncodedInput], 228 | Dict[str, List[EncodedInput]], 229 | List[Dict[str, EncodedInput]], 230 | ], 231 | class_type, # options: "gene" or "cell" 232 | padding: Union[bool, str, PaddingStrategy] = True, 233 | max_length: Optional[int] = None, 234 | pad_to_multiple_of: Optional[int] = None, 235 | return_attention_mask: Optional[bool] = True, 236 | return_tensors: Optional[Union[str, TensorType]] = None, 237 | verbose: bool = True, 238 | ) -> BatchEncoding: 239 | """ 240 | Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length 241 | in the batch. 242 | 243 | Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``, 244 | ``self.pad_token_id`` and ``self.pad_token_type_id``) 245 | 246 | .. note:: 247 | 248 | If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the 249 | result will use the same type unless you provide a different tensor type with ``return_tensors``. In the 250 | case of PyTorch tensors, you will lose the specific device of your tensors however. 251 | 252 | Args: 253 | encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`): 254 | Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str, 255 | List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str, 256 | List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as 257 | well as in a PyTorch Dataloader collate function. 258 | 259 | Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), 260 | see the note above for the return type. 261 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): 262 | Select a strategy to pad the returned sequences (according to the model's padding side and padding 263 | index) among: 264 | 265 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a 266 | single sequence if provided). 267 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 268 | maximum acceptable input length for the model if that argument is not provided. 269 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 270 | different lengths). 271 | max_length (:obj:`int`, `optional`): 272 | Maximum length of the returned list and optionally padding length (see above). 273 | pad_to_multiple_of (:obj:`int`, `optional`): 274 | If set will pad the sequence to a multiple of the provided value. 275 | 276 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability 277 | >= 7.5 (Volta). 278 | return_attention_mask (:obj:`bool`, `optional`): 279 | Whether to return the attention mask. If left to the default, will return the attention mask according 280 | to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. 281 | 282 | `What are attention masks? 
<../glossary.html#attention-mask>`__
283 |             return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
284 |                 If set, will return tensors instead of list of python integers. Acceptable values are:
285 |
286 |                 * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
287 |                 * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
288 |                 * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
289 |             verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
290 |                 Whether or not to print more information and warnings.
291 |         """
292 |         # If we have a list of dicts, let's convert it into a dict of lists
293 |         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
294 |         if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
295 |             encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
296 |
297 |         # The model's main input name, usually `input_ids`, has to be passed for padding
298 |         if self.model_input_names[0] not in encoded_inputs:
299 |             raise ValueError(
300 |                 "You should supply an encoding or a list of encodings to this method "
301 |                 f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
302 |             )
303 |
304 |         required_input = encoded_inputs[self.model_input_names[0]]
305 |
306 |         if not required_input:
307 |             if return_attention_mask:
308 |                 encoded_inputs["attention_mask"] = []
309 |             return encoded_inputs
310 |
311 |         # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
312 |         # and rebuild them afterwards if no return_tensors is specified
313 |         # Note that we lose the specific device the tensor may be on for PyTorch
314 |
315 |         first_element = required_input[0]
316 |         if isinstance(first_element, (list, tuple)):
317 |             # first_element might be an empty list/tuple in some edge cases so we grab the first non-empty element.
318 |             index = 0
319 |             while len(required_input[index]) == 0:
320 |                 index += 1
321 |             if index < len(required_input):
322 |                 first_element = required_input[index][0]
323 |         # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
324 |         if not isinstance(first_element, (int, list, tuple)):
325 |             if is_tf_available() and _is_tensorflow(first_element):
326 |                 return_tensors = "tf" if return_tensors is None else return_tensors
327 |             elif is_torch_available() and _is_torch(first_element):
328 |                 return_tensors = "pt" if return_tensors is None else return_tensors
329 |             elif isinstance(first_element, np.ndarray):
330 |                 return_tensors = "np" if return_tensors is None else return_tensors
331 |             else:
332 |                 raise ValueError(
333 |                     f"type of {first_element} unknown: {type(first_element)}. "
334 |                     f"Should be one of a python, numpy, pytorch or tensorflow object."
335 | ) 336 | 337 | for key, value in encoded_inputs.items(): 338 | encoded_inputs[key] = to_py_obj(value) 339 | 340 | # Convert padding_strategy in PaddingStrategy 341 | padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies( 342 | padding=padding, max_length=max_length, verbose=verbose 343 | ) 344 | 345 | required_input = encoded_inputs[self.model_input_names[0]] 346 | if required_input and not isinstance(required_input[0], (list, tuple)): 347 | encoded_inputs = self._pad( 348 | encoded_inputs, 349 | class_type=class_type, 350 | max_length=max_length, 351 | padding_strategy=padding_strategy, 352 | pad_to_multiple_of=pad_to_multiple_of, 353 | return_attention_mask=return_attention_mask, 354 | ) 355 | return BatchEncoding(encoded_inputs, tensor_type=return_tensors) 356 | 357 | batch_size = len(required_input) 358 | assert all( 359 | len(v) == batch_size for v in encoded_inputs.values() 360 | ), "Some items in the output dictionary have a different batch size than others." 361 | 362 | if padding_strategy == PaddingStrategy.LONGEST: 363 | max_length = max(len(inputs) for inputs in required_input) 364 | padding_strategy = PaddingStrategy.MAX_LENGTH 365 | 366 | batch_outputs = {} 367 | for i in range(batch_size): 368 | inputs = dict((k, v[i]) for k, v in encoded_inputs.items()) 369 | outputs = self._pad( 370 | inputs, 371 | class_type=class_type, 372 | max_length=max_length, 373 | padding_strategy=padding_strategy, 374 | pad_to_multiple_of=pad_to_multiple_of, 375 | return_attention_mask=return_attention_mask, 376 | ) 377 | 378 | for key, value in outputs.items(): 379 | if key not in batch_outputs: 380 | batch_outputs[key] = [] 381 | batch_outputs[key].append(value) 382 | if class_type == "cell": 383 | del batch_outputs["label"] 384 | return BatchEncoding(batch_outputs, tensor_type=return_tensors) 385 | 386 | def _pad( 387 | self, 388 | encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], 389 | class_type, # options: "gene" or "cell" 390 | max_length: Optional[int] = None, 391 | padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST, 392 | pad_to_multiple_of: Optional[int] = None, 393 | return_attention_mask: Optional[bool] = True, 394 | ) -> dict: 395 | """ 396 | Pad encoded inputs (on left/right and up to predefined length or max length in the batch) 397 | 398 | Args: 399 | encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). 400 | max_length: maximum length of the returned list and optionally padding length (see below). 401 | Will truncate by taking into account the special tokens. 402 | padding_strategy: PaddingStrategy to use for padding. 403 | 404 | - PaddingStrategy.LONGEST Pad to the longest sequence in the batch 405 | - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) 406 | - PaddingStrategy.DO_NOT_PAD: Do not pad 407 | The tokenizer padding sides are defined in self.padding_side: 408 | 409 | - 'left': pads on the left of the sequences 410 | - 'right': pads on the right of the sequences 411 | pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. 412 | This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability 413 | >= 7.5 (Volta). 
414 | return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) 415 | """ 416 | # Load from model defaults 417 | if return_attention_mask is None: 418 | return_attention_mask = "attention_mask" in self.model_input_names 419 | 420 | required_input = encoded_inputs[self.model_input_names[0]] 421 | 422 | if padding_strategy == PaddingStrategy.LONGEST: 423 | max_length = len(required_input) 424 | 425 | if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): 426 | max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of 427 | 428 | needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length 429 | 430 | if needs_to_be_padded: 431 | difference = max_length - len(required_input) 432 | if self.padding_side == "right": 433 | if return_attention_mask: 434 | encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference 435 | if "token_type_ids" in encoded_inputs: 436 | encoded_inputs["token_type_ids"] = ( 437 | encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference 438 | ) 439 | if "special_tokens_mask" in encoded_inputs: 440 | encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference 441 | encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference 442 | if class_type == "gene": 443 | encoded_inputs["labels"] = encoded_inputs["labels"] + [-100] * difference 444 | elif self.padding_side == "left": 445 | if return_attention_mask: 446 | encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input) 447 | if "token_type_ids" in encoded_inputs: 448 | encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ 449 | "token_type_ids" 450 | ] 451 | if "special_tokens_mask" in encoded_inputs: 452 | encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] 453 | encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input 454 | if class_type == "gene": 455 | encoded_inputs["labels"] = [-100] * difference + encoded_inputs["labels"] 456 | else: 457 | raise ValueError("Invalid padding strategy:" + str(self.padding_side)) 458 | elif return_attention_mask and "attention_mask" not in encoded_inputs: 459 | encoded_inputs["attention_mask"] = [1] * len(required_input) 460 | 461 | return encoded_inputs 462 | 463 | def get_special_tokens_mask( 464 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 465 | ) -> List[int]: 466 | """ 467 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding 468 | special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. 469 | Args: 470 | token_ids_0 (:obj:`List[int]`): 471 | List of ids of the first sequence. 472 | token_ids_1 (:obj:`List[int]`, `optional`): 473 | List of ids of the second sequence. 474 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): 475 | Whether or not the token list is already formatted with special tokens for the model. 476 | Returns: 477 | A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 478 | """ 479 | assert already_has_special_tokens and token_ids_1 is None, ( 480 | "You cannot use ``already_has_special_tokens=False`` with this tokenizer. 
" 481 | "Please use a slow (full python) tokenizer to activate this argument." 482 | "Or set `return_special_tokens_mask=True` when calling the encoding method " 483 | "to get the special tokens mask in any tokenizer. " 484 | ) 485 | 486 | all_special_ids = self.all_special_ids # cache the property 487 | 488 | special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0] 489 | 490 | return special_tokens_mask 491 | 492 | def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]: 493 | """ 494 | Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the 495 | vocabulary. 496 | Args: 497 | tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s). 498 | Returns: 499 | :obj:`int` or :obj:`List[int]`: The token id or list of token ids. 500 | """ 501 | if tokens is None: 502 | return None 503 | 504 | if isinstance(tokens, str): 505 | return self._convert_token_to_id_with_added_voc(tokens) 506 | 507 | ids = [] 508 | for token in tokens: 509 | ids.append(self._convert_token_to_id_with_added_voc(token)) 510 | return ids 511 | 512 | def _convert_token_to_id_with_added_voc(self, token): 513 | if token is None: 514 | return None 515 | 516 | return token_dictionary.get(token) 517 | 518 | def __len__(self): 519 | return len(token_dictionary) 520 | 521 | 522 | # collator functions 523 | 524 | class DataCollatorForGeneClassification(DataCollatorForTokenClassification): 525 | """ 526 | Data collator that will dynamically pad the inputs received, as well as the labels. 527 | Args: 528 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): 529 | The tokenizer used for encoding the data. 530 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): 531 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 532 | among: 533 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 534 | sequence if provided). 535 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 536 | maximum acceptable input length for the model if that argument is not provided. 537 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 538 | different lengths). 539 | max_length (:obj:`int`, `optional`): 540 | Maximum length of the returned list and optionally padding length (see above). 541 | pad_to_multiple_of (:obj:`int`, `optional`): 542 | If set will pad the sequence to a multiple of the provided value. 543 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 544 | 7.5 (Volta). 545 | label_pad_token_id (:obj:`int`, `optional`, defaults to -100): 546 | The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions). 
547 | """ 548 | 549 | tokenizer = PrecollatorForGeneAndCellClassification() 550 | class_type = "gene" 551 | padding: Union[bool, str, PaddingStrategy] = True 552 | max_length: Optional[int] = None 553 | pad_to_multiple_of: Optional[int] = None 554 | label_pad_token_id: int = -100 555 | 556 | def __init__(self, *args, **kwargs) -> None: 557 | super().__init__( 558 | tokenizer=self.tokenizer, 559 | padding=self.padding, 560 | max_length=self.max_length, 561 | pad_to_multiple_of=self.pad_to_multiple_of, 562 | label_pad_token_id=self.label_pad_token_id, 563 | *args, **kwargs) 564 | 565 | def _prepare_batch(self, features): 566 | label_name = "label" if "label" in features[0].keys() else "labels" 567 | labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None 568 | batch = self.tokenizer.pad( 569 | features, 570 | class_type=self.class_type, 571 | padding=self.padding, 572 | max_length=self.max_length, 573 | pad_to_multiple_of=self.pad_to_multiple_of, 574 | return_tensors="pt", 575 | ) 576 | return batch 577 | 578 | def __call__(self, features): 579 | batch = self._prepare_batch(features) 580 | 581 | batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()} 582 | return batch 583 | 584 | 585 | class DataCollatorForCellClassification(DataCollatorForGeneClassification): 586 | 587 | class_type = "cell" 588 | 589 | def _prepare_batch(self, features): 590 | 591 | batch = super()._prepare_batch(features) 592 | 593 | # Special handling for labels. 594 | # Ensure that tensor is created with the correct type 595 | # (it should be automatically the case, but let's make sure of it.) 596 | first = features[0] 597 | if "label" in first and first["label"] is not None: 598 | label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"] 599 | dtype = torch.long if isinstance(label, int) else torch.float 600 | batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype) 601 | 602 | return batch 603 | -------------------------------------------------------------------------------- /geneformer_001/geneformer/emb_extractor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Geneformer embedding extractor. 
3 | 4 | Usage: 5 | from geneformer import EmbExtractor 6 | embex = EmbExtractor(model_type="CellClassifier", 7 | num_classes=3, 8 | emb_mode="cell", 9 | cell_emb_style="mean_pool", 10 | filter_data={"cell_type":["cardiomyocyte"]}, 11 | max_ncells=1000, 12 | max_ncells_to_plot=1000, 13 | emb_layer=-1, 14 | emb_label=["disease","cell_type"], 15 | labels_to_plot=["disease","cell_type"], 16 | forward_batch_size=100, 17 | nproc=16) 18 | embs = embex.extract_embs("path/to/model", 19 | "path/to/input_data", 20 | "path/to/output_directory", 21 | "output_prefix") 22 | embex.plot_embs(embs=embs, 23 | plot_style="heatmap", 24 | output_directory="path/to/output_directory", 25 | output_prefix="output_prefix") 26 | 27 | """ 28 | 29 | # imports 30 | import logging 31 | import anndata 32 | import matplotlib.pyplot as plt 33 | import numpy as np 34 | import pandas as pd 35 | import pickle 36 | import scanpy as sc 37 | import seaborn as sns 38 | import torch 39 | from collections import Counter 40 | from pathlib import Path 41 | from tqdm.notebook import trange 42 | from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification 43 | 44 | from .tokenizer import TOKEN_DICTIONARY_FILE 45 | 46 | from .in_silico_perturber import downsample_and_sort, \ 47 | gen_attention_mask, \ 48 | get_model_input_size, \ 49 | load_and_filter, \ 50 | load_model, \ 51 | mean_nonpadding_embs, \ 52 | pad_tensor_list, \ 53 | quant_layers 54 | 55 | logger = logging.getLogger(__name__) 56 | 57 | # average embedding position of goal cell states 58 | def get_embs(model, 59 | filtered_input_data, 60 | emb_mode, 61 | layer_to_quant, 62 | pad_token_id, 63 | forward_batch_size): 64 | 65 | model_input_size = get_model_input_size(model) 66 | total_batch_length = len(filtered_input_data) 67 | if ((total_batch_length-1)/forward_batch_size).is_integer(): 68 | forward_batch_size = forward_batch_size-1 69 | 70 | embs_list = [] 71 | for i in trange(0, total_batch_length, forward_batch_size): 72 | max_range = min(i+forward_batch_size, total_batch_length) 73 | 74 | minibatch = filtered_input_data.select([i for i in range(i, max_range)]) 75 | max_len = max(minibatch["length"]) 76 | original_lens = torch.tensor(minibatch["length"]).to("cuda") 77 | minibatch.set_format(type="torch") 78 | 79 | input_data_minibatch = minibatch["input_ids"] 80 | input_data_minibatch = pad_tensor_list(input_data_minibatch, 81 | max_len, 82 | pad_token_id, 83 | model_input_size) 84 | 85 | with torch.no_grad(): 86 | outputs = model( 87 | input_ids = input_data_minibatch.to("cuda"), 88 | attention_mask = gen_attention_mask(minibatch) 89 | ) 90 | 91 | embs_i = outputs.hidden_states[layer_to_quant] 92 | 93 | if emb_mode == "cell": 94 | mean_embs = mean_nonpadding_embs(embs_i, original_lens) 95 | embs_list += [mean_embs] 96 | 97 | del outputs 98 | del minibatch 99 | del input_data_minibatch 100 | del embs_i 101 | del mean_embs 102 | torch.cuda.empty_cache() 103 | 104 | embs_stack = torch.cat(embs_list) 105 | return embs_stack 106 | 107 | def label_embs(embs, downsampled_data, emb_labels): 108 | embs_df = pd.DataFrame(embs.cpu()) 109 | if emb_labels is not None: 110 | for label in emb_labels: 111 | emb_label = downsampled_data[label] 112 | embs_df[label] = emb_label 113 | return embs_df 114 | 115 | def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict): 116 | only_embs_df = embs_df.iloc[:,:emb_dims] 117 | only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str) 118 | only_embs_df.columns = pd.RangeIndex(0, 
only_embs_df.shape[1], name=None).astype(str) 119 | vars_dict = {"embs": only_embs_df.columns} 120 | obs_dict = {"cell_id": list(only_embs_df.index), 121 | f"{label}": list(embs_df[label])} 122 | adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict) 123 | sc.tl.pca(adata, svd_solver='arpack') 124 | sc.pp.neighbors(adata) 125 | sc.tl.umap(adata) 126 | sns.set(rc={'figure.figsize':(10,10)}, font_scale=2.3) 127 | sns.set_style("white") 128 | default_kwargs_dict = {"palette":"Set2", "size":200} 129 | if kwargs_dict is not None: 130 | default_kwargs_dict.update(kwargs_dict) 131 | 132 | sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict) 133 | 134 | 135 | def gen_heatmap_class_colors(labels, df): 136 | pal = sns.cubehelix_palette(len(Counter(labels).keys()), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2) 137 | lut = dict(zip(map(str, Counter(labels).keys()), pal)) 138 | colors = pd.Series(labels, index=df.index).map(lut) 139 | return colors 140 | 141 | def gen_heatmap_class_dict(classes, label_colors_series): 142 | class_color_dict_df = pd.DataFrame({"classes": classes, "color": label_colors_series}) 143 | class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"]) 144 | return dict(zip(class_color_dict_df["classes"],class_color_dict_df["color"])) 145 | 146 | def make_colorbar(embs_df, label): 147 | 148 | labels = list(embs_df[label]) 149 | 150 | cell_type_colors = gen_heatmap_class_colors(labels, embs_df) 151 | label_colors = pd.DataFrame(cell_type_colors, columns=[label]) 152 | 153 | for i,row in label_colors.iterrows(): 154 | colors=row[0] 155 | if len(colors)!=3 or any(np.isnan(colors)): 156 | print(i,colors) 157 | 158 | label_colors.isna().sum() 159 | 160 | # create dictionary for colors and classes 161 | label_color_dict = gen_heatmap_class_dict(labels, label_colors[label]) 162 | return label_colors, label_color_dict 163 | 164 | def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict): 165 | sns.set_style("white") 166 | sns.set(font_scale=2) 167 | plt.figure(figsize=(15, 15), dpi=150) 168 | label_colors, label_color_dict = make_colorbar(embs_df, label) 169 | 170 | default_kwargs_dict = {"row_cluster": True, 171 | "col_cluster": True, 172 | "row_colors": label_colors, 173 | "standard_scale": 1, 174 | "linewidths": 0, 175 | "xticklabels": False, 176 | "yticklabels": False, 177 | "figsize": (15,15), 178 | "center": 0, 179 | "cmap": "magma"} 180 | 181 | if kwargs_dict is not None: 182 | default_kwargs_dict.update(kwargs_dict) 183 | g = sns.clustermap(embs_df.iloc[:,0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict) 184 | 185 | plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right") 186 | 187 | for label_color in list(label_color_dict.keys()): 188 | g.ax_col_dendrogram.bar(0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0) 189 | 190 | l1 = g.ax_col_dendrogram.legend(title=f"{label}", 191 | loc="lower center", 192 | ncol=4, 193 | bbox_to_anchor=(0.5, 1), 194 | facecolor="white") 195 | 196 | plt.savefig(output_file, bbox_inches='tight') 197 | 198 | class EmbExtractor: 199 | valid_option_dict = { 200 | "model_type": {"Pretrained","GeneClassifier","CellClassifier"}, 201 | "num_classes": {int}, 202 | "emb_mode": {"cell","gene"}, 203 | "cell_emb_style": {"mean_pool"}, 204 | "filter_data": {None, dict}, 205 | "max_ncells": {None, int}, 206 | "emb_layer": {-1, 0}, 207 | "emb_label": {None, list}, 208 | "labels_to_plot": {None, list}, 209 | "forward_batch_size": {int}, 210 | 
"nproc": {int}, 211 | } 212 | def __init__( 213 | self, 214 | model_type="Pretrained", 215 | num_classes=0, 216 | emb_mode="cell", 217 | cell_emb_style="mean_pool", 218 | filter_data=None, 219 | max_ncells=1000, 220 | emb_layer=-1, 221 | emb_label=None, 222 | labels_to_plot=None, 223 | forward_batch_size=100, 224 | nproc=4, 225 | token_dictionary_file=TOKEN_DICTIONARY_FILE, 226 | ): 227 | """ 228 | Initialize embedding extractor. 229 | 230 | Parameters 231 | ---------- 232 | model_type : {"Pretrained","GeneClassifier","CellClassifier"} 233 | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier. 234 | num_classes : int 235 | If model is a gene or cell classifier, specify number of classes it was trained to classify. 236 | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier. 237 | emb_mode : {"cell","gene"} 238 | Whether to output cell or gene embeddings. 239 | cell_emb_style : "mean_pool" 240 | Method for summarizing cell embeddings. 241 | Currently only option is mean pooling of gene embeddings for given cell. 242 | filter_data : None, dict 243 | Default is to extract embeddings from all input data. 244 | Otherwise, dictionary specifying .dataset column name and list of values to filter by. 245 | max_ncells : None, int 246 | Maximum number of cells to extract embeddings from. 247 | Default is 1000 cells randomly sampled from input data. 248 | If None, will extract embeddings from all cells. 249 | emb_layer : {-1, 0} 250 | Embedding layer to extract. 251 | The last layer is most specifically weighted to optimize the given learning objective. 252 | Generally, it is best to extract the 2nd to last layer to get a more general representation. 253 | -1: 2nd to last layer 254 | 0: last layer 255 | emb_label : None, list 256 | List of column name(s) in .dataset to add as labels to embedding output. 257 | labels_to_plot : None, list 258 | Cell labels to plot. 259 | Shown as color bar in heatmap. 260 | Shown as cell color in umap. 261 | Plotting umap requires labels to plot. 262 | forward_batch_size : int 263 | Batch size for forward pass. 264 | nproc : int 265 | Number of CPU processes to use. 266 | token_dictionary_file : Path 267 | Path to pickle file containing token dictionary (Ensembl ID:token). 268 | """ 269 | 270 | self.model_type = model_type 271 | self.num_classes = num_classes 272 | self.emb_mode = emb_mode 273 | self.cell_emb_style = cell_emb_style 274 | self.filter_data = filter_data 275 | self.max_ncells = max_ncells 276 | self.emb_layer = emb_layer 277 | self.emb_label = emb_label 278 | self.labels_to_plot = labels_to_plot 279 | self.forward_batch_size = forward_batch_size 280 | self.nproc = nproc 281 | 282 | self.validate_options() 283 | 284 | # load token dictionary (Ensembl IDs:token) 285 | with open(token_dictionary_file, "rb") as f: 286 | self.gene_token_dict = pickle.load(f) 287 | 288 | self.pad_token_id = self.gene_token_dict.get("") 289 | 290 | 291 | def validate_options(self): 292 | # first disallow options under development 293 | if self.emb_mode == "gene": 294 | logger.error( 295 | "Extraction and plotting of gene-level embeddings currently under development. 
" \ 296 | "Current valid option for 'emb_mode': 'cell'" 297 | ) 298 | raise 299 | 300 | # confirm arguments are within valid options and compatible with each other 301 | for attr_name,valid_options in self.valid_option_dict.items(): 302 | attr_value = self.__dict__[attr_name] 303 | if type(attr_value) not in {list, dict}: 304 | if attr_value in valid_options: 305 | continue 306 | valid_type = False 307 | for option in valid_options: 308 | if (option in [int,list,dict]) and isinstance(attr_value, option): 309 | valid_type = True 310 | break 311 | if valid_type: 312 | continue 313 | logger.error( 314 | f"Invalid option for {attr_name}. " \ 315 | f"Valid options for {attr_name}: {valid_options}" 316 | ) 317 | raise 318 | 319 | if self.filter_data is not None: 320 | for key,value in self.filter_data.items(): 321 | if type(value) != list: 322 | self.filter_data[key] = [value] 323 | logger.warning( 324 | "Values in filter_data dict must be lists. " \ 325 | f"Changing {key} value to list ([{value}]).") 326 | 327 | def extract_embs(self, 328 | model_directory, 329 | input_data_file, 330 | output_directory, 331 | output_prefix): 332 | """ 333 | Extract embeddings from input data and save as results in output_directory. 334 | 335 | Parameters 336 | ---------- 337 | model_directory : Path 338 | Path to directory containing model 339 | input_data_file : Path 340 | Path to directory containing .dataset inputs 341 | output_directory : Path 342 | Path to directory where embedding data will be saved as csv 343 | output_prefix : str 344 | Prefix for output file 345 | """ 346 | 347 | filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file) 348 | downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells) 349 | model = load_model(self.model_type, self.num_classes, model_directory) 350 | layer_to_quant = quant_layers(model)+self.emb_layer 351 | embs = get_embs(model, 352 | downsampled_data, 353 | self.emb_mode, 354 | layer_to_quant, 355 | self.pad_token_id, 356 | self.forward_batch_size) 357 | embs_df = label_embs(embs, downsampled_data, self.emb_label) 358 | 359 | # save embeddings to output_path 360 | output_path = (Path(output_directory) / output_prefix).with_suffix(".csv") 361 | embs_df.to_csv(output_path) 362 | 363 | return embs_df 364 | 365 | def plot_embs(self, 366 | embs, 367 | plot_style, 368 | output_directory, 369 | output_prefix, 370 | max_ncells_to_plot=1000, 371 | kwargs_dict=None): 372 | 373 | """ 374 | Plot embeddings, coloring by provided labels. 375 | 376 | Parameters 377 | ---------- 378 | embs : pandas.core.frame.DataFrame 379 | Pandas dataframe containing embeddings output from extract_embs 380 | plot_style : str 381 | Style of plot: "heatmap" or "umap" 382 | output_directory : Path 383 | Path to directory where plots will be saved as pdf 384 | output_prefix : str 385 | Prefix for output file 386 | max_ncells_to_plot : None, int 387 | Maximum number of cells to plot. 388 | Default is 1000 cells randomly sampled from embeddings. 389 | If None, will plot embeddings from all cells. 390 | kwargs_dict : dict 391 | Dictionary of kwargs to pass to plotting function. 392 | """ 393 | 394 | if plot_style not in ["heatmap","umap"]: 395 | logger.error( 396 | "Invalid option for 'plot_style'. " \ 397 | "Valid options: {'heatmap','umap'}" 398 | ) 399 | raise 400 | 401 | if (plot_style == "umap") and (self.labels_to_plot is None): 402 | logger.error( 403 | "Plotting UMAP requires 'labels_to_plot'. 
" 404 | ) 405 | raise 406 | 407 | if max_ncells_to_plot > self.max_ncells: 408 | max_ncells_to_plot = self.max_ncells 409 | logger.warning( 410 | "max_ncells_to_plot must be <= max_ncells. " \ 411 | f"Changing max_ncells_to_plot to {self.max_ncells}.") 412 | 413 | if (max_ncells_to_plot is not None) \ 414 | and (max_ncells_to_plot < self.max_ncells): 415 | embs = embs.sample(max_ncells_to_plot, axis=0) 416 | 417 | if self.emb_label is None: 418 | label_len = 0 419 | else: 420 | label_len = len(self.emb_label) 421 | 422 | emb_dims = embs.shape[1] - label_len 423 | 424 | if self.emb_label is None: 425 | emb_labels = None 426 | else: 427 | emb_labels = embs.columns[emb_dims:] 428 | 429 | if plot_style == "umap": 430 | for label in self.labels_to_plot: 431 | if label not in emb_labels: 432 | logger.warning( 433 | f"Label {label} from labels_to_plot " \ 434 | f"not present in provided embeddings dataframe.") 435 | continue 436 | output_prefix_label = "_" + output_prefix + f"_umap_{label}" 437 | output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf") 438 | plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict) 439 | 440 | if plot_style == "heatmap": 441 | for label in self.labels_to_plot: 442 | if label not in emb_labels: 443 | logger.warning( 444 | f"Label {label} from labels_to_plot " \ 445 | f"not present in provided embeddings dataframe.") 446 | continue 447 | output_prefix_label = output_prefix + f"_heatmap_{label}" 448 | output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf") 449 | plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict) 450 | -------------------------------------------------------------------------------- /geneformer_001/geneformer/gene_median_dictionary.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PharMolix/LangCell/69e41ef2fae485b67703294b50f3150bb1d5bb9b/geneformer_001/geneformer/gene_median_dictionary.pkl -------------------------------------------------------------------------------- /geneformer_001/geneformer/gene_name_id_dict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PharMolix/LangCell/69e41ef2fae485b67703294b50f3150bb1d5bb9b/geneformer_001/geneformer/gene_name_id_dict.pkl -------------------------------------------------------------------------------- /geneformer_001/geneformer/in_silico_perturber_stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Geneformer in silico perturber stats generator. 
3 | 4 | Usage: 5 | from geneformer import InSilicoPerturberStats 6 | ispstats = InSilicoPerturberStats(mode="goal_state_shift", 7 | combos=0, 8 | anchor_gene=None, 9 | cell_states_to_model={"state_key": "disease", 10 | "start_state": "dcm", 11 | "goal_state": "nf", 12 | "alt_states": ["hcm", "other1", "other2"]}) 13 | ispstats.get_stats("path/to/input_data", 14 | None, 15 | "path/to/output_directory", 16 | "output_prefix") 17 | """ 18 | 19 | 20 | import os 21 | import logging 22 | import numpy as np 23 | import pandas as pd 24 | import pickle 25 | import random 26 | import statsmodels.stats.multitest as smt 27 | from pathlib import Path 28 | from scipy.stats import ranksums 29 | from sklearn.mixture import GaussianMixture 30 | from tqdm.notebook import trange, tqdm 31 | 32 | from .in_silico_perturber import flatten_list 33 | 34 | from .tokenizer import TOKEN_DICTIONARY_FILE 35 | 36 | GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl" 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | # invert dictionary keys/values 41 | def invert_dict(dictionary): 42 | return {v: k for k, v in dictionary.items()} 43 | 44 | # read raw dictionary files 45 | def read_dictionaries(input_data_directory, cell_or_gene_emb, anchor_token): 46 | file_found = 0 47 | file_path_list = [] 48 | dict_list = [] 49 | for file in os.listdir(input_data_directory): 50 | # process only _raw.pickle files 51 | if file.endswith("_raw.pickle"): 52 | file_found = 1 53 | file_path_list += [f"{input_data_directory}/{file}"] 54 | for file_path in tqdm(file_path_list): 55 | with open(file_path, "rb") as fp: 56 | cos_sims_dict = pickle.load(fp) 57 | if cell_or_gene_emb == "cell": 58 | cell_emb_dict = {k: v for k, 59 | v in cos_sims_dict.items() if v and "cell_emb" in k} 60 | dict_list += [cell_emb_dict] 61 | elif cell_or_gene_emb == "gene": 62 | gene_emb_dict = {k: v for k, 63 | v in cos_sims_dict.items() if v and anchor_token == k[0]} 64 | dict_list += [gene_emb_dict] 65 | if file_found == 0: 66 | logger.error( 67 | "No raw data for processing found within provided directory. 
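# Illustrative note (not part of the original file): each *_raw.pickle holds a
# dict keyed by tuples, e.g. (gene_token, "cell_emb") -> [cosine shifts across
# cells] for cell-level data, or (anchor_token, gene_token) -> [...] for
# gene-level data; read_dictionaries simply filters and collects these dicts.
# A hypothetical shape:
#
#   cos_sims_dict = {(2041, "cell_emb"): [0.12, 0.08, ...],
#                    (2041, 305): [0.03, ...]}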
" \ 68 | "Please ensure data files end with '_raw.pickle'.") 69 | raise 70 | return dict_list 71 | 72 | # get complete gene list 73 | def get_gene_list(dict_list,mode): 74 | if mode == "cell": 75 | position = 0 76 | elif mode == "gene": 77 | position = 1 78 | gene_set = set() 79 | for dict_i in dict_list: 80 | gene_set.update([k[position] for k, v in dict_i.items() if v]) 81 | gene_list = list(gene_set) 82 | if mode == "gene": 83 | gene_list.remove("cell_emb") 84 | gene_list.sort() 85 | return gene_list 86 | 87 | def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict): 88 | return tuple([gene_token_id_dict.get(i, np.nan) for i in token_tuple]) 89 | 90 | def n_detections(token, dict_list, mode, anchor_token): 91 | cos_sim_megalist = [] 92 | for dict_i in dict_list: 93 | if mode == "cell": 94 | cos_sim_megalist += dict_i.get((token, "cell_emb"),[]) 95 | elif mode == "gene": 96 | cos_sim_megalist += dict_i.get((anchor_token, token),[]) 97 | return len(cos_sim_megalist) 98 | 99 | def get_fdr(pvalues): 100 | return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1]) 101 | 102 | def get_impact_component(test_value, gaussian_mixture_model): 103 | impact_border = gaussian_mixture_model.means_[0][0] 104 | nonimpact_border = gaussian_mixture_model.means_[1][0] 105 | if test_value > nonimpact_border: 106 | impact_component = 0 107 | elif test_value < impact_border: 108 | impact_component = 1 109 | else: 110 | impact_component_raw = gaussian_mixture_model.predict([[test_value]])[0] 111 | if impact_component_raw == 1: 112 | impact_component = 0 113 | elif impact_component_raw == 0: 114 | impact_component = 1 115 | return impact_component 116 | 117 | # aggregate data for single perturbation in multiple cells 118 | def isp_aggregate_grouped_perturb(cos_sims_df, dict_list): 119 | names=["Cosine_shift"] 120 | cos_sims_full_df = pd.DataFrame(columns=names) 121 | 122 | cos_shift_data = [] 123 | token = cos_sims_df["Gene"][0] 124 | for dict_i in dict_list: 125 | cos_shift_data += dict_i.get((token, "cell_emb"),[]) 126 | cos_sims_full_df["Cosine_shift"] = cos_shift_data 127 | return cos_sims_full_df 128 | 129 | # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations 130 | def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_perturbed): 131 | cell_state_key = cell_states_to_model["start_state"] 132 | if ("alt_states" not in cell_states_to_model.keys()) \ 133 | or (len(cell_states_to_model["alt_states"]) == 0) \ 134 | or (cell_states_to_model["alt_states"] == [None]): 135 | alt_end_state_exists = False 136 | elif (len(cell_states_to_model["alt_states"]) > 0) and (cell_states_to_model["alt_states"] != [None]): 137 | alt_end_state_exists = True 138 | 139 | # for single perturbation in multiple cells, there are no random perturbations to compare to 140 | if genes_perturbed != "all": 141 | names=["Shift_to_goal_end", 142 | "Shift_to_alt_end"] 143 | if alt_end_state_exists == False: 144 | names.remove("Shift_to_alt_end") 145 | cos_sims_full_df = pd.DataFrame(columns=names) 146 | 147 | cos_shift_data = [] 148 | token = cos_sims_df["Gene"][0] 149 | for dict_i in dict_list: 150 | cos_shift_data += dict_i.get((token, "cell_emb"),[]) 151 | if alt_end_state_exists == False: 152 | cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end in cos_shift_data] 153 | if alt_end_state_exists == True: 154 | cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end,alt_end in cos_shift_data] 155 | 
cos_sims_full_df["Shift_to_alt_end"] = [alt_end for start_state,goal_end,alt_end in cos_shift_data] 156 | 157 | # sort by shift to desired state 158 | cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_to_goal_end"], 159 | ascending=[False]) 160 | return cos_sims_full_df 161 | 162 | elif genes_perturbed == "all": 163 | random_tuples = [] 164 | for i in trange(cos_sims_df.shape[0]): 165 | token = cos_sims_df["Gene"][i] 166 | for dict_i in dict_list: 167 | random_tuples += dict_i.get((token, "cell_emb"),[]) 168 | 169 | if alt_end_state_exists == False: 170 | goal_end_random_megalist = [goal_end for start_state,goal_end in random_tuples] 171 | elif alt_end_state_exists == True: 172 | goal_end_random_megalist = [goal_end for start_state,goal_end,alt_end in random_tuples] 173 | alt_end_random_megalist = [alt_end for start_state,goal_end,alt_end in random_tuples] 174 | 175 | # downsample to improve speed of ranksums 176 | if len(goal_end_random_megalist) > 100_000: 177 | random.seed(42) 178 | goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000) 179 | if alt_end_state_exists == True: 180 | if len(alt_end_random_megalist) > 100_000: 181 | random.seed(42) 182 | alt_end_random_megalist = random.sample(alt_end_random_megalist, k=100_000) 183 | 184 | names=["Gene", 185 | "Gene_name", 186 | "Ensembl_ID", 187 | "Shift_to_goal_end", 188 | "Shift_to_alt_end", 189 | "Goal_end_vs_random_pval", 190 | "Alt_end_vs_random_pval"] 191 | if alt_end_state_exists == False: 192 | names.remove("Shift_to_alt_end") 193 | names.remove("Alt_end_vs_random_pval") 194 | cos_sims_full_df = pd.DataFrame(columns=names) 195 | 196 | for i in trange(cos_sims_df.shape[0]): 197 | token = cos_sims_df["Gene"][i] 198 | name = cos_sims_df["Gene_name"][i] 199 | ensembl_id = cos_sims_df["Ensembl_ID"][i] 200 | cos_shift_data = [] 201 | 202 | for dict_i in dict_list: 203 | cos_shift_data += dict_i.get((token, "cell_emb"),[]) 204 | 205 | if alt_end_state_exists == False: 206 | goal_end_cos_sim_megalist = [goal_end for start_state,goal_end in cos_shift_data] 207 | elif alt_end_state_exists == True: 208 | goal_end_cos_sim_megalist = [goal_end for start_state,goal_end,alt_end in cos_shift_data] 209 | alt_end_cos_sim_megalist = [alt_end for start_state,goal_end,alt_end in cos_shift_data] 210 | mean_alt_end = np.mean(alt_end_cos_sim_megalist) 211 | pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue 212 | 213 | mean_goal_end = np.mean(goal_end_cos_sim_megalist) 214 | pval_goal_end = ranksums(goal_end_random_megalist,goal_end_cos_sim_megalist).pvalue 215 | 216 | if alt_end_state_exists == False: 217 | data_i = [token, 218 | name, 219 | ensembl_id, 220 | mean_goal_end, 221 | pval_goal_end] 222 | elif alt_end_state_exists == True: 223 | data_i = [token, 224 | name, 225 | ensembl_id, 226 | mean_goal_end, 227 | mean_alt_end, 228 | pval_goal_end, 229 | pval_alt_end] 230 | 231 | cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i]) 232 | cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i]) 233 | 234 | cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"])) 235 | if alt_end_state_exists == True: 236 | cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"])) 237 | 238 | # quantify number of detections of each gene 239 | cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]] 240 | 241 | # sort by shift to desired state\ 242 | 
cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]] 243 | cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig", 244 | "Shift_to_goal_end", 245 | "Goal_end_FDR"], 246 | ascending=[False,False,True]) 247 | 248 | return cos_sims_full_df 249 | 250 | # stats comparing cos sim shifts of test perturbations vs null distribution 251 | def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list): 252 | cos_sims_full_df = cos_sims_df.copy() 253 | 254 | cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float) 255 | cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float) 256 | cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float) 257 | cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float) 258 | cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float) 259 | cos_sims_full_df["N_Detections_test"] = np.zeros(cos_sims_df.shape[0], dtype="uint32") 260 | cos_sims_full_df["N_Detections_null"] = np.zeros(cos_sims_df.shape[0], dtype="uint32") 261 | 262 | for i in trange(cos_sims_df.shape[0]): 263 | token = cos_sims_df["Gene"][i] 264 | test_shifts = [] 265 | null_shifts = [] 266 | 267 | for dict_i in dict_list: 268 | test_shifts += dict_i.get((token, "cell_emb"),[]) 269 | 270 | for dict_i in null_dict_list: 271 | null_shifts += dict_i.get((token, "cell_emb"),[]) 272 | 273 | cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts) 274 | cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts) 275 | cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(test_shifts)-np.mean(null_shifts) 276 | cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(test_shifts, 277 | null_shifts, nan_policy="omit").pvalue 278 | 279 | cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts) 280 | cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts) 281 | 282 | cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(cos_sims_full_df["Test_vs_null_pval"]) 283 | 284 | cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]] 285 | cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig", 286 | "Test_vs_null_avg_shift", 287 | "Test_vs_null_FDR"], 288 | ascending=[False,False,True]) 289 | return cos_sims_full_df 290 | 291 | # stats for identifying perturbations with largest effect within a given set of cells 292 | # fits a mixture model to 2 components (impact vs. 
non-impact) and 293 | # reports the most likely component for each test perturbation 294 | # Note: because assumes given perturbation has a consistent effect in the cells tested, 295 | # we recommend only using the mixture model strategy with uniform cell populations 296 | def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token): 297 | 298 | names=["Gene", 299 | "Gene_name", 300 | "Ensembl_ID"] 301 | 302 | if combos == 0: 303 | names += ["Test_avg_shift"] 304 | elif combos == 1: 305 | names += ["Anchor_shift", 306 | "Test_token_shift", 307 | "Sum_of_indiv_shifts", 308 | "Combo_shift", 309 | "Combo_minus_sum_shift"] 310 | 311 | names += ["Impact_component", 312 | "Impact_component_percent"] 313 | 314 | cos_sims_full_df = pd.DataFrame(columns=names) 315 | avg_values = [] 316 | gene_names = [] 317 | 318 | for i in trange(cos_sims_df.shape[0]): 319 | token = cos_sims_df["Gene"][i] 320 | name = cos_sims_df["Gene_name"][i] 321 | ensembl_id = cos_sims_df["Ensembl_ID"][i] 322 | cos_shift_data = [] 323 | 324 | for dict_i in dict_list: 325 | if (combos == 0) and (anchor_token is not None): 326 | cos_shift_data += dict_i.get((anchor_token, token),[]) 327 | else: 328 | cos_shift_data += dict_i.get((token, "cell_emb"),[]) 329 | 330 | # Extract values for current gene 331 | if combos == 0: 332 | test_values = cos_shift_data 333 | elif combos == 1: 334 | test_values = [] 335 | for tup in cos_shift_data: 336 | test_values.append(tup[2]) 337 | 338 | if len(test_values) > 0: 339 | avg_value = np.mean(test_values) 340 | avg_values.append(avg_value) 341 | gene_names.append(name) 342 | 343 | # fit Gaussian mixture model to dataset of mean for each gene 344 | avg_values_to_fit = np.array(avg_values).reshape(-1, 1) 345 | gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit) 346 | 347 | for i in trange(cos_sims_df.shape[0]): 348 | token = cos_sims_df["Gene"][i] 349 | name = cos_sims_df["Gene_name"][i] 350 | ensembl_id = cos_sims_df["Ensembl_ID"][i] 351 | cos_shift_data = [] 352 | 353 | for dict_i in dict_list: 354 | if (combos == 0) and (anchor_token is not None): 355 | cos_shift_data += dict_i.get((anchor_token, token),[]) 356 | else: 357 | cos_shift_data += dict_i.get((token, "cell_emb"),[]) 358 | 359 | if combos == 0: 360 | mean_test = np.mean(cos_shift_data) 361 | impact_components = [get_impact_component(value,gm) for value in cos_shift_data] 362 | elif combos == 1: 363 | anchor_cos_sim_megalist = [anchor for anchor,token,combo in cos_shift_data] 364 | token_cos_sim_megalist = [token for anchor,token,combo in cos_shift_data] 365 | anchor_plus_token_cos_sim_megalist = [1-((1-anchor)+(1-token)) for anchor,token,combo in cos_shift_data] 366 | combo_anchor_token_cos_sim_megalist = [combo for anchor,token,combo in cos_shift_data] 367 | combo_minus_sum_cos_sim_megalist = [combo-(1-((1-anchor)+(1-token))) for anchor,token,combo in cos_shift_data] 368 | 369 | mean_anchor = np.mean(anchor_cos_sim_megalist) 370 | mean_token = np.mean(token_cos_sim_megalist) 371 | mean_sum = np.mean(anchor_plus_token_cos_sim_megalist) 372 | mean_test = np.mean(combo_anchor_token_cos_sim_megalist) 373 | mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist) 374 | 375 | impact_components = [get_impact_component(value,gm) for value in combo_anchor_token_cos_sim_megalist] 376 | 377 | impact_component = get_impact_component(mean_test,gm) 378 | impact_component_percent = np.mean(impact_components)*100 379 | 380 | data_i = [token, 381 | name, 382 | ensembl_id] 383 | if combos == 0: 384 | data_i 
+= [mean_test] 385 | elif combos == 1: 386 | data_i += [mean_anchor, 387 | mean_token, 388 | mean_sum, 389 | mean_test, 390 | mean_combo_minus_sum] 391 | data_i += [impact_component, 392 | impact_component_percent] 393 | 394 | cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i]) 395 | cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i]) 396 | 397 | # quantify number of detections of each gene 398 | cos_sims_full_df["N_Detections"] = [n_detections(i, 399 | dict_list, 400 | "gene", 401 | anchor_token) for i in cos_sims_full_df["Gene"]] 402 | 403 | if combos == 0: 404 | cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component", 405 | "Test_avg_shift"], 406 | ascending=[False,True]) 407 | elif combos == 1: 408 | cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component", 409 | "Combo_minus_sum_shift"], 410 | ascending=[False,True]) 411 | return cos_sims_full_df 412 | 413 | class InSilicoPerturberStats: 414 | valid_option_dict = { 415 | "mode": {"goal_state_shift","vs_null","mixture_model","aggregate_data"}, 416 | "combos": {0,1}, 417 | "anchor_gene": {None, str}, 418 | "cell_states_to_model": {None, dict}, 419 | } 420 | def __init__( 421 | self, 422 | mode="mixture_model", 423 | genes_perturbed="all", 424 | combos=0, 425 | anchor_gene=None, 426 | cell_states_to_model=None, 427 | token_dictionary_file=TOKEN_DICTIONARY_FILE, 428 | gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE, 429 | ): 430 | """ 431 | Initialize in silico perturber stats generator. 432 | 433 | Parameters 434 | ---------- 435 | mode : {"goal_state_shift","vs_null","mixture_model","aggregate_data"} 436 | Type of stats. 437 | "goal_state_shift": perturbation vs. random for desired cell state shift 438 | "vs_null": perturbation vs. null from provided null distribution dataset 439 | "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction) 440 | "aggregate_data": aggregates cosine shifts for single perturbation in multiple cells 441 | genes_perturbed : "all", list 442 | Genes perturbed in isp experiment. 443 | Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell). 444 | Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together. 445 | combos : {0,1,2} 446 | Whether to perturb genes individually (0), in pairs (1), or in triplets (2). 447 | anchor_gene : None, str 448 | ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes. 449 | For example, if combos=1 and anchor_gene="ENSG00000136574": 450 | analyzes data for anchor gene perturbed in combination with each other gene. 451 | However, if combos=0 and anchor_gene="ENSG00000136574": 452 | analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene. 453 | cell_states_to_model: None, dict 454 | Cell states to model if testing perturbations that achieve goal state change. 
455 | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
456 | state_key: key specifying name of column in .dataset that defines the start/goal states
457 | start_state: value in the state_key column that specifies the start state
458 | goal_state: value in the state_key column that specifies the goal end state
459 | alt_states: list of values in the state_key column that specify the alternate end states
460 | For example: {"state_key": "disease",
461 | "start_state": "dcm",
462 | "goal_state": "nf",
463 | "alt_states": ["hcm", "other1", "other2"]}
464 | token_dictionary_file : Path
465 | Path to pickle file containing token dictionary (Ensembl ID:token).
466 | gene_name_id_dictionary_file : Path
467 | Path to pickle file containing gene name to ID dictionary (gene name:Ensembl ID).
468 | """
469 | 
470 | self.mode = mode
471 | self.genes_perturbed = genes_perturbed
472 | self.combos = combos
473 | self.anchor_gene = anchor_gene
474 | self.cell_states_to_model = cell_states_to_model
475 | 
476 | self.validate_options()
477 | 
478 | # load token dictionary (Ensembl IDs:token)
479 | with open(token_dictionary_file, "rb") as f:
480 | self.gene_token_dict = pickle.load(f)
481 | 
482 | # load gene name dictionary (gene name:Ensembl ID)
483 | with open(gene_name_id_dictionary_file, "rb") as f:
484 | self.gene_name_id_dict = pickle.load(f)
485 | 
486 | if anchor_gene is None:
487 | self.anchor_token = None
488 | else:
489 | self.anchor_token = self.gene_token_dict[self.anchor_gene]
490 | 
491 | def validate_options(self):
492 | for attr_name,valid_options in self.valid_option_dict.items():
493 | attr_value = self.__dict__[attr_name]
494 | if type(attr_value) not in {list, dict}:
495 | if attr_name in {"anchor_gene"}:
496 | continue
497 | elif attr_value in valid_options:
498 | continue
499 | valid_type = False
500 | for option in valid_options:
501 | if (option in [int,list,dict]) and isinstance(attr_value, option):
502 | valid_type = True
503 | break
504 | if valid_type:
505 | continue
506 | logger.error(
507 | f"Invalid option for {attr_name}. " \
508 | f"Valid options for {attr_name}: {valid_options}"
509 | )
510 | raise
511 | 
512 | if self.cell_states_to_model is not None:
513 | if len(self.cell_states_to_model.items()) == 1:
514 | logger.warning(
515 | "The single value dictionary for cell_states_to_model will be " \
516 | "replaced with a dictionary with named keys for start, goal, and alternate states. " \
517 | "Please specify state_key, start_state, goal_state, and alt_states " \
518 | "in the cell_states_to_model dictionary for future use. 
" \ 519 | "For example, cell_states_to_model={" \ 520 | "'state_key': 'disease', " \ 521 | "'start_state': 'dcm', " \ 522 | "'goal_state': 'nf', " \ 523 | "'alt_states': ['hcm', 'other1', 'other2']}" 524 | ) 525 | for key,value in self.cell_states_to_model.items(): 526 | if (len(value) == 3) and isinstance(value, tuple): 527 | if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list): 528 | if len(value[0]) == 1 and len(value[1]) == 1: 529 | all_values = value[0]+value[1]+value[2] 530 | if len(all_values) == len(set(all_values)): 531 | continue 532 | # reformat to the new named key format 533 | state_values = flatten_list(list(self.cell_states_to_model.values())) 534 | self.cell_states_to_model = { 535 | "state_key": list(self.cell_states_to_model.keys())[0], 536 | "start_state": state_values[0][0], 537 | "goal_state": state_values[1][0], 538 | "alt_states": state_values[2:][0] 539 | } 540 | elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}: 541 | if (self.cell_states_to_model["state_key"] is None) \ 542 | or (self.cell_states_to_model["start_state"] is None) \ 543 | or (self.cell_states_to_model["goal_state"] is None): 544 | logger.error( 545 | "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.") 546 | raise 547 | 548 | if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]: 549 | logger.error( 550 | "All states must be unique.") 551 | raise 552 | 553 | if self.cell_states_to_model["alt_states"] is not None: 554 | if type(self.cell_states_to_model["alt_states"]) is not list: 555 | logger.error( 556 | "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)." 557 | ) 558 | raise 559 | if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])): 560 | logger.error( 561 | "All states must be unique.") 562 | raise 563 | 564 | else: 565 | logger.error( 566 | "cell_states_to_model must only have the following four keys: " \ 567 | "'state_key', 'start_state', 'goal_state', 'alt_states'." \ 568 | "For example, cell_states_to_model={" \ 569 | "'state_key': 'disease', " \ 570 | "'start_state': 'dcm', " \ 571 | "'goal_state': 'nf', " \ 572 | "'alt_states': ['hcm', 'other1', 'other2']}" 573 | ) 574 | raise 575 | 576 | if self.anchor_gene is not None: 577 | self.anchor_gene = None 578 | logger.warning( 579 | "anchor_gene set to None. " \ 580 | "Currently, anchor gene not available " \ 581 | "when modeling multiple cell states.") 582 | 583 | if self.combos > 0: 584 | if self.anchor_gene is None: 585 | logger.error( 586 | "Currently, stats are only supported for combination " \ 587 | "in silico perturbation run with anchor gene. Please add " \ 588 | "anchor gene when using with combos > 0. 
") 589 | raise 590 | 591 | if (self.mode == "mixture_model") and (self.genes_perturbed != "all"): 592 | logger.error( 593 | "Mixture model mode requires multiple gene perturbations to fit model " \ 594 | "so is incompatible with a single grouped perturbation.") 595 | raise 596 | if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"): 597 | logger.error( 598 | "Simple data aggregation mode is for single perturbation in multiple cells " \ 599 | "so is incompatible with a genes_perturbed being 'all'.") 600 | raise 601 | 602 | def get_stats(self, 603 | input_data_directory, 604 | null_dist_data_directory, 605 | output_directory, 606 | output_prefix): 607 | """ 608 | Get stats for in silico perturbation data and save as results in output_directory. 609 | 610 | Parameters 611 | ---------- 612 | input_data_directory : Path 613 | Path to directory containing cos_sim dictionary inputs 614 | null_dist_data_directory : Path 615 | Path to directory containing null distribution cos_sim dictionary inputs 616 | output_directory : Path 617 | Path to directory where perturbation data will be saved as .csv 618 | output_prefix : str 619 | Prefix for output .csv 620 | 621 | Outputs 622 | ---------- 623 | Definition of possible columns in .csv output file. 624 | 625 | Of note, not all columns will be present in all output files. 626 | Some columns are specific to particular perturbation modes. 627 | 628 | "Gene": gene token 629 | "Gene_name": gene name 630 | "Ensembl_ID": gene Ensembl ID 631 | "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset 632 | "Sig": 1 if FDR<0.05, otherwise 0 633 | 634 | "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation 635 | "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation 636 | "Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon 637 | pvalue compares shift caused by perturbing given gene compared to random genes 638 | "Alt_end_vs_random_pval": pvalue of cosine shift from start state towards alternate end state by Wilcoxon 639 | pvalue compares shift caused by perturbing given gene compared to random genes 640 | "Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval" 641 | "Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval" 642 | 643 | "Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution 644 | "Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells) 645 | "Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution 646 | (i.e. "Test_avg_shift" minus "Null_avg_shift") 647 | "Test_vs_null_pval": pvalue of cosine shift in test vs. 
null distribution
648 | "Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
649 | "N_Detections_test": "N_Detections" in cells from test distribution
650 | "N_Detections_null": "N_Detections" in cells from null distribution
651 | 
652 | "Anchor_shift": cosine shift in response to given perturbation of anchor gene
653 | "Test_token_shift": cosine shift in response to given perturbation of test gene
654 | "Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
655 | "Combo_shift": cosine shift in response to given perturbation of both anchor and test gene(s) in combination
656 | "Combo_minus_sum_shift": difference of cosine shifts in response to combo perturbation vs. sum of individual perturbations
657 | (i.e. "Combo_shift" minus "Sum_of_indiv_shifts")
658 | "Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
659 | 1: within impact component; 0: not within impact component
660 | "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
661 | """
662 | 
663 | if self.mode not in ["goal_state_shift", "vs_null", "mixture_model","aggregate_data"]:
664 | logger.error(
665 | "Currently, only modes available are stats for goal_state_shift, " \
666 | "vs_null (comparing to null distribution), mixture_model (fitting mixture " \
667 | "model for perturbations with or without impact), and aggregate_data.")
668 | raise
669 | 
670 | self.gene_token_id_dict = invert_dict(self.gene_token_dict)
671 | self.gene_id_name_dict = invert_dict(self.gene_name_id_dict)
672 | 
673 | # obtain total gene list
674 | if (self.combos == 0) and (self.anchor_token is not None):
675 | # cos sim data for effect of gene perturbation on the embedding of each other gene
676 | dict_list = read_dictionaries(input_data_directory, "gene", self.anchor_token)
677 | gene_list = get_gene_list(dict_list, "gene")
678 | else:
679 | # cos sim data for effect of gene perturbation on the embedding of each cell
680 | dict_list = read_dictionaries(input_data_directory, "cell", self.anchor_token)
681 | gene_list = get_gene_list(dict_list, "cell")
682 | 
683 | # initiate results dataframe
684 | cos_sims_df_initial = pd.DataFrame({"Gene": gene_list,
685 | "Gene_name": [self.token_to_gene_name(item) \
686 | for item in gene_list], \
687 | "Ensembl_ID": [token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict) \
688 | if self.genes_perturbed != "all" else \
689 | self.gene_token_id_dict[genes[1]] \
690 | if isinstance(genes,tuple) else \
691 | self.gene_token_id_dict[genes] \
692 | for genes in gene_list]}, \
693 | index=[i for i in range(len(gene_list))])
694 | 
695 | if self.mode == "goal_state_shift":
696 | cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list, self.cell_states_to_model, self.genes_perturbed)
697 | 
698 | elif self.mode == "vs_null":
699 | null_dict_list = read_dictionaries(null_dist_data_directory, "cell", self.anchor_token)
700 | cos_sims_df = isp_stats_vs_null(cos_sims_df_initial, dict_list, null_dict_list)
701 | 
702 | elif self.mode == "mixture_model":
703 | cos_sims_df = isp_stats_mixture_model(cos_sims_df_initial, dict_list, self.combos, self.anchor_token)
704 | 
705 | elif self.mode == "aggregate_data":
706 | cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
707 | 
708 | # save perturbation stats to output_path
709 | output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
710 | 
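# Illustrative usage sketch for the "vs_null" mode (not part of the original
# file; paths are hypothetical). Unlike the module docstring example, this
# mode also requires the null-distribution directory as the second argument:
#
#   ispstats = InSilicoPerturberStats(mode="vs_null")
#   ispstats.get_stats("path/to/test_data",
#                      "path/to/null_dist_data",
#                      "path/to/output_directory",
#                      "output_prefix")   # -> output_directory/output_prefix.csv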
cos_sims_df.to_csv(output_path) 711 | 712 | def token_to_gene_name(self, item): 713 | if isinstance(item,int): 714 | return self.gene_id_name_dict.get(self.gene_token_id_dict.get(item, np.nan), np.nan) 715 | if isinstance(item,tuple): 716 | return tuple([self.gene_id_name_dict.get(self.gene_token_id_dict.get(i, np.nan), np.nan) for i in item]) 717 | -------------------------------------------------------------------------------- /geneformer_001/geneformer/token_dictionary.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PharMolix/LangCell/69e41ef2fae485b67703294b50f3150bb1d5bb9b/geneformer_001/geneformer/token_dictionary.pkl -------------------------------------------------------------------------------- /geneformer_001/geneformer/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Geneformer tokenizer. 3 | 4 | Input data: 5 | Required format: raw counts scRNAseq data without feature selection as .loom file 6 | Required row (gene) attribute: "ensembl_id"; Ensembl ID for each gene 7 | Required col (cell) attribute: "n_counts"; total read counts in that cell 8 | Optional col (cell) attribute: "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria 9 | Optional col (cell) attributes: any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below 10 | 11 | Usage: 12 | from geneformer import TranscriptomeTokenizer 13 | tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4) 14 | tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix") 15 | """ 16 | 17 | import pickle 18 | from pathlib import Path 19 | 20 | import logging 21 | 22 | import warnings 23 | warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") 24 | 25 | import loompy as lp 26 | import numpy as np 27 | from datasets import Dataset 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl" 32 | TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl" 33 | 34 | 35 | def tokenize_cell(gene_vector, gene_tokens): 36 | """ 37 | Convert normalized gene expression vector to tokenized rank value encoding. 38 | """ 39 | # create array of gene vector with token indices 40 | # mask undetected genes 41 | nonzero_mask = np.nonzero(gene_vector)[0] 42 | # sort by median-scaled gene values 43 | sorted_indices = np.argsort(-gene_vector[nonzero_mask]) 44 | # tokenize 45 | sentence_tokens = gene_tokens[nonzero_mask][sorted_indices] 46 | return sentence_tokens 47 | 48 | 49 | class TranscriptomeTokenizer: 50 | def __init__( 51 | self, 52 | custom_attr_name_dict=None, 53 | nproc=1, 54 | gene_median_file=GENE_MEDIAN_FILE, 55 | token_dictionary_file=TOKEN_DICTIONARY_FILE, 56 | ): 57 | """ 58 | Initialize tokenizer. 59 | 60 | Parameters 61 | ---------- 62 | custom_attr_name_dict : None, dict 63 | Dictionary of custom attributes to be added to the dataset. 64 | Keys are the names of the attributes in the loom file. 65 | Values are the names of the attributes in the dataset. 66 | nproc : int 67 | Number of processes to use for dataset mapping. 68 | gene_median_file : Path 69 | Path to pickle file containing dictionary of non-zero median 70 | gene expression values across Genecorpus-30M. 
71 | token_dictionary_file : Path 72 | Path to pickle file containing token dictionary (Ensembl IDs:token). 73 | """ 74 | # dictionary of custom attributes {output dataset column name: input .loom column name} 75 | self.custom_attr_name_dict = custom_attr_name_dict 76 | 77 | # number of processes for dataset mapping 78 | self.nproc = nproc 79 | 80 | # load dictionary of gene normalization factors 81 | # (non-zero median value of expression across Genecorpus-30M) 82 | with open(gene_median_file, "rb") as f: 83 | self.gene_median_dict = pickle.load(f) 84 | 85 | # load token dictionary (Ensembl IDs:token) 86 | with open(token_dictionary_file, "rb") as f: 87 | self.gene_token_dict = pickle.load(f) 88 | 89 | # gene keys for full vocabulary 90 | self.gene_keys = list(self.gene_median_dict.keys()) 91 | 92 | # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization 93 | self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys))) 94 | 95 | def tokenize_data(self, loom_data_directory, output_directory, output_prefix): 96 | """ 97 | Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory. 98 | 99 | Parameters 100 | ---------- 101 | loom_data_directory : Path 102 | Path to directory containing loom files 103 | output_directory : Path 104 | Path to directory where tokenized data will be saved as .dataset 105 | output_prefix : str 106 | Prefix for output .dataset 107 | """ 108 | tokenized_cells, cell_metadata = self.tokenize_files(Path(loom_data_directory)) 109 | tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata) 110 | 111 | output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset") 112 | tokenized_dataset.save_to_disk(output_path) 113 | 114 | def tokenize_files(self, loom_data_directory): 115 | tokenized_cells = [] 116 | if self.custom_attr_name_dict is not None: 117 | loom_cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()] 118 | cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()} 119 | 120 | # loops through directories to tokenize .loom files 121 | file_found = 0 122 | for loom_file_path in loom_data_directory.glob("*.loom"): 123 | file_found = 1 124 | print(f"Tokenizing {loom_file_path}") 125 | file_tokenized_cells, file_cell_metadata = self.tokenize_file( 126 | loom_file_path 127 | ) 128 | tokenized_cells += file_tokenized_cells 129 | if self.custom_attr_name_dict is not None: 130 | for k in loom_cell_attr: 131 | cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k] 132 | else: 133 | cell_metadata = None 134 | 135 | if file_found == 0: 136 | logger.error( 137 | f"No .loom files found in directory {loom_data_directory}.") 138 | raise 139 | return tokenized_cells, cell_metadata 140 | 141 | def tokenize_file(self, loom_file_path): 142 | if self.custom_attr_name_dict is not None: 143 | file_cell_metadata = { 144 | attr_key: [] for attr_key in self.custom_attr_name_dict.keys() 145 | } 146 | 147 | with lp.connect(str(loom_file_path)) as data: 148 | # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors 149 | coding_miRNA_loc = np.where( 150 | [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]] 151 | )[0] 152 | norm_factor_vector = np.array( 153 | [ 154 | self.gene_median_dict[i] 155 | for i in data.ra["ensembl_id"][coding_miRNA_loc] 156 | ] 157 | ) 158 | coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc] 159 | coding_miRNA_tokens = 
np.array( 160 | [self.gene_token_dict[i] for i in coding_miRNA_ids] 161 | ) 162 | 163 | # define coordinates of cells passing filters for inclusion (e.g. QC) 164 | try: 165 | data.ca["filter_pass"] 166 | except AttributeError: 167 | var_exists = False 168 | else: 169 | var_exists = True 170 | 171 | if var_exists is True: 172 | filter_pass_loc = np.where( 173 | [True if i == 1 else False for i in data.ca["filter_pass"]] 174 | )[0] 175 | elif var_exists is False: 176 | print( 177 | f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells." 178 | ) 179 | filter_pass_loc = np.array([i for i in range(data.shape[1])]) 180 | 181 | # scan through .loom files and tokenize cells 182 | tokenized_cells = [] 183 | for (_ix, _selection, view) in data.scan(items=filter_pass_loc, axis=1): 184 | # select subview with protein-coding and miRNA genes 185 | subview = view.view[coding_miRNA_loc, :] 186 | 187 | # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision 188 | # and normalize by gene normalization factors 189 | subview_norm_array = ( 190 | subview[:, :] 191 | / subview.ca.n_counts 192 | * 10_000 193 | / norm_factor_vector[:, None] 194 | ) 195 | # tokenize subview gene vectors 196 | tokenized_cells += [ 197 | tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens) 198 | for i in range(subview_norm_array.shape[1]) 199 | ] 200 | 201 | # add custom attributes for subview to dict 202 | if self.custom_attr_name_dict is not None: 203 | for k in file_cell_metadata.keys(): 204 | file_cell_metadata[k] += subview.ca[k].tolist() 205 | else: 206 | file_cell_metadata = None 207 | 208 | return tokenized_cells, file_cell_metadata 209 | 210 | def create_dataset(self, tokenized_cells, cell_metadata): 211 | # create dict for dataset creation 212 | dataset_dict = {"input_ids": tokenized_cells} 213 | if self.custom_attr_name_dict is not None: 214 | dataset_dict.update(cell_metadata) 215 | 216 | # create dataset 217 | output_dataset = Dataset.from_dict(dataset_dict) 218 | 219 | # truncate dataset 220 | def truncate(example): 221 | example["input_ids"] = example["input_ids"][0:2048] 222 | return example 223 | 224 | output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc) 225 | 226 | # measure lengths of dataset 227 | def measure_length(example): 228 | example["length"] = len(example["input_ids"]) 229 | return example 230 | 231 | output_dataset_truncated_w_length = output_dataset_truncated.map( 232 | measure_length, num_proc=self.nproc 233 | ) 234 | 235 | return output_dataset_truncated_w_length 236 | -------------------------------------------------------------------------------- /geneformer_001/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="geneformer", 5 | version="0.0.1", 6 | author="Christina Theodoris", 7 | author_email="christina.theodoris@gladstone.ucsf.edu", 8 | description="Geneformer is a transformer model pretrained \ 9 | on a large-scale corpus of ~30 million single \ 10 | cell transcriptomes to enable context-aware \ 11 | predictions in settings with limited data in \ 12 | network biology.", 13 | packages=["geneformer"], 14 | include_package_data=True, 15 | install_requires=[ 16 | "datasets", 17 | "loompy", 18 | "numpy", 19 | "transformers", 20 | ], 21 | ) 22 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | geneformer==0.0.1 3 | transformers==4.39.1 4 | datasets==2.14.0 5 | numpy==1.23.4 6 | pandas==2.2.1 7 | scanpy==1.9.3 8 | scikit-learn==1.3.0 9 | seaborn==0.12.2 10 | matplotlib==3.8.2 --------------------------------------------------------------------------------
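# Appendix (illustrative, not part of the repository): a toy walk-through of the
# rank value encoding implemented by tokenize_cell in geneformer/tokenizer.py
# above. Nonzero genes are ranked by median-normalized expression, highest
# first, and mapped to their token ids; the ids below are hypothetical.

import numpy as np

def tokenize_cell(gene_vector, gene_tokens):
    # mask undetected genes, then sort remaining genes by descending expression
    nonzero_mask = np.nonzero(gene_vector)[0]
    sorted_indices = np.argsort(-gene_vector[nonzero_mask])
    return gene_tokens[nonzero_mask][sorted_indices]

gene_vector = np.array([0.0, 2.5, 0.1, 7.0])    # median-normalized expression
gene_tokens = np.array([10, 11, 12, 13])        # hypothetical token ids
print(tokenize_cell(gene_vector, gene_tokens))  # -> [13 11 12]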