├── .gitignore ├── LICENSE ├── README.md ├── abgpt ├── __init__.py └── generate.py ├── pyproject.toml └── setup.cfg /.gitignore: -------------------------------------------------------------------------------- 1 | abgpt.egg-info/ 2 | dist/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Desmond Kuan, Amir Barati Farimani 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AbGPT 2 | Official repository for AbGPT: [De Novo B-Cell Receptor Design via Generative Language Modeling](https://www.biorxiv.org/xxxxxxxxxx). 3 | 4 | ## Setup 5 | To use AbGPT, install via pip: 6 | ```bash 7 | pip install abgpt 8 | ``` 9 | 10 | 15 | 16 | ## Command line usage 17 | 18 | ### Full sequence generation 19 | To generate 1000 light chain sequences starting with "QLQL": 20 | ```bash 21 | abgpt_generate --chain_type light --starting_residue QLQL --num_seqs 1000 22 | ``` 23 | 24 | To generate a BCR library with 1000 sequences for a number of starting residue (e.g., "QVQL", "EVQL", "VQLV") in the heavy chain: 25 | ```bash 26 | abgpt_generate --chain_type heavy --starting_residue QVQL,EVQL,VQLV --num_seqs_each_starting_residue 1000 27 | ``` 28 | 29 | To generate a BCR library with 1000 sequences for a number of starting residue (e.g., "EIVL", "EIVM", "DIQM") in the light chain: 30 | ```bash 31 | abgpt_generate --chain_type light --starting_residue EIVL,EIVM,DIQM --num_seqs_each_starting_residue 1000 32 | ``` -------------------------------------------------------------------------------- /abgpt/__init__.py: -------------------------------------------------------------------------------- 1 | from .generate import generate_specific_sequences, generate_bcr_library -------------------------------------------------------------------------------- /abgpt/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import random 5 | import math 6 | from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel 7 | from tqdm import tqdm 8 | 9 | CHAIN_LENGTHS = { 10 | "light": (100, 120), 11 | "heavy": (110, 140) 12 | } 13 | MODEL_REPO = "deskk/AbGPT" 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='BCR Sequence Generation') 17 | parser.add_argument('--chain_type', type=str, choices=['light', 'heavy'], required=True, help='Chain type: light or heavy') 18 | parser.add_argument('--starting_residue', type=str, help='Starting residue for sequence generation') 19 | parser.add_argument('--num_seqs', type=int, help='Number of sequences to generate') 20 | parser.add_argument('--num_seqs_each_starting_residue', type=int, help='Number of sequences to generate for each starting residue') 21 | args = parser.parse_args() 22 | return args 23 | 24 | def generate_bcr_sequences(num_sequences=100, chain_type="light", starting_residue=""): 25 | device = "cuda" if torch.cuda.is_available() else "cpu" 26 | min_length = 20 if chain_type == "light" else 28 27 | starting_prompt = f"<|endoftext|>{starting_residue}" 28 | abgpt_pipeline = pipeline('text-generation', model=MODEL_REPO, tokenizer=MODEL_REPO, device=0 if device == "cuda" else -1) 29 | generated_sequences = abgpt_pipeline( 30 | starting_prompt, 31 | min_length=min_length, 32 | do_sample=True, 33 | top_k=950, 34 | repetition_penalty=1.2, 35 | num_return_sequences=num_sequences, 36 | eos_token_id=0 37 | ) 38 | return [seq['generated_text'].replace("<|endoftext|>", "").strip() for seq in generated_sequences] 39 | 40 | def save_sequences(filename, sequences): 41 | with open(filename, 'w') as f: 42 | for seq in sequences: 43 | f.write(f"{seq}\n\n") 44 | 45 | def calculate_perplexity(sequence, model, tokenizer, device): 46 | input_ids = torch.tensor(tokenizer.encode(sequence)).unsqueeze(0).to(device) 47 | with torch.no_grad(): 48 | loss = model(input_ids, labels=input_ids).loss 49 | return math.exp(loss.item()) 50 | 51 | def preprocess_sequence(sequence): 52 | return "<|endoftext|>" + '\n'.join([sequence[i:i+60] for i in range(0, len(sequence), 60)]) + "<|endoftext|>" 53 | 54 | def save_best_sequences(filename, sequences): 55 | with open(filename, 'w') as file: 56 | for seq, ppl in sequences: 57 | cleaned_seq = seq.replace("<|endoftext|>", "").strip() 58 | formatted_seq = "\n".join([cleaned_seq[i:i+60] for i in range(0, len(cleaned_seq), 60)]) 59 | file.write(f"Sequence: {formatted_seq}\nPerplexity: {ppl}\n\n") 60 | 61 | def process_sequences_from_file(file_path, model, tokenizer, device): 62 | with open(file_path, 'r') as file: 63 | sequences = file.read().split("\n\n") 64 | filtered_sequences = [] 65 | for seq_block in tqdm(sequences, desc=f"Calculating sequences for {file_path}"): 66 | seq = seq_block.strip() 67 | if seq: 68 | concatenated_seq = ''.join(seq.splitlines()) 69 | preprocessed_seq = preprocess_sequence(concatenated_seq) 70 | ppl = calculate_perplexity(preprocessed_seq, model, tokenizer, device) 71 | if ppl < 13.0: 72 | filtered_sequences.append((concatenated_seq, ppl)) 73 | return sorted(filtered_sequences, key=lambda x: x[1]) 74 | 75 | def read_and_filter_sequences(file_path): 76 | with open(file_path, 'r') as file: 77 | text = file.read() 78 | chunks = text.split('Sequence: ')[1:] 79 | sequences = [] 80 | for chunk in chunks: 81 | sequence = chunk.split('Perplexity:')[0].strip().replace('\n', '') 82 | if 'X' not in sequence and 'B' not in sequence: 83 | sequences.append(sequence) 84 | return sequences 85 | 86 | def format_sequence(sequence): 87 | return '\n'.join(sequence[i:i+60] for i in range(0, len(sequence), 60)) 88 | 89 | def save_to_fasta(sequences, output_file, sequence_counter, seen_sequences): 90 | with open(output_file, 'a') as f: 91 | for sequence in sequences: 92 | if sequence not in seen_sequences: 93 | seen_sequences.add(sequence) 94 | sequence_counter += 1 95 | formatted_sequence = format_sequence(sequence) 96 | f.write(f'>Sequence_{sequence_counter}\n') 97 | f.write(f'{formatted_sequence}\n') 98 | return sequence_counter, seen_sequences 99 | 100 | def process_directory(directory_path, output_file): 101 | sequence_counter = 0 102 | seen_sequences = set() 103 | with open(output_file, 'w') as f: 104 | pass 105 | for filename in os.listdir(directory_path): 106 | if filename.endswith('.txt'): 107 | file_path = os.path.join(directory_path, filename) 108 | sequences = read_and_filter_sequences(file_path) 109 | sequence_counter, seen_sequences = save_to_fasta(sequences, output_file, sequence_counter, seen_sequences) 110 | 111 | def generate_specific_sequences(args): 112 | sequences = generate_bcr_sequences( 113 | num_sequences=args.num_seqs, 114 | chain_type=args.chain_type, 115 | starting_residue=args.starting_residue 116 | ) 117 | output_dir = 'bcr_design' 118 | os.makedirs(output_dir, exist_ok=True) 119 | filename = os.path.join(output_dir, f'{args.chain_type}_{args.starting_residue}.txt') 120 | save_sequences(filename, sequences) 121 | 122 | def generate_bcr_library(args): 123 | output_dir = 'bcr_library' 124 | os.makedirs(output_dir, exist_ok=True) 125 | combined_filename = os.path.join(output_dir, f'{args.chain_type}_BCR_library.txt') 126 | all_sequences = [] 127 | residues = args.starting_residue.split(',') 128 | for starting_residue in residues: 129 | sequences = generate_bcr_sequences( 130 | num_sequences=args.num_seqs_each_starting_residue, 131 | chain_type=args.chain_type, 132 | starting_residue=starting_residue 133 | ) 134 | all_sequences.extend(sequences) 135 | save_sequences(combined_filename, all_sequences) 136 | 137 | def main(): 138 | args = parse_args() 139 | if args.starting_residue and args.num_seqs: 140 | generate_specific_sequences(args) 141 | elif args.starting_residue and args.num_seqs_each_starting_residue: 142 | generate_bcr_library(args) 143 | else: 144 | print("Please specify either --starting_residue and --num_seqs or --starting_residue and --num_seqs_each_starting_residue for library generation.") 145 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = abgpt 3 | version = 0.1.2 4 | author = Desmond Kuan, Amir Barati Farimani 5 | author_email = barati@cmu.edu 6 | description = AbGPT: De Novo B-Cell Receptor Design via Generative Language Modeling 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | url = https://github.com/deskk/AbGPT 10 | license = MIT 11 | classifiers = 12 | Programming Language :: Python :: 3 13 | License :: OSI Approved :: MIT License 14 | Operating System :: OS Independent 15 | 16 | [options] 17 | packages = find: 18 | install_requires = 19 | numpy>=1.21.2 20 | tokenizers>=0.14.0 21 | torch>=2.0.1 22 | tqdm>=4.66.1 23 | transformers==4.44.2 24 | # transformers @ git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 25 | python_requires = >=3.6 26 | 27 | [options.entry_points] 28 | console_scripts = 29 | abgpt_generate = abgpt.generate:main 30 | --------------------------------------------------------------------------------