├── .github
│   ├── CODE_OF_CONDUCT.md
│   └── CONTRIBUTING.md
├── LICENSE
├── README.md
├── embeddings
│   ├── cell_line_embedding_full_ccle_300_scaled.csv
│   └── phenotypes.csv
├── prophet
│   ├── Prophet.py
│   ├── __init__.py
│   ├── callbacks.py
│   ├── config.py
│   ├── dataloader.py
│   ├── dataset.py
│   ├── model.py
│   ├── train.py
│   └── train_model.py
├── setup.py
├── test
│   └── test_dataloader.py
└── tutorials
    ├── config_file_finetuning.yaml
    ├── finetuning.ipynb
    └── insilico_screening.ipynb
/.github/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 | 
3 | ## Our Pledge
4 | 
5 | We are committed to fostering a welcoming and inclusive environment for everyone. We pledge to make participation in our project a harassment-free experience for all, regardless of background or identity.
6 | 
7 | ## Our Standards
8 | 
9 | Examples of behavior that contributes to a positive environment:
10 | 
11 | - Using welcoming and inclusive language
12 | - Being respectful of differing viewpoints and experiences
13 | - Accepting constructive criticism gracefully
14 | - Focusing on what is best for the community
15 | 
16 | Examples of unacceptable behavior:
17 | 
18 | - Harassment, discrimination, or exclusionary behavior
19 | - Trolling, insulting or derogatory comments
20 | - Personal or political attacks
21 | - Public or private harassment
22 | 
23 | ## Our Responsibilities
24 | 
25 | Maintainers are responsible for clarifying standards of acceptable behavior and will take appropriate action in response to any unacceptable behavior.
26 | 
27 | ## Scope
28 | 
29 | This Code of Conduct applies within all project spaces and whenever an individual represents the project in public.
30 | 
31 | ## Enforcement
32 | 
33 | Violations may result in warnings, temporary bans, or permanent exclusion from the project.
34 | 
35 | If you experience or witness unacceptable behavior, please report it by contacting the maintainers directly or via a dedicated email (if applicable).
36 | 
37 | ## Attribution
38 | 
39 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 2.1.
40 | 
--------------------------------------------------------------------------------
/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to This Project
2 |
3 | Thank you for your interest in contributing! We welcome all contributions, including bug reports, feature requests, documentation improvements, and code changes.
4 |
5 | ## Getting Started
6 |
7 | 1. **Fork the repository** and clone your fork locally.
8 | 2. Create a new branch for your changes:
9 | ```bash
10 | git checkout -b my-feature-branch
11 |    ```
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024, Alejandro Tejada Lapuerta, Yuge Ji.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![License: MIT][mit-shield]][mit]
2 | # Prophet
3 |
4 | Prophet is a transformer-based regression model that predicts cellular responses. It decomposes each experiment into cell state, treatment, and functional readout, and leverages extensive screening datasets and a scalable architecture to significantly reduce the number of required experiments and to identify effective treatments.
5 |
6 | ## Model Overview
7 |
8 | Prophet decomposes biological experiments into three key components:
9 | 1. **Cell state** - represented by cell line embeddings derived from gene expression profiles
10 | 2. **Treatment** - represented by intervention embeddings (e.g., small molecules, genetic perturbations)
11 | 3. **Functional readout** - the phenotypic measurement being predicted (e.g., viability, IC50)
12 |
13 | The model uses a transformer architecture to learn complex interactions between these components and predict experimental outcomes without requiring the experiments to be performed.
14 |
15 | ### Embeddings
16 |
17 | Prophet uses three types of embeddings:
18 | - **Cell line embeddings**: 300-dimensional vectors derived from CCLE gene expression data
19 | - **Intervention embeddings**: 500-dimensional vectors representing small molecules or genetic perturbations
20 | - **Phenotype embeddings**: Representations of different readout types (optional)
21 |
22 | These embeddings capture the biological properties of each component and allow the model to generalize across different experimental conditions.
23 |
24 | ## Training
25 |
26 | Prophet was trained on a large dataset of cellular response measurements, including:
27 | - Drug sensitivity screens (GDSC, PRISM, CTRP)
28 | - Genetic perturbation screens (DepMap, Achilles)
29 | - Combinatorial perturbation experiments
30 |
31 | The model was trained using a masked attention mechanism to handle variable numbers of perturbations and a cosine learning rate schedule with warmup. Training was performed on NVIDIA A100 GPUs with early stopping based on validation loss.
32 |
33 | ## Installation
34 | ```bash
35 | mamba create -n prophet_env python=3.10
36 | mamba activate prophet_env
37 |
38 | git clone https://github.com/theislab/prophet.git
39 | cd prophet
40 | pip install -e .
41 | ```
42 |
43 | ## Usage
44 |
45 | ### Downloading Resources
46 |
47 | Model checkpoints and input embeddings can be downloaded from [Hugging Face](https://huggingface.co/datasets/aletlvl/Prophet_v1/tree/main) and [Mendeley Data](https://data.mendeley.com/datasets/g7z3pw3bfw).
48 |
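49 | ### Quick start
50 | 
51 | A minimal sketch of loading a downloaded checkpoint and predicting responses for single interventions (this assumes a checkpoint trained with one intervention slot; combination models expect two intervention columns). Paths and the intervention/cell line names below are placeholders; valid phenotype names such as `GDSC` are listed in `embeddings/phenotypes.csv`.
52 | 
53 | ```python
54 | from prophet import Prophet
55 | 
56 | model = Prophet(
57 |     # the intervention embedding CSV must contain a 'type' column ('gene' or 'drug')
58 |     iv_emb_path="path/to/intervention_embeddings.csv",  # placeholder path
59 |     cl_emb_path="embeddings/cell_line_embedding_full_ccle_300_scaled.csv",
60 |     model_pth="path/to/prophet_checkpoint.ckpt",  # placeholder path
61 | )
62 | 
63 | predictions = model.predict(
64 |     target_ivs=["paclitaxel"],  # placeholder intervention name
65 |     target_cls=["A549"],  # placeholder cell line name
66 |     target_phs=["GDSC"],
67 |     iv_col=["iv1"],
68 |     save=False,  # return a DataFrame instead of writing parquet files
69 | )
70 | ```
71 | 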
72 | If you have used our work in your research, please cite our [preprint](https://www.biorxiv.org/content/10.1101/2024.08.12.607533v2).
73 | 
74 | [mit]: https://opensource.org/licenses/MIT
75 | [mit-image]: https://img.shields.io/badge/License-MIT-yellow.svg
76 | [mit-shield]: https://img.shields.io/badge/License-MIT-yellow.svg
77 | 
--------------------------------------------------------------------------------
/embeddings/phenotypes.csv:
--------------------------------------------------------------------------------
1 | ,0
2 | 0,SCORE
3 | 1,Horlbeck
4 | 2,UBE2A
5 | 3,GNAI1
6 | 4,MYBL2
7 | 5,BIRC5
8 | 6,NFIL3
9 | 7,TGFB3
10 | 8,P4HTM
11 | 9,RPN1
12 | 10,PDIA5
13 | 11,BAMBI
14 | 12,TATDN2
15 | 13,FAM57A
16 | 14,FGFR4
17 | 15,MOK
18 | 16,CCNE2
19 | 17,KIAA0753
20 | 18,KDELR2
21 | 19,GTF2A2
22 | 20,DCTD
23 | 21,USP22
24 | 22,CANT1
25 | 23,NENF
26 | 24,RRP12
27 | 25,DSG2
28 | 26,SACM1L
29 | 27,FBXO11
30 | 28,AKAP8
31 | 29,BAG3
32 | 30,MBNL2
33 | 31,C2CD2
34 | 32,YKT6
35 | 33,MVP
36 | 34,NOL3
37 | 35,CSRP1
38 | 36,DMTF1
39 | 37,ALDOC
40 | 38,SYNGR3
41 | 39,TCEAL4
42 | 40,PLEKHM1
43 | 41,IGF1R
44 | 42,MMP1
45 | 43,ETV1
46 | 44,CHERP
47 | 45,PARP2
48 | 46,TIMM17B
49 | 47,NUP133
50 | 48,ENOSF1
51 | 49,RALA
52 | 50,CNOT4
53 | 51,FBXL12
54 | 52,GDSC
55 | 53,GDSCcomb
56 | 54,PRISM
57 | 55,inhouse
58 | 56,CTRP
59 |
--------------------------------------------------------------------------------
/prophet/Prophet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
4 | import numpy as np
5 | import pandas as pd
6 | import warnings
7 | from tqdm import tqdm
8 | from typing import List, Union, Optional
9 | from prophet.callbacks import R2ScoreCallback
10 | import functools
11 | from joblib import load
12 | from sklearn.ensemble import RandomForestRegressor
13 | from .dataloader import (
14 | dataloader_phenotypes,
15 | process_priors,
16 | remove_nonexistent_cat,
17 | )
18 | from .model import load_models_config, TransformerPredictor
19 | from pytorch_lightning.callbacks import TQDMProgressBar
20 |
21 | def inherit_docs_and_signature(from_method):
22 | def decorator(to_method):
23 | @functools.wraps(from_method)
24 | def wrapper(self, *args, **kwargs):
25 | return to_method(self, *args, **kwargs)
26 |         # functools.wraps already copies __doc__ and sets __wrapped__, so
27 |         # inspect.signature(wrapper) recovers from_method's signature
28 | return wrapper
29 | return decorator
30 |
31 | class Prophet:
32 | def __init__(
33 | self,
34 | iv_emb_path: Union[str, List[str]] = None,
35 | cl_emb_path: Union[str, List[str]] = None,
36 | ph_emb_path: Union[str, List[str]] = None,
37 | model_pth=None,
38 | architecture="Transformer",
39 | ):
40 | """Initialize the Prophet model.
41 |
42 | Args:
43 | iv_emb_path (Union[str, List[str]], optional): The path to the gene embeddings. Defaults to None.
44 | cl_emb_path (Union[str, List[str]], optional): The path to the cell line embeddings. Defaults to None.
45 | ph_emb_path (Union[str, List[str]], optional): The path to the phenotype embeddings. Defaults to None.
46 |             model_pth (str, optional): The path to the trained model checkpoint. Defaults to None.
47 | architecture (str, optional): The architecture of the model. Defaults to "Transformer".
48 | """
49 |
50 | self.architecture = architecture
51 | self.iv_emb_path = iv_emb_path
52 | self.cl_emb_path = cl_emb_path
53 | self.ph_emb_path = ph_emb_path
54 | # set phenotypes (must be in the same order regardless of what is passed in predict)
55 | self.phenotypes = None
56 | self.column_map = None
57 | self.pert_len = None
58 |
59 | if model_pth and architecture == "RandomForest":
60 | self.model = load(model_pth)
61 | else:
62 | self.model_pth = model_pth
63 | self.model = self._build_model(architecture)
64 | self.phenotypes = self.model.hparams["phenotypes"]
65 | self.iv_embedding, self.cl_embedding, self.ph_embedding = process_priors(self.iv_emb_path, self.cl_emb_path, self.ph_emb_path)
66 | if self.model.hparams.explicit_phenotype and self.ph_embedding is None:
67 | raise ValueError('model was run with explicit phenotype! must pass a ph_emb_path')
68 |
69 | def _build_model(self, arch):
70 | if arch == "RandomForest":
71 | self.torch_dataset = False
72 | return RandomForestRegressor()
73 | elif arch == "Transformer":
74 | self.torch_dataset = True
75 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
76 | print('returning trained model!')
77 | model = TransformerPredictor.load_from_checkpoint(checkpoint_path=self.model_pth, map_location=torch.device('cpu')) #map_location must be cpu to load from checkpoint if you used ddp-notebook
78 | model.eval()
79 | # working backwards from config
80 | if model.hparams.simpler:
81 | self.pert_len = model.hparams.ctx_len - 1
82 | else:
83 | self.pert_len = model.hparams.ctx_len - 3
84 |
85 | return model
86 | else:
87 | raise ValueError(arch, " is not a valid model architecture.")
88 |
89 | def _remove_nonexistent_cat(
90 | self,
91 | data_label: Optional[pd.DataFrame] = None,
92 | verbose=True,
93 | ):
94 | embeddings = [self.iv_embedding, self.cl_embedding, self.ph_embedding]
95 | cols = [self.iv_cols, self.cl_col, self.ph_col]
96 | for i, embedding in enumerate(embeddings):
97 | if embedding is None: # phenotype embedding can be None
98 | continue
99 | data_label = remove_nonexistent_cat(data_label, embedding, cols[i], verbose)
100 | data_label = data_label.reset_index(drop=True)
101 |
102 | if len(data_label) == 0 and not verbose:
103 | self._remove_nonexistent_cat(data_label=data_label, verbose=True)
104 | raise ValueError('labels did not match embeddings passed!')
105 | return data_label
106 |
107 | def _init_input(
108 | self,
109 | iv_col: Union[List[str], str] = ['iv1', 'iv2'],
110 | cl_col: str = "cell_line",
111 | ph_col: str = "phenotype",
112 | readout_col: str = "value",
113 | ):
114 | """Sets some state variables in the model, but is always overwritten by
115 | either train or predict.
116 | """
117 | if isinstance(iv_col, str):
118 | iv_col = [iv_col]
119 | intervention_mapping = {col: f"iv{i+1}" for i, col in enumerate(iv_col)}
120 | self.iv_cols = list(intervention_mapping.values())
121 |
122 | # store the columns used for training for reference
123 | self.cl_col = cl_col
124 | self.ph_col = ph_col
125 | self.readout_col = readout_col
126 |
127 | # create the mapping to the internal variables used
128 | self.column_map = {
129 | self.cl_col: "cell_line",
130 | self.ph_col: "phenotype",
131 | self.readout_col: "value",
132 | **intervention_mapping
133 | }
134 | if self.pert_len is None:
135 | self.pert_len = len(self.iv_cols)
136 | else:
137 | if self.pert_len != len(self.iv_cols):
138 | raise ValueError(f"Are you sure you passed the right number of intervention columns? Currently receiving {self.iv_cols}")
139 |
140 | def train(
141 | self,
142 | df: pd.DataFrame,
143 | iv_col: Union[List[str], str] = ['iv1', 'iv2'],
144 | cl_col: str = "cell_line",
145 | ph_col: str = "phenotype",
146 | readout_col: str = "value",
147 | model_config: dict = None,
148 | ):
149 | """Train the Prophet model on the provided DataFrame.
150 |
151 | This function reformats the DataFrame according to the specified settings and intervention columns,
152 | then trains the model using the reformatted data.
153 |
154 | Args:
155 | df (pd.DataFrame): The DataFrame containing the experimental data.
156 | iv_col (Union[List[str], str]): The names of the intervention columns in the DataFrame. Can be a single column name or a list of names.
157 | cl_col (str, optional): The name of the column in df that contains the setting labels. Defaults to "cell_line".
158 | ph_col (str, optional): The name of the column in df that contains the phenotype labels. Defaults to "phenotype".
159 | readout_col (str, optional): The name of the column in df that contains the readout data. Defaults to "value".
160 |             model_config (Config, optional): configuration object, e.g. built with set_config from a yaml such as tutorials/config_file_finetuning.yaml
161 | """
162 | self._init_input(iv_col, cl_col, ph_col, readout_col)
163 | # user-friendly check that the columns were passed in correctly
164 |         for old_name, new_name in self.column_map.items():
165 | if old_name not in df.columns:
166 | raise ValueError(f"{old_name} not in df columns.")
167 |
168 | df = df.rename(columns=self.column_map).copy()
169 |
170 | # Formatting
171 | df = df.drop_duplicates()
172 | df = df.reset_index(drop=True)
173 |
174 | ## generate training dataloader
175 | df = self._remove_nonexistent_cat(data_label=df, verbose=True)
176 | split = dataloader_phenotypes(
177 | gene_embedding=self.iv_embedding,
178 | cell_lines_embedding=self.cl_embedding,
179 | phenotype_embedding=self.ph_embedding if self.ph_embedding is not None else None,
180 | data_label=df,
181 | label_name="value",
182 | index=(
183 | np.array(df.index),
184 | [],
185 | [],
186 | "",
187 |             ), # (train, valid, test, descr)
188 | torch_dataset=self.torch_dataset,
189 | pert_len=len(self.iv_cols)
190 | )
191 |
192 | print("Fitting model.")
193 | if not self.torch_dataset:
194 |             X_train, y_train = split[0]  # train split (valid and test indices are empty here)
195 | self.model.fit(X_train, y_train)
196 | else:
197 | print("pytorch model, finetuning")
198 | # automatically take 10% of the data as validation set
199 | train_indices = np.array(df.index)[np.random.choice(len(df.index), int(len(df.index) * 0.9), replace=False)]
200 | val_indices = np.array(df.index)[~np.isin(df.index, train_indices)]
201 | split = dataloader_phenotypes(
202 | gene_embedding=self.iv_embedding,
203 | cell_lines_embedding=self.cl_embedding,
204 | phenotype_embedding=self.ph_embedding if self.ph_embedding is not None else None,
205 | data_label=df,
206 | label_name="value",
207 | index=(
208 | train_indices,
209 | val_indices,
210 | [],
211 | "",
212 | ), # (train, val, test, descr)
213 | torch_dataset=self.torch_dataset,
214 | pert_len=len(self.iv_cols),
215 | valid_set=True
216 | )
217 |
218 | model_config.ohe_dim = 0
219 |
220 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
221 | model, model_config = load_models_config(model_config, seed=42, phenotypes=None)
222 |
223 | lr_monitor = LearningRateMonitor(logging_interval='step')
224 | dirpath = model_config.dirpath
225 | model_checkpointer = ModelCheckpoint(dirpath=dirpath, save_top_k=1, every_n_epochs=1, monitor='R2_train', mode='max')
226 | r2_callback = R2ScoreCallback(device=model.device, average=False)
227 | early_stopping = EarlyStopping(monitor="R2_train", mode="max", patience=model_config.patience, min_delta=0.0)
228 |
229 | tqdm_progress_bar = TQDMProgressBar(refresh_rate=1)
230 | callbacks = [r2_callback, model_checkpointer, lr_monitor, early_stopping,tqdm_progress_bar]
231 |
232 | print(f"Running with early stopping: {model_config.early_stopping}")
233 | if model_config.early_stopping:
234 | print(f"Early stopping patience: {model_config.patience}")
235 |
236 | trainer = pl.Trainer(
237 | min_epochs=1,
238 | #max_steps=100,
239 | max_steps=model_config.max_steps,
240 | max_epochs=2,
241 | accelerator='gpu',
242 | check_val_every_n_epoch=1,
243 | callbacks=callbacks,
244 | strategy="auto", #choose a notebook-compatible strategy: `Trainer(strategy='ddp_notebook')`
245 | #precision="16-mixed",
246 | enable_progress_bar=True,
247 | gradient_clip_val=1,
248 | log_every_n_steps = 1,
249 | deterministic=True)
250 |
251 | trainer.fit(model=model, train_dataloaders=split[0], val_dataloaders=split[1])
252 |
253 | def _generate_predict_df(self,
254 | run_index: int,
255 | num_iterations: int,
256 | target_ivs: List[str],
257 | target_cls: List[str],
258 | target_phs: List[str] = ['_'],
259 | ):
260 | subset_cl = pd.DataFrame(target_cls, columns=["cell_line"])
261 | subset_iv = pd.DataFrame(target_ivs, columns=["iv"])
262 | if target_phs is None:
263 | target_phs = ['_']
264 | subset_ph = pd.DataFrame(target_phs, columns=["phenotype"])
265 | if len(self.iv_cols) > 2:
266 |             raise NotImplementedError("Only 1 or 2 interventions are supported when passing lists of interventions. Please create the data label dataframe yourself and pass it to predict()!")
267 |
268 | batch_size = int(len(subset_iv) // num_iterations)
269 | start_idx = run_index * batch_size
270 | end_idx = (
271 | (start_idx + batch_size)
272 | if (run_index < num_iterations - 1)
273 | else len(subset_iv["iv"])
274 | )
275 |
276 | data_label = pd.merge(subset_iv[["iv"]][start_idx:end_idx], subset_cl, how="cross")
277 | data_label = pd.merge(data_label, subset_ph, how="cross")
278 |
279 | if len(self.iv_cols) == 1:
280 | data_label.rename(columns={"iv": "iv1"}, inplace=True)
281 | else:
282 | data_label = pd.merge(subset_iv[["iv"]], data_label, how="cross", suffixes=("1", "2"))
283 |             # A+B and B+A should be the same, so we remove all duplicates in favor of the sorted ordering A+B
284 | data_label['iv1+iv2'] = ['+'.join(sorted([row['iv1'], row['iv2']])) for _, row in data_label.iterrows()]
285 | data_label = data_label.drop_duplicates(subset=['iv1+iv2', 'cell_line', 'phenotype'])
286 |
287 | data_label['value'] = '_'
288 |
289 | return data_label
290 |
291 | def _decide_iteration_num(
292 | self,
293 | total_size: int,
294 | single_run_size: int = None,
295 | memory_size: int = None,
296 | ):
297 | if total_size <= single_run_size:
298 | num_iterations = 1
299 | else:
300 | num_iterations = total_size // single_run_size
301 |
302 | return int(num_iterations)
303 |
304 | def predict(
305 | self,
306 | df: pd.DataFrame = None,
307 | target_ivs: Union[str, List[str]] = None,
308 | target_cls: Union[str, List[str]] = None,
309 | target_phs: Union[str, List[str]] = None,
310 | iv_col: Union[List[str], str] = ['iv1', 'iv2'],
311 | cl_col: str = "cell_line",
312 | ph_col: str = "phenotype",
313 | num_iterations: int = None,
314 | save: bool = True,
315 | filename: str = "Prophet_prediction",
316 | ):
317 | """Predict outcomes using the trained Prophet model.
318 |
319 |         This function can take either a DataFrame or lists of interventions and cell lines to make predictions.
320 |         If a dataframe is passed, the mapping from its columns to the model inputs must also be passed. If lists
321 |         are passed, predictions are made for the full cross product of interventions, cell lines, and phenotypes.
322 |
323 | Args:
324 | df (pd.DataFrame, optional): The DataFrame containing the data for prediction. If None, predictions will be made for all combinations of provided genes and cell lines.
325 | target_ivs (Union[str, List[str]], optional): The intervention or list of interventions for prediction. If df and target_ivs are both None, predictions will be made for all available treatments.
326 | target_cls (Union[str, List[str]], optional): The cell line or list of cell lines for prediction. If df and target_cls are both None, predictions will be made for all available cell lines.
327 | target_phs (Union[str, List[str]], optional): The phenotype for prediction. If None, it will be set to "_".
328 | num_iterations (int, optional): The number of iterations to run the prediction. If None, it will be calculated automatically.
329 | save (bool, optional): Whether to save the prediction results to a file. If False, return the prediction result as dataframe. Defaults to True.
330 | filename (str, optional): The filename to save the prediction results. If None, a default name will be used.
331 | """
332 | if save:
333 | print(f"Saving results to {filename}.parquet")
334 |
335 | # If only pass target_cls and target_ivs
336 | if df is None:
337 |
338 | if isinstance(target_cls, str):
339 | target_cls = [target_cls]
340 | if isinstance(target_phs, str):
341 | target_phs = [target_phs]
342 | if isinstance(target_ivs, str):
343 | target_ivs = [target_ivs]
344 |
345 | if target_cls is None:
346 | target_cls = list(self.cl_embedding.index)
347 | warnings.warn(f"Trying to predict for all cell lines that are from {self.cl_embedding.index}!")
348 | print(f"There are {len(target_cls)} cell lines in total!")
349 | if target_ivs is None:
350 |                 target_ivs = list(self.iv_embedding.drop_duplicates().index)  # there can be compounds with the same embedding but different names
351 | warnings.warn(f"Trying to predict for all gene combinations that are from {self.iv_embedding.index}!")
352 | print(f"There are {len(target_ivs)} genes in total!")
353 |
354 | # in case the user wants to specify more than one intervention
355 | if iv_col is None:
356 | iv_col = ['iv1']
357 |
358 |             total_size = len(target_ivs) * len(target_cls) * (len(target_phs) if target_phs is not None else 1)
359 | self._init_input(iv_col, 'cell_line', 'phenotype', 'value')
360 | # or pass a dataframe has similiar format with in train
361 | else:
362 | self._init_input(iv_col, cl_col, ph_col, 'value')
363 | # user-friendly column name check
364 |             for old_name, new_name in self.column_map.items():
365 | if new_name == 'value': # prediction does not need a readout col
366 | continue
367 | if old_name not in df.columns:
368 | raise ValueError(f"{old_name} not in df columns.")
369 |
370 | df = df.rename(columns=self.column_map).copy()
371 |
372 | # same formatting as in train
373 | df = df.drop_duplicates()
374 | df = df.reset_index(drop=True)
375 | df = self._remove_nonexistent_cat(data_label=df, verbose=False)
376 | total_size = len(df)
377 |
378 | # Divide work into partitions. This allows users to run a large amount of inference
379 | # in one shot without memory problems.
380 |         if num_iterations:
381 |             if num_iterations < 1:
382 |                 raise ValueError("num_iterations must be larger than 0.")
383 |             else:
384 |                 if self._decide_iteration_num(total_size=total_size, single_run_size=1e08) - num_iterations > 5:
385 |                     warnings.warn("The num_iterations you passed might be too small.")
386 | else:
387 | num_iterations = self._decide_iteration_num(total_size=total_size, single_run_size=1e08)
388 |
389 | if num_iterations % 2 == 0:
390 | num_iterations = num_iterations + 1
391 |
392 | print(f"There are {num_iterations} iterations")
393 | data_label_list = []
394 | for run_index in tqdm(range(num_iterations)):
395 |
396 | # If the user input is df
397 | if isinstance(df, pd.DataFrame):
398 | batch_size = int(total_size // num_iterations)
399 | start_idx = run_index * batch_size
400 | end_idx = (
401 | start_idx + batch_size
402 | if (run_index < num_iterations - 1)
403 | else len(df)
404 | )
405 | data_label = df.iloc[start_idx:end_idx]
406 |
407 | # If the user input is gene list and cl list
408 | else:
409 | data_label = self._generate_predict_df(run_index=run_index,num_iterations=num_iterations,target_ivs=target_ivs,target_cls=target_cls, target_phs=target_phs)
410 |
411 | data_label = data_label.drop_duplicates()
412 |
413 | data_label = self._remove_nonexistent_cat(data_label=data_label, verbose=not isinstance(df, pd.DataFrame))
414 |
415 | if self.torch_dataset: # format for pytorch dataloading
416 | # must have a value column
417 | data_label['_'] = 0
418 |
419 | split = dataloader_phenotypes(
420 | gene_embedding=self.iv_embedding,
421 | cell_lines_embedding=self.cl_embedding,
422 | phenotype_embedding=self.ph_embedding,
423 | data_label=data_label,
424 | label_name='_',
425 | index=(
426 | np.array(data_label.index),
427 | [],
428 | np.array(data_label.index).tolist(), # test is not shuffled
429 | "",
430 |                 ), # (train, valid, test, descr)
431 | torch_dataset=self.torch_dataset,
432 | pert_len=self.pert_len
433 | )
434 |
435 | if self.torch_dataset:
436 |                 X = split[2]  # take the test dataloader; the test set is not shuffled
437 | else:
438 | X, _ = split[2]
439 |
440 | train_dataloader, valid_dataloader, test_dataloader, train_indices, test_indices, descriptor = split
441 | trainer = pl.Trainer(devices=1)
442 | predictions = trainer.predict(self.model, test_dataloader)
443 | predictions = [t[0] for t in predictions]
444 | predictions = torch.cat(predictions, dim=0)
445 | data_label["pred"] = predictions
446 | data_label.drop(columns=['_'], inplace=True)
447 |
448 | if save:
449 | data_label.to_parquet(
450 | f"{filename}_{run_index}.parquet",
451 | # engine="fastparquet"
452 | )
453 | else:
454 | data_label_list.append(data_label)
455 | if not save:
456 | concatenated_df = pd.concat(data_label_list)
457 | return concatenated_df
458 |
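459 | 
460 | if __name__ == "__main__":
461 |     # Minimal sketch (illustrative): the label format that train() and
462 |     # predict(df=...) expect after column renaming. All identifiers are
463 |     # placeholders; negative_drug / negative_gene mark a missing second
464 |     # perturbation and are masked out by the transformer's attention.
465 |     example = pd.DataFrame({
466 |         "iv1": ["paclitaxel", "braf"],
467 |         "iv2": ["negative_drug", "negative_gene"],
468 |         "cell_line": ["A549", "A549"],
469 |         "phenotype": ["GDSC", "SCORE"],  # names listed in embeddings/phenotypes.csv
470 |         "value": [0.42, -0.10],
471 |     })
472 |     print(example)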
--------------------------------------------------------------------------------
/prophet/__init__.py:
--------------------------------------------------------------------------------
1 | from .Prophet import Prophet
2 | from .config import set_config
3 |
4 | __all__ = ['Prophet', 'set_config']
5 |
--------------------------------------------------------------------------------
/prophet/callbacks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | from torch import optim
4 | import numpy as np
5 | from sklearn.metrics import r2_score
6 | from scipy.stats import spearmanr
7 |
8 |
9 | class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):
10 |
11 | def __init__(self, optimizer, warmup, max_iters):
12 | self.warmup = warmup
13 | self.max_num_iters = max_iters
14 | super().__init__(optimizer)
15 |
16 | def get_lr(self):
17 | lr_factor = self.get_lr_factor(epoch=self.last_epoch)
18 | return [base_lr * lr_factor for base_lr in self.base_lrs]
19 |
20 | def get_lr_factor(self, epoch):
21 | lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
22 | if epoch <= self.warmup:
23 | lr_factor *= epoch * 1.0 / self.warmup
24 | return lr_factor
25 |
26 |
27 | class R2ScoreCallback(pl.Callback):
28 | def __init__(self, device: torch.device = 'cpu', average = False):
29 | super().__init__()
30 | self.predictions = []
31 | self.targets = []
32 |
33 |
34 | self.prediction_train = []
35 | self.prediction_test = []
36 |
37 | self.targets_train = []
38 | self.targets_test = []
39 |
40 | self.phenotype_validation = []
41 |
42 | self.device = device
43 | self.average = average
44 |
45 | self.table_val = None
46 | self.table_train = None
47 |
48 | print("R2 average: ", self.average)
49 |
50 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
51 | y_pred, y_true = outputs['y_pred'].detach(), outputs['y_true'].detach()
52 | self.prediction_train.append(y_pred)
53 | self.targets_train.append(y_true)
54 |
55 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
56 | y_pred, y_true, phenotype = outputs['y_pred'], outputs['y_true'], outputs['phenotype']
57 | self.predictions.append(y_pred)
58 | self.targets.append(y_true)
59 | self.phenotype_validation.append(phenotype)
60 |
61 | def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
62 | y_pred, y_true = outputs['y_pred'], outputs['y_true']
63 | self.prediction_test.append(y_pred)
64 | self.targets_test.append(y_true)
65 |
66 | def on_validation_epoch_end(self, trainer, pl_module):
67 | predictions = torch.cat(self.predictions, dim=0)#.cpu().numpy()
68 | targets = torch.cat(self.targets, dim=0)#.cpu().numpy()
69 | phenotypes = torch.cat(self.phenotype_validation, dim=0)
70 |
71 |
72 | predictions = predictions.cpu().numpy()
73 | targets = targets.cpu().numpy()
74 | phenotypes = phenotypes.cpu().numpy()
75 |
76 | if self.average:
77 | r2_scores = 0
78 | spearman_scores = 0
79 | unique_phe = np.unique(phenotypes)
80 | for phe in unique_phe:
81 | indices = np.nonzero(phenotypes == phe)
82 | phe_predictions = predictions[indices]
83 | phe_targets = targets[indices]
84 |
85 | r2 = r2_score(phe_targets, phe_predictions)
86 | r2_scores += r2
87 |
88 | spearman = spearmanr(phe_predictions, phe_targets).statistic
89 | spearman_scores += spearman
90 |
91 | r2_total = r2_scores / len(unique_phe)
92 | spearman_total = spearman_scores / len(unique_phe)
93 |
94 | self.log("R2", r2_total, sync_dist=True, batch_size=predictions.shape[0])
95 | self.log("Spearman", spearman_total, sync_dist=True, batch_size=predictions.shape[0])
96 |
97 | else:
98 |
99 | r2 = r2_score(targets, predictions)
100 | self.log("R2", r2, sync_dist=True, batch_size=predictions.shape[0])
101 |
102 | spearman = spearmanr(predictions, targets).statistic
103 | self.log("Spearman", spearman, sync_dist=True, batch_size=predictions.shape[0])
104 |
105 | self.predictions = []
106 | self.targets = []
107 | self.phenotype_validation = []
108 |
109 | def on_train_epoch_end(self, trainer, pl_module):
110 | predictions = torch.cat(self.prediction_train, dim=0)#.cpu().numpy()
111 | targets = torch.cat(self.targets_train, dim=0)#.cpu().numpy()
112 |
113 | targets_mean = targets.mean(0)
114 | targets_mean = targets_mean.repeat(targets.shape[0])
115 |
116 | targets_mean = targets_mean.cpu().numpy()
117 | predictions = predictions.cpu().numpy()
118 | targets = targets.cpu().numpy()
119 |
120 | r2 = r2_score(targets, predictions)
121 | self.log("R2_train", r2, sync_dist=True, batch_size=predictions.shape[0])
122 |
123 | spearman = spearmanr(predictions, targets).statistic
124 | self.log("Spearman_train", spearman, sync_dist=True, batch_size=predictions.shape[0])
125 |
126 | self.prediction_train = []
127 | self.targets_train = []
128 |
129 | def on_test_epoch_end(self, trainer, pl_module):
130 | predictions = torch.cat(self.prediction_test, dim=0)#.cpu().numpy()
131 | targets = torch.cat(self.targets_test, dim=0)#.cpu().numpy()
132 |
133 | targets_mean = targets.mean(0)
134 | targets_mean = targets_mean.repeat(targets.shape[0])
135 |
136 | targets_mean = targets_mean.cpu().numpy()
137 | predictions = predictions.cpu().numpy()
138 | targets = targets.cpu().numpy()
139 |
140 | r2 = r2_score(targets, predictions)
141 | self.log("R2_test", r2, sync_dist=True, batch_size=predictions.shape[0])
142 |
143 | spearman = spearmanr(predictions, targets).statistic
144 | self.log("Spearman_test", spearman, sync_dist=True, batch_size=predictions.shape[0])
145 |
146 | self.prediction_test = []
147 | self.targets_test = []
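148 | 
149 | 
150 | if __name__ == "__main__":
151 |     # Minimal sketch (illustrative): drive CosineWarmupScheduler on a throwaway
152 |     # optimizer to inspect the warmup + cosine decay of the learning rate.
153 |     net = torch.nn.Linear(4, 1)
154 |     opt = optim.SGD(net.parameters(), lr=0.1)
155 |     sched = CosineWarmupScheduler(opt, warmup=100, max_iters=1000)
156 |     for _ in range(300):
157 |         opt.step()
158 |         sched.step()
159 |     print("lr after 300 steps:", sched.get_last_lr())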
--------------------------------------------------------------------------------
/prophet/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import List, Dict
3 |
4 | @dataclass
5 | class TransformerConfig:
6 | dim_cl: int = 300
7 | dim_iv: int = 800
8 | dim_phe: int = 300
9 | model_dim: int = 128
10 | num_heads: int = 1
11 | num_layers: int = 2
12 | iv_dropout: float = 0.2
13 | cl_dropout: float = 0.2
14 | ph_dropout: float = 0.2
15 | regressor_dropout: float = 0.2
16 | lr: float = 0.0001
17 | weight_decay: float = 0.01
18 | warmup: int = 10000
19 | max_iters: int = 70000
20 | dropout: float = 0.2
21 | exclude_cl_embedding: bool = False
22 | pool: str = "cls"
23 | simpler: bool = True
24 | mask: bool = True
25 | sum: bool = False
26 | explicit_phenotype: bool = False
27 | linear_predictor: bool = False
28 | tokenizer_layers: int = 2
29 |
30 | def update_from_dict(self, updates: Dict):
31 | for key, value in updates.items():
32 | if hasattr(self, key):
33 | setattr(self, key, value)
34 |
35 | @dataclass
36 | class Config:
37 | setting: str
38 | leaveout_method: str
39 | dirpath: str = './ckpts/'
40 | ckpt_path: str = None
41 | project_name: str = "Prophet_hparams"
42 | cell_lines_prior: List[str] = field(default_factory=lambda: ["./embeddings/cell_line_embedding_full_ccle_300_scaled.csv"])
43 | genes_prior: List[str] = field(default_factory=lambda: ["./embeddings/ccle_T_pca_300_enformer_full_gene_mean_PCA_500_scaled.csv"] )
44 | phenotype_prior: List[str] = None
45 | unbalanced: bool = False
46 | pert_len: int = 2
47 | max_steps: int = 140000
48 | batch_size: int = 2048
49 | early_stopping: bool = True
50 | patience: int = 20
51 |     fine_tune: bool = False
52 | 
53 | transformer: TransformerConfig = field(default_factory=TransformerConfig)
54 |
55 | def update_from_dict(self, updates: Dict):
56 | for key, value in updates.items():
57 | if key == 'Transformer' and isinstance(value, dict):
58 | self.transformer.update_from_dict(value)
59 | elif hasattr(self, key):
60 | setattr(self, key, value)
61 |
62 | def set_config(models_config):
63 | config = Config(
64 | setting=models_config['setting'],
65 | leaveout_method=models_config['leaveout_method'],
66 | )
67 | config.update_from_dict(models_config)
68 |
69 | if config.transformer.simpler:
70 | config.ctx_len = config.pert_len + 1
71 | else:
72 | config.ctx_len = config.pert_len + 3
73 | return config
74 |
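75 | 
76 | if __name__ == "__main__":
77 |     # Minimal sketch (illustrative): build a Config from a plain dict such as
78 |     # one parsed from tutorials/config_file_finetuning.yaml; values are placeholders.
79 |     cfg = set_config({
80 |         "setting": "finetune",
81 |         "leaveout_method": "random",
82 |         "Transformer": {"num_layers": 4, "lr": 1e-4},
83 |     })
84 |     print(cfg.ctx_len, cfg.transformer.num_layers)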
--------------------------------------------------------------------------------
/prophet/dataloader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from typing import List, Tuple
3 | from torch.utils.data import WeightedRandomSampler
4 | import numpy as np
5 | import pandas as pd
6 | from functools import reduce
7 | from .dataset import PhenotypeDataset
8 |
9 | SEED = 42 # the true, baseline seed (that sets test splits)
10 |
11 | def _choose(a, size, seed):
12 | """Guaranteed deterministic choosing."""
13 | np.random.seed(seed) # reset the generator
14 | return np.random.choice(a, size=size, replace=False)
15 |
16 | def dataloader_phenotypes(
17 | gene_embedding: List[pd.DataFrame],
18 | cell_lines_embedding: List[pd.DataFrame],
19 | phenotype_embedding: List[pd.DataFrame],
20 | data_label,
21 | index,
22 | batch_size = 2048,
23 | label_name=None,
24 | unbalanced: bool = False,
25 | torch_dataset: bool = True,
26 | pert_len: int = 2,
27 | valid_set: bool = True,
28 | test_set: bool = True,
29 | phenotypes: list = None,
30 | ) -> Tuple[DataLoader, DataLoader, DataLoader, np.ndarray, np.ndarray, str]:
31 | """Dataloader for multiple sources of information
32 |
33 |     Args:
34 |         gene_embedding (pd.DataFrame): index needs to be gene or gRNA
35 |         cell_lines_embedding (pd.DataFrame): index needs to be cancer cell line name
36 |         phenotype_embedding (pd.DataFrame): index needs to be phenotype name
37 |         data_label (pd.DataFrame): experimental data with columns that match indices in the gene and cell line embedding dataframes
38 |         index (tuple): (train_indices, valid_indices, test_indices, descriptor) arrays of positions in data_label
39 |         batch_size (int, optional): batch size for the dataloaders; -1 to indicate work with sklearn
40 |         label_name (str, optional): name of the column in data_label that holds the label to predict
41 |         unbalanced (bool, optional): if True, oversample under-represented phenotypes/datasets with a WeightedRandomSampler
42 |         torch_dataset (bool, optional): if False, return (X, y) numpy arrays for sklearn models instead of dataloaders
43 |         pert_len (int): number of perturbations to give to the model
44 |         phenotypes (list): pass a list of phenotypes to force indexing to occur correctly at prediction time for a pytorch model. If None, automatically determined from data_label.
45 |     Returns: (train_dataloader, valid_dataloader, test_dataloader, train_indices, test_indices, descriptor)
46 |     """
47 | train_indices, valid_indices, test_indices, cl_holdout = index
48 | if len(valid_indices) == 0:
49 | valid_set = False
50 | if len(test_indices) == 0:
51 | test_set = False
52 |
53 | # accounting for multiple datasets
54 | test_dict = None
55 |     if isinstance(test_indices, dict):
56 | test_dict = test_indices.copy()
57 | test_indices = test_indices['all']
58 |
59 | # value checks
60 | if 'type' not in gene_embedding.columns:
61 | raise ValueError("No column 'type' in gene_embedding")
62 |
63 | if not torch_dataset:
64 | # create input dataframes
65 | ge = gene_embedding.dropna()
66 | ce = cell_lines_embedding.dropna()
67 | ge = ge.drop(columns=['type'])
68 |
69 | if phenotype_embedding is not None:
70 | pe = phenotype_embedding.dropna()
71 | X_phenotype = pe.loc[data_label.phenotype].to_numpy()
72 | else:
73 | X_phenotype = pd.get_dummies(data_label['phenotype']).astype(float).values # 1he
74 | X_iv = np.concatenate(
75 | [ge.loc[data_label[f'iv{i}']].to_numpy() for i in range(1, pert_len + 1)]
76 | , axis=1)
77 | X_cellline = ce.loc[data_label.cell_line].to_numpy()
78 | if label_name is None:
79 | return [
80 | (np.concatenate([X_phenotype[idxs], X_cellline[idxs], X_iv[idxs]], axis=1),
81 | None) \
82 | for idxs in [train_indices, valid_indices, test_indices]]
83 | else:
84 | y_label = data_label[label_name].to_numpy()
85 |
86 | return [
87 | (np.concatenate([X_phenotype[idxs], X_cellline[idxs], X_iv[idxs]], axis=1),
88 | y_label[idxs]) \
89 | for idxs in [train_indices, valid_indices, test_indices]]
90 |
91 | data = data_label.copy()
92 | if phenotypes is None:
93 | phenotypes = sorted(list(data_label.phenotype.unique()))
94 |
95 | train_set = PhenotypeDataset(
96 | experimental_data = data.loc[train_indices],
97 | label_key = label_name,
98 | iv_embeddings = gene_embedding,
99 | cell_line_embeddings = cell_lines_embedding,
100 | phenotype_embeddings = phenotype_embedding if phenotype_embedding is not None else None, # if it's None, None[0] will return error
101 | phenotypes = phenotypes,
102 | pert_len=pert_len
103 | )
104 | if valid_set:
105 | valid_set = PhenotypeDataset(
106 | experimental_data = data.loc[valid_indices],
107 | label_key = label_name,
108 | iv_embeddings = gene_embedding,
109 | cell_line_embeddings = cell_lines_embedding,
110 | phenotype_embeddings = phenotype_embedding if phenotype_embedding is not None else None,
111 | phenotypes = phenotypes,
112 | pert_len=pert_len
113 | )
114 | valid_dataloader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=4)
115 | else:
116 | valid_dataloader = None
117 |
118 | if test_set:
119 | test_set = PhenotypeDataset(
120 | experimental_data = data.loc[test_indices],
121 | label_key = label_name,
122 | iv_embeddings = gene_embedding,
123 | cell_line_embeddings = cell_lines_embedding,
124 | phenotype_embeddings = phenotype_embedding if phenotype_embedding is not None else None,
125 | phenotypes = phenotypes,
126 | pert_len=pert_len
127 | )
128 | test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
129 | # convert to dict if there are multiple test sets
130 | if test_dict is not None:
131 | test_dl_dict = {'all': test_dataloader}
132 | for k, test_indices in test_dict.items():
133 |                 if k == 'all': # already loaded in test_dataloader
134 |                     continue
135 | test_set = PhenotypeDataset(
136 | experimental_data = data.loc[test_indices],
137 | label_key = label_name,
138 | iv_embeddings = gene_embedding,
139 | cell_line_embeddings = cell_lines_embedding,
140 | phenotype_embeddings = phenotype_embedding if phenotype_embedding is not None else None,
141 | phenotypes = phenotypes,
142 | pert_len=pert_len
143 | )
144 | test_dl_dict[k] = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
145 |
146 | else:
147 | test_dataloader = None
148 |
149 | if unbalanced: # unbalanced means that one phenotype is way more measured that the other
150 |
151 | ## Count the occurrences of each class
152 | key = 'phenotype'
153 | if 'dataset' in data.columns:
154 | key = 'dataset'
155 | class_counts = data.loc[train_indices][key].value_counts()
156 | num_samples = len(data)
157 | class_weights = [num_samples / class_counts.values[i] for i in range(len(class_counts))]
158 | scaling_factor = 1 / min(class_weights)
159 | class_weights = [x * scaling_factor for x in class_weights]
160 | class_weight_dict = dict(zip(class_counts.index, class_weights))
161 |
162 | # Create a custom WeightedRandomSampler to oversample the minority class
163 | weights = [class_weight_dict[c] for c in data.loc[train_indices][key].values]
164 | train_sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
165 | train_dataloader = DataLoader(train_set, batch_size=batch_size, sampler=train_sampler, num_workers=4)
166 | else:
167 | train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
168 |
169 | if test_dict:
170 | return (train_dataloader, valid_dataloader, test_dl_dict, train_indices, test_indices, cl_holdout)
171 | else:
172 | return (train_dataloader, valid_dataloader, test_dataloader, train_indices, test_indices, cl_holdout)
173 |
174 | def read_in_priors(prior_files):
175 | """Convert list of filenames to dfs. When presented with multiple files, assumes they contain different
176 | indices and adds additional rows. Assumes columns which contain the same feature are named accordingly."""
177 | prior = []
178 | if isinstance(prior_files, str):
179 | prior_files = [prior_files]
180 |     for file in prior_files:
181 |         emb = pd.read_csv(file, index_col=0)
182 | emb.index = emb.index.astype(str)
183 | prior.append(emb)
184 |
185 | return pd.concat(prior).fillna(0) if len(prior) > 0 else None
186 |
187 | def process_priors(genes_prior, cell_lines_prior, phenotype_prior):
188 | gene_prior = read_in_priors(genes_prior)
189 | try:
190 | gene_prior = gene_prior.set_index('smiles')
191 | except KeyError: # still works even if there's no smiles column, like for genetic interventions
192 | pass
193 | gene_prior.index = [str(x).lower() for x in gene_prior.index] # allow translatability across organisms and drugs
194 | cl_prior = read_in_priors(cell_lines_prior)
195 | phe_prior = None
196 | if phenotype_prior is not None:
197 | phe_prior = read_in_priors(phenotype_prior)
198 |
199 | if gene_prior is not None and "type" not in gene_prior.columns:
200 | raise ValueError("type not in iv_embedding columns")
201 |
202 | # add 0 for control
203 | gene_prior.loc['negative_gene'] = 0
204 | gene_prior.loc['negative_drug'] = 0
205 | gene_prior.loc['negative_gene', 'type'] = 'gene'
206 | gene_prior.loc['negative_drug', 'type'] = 'drug'
207 |
208 | return gene_prior, cl_prior, phe_prior
209 |
210 |
211 | def remove_nonexistent_cat(data_label, prior, columns, verbose=True):
212 | """Takes in a cell line or gene prior embedding dataframe and removes rows from
213 | data_label where the embedding doesn't exist.
214 |
215 | Parameters
216 | ----------
217 | data_label : pandas.DataFrame
218 | A dataframe with phenotype, cell context, and iv columns. This dataframe is modified by the function.
219 | prior : pandas.DataFrame
220 | A dataframe representing prior embeddings. The index of this dataframe should contain the categories to be filtered on.
221 | columns : list[str]
222 | A list of column names in `data_label` where the values are checked against the categories in `prior`.
223 |
224 | Returns
225 | -------
226 | pandas.DataFrame
227 | """
228 | if isinstance(columns, str):
229 | columns = [columns]
230 | data_label_cats = reduce(lambda x, y: np.union1d(x, y), [data_label[col].astype(str).values for col in columns])
231 | emb_cats = prior.index.to_list()
232 | strings_to_remove = list(np.setdiff1d(data_label_cats, emb_cats))
233 | for col in columns:
234 | data_label = data_label[~data_label[col].isin(strings_to_remove)]
235 | if verbose:
236 | print(f"Removing {len(strings_to_remove)} such as {strings_to_remove[:5]} from {columns}. {data_label.shape[0]} rows remaining.", flush=True)
237 | return data_label
238 |
239 | def check_iv_emb(emb):
240 | if "type" not in emb.columns:
241 | raise KeyError("Intervention embedding has no `type` in columns.")
242 |
243 | def check_data(data_label):
244 | cols = set(data_label.columns)
245 | needed = set(["phenotype", "cell_line", "iv1"])
246 | if not needed.issubset(cols):
247 | raise KeyError(f"Cols is missing {cols-needed}")
248 |
249 | def universal_processing(data_label):
250 |
251 | if 'phenotype' not in data_label.columns:
252 | data_label['phenotype'] = 'none'
253 |
254 | data_label['value'] = data_label['value'].astype('f4')
255 | data_label = data_label.reset_index(drop=True)
256 |
257 |
258 | data_label['iv1'] = [x.lower() for x in data_label.iv1.values] # allow translatability across organisms and drugs
259 | data_label['iv2'] = [x.lower() for x in data_label.iv2.values]
260 | check_valid(data_label)
261 |
262 | data_label_flipped = data_label.rename(
263 | columns={'iv1': 'iv2', 'iv2': 'iv1'})
264 | data_label = pd.concat([data_label, data_label_flipped], axis=0, ignore_index=True)
265 |
266 | return data_label
267 |
268 | def check_valid(df):
269 | if 'iv1' not in df.columns:
270 |         raise ValueError("Dataset must have at least one perturbation in a column named iv1, iv2, etc.")
271 | if 'cell_line' not in df.columns:
272 | raise ValueError("Dataset must have a cellular context in a column labeled `cell_line`.")
273 | if 'negative' in df.iv1.values:
274 | raise ValueError("Dataset still contains the negative label, please specify negative_gene or negative_drug.")
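275 | 
276 | 
277 | if __name__ == "__main__":
278 |     # Minimal sketch (illustrative): drop label rows whose intervention has no
279 |     # embedding, using remove_nonexistent_cat defined above. Names are toy placeholders.
280 |     emb = pd.DataFrame(np.random.rand(2, 3), index=["braf", "kras"])
281 |     labels = pd.DataFrame({
282 |         "iv1": ["braf", "tp53"],
283 |         "cell_line": ["cl_a", "cl_b"],
284 |         "phenotype": ["GDSC", "GDSC"],
285 |     })
286 |     print(remove_nonexistent_cat(labels, emb, "iv1"))  # keeps only the "braf" row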
--------------------------------------------------------------------------------
/prophet/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from torch.utils.data import Dataset
4 |
5 | class PhenotypeDataset(Dataset):
6 | """
7 |     Dataset that gathers multiple phenotypes. The splits are done beforehand: each of the train, test, and validation splits
8 |     is a separate PhenotypeDataset whose 'experimental_data' dataframe holds that split's data.
9 | """
10 | def __init__(
11 | self,
12 | experimental_data: pd.DataFrame,
13 | label_key: str,
14 | iv_embeddings: pd.DataFrame,
15 | cell_line_embeddings: pd.DataFrame,
16 | phenotype_embeddings: pd.DataFrame = None,
17 | phenotypes: list = None,
18 | cl_embedding: bool = False,
19 | pert_len: int = 2
20 | ):
21 | """
22 | Args:
23 | experimental_data (pd.DataFrame): experimental data, contains the label and the training data (combinations of gRNA)
24 | label_key (str): key that identifies the label in the experimental data
25 | iv_embeddings (pd.DataFrame): pandas dataframe with the embeddings of the perturbations
26 | cell_line_embeddings (pd.DataFrame): pandas dataframe with the embedding of cell lines
27 |             cl_embedding (bool): if True, use the predefined embedding; if False, retrieve just the index because the embedding will be learned
28 | phenotypes (list): if phenotype embeddings are not provided, then a sorted list of phenotypes must be provided.
29 |             pert_len (int): number of perturbations to provide to the model; the full attention mask spans pert_len + 3 tokens (CLS, perturbations, cell line, phenotype) and is truncated to the model's ctx_len downstream
30 | """
31 | # precompute the attention mask
32 | self.attn_mask = [[False]*experimental_data.shape[0]] # always pay attention to CLS, which is first token
33 | for i in range(1, pert_len + 1):
34 | col = f'iv{i}'
35 | mask_values = experimental_data[col].isin(['negative_drug', 'negative_gene']).values # mask if negative
36 | self.attn_mask.append(mask_values)
37 | self.attn_mask.append([False]*experimental_data.shape[0]) # once for cell_line
38 | self.attn_mask.append([False]*experimental_data.shape[0]) # once for phenotype
39 | self.attn_mask = np.array(self.attn_mask).T
40 |
41 | columns = ['cell_line', 'phenotype'] + [f'iv{i}' for i in range(1, pert_len + 1)]
42 | self.experimental_data = experimental_data[columns].values # ordered
43 | self.labels = experimental_data[label_key].values
44 | self.iv = iv_embeddings.iloc[:, 1:].values
45 | self.iv_embs_types = iv_embeddings.iloc[:, 0].values
46 | self.cell_line = cell_line_embeddings.values
47 | self.iv_to_index = dict(zip(iv_embeddings.index, range(iv_embeddings.shape[0])))
48 | self.cl_to_index = dict(zip(cell_line_embeddings.index, range(cell_line_embeddings.shape[0])))
49 |
50 | # special handling for phenotypes
51 | self.ph_to_index = dict(zip(phenotypes, range(len(phenotypes))))
52 | if phenotype_embeddings is not None:
53 | phenotype_embeddings = phenotype_embeddings.T[phenotypes].T # reorder the embedding so that ph_to_index matches
54 | self.phenotype_embeddings = phenotype_embeddings.values
55 | else:
56 | self.phenotype_embeddings = None
57 |
58 | self.pert_len = pert_len
59 |
60 | # print("Interventions: ", self.iv.shape)
61 | # print("Cell line: ", self.cell_line.shape)
62 | # print("Order of phenotypes: ", phenotypes)
63 | # print("Don't using explicit phenotype embeddings") if self.phenotype_embeddings is None else print(f"Explicit phenotype {self.phenotype_embeddings.shape} was passed")
64 |
65 | def __len__(self):
66 | return len(self.experimental_data)
67 |
68 | def __getitem__(self, idx):
69 | """
70 | Returns a dictionary with:
71 | phenotype: index to query
72 | cell_line: embedding
73 | label: scalar to predict
74 | names: names of the perturbations
75 | idx: index of the observation
76 | pert_type: list that says whether perturbations are genes or drugs
77 | **iv_values_dict: dictionary with keys iv1, iv2 etc. and respective embedding values. Same size as pert_len.
78 | """
79 |
80 | item = self.experimental_data[idx]
81 | cell_line = item[0]
82 | phenotype = item[1]
83 |
84 | iv_values_dict = {}
85 | iv_type = []
86 | for i in range(2, self.pert_len + 2):
87 | name = item[i] # perturbation name
88 |             emb_entry = self.iv[self.iv_to_index[name]]  # numeric embedding (the 'type' column was already stripped in __init__)
89 |             iv_type.append(self.iv_embs_types[self.iv_to_index[name]]) # 'gene' or 'drug'
90 |             iv_values_dict[f'iv{i-1}'] = emb_entry.astype('float64')
91 |
92 | # if there isn't phenotype embedding
93 | if self.phenotype_embeddings is None:
94 | context = self.ph_to_index[phenotype] # retrieve index
95 | context = context + 1 # CLS is 0
96 | else: # if there's embedding
97 | context = self.phenotype_embeddings[self.ph_to_index[phenotype]] # retrieve embedding
98 |
99 | # Gene = 0, Drug = 1
100 | iv_types = [0 if item == 'gene' else (1 if item == 'drug' else item) for item in iv_type]
101 | iv_types = np.array(iv_types)
102 |
103 |         return {'phenotype': context, # sometimes an int, sometimes an embedding
104 | 'cell_line': self.cell_line[self.cl_to_index[cell_line]],
105 | 'label': self.labels[idx],
106 | 'attn_mask': self.attn_mask[idx],
107 | 'idx': idx,
108 | 'pert_type': iv_types,
109 | **iv_values_dict
110 | }
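111 | 
112 | 
113 | if __name__ == "__main__":
114 |     # Minimal sketch (illustrative): build a toy PhenotypeDataset and fetch one
115 |     # item. Identifiers are placeholders; 'type' must be the first column of the
116 |     # intervention embedding dataframe.
117 |     iv_emb = pd.DataFrame(
118 |         {"type": ["gene", "gene"], "d0": [0.1, 0.0], "d1": [0.3, 0.0]},
119 |         index=["braf", "negative_gene"],
120 |     )
121 |     cl_emb = pd.DataFrame({"p0": [1.0], "p1": [2.0]}, index=["cl_a"])
122 |     data = pd.DataFrame({
123 |         "cell_line": ["cl_a"], "phenotype": ["GDSC"],
124 |         "iv1": ["braf"], "iv2": ["negative_gene"], "value": [0.5],
125 |     })
126 |     ds = PhenotypeDataset(data, "value", iv_emb, cl_emb, phenotypes=["GDSC"])
127 |     print(ds[0]["phenotype"], ds[0]["attn_mask"])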
--------------------------------------------------------------------------------
/prophet/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pytorch_lightning as pl
4 | from torch import optim
5 | import torch.nn.init as init
6 | from prophet.callbacks import CosineWarmupScheduler
7 | import logging
8 |
9 |
10 | class TransformerPredictor(pl.LightningModule):
11 |
12 | def __init__(self,
13 | dim_cl: int,
14 | dim_iv: int,
15 | dim_phe: int,
16 | model_dim: int,
17 | num_heads: int,
18 | num_layers: int,
19 | iv_dropout: float,
20 | cl_dropout: float,
21 | ph_dropout: float,
22 | regressor_dropout: float,
23 | lr: float,
24 | warmup: int,
25 | weight_decay: float,
26 | max_iters: int,
27 | batch_size: int = 0,
28 | dropout = 0.0,
29 | pool: str = 'cls',
30 | simpler: bool = False,
31 | ctx_len: int = 4,
32 | mask: bool = True,
33 | sum: bool = False,
34 | explicit_phenotype: bool = False,
35 | linear_predictor: bool = False,
36 | tokenizer_layers: int = 2,
37 | seed=42,
38 | phenotypes=None):
39 | """
40 | Inputs:
41 | dim_cl - Number of dimensions to take from the cell lines
42 | dim_iv - Number of dimensions of the interventional embeddings
43 | model_dim - Hidden dimensionality to use inside the Transformer
44 | num_heads - Number of heads to use in the Multi-Head Attention blocks
45 | num_layers - Number of encoder blocks to use.
46 | iv_dropout - dropout in iv layers
47 | cl_dropout - dropout in cl layers
48 | regressor_dropout - dropout in regressor
49 | lr - Learning rate in the optimizer
50 | warmup - Number of warmup steps. Usually between 50 and 300
51 | max_iters - Number of maximum iterations the model is trained for. This is needed for the CosineWarmup scheduler
52 | dropout - Dropout to apply inside the model
53 | pool - 'mean', 'cls', or 'pool'. Mean takes the mean, CLS predicts just with the CLS, pool takes the max value
54 | simpler - uses the transformer just for the set of perturbations, then it concatenates the result with the other representations
55 | ctx_len - context length
56 | mask - if True, mask attention, otherwise don't do it
57 | sum - if True, don't use Transformer, just sum embeddings
58 | explicit_phenotype - if True, the user passes an embedding as phenotype directly
59 |             linear_predictor - if True, replace the 2-layer regressor head with a single linear layer
60 |             tokenizer_layers - number of layers in the tokenizer networks (1 or 2)
61 | """
62 | super().__init__()
63 |
64 | ## Process phenotypes to save them in the checkpoint
65 | if phenotypes is not None:
66 | self.ph_to_index = dict(zip(phenotypes, range(len(phenotypes))))
67 | logging.info(f"Phenotypes: {self.ph_to_index}")
68 |
69 | self.save_hyperparameters()
70 | self._create_model()
71 |
72 | # Initialize the weights
73 | self.initialize_weights()
74 |
75 | def _create_model(self):
76 | self.learnable_embedding = torch.nn.Embedding(num_embeddings=1000,
77 | embedding_dim=self.hparams.model_dim,
78 | max_norm=0.5,
79 | )
80 | self.embedding_dropout = nn.Dropout(self.hparams.ph_dropout)
81 |
82 | # Tokenizer layer strong enough to non-linearly transform the data
83 | self.gene_net = nn.Sequential(
84 | nn.Linear(self.hparams.dim_iv, self.hparams.model_dim),
85 | nn.GELU(),
86 | nn.Dropout(self.hparams.iv_dropout),
87 | nn.Linear(self.hparams.model_dim, self.hparams.model_dim)
88 | )
89 |
90 | self.drug_net = nn.Sequential(
91 | nn.Linear(self.hparams.dim_iv, self.hparams.model_dim),
92 | nn.GELU(),
93 | nn.Dropout(self.hparams.iv_dropout),
94 | nn.Linear(self.hparams.model_dim, self.hparams.model_dim)
95 | )
96 |
97 | self.cl_net = nn.Sequential(
98 | nn.Linear(self.hparams.dim_cl, self.hparams.model_dim),
99 | nn.GELU(),
100 | nn.Dropout(self.hparams.cl_dropout),
101 | nn.Linear(self.hparams.model_dim, self.hparams.model_dim)
102 | )
103 |
104 | if self.hparams.tokenizer_layers == 1:
105 | self.gene_net = nn.Sequential(
106 | nn.Linear(self.hparams.dim_iv, self.hparams.model_dim),
107 | nn.Dropout(self.hparams.iv_dropout),
108 | )
109 | self.drug_net = nn.Sequential(
110 | nn.Linear(self.hparams.dim_iv, self.hparams.model_dim),
111 | nn.Dropout(self.hparams.iv_dropout),
112 | )
113 | self.cl_net = nn.Sequential(
114 | nn.Linear(self.hparams.dim_cl, self.hparams.model_dim),
115 | nn.Dropout(self.hparams.cl_dropout),
116 | )
117 |
118 |
119 | if self.hparams.explicit_phenotype:
120 | self.phenotype_net = nn.Sequential(
121 | nn.Linear(self.hparams.dim_phe, self.hparams.dim_phe),
122 | nn.GELU(),
123 | nn.Dropout(self.hparams.ph_dropout),
124 | nn.Linear(self.hparams.dim_phe, self.hparams.model_dim)
125 | )
126 |
127 | # Transformer
128 | layer = nn.TransformerEncoderLayer(d_model=self.hparams.model_dim,
129 | nhead=self.hparams.num_heads,
130 | dim_feedforward=2*self.hparams.model_dim,
131 | dropout=self.hparams.dropout,
132 | batch_first=True,
133 | activation="gelu")
134 |
135 | self.transformer = nn.TransformerEncoder(encoder_layer=layer,
136 | num_layers=self.hparams.num_layers)
137 |
138 | # 2 layers regressor
139 | dim_regressor_input = 2*self.hparams.model_dim + self.learnable_embedding.embedding_dim # pooled output (model_dim) | cell line (model_dim) | phenotype (embedding_dim)
140 | if not self.hparams.simpler:
141 | dim_regressor_input = self.hparams.model_dim
142 | if self.hparams.sum:
143 | dim_regressor_input = dim_regressor_input+self.hparams.model_dim+self.hparams.model_dim
144 |
145 | self.output_net = nn.Sequential(
146 | nn.Linear(dim_regressor_input, self.hparams.model_dim),
147 | nn.GELU(),
148 | nn.Dropout(self.hparams.regressor_dropout),
149 | nn.Linear(self.hparams.model_dim, self.hparams.model_dim),
150 | nn.GELU(),
151 | nn.Linear(self.hparams.model_dim, 1),
152 | )
153 | if self.hparams.linear_predictor:
154 | self.output_net = nn.Sequential(
155 | nn.Linear(dim_regressor_input, 1)
156 | )
157 |
158 | print('Gene net: ', self.gene_net, flush=True)
159 | print('Cell line net: ', self.cl_net, flush=True)
160 | print('Regressor: ', self.output_net, flush=True)
161 | if self.hparams.explicit_phenotype:
162 | print("Using explicit phenotype")
163 | if self.hparams.linear_predictor:
164 | print("Using linear predictor")
165 |
166 | def forward(self, phenotype, cl, perturbations, perturbations_type, attn_mask):
167 | """
168 | Predicts the readout from the phenotype (id or explicit embedding), cell-line
169 | embedding, list of per-intervention embeddings, intervention types, and key-padding mask.
170 | """
171 | cl = cl[:,:self.hparams.dim_cl]
172 | perturbations = [pert[:, :self.hparams.dim_iv] for pert in perturbations]
173 | attn_mask = attn_mask[:, :self.hparams.ctx_len]
174 |
175 | if self.hparams.explicit_phenotype:
176 | phenotype_emb = self.phenotype_net(phenotype[:,:self.hparams.dim_phe])
177 | else:
178 | phenotype_emb = self.learnable_embedding(phenotype) # Phenotype
179 |
180 | phenotype_emb = self.embedding_dropout(phenotype_emb) # dropout
181 | # shape is (batch_size x n_dim)
182 |
183 | # Drugs to drug network and genes to gene network
184 | # Attention to the negative (padding) perturbations is masked, so routing them through both nets is harmless
185 | drug_perturbations = [self.drug_net(tensor).unsqueeze(1) for tensor in perturbations] # all perts to drug
186 | gene_perturbations = [self.gene_net(tensor).unsqueeze(1) for tensor in perturbations] # all perts to gene
187 |
188 | drug_perturbations = torch.cat(drug_perturbations, dim=1) # bs x n x dim
189 | gene_perturbations = torch.cat(gene_perturbations, dim=1) # bs x n x dim
190 |
191 | perturbations = torch.where(perturbations_type.unsqueeze(2) == 0, gene_perturbations, drug_perturbations)
192 |
193 | cl_embedding = self.cl_net(cl).unsqueeze(1) # unsqueeze just useful if not simpler
194 | phenotype_emb = phenotype_emb.unsqueeze(1)
195 | # bs x n x dim
196 |
197 | # Regression token (CLS) stored in index 0
198 | cls = torch.zeros(size=(phenotype_emb.shape[0], 1), device=phenotype_emb.device, dtype=torch.int32)
199 | cls = self.learnable_embedding(cls) # we get the embedding, shape (bs x 1 x dim)
200 |
201 | if self.hparams.pool == 'cls':
202 | if self.hparams.simpler: # if simpler, just perturbations and CLS to transformer
203 | embeddings = torch.cat((cls, perturbations), dim=1)
204 | else: # otherwise everything goes to the transformer
205 | embeddings = torch.cat((cls, perturbations, cl_embedding, phenotype_emb), dim=1)
206 | else: # if not CLS, we don't need it
207 | if self.hparams.simpler:
208 | embeddings = perturbations
209 | else:
210 | embeddings = torch.cat((perturbations, cl_embedding, phenotype_emb), dim=1)
211 |
212 | # Run Transformer Layer
213 | if not self.hparams.sum:
214 | if self.hparams.mask:
215 | x = self.transformer(embeddings, mask=None, src_key_padding_mask=attn_mask)
216 | else:
217 | x = self.transformer(embeddings, mask=None)
218 |
219 | if self.hparams.pool == 'cls':
220 | x = x[:, 0, :] # use just the regressor token for regression
221 | elif self.hparams.pool == 'mean':
222 | x = torch.mean(x, dim=1) # mean-pool
223 | else:
224 | x = torch.max(x, dim=1).values # max-pool (torch.max with dim returns a (values, indices) tuple)
225 |
226 | # If sum, forget about everything else
227 | if self.hparams.sum:
228 | x = torch.reshape(embeddings, (embeddings.shape[0], -1))
229 |
230 | if self.hparams.simpler:
231 | x = torch.cat((x, cl_embedding.squeeze(1), phenotype_emb.squeeze(1)), dim=-1)
232 |
233 |
234 | x = self.output_net(x)
235 |
236 | return x
237 |
238 | def embedding(self, phenotype, cl, perturbations, perturbations_type, attn_mask):
239 | """
240 | Same inputs as forward(); returns a dict of intermediate embeddings
241 | alongside the final prediction.
242 | """
243 |
244 | # Cut CL and perturbations to number of selected dimensions
245 | cl = cl[:,:self.hparams.dim_cl]
246 | perturbations = [pert[:, :self.hparams.dim_iv] for pert in perturbations]
247 | attn_mask = attn_mask[:, :self.hparams.ctx_len]
248 |
249 | if self.hparams.explicit_phenotype:
250 | # if explicit_phenotype, use the embedding the user passed in (dim_phe-dimensional, as in forward)
251 | phenotype_emb = self.phenotype_net(phenotype[:,:self.hparams.dim_phe])
252 | else:
253 | phenotype_emb = self.learnable_embedding(phenotype) # Phenotype
254 |
255 | phenotype_emb = self.embedding_dropout(phenotype_emb) # dropout
256 |
257 | # Drugs to drug network and genes to gene network
258 | # Attention to the negative (padding) perturbations is masked, so routing them through both nets is harmless
259 | drug_perturbations = [self.drug_net(tensor).unsqueeze(1) for tensor in perturbations] # all perts to drug
260 | gene_perturbations = [self.gene_net(tensor).unsqueeze(1) for tensor in perturbations] # all perts to gene
261 |
262 | drug_perturbations = torch.cat(drug_perturbations, dim=1) # bs x n x dim
263 | gene_perturbations = torch.cat(gene_perturbations, dim=1) # bs x n x dim
264 |
265 | perturbations = torch.where(perturbations_type.unsqueeze(2) == 0, gene_perturbations, drug_perturbations)
266 |
267 | cl_embedding = self.cl_net(cl).unsqueeze(1) # unsqueeze just useful if not simpler
268 | phenotype_emb = phenotype_emb.unsqueeze(1)
269 | # bs x n x dim
270 |
271 | # Regression token (CLS) stored in index 0
272 | cls = torch.zeros(size=(phenotype_emb.shape[0], 1), device=phenotype_emb.device, dtype=torch.int32)
273 | cls = self.learnable_embedding(cls) # we get the embedding, shape (bs x 1 x dim)
274 |
275 | if self.hparams.pool == 'cls':
276 | if self.hparams.simpler: # if simpler, just perturbations and CLS to transformer
277 | embeddings = torch.cat((cls, perturbations), dim=1)
278 | else: # otherwise everything goes to the transformer
279 | embeddings = torch.cat((cls, perturbations, cl_embedding, phenotype_emb), dim=1)
280 | else: # if not CLS, we don't need it
281 | if self.hparams.simpler:
282 | embeddings = perturbations
283 | else:
284 | embeddings = torch.cat((perturbations, cl_embedding, phenotype_emb), dim=1)
285 |
286 | # Run Transformer Layer
287 | if not self.hparams.sum:
288 | if self.hparams.mask:
289 | x = self.transformer(embeddings, mask=None, src_key_padding_mask=attn_mask)
290 | else:
291 | x = self.transformer(embeddings, mask=None)
292 |
293 | transformer_output = x
294 |
295 | if self.hparams.pool == 'cls':
296 | pert_emb = x[:, 0, :] # use just the regressor token for regression
297 | elif self.hparams.pool == 'mean':
298 | pert_emb = torch.mean(x, dim=1) # mean-pool
299 | else:
300 | pert_emb = torch.max(x, dim=1).values # max-pool (values only)
301 | x = pert_emb # default prediction input: the pooled representation (overridden below for sum/simpler)
302 | # If sum, forget about everything else
303 | if self.hparams.sum:
304 | #x = torch.sum(embeddings, dim=1)
305 | x = torch.reshape(embeddings, (embeddings.shape[0], -1))
306 |
307 | if self.hparams.simpler:
308 | x = torch.cat((pert_emb, cl_embedding.squeeze(1), phenotype_emb.squeeze(1)), dim=-1)
309 |
310 | x = self.output_net(x)
311 |
312 | return {'pert_emb': pert_emb, # output of transformer: CLS or mean
313 | 'output': x,
314 | 'perturbations_after_transformer': transformer_output[:, 1:, :],
315 | 'perturbations': perturbations, # tokens
316 | 'cl_embedding': cl_embedding, # tokens
317 | 'phenotype': phenotype,
318 | 'pert_type': perturbations_type,
319 | }
320 |
321 | def configure_optimizers(self):
322 | optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
323 |
324 | # Apply lr scheduler per step
325 | lr_scheduler = CosineWarmupScheduler(optimizer,
326 | warmup=self.hparams.warmup,
327 | max_iters=self.hparams.max_iters)
328 | return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]
329 |
330 | def training_step(self, batch, batch_idx):
331 |
332 | # complete_masking(batch, self.hparams.ctx_len)
333 | attn_mask = batch['attn_mask']
334 | attn_mask = attn_mask[:, :self.hparams.ctx_len]
335 |
336 | phenotype = batch['phenotype']
337 | cl = batch['cell_line']
338 | y = batch['label']
339 | perturbations_type = batch['pert_type']
340 |
341 | perturbations = []
342 | for pert in range(1, self.hparams.ctx_len - (2 if not self.hparams.simpler else 0)):
343 | perturbations.append(batch[f'iv{pert}'].to(torch.float32))
344 |
345 | if self.hparams.explicit_phenotype:
346 | phenotype, cl = phenotype.to(torch.float32), cl.to(torch.float32)
347 | else:
348 | phenotype, cl = phenotype.to(torch.int32), cl.to(torch.float32)
349 |
350 | y_hat = self(phenotype, cl, perturbations, perturbations_type, attn_mask)
351 |
352 | y = y.unsqueeze(1)
353 |
354 | loss = torch.nn.functional.mse_loss(y, y_hat)
355 | self.log("train_loss", loss, prog_bar=True, sync_dist=True, batch_size=phenotype.shape[0])
356 |
357 | return {'loss': loss, 'y_pred': y_hat, 'y_true': y}
358 |
359 | def validation_step(self, batch, batch_idx):
360 |
361 | attn_mask = batch['attn_mask']
362 | attn_mask = attn_mask[:, :self.hparams.ctx_len]
363 |
364 | phenotype = batch['phenotype']
365 | cl = batch['cell_line']
366 | y = batch['label']
367 | perturbations_type = batch['pert_type']
368 |
369 | perturbations = []
370 | for pert in range(1, self.hparams.ctx_len - (2 if not self.hparams.simpler else 0)):
371 | perturbations.append(batch[f'iv{pert}'].to(torch.float32))
372 |
373 | if self.hparams.explicit_phenotype:
374 | phenotype, cl = phenotype.to(torch.float32), cl.to(torch.float32)
375 | else:
376 | phenotype, cl = phenotype.to(torch.int32), cl.to(torch.float32)
377 |
378 | y_hat = self(phenotype, cl, perturbations, perturbations_type, attn_mask)
379 |
380 | y = y.unsqueeze(1)
381 | loss = torch.nn.functional.mse_loss(y, y_hat)
382 |
383 | self.log("validation_loss", loss, sync_dist=True, batch_size=phenotype.shape[0])
384 |
385 | return {'y_pred': y_hat, 'y_true': y, 'phenotype': phenotype}
386 |
387 | def test_step(self, batch, batch_idx):
388 |
389 | attn_mask = batch['attn_mask']
390 | attn_mask = attn_mask[:, :self.hparams.ctx_len]
391 |
392 | phenotype = batch['phenotype']
393 | cl = batch['cell_line']
394 | y = batch['label']
395 | perturbations_type = batch['pert_type']
396 |
397 | perturbations = []
398 | for pert in range(1, self.hparams.ctx_len - (2 if not self.hparams.simpler else 0)):
399 | perturbations.append(batch[f'iv{pert}'].to(torch.float32))
400 |
401 | if self.hparams.explicit_phenotype:
402 | phenotype, cl = phenotype.to(torch.float32), cl.to(torch.float32)
403 | else:
404 | phenotype, cl = phenotype.to(torch.int32), cl.to(torch.float32)
405 |
406 | y_hat = self(phenotype, cl, perturbations, perturbations_type, attn_mask)
407 |
408 | y = y.unsqueeze(1)
409 | loss = torch.nn.functional.mse_loss(y, y_hat)
410 |
411 | self.log("test_loss", loss, sync_dist=True, batch_size=phenotype.shape[0])
412 |
413 | return {'y_pred': y_hat, 'y_true': y}
414 |
415 | def get_embeddings(self, batch):
416 |
417 | attn_mask = batch['attn_mask']
418 | attn_mask = attn_mask[:, :self.hparams.ctx_len]
419 |
420 | x = batch['phenotype']
421 | cl = batch['cell_line']
422 | names = batch['names']
423 | perturbations_type = batch['pert_type']
424 |
425 | perturbations = []
426 | for pert in range(1, self.hparams.ctx_len - (2 if not self.hparams.simpler else 0)):
427 | perturbations.append(batch[f'iv{pert}'].to(torch.float32))
428 |
429 | x, cl = x.to(torch.int32), cl.to(torch.float32)
430 |
431 | emb_dict = self.embedding(x, cl, perturbations, perturbations_type, attn_mask)
432 |
433 | return {'pert_emb': emb_dict['pert_emb'],
434 | 'output': emb_dict['output'],
435 | 'perturbations_after_transformer': emb_dict['perturbations_after_transformer'],
436 | 'perturbations': emb_dict['perturbations'],
437 | 'cell_line': emb_dict['cl_embedding'],
438 | 'phenotype': emb_dict['phenotype'],
439 | 'pert_type': emb_dict['pert_type'],
440 | 'names': names}
441 |
442 | def predict_step(self, batch, batch_idx):
443 |
444 | attn_mask = batch['attn_mask']
445 | attn_mask = attn_mask[:, :self.hparams.ctx_len]
446 |
447 | phenotype = batch['phenotype']
448 | cl = batch['cell_line']
449 | y = batch['label']
450 | perturbations_type = batch['pert_type']
451 |
452 | perturbations = []
453 | for pert in range(1, self.hparams.ctx_len - (2 if not self.hparams.simpler else 0)):
454 | perturbations.append(batch[f'iv{pert}'].to(torch.float32))
455 |
456 | if self.hparams.explicit_phenotype:
457 | phenotype, cl = phenotype.to(torch.float32), cl.to(torch.float32)
458 | else:
459 | phenotype, cl = phenotype.to(torch.int32), cl.to(torch.float32)
460 |
461 | y_hat = self(phenotype, cl, perturbations, perturbations_type, attn_mask)
462 |
463 | return y_hat, y
464 |
465 | def on_before_optimizer_step(self, optimizer) -> None:
466 |
467 | total_norm = 0
468 | for p in self.parameters():
469 | param_norm = p.grad.data.norm(2) if p.grad is not None else torch.tensor(0.0) # guard params without grads
470 | total_norm += param_norm.item() ** 2
471 | total_norm = total_norm ** 0.5
472 | self.log('grad_norm', total_norm, sync_dist=True)
473 |
474 | return super().on_before_optimizer_step(optimizer)
475 |
476 | def initialize_weights(self):
477 | for m in self.modules():
478 | if isinstance(m, nn.Linear):
479 | # You can choose a different initialization method
480 | init.xavier_normal_(m.weight)
481 | if m.bias is not None: init.zeros_(m.bias)
482 |
483 |
484 | def load_models_config(models_config, seed, hparams=False, trial=None, phenotypes=None):
485 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
486 |
487 | if hparams:
488 | if trial is None:
489 | raise ValueError("Must pass trial when hparams is True.")
490 |
491 | transformer = TransformerPredictor(dim_cl=models_config.transformer.dim_cl, dim_iv=models_config.transformer.dim_iv, dim_phe=models_config.transformer.dim_phe,
492 | model_dim=models_config.transformer.model_dim, num_heads=models_config.transformer.num_heads,
493 | num_layers=models_config.transformer.num_layers, iv_dropout=models_config.transformer.iv_dropout,
494 | cl_dropout=models_config.transformer.cl_dropout, ph_dropout=models_config.transformer.ph_dropout, regressor_dropout=models_config.transformer.regressor_dropout,
495 | lr=models_config.transformer.lr, weight_decay=models_config.transformer.weight_decay, warmup=models_config.transformer.warmup, batch_size=models_config.batch_size,
496 | max_iters=models_config.transformer.max_iters, dropout=models_config.transformer.dropout,
497 | pool=models_config.transformer.pool, simpler=models_config.transformer.simpler,
498 | ctx_len=models_config.ctx_len, mask=models_config.transformer.mask, sum=models_config.transformer.sum,
499 | explicit_phenotype=models_config.transformer.explicit_phenotype, linear_predictor=models_config.transformer.linear_predictor, tokenizer_layers=models_config.transformer.tokenizer_layers,
500 | seed=seed, phenotypes=phenotypes) # pass phenotypes through so they are saved in the checkpoint
501 |
502 | return transformer, models_config
503 |
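504 | # --- Usage sketch (illustrative only, not part of the training pipeline) ---
505 | # A minimal forward pass with dummy tensors, assuming toy dimensions; shapes
506 | # follow the tokenizer nets above (dim_cl per cell line, dim_iv per
507 | # intervention), and ctx_len counts the tokens: CLS + interventions + cell
508 | # line + phenotype.
509 | #
510 | #   model = TransformerPredictor(dim_cl=300, dim_iv=1219, dim_phe=512,
511 | #                                model_dim=64, num_heads=4, num_layers=2,
512 | #                                iv_dropout=0.1, cl_dropout=0.1, ph_dropout=0.1,
513 | #                                regressor_dropout=0.1, lr=1e-4, weight_decay=0.01,
514 | #                                warmup=100, max_iters=1000, ctx_len=5)
515 | #   phenotype = torch.zeros(2, dtype=torch.int32)          # phenotype ids
516 | #   cl = torch.randn(2, 300)                               # cell-line embeddings
517 | #   perts = [torch.randn(2, 1219), torch.randn(2, 1219)]   # two interventions
518 | #   pert_type = torch.zeros(2, 2, dtype=torch.int32)       # 0 = gene, 1 = drug
519 | #   attn_mask = torch.zeros(2, 5, dtype=torch.bool)        # False = attend
520 | #   y_hat = model(phenotype, cl, perts, pert_type, attn_mask)  # shape (2, 1)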
--------------------------------------------------------------------------------
/prophet/train.py:
--------------------------------------------------------------------------------
1 | from model import TransformerPredictor
2 | from prophet.callbacks import R2ScoreCallback
3 | import pytorch_lightning as pl
4 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
5 | from pytorch_lightning.loggers import WandbLogger
6 | from torchmetrics.regression import R2Score
7 | import torch
8 | import wandb
9 | import os
10 |
11 |
12 | def train_transformer(data, model, config, name, seed):
13 | train_dataloader, valid_dataloader, test_dataloader, train_indices, test_indices, descriptor = data
14 |
15 | wandb_config = {
16 | 'model': 'Transformer',
17 | 'descr':descriptor,
18 | 'seed': seed,
19 | 'leaveout_method':config.leaveout_method,
20 | 'setting':config.setting,
21 | 'path': config.path,
22 | 'pooling': config.transformer.pool,
23 | 'unbalanced':config.unbalanced,
24 | 'n_heads':config.transformer.num_heads,
25 | 'n_layers':config.transformer.num_layers,
26 | 'gene_prior': os.path.basename(config.genes_prior[0]),
27 | 'cell_lines_prior': os.path.basename(config.cell_lines_prior[0]),
28 | 'batch_size': config.batch_size,
29 | 'early_stopping': config.early_stopping,
30 | 'patience': config.patience,
31 | 'ckpt_path': config.ckpt_path,
32 | 'fine_tune': config.fine_tune,
33 | 'max_steps': config.max_steps,
34 | }
35 | sub_descr = {f'descr{i}':x for i, x in enumerate(descriptor.split('_'))}
36 |
37 | wandb_logger = WandbLogger(
38 | project=config.project_name,
39 | name=f'{descriptor}_{config.setting}_nheads_{config.transformer.num_heads}_nlayers_{config.transformer.num_layers}_{config.transformer.simpler}simpler_{config.transformer.mask}mask_{config.transformer.lr}lr_{config.transformer.warmup}warmup_{config.transformer.max_iters}max_iters',
40 | config={**wandb_config, **sub_descr},
41 | )
42 |
43 | lr_monitor = LearningRateMonitor(logging_interval='step')
44 | dirpath = f"./pretrained_prophet/{config.setting}/{descriptor}_{config.leaveout_method[6:]}_{config.transformer.dim_cl}cl_{config.transformer.dim_iv}iv_{config.transformer.model_dim}model_{config.transformer.num_layers}layers_{config.transformer.simpler}simpler_{config.transformer.mask}mask_{config.transformer.lr}lr_{config.transformer.explicit_phenotype}explicitphenotype_{config.transformer.warmup}warmup_{config.transformer.max_iters}max_iters_{config.unbalanced}unbalanced_{config.transformer.weight_decay}wd_{config.batch_size}bs_{config.fine_tune}ft/{name}_seed_{seed}"
45 | model_checkpointer = ModelCheckpoint(dirpath=dirpath, save_top_k=1, every_n_epochs=1, monitor='R2', mode='max')
46 | r2_callback = R2ScoreCallback(device=model.device, average=True if config.setting == 'everything' else False)
47 | early_stopping = EarlyStopping(monitor="R2", mode="max", patience=config.patience, min_delta=0.0)
48 |
49 | callbacks = [r2_callback, model_checkpointer, lr_monitor, early_stopping]
50 | if not config.early_stopping:
51 | callbacks = [r2_callback, model_checkpointer, lr_monitor]
52 |
53 | test_mode = False # manual toggle for debugging
54 |
55 | trainer = pl.Trainer(
56 | min_epochs=1,
57 | max_steps=config.max_steps,
58 | accelerator='gpu',
59 | devices=-1,
60 | check_val_every_n_epoch=1,
61 | callbacks=callbacks,
62 | logger=wandb_logger,
63 | strategy="ddp",
64 | gradient_clip_val=1,
65 | deterministic=True)
66 |
67 | if config.ckpt_path is not None:
68 | print("name: ", name)
69 | axis_name = 'gene' if 'gene' in config.leaveout_method else 'cl'
70 | if axis_name == 'gene':
71 | axis = 'iv'
72 | else:
73 | axis = 'cl'
74 | split_num = descriptor.split('_')[1]
75 | ckpt_path = f"{config.ckpt_path}/{axis_name}_{split_num}_seed_{seed}/"
76 | if config.fine_tune:
77 | ckpt_path = f"{config.ckpt_path.replace('iv_0_iv', f'{axis}_{split_num}_{axis}')}/{axis_name}_{split_num}_seed_{seed}/"
78 | files = os.listdir(ckpt_path) # ckpt_path is the path to a folder
79 | full_paths = [os.path.join(ckpt_path, file) for file in files]
80 | ckpt_file = max(full_paths, key=os.path.getmtime) # ckpt_file is the .ckpt file
81 | print(f"Resume training from {ckpt_file}")
82 | if config.fine_tune:
83 | print("Fine tuning")
84 | model = TransformerPredictor.load_from_checkpoint(ckpt_file, warmup=config.transformer.warmup)
85 | trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader) # fine-tune: pretrained weights, fresh optimizer and scheduler state
86 | else:
87 | trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader, ckpt_path=config.ckpt_path)
88 | else:
89 | trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader, ckpt_path=config.ckpt_path)
90 |
91 | if not test_mode:
92 |
93 | # Get most recent checkpoint
94 | files = os.listdir(dirpath)
95 | full_paths = [os.path.join(dirpath, file) for file in files]
96 | ckpt_file = max(full_paths, key=os.path.getmtime)
97 |
98 | print(f"Testing model from {ckpt_file}")
99 | # Load best model in terms of R2
100 | model = TransformerPredictor.load_from_checkpoint(checkpoint_path=ckpt_file)
101 | print(type(test_dataloader))
102 | if not isinstance(test_dataloader, dict):
103 | trainer.test(model, test_dataloader)
104 | else:
105 | for id_dataset, dataset in test_dataloader.items():
106 | if id_dataset == 'all': # logs as R2_test as normal
107 | trainer.test(model, dataset)
108 | continue
109 |
114 | predictions_and_targets = trainer.predict(model, dataset)
115 |
116 | predictions = [t[0] for t in predictions_and_targets]
117 | targets = [t[1] for t in predictions_and_targets]
118 |
119 | predictions = torch.cat(predictions, dim=0)
120 | targets = torch.cat(targets, dim=0)
121 |
122 | r2score = R2Score().to(model.device)
123 | wandb.log({f"R2_test_{id_dataset}": r2score(predictions, targets.unsqueeze(-1))})
124 |
125 |
126 | wandb.finish()
127 |
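128 | # Illustrative call, assuming dataloaders built via dataloader_phenotypes
129 | # (a sketch of the expected `data` tuple, not a definitive API):
130 | #
131 | #   data = (train_dl, valid_dl, test_dl, train_idx, test_idx, "iv_0_iv_out")
132 | #   train_transformer(data=data, model=model, config=config, name="iv_0", seed=2024)
133 | #
134 | # `test_dl` may be a single DataLoader or a dict keyed by dataset id: the
135 | # 'all' entry is evaluated with trainer.test (logged as the usual test R2),
136 | # while every other entry is logged separately as R2_test_<id>.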
--------------------------------------------------------------------------------
/prophet/train_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import pprint
4 |
5 | from dataloader import dataloader_phenotypes, get_split_indices, get_data_by_setting
6 | from train import train_transformer
7 | from model import load_models_config
8 | import pytorch_lightning as pl
9 | from prophet.config import set_config
10 | import os
11 | import yaml
12 |
13 | # Add arguments
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--setting", type=str,
16 | default="Rad", required=False)
17 | parser.add_argument("--leaveout_method", type=str,
18 | default="leave_one_cl_out", required=False)
19 | parser.add_argument("--config_file", type=str, required=True) # config file with info regarding the architecture
20 | parser.add_argument("--fine_tune", action='store_true', default=False, required=False)
21 |
22 | args = parser.parse_args()
23 | config_file = args.config_file.split('-', 1)[1] # --config_file is passed as "<seed>-<config path>"
24 | seed = int(args.config_file.split('-', 1)[0])
25 |
26 | def get_global_rank():
27 | return int(os.getenv('SLURM_PROCID', '0'))
28 |
29 | if __name__ == "__main__":
30 |
31 | # Check the number of available GPUs
32 | num_gpus = torch.cuda.device_count()
33 | print("Number of gpus: ", num_gpus)
34 |
35 | global_rank = get_global_rank()
36 | print(f"Global Rank: {global_rank}")
37 |
38 | with open(config_file, 'r') as f:
39 | models_config = yaml.safe_load(f)
40 |
41 | models_config = set_config(models_config)
42 |
43 | # override config to the default fine tuning model parameters during fine-tuning
44 | if args.fine_tune:
45 | with open('config_files/config_file_finetuning.yaml', 'r') as f:
46 | ft_config = set_config(yaml.safe_load(f))
47 | ft_config.setting = models_config.setting
48 | ft_config.leaveout_method = models_config.leaveout_method
49 | ft_config.dirpath = models_config.dirpath
50 | models_config = ft_config
51 |
52 | pl.seed_everything(seed, workers=True)
53 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
54 | os.environ['TORCH_USE_CUDA_DSA'] = '1'
55 |
56 | data_label, gene_prior, cl_prior, phe_prior, path = get_data_by_setting(models_config.setting, models_config.genes_prior, models_config.cell_lines_prior, models_config.phenotype_prior)
57 |
58 | models_config.path = path # to print the path
59 |
60 | pprint.pprint(models_config)
61 |
62 | # Get the train/validation/test split indices for each leave-out fold
63 | # (get_split_indices yields one tuple of indices per fold)
64 | indices = get_split_indices(data_label, models_config.leaveout_method, seed)
65 |
66 | for index in indices:
67 |
68 | data = dataloader_phenotypes(
69 | gene_embedding = gene_prior,
70 | cell_lines_embedding = cl_prior,
71 | phenotype_embedding = phe_prior,
72 | data_label = data_label,
73 | label_name = "value",
74 | index = index,
75 | batch_size = models_config.batch_size,
76 | unbalanced = models_config.unbalanced,
77 | pert_len = models_config.pert_len,
78 | )
79 |
80 | models_config.ohe_dim = 0
81 |
82 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
83 | model, models_config = load_models_config(models_config, seed)
84 |
85 | train_transformer(data=data, model=model, config=models_config, name=index[-1], seed=seed)
86 | del data
87 | del model
88 |
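89 | # Example invocation; --config_file is parsed above as "<seed>-<config path>",
90 | # so a hypothetical run might look like:
91 | #
92 | #   python train_model.py --config_file 2024-config_files/my_config.yaml \
93 | #       --setting Rad --leaveout_method leave_one_cl_out
94 | #
95 | # Add --fine_tune to override the config with config_file_finetuning.yaml.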
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='prophet',
5 | version='0.1.0',
6 | url="https://github.com/theislab/prophet",
7 | license='CC-BY-NC 4.0',
8 | description='Scalable and universal prediction of cellular phenotypes',
9 | author='Alejandro Tejada-Lapuerta, Yuge Ji',
10 | author_email='alejandro.tejada@helmholtz-munich.de, yuge.ji@helmholtz-munich.de',
11 | packages=find_packages(),
12 | python_requires='>=3.10',
13 | install_requires=[
14 | 'joblib==1.4.2',
15 | 'numpy==2.2.3',
16 | 'pandas==2.2.3',
17 | 'pytorch_lightning==2.1.0',
18 | 'PyYAML==6.0.2',
19 | 'scikit_learn==1.5.1',
20 | 'scipy==1.14.0',
21 | 'torch==2.3.0',
22 | 'torchmetrics==1.4.0.post0',
23 | 'tqdm==4.66.4',
24 | 'wandb==0.17.6',
25 | 'jupyterlab==4.1.5',
26 | 'pytest==8.3.5'
27 | ],
28 | )
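29 |
30 | # For development, an editable install from the repo root is the usual route:
31 | #   pip install -e .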
--------------------------------------------------------------------------------
/test/test_dataloader.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import pandas as pd
3 | import numpy as np
4 | from torch.utils.data import DataLoader
5 |
6 | from prophet.dataloader import (
7 | dataloader_phenotypes,
8 | process_priors,
9 | remove_nonexistent_cat,
10 | universal_processing,
11 | )
12 |
13 | # Mock data
14 | @pytest.fixture
15 | def mock_data():
16 | iv_embedding = pd.DataFrame({
17 | 'type': ['gene', 'gene', 'drug', 'gene', 'drug'],
18 | 'feature1': [0.1, 0.2, 0.3, 0.0, 0.0],
19 | 'feature2': [0.4, 0.5, 0.6, 0.0, 0.0]
20 | }, index=['gene1', 'gene2', 'drug1', 'negative_gene', 'negative_drug']) # note: no embedding exists for drug2
21 |
22 |
23 | cell_lines_embedding = pd.DataFrame({
24 | 'feature1': [0.7, 0.8],
25 | 'feature2': [0.9, 1.0]
26 | }, index=['cell_line_1', 'cell_line_2'])
27 |
28 | phenotype_embedding = None # usually None
29 |
30 | data_label = pd.DataFrame({
31 | 'phenotype': ['phenotype1', 'phenotype2','phenotype1', 'phenotype2'],
32 | 'cell_line': ['cell_line_1', 'cell_line_2', 'cell_line_1', 'cell_line_2'],
33 | 'iv1': ['gene1', 'gene2', 'drug1', 'drug2'],
34 | 'iv2': ['negative_gene', 'negative_gene', 'negative_drug', 'negative_drug'],
35 | 'value': [0.5, 0.6, 0.9, 0.3]
36 | })
37 |
38 | index = (
39 | np.array([0,2]), # train_indices
40 | np.array([1]), # validation_indices
41 | [], # test_indices
42 | None # cl_holdout
43 | )
44 |
45 | return iv_embedding, cell_lines_embedding, phenotype_embedding, data_label, index
46 |
47 | def test_dataloader_phenotypes(mock_data):
48 |
49 | iv_embedding, cell_lines_embedding, phenotype_embedding, data_label, index = mock_data
50 |
51 | # Filter out rows whose iv1 or iv2 values don't have embeddings
52 | valid_iv1 = data_label['iv1'].isin(iv_embedding.index)
53 | valid_iv2 = data_label['iv2'].isin(iv_embedding.index)
54 | data_label = data_label[valid_iv1 & valid_iv2]
55 |
56 | train_dataloader, valid_dataloader, test_dataloader, train_indices, test_indices, cl_holdout = dataloader_phenotypes(
57 | gene_embedding=iv_embedding,
58 | cell_lines_embedding=cell_lines_embedding,
59 | phenotype_embedding=phenotype_embedding,
60 | data_label=data_label,
61 | index=index,
62 | batch_size=2,
63 | label_name='value',
64 | unbalanced=False,
65 | torch_dataset=True,
66 | valid_set=True
67 | )
68 |
69 | assert isinstance(train_dataloader, DataLoader)
70 | assert len(train_dataloader.dataset) == 2 # indices: 0,2
71 |
72 | assert isinstance(valid_dataloader, DataLoader)
73 | assert len(valid_dataloader.dataset) == 1 # indices: 1
74 |
75 | assert test_dataloader is None # no test samples given
76 |
77 | def test_process_priors(mock_data):
78 | iv_embedding, cell_lines_embedding, phenotype_embedding, _, _ = mock_data
79 |
80 | # setup temp paths
81 | iv_path = "./iv_embedding.csv"
82 | cl_path = "./cell_lines_embedding.csv"
83 | ph_path = "./phenotype_embedding.csv"
84 |
85 | iv_embedding.to_csv(iv_path)
86 | cell_lines_embedding.to_csv(cl_path)
87 | if phenotype_embedding is not None: # create a dummy file so process_priors can be tested with a phenotype prior
88 | pd.DataFrame().to_csv(ph_path)
89 | else:
90 | ph_path = None
91 |
92 | iv_prior, cl_prior, phe_prior = process_priors(
93 | genes_prior=[str(iv_path)],
94 | cell_lines_prior=[str(cl_path)],
95 | phenotype_prior=[str(ph_path)] if ph_path else None
96 | )
97 |
98 | assert 'negative_gene' in iv_prior.index
99 | assert 'negative_drug' in iv_prior.index
100 | assert iv_prior.loc['negative_gene', 'type'] == 'gene'
101 |
102 | assert cl_prior.shape == cell_lines_embedding.shape
103 |
104 | assert phe_prior is None
105 |
106 | def test_remove_nonexistent_cat(mock_data):
107 | iv_embedding, cell_lines_embedding, phenotype_embedding, data_label, _ = mock_data
108 | embeddings = [iv_embedding, cell_lines_embedding, phenotype_embedding]
109 |
110 | iv_cols = ['iv1', 'iv2']
111 | cl_col = "cell_line"
112 | ph_col = "phenotype"
113 | cols = [iv_cols, cl_col, ph_col]
114 |
115 | prior = pd.DataFrame(index=['gene1', 'gene2', 'drug1'])
116 | print("embeddings:", embeddings)
117 |
118 | for i, embedding in enumerate(embeddings):
119 | if embedding is None: # phenotype embedding can be None
120 | continue
121 | data_label = remove_nonexistent_cat(data_label, embedding, cols[i], verbose=True)
122 |
123 | assert len(data_label) == 3 # 4 - 1 = 3: the row containing "drug2" is removed
124 | # because no embedding exists for it in the mock iv_embedding
125 | assert 'gene1' in data_label['iv1'].values
126 | assert 'gene2' in data_label['iv1'].values
127 | assert 'drug1' in data_label['iv1'].values
128 | assert 'drug2' not in data_label['iv1'].values # removed
129 |
130 | def test_universal_processing(mock_data):
131 | _, _, _, data_label, _ = mock_data
132 |
133 | processed_data = universal_processing(data_label)
134 |
135 | assert 'phenotype' in processed_data.columns
136 | assert 'value' in processed_data.columns
137 | assert 'iv1' in processed_data.columns
138 | assert 'iv2' in processed_data.columns
139 | assert len(processed_data) == 8 # original (4) + flipped rows (4)
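140 |
141 | # Sketch of the symmetry assumed above (a hypothetical follow-up check; the
142 | # ordering of the appended rows is an assumption, not asserted by these tests):
143 | #   processed = universal_processing(data_label)
144 | #   flipped = processed.iloc[len(data_label):]
145 | #   assert set(zip(flipped['iv1'], flipped['iv2'])) == \
146 | #          set(zip(data_label['iv2'], data_label['iv1']))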
--------------------------------------------------------------------------------
/tutorials/config_file_finetuning.yaml:
--------------------------------------------------------------------------------
1 | setting: NULL
2 | leaveout_method: NULL
3 |
4 | project_name: FT_Prophet
5 | pert_len: 2
6 | ckpt_path: /lustre/groups/ml01/projects/super_rad_project/pretrained_prophet/everything/iv_0_iv_out_multitest_300cl_1219iv_512model_8layers_Falsesimpler_Truemask_0.0001lr_Falseexplicitphenotype_10000warmup_150001max_iters_Falseunbalanced_0.01wd_4096bs_Falseft/iv_0_seed_2024/epoch=29.ckpt
7 | fine_tune: True
8 |
9 | cell_lines_prior: ['../embs/cell_line_embedding_full_ccle_300_scaled.csv']
10 | genes_prior: [ '../embs/global_iv_scaledv3.csv',
11 | '../embs/CTRP_with_smiles_simscaled.csv',
12 | '../embs/Hadian_plates_NEW_simscaled.csv']
13 | max_steps: 1000
14 | batch_size: 256
15 |
16 | Transformer:
17 | simpler: False
18 | num_layers: 8
19 | num_heads: 8
20 | model_dim: 512
21 | max_iters: 5000
22 | dim_iv: 1219
23 | warmup: 5000
24 | iv_dropout: 0.1
25 | cl_dropout: 0.1
26 | ph_dropout: 0.1
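27 |
28 | # This file is consumed via `set_config(yaml.safe_load(f))`, as in
29 | # tutorials/finetuning.ipynb; `setting` and `leaveout_method` stay NULL here
30 | # and are copied from the pretraining config when --fine_tune is set
31 | # (see prophet/train_model.py).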
--------------------------------------------------------------------------------
/tutorials/finetuning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "f95e67be",
6 | "metadata": {},
7 | "source": [
8 | "# fine-tuning using your in-house assay\n",
9 | "\n",
10 | "This tutorial is for users who have already run a functional assay and would like to answer the question \"what if I had run my assay with other drugs, other CRISPR treatments, or in other cell lines?\"\n",
11 | "\n",
12 | "For example, [Tieu et al (2024)](https://doi.org/10.1016/j.cell.2024.01.035) run a combinatorial CAR-T transduction with 24 guides for a total of 576 pairwise combinations. We can fine-tune Prophet on this dataset and make predictions for genes spanning the entire genome and additional combinations therein."
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 1,
18 | "id": "66127741",
19 | "metadata": {
20 | "scrolled": false
21 | },
22 | "outputs": [],
23 | "source": [
24 | "import pandas as pd\n",
25 | "import yaml\n",
26 | "from prophet import Prophet\n",
27 | "from prophet.config import set_config\n",
28 | "from prophet.dataloader import universal_processing"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "id": "164fa87f",
34 | "metadata": {},
35 | "source": [
36 | "We load in a config file to get the file paths for the embeddings which should be used. These can be downloaded from the links on Github."
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 2,
42 | "id": "8823f063",
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "with open('config_file_finetuning.yaml', 'r') as f:\n",
47 | " config = set_config(yaml.safe_load(f))"
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "id": "6d1aec6c",
53 | "metadata": {},
54 | "source": [
55 | "We've randomly select one of the pretrained model checkpoints (noted in the config file above) to train on here. For more robust results, we recommend training with several checkpoints and taking the ensemble prediction."
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 3,
61 | "id": "004e340c",
62 | "metadata": {},
63 | "outputs": [
64 | {
65 | "data": {
66 | "text/plain": [
67 | "'/lustre/groups/ml01/projects/super_rad_project/pretrained_prophet/everything/iv_0_iv_out_multitest_300cl_1219iv_512model_8layers_Falsesimpler_Truemask_0.0001lr_Falseexplicitphenotype_10000warmup_150001max_iters_Falseunbalanced_0.01wd_4096bs_Falseft/iv_0_seed_2024/epoch=29.ckpt'"
68 | ]
69 | },
70 | "execution_count": 3,
71 | "metadata": {},
72 | "output_type": "execute_result"
73 | }
74 | ],
75 | "source": [
76 | "config.ckpt_path"
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "id": "54fd2dca",
82 | "metadata": {},
83 | "source": [
84 | "Load in the pretrained model."
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 4,
90 | "id": "740983bb",
91 | "metadata": {},
92 | "outputs": [
93 | {
94 | "name": "stdout",
95 | "output_type": "stream",
96 | "text": [
97 | "returning trained model!\n",
98 | "Gene net: Sequential(\n",
99 | " (0): Linear(in_features=1219, out_features=512, bias=True)\n",
100 | " (1): GELU(approximate='none')\n",
101 | " (2): Dropout(p=0.2, inplace=False)\n",
102 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
103 | ")\n",
104 | "Cell line net: Sequential(\n",
105 | " (0): Linear(in_features=300, out_features=512, bias=True)\n",
106 | " (1): GELU(approximate='none')\n",
107 | " (2): Dropout(p=0.2, inplace=False)\n",
108 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
109 | ")\n",
110 | "Regressor: Sequential(\n",
111 | " (0): Linear(in_features=512, out_features=512, bias=True)\n",
112 | " (1): GELU(approximate='none')\n",
113 | " (2): Dropout(p=0.2, inplace=False)\n",
114 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
115 | " (4): GELU(approximate='none')\n",
116 | " (5): Linear(in_features=512, out_features=1, bias=True)\n",
117 | ")\n"
118 | ]
119 | }
120 | ],
121 | "source": [
122 | "model = Prophet(\n",
123 | " iv_emb_path=config.genes_prior,\n",
124 | " cl_emb_path=config.cell_lines_prior,\n",
125 | " ph_emb_path=None,\n",
126 | " model_pth=config.ckpt_path,\n",
127 | ")"
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "id": "c4cdd598",
133 | "metadata": {},
134 | "source": [
135 | "Here we've provided two examples for finetuning to demonstrate the variety of datasets on which you can finetune Prophet. We recommend that all datasets have values minmaxed to between 0 and 1 before finetuning."
136 | ]
137 | },
138 | {
139 | "cell_type": "markdown",
140 | "id": "05e05905",
141 | "metadata": {},
142 | "source": [
143 | "### GDSC2\n",
144 | "\n",
145 | "GDSC2 (https://www.cancerrxgene.org/) is a cancer-screening dataset with titrated IC50s.\n",
146 | "\n",
147 | " - cell states: cancer cell lines\n",
148 | " - interventions: small molecule singletons\n",
149 | " - readout: IC50 of viability as measured using CellTitreGlo at 72hrs, fitted over multiple concentrations"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": 5,
155 | "id": "b45cad0d",
156 | "metadata": {},
157 | "outputs": [
158 | {
159 | "data": {
160 | "text/html": [
161 | "
\n",
162 | "\n",
175 | "
\n",
176 | " \n",
177 | " \n",
178 | " | \n",
179 | " cell_line | \n",
180 | " iv1 | \n",
181 | " value | \n",
182 | " iv_name | \n",
183 | " phenotype | \n",
184 | " iv2 | \n",
185 | "
\n",
186 | " \n",
187 | " \n",
188 | " \n",
189 | " 0 | \n",
190 | " PFSK1 | \n",
191 | " cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o | \n",
192 | " 0.323078 | \n",
193 | " Camptothecin | \n",
194 | " GDSC | \n",
195 | " negative_drug | \n",
196 | "
\n",
197 | " \n",
198 | " 1 | \n",
199 | " A673 | \n",
200 | " cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o | \n",
201 | " 0.172422 | \n",
202 | " Camptothecin | \n",
203 | " GDSC | \n",
204 | " negative_drug | \n",
205 | "
\n",
206 | " \n",
207 | " 2 | \n",
208 | " ES5 | \n",
209 | " cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o | \n",
210 | " 0.239133 | \n",
211 | " Camptothecin | \n",
212 | " GDSC | \n",
213 | " negative_drug | \n",
214 | "
\n",
215 | " \n",
216 | " 3 | \n",
217 | " ES7 | \n",
218 | " cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o | \n",
219 | " 0.164659 | \n",
220 | " Camptothecin | \n",
221 | " GDSC | \n",
222 | " negative_drug | \n",
223 | "
\n",
224 | " \n",
225 | " 4 | \n",
226 | " EW11 | \n",
227 | " cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o | \n",
228 | " 0.222290 | \n",
229 | " Camptothecin | \n",
230 | " GDSC | \n",
231 | " negative_drug | \n",
232 | "
\n",
233 | " \n",
234 | "
\n",
235 | "
"
236 | ],
237 | "text/plain": [
238 | " cell_line iv1 value \\\n",
239 | "0 PFSK1 cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o 0.323078 \n",
240 | "1 A673 cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o 0.172422 \n",
241 | "2 ES5 cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o 0.239133 \n",
242 | "3 ES7 cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o 0.164659 \n",
243 | "4 EW11 cc[c@@]1(o)c(=o)occ2c1cc1-c3nc4ccccc4cc3cn1c2=o 0.222290 \n",
244 | "\n",
245 | " iv_name phenotype iv2 \n",
246 | "0 Camptothecin GDSC negative_drug \n",
247 | "1 Camptothecin GDSC negative_drug \n",
248 | "2 Camptothecin GDSC negative_drug \n",
249 | "3 Camptothecin GDSC negative_drug \n",
250 | "4 Camptothecin GDSC negative_drug "
251 | ]
252 | },
253 | "execution_count": 5,
254 | "metadata": {},
255 | "output_type": "execute_result"
256 | }
257 | ],
258 | "source": [
259 | "gdsc_data_path = \"/lustre/groups/ml01/projects/super_rad_project/data/GDSC_notscaled_minmax.csv\"\n",
260 | "data_label = pd.read_csv(gdsc_data_path, index_col=0)\n",
261 | "\n",
262 | "data_label['iv2'] = 'negative_drug' # there is no second compound so we specify negative_drug so the token is masked\n",
263 | "data_label['phenotype'] = 'GDSC' # this is the label to use for this readout when you run inference later\n",
264 | "\n",
265 | "data_label = universal_processing(data_label)\n",
266 | "data_label.head()"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "execution_count": 6,
272 | "id": "bbf88222",
273 | "metadata": {
274 | "scrolled": true
275 | },
276 | "outputs": [
277 | {
278 | "name": "stdout",
279 | "output_type": "stream",
280 | "text": [
281 | "Removing 61 such as ['123138', '123829', '150412', '50869', '615590'] from ['iv1', 'iv2']. 393646 rows remaining.\n",
282 | "Removing 285 such as ['451LU', '7860', 'ALLPO', 'ARH77', 'ATN1'] from ['cell_line']. 280202 rows remaining.\n",
283 | "Fitting model.\n",
284 | "pytorch model, finetuning\n",
285 | "Gene net: Sequential(\n",
286 | " (0): Linear(in_features=1219, out_features=512, bias=True)\n",
287 | " (1): GELU(approximate='none')\n",
288 | " (2): Dropout(p=0.1, inplace=False)\n",
289 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
290 | ")\n",
291 | "Cell line net: Sequential(\n",
292 | " (0): Linear(in_features=300, out_features=512, bias=True)\n",
293 | " (1): GELU(approximate='none')\n",
294 | " (2): Dropout(p=0.1, inplace=False)\n",
295 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
296 | ")\n",
297 | "Regressor: Sequential(\n",
298 | " (0): Linear(in_features=512, out_features=512, bias=True)\n",
299 | " (1): GELU(approximate='none')\n",
300 | " (2): Dropout(p=0.2, inplace=False)\n",
301 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
302 | " (4): GELU(approximate='none')\n",
303 | " (5): Linear(in_features=512, out_features=1, bias=True)\n",
304 | ")\n"
305 | ]
306 | },
307 | {
308 | "name": "stderr",
309 | "output_type": "stream",
310 | "text": [
311 | "GPU available: True (cuda), used: True\n",
312 | "TPU available: False, using: 0 TPU cores\n",
313 | "IPU available: False, using: 0 IPUs\n",
314 | "HPU available: False, using: 0 HPUs\n"
315 | ]
316 | },
317 | {
318 | "name": "stdout",
319 | "output_type": "stream",
320 | "text": [
321 | "R2 average: False\n",
322 | "Running with early stopping: True\n",
323 | "Early stopping patience: 20\n"
324 | ]
325 | },
326 | {
327 | "name": "stderr",
328 | "output_type": "stream",
329 | "text": [
330 | "You are using a CUDA device ('NVIDIA A100-PCIE-40GB MIG 3g.20gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
331 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-c51c82f2-7e04-56a8-8b74-4211c9821715]\n",
332 | "\n",
333 | " | Name | Type | Params\n",
334 | "-----------------------------------------------------------\n",
335 | "0 | learnable_embedding | Embedding | 512 K \n",
336 | "1 | embedding_dropout | Dropout | 0 \n",
337 | "2 | gene_net | Sequential | 887 K \n",
338 | "3 | drug_net | Sequential | 887 K \n",
339 | "4 | cl_net | Sequential | 416 K \n",
340 | "5 | transformer | TransformerEncoder | 16.8 M\n",
341 | "6 | output_net | Sequential | 525 K \n",
342 | "-----------------------------------------------------------\n",
343 | "20.1 M Trainable params\n",
344 | "0 Non-trainable params\n",
345 | "20.1 M Total params\n",
346 | "80.206 Total estimated model params size (MB)\n"
347 | ]
348 | },
349 | {
350 | "name": "stdout",
351 | "output_type": "stream",
352 | "text": [
353 | "Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 124/124 [00:37<00:00, 3.26it/s, v_num=8, train_loss=0.119]\n",
354 | "Validation: | | 0/? [00:00, ?it/s]\u001b[A\n",
355 | "Validation: 0%| | 0/14 [00:00, ?it/s]\u001b[A\n",
356 | "Validation DataLoader 0: 0%| | 0/14 [00:00, ?it/s]\u001b[A\n",
357 | "Validation DataLoader 0: 7%|████████▋ | 1/14 [00:00<00:00, 50.76it/s]\u001b[A\n",
358 | "Validation DataLoader 0: 14%|█████████████████▎ | 2/14 [00:00<00:00, 22.06it/s]\u001b[A\n",
359 | "Validation DataLoader 0: 21%|█████████████████████████▉ | 3/14 [00:00<00:00, 18.37it/s]\u001b[A\n",
360 | "Validation DataLoader 0: 29%|██████████████████████████████████▌ | 4/14 [00:00<00:00, 16.87it/s]\u001b[A\n",
361 | "Validation DataLoader 0: 36%|███████████████████████████████████████████▏ | 5/14 [00:00<00:00, 16.11it/s]\u001b[A\n",
362 | "Validation DataLoader 0: 43%|███████████████████████████████████████████████████▊ | 6/14 [00:00<00:00, 15.55it/s]\u001b[A\n",
363 | "Validation DataLoader 0: 50%|████████████████████████████████████████████████████████████▌ | 7/14 [00:00<00:00, 15.30it/s]\u001b[A\n",
364 | "Validation DataLoader 0: 57%|█████████████████████████████████████████████████████████████████████▏ | 8/14 [00:00<00:00, 15.13it/s]\u001b[A\n",
365 | "Validation DataLoader 0: 64%|█████████████████████████████████████████████████████████████████████████████▊ | 9/14 [00:00<00:00, 14.98it/s]\u001b[A\n",
366 | "Validation DataLoader 0: 71%|█████████████████████████████████████████████████████████████████████████████████████▋ | 10/14 [00:00<00:00, 14.87it/s]\u001b[A\n",
367 | "Validation DataLoader 0: 79%|██████████████████████████████████████████████████████████████████████████████████████████████▎ | 11/14 [00:00<00:00, 14.78it/s]\u001b[A\n",
368 | "Validation DataLoader 0: 86%|██████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 12/14 [00:00<00:00, 14.71it/s]\u001b[A\n",
369 | "Validation DataLoader 0: 93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 13/14 [00:00<00:00, 14.66it/s]\u001b[A\n",
370 | "Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 14.64it/s]\u001b[A\n",
371 | "Epoch 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 124/124 [00:38<00:00, 3.25it/s, v_num=8, train_loss=0.0325]\u001b[A\n",
372 | "Validation: | | 0/? [00:00, ?it/s]\u001b[A\n",
373 | "Validation: 0%| | 0/14 [00:00, ?it/s]\u001b[A\n",
374 | "Validation DataLoader 0: 0%| | 0/14 [00:00, ?it/s]\u001b[A\n",
375 | "Validation DataLoader 0: 7%|████████▋ | 1/14 [00:00<00:00, 23.63it/s]\u001b[A\n",
376 | "Validation DataLoader 0: 14%|█████████████████▎ | 2/14 [00:00<00:00, 20.51it/s]\u001b[A\n",
377 | "Validation DataLoader 0: 21%|█████████████████████████▉ | 3/14 [00:00<00:00, 16.10it/s]\u001b[A\n",
378 | "Validation DataLoader 0: 29%|██████████████████████████████████▌ | 4/14 [00:00<00:00, 15.03it/s]\u001b[A\n",
379 | "Validation DataLoader 0: 36%|███████████████████████████████████████████▏ | 5/14 [00:00<00:00, 14.41it/s]\u001b[A\n",
380 | "Validation DataLoader 0: 43%|███████████████████████████████████████████████████▊ | 6/14 [00:00<00:00, 10.26it/s]\u001b[A\n",
381 | "Validation DataLoader 0: 50%|████████████████████████████████████████████████████████████▌ | 7/14 [00:00<00:00, 10.51it/s]\u001b[A\n",
382 | "Validation DataLoader 0: 57%|█████████████████████████████████████████████████████████████████████▏ | 8/14 [00:00<00:00, 10.71it/s]\u001b[A\n",
383 | "Validation DataLoader 0: 64%|█████████████████████████████████████████████████████████████████████████████▊ | 9/14 [00:00<00:00, 10.95it/s]\u001b[A\n",
384 | "Validation DataLoader 0: 71%|█████████████████████████████████████████████████████████████████████████████████████▋ | 10/14 [00:01<00:00, 9.85it/s]\u001b[A\n",
385 | "Validation DataLoader 0: 79%|██████████████████████████████████████████████████████████████████████████████████████████████▎ | 11/14 [00:01<00:00, 10.07it/s]\u001b[A\n",
386 | "Validation DataLoader 0: 86%|██████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 12/14 [00:01<00:00, 10.24it/s]\u001b[A\n",
387 | "Validation DataLoader 0: 93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 13/14 [00:01<00:00, 10.38it/s]\u001b[A\n",
388 | "Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00, 10.53it/s]\u001b[A\n",
389 | "Epoch 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 124/124 [00:40<00:00, 3.10it/s, v_num=8, train_loss=0.0325]\u001b[A"
390 | ]
391 | },
392 | {
393 | "name": "stderr",
394 | "output_type": "stream",
395 | "text": [
396 | "`Trainer.fit` stopped: `max_epochs=2` reached.\n"
397 | ]
398 | },
399 | {
400 | "name": "stdout",
401 | "output_type": "stream",
402 | "text": [
403 | "Epoch 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 124/124 [00:41<00:00, 3.00it/s, v_num=8, train_loss=0.0325]\n"
404 | ]
405 | }
406 | ],
407 | "source": [
408 | "config.dirpath = './ckpts_GDSC/' # set the save directory here; you can also set it in the config file\n",
409 | "model.train(\n",
410 | " data_label,\n",
411 | " iv_col = ['iv1', 'iv2'],\n",
412 | " cl_col = 'cell_line',\n",
413 | " ph_col = 'phenotype',\n",
414 | " model_config = config\n",
415 | ")"
416 | ]
417 | },
418 | {
419 | "cell_type": "markdown",
420 | "id": "b0d4c109",
421 | "metadata": {},
422 | "source": [
423 | "### TieuQi\n",
424 | "\n",
425 | "This is a T-cell proliferation dataset from https://doi.org/10.1016/j.cell.2024.01.035.\n",
426 | "\n",
427 | "- cell state: T-cell proliferation\n",
428 | "- intervention: combinatorial CRISPRi\n",
429 | "- readout: Log2FC of CD8+ T-cells vs. control pDNA"
430 | ]
431 | },
432 | {
433 | "cell_type": "code",
434 | "execution_count": 7,
435 | "id": "8a968d5d",
436 | "metadata": {},
437 | "outputs": [
438 | {
439 | "data": {
440 | "text/html": [
441 | "\n",
442 | "\n",
455 | "
\n",
456 | " \n",
457 | " \n",
458 | " | \n",
459 | " iv1 | \n",
460 | " iv2 | \n",
461 | " value | \n",
462 | " cell_line | \n",
463 | " phenotype | \n",
464 | "
\n",
465 | " \n",
466 | " \n",
467 | " \n",
468 | " 0 | \n",
469 | " batf3 | \n",
470 | " batf3 | \n",
471 | " 0.633443 | \n",
472 | " JURKAT | \n",
473 | " T-cell_viability | \n",
474 | "
\n",
475 | " \n",
476 | " 1 | \n",
477 | " batf3 | \n",
478 | " cblb | \n",
479 | " 0.544929 | \n",
480 | " JURKAT | \n",
481 | " T-cell_viability | \n",
482 | "
\n",
483 | " \n",
484 | " 2 | \n",
485 | " batf3 | \n",
486 | " ctla4 | \n",
487 | " 0.483662 | \n",
488 | " JURKAT | \n",
489 | " T-cell_viability | \n",
490 | "
\n",
491 | " \n",
492 | " 3 | \n",
493 | " batf3 | \n",
494 | " dhx37 | \n",
495 | " 0.693173 | \n",
496 | " JURKAT | \n",
497 | " T-cell_viability | \n",
498 | "
\n",
499 | " \n",
500 | " 4 | \n",
501 | " batf3 | \n",
502 | " fas | \n",
503 | " 0.594995 | \n",
504 | " JURKAT | \n",
505 | " T-cell_viability | \n",
506 | "
\n",
507 | " \n",
508 | " ... | \n",
509 | " ... | \n",
510 | " ... | \n",
511 | " ... | \n",
512 | " ... | \n",
513 | " ... | \n",
514 | "
\n",
515 | " \n",
516 | " 1245 | \n",
517 | " tigit | \n",
518 | " zc3h12a | \n",
519 | " 0.604678 | \n",
520 | " JURKAT | \n",
521 | " T-cell_viability | \n",
522 | "
\n",
523 | " \n",
524 | " 1246 | \n",
525 | " tox | \n",
526 | " zc3h12a | \n",
527 | " 0.596647 | \n",
528 | " JURKAT | \n",
529 | " T-cell_viability | \n",
530 | "
\n",
531 | " \n",
532 | " 1247 | \n",
533 | " tox2 | \n",
534 | " zc3h12a | \n",
535 | " 0.418890 | \n",
536 | " JURKAT | \n",
537 | " T-cell_viability | \n",
538 | "
\n",
539 | " \n",
540 | " 1248 | \n",
541 | " trac | \n",
542 | " zc3h12a | \n",
543 | " 0.625894 | \n",
544 | " JURKAT | \n",
545 | " T-cell_viability | \n",
546 | "
\n",
547 | " \n",
548 | " 1249 | \n",
549 | " zc3h12a | \n",
550 | " zc3h12a | \n",
551 | " 0.490973 | \n",
552 | " JURKAT | \n",
553 | " T-cell_viability | \n",
554 | "
\n",
555 | " \n",
556 | "
\n",
557 | "
1250 rows × 5 columns
\n",
558 | "
"
559 | ],
560 | "text/plain": [
561 | " iv1 iv2 value cell_line phenotype\n",
562 | "0 batf3 batf3 0.633443 JURKAT T-cell_viability\n",
563 | "1 batf3 cblb 0.544929 JURKAT T-cell_viability\n",
564 | "2 batf3 ctla4 0.483662 JURKAT T-cell_viability\n",
565 | "3 batf3 dhx37 0.693173 JURKAT T-cell_viability\n",
566 | "4 batf3 fas 0.594995 JURKAT T-cell_viability\n",
567 | "... ... ... ... ... ...\n",
568 | "1245 tigit zc3h12a 0.604678 JURKAT T-cell_viability\n",
569 | "1246 tox zc3h12a 0.596647 JURKAT T-cell_viability\n",
570 | "1247 tox2 zc3h12a 0.418890 JURKAT T-cell_viability\n",
571 | "1248 trac zc3h12a 0.625894 JURKAT T-cell_viability\n",
572 | "1249 zc3h12a zc3h12a 0.490973 JURKAT T-cell_viability\n",
573 | "\n",
574 | "[1250 rows x 5 columns]"
575 | ]
576 | },
577 | "execution_count": 7,
578 | "metadata": {},
579 | "output_type": "execute_result"
580 | }
581 | ],
582 | "source": [
583 | "tieu_data_path = \"/ictstr01/home/icb/yuge.ji/projects/super_rad_project/data/TieuQi_lfc_day11_minmax.csv\"\n",
584 | "data_label = pd.read_csv(tieu_data_path, index_col=0)\n",
585 | "\n",
586 | "data_label['phenotype'] = 'T-cell_viability' # this is the label to use for this readout when you run inference later\n",
587 | "\n",
588 | "data_label = universal_processing(data_label)\n",
589 | "data_label"
590 | ]
591 | },
592 | {
593 | "cell_type": "code",
594 | "execution_count": 8,
595 | "id": "1b5c884c",
596 | "metadata": {},
597 | "outputs": [
598 | {
599 | "name": "stdout",
600 | "output_type": "stream",
601 | "text": [
602 | "Removing 4 such as ['tceb2', 'tet2', 'tigit', 'trac'] from ['iv1', 'iv2']. 861 rows remaining.\n",
603 | "Removing 0 such as [] from ['cell_line']. 861 rows remaining.\n",
604 | "Fitting model.\n",
605 | "pytorch model, finetuning\n",
606 | "Gene net: Sequential(\n",
607 | " (0): Linear(in_features=1219, out_features=512, bias=True)\n",
608 | " (1): GELU(approximate='none')\n",
609 | " (2): Dropout(p=0.1, inplace=False)\n",
610 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
611 | ")\n",
612 | "Cell line net: Sequential(\n",
613 | " (0): Linear(in_features=300, out_features=512, bias=True)\n",
614 | " (1): GELU(approximate='none')\n",
615 | " (2): Dropout(p=0.1, inplace=False)\n",
616 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
617 | ")\n",
618 | "Regressor: Sequential(\n",
619 | " (0): Linear(in_features=512, out_features=512, bias=True)\n",
620 | " (1): GELU(approximate='none')\n",
621 | " (2): Dropout(p=0.2, inplace=False)\n",
622 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
623 | " (4): GELU(approximate='none')\n",
624 | " (5): Linear(in_features=512, out_features=1, bias=True)\n",
625 | ")\n"
626 | ]
627 | },
628 | {
629 | "name": "stderr",
630 | "output_type": "stream",
631 | "text": [
632 | "GPU available: True (cuda), used: True\n",
633 | "TPU available: False, using: 0 TPU cores\n",
634 | "IPU available: False, using: 0 IPUs\n",
635 | "HPU available: False, using: 0 HPUs\n",
636 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-c51c82f2-7e04-56a8-8b74-4211c9821715]\n",
637 | "\n",
638 | " | Name | Type | Params\n",
639 | "-----------------------------------------------------------\n",
640 | "0 | learnable_embedding | Embedding | 512 K \n",
641 | "1 | embedding_dropout | Dropout | 0 \n",
642 | "2 | gene_net | Sequential | 887 K \n",
643 | "3 | drug_net | Sequential | 887 K \n",
644 | "4 | cl_net | Sequential | 416 K \n",
645 | "5 | transformer | TransformerEncoder | 16.8 M\n",
646 | "6 | output_net | Sequential | 525 K \n",
647 | "-----------------------------------------------------------\n",
648 | "20.1 M Trainable params\n",
649 | "0 Non-trainable params\n",
650 | "20.1 M Total params\n",
651 | "80.206 Total estimated model params size (MB)\n"
652 | ]
653 | },
654 | {
655 | "name": "stdout",
656 | "output_type": "stream",
657 | "text": [
658 | "R2 average: False\n",
659 | "Running with early stopping: True\n",
660 | "Early stopping patience: 20\n",
661 | "Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2.07it/s, v_num=9, train_loss=0.638]\n",
662 | "Validation: | | 0/? [00:00, ?it/s]\u001b[A\n",
663 | "Validation: 0%| | 0/1 [00:00, ?it/s]\u001b[A\n",
664 | "Validation DataLoader 0: 0%| | 0/1 [00:00, ?it/s]\u001b[A\n",
665 | "Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 41.25it/s]\u001b[A\n",
666 | "Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.01it/s, v_num=9, train_loss=0.592]\u001b[A\n",
667 | "Validation: | | 0/? [00:00, ?it/s]\u001b[A\n",
668 | "Validation: 0%| | 0/1 [00:00, ?it/s]\u001b[A\n",
669 | "Validation DataLoader 0: 0%| | 0/1 [00:00, ?it/s]\u001b[A\n",
670 | "Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 40.94it/s]\u001b[A\n",
671 | "Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00, 0.65it/s, v_num=9, train_loss=0.592]\u001b[A"
672 | ]
673 | },
674 | {
675 | "name": "stderr",
676 | "output_type": "stream",
677 | "text": [
678 | "`Trainer.fit` stopped: `max_epochs=2` reached.\n"
679 | ]
680 | },
681 | {
682 | "name": "stdout",
683 | "output_type": "stream",
684 | "text": [
685 | "Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00, 0.31it/s, v_num=9, train_loss=0.592]\n"
686 | ]
687 | }
688 | ],
689 | "source": [
690 | "config.dirpath = './ckpts_TieuQi/'\n",
691 | "model.train(\n",
692 | " data_label,\n",
693 | " iv_col = ['iv1', 'iv2'],\n",
694 | " cl_col = 'cell_line',\n",
695 | " ph_col = 'phenotype',\n",
696 | " model_config = config\n",
697 | ")"
698 | ]
699 | },
700 | {
701 | "cell_type": "markdown",
702 | "id": "d89e3040",
703 | "metadata": {},
704 | "source": [
705 | "## Loading in a model checkpoint"
706 | ]
707 | },
708 | {
709 | "cell_type": "code",
710 | "execution_count": 9,
711 | "id": "42e32a24",
712 | "metadata": {},
713 | "outputs": [
714 | {
715 | "name": "stdout",
716 | "output_type": "stream",
717 | "text": [
718 | "returning trained model!\n",
719 | "Gene net: Sequential(\n",
720 | " (0): Linear(in_features=1219, out_features=512, bias=True)\n",
721 | " (1): GELU(approximate='none')\n",
722 | " (2): Dropout(p=0.1, inplace=False)\n",
723 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
724 | ")\n",
725 | "Cell line net: Sequential(\n",
726 | " (0): Linear(in_features=300, out_features=512, bias=True)\n",
727 | " (1): GELU(approximate='none')\n",
728 | " (2): Dropout(p=0.1, inplace=False)\n",
729 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
730 | ")\n",
731 | "Regressor: Sequential(\n",
732 | " (0): Linear(in_features=512, out_features=512, bias=True)\n",
733 | " (1): GELU(approximate='none')\n",
734 | " (2): Dropout(p=0.2, inplace=False)\n",
735 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
736 | " (4): GELU(approximate='none')\n",
737 | " (5): Linear(in_features=512, out_features=1, bias=True)\n",
738 | ")\n"
739 | ]
740 | }
741 | ],
742 | "source": [
743 | "pretrained_checkpoint_path = \"./ckpts_TieuQi/epoch=0-step=1.ckpt\"\n",
744 | "model = Prophet(\n",
745 | " iv_emb_path=config.genes_prior,\n",
746 | " cl_emb_path=config.cell_lines_prior,\n",
747 | " ph_emb_path=None,\n",
748 | " model_pth=pretrained_checkpoint_path,\n",
749 | ")"
750 | ]
751 | },
752 | {
753 | "cell_type": "markdown",
754 | "id": "058a1a2e",
755 | "metadata": {},
756 | "source": [
757 | "As you can see, this checkpoint can now be loaded for inference! See other notebooks in `tutorials` for how to perform inference, or copy this notebook for more finetuning."
758 | ]
759 | }
760 | ],
761 | "metadata": {
762 | "kernelspec": {
763 | "display_name": "Python [conda env:prophet_api]",
764 | "language": "python",
765 | "name": "conda-env-prophet_api-py"
766 | },
767 | "language_info": {
768 | "codemirror_mode": {
769 | "name": "ipython",
770 | "version": 3
771 | },
772 | "file_extension": ".py",
773 | "mimetype": "text/x-python",
774 | "name": "python",
775 | "nbconvert_exporter": "python",
776 | "pygments_lexer": "ipython3",
777 | "version": "3.12.4"
778 | }
779 | },
780 | "nbformat": 4,
781 | "nbformat_minor": 5
782 | }
783 |
--------------------------------------------------------------------------------
/tutorials/insilico_screening.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "f95e67be",
6 | "metadata": {},
7 | "source": [
8 | "# massive in silico screening with Prophet\n",
9 | "\n",
10 | "This notebook demonstrates how to make predictions with Prophet with any of the checkpoints we have made available."
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "id": "66127741",
17 | "metadata": {
18 | "scrolled": false
19 | },
20 | "outputs": [],
21 | "source": [
22 | "import pandas as pd\n",
23 | "import yaml\n",
24 | "from prophet import Prophet\n",
25 | "from prophet.config import set_config"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "id": "222c3eeb",
31 | "metadata": {},
32 | "source": [
33 | "Load in the same config file that was used for finetuning, for the embedding files."
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 2,
39 | "id": "0e0b21b1",
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "with open('config_file_finetuning.yaml', 'r') as f:\n",
44 | " config = set_config(yaml.safe_load(f))"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 3,
50 | "id": "42e32a24",
51 | "metadata": {},
52 | "outputs": [
53 | {
54 | "name": "stdout",
55 | "output_type": "stream",
56 | "text": [
57 | "returning trained model!\n",
58 | "Gene net: Sequential(\n",
59 | " (0): Linear(in_features=1219, out_features=512, bias=True)\n",
60 | " (1): GELU(approximate='none')\n",
61 | " (2): Dropout(p=0.1, inplace=False)\n",
62 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
63 | ")\n",
64 | "Cell line net: Sequential(\n",
65 | " (0): Linear(in_features=300, out_features=512, bias=True)\n",
66 | " (1): GELU(approximate='none')\n",
67 | " (2): Dropout(p=0.1, inplace=False)\n",
68 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
69 | ")\n",
70 | "Regressor: Sequential(\n",
71 | " (0): Linear(in_features=512, out_features=512, bias=True)\n",
72 | " (1): GELU(approximate='none')\n",
73 | " (2): Dropout(p=0.2, inplace=False)\n",
74 | " (3): Linear(in_features=512, out_features=512, bias=True)\n",
75 | " (4): GELU(approximate='none')\n",
76 | " (5): Linear(in_features=512, out_features=1, bias=True)\n",
77 | ")\n"
78 | ]
79 | }
80 | ],
81 | "source": [
82 | "pretrained_checkpoint_path = \"./ckpts/epoch=1-step=248.ckpt\"\n",
83 | "model = Prophet(\n",
84 | " iv_emb_path=config.genes_prior,\n",
85 | " cl_emb_path=config.cell_lines_prior,\n",
86 | " ph_emb_path=None,\n",
87 | " model_pth=pretrained_checkpoint_path,\n",
88 | ")"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "id": "ff7ca95e",
94 | "metadata": {},
95 | "source": [
96 | "Suppose we have some small molecules, some cell lines we would like to test them in, and we're interested in measuring their relative IC50. We can pass in lists of these inputs, and Prophet will return predictions for all combinations:"
97 | ]
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "id": "e5d1e7cc",
102 | "metadata": {},
103 | "source": [
104 | "### Making predictions by passing all treatments, cell lines, and phenotypes you want to run\n",
105 | "\n",
106 | "This format can be useful when running large combinatorial screens in silico, as it splits the experiments up into batches to help prevent memory errors."
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 4,
112 | "id": "49b65b30",
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "iv_list = [\n",
117 | " 'oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc4=c(c=c(c=c4)i)f)=o',\n",
118 | " 'cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc5=cc=c(c=c5f)i)=c(c4=o)c)=c1)=o',\n",
119 | " 'fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c=c4)cl)=o)ns(ccc)(=o)=o',\n",
120 | " 'cs(=o)c' # DMSO\n",
121 | "]\n",
122 | "cl_list = ['A375', 'UACC62', 'WM983B', 'MALME3M', 'A2058', 'WM793', 'HT144', 'RPMI7951', 'SKMEL2', 'SKMEL1', 'HMCB', 'MDAMB435S', 'WM1799', 'LOXIMVI']\n",
123 | "ph_list = ['GDSC']"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 5,
129 | "id": "b56b1e46",
130 | "metadata": {},
131 | "outputs": [
132 | {
133 | "name": "stdout",
134 | "output_type": "stream",
135 | "text": [
136 | "There are 1 iterations\n"
137 | ]
138 | },
139 | {
140 | "name": "stderr",
141 | "output_type": "stream",
142 | "text": [
143 | "\r",
144 | " 0%| | 0/1 [00:00, ?it/s]"
145 | ]
146 | },
147 | {
148 | "name": "stdout",
149 | "output_type": "stream",
150 | "text": [
151 | "Removing 0 such as [] from ['iv1', 'iv2']. 140 rows remaining.\n",
152 | "Removing 0 such as [] from ['cell_line']. 140 rows remaining.\n"
153 | ]
154 | },
155 | {
156 | "name": "stderr",
157 | "output_type": "stream",
158 | "text": [
159 | "GPU available: True (cuda), used: True\n",
160 | "TPU available: False, using: 0 TPU cores\n",
161 | "IPU available: False, using: 0 IPUs\n",
162 | "HPU available: False, using: 0 HPUs\n",
163 | "You are using a CUDA device ('NVIDIA A100-PCIE-40GB MIG 3g.20gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
164 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-c51c82f2-7e04-56a8-8b74-4211c9821715]\n"
165 | ]
166 | },
167 | {
168 | "name": "stdout",
169 | "output_type": "stream",
170 | "text": [
171 | "Predicting DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00, 0.94it/s]\n"
172 | ]
173 | },
174 | {
175 | "name": "stderr",
176 | "output_type": "stream",
177 | "text": [
178 | "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:06<00:00, 6.88s/it]\n"
179 | ]
180 | },
181 | {
182 | "data": {
183 | "text/html": [
184 | "\n",
185 | "\n",
198 | "
\n",
199 | " \n",
200 | " \n",
201 | " | \n",
202 | " iv1 | \n",
203 | " iv2 | \n",
204 | " cell_line | \n",
205 | " phenotype | \n",
206 | " iv1+iv2 | \n",
207 | " value | \n",
208 | " pred | \n",
209 | "
\n",
210 | " \n",
211 | " \n",
212 | " \n",
213 | " 0 | \n",
214 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
215 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
216 | " A375 | \n",
217 | " GDSC | \n",
218 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
219 | " _ | \n",
220 | " 0.480913 | \n",
221 | "
\n",
222 | " \n",
223 | " 1 | \n",
224 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
225 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
226 | " UACC62 | \n",
227 | " GDSC | \n",
228 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
229 | " _ | \n",
230 | " 0.484669 | \n",
231 | "
\n",
232 | " \n",
233 | " 2 | \n",
234 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
235 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
236 | " WM983B | \n",
237 | " GDSC | \n",
238 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
239 | " _ | \n",
240 | " 0.491997 | \n",
241 | "
\n",
242 | " \n",
243 | " 3 | \n",
244 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
245 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
246 | " MALME3M | \n",
247 | " GDSC | \n",
248 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
249 | " _ | \n",
250 | " 0.473723 | \n",
251 | "
\n",
252 | " \n",
253 | " 4 | \n",
254 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
255 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
256 | " A2058 | \n",
257 | " GDSC | \n",
258 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
259 | " _ | \n",
260 | " 0.489981 | \n",
261 | "
\n",
262 | " \n",
263 | " ... | \n",
264 | " ... | \n",
265 | " ... | \n",
266 | " ... | \n",
267 | " ... | \n",
268 | " ... | \n",
269 | " ... | \n",
270 | " ... | \n",
271 | "
\n",
272 | " \n",
273 | " 135 | \n",
274 | " cs(=o)c | \n",
275 | " cs(=o)c | \n",
276 | " SKMEL1 | \n",
277 | " GDSC | \n",
278 | " cs(=o)c+cs(=o)c | \n",
279 | " _ | \n",
280 | " 0.513998 | \n",
281 | "
\n",
282 | " \n",
283 | " 136 | \n",
284 | " cs(=o)c | \n",
285 | " cs(=o)c | \n",
286 | " HMCB | \n",
287 | " GDSC | \n",
288 | " cs(=o)c+cs(=o)c | \n",
289 | " _ | \n",
290 | " 0.540128 | \n",
291 | "
\n",
292 | " \n",
293 | " 137 | \n",
294 | " cs(=o)c | \n",
295 | " cs(=o)c | \n",
296 | " MDAMB435S | \n",
297 | " GDSC | \n",
298 | " cs(=o)c+cs(=o)c | \n",
299 | " _ | \n",
300 | " 0.538795 | \n",
301 | "
\n",
302 | " \n",
303 | " 138 | \n",
304 | " cs(=o)c | \n",
305 | " cs(=o)c | \n",
306 | " WM1799 | \n",
307 | " GDSC | \n",
308 | " cs(=o)c+cs(=o)c | \n",
309 | " _ | \n",
310 | " 0.526665 | \n",
311 | "
\n",
312 | " \n",
313 | " 139 | \n",
314 | " cs(=o)c | \n",
315 | " cs(=o)c | \n",
316 | " LOXIMVI | \n",
317 | " GDSC | \n",
318 | " cs(=o)c+cs(=o)c | \n",
319 | " _ | \n",
320 | " 0.528534 | \n",
321 | "
\n",
322 | " \n",
323 | "
\n",
324 | "
140 rows × 7 columns
\n",
325 | "
"
326 | ],
327 | "text/plain": [
328 | " iv1 \\\n",
329 | "0 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... \n",
330 | "1 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... \n",
331 | "2 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... \n",
332 | "3 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... \n",
333 | "4 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... \n",
334 | ".. ... \n",
335 | "135 cs(=o)c \n",
336 | "136 cs(=o)c \n",
337 | "137 cs(=o)c \n",
338 | "138 cs(=o)c \n",
339 | "139 cs(=o)c \n",
340 | "\n",
341 | " iv2 cell_line phenotype \\\n",
342 | "0 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... A375 GDSC \n",
343 | "1 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... UACC62 GDSC \n",
344 | "2 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM983B GDSC \n",
345 | "3 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... MALME3M GDSC \n",
346 | "4 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... A2058 GDSC \n",
347 | ".. ... ... ... \n",
348 | "135 cs(=o)c SKMEL1 GDSC \n",
349 | "136 cs(=o)c HMCB GDSC \n",
350 | "137 cs(=o)c MDAMB435S GDSC \n",
351 | "138 cs(=o)c WM1799 GDSC \n",
352 | "139 cs(=o)c LOXIMVI GDSC \n",
353 | "\n",
354 | " iv1+iv2 value pred \n",
355 | "0 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... _ 0.480913 \n",
356 | "1 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... _ 0.484669 \n",
357 | "2 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... _ 0.491997 \n",
358 | "3 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... _ 0.473723 \n",
359 | "4 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... _ 0.489981 \n",
360 | ".. ... ... ... \n",
361 | "135 cs(=o)c+cs(=o)c _ 0.513998 \n",
362 | "136 cs(=o)c+cs(=o)c _ 0.540128 \n",
363 | "137 cs(=o)c+cs(=o)c _ 0.538795 \n",
364 | "138 cs(=o)c+cs(=o)c _ 0.526665 \n",
365 | "139 cs(=o)c+cs(=o)c _ 0.528534 \n",
366 | "\n",
367 | "[140 rows x 7 columns]"
368 | ]
369 | },
370 | "execution_count": 5,
371 | "metadata": {},
372 | "output_type": "execute_result"
373 | }
374 | ],
375 | "source": [
376 | "# predict with lists of treatments and cell lines\n",
377 | "df = model.predict(\n",
378 | " target_ivs=iv_list,\n",
379 | " target_cls=cl_list,\n",
380 | " target_phs=ph_list,\n",
381 | " iv_col=['iv1', 'iv2'], # pass to turn on combinatorial predictions\n",
382 | " num_iterations=1, save=False,\n",
383 | ")\n",
384 | "df"
385 | ]
386 | },
387 | {
388 | "cell_type": "markdown",
389 | "id": "32b4e1ad",
390 | "metadata": {},
391 | "source": [
392 | "### Making predictions for a specific set of treatments, cell lines, and phenotypes\n",
393 | "\n",
394 | "If we're interested in only a subset of the experimental matrix, we can also pass in a custom dataframe. (This is the recommended usage, as users understand exactly the list being predicted.)"
395 | ]
396 | },
397 | {
398 | "cell_type": "code",
399 | "execution_count": 6,
400 | "id": "8b0baefd",
401 | "metadata": {
402 | "scrolled": true
403 | },
404 | "outputs": [
405 | {
406 | "data": {
407 | "text/html": [
408 | "\n",
409 | "\n",
422 | "
\n",
423 | " \n",
424 | " \n",
425 | " | \n",
426 | " iv1 | \n",
427 | " cell_line | \n",
428 | " iv2 | \n",
429 | " phenotype | \n",
430 | "
\n",
431 | " \n",
432 | " \n",
433 | " \n",
434 | " 0 | \n",
435 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
436 | " A375 | \n",
437 | " cs(=o)c | \n",
438 | " GDSC | \n",
439 | "
\n",
440 | " \n",
441 | " 1 | \n",
442 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
443 | " UACC62 | \n",
444 | " cs(=o)c | \n",
445 | " GDSC | \n",
446 | "
\n",
447 | " \n",
448 | " 2 | \n",
449 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
450 | " WM983B | \n",
451 | " cs(=o)c | \n",
452 | " GDSC | \n",
453 | "
\n",
454 | " \n",
455 | " 3 | \n",
456 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
457 | " MALME3M | \n",
458 | " cs(=o)c | \n",
459 | " GDSC | \n",
460 | "
\n",
461 | " \n",
462 | " 4 | \n",
463 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
464 | " A2058 | \n",
465 | " cs(=o)c | \n",
466 | " GDSC | \n",
467 | "
\n",
468 | " \n",
469 | " 5 | \n",
470 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
471 | " WM793 | \n",
472 | " cs(=o)c | \n",
473 | " GDSC | \n",
474 | "
\n",
475 | " \n",
476 | " 6 | \n",
477 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
478 | " HT144 | \n",
479 | " cs(=o)c | \n",
480 | " GDSC | \n",
481 | "
\n",
482 | " \n",
483 | " 7 | \n",
484 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
485 | " RPMI7951 | \n",
486 | " cs(=o)c | \n",
487 | " GDSC | \n",
488 | "
\n",
489 | " \n",
490 | " 8 | \n",
491 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
492 | " SKMEL2 | \n",
493 | " cs(=o)c | \n",
494 | " GDSC | \n",
495 | "
\n",
496 | " \n",
497 | " 9 | \n",
498 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
499 | " SKMEL1 | \n",
500 | " cs(=o)c | \n",
501 | " GDSC | \n",
502 | "
\n",
503 | " \n",
504 | " 10 | \n",
505 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
506 | " HMCB | \n",
507 | " cs(=o)c | \n",
508 | " GDSC | \n",
509 | "
\n",
510 | " \n",
511 | " 11 | \n",
512 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
513 | " MDAMB435S | \n",
514 | " cs(=o)c | \n",
515 | " GDSC | \n",
516 | "
\n",
517 | " \n",
518 | " 12 | \n",
519 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
520 | " WM1799 | \n",
521 | " cs(=o)c | \n",
522 | " GDSC | \n",
523 | "
\n",
524 | " \n",
525 | " 13 | \n",
526 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
527 | " LOXIMVI | \n",
528 | " cs(=o)c | \n",
529 | " GDSC | \n",
530 | "
\n",
531 | " \n",
532 | " 14 | \n",
533 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
534 | " A375 | \n",
535 | " cs(=o)c | \n",
536 | " GDSC | \n",
537 | "
\n",
538 | " \n",
539 | " 15 | \n",
540 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
541 | " UACC62 | \n",
542 | " cs(=o)c | \n",
543 | " GDSC | \n",
544 | "
\n",
545 | " \n",
546 | " 16 | \n",
547 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
548 | " WM983B | \n",
549 | " cs(=o)c | \n",
550 | " GDSC | \n",
551 | "
\n",
552 | " \n",
553 | " 17 | \n",
554 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
555 | " MALME3M | \n",
556 | " cs(=o)c | \n",
557 | " GDSC | \n",
558 | "
\n",
559 | " \n",
560 | " 18 | \n",
561 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
562 | " A2058 | \n",
563 | " cs(=o)c | \n",
564 | " GDSC | \n",
565 | "
\n",
566 | " \n",
567 | " 19 | \n",
568 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
569 | " WM793 | \n",
570 | " cs(=o)c | \n",
571 | " GDSC | \n",
572 | "
\n",
573 | " \n",
574 | " 20 | \n",
575 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
576 | " HT144 | \n",
577 | " cs(=o)c | \n",
578 | " GDSC | \n",
579 | "
\n",
580 | " \n",
581 | " 21 | \n",
582 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
583 | " RPMI7951 | \n",
584 | " cs(=o)c | \n",
585 | " GDSC | \n",
586 | "
\n",
587 | " \n",
588 | " 22 | \n",
589 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
590 | " SKMEL2 | \n",
591 | " cs(=o)c | \n",
592 | " GDSC | \n",
593 | "
\n",
594 | " \n",
595 | " 23 | \n",
596 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
597 | " SKMEL1 | \n",
598 | " cs(=o)c | \n",
599 | " GDSC | \n",
600 | "
\n",
601 | " \n",
602 | " 24 | \n",
603 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
604 | " HMCB | \n",
605 | " cs(=o)c | \n",
606 | " GDSC | \n",
607 | "
\n",
608 | " \n",
609 | " 25 | \n",
610 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
611 | " MDAMB435S | \n",
612 | " cs(=o)c | \n",
613 | " GDSC | \n",
614 | "
\n",
615 | " \n",
616 | " 26 | \n",
617 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
618 | " WM1799 | \n",
619 | " cs(=o)c | \n",
620 | " GDSC | \n",
621 | "
\n",
622 | " \n",
623 | " 27 | \n",
624 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
625 | " LOXIMVI | \n",
626 | " cs(=o)c | \n",
627 | " GDSC | \n",
628 | "
\n",
629 | " \n",
630 | " 28 | \n",
631 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
632 | " A375 | \n",
633 | " cs(=o)c | \n",
634 | " GDSC | \n",
635 | "
\n",
636 | " \n",
637 | " 29 | \n",
638 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
639 | " UACC62 | \n",
640 | " cs(=o)c | \n",
641 | " GDSC | \n",
642 | "
\n",
643 | " \n",
644 | " 30 | \n",
645 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
646 | " WM983B | \n",
647 | " cs(=o)c | \n",
648 | " GDSC | \n",
649 | "
\n",
650 | " \n",
651 | " 31 | \n",
652 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
653 | " MALME3M | \n",
654 | " cs(=o)c | \n",
655 | " GDSC | \n",
656 | "
\n",
657 | " \n",
658 | " 32 | \n",
659 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
660 | " A2058 | \n",
661 | " cs(=o)c | \n",
662 | " GDSC | \n",
663 | "
\n",
664 | " \n",
665 | " 33 | \n",
666 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
667 | " WM793 | \n",
668 | " cs(=o)c | \n",
669 | " GDSC | \n",
670 | "
\n",
671 | " \n",
672 | " 34 | \n",
673 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
674 | " HT144 | \n",
675 | " cs(=o)c | \n",
676 | " GDSC | \n",
677 | "
\n",
678 | " \n",
679 | " 35 | \n",
680 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
681 | " RPMI7951 | \n",
682 | " cs(=o)c | \n",
683 | " GDSC | \n",
684 | "
\n",
685 | " \n",
686 | " 36 | \n",
687 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
688 | " SKMEL2 | \n",
689 | " cs(=o)c | \n",
690 | " GDSC | \n",
691 | "
\n",
692 | " \n",
693 | " 37 | \n",
694 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
695 | " SKMEL1 | \n",
696 | " cs(=o)c | \n",
697 | " GDSC | \n",
698 | "
\n",
699 | " \n",
700 | " 38 | \n",
701 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
702 | " HMCB | \n",
703 | " cs(=o)c | \n",
704 | " GDSC | \n",
705 | "
\n",
706 | " \n",
707 | " 39 | \n",
708 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
709 | " MDAMB435S | \n",
710 | " cs(=o)c | \n",
711 | " GDSC | \n",
712 | "
\n",
713 | " \n",
714 | " 40 | \n",
715 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
716 | " WM1799 | \n",
717 | " cs(=o)c | \n",
718 | " GDSC | \n",
719 | "
\n",
720 | " \n",
721 | " 41 | \n",
722 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
723 | " LOXIMVI | \n",
724 | " cs(=o)c | \n",
725 | " GDSC | \n",
726 | "
\n",
727 | " \n",
728 | " 42 | \n",
729 | " cs(=o)c | \n",
730 | " A375 | \n",
731 | " cs(=o)c | \n",
732 | " GDSC | \n",
733 | "
\n",
734 | " \n",
735 | " 43 | \n",
736 | " cs(=o)c | \n",
737 | " UACC62 | \n",
738 | " cs(=o)c | \n",
739 | " GDSC | \n",
740 | "
\n",
741 | " \n",
742 | " 44 | \n",
743 | " cs(=o)c | \n",
744 | " WM983B | \n",
745 | " cs(=o)c | \n",
746 | " GDSC | \n",
747 | "
\n",
748 | " \n",
749 | " 45 | \n",
750 | " cs(=o)c | \n",
751 | " MALME3M | \n",
752 | " cs(=o)c | \n",
753 | " GDSC | \n",
754 | "
\n",
755 | " \n",
756 | " 46 | \n",
757 | " cs(=o)c | \n",
758 | " A2058 | \n",
759 | " cs(=o)c | \n",
760 | " GDSC | \n",
761 | "
\n",
762 | " \n",
763 | " 47 | \n",
764 | " cs(=o)c | \n",
765 | " WM793 | \n",
766 | " cs(=o)c | \n",
767 | " GDSC | \n",
768 | "
\n",
769 | " \n",
770 | " 48 | \n",
771 | " cs(=o)c | \n",
772 | " HT144 | \n",
773 | " cs(=o)c | \n",
774 | " GDSC | \n",
775 | "
\n",
776 | " \n",
777 | " 49 | \n",
778 | " cs(=o)c | \n",
779 | " RPMI7951 | \n",
780 | " cs(=o)c | \n",
781 | " GDSC | \n",
782 | "
\n",
783 | " \n",
784 | " 50 | \n",
785 | " cs(=o)c | \n",
786 | " SKMEL2 | \n",
787 | " cs(=o)c | \n",
788 | " GDSC | \n",
789 | "
\n",
790 | " \n",
791 | " 51 | \n",
792 | " cs(=o)c | \n",
793 | " SKMEL1 | \n",
794 | " cs(=o)c | \n",
795 | " GDSC | \n",
796 | "
\n",
797 | " \n",
798 | " 52 | \n",
799 | " cs(=o)c | \n",
800 | " HMCB | \n",
801 | " cs(=o)c | \n",
802 | " GDSC | \n",
803 | "
\n",
804 | " \n",
805 | " 53 | \n",
806 | " cs(=o)c | \n",
807 | " MDAMB435S | \n",
808 | " cs(=o)c | \n",
809 | " GDSC | \n",
810 | "
\n",
811 | " \n",
812 | " 54 | \n",
813 | " cs(=o)c | \n",
814 | " WM1799 | \n",
815 | " cs(=o)c | \n",
816 | " GDSC | \n",
817 | "
\n",
818 | " \n",
819 | " 55 | \n",
820 | " cs(=o)c | \n",
821 | " LOXIMVI | \n",
822 | " cs(=o)c | \n",
823 | " GDSC | \n",
824 | "
\n",
825 | " \n",
826 | "
\n",
827 | "
"
828 | ],
829 | "text/plain": [
830 | " iv1 cell_line iv2 \\\n",
831 | "0 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... A375 cs(=o)c \n",
832 | "1 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... UACC62 cs(=o)c \n",
833 | "2 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM983B cs(=o)c \n",
834 | "3 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... MALME3M cs(=o)c \n",
835 | "4 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... A2058 cs(=o)c \n",
836 | "5 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM793 cs(=o)c \n",
837 | "6 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... HT144 cs(=o)c \n",
838 | "7 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... RPMI7951 cs(=o)c \n",
839 | "8 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... SKMEL2 cs(=o)c \n",
840 | "9 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... SKMEL1 cs(=o)c \n",
841 | "10 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... HMCB cs(=o)c \n",
842 | "11 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... MDAMB435S cs(=o)c \n",
843 | "12 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM1799 cs(=o)c \n",
844 | "13 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... LOXIMVI cs(=o)c \n",
845 | "14 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... A375 cs(=o)c \n",
846 | "15 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... UACC62 cs(=o)c \n",
847 | "16 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... WM983B cs(=o)c \n",
848 | "17 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... MALME3M cs(=o)c \n",
849 | "18 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... A2058 cs(=o)c \n",
850 | "19 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... WM793 cs(=o)c \n",
851 | "20 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... HT144 cs(=o)c \n",
852 | "21 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... RPMI7951 cs(=o)c \n",
853 | "22 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... SKMEL2 cs(=o)c \n",
854 | "23 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... SKMEL1 cs(=o)c \n",
855 | "24 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... HMCB cs(=o)c \n",
856 | "25 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... MDAMB435S cs(=o)c \n",
857 | "26 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... WM1799 cs(=o)c \n",
858 | "27 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... LOXIMVI cs(=o)c \n",
859 | "28 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... A375 cs(=o)c \n",
860 | "29 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... UACC62 cs(=o)c \n",
861 | "30 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... WM983B cs(=o)c \n",
862 | "31 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... MALME3M cs(=o)c \n",
863 | "32 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... A2058 cs(=o)c \n",
864 | "33 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... WM793 cs(=o)c \n",
865 | "34 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... HT144 cs(=o)c \n",
866 | "35 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... RPMI7951 cs(=o)c \n",
867 | "36 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... SKMEL2 cs(=o)c \n",
868 | "37 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... SKMEL1 cs(=o)c \n",
869 | "38 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... HMCB cs(=o)c \n",
870 | "39 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... MDAMB435S cs(=o)c \n",
871 | "40 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... WM1799 cs(=o)c \n",
872 | "41 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... LOXIMVI cs(=o)c \n",
873 | "42 cs(=o)c A375 cs(=o)c \n",
874 | "43 cs(=o)c UACC62 cs(=o)c \n",
875 | "44 cs(=o)c WM983B cs(=o)c \n",
876 | "45 cs(=o)c MALME3M cs(=o)c \n",
877 | "46 cs(=o)c A2058 cs(=o)c \n",
878 | "47 cs(=o)c WM793 cs(=o)c \n",
879 | "48 cs(=o)c HT144 cs(=o)c \n",
880 | "49 cs(=o)c RPMI7951 cs(=o)c \n",
881 | "50 cs(=o)c SKMEL2 cs(=o)c \n",
882 | "51 cs(=o)c SKMEL1 cs(=o)c \n",
883 | "52 cs(=o)c HMCB cs(=o)c \n",
884 | "53 cs(=o)c MDAMB435S cs(=o)c \n",
885 | "54 cs(=o)c WM1799 cs(=o)c \n",
886 | "55 cs(=o)c LOXIMVI cs(=o)c \n",
887 | "\n",
888 | " phenotype \n",
889 | "0 GDSC \n",
890 | "1 GDSC \n",
891 | "2 GDSC \n",
892 | "3 GDSC \n",
893 | "4 GDSC \n",
894 | "5 GDSC \n",
895 | "6 GDSC \n",
896 | "7 GDSC \n",
897 | "8 GDSC \n",
898 | "9 GDSC \n",
899 | "10 GDSC \n",
900 | "11 GDSC \n",
901 | "12 GDSC \n",
902 | "13 GDSC \n",
903 | "14 GDSC \n",
904 | "15 GDSC \n",
905 | "16 GDSC \n",
906 | "17 GDSC \n",
907 | "18 GDSC \n",
908 | "19 GDSC \n",
909 | "20 GDSC \n",
910 | "21 GDSC \n",
911 | "22 GDSC \n",
912 | "23 GDSC \n",
913 | "24 GDSC \n",
914 | "25 GDSC \n",
915 | "26 GDSC \n",
916 | "27 GDSC \n",
917 | "28 GDSC \n",
918 | "29 GDSC \n",
919 | "30 GDSC \n",
920 | "31 GDSC \n",
921 | "32 GDSC \n",
922 | "33 GDSC \n",
923 | "34 GDSC \n",
924 | "35 GDSC \n",
925 | "36 GDSC \n",
926 | "37 GDSC \n",
927 | "38 GDSC \n",
928 | "39 GDSC \n",
929 | "40 GDSC \n",
930 | "41 GDSC \n",
931 | "42 GDSC \n",
932 | "43 GDSC \n",
933 | "44 GDSC \n",
934 | "45 GDSC \n",
935 | "46 GDSC \n",
936 | "47 GDSC \n",
937 | "48 GDSC \n",
938 | "49 GDSC \n",
939 | "50 GDSC \n",
940 | "51 GDSC \n",
941 | "52 GDSC \n",
942 | "53 GDSC \n",
943 | "54 GDSC \n",
944 | "55 GDSC "
945 | ]
946 | },
947 | "execution_count": 6,
948 | "metadata": {},
949 | "output_type": "execute_result"
950 | }
951 | ],
952 | "source": [
953 | "# Construct a dataframe containing the experiments we want to run. In practice, the user would load\n",
954 | "# a premade dataframe here.\n",
955 | "input_df = pd.MultiIndex.from_product([\n",
956 | " iv_list,\n",
957 | " cl_list,\n",
958 | "], names=['iv1', 'cell_line'])\n",
959 | "input_df = input_df.to_frame(index=False).reset_index(drop=True)\n",
960 | "input_df['iv2'] = 'cs(=o)c' # DMSO\n",
961 | "input_df['phenotype'] = 'GDSC'\n",
962 | "input_df"
963 | ]
964 | },
965 | {
966 | "cell_type": "code",
967 | "execution_count": 7,
968 | "id": "65e53302",
969 | "metadata": {},
970 | "outputs": [
971 | {
972 | "name": "stdout",
973 | "output_type": "stream",
974 | "text": [
975 | "There are 1 iterations\n"
976 | ]
977 | },
978 | {
979 | "name": "stderr",
980 | "output_type": "stream",
981 | "text": [
982 | " 0%| | 0/1 [00:00, ?it/s]GPU available: True (cuda), used: True\n",
983 | "TPU available: False, using: 0 TPU cores\n",
984 | "IPU available: False, using: 0 IPUs\n",
985 | "HPU available: False, using: 0 HPUs\n",
986 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-c51c82f2-7e04-56a8-8b74-4211c9821715]\n"
987 | ]
988 | },
989 | {
990 | "name": "stdout",
991 | "output_type": "stream",
992 | "text": [
993 | "Predicting DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 28.32it/s]\n"
994 | ]
995 | },
996 | {
997 | "name": "stderr",
998 | "output_type": "stream",
999 | "text": [
1000 | "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00, 2.28s/it]\n"
1001 | ]
1002 | },
1003 | {
1004 | "data": {
1005 | "text/html": [
1006 | "\n",
1007 | "\n",
1020 | "
\n",
1021 | " \n",
1022 | " \n",
1023 | " | \n",
1024 | " iv1 | \n",
1025 | " cell_line | \n",
1026 | " iv2 | \n",
1027 | " phenotype | \n",
1028 | " pred | \n",
1029 | "
\n",
1030 | " \n",
1031 | " \n",
1032 | " \n",
1033 | " 0 | \n",
1034 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1035 | " A375 | \n",
1036 | " cs(=o)c | \n",
1037 | " GDSC | \n",
1038 | " 0.489043 | \n",
1039 | "
\n",
1040 | " \n",
1041 | " 1 | \n",
1042 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1043 | " UACC62 | \n",
1044 | " cs(=o)c | \n",
1045 | " GDSC | \n",
1046 | " 0.495842 | \n",
1047 | "
\n",
1048 | " \n",
1049 | " 2 | \n",
1050 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1051 | " WM983B | \n",
1052 | " cs(=o)c | \n",
1053 | " GDSC | \n",
1054 | " 0.501370 | \n",
1055 | "
\n",
1056 | " \n",
1057 | " 3 | \n",
1058 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1059 | " MALME3M | \n",
1060 | " cs(=o)c | \n",
1061 | " GDSC | \n",
1062 | " 0.484853 | \n",
1063 | "
\n",
1064 | " \n",
1065 | " 4 | \n",
1066 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1067 | " A2058 | \n",
1068 | " cs(=o)c | \n",
1069 | " GDSC | \n",
1070 | " 0.499624 | \n",
1071 | "
\n",
1072 | " \n",
1073 | " 5 | \n",
1074 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1075 | " WM793 | \n",
1076 | " cs(=o)c | \n",
1077 | " GDSC | \n",
1078 | " 0.493993 | \n",
1079 | "
\n",
1080 | " \n",
1081 | " 6 | \n",
1082 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1083 | " HT144 | \n",
1084 | " cs(=o)c | \n",
1085 | " GDSC | \n",
1086 | " 0.495368 | \n",
1087 | "
\n",
1088 | " \n",
1089 | " 7 | \n",
1090 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1091 | " RPMI7951 | \n",
1092 | " cs(=o)c | \n",
1093 | " GDSC | \n",
1094 | " 0.479674 | \n",
1095 | "
\n",
1096 | " \n",
1097 | " 8 | \n",
1098 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1099 | " SKMEL2 | \n",
1100 | " cs(=o)c | \n",
1101 | " GDSC | \n",
1102 | " 0.506403 | \n",
1103 | "
\n",
1104 | " \n",
1105 | " 9 | \n",
1106 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1107 | " SKMEL1 | \n",
1108 | " cs(=o)c | \n",
1109 | " GDSC | \n",
1110 | " 0.471883 | \n",
1111 | "
\n",
1112 | " \n",
1113 | " 10 | \n",
1114 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1115 | " HMCB | \n",
1116 | " cs(=o)c | \n",
1117 | " GDSC | \n",
1118 | " 0.505215 | \n",
1119 | "
\n",
1120 | " \n",
1121 | " 11 | \n",
1122 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1123 | " MDAMB435S | \n",
1124 | " cs(=o)c | \n",
1125 | " GDSC | \n",
1126 | " 0.500283 | \n",
1127 | "
\n",
1128 | " \n",
1129 | " 12 | \n",
1130 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1131 | " WM1799 | \n",
1132 | " cs(=o)c | \n",
1133 | " GDSC | \n",
1134 | " 0.498168 | \n",
1135 | "
\n",
1136 | " \n",
1137 | " 13 | \n",
1138 | " oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... | \n",
1139 | " LOXIMVI | \n",
1140 | " cs(=o)c | \n",
1141 | " GDSC | \n",
1142 | " 0.495914 | \n",
1143 | "
\n",
1144 | " \n",
1145 | " 14 | \n",
1146 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1147 | " A375 | \n",
1148 | " cs(=o)c | \n",
1149 | " GDSC | \n",
1150 | " 0.479828 | \n",
1151 | "
\n",
1152 | " \n",
1153 | " 15 | \n",
1154 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1155 | " UACC62 | \n",
1156 | " cs(=o)c | \n",
1157 | " GDSC | \n",
1158 | " 0.496269 | \n",
1159 | "
\n",
1160 | " \n",
1161 | " 16 | \n",
1162 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1163 | " WM983B | \n",
1164 | " cs(=o)c | \n",
1165 | " GDSC | \n",
1166 | " 0.481987 | \n",
1167 | "
\n",
1168 | " \n",
1169 | " 17 | \n",
1170 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1171 | " MALME3M | \n",
1172 | " cs(=o)c | \n",
1173 | " GDSC | \n",
1174 | " 0.470828 | \n",
1175 | "
\n",
1176 | " \n",
1177 | " 18 | \n",
1178 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1179 | " A2058 | \n",
1180 | " cs(=o)c | \n",
1181 | " GDSC | \n",
1182 | " 0.486716 | \n",
1183 | "
\n",
1184 | " \n",
1185 | " 19 | \n",
1186 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1187 | " WM793 | \n",
1188 | " cs(=o)c | \n",
1189 | " GDSC | \n",
1190 | " 0.492139 | \n",
1191 | "
\n",
1192 | " \n",
1193 | " 20 | \n",
1194 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1195 | " HT144 | \n",
1196 | " cs(=o)c | \n",
1197 | " GDSC | \n",
1198 | " 0.486285 | \n",
1199 | "
\n",
1200 | " \n",
1201 | " 21 | \n",
1202 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1203 | " RPMI7951 | \n",
1204 | " cs(=o)c | \n",
1205 | " GDSC | \n",
1206 | " 0.473088 | \n",
1207 | "
\n",
1208 | " \n",
1209 | " 22 | \n",
1210 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1211 | " SKMEL2 | \n",
1212 | " cs(=o)c | \n",
1213 | " GDSC | \n",
1214 | " 0.487896 | \n",
1215 | "
\n",
1216 | " \n",
1217 | " 23 | \n",
1218 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1219 | " SKMEL1 | \n",
1220 | " cs(=o)c | \n",
1221 | " GDSC | \n",
1222 | " 0.469014 | \n",
1223 | "
\n",
1224 | " \n",
1225 | " 24 | \n",
1226 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1227 | " HMCB | \n",
1228 | " cs(=o)c | \n",
1229 | " GDSC | \n",
1230 | " 0.505564 | \n",
1231 | "
\n",
1232 | " \n",
1233 | " 25 | \n",
1234 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1235 | " MDAMB435S | \n",
1236 | " cs(=o)c | \n",
1237 | " GDSC | \n",
1238 | " 0.490048 | \n",
1239 | "
\n",
1240 | " \n",
1241 | " 26 | \n",
1242 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1243 | " WM1799 | \n",
1244 | " cs(=o)c | \n",
1245 | " GDSC | \n",
1246 | " 0.486165 | \n",
1247 | "
\n",
1248 | " \n",
1249 | " 27 | \n",
1250 | " cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... | \n",
1251 | " LOXIMVI | \n",
1252 | " cs(=o)c | \n",
1253 | " GDSC | \n",
1254 | " 0.495294 | \n",
1255 | "
\n",
1256 | " \n",
1257 | " 28 | \n",
1258 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1259 | " A375 | \n",
1260 | " cs(=o)c | \n",
1261 | " GDSC | \n",
1262 | " 0.502929 | \n",
1263 | "
\n",
1264 | " \n",
1265 | " 29 | \n",
1266 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1267 | " UACC62 | \n",
1268 | " cs(=o)c | \n",
1269 | " GDSC | \n",
1270 | " 0.514967 | \n",
1271 | "
\n",
1272 | " \n",
1273 | " 30 | \n",
1274 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1275 | " WM983B | \n",
1276 | " cs(=o)c | \n",
1277 | " GDSC | \n",
1278 | " 0.490298 | \n",
1279 | "
\n",
1280 | " \n",
1281 | " 31 | \n",
1282 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1283 | " MALME3M | \n",
1284 | " cs(=o)c | \n",
1285 | " GDSC | \n",
1286 | " 0.488163 | \n",
1287 | "
\n",
1288 | " \n",
1289 | " 32 | \n",
1290 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1291 | " A2058 | \n",
1292 | " cs(=o)c | \n",
1293 | " GDSC | \n",
1294 | " 0.506753 | \n",
1295 | "
\n",
1296 | " \n",
1297 | " 33 | \n",
1298 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1299 | " WM793 | \n",
1300 | " cs(=o)c | \n",
1301 | " GDSC | \n",
1302 | " 0.511536 | \n",
1303 | "
\n",
1304 | " \n",
1305 | " 34 | \n",
1306 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1307 | " HT144 | \n",
1308 | " cs(=o)c | \n",
1309 | " GDSC | \n",
1310 | " 0.504222 | \n",
1311 | "
\n",
1312 | " \n",
1313 | " 35 | \n",
1314 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1315 | " RPMI7951 | \n",
1316 | " cs(=o)c | \n",
1317 | " GDSC | \n",
1318 | " 0.481429 | \n",
1319 | "
\n",
1320 | " \n",
1321 | " 36 | \n",
1322 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1323 | " SKMEL2 | \n",
1324 | " cs(=o)c | \n",
1325 | " GDSC | \n",
1326 | " 0.514762 | \n",
1327 | "
\n",
1328 | " \n",
1329 | " 37 | \n",
1330 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1331 | " SKMEL1 | \n",
1332 | " cs(=o)c | \n",
1333 | " GDSC | \n",
1334 | " 0.490700 | \n",
1335 | "
\n",
1336 | " \n",
1337 | " 38 | \n",
1338 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1339 | " HMCB | \n",
1340 | " cs(=o)c | \n",
1341 | " GDSC | \n",
1342 | " 0.521368 | \n",
1343 | "
\n",
1344 | " \n",
1345 | " 39 | \n",
1346 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1347 | " MDAMB435S | \n",
1348 | " cs(=o)c | \n",
1349 | " GDSC | \n",
1350 | " 0.498836 | \n",
1351 | "
\n",
1352 | " \n",
1353 | " 40 | \n",
1354 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1355 | " WM1799 | \n",
1356 | " cs(=o)c | \n",
1357 | " GDSC | \n",
1358 | " 0.499848 | \n",
1359 | "
\n",
1360 | " \n",
1361 | " 41 | \n",
1362 | " fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... | \n",
1363 | " LOXIMVI | \n",
1364 | " cs(=o)c | \n",
1365 | " GDSC | \n",
1366 | " 0.512768 | \n",
1367 | "
\n",
1368 | " \n",
1369 | " 42 | \n",
1370 | " cs(=o)c | \n",
1371 | " A375 | \n",
1372 | " cs(=o)c | \n",
1373 | " GDSC | \n",
1374 | " 0.523384 | \n",
1375 | "
\n",
1376 | " \n",
1377 | " 43 | \n",
1378 | " cs(=o)c | \n",
1379 | " UACC62 | \n",
1380 | " cs(=o)c | \n",
1381 | " GDSC | \n",
1382 | " 0.526230 | \n",
1383 | "
\n",
1384 | " \n",
1385 | " 44 | \n",
1386 | " cs(=o)c | \n",
1387 | " WM983B | \n",
1388 | " cs(=o)c | \n",
1389 | " GDSC | \n",
1390 | " 0.529374 | \n",
1391 | "
\n",
1392 | " \n",
1393 | " 45 | \n",
1394 | " cs(=o)c | \n",
1395 | " MALME3M | \n",
1396 | " cs(=o)c | \n",
1397 | " GDSC | \n",
1398 | " 0.517211 | \n",
1399 | "
\n",
1400 | " \n",
1401 | " 46 | \n",
1402 | " cs(=o)c | \n",
1403 | " A2058 | \n",
1404 | " cs(=o)c | \n",
1405 | " GDSC | \n",
1406 | " 0.541236 | \n",
1407 | "
\n",
1408 | " \n",
1409 | " 47 | \n",
1410 | " cs(=o)c | \n",
1411 | " WM793 | \n",
1412 | " cs(=o)c | \n",
1413 | " GDSC | \n",
1414 | " 0.524480 | \n",
1415 | "
\n",
1416 | " \n",
1417 | " 48 | \n",
1418 | " cs(=o)c | \n",
1419 | " HT144 | \n",
1420 | " cs(=o)c | \n",
1421 | " GDSC | \n",
1422 | " 0.531894 | \n",
1423 | "
\n",
1424 | " \n",
1425 | " 49 | \n",
1426 | " cs(=o)c | \n",
1427 | " RPMI7951 | \n",
1428 | " cs(=o)c | \n",
1429 | " GDSC | \n",
1430 | " 0.505558 | \n",
1431 | "
\n",
1432 | " \n",
1433 | " 50 | \n",
1434 | " cs(=o)c | \n",
1435 | " SKMEL2 | \n",
1436 | " cs(=o)c | \n",
1437 | " GDSC | \n",
1438 | " 0.529134 | \n",
1439 | "
\n",
1440 | " \n",
1441 | " 51 | \n",
1442 | " cs(=o)c | \n",
1443 | " SKMEL1 | \n",
1444 | " cs(=o)c | \n",
1445 | " GDSC | \n",
1446 | " 0.513998 | \n",
1447 | "
\n",
1448 | " \n",
1449 | " 52 | \n",
1450 | " cs(=o)c | \n",
1451 | " HMCB | \n",
1452 | " cs(=o)c | \n",
1453 | " GDSC | \n",
1454 | " 0.540129 | \n",
1455 | "
\n",
1456 | " \n",
1457 | " 53 | \n",
1458 | " cs(=o)c | \n",
1459 | " MDAMB435S | \n",
1460 | " cs(=o)c | \n",
1461 | " GDSC | \n",
1462 | " 0.538795 | \n",
1463 | "
\n",
1464 | " \n",
1465 | " 54 | \n",
1466 | " cs(=o)c | \n",
1467 | " WM1799 | \n",
1468 | " cs(=o)c | \n",
1469 | " GDSC | \n",
1470 | " 0.526665 | \n",
1471 | "
\n",
1472 | " \n",
1473 | " 55 | \n",
1474 | " cs(=o)c | \n",
1475 | " LOXIMVI | \n",
1476 | " cs(=o)c | \n",
1477 | " GDSC | \n",
1478 | " 0.528534 | \n",
1479 | "
\n",
1480 | " \n",
1481 | "
\n",
1482 | "
"
1483 | ],
1484 | "text/plain": [
1485 | " iv1 cell_line iv2 \\\n",
1486 | "0 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... A375 cs(=o)c \n",
1487 | "1 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... UACC62 cs(=o)c \n",
1488 | "2 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM983B cs(=o)c \n",
1489 | "3 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... MALME3M cs(=o)c \n",
1490 | "4 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... A2058 cs(=o)c \n",
1491 | "5 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM793 cs(=o)c \n",
1492 | "6 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... HT144 cs(=o)c \n",
1493 | "7 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... RPMI7951 cs(=o)c \n",
1494 | "8 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... SKMEL2 cs(=o)c \n",
1495 | "9 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... SKMEL1 cs(=o)c \n",
1496 | "10 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... HMCB cs(=o)c \n",
1497 | "11 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... MDAMB435S cs(=o)c \n",
1498 | "12 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... WM1799 cs(=o)c \n",
1499 | "13 oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc... LOXIMVI cs(=o)c \n",
1500 | "14 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... A375 cs(=o)c \n",
1501 | "15 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... UACC62 cs(=o)c \n",
1502 | "16 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... WM983B cs(=o)c \n",
1503 | "17 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... MALME3M cs(=o)c \n",
1504 | "18 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... A2058 cs(=o)c \n",
1505 | "19 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... WM793 cs(=o)c \n",
1506 | "20 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... HT144 cs(=o)c \n",
1507 | "21 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... RPMI7951 cs(=o)c \n",
1508 | "22 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... SKMEL2 cs(=o)c \n",
1509 | "23 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... SKMEL1 cs(=o)c \n",
1510 | "24 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... HMCB cs(=o)c \n",
1511 | "25 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... MDAMB435S cs(=o)c \n",
1512 | "26 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... WM1799 cs(=o)c \n",
1513 | "27 cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc... LOXIMVI cs(=o)c \n",
1514 | "28 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... A375 cs(=o)c \n",
1515 | "29 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... UACC62 cs(=o)c \n",
1516 | "30 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... WM983B cs(=o)c \n",
1517 | "31 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... MALME3M cs(=o)c \n",
1518 | "32 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... A2058 cs(=o)c \n",
1519 | "33 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... WM793 cs(=o)c \n",
1520 | "34 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... HT144 cs(=o)c \n",
1521 | "35 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... RPMI7951 cs(=o)c \n",
1522 | "36 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... SKMEL2 cs(=o)c \n",
1523 | "37 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... SKMEL1 cs(=o)c \n",
1524 | "38 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... HMCB cs(=o)c \n",
1525 | "39 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... MDAMB435S cs(=o)c \n",
1526 | "40 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... WM1799 cs(=o)c \n",
1527 | "41 fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c... LOXIMVI cs(=o)c \n",
1528 | "42 cs(=o)c A375 cs(=o)c \n",
1529 | "43 cs(=o)c UACC62 cs(=o)c \n",
1530 | "44 cs(=o)c WM983B cs(=o)c \n",
1531 | "45 cs(=o)c MALME3M cs(=o)c \n",
1532 | "46 cs(=o)c A2058 cs(=o)c \n",
1533 | "47 cs(=o)c WM793 cs(=o)c \n",
1534 | "48 cs(=o)c HT144 cs(=o)c \n",
1535 | "49 cs(=o)c RPMI7951 cs(=o)c \n",
1536 | "50 cs(=o)c SKMEL2 cs(=o)c \n",
1537 | "51 cs(=o)c SKMEL1 cs(=o)c \n",
1538 | "52 cs(=o)c HMCB cs(=o)c \n",
1539 | "53 cs(=o)c MDAMB435S cs(=o)c \n",
1540 | "54 cs(=o)c WM1799 cs(=o)c \n",
1541 | "55 cs(=o)c LOXIMVI cs(=o)c \n",
1542 | "\n",
1543 | " phenotype pred \n",
1544 | "0 GDSC 0.489043 \n",
1545 | "1 GDSC 0.495842 \n",
1546 | "2 GDSC 0.501370 \n",
1547 | "3 GDSC 0.484853 \n",
1548 | "4 GDSC 0.499624 \n",
1549 | "5 GDSC 0.493993 \n",
1550 | "6 GDSC 0.495368 \n",
1551 | "7 GDSC 0.479674 \n",
1552 | "8 GDSC 0.506403 \n",
1553 | "9 GDSC 0.471883 \n",
1554 | "10 GDSC 0.505215 \n",
1555 | "11 GDSC 0.500283 \n",
1556 | "12 GDSC 0.498168 \n",
1557 | "13 GDSC 0.495914 \n",
1558 | "14 GDSC 0.479828 \n",
1559 | "15 GDSC 0.496269 \n",
1560 | "16 GDSC 0.481987 \n",
1561 | "17 GDSC 0.470828 \n",
1562 | "18 GDSC 0.486716 \n",
1563 | "19 GDSC 0.492139 \n",
1564 | "20 GDSC 0.486285 \n",
1565 | "21 GDSC 0.473088 \n",
1566 | "22 GDSC 0.487896 \n",
1567 | "23 GDSC 0.469014 \n",
1568 | "24 GDSC 0.505564 \n",
1569 | "25 GDSC 0.490048 \n",
1570 | "26 GDSC 0.486165 \n",
1571 | "27 GDSC 0.495294 \n",
1572 | "28 GDSC 0.502929 \n",
1573 | "29 GDSC 0.514967 \n",
1574 | "30 GDSC 0.490298 \n",
1575 | "31 GDSC 0.488163 \n",
1576 | "32 GDSC 0.506753 \n",
1577 | "33 GDSC 0.511536 \n",
1578 | "34 GDSC 0.504222 \n",
1579 | "35 GDSC 0.481429 \n",
1580 | "36 GDSC 0.514762 \n",
1581 | "37 GDSC 0.490700 \n",
1582 | "38 GDSC 0.521368 \n",
1583 | "39 GDSC 0.498836 \n",
1584 | "40 GDSC 0.499848 \n",
1585 | "41 GDSC 0.512768 \n",
1586 | "42 GDSC 0.523384 \n",
1587 | "43 GDSC 0.526230 \n",
1588 | "44 GDSC 0.529374 \n",
1589 | "45 GDSC 0.517211 \n",
1590 | "46 GDSC 0.541236 \n",
1591 | "47 GDSC 0.524480 \n",
1592 | "48 GDSC 0.531894 \n",
1593 | "49 GDSC 0.505558 \n",
1594 | "50 GDSC 0.529134 \n",
1595 | "51 GDSC 0.513998 \n",
1596 | "52 GDSC 0.540129 \n",
1597 | "53 GDSC 0.538795 \n",
1598 | "54 GDSC 0.526665 \n",
1599 | "55 GDSC 0.528534 "
1600 | ]
1601 | },
1602 | "execution_count": 7,
1603 | "metadata": {},
1604 | "output_type": "execute_result"
1605 | }
1606 | ],
1607 | "source": [
1608 | "df = model.predict(input_df, num_iterations=1, save=False)\n",
1609 | "df"
1610 | ]
1611 | }
1612 | ],
1613 | "metadata": {
1614 | "kernelspec": {
1615 | "display_name": "Python [conda env:prophet_api]",
1616 | "language": "python",
1617 | "name": "conda-env-prophet_api-py"
1618 | },
1619 | "language_info": {
1620 | "codemirror_mode": {
1621 | "name": "ipython",
1622 | "version": 3
1623 | },
1624 | "file_extension": ".py",
1625 | "mimetype": "text/x-python",
1626 | "name": "python",
1627 | "nbconvert_exporter": "python",
1628 | "pygments_lexer": "ipython3",
1629 | "version": "3.12.4"
1630 | }
1631 | },
1632 | "nbformat": 4,
1633 | "nbformat_minor": 5
1634 | }
1635 |
--------------------------------------------------------------------------------