├── .github ├── CODE_OF_CONDUCT.md └── CONTRIBUTING.md ├── LICENSE ├── README.md ├── embeddings ├── cell_line_embedding_full_ccle_300_scaled.csv └── phenotypes.csv ├── prophet ├── Prophet.py ├── __init__.py ├── callbacks.py ├── config.py ├── dataloader.py ├── dataset.py ├── model.py ├── train.py └── train_model.py ├── setup.py ├── test └── test_dataloader.py └── tutorials ├── config_file_finetuning.yaml ├── finetuning.ipynb └── insilico_screening.ipynb /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | 2 | --- 3 | 4 | ### 📄 `CODE_OF_CONDUCT.md` 5 | 6 | ```markdown 7 | # Code of Conduct 8 | 9 | ## Our Pledge 10 | 11 | We are committed to fostering a welcoming and inclusive environment for everyone. We pledge to make participation in our project a harassment-free experience for all, regardless of background or identity. 12 | 13 | ## Our Standards 14 | 15 | Examples of behavior that contributes to a positive environment: 16 | 17 | - Using welcoming and inclusive language 18 | - Being respectful of differing viewpoints and experiences 19 | - Accepting constructive criticism gracefully 20 | - Focusing on what is best for the community 21 | 22 | Examples of unacceptable behavior: 23 | 24 | - Harassment, discrimination, or exclusionary behavior 25 | - Trolling, insulting or derogatory comments 26 | - Personal or political attacks 27 | - Public or private harassment 28 | 29 | ## Our Responsibilities 30 | 31 | Maintainers are responsible for clarifying standards of acceptable behavior and will take appropriate action in response to any unacceptable behavior. 32 | 33 | ## Scope 34 | 35 | This Code of Conduct applies within all project spaces and applies when representing the project in any public space. 36 | 37 | ## Enforcement 38 | 39 | Violations may result in warnings, temporary bans, or permanent exclusion from the project. 40 | 41 | If you experience or witness unacceptable behavior, please report it by contacting the maintainers directly or via a dedicated email (if applicable). 42 | 43 | ## Attribution 44 | 45 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 2.1. 46 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to This Project 2 | 3 | Thank you for your interest in contributing! We welcome all contributions, including bug reports, feature requests, documentation improvements, and code changes. 4 | 5 | ## Getting Started 6 | 7 | 1. **Fork the repository** and clone your fork locally. 8 | 2. Create a new branch for your changes: 9 | ```bash 10 | git checkout -b my-feature-branch 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024, Alejandro Tejada Lapuerta, Yuge Ji. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License: MIT][mit-shield]][mit] 2 | # Prophet 3 | 4 | Prophet is a transformer-based regression model that predicts cellular responses by decomposing experiments into cell state, treatment, and functional readout, leveraging extensive screening datasets and scalability to significantly reduce the number of required experiments and identify effective treatments. 5 | 6 | ## Model Overview 7 | 8 | Prophet decomposes biological experiments into three key components: 9 | 1. **Cell state** - represented by cell line embeddings derived from gene expression profiles 10 | 2. **Treatment** - represented by intervention embeddings (e.g., small molecules, genetic perturbations) 11 | 3. **Functional readout** - the phenotypic measurement being predicted (e.g., viability, IC50) 12 | 13 | The model uses a transformer architecture to learn complex interactions between these components and predict experimental outcomes without requiring the experiments to be performed. 14 | 15 | ### Embeddings 16 | 17 | Prophet uses three types of embeddings: 18 | - **Cell line embeddings**: 300-dimensional vectors derived from CCLE gene expression data 19 | - **Intervention embeddings**: 500-dimensional vectors representing small molecules or genetic perturbations 20 | - **Phenotype embeddings**: Representations of different readout types (optional) 21 | 22 | These embeddings capture the biological properties of each component and allow the model to generalize across different experimental conditions. 23 | 24 | ## Training 25 | 26 | Prophet was trained on a large dataset of cellular response measurements, including: 27 | - Drug sensitivity screens (GDSC, PRISM, CTRP) 28 | - Genetic perturbation screens (DepMap, Achilles) 29 | - Combinatorial perturbation experiments 30 | 31 | The model was trained using a masked attention mechanism to handle variable numbers of perturbations and a cosine learning rate schedule with warmup. Training was performed on NVIDIA A100 GPUs with early stopping based on validation loss. 32 | 33 | ## Installation 34 | ``` 35 | mamba create -n prophet_env python=3.10 36 | mamba activate prophet_env 37 | 38 | git clone https://github.com/theislab/prophet.git 39 | cd prophet 40 | pip install -e . 
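# optional sanity check (a minimal sketch; assumes the editable install above finished without errors)
python -c "from prophet import Prophet"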
41 | ``` 42 | 43 | ## Usage 44 | 45 | ### Downloading Resources 46 | 47 | Model checkpoints and input embeddings can be downloaded [here](https://huggingface.co/datasets/aletlvl/Prophet_v1/tree/main) and [here](https://data.mendeley.com/datasets/g7z3pw3bfw). 48 | 49 | If you have used our work in your research, please cite our [preprint](https://www.biorxiv.org/content/10.1101/2024.08.12.607533v2). 50 | 51 | [mit]: https://opensource.org/licenses/MIT 52 | [mit-image]: https://img.shields.io/badge/License-MIT-yellow.svg 53 | [mit-shield]: https://img.shields.io/badge/License-MIT-yellow.svg 54 | -------------------------------------------------------------------------------- /embeddings/phenotypes.csv: -------------------------------------------------------------------------------- 1 | ,0 2 | 0,SCORE 3 | 1,Horlbeck 4 | 2,UBE2A 5 | 3,GNAI1 6 | 4,MYBL2 7 | 5,BIRC5 8 | 6,NFIL3 9 | 7,TGFB3 10 | 8,P4HTM 11 | 9,RPN1 12 | 10,PDIA5 13 | 11,BAMBI 14 | 12,TATDN2 15 | 13,FAM57A 16 | 14,FGFR4 17 | 15,MOK 18 | 16,CCNE2 19 | 17,KIAA0753 20 | 18,KDELR2 21 | 19,GTF2A2 22 | 20,DCTD 23 | 21,USP22 24 | 22,CANT1 25 | 23,NENF 26 | 24,RRP12 27 | 25,DSG2 28 | 26,SACM1L 29 | 27,FBXO11 30 | 28,AKAP8 31 | 29,BAG3 32 | 30,MBNL2 33 | 31,C2CD2 34 | 32,YKT6 35 | 33,MVP 36 | 34,NOL3 37 | 35,CSRP1 38 | 36,DMTF1 39 | 37,ALDOC 40 | 38,SYNGR3 41 | 39,TCEAL4 42 | 40,PLEKHM1 43 | 41,IGF1R 44 | 42,MMP1 45 | 43,ETV1 46 | 44,CHERP 47 | 45,PARP2 48 | 46,TIMM17B 49 | 47,NUP133 50 | 48,ENOSF1 51 | 49,RALA 52 | 50,CNOT4 53 | 51,FBXL12 54 | 52,GDSC 55 | 53,GDSCcomb 56 | 54,PRISM 57 | 55,inhouse 58 | 56,CTRP 59 | -------------------------------------------------------------------------------- /prophet/Prophet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor 4 | import numpy as np 5 | import pandas as pd 6 | import warnings 7 | from tqdm import tqdm 8 | from typing import List, Union, Optional 9 | from prophet.callbacks import R2ScoreCallback 10 | import functools 11 | from joblib import load 12 | from sklearn.ensemble import RandomForestRegressor 13 | from .dataloader import ( 14 | dataloader_phenotypes, 15 | process_priors, 16 | remove_nonexistent_cat, 17 | ) 18 | from .model import load_models_config, TransformerPredictor 19 | from pytorch_lightning.callbacks import TQDMProgressBar 20 | 21 | def inherit_docs_and_signature(from_method): 22 | def decorator(to_method): 23 | @functools.wraps(from_method) 24 | def wrapper(self, *args, **kwargs): 25 | return to_method(self, *args, **kwargs) 26 | wrapper.__doc__ = from_method.__doc__ 27 | wrapper.__signature__ = from_method.__signature__ 28 | return wrapper 29 | return decorator 30 | 31 | class Prophet: 32 | def __init__( 33 | self, 34 | iv_emb_path: Union[str, List[str]] = None, 35 | cl_emb_path: Union[str, List[str]] = None, 36 | ph_emb_path: Union[str, List[str]] = None, 37 | model_pth=None, 38 | architecture="Transformer", 39 | ): 40 | """Initialize the Prophet model. 41 | 42 | Args: 43 | iv_emb_path (Union[str, List[str]], optional): The path to the gene embeddings. Defaults to None. 44 | cl_emb_path (Union[str, List[str]], optional): The path to the cell line embeddings. Defaults to None. 45 | ph_emb_path (Union[str, List[str]], optional): The path to the phenotype embeddings. Defaults to None. 46 | model_pth ([type], optional): The path to the trained model. Defaults to None. 
47 | architecture (str, optional): The architecture of the model. Defaults to "Transformer". 48 | """ 49 | 50 | self.architecture = architecture 51 | self.iv_emb_path = iv_emb_path 52 | self.cl_emb_path = cl_emb_path 53 | self.ph_emb_path = ph_emb_path 54 | # set phenotypes (must be in the same order regardless of what is passed in predict) 55 | self.phenotypes = None 56 | self.column_map = None 57 | self.pert_len = None 58 | 59 | if model_pth and architecture == "RandomForest": 60 | self.model = load(model_pth) 61 | else: 62 | self.model_pth = model_pth 63 | self.model = self._build_model(architecture) 64 | self.phenotypes = self.model.hparams["phenotypes"] 65 | self.iv_embedding, self.cl_embedding, self.ph_embedding = process_priors(self.iv_emb_path, self.cl_emb_path, self.ph_emb_path) 66 | if self.model.hparams.explicit_phenotype and self.ph_embedding is None: 67 | raise ValueError('model was run with explicit phenotype! must pass a ph_emb_path') 68 | 69 | def _build_model(self, arch): 70 | if arch == "RandomForest": 71 | self.torch_dataset = False 72 | return RandomForestRegressor() 73 | elif arch == "Transformer": 74 | self.torch_dataset = True 75 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 76 | print('returning trained model!') 77 | model = TransformerPredictor.load_from_checkpoint(checkpoint_path=self.model_pth, map_location=torch.device('cpu')) #map_location must be cpu to load from checkpoint if you used ddp-notebook 78 | model.eval() 79 | # working backwards from config 80 | if model.hparams.simpler: 81 | self.pert_len = model.hparams.ctx_len - 1 82 | else: 83 | self.pert_len = model.hparams.ctx_len - 3 84 | 85 | return model 86 | else: 87 | raise ValueError(arch, " is not a valid model architecture.") 88 | 89 | def _remove_nonexistent_cat( 90 | self, 91 | data_label: Optional[pd.DataFrame] = None, 92 | verbose=True, 93 | ): 94 | embeddings = [self.iv_embedding, self.cl_embedding, self.ph_embedding] 95 | cols = [self.iv_cols, self.cl_col, self.ph_col] 96 | for i, embedding in enumerate(embeddings): 97 | if embedding is None: # phenotype embedding can be None 98 | continue 99 | data_label = remove_nonexistent_cat(data_label, embedding, cols[i], verbose) 100 | data_label = data_label.reset_index(drop=True) 101 | 102 | if len(data_label) == 0 and not verbose: 103 | self._remove_nonexistent_cat(data_label=data_label, verbose=True) 104 | raise ValueError('labels did not match embeddings passed!') 105 | return data_label 106 | 107 | def _init_input( 108 | self, 109 | iv_col: Union[List[str], str] = ['iv1', 'iv2'], 110 | cl_col: str = "cell_line", 111 | ph_col: str = "phenotype", 112 | readout_col: str = "value", 113 | ): 114 | """Sets some state variables in the model, but is always overwritten by 115 | either train or predict. 
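        In particular, it builds self.column_map (user-supplied column names mapped to the internal
        "iv1"/"iv2"/.../"cell_line"/"phenotype"/"value" names) and, when pert_len was already set from a
        loaded checkpoint, checks that the number of intervention columns matches it.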
116 | """ 117 | if isinstance(iv_col, str): 118 | iv_col = [iv_col] 119 | intervention_mapping = {col: f"iv{i+1}" for i, col in enumerate(iv_col)} 120 | self.iv_cols = list(intervention_mapping.values()) 121 | 122 | # store the columns used for training for reference 123 | self.cl_col = cl_col 124 | self.ph_col = ph_col 125 | self.readout_col = readout_col 126 | 127 | # create the mapping to the internal variables used 128 | self.column_map = { 129 | self.cl_col: "cell_line", 130 | self.ph_col: "phenotype", 131 | self.readout_col: "value", 132 | **intervention_mapping 133 | } 134 | if self.pert_len is None: 135 | self.pert_len = len(self.iv_cols) 136 | else: 137 | if self.pert_len != len(self.iv_cols): 138 | raise ValueError(f"Are you sure you passed the right number of intervention columns? Currently receiving {self.iv_cols}") 139 | 140 | def train( 141 | self, 142 | df: pd.DataFrame, 143 | iv_col: Union[List[str], str] = ['iv1', 'iv2'], 144 | cl_col: str = "cell_line", 145 | ph_col: str = "phenotype", 146 | readout_col: str = "value", 147 | model_config: dict = None, 148 | ): 149 | """Train the Prophet model on the provided DataFrame. 150 | 151 | This function reformats the DataFrame according to the specified settings and intervention columns, 152 | then trains the model using the reformatted data. 153 | 154 | Args: 155 | df (pd.DataFrame): The DataFrame containing the experimental data. 156 | iv_col (Union[List[str], str]): The names of the intervention columns in the DataFrame. Can be a single column name or a list of names. 157 | cl_col (str, optional): The name of the column in df that contains the setting labels. Defaults to "cell_line". 158 | ph_col (str, optional): The name of the column in df that contains the phenotype labels. Defaults to "phenotype". 159 | readout_col (str, optional): The name of the column in df that contains the readout data. Defaults to "value". 160 | model_config (dict, optional): Configuration yaml, e.g. 
config_file_finetuning 161 | """ 162 | self._init_input(iv_col, cl_col, ph_col, readout_col) 163 | # user-friendly check that the columns were passed in correctly 164 | for _, (old_name, new_name) in enumerate(self.column_map.items()): 165 | if old_name not in df.columns: 166 | raise ValueError(f"{old_name} not in df columns.") 167 | 168 | df = df.rename(columns=self.column_map).copy() 169 | 170 | # Formatting 171 | df = df.drop_duplicates() 172 | df = df.reset_index(drop=True) 173 | 174 | ## generate training dataloader 175 | df = self._remove_nonexistent_cat(data_label=df, verbose=True) 176 | split = dataloader_phenotypes( 177 | gene_embedding=self.iv_embedding, 178 | cell_lines_embedding=self.cl_embedding, 179 | phenotype_embedding=self.ph_embedding if self.ph_embedding is not None else None, 180 | data_label=df, 181 | label_name="value", 182 | index=( 183 | np.array(df.index), 184 | [], 185 | [], 186 | "", 187 | ), # (train, test, val, descr) 188 | torch_dataset=self.torch_dataset, 189 | pert_len=len(self.iv_cols) 190 | ) 191 | 192 | print("Fitting model.") 193 | if not self.torch_dataset: 194 | X_train, y_train = split[2] 195 | self.model.fit(X_train, y_train) 196 | else: 197 | print("pytorch model, finetuning") 198 | # automatically take 10% of the data as validation set 199 | train_indices = np.array(df.index)[np.random.choice(len(df.index), int(len(df.index) * 0.9), replace=False)] 200 | val_indices = np.array(df.index)[~np.isin(df.index, train_indices)] 201 | split = dataloader_phenotypes( 202 | gene_embedding=self.iv_embedding, 203 | cell_lines_embedding=self.cl_embedding, 204 | phenotype_embedding=self.ph_embedding if self.ph_embedding is not None else None, 205 | data_label=df, 206 | label_name="value", 207 | index=( 208 | train_indices, 209 | val_indices, 210 | [], 211 | "", 212 | ), # (train, val, test, descr) 213 | torch_dataset=self.torch_dataset, 214 | pert_len=len(self.iv_cols), 215 | valid_set=True 216 | ) 217 | 218 | model_config.ohe_dim = 0 219 | 220 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 221 | model, model_config = load_models_config(model_config, seed=42, phenotypes=None) 222 | 223 | lr_monitor = LearningRateMonitor(logging_interval='step') 224 | dirpath = model_config.dirpath 225 | model_checkpointer = ModelCheckpoint(dirpath=dirpath, save_top_k=1, every_n_epochs=1, monitor='R2_train', mode='max') 226 | r2_callback = R2ScoreCallback(device=model.device, average=False) 227 | early_stopping = EarlyStopping(monitor="R2_train", mode="max", patience=model_config.patience, min_delta=0.0) 228 | 229 | tqdm_progress_bar = TQDMProgressBar(refresh_rate=1) 230 | callbacks = [r2_callback, model_checkpointer, lr_monitor, early_stopping,tqdm_progress_bar] 231 | 232 | print(f"Running with early stopping: {model_config.early_stopping}") 233 | if model_config.early_stopping: 234 | print(f"Early stopping patience: {model_config.patience}") 235 | 236 | trainer = pl.Trainer( 237 | min_epochs=1, 238 | #max_steps=100, 239 | max_steps=model_config.max_steps, 240 | max_epochs=2, 241 | accelerator='gpu', 242 | check_val_every_n_epoch=1, 243 | callbacks=callbacks, 244 | strategy="auto", #choose a notebook-compatible strategy: `Trainer(strategy='ddp_notebook')` 245 | #precision="16-mixed", 246 | enable_progress_bar=True, 247 | gradient_clip_val=1, 248 | log_every_n_steps = 1, 249 | deterministic=True) 250 | 251 | trainer.fit(model=model, train_dataloaders=split[0], val_dataloaders=split[1]) 252 | 253 | def _generate_predict_df(self, 254 | run_index: int, 255 | 
num_iterations: int, 256 | target_ivs: List[str], 257 | target_cls: List[str], 258 | target_phs: List[str] = ['_'], 259 | ): 260 | subset_cl = pd.DataFrame(target_cls, columns=["cell_line"]) 261 | subset_iv = pd.DataFrame(target_ivs, columns=["iv"]) 262 | if target_phs is None: 263 | target_phs = ['_'] 264 | subset_ph = pd.DataFrame(target_phs, columns=["phenotype"]) 265 | if len(self.iv_cols) > 2: 266 | raise NotImplementedError("Only support 1 or 2 interventions if you input a list of interventions. Please create the data label dataframe yourself and input to predict()!") 267 | 268 | batch_size = int(len(subset_iv) // num_iterations) 269 | start_idx = run_index * batch_size 270 | end_idx = ( 271 | (start_idx + batch_size) 272 | if (run_index < num_iterations - 1) 273 | else len(subset_iv["iv"]) 274 | ) 275 | 276 | data_label = pd.merge(subset_iv[["iv"]][start_idx:end_idx], subset_cl, how="cross") 277 | data_label = pd.merge(data_label, subset_ph, how="cross") 278 | 279 | if len(self.iv_cols) == 1: 280 | data_label.rename(columns={"iv": "iv1"}, inplace=True) 281 | else: 282 | data_label = pd.merge(subset_iv[["iv"]], data_label, how="cross", suffixes=("1", "2")) 283 | # A+B and B+A should be the same, so we remove all duplicates in favor of A+B (was pretty sure this shouldn't exist in the implementation @John) 284 | data_label['iv1+iv2'] = ['+'.join(sorted([row['iv1'], row['iv2']])) for _, row in data_label.iterrows()] 285 | data_label = data_label.drop_duplicates(subset=['iv1+iv2', 'cell_line', 'phenotype']) 286 | 287 | data_label['value'] = '_' 288 | 289 | return data_label 290 | 291 | def _decide_iteration_num( 292 | self, 293 | total_size: int, 294 | single_run_size: int = None, 295 | memory_size: int = None, 296 | ): 297 | if total_size <= single_run_size: 298 | num_iterations = 1 299 | else: 300 | num_iterations = total_size // single_run_size 301 | 302 | return int(num_iterations) 303 | 304 | def predict( 305 | self, 306 | df: pd.DataFrame = None, 307 | target_ivs: Union[str, List[str]] = None, 308 | target_cls: Union[str, List[str]] = None, 309 | target_phs: Union[str, List[str]] = None, 310 | iv_col: Union[List[str], str] = ['iv1', 'iv2'], 311 | cl_col: str = "cell_line", 312 | ph_col: str = "phenotype", 313 | num_iterations: int = None, 314 | save: bool = True, 315 | filename: str = "Prophet_prediction", 316 | ): 317 | """Predict outcomes using the trained Prophet model. 318 | 319 | This function can take either a DataFrame or a combination of gene and cell line lists to make predictions. 320 | If a dataframe is passed, which columns correspond to which inputs must also be passed. If lists are passed, 321 | all combinations within are taken (not including combinations). 322 | 323 | Args: 324 | df (pd.DataFrame, optional): The DataFrame containing the data for prediction. If None, predictions will be made for all combinations of provided genes and cell lines. 325 | target_ivs (Union[str, List[str]], optional): The intervention or list of interventions for prediction. If df and target_ivs are both None, predictions will be made for all available treatments. 326 | target_cls (Union[str, List[str]], optional): The cell line or list of cell lines for prediction. If df and target_cls are both None, predictions will be made for all available cell lines. 327 | target_phs (Union[str, List[str]], optional): The phenotype for prediction. If None, it will be set to "_". 328 | num_iterations (int, optional): The number of iterations to run the prediction. 
If None, it will be calculated automatically. 329 | save (bool, optional): Whether to save the prediction results to a file. If False, return the prediction result as dataframe. Defaults to True. 330 | filename (str, optional): The filename to save the prediction results. If None, a default name will be used. 331 | """ 332 | if save: 333 | print(f"Saving results to {filename}.parquet") 334 | 335 | # If only pass target_cls and target_ivs 336 | if df is None: 337 | 338 | if isinstance(target_cls, str): 339 | target_cls = [target_cls] 340 | if isinstance(target_phs, str): 341 | target_phs = [target_phs] 342 | if isinstance(target_ivs, str): 343 | target_ivs = [target_ivs] 344 | 345 | if target_cls is None: 346 | target_cls = list(self.cl_embedding.index) 347 | warnings.warn(f"Trying to predict for all cell lines that are from {self.cl_embedding.index}!") 348 | print(f"There are {len(target_cls)} cell lines in total!") 349 | if target_ivs is None: 350 | target_ivs = list(self.iv_embedding.drop_duplicates().index) # because there compounds with the same embedding but different 351 | warnings.warn(f"Trying to predict for all gene combinations that are from {self.iv_embedding.index}!") 352 | print(f"There are {len(target_ivs)} genes in total!") 353 | 354 | # in case the user wants to specify more than one intervention 355 | if iv_col is None: 356 | iv_col = ['iv1'] 357 | 358 | total_size = len(target_ivs) * len(target_cls) * len(target_phs) 359 | self._init_input(iv_col, 'cell_line', 'phenotype', 'value') 360 | # or pass a dataframe has similiar format with in train 361 | else: 362 | self._init_input(iv_col, cl_col, ph_col, 'value') 363 | # user-friendly column name check 364 | for _, (old_name, new_name) in enumerate(self.column_map.items()): 365 | if new_name == 'value': # prediction does not need a readout col 366 | continue 367 | if old_name not in df.columns: 368 | raise ValueError(f"{old_name} not in df columns.") 369 | 370 | df = df.rename(columns=self.column_map).copy() 371 | 372 | # same formatting as in train 373 | df = df.drop_duplicates() 374 | df = df.reset_index(drop=True) 375 | df = self._remove_nonexistent_cat(data_label=df, verbose=False) 376 | total_size = len(df) 377 | 378 | # Divide work into partitions. This allows users to run a large amount of inference 379 | # in one shot without memory problems. 
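        # Worked example of the automatic partitioning below (assumes the default
        # single_run_size of 1e8): 4e8 candidate rows -> 4e8 // 1e8 = 4 partitions,
        # which the parity check afterwards bumps to 5 because even counts are incremented.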
380 | if num_iterations: 381 | if num_iterations < 1: 382 | raise KeyError("num_iterations must be larger than 0.") 383 | else: 384 | if num_iterations - self._decide_iteration_num(total_size=total_size, single_run_size=1e08) > 5: 385 | warnings.warn("The num_iterations you passed might be too small.") 386 | else: 387 | num_iterations = self._decide_iteration_num(total_size=total_size, single_run_size=1e08) 388 | 389 | if num_iterations % 2 == 0: 390 | num_iterations = num_iterations + 1 391 | 392 | print(f"There are {num_iterations} iterations") 393 | data_label_list = [] 394 | for run_index in tqdm(range(num_iterations)): 395 | 396 | # If the user input is df 397 | if isinstance(df, pd.DataFrame): 398 | batch_size = int(total_size // num_iterations) 399 | start_idx = run_index * batch_size 400 | end_idx = ( 401 | start_idx + batch_size 402 | if (run_index < num_iterations - 1) 403 | else len(df) 404 | ) 405 | data_label = df.iloc[start_idx:end_idx] 406 | 407 | # If the user input is gene list and cl list 408 | else: 409 | data_label = self._generate_predict_df(run_index=run_index,num_iterations=num_iterations,target_ivs=target_ivs,target_cls=target_cls, target_phs=target_phs) 410 | 411 | data_label = data_label.drop_duplicates() 412 | 413 | data_label = self._remove_nonexistent_cat(data_label=data_label, verbose=not isinstance(df, pd.DataFrame)) 414 | 415 | if self.torch_dataset: # format for pytorch dataloading 416 | # must have a value column 417 | data_label['_'] = 0 418 | 419 | split = dataloader_phenotypes( 420 | gene_embedding=self.iv_embedding, 421 | cell_lines_embedding=self.cl_embedding, 422 | phenotype_embedding=self.ph_embedding, 423 | data_label=data_label, 424 | label_name='_', 425 | index=( 426 | np.array(data_label.index), 427 | [], 428 | np.array(data_label.index).tolist(), # test is not shuffled 429 | "", 430 | ), # (train, test, val, descr) 431 | torch_dataset=self.torch_dataset, 432 | pert_len=self.pert_len 433 | ) 434 | 435 | if self.torch_dataset: 436 | X = split[2] # take test is not shuffled 437 | else: 438 | X, _ = split[2] 439 | 440 | train_dataloader, valid_dataloader, test_dataloader, train_indices, test_indices, descriptor = split 441 | trainer = pl.Trainer(devices=1) 442 | predictions = trainer.predict(self.model, test_dataloader) 443 | predictions = [t[0] for t in predictions] 444 | predictions = torch.cat(predictions, dim=0) 445 | data_label["pred"] = predictions 446 | data_label.drop(columns=['_'], inplace=True) 447 | 448 | if save: 449 | data_label.to_parquet( 450 | f"{filename}_{run_index}.parquet", 451 | # engine="fastparquet" 452 | ) 453 | else: 454 | data_label_list.append(data_label) 455 | if not save: 456 | concatenated_df = pd.concat(data_label_list) 457 | return concatenated_df 458 | -------------------------------------------------------------------------------- /prophet/__init__.py: -------------------------------------------------------------------------------- 1 | from .Prophet import Prophet 2 | from .config import set_config 3 | 4 | __all__ = ['Prophet', 'set_config'] 5 | -------------------------------------------------------------------------------- /prophet/callbacks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | from torch import optim 4 | import numpy as np 5 | from sklearn.metrics import r2_score 6 | from scipy.stats import spearmanr 7 | 8 | 9 | class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler): 10 | 11 | def __init__(self, 
optimizer, warmup, max_iters): 12 | self.warmup = warmup 13 | self.max_num_iters = max_iters 14 | super().__init__(optimizer) 15 | 16 | def get_lr(self): 17 | lr_factor = self.get_lr_factor(epoch=self.last_epoch) 18 | return [base_lr * lr_factor for base_lr in self.base_lrs] 19 | 20 | def get_lr_factor(self, epoch): 21 | lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters)) 22 | if epoch <= self.warmup: 23 | lr_factor *= epoch * 1.0 / self.warmup 24 | return lr_factor 25 | 26 | 27 | class R2ScoreCallback(pl.Callback): 28 | def __init__(self, device: torch.device = 'cpu', average = False): 29 | super().__init__() 30 | self.predictions = [] 31 | self.targets = [] 32 | 33 | 34 | self.prediction_train = [] 35 | self.prediction_test = [] 36 | 37 | self.targets_train = [] 38 | self.targets_test = [] 39 | 40 | self.phenotype_validation = [] 41 | 42 | self.device = device 43 | self.average = average 44 | 45 | self.table_val = None 46 | self.table_train = None 47 | 48 | print("R2 average: ", self.average) 49 | 50 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 51 | y_pred, y_true = outputs['y_pred'].detach(), outputs['y_true'].detach() 52 | self.prediction_train.append(y_pred) 53 | self.targets_train.append(y_true) 54 | 55 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 56 | y_pred, y_true, phenotype = outputs['y_pred'], outputs['y_true'], outputs['phenotype'] 57 | self.predictions.append(y_pred) 58 | self.targets.append(y_true) 59 | self.phenotype_validation.append(phenotype) 60 | 61 | def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 62 | y_pred, y_true = outputs['y_pred'], outputs['y_true'] 63 | self.prediction_test.append(y_pred) 64 | self.targets_test.append(y_true) 65 | 66 | def on_validation_epoch_end(self, trainer, pl_module): 67 | predictions = torch.cat(self.predictions, dim=0)#.cpu().numpy() 68 | targets = torch.cat(self.targets, dim=0)#.cpu().numpy() 69 | phenotypes = torch.cat(self.phenotype_validation, dim=0) 70 | 71 | 72 | predictions = predictions.cpu().numpy() 73 | targets = targets.cpu().numpy() 74 | phenotypes = phenotypes.cpu().numpy() 75 | 76 | if self.average: 77 | r2_scores = 0 78 | spearman_scores = 0 79 | unique_phe = np.unique(phenotypes) 80 | for phe in unique_phe: 81 | indices = np.nonzero(phenotypes == phe) 82 | phe_predictions = predictions[indices] 83 | phe_targets = targets[indices] 84 | 85 | r2 = r2_score(phe_targets, phe_predictions) 86 | r2_scores += r2 87 | 88 | spearman = spearmanr(phe_predictions, phe_targets).statistic 89 | spearman_scores += spearman 90 | 91 | r2_total = r2_scores / len(unique_phe) 92 | spearman_total = spearman_scores / len(unique_phe) 93 | 94 | self.log("R2", r2_total, sync_dist=True, batch_size=predictions.shape[0]) 95 | self.log("Spearman", spearman_total, sync_dist=True, batch_size=predictions.shape[0]) 96 | 97 | else: 98 | 99 | r2 = r2_score(targets, predictions) 100 | self.log("R2", r2, sync_dist=True, batch_size=predictions.shape[0]) 101 | 102 | spearman = spearmanr(predictions, targets).statistic 103 | self.log("Spearman", spearman, sync_dist=True, batch_size=predictions.shape[0]) 104 | 105 | self.predictions = [] 106 | self.targets = [] 107 | self.phenotype_validation = [] 108 | 109 | def on_train_epoch_end(self, trainer, pl_module): 110 | predictions = torch.cat(self.prediction_train, dim=0)#.cpu().numpy() 111 | targets = torch.cat(self.targets_train, dim=0)#.cpu().numpy() 112 | 113 | targets_mean = targets.mean(0) 114 | 
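        # tiled to batch size below; looks like a mean-predictor baseline, but it is
        # currently only converted to numpy and not used in the logged metrics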
targets_mean = targets_mean.repeat(targets.shape[0]) 115 | 116 | targets_mean = targets_mean.cpu().numpy() 117 | predictions = predictions.cpu().numpy() 118 | targets = targets.cpu().numpy() 119 | 120 | r2 = r2_score(targets, predictions) 121 | self.log("R2_train", r2, sync_dist=True, batch_size=predictions.shape[0]) 122 | 123 | spearman = spearmanr(predictions, targets).statistic 124 | self.log("Spearman_train", spearman, sync_dist=True, batch_size=predictions.shape[0]) 125 | 126 | self.prediction_train = [] 127 | self.targets_train = [] 128 | 129 | def on_test_epoch_end(self, trainer, pl_module): 130 | predictions = torch.cat(self.prediction_test, dim=0)#.cpu().numpy() 131 | targets = torch.cat(self.targets_test, dim=0)#.cpu().numpy() 132 | 133 | targets_mean = targets.mean(0) 134 | targets_mean = targets_mean.repeat(targets.shape[0]) 135 | 136 | targets_mean = targets_mean.cpu().numpy() 137 | predictions = predictions.cpu().numpy() 138 | targets = targets.cpu().numpy() 139 | 140 | r2 = r2_score(targets, predictions) 141 | self.log("R2_test", r2, sync_dist=True, batch_size=predictions.shape[0]) 142 | 143 | spearman = spearmanr(predictions, targets).statistic 144 | self.log("Spearman_test", spearman, sync_dist=True, batch_size=predictions.shape[0]) 145 | 146 | self.prediction_test = [] 147 | self.targets_test = [] -------------------------------------------------------------------------------- /prophet/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List, Dict 3 | 4 | @dataclass 5 | class TransformerConfig: 6 | dim_cl: int = 300 7 | dim_iv: int = 800 8 | dim_phe: int = 300 9 | model_dim: int = 128 10 | num_heads: int = 1 11 | num_layers: int = 2 12 | iv_dropout: float = 0.2 13 | cl_dropout: float = 0.2 14 | ph_dropout: float = 0.2 15 | regressor_dropout: float = 0.2 16 | lr: float = 0.0001 17 | weight_decay: float = 0.01 18 | warmup: int = 10000 19 | max_iters: int = 70000 20 | dropout: float = 0.2 21 | exclude_cl_embedding: bool = False 22 | pool: str = "cls" 23 | simpler: bool = True 24 | mask: bool = True 25 | sum: bool = False 26 | explicit_phenotype: bool = False 27 | linear_predictor: bool = False 28 | tokenizer_layers: int = 2 29 | 30 | def update_from_dict(self, updates: Dict): 31 | for key, value in updates.items(): 32 | if hasattr(self, key): 33 | setattr(self, key, value) 34 | 35 | @dataclass 36 | class Config: 37 | setting: str 38 | leaveout_method: str 39 | dirpath: str = './ckpts/' 40 | ckpt_path: str = None 41 | project_name: str = "Prophet_hparams" 42 | cell_lines_prior: List[str] = field(default_factory=lambda: ["./embeddings/cell_line_embedding_full_ccle_300_scaled.csv"]) 43 | genes_prior: List[str] = field(default_factory=lambda: ["./embeddings/ccle_T_pca_300_enformer_full_gene_mean_PCA_500_scaled.csv"] ) 44 | phenotype_prior: List[str] = None 45 | unbalanced: bool = False 46 | pert_len: int = 2 47 | max_steps: int = 140000 48 | batch_size: int = 2048 49 | early_stopping: bool = True 50 | patience: int = 20 51 | ckpt_path = None 52 | fine_tune = False 53 | transformer: TransformerConfig = field(default_factory=TransformerConfig) 54 | 55 | def update_from_dict(self, updates: Dict): 56 | for key, value in updates.items(): 57 | if key == 'Transformer' and isinstance(value, dict): 58 | self.transformer.update_from_dict(value) 59 | elif hasattr(self, key): 60 | setattr(self, key, value) 61 | 62 | def set_config(models_config): 63 | config = Config( 64 | 
setting=models_config['setting'], 65 | leaveout_method=models_config['leaveout_method'], 66 | ) 67 | config.update_from_dict(models_config) 68 | 69 | if config.transformer.simpler: 70 | config.ctx_len = config.pert_len + 1 71 | else: 72 | config.ctx_len = config.pert_len + 3 73 | return config 74 | -------------------------------------------------------------------------------- /prophet/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from typing import List, Tuple 3 | from torch.utils.data import WeightedRandomSampler 4 | import numpy as np 5 | import pandas as pd 6 | from functools import reduce 7 | from .dataset import PhenotypeDataset 8 | 9 | SEED = 42 # the true, baseline seed (that sets test splits) 10 | 11 | def _choose(a, size, seed): 12 | """Guaranteed deterministic choosing.""" 13 | np.random.seed(seed) # reset the generator 14 | return np.random.choice(a, size=size, replace=False) 15 | 16 | def dataloader_phenotypes( 17 | gene_embedding: List[pd.DataFrame], 18 | cell_lines_embedding: List[pd.DataFrame], 19 | phenotype_embedding: List[pd.DataFrame], 20 | data_label, 21 | index, 22 | batch_size = 2048, 23 | label_name=None, 24 | unbalanced: bool = False, 25 | torch_dataset: bool = True, 26 | pert_len: int = 2, 27 | valid_set: bool = True, 28 | test_set: bool = True, 29 | phenotypes: list = None, 30 | ) -> List[Tuple[DataLoader, DataLoader, DataLoader, np.array, np.array]] : 31 | """Dataloader for multiple sources of information 32 | 33 | Args: 34 | gene_embedding (pd.DataFrame): index needs to be gene or gRNA 35 | cell_lines_embedding (pd.DataFrame): index needs to be cancer cell line name 36 | phenotype_embedding (pd.DataFrame): index needs to be phenotype name 37 | data_label (_type_): experimental data with columns that match indices in gene_embedding dataframes and cell line embedings dataframes 38 | label_name (_type_): label to predict 39 | indices (_type_, optional): _description_. Defaults to None. 40 | batch_size (_type_, optional): -1 to indicate work with sklearn 41 | embedding (Optional[List[bool]], optional): whether gene and cell_line should send index for embedding of vectors. [F, F] send vectors, [F, T] sends vector for genes and index for cell_line 42 | pert_len (int): number of perturbations to give to the model 43 | phenotypes (list): Pass a list of phenotypes to force indexing to occur correctly at prediction time for a pytorch model. If None, automatically determined from data_label. 
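        unbalanced (bool): if True, oversample under-represented phenotypes (or datasets) in the training loader via a WeightedRandomSampler
        torch_dataset (bool): if True, build PhenotypeDataset-backed PyTorch DataLoaders; if False, return (X, y) numpy arrays for sklearn-style models
        valid_set / test_set (bool): automatically switched to False when the corresponding index list in `index` is empty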
44 | Returns: 45 | _type_: _description_ 46 | """ 47 | train_indices, valid_indices, test_indices, cl_holdout = index 48 | if len(valid_indices) == 0: 49 | valid_set = False 50 | if len(test_indices) == 0: 51 | test_set = False 52 | 53 | # accounting for multiple datasets 54 | test_dict = None 55 | if type(test_indices) == dict: 56 | test_dict = test_indices.copy() 57 | test_indices = test_indices['all'] 58 | 59 | # value checks 60 | if 'type' not in gene_embedding.columns: 61 | raise ValueError("No column 'type' in gene_embedding") 62 | 63 | if not torch_dataset: 64 | # create input dataframes 65 | ge = gene_embedding.dropna() 66 | ce = cell_lines_embedding.dropna() 67 | ge = ge.drop(columns=['type']) 68 | 69 | if phenotype_embedding is not None: 70 | pe = phenotype_embedding.dropna() 71 | X_phenotype = pe.loc[data_label.phenotype].to_numpy() 72 | else: 73 | X_phenotype = pd.get_dummies(data_label['phenotype']).astype(float).values # 1he 74 | X_iv = np.concatenate( 75 | [ge.loc[data_label[f'iv{i}']].to_numpy() for i in range(1, pert_len + 1)] 76 | , axis=1) 77 | X_cellline = ce.loc[data_label.cell_line].to_numpy() 78 | if label_name is None: 79 | return [ 80 | (np.concatenate([X_phenotype[idxs], X_cellline[idxs], X_iv[idxs]], axis=1), 81 | None) \ 82 | for idxs in [train_indices, valid_indices, test_indices]] 83 | else: 84 | y_label = data_label[label_name].to_numpy() 85 | 86 | return [ 87 | (np.concatenate([X_phenotype[idxs], X_cellline[idxs], X_iv[idxs]], axis=1), 88 | y_label[idxs]) \ 89 | for idxs in [train_indices, valid_indices, test_indices]] 90 | 91 | data = data_label.copy() 92 | if phenotypes is None: 93 | phenotypes = sorted(list(data_label.phenotype.unique())) 94 | 95 | train_set = PhenotypeDataset( 96 | experimental_data = data.loc[train_indices], 97 | label_key = label_name, 98 | iv_embeddings = gene_embedding, 99 | cell_line_embeddings = cell_lines_embedding, 100 | phenotype_embeddings = phenotype_embedding if phenotype_embedding is not None else None, # if it's None, None[0] will return error 101 | phenotypes = phenotypes, 102 | pert_len=pert_len 103 | ) 104 | if valid_set: 105 | valid_set = PhenotypeDataset( 106 | experimental_data = data.loc[valid_indices], 107 | label_key = label_name, 108 | iv_embeddings = gene_embedding, 109 | cell_line_embeddings = cell_lines_embedding, 110 | phenotype_embeddings = phenotype_embedding if phenotype_embedding is not None else None, 111 | phenotypes = phenotypes, 112 | pert_len=pert_len 113 | ) 114 | valid_dataloader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=4) 115 | else: 116 | valid_dataloader = None 117 | 118 | if test_set: 119 | test_set = PhenotypeDataset( 120 | experimental_data = data.loc[test_indices], 121 | label_key = label_name, 122 | iv_embeddings = gene_embedding, 123 | cell_line_embeddings = cell_lines_embedding, 124 | phenotype_embeddings = phenotype_embedding if phenotype_embedding is not None else None, 125 | phenotypes = phenotypes, 126 | pert_len=pert_len 127 | ) 128 | test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4) 129 | # convert to dict if there are multiple test sets 130 | if test_dict is not None: 131 | test_dl_dict = {'all': test_dataloader} 132 | for k, test_indices in test_dict.items(): 133 | if k == 'all': # already loaded in test_dataloader 134 | pass 135 | test_set = PhenotypeDataset( 136 | experimental_data = data.loc[test_indices], 137 | label_key = label_name, 138 | iv_embeddings = gene_embedding, 139 | 
cell_line_embeddings = cell_lines_embedding, 140 | phenotype_embeddings = phenotype_embedding if phenotype_embedding is not None else None, 141 | phenotypes = phenotypes, 142 | pert_len=pert_len 143 | ) 144 | test_dl_dict[k] = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4) 145 | 146 | else: 147 | test_dataloader = None 148 | 149 | if unbalanced: # unbalanced means that one phenotype is way more measured that the other 150 | 151 | ## Count the occurrences of each class 152 | key = 'phenotype' 153 | if 'dataset' in data.columns: 154 | key = 'dataset' 155 | class_counts = data.loc[train_indices][key].value_counts() 156 | num_samples = len(data) 157 | class_weights = [num_samples / class_counts.values[i] for i in range(len(class_counts))] 158 | scaling_factor = 1 / min(class_weights) 159 | class_weights = [x * scaling_factor for x in class_weights] 160 | class_weight_dict = dict(zip(class_counts.index, class_weights)) 161 | 162 | # Create a custom WeightedRandomSampler to oversample the minority class 163 | weights = [class_weight_dict[c] for c in data.loc[train_indices][key].values] 164 | train_sampler = WeightedRandomSampler(weights, len(weights), replacement=True) 165 | train_dataloader = DataLoader(train_set, batch_size=batch_size, sampler=train_sampler, num_workers=4) 166 | else: 167 | train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4) 168 | 169 | if test_dict: 170 | return (train_dataloader, valid_dataloader, test_dl_dict, train_indices, test_indices, cl_holdout) 171 | else: 172 | return (train_dataloader, valid_dataloader, test_dataloader, train_indices, test_indices, cl_holdout) 173 | 174 | def read_in_priors(prior_files): 175 | """Convert list of filenames to dfs. When presented with multiple files, assumes they contain different 176 | indices and adds additional rows. Assumes columns which contain the same feature are named accordingly.""" 177 | prior = [] 178 | if isinstance(prior_files, str): 179 | prior_files = [prior_files] 180 | for file in range(len(prior_files)): 181 | emb = pd.read_csv(prior_files[file], index_col=0) 182 | emb.index = emb.index.astype(str) 183 | prior.append(emb) 184 | 185 | return pd.concat(prior).fillna(0) if len(prior) > 0 else None 186 | 187 | def process_priors(genes_prior, cell_lines_prior, phenotype_prior): 188 | gene_prior = read_in_priors(genes_prior) 189 | try: 190 | gene_prior = gene_prior.set_index('smiles') 191 | except KeyError: # still works even if there's no smiles column, like for genetic interventions 192 | pass 193 | gene_prior.index = [str(x).lower() for x in gene_prior.index] # allow translatability across organisms and drugs 194 | cl_prior = read_in_priors(cell_lines_prior) 195 | phe_prior = None 196 | if phenotype_prior is not None: 197 | phe_prior = read_in_priors(phenotype_prior) 198 | 199 | if gene_prior is not None and "type" not in gene_prior.columns: 200 | raise ValueError("type not in iv_embedding columns") 201 | 202 | # add 0 for control 203 | gene_prior.loc['negative_gene'] = 0 204 | gene_prior.loc['negative_drug'] = 0 205 | gene_prior.loc['negative_gene', 'type'] = 'gene' 206 | gene_prior.loc['negative_drug', 'type'] = 'drug' 207 | 208 | return gene_prior, cl_prior, phe_prior 209 | 210 | 211 | def remove_nonexistent_cat(data_label, prior, columns, verbose=True): 212 | """Takes in a cell line or gene prior embedding dataframe and removes rows from 213 | data_label where the embedding doesn't exist. 
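    Category values are compared as strings against the prior's index; rows whose value in any of
    `columns` has no matching embedding row are dropped, and a short summary is printed when `verbose` is True.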
214 | 215 | Parameters 216 | ---------- 217 | data_label : pandas.DataFrame 218 | A dataframe with phenotype, cell context, and iv columns. This dataframe is modified by the function. 219 | prior : pandas.DataFrame 220 | A dataframe representing prior embeddings. The index of this dataframe should contain the categories to be filtered on. 221 | columns : list[str] 222 | A list of column names in `data_label` where the values are checked against the categories in `prior`. 223 | 224 | Returns 225 | ------- 226 | pandas.DataFrame 227 | """ 228 | if isinstance(columns, str): 229 | columns = [columns] 230 | data_label_cats = reduce(lambda x, y: np.union1d(x, y), [data_label[col].astype(str).values for col in columns]) 231 | emb_cats = prior.index.to_list() 232 | strings_to_remove = list(np.setdiff1d(data_label_cats, emb_cats)) 233 | for col in columns: 234 | data_label = data_label[~data_label[col].isin(strings_to_remove)] 235 | if verbose: 236 | print(f"Removing {len(strings_to_remove)} such as {strings_to_remove[:5]} from {columns}. {data_label.shape[0]} rows remaining.", flush=True) 237 | return data_label 238 | 239 | def check_iv_emb(emb): 240 | if "type" not in emb.columns: 241 | raise KeyError("Intervention embedding has no `type` in columns.") 242 | 243 | def check_data(data_label): 244 | cols = set(data_label.columns) 245 | needed = set(["phenotype", "cell_line", "iv1"]) 246 | if not needed.issubset(cols): 247 | raise KeyError(f"Cols is missing {cols-needed}") 248 | 249 | def universal_processing(data_label): 250 | 251 | if 'phenotype' not in data_label.columns: 252 | data_label['phenotype'] = 'none' 253 | 254 | data_label['value'] = data_label['value'].astype('f4') 255 | data_label = data_label.reset_index(drop=True) 256 | 257 | 258 | data_label['iv1'] = [x.lower() for x in data_label.iv1.values] # allow translatability across organisms and drugs 259 | data_label['iv2'] = [x.lower() for x in data_label.iv2.values] 260 | check_valid(data_label) 261 | 262 | data_label_flipped = data_label.rename( 263 | columns={'iv1': 'iv2', 'iv2': 'iv1'}) 264 | data_label = pd.concat([data_label, data_label_flipped], axis=0, ignore_index=True) 265 | 266 | return data_label 267 | 268 | def check_valid(df): 269 | if 'iv1' not in df.columns: 270 | raise ValueError("Dataset must have at least one perturbation in a columns named iv1, iv2, etc.") 271 | if 'cell_line' not in df.columns: 272 | raise ValueError("Dataset must have a cellular context in a column labeled `cell_line`.") 273 | if 'negative' in df.iv1.values: 274 | raise ValueError("Dataset still contains the negative label, please specify negative_gene or negative_drug.") -------------------------------------------------------------------------------- /prophet/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from torch.utils.data import Dataset 4 | 5 | class PhenotypeDataset(Dataset): 6 | """ 7 | Dataset that gathers multiple phenotypes. The splits are done before. Experimental_data is the dataframe with the train, test or validation data. 
8 | Each train, test or validation data is a different PhenotypeDataset dataset with different 'experimental_data' according to the splits 9 | """ 10 | def __init__( 11 | self, 12 | experimental_data: pd.DataFrame, 13 | label_key: str, 14 | iv_embeddings: pd.DataFrame, 15 | cell_line_embeddings: pd.DataFrame, 16 | phenotype_embeddings: pd.DataFrame = None, 17 | phenotypes: list = None, 18 | cl_embedding: bool = False, 19 | pert_len: int = 2 20 | ): 21 | """ 22 | Args: 23 | experimental_data (pd.DataFrame): experimental data, contains the label and the training data (combinations of gRNA) 24 | label_key (str): key that identifies the label in the experimental data 25 | iv_embeddings (pd.DataFrame): pandas dataframe with the embeddings of the perturbations 26 | cell_line_embeddings (pd.DataFrame): pandas dataframe with the embedding of cell lines 27 | cl_embedding (bool): if True, use predfined embedding; if False, retrieve just index cause it will be learn 28 | phenotypes (list): if phenotype embeddings are not provided, then a sorted list of phenotypes must be provided. 29 | pert_len (int): number of perturbations to provide to the model, context length will be pert_len + 2, which comes from phenotype + cell_type 30 | """ 31 | # precompute the attention mask 32 | self.attn_mask = [[False]*experimental_data.shape[0]] # always pay attention to CLS, which is first token 33 | for i in range(1, pert_len + 1): 34 | col = f'iv{i}' 35 | mask_values = experimental_data[col].isin(['negative_drug', 'negative_gene']).values # mask if negative 36 | self.attn_mask.append(mask_values) 37 | self.attn_mask.append([False]*experimental_data.shape[0]) # once for cell_line 38 | self.attn_mask.append([False]*experimental_data.shape[0]) # once for phenotype 39 | self.attn_mask = np.array(self.attn_mask).T 40 | 41 | columns = ['cell_line', 'phenotype'] + [f'iv{i}' for i in range(1, pert_len + 1)] 42 | self.experimental_data = experimental_data[columns].values # ordered 43 | self.labels = experimental_data[label_key].values 44 | self.iv = iv_embeddings.iloc[:, 1:].values 45 | self.iv_embs_types = iv_embeddings.iloc[:, 0].values 46 | self.cell_line = cell_line_embeddings.values 47 | self.iv_to_index = dict(zip(iv_embeddings.index, range(iv_embeddings.shape[0]))) 48 | self.cl_to_index = dict(zip(cell_line_embeddings.index, range(cell_line_embeddings.shape[0]))) 49 | 50 | # special handling for phenotypes 51 | self.ph_to_index = dict(zip(phenotypes, range(len(phenotypes)))) 52 | if phenotype_embeddings is not None: 53 | phenotype_embeddings = phenotype_embeddings.T[phenotypes].T # reorder the embedding so that ph_to_index matches 54 | self.phenotype_embeddings = phenotype_embeddings.values 55 | else: 56 | self.phenotype_embeddings = None 57 | 58 | self.pert_len = pert_len 59 | 60 | # print("Interventions: ", self.iv.shape) 61 | # print("Cell line: ", self.cell_line.shape) 62 | # print("Order of phenotypes: ", phenotypes) 63 | # print("Don't using explicit phenotype embeddings") if self.phenotype_embeddings is None else print(f"Explicit phenotype {self.phenotype_embeddings.shape} was passed") 64 | 65 | def __len__(self): 66 | return len(self.experimental_data) 67 | 68 | def __getitem__(self, idx): 69 | """ 70 | Returns a dictionary with: 71 | phenotype: index to query 72 | cell_line: embedding 73 | label: scalar to predict 74 | names: names of the perturbations 75 | idx: index of the observation 76 | pert_type: list that says whether perturbations are genes or drugs 77 | **iv_values_dict: dictionary with keys iv1, iv2 
etc. and respective embedding values. Same size as pert_len. 78 | """ 79 | 80 | item = self.experimental_data[idx] 81 | cell_line = item[0] 82 | phenotype = item[1] 83 | 84 | iv_values_dict = {} 85 | iv_type = [] 86 | for i in range(2, self.pert_len + 2): 87 | name = item[i] # perturbation name 88 | emb_entry = self.iv[self.iv_to_index[name]] 89 | iv_type.append(emb_entry[0]) # first item of the embedding is 'gene' or 'drug' 90 | iv_values_dict[f'iv{i-1}'] = emb_entry.astype('float64') # use all dimensions but the first one 91 | 92 | # if there isn't phenotype embedding 93 | if self.phenotype_embeddings is None: 94 | context = self.ph_to_index[phenotype] # retrieve index 95 | context = context + 1 # CLS is 0 96 | else: # if there's embedding 97 | context = self.phenotype_embeddings[self.ph_to_index[phenotype]] # retrieve embedding 98 | 99 | # Gene = 0, Drug = 1 100 | iv_types = [0 if item == 'gene' else (1 if item == 'drug' else item) for item in iv_type] 101 | iv_types = np.array(iv_types) 102 | 103 | return {'phenotype': context, # sometimes an int, sometimes an embedidng 104 | 'cell_line': self.cell_line[self.cl_to_index[cell_line]], 105 | 'label': self.labels[idx], 106 | 'attn_mask': self.attn_mask[idx], 107 | 'idx': idx, 108 | 'pert_type': iv_types, 109 | **iv_values_dict 110 | } -------------------------------------------------------------------------------- /prophet/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_lightning as pl 4 | from torch import optim 5 | import torch.nn.init as init 6 | from prophet.callbacks import CosineWarmupScheduler 7 | import logging 8 | 9 | 10 | class TransformerPredictor(pl.LightningModule): 11 | 12 | def __init__(self, 13 | dim_cl: int, 14 | dim_iv: int, 15 | dim_phe: int, 16 | model_dim: int, 17 | num_heads: int, 18 | num_layers: int, 19 | iv_dropout: float, 20 | cl_dropout: float, 21 | ph_dropout: float, 22 | regressor_dropout: float, 23 | lr: float, 24 | warmup: int, 25 | weight_decay: float, 26 | max_iters: int, 27 | batch_size: int = 0, 28 | dropout = 0.0, 29 | pool: str = 'cls', 30 | simpler: bool = False, 31 | ctx_len: int = 4, 32 | mask: bool = True, 33 | sum: bool = False, 34 | explicit_phenotype: bool = False, 35 | linear_predictor: bool = False, 36 | tokenizer_layers: int = 2, 37 | seed=42, 38 | phenotypes=None): 39 | """ 40 | Inputs: 41 | dim_cl - Number of dimensions to take from the cell lines 42 | dim_iv - Number of dimensions of the interventional embeddings 43 | model_dim - Hidden dimensionality to use inside the Transformer 44 | num_heads - Number of heads to use in the Multi-Head Attention blocks 45 | num_layers - Number of encoder blocks to use. 46 | iv_dropout - dropout in iv layers 47 | cl_dropout - dropout in cl layers 48 | regressor_dropout - dropout in regressor 49 | lr - Learning rate in the optimizer 50 | warmup - Number of warmup steps. Usually between 50 and 300 51 | max_iters - Number of maximum iterations the model is trained for. This is needed for the CosineWarmup scheduler 52 | dropout - Dropout to apply inside the model 53 | pool - 'mean', 'cls', or 'pool'. 
Mean takes the mean, CLS predicts just with the CLS, pool takes the max value 54 | simpler - uses the transformer just for the set of perturbations, then it concatenates the result with the other representations 55 | ctx_len - context length 56 | mask - if True, mask attention, otherwise don't do it 57 | sum - if True, don't use Transformer, just sum embeddings 58 | explicit_phenotype - if True, the user passes an embedding as phenotype directly 59 | linear_predictor 60 | tokenizer_layers 61 | """ 62 | super().__init__() 63 | 64 | ## Process phenotypes to save them in the checkpoint 65 | if phenotypes is not None: 66 | self.ph_to_index = dict(zip(phenotypes, range(len(phenotypes)))) 67 | logging.info(f"Phenotypes: {self.ph_to_index}") 68 | 69 | self.save_hyperparameters() 70 | self._create_model() 71 | 72 | # Initialize the weights 73 | self.initialize_weights() 74 | 75 | def _create_model(self): 76 | self.learnable_embedding = torch.nn.Embedding(num_embeddings=1000, 77 | embedding_dim=self.hparams.model_dim, 78 | max_norm=0.5, 79 | ) 80 | self.embedding_dropout = nn.Dropout(self.hparams.ph_dropout) 81 | 82 | # Tokenizer layer strong enough to non-linearly transform the data 83 | self.gene_net = nn.Sequential( 84 | nn.Linear(self.hparams.dim_iv, self.hparams.model_dim), 85 | nn.GELU(), 86 | nn.Dropout(self.hparams.iv_dropout), 87 | nn.Linear(self.hparams.model_dim, self.hparams.model_dim) 88 | ) 89 | 90 | self.drug_net = nn.Sequential( 91 | nn.Linear(self.hparams.dim_iv, self.hparams.model_dim), 92 | nn.GELU(), 93 | nn.Dropout(self.hparams.iv_dropout), 94 | nn.Linear(self.hparams.model_dim, self.hparams.model_dim) 95 | ) 96 | 97 | self.cl_net = nn.Sequential( 98 | nn.Linear(self.hparams.dim_cl, self.hparams.model_dim), 99 | nn.GELU(), 100 | nn.Dropout(self.hparams.cl_dropout), 101 | nn.Linear(self.hparams.model_dim, self.hparams.model_dim) 102 | ) 103 | 104 | if self.hparams.tokenizer_layers == 1: 105 | self.gene_net = nn.Sequential( 106 | nn.Linear(self.hparams.dim_iv, self.hparams.model_dim), 107 | nn.Dropout(self.hparams.iv_dropout), 108 | ) 109 | self.drug_net = nn.Sequential( 110 | nn.Linear(self.hparams.dim_iv, self.hparams.model_dim), 111 | nn.Dropout(self.hparams.iv_dropout), 112 | ) 113 | self.cl_net = nn.Sequential( 114 | nn.Linear(self.hparams.dim_cl, self.hparams.model_dim), 115 | nn.Dropout(self.hparams.cl_dropout), 116 | ) 117 | 118 | 119 | if self.hparams.explicit_phenotype: 120 | self.phenotype_net = nn.Sequential( 121 | nn.Linear(self.hparams.dim_phe, self.hparams.dim_phe), 122 | nn.GELU(), 123 | nn.Dropout(self.hparams.ph_dropout), 124 | nn.Linear(self.hparams.dim_phe, self.hparams.model_dim) 125 | ) 126 | 127 | # Transformer 128 | layer = nn.TransformerEncoderLayer(d_model=self.hparams.model_dim, 129 | nhead=self.hparams.num_heads, 130 | dim_feedforward=2*self.hparams.model_dim, 131 | dropout=self.hparams.dropout, 132 | batch_first=True, 133 | activation="gelu") 134 | 135 | self.transformer = nn.TransformerEncoder(encoder_layer=layer, 136 | num_layers=self.hparams.num_layers) 137 | 138 | # 2 layers regressor 139 | dim_regressor_input = 2*self.hparams.model_dim + self.learnable_embedding.embedding_dim # CLS (model_dim) | CellLine (model_dim) | phenotype (varies) 140 | if not self.hparams.simpler: 141 | dim_regressor_input = self.hparams.model_dim 142 | if self.hparams.sum: 143 | dim_regressor_input = dim_regressor_input+self.hparams.model_dim+self.hparams.model_dim 144 | 145 | self.output_net = nn.Sequential( 146 | nn.Linear(dim_regressor_input, self.hparams.model_dim), 
147 | nn.GELU(), 148 | nn.Dropout(self.hparams.regressor_dropout), 149 | nn.Linear(self.hparams.model_dim, self.hparams.model_dim), 150 | nn.GELU(), 151 | nn.Linear(self.hparams.model_dim, 1), 152 | ) 153 | if self.hparams.linear_predictor: 154 | self.output_net = nn.Sequential( 155 | nn.Linear(dim_regressor_input, 1) 156 | ) 157 | 158 | print('Gene net: ', self.gene_net, flush=True) 159 | print('Cell line net: ', self.cl_net, flush=True) 160 | print('Regressor: ', self.output_net, flush=True) 161 | if self.hparams.explicit_phenotype: 162 | print("Using explicit phenotype") 163 | if self.hparams.linear_predictor: 164 | print("Using linear predictor") 165 | 166 | def forward(self, phenotype, cl, perturbations, perturbations_type, attn_mask): 167 | """ 168 | Inputs: phenotype - phenotype index (or phenotype embedding if explicit_phenotype), cl - cell line embedding [Batch, dim_cl], 169 | perturbations - list of intervention embeddings [Batch, dim_iv], perturbations_type - 0 (gene) / 1 (drug) per token, attn_mask - key padding mask [Batch, ctx_len] 170 | """ 171 | cl = cl[:,:self.hparams.dim_cl] 172 | perturbations = [pert[:, :self.hparams.dim_iv] for pert in perturbations] 173 | attn_mask = attn_mask[:, :self.hparams.ctx_len] 174 | 175 | if self.hparams.explicit_phenotype: 176 | phenotype_emb = self.phenotype_net(phenotype[:,:self.hparams.dim_phe]) 177 | else: 178 | phenotype_emb = self.learnable_embedding(phenotype) # Phenotype 179 | 180 | phenotype_emb = self.embedding_dropout(phenotype_emb) # dropout 181 | # shape is (batch_size x n_dim) 182 | 183 | # Drugs to drug network and genes to gene network 184 | # We mask the attention to the negative perturbations, so it's like not using the networks 185 | drug_perturbations = [self.drug_net(tensor).unsqueeze(1) for tensor in perturbations] # all perts to drug 186 | gene_perturbations = [self.gene_net(tensor).unsqueeze(1) for tensor in perturbations] # all perts to gene 187 | 188 | drug_perturbations = torch.cat(drug_perturbations, dim=1) # bs x n x dim 189 | gene_perturbations = torch.cat(gene_perturbations, dim=1) # bs x n x dim 190 | 191 | perturbations = torch.where(perturbations_type.unsqueeze(2) == 0, gene_perturbations, drug_perturbations) 192 | 193 | cl_embedding = self.cl_net(cl).unsqueeze(1) # unsqueeze just useful if not simpler 194 | phenotype_emb = phenotype_emb.unsqueeze(1) 195 | # bs x n x dim 196 | 197 | # Regression token (CLS) stored in index 0 198 | cls = torch.zeros(size=(phenotype_emb.shape[0], 1), device=phenotype_emb.device, dtype=torch.int32) 199 | cls = self.learnable_embedding(cls) # we get the embedding, shape (bs x 1 x dim) 200 | 201 | if self.hparams.pool == 'cls': 202 | if self.hparams.simpler: # if simpler, just perturbations and CLS to transformer 203 | embeddings = torch.cat((cls, perturbations), dim=1) 204 | else: # otherwise everything to transformer 205 | embeddings = torch.cat((cls, perturbations, cl_embedding, phenotype_emb), dim=1) 206 | else: # if not CLS, we don't need it 207 | if self.hparams.simpler: 208 | embeddings = perturbations 209 | else: 210 | embeddings = torch.cat((perturbations, cl_embedding, phenotype_emb), dim=1) 211 | 212 | # Run Transformer Layer 213 | if not self.hparams.sum: 214 | if self.hparams.mask: 215 | x = self.transformer(embeddings, mask=None, src_key_padding_mask=attn_mask) 216 | else: 217 | x = self.transformer(embeddings, mask=None) 218 | 219 | if self.hparams.pool == 'cls': 220 | x = x[:, 0, :] # use just the regressor token for regression 221 | elif self.hparams.pool == 'mean': 222 | x = torch.mean(x, dim=1) # mean-pool 223 | else: 224 | x = torch.max(x, dim=1).values # max-pool 225 | 226 | # If sum, forget about everything else 227 | if self.hparams.sum: 228 | x = torch.reshape(embeddings,
(embeddings.shape[0], -1)) 229 | 230 | if self.hparams.simpler: 231 | x = torch.cat((x, cl_embedding.squeeze(1), phenotype_emb.squeeze(1)), dim=-1) 232 | 233 | 234 | x = self.output_net(x) 235 | 236 | return x 237 | 238 | def embedding(self, phenotype, cl, perturbations, perturbations_type, attn_mask): 239 | """ 240 | Inputs: phenotype - phenotype index (or phenotype embedding if explicit_phenotype), cl - cell line embedding [Batch, dim_cl], 241 | perturbations - list of intervention embeddings [Batch, dim_iv], perturbations_type - 0 (gene) / 1 (drug) per token, attn_mask - key padding mask [Batch, ctx_len] 242 | """ 243 | 244 | # Cut CL and perturbations to number of selected dimensions 245 | cl = cl[:,:self.hparams.dim_cl] 246 | perturbations = [pert[:, :self.hparams.dim_iv] for pert in perturbations] 247 | attn_mask = attn_mask[:, :self.hparams.ctx_len] 248 | 249 | if self.hparams.explicit_phenotype: 250 | # if explicit_phenotype, we take the phenotype that the user input. The user must make it dim_phe-dimensional (as in forward) 251 | phenotype_emb = self.phenotype_net(phenotype[:,:self.hparams.dim_phe]) 252 | else: 253 | phenotype_emb = self.learnable_embedding(phenotype) # Phenotype 254 | 255 | phenotype_emb = self.embedding_dropout(phenotype_emb) # dropout 256 | 257 | # Drugs to drug network and genes to gene network 258 | # We mask the attention to the negative perturbations, so it's like not using the networks 259 | drug_perturbations = [self.drug_net(tensor).unsqueeze(1) for tensor in perturbations] # all perts to drug 260 | gene_perturbations = [self.gene_net(tensor).unsqueeze(1) for tensor in perturbations] # all perts to gene 261 | 262 | drug_perturbations = torch.cat(drug_perturbations, dim=1) # bs x n x dim 263 | gene_perturbations = torch.cat(gene_perturbations, dim=1) # bs x n x dim 264 | 265 | perturbations = torch.where(perturbations_type.unsqueeze(2) == 0, gene_perturbations, drug_perturbations) 266 | 267 | cl_embedding = self.cl_net(cl).unsqueeze(1) # unsqueeze just useful if not simpler 268 | phenotype_emb = phenotype_emb.unsqueeze(1) 269 | # bs x n x dim 270 | 271 | # Regression token (CLS) stored in index 0 272 | cls = torch.zeros(size=(phenotype_emb.shape[0], 1), device=phenotype_emb.device, dtype=torch.int32) 273 | cls = self.learnable_embedding(cls) # we get the embedding, shape (bs x 1 x dim) 274 | 275 | if self.hparams.pool == 'cls': 276 | if self.hparams.simpler: # if simpler, just perturbations and CLS to transformer 277 | embeddings = torch.cat((cls, perturbations), dim=1) 278 | else: # otherwise everything to transformer 279 | embeddings = torch.cat((cls, perturbations, cl_embedding, phenotype_emb), dim=1) 280 | else: # if not CLS, we don't need it 281 | if self.hparams.simpler: 282 | embeddings = perturbations 283 | else: 284 | embeddings = torch.cat((perturbations, cl_embedding, phenotype_emb), dim=1) 285 | 286 | # Run Transformer Layer 287 | if not self.hparams.sum: 288 | if self.hparams.mask: 289 | x = self.transformer(embeddings, mask=None, src_key_padding_mask=attn_mask) 290 | else: 291 | x = self.transformer(embeddings, mask=None) 292 | 293 | transformer_output = x 294 | 295 | if self.hparams.pool == 'cls': 296 | pert_emb = x[:, 0, :] # use just the regressor token for regression 297 | elif self.hparams.pool == 'mean': 298 | pert_emb = torch.mean(x, dim=1) # mean-pool 299 | else: 300 | pert_emb = torch.max(x, dim=1).values # max-pool 301 | 302 | # If sum, forget about everything else 303 | if self.hparams.sum: 304 | #x = torch.sum(embeddings, dim=1) 305 | x = torch.reshape(embeddings, (embeddings.shape[0], -1)) 306 | 307 | if self.hparams.simpler: 308 | x = torch.cat((pert_emb, cl_embedding.squeeze(1), phenotype_emb.squeeze(1)), dim=-1) 309 | 310 | x = self.output_net(x) 311 | 312 | return
{'pert_emb': pert_emb, # output of transformer: CLS or mean 313 | 'output': x, 314 | 'perturbations_after_transformer': transformer_output[:, 1:, :], 315 | 'perturbations': perturbations, # tokens 316 | 'cl_embedding': cl_embedding, # tokens 317 | 'phenotype': phenotype, 318 | 'pert_type': perturbations_type, 319 | } 320 | 321 | def configure_optimizers(self): 322 | optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay) 323 | 324 | # Apply lr scheduler per step 325 | lr_scheduler = CosineWarmupScheduler(optimizer, 326 | warmup=self.hparams.warmup, 327 | max_iters=self.hparams.max_iters) 328 | return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}] 329 | 330 | def training_step(self, batch, batch_idx): 331 | 332 | # complete_masking(batch, self.hparams.ctx_len) 333 | attn_mask = batch['attn_mask'] 334 | attn_mask = attn_mask[:, :self.hparams.ctx_len] 335 | 336 | phenotype = batch['phenotype'] 337 | cl = batch['cell_line'] 338 | y = batch['label'] 339 | perturbations_type = batch['pert_type'] 340 | 341 | perturbations = [] 342 | for pert in range(1, self.hparams.ctx_len - (2 if not self.hparams.simpler else 0)): 343 | perturbations.append(batch[f'iv{pert}'].to(torch.float32)) 344 | 345 | if self.hparams.explicit_phenotype: 346 | phenotype, cl = phenotype.to(torch.float32), cl.to(torch.float32) 347 | else: 348 | phenotype, cl = phenotype.to(torch.int32), cl.to(torch.float32) 349 | 350 | y_hat = self(phenotype, cl, perturbations, perturbations_type, attn_mask) 351 | 352 | y = y.unsqueeze(1) 353 | 354 | loss = torch.nn.functional.mse_loss(y, y_hat) 355 | self.log("train_loss", loss, prog_bar=True, sync_dist=True, batch_size=phenotype.shape[0]) 356 | 357 | return {'loss': loss, 'y_pred': y_hat, 'y_true': y} 358 | 359 | def validation_step(self, batch, batch_idx): 360 | 361 | attn_mask = batch['attn_mask'] 362 | attn_mask = attn_mask[:, :self.hparams.ctx_len] 363 | 364 | phenotype = batch['phenotype'] 365 | cl = batch['cell_line'] 366 | y = batch['label'] 367 | perturbations_type = batch['pert_type'] 368 | 369 | perturbations = [] 370 | for pert in range(1, self.hparams.ctx_len - (2 if not self.hparams.simpler else 0)): 371 | perturbations.append(batch[f'iv{pert}'].to(torch.float32)) 372 | 373 | if self.hparams.explicit_phenotype: 374 | phenotype, cl = phenotype.to(torch.float32), cl.to(torch.float32) 375 | else: 376 | phenotype, cl = phenotype.to(torch.int32), cl.to(torch.float32) 377 | 378 | y_hat = self(phenotype, cl, perturbations, perturbations_type, attn_mask) 379 | 380 | y = y.unsqueeze(1) 381 | loss = torch.nn.functional.mse_loss(y, y_hat) 382 | 383 | self.log("validation_loss", loss, sync_dist=True, batch_size=phenotype.shape[0]) 384 | 385 | return {'y_pred': y_hat, 'y_true': y, 'phenotype': phenotype} 386 | 387 | def test_step(self, batch, batch_idx): 388 | 389 | attn_mask = batch['attn_mask'] 390 | attn_mask = attn_mask[:, :self.hparams.ctx_len] 391 | 392 | phenotype = batch['phenotype'] 393 | cl = batch['cell_line'] 394 | y = batch['label'] 395 | perturbations_type = batch['pert_type'] 396 | 397 | perturbations = [] 398 | for pert in range(1, self.hparams.ctx_len - (2 if not self.hparams.simpler else 0)): 399 | perturbations.append(batch[f'iv{pert}'].to(torch.float32)) 400 | 401 | if self.hparams.explicit_phenotype: 402 | phenotype, cl = phenotype.to(torch.float32), cl.to(torch.float32) 403 | else: 404 | phenotype, cl = phenotype.to(torch.int32), cl.to(torch.float32) 405 | 406 | y_hat = self(phenotype, cl, perturbations, 
perturbations_type, attn_mask) 407 | 408 | y = y.unsqueeze(1) 409 | loss = torch.nn.functional.mse_loss(y, y_hat) 410 | 411 | self.log("test_loss", loss, sync_dist=True, batch_size=phenotype.shape[0]) 412 | 413 | return {'y_pred': y_hat, 'y_true': y} 414 | 415 | def get_embeddings(self, batch): 416 | 417 | attn_mask = batch['attn_mask'] 418 | attn_mask = attn_mask[:, :self.hparams.ctx_len] 419 | 420 | x = batch['phenotype'] 421 | cl = batch['cell_line'] 422 | names = batch['names'] 423 | perturbations_type = batch['pert_type'] 424 | 425 | perturbations = [] 426 | for pert in range(1, self.hparams.ctx_len - (2 if not self.hparams.simpler else 0)): 427 | perturbations.append(batch[f'iv{pert}'].to(torch.float32)) 428 | 429 | x, cl = x.to(torch.int32), cl.to(torch.float32) 430 | 431 | emb_dict = self.embedding(x, cl, perturbations, perturbations_type, attn_mask) 432 | 433 | return {'pert_emb': emb_dict['pert_emb'], 434 | 'output': emb_dict['output'], 435 | 'perturbations_after_transformer': emb_dict['perturbations_after_transformer'], 436 | 'perturbations': emb_dict['perturbations'], 437 | 'cell_line': emb_dict['cl_embedding'], 438 | 'phenotype': emb_dict['phenotype'], 439 | 'pert_type': emb_dict['pert_type'], 440 | 'names': names} 441 | 442 | def predict_step(self, batch, batch_idx): 443 | 444 | attn_mask = batch['attn_mask'] 445 | attn_mask = attn_mask[:, :self.hparams.ctx_len] 446 | 447 | phenotype = batch['phenotype'] 448 | cl = batch['cell_line'] 449 | y = batch['label'] 450 | perturbations_type = batch['pert_type'] 451 | 452 | perturbations = [] 453 | for pert in range(1, self.hparams.ctx_len - (2 if not self.hparams.simpler else 0)): 454 | perturbations.append(batch[f'iv{pert}'].to(torch.float32)) 455 | 456 | if self.hparams.explicit_phenotype: 457 | phenotype, cl = phenotype.to(torch.float32), cl.to(torch.float32) 458 | else: 459 | phenotype, cl = phenotype.to(torch.int32), cl.to(torch.float32) 460 | 461 | y_hat = self(phenotype, cl, perturbations, perturbations_type, attn_mask) 462 | 463 | return y_hat, y 464 | 465 | def on_before_optimizer_step(self, optimizer) -> None: 466 | 467 | total_norm = 0 468 | for p in self.parameters(): 469 | param_norm = p.grad.data.norm(2) 470 | total_norm += param_norm.item() ** 2 471 | total_norm = total_norm ** 0.5 472 | self.log('grad_norm', total_norm, sync_dist=True) 473 | 474 | return super().on_before_optimizer_step(optimizer) 475 | 476 | def initialize_weights(self): 477 | for m in self.modules(): 478 | if isinstance(m, nn.Linear): 479 | # You can choose a different initialization method 480 | init.xavier_normal_(m.weight) 481 | init.zeros_(m.bias) 482 | 483 | 484 | def load_models_config(models_config, seed, hparams=False, trial=None, phenotypes=None): 485 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 486 | 487 | if hparams: 488 | if trial is None: 489 | raise ValueError("Must pass trial when hparams is True.") 490 | 491 | transformer = TransformerPredictor(dim_cl=models_config.transformer.dim_cl, dim_iv=models_config.transformer.dim_iv, dim_phe=models_config.transformer.dim_phe, 492 | model_dim=models_config.transformer.model_dim, num_heads=models_config.transformer.num_heads, 493 | num_layers=models_config.transformer.num_layers, iv_dropout=models_config.transformer.iv_dropout, 494 | cl_dropout=models_config.transformer.cl_dropout, ph_dropout=models_config.transformer.ph_dropout, regressor_dropout=models_config.transformer.regressor_dropout, 495 | lr=models_config.transformer.lr, 
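        # Editor's note (illustrative values, not defaults): the fine-tuning tutorial config
        # (tutorials/config_file_finetuning.yaml) and the tutorial printouts use dim_iv=1219, model_dim=512,
        # num_heads=8, num_layers=8, warmup=5000, max_iters=5000, with 300-dimensional cell line embeddings.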
weight_decay=models_config.transformer.weight_decay, warmup=models_config.transformer.warmup, batch_size=models_config.batch_size, 496 | max_iters=models_config.transformer.max_iters, dropout=models_config.transformer.dropout, 497 | pool=models_config.transformer.pool, simpler=models_config.transformer.simpler, 498 | ctx_len=models_config.ctx_len, mask=models_config.transformer.mask, sum=models_config.transformer.sum, 499 | explicit_phenotype=models_config.transformer.explicit_phenotype, linear_predictor=models_config.transformer.linear_predictor, tokenizer_layers=models_config.transformer.tokenizer_layers, 500 | seed=seed,phenotypes=None) 501 | 502 | return transformer, models_config 503 | -------------------------------------------------------------------------------- /prophet/train.py: -------------------------------------------------------------------------------- 1 | from model import TransformerPredictor 2 | from prophet.callbacks import R2ScoreCallback 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor 5 | from pytorch_lightning.loggers import WandbLogger 6 | from torchmetrics.regression import R2Score 7 | import torch 8 | import wandb 9 | import os 10 | 11 | 12 | def train_transformer(data, model, config, name, seed): 13 | train_dataloader, valid_dataloader, test_dataloader, train_indices, test_indices, descriptor = data 14 | 15 | wandb_config = { 16 | 'model': 'Transformer', 17 | 'descr':descriptor, 18 | 'seed': seed, 19 | 'leaveout_method':config.leaveout_method, 20 | 'setting':config.setting, 21 | 'path': config.path, 22 | 'pooling': config.transformer.pool, 23 | 'unbalanced':config.unbalanced, 24 | 'n_heads':config.transformer.num_heads, 25 | 'n_layers':config.transformer.num_layers, 26 | 'gene_prior': os.path.basename(config.genes_prior[0]), 27 | 'cell_lines_prior': os.path.basename(config.cell_lines_prior[0]), 28 | 'batch_size': config.batch_size, 29 | 'early_stopping': config.early_stopping, 30 | 'patience': config.patience, 31 | 'ckpt_path': config.ckpt_path, 32 | 'fine_tune': config.fine_tune, 33 | 'max_steps': config.max_steps, 34 | } 35 | sub_descr = {f'descr{i}':x for i, x in enumerate(descriptor.split('_'))} 36 | 37 | wandb_logger = WandbLogger( 38 | project=config.project_name, 39 | name=f'{descriptor}_{config.setting}_nheads_{config.transformer.num_heads}_nlayers_{config.transformer.num_layers}_{config.transformer.simpler}simpler_{config.transformer.mask}mask_{config.transformer.lr}lr_{config.transformer.warmup}warmup_{config.transformer.max_iters}max_iters', 40 | config={**wandb_config, **sub_descr}, 41 | ) 42 | 43 | lr_monitor = LearningRateMonitor(logging_interval='step') 44 | dirpath = f"./pretrained_prophet/{config.setting}/{descriptor}_{config.leaveout_method[6:]}_{config.transformer.dim_cl}cl_{config.transformer.dim_iv}iv_{config.transformer.model_dim}model_{config.transformer.num_layers}layers_{config.transformer.simpler}simpler_{config.transformer.mask}mask_{config.transformer.lr}lr_{config.transformer.explicit_phenotype}explicitphenotype_{config.transformer.warmup}warmup_{config.transformer.max_iters}max_iters_{config.unbalanced}unbalanced_{config.transformer.weight_decay}wd_{config.batch_size}bs_{config.fine_tune}ft/{name}_seed_{seed}" 45 | model_checkpointer = ModelCheckpoint(dirpath=dirpath, save_top_k=1, every_n_epochs=1, monitor='R2', mode='max') 46 | r2_callback = R2ScoreCallback(device=model.device, average=True if config.setting == 'everything' else False) 47 | 
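    # Editor's note (inference from the monitor='R2'/mode='max' arguments here and in ModelCheckpoint above):
    # R2ScoreCallback (prophet/callbacks.py) is expected to log an epoch-level "R2" metric, which both
    # checkpointing and early stopping maximize.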
early_stopping = EarlyStopping(monitor="R2", mode="max", patience=config.patience, min_delta=0.0) 48 | 49 | callbacks = [r2_callback, model_checkpointer, lr_monitor, early_stopping] 50 | if not config.early_stopping: 51 | callbacks = [r2_callback, model_checkpointer, lr_monitor] 52 | 53 | test_mode = False # manual toggle for debugging 54 | 55 | trainer = pl.Trainer( 56 | min_epochs=1, 57 | max_steps=config.max_steps, 58 | accelerator='gpu', 59 | devices=-1, 60 | check_val_every_n_epoch=1, 61 | callbacks=callbacks, 62 | logger=wandb_logger, 63 | strategy="ddp", 64 | gradient_clip_val=1, 65 | deterministic=True) 66 | 67 | if config.ckpt_path is not None: 68 | print("name: ", name) 69 | axis_name = 'gene' if 'gene' in config.leaveout_method else 'cl' 70 | if axis_name == 'gene': 71 | axis = 'iv' 72 | else: 73 | axis = 'cl' 74 | split_num = descriptor.split('_')[1] 75 | ckpt_path = f"{config.ckpt_path}/{axis_name}_{split_num}_seed_{seed}/" 76 | if config.fine_tune: 77 | ckpt_path = f"{config.ckpt_path.replace('iv_0_iv', f'{axis}_{split_num}_{axis}')}/{axis_name}_{split_num}_seed_{seed}/" 78 | files = os.listdir(ckpt_path) # ckpt_path is the path to a folder 79 | full_paths = [os.path.join(ckpt_path, file) for file in files] 80 | ckpt_file = max(full_paths, key=os.path.getmtime) # ckpt_file is the .ckpt file 81 | print(f"Resume training from {ckpt_file}") 82 | if config.fine_tune: 83 | print("Fine tuning") 84 | model = TransformerPredictor.load_from_checkpoint(ckpt_file, warmup=config.transformer.warmup) 85 | trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader) # fresh fit from the loaded weights (no trainer state resumed) 86 | else: 87 | trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader, ckpt_path=config.ckpt_path) 88 | else: 89 | trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader, ckpt_path=config.ckpt_path) 90 | 91 | if not test_mode: 92 | 93 | # Get most recent checkpoint 94 | files = os.listdir(dirpath) 95 | full_paths = [os.path.join(dirpath, file) for file in files] 96 | ckpt_file = max(full_paths, key=os.path.getmtime) 97 | 98 | print(f"Testing model from {ckpt_file}") 99 | # Load best model in terms of R2 100 | model = TransformerPredictor.load_from_checkpoint(checkpoint_path=ckpt_file) 101 | print(type(test_dataloader)) 102 | if not isinstance(test_dataloader, dict): 103 | trainer.test(model, test_dataloader) 104 | else: 105 | for id_dataset, dataset in test_dataloader.items(): 106 | if id_dataset == 'all': # logs as R2_test as normal 107 | trainer.test(model, dataset) 108 | continue 109 | 110 | predictions_and_targets = trainer.predict(model, dataset) 111 | 112 | predictions = [t[0] for t in predictions_and_targets] 113 | targets = [t[1] for t in predictions_and_targets] 114 | 115 | predictions = torch.cat(predictions, dim=0) 116 | targets = torch.cat(targets, dim=0) 117 | 118 | r2score = R2Score().to(model.device) 119 | wandb.log({f"R2_test_{id_dataset}": r2score(predictions, targets.unsqueeze(-1))}) 120 | 121 | 122 | wandb.finish() 123 | -------------------------------------------------------------------------------- /prophet/train_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import pprint 4 | 5 | from dataloader import dataloader_phenotypes, get_split_indices, get_data_by_setting 6 | from train import train_transformer 7 |
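# Editor's usage sketch (hypothetical file name, inferred from the argparse/split('-') handling below):
#   cd prophet && python train_model.py --config_file 2024-config_file.yaml
# i.e. --config_file packs "<seed>-<yaml path>", giving seed=2024 and config_file='config_file.yaml'.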
from model import load_models_config 8 | import pytorch_lightning as pl 9 | from prophet.config import set_config 10 | import os 11 | import yaml 12 | 13 | # Add arguments 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--setting", type=str, 16 | default="Rad", required=False) 17 | parser.add_argument("--leaveout_method", type=str, 18 | default="leave_one_cl_out", required=False) 19 | parser.add_argument("--config_file", type=str, required=True) # config file with info regarding the architecture 20 | parser.add_argument("--fine_tune", action='store_true', default=False, required=False) 21 | 22 | args = parser.parse_args() 23 | config_file = args.config_file.split('-')[1] 24 | seed = int(args.config_file.split('-')[0]) 25 | 26 | def get_global_rank(): 27 | return int(os.getenv('SLURM_PROCID', '0')) 28 | 29 | if __name__ == "__main__": 30 | 31 | # Check the number of available GPUs 32 | num_gpus = torch.cuda.device_count() 33 | print("Number of gpus: ", num_gpus) 34 | 35 | global_rank = get_global_rank() 36 | print(f"Global Rank: {global_rank}") 37 | 38 | with open(config_file, 'r') as f: 39 | models_config = yaml.safe_load(f) 40 | 41 | models_config = set_config(models_config) 42 | 43 | # override config to the default fine tuning model parameters during fine-tuning 44 | if args.fine_tune: 45 | with open('config_files/config_file_finetuning.yaml', 'r') as f: 46 | ft_config = set_config(yaml.safe_load(f)) 47 | ft_config.setting = models_config.setting 48 | ft_config.leaveout_method = models_config.leaveout_method 49 | ft_config.dirpath = models_config.dirpath 50 | models_config = ft_config 51 | 52 | pl.seed_everything(seed, workers=True) 53 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 54 | os.environ['TORCH_USE_CUDA_DSA'] = '1' 55 | 56 | data_label, gene_prior, cl_prior, phe_prior, path = get_data_by_setting(models_config.setting, models_config.genes_prior, models_config.cell_lines_prior, models_config.phenotype_prior) 57 | 58 | models_config.path = path # to print the path 59 | 60 | pprint.pprint(models_config) 61 | 62 | # Get the indices of the DataFrame 63 | indices = data_label.index 64 | indices = get_split_indices(data_label, models_config.leaveout_method, seed) 65 | 66 | for index in indices: 67 | 68 | data = dataloader_phenotypes( 69 | gene_embedding = gene_prior, 70 | cell_lines_embedding = cl_prior, 71 | phenotype_embedding = phe_prior, 72 | data_label = data_label, 73 | label_name = "value", 74 | index = index, 75 | batch_size = models_config.batch_size, 76 | unbalanced = models_config.unbalanced, 77 | pert_len = models_config.pert_len, 78 | ) 79 | 80 | models_config.ohe_dim = 0 81 | 82 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 83 | model, models_config = load_models_config(models_config, seed) 84 | 85 | train_transformer(data=data, model=model, config=models_config, name=index[-1], seed=seed) 86 | del data 87 | del model 88 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='prophet', 5 | version='0.1.0', 6 | url="https://github.com/theislab/prophet", 7 | license='CC-BY-NC 4.0', 8 | description='Scalable and universal prediction of cellular phenotypes', 9 | author='Alejandro Tejada-Lapuerta, Yuge Ji', 10 | author_email='alejandro.tejada@helmholtz-munich.de, yuge.ji@helmholtz-munich.de', 11 | packages=find_packages(), 12 | 
python_requires='>=3.10', 13 | install_requires=[ 14 | 'joblib==1.4.2', 15 | 'numpy==2.2.3', 16 | 'pandas==2.2.3', 17 | 'pytorch_lightning==2.1.0', 18 | 'PyYAML==6.0.2', 19 | 'scikit_learn==1.5.1', 20 | 'scipy==1.14.0', 21 | 'torch==2.3.0', 22 | 'torchmetrics==1.4.0.post0', 23 | 'tqdm==4.66.4', 24 | 'wandb==0.17.6', 25 | 'jupyterlab==4.1.5', 26 | 'pytest==8.3.5' 27 | ], 28 | ) -------------------------------------------------------------------------------- /test/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pandas as pd 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | 6 | from prophet.dataloader import ( 7 | dataloader_phenotypes, 8 | process_priors, 9 | remove_nonexistent_cat, 10 | universal_processing, 11 | ) 12 | 13 | # Mock data 14 | @pytest.fixture 15 | def mock_data(): 16 | iv_embedding = pd.DataFrame({ 17 | 'type': ['gene', 'gene', 'drug', 'gene', 'drug'], 18 | 'feature1': [0.1, 0.2, 0.3, 0.0, 0.0], 19 | 'feature2': [0.4, 0.5, 0.6, 0.0, 0.0] 20 | }, index=['gene1', 'gene2', 'drug1', 'negative_gene', 'negative_drug']) # assume the embedding for drug2 does not exist 21 | 22 | 23 | cell_lines_embedding = pd.DataFrame({ 24 | 'feature1': [0.7, 0.8], 25 | 'feature2': [0.9, 1.0] 26 | }, index=['cell_line_1', 'cell_line_2']) 27 | 28 | phenotype_embedding = None # usually None 29 | 30 | data_label = pd.DataFrame({ 31 | 'phenotype': ['phenotype1', 'phenotype2','phenotype1', 'phenotype2'], 32 | 'cell_line': ['cell_line_1', 'cell_line_2', 'cell_line_1', 'cell_line_2'], 33 | 'iv1': ['gene1', 'gene2', 'drug1', 'drug2'], 34 | 'iv2': ['negative_gene', 'negative_gene', 'negative_drug', 'negative_drug'], 35 | 'value': [0.5, 0.6, 0.9, 0.3] 36 | }) 37 | 38 | index = ( 39 | np.array([0,2]), # train_indices 40 | np.array([1]), # validation_indices 41 | [], # test_indices 42 | None # cl_holdout 43 | ) 44 | 45 | return iv_embedding, cell_lines_embedding, phenotype_embedding, data_label, index 46 | 47 | def test_dataloader_phenotypes(mock_data): 48 | 49 | iv_embedding, cell_lines_embedding, phenotype_embedding, data_label, index = mock_data 50 | 51 | # Filter out rows whose iv1 or iv2 values don't have embeddings 52 | valid_iv1 = data_label['iv1'].isin(iv_embedding.index) 53 | valid_iv2 = data_label['iv2'].isin(iv_embedding.index) 54 | data_label = data_label[valid_iv1 & valid_iv2] 55 | 56 | train_dataloader, valid_dataloader, test_dataloader, train_indices, test_indices, cl_holdout = dataloader_phenotypes( 57 | gene_embedding=iv_embedding, 58 | cell_lines_embedding=cell_lines_embedding, 59 | phenotype_embedding=phenotype_embedding, 60 | data_label=data_label, 61 | index=index, 62 | batch_size=2, 63 | label_name='value', 64 | unbalanced=False, 65 | torch_dataset=True, 66 | valid_set=True 67 | ) 68 | 69 | assert isinstance(train_dataloader, DataLoader) 70 | assert len(train_dataloader.dataset) == 2 # indices: 0,2 71 | 72 | assert isinstance(valid_dataloader, DataLoader) 73 | assert len(valid_dataloader.dataset) == 1 # indices: 1 74 | 75 | assert test_dataloader is None # no test samples given 76 | 77 | def test_process_priors(mock_data): 78 | iv_embedding, cell_lines_embedding, phenotype_embedding, _, _ = mock_data 79 | 80 | # setup temp paths 81 | iv_path = "./iv_embedding.csv" 82 | cl_path = "./cell_lines_embedding.csv" 83 | ph_path = "./phenotype_embedding.csv" 84 | 85 | iv_embedding.to_csv(iv_path) 86 | cell_lines_embedding.to_csv(cl_path) 87 | if phenotype_embedding is not None: # create dummy
files for testing of process_priors 88 | pd.DataFrame().to_csv(ph_path) 89 | else: 90 | ph_path = None 91 | 92 | iv_prior, cl_prior, phe_prior = process_priors( 93 | genes_prior=[str(iv_path)], 94 | cell_lines_prior=[str(cl_path)], 95 | phenotype_prior=[str(ph_path)] if ph_path else None 96 | ) 97 | 98 | assert 'negative_gene' in iv_prior.index 99 | assert 'negative_drug' in iv_prior.index 100 | assert iv_prior.loc['negative_gene', 'type'] == 'gene' 101 | 102 | assert cl_prior.shape == cell_lines_embedding.shape 103 | 104 | assert phe_prior is None 105 | 106 | def test_remove_nonexistent_cat(mock_data): 107 | iv_embedding, cell_lines_embedding, phenotype_embedding, data_label, _ = mock_data 108 | embeddings = [iv_embedding, cell_lines_embedding, phenotype_embedding] 109 | 110 | iv_cols = ['iv1', 'iv2'] 111 | cl_col = "cell_line" 112 | ph_col = "phenotype" 113 | cols = [iv_cols, cl_col, ph_col] 114 | 115 | prior = pd.DataFrame(index=['gene1', 'gene2', 'drug1']) 116 | print("embeddings:", embeddings) 117 | 118 | for i, embedding in enumerate(embeddings): 119 | if embedding is None: # phenotype embedding can be None 120 | continue 121 | data_label = remove_nonexistent_cat(data_label, embedding, cols[i], verbose=True) 122 | 123 | assert len(data_label) == 3 # 4 - 1 = 3: according to the mock embeddings and data_label, the one row containing 124 | # "drug2" should be removed since its embedding does not exist 125 | assert 'gene1' in data_label['iv1'].values 126 | assert 'gene2' in data_label['iv1'].values 127 | assert 'drug1' in data_label['iv1'].values 128 | assert 'drug2' not in data_label['iv1'].values # removed 129 | 130 | def test_universal_processing(mock_data): 131 | _, _, _, data_label, _ = mock_data 132 | 133 | processed_data = universal_processing(data_label) 134 | 135 | assert 'phenotype' in processed_data.columns 136 | assert 'value' in processed_data.columns 137 | assert 'iv1' in processed_data.columns 138 | assert 'iv2' in processed_data.columns 139 | assert len(processed_data) == 8 # original (4) + flipped rows (4) -------------------------------------------------------------------------------- /tutorials/config_file_finetuning.yaml: -------------------------------------------------------------------------------- 1 | setting: NULL 2 | leaveout_method: NULL 3 | 4 | project_name: FT_Prophet 5 | pert_len: 2 6 | ckpt_path: /lustre/groups/ml01/projects/super_rad_project/pretrained_prophet/everything/iv_0_iv_out_multitest_300cl_1219iv_512model_8layers_Falsesimpler_Truemask_0.0001lr_Falseexplicitphenotype_10000warmup_150001max_iters_Falseunbalanced_0.01wd_4096bs_Falseft/iv_0_seed_2024/epoch=29.ckpt 7 | fine_tune: True 8 | 9 | cell_lines_prior: ['../embs/cell_line_embedding_full_ccle_300_scaled.csv'] 10 | genes_prior: [ '../embs/global_iv_scaledv3.csv', 11 | '../embs/CTRP_with_smiles_simscaled.csv', 12 | '../embs/Hadian_plates_NEW_simscaled.csv'] 13 | max_steps: 1000 14 | batch_size: 256 15 | 16 | Transformer: 17 | simpler: False 18 | num_layers: 8 19 | num_heads: 8 20 | model_dim: 512 21 | max_iters: 5000 22 | dim_iv: 1219 23 | warmup: 5000 24 | iv_dropout: 0.1 25 | cl_dropout: 0.1 26 | ph_dropout: 0.1 -------------------------------------------------------------------------------- /tutorials/finetuning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f95e67be", 6 | "metadata": {}, 7 | "source": [ 8 | "# fine-tuning using your in-house assay\n", 9 | "\n", 10 | "This tutorial is for users who have
already run a functional assay and would like to answer the question \"what if I had run my assay with other drugs, other CRISPR treatments, or in other cell lines?\"\n", 11 | "\n", 12 | "For example, [Tieu et al (2024)](https://doi.org/10.1016/j.cell.2024.01.035) run a combinatorial CAR-T transduction with 24 guides for a total of 576 pairwise combinations. We can fine-tune Prophet on this dataset and make predictions for genes spanning the entire genome and additional combinations therein." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "id": "66127741", 19 | "metadata": { 20 | "scrolled": false 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "import pandas as pd\n", 25 | "import yaml\n", 26 | "from prophet import Prophet\n", 27 | "from prophet.config import set_config\n", 28 | "from prophet.dataloader import universal_processing" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "id": "164fa87f", 34 | "metadata": {}, 35 | "source": [ 36 | "We load in a config file to get the file paths for the embeddings which should be used. These can be downloaded from the links on Github." 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "8823f063", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "with open('config_file_finetuning.yaml', 'r') as f:\n", 47 | " config = set_config(yaml.safe_load(f))" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "id": "6d1aec6c", 53 | "metadata": {}, 54 | "source": [ 55 | "We've randomly select one of the pretrained model checkpoints (noted in the config file above) to train on here. For more robust results, we recommend training with several checkpoints and taking the ensemble prediction." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "id": "004e340c", 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "'/lustre/groups/ml01/projects/super_rad_project/pretrained_prophet/everything/iv_0_iv_out_multitest_300cl_1219iv_512model_8layers_Falsesimpler_Truemask_0.0001lr_Falseexplicitphenotype_10000warmup_150001max_iters_Falseunbalanced_0.01wd_4096bs_Falseft/iv_0_seed_2024/epoch=29.ckpt'" 68 | ] 69 | }, 70 | "execution_count": 3, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | } 74 | ], 75 | "source": [ 76 | "config.ckpt_path" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "id": "54fd2dca", 82 | "metadata": {}, 83 | "source": [ 84 | "Load in the pretrained model." 
85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 4, 90 | "id": "740983bb", 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "returning trained model!\n", 98 | "Gene net: Sequential(\n", 99 | " (0): Linear(in_features=1219, out_features=512, bias=True)\n", 100 | " (1): GELU(approximate='none')\n", 101 | " (2): Dropout(p=0.2, inplace=False)\n", 102 | " (3): Linear(in_features=512, out_features=512, bias=True)\n", 103 | ")\n", 104 | "Cell line net: Sequential(\n", 105 | " (0): Linear(in_features=300, out_features=512, bias=True)\n", 106 | " (1): GELU(approximate='none')\n", 107 | " (2): Dropout(p=0.2, inplace=False)\n", 108 | " (3): Linear(in_features=512, out_features=512, bias=True)\n", 109 | ")\n", 110 | "Regressor: Sequential(\n", 111 | " (0): Linear(in_features=512, out_features=512, bias=True)\n", 112 | " (1): GELU(approximate='none')\n", 113 | " (2): Dropout(p=0.2, inplace=False)\n", 114 | " (3): Linear(in_features=512, out_features=512, bias=True)\n", 115 | " (4): GELU(approximate='none')\n", 116 | " (5): Linear(in_features=512, out_features=1, bias=True)\n", 117 | ")\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "model = Prophet(\n", 123 | " iv_emb_path=config.genes_prior,\n", 124 | " cl_emb_path=config.cell_lines_prior,\n", 125 | " ph_emb_path=None,\n", 126 | " model_pth=config.ckpt_path,\n", 127 | ")" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "id": "c4cdd598", 133 | "metadata": {}, 134 | "source": [ 135 | "Here we've provided two examples for finetuning to demonstrate the variety of datasets on which you can finetune Prophet. We recommend that all datasets have values minmaxed to between 0 and 1 before finetuning." 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "id": "05e05905", 141 | "metadata": {}, 142 | "source": [ 143 | "### GDSC2\n", 144 | "\n", 145 | "GDSC2 (https://www.cancerrxgene.org/) is a cancer-screening dataset with titrated IC50s.\n", 146 | "\n", 147 | " - cell states: cancer cell lines\n", 148 | " - interventions: small molecule singletons\n", 149 | " - readout: IC50 of viability as measured using CellTitreGlo at 72hrs, fitted over multiple concentrations" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 5, 155 | "id": "b45cad0d", 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "data": { 160 | "text/html": [ 161 | "
\n", 162 | "\n", 175 | "\n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | "
cell_lineiv1valueiv_namephenotypeiv2
0PFSK1cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o0.323078CamptothecinGDSCnegative_drug
1A673cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o0.172422CamptothecinGDSCnegative_drug
2ES5cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o0.239133CamptothecinGDSCnegative_drug
3ES7cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o0.164659CamptothecinGDSCnegative_drug
4EW11cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o0.222290CamptothecinGDSCnegative_drug
\n", 235 | "
" 236 | ], 237 | "text/plain": [ 238 | " cell_line iv1 value \\\n", 239 | "0 PFSK1 cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o 0.323078 \n", 240 | "1 A673 cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o 0.172422 \n", 241 | "2 ES5 cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o 0.239133 \n", 242 | "3 ES7 cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o 0.164659 \n", 243 | "4 EW11 cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o 0.222290 \n", 244 | "\n", 245 | " iv_name phenotype iv2 \n", 246 | "0 Camptothecin GDSC negative_drug \n", 247 | "1 Camptothecin GDSC negative_drug \n", 248 | "2 Camptothecin GDSC negative_drug \n", 249 | "3 Camptothecin GDSC negative_drug \n", 250 | "4 Camptothecin GDSC negative_drug " 251 | ] 252 | }, 253 | "execution_count": 5, 254 | "metadata": {}, 255 | "output_type": "execute_result" 256 | } 257 | ], 258 | "source": [ 259 | "gdsc_data_path = \"/lustre/groups/ml01/projects/super_rad_project/data/GDSC_notscaled_minmax.csv\"\n", 260 | "data_label = pd.read_csv(gdsc_data_path, index_col=0)\n", 261 | "\n", 262 | "data_label['iv2'] = 'negative_drug' # there is no second compound so we specify negative_drug so the token is masked\n", 263 | "data_label['phenotype'] = 'GDSC' # this is the label to use for this readout when you run inference later\n", 264 | "\n", 265 | "data_label = universal_processing(data_label)\n", 266 | "data_label.head()" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 6, 272 | "id": "bbf88222", 273 | "metadata": { 274 | "scrolled": true 275 | }, 276 | "outputs": [ 277 | { 278 | "name": "stdout", 279 | "output_type": "stream", 280 | "text": [ 281 | "Removing 61 such as ['123138', '123829', '150412', '50869', '615590'] from ['iv1', 'iv2']. 393646 rows remaining.\n", 282 | "Removing 285 such as ['451LU', '7860', 'ALLPO', 'ARH77', 'ATN1'] from ['cell_line']. 
280202 rows remaining.\n", 283 | "Fitting model.\n", 284 | "pytorch model, finetuning\n", 285 | "Gene net: Sequential(\n", 286 | " (0): Linear(in_features=1219, out_features=512, bias=True)\n", 287 | " (1): GELU(approximate='none')\n", 288 | " (2): Dropout(p=0.1, inplace=False)\n", 289 | " (3): Linear(in_features=512, out_features=512, bias=True)\n", 290 | ")\n", 291 | "Cell line net: Sequential(\n", 292 | " (0): Linear(in_features=300, out_features=512, bias=True)\n", 293 | " (1): GELU(approximate='none')\n", 294 | " (2): Dropout(p=0.1, inplace=False)\n", 295 | " (3): Linear(in_features=512, out_features=512, bias=True)\n", 296 | ")\n", 297 | "Regressor: Sequential(\n", 298 | " (0): Linear(in_features=512, out_features=512, bias=True)\n", 299 | " (1): GELU(approximate='none')\n", 300 | " (2): Dropout(p=0.2, inplace=False)\n", 301 | " (3): Linear(in_features=512, out_features=512, bias=True)\n", 302 | " (4): GELU(approximate='none')\n", 303 | " (5): Linear(in_features=512, out_features=1, bias=True)\n", 304 | ")\n" 305 | ] 306 | }, 307 | { 308 | "name": "stderr", 309 | "output_type": "stream", 310 | "text": [ 311 | "GPU available: True (cuda), used: True\n", 312 | "TPU available: False, using: 0 TPU cores\n", 313 | "IPU available: False, using: 0 IPUs\n", 314 | "HPU available: False, using: 0 HPUs\n" 315 | ] 316 | }, 317 | { 318 | "name": "stdout", 319 | "output_type": "stream", 320 | "text": [ 321 | "R2 average: False\n", 322 | "Running with early stopping: True\n", 323 | "Early stopping patience: 20\n" 324 | ] 325 | }, 326 | { 327 | "name": "stderr", 328 | "output_type": "stream", 329 | "text": [ 330 | "You are using a CUDA device ('NVIDIA A100-PCIE-40GB MIG 3g.20gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", 331 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-c51c82f2-7e04-56a8-8b74-4211c9821715]\n", 332 | "\n", 333 | " | Name | Type | Params\n", 334 | "-----------------------------------------------------------\n", 335 | "0 | learnable_embedding | Embedding | 512 K \n", 336 | "1 | embedding_dropout | Dropout | 0 \n", 337 | "2 | gene_net | Sequential | 887 K \n", 338 | "3 | drug_net | Sequential | 887 K \n", 339 | "4 | cl_net | Sequential | 416 K \n", 340 | "5 | transformer | TransformerEncoder | 16.8 M\n", 341 | "6 | output_net | Sequential | 525 K \n", 342 | "-----------------------------------------------------------\n", 343 | "20.1 M Trainable params\n", 344 | "0 Non-trainable params\n", 345 | "20.1 M Total params\n", 346 | "80.206 Total estimated model params size (MB)\n" 347 | ] 348 | }, 349 | { 350 | "name": "stdout", 351 | "output_type": "stream", 352 | "text": [ 353 | "Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 124/124 [00:37<00:00, 3.26it/s, v_num=8, train_loss=0.119]\n", 354 | "Validation: | | 0/? 
[00:00\n", 442 | "\n", 455 | "\n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | "
iv1iv2valuecell_linephenotype
0batf3batf30.633443JURKATT-cell_viability
1batf3cblb0.544929JURKATT-cell_viability
2batf3ctla40.483662JURKATT-cell_viability
3batf3dhx370.693173JURKATT-cell_viability
4batf3fas0.594995JURKATT-cell_viability
..................
1245tigitzc3h12a0.604678JURKATT-cell_viability
1246toxzc3h12a0.596647JURKATT-cell_viability
1247tox2zc3h12a0.418890JURKATT-cell_viability
1248traczc3h12a0.625894JURKATT-cell_viability
1249zc3h12azc3h12a0.490973JURKATT-cell_viability
\n", 557 | "

1250 rows × 5 columns

\n", 558 | "" 559 | ], 560 | "text/plain": [ 561 | " iv1 iv2 value cell_line phenotype\n", 562 | "0 batf3 batf3 0.633443 JURKAT T-cell_viability\n", 563 | "1 batf3 cblb 0.544929 JURKAT T-cell_viability\n", 564 | "2 batf3 ctla4 0.483662 JURKAT T-cell_viability\n", 565 | "3 batf3 dhx37 0.693173 JURKAT T-cell_viability\n", 566 | "4 batf3 fas 0.594995 JURKAT T-cell_viability\n", 567 | "... ... ... ... ... ...\n", 568 | "1245 tigit zc3h12a 0.604678 JURKAT T-cell_viability\n", 569 | "1246 tox zc3h12a 0.596647 JURKAT T-cell_viability\n", 570 | "1247 tox2 zc3h12a 0.418890 JURKAT T-cell_viability\n", 571 | "1248 trac zc3h12a 0.625894 JURKAT T-cell_viability\n", 572 | "1249 zc3h12a zc3h12a 0.490973 JURKAT T-cell_viability\n", 573 | "\n", 574 | "[1250 rows x 5 columns]" 575 | ] 576 | }, 577 | "execution_count": 7, 578 | "metadata": {}, 579 | "output_type": "execute_result" 580 | } 581 | ], 582 | "source": [ 583 | "tieu_data_path = \"/ictstr01/home/icb/yuge.ji/projects/super_rad_project/data/TieuQi_lfc_day11_minmax.csv\"\n", 584 | "data_label = pd.read_csv(tieu_data_path, index_col=0)\n", 585 | "\n", 586 | "data_label['phenotype'] = 'T-cell_viability' # this is the label to use for this readout when you run inference later\n", 587 | "\n", 588 | "data_label = universal_processing(data_label)\n", 589 | "data_label" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": 8, 595 | "id": "1b5c884c", 596 | "metadata": {}, 597 | "outputs": [ 598 | { 599 | "name": "stdout", 600 | "output_type": "stream", 601 | "text": [ 602 | "Removing 4 such as ['tceb2', 'tet2', 'tigit', 'trac'] from ['iv1', 'iv2']. 861 rows remaining.\n", 603 | "Removing 0 such as [] from ['cell_line']. 861 rows remaining.\n", 604 | "Fitting model.\n", 605 | "pytorch model, finetuning\n", 606 | "Gene net: Sequential(\n", 607 | " (0): Linear(in_features=1219, out_features=512, bias=True)\n", 608 | " (1): GELU(approximate='none')\n", 609 | " (2): Dropout(p=0.1, inplace=False)\n", 610 | " (3): Linear(in_features=512, out_features=512, bias=True)\n", 611 | ")\n", 612 | "Cell line net: Sequential(\n", 613 | " (0): Linear(in_features=300, out_features=512, bias=True)\n", 614 | " (1): GELU(approximate='none')\n", 615 | " (2): Dropout(p=0.1, inplace=False)\n", 616 | " (3): Linear(in_features=512, out_features=512, bias=True)\n", 617 | ")\n", 618 | "Regressor: Sequential(\n", 619 | " (0): Linear(in_features=512, out_features=512, bias=True)\n", 620 | " (1): GELU(approximate='none')\n", 621 | " (2): Dropout(p=0.2, inplace=False)\n", 622 | " (3): Linear(in_features=512, out_features=512, bias=True)\n", 623 | " (4): GELU(approximate='none')\n", 624 | " (5): Linear(in_features=512, out_features=1, bias=True)\n", 625 | ")\n" 626 | ] 627 | }, 628 | { 629 | "name": "stderr", 630 | "output_type": "stream", 631 | "text": [ 632 | "GPU available: True (cuda), used: True\n", 633 | "TPU available: False, using: 0 TPU cores\n", 634 | "IPU available: False, using: 0 IPUs\n", 635 | "HPU available: False, using: 0 HPUs\n", 636 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-c51c82f2-7e04-56a8-8b74-4211c9821715]\n", 637 | "\n", 638 | " | Name | Type | Params\n", 639 | "-----------------------------------------------------------\n", 640 | "0 | learnable_embedding | Embedding | 512 K \n", 641 | "1 | embedding_dropout | Dropout | 0 \n", 642 | "2 | gene_net | Sequential | 887 K \n", 643 | "3 | drug_net | Sequential | 887 K \n", 644 | "4 | cl_net | Sequential | 416 K \n", 645 | "5 | transformer | TransformerEncoder | 16.8 M\n", 646 | "6 | output_net | 
Sequential | 525 K \n", 647 | "-----------------------------------------------------------\n", 648 | "20.1 M Trainable params\n", 649 | "0 Non-trainable params\n", 650 | "20.1 M Total params\n", 651 | "80.206 Total estimated model params size (MB)\n" 652 | ] 653 | }, 654 | { 655 | "name": "stdout", 656 | "output_type": "stream", 657 | "text": [ 658 | "R2 average: False\n", 659 | "Running with early stopping: True\n", 660 | "Early stopping patience: 20\n", 661 | "Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2.07it/s, v_num=9, train_loss=0.638]\n", 662 | "Validation: | | 0/? [00:00\n", 185 | "\n", 198 | "\n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | "
iv1iv2cell_linephenotypeiv1+iv2valuepred
0oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...A375GDSCoc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc..._0.480913
1oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...UACC62GDSCoc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc..._0.484669
2oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...WM983BGDSCoc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc..._0.491997
3oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...MALME3MGDSCoc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc..._0.473723
4oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...A2058GDSCoc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc..._0.489981
........................
135cs(=o)ccs(=o)cSKMEL1GDSCcs(=o)c+cs(=o)c_0.513998
136cs(=o)ccs(=o)cHMCBGDSCcs(=o)c+cs(=o)c_0.540128
137cs(=o)ccs(=o)cMDAMB435SGDSCcs(=o)c+cs(=o)c_0.538795
138cs(=o)ccs(=o)cWM1799GDSCcs(=o)c+cs(=o)c_0.526665
139cs(=o)ccs(=o)cLOXIMVIGDSCcs(=o)c+cs(=o)c_0.528534
\n", 324 | "

140 rows × 7 columns

\n", 325 | "" 326 | ], 327 | "text/plain": [ 328 | " iv1 \\\n", 329 | "0 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... \n", 330 | "1 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... \n", 331 | "2 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... \n", 332 | "3 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... \n", 333 | "4 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... \n", 334 | ".. ... \n", 335 | "135 cs(=o)c \n", 336 | "136 cs(=o)c \n", 337 | "137 cs(=o)c \n", 338 | "138 cs(=o)c \n", 339 | "139 cs(=o)c \n", 340 | "\n", 341 | " iv2 cell_line phenotype \\\n", 342 | "0 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... A375 GDSC \n", 343 | "1 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... UACC62 GDSC \n", 344 | "2 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM983B GDSC \n", 345 | "3 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... MALME3M GDSC \n", 346 | "4 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... A2058 GDSC \n", 347 | ".. ... ... ... \n", 348 | "135 cs(=o)c SKMEL1 GDSC \n", 349 | "136 cs(=o)c HMCB GDSC \n", 350 | "137 cs(=o)c MDAMB435S GDSC \n", 351 | "138 cs(=o)c WM1799 GDSC \n", 352 | "139 cs(=o)c LOXIMVI GDSC \n", 353 | "\n", 354 | " iv1+iv2 value pred \n", 355 | "0 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... _ 0.480913 \n", 356 | "1 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... _ 0.484669 \n", 357 | "2 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... _ 0.491997 \n", 358 | "3 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... _ 0.473723 \n", 359 | "4 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... _ 0.489981 \n", 360 | ".. ... ... ... \n", 361 | "135 cs(=o)c+cs(=o)c _ 0.513998 \n", 362 | "136 cs(=o)c+cs(=o)c _ 0.540128 \n", 363 | "137 cs(=o)c+cs(=o)c _ 0.538795 \n", 364 | "138 cs(=o)c+cs(=o)c _ 0.526665 \n", 365 | "139 cs(=o)c+cs(=o)c _ 0.528534 \n", 366 | "\n", 367 | "[140 rows x 7 columns]" 368 | ] 369 | }, 370 | "execution_count": 5, 371 | "metadata": {}, 372 | "output_type": "execute_result" 373 | } 374 | ], 375 | "source": [ 376 | "# predict with lists of treatments and cell lines\n", 377 | "df = model.predict(\n", 378 | " target_ivs=iv_list,\n", 379 | " target_cls=cl_list,\n", 380 | " target_phs=ph_list,\n", 381 | " iv_col=['iv1', 'iv2'], # pass to turn on combinatorial predictions\n", 382 | " num_iterations=1, save=False,\n", 383 | ")\n", 384 | "df" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "id": "32b4e1ad", 390 | "metadata": {}, 391 | "source": [ 392 | "### Making predictions for a specific set of treatments, cell lines, and phenotypes\n", 393 | "\n", 394 | "If we're interested in only a subset of the experimental matrix, we can also pass in a custom dataframe. (This is the recommended usage, as users understand exactly the list being predicted.)" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 6, 400 | "id": "8b0baefd", 401 | "metadata": { 402 | "scrolled": true 403 | }, 404 | "outputs": [ 405 | { 406 | "data": { 407 | "text/html": [ 408 | "
\n", 409 | "\n", 422 | "\n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 
694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | "
iv1cell_lineiv2phenotype
0oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...A375cs(=o)cGDSC
1oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...UACC62cs(=o)cGDSC
2oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...WM983Bcs(=o)cGDSC
3oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...MALME3Mcs(=o)cGDSC
4oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...A2058cs(=o)cGDSC
5oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...WM793cs(=o)cGDSC
6oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...HT144cs(=o)cGDSC
7oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...RPMI7951cs(=o)cGDSC
8oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...SKMEL2cs(=o)cGDSC
9oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...SKMEL1cs(=o)cGDSC
10oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...HMCBcs(=o)cGDSC
11oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...MDAMB435Scs(=o)cGDSC
12oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...WM1799cs(=o)cGDSC
13oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...LOXIMVIcs(=o)cGDSC
14cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...A375cs(=o)cGDSC
15cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...UACC62cs(=o)cGDSC
16cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...WM983Bcs(=o)cGDSC
17cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...MALME3Mcs(=o)cGDSC
18cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...A2058cs(=o)cGDSC
19cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...WM793cs(=o)cGDSC
20cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...HT144cs(=o)cGDSC
21cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...RPMI7951cs(=o)cGDSC
22cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...SKMEL2cs(=o)cGDSC
23cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...SKMEL1cs(=o)cGDSC
24cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...HMCBcs(=o)cGDSC
25cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...MDAMB435Scs(=o)cGDSC
26cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...WM1799cs(=o)cGDSC
27cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...LOXIMVIcs(=o)cGDSC
28fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...A375cs(=o)cGDSC
29fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...UACC62cs(=o)cGDSC
30fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...WM983Bcs(=o)cGDSC
31fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...MALME3Mcs(=o)cGDSC
32fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...A2058cs(=o)cGDSC
33fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...WM793cs(=o)cGDSC
34fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...HT144cs(=o)cGDSC
35fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...RPMI7951cs(=o)cGDSC
36fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...SKMEL2cs(=o)cGDSC
37fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...SKMEL1cs(=o)cGDSC
38fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...HMCBcs(=o)cGDSC
39fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...MDAMB435Scs(=o)cGDSC
40fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...WM1799cs(=o)cGDSC
41fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...LOXIMVIcs(=o)cGDSC
42cs(=o)cA375cs(=o)cGDSC
43cs(=o)cUACC62cs(=o)cGDSC
44cs(=o)cWM983Bcs(=o)cGDSC
45cs(=o)cMALME3Mcs(=o)cGDSC
46cs(=o)cA2058cs(=o)cGDSC
47cs(=o)cWM793cs(=o)cGDSC
48cs(=o)cHT144cs(=o)cGDSC
49cs(=o)cRPMI7951cs(=o)cGDSC
50cs(=o)cSKMEL2cs(=o)cGDSC
51cs(=o)cSKMEL1cs(=o)cGDSC
52cs(=o)cHMCBcs(=o)cGDSC
53cs(=o)cMDAMB435Scs(=o)cGDSC
54cs(=o)cWM1799cs(=o)cGDSC
55cs(=o)cLOXIMVIcs(=o)cGDSC
\n", 827 | "
" 828 | ], 829 | "text/plain": [ 830 | " iv1 cell_line iv2 \\\n", 831 | "0 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... A375 cs(=o)c \n", 832 | "1 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... UACC62 cs(=o)c \n", 833 | "2 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM983B cs(=o)c \n", 834 | "3 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... MALME3M cs(=o)c \n", 835 | "4 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... A2058 cs(=o)c \n", 836 | "5 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM793 cs(=o)c \n", 837 | "6 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... HT144 cs(=o)c \n", 838 | "7 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... RPMI7951 cs(=o)c \n", 839 | "8 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... SKMEL2 cs(=o)c \n", 840 | "9 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... SKMEL1 cs(=o)c \n", 841 | "10 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... HMCB cs(=o)c \n", 842 | "11 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... MDAMB435S cs(=o)c \n", 843 | "12 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM1799 cs(=o)c \n", 844 | "13 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... LOXIMVI cs(=o)c \n", 845 | "14 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... A375 cs(=o)c \n", 846 | "15 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... UACC62 cs(=o)c \n", 847 | "16 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... WM983B cs(=o)c \n", 848 | "17 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... MALME3M cs(=o)c \n", 849 | "18 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... A2058 cs(=o)c \n", 850 | "19 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... WM793 cs(=o)c \n", 851 | "20 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... HT144 cs(=o)c \n", 852 | "21 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... RPMI7951 cs(=o)c \n", 853 | "22 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... SKMEL2 cs(=o)c \n", 854 | "23 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... SKMEL1 cs(=o)c \n", 855 | "24 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... HMCB cs(=o)c \n", 856 | "25 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... MDAMB435S cs(=o)c \n", 857 | "26 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... WM1799 cs(=o)c \n", 858 | "27 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... LOXIMVI cs(=o)c \n", 859 | "28 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... A375 cs(=o)c \n", 860 | "29 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... UACC62 cs(=o)c \n", 861 | "30 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... WM983B cs(=o)c \n", 862 | "31 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... MALME3M cs(=o)c \n", 863 | "32 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... A2058 cs(=o)c \n", 864 | "33 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... WM793 cs(=o)c \n", 865 | "34 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... HT144 cs(=o)c \n", 866 | "35 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... RPMI7951 cs(=o)c \n", 867 | "36 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... SKMEL2 cs(=o)c \n", 868 | "37 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... SKMEL1 cs(=o)c \n", 869 | "38 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... HMCB cs(=o)c \n", 870 | "39 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... MDAMB435S cs(=o)c \n", 871 | "40 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... WM1799 cs(=o)c \n", 872 | "41 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... 
LOXIMVI cs(=o)c \n", 873 | "42 cs(=o)c A375 cs(=o)c \n", 874 | "43 cs(=o)c UACC62 cs(=o)c \n", 875 | "44 cs(=o)c WM983B cs(=o)c \n", 876 | "45 cs(=o)c MALME3M cs(=o)c \n", 877 | "46 cs(=o)c A2058 cs(=o)c \n", 878 | "47 cs(=o)c WM793 cs(=o)c \n", 879 | "48 cs(=o)c HT144 cs(=o)c \n", 880 | "49 cs(=o)c RPMI7951 cs(=o)c \n", 881 | "50 cs(=o)c SKMEL2 cs(=o)c \n", 882 | "51 cs(=o)c SKMEL1 cs(=o)c \n", 883 | "52 cs(=o)c HMCB cs(=o)c \n", 884 | "53 cs(=o)c MDAMB435S cs(=o)c \n", 885 | "54 cs(=o)c WM1799 cs(=o)c \n", 886 | "55 cs(=o)c LOXIMVI cs(=o)c \n", 887 | "\n", 888 | " phenotype \n", 889 | "0 GDSC \n", 890 | "1 GDSC \n", 891 | "2 GDSC \n", 892 | "3 GDSC \n", 893 | "4 GDSC \n", 894 | "5 GDSC \n", 895 | "6 GDSC \n", 896 | "7 GDSC \n", 897 | "8 GDSC \n", 898 | "9 GDSC \n", 899 | "10 GDSC \n", 900 | "11 GDSC \n", 901 | "12 GDSC \n", 902 | "13 GDSC \n", 903 | "14 GDSC \n", 904 | "15 GDSC \n", 905 | "16 GDSC \n", 906 | "17 GDSC \n", 907 | "18 GDSC \n", 908 | "19 GDSC \n", 909 | "20 GDSC \n", 910 | "21 GDSC \n", 911 | "22 GDSC \n", 912 | "23 GDSC \n", 913 | "24 GDSC \n", 914 | "25 GDSC \n", 915 | "26 GDSC \n", 916 | "27 GDSC \n", 917 | "28 GDSC \n", 918 | "29 GDSC \n", 919 | "30 GDSC \n", 920 | "31 GDSC \n", 921 | "32 GDSC \n", 922 | "33 GDSC \n", 923 | "34 GDSC \n", 924 | "35 GDSC \n", 925 | "36 GDSC \n", 926 | "37 GDSC \n", 927 | "38 GDSC \n", 928 | "39 GDSC \n", 929 | "40 GDSC \n", 930 | "41 GDSC \n", 931 | "42 GDSC \n", 932 | "43 GDSC \n", 933 | "44 GDSC \n", 934 | "45 GDSC \n", 935 | "46 GDSC \n", 936 | "47 GDSC \n", 937 | "48 GDSC \n", 938 | "49 GDSC \n", 939 | "50 GDSC \n", 940 | "51 GDSC \n", 941 | "52 GDSC \n", 942 | "53 GDSC \n", 943 | "54 GDSC \n", 944 | "55 GDSC " 945 | ] 946 | }, 947 | "execution_count": 6, 948 | "metadata": {}, 949 | "output_type": "execute_result" 950 | } 951 | ], 952 | "source": [ 953 | "# Construct a dataframe containing the experiments we want to run. 
In practice, the user would load\n", 954 | "# a premade dataframe here.\n", 955 | "input_df = pd.MultiIndex.from_product([\n", 956 | " iv_list,\n", 957 | " cl_list,\n", 958 | "], names=['iv1', 'cell_line'])\n", 959 | "input_df = input_df.to_frame(index=False).reset_index(drop=True)\n", 960 | "input_df['iv2'] = 'cs(=o)c' # DMSO\n", 961 | "input_df['phenotype'] = 'GDSC'\n", 962 | "input_df" 963 | ] 964 | }, 965 | { 966 | "cell_type": "code", 967 | "execution_count": 7, 968 | "id": "65e53302", 969 | "metadata": {}, 970 | "outputs": [ 971 | { 972 | "name": "stdout", 973 | "output_type": "stream", 974 | "text": [ 975 | "There are 1 iterations\n" 976 | ] 977 | }, 978 | { 979 | "name": "stderr", 980 | "output_type": "stream", 981 | "text": [ 982 | " 0%| | 0/1 [00:00\n", 1007 | "\n", 1020 | "\n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " \n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " \n", 1061 | " \n", 1062 | " \n", 1063 | " \n", 1064 | " \n", 1065 | " \n", 1066 | " \n", 1067 | " \n", 1068 | " \n", 1069 | " \n", 1070 | " \n", 1071 | " \n", 1072 | " \n", 1073 | " \n", 1074 | " \n", 1075 | " \n", 1076 | " \n", 1077 | " \n", 1078 | " \n", 1079 | " \n", 1080 | " \n", 1081 | " \n", 1082 | " \n", 1083 | " \n", 1084 | " \n", 1085 | " \n", 1086 | " \n", 1087 | " \n", 1088 | " \n", 1089 | " \n", 1090 | " \n", 1091 | " \n", 1092 | " \n", 1093 | " \n", 1094 | " \n", 1095 | " \n", 1096 | " \n", 1097 | " \n", 1098 | " \n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | " \n", 1112 | " \n", 1113 | " \n", 1114 | " \n", 1115 | " \n", 1116 | " \n", 1117 | " \n", 1118 | " \n", 1119 | " \n", 1120 | " \n", 1121 | " \n", 1122 | " \n", 1123 | " \n", 1124 | " \n", 1125 | " \n", 1126 | " \n", 1127 | " \n", 1128 | " \n", 1129 | " \n", 1130 | " \n", 1131 | " \n", 1132 | " \n", 1133 | " \n", 1134 | " \n", 1135 | " \n", 1136 | " \n", 1137 | " \n", 1138 | " \n", 1139 | " \n", 1140 | " \n", 1141 | " \n", 1142 | " \n", 1143 | " \n", 1144 | " \n", 1145 | " \n", 1146 | " \n", 1147 | " \n", 1148 | " \n", 1149 | " \n", 1150 | " \n", 1151 | " \n", 1152 | " \n", 1153 | " \n", 1154 | " \n", 1155 | " \n", 1156 | " \n", 1157 | " \n", 1158 | " \n", 1159 | " \n", 1160 | " \n", 1161 | " \n", 1162 | " \n", 1163 | " \n", 1164 | " \n", 1165 | " \n", 1166 | " \n", 1167 | " \n", 1168 | " \n", 1169 | " \n", 1170 | " \n", 1171 | " \n", 1172 | " \n", 1173 | " \n", 1174 | " \n", 1175 | " \n", 1176 | " \n", 1177 | " \n", 1178 | " \n", 1179 | " \n", 1180 | " \n", 1181 | " \n", 1182 | " \n", 1183 | " \n", 1184 | " \n", 1185 | " \n", 1186 | " \n", 1187 | " \n", 1188 | " \n", 1189 | " \n", 1190 | " \n", 1191 | " \n", 1192 | " \n", 1193 | " \n", 1194 | " \n", 1195 | " \n", 1196 | " \n", 1197 | " \n", 1198 | " \n", 1199 | " \n", 1200 | " \n", 1201 | " \n", 1202 | " \n", 1203 | " \n", 1204 | " \n", 1205 | " \n", 1206 | " \n", 1207 | " \n", 1208 | " \n", 1209 | " \n", 1210 | " \n", 1211 | " \n", 1212 | " \n", 1213 | " \n", 1214 | " \n", 1215 | " \n", 1216 | " \n", 1217 | " 
\n", 1218 | " \n", 1219 | " \n", 1220 | " \n", 1221 | " \n", 1222 | " \n", 1223 | " \n", 1224 | " \n", 1225 | " \n", 1226 | " \n", 1227 | " \n", 1228 | " \n", 1229 | " \n", 1230 | " \n", 1231 | " \n", 1232 | " \n", 1233 | " \n", 1234 | " \n", 1235 | " \n", 1236 | " \n", 1237 | " \n", 1238 | " \n", 1239 | " \n", 1240 | " \n", 1241 | " \n", 1242 | " \n", 1243 | " \n", 1244 | " \n", 1245 | " \n", 1246 | " \n", 1247 | " \n", 1248 | " \n", 1249 | " \n", 1250 | " \n", 1251 | " \n", 1252 | " \n", 1253 | " \n", 1254 | " \n", 1255 | " \n", 1256 | " \n", 1257 | " \n", 1258 | " \n", 1259 | " \n", 1260 | " \n", 1261 | " \n", 1262 | " \n", 1263 | " \n", 1264 | " \n", 1265 | " \n", 1266 | " \n", 1267 | " \n", 1268 | " \n", 1269 | " \n", 1270 | " \n", 1271 | " \n", 1272 | " \n", 1273 | " \n", 1274 | " \n", 1275 | " \n", 1276 | " \n", 1277 | " \n", 1278 | " \n", 1279 | " \n", 1280 | " \n", 1281 | " \n", 1282 | " \n", 1283 | " \n", 1284 | " \n", 1285 | " \n", 1286 | " \n", 1287 | " \n", 1288 | " \n", 1289 | " \n", 1290 | " \n", 1291 | " \n", 1292 | " \n", 1293 | " \n", 1294 | " \n", 1295 | " \n", 1296 | " \n", 1297 | " \n", 1298 | " \n", 1299 | " \n", 1300 | " \n", 1301 | " \n", 1302 | " \n", 1303 | " \n", 1304 | " \n", 1305 | " \n", 1306 | " \n", 1307 | " \n", 1308 | " \n", 1309 | " \n", 1310 | " \n", 1311 | " \n", 1312 | " \n", 1313 | " \n", 1314 | " \n", 1315 | " \n", 1316 | " \n", 1317 | " \n", 1318 | " \n", 1319 | " \n", 1320 | " \n", 1321 | " \n", 1322 | " \n", 1323 | " \n", 1324 | " \n", 1325 | " \n", 1326 | " \n", 1327 | " \n", 1328 | " \n", 1329 | " \n", 1330 | " \n", 1331 | " \n", 1332 | " \n", 1333 | " \n", 1334 | " \n", 1335 | " \n", 1336 | " \n", 1337 | " \n", 1338 | " \n", 1339 | " \n", 1340 | " \n", 1341 | " \n", 1342 | " \n", 1343 | " \n", 1344 | " \n", 1345 | " \n", 1346 | " \n", 1347 | " \n", 1348 | " \n", 1349 | " \n", 1350 | " \n", 1351 | " \n", 1352 | " \n", 1353 | " \n", 1354 | " \n", 1355 | " \n", 1356 | " \n", 1357 | " \n", 1358 | " \n", 1359 | " \n", 1360 | " \n", 1361 | " \n", 1362 | " \n", 1363 | " \n", 1364 | " \n", 1365 | " \n", 1366 | " \n", 1367 | " \n", 1368 | " \n", 1369 | " \n", 1370 | " \n", 1371 | " \n", 1372 | " \n", 1373 | " \n", 1374 | " \n", 1375 | " \n", 1376 | " \n", 1377 | " \n", 1378 | " \n", 1379 | " \n", 1380 | " \n", 1381 | " \n", 1382 | " \n", 1383 | " \n", 1384 | " \n", 1385 | " \n", 1386 | " \n", 1387 | " \n", 1388 | " \n", 1389 | " \n", 1390 | " \n", 1391 | " \n", 1392 | " \n", 1393 | " \n", 1394 | " \n", 1395 | " \n", 1396 | " \n", 1397 | " \n", 1398 | " \n", 1399 | " \n", 1400 | " \n", 1401 | " \n", 1402 | " \n", 1403 | " \n", 1404 | " \n", 1405 | " \n", 1406 | " \n", 1407 | " \n", 1408 | " \n", 1409 | " \n", 1410 | " \n", 1411 | " \n", 1412 | " \n", 1413 | " \n", 1414 | " \n", 1415 | " \n", 1416 | " \n", 1417 | " \n", 1418 | " \n", 1419 | " \n", 1420 | " \n", 1421 | " \n", 1422 | " \n", 1423 | " \n", 1424 | " \n", 1425 | " \n", 1426 | " \n", 1427 | " \n", 1428 | " \n", 1429 | " \n", 1430 | " \n", 1431 | " \n", 1432 | " \n", 1433 | " \n", 1434 | " \n", 1435 | " \n", 1436 | " \n", 1437 | " \n", 1438 | " \n", 1439 | " \n", 1440 | " \n", 1441 | " \n", 1442 | " \n", 1443 | " \n", 1444 | " \n", 1445 | " \n", 1446 | " \n", 1447 | " \n", 1448 | " \n", 1449 | " \n", 1450 | " \n", 1451 | " \n", 1452 | " \n", 1453 | " \n", 1454 | " \n", 1455 | " \n", 1456 | " \n", 1457 | " \n", 1458 | " \n", 1459 | " \n", 1460 | " \n", 1461 | " \n", 1462 | " \n", 1463 | " \n", 1464 | " \n", 1465 | " \n", 1466 | " \n", 1467 | " \n", 1468 | " \n", 1469 | " \n", 1470 | " \n", 1471 | 
" \n", 1472 | " \n", 1473 | " \n", 1474 | " \n", 1475 | " \n", 1476 | " \n", 1477 | " \n", 1478 | " \n", 1479 | " \n", 1480 | " \n", 1481 | "
iv1cell_lineiv2phenotypepred
0oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...A375cs(=o)cGDSC0.489043
1oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...UACC62cs(=o)cGDSC0.495842
2oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...WM983Bcs(=o)cGDSC0.501370
3oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...MALME3Mcs(=o)cGDSC0.484853
4oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...A2058cs(=o)cGDSC0.499624
5oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...WM793cs(=o)cGDSC0.493993
6oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...HT144cs(=o)cGDSC0.495368
7oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...RPMI7951cs(=o)cGDSC0.479674
8oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...SKMEL2cs(=o)cGDSC0.506403
9oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...SKMEL1cs(=o)cGDSC0.471883
10oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...HMCBcs(=o)cGDSC0.505215
11oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...MDAMB435Scs(=o)cGDSC0.500283
12oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...WM1799cs(=o)cGDSC0.498168
13oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...LOXIMVIcs(=o)cGDSC0.495914
14cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...A375cs(=o)cGDSC0.479828
15cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...UACC62cs(=o)cGDSC0.496269
16cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...WM983Bcs(=o)cGDSC0.481987
17cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...MALME3Mcs(=o)cGDSC0.470828
18cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...A2058cs(=o)cGDSC0.486716
19cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...WM793cs(=o)cGDSC0.492139
20cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...HT144cs(=o)cGDSC0.486285
21cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...RPMI7951cs(=o)cGDSC0.473088
22cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...SKMEL2cs(=o)cGDSC0.487896
23cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...SKMEL1cs(=o)cGDSC0.469014
24cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...HMCBcs(=o)cGDSC0.505564
25cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...MDAMB435Scs(=o)cGDSC0.490048
26cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...WM1799cs(=o)cGDSC0.486165
27cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc...LOXIMVIcs(=o)cGDSC0.495294
28fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...A375cs(=o)cGDSC0.502929
29fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...UACC62cs(=o)cGDSC0.514967
30fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...WM983Bcs(=o)cGDSC0.490298
31fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...MALME3Mcs(=o)cGDSC0.488163
32fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...A2058cs(=o)cGDSC0.506753
33fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...WM793cs(=o)cGDSC0.511536
34fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...HT144cs(=o)cGDSC0.504222
35fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...RPMI7951cs(=o)cGDSC0.481429
36fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...SKMEL2cs(=o)cGDSC0.514762
37fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...SKMEL1cs(=o)cGDSC0.490700
38fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...HMCBcs(=o)cGDSC0.521368
39fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...MDAMB435Scs(=o)cGDSC0.498836
40fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...WM1799cs(=o)cGDSC0.499848
41fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c...LOXIMVIcs(=o)cGDSC0.512768
42cs(=o)cA375cs(=o)cGDSC0.523384
43cs(=o)cUACC62cs(=o)cGDSC0.526230
44cs(=o)cWM983Bcs(=o)cGDSC0.529374
45cs(=o)cMALME3Mcs(=o)cGDSC0.517211
46cs(=o)cA2058cs(=o)cGDSC0.541236
47cs(=o)cWM793cs(=o)cGDSC0.524480
48cs(=o)cHT144cs(=o)cGDSC0.531894
49cs(=o)cRPMI7951cs(=o)cGDSC0.505558
50cs(=o)cSKMEL2cs(=o)cGDSC0.529134
51cs(=o)cSKMEL1cs(=o)cGDSC0.513998
52cs(=o)cHMCBcs(=o)cGDSC0.540129
53cs(=o)cMDAMB435Scs(=o)cGDSC0.538795
54cs(=o)cWM1799cs(=o)cGDSC0.526665
55cs(=o)cLOXIMVIcs(=o)cGDSC0.528534
\n", 1482 | "" 1483 | ], 1484 | "text/plain": [ 1485 | " iv1 cell_line iv2 \\\n", 1486 | "0 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... A375 cs(=o)c \n", 1487 | "1 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... UACC62 cs(=o)c \n", 1488 | "2 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM983B cs(=o)c \n", 1489 | "3 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... MALME3M cs(=o)c \n", 1490 | "4 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... A2058 cs(=o)c \n", 1491 | "5 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM793 cs(=o)c \n", 1492 | "6 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... HT144 cs(=o)c \n", 1493 | "7 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... RPMI7951 cs(=o)c \n", 1494 | "8 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... SKMEL2 cs(=o)c \n", 1495 | "9 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... SKMEL1 cs(=o)c \n", 1496 | "10 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... HMCB cs(=o)c \n", 1497 | "11 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... MDAMB435S cs(=o)c \n", 1498 | "12 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM1799 cs(=o)c \n", 1499 | "13 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... LOXIMVI cs(=o)c \n", 1500 | "14 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... A375 cs(=o)c \n", 1501 | "15 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... UACC62 cs(=o)c \n", 1502 | "16 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... WM983B cs(=o)c \n", 1503 | "17 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... MALME3M cs(=o)c \n", 1504 | "18 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... A2058 cs(=o)c \n", 1505 | "19 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... WM793 cs(=o)c \n", 1506 | "20 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... HT144 cs(=o)c \n", 1507 | "21 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... RPMI7951 cs(=o)c \n", 1508 | "22 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... SKMEL2 cs(=o)c \n", 1509 | "23 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... SKMEL1 cs(=o)c \n", 1510 | "24 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... HMCB cs(=o)c \n", 1511 | "25 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... MDAMB435S cs(=o)c \n", 1512 | "26 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... WM1799 cs(=o)c \n", 1513 | "27 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... LOXIMVI cs(=o)c \n", 1514 | "28 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... A375 cs(=o)c \n", 1515 | "29 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... UACC62 cs(=o)c \n", 1516 | "30 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... WM983B cs(=o)c \n", 1517 | "31 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... MALME3M cs(=o)c \n", 1518 | "32 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... A2058 cs(=o)c \n", 1519 | "33 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... WM793 cs(=o)c \n", 1520 | "34 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... HT144 cs(=o)c \n", 1521 | "35 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... RPMI7951 cs(=o)c \n", 1522 | "36 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... SKMEL2 cs(=o)c \n", 1523 | "37 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... SKMEL1 cs(=o)c \n", 1524 | "38 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... HMCB cs(=o)c \n", 1525 | "39 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... MDAMB435S cs(=o)c \n", 1526 | "40 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... WM1799 cs(=o)c \n", 1527 | "41 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... 
LOXIMVI cs(=o)c \n", 1528 | "42 cs(=o)c A375 cs(=o)c \n", 1529 | "43 cs(=o)c UACC62 cs(=o)c \n", 1530 | "44 cs(=o)c WM983B cs(=o)c \n", 1531 | "45 cs(=o)c MALME3M cs(=o)c \n", 1532 | "46 cs(=o)c A2058 cs(=o)c \n", 1533 | "47 cs(=o)c WM793 cs(=o)c \n", 1534 | "48 cs(=o)c HT144 cs(=o)c \n", 1535 | "49 cs(=o)c RPMI7951 cs(=o)c \n", 1536 | "50 cs(=o)c SKMEL2 cs(=o)c \n", 1537 | "51 cs(=o)c SKMEL1 cs(=o)c \n", 1538 | "52 cs(=o)c HMCB cs(=o)c \n", 1539 | "53 cs(=o)c MDAMB435S cs(=o)c \n", 1540 | "54 cs(=o)c WM1799 cs(=o)c \n", 1541 | "55 cs(=o)c LOXIMVI cs(=o)c \n", 1542 | "\n", 1543 | " phenotype pred \n", 1544 | "0 GDSC 0.489043 \n", 1545 | "1 GDSC 0.495842 \n", 1546 | "2 GDSC 0.501370 \n", 1547 | "3 GDSC 0.484853 \n", 1548 | "4 GDSC 0.499624 \n", 1549 | "5 GDSC 0.493993 \n", 1550 | "6 GDSC 0.495368 \n", 1551 | "7 GDSC 0.479674 \n", 1552 | "8 GDSC 0.506403 \n", 1553 | "9 GDSC 0.471883 \n", 1554 | "10 GDSC 0.505215 \n", 1555 | "11 GDSC 0.500283 \n", 1556 | "12 GDSC 0.498168 \n", 1557 | "13 GDSC 0.495914 \n", 1558 | "14 GDSC 0.479828 \n", 1559 | "15 GDSC 0.496269 \n", 1560 | "16 GDSC 0.481987 \n", 1561 | "17 GDSC 0.470828 \n", 1562 | "18 GDSC 0.486716 \n", 1563 | "19 GDSC 0.492139 \n", 1564 | "20 GDSC 0.486285 \n", 1565 | "21 GDSC 0.473088 \n", 1566 | "22 GDSC 0.487896 \n", 1567 | "23 GDSC 0.469014 \n", 1568 | "24 GDSC 0.505564 \n", 1569 | "25 GDSC 0.490048 \n", 1570 | "26 GDSC 0.486165 \n", 1571 | "27 GDSC 0.495294 \n", 1572 | "28 GDSC 0.502929 \n", 1573 | "29 GDSC 0.514967 \n", 1574 | "30 GDSC 0.490298 \n", 1575 | "31 GDSC 0.488163 \n", 1576 | "32 GDSC 0.506753 \n", 1577 | "33 GDSC 0.511536 \n", 1578 | "34 GDSC 0.504222 \n", 1579 | "35 GDSC 0.481429 \n", 1580 | "36 GDSC 0.514762 \n", 1581 | "37 GDSC 0.490700 \n", 1582 | "38 GDSC 0.521368 \n", 1583 | "39 GDSC 0.498836 \n", 1584 | "40 GDSC 0.499848 \n", 1585 | "41 GDSC 0.512768 \n", 1586 | "42 GDSC 0.523384 \n", 1587 | "43 GDSC 0.526230 \n", 1588 | "44 GDSC 0.529374 \n", 1589 | "45 GDSC 0.517211 \n", 1590 | "46 GDSC 0.541236 \n", 1591 | "47 GDSC 0.524480 \n", 1592 | "48 GDSC 0.531894 \n", 1593 | "49 GDSC 0.505558 \n", 1594 | "50 GDSC 0.529134 \n", 1595 | "51 GDSC 0.513998 \n", 1596 | "52 GDSC 0.540129 \n", 1597 | "53 GDSC 0.538795 \n", 1598 | "54 GDSC 0.526665 \n", 1599 | "55 GDSC 0.528534 " 1600 | ] 1601 | }, 1602 | "execution_count": 7, 1603 | "metadata": {}, 1604 | "output_type": "execute_result" 1605 | } 1606 | ], 1607 | "source": [ 1608 | "df = model.predict(input_df, num_iterations=1, save=False)\n", 1609 | "df" 1610 | ] 1611 | } 1612 | ], 1613 | "metadata": { 1614 | "kernelspec": { 1615 | "display_name": "Python [conda env:prophet_api]", 1616 | "language": "python", 1617 | "name": "conda-env-prophet_api-py" 1618 | }, 1619 | "language_info": { 1620 | "codemirror_mode": { 1621 | "name": "ipython", 1622 | "version": 3 1623 | }, 1624 | "file_extension": ".py", 1625 | "mimetype": "text/x-python", 1626 | "name": "python", 1627 | "nbconvert_exporter": "python", 1628 | "pygments_lexer": "ipython3", 1629 | "version": "3.12.4" 1630 | } 1631 | }, 1632 | "nbformat": 4, 1633 | "nbformat_minor": 5 1634 | } 1635 | --------------------------------------------------------------------------------