--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Protein Design by Directed Evolution guided by Large Language Models
4 |
5 |
6 | Contributors:
7 | * Tran Van Trong Thanh
8 | * Truong-Son Hy (PI / Correspondent)
9 |
10 | Publication:
11 |
12 | https://doi.org/10.1109/TEVC.2024.3439690
13 |
14 | Preprint:
15 |
16 | https://doi.org/10.1101/2023.11.28.568945
17 |
18 | ## Table of Contents:
19 |
20 | - [Introduction](#introduction)
21 | - [Structure Description](#structure-description)
22 | - [Installation](#installation)
23 | - [Usage](#usage)
24 | - [Training](#training)
25 | - [Inference](#inference)
26 | - [Citation](#citation)
27 | - [License](#license)
28 |
29 | ## Introduction
30 | This is the official implementation of the paper [Protein Design by Directed Evolution guided by Large Language Models](https://www.biorxiv.org/content/10.1101/2023.11.28.568945v2).
31 |
32 | | ![Overview of the MLDE framework](assets/main.png) |
33 | |:--:|
34 | | ***(A)** Workflow of traditional directed evolution. **(B)** We train the in-silico oracle as the “ground-truth” evaluator to predict the fitness score of each generated sequence. **(C)** Our proposed MLDE framework.* |
35 |
36 | ## Structure Description
37 |
38 | Our repository is structured as follows:
39 | ```python
40 | .
41 | |-assets
42 | |-README.md
43 | |-LICENSE
44 | |-preprocessed_data # training data
45 | |-requirements.txt
46 | |-scripts
47 | | |-train_decoder.py # trains oracle
48 | | |-run_de.sh # Shell file to run
49 | | |-run_discrete_de.py # Python file to run
50 | | |-preprocess # contains codes to preprocess data
51 | |-exps
52 | | |-results # results stored here
53 | | |-logs # logs stored here
54 | | |-checkpoints # checkpoints stored here
55 | |-setup.py
56 | |-de # contains main source code
57 | ```
58 |
59 | ## Installation
60 |
61 | Python 3.10 or higher is required. We highly recommend working inside a virtual environment (e.g., conda). To install, run the following commands:
62 |
63 | ```shell
64 | git clone https://github.com/HySonLab/Directed_Evolution.git
65 | cd Directed_Evolution
66 |
67 | conda create -n mlde python=3.10 -y
68 | conda activate mlde
69 |
70 | pip install -e .
71 | ```
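
To quickly verify the editable install, you can print the package version defined in `de/version.py`:

```shell
python -c "from de.version import __version__; print(__version__)"  # expected output: 1.0.0
```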
72 |
73 | ## Usage
74 |
75 | ### Training
76 |
77 | To train the oracle (i.e., Attention1d) on a given dataset (e.g., AAV), simply run:
78 | ```shell
79 | python train_decoder.py \
80 | --data_file /path/to/AAV.csv \
81 | --dataset_name AAV \
82 |     --pretrained_encoder facebook/esm2_t12_35M_UR50D \
83 | --dec_hidden_dim 1280 \
84 | --batch_size 256 \
85 | --ckpt_path /path/to/ckpt_to_continue_from \
86 | --devices 0 \
87 | --grad_accum_steps 1 \
88 | --lr 5e-5 \
89 | --num_epochs 50 \
90 |     --num_ckpts 2
91 | ```
92 | To train the model without WandB logging, simply prepend `WANDB_DISABLED=True` to the command, as shown below:
93 |
94 | ```shell
95 | WANDB_DISABLED=True python train_decoder.py ...
96 | ```
97 |
98 | Arguments list:
99 | ```shell
100 | options:
101 | -h, --help show this help message and exit
102 | --data_file DATA_FILE
103 | Path to data directory.
104 | --dataset_name DATASET_NAME
105 | Name of trained dataset.
106 | --pretrained_encoder PRETRAINED_ENCODER
107 | Path to pretrained encoder.
108 | --dec_hidden_dim DEC_HIDDEN_DIM
109 | Hidden dim of decoder.
110 | --batch_size BATCH_SIZE
111 | Batch size.
112 | --ckpt_path CKPT_PATH
113 | Checkpoint of model.
114 | --devices DEVICES Training devices separated by comma.
115 | --output_dir OUTPUT_DIR
116 | Path to output directory.
117 | --grad_accum_steps GRAD_ACCUM_STEPS
118 | No. updates steps to accumulate the gradient.
119 | --lr LR Learning rate.
120 | --num_epochs NUM_EPOCHS
121 | Number of epochs.
122 | --wandb_project WANDB_PROJECT
123 | WandB project's name.
124 | --seed SEED Random seed for reproducibility.
125 | --set_seed_only Whether to not set deterministic flag.
126 | --num_workers NUM_WORKERS
127 | No. workers.
128 | --num_ckpts NUM_CKPTS
129 | Maximum no. checkpoints can be saved.
130 | --log_interval LOG_INTERVAL
131 | How often to log within steps.
132 | --precision {highest,high,medium}
133 | Internal precision of float32 matrix multiplications.
134 | ```
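
The file passed via `--data_file` is loaded by `ProteinDataset` (`de/dataio/proteins.py`), which expects a CSV whose first column, named `sequence`, holds the amino-acid sequences and whose second column holds the fitness labels. A minimal sketch of preparing such a file with pandas (the second column's name, the variant row, and the file path are illustrative; the wild-type entry comes from `preprocessed_data/AAV/AAV_reference_sequence.txt`):

```python
import pandas as pd

# Column 0 holds the amino-acid sequence, column 1 the fitness label,
# matching how ProteinDataset indexes the CSV (iloc columns 0 and 1).
df = pd.DataFrame({
    "sequence": [
        "DEEEIRTTNPVATEQYGSVSTNLQRGNR",  # AAV wild-type region (reference file)
        "DEEEIRTTNPVATEQYGSVSTNLQAGNR",  # illustrative variant
    ],
    "fitness": [-2.731, -1.503],  # second value is illustrative
})
df.to_csv("preprocessed_data/AAV/AAV.csv", index=False)  # illustrative path
```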
135 |
136 | ### Inference
137 |
138 | Once you have an oracle checkpoint for a given dataset (e.g., AAV), you can generate novel proteins by running:
139 | ```shell
140 | python run_discrete_de.py \
141 |     --wt DEEEIRTTNPVATEQYGSVSTNLQRGNR \
142 | --wt_fitness -100 \
143 | --n_steps 60 \
144 | --population 128 \
145 | --num_proposes_per_var 4 \
146 | --k 1 \
147 | --rm_dups \
148 | --population_ratio_per_mask 0.6 0.4 \
149 | --pretrained_mutation_name facebook/esm2_t12_35M_UR50D \
150 | --dec_hidden_size 1280 \
151 | --predictor_ckpt_path /path/to/ckpt \
152 | --verbose \
153 |     --devices 0
154 | ```
155 |
156 | Arguments list:
157 | ```shell
158 | options:
159 | -h, --help show this help message and exit
160 | --data_file DATA_FILE
161 | Path to data file.
162 | --wt WT Amino acid sequence.
163 | --wt_fitness WT_FITNESS
164 | Wild-type sequence's fitness.
165 | --n_steps N_STEPS No. steps to run directed evolution.
166 | --population POPULATION
167 | No. population per step.
168 | --num_proposes_per_var NUM_PROPOSES_PER_VAR
169 | Number of proposed mutations for each variant in the pool.
170 | --k K Split sequence into multiple tokens with length `k`.
171 | --rm_dups Whether to remove duplications in the proposed candidate pool.
172 | --population_ratio_per_mask POPULATION_RATIO_PER_MASK [POPULATION_RATIO_PER_MASK ...]
173 | Population ratio to run per masker.
174 | --pretrained_mutation_name PRETRAINED_MUTATION_NAME
175 | Pretrained model name or path for mutation checkpoint.
176 | --dec_hidden_size DEC_HIDDEN_SIZE
177 | Decoder hidden size (for conditional task).
178 | --predictor_ckpt_path PREDICTOR_CKPT_PATH
179 | Path to fitness predictor checkpoints.
180 | --num_masked_tokens NUM_MASKED_TOKENS
181 | No. masked tokens to predict.
182 | --mask_high_importance
183 | Whether to mask high-importance token in the sequence.
184 | --verbose Whether to display output.
185 | --seed SEED Random seed.
186 | --set_seed_only Whether to enable full determinism or set random seed only.
187 | --result_dir RESULT_DIR
188 | Directory to save result csv file.
189 | --save_name SAVE_NAME
190 | Filename of the result csv file.
191 | --devices DEVICES Devices, separated by commas.
192 | ```
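
For reference, `run_discrete_de.py` drives the `DiscreteDirectedEvolution2` loop implemented in `de/directed_evolution.py`. Below is a minimal, illustrative sketch of calling that loop directly from Python with roughly the same settings as the command above; it assumes the ESM-1b landscape parameters expected by `ESM1b_Landscape` (under `./landscape_params/esm1b_landscape/AAV/`) are available locally, and it is not a substitute for the script's full option handling:

```python
import torch

from de.directed_evolution import DiscreteDirectedEvolution2
from de.predictors.oracle import ESM1b_Landscape
from de.samplers.maskers import ImportanceMasker2, RandomMasker2
from de.samplers.models.esm import ESM2

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# ESM-2 proposes replacements for masked positions; its tokenizer is reused below.
mutation_model = ESM2("facebook/esm2_t12_35M_UR50D").to(device)

# Two maskers, mirroring `--population_ratio_per_mask 0.6 0.4` in the command above.
maskers = [RandomMasker2(k=1, max_subs=1, mask_token="<mask>"),
           ImportanceMasker2(k=1, max_subs=1, mask_token="<mask>")]

# Any object exposing infer_fitness(List[str]) can serve as the fitness predictor
# (e.g., a trained ESM2DecoderModule); ESM1b_Landscape loads pretrained landscape
# parameters from ./landscape_params/esm1b_landscape/AAV/.
oracle = ESM1b_Landscape("AAV", device)

de_loop = DiscreteDirectedEvolution2(
    n_steps=60,
    population=128,
    maskers=maskers,
    mutation_model=mutation_model,
    mutation_tokenizer=mutation_model.tokenizer,
    fitness_predictor=oracle,
    remove_duplications=True,
    k=1,
    population_ratio_per_mask=[0.6, 0.4],
    num_propose_mutation_per_variant=4,
    mutation_device=device,
    log_dir="./exps/logs",  # directory must exist; the run log is written here
)

wt_seq = "DEEEIRTTNPVATEQYGSVSTNLQRGNR"  # AAV wild-type region
mutants, fitness, variants = de_loop(wt_seq, wt_fitness=-100.0)
print(mutants[:5], fitness[:5])
```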
193 |
194 | ## Citation
195 | If our paper aids your work, please kindly cite it using the following BibTeX entries:
196 | ```bibtex
197 | @ARTICLE{10628050,
198 | author={Tran, Thanh V. T. and Hy, Truong Son},
199 | journal={IEEE Transactions on Evolutionary Computation},
200 | title={Protein Design by Directed Evolution Guided by Large Language Models},
201 | year={2025},
202 | volume={29},
203 | number={2},
204 | pages={418-428},
205 | keywords={Proteins;Evolution (biology);Large language models;Optimization;Transformers;Protein engineering;Task analysis;Directed evolution;large language models (LLMs);machine learning (ML);protein engineering},
206 | doi={10.1109/TEVC.2024.3439690}}
207 | ```
208 |
209 | ```bibtex
210 | @article {Tran2023.11.28.568945,
211 | author = {Trong Thanh Tran and Truong Son Hy},
212 | title = {Protein Design by Directed Evolution Guided by Large Language Models},
213 | elocation-id = {2023.11.28.568945},
214 | year = {2023},
215 | doi = {10.1101/2023.11.28.568945},
216 | publisher = {Cold Spring Harbor Laboratory},
217 | abstract = {Directed evolution, a strategy for protein engineering, optimizes protein properties (i.e., fitness) by a rigorous and resource-intensive process of screening or selecting among a vast range of mutations. By conducting an in-silico screening of sequence properties, machine learning-guided directed evolution (MLDE) can expedite the optimization process and alleviate the experimental workload. In this work, we propose a general MLDE framework in which we apply recent advancements of Deep Learning in protein representation learning and protein property prediction to accelerate the searching and optimization processes. In particular, we introduce an optimization pipeline that utilizes Large Language Models (LLMs) to pinpoint the mutation hotspots in the sequence and then suggest replacements to improve the overall fitness. Our experiments have shown the superior efficiency and efficacy of our proposed framework in the conditional protein generation, in comparision with traditional searching algorithms, diffusion models, and other generative models. We expect this work will shed a new light on not only protein engineering but also on solving combinatorial problems using data-driven methods. Our implementation is publicly available at https://github.com/HySonLab/Directed_EvolutionCompeting Interest StatementThe authors have declared no competing interest.},
218 | URL = {https://www.biorxiv.org/content/early/2023/11/29/2023.11.28.568945},
219 | eprint = {https://www.biorxiv.org/content/early/2023/11/29/2023.11.28.568945.full.pdf},
220 | journal = {bioRxiv}
221 | }
222 | ```
223 |
224 | ## License
225 |
226 | [GPL-3.0 License](./LICENSE)
227 |
--------------------------------------------------------------------------------
/assets/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HySonLab/Directed_Evolution/a9ef68497dd722cfc933eb76dd8a5a75424d97f9/assets/main.png
--------------------------------------------------------------------------------
/de/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HySonLab/Directed_Evolution/a9ef68497dd722cfc933eb76dd8a5a75424d97f9/de/__init__.py
--------------------------------------------------------------------------------
/de/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HySonLab/Directed_Evolution/a9ef68497dd722cfc933eb76dd8a5a75424d97f9/de/common/__init__.py
--------------------------------------------------------------------------------
/de/common/constants.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 |
4 | CANONICAL_ALPHABET = [
5 | 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R',
6 | 'S', 'T', 'V', 'W', 'Y'
7 | ]
8 |
9 |
10 | def all_possible_kmers(k: int):
11 | kmers = [''.join(comb) for comb in product(CANONICAL_ALPHABET, repeat=k)]
12 | return kmers
13 |
--------------------------------------------------------------------------------
/de/common/io_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import List, Dict
3 | from .constants import CANONICAL_ALPHABET
4 |
5 |
6 | def read_fasta(
7 | filepath: str,
8 | do_filter: bool = True,
9 | max_seq_length: int = 1024,
10 | accepted_residues: List[str] = CANONICAL_ALPHABET,
11 | ) -> Dict[str, str]:
12 | """ Read a fasta file
13 |
14 | Args:
15 | filepath (str): path to fasta file
16 |
17 | Returns:
18 | sequences (dict): map multiple sequence ids to corresponding sequences."""
19 | sequences = {}
20 | with open(filepath, 'r') as file:
21 | sequence_id = None
22 | sequence = ''
23 | for line in file:
24 | line = line.strip()
25 | if line.startswith(">"):
26 | if sequence_id:
27 | sequences[sequence_id] = sequence.upper()
28 | sequence_id = line[1:]
29 | sequence = ''
30 | else:
31 | sequence += line.strip()
32 | if sequence_id:
33 | sequences[sequence_id] = sequence.upper()
34 |
35 | if do_filter:
36 | sequences = filter_seqs(sequences, max_seq_length, accepted_residues)
37 |
38 | return sequences
39 |
40 |
41 | def filter_seqs(
42 | sequences: List[str],
43 | max_length: int = 1024,
44 | accepted_residues: List[str] = CANONICAL_ALPHABET
45 | ) -> List[str]:
46 | valid_residues = "".join(accepted_residues)
47 |
48 | def contains_invalid_chars(input):
49 | pattern = f"[^{re.escape(valid_residues)}]"
50 | return bool(re.search(pattern, input))
51 |
52 | new_seqs = {}
53 | for id, seq in sequences.items():
54 | if max_length > 0 and len(seq) > max_length:
55 | continue
56 | if contains_invalid_chars(seq):
57 | continue
58 | new_seqs[id] = seq
59 | return new_seqs
60 |
--------------------------------------------------------------------------------
/de/common/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import time
4 | import torch
5 | import random
6 | from datetime import datetime
7 | from functools import wraps
8 | from itertools import combinations
9 | from polyleven import levenshtein
10 | from typing import List
11 | from .constants import CANONICAL_ALPHABET
12 |
13 |
14 | def edit_distance(seq1, seq2):
15 | return levenshtein(seq1, seq2)
16 |
17 |
18 | def measure_diversity(seqs: List[str]):
19 | dists = []
20 | for pair in combinations(seqs, 2):
21 | dists.append(edit_distance(*pair))
22 | return np.mean(dists)
23 |
24 |
25 | def measure_distwt(seqs: List[str], wt: str):
26 | dists = []
27 | for seq in seqs:
28 | dists.append(edit_distance(seq, wt))
29 | return np.mean(dists)
30 |
31 |
32 | def measure_novelty(seqs: List[str], train_seqs: List[str]):
33 | all_novelty = []
34 | for seq in seqs:
35 | min_dist = 1e9
36 | for known in train_seqs:
37 | dist = edit_distance(seq, known)
38 | if dist == 0:
39 | all_novelty.append(dist)
40 | break
41 | elif dist < min_dist:
42 | min_dist = dist
43 | all_novelty.append(min_dist)
44 | return np.mean(all_novelty)
45 |
46 |
47 | def remove_duplicates(seqs: List[str], scores: List[float], return_idx: bool = False):
48 | new_seqs = []
49 | new_scores = []
50 | ids = []
51 | for idx, (seq, score) in enumerate(zip(seqs, scores)):
52 | if seq in new_seqs:
53 | continue
54 | else:
55 | new_seqs.append(seq)
56 | new_scores.append(score)
57 | ids.append(idx)
58 |     return (new_seqs, new_scores, ids) if return_idx else (new_seqs, new_scores)
59 |
60 |
61 | def get_mutated_sequence(focus_seq: str,
62 | mutant: str,
63 | start_idx: int = 1,
64 | AA_vocab: str = ''.join(CANONICAL_ALPHABET)) -> str:
65 | """Mutates an input sequence (focus_seq) via an input mutation triplet (substitutions only).
66 |
67 | Args:
68 | focus_seq (str): Input sequence.
69 | mutant (str): list of mutants applied to input sequence (e.g., "B12F:A83M").
70 | start_idx (int): Index to start indexing.
71 | AA_vocab (str): Amino acids.
72 |
73 | Returns:
74 | (str): mutated sequence.
75 | """
76 | if mutant == "":
77 | return focus_seq
78 | mutated_seq = list(focus_seq)
79 | for mutation in mutant.split(":"):
80 | try:
81 | from_AA, position, to_AA = mutation[0], int(
82 | mutation[1:-1]), mutation[-1]
83 | except ValueError:
84 | print("Issue with mutant: " + str(mutation))
85 | relative_position = position - start_idx
86 | assert from_AA == focus_seq[relative_position], \
87 | f"Invalid from_AA or mutant position: {str(mutation)} from_AA {str(str(from_AA))} " \
88 | f"relative pos: {str(relative_position)} focus_seq: {str(focus_seq)}"
89 | assert to_AA in AA_vocab, f"Mutant to_AA is invalid: {str(mutation)}"
90 | mutated_seq[relative_position] = to_AA
91 | return "".join(mutated_seq)
92 |
93 |
94 | def get_mutants(wt_seq: str, variant: str, offset_idx: int = 1):
95 | assert len(wt_seq) == len(variant), "Length must be the same."
96 | mutant = []
97 | for i in range(len(wt_seq)):
98 | if wt_seq[i] != variant[i]:
99 | mutant.append(f"{wt_seq[i]}{i + offset_idx}{variant[i]}")
100 |
101 | return ":".join(mutant)
102 |
103 |
104 | def split_kmers2(seqs: List[str], k: int = 3) -> List[List[str]]:
105 | return [[seq[i:i + k] for i in range(len(seq) - k + 1)] for seq in seqs]
106 |
107 |
108 | def set_seed(seed: int):
109 | """Set random seed for reproducibility.
110 |
111 | Args:
112 | seed (int): seed number.
113 | """
114 | random.seed(seed)
115 | np.random.seed(seed)
116 | torch.manual_seed(seed)
117 | torch.cuda.manual_seed_all(seed)
118 |
119 |
120 | def enable_full_deterministic(seed: int):
121 | """Helper function for reproducible behavior during distributed training
122 | See: https://pytorch.org/docs/stable/notes/randomness.html
123 | """
124 | set_seed(seed)
125 |
126 | # Enable PyTorch deterministic mode. This potentially requires either the environment
127 | # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
128 | # depending on the CUDA version, so we set them both here
129 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
130 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
131 | torch.use_deterministic_algorithms(True, warn_only=False)
132 | # Enable CuDNN deterministic mode
133 | torch.backends.cudnn.deterministic = True
134 | torch.backends.cudnn.benchmark = False
135 |
136 |
137 | def print_variant_in_color(seq: str,
138 | wt: str,
139 | ignore_gaps: bool = True) -> None:
140 | """Print a variant in color."""
141 | for j in range(len(wt)):
142 | if seq[j] != wt[j]:
143 | if ignore_gaps and (seq[j] == '-' or seq[j] == 'X'):
144 | continue
145 | print(f'\033[91m{seq[j]}', end='')
146 | else:
147 | print(f'\033[0m{seq[j]}', end='')
148 | print('\033[0m')
149 |
150 |
151 | def timer(func):
152 | @wraps(func)
153 | def timeit_wrapper(*args, **kwargs):
154 | start_time = time.perf_counter()
155 | result = func(*args, **kwargs)
156 | end_time = time.perf_counter()
157 | total_time = end_time - start_time
158 | now = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
159 | print(f'{now}: Function {func.__name__} took {total_time:.4f} seconds')
160 | return result
161 | return timeit_wrapper
162 |
--------------------------------------------------------------------------------
/de/dataio/proteins.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 | from lightning import LightningDataModule
4 | from torch.utils.data import Dataset, DataLoader, random_split
5 | from transformers import PreTrainedTokenizer
6 | from typing import Dict, Tuple
7 |
8 |
9 | class ProteinDataset(Dataset):
10 |
11 | def __init__(self, csv_file: str, tokenizer: PreTrainedTokenizer, max_length: int = None):
12 | """
13 | Args:
14 | csv_file (str): Path to the csv file.
15 | """
16 | self.data = pd.read_csv(csv_file)
17 | self.tokenizer = tokenizer
18 | self.max_length = max_length or max(self.data["sequence"].apply(lambda x: len(x)).to_list())
19 |
20 | def __len__(self):
21 | return len(self.data)
22 |
23 | def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
24 | if torch.is_tensor(idx):
25 | idx = idx.tolist()
26 |
27 | sequences = self.data.iloc[idx, 0]
28 | fitnesses = self.data.iloc[idx, 1]
29 | if isinstance(sequences, pd.Series):
30 | sequences = sequences.tolist()
31 | fitnesses = fitnesses.tolist()
32 | input_ids = self.tokenizer(sequences,
33 | add_special_tokens=True,
34 | truncation=True,
35 | padding="max_length",
36 | max_length=self.max_length)["input_ids"]
37 | return {"input_ids": torch.tensor(input_ids, dtype=torch.long),
38 | "fitness": torch.tensor(fitnesses, dtype=torch.float32)}
39 |
40 |
41 | class ProteinsDataModule(LightningDataModule):
42 |
43 | def __init__(self,
44 | csv_file: str,
45 | tokenizer: PreTrainedTokenizer,
46 | max_length: int = None,
47 | train_val_split: Tuple[float, float] = (0.9, 0.1),
48 | train_batch_size: int = 32,
49 | valid_batch_size: int = 32,
50 | num_workers: int = 64,
51 | seed: int = 0):
52 | super().__init__()
53 |
54 | self.save_hyperparameters(logger=False)
55 |
56 | self.train_dataset = None
57 | self.valid_dataset = None
58 |
59 | def setup(self, stage):
60 | datasets = ProteinDataset(self.hparams.csv_file,
61 | self.hparams.tokenizer,
62 | self.hparams.max_length)
63 | self.train_dataset, self.valid_dataset = random_split(
64 | dataset=datasets,
65 | lengths=self.hparams.train_val_split,
66 | generator=torch.Generator().manual_seed(self.hparams.seed)
67 | )
68 |
69 | def train_dataloader(self) -> DataLoader:
70 | return DataLoader(
71 | self.train_dataset,
72 | batch_size=self.hparams.train_batch_size,
73 | num_workers=self.hparams.num_workers,
74 | shuffle=True,
75 | )
76 |
77 | def val_dataloader(self) -> DataLoader:
78 | return DataLoader(
79 | self.valid_dataset,
80 | batch_size=self.hparams.valid_batch_size,
81 | num_workers=self.hparams.num_workers,
82 | shuffle=False,
83 | )
84 |
--------------------------------------------------------------------------------
/de/directed_evolution.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import logging
3 | import numpy as np
4 | import torch
5 | from copy import deepcopy
6 | from datetime import datetime
7 | from operator import itemgetter
8 | from typing import List, Tuple, Union
9 | from transformers import PreTrainedTokenizer
10 | from de.common.utils import timer
11 | from de.samplers.maskers import BaseMasker
12 |
13 |
14 | class DiscreteDirectedEvolution2:
15 | def __init__(self,
16 | n_steps: int,
17 | population: int,
18 | maskers: List[BaseMasker],
19 | mutation_model: torch.nn.Module,
20 | mutation_tokenizer: PreTrainedTokenizer,
21 | fitness_predictor: Union[torch.nn.Module, object],
22 | remove_duplications: bool = False,
23 | k: int = 3,
24 | population_ratio_per_mask: List[float] = None,
25 | num_propose_mutation_per_variant: int = 5,
26 | verbose: bool = False,
27 | num_workers: int = 16,
28 | mutation_device: Union[torch.device, str] = "cpu",
29 | log_dir: str = "./logs/",
30 | seed: int = 0,):
31 | """Main class for Discrete-space Directed Evolution
32 |
33 | Args:
34 | n_steps (int): No. steps to run directed evolution
35 | population (int): No. population per run
36 | verbose (bool): Whether to print output
37 | """
38 | self.n_steps = n_steps
39 | self.population = population
40 | self.maskers = maskers
41 | self.mutation_model = mutation_model
42 | self.mutation_tokenizer = mutation_tokenizer
43 | self.fitness_predictor = fitness_predictor
44 | self.rm_dups = remove_duplications
45 | self.k = k
46 | self.num_propose_mutation_per_variant = num_propose_mutation_per_variant
47 | self.num_workers = num_workers
48 | self.verbose = verbose
49 | self.mutation_device = mutation_device
50 | self.seed = seed
51 | self.population_ratio_per_mask = population_ratio_per_mask
52 | if population_ratio_per_mask is None:
53 | self.population_ratio_per_mask = [1 / len(maskers) for _ in range(len(maskers))]
54 |
55 | # Logging and caching variables
56 | self.mutation_logger = None
57 | self.prev_fitness = None
58 | self.prev_mutants = None
59 | self.prev_variants = None
60 | # Checks
61 | if self.n_steps < 1:
62 | raise ValueError("`n_steps` must be >= 1")
63 | if self.k < 1:
64 | raise ValueError("`k` must be >= 1")
65 |
66 | filename = f"{log_dir}/log_mask={'-'.join([str(msk) for msk in self.population_ratio_per_mask])}_k={k}_beam={num_propose_mutation_per_variant}_{self.seed}.log"
67 | logging.basicConfig(
68 | filename=filename,
69 | level=logging.INFO,
70 | format='%(asctime)s - %(levelname)s - %(message)s',
71 | filemode='w'
72 | )
73 |
74 | @timer
75 | def mask_sequences(
76 | self,
77 | variants: List[str],
78 | ids: List[int]
79 | ) -> Tuple[List[str], List[List[int]]]:
80 | """First step in Directed Evolution
81 | Args:
82 | variants (List[str]): List of sequences to be masked.
83 | ids (List[int]): Corresponding indices of `variants` w.r.t original list.
84 |
85 | Returns:
86 | masked_variants (List[str]): Masked sequences
87 | masked_poses (List[List[int]]): Masked positions.
88 | """
89 | num_variant = len(variants)
90 | if self.verbose:
91 | now = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
92 | print(f"\n{now}: ====== MASK VARIANTS ======")
93 | print(f"{now}: Start masking {num_variant} variants.")
94 |
95 | masked_variants, masked_positions = [], []
96 | begin_idx = 0
97 | for population_ratio, masker in zip(self.population_ratio_per_mask, self.maskers):
98 | sub_population = int(num_variant * population_ratio)
99 | sub_variants = variants[begin_idx:begin_idx + sub_population]
100 | sub_ids = ids[begin_idx:begin_idx + sub_population]
101 | begin_idx += sub_population
102 |
103 | if len(sub_variants) == 0:
104 | continue
105 | masked_vars, masked_pos = masker.run(sub_variants, sub_ids)
106 | masked_variants.extend(masked_vars)
107 | masked_positions.extend(masked_pos)
108 |
109 | return masked_variants, masked_positions
110 |
111 | @timer
112 | def mutate_masked_sequences(
113 | self,
114 | wt_seq: str,
115 | masked_variants: List[str],
116 | masked_positions: List[List[int]]
117 | ) -> Tuple[List[str], List[str]]:
118 | """Second step of Directed Evolution
119 | Args:
120 | wt_seq (str): wild-type sequence.
121 | masked_variants (List[str]): Masked sequences (each has been splitted into k-mers).
122 | masked_positions (List[List[int]]): Masked positions.
123 |
124 | Returns:
125 | mutated_seqs (List[str]): Mutated sequences
126 | mutants (List[str]): List of strings indicates the mutations in each sequence.
127 | """
128 | if self.verbose:
129 | now = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
130 | print(f"\n{now}: ====== MUTATE MASKED POSITION ======")
131 |
132 | # token position
133 | eos_id = self.mutation_tokenizer.eos_token_id
134 | masked_inputs = self.mutation_model.tokenize(masked_variants)
135 | # move to device
136 | masked_inputs.to(self.mutation_device)
137 | with torch.inference_mode():
138 | masked_outputs = self.mutation_model(masked_inputs)
139 | logits = masked_outputs.logits
140 | state = masked_outputs.hidden_states[-1]
141 | # return to cpu
142 | masked_inputs = masked_inputs.to(torch.device("cpu"))
143 | logits = logits.to(torch.device("cpu"))
144 | state = state.to(torch.device("cpu"))
145 |
146 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # [N, seq_len, vocab_size]
147 | # actual seq_len are similar => hard fix to prevent prediction at the end of seq.
148 | log_probs[:, -2, eos_id] = -1e6
149 | predicted_toks = torch.argmax(log_probs, dim=-1) # [N, seq_len] (N ~ num_variant)
150 | # get masked positions ( added to the beginning)
151 | masked_positions_tensor = torch.tensor(masked_positions, dtype=torch.int64) + 1
152 | # get mutations
153 | mutations = torch.gather(predicted_toks, dim=1, index=masked_positions_tensor)
154 | mutated_toks = masked_inputs["input_ids"].scatter_(1, masked_positions_tensor, mutations)
155 | mutated_seqs = self.mutation_tokenizer.batch_decode(mutated_toks, skip_special_tokens=True)
156 | mutated_seqs = [seq.replace(" ", "") for seq in mutated_seqs]
157 |
158 | mutants = []
159 | for idx, (posis, seq) in enumerate(zip(masked_positions, mutated_seqs)):
160 | for i in posis:
161 | self.mutation_logger[idx][str(i + 1)] = [wt_seq[i], seq[i]]
162 | mutants = self.logger2mutants(len(mutated_seqs))
163 |
164 | return mutated_seqs, mutants, state
165 |
166 | @timer
167 | def predict_fitness(self,
168 | inputs: Union[str, torch.Tensor],
169 | wt_fitness: float,
170 | mutated_seqs: List[str],
171 | mutants: List[str],
172 | wt_seq: str = None) -> Union[List[str], List[float]]:
173 | """Third step of Directed Evolution
174 | Args:
175 | inputs (str | torch.Tensor): wild-type sequence or sequence representation shape of
176 | (batch, sequence_len, dim).
177 | mutated_seqs (List[str]): Mutated sequences
178 | mutants (List[str]): List of strings indicates the mutations in each sequence.
179 |
180 | Returns:
181 | top_variants (List[str]): List of mutated sequences sorted by fitness score.
182 | top_fitness_score (List[float]): List of fitness score sorted in descending order.
183 | """
184 | if self.verbose:
185 | now = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
186 | print(f"\n{now}: ====== FITNESS PREDICTION ======")
187 |
188 | inputs = inputs.to(self.mutation_device)
189 |
190 | # (batch, 1)
191 | # fitness = self.fitness_predictor.infer_fitness(inputs).detach().cpu()
192 | # fitness = torch.concat([fitness, self.prev_fitness], dim=0)
193 | fitness = torch.tensor(self.fitness_predictor.infer_fitness(mutated_seqs),
194 | dtype=torch.float32)
195 | fitness = fitness.unsqueeze(1) if fitness.ndim == 1 else fitness
196 | fitness = torch.concat([fitness, self.prev_fitness], dim=0)
197 | mutants = mutants + self.prev_mutants
198 | mutated_seqs = mutated_seqs + self.prev_variants
199 |
200 | # Get topk fitness score
201 | k = self.population if len(mutants) >= self.population else len(mutants)
202 | topk_fitness, topk_indices = torch.topk(fitness, k, dim=0)
203 | top_fitness_score = topk_fitness.squeeze(1).numpy().tolist()
204 | top_indices = topk_indices.squeeze(1).numpy().tolist()
205 |
206 | # Fill pool to fit pool size (if needed)
207 | n = 0
208 | if len(top_fitness_score) < self.population:
209 | n = self.population - len(top_fitness_score)
210 | top_fitness_score = [top_fitness_score[0] for _ in range(n)] + top_fitness_score
211 | top_indices = [top_indices[0] for _ in range(n)] + top_indices
212 |
213 | # Get top variants
214 | retriever = itemgetter(*top_indices)
215 | top_variants = list(retriever(mutated_seqs))
216 | top_mutants = list(retriever(mutants))
217 |
218 | # update self.mutation_logger according to saved mutant
219 | self.mutation_logger = self.mutants2logger(top_mutants)
220 | self.prev_fitness = topk_fitness
221 | self.prev_mutants = top_mutants[n:]
222 | self.prev_variants = top_variants[n:]
223 |
224 | return top_variants, top_fitness_score
225 |
226 | def __call__(self, wt_seq: str, wt_fitness: float):
227 | """Run the discrete-space directed evolution
228 |
229 | Args:
230 | wt_seq (str): wild-type protein sequence
231 |
232 | Returns:
233 | variants (List[str]): list of protein sequences
234 | scores (torch.Tensor): scores for the variants
235 | """
236 | if self.verbose:
237 | now = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
238 | print(f"{now}: Wild-type sequence: {wt_seq}")
239 |
240 | # Initialize
241 | variants = [wt_seq for _ in range(self.population)]
242 | self.mutation_logger = [{} for _ in range(self.population)]
243 | self.prev_fitness = torch.tensor([[wt_fitness]], dtype=torch.float32)
244 | self.prev_mutants = [""]
245 | self.prev_variants = [wt_seq]
246 |
247 | for step in range(self.n_steps):
248 | # ============================ #
249 | # ====== PRE-PROCESSING ====== #
250 | # ============================ #
251 | if self.verbose:
252 | now = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
253 | print(f"\n{now}: ====== Step {step + 1} ======")
254 |
255 | variants = list(itertools.chain.from_iterable(
256 | list(deepcopy(i) for _ in range(self.num_propose_mutation_per_variant))
257 | for i in variants
258 | ))
259 | self.mutation_logger = list(itertools.chain.from_iterable(
260 | list(deepcopy(i) for _ in range(self.num_propose_mutation_per_variant))
261 | for i in self.mutation_logger
262 | ))
263 | shuffled_ids = np.random.permutation(len(variants)).tolist()
264 | retriever = itemgetter(*shuffled_ids)
265 | shuffled_variants = list(retriever(variants))
266 | if step != 0:
267 | self.mutation_logger = list(retriever(self.mutation_logger))
268 | del retriever
269 |
270 | # =========================== #
271 | # ====== MASK VARIANTS ====== #
272 | # =========================== #
273 | masked_variants, masked_positions = self.mask_sequences(shuffled_variants, shuffled_ids)
274 |
275 | # ==================================== #
276 | # ====== MUTATE MASKED POSITION ====== #
277 | # ==================================== #
278 | mutated_seqs, mutants, enc_out = self.mutate_masked_sequences(wt_seq,
279 | masked_variants,
280 | masked_positions)
281 |
282 | # Remove duplications if needed
283 | mutated_seqs, mutants, enc_out = self.remove_dups(enc_out, mutated_seqs, mutants)
284 |
285 | # ================================ #
286 | # ====== FITNESS PREDICTION ====== #
287 | # ================================ #
288 | inputs = enc_out
289 | variants, score = self.predict_fitness(
290 | inputs, wt_fitness, mutated_seqs, mutants, wt_seq
291 | )
292 |
293 | logging.info(f"\n-------- STEP {step} --------")
294 | for i, (var, mut, s) in enumerate(zip(variants, self.prev_mutants, score)):
295 | logging.info(f"{i}:\t{s}\t{mut}\t{var}")
296 |
297 | return self.prev_mutants, self.prev_fitness, variants
298 |
299 | def remove_dups(self, enc_out, mutated_seqs, mutants):
300 | candidate_array = np.array(mutated_seqs)
301 | unique_cand, indices = np.unique(candidate_array, return_index=True)
302 | unique_mutated_seqs = unique_cand.tolist()
303 | unique_indices = indices.tolist()
304 |
305 | # Retrieve unique elements based on indices
306 | unique_enc_out = enc_out[unique_indices]
307 | retriever = itemgetter(*unique_indices)
308 | unique_mutants = list(retriever(mutants))
309 | self.mutation_logger = list(retriever(self.mutation_logger))
310 |
311 | return unique_mutated_seqs, unique_mutants, unique_enc_out
312 |
313 | def logger2mutants(self, num2convert: int):
314 | mutants = []
315 | for i in range(num2convert):
316 | mutant = ''
317 | for k, v in self.mutation_logger[i].items():
318 | mutant += v[0] + k + v[1] + ":"
319 | mutants.append(mutant[:-1])
320 | return mutants
321 |
322 | def mutants2logger(self, mutants: List[str]):
323 | logger = [{} for _ in range(len(mutants))]
324 | for idx, mutant in enumerate(mutants):
325 | if len(mutant) == 0:
326 | continue
327 | for m in mutant.split(":"):
328 | before, pos, after = m[0], m[1:-1], m[-1]
329 | logger[idx][pos] = [before, after]
330 | return logger
331 |
--------------------------------------------------------------------------------
/de/predictors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HySonLab/Directed_Evolution/a9ef68497dd722cfc933eb76dd8a5a75424d97f9/de/predictors/__init__.py
--------------------------------------------------------------------------------
/de/predictors/attention/decoder.py:
--------------------------------------------------------------------------------
1 | """ Code adopted from: `https://github.com/microsoft/protein-sequence-models` """
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class MaskedConv1d(nn.Conv1d):
9 | """ A masked 1-dimensional convolution layer.
10 |
11 | Takes the same arguments as torch.nn.Conv1D, except that the padding is set automatically.
12 |
13 | Shape:
14 | Input: (N, L, in_channels)
15 | input_mask: (N, L, 1), optional
16 | Output: (N, L, out_channels)
17 | """
18 |
19 | def __init__(self,
20 | in_channels: int,
21 | out_channels: int,
22 | kernel_size: int,
23 | stride: int = 1,
24 | dilation: int = 1,
25 | groups: int = 1,
26 | bias: bool = True):
27 | """
28 | Args:
29 | in_channels (int): input channels
30 | out_channels (int): output channels
31 | kernel_size (int): the kernel width
32 | stride (int): filter shift
33 | dilation (int): dilation factor
34 | groups (int): perform depth-wise convolutions
35 | bias (bool): adds learnable bias to output
36 | """
37 | padding = dilation * (kernel_size - 1) // 2
38 | super().__init__(in_channels,
39 | out_channels,
40 | kernel_size,
41 | stride=stride,
42 | dilation=dilation,
43 | groups=groups,
44 | bias=bias,
45 | padding=padding)
46 |
47 | def forward(self, x, input_mask=None):
48 | if input_mask is not None:
49 | x = x * input_mask
50 | return super().forward(x.transpose(1, 2)).transpose(1, 2)
51 |
52 |
53 | class Attention1d(nn.Module):
54 |
55 | def __init__(self, in_dim: int):
56 | super().__init__()
57 | self.layer = MaskedConv1d(in_dim, 1, 1)
58 |
59 | def forward(self, x, input_mask=None):
60 | n, ell, _ = x.shape
61 | attn = self.layer(x)
62 | attn = attn.view(n, -1)
63 | if input_mask is not None:
64 | attn = attn.masked_fill_(~input_mask.view(n, -1).bool(),
65 | float('-inf'))
66 | attn = F.softmax(attn, dim=-1).view(n, -1, 1)
67 | out = (attn * x).sum(dim=1)
68 | return out
69 |
70 |
71 | class Decoder(nn.Module):
72 |
73 | def __init__(self, input_dim: int, hidden_dim: int):
74 | super().__init__()
75 | self.dense_1 = nn.Linear(input_dim, hidden_dim)
76 | self.dense_2 = nn.Linear(hidden_dim, hidden_dim)
77 | self.attention1d = Attention1d(in_dim=hidden_dim)
78 | self.dense_3 = nn.Linear(hidden_dim, hidden_dim)
79 | self.dense_4 = nn.Linear(hidden_dim, 1)
80 |
81 | def forward(self, x):
82 | x = torch.relu(self.dense_1(x))
83 | x = torch.relu(self.dense_2(x))
84 | x = self.attention1d(x)
85 | x = torch.relu(self.dense_3(x))
86 | x = self.dense_4(x)
87 | return x
88 |
--------------------------------------------------------------------------------
/de/predictors/attention/module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from lightning import LightningModule
4 | from torchmetrics import MinMetric, MeanMetric
5 | from torchmetrics.regression.mse import MeanSquaredError
6 | from torchmetrics.regression.mae import MeanAbsoluteError
7 | from typing import Any, List
8 | from .decoder import Decoder
9 | from transformers import EsmModel, AutoTokenizer
10 |
11 |
12 | class ESM2_Attention(nn.Module):
13 | def __init__(self,
14 | pretrained_model_name_or_path: str = "facebook/esm2_t12_35M_UR50D",
15 | hidden_dim: int = 512):
16 | super().__init__()
17 | self.esm = EsmModel.from_pretrained(pretrained_model_name_or_path)
18 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
19 | input_dim = self.esm.config.hidden_size
20 | self.decoder = Decoder(input_dim, hidden_dim)
21 |
22 | def freeze_encoder(self):
23 | for param in self.esm.parameters():
24 | param.requires_grad = False
25 |
26 | def forward(self, x):
27 | enc_out = self.esm(x).last_hidden_state
28 | output = self.decoder(enc_out)
29 | return output
30 |
31 |
32 | class ESM2DecoderModule(LightningModule):
33 | def __init__(self,
34 | net: nn.Module,
35 | optimizer: torch.optim.Optimizer):
36 | super().__init__()
37 |
38 | # this line allows to access init params with 'self.hparams' attribute
39 | # also ensures init params will be stored in ckpt
40 | self.save_hyperparameters(ignore=["net"])
41 | self.net = net
42 | # loss function
43 | self.criterion = torch.nn.MSELoss()
44 |
45 | # metric objects for calculating and averaging error
46 | self.train_mae = MeanAbsoluteError()
47 | self.valid_mae = MeanAbsoluteError()
48 | self.valid_mse = MeanSquaredError()
49 |
50 | # averaging loss across batches
51 | self.train_loss = MeanMetric()
52 | self.val_loss = MeanMetric()
53 |
54 | # for tracking best so far
55 | self.val_mae_best = MinMetric()
56 | self.val_mse_best = MinMetric()
57 |
58 | def forward(self, x):
59 | return self.net(x)
60 |
61 | def on_train_start(self):
62 | self.val_loss.reset()
63 | self.valid_mae.reset()
64 | self.valid_mse.reset()
65 | self.val_mse_best.reset()
66 | self.val_mae_best.reset()
67 |
68 | def model_step(self, batch):
69 | x, y = batch["input_ids"], batch["fitness"]
70 | y = y.unsqueeze(1)
71 | pred = self.forward(x)
72 | loss = self.criterion(pred, y)
73 | return loss, pred, y
74 |
75 | def training_step(self, batch, batch_idx):
76 | loss, preds, targets = self.model_step(batch)
77 |
78 | # update and log metrics
79 | self.train_loss(loss)
80 | self.train_mae(preds, targets)
81 | self.log("train_loss", self.train_loss, on_step=True, on_epoch=True, prog_bar=True)
82 | self.log("train_mae", self.train_mae, on_step=True, on_epoch=True, prog_bar=True)
83 |
84 | # return loss
85 | return loss
86 |
87 | def validation_step(self, batch, batch_idx):
88 | loss, preds, targets = self.model_step(batch)
89 |
90 | # update and log metrics
91 | self.val_loss(loss)
92 | self.valid_mae(preds, targets)
93 | self.valid_mse(preds, targets)
94 | self.log("val_loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
95 | self.log("val_mae", self.valid_mae, on_step=False, on_epoch=True, prog_bar=True)
96 |
97 | def on_validation_epoch_end(self) -> None:
98 | mae = self.valid_mae.compute() # get current mae
99 | mse = self.valid_mse.compute() # get current mse
100 | self.val_mae_best(mae)
101 | self.val_mse_best(mse)
102 | self.log("val_mae_best", self.val_mae_best.compute(), sync_dist=True, prog_bar=True)
103 | self.log("val_mse_best", self.val_mse_best.compute(), sync_dist=True, prog_bar=True)
104 |
105 | def configure_optimizers(self) -> Any:
106 | optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
107 | return {"optimizer": optimizer}
108 |
109 | def predict_fitness(self, representation: torch.Tensor):
110 | fitness = self.net.decoder(representation)
111 | return fitness
112 |
113 | def infer_fitness(self, seqs: List[str]):
114 | with torch.inference_mode():
115 | inputs = self.net.tokenizer(seqs, return_tensors="pt").to(self.device)
116 | repr = self.net.esm(**inputs).last_hidden_state
117 | outputs = self.predict_fitness(repr)
118 | return outputs.cpu()
119 |
--------------------------------------------------------------------------------
/de/predictors/oracle.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import torch.nn as nn
5 | from transformers import AutoTokenizer, EsmModel
6 | from typing import List, Union
7 | # from .attention.decoder import Decoder
8 | from de.common.utils import get_mutants
9 | from de.predictors.attention.decoder import Decoder
10 |
11 |
12 | class ESM1b_Attention1d(nn.Module):
13 |
14 | def __init__(self):
15 | super(ESM1b_Attention1d, self).__init__()
16 | self.encoder = EsmModel.from_pretrained("facebook/esm1b_t33_650M_UR50S")
17 | self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
18 | self.decoder = Decoder(input_dim=1280, hidden_dim=512)
19 |
20 | def forward(self, inputs):
21 | x = self.encoder(**inputs).last_hidden_state
22 | x = self.decoder(x)
23 | return x
24 |
25 |
26 | class ESM1b_Landscape:
27 | """
28 | An ESM-based oracle model to simulate protein fitness landscape.
29 | """
30 |
31 | def __init__(self, task: str, device: Union[str, torch.device]):
32 | task_dir_path = os.path.join('./landscape_params/esm1b_landscape', task)
33 | task_dir_path = os.path.abspath(task_dir_path)
34 | assert os.path.exists(os.path.join(task_dir_path, 'decoder.pt'))
35 | self.model = ESM1b_Attention1d()
36 | self.model.decoder.load_state_dict(
37 | torch.load(os.path.join(task_dir_path, 'decoder.pt'))
38 | )
39 | with open(os.path.join(task_dir_path, 'starting_sequence.json')) as f:
40 | self.starting_sequence = json.load(f)
41 |
42 | self.tokenizer = self.model.tokenizer
43 | self.device = device
44 | self.model.to(self.device)
45 |
46 | def infer_fitness(self, sequences: List[str], batch_size: int = 16, device=None):
47 | # Input: - sequences: [query_batch_size, sequence_length]
48 | # Output: - fitness_scores: [query_batch_size]
49 |
50 | self.model.eval()
51 | fitness_scores = []
52 | seqs = [sequences[i:i + batch_size] for i in range(0, len(sequences), batch_size)]
53 | for seq in seqs:
54 | inputs = self.tokenizer(seq, return_tensors="pt").to(self.device)
55 | fitness = self.model(inputs).cpu().tolist()
56 | fitness_scores.extend(fitness)
57 | # fitness_scores.append(self.model(inputs).item())
58 | return fitness_scores
59 |
60 |
61 | class ESM1v:
62 |
63 | def __init__(self, model_name: str, device, method: str, offset_idx: int):
64 | self.tokenizer = AutoTokenizer.from_pretrained(f"facebook/{model_name}")
65 | self.model = EsmModel.from_pretrained(f"facebook/{model_name}")
66 | self.model.eval()
67 | self.model = self.model.to(device)
68 | self.device = device
69 | self.method = method
70 | self.offset_idx = offset_idx
71 |
72 | def compute_pppl(self, variants: List[str]):
73 | log_probs = []
74 |         mask_id = self.tokenizer._token_to_id["<mask>"]
75 | inputs = self.tokenizer(variants, return_tensors="pt").to(self.device)
76 | input_ids, attention_mask = inputs.input_ids, inputs.attention_mask
77 |
78 | for i in range(1, len(variants[0]) - 1):
79 | token_ids = input_ids[:, i].unsqueeze(1)
80 | batch_token_masked = input_ids.clone()
81 | batch_token_masked[:, i] = mask_id
82 |
83 | with torch.inference_mode():
84 | logits = self.model(batch_token_masked, attention_mask).last_hidden_state
85 | token_probs = torch.log_softmax(logits, dim=-1)[:, i]
86 | token_probs = torch.gather(token_probs, dim=1, index=token_ids)
87 |
88 | log_probs.append(token_probs)
89 |
90 | return torch.sum(torch.concat(log_probs, dim=1), dim=1).cpu().tolist()
91 |
92 | def compute_masked_marginals(self, wt_seq: str, mutants: List[str]):
93 | all_token_probs = []
94 |         mask_id = self.tokenizer._token_to_id["<mask>"]
95 | inputs = self.tokenizer(wt_seq, return_tensors="pt").to(self.device)
96 | input_ids, attention_mask = inputs.input_ids, inputs.attention_mask
97 | for i in range(input_ids.size(1)):
98 | batch_token_masked = input_ids.clone()
99 | batch_token_masked[:, i] = mask_id
100 |
101 | with torch.inference_mode():
102 | logits = self.model(batch_token_masked, attention_mask).last_hidden_state
103 | token_probs = torch.log_softmax(logits, dim=-1)[:, i]
104 |
105 | all_token_probs.append(token_probs)
106 |
107 | token_probs = torch.cat(all_token_probs, dim=0)
108 | scores = []
109 | for mutant in mutants:
110 | ms = mutant.split(":")
111 | score = 0
112 | for row in ms:
113 | if len(row) == 0:
114 | continue
115 | wt, idx, mt = row[0], int(row[1:-1]) - self.offset_idx, row[-1]
116 | assert wt_seq[idx] == wt
117 |
118 | wt_encoded, mt_encoded = self.tokenizer._token_to_id[wt], self.tokenizer._token_to_id[mt]
119 | mt_score = token_probs[1 + idx, mt_encoded] - token_probs[1 + idx, wt_encoded]
120 | score = score + mt_score.item()
121 |
122 | scores.append(score)
123 |
124 | return scores
125 |
126 | def infer_fitness(self, sequences: List[str], wt_seq: str = None, device=None):
127 | if self.method == "pseudo":
128 | scores = self.compute_pppl(sequences)
129 | elif self.method == "masked":
130 | assert wt_seq is not None, "wt_seq must be provided when using masked marginal."
131 | mutants = [get_mutants(wt_seq, seq, self.offset_idx) for seq in sequences]
132 | scores = self.compute_masked_marginals(wt_seq, mutants)
133 | else:
134 | raise ValueError("method is not supported")
135 | return scores
136 |
137 |
138 | if __name__ == "__main__":
139 | import sys
140 | import pandas as pd
141 | from de.common.utils import get_mutated_sequence
142 |
143 | csv_file = sys.argv[1]
144 |
145 | device = torch.device("cuda:0")
146 | landscape = ESM1b_Landscape("AAV", device)
147 |
148 | df = pd.read_csv(csv_file)
149 | df["mutated"] = df.apply(lambda x: get_mutated_sequence(x["WT"], x.mutants), axis=1)
150 | opt_score = df["score"].tolist()
151 | mutated_seqs = df["mutated"].tolist()
152 |
153 | scores = landscape.infer_fitness(mutated_seqs)
154 | results = {"mutated": mutated_seqs, "opt_score": opt_score, "eval_score": scores}
155 | df = pd.DataFrame.from_dict(results)
156 | target_path = os.path.join(os.path.dirname(csv_file), "tmp.csv")
157 | df.to_csv(target_path)
158 |
--------------------------------------------------------------------------------
/de/samplers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HySonLab/Directed_Evolution/a9ef68497dd722cfc933eb76dd8a5a75424d97f9/de/samplers/__init__.py
--------------------------------------------------------------------------------
/de/samplers/maskers/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseMasker
2 | from .random import RandomMasker2
3 | from .importance import ImportanceMasker2
4 |
5 |
6 | __all__ = [
7 | "BaseMasker",
8 | "RandomMasker2",
9 | "ImportanceMasker2"
10 | ]
11 |
--------------------------------------------------------------------------------
/de/samplers/maskers/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List
3 |
4 |
5 | class BaseMasker(ABC):
6 | """Base class for maskers."""
7 | @abstractmethod
8 | def run(self,
9 | population: List[str],
10 | indices: List[int] = None):
11 | """
12 | Args:
13 | population (List[str]): List of sequences to be masked
14 | indices (List[int]): List of indices of each sequence in original population.
15 | Returns:
16 | masked_population (List[str]): List of masked sequence
17 | masked_poses (List[List[int]]): List of masked positions for each sequence.
18 | """
19 | raise NotImplementedError
20 |
--------------------------------------------------------------------------------
/de/samplers/maskers/importance.py:
--------------------------------------------------------------------------------
1 | import math
2 | import itertools
3 | from sklearn.feature_extraction.text import TfidfVectorizer
4 | from typing import Dict, List
5 | from .base import BaseMasker
6 | from ...common.utils import split_kmers2
7 |
8 |
9 | class ImportanceMasker2(BaseMasker):
10 | def __init__(self,
11 | k: int = 3,
12 | max_subs: int = 5,
13 |                  mask_token: str = "<mask>",
14 | low_importance_mask: bool = True):
15 | # TODO: mask by assigning weight by the importance?
16 | self.k = k
17 | self.max_subs = max_subs
18 | self.mask_token = mask_token
19 | self.low_importance_mask = low_importance_mask
20 | # calculate the importance
21 | self.importances = None
22 | # cache importance of kmer (as we do not alter every kmer)
23 | self.cache = None
24 | # TF-IDF does not filter out stand-alone amino acid.
25 | self.tfidf = TfidfVectorizer(lowercase=False, token_pattern=r"(?u)\b\w+\b")
26 | self.actual_vocabs = None
27 |
28 | def _measure_importance(self, sequences: List[List[str]]):
29 | """Inspired by paper
30 | `A Cheaper and Better Diffusion Language Model with Soft-Masked Noise`
31 | """
32 | merge_seqs = [' '.join(seq) for seq in sequences]
33 | # Run TF-IDF
34 | tfidfs = self.tfidf.fit_transform(merge_seqs)
35 | self.actual_vocabs = {
36 | name: idx for idx, name in enumerate(self.tfidf.get_feature_names_out())
37 | }
38 | # Get entropy
39 | kmer2entropy = self._get_entropy_of_unique_tokens(sequences)
40 |
41 | # Measure importance
42 | importances = []
43 | for seq_idx, seq in enumerate(sequences):
44 | kmer2imp = dict()
45 | setseq = list(set(seq))
46 | seq_tfidf = tfidfs[seq_idx].sum()
47 | seq_entropy = 0
48 | seq_tfidfs = []
49 | for kmer in setseq:
50 | # Temporary
51 | try:
52 | kmer_idx = self.actual_vocabs[kmer]
53 | except KeyError:
54 | self.actual_vocabs[kmer] = len(self.actual_vocabs)
55 | kmer_idx = self.actual_vocabs[kmer]
56 |
57 | tfidf = tfidfs[seq_idx, kmer_idx]
58 | seq_tfidfs.append(tfidf)
59 | seq_entropy += kmer2entropy[kmer]
60 |
61 | for kmer, tfidf in zip(setseq, seq_tfidfs):
62 | try:
63 | kmer2imp[kmer] = tfidf / seq_tfidf + kmer2entropy[kmer] / seq_entropy
64 | except ZeroDivisionError:
65 | kmer2imp[kmer] = tfidf / seq_tfidf
66 |
67 | importances.append(kmer2imp)
68 |
69 | return importances
70 |
71 | def _get_entropy_of_unique_tokens(self, seqs: List[List[str]]):
72 | bag_of_toks = list(itertools.chain.from_iterable(seqs))
73 | set_toks = set(bag_of_toks)
74 | count = {tok: bag_of_toks.count(tok) for tok in set_toks}
75 |
76 | entropy = {}
77 | for k, v in count.items():
78 | prob = v / len(bag_of_toks)
79 | entropy[k] = -1.0 * prob * math.log(prob)
80 |
81 | return entropy
82 |
83 | def mask_sequence(self,
84 | org_seq: str,
85 | kmer_seq: List[str],
86 | kmer2imp: Dict):
87 | """Mask sequence based on kmer's importance.
88 | Default is to mask kmers with low importances.
89 |
90 | Args:
91 | org_seq (str): Protein sequence.
92 | kmer_seq (List[str]): List of overlapping k-mers.
93 | kmer2imp (Dict): A dictionary map kmer with its importance in the sequence.
94 |
95 | Returns:
96 | seq (str): Masked protein sequence.
97 | pos_to_mutate (List[int]): Masked positions.
98 | """
99 | if self.k > 1:
100 | assert self.max_subs == 1, "Only substitute 1 k-mer at a time for k > 1."
101 |
102 | if self.low_importance_mask:
103 | sorted_kmers_by_imps = sorted(kmer2imp.items(), key=lambda x: x[1])
104 | else:
105 | sorted_kmers_by_imps = sorted(kmer2imp.items(), key=lambda x: x[1], reverse=True)
106 | sorted_kmers_by_imps = dict(sorted_kmers_by_imps)
107 |
108 | positions = []
109 | curr_idx, start_pos = 0, 0
110 | lseq = list(org_seq)
111 | for _ in range(self.max_subs):
112 | try:
113 | pos = kmer_seq.index(list(sorted_kmers_by_imps.keys())[curr_idx], start_pos)
114 | except ValueError:
115 | curr_idx += 1
116 | start_pos = 0
117 | pos = kmer_seq.index(list(sorted_kmers_by_imps.keys())[curr_idx], start_pos)
118 | finally:
119 | lseq[pos:pos + self.k] = [self.mask_token] * self.k
120 | positions.append(pos)
121 | start_pos = pos + 1
122 |
123 | if self.k == 1:
124 | return ''.join(lseq), positions
125 | else:
126 | return ''.join(lseq), list(range(positions[0], positions[0] + self.k))
127 |
128 | def run(self,
129 | population: List[str],
130 | indices: List[int] = None):
131 | kmer_population = split_kmers2(population, k=self.k)
132 | importances = self._measure_importance(kmer_population)
133 |
134 | masked_population = []
135 | masked_positions = []
136 | for kmer2imp, seq, pop in zip(importances, kmer_population, population):
137 | new_seq, masked_pos = self.mask_sequence(pop, seq, kmer2imp)
138 | masked_population.append(new_seq)
139 | masked_positions.append(masked_pos)
140 | return masked_population, masked_positions
141 |
--------------------------------------------------------------------------------
/de/samplers/maskers/random.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import List
3 | from .base import BaseMasker
4 |
5 |
6 | class RandomMasker2(BaseMasker):
7 |     def __init__(self, k: int = 1, max_subs: int = 5, mask_token: str = "<mask>"):
8 | self.k = k
9 | self.mask_token = mask_token
10 | self.max_subs = max_subs
11 |
12 | def mask_random_pos(self, seq: str):
13 | """Mask random positions in the protein sequence
14 |
15 | Args:
16 | seq (List[str]): Protein sequence.
17 |
18 | Returns:
19 | seq (str): Masked protein sequence.
20 | pos_to_mutate (List[int]): Masked positions.
21 | """
22 | if self.k > 1:
23 | assert self.max_subs == 1, "Only substitute 1 k-mer at a time for k > 1."
24 |
25 | lseq = list(seq)
26 | min_pos = 0
27 | max_pos = len(lseq) - self.k + 1
28 |
29 | candidate_masked_pos = list(range(min_pos, max_pos))
30 | random.shuffle(candidate_masked_pos)
31 | pos_to_mutate = candidate_masked_pos[:self.max_subs]
32 |
33 | for i in range(self.max_subs):
34 | pos = pos_to_mutate[i]
35 | lseq[pos:pos + self.k] = [self.mask_token] * self.k
36 |
37 | if self.k == 1:
38 | return ''.join(lseq), pos_to_mutate
39 | else:
40 | return ''.join(lseq), list(range(pos_to_mutate[0], pos_to_mutate[0] + self.k))
41 |
42 | def run(self,
43 | population: List[str],
44 | indices: List[int] = None):
45 | masked_population = []
46 | masked_positions = []
47 | for seq in population:
48 | new_seq, masked_pos = self.mask_random_pos(seq)
49 | masked_population.append(new_seq)
50 | masked_positions.append(masked_pos)
51 | return masked_population, masked_positions
52 |
--------------------------------------------------------------------------------
/de/samplers/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HySonLab/Directed_Evolution/a9ef68497dd722cfc933eb76dd8a5a75424d97f9/de/samplers/models/__init__.py
--------------------------------------------------------------------------------
/de/samplers/models/esm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoTokenizer, EsmForMaskedLM, BatchEncoding
3 | from typing import List
4 |
5 |
6 | class ESM2(torch.nn.Module):
7 | def __init__(self, pretrained_model_name_or_path: str = "facebook/esm2_t12_35M_UR50D"):
8 | """
9 | Args:
10 | pretrained_model_name_or_path (str): Pre-trained model to load.
11 | """
12 | super(ESM2, self).__init__()
13 | assert pretrained_model_name_or_path is not None
14 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
15 | self.model = EsmForMaskedLM.from_pretrained(pretrained_model_name_or_path)
16 |
17 | def tokenize(self, inputs: List[str]) -> BatchEncoding:
18 | """Convert inputs to a format suitable for the model.
19 |
20 | Args:
21 | inputs (List[str]): A list of protein sequence strings of len [population].
22 |
23 | Returns:
24 | encoded_inputs (BatchEncoding): a BatchEncoding object.
25 | """
26 | encoded_inputs = self.tokenizer(inputs,
27 | add_special_tokens=True,
28 | return_tensors="pt",
29 | padding=True)
30 | return encoded_inputs
31 |
32 | def decode(self, tokens: torch.Tensor) -> List[str]:
33 | """Decode predicted tokens into alphabet characters
34 |
35 | Args:
36 | tokens (torch.Tensor): Predicted tokens of shape [batch, sequence_length]
37 |
38 | Returns:
39 | (List[str]): Predicted characters.
40 | """
41 | return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
42 |
43 |     def forward(self, inputs: BatchEncoding):
44 | """Forward pass of ESM2 model
45 |
46 | Args:
47 | inputs (BatchEncoding): Output of tokenizer.
48 |
49 | Returns:
50 |             results (MaskedLMOutput): Full model output containing logits and hidden states.
51 | """
52 | results = self.model(output_hidden_states=True, **inputs)
53 | return results
54 |
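
A short sketch of the tokenize → forward → decode round trip this wrapper exposes (sequences are toy examples; the checkpoint name is the class default):

```python
import torch
from de.samplers.models.esm import ESM2

model = ESM2("facebook/esm2_t12_35M_UR50D")
model.eval()

batch = model.tokenize(["MKTAYIAKQR", "MKTAYIAKQL"])
with torch.no_grad():
    out = model(batch)                    # masked-LM output with logits and hidden states
pred_tokens = out.logits.argmax(dim=-1)   # [batch, seq_len] token ids
print(model.decode(pred_tokens))          # predicted residues, e.g. ['M K T A Y I A K Q R', ...]
```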
--------------------------------------------------------------------------------
/de/version.py:
--------------------------------------------------------------------------------
1 | """This file is auto-generated by setup.py, please do not alter."""
2 | __version__ = "1.0.0"
3 |
--------------------------------------------------------------------------------
/exps/checkpoints/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HySonLab/Directed_Evolution/a9ef68497dd722cfc933eb76dd8a5a75424d97f9/exps/checkpoints/.gitkeep
--------------------------------------------------------------------------------
/exps/logs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HySonLab/Directed_Evolution/a9ef68497dd722cfc933eb76dd8a5a75424d97f9/exps/logs/.gitkeep
--------------------------------------------------------------------------------
/exps/results/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HySonLab/Directed_Evolution/a9ef68497dd722cfc933eb76dd8a5a75424d97f9/exps/results/.gitkeep
--------------------------------------------------------------------------------
/preprocessed_data/AAV/AAV_reference_sequence.txt:
--------------------------------------------------------------------------------
1 | DEEEIRTTNPVATEQYGSVSTNLQRGNR
2 | -2.731
--------------------------------------------------------------------------------
/preprocessed_data/AMIE/AMIE_reference_sequence.txt:
--------------------------------------------------------------------------------
1 | MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHRFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEG
2 | -2.789
--------------------------------------------------------------------------------
/preprocessed_data/E4B/E4B_reference_sequence.txt:
--------------------------------------------------------------------------------
1 | IEKFKLLAEKVEEIVAKNARAEIDYSDAPDEFRDPLMDTLMTDPVRLPSGVTMDRSIILRHLLNSPTDPFNRQMLTESMLEPVPELKEQIQAWMREKQSSDH
2 | 0.774
--------------------------------------------------------------------------------
/preprocessed_data/LGK/LGK_reference_sequence.txt:
--------------------------------------------------------------------------------
1 | MPIATSTGDNVLDFTVLGLNSGTSMDGIDCALCHFYQKTPDAPMEFELLEYGEVPLAQPIKQRVMRMILEDTTSPSELSEVNVILGEHFADAVRQPAAERNVDLSTIDAIASHGQTIWLLSMPEEGQVKSALTMAEGAIIAARTGITSITDFRISDQAAGRQGAPLIAFFDALLLHHPTKLRACQNIGGIANVCFIPPDVDGRRTDEYYDFDTGPGNVFIDAVVRHFTNGEQEYDKDGAMGKRGKVDQELVDDFLKMPYFQLDPPKTTGREVFRDTLAHDLIRRAEAKGLSPDDIVATTTRITAQAIVDHYRRYAPSQEIDEIFMCGGGAYNPNIVEFIQQSYPNTKIMMLDEAGVPAGAKEAITFAWQGMECLVGRSIPVPTRVETRQHYVLGKVSPGLNYRSVMKKGMAFGGDAQQLPWVSEMIVKKKGKVITNNWA
2 | -1.260
--------------------------------------------------------------------------------
/preprocessed_data/Pab1/Pab1_reference_sequence.txt:
--------------------------------------------------------------------------------
1 | GNIFIKNLHPDIDNKALYDTFSVFGDILSSKIAPDENGKSKGFGFVPFEEEGAAKEAIDALNGMLLNGQEIYVAP
2 | 0.014
--------------------------------------------------------------------------------
/preprocessed_data/TEM/TEM_reference_sequence.txt:
--------------------------------------------------------------------------------
1 | MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW
2 | 1.084
--------------------------------------------------------------------------------
/preprocessed_data/UBE2I/UBE2I_reference_sequence.txt:
--------------------------------------------------------------------------------
1 | MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPSY
2 | 0.766
--------------------------------------------------------------------------------
/preprocessed_data/avGFP/avGFP_reference_sequence.txt:
--------------------------------------------------------------------------------
1 | SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK
2 | 3.677
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.1
2 | lightning==2.0.8
3 | transformers[torch]==4.31.0
4 | scikit-learn==1.3.0
5 | scipy==1.11.2
6 | biopython
7 | polyleven
8 | wandb
9 | matplotlib
10 | seaborn
11 | tqdm
--------------------------------------------------------------------------------
/scripts/preprocess/preprocess_AAV.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 | import pandas as pd
4 | import os
5 |
6 |
7 | def get_aa_sequence(filepath):
8 | with open(filepath, "r") as f:
9 | seq = f.readlines()[0].strip()
10 | return seq
11 |
12 |
13 | def generate_data(data_file):
14 | df = pd.read_csv(data_file)
15 |
16 | # preprocess data
17 | df.replace([np.inf, -np.inf], np.nan, inplace=True)
18 | df.dropna(inplace=True)
19 | sequences = [seq.upper() for seq in df["sequence"].to_list()]
20 | fitnesses = df["viral_selection"].to_list()
21 |
22 | return {"sequence": sequences, "fitness": fitnesses}
23 |
24 |
25 | if __name__ == "__main__":
26 | # Files
27 | data_dir = sys.argv[1]
28 | seq_file = os.path.join(data_dir, "AAV_reference_sequence.txt")
29 | data_file = os.path.join(data_dir, "allseqs_20191230.csv")
30 | out_file = os.path.join(data_dir, "AAV.csv")
31 |
32 | # Generate data
33 | seq2fit = generate_data(data_file)
34 | mut_df = pd.DataFrame.from_dict(seq2fit)
35 | # Drop duplications
36 | mut_df.drop_duplicates(subset="sequence", inplace=True, ignore_index=True)
37 | mut_df.to_csv(out_file, index=False)
38 |
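
A sketch invocation, assuming the raw AAV table `allseqs_20191230.csv` has been placed in the data directory the script is pointed at:

```shell
python scripts/preprocess/preprocess_AAV.py preprocessed_data/AAV
# writes preprocessed_data/AAV/AAV.csv with `sequence` and `fitness` columns
```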
--------------------------------------------------------------------------------
/scripts/preprocess/preprocess_AMIE.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pandas as pd
3 | import os
4 | from de.common.utils import get_mutated_sequence
5 |
6 |
7 | def get_substrate(substrate_type):
8 |     if substrate_type is None or substrate_type == "A":
9 |         return "Acetamide"
10 |     elif substrate_type == "I":
11 |         return "Isobutyramide"
12 |     elif substrate_type == "P":
13 |         return "Propionamide"
14 |     else:
15 |         raise ValueError(f"Substrate type {substrate_type} is not supported. Choices are 'A', 'I', and 'P'")
16 |
17 |
18 | def get_aa_sequence(filepath):
19 | with open(filepath, "r") as f:
20 | seq = f.readlines()[0].strip()
21 | return seq
22 |
23 |
24 | def generate_data(wt_seq, data_file):
25 | df = pd.read_csv(data_file, sep="\t")
26 | # preprocess df
27 | df = df[df["mutation"] != "*"]
28 | df = df[df["normalized_fitness"] != "NS"]
29 |
30 | sequences = []
31 | fitnesses = []
32 |
33 | for i in range(len(df)):
34 | # get vars
35 | loc = df["location"].iloc[i]
36 | wt_aa = wt_seq[loc - 1]
37 | new_aa = df["mutation"].iloc[i]
38 | fitness = float(df["normalized_fitness"].iloc[i])
39 |
40 | mut = wt_aa + str(loc) + new_aa
41 | mut_seq = get_mutated_sequence(wt_seq, mut)
42 |
43 | sequences.append(mut_seq)
44 | fitnesses.append(fitness)
45 |
46 | return {"sequence": sequences, "fitness": fitnesses}
47 |
48 |
49 | if __name__ == "__main__":
50 | # Files
51 | data_dir = sys.argv[1]
52 | substrate = get_substrate(sys.argv[2])
53 | seq_file = os.path.join(data_dir, "amiE_reference_sequence.txt")
54 | data_file = os.path.join(data_dir, f"amiESelectionFitnessData_{substrate}.txt")
55 | out_file = os.path.join(data_dir, f"amiE_{substrate}.csv")
56 |
57 | # Get protein sequence
58 | wt_seq = get_aa_sequence(seq_file)
59 |
60 | # Generate data
61 | seq2fit = generate_data(wt_seq, data_file)
62 | mut_df = pd.DataFrame.from_dict(seq2fit)
63 | # Drop duplications
64 | mut_df.drop_duplicates(subset="sequence", inplace=True, ignore_index=True)
65 | mut_df.to_csv(out_file, index=False)
66 |
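
The second positional argument selects the substrate ('A' = Acetamide, 'I' = Isobutyramide, 'P' = Propionamide). A sketch invocation, assuming the raw amiE fitness table is in the data directory:

```shell
python scripts/preprocess/preprocess_AMIE.py preprocessed_data/AMIE A
# writes preprocessed_data/AMIE/amiE_Acetamide.csv
```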
--------------------------------------------------------------------------------
/scripts/preprocess/preprocess_E4B.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pandas as pd
3 | import os
4 | from de.common.utils import get_mutated_sequence
5 |
6 |
7 | def get_aa_sequence(filepath):
8 | with open(filepath, "r") as f:
9 | seq = f.readlines()[0].strip()
10 | return seq
11 |
12 |
13 | def generate_data(wt_seq, data_file):
14 | df = pd.read_csv(data_file, sep="\t")
15 | df.dropna(subset="log2_ratio", inplace=True)
16 |
17 | sequences = []
18 | fitnesses = []
19 |
20 | def convert2mutant(mutations):
21 | context = mutations.split("-")
22 | locs = [int(loc) for loc in context[0].split(",")]
23 | aas = context[1].split(",")
24 | if "*" in aas:
25 | return None
26 | mutants = ""
27 | for loc, aa in zip(locs, aas):
28 | mutants = mutants + f"{wt_seq[loc]}{loc + 1}{aa}" + ":"
29 | return mutants[:-1]
30 |
31 | for i in range(len(df)):
32 | mutations = df["seqID"].iloc[i]
33 | mutant = convert2mutant(mutations)
34 | if mutant is None:
35 | continue
36 | seq = get_mutated_sequence(wt_seq, mutant)
37 | sequences.append(seq)
38 | fitnesses.append(df["log2_ratio"].iloc[i])
39 |
40 | return {"sequence": sequences, "fitness": fitnesses}
41 |
42 |
43 | if __name__ == "__main__":
44 | # Files
45 | data_dir = sys.argv[1]
46 | seq_file = os.path.join(data_dir, "E4B_reference_sequence.txt")
47 | data_file = os.path.join(data_dir, "1303309110_sd01.tsv")
48 | out_file = os.path.join(data_dir, "E4B.csv")
49 |
50 | # Get protein sequence
51 | wt_seq = get_aa_sequence(seq_file)
52 |
53 | # Generate data
54 | seq2fit = generate_data(wt_seq, data_file)
55 | mut_df = pd.DataFrame.from_dict(seq2fit)
56 | # Drop duplications
57 | mut_df.drop_duplicates(subset="sequence", inplace=True, ignore_index=True)
58 | mut_df.to_csv(out_file, index=False)
59 |
--------------------------------------------------------------------------------
/scripts/preprocess/preprocess_LGK.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pandas as pd
3 | import os
4 | from de.predictors.tranception.utils.scoring_utils import get_mutated_sequence
5 |
6 |
7 | def get_aa_sequence(filepath):
8 | with open(filepath, "r") as f:
9 | seq = f.readlines()[0].strip()
10 | return seq
11 |
12 |
13 | def generate_data(wt_seq, data_file):
14 | df = pd.read_csv(data_file)
15 | # preprocess df
16 | df = df[df["Mutation"] != "*"]
17 | df = df[df["Normalized_ER"] != "NS"]
18 |
19 | sequences = []
20 | fitnesses = []
21 |
22 | for i in range(len(df)):
23 | # get vars
24 | loc = df["Location"].iloc[i]
25 | wt_aa = wt_seq[loc]
26 | new_aa = df["Mutation"].iloc[i]
27 | fitness = float(df["Normalized_ER"].iloc[i])
28 |
29 | mut = wt_aa + str(loc + 1) + new_aa
30 | mut_seq = get_mutated_sequence(wt_seq, mut)
31 |
32 | sequences.append(mut_seq)
33 | fitnesses.append(fitness)
34 |
35 | return {"sequence": sequences, "fitness": fitnesses}
36 |
37 |
38 | if __name__ == "__main__":
39 | # Files
40 | data_dir = sys.argv[1]
41 | seq_file = os.path.join(data_dir, "LGK_reference_sequence.txt")
42 | data_file = os.path.join(data_dir, "raw.csv")
43 | out_file = os.path.join(data_dir, "LGK.csv")
44 |
45 | # Get protein sequence
46 | wt_seq = get_aa_sequence(seq_file)
47 |
48 | # Generate data
49 | seq2fit = generate_data(wt_seq, data_file)
50 | mut_df = pd.DataFrame.from_dict(seq2fit)
51 | # Drop duplications
52 | mut_df.drop_duplicates(subset="sequence", inplace=True, ignore_index=True)
53 | mut_df.to_csv(out_file, index=False)
54 |
--------------------------------------------------------------------------------
/scripts/preprocess/preprocess_Pab1.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pandas as pd
3 | import os
4 | from de.predictors.tranception.utils.scoring_utils import get_mutated_sequence
5 |
6 |
7 | def get_aa_sequence(filepath):
8 | with open(filepath, "r") as f:
9 | seq = f.readlines()[0].strip()
10 | return seq
11 |
12 |
13 | def generate_data(wt_seq, data_file):
14 | xlsx = pd.ExcelFile(data_file)
15 | df = pd.read_excel(xlsx, "All_Epistasis")
16 |
17 | sequences = []
18 | fitnesses = []
19 |
20 | def convert2mutant(mutations):
21 | context = mutations.split("-")
22 | locs = [int(loc) - 126 for loc in context[0].split(",")]
23 | aas = context[1].split(",")
24 | if "*" in aas:
25 | return None
26 | mutants = ""
27 | for loc, aa in zip(locs, aas):
28 | mutants = mutants + f"{wt_seq[loc]}{loc + 1}{aa}" + ":"
29 | return mutants[:-1]
30 |
31 | for i in range(len(df)):
32 | mutations = df["seqID_XY"].iloc[i]
33 | mutant = convert2mutant(mutations)
34 | if mutant is None:
35 | continue
36 | seq = get_mutated_sequence(wt_seq, mutant)
37 | sequences.append(seq)
38 | fitnesses.append(df["Epistasis_score"].iloc[i])
39 |
40 | return {"sequence": sequences, "fitness": fitnesses}
41 |
42 |
43 | if __name__ == "__main__":
44 | # Files
45 | data_dir = sys.argv[1]
46 | seq_file = os.path.join(data_dir, "Pab1_reference_sequence.txt")
47 | data_file = os.path.join(data_dir, "Supplementary_Table_5.xlsx")
48 | out_file = os.path.join(data_dir, "Pab1.csv")
49 |
50 | # Get protein sequence
51 | wt_seq = get_aa_sequence(seq_file)
52 |
53 | # Generate data
54 | seq2fit = generate_data(wt_seq, data_file)
55 | mut_df = pd.DataFrame.from_dict(seq2fit)
56 | # Drop duplications
57 | mut_df.drop_duplicates(subset="sequence", inplace=True, ignore_index=True)
58 | mut_df.to_csv(out_file, index=False)
59 |
--------------------------------------------------------------------------------
/scripts/preprocess/preprocess_TEM.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 | import pandas as pd
4 | import os
5 | from de.common.utils import get_mutated_sequence
6 |
7 |
8 | def get_aa_sequence(filepath):
9 | with open(filepath, "r") as f:
10 | seq = f.readlines()[0].strip()
11 | return seq
12 |
13 |
14 | def generate_data(wt_seq, data_file):
15 | df = pd.read_csv(data_file, sep="\t")
16 | # df.dropna(inplace=True, ignore_index=True)
17 | df["real_loc"] = df["location"].apply(lambda x: x - 3)
18 |
19 | sequences = []
20 | fitnesses = []
21 |
22 | for i in range(len(df)):
23 | # get vars
24 | loc = df["real_loc"].iloc[i]
25 | wt_aa = df["wt_aa"].iloc[i]
26 | new_aa = df["new_aa"].iloc[i]
27 | fitness = df["fitness"].iloc[i]
28 |
29 | if np.isnan(fitness):
30 | continue
31 |
32 | if wt_seq[loc] != wt_aa:
33 | print(f"i = {i}")
34 | print(f"loc = {loc}")
35 | print(df.iloc[i])
36 | print(f"wt_seq[{loc}] = {wt_seq[loc]}")
37 | print(f"wt_aa = {wt_aa}")
38 | raise ValueError(f"Position {loc + 1} of WT sequence is {wt_seq[loc]}, not {wt_aa}")
39 | mut = wt_aa + str(loc + 1) + new_aa
40 | mut_seq = get_mutated_sequence(wt_seq, mut)
41 |
42 | sequences.append(mut_seq)
43 | fitnesses.append(fitness)
44 |
45 | return {"sequence": sequences, "fitness": fitnesses}
46 |
47 |
48 | if __name__ == "__main__":
49 | # Files
50 | data_dir = sys.argv[1]
51 | seq_file = os.path.join(data_dir, "TEM_reference_sequence.txt")
52 | data_file = os.path.join(data_dir, "TEM_mutation.tsv")
53 | out_file = os.path.join(data_dir, "TEM.csv")
54 |
55 | # Get protein sequence
56 | wt_seq = get_aa_sequence(seq_file)
57 |
58 | # Generate data
59 | seq2fit = generate_data(wt_seq, data_file)
60 | mut_df = pd.DataFrame.from_dict(seq2fit)
61 | # Drop duplications
62 | mut_df.drop_duplicates(subset="sequence", inplace=True, ignore_index=True)
63 | mut_df.to_csv(out_file, index=False)
64 |
--------------------------------------------------------------------------------
/scripts/preprocess/preprocess_UBE2I.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 | import pandas as pd
4 | import os
5 | from de.common.utils import get_mutated_sequence
6 |
7 |
8 | def get_aa_sequence(filepath):
9 | with open(filepath, "r") as f:
10 | seq = f.readlines()[0].strip()
11 | return seq
12 |
13 |
14 | def generate_data(wt_seq, data_file1, data_file2):
15 | df1 = pd.read_csv(data_file1)
16 | df2 = pd.read_csv(data_file2)
17 |
18 | sequences = []
19 | fitnesses = []
20 |
21 | for i in range(len(df1)):
22 | mut = df1["mut"].iloc[i]
23 | mut_seq = get_mutated_sequence(wt_seq, mut)
24 | if np.isnan(df1["screen.score"].iloc[i]) and np.isnan(df2["screen.score"].iloc[i]):
25 | fitness = (df1["joint.score"].iloc[i] + df2["joint.score"].iloc[2]) / 2
26 | else:
27 | fitness = df1["screen.score"].iloc[i] or df2["screen.score"].iloc[i]
28 |
29 | sequences.append(mut_seq)
30 | fitnesses.append(fitness)
31 |
32 | return {"sequence": sequences, "fitness": fitnesses}
33 |
34 |
35 | if __name__ == "__main__":
36 | # Files
37 | data_dir = sys.argv[1]
38 | seq_file = os.path.join(data_dir, "UBE2I_reference_sequence.txt")
39 | data_file1 = os.path.join(data_dir, "UBE2I_scores.csv")
40 | data_file2 = os.path.join(data_dir, "UBE2I_flipped_scores.csv")
41 | out_file = os.path.join(data_dir, "UBE2I.csv")
42 |
43 | # Get protein sequence
44 | wt_seq = get_aa_sequence(seq_file)
45 |
46 | # Generate data
47 | seq2fit = generate_data(wt_seq, data_file1, data_file2)
48 | mut_df = pd.DataFrame.from_dict(seq2fit)
49 | # Drop duplications
50 | mut_df.drop_duplicates(subset="sequence", inplace=True, ignore_index=True)
51 | mut_df.to_csv(out_file, index=False)
52 |
--------------------------------------------------------------------------------
/scripts/preprocess/preprocess_avGFP.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import sys
3 | import os
4 | from Bio.Seq import translate
5 |
6 |
7 | def get_aa_sequence(filepath: str):
8 | with open(filepath, "r") as f:
9 | content = f.readlines()
10 | dna_seq = content[-1]
11 | prot_seq = translate(dna_seq, to_stop=True)
12 | return prot_seq
13 |
14 |
15 | def mutant2seq(wt_seq: str, mutant: str):
16 | if mutant == "":
17 | return wt_seq
18 | elif "*" in mutant:
19 | return None
20 | else:
21 | seq = list(wt_seq)
22 | muts = mutant.split(":")
23 | for mut in muts:
24 | aa_org, pos, aa_new = mut[1], int(mut[2:-1]), mut[-1]
25 | if aa_org != wt_seq[pos]:
26 | raise ValueError(f"{aa_org} is different from wt_seq[{pos}].")
27 | seq[pos] = aa_new
28 |
29 | return "".join(seq)
30 |
31 |
32 | def generate_data(wt_seq: str, df: pd.DataFrame):
33 | df["aaMutations"].fillna("", inplace=True)
34 | mutants = df["aaMutations"].tolist()
35 | fitness = df["medianBrightness"].tolist()
36 | variants = []
37 | fitnesses = []
38 | for mut, fit in zip(mutants, fitness):
39 | variant = mutant2seq(wt_seq, mut)
40 | if variant is not None:
41 | variants.append(variant)
42 | fitnesses.append(fit)
43 |
44 | return {"sequence": variants, "fitness": fitnesses}
45 |
46 |
47 | if __name__ == "__main__":
48 | # Files
49 | data_dir = sys.argv[1]
50 | seq_file = os.path.join(data_dir, "avGFP_reference_sequence.fa")
51 | data_file = os.path.join(data_dir, "amino_acid_genotypes_to_brightness.tsv")
52 | out_file = os.path.join(data_dir, "avGFP.csv")
53 |
54 | # Convert DNA to protein sequence
55 | wt_seq = get_aa_sequence(seq_file)
56 |
57 | # Generate data
58 | df = pd.read_csv(data_file, sep="\t")
59 | seq2fit = generate_data(wt_seq, df)
60 | mut_df = pd.DataFrame.from_dict(seq2fit)
61 | mut_df.to_csv(out_file, index=False)
62 |
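
A sketch invocation, assuming the DNA reference (`avGFP_reference_sequence.fa`) and the brightness table (`amino_acid_genotypes_to_brightness.tsv`) are in the data directory:

```shell
python scripts/preprocess/preprocess_avGFP.py preprocessed_data/avGFP
# writes preprocessed_data/avGFP/avGFP.csv with `sequence` and `fitness` columns
```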
--------------------------------------------------------------------------------
/scripts/run_de.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Parse command line arguments or set default values
4 | dataset="$1"
5 | n_steps="${2}"
6 | seed="${3:-0}"
7 | devices="${4:-0}"
8 | k="${5:-1}"
9 | num_proposes_per_var="${6:-4}"
10 | num_chunk="${7:-1}"
11 | ckpt_path=$8
12 | population=128
13 | num_toks=1
14 | pretrained_mutation_name="facebook/esm2_t12_35M_UR50D"
15 | model_name="esm2-35M"
16 |
17 | python scripts/run_discrete_de.py --task "$dataset" --n_steps "$((n_steps))" --population "$((population))" \
18 | --num_proposes_per_var "$((num_proposes_per_var))" --seed "$seed" --rm_dups \
19 | --save_name results_${dataset}_model=${model_name}_steps${n_steps}_pop${population}_pros${num_proposes_per_var}_seed${seed}_k${k}_num${num_toks}_imp.csv \
20 | --k "$((k))" --num_masked_tokens "$((num_toks))" --verbose --devices "$devices" --predictor_ckpt_path "$ckpt_path" \
21 | --population_ratio_per_mask 0.1 0.9 --pretrained_mutation_name "$pretrained_mutation_name"
--------------------------------------------------------------------------------
/scripts/run_discrete_de.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import os
4 | import pandas as pd
5 | import torch
6 | from typing import List, Union, Tuple
7 | from de.common.utils import set_seed, enable_full_deterministic
8 | from de.directed_evolution import DiscreteDirectedEvolution2
9 | from de.samplers.maskers import RandomMasker2, ImportanceMasker2
10 | from de.samplers.models.esm import ESM2
11 | from de.predictors.attention.module import ESM2DecoderModule, ESM2_Attention
12 | from de.predictors.oracle import ESM1b_Landscape, ESM1v
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--task",
18 | type=str,
19 | choices=["AAV", "avGFP", "TEM", "E4B", "UBE2I", "LGK", "Pab1", "AMIE"],
20 | help="Benchmark task.")
21 | parser.add_argument("--n_steps",
22 | type=int,
23 | default=100,
24 | help="No. steps to run directed evolution.")
25 | parser.add_argument("--population",
26 | type=int,
27 | default=128,
28 | help="No. population per step.")
29 | parser.add_argument("--num_proposes_per_var",
30 | type=int,
31 | default=4,
32 | help="Number of proposed mutations for each variant in the pool.")
33 | parser.add_argument("--k",
34 | type=int,
35 | default=1,
36 | help="Split sequence into multiple tokens with length `k`.")
37 | parser.add_argument("--rm_dups",
38 | action="store_true",
39 | help="Whether to remove duplications in the proposed candidate pool.")
40 | parser.add_argument("--population_ratio_per_mask",
41 | nargs="+",
42 | type=float,
43 | help="Population ratio to run per masker.")
44 | parser.add_argument("--pretrained_mutation_name",
45 | type=str,
46 | default="facebook/esm2_t12_35M_UR50D",
47 | help="Pretrained model name or path for mutation checkpoint.")
48 | parser.add_argument("--dec_hidden_size",
49 | type=int,
50 | default=512,
51 | help="Decoder hidden size (for conditional task).")
52 | parser.add_argument("--predictor_ckpt_path",
53 | type=str,
54 | help="Path to fitness predictor checkpoints.")
55 | parser.add_argument("--num_masked_tokens",
56 | type=int,
57 | default=1,
58 | help="No. masked tokens to predict.")
59 | parser.add_argument("--mask_high_importance",
60 | action="store_true",
61 | help="Whether to mask high-importance token in the sequence.")
62 | parser.add_argument("--verbose",
63 | action="store_true",
64 | help="Whether to display output.")
65 | parser.add_argument("--seed",
66 | type=int,
67 | default=0,
68 | help="Random seed.")
69 | parser.add_argument("--set_seed_only",
70 | action="store_true",
71 | help="Whether to enable full determinism or set random seed only.")
72 | parser.add_argument("--result_dir",
73 | type=str,
74 | default=os.path.abspath("./exps/results"),
75 | help="Directory to save result csv file.")
76 | parser.add_argument("--log_dir",
77 | type=str,
78 | default=os.path.abspath("./exps/logs"),
79 | help="Directory to save logfile")
80 | parser.add_argument("--save_name",
81 | type=str,
82 | help="Filename of the result csv file.")
83 | parser.add_argument("--devices",
84 | type=str,
85 | default="-1",
86 | help="Devices, separated by commas.")
87 | parser.add_argument("--esm1v_seed",
88 | type=int,
89 | choices=[1, 2, 3, 4, 5])
90 | parser.add_argument("--predictor_ckpt_path", type=str)
91 | args = parser.parse_args()
92 | return args
93 |
94 |
95 | def extract_from_csv(csv_file: str, top_k: int = -1) -> Tuple[List[str], np.ndarray]:
96 | df = pd.read_csv(csv_file)
97 | if top_k != -1:
98 | df = df.nlargest(top_k, columns="fitness")
99 | targets = df["fitness"].to_list()
100 | seqs = df.sequence.tolist()
101 | return seqs, targets
102 |
103 |
104 | def initialize_mutation_model(args, device: torch.device):
105 | model = ESM2(pretrained_model_name_or_path=args.pretrained_mutation_name)
106 | tokenizer = model.tokenizer
107 | model.to(device)
108 | model.eval()
109 | return model, tokenizer
110 |
111 |
112 | def initialize_maskers(args):
113 | imp_masker = ImportanceMasker2(args.k,
114 | max_subs=args.num_masked_tokens,
115 | low_importance_mask=not args.mask_high_importance)
116 | rand_masker = RandomMasker2(args.k, max_subs=args.num_masked_tokens)
117 |
118 | return [rand_masker, imp_masker]
119 |
120 |
121 | def initialize_oracle(args, device: Union[str, torch.device]):
122 | landscape = ESM1b_Landscape(args.task, device)
123 | return landscape
124 |
125 |
126 | def initialize_oracle2(args, device):
127 | model = ESM1v(f"esm1v_t33_650M_UR90S_{args.esm1v_seed}", device, "masked", 1)
128 | return model
129 |
130 |
131 | def initialize_fitness_predictor(args, device: Union[str, torch.device]):
132 | tmp_name = "facebook/esm2_t33_650M_UR50D"
133 | # decoder = ESM2_Attention(args.pretrained_mutation_name, hidden_dim=args.dec_hidden_size)
134 | decoder = ESM2_Attention(tmp_name, hidden_dim=args.dec_hidden_size)
135 | model = ESM2DecoderModule.load_from_checkpoint(
136 | args.predictor_ckpt_path, map_location=device, net=decoder
137 | )
138 | model.eval()
139 |
140 | return model
141 |
142 |
143 | def save_results(wt_seqs: List[str], mutants, score, valid_score, output_path: str):
144 | output_dir = os.path.dirname(output_path)
145 | os.makedirs(output_dir, exist_ok=True)
146 | df = pd.DataFrame.from_dict({"WT": wt_seqs,
147 | "mutants": mutants,
148 | "score": score,
149 | "orc. score": valid_score})
150 | df.sort_values(by=["orc. score"], ascending=False, inplace=True, ignore_index=True)
151 | df.to_csv(output_path, index=False)
152 |
153 |
154 | def main(args):
155 | # Init env stuffs
156 | set_seed(args.seed) if args.set_seed_only else enable_full_deterministic(args.seed)
157 | os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = 'true'
158 | device = torch.device("cpu" if args.devices == "-1" else f"cuda:{args.devices}")
159 |
160 | # Init models
161 | mutation_model, mutation_tokenizer = initialize_mutation_model(args, device)
162 | fitness_predictor = initialize_fitness_predictor(args, device)
163 | # Init oracle
164 | oracle = initialize_oracle(args, device)
165 | # oracle2 = initialize_oracle2(args, device)
166 | # Init masker
167 | maskers = initialize_maskers(args)
168 | # Create folder
169 | result_dir = os.path.join(args.result_dir, args.task)
170 | log_dir = os.path.join(args.log_dir, args.task)
171 | os.makedirs(result_dir, exist_ok=True)
172 | os.makedirs(log_dir, exist_ok=True)
173 |
174 | # Init procedure
175 | direct_evo = DiscreteDirectedEvolution2(
176 | n_steps=args.n_steps,
177 | population=args.population,
178 | maskers=maskers,
179 | mutation_model=mutation_model,
180 | mutation_tokenizer=mutation_tokenizer,
181 | fitness_predictor=fitness_predictor,
182 | remove_duplications=args.rm_dups,
183 | k=args.k,
184 | population_ratio_per_mask=args.population_ratio_per_mask,
185 | num_propose_mutation_per_variant=args.num_proposes_per_var,
186 | verbose=args.verbose,
187 | mutation_device=device,
188 | log_dir=log_dir,
189 | seed=args.seed,
190 | )
191 |
192 | lines = open(f"./preprocessed_data/{args.task}/{args.task}_reference_sequence.txt").readlines()
193 | wt_seq, wt_fitness = lines[0].strip(), float(lines[1].strip())
194 | mutants, pred_fitness, variants = direct_evo(wt_seq, wt_fitness)
195 | pred_fitness = pred_fitness.squeeze(1).numpy().tolist()
196 |
197 | valid_fitness = oracle.infer_fitness(variants)
198 |
199 | filepath = os.path.join(result_dir, args.save_name)
200 | save_results([wt_seq] * len(mutants), mutants, pred_fitness, valid_fitness, filepath)
201 |
202 |
203 | if __name__ == "__main__":
204 | args = parse_args()
205 | main(args)
206 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | dataset=$1
2 | devices=$2
3 | batch_size=${3:-128}
4 | ckpt_path=${4:-''}
5 |
6 | data_file="preprocessed_data/${dataset}/${dataset}.csv"
7 | pretrained_encoder="facebook/esm2_t12_35M_UR50D"
8 | dec_hidden_dim=1280
9 | lr=0.0002
10 | num_epochs=100
11 | num_ckpts=3
12 | precision="highest"
13 |
14 | python scripts/train_decoder.py --data_file $data_file --dataset_name $dataset \
15 | --pretrained_encoder $pretrained_encoder --dec_hidden_dim $dec_hidden_dim \
16 | --batch_size $batch_size --devices $devices \
17 | --lr $lr --num_epochs $num_epochs --num_ckpts $num_ckpts \
18 | --precision $precision #--ckpt_path=$ckpt_path
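
A sketch invocation from the repository root (dataset name, device id, and optional batch size; the resume-from-checkpoint argument is currently commented out in the python call):

```shell
bash scripts/train.sh AAV 0 256
```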
--------------------------------------------------------------------------------
/scripts/train_decoder.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from functools import partial
5 | from lightning import Trainer, seed_everything
6 | from lightning.pytorch import loggers, callbacks
7 | from torch.optim import Adam
8 | from de.dataio.proteins import ProteinsDataModule
9 | from de.predictors.attention.module import ESM2_Attention, ESM2DecoderModule
10 |
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser(description="Train decoder.")
14 | parser.add_argument("--data_file",
15 | type=str,
16 | help="Path to data directory.")
17 | parser.add_argument("--dataset_name",
18 | type=str,
19 | help="Name of trained dataset.")
20 | parser.add_argument("--pretrained_encoder",
21 | type=str,
22 | default="facebook/esm2_t12_35M_UR50D",
23 | help="Path to pretrained encoder.")
24 | parser.add_argument("--dec_hidden_dim",
25 | type=int,
26 | default=1280,
27 | help="Hidden dim of decoder.")
28 | parser.add_argument("--batch_size",
29 | type=int,
30 | default=128,
31 | help="Batch size.")
32 | parser.add_argument("--ckpt_path",
33 | type=str,
34 | help="Checkpoint of model.")
35 | parser.add_argument("--devices",
36 | type=str,
37 | default="-1",
38 | help="Training devices separated by comma.")
39 | parser.add_argument("--output_dir",
40 | type=str,
41 | default="./exps",
42 | help="Path to output directory.")
43 | parser.add_argument("--grad_accum_steps",
44 | type=int,
45 | default=1,
46 | help="No. updates steps to accumulate the gradient.")
47 | parser.add_argument("--lr",
48 | type=float,
49 | default=1e-4,
50 | help="Learning rate.")
51 | parser.add_argument("--num_epochs",
52 | type=int,
53 | default=30,
54 | help="Number of epochs.")
55 | parser.add_argument("--wandb_project",
56 | type=str,
57 | default="directed_evolution",
58 | help="WandB project's name.")
59 | parser.add_argument("--seed",
60 | type=int,
61 | default=0,
62 | help="Random seed for reproducibility.")
63 | parser.add_argument("--set_seed_only",
64 | action="store_true",
65 | help="Whether to not set deterministic flag.")
66 | parser.add_argument("--num_workers",
67 | type=int,
68 | default=64,
69 | help="No. workers.")
70 | parser.add_argument("--num_ckpts",
71 | type=int,
72 | default=5,
73 | help="Maximum no. checkpoints can be saved.")
74 | parser.add_argument("--log_interval",
75 | type=int,
76 | default=100,
77 | help="How often to log within steps.")
78 | parser.add_argument("--precision",
79 | type=str,
80 | choices=["highest", "high", "medium"],
81 | default="highest",
82 | help="Internal precision of float32 matrix multiplications.")
83 | args = parser.parse_args()
84 | return args
85 |
86 |
87 | def init_model(pretrained_encoder, hidden_dim):
88 | model = ESM2_Attention(pretrained_encoder, hidden_dim)
89 | tokenizer = model.tokenizer
90 | model.freeze_encoder()
91 | return model, tokenizer
92 |
93 |
94 | def train(args):
95 | seed_everything(args.seed, workers=True)
96 | torch.set_float32_matmul_precision(args.precision)
97 | accelerator = "cpu" if args.devices == "-1" else "gpu"
98 |
99 | # Load model
100 | model, tokenizer = init_model(args.pretrained_encoder, args.dec_hidden_dim)
101 | # Init optimizer
102 | optim = partial(Adam, lr=args.lr)
103 |
104 | # ================== #
105 | # ====== Data ====== #
106 | # ================== #
107 | datamodule = ProteinsDataModule(
108 | csv_file=args.data_file,
109 | tokenizer=tokenizer,
110 | train_batch_size=args.batch_size,
111 | valid_batch_size=args.batch_size,
112 | num_workers=args.num_workers,
113 | seed=args.seed,
114 | )
115 |
116 | # ==================== #
117 | # ====== Model ====== #
118 | # ==================== #
119 | module = ESM2DecoderModule(model, optim)
120 |
121 | # ====================== #
122 | # ====== Training ====== #
123 | # ====================== #
124 | logger_list = [
125 | loggers.CSVLogger(args.output_dir),
126 | loggers.WandbLogger(save_dir=args.output_dir,
127 | project=args.wandb_project)
128 | ]
129 | prefix = args.pretrained_encoder.split("/")[-1] + f"-dec_{args.dec_hidden_dim}"
130 | callback_list = [
131 | callbacks.RichModelSummary(),
132 | callbacks.RichProgressBar(),
133 | callbacks.ModelCheckpoint(
134 | dirpath=os.path.join(args.output_dir, "checkpoints"),
135 | filename=f"{prefix}-{args.dataset_name}_" +
136 | "{epoch:02d}-{train_loss:.3f}-{val_loss:.3f}",
137 | monitor="val_loss",
138 | verbose=True,
139 | save_top_k=args.num_ckpts,
140 | save_weights_only=False,
141 | every_n_epochs=1,
142 | )
143 | ]
144 |
145 | trainer = Trainer(
146 | accelerator=accelerator,
147 | devices=[int(d) for d in args.devices.split(",")],
148 | max_epochs=args.num_epochs,
149 | log_every_n_steps=args.log_interval,
150 | accumulate_grad_batches=args.grad_accum_steps,
151 | deterministic=not args.set_seed_only,
152 | default_root_dir=args.output_dir,
153 | logger=logger_list,
154 | callbacks=callback_list,
155 | )
156 |
157 | trainer.fit(module, datamodule=datamodule, ckpt_path=args.ckpt_path)
158 |
159 |
160 | if __name__ == "__main__":
161 | args = parse_args()
162 | train(args)
163 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import find_packages, setup
3 |
4 |
5 | with open("README.md", "r") as f:
6 | long_description = f.read()
7 |
8 | with open("requirements.txt", "r") as f:
9 | install_requires = f.read().splitlines()
10 |
11 | version = "1.0.0"
12 |
13 | with open(os.path.join("de", "version.py"), "w") as f:
14 | f.writelines([
15 | '"""This file is auto-generated by setup.py, please do not alter."""\n',
16 | f'__version__ = "{version}"\n',
17 | "",
18 | ])
19 |
20 |
21 | setup(
22 | name="de",
23 | version=version,
24 | description="Protein design by Directed Evolution guided by Large Language Models.",
25 | long_description=long_description,
26 | long_description_content_type="text/markdown",
27 | url="https://github.com/HySonLab/Directed_Evolution",
28 | packages=find_packages(),
29 | install_requires=install_requires,
30 | python_requires=">=3.10",
31 | license="GNU",
32 | keywords=["directed evolution", "protein engineering", "large language model"]
33 | )
34 |
--------------------------------------------------------------------------------