├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── feature_request.md
│   │   └── other.md
│   └── workflows
│       └── ci.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── CodonTransformer
│   ├── CodonData.py
│   ├── CodonEvaluation.py
│   ├── CodonJupyter.py
│   ├── CodonPrediction.py
│   ├── CodonUtils.py
│   └── __init__.py
├── CodonTransformerDemo.ipynb
├── LICENSE
├── Makefile
├── README.md
├── demo
│   ├── sample_dataset.csv
│   └── sample_predictions.csv
├── finetune.py
├── pretrain.py
├── pyproject.toml
├── requirements.txt
├── setup.py
├── slurm
│   ├── finetune.sh
│   └── pretrain.sh
├── src
│   ├── CodonTransformerTokenizer.json
│   ├── CodonTransformer_inference_template.xlsx
│   ├── __init__.py
│   ├── banner_final.png
│   └── organism2id.pkl
└── tests
    ├── __init__.py
    ├── test_CodonData.py
    ├── test_CodonJupyter.py
    ├── test_CodonPrediction.py
    └── test_CodonUtils.py

/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 | 
8 | ---
9 | 
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 | 
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 | 
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 | 
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 | 
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g. iOS]
28 | - Browser [e.g. chrome, safari]
29 | - Version [e.g. 22]
30 | 
31 | **Smartphone (please complete the following information):**
32 | - Device: [e.g. iPhone6]
33 | - OS: [e.g. iOS8.1]
34 | - Browser [e.g. stock browser, safari]
35 | - Version [e.g. 22]
36 | 
37 | **Additional context**
38 | Add any other context about the problem here.
39 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: enhancement
6 | assignees: ''
7 | 
8 | ---
9 | 
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 | 
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 | 
16 | **Additional context**
17 | Add any other context or screenshots about the feature request here.
18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/other.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Other 3 | about: Any other issue 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe your issue here** 11 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | # .github/workflows/ci.yml 2 | 3 | name: CI 4 | 5 | on: [push, pull_request] 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - name: Checkout code 13 | uses: actions/checkout@v4 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.10' 19 | 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install -r requirements.txt 24 | pip install "coverage[toml]" 25 | 26 | - name: Run tests with coverage 27 | run: | 28 | make test_with_coverage 29 | coverage report 30 | coverage xml 31 | 32 | - name: Upload coverage to Codecov 33 | uses: codecov/codecov-action@v4 34 | with: 35 | token: ${{ secrets.CODECOV_TOKEN }} 36 | file: coverage.xml 37 | flags: unittests 38 | name: codecov-umbrella 39 | fail_ci_if_error: true 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # Coverage reports 163 | coverage.xml 164 | 165 | # Jupyter Notebook checkpoints 166 | .ipynb_checkpoints/ 167 | 168 | # Temporary files 169 | *.tmp 170 | *.temp 171 | 172 | # PyTorch Lightning checkpoints 173 | lightning_logs/ 174 | 175 | # PyTorch model weights 176 | *.pth 177 | *.pt 178 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | 4 | # Don't run pre-commit on files under third-party/ 5 | exclude: "^\ 6 | (third-party/.*)\ 7 | " 8 | 9 | repos: 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v4.1.0 12 | hooks: 13 | - id: check-added-large-files # prevents giant files from being committed. 14 | - id: check-case-conflict # checks for files that would conflict in case-insensitive filesystems. 15 | - id: check-merge-conflict # checks for files that contain merge conflict strings. 16 | - id: check-yaml # checks yaml files for parseable syntax. 17 | - id: detect-private-key # detects the presence of private keys. 18 | - id: end-of-file-fixer # ensures that a file is either empty, or ends with one newline. 19 | - id: fix-byte-order-marker # removes utf-8 byte order marker. 20 | - id: mixed-line-ending # replaces or checks mixed line ending. 21 | - id: requirements-txt-fixer # sorts entries in requirements.txt. 22 | - id: trailing-whitespace # trims trailing whitespace. 
23 | 24 | - repo: https://github.com/sirosen/check-jsonschema 25 | rev: 0.23.2 26 | hooks: 27 | - id: check-github-actions 28 | - id: check-github-workflows 29 | 30 | - repo: https://github.com/astral-sh/ruff-pre-commit 31 | rev: v0.1.13 32 | hooks: 33 | - id: ruff 34 | - id: ruff-format 35 | 36 | - repo: https://github.com/psf/black 37 | rev: 24.8.0 38 | hooks: 39 | - id: black 40 | language_version: python3 41 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | Adibvafa.fallahpour@mail.utoronto.ca. 
64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CodonTransformer/CodonData.py: -------------------------------------------------------------------------------- 1 | """ 2 | File: CodonData.py 3 | --------------------- 4 | Includes helper functions for preprocessing NCBI or Kazusa databases and 5 | preparing the data for training and inference of the CodonTransformer model. 
6 | """ 7 | 8 | import json 9 | import os 10 | import random 11 | from typing import Dict, List, Optional, Tuple, Union 12 | 13 | import pandas as pd 14 | import python_codon_tables as pct 15 | from Bio import SeqIO 16 | from Bio.Seq import Seq 17 | from sklearn.utils import shuffle as sk_shuffle 18 | from tqdm import tqdm 19 | 20 | from CodonTransformer.CodonUtils import ( 21 | AMBIGUOUS_AMINOACID_MAP, 22 | AMINO2CODON_TYPE, 23 | AMINO_ACIDS, 24 | ORGANISM2ID, 25 | START_CODONS, 26 | STOP_CODONS, 27 | STOP_SYMBOL, 28 | STOP_SYMBOLS, 29 | ProteinConfig, 30 | find_pattern_in_fasta, 31 | get_taxonomy_id, 32 | sort_amino2codon_skeleton, 33 | ) 34 | 35 | 36 | def prepare_training_data( 37 | dataset: Union[str, pd.DataFrame], output_file: str, shuffle: bool = True 38 | ) -> None: 39 | """ 40 | Prepare a JSON dataset for training the CodonTransformer model. 41 | 42 | Input dataset should have columns below: 43 | - dna: str (DNA sequence) 44 | - protein: str (Protein sequence) 45 | - organism: Union[int, str] (ID or Name of the organism) 46 | 47 | The output JSON dataset will have the following format: 48 | {"idx": 0, "codons": "M_ATG R_AGG L_TTG L_CTA R_CGA __TAG", "organism": 51} 49 | {"idx": 1, "codons": "M_ATG K_AAG C_TGC F_TTT F_TTC __TAA", "organism": 59} 50 | 51 | Args: 52 | dataset (Union[str, pd.DataFrame]): Input dataset in CSV or DataFrame format. 53 | output_file (str): Path to save the output JSON dataset. 54 | shuffle (bool, optional): Whether to shuffle the dataset before saving. 55 | Defaults to True. 56 | 57 | Returns: 58 | None 59 | """ 60 | if isinstance(dataset, str): 61 | dataset = pd.read_csv(dataset) 62 | 63 | required_columns = {"dna", "protein", "organism"} 64 | if not required_columns.issubset(dataset.columns): 65 | raise ValueError(f"Input dataset must have columns: {required_columns}") 66 | 67 | # Prepare the dataset for finetuning 68 | dataset["codons"] = dataset.apply( 69 | lambda row: get_merged_seq(row["protein"], row["dna"], separator="_"), axis=1 70 | ) 71 | 72 | # Replace organism str with organism id using ORGANISM2ID 73 | dataset["organism"] = dataset["organism"].apply( 74 | lambda org: process_organism(org, ORGANISM2ID) 75 | ) 76 | 77 | # Save the dataset to a JSON file 78 | dataframe_to_json(dataset[["codons", "organism"]], output_file, shuffle=shuffle) 79 | 80 | 81 | def dataframe_to_json(df: pd.DataFrame, output_file: str, shuffle: bool = True) -> None: 82 | """ 83 | Convert pandas DataFrame to JSON file format suitable for training CodonTransformer. 84 | 85 | This function takes a preprocessed DataFrame and writes it to a JSON file 86 | where each line is a JSON object representing a single record. 87 | 88 | Args: 89 | df (pd.DataFrame): The input DataFrame with 'codons' and 'organism' columns. 90 | output_file (str): Path to the output JSON file. 91 | shuffle (bool, optional): Whether to shuffle the dataset before saving. 92 | Defaults to True. 93 | 94 | Returns: 95 | None 96 | 97 | Raises: 98 | ValueError: If the required columns are not present in the DataFrame. 
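    Example (illustrative sketch; the output filename is hypothetical):
        >>> df = pd.DataFrame(
        ...     {"codons": ["M_ATG K_AAG __TAA"], "organism": [51]}
        ... )
        >>> dataframe_to_json(df, "dataset.json", shuffle=False)  # doctest: +SKIP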
99 | """ 100 | required_columns = {"codons", "organism"} 101 | if not required_columns.issubset(df.columns): 102 | raise ValueError(f"DataFrame must contain columns: {required_columns}") 103 | 104 | print(f"\nStarted writing to {output_file}...") 105 | 106 | # Shuffle the DataFrame if requested 107 | if shuffle: 108 | df = sk_shuffle(df) 109 | 110 | # Write the DataFrame to a JSON file 111 | with open(output_file, "w") as f: 112 | for idx, row in tqdm( 113 | df.iterrows(), total=len(df), desc="Writing JSON...", unit=" records" 114 | ): 115 | doc = {"idx": idx, "codons": row["codons"], "organism": row["organism"]} 116 | f.write(json.dumps(doc) + "\n") 117 | 118 | print(f"\nTotal Entries Saved: {len(df)}, JSON data saved to {output_file}") 119 | 120 | 121 | def process_organism(organism: Union[str, int], organism_to_id: Dict[str, int]) -> int: 122 | """ 123 | Process and validate the organism input, converting it to a valid organism ID. 124 | 125 | This function handles both string (organism name) and integer (organism ID) inputs. 126 | It validates the input against a provided mapping of organism names to IDs. 127 | 128 | Args: 129 | organism (Union[str, int]): Input organism, either as a name (str) or ID (int). 130 | organism_to_id (Dict[str, int]): Dictionary mapping organism names to their 131 | corresponding IDs. 132 | 133 | Returns: 134 | int: The validated organism ID. 135 | 136 | Raises: 137 | ValueError: If the input is an invalid organism name or ID. 138 | TypeError: If the input is neither a string nor an integer. 139 | """ 140 | if isinstance(organism, str): 141 | if organism not in organism_to_id: 142 | raise ValueError(f"Invalid organism name: {organism}") 143 | return organism_to_id[organism] 144 | 145 | elif isinstance(organism, int): 146 | if organism not in organism_to_id.values(): 147 | raise ValueError(f"Invalid organism ID: {organism}") 148 | return organism 149 | 150 | raise TypeError( 151 | f"Organism must be a string or integer, not {type(organism).__name__}" 152 | ) 153 | 154 | 155 | def preprocess_protein_sequence(protein: str) -> str: 156 | """ 157 | Preprocess a protein sequence by cleaning, standardizing, and handling 158 | ambiguous amino acids. 159 | 160 | Args: 161 | protein (str): The input protein sequence. 162 | 163 | Returns: 164 | str: The preprocessed protein sequence. 165 | 166 | Raises: 167 | ValueError: If the protein sequence is invalid or if the configuration is invalid. 
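    Example (illustrative sketch; assumes the default ProteinConfig behavior):
        >>> preprocess_protein_sequence("mkt v*")
        'MKTV_'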
168 | """ 169 | if not protein: 170 | raise ValueError("Protein sequence is empty.") 171 | 172 | # Clean and standardize the protein sequence 173 | protein = ( 174 | protein.upper().strip().replace("\n", "").replace(" ", "").replace("\t", "") 175 | ) 176 | 177 | # Handle ambiguous amino acids based on the specified behavior 178 | config = ProteinConfig() 179 | ambiguous_aminoacid_map_override = config.get("ambiguous_aminoacid_map_override") 180 | ambiguous_aminoacid_behavior = config.get("ambiguous_aminoacid_behavior") 181 | ambiguous_aminoacid_map = AMBIGUOUS_AMINOACID_MAP.copy() 182 | 183 | for aminoacid, standard_aminoacids in ambiguous_aminoacid_map_override.items(): 184 | ambiguous_aminoacid_map[aminoacid] = standard_aminoacids 185 | 186 | if ambiguous_aminoacid_behavior == "raise_error": 187 | if any(aminoacid in ambiguous_aminoacid_map for aminoacid in protein): 188 | raise ValueError("Ambiguous amino acids found in protein sequence.") 189 | elif ambiguous_aminoacid_behavior == "standardize_deterministic": 190 | protein = "".join( 191 | ambiguous_aminoacid_map.get(aminoacid, [aminoacid])[0] 192 | for aminoacid in protein 193 | ) 194 | elif ambiguous_aminoacid_behavior == "standardize_random": 195 | protein = "".join( 196 | random.choice(ambiguous_aminoacid_map.get(aminoacid, [aminoacid])) 197 | for aminoacid in protein 198 | ) 199 | else: 200 | raise ValueError( 201 | f"Invalid ambiguous_aminoacid_behavior: {ambiguous_aminoacid_behavior}." 202 | ) 203 | 204 | # Check for sequence validity 205 | if any(aminoacid not in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein): 206 | raise ValueError("Invalid characters in protein sequence.") 207 | 208 | if protein[-1] not in AMINO_ACIDS + STOP_SYMBOLS: 209 | raise ValueError( 210 | "Protein sequence must end with `*`, or `_`, or an amino acid." 211 | ) 212 | 213 | # Replace '*' at the end of protein with STOP_SYMBOL if present 214 | if protein[-1] == "*": 215 | protein = protein[:-1] + STOP_SYMBOL 216 | 217 | # Add stop symbol to end of protein 218 | if protein[-1] != STOP_SYMBOL: 219 | protein += STOP_SYMBOL 220 | 221 | return protein 222 | 223 | 224 | def replace_ambiguous_codons(dna: str) -> str: 225 | """ 226 | Replaces ambiguous codons in a DNA sequence with "UNK". 227 | 228 | Args: 229 | dna (str): The DNA sequence to process. 230 | 231 | Returns: 232 | str: The processed DNA sequence with ambiguous codons replaced by "UNK". 233 | """ 234 | result = [] 235 | dna = dna.upper() 236 | 237 | # Check codons in DNA sequence 238 | for i in range(0, len(dna), 3): 239 | codon = dna[i : i + 3] 240 | 241 | if len(codon) == 3 and all(nucleotide in "ATCG" for nucleotide in codon): 242 | result.append(codon) 243 | else: 244 | result.append("UNK") 245 | 246 | return "".join(result) 247 | 248 | 249 | def preprocess_dna_sequence(dna: str) -> str: 250 | """ 251 | Cleans and preprocesses a DNA sequence by standardizing it and replacing 252 | ambiguous codons. 253 | 254 | Args: 255 | dna (str): The DNA sequence to preprocess. 256 | 257 | Returns: 258 | str: The cleaned and preprocessed DNA sequence. 
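    Example (illustrative sketch; the 'N' makes the middle codon ambiguous):
        >>> preprocess_dna_sequence("atg gcn taa")
        'ATGUNKTAA'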
259 |     """
260 |     if not dna:
261 |         return ""
262 | 
263 |     # Clean and standardize the DNA sequence
264 |     dna = dna.upper().strip().replace("\n", "").replace(" ", "").replace("\t", "")
265 | 
266 |     # Replace codons with ambiguous nucleotides with "UNK"
267 |     dna = replace_ambiguous_codons(dna)
268 | 
269 |     # Add unknown stop codon to end of DNA sequence if not present
270 |     if dna[-3:] not in STOP_CODONS:
271 |         dna += "UNK"
272 | 
273 |     return dna
274 | 
275 | 
276 | def get_merged_seq(protein: str, dna: str = "", separator: str = "_") -> str:
277 |     """
278 |     Return the merged sequence of protein amino acids and DNA codons in the form
279 |     of tokens separated by space, where each token is composed of an amino acid +
280 |     separator + codon.
281 | 
282 |     Args:
283 |         protein (str): Protein sequence.
284 |         dna (str): DNA sequence.
285 |         separator (str): Separator between amino acid and codon.
286 | 
287 |     Returns:
288 |         str: Merged sequence.
289 | 
290 |     Example:
291 |         >>> get_merged_seq(protein="MAV_", dna="ATGGCTGTGTAA", separator="_")
292 |         'M_ATG A_GCT V_GTG __TAA'
293 | 
294 |         >>> get_merged_seq(protein="QHH_", dna="", separator="_")
295 |         'Q_UNK H_UNK H_UNK __UNK'
296 |     """
297 |     merged_seq = ""
298 | 
299 |     # Prepare protein and dna sequences
300 |     dna = preprocess_dna_sequence(dna)
301 |     protein = preprocess_protein_sequence(protein)
302 | 
303 |     # Check that the lengths of the protein and dna sequences are equal
304 |     if len(dna) > 0 and len(protein) != len(dna) / 3:
305 |         raise ValueError(
306 |             'Length of protein (including stop symbol such as "_") and '
307 |             "the number of codons in DNA sequence (including stop codon) "
308 |             "must be equal."
309 |         )
310 | 
311 |     # Merge protein and DNA sequences into tokens
312 |     for i, aminoacid in enumerate(protein):
313 |         merged_seq += f'{aminoacid}{separator}{dna[i * 3:i * 3 + 3] if dna else "UNK"} '
314 | 
315 |     return merged_seq.strip()
316 | 
317 | 
318 | def is_correct_seq(dna: str, protein: str, stop_symbol: str = STOP_SYMBOL) -> bool:
319 |     """
320 |     Check if the given DNA and protein pair is correct, that is:
321 |         1. The length of dna is divisible by 3
322 |         2. There is an initiator codon in the beginning of dna
323 |         3. There is only one stop codon in the sequence
324 |         4. The only stop codon is the last codon
325 | 
326 |     Note that since in Codon Table 3, 'TGA' is interpreted as Tryptophan (W),
327 |     there is a separate check to make sure those sequences are considered correct.
328 | 
329 |     Args:
330 |         dna (str): DNA sequence.
331 |         protein (str): Protein sequence.
332 |         stop_symbol (str): Stop symbol.
333 | 
334 |     Returns:
335 |         bool: True if the sequence is correct, False otherwise.
336 |     """
337 |     return (
338 |         len(dna) % 3 == 0  # Check if DNA length is divisible by 3
339 |         and dna[:3].upper() in START_CODONS  # Check for initiator codon
340 |         and protein[-1]
341 |         == stop_symbol  # Check if the last protein symbol is the stop symbol
342 |         and protein.count(stop_symbol) == 1  # Check if there is only one stop symbol
343 |         and len(set(dna))
344 |         == 4  # Check if DNA consists of 4 unique nucleotides (A, T, C, G)
345 |     )
346 | 
347 | 
348 | def get_amino_acid_sequence(
349 |     dna: str,
350 |     stop_symbol: str = "_",
351 |     codon_table: int = 1,
352 |     return_correct_seq: bool = False,
353 | ) -> Union[str, Tuple[str, bool]]:
354 |     """
355 |     Return the translated protein sequence given a DNA sequence and codon table.
356 | 
357 |     Args:
358 |         dna (str): DNA sequence.
359 |         stop_symbol (str): Stop symbol.
360 |         codon_table (int): Codon table number.
361 |         return_correct_seq (bool): Whether to also return a flag indicating if the sequence is correct.
362 | 
363 |     Returns:
364 |         Union[str, Tuple[str, bool]]: Protein sequence and correctness flag if
365 |             return_correct_seq is True, otherwise just the protein sequence.
366 |     """
367 |     dna_seq = Seq(dna).strip()
368 | 
369 |     # Translate the DNA sequence to a protein sequence
370 |     protein_seq = str(
371 |         dna_seq.translate(
372 |             stop_symbol=stop_symbol,  # Symbol to use for stop codons
373 |             to_stop=False,  # Translate the entire sequence, including any stop codons
374 |             cds=False,  # Do not assume the input is a coding sequence
375 |             table=codon_table,  # Codon table to use for translation
376 |         )
377 |     ).strip()
378 | 
379 |     return (
380 |         protein_seq
381 |         if not return_correct_seq
382 |         else (protein_seq, is_correct_seq(dna_seq, protein_seq, stop_symbol))
383 |     )
384 | 
385 | 
386 | def read_fasta_file(
387 |     input_file: str,
388 |     save_to_file: Optional[str] = None,
389 |     organism: str = "",
390 |     buffer_size: int = 50000,
391 | ) -> pd.DataFrame:
392 |     """
393 |     Read a FASTA file of DNA sequences and convert it to a Pandas DataFrame.
394 |     Optionally, save the DataFrame to a CSV file.
395 | 
396 |     Args:
397 |         input_file (str): Path to the input FASTA file.
398 |         save_to_file (Optional[str]): Path to save the output DataFrame. If None,
399 |             data is only returned.
400 |         organism (str): Name of the organism. If empty, it will be extracted from
401 |             the FASTA description.
402 |         buffer_size (int): Number of records to process before writing to file.
403 | 
404 |     Returns:
405 |         pd.DataFrame: DataFrame containing the DNA sequences with their
406 |             translations and metadata.
407 | 
408 |     Raises:
409 |         FileNotFoundError: If the input file does not exist.
410 |     """
411 |     if not os.path.exists(input_file):
412 |         raise FileNotFoundError(f"Input file not found: {input_file}")
413 | 
414 |     buffer = []
415 |     columns = [
416 |         "dna",
417 |         "protein",
418 |         "correct_seq",
419 |         "organism",
420 |         "GeneID",
421 |         "description",
422 |         "tokenized",
423 |     ]
424 | 
425 |     # Initialize DataFrame to accumulate all processed records
426 |     all_data = pd.DataFrame(columns=columns)
427 | 
428 |     with open(input_file, "r") as fasta_file:
429 |         for record in tqdm(
430 |             SeqIO.parse(fasta_file, "fasta"),
431 |             desc=f"Processing {organism}",
432 |             unit=" Records",
433 |         ):
434 |             dna = str(record.seq).strip().upper()  # Ensure uppercase DNA sequence
435 | 
436 |             # Determine the organism from the record if not provided
437 |             current_organism = organism or find_pattern_in_fasta(
438 |                 "organism", record.description
439 |             )
440 |             gene_id = find_pattern_in_fasta("GeneID", record.description)
441 | 
442 |             # Get the appropriate codon table for the organism
443 |             codon_table = get_codon_table(current_organism)
444 | 
445 |             # Translate DNA to protein sequence
446 |             protein, correct_seq = get_amino_acid_sequence(
447 |                 dna,
448 |                 stop_symbol=STOP_SYMBOL,
449 |                 codon_table=codon_table,
450 |                 return_correct_seq=True,
451 |             )
452 |             description = record.description.split("[", 1)[0].strip()
453 |             tokenized = get_merged_seq(protein, dna, separator=STOP_SYMBOL)
454 | 
455 |             # Create a data row for the current sequence
456 |             data_row = {
457 |                 "dna": dna,
458 |                 "protein": protein,
459 |                 "correct_seq": correct_seq,
460 |                 "organism": current_organism,
461 |                 "GeneID": gene_id,
462 |                 "description": description,
463 |                 "tokenized": tokenized,
464 |             }
465 |             buffer.append(data_row)
466 | 
467 |             # Write buffer to CSV file when buffer size is reached
468 |             if save_to_file
and len(buffer) >= buffer_size: 469 | write_buffer_to_csv(buffer, save_to_file, columns) 470 | buffer = [] 471 | 472 | all_data = pd.concat( 473 | [all_data, pd.DataFrame([data_row])], ignore_index=True 474 | ) 475 | 476 | # Write remaining buffer to CSV file 477 | if save_to_file and buffer: 478 | write_buffer_to_csv(buffer, save_to_file, columns) 479 | 480 | return all_data 481 | 482 | 483 | def write_buffer_to_csv(buffer: List[Dict], output_path: str, columns: List[str]): 484 | """Helper function to write buffer to CSV file.""" 485 | buffer_df = pd.DataFrame(buffer, columns=columns) 486 | buffer_df.to_csv( 487 | output_path, 488 | mode="a", 489 | header=(not os.path.exists(output_path)), 490 | index=True, 491 | ) 492 | 493 | 494 | def download_codon_frequencies_from_kazusa( 495 | taxonomy_id: Optional[int] = None, 496 | organism: Optional[str] = None, 497 | taxonomy_reference: Optional[str] = None, 498 | return_original_format: bool = False, 499 | ) -> AMINO2CODON_TYPE: 500 | """ 501 | Return the codon table of the given taxonomy ID from the Kazusa Database. 502 | 503 | Args: 504 | taxonomy_id (Optional[int]): Taxonomy ID. 505 | organism (Optional[str]): Name of the organism. 506 | taxonomy_reference (Optional[str]): Taxonomy reference. 507 | return_original_format (bool): Whether to return in the original format. 508 | 509 | Returns: 510 | AMINO2CODON_TYPE: Codon table. 511 | """ 512 | if taxonomy_reference: 513 | taxonomy_id = get_taxonomy_id(taxonomy_reference, organism=organism) 514 | 515 | kazusa_amino2codon = pct.get_codons_table(table_name=taxonomy_id) 516 | 517 | if return_original_format: 518 | return kazusa_amino2codon 519 | 520 | # Replace "*" with STOP_SYMBOL in the codon table 521 | kazusa_amino2codon[STOP_SYMBOL] = kazusa_amino2codon.pop("*") 522 | 523 | # Create amino2codon dictionary 524 | amino2codon = { 525 | aminoacid: (list(codon2freq.keys()), list(codon2freq.values())) 526 | for aminoacid, codon2freq in kazusa_amino2codon.items() 527 | } 528 | 529 | return sort_amino2codon_skeleton(amino2codon) 530 | 531 | 532 | def build_amino2codon_skeleton(organism: str) -> AMINO2CODON_TYPE: 533 | """ 534 | Return the empty skeleton of the amino2codon dictionary, needed for 535 | get_codon_frequencies. 536 | 537 | Args: 538 | organism (str): Name of the organism. 539 | 540 | Returns: 541 | AMINO2CODON_TYPE: Empty amino2codon dictionary. 
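    Example (illustrative sketch; assumes the organism name resolves to NCBI
    codon table 11 via get_codon_table):
        >>> skeleton = build_amino2codon_skeleton("Escherichia coli general")
        >>> skeleton["M"]
        (['ATG'], [0])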
542 | """ 543 | amino2codon = {} 544 | possible_codons = [f"{i}{j}{k}" for i in "ACGT" for j in "ACGT" for k in "ACGT"] 545 | possible_aminoacids = get_amino_acid_sequence( 546 | dna="".join(possible_codons), 547 | codon_table=get_codon_table(organism), 548 | return_correct_seq=False, 549 | ) 550 | 551 | # Initialize the amino2codon skeleton with all possible codons and set their 552 | # frequencies to 0 553 | for i, (codon, amino) in enumerate(zip(possible_codons, possible_aminoacids)): 554 | if amino not in amino2codon: 555 | amino2codon[amino] = ([], []) 556 | 557 | amino2codon[amino][0].append(codon) 558 | amino2codon[amino][1].append(0) 559 | 560 | # Sort the dictionary and each list of codon frequency alphabetically 561 | amino2codon = sort_amino2codon_skeleton(amino2codon) 562 | 563 | return amino2codon 564 | 565 | 566 | def get_codon_frequencies( 567 | dna_sequences: List[str], 568 | protein_sequences: Optional[List[str]] = None, 569 | organism: Optional[str] = None, 570 | ) -> AMINO2CODON_TYPE: 571 | """ 572 | Return a dictionary mapping each codon to its respective frequency based on 573 | the collection of DNA sequences and protein sequences. 574 | 575 | Args: 576 | dna_sequences (List[str]): List of DNA sequences. 577 | protein_sequences (Optional[List[str]]): List of protein sequences. 578 | organism (Optional[str]): Name of the organism. 579 | 580 | Returns: 581 | AMINO2CODON_TYPE: Dictionary mapping each amino acid to a tuple of codons 582 | and frequencies. 583 | """ 584 | if organism: 585 | codon_table = get_codon_table(organism) 586 | protein_sequences = [ 587 | get_amino_acid_sequence( 588 | dna, codon_table=codon_table, return_correct_seq=False 589 | ) 590 | for dna in dna_sequences 591 | ] 592 | 593 | amino2codon = build_amino2codon_skeleton(organism) 594 | 595 | # Count the frequencies of each codon for each amino acid 596 | for dna, protein in zip(dna_sequences, protein_sequences): 597 | for i, amino in enumerate(protein): 598 | codon = dna[i * 3 : (i + 1) * 3] 599 | codon_loc = amino2codon[amino][0].index(codon) 600 | amino2codon[amino][1][codon_loc] += 1 601 | 602 | # Normalize codon frequencies per amino acid so they sum to 1 603 | amino2codon = { 604 | amino: (codons, [freq / (sum(frequencies) + 1e-100) for freq in frequencies]) 605 | for amino, (codons, frequencies) in amino2codon.items() 606 | } 607 | 608 | return amino2codon 609 | 610 | 611 | def get_organism_to_codon_frequencies( 612 | dataset: pd.DataFrame, organisms: List[str] 613 | ) -> Dict[str, AMINO2CODON_TYPE]: 614 | """ 615 | Return a dictionary mapping each organism to their codon frequency distribution. 616 | 617 | Args: 618 | dataset (pd.DataFrame): DataFrame containing DNA sequences. 619 | organisms (List[str]): List of organisms. 620 | 621 | Returns: 622 | Dict[str, AMINO2CODON_TYPE]: Dictionary mapping each organism to its codon 623 | frequency distribution. 
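    Example (illustrative sketch; the single-row dataset is made up for
    demonstration and follows the column format described above):
        >>> dataset = pd.DataFrame(
        ...     {
        ...         "organism": ["Escherichia coli general"],
        ...         "dna": ["ATGTGGTAA"],
        ...         "protein": ["MW_"],
        ...     }
        ... )
        >>> freqs = get_organism_to_codon_frequencies(
        ...     dataset, organisms=["Escherichia coli general"]
        ... )  # doctest: +SKIP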
624 | """ 625 | organism2frequencies = {} 626 | 627 | # Calculate codon frequencies for each organism in the dataset 628 | for organism in tqdm( 629 | organisms, desc="Calculating Codon Frequencies: ", unit="Organism" 630 | ): 631 | organism_data = dataset.loc[dataset["organism"] == organism] 632 | 633 | dna_sequences = organism_data["dna"].to_list() 634 | protein_sequences = organism_data["protein"].to_list() 635 | 636 | codon_frequencies = get_codon_frequencies(dna_sequences, protein_sequences) 637 | organism2frequencies[organism] = codon_frequencies 638 | 639 | return organism2frequencies 640 | 641 | 642 | def get_codon_table(organism: str) -> int: 643 | """ 644 | Return the appropriate NCBI codon table for a given organism. 645 | 646 | Args: 647 | organism (str): Name of the organism. 648 | 649 | Returns: 650 | int: Codon table number. 651 | """ 652 | # Common codon table (Table 1) for many model organisms 653 | if organism in [ 654 | "Arabidopsis thaliana", 655 | "Caenorhabditis elegans", 656 | "Chlamydomonas reinhardtii", 657 | "Saccharomyces cerevisiae", 658 | "Danio rerio", 659 | "Drosophila melanogaster", 660 | "Homo sapiens", 661 | "Mus musculus", 662 | "Nicotiana tabacum", 663 | "Solanum tuberosum", 664 | "Solanum lycopersicum", 665 | "Oryza sativa", 666 | "Glycine max", 667 | "Zea mays", 668 | ]: 669 | codon_table = 1 670 | 671 | # Chloroplast codon table (Table 11) 672 | elif organism in [ 673 | "Chlamydomonas reinhardtii chloroplast", 674 | "Nicotiana tabacum chloroplast", 675 | ]: 676 | codon_table = 11 677 | 678 | # Default to Table 11 for other bacteria and archaea 679 | else: 680 | codon_table = 11 681 | 682 | return codon_table 683 | -------------------------------------------------------------------------------- /CodonTransformer/CodonEvaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | File: CodonEvaluation.py 3 | --------------------------- 4 | Includes functions to calculate various evaluation metrics along with helper 5 | functions. 6 | """ 7 | 8 | from typing import Dict, List, Tuple 9 | 10 | import pandas as pd 11 | from CAI import CAI, relative_adaptiveness 12 | from tqdm import tqdm 13 | 14 | 15 | def get_CSI_weights(sequences: List[str]) -> Dict[str, float]: 16 | """ 17 | Calculate the Codon Similarity Index (CSI) weights for a list of DNA sequences. 18 | 19 | Args: 20 | sequences (List[str]): List of DNA sequences. 21 | 22 | Returns: 23 | dict: The CSI weights. 24 | """ 25 | return relative_adaptiveness(sequences=sequences) 26 | 27 | 28 | def get_CSI_value(dna: str, weights: Dict[str, float]) -> float: 29 | """ 30 | Calculate the Codon Similarity Index (CSI) for a DNA sequence. 31 | 32 | Args: 33 | dna (str): The DNA sequence. 34 | weights (dict): The CSI weights from get_CSI_weights. 35 | 36 | Returns: 37 | float: The CSI value. 38 | """ 39 | return CAI(dna, weights) 40 | 41 | 42 | def get_organism_to_CSI_weights( 43 | dataset: pd.DataFrame, organisms: List[str] 44 | ) -> Dict[str, dict]: 45 | """ 46 | Calculate the Codon Similarity Index (CSI) weights for a list of organisms. 47 | 48 | Args: 49 | dataset (pd.DataFrame): Dataset containing organism and DNA sequence info. 50 | organisms (List[str]): List of organism names. 51 | 52 | Returns: 53 | Dict[str, dict]: A dictionary mapping each organism to its CSI weights. 
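    Example (illustrative sketch; the DNA sequences and organism name are made up):
        >>> dataset = pd.DataFrame(
        ...     {
        ...         "organism": ["Escherichia coli general"] * 2,
        ...         "dna": ["ATGAAAGGCTAA", "ATGCTGGAATAA"],
        ...     }
        ... )
        >>> weights = get_organism_to_CSI_weights(
        ...     dataset, ["Escherichia coli general"]
        ... )  # doctest: +SKIP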
54 |     """
55 |     organism2weights = {}
56 | 
57 |     # Iterate through each organism to calculate its CSI weights
58 |     for organism in tqdm(organisms, desc="Calculating CSI Weights: ", unit="Organism"):
59 |         organism_data = dataset.loc[dataset["organism"] == organism]
60 |         sequences = organism_data["dna"].to_list()
61 |         weights = get_CSI_weights(sequences)
62 |         organism2weights[organism] = weights
63 | 
64 |     return organism2weights
65 | 
66 | 
67 | def get_GC_content(dna: str, lower: bool = False) -> float:
68 |     """
69 |     Calculate the GC content of a DNA sequence.
70 | 
71 |     Args:
72 |         dna (str): The DNA sequence.
73 |         lower (bool): If True, normalizes a lowercase sequence before counting.
74 | 
75 |     Returns:
76 |         float: The GC content as a percentage.
77 |     """
78 |     if lower:
79 |         dna = dna.upper()  # Normalize case so lowercase input is counted correctly
80 |     return (dna.count("G") + dna.count("C")) / len(dna) * 100
81 | 
82 | 
83 | def get_cfd(
84 |     dna: str,
85 |     codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
86 |     threshold: float = 0.3,
87 | ) -> float:
88 |     """
89 |     Calculate the codon frequency distribution (CFD) metric for a DNA sequence.
90 | 
91 |     Args:
92 |         dna (str): The DNA sequence.
93 |         codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
94 |             frequency distribution per amino acid.
95 |         threshold (float): Frequency threshold for counting rare codons.
96 | 
97 |     Returns:
98 |         float: The CFD metric as a percentage.
99 |     """
100 |     # Get a dictionary mapping each codon to its normalized frequency
101 |     codon2frequency = {
102 |         codon: freq / max(frequencies)
103 |         for amino, (codons, frequencies) in codon_frequencies.items()
104 |         for codon, freq in zip(codons, frequencies)
105 |     }
106 | 
107 |     cfd = 0
108 | 
109 |     # Iterate through the DNA sequence in steps of 3 to process each codon
110 |     for i in range(0, len(dna), 3):
111 |         codon = dna[i : i + 3]
112 |         codon_frequency = codon2frequency[codon]
113 | 
114 |         if codon_frequency < threshold:
115 |             cfd += 1
116 | 
117 |     return cfd / (len(dna) / 3) * 100
118 | 
119 | 
120 | def get_min_max_percentage(
121 |     dna: str,
122 |     codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
123 |     window_size: int = 18,
124 | ) -> List[float]:
125 |     """
126 |     Calculate the %MinMax metric for a DNA sequence.
127 | 
128 |     Args:
129 |         dna (str): The DNA sequence.
130 |         codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
131 |             frequency distribution per amino acid.
132 |         window_size (int): Size of the window to calculate %MinMax.
133 | 
134 |     Returns:
135 |         List[float]: List of %MinMax values for the sequence.
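    Example (illustrative sketch; `dna_sequences` and `dna` are placeholder
    names, with the frequencies produced by get_codon_frequencies and `dna`
    at least window_size codons long):
        >>> frequencies = get_codon_frequencies(
        ...     dna_sequences, organism="Escherichia coli general"
        ... )  # doctest: +SKIP
        >>> values = get_min_max_percentage(dna, frequencies, window_size=18)  # doctest: +SKIP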
136 | 
137 |     Credit: https://github.com/chowington/minmax
138 |     """
139 |     # Get a dictionary mapping each codon to its respective amino acid
140 |     codon2amino = {
141 |         codon: amino
142 |         for amino, (codons, frequencies) in codon_frequencies.items()
143 |         for codon in codons
144 |     }
145 | 
146 |     min_max_values = []
147 |     codons = [dna[i : i + 3] for i in range(0, len(dna), 3)]  # Split DNA into codons
148 | 
149 |     # Iterate through the DNA sequence using the specified window size
150 |     for i in range(len(codons) - window_size + 1):
151 |         codon_window = codons[i : i + window_size]  # Codons in the current window
152 | 
153 |         Actual = 0.0  # Average of the actual codon frequencies
154 |         Max = 0.0  # Average of the max codon frequencies
155 |         Min = 0.0  # Average of the min codon frequencies
156 |         Avg = 0.0  # Average of the averages of all frequencies for each amino acid
157 | 
158 |         # Sum the frequencies for codons in the current window
159 |         for codon in codon_window:
160 |             aminoacid = codon2amino[codon]
161 |             frequencies = codon_frequencies[aminoacid][1]
162 |             codon_index = codon_frequencies[aminoacid][0].index(codon)
163 |             codon_frequency = codon_frequencies[aminoacid][1][codon_index]
164 | 
165 |             Actual += codon_frequency
166 |             Max += max(frequencies)
167 |             Min += min(frequencies)
168 |             Avg += sum(frequencies) / len(frequencies)
169 | 
170 |         # Divide by the window size to get the averages
171 |         Actual = Actual / window_size
172 |         Max = Max / window_size
173 |         Min = Min / window_size
174 |         Avg = Avg / window_size
175 | 
176 |         # Calculate %MinMax
177 |         percentMax = ((Actual - Avg) / (Max - Avg)) * 100
178 |         percentMin = ((Avg - Actual) / (Avg - Min)) * 100
179 | 
180 |         # Append the appropriate %MinMax value
181 |         if percentMax >= 0:
182 |             min_max_values.append(percentMax)
183 |         else:
184 |             min_max_values.append(-percentMin)
185 | 
186 |     # Pad the tail of min_max_values with floor(window_size / 2) None entries
187 |     for i in range(int(window_size / 2)):
188 |         min_max_values.append(None)
189 | 
190 |     return min_max_values
191 | 
192 | 
193 | def get_sequence_complexity(dna: str) -> float:
194 |     """
195 |     Calculate the sequence complexity score of a DNA sequence.
196 | 
197 |     Args:
198 |         dna (str): The DNA sequence.
199 | 
200 |     Returns:
201 |         float: The sequence complexity score.
202 |     """
203 | 
204 |     def sum_up_to(x):
205 |         """Recursive function to calculate the sum of integers from 1 to x."""
206 |         if x <= 1:
207 |             return 1
208 |         else:
209 |             return x + sum_up_to(x - 1)
210 | 
211 |     def f(x):
212 |         """Returns 4 if x is greater than or equal to 4, else returns x."""
213 |         if x >= 4:
214 |             return 4
215 |         elif x < 4:
216 |             return x
217 | 
218 |     unique_subseq_length = []
219 | 
220 |     # Calculate unique subsequences lengths
221 |     for i in range(1, len(dna) + 1):
222 |         unique_subseq = set()
223 |         for j in range(len(dna) - (i - 1)):
224 |             unique_subseq.add(dna[j : (j + i)])
225 |         unique_subseq_length.append(len(unique_subseq))
226 | 
227 |     # Calculate complexity score
228 |     complexity_score = (
229 |         sum(unique_subseq_length) / (sum_up_to(len(dna) - 1) + f(len(dna)))
230 |     ) * 100
231 | 
232 |     return complexity_score
233 | 
234 | 
235 | def get_sequence_similarity(
236 |     original: str, predicted: str, truncate: bool = True, window_length: int = 1
237 | ) -> float:
238 |     """
239 |     Calculate the sequence similarity between two sequences.
240 | 
241 |     Args:
242 |         original (str): The original sequence.
243 |         predicted (str): The predicted sequence.
244 | truncate (bool): If True, truncate the original sequence to match the length 245 | of the predicted sequence. 246 | window_length (int): Length of the window for comparison (1 for amino acids, 247 | 3 for codons). 248 | 249 | Returns: 250 | float: The sequence similarity as a percentage. 251 | 252 | Preconditions: 253 | len(predicted) <= len(original). 254 | """ 255 | if not truncate and len(original) != len(predicted): 256 | raise ValueError( 257 | "Set truncate to True if the length of sequences do not match." 258 | ) 259 | 260 | identity = 0.0 261 | original = original.strip() 262 | predicted = predicted.strip() 263 | 264 | if truncate: 265 | original = original[: len(predicted)] 266 | 267 | if window_length == 1: 268 | # Simple comparison for amino acid 269 | for i in range(len(predicted)): 270 | if original[i] == predicted[i]: 271 | identity += 1 272 | else: 273 | # Comparison for substrings based on window_length 274 | for i in range(0, len(original) - window_length + 1, window_length): 275 | if original[i : i + window_length] == predicted[i : i + window_length]: 276 | identity += 1 277 | 278 | return (identity / (len(predicted) / window_length)) * 100 279 | -------------------------------------------------------------------------------- /CodonTransformer/CodonJupyter.py: -------------------------------------------------------------------------------- 1 | """ 2 | File: CodonJupyter.py 3 | --------------------- 4 | Includes Jupyter-specific functions for displaying interactive widgets. 5 | """ 6 | 7 | from typing import Dict, List, Tuple 8 | 9 | import ipywidgets as widgets 10 | from IPython.display import HTML, display 11 | 12 | from CodonTransformer.CodonUtils import ( 13 | COMMON_ORGANISMS, 14 | ID2ORGANISM, 15 | ORGANISM2ID, 16 | DNASequencePrediction, 17 | ) 18 | 19 | 20 | class UserContainer: 21 | """ 22 | A container class to store user inputs for organism and protein sequence. 23 | Attributes: 24 | organism (int): The selected organism id. 25 | protein (str): The input protein sequence. 26 | """ 27 | 28 | def __init__(self) -> None: 29 | self.organism: int = -1 30 | self.protein: str = "" 31 | 32 | 33 | def create_styled_options( 34 | organisms: list, organism2id: Dict[str, int], is_fine_tuned: bool = False 35 | ) -> list: 36 | """ 37 | Create styled options for the dropdown widget. 38 | 39 | Args: 40 | organisms (list): List of organism names. 41 | organism2id (Dict[str, int]): Dictionary mapping organism names to their IDs. 42 | is_fine_tuned (bool): Whether these are fine-tuned organisms. 43 | 44 | Returns: 45 | list: Styled options for the dropdown widget. 46 | """ 47 | styled_options = [] 48 | for organism in organisms: 49 | organism_id = organism2id[organism] 50 | if is_fine_tuned: 51 | if organism_id < 10: 52 | styled_options.append(f"\u200b{organism_id:>6}. {organism}") 53 | elif organism_id < 100: 54 | styled_options.append(f"\u200b{organism_id:>5}. {organism}") 55 | else: 56 | styled_options.append(f"\u200b{organism_id:>4}. {organism}") 57 | else: 58 | if organism_id < 10: 59 | styled_options.append(f"{organism_id:>6}. {organism}") 60 | elif organism_id < 100: 61 | styled_options.append(f"{organism_id:>5}. {organism}") 62 | else: 63 | styled_options.append(f"{organism_id:>4}. {organism}") 64 | return styled_options 65 | 66 | 67 | def create_dropdown_options(organism2id: Dict[str, int]) -> list: 68 | """ 69 | Create the full list of dropdown options, including section headers. 
70 | 71 | Args: 72 | organism2id (Dict[str, int]): Dictionary mapping organism names to their IDs. 73 | 74 | Returns: 75 | list: Full list of dropdown options. 76 | """ 77 | fine_tuned_organisms = sorted( 78 | [org for org in organism2id.keys() if org in COMMON_ORGANISMS] 79 | ) 80 | all_organisms = sorted(organism2id.keys()) 81 | 82 | fine_tuned_options = create_styled_options( 83 | fine_tuned_organisms, organism2id, is_fine_tuned=True 84 | ) 85 | all_organisms_options = create_styled_options( 86 | all_organisms, organism2id, is_fine_tuned=False 87 | ) 88 | 89 | return ( 90 | [""] 91 | + ["Selected Organisms"] 92 | + fine_tuned_options 93 | + [""] 94 | + ["All Organisms"] 95 | + all_organisms_options 96 | ) 97 | 98 | 99 | def create_organism_dropdown(container: UserContainer) -> widgets.Dropdown: 100 | """ 101 | Create and configure the organism dropdown widget. 102 | 103 | Args: 104 | container (UserContainer): Container to store the selected organism. 105 | 106 | Returns: 107 | widgets.Dropdown: Configured dropdown widget. 108 | """ 109 | dropdown = widgets.Dropdown( 110 | options=create_dropdown_options(ORGANISM2ID), 111 | description="", 112 | layout=widgets.Layout(width="40%", margin="0 0 10px 0"), 113 | style={"description_width": "initial"}, 114 | ) 115 | 116 | def show_organism(change: Dict[str, str]) -> None: 117 | """ 118 | Update the container with the selected organism and print to terminal. 119 | 120 | Args: 121 | change (Dict[str, str]): Information about the change in dropdown value. 122 | """ 123 | dropdown_choice = change["new"] 124 | if dropdown_choice and dropdown_choice not in [ 125 | "Selected Organisms", 126 | "All Organisms", 127 | ]: 128 | organism = "".join(filter(str.isdigit, dropdown_choice)) 129 | organism_id = ID2ORGANISM[int(organism)] 130 | container.organism = organism_id 131 | else: 132 | container.organism = None 133 | 134 | dropdown.observe(show_organism, names="value") 135 | return dropdown 136 | 137 | 138 | def get_dropdown_style() -> str: 139 | """ 140 | Return the custom CSS style for the dropdown widget. 141 | 142 | Returns: 143 | str: CSS style string. 144 | """ 145 | return """ 146 | 179 | """ 180 | 181 | 182 | def display_organism_dropdown(container: UserContainer) -> None: 183 | """ 184 | Display the organism dropdown widget and apply custom styles. 185 | 186 | Args: 187 | container (UserContainer): Container to store the selected organism. 188 | """ 189 | dropdown = create_organism_dropdown(container) 190 | header = widgets.HTML( 191 | 'Select Organism:' 192 | '
' 193 | ) 194 | container_widget = widgets.VBox( 195 | [header, dropdown], 196 | layout=widgets.Layout(padding="12px 0 12px 25px"), 197 | ) 198 | display(container_widget) 199 | display(HTML(get_dropdown_style())) 200 | 201 | 202 | def display_protein_input(container: UserContainer) -> None: 203 | """ 204 | Display a widget for entering a protein sequence and save it to the container. 205 | 206 | Args: 207 | container (UserContainer): A container to store the entered protein sequence. 208 | """ 209 | protein_input = widgets.Textarea( 210 | value="", 211 | placeholder="Enter here...", 212 | description="", 213 | layout=widgets.Layout(width="100%", height="100px", margin="0 0 10px 0"), 214 | style={"description_width": "initial"}, 215 | ) 216 | 217 | # Custom CSS for the input widget 218 | input_style = """ 219 | 238 | """ 239 | 240 | # Function to save the input protein sequence to the container 241 | def save_protein(change: Dict[str, str]) -> None: 242 | """ 243 | Save the input protein sequence to the container. 244 | 245 | Args: 246 | change (Dict[str, str]): A dictionary containing information about 247 | the change in textarea value. 248 | """ 249 | container.protein = ( 250 | change["new"] 251 | .upper() 252 | .strip() 253 | .replace("\n", "") 254 | .replace(" ", "") 255 | .replace("\t", "") 256 | ) 257 | 258 | # Attach the function to the input widget 259 | protein_input.observe(save_protein, names="value") 260 | 261 | # Display the input widget 262 | header = widgets.HTML( 263 | 'Enter Protein Sequence:' 264 | '
' 265 | ) 266 | container_widget = widgets.VBox( 267 | [header, protein_input], layout=widgets.Layout(padding="12px 12px 0 25px") 268 | ) 269 | 270 | display(container_widget) 271 | display(widgets.HTML(input_style)) 272 | 273 | 274 | def format_model_output(output: DNASequencePrediction) -> str: 275 | """ 276 | Format DNA sequence prediction output in an appealing and easy-to-read manner. 277 | 278 | This function takes the prediction output and formats it into 279 | a structured string with clear section headers and separators. 280 | 281 | Args: 282 | output (DNASequencePrediction): Object containing the prediction output. 283 | Expected attributes: 284 | - organism (str): The organism name. 285 | - protein (str): The input protein sequence. 286 | - processed_input (str): The processed input sequence. 287 | - predicted_dna (str): The predicted DNA sequence. 288 | 289 | Returns: 290 | str: A formatted string containing the organized output. 291 | """ 292 | 293 | def format_section(title: str, content: str) -> str: 294 | """Helper function to format individual sections.""" 295 | separator = "-" * 29 296 | title_line = f"| {title.center(25)} |" 297 | return f"{separator}\n{title_line}\n{separator}\n{content}\n\n" 298 | 299 | sections: List[Tuple[str, str]] = [ 300 | ("Organism", output.organism), 301 | ("Input Protein", output.protein), 302 | ("Processed Input", output.processed_input), 303 | ("Predicted DNA", output.predicted_dna), 304 | ] 305 | 306 | formatted_output = "" 307 | for title, content in sections: 308 | formatted_output += format_section(title, content) 309 | 310 | # Remove the last newline to avoid extra space at the end 311 | return formatted_output.rstrip() 312 | -------------------------------------------------------------------------------- /CodonTransformer/CodonPrediction.py: -------------------------------------------------------------------------------- 1 | """ 2 | File: CodonPrediction.py 3 | --------------------------- 4 | Includes functions to tokenize input, load models, infer predicted dna sequences and 5 | helper functions related to processing data for passing to the model. 6 | """ 7 | 8 | import warnings 9 | from typing import Any, Dict, List, Optional, Tuple, Union 10 | 11 | import numpy as np 12 | import onnxruntime as rt 13 | import torch 14 | import transformers 15 | from transformers import ( 16 | AutoTokenizer, 17 | BatchEncoding, 18 | BigBirdConfig, 19 | BigBirdForMaskedLM, 20 | PreTrainedTokenizerFast, 21 | ) 22 | 23 | from CodonTransformer.CodonData import get_merged_seq 24 | from CodonTransformer.CodonUtils import ( 25 | AMINO_ACID_TO_INDEX, 26 | INDEX2TOKEN, 27 | NUM_ORGANISMS, 28 | ORGANISM2ID, 29 | TOKEN2INDEX, 30 | DNASequencePrediction, 31 | ) 32 | 33 | 34 | def predict_dna_sequence( 35 | protein: str, 36 | organism: Union[int, str], 37 | device: torch.device, 38 | tokenizer: Union[str, PreTrainedTokenizerFast] = None, 39 | model: Union[str, torch.nn.Module] = None, 40 | attention_type: str = "original_full", 41 | deterministic: bool = True, 42 | temperature: float = 0.2, 43 | top_p: float = 0.95, 44 | num_sequences: int = 1, 45 | match_protein: bool = False, 46 | ) -> Union[DNASequencePrediction, List[DNASequencePrediction]]: 47 | """ 48 | Predict the DNA sequence(s) for a given protein using the CodonTransformer model. 49 | 50 | This function takes a protein sequence and an organism (as ID or name) as input 51 | and returns the predicted DNA sequence(s) using the CodonTransformer model. 
It can use 52 | either provided tokenizer and model objects or load them from specified paths. 53 | 54 | Args: 55 | protein (str): The input protein sequence for which to predict the DNA sequence. 56 | organism (Union[int, str]): Either the ID of the organism or its name (e.g., 57 | "Escherichia coli general"). If a string is provided, it will be converted 58 | to the corresponding ID using ORGANISM2ID. 59 | device (torch.device): The device (CPU or GPU) to run the model on. 60 | tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file 61 | path to load the tokenizer from, a pre-loaded tokenizer object, or None. If 62 | None, it will be loaded from HuggingFace. Defaults to None. 63 | model (Union[str, torch.nn.Module, None], optional): Either a file path to load 64 | the model from, a pre-loaded model object, or None. If None, it will be 65 | loaded from HuggingFace. Defaults to None. 66 | attention_type (str, optional): The type of attention mechanism to use in the 67 | model. Can be either 'block_sparse' or 'original_full'. Defaults to 68 | "original_full". 69 | deterministic (bool, optional): Whether to use deterministic decoding (most 70 | likely tokens). If False, samples tokens according to their probabilities 71 | adjusted by the temperature. Defaults to True. 72 | temperature (float, optional): A value controlling the randomness of predictions 73 | during non-deterministic decoding. Lower values (e.g., 0.2) make the model 74 | more conservative, while higher values (e.g., 0.8) increase randomness. 75 | Using high temperatures may result in prediction of DNA sequences that 76 | do not translate to the input protein. 77 | Recommended values are: 78 | - Low randomness: 0.2 79 | - Medium randomness: 0.5 80 | - High randomness: 0.8 81 | The temperature must be a positive float. Defaults to 0.2. 82 | top_p (float, optional): The cumulative probability threshold for nucleus sampling. 83 | Tokens with cumulative probability up to top_p are considered for sampling. 84 | This parameter helps balance diversity and coherence in the predicted DNA sequences. 85 | The value must be a float between 0 and 1. Defaults to 0.95. 86 | num_sequences (int, optional): The number of DNA sequences to generate. Only applicable 87 | when deterministic is False. Defaults to 1. 88 | match_protein (bool, optional): Ensures the predicted DNA sequence is translated 89 | to the input protein sequence by sampling from only the respective codons of 90 | given amino acids. Defaults to False. 91 | 92 | Returns: 93 | Union[DNASequencePrediction, List[DNASequencePrediction]]: An object or list of objects 94 | containing the prediction results: 95 | - organism (str): Name of the organism used for prediction. 96 | - protein (str): Input protein sequence for which DNA sequence is predicted. 97 | - processed_input (str): Processed input sequence (merged protein and DNA). 98 | - predicted_dna (str): Predicted DNA sequence. 99 | 100 | Raises: 101 | ValueError: If the protein sequence is empty, if the organism is invalid, 102 | if the temperature is not a positive float, if top_p is not between 0 and 1, 103 | or if num_sequences is less than 1 or used with deterministic mode. 104 | 105 | Note: 106 | This function uses ORGANISM2ID, INDEX2TOKEN, and AMINO_ACID_TO_INDEX dictionaries 107 | imported from CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their 108 | corresponding IDs. INDEX2TOKEN maps model output indices (token IDs) to 109 | respective codons. 
AMINO_ACID_TO_INDEX maps each amino acid and stop symbol to indices 110 | of codon tokens that translate to it. 111 | 112 | Example: 113 | >>> import torch 114 | >>> from transformers import AutoTokenizer, BigBirdForMaskedLM 115 | >>> from CodonTransformer.CodonPrediction import predict_dna_sequence 116 | >>> from CodonTransformer.CodonJupyter import format_model_output 117 | >>> 118 | >>> # Set up device 119 | >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 120 | >>> 121 | >>> # Load tokenizer and model 122 | >>> tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") 123 | >>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer") 124 | >>> model = model.to(device) 125 | >>> 126 | >>> # Define protein sequence and organism 127 | >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA" 128 | >>> organism = "Escherichia coli general" 129 | >>> 130 | >>> # Predict DNA sequence with deterministic decoding (single sequence) 131 | >>> output = predict_dna_sequence( 132 | ... protein=protein, 133 | ... organism=organism, 134 | ... device=device, 135 | ... tokenizer=tokenizer, 136 | ... model=model, 137 | ... attention_type="original_full", 138 | ... deterministic=True 139 | ... ) 140 | >>> 141 | >>> # Predict multiple DNA sequences with low randomness and top_p sampling 142 | >>> output_random = predict_dna_sequence( 143 | ... protein=protein, 144 | ... organism=organism, 145 | ... device=device, 146 | ... tokenizer=tokenizer, 147 | ... model=model, 148 | ... attention_type="original_full", 149 | ... deterministic=False, 150 | ... temperature=0.2, 151 | ... top_p=0.95, 152 | ... num_sequences=3 153 | ... ) 154 | >>> 155 | >>> print(format_model_output(output)) 156 | >>> for i, seq in enumerate(output_random, 1): 157 | ... print(f"Sequence {i}:") 158 | ... print(format_model_output(seq)) 159 | ... print() 160 | """ 161 | if not protein: 162 | raise ValueError("Protein sequence cannot be empty.") 163 | 164 | if not isinstance(temperature, (float, int)) or temperature <= 0: 165 | raise ValueError("Temperature must be a positive float.") 166 | 167 | if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0: 168 | raise ValueError("top_p must be a float between 0 and 1.") 169 | 170 | if not isinstance(num_sequences, int) or num_sequences < 1: 171 | raise ValueError("num_sequences must be a positive integer.") 172 | 173 | if deterministic and num_sequences > 1: 174 | raise ValueError( 175 | "Multiple sequences can only be generated in non-deterministic mode." 
176 | ) 177 | 178 | # Load tokenizer 179 | if not isinstance(tokenizer, PreTrainedTokenizerFast): 180 | tokenizer = load_tokenizer(tokenizer) 181 | 182 | # Load model 183 | if not isinstance(model, torch.nn.Module): 184 | model = load_model(model, device=device, attention_type=attention_type) 185 | else: 186 | model.eval() 187 | model.bert.set_attention_type(attention_type) 188 | model.to(device) 189 | 190 | # Validate organism and convert to organism_id and organism_name 191 | organism_id, organism_name = validate_and_convert_organism(organism) 192 | 193 | # Inference loop 194 | with torch.no_grad(): 195 | # Tokenize the input sequence 196 | merged_seq = get_merged_seq(protein=protein, dna="") 197 | input_dict = { 198 | "idx": 0, # sample index 199 | "codons": merged_seq, 200 | "organism": organism_id, 201 | } 202 | tokenized_input = tokenize([input_dict], tokenizer=tokenizer).to(device) 203 | 204 | # Get the model predictions 205 | output_dict = model(**tokenized_input, return_dict=True) 206 | logits = output_dict.logits.detach().cpu() 207 | logits = logits[:, 1:-1, :] # Remove [CLS] and [SEP] tokens 208 | 209 | # Mask the logits of codons that do not correspond to the input protein sequence 210 | if match_protein: 211 | possible_tokens_per_position = [ 212 | AMINO_ACID_TO_INDEX[token[0]] for token in merged_seq.split(" ") 213 | ] 214 | mask = torch.full_like(logits, float("-inf")) 215 | 216 | for pos, possible_tokens in enumerate(possible_tokens_per_position): 217 | mask[:, pos, possible_tokens] = 0 218 | 219 | logits = mask + logits 220 | 221 | predictions = [] 222 | for _ in range(num_sequences): 223 | # Decode the predicted DNA sequence from the model output 224 | if deterministic: 225 | predicted_indices = logits.argmax(dim=-1).squeeze().tolist() 226 | else: 227 | predicted_indices = sample_non_deterministic( 228 | logits=logits, temperature=temperature, top_p=top_p 229 | ) 230 | 231 | predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices)) 232 | predicted_dna = ( 233 | "".join([token[-3:] for token in predicted_dna]).strip().upper() 234 | ) 235 | 236 | predictions.append( 237 | DNASequencePrediction( 238 | organism=organism_name, 239 | protein=protein, 240 | processed_input=merged_seq, 241 | predicted_dna=predicted_dna, 242 | ) 243 | ) 244 | 245 | return predictions[0] if num_sequences == 1 else predictions 246 | 247 | 248 | def sample_non_deterministic( 249 | logits: torch.Tensor, 250 | temperature: float = 0.2, 251 | top_p: float = 0.95, 252 | ) -> List[int]: 253 | """ 254 | Sample token indices from logits using temperature scaling and nucleus (top-p) sampling. 255 | 256 | This function applies temperature scaling to the logits, computes probabilities, 257 | and then performs nucleus sampling to select token indices. It is used for 258 | non-deterministic decoding in language models to introduce randomness while 259 | maintaining coherence in the generated sequences. 260 | 261 | Args: 262 | logits (torch.Tensor): The logits output from the model of shape 263 | [seq_len, vocab_size] or [batch_size, seq_len, vocab_size]. 264 | temperature (float, optional): Temperature value for scaling logits. 265 | Must be a positive float. Defaults to 0.2. 266 | top_p (float, optional): Cumulative probability threshold for nucleus sampling. 267 | Must be a float between 0 and 1. Tokens with cumulative probability up to 268 | `top_p` are considered for sampling. Defaults to 0.95.
269 | 270 | Returns: 271 | List[int]: A list of sampled token indices corresponding to the predicted tokens. 272 | 273 | Raises: 274 | ValueError: If `temperature` is not a positive float or if `top_p` is not between 0 and 1. 275 | 276 | Example: 277 | >>> logits = model_output.logits # Assume logits is a tensor of shape [seq_len, vocab_size] 278 | >>> predicted_indices = sample_non_deterministic(logits, temperature=0.7, top_p=0.9) 279 | """ 280 | if not isinstance(temperature, (float, int)) or temperature <= 0: 281 | raise ValueError("Temperature must be a positive float.") 282 | 283 | if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0: 284 | raise ValueError("top_p must be a float between 0 and 1.") 285 | 286 | # Compute probabilities using temperature scaling 287 | probs = torch.softmax(logits / temperature, dim=-1) 288 | 289 | 290 | # Remove batch dimension if present 291 | if probs.dim() == 3: 292 | probs = probs.squeeze(0) # Shape: [seq_len, vocab_size] 293 | 294 | # Sort probabilities in descending order 295 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 296 | probs_sum = torch.cumsum(probs_sort, dim=-1) 297 | mask = probs_sum - probs_sort > top_p 298 | 299 | # Zero out probabilities for tokens beyond the top-p threshold 300 | probs_sort[mask] = 0.0 301 | 302 | # Renormalize the probabilities 303 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 304 | next_token = torch.multinomial(probs_sort, num_samples=1) 305 | predicted_indices = torch.gather(probs_idx, -1, next_token).squeeze(-1) 306 | 307 | return predicted_indices.tolist() 308 | 309 | 310 | def load_model( 311 | model_path: Optional[str] = None, 312 | device: torch.device = None, 313 | attention_type: str = "original_full", 314 | num_organisms: int = None, 315 | remove_prefix: bool = True, 316 | ) -> torch.nn.Module: 317 | """ 318 | Load a BigBirdForMaskedLM model from a model file, checkpoint, or HuggingFace. 319 | 320 | Args: 321 | model_path (Optional[str]): Path to the model file or checkpoint. If None, 322 | load from HuggingFace. 323 | device (torch.device, optional): The device to load the model onto. 324 | attention_type (str, optional): The type of attention, 'block_sparse' 325 | or 'original_full'. 326 | num_organisms (int, optional): Number of organisms, needed if loading from a 327 | checkpoint that requires this. 328 | remove_prefix (bool, optional): Whether to remove the "model." prefix from the 329 | keys in the state dict. 330 | 331 | Returns: 332 | torch.nn.Module: The loaded model. 333 | """ 334 | if not model_path: 335 | warnings.warn("Model path not provided. Loading from HuggingFace.", UserWarning) 336 | model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer") 337 | 338 | elif model_path.endswith(".ckpt"): 339 | checkpoint = torch.load(model_path) 340 | state_dict = checkpoint["state_dict"] 341 | 342 | # Remove the "model." 
prefix from the keys 343 | if remove_prefix: 344 | state_dict = { 345 | key.replace("model.", ""): value for key, value in state_dict.items() 346 | } 347 | 348 | if num_organisms is None: 349 | num_organisms = NUM_ORGANISMS 350 | 351 | # Load model configuration and instantiate the model 352 | config = load_bigbird_config(num_organisms) 353 | model = BigBirdForMaskedLM(config=config) 354 | model.load_state_dict(state_dict) 355 | 356 | elif model_path.endswith(".pt"): 357 | state_dict = torch.load(model_path) 358 | config = state_dict.pop("self.config") 359 | model = BigBirdForMaskedLM(config=config) 360 | model.load_state_dict(state_dict) 361 | 362 | else: 363 | raise ValueError( 364 | "Unsupported file type. Please provide a .ckpt or .pt file, " 365 | "or None to load from HuggingFace." 366 | ) 367 | 368 | # Prepare model for evaluation 369 | model.bert.set_attention_type(attention_type) 370 | model.eval() 371 | if device: 372 | model.to(device) 373 | 374 | return model 375 | 376 | 377 | def load_bigbird_config(num_organisms: int) -> BigBirdConfig: 378 | """ 379 | Load the config object used to train the BigBird transformer. 380 | 381 | Args: 382 | num_organisms (int): The number of organisms. 383 | 384 | Returns: 385 | BigBirdConfig: The configuration object for BigBird. 386 | """ 387 | config = transformers.BigBirdConfig( 388 | vocab_size=len(TOKEN2INDEX), # Equal to len(tokenizer) 389 | type_vocab_size=num_organisms, 390 | sep_token_id=2, 391 | ) 392 | return config 393 | 394 | 395 | def create_model_from_checkpoint( 396 | checkpoint_dir: str, output_model_dir: str, num_organisms: int 397 | ) -> None: 398 | """ 399 | Save a model to disk using a previous checkpoint. 400 | 401 | Args: 402 | checkpoint_dir (str): Directory where the checkpoint is stored. 403 | output_model_dir (str): Directory where the model will be saved. 404 | num_organisms (int): Number of organisms. 405 | """ 406 | checkpoint = load_model(model_path=checkpoint_dir, num_organisms=num_organisms) 407 | state_dict = checkpoint.state_dict() 408 | state_dict["self.config"] = load_bigbird_config(num_organisms=num_organisms) 409 | 410 | # Save the model state dict to the output directory 411 | torch.save(state_dict, output_model_dir) 412 | 413 | 414 | def load_tokenizer(tokenizer_path: Optional[str] = None) -> PreTrainedTokenizerFast: 415 | """ 416 | Create and return a tokenizer object from tokenizer path or HuggingFace. 417 | 418 | Args: 419 | tokenizer_path (Optional[str]): Path to the tokenizer file. If None, 420 | load from HuggingFace. 421 | 422 | Returns: 423 | PreTrainedTokenizerFast: The tokenizer object. 424 | """ 425 | if not tokenizer_path: 426 | warnings.warn( 427 | "Tokenizer path not provided. Loading from HuggingFace.", UserWarning 428 | ) 429 | return AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") 430 | 431 | return transformers.PreTrainedTokenizerFast( 432 | tokenizer_file=tokenizer_path, 433 | bos_token="[CLS]", 434 | eos_token="[SEP]", 435 | unk_token="[UNK]", 436 | sep_token="[SEP]", 437 | pad_token="[PAD]", 438 | cls_token="[CLS]", 439 | mask_token="[MASK]", 440 | ) 441 | 442 | 443 | def tokenize( 444 | batch: List[Dict[str, Any]], 445 | tokenizer: Union[PreTrainedTokenizerFast, str] = None, 446 | max_len: int = 2048, 447 | ) -> BatchEncoding: 448 | """ 449 | Return the tokenized sequences given a batch of input data. 450 | Each data in the batch is expected to be a dictionary with "codons" and 451 | "organism" keys. 
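For illustration, a minimal (hypothetical) batch element could look like
{"idx": 0, "codons": "m_atg k_aaa __tag", "organism": 51}, where "codons"
holds the space-separated merged tokens and "organism" is a valid
ORGANISM2ID value; only the "codons" and "organism" keys are read here.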
452 | 453 | Args: 454 | batch (List[Dict[str, Any]]): A list of dictionaries with "codons" and 455 | "organism" keys. 456 | tokenizer (PreTrainedTokenizerFast, str, optional): The tokenizer object or 457 | path to the tokenizer file. 458 | max_len (int, optional): Maximum length of the tokenized sequence. 459 | 460 | Returns: 461 | BatchEncoding: The tokenized batch. 462 | """ 463 | if not isinstance(tokenizer, PreTrainedTokenizerFast): 464 | tokenizer = load_tokenizer(tokenizer) 465 | 466 | tokenized = tokenizer( 467 | [data["codons"] for data in batch], 468 | return_attention_mask=True, 469 | return_token_type_ids=True, 470 | truncation=True, 471 | padding=True, 472 | max_length=max_len, 473 | return_tensors="pt", 474 | ) 475 | 476 | # Add token type IDs for species 477 | seq_len = tokenized["input_ids"].shape[-1] 478 | species_index = torch.tensor([[data["organism"]] for data in batch]) 479 | tokenized["token_type_ids"] = species_index.repeat(1, seq_len) 480 | 481 | return tokenized 482 | 483 | 484 | def validate_and_convert_organism(organism: Union[int, str]) -> Tuple[int, str]: 485 | """ 486 | Validate and convert the organism input to both ID and name. 487 | 488 | This function takes either an organism ID or name as input and returns both 489 | the ID and name. It performs validation to ensure the input corresponds to 490 | a valid organism in the ORGANISM2ID dictionary. 491 | 492 | Args: 493 | organism (Union[int, str]): Either the ID of the organism (int) or its 494 | name (str). 495 | 496 | Returns: 497 | Tuple[int, str]: A tuple containing the organism ID (int) and name (str). 498 | 499 | Raises: 500 | ValueError: If the input is neither a string nor an integer, if the 501 | organism name is not found in ORGANISM2ID, if the organism ID is not a 502 | value in ORGANISM2ID, or if no name is found for a given ID. 503 | 504 | Note: 505 | This function relies on the ORGANISM2ID dictionary imported from 506 | CodonTransformer.CodonUtils, which maps organism names to their 507 | corresponding IDs. 508 | """ 509 | if isinstance(organism, str): 510 | if organism not in ORGANISM2ID: 511 | raise ValueError( 512 | f"Invalid organism name: {organism}. " 513 | "Please use a valid organism name or ID." 514 | ) 515 | organism_id = ORGANISM2ID[organism] 516 | organism_name = organism 517 | 518 | elif isinstance(organism, int): 519 | if organism not in ORGANISM2ID.values(): 520 | raise ValueError( 521 | f"Invalid organism ID: {organism}. " 522 | "Please use a valid organism name or ID." 523 | ) 524 | 525 | organism_id = organism 526 | organism_name = next( 527 | (name for name, id in ORGANISM2ID.items() if id == organism), None 528 | ) 529 | if organism_name is None: 530 | raise ValueError(f"No organism name found for ID: {organism}") 531 | 532 | return organism_id, organism_name 533 | 534 | 535 | def get_high_frequency_choice_sequence( 536 | protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] 537 | ) -> str: 538 | """ 539 | Return the DNA sequence optimized using High Frequency Choice (HFC) approach 540 | in which the most frequent codon for a given amino acid is always chosen. 541 | 542 | Args: 543 | protein (str): The protein sequence. 544 | codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon 545 | frequencies for each amino acid. 546 | 547 | Returns: 548 | str: The optimized DNA sequence. 
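Example (illustrative toy frequencies, not real organism data):
    >>> freqs = {"M": (["ATG"], [1.0]), "K": (["AAA", "AAG"], [0.7, 0.3])}
    >>> get_high_frequency_choice_sequence("MK", freqs)
    'ATGAAA'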
549 | """ 550 | # Select the most frequent codon for each amino acid in the protein sequence 551 | dna_codons = [ 552 | codon_frequencies[aminoacid][0][np.argmax(codon_frequencies[aminoacid][1])] 553 | for aminoacid in protein 554 | ] 555 | return "".join(dna_codons) 556 | 557 | 558 | def precompute_most_frequent_codons( 559 | codon_frequencies: Dict[str, Tuple[List[str], List[float]]], 560 | ) -> Dict[str, str]: 561 | """ 562 | Precompute the most frequent codon for each amino acid. 563 | 564 | Args: 565 | codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon 566 | frequencies for each amino acid. 567 | 568 | Returns: 569 | Dict[str, str]: The most frequent codon for each amino acid. 570 | """ 571 | # Create a dictionary mapping each amino acid to its most frequent codon 572 | return { 573 | aminoacid: codons[np.argmax(frequencies)] 574 | for aminoacid, (codons, frequencies) in codon_frequencies.items() 575 | } 576 | 577 | 578 | def get_high_frequency_choice_sequence_optimized( 579 | protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] 580 | ) -> str: 581 | """ 582 | Efficient implementation of get_high_frequency_choice_sequence that uses 583 | vectorized operations and helper functions, achieving up to x10 faster speed. 584 | 585 | Args: 586 | protein (str): The protein sequence. 587 | codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon 588 | frequencies for each amino acid. 589 | 590 | Returns: 591 | str: The optimized DNA sequence. 592 | """ 593 | # Precompute the most frequent codons for each amino acid 594 | most_frequent_codons = precompute_most_frequent_codons(codon_frequencies) 595 | 596 | return "".join(most_frequent_codons[aminoacid] for aminoacid in protein) 597 | 598 | 599 | def get_background_frequency_choice_sequence( 600 | protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] 601 | ) -> str: 602 | """ 603 | Return the DNA sequence optimized using Background Frequency Choice (BFC) 604 | approach in which a random codon for a given amino acid is chosen using 605 | the codon frequencies probability distribution. 606 | 607 | Args: 608 | protein (str): The protein sequence. 609 | codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon 610 | frequencies for each amino acid. 611 | 612 | Returns: 613 | str: The optimized DNA sequence. 614 | """ 615 | # Select a random codon for each amino acid based on the codon frequencies 616 | # probability distribution 617 | dna_codons = [ 618 | np.random.choice( 619 | codon_frequencies[aminoacid][0], p=codon_frequencies[aminoacid][1] 620 | ) 621 | for aminoacid in protein 622 | ] 623 | return "".join(dna_codons) 624 | 625 | 626 | def precompute_cdf( 627 | codon_frequencies: Dict[str, Tuple[List[str], List[float]]], 628 | ) -> Dict[str, Tuple[List[str], Any]]: 629 | """ 630 | Precompute the cumulative distribution function (CDF) for each amino acid. 631 | 632 | Args: 633 | codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon 634 | frequencies for each amino acid. 635 | 636 | Returns: 637 | Dict[str, Tuple[List[str], Any]]: CDFs for each amino acid. 
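Example (illustrative toy frequencies):
    >>> cdf = precompute_cdf({"K": (["AAA", "AAG"], [0.7, 0.3])})
    >>> cdf["K"][0]
    ['AAA', 'AAG']
    The cumulative array cdf["K"][1] is then approximately [0.7, 1.0].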
638 | """ 639 | cdf = {} 640 | 641 | # Calculate the cumulative distribution function for each amino acid 642 | for aminoacid, (codons, frequencies) in codon_frequencies.items(): 643 | cdf[aminoacid] = (codons, np.cumsum(frequencies)) 644 | 645 | return cdf 646 | 647 | 648 | def get_background_frequency_choice_sequence_optimized( 649 | protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] 650 | ) -> str: 651 | """ 652 | Efficient implementation of get_background_frequency_choice_sequence that uses 653 | vectorized operations and helper functions, achieving up to x8 faster speed. 654 | 655 | Args: 656 | protein (str): The protein sequence. 657 | codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon 658 | frequencies for each amino acid. 659 | 660 | Returns: 661 | str: The optimized DNA sequence. 662 | """ 663 | dna_codons = [] 664 | cdf = precompute_cdf(codon_frequencies) 665 | 666 | # Select a random codon for each amino acid using the precomputed CDFs 667 | for aminoacid in protein: 668 | codons, cumulative_prob = cdf[aminoacid] 669 | selected_codon_index = np.searchsorted(cumulative_prob, np.random.rand()) 670 | dna_codons.append(codons[selected_codon_index]) 671 | 672 | return "".join(dna_codons) 673 | 674 | 675 | def get_uniform_random_choice_sequence( 676 | protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] 677 | ) -> str: 678 | """ 679 | Return the DNA sequence optimized using Uniform Random Choice (URC) approach 680 | in which a random codon for a given amino acid is chosen using a uniform 681 | prior. 682 | 683 | Args: 684 | protein (str): The protein sequence. 685 | codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon 686 | frequencies for each amino acid. 687 | 688 | Returns: 689 | str: The optimized DNA sequence. 690 | """ 691 | # Select a random codon for each amino acid using a uniform prior distribution 692 | dna_codons = [] 693 | for aminoacid in protein: 694 | codons = codon_frequencies[aminoacid][0] 695 | random_index = np.random.randint(0, len(codons)) 696 | dna_codons.append(codons[random_index]) 697 | return "".join(dna_codons) 698 | 699 | 700 | def get_icor_prediction(input_seq: str, model_path: str, stop_symbol: str) -> str: 701 | """ 702 | Return the optimized codon sequence for the given protein sequence using ICOR. 703 | 704 | Credit: ICOR: improving codon optimization with recurrent neural networks 705 | Rishab Jain, Aditya Jain, Elizabeth Mauro, Kevin LeShane, Douglas 706 | Densmore 707 | 708 | Args: 709 | input_seq (str): The input protein sequence. 710 | model_path (str): The path to the ICOR model. 711 | stop_symbol (str): The symbol representing stop codons in the sequence. 712 | 713 | Returns: 714 | str: The optimized DNA sequence. 715 | """ 716 | input_seq = input_seq.strip().upper() 717 | input_seq = input_seq.replace(stop_symbol, "*") 718 | 719 | # Define categorical labels from when model was trained. 
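# Note: the order of this list is fixed by the trained model's output layer;
# the softmax indices computed below map directly into it, so it must not be
# re-sorted or otherwise modified.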
720 | labels = [ 721 | "AAA", 722 | "AAC", 723 | "AAG", 724 | "AAT", 725 | "ACA", 726 | "ACG", 727 | "ACT", 728 | "AGC", 729 | "ATA", 730 | "ATC", 731 | "ATG", 732 | "ATT", 733 | "CAA", 734 | "CAC", 735 | "CAG", 736 | "CCG", 737 | "CCT", 738 | "CTA", 739 | "CTC", 740 | "CTG", 741 | "CTT", 742 | "GAA", 743 | "GAT", 744 | "GCA", 745 | "GCC", 746 | "GCG", 747 | "GCT", 748 | "GGA", 749 | "GGC", 750 | "GTC", 751 | "GTG", 752 | "GTT", 753 | "TAA", 754 | "TAT", 755 | "TCA", 756 | "TCG", 757 | "TCT", 758 | "TGG", 759 | "TGT", 760 | "TTA", 761 | "TTC", 762 | "TTG", 763 | "TTT", 764 | "ACC", 765 | "CAT", 766 | "CCA", 767 | "CGG", 768 | "CGT", 769 | "GAC", 770 | "GAG", 771 | "GGT", 772 | "AGT", 773 | "GGG", 774 | "GTA", 775 | "TGC", 776 | "CCC", 777 | "CGA", 778 | "CGC", 779 | "TAC", 780 | "TAG", 781 | "TCC", 782 | "AGA", 783 | "AGG", 784 | "TGA", 785 | ] 786 | 787 | # Define aa to integer table 788 | def aa2int(seq: str) -> List[int]: 789 | _aa2int = { 790 | "A": 1, 791 | "R": 2, 792 | "N": 3, 793 | "D": 4, 794 | "C": 5, 795 | "Q": 6, 796 | "E": 7, 797 | "G": 8, 798 | "H": 9, 799 | "I": 10, 800 | "L": 11, 801 | "K": 12, 802 | "M": 13, 803 | "F": 14, 804 | "P": 15, 805 | "S": 16, 806 | "T": 17, 807 | "W": 18, 808 | "Y": 19, 809 | "V": 20, 810 | "B": 21, 811 | "Z": 22, 812 | "X": 23, 813 | "*": 24, 814 | "-": 25, 815 | "?": 26, 816 | } 817 | return [_aa2int[i] for i in seq] 818 | 819 | # Create empty array to fill 820 | oh_array = np.zeros(shape=(26, len(input_seq))) 821 | 822 | # Load placements from aa2int 823 | aa_placement = aa2int(input_seq) 824 | 825 | # One-hot encode the amino acid sequence: 826 | 827 | # Place a 1 at each position's amino acid row, one column per residue. 828 | for i, aa_index in enumerate(aa_placement): 829 | oh_array[aa_index, i] = 1 830 | 831 | 832 | oh_array = [oh_array] 833 | x = np.array(np.transpose(oh_array)) 834 | 835 | y = x.astype(np.float32) 836 | 837 | y = np.reshape(y, (y.shape[0], 1, 26)) 838 | 839 | # Start ICOR session using model. 840 | sess = rt.InferenceSession(model_path) 841 | input_name = sess.get_inputs()[0].name 842 | 843 | # Get prediction: 844 | pred_onx = sess.run(None, {input_name: y}) 845 | 846 | # Get the index of the highest probability from softmax output: 847 | pred_indices = [] 848 | for pred in pred_onx[0]: 849 | pred_indices.append(np.argmax(pred)) 850 | 851 | out_str = "" 852 | for index in pred_indices: 853 | out_str += labels[index] 854 | 855 | return out_str 856 | -------------------------------------------------------------------------------- /CodonTransformer/CodonUtils.py: -------------------------------------------------------------------------------- 1 | """ 2 | File: CodonUtils.py 3 | --------------------- 4 | Includes constants and helper functions used by other Python scripts.
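Key exports include the codon token vocabularies (TOKEN2INDEX, INDEX2TOKEN,
AMINO_ACID_TO_INDEX), the organism mappings (ORGANISM2ID, ID2ORGANISM), and
the DNASequencePrediction dataclass.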
5 | """ 6 | 7 | import itertools 8 | import os 9 | import pickle 10 | import re 11 | from abc import ABC, abstractmethod 12 | from dataclasses import dataclass 13 | from typing import Any, Dict, Iterator, List, Optional, Tuple 14 | 15 | import pandas as pd 16 | import requests 17 | import torch 18 | 19 | # List of all amino acids 20 | AMINO_ACIDS: List[str] = [ 21 | "A", # Alanine 22 | "C", # Cysteine 23 | "D", # Aspartic acid 24 | "E", # Glutamic acid 25 | "F", # Phenylalanine 26 | "G", # Glycine 27 | "H", # Histidine 28 | "I", # Isoleucine 29 | "K", # Lysine 30 | "L", # Leucine 31 | "M", # Methionine 32 | "N", # Asparagine 33 | "P", # Proline 34 | "Q", # Glutamine 35 | "R", # Arginine 36 | "S", # Serine 37 | "T", # Threonine 38 | "V", # Valine 39 | "W", # Tryptophan 40 | "Y", # Tyrosine 41 | ] 42 | STOP_SYMBOLS = ["_", "*"] # Stop codon symbols 43 | 44 | # Dictionary ambiguous amino acids to standard amino acids 45 | AMBIGUOUS_AMINOACID_MAP: Dict[str, list[str]] = { 46 | "B": ["N", "D"], # Asparagine (N) or Aspartic acid (D) 47 | "Z": ["Q", "E"], # Glutamine (Q) or Glutamic acid (E) 48 | "X": ["A"], # Any amino acid (typically replaced with Alanine) 49 | "J": ["L", "I"], # Leucine (L) or Isoleucine (I) 50 | "U": ["C"], # Selenocysteine (typically replaced with Cysteine) 51 | "O": ["K"], # Pyrrolysine (typically replaced with Lysine) 52 | } 53 | 54 | # List of all possible start and stop codons 55 | START_CODONS: List[str] = ["ATG", "TTG", "CTG", "GTG"] 56 | STOP_CODONS: List[str] = ["TAA", "TAG", "TGA"] 57 | 58 | # Token-to-index mapping for amino acids and special tokens 59 | TOKEN2INDEX: Dict[str, int] = { 60 | "[UNK]": 0, 61 | "[CLS]": 1, 62 | "[SEP]": 2, 63 | "[PAD]": 3, 64 | "[MASK]": 4, 65 | "a_unk": 5, 66 | "c_unk": 6, 67 | "d_unk": 7, 68 | "e_unk": 8, 69 | "f_unk": 9, 70 | "g_unk": 10, 71 | "h_unk": 11, 72 | "i_unk": 12, 73 | "k_unk": 13, 74 | "l_unk": 14, 75 | "m_unk": 15, 76 | "n_unk": 16, 77 | "p_unk": 17, 78 | "q_unk": 18, 79 | "r_unk": 19, 80 | "s_unk": 20, 81 | "t_unk": 21, 82 | "v_unk": 22, 83 | "w_unk": 23, 84 | "y_unk": 24, 85 | "__unk": 25, 86 | "k_aaa": 26, 87 | "n_aac": 27, 88 | "k_aag": 28, 89 | "n_aat": 29, 90 | "t_aca": 30, 91 | "t_acc": 31, 92 | "t_acg": 32, 93 | "t_act": 33, 94 | "r_aga": 34, 95 | "s_agc": 35, 96 | "r_agg": 36, 97 | "s_agt": 37, 98 | "i_ata": 38, 99 | "i_atc": 39, 100 | "m_atg": 40, 101 | "i_att": 41, 102 | "q_caa": 42, 103 | "h_cac": 43, 104 | "q_cag": 44, 105 | "h_cat": 45, 106 | "p_cca": 46, 107 | "p_ccc": 47, 108 | "p_ccg": 48, 109 | "p_cct": 49, 110 | "r_cga": 50, 111 | "r_cgc": 51, 112 | "r_cgg": 52, 113 | "r_cgt": 53, 114 | "l_cta": 54, 115 | "l_ctc": 55, 116 | "l_ctg": 56, 117 | "l_ctt": 57, 118 | "e_gaa": 58, 119 | "d_gac": 59, 120 | "e_gag": 60, 121 | "d_gat": 61, 122 | "a_gca": 62, 123 | "a_gcc": 63, 124 | "a_gcg": 64, 125 | "a_gct": 65, 126 | "g_gga": 66, 127 | "g_ggc": 67, 128 | "g_ggg": 68, 129 | "g_ggt": 69, 130 | "v_gta": 70, 131 | "v_gtc": 71, 132 | "v_gtg": 72, 133 | "v_gtt": 73, 134 | "__taa": 74, 135 | "y_tac": 75, 136 | "__tag": 76, 137 | "y_tat": 77, 138 | "s_tca": 78, 139 | "s_tcc": 79, 140 | "s_tcg": 80, 141 | "s_tct": 81, 142 | "__tga": 82, 143 | "c_tgc": 83, 144 | "w_tgg": 84, 145 | "c_tgt": 85, 146 | "l_tta": 86, 147 | "f_ttc": 87, 148 | "l_ttg": 88, 149 | "f_ttt": 89, 150 | } 151 | 152 | # Index-to-token mapping, reverse of TOKEN2INDEX 153 | INDEX2TOKEN: Dict[int, str] = {i: c for c, i in TOKEN2INDEX.items()} 154 | 155 | # Dictionary mapping each amino acid and stop symbol to indices of codon tokens that translate to 
it 156 | AMINO_ACID_TO_INDEX = { 157 | aa: sorted( 158 | [i for t, i in TOKEN2INDEX.items() if t[0].upper() == aa and t[-3:] != "unk"] 159 | ) 160 | for aa in (AMINO_ACIDS + STOP_SYMBOLS) 161 | } 162 | 163 | 164 | # Mask token mapping 165 | TOKEN2MASK: Dict[int, int] = { 166 | 0: 0, 167 | 1: 1, 168 | 2: 2, 169 | 3: 3, 170 | 4: 4, 171 | 5: 5, 172 | 6: 6, 173 | 7: 7, 174 | 8: 8, 175 | 9: 9, 176 | 10: 10, 177 | 11: 11, 178 | 12: 12, 179 | 13: 13, 180 | 14: 14, 181 | 15: 15, 182 | 16: 16, 183 | 17: 17, 184 | 18: 18, 185 | 19: 19, 186 | 20: 20, 187 | 21: 21, 188 | 22: 22, 189 | 23: 23, 190 | 24: 24, 191 | 25: 25, 192 | 26: 13, 193 | 27: 16, 194 | 28: 13, 195 | 29: 16, 196 | 30: 21, 197 | 31: 21, 198 | 32: 21, 199 | 33: 21, 200 | 34: 19, 201 | 35: 20, 202 | 36: 19, 203 | 37: 20, 204 | 38: 12, 205 | 39: 12, 206 | 40: 15, 207 | 41: 12, 208 | 42: 18, 209 | 43: 11, 210 | 44: 18, 211 | 45: 11, 212 | 46: 17, 213 | 47: 17, 214 | 48: 17, 215 | 49: 17, 216 | 50: 19, 217 | 51: 19, 218 | 52: 19, 219 | 53: 19, 220 | 54: 14, 221 | 55: 14, 222 | 56: 14, 223 | 57: 14, 224 | 58: 8, 225 | 59: 7, 226 | 60: 8, 227 | 61: 7, 228 | 62: 5, 229 | 63: 5, 230 | 64: 5, 231 | 65: 5, 232 | 66: 10, 233 | 67: 10, 234 | 68: 10, 235 | 69: 10, 236 | 70: 22, 237 | 71: 22, 238 | 72: 22, 239 | 73: 22, 240 | 74: 25, 241 | 75: 24, 242 | 76: 25, 243 | 77: 24, 244 | 78: 20, 245 | 79: 20, 246 | 80: 20, 247 | 81: 20, 248 | 82: 25, 249 | 83: 6, 250 | 84: 23, 251 | 85: 6, 252 | 86: 14, 253 | 87: 9, 254 | 88: 14, 255 | 89: 9, 256 | } 257 | 258 | # List of organisms used for fine-tuning 259 | FINE_TUNE_ORGANISMS: List[str] = [ 260 | "Arabidopsis thaliana", 261 | "Bacillus subtilis", 262 | "Caenorhabditis elegans", 263 | "Chlamydomonas reinhardtii", 264 | "Chlamydomonas reinhardtii chloroplast", 265 | "Danio rerio", 266 | "Drosophila melanogaster", 267 | "Homo sapiens", 268 | "Mus musculus", 269 | "Nicotiana tabacum", 270 | "Nicotiana tabacum chloroplast", 271 | "Pseudomonas putida", 272 | "Saccharomyces cerevisiae", 273 | "Escherichia coli O157-H7 str. Sakai", 274 | "Escherichia coli general", 275 | "Escherichia coli str. K-12 substr. 
MG1655", 276 | "Thermococcus barophilus MPT", 277 | ] 278 | 279 | # List of organisms most commonly used for codon optimization 280 | COMMON_ORGANISMS: List[str] = [ 281 | "Arabidopsis thaliana", 282 | "Bacillus subtilis", 283 | "Caenorhabditis elegans", 284 | "Chlamydomonas reinhardtii", 285 | "Danio rerio", 286 | "Drosophila melanogaster", 287 | "Homo sapiens", 288 | "Mus musculus", 289 | "Nicotiana tabacum", 290 | "Pseudomonas putida", 291 | "Saccharomyces cerevisiae", 292 | "Escherichia coli general", 293 | ] 294 | 295 | # Dictionary mapping each organism name to respective organism id 296 | ORGANISM2ID: Dict[str, int] = { 297 | "Arabidopsis thaliana": 0, 298 | "Atlantibacter hermannii": 1, 299 | "Bacillus subtilis": 2, 300 | "Brenneria goodwinii": 3, 301 | "Buchnera aphidicola (Schizaphis graminum)": 4, 302 | "Caenorhabditis elegans": 5, 303 | "Candidatus Erwinia haradaeae": 6, 304 | "Candidatus Hamiltonella defensa 5AT (Acyrthosiphon pisum)": 7, 305 | "Chlamydomonas reinhardtii": 8, 306 | "Chlamydomonas reinhardtii chloroplast": 9, 307 | "Citrobacter amalonaticus": 10, 308 | "Citrobacter braakii": 11, 309 | "Citrobacter cronae": 12, 310 | "Citrobacter europaeus": 13, 311 | "Citrobacter farmeri": 14, 312 | "Citrobacter freundii": 15, 313 | "Citrobacter koseri ATCC BAA-895": 16, 314 | "Citrobacter portucalensis": 17, 315 | "Citrobacter werkmanii": 18, 316 | "Citrobacter youngae": 19, 317 | "Cronobacter dublinensis subsp. dublinensis LMG 23823": 20, 318 | "Cronobacter malonaticus LMG 23826": 21, 319 | "Cronobacter sakazakii": 22, 320 | "Cronobacter turicensis": 23, 321 | "Danio rerio": 24, 322 | "Dickeya dadantii 3937": 25, 323 | "Dickeya dianthicola": 26, 324 | "Dickeya fangzhongdai": 27, 325 | "Dickeya solani": 28, 326 | "Dickeya zeae": 29, 327 | "Drosophila melanogaster": 30, 328 | "Edwardsiella anguillarum ET080813": 31, 329 | "Edwardsiella ictaluri": 32, 330 | "Edwardsiella piscicida": 33, 331 | "Edwardsiella tarda": 34, 332 | "Enterobacter asburiae": 35, 333 | "Enterobacter bugandensis": 36, 334 | "Enterobacter cancerogenus": 37, 335 | "Enterobacter chengduensis": 38, 336 | "Enterobacter cloacae": 39, 337 | "Enterobacter hormaechei": 40, 338 | "Enterobacter kobei": 41, 339 | "Enterobacter ludwigii": 42, 340 | "Enterobacter mori": 43, 341 | "Enterobacter quasiroggenkampii": 44, 342 | "Enterobacter roggenkampii": 45, 343 | "Enterobacter sichuanensis": 46, 344 | "Erwinia amylovora CFBP1430": 47, 345 | "Erwinia persicina": 48, 346 | "Escherichia albertii": 49, 347 | "Escherichia coli O157-H7 str. Sakai": 50, 348 | "Escherichia coli general": 51, 349 | "Escherichia coli str. K-12 substr. MG1655": 52, 350 | "Escherichia fergusonii": 53, 351 | "Escherichia marmotae": 54, 352 | "Escherichia ruysiae": 55, 353 | "Ewingella americana": 56, 354 | "Hafnia alvei": 57, 355 | "Hafnia paralvei": 58, 356 | "Homo sapiens": 59, 357 | "Kalamiella piersonii": 60, 358 | "Klebsiella aerogenes": 61, 359 | "Klebsiella grimontii": 62, 360 | "Klebsiella michiganensis": 63, 361 | "Klebsiella oxytoca": 64, 362 | "Klebsiella pasteurii": 65, 363 | "Klebsiella pneumoniae subsp.
pneumoniae HS11286": 66, 364 | "Klebsiella quasipneumoniae": 67, 365 | "Klebsiella quasivariicola": 68, 366 | "Klebsiella variicola": 69, 367 | "Kosakonia cowanii": 70, 368 | "Kosakonia radicincitans": 71, 369 | "Leclercia adecarboxylata": 72, 370 | "Lelliottia amnigena": 73, 371 | "Lonsdalea populi": 74, 372 | "Moellerella wisconsensis": 75, 373 | "Morganella morganii": 76, 374 | "Mus musculus": 77, 375 | "Nicotiana tabacum": 78, 376 | "Nicotiana tabacum chloroplast": 79, 377 | "Obesumbacterium proteus": 80, 378 | "Pantoea agglomerans": 81, 379 | "Pantoea allii": 82, 380 | "Pantoea ananatis PA13": 83, 381 | "Pantoea dispersa": 84, 382 | "Pantoea stewartii": 85, 383 | "Pantoea vagans": 86, 384 | "Pectobacterium aroidearum": 87, 385 | "Pectobacterium atrosepticum": 88, 386 | "Pectobacterium brasiliense": 89, 387 | "Pectobacterium carotovorum": 90, 388 | "Pectobacterium odoriferum": 91, 389 | "Pectobacterium parmentieri": 92, 390 | "Pectobacterium polaris": 93, 391 | "Pectobacterium versatile": 94, 392 | "Photorhabdus laumondii subsp. laumondii TTO1": 95, 393 | "Plesiomonas shigelloides": 96, 394 | "Pluralibacter gergoviae": 97, 395 | "Proteus faecis": 98, 396 | "Proteus mirabilis HI4320": 99, 397 | "Proteus penneri": 100, 398 | "Proteus terrae subsp. cibarius": 101, 399 | "Proteus vulgaris": 102, 400 | "Providencia alcalifaciens": 103, 401 | "Providencia heimbachae": 104, 402 | "Providencia rettgeri": 105, 403 | "Providencia rustigianii": 106, 404 | "Providencia stuartii": 107, 405 | "Providencia thailandensis": 108, 406 | "Pseudomonas putida": 109, 407 | "Pyrococcus furiosus": 110, 408 | "Pyrococcus horikoshii": 111, 409 | "Pyrococcus yayanosii": 112, 410 | "Rahnella aquatilis CIP 78.65 = ATCC 33071": 113, 411 | "Raoultella ornithinolytica": 114, 412 | "Raoultella planticola": 115, 413 | "Raoultella terrigena": 116, 414 | "Rosenbergiella epipactidis": 117, 415 | "Rouxiella badensis": 118, 416 | "Saccharolobus solfataricus": 119, 417 | "Saccharomyces cerevisiae": 120, 418 | "Salmonella bongori N268-08": 121, 419 | "Salmonella enterica subsp. enterica serovar Typhimurium str. LT2": 122, 420 | "Serratia bockelmannii": 123, 421 | "Serratia entomophila": 124, 422 | "Serratia ficaria": 125, 423 | "Serratia fonticola": 126, 424 | "Serratia grimesii": 127, 425 | "Serratia liquefaciens": 128, 426 | "Serratia marcescens": 129, 427 | "Serratia nevei": 130, 428 | "Serratia plymuthica AS9": 131, 429 | "Serratia proteamaculans": 132, 430 | "Serratia quinivorans": 133, 431 | "Serratia rubidaea": 134, 432 | "Serratia ureilytica": 135, 433 | "Shigella boydii": 136, 434 | "Shigella dysenteriae": 137, 435 | "Shigella flexneri 2a str. 301": 138, 436 | "Shigella sonnei": 139, 437 | "Thermoccoccus kodakarensis": 140, 438 | "Thermococcus barophilus MPT": 141, 439 | "Thermococcus chitonophagus": 142, 440 | "Thermococcus gammatolerans": 143, 441 | "Thermococcus litoralis": 144, 442 | "Thermococcus onnurineus": 145, 443 | "Thermococcus sibiricus": 146, 444 | "Xenorhabdus bovienii str. 
feltiae Florida": 147, 445 | "Yersinia aldovae 670-83": 148, 446 | "Yersinia aleksiciae": 149, 447 | "Yersinia alsatica": 150, 448 | "Yersinia enterocolitica": 151, 449 | "Yersinia frederiksenii ATCC 33641": 152, 450 | "Yersinia intermedia": 153, 451 | "Yersinia kristensenii": 154, 452 | "Yersinia massiliensis CCUG 53443": 155, 453 | "Yersinia mollaretii ATCC 43969": 156, 454 | "Yersinia pestis A1122": 157, 455 | "Yersinia proxima": 158, 456 | "Yersinia pseudotuberculosis IP 32953": 159, 457 | "Yersinia rochesterensis": 160, 458 | "Yersinia rohdei": 161, 459 | "Yersinia ruckeri": 162, 460 | "Yokenella regensburgei": 163, 461 | } 462 | 463 | # Dictionary mapping each organism id to respective organism name 464 | ID2ORGANISM = {v: k for k, v in ORGANISM2ID.items()} 465 | 466 | # Type alias for amino acid to codon mapping 467 | AMINO2CODON_TYPE = Dict[str, Tuple[List[str], List[float]]] 468 | 469 | # Constants for the number of organisms and sequence lengths 470 | NUM_ORGANISMS = 164 471 | MAX_LEN = 2048 472 | MAX_AMINO_ACIDS = MAX_LEN - 2 # Without special tokens [CLS] and [SEP] 473 | STOP_SYMBOL = "_" 474 | 475 | 476 | @dataclass 477 | class DNASequencePrediction: 478 | """ 479 | A class to hold the output of the DNA sequence prediction. 480 | 481 | Attributes: 482 | organism (str): Name of the organism used for prediction. 483 | protein (str): Input protein sequence for which DNA sequence is predicted. 484 | processed_input (str): Processed input sequence (merged protein and DNA). 485 | predicted_dna (str): Predicted DNA sequence. 486 | """ 487 | 488 | organism: str 489 | protein: str 490 | processed_input: str 491 | predicted_dna: str 492 | 493 | 494 | class IterableData(torch.utils.data.IterableDataset): 495 | """ 496 | Defines the logic for iterable datasets (working over streams of 497 | data) in parallel multi-processing environments, e.g., multi-GPU. 498 | 499 | Args: 500 | dist_env (Optional[str]): The distribution environment identifier 501 | (e.g., "slurm"). 502 | 503 | Credit: Guillaume Filion 504 | """ 505 | 506 | def __init__(self, dist_env: Optional[str] = None): 507 | super().__init__() 508 | self.world_size_handle, self.rank_handle = { 509 | "slurm": ("SLURM_NTASKS", "SLURM_PROCID") 510 | }.get(dist_env, ("WORLD_SIZE", "LOCAL_RANK")) 511 | 512 | @property 513 | def iterator(self) -> Iterator: 514 | """Define the stream logic for the dataset. Implement in subclasses.""" 515 | raise NotImplementedError 516 | 517 | def __iter__(self) -> Iterator: 518 | """ 519 | Create an iterator for the dataset, handling multi-processing contexts. 520 | 521 | Returns: 522 | Iterator: The iterator for the dataset. 523 | """ 524 | worker_info = torch.utils.data.get_worker_info() 525 | if worker_info is None: 526 | return self.iterator 527 | 528 | # In multi-processing context, use 'os.environ' to 529 | # find global worker rank. Then use 'islice' to allocate 530 | # the items of the stream to the workers. 531 | world_size = int(os.environ.get(self.world_size_handle)) 532 | global_rank = int(os.environ.get(self.rank_handle)) 533 | local_rank = worker_info.id 534 | local_num_workers = worker_info.num_workers 535 | 536 | # Assume that each process has the same number of local workers. 
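# Worked example (hypothetical values): with world_size=2 processes and
# local_num_workers=3, a worker with global_rank=1 and local_rank=2 gets
# worker_rk = 1 * 3 + 2 = 5 out of worker_nb = 2 * 3 = 6, so islice below
# hands it stream items 5, 11, 17, ...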
537 | worker_rk = global_rank * local_num_workers + local_rank 538 | worker_nb = world_size * local_num_workers 539 | return itertools.islice(self.iterator, worker_rk, None, worker_nb) 540 | 541 | 542 | class IterableJSONData(IterableData): 543 | """ 544 | Iterate over the lines of a JSON file and uncompress if needed. 545 | 546 | Args: 547 | data_path (str): The path to the JSON data file. 548 | train (bool): Flag indicating if the dataset is for training. 549 | **kwargs: Additional keyword arguments for the base class. 550 | """ 551 | 552 | def __init__(self, data_path: str, train: bool = True, **kwargs): 553 | super().__init__(**kwargs) 554 | self.data_path = data_path 555 | self.train = train 556 | 557 | 558 | class ConfigManager(ABC): 559 | """ 560 | Abstract base class for managing configuration settings. 561 | """ 562 | 563 | def __enter__(self): 564 | return self 565 | 566 | def __exit__(self, exc_type, exc_value, traceback): 567 | if exc_type is not None: 568 | print(f"Exception occurred: {exc_type}, {exc_value}, {traceback}") 569 | self.reset_config() 570 | 571 | @abstractmethod 572 | def reset_config(self) -> None: 573 | """Reset the configuration to default values.""" 574 | pass 575 | 576 | def get(self, key: str) -> Any: 577 | """ 578 | Get the value of a configuration key. 579 | 580 | Args: 581 | key (str): The key to retrieve the value for. 582 | 583 | Returns: 584 | Any: The value of the configuration key. 585 | """ 586 | return self._config.get(key) 587 | 588 | def set(self, key: str, value: Any) -> None: 589 | """ 590 | Set the value of a configuration key. 591 | 592 | Args: 593 | key (str): The key to set the value for. 594 | value (Any): The value to set for the key. 595 | """ 596 | self.validate_inputs(key, value) 597 | self._config[key] = value 598 | 599 | def update(self, config_dict: dict) -> None: 600 | """ 601 | Update the configuration with a dictionary of key-value pairs after validating them. 602 | 603 | Args: 604 | config_dict (dict): A dictionary of key-value pairs to update the configuration. 605 | """ 606 | for key, value in config_dict.items(): 607 | self.validate_inputs(key, value) 608 | self._config.update(config_dict) 609 | 610 | @abstractmethod 611 | def validate_inputs(self, key: str, value: Any) -> None: 612 | """Validate the inputs for the configuration.""" 613 | pass 614 | 615 | 616 | class ProteinConfig(ConfigManager): 617 | """ 618 | A class to manage configuration settings for protein sequences. 619 | 620 | This class ensures that the configuration is a singleton. 621 | It provides methods to get, set, and update configuration values. 622 | 623 | Attributes: 624 | _instance (Optional[ConfigManager]): The singleton instance of the ConfigManager. 625 | _config (Dict[str, Any]): The configuration dictionary. 626 | """ 627 | 628 | _instance = None 629 | 630 | def __new__(cls): 631 | """ 632 | Create a new instance of the ProteinConfig class. 633 | 634 | Returns: 635 | ProteinConfig: The singleton instance of the ProteinConfig. 636 | """ 637 | if cls._instance is None: 638 | cls._instance = super(ProteinConfig, cls).__new__(cls) 639 | cls._instance.reset_config() 640 | return cls._instance 641 | 642 | def validate_inputs(self, key: str, value: Any) -> None: 643 | """ 644 | Validate the inputs for the configuration. 645 | 646 | Args: 647 | key (str): The key to validate. 648 | value (Any): The value to validate. 649 | 650 | Raises: 651 | ValueError: If the value is invalid. 652 | TypeError: If the value is of the wrong type. 
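Example (illustrative; validation normally runs via set/update rather
than being called directly):
    >>> config = ProteinConfig()
    >>> config.set("ambiguous_aminoacid_behavior", "standardize_deterministic")
    >>> config.update({"ambiguous_aminoacid_map_override": {"U": ["C"]}})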
653 | """ 654 | if key == "ambiguous_aminoacid_behavior": 655 | if value not in [ 656 | "raise_error", 657 | "standardize_deterministic", 658 | "standardize_random", 659 | ]: 660 | raise ValueError( 661 | f"Invalid value for ambiguous_aminoacid_behavior: {value}." 662 | ) 663 | elif key == "ambiguous_aminoacid_map_override": 664 | if not isinstance(value, dict): 665 | raise TypeError( 666 | f"Invalid type for ambiguous_aminoacid_map_override: {value}." 667 | ) 668 | for ambiguous_aminoacid, aminoacids in value.items(): 669 | if not isinstance(aminoacids, list): 670 | raise TypeError(f"Invalid type for aminoacids: {aminoacids}.") 671 | if not aminoacids: 672 | raise ValueError( 673 | f"Override for aminoacid '{ambiguous_aminoacid}' cannot be empty list." 674 | ) 675 | if ambiguous_aminoacid not in AMBIGUOUS_AMINOACID_MAP: 676 | raise ValueError( 677 | f"Invalid amino acid in ambiguous_aminoacid_map_override: {ambiguous_aminoacid}" 678 | ) 679 | else: 680 | raise ValueError(f"Invalid configuration key: {key}") 681 | 682 | def reset_config(self) -> None: 683 | """ 684 | Reset the configuration to the default values. 685 | """ 686 | self._config = { 687 | "ambiguous_aminoacid_behavior": "standardize_random", 688 | "ambiguous_aminoacid_map_override": {}, 689 | } 690 | 691 | 692 | def load_python_object_from_disk(file_path: str) -> Any: 693 | """ 694 | Load a Pickle object from disk and return it as a Python object. 695 | 696 | Args: 697 | file_path (str): The path to the Pickle file. 698 | 699 | Returns: 700 | Any: The loaded Python object. 701 | """ 702 | with open(file_path, "rb") as file: 703 | return pickle.load(file) 704 | 705 | 706 | def save_python_object_to_disk(input_object: Any, file_path: str) -> None: 707 | """ 708 | Save a Python object to disk using Pickle. 709 | 710 | Args: 711 | input_object (Any): The Python object to save. 712 | file_path (str): The path where the object will be saved. 713 | """ 714 | with open(file_path, "wb") as file: 715 | pickle.dump(input_object, file) 716 | 717 | 718 | def find_pattern_in_fasta(keyword: str, text: str) -> str: 719 | """ 720 | Find a specific keyword pattern in text. Helpful for identifying parts 721 | of a FASTA sequence. 722 | 723 | Args: 724 | keyword (str): The keyword pattern to search for. 725 | text (str): The text to search within. 726 | 727 | Returns: 728 | str: The found pattern or an empty string if not found. 729 | """ 730 | # Search for the keyword pattern in the text using regex 731 | result = re.search(keyword + r"=(.*?)]", text) 732 | return result.group(1) if result else "" 733 | 734 | 735 | def get_organism2id_dict(organism_reference: str) -> Dict[str, int]: 736 | """ 737 | Return a dictionary mapping each organism in training data to an index 738 | used for training. 739 | 740 | Args: 741 | organism_reference (str): Path to a CSV file containing a list of 742 | all organisms. The format of the CSV file should be as follows: 743 | 744 | 0,Escherichia coli 745 | 1,Homo sapiens 746 | 2,Mus musculus 747 | 748 | Returns: 749 | Dict[str, int]: Dictionary mapping organism names to their respective indices. 
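Example (assuming "organisms.csv" is a hypothetical file in the format
shown above):
    >>> organism2id = get_organism2id_dict("organisms.csv")
    >>> organism2id["Homo sapiens"]
    1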
750 | """ 751 | # Read the CSV file and create a dictionary mapping organisms to their indices 752 | organisms = pd.read_csv(organism_reference, index_col=0, header=None) 753 | organism2id = {organisms.iloc[i].values[0]: i for i in organisms.index} 754 | 755 | return organism2id 756 | 757 | 758 | def get_taxonomy_id( 759 | taxonomy_reference: str, organism: Optional[str] = None, return_dict: bool = False 760 | ) -> Any: 761 | """ 762 | Return the taxonomy id of a given organism using a reference file. 763 | Optionally, return the whole dictionary instead if return_dict is True. 764 | 765 | Args: 766 | taxonomy_reference (str): Path to the taxonomy reference file. 767 | organism (Optional[str]): The name of the organism to look up. 768 | return_dict (bool): Whether to return the entire dictionary. 769 | 770 | Returns: 771 | Any: The taxonomy id of the organism or the entire dictionary. 772 | """ 773 | # Load the organism-to-taxonomy mapping from a Pickle file 774 | organism2taxonomy = load_python_object_from_disk(taxonomy_reference) 775 | 776 | if return_dict: 777 | return dict(sorted(organism2taxonomy.items())) 778 | 779 | return organism2taxonomy[organism] 780 | 781 | 782 | def sort_amino2codon_skeleton(amino2codon: Dict[str, Any]) -> Dict[str, Any]: 783 | """ 784 | Sort the amino2codon dictionary alphabetically by amino acid and by codon name. 785 | 786 | Args: 787 | amino2codon (Dict[str, Any]): The amino2codon dictionary to sort. 788 | 789 | Returns: 790 | Dict[str, Any]: The sorted amino2codon dictionary. 791 | """ 792 | # Sort the dictionary by amino acid and then by codon name 793 | amino2codon = dict(sorted(amino2codon.items())) 794 | amino2codon = { 795 | amino: ( 796 | [codon for codon, _ in sorted(zip(codons, frequencies))], 797 | [freq for _, freq in sorted(zip(codons, frequencies))], 798 | ) 799 | for amino, (codons, frequencies) in amino2codon.items() 800 | } 801 | 802 | return amino2codon 803 | 804 | 805 | def load_pkl_from_url(url: str) -> Any: 806 | """ 807 | Download a Pickle file from a URL and return the loaded object. 808 | 809 | Args: 810 | url (str): The URL to download the Pickle file from. 811 | 812 | Returns: 813 | Any: The loaded Python object from the Pickle file. 
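Note:
    Only load Pickle files from URLs you trust; unpickling can execute
    arbitrary code during deserialization.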
814 | """ 815 | response = requests.get(url) 816 | response.raise_for_status() # Ensure the request was successful 817 | 818 | # Load the Pickle object from the response content 819 | return pickle.loads(response.content) 820 | -------------------------------------------------------------------------------- /CodonTransformer/__init__.py: -------------------------------------------------------------------------------- 1 | """CodonTransformer package.""" 2 | -------------------------------------------------------------------------------- /CodonTransformerDemo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import warnings\n", 10 | "from tqdm import tqdm\n", 11 | "\n", 12 | "import pandas as pd\n", 13 | "import torch\n", 14 | "from transformers import AutoTokenizer, BigBirdForMaskedLM\n", 15 | "\n", 16 | "from CodonTransformer.CodonJupyter import (\n", 17 | " UserContainer,\n", 18 | " display_organism_dropdown,\n", 19 | " display_protein_input,\n", 20 | " format_model_output,\n", 21 | ")\n", 22 | "from CodonTransformer.CodonPrediction import predict_dna_sequence\n", 23 | "\n", 24 | "warnings.filterwarnings(\"ignore\")\n", 25 | "\n", 26 | "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "# Load model and tokenizer\n", 36 | "tokenizer = AutoTokenizer.from_pretrained(\"adibvafa/CodonTransformer\")\n", 37 | "model = BigBirdForMaskedLM.from_pretrained(\"adibvafa/CodonTransformer\").to(DEVICE)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "**Optimizing a Single Sequence**\n", 45 | "-------------------------------------\n", 46 | "1. Run the next code cell and input only your protein sequence and organism\n", 47 | "\n", 48 | "2. Run the code cell after it to optimize the sequence and display it.\n", 49 | "\n", 50 | "Protein sequences should end with \"*\" or \"_\" or an amino acid." 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "# Sample: MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG, Homo sapiens\n", 60 | "user = UserContainer()\n", 61 | "display_protein_input(user)\n", 62 | "display_organism_dropdown(user)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "output = predict_dna_sequence(\n", 72 | " protein=user.protein,\n", 73 | " organism=user.organism,\n", 74 | " device=DEVICE,\n", 75 | " tokenizer=tokenizer,\n", 76 | " model=model,\n", 77 | " attention_type=\"original_full\",\n", 78 | " deterministic=True,\n", 79 | " # Can set temperature for non deterministic prediction\n", 80 | ")\n", 81 | "\n", 82 | "print(format_model_output(output))" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "**Optimizing Multiple Sequences**\n", 90 | "-------------------------------------\n", 91 | "1. Create a CSV file that has columns 'protein_sequence' and 'organism'.\n", 92 | " You can have other columns in any order.\n", 93 | "\n", 94 | "2. Replace the _dataset_path_ below with the actual path to your CSV file.\n", 95 | "\n", 96 | "3. 
Run the next code cells to optimize and save the predicted DNA sequences." 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "# Update with the actual path to your dataset\n", 106 | "dataset_path = \"demo/sample_dataset.csv\"\n", 107 | "output_path = \"demo/sample_predictions.csv\"\n", 108 | "\n", 109 | "dataset = pd.read_csv(dataset_path, index_col=0)\n", 110 | "dataset[\"predicted_dna\"] = None\n", 111 | "dataset.head()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "for index, data in tqdm(\n", 121 | "    dataset.iterrows(),\n", 122 | "    desc=\"CodonTransformer Predicting\",\n", 123 | "    unit=\" Sequences\",\n", 124 | "    total=dataset.shape[0],\n", 125 | "):\n", 126 | "\n", 127 | "    outputs = predict_dna_sequence(\n", 128 | "        protein=data[\"protein_sequence\"],\n", 129 | "        organism=data[\"organism\"],\n", 130 | "        device=DEVICE,\n", 131 | "        tokenizer=tokenizer,\n", 132 | "        model=model,\n", 133 | "    )\n", 134 | "    dataset.loc[index, \"predicted_dna\"] = outputs.predicted_dna\n", 135 | "\n", 136 | "dataset.to_csv(output_path)\n", 137 | "dataset.head()" 138 | ] 139 | } 140 | ], 141 | "metadata": { 142 | "kernelspec": { 143 | "display_name": "light", 144 | "language": "python", 145 | "name": "python3" 146 | }, 147 | "language_info": { 148 | "codemirror_mode": { 149 | "name": "ipython", 150 | "version": 3 151 | }, 152 | "file_extension": ".py", 153 | "mimetype": "text/x-python", 154 | "name": "python", 155 | "nbconvert_exporter": "python", 156 | "pygments_lexer": "ipython3", 157 | "version": "3.12.2" 158 | } 159 | }, 160 | "nbformat": 4, 161 | "nbformat_minor": 2 162 | } 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types.
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Adibvafa Fallahpour 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Makefile 2 | 3 | .PHONY: test 4 | test: 5 | python -m unittest discover -s tests 6 | 7 | .PHONY: test_with_coverage 8 | test_with_coverage: 9 | coverage run -m unittest discover -s tests 10 | -------------------------------------------------------------------------------- /demo/sample_dataset.csv: -------------------------------------------------------------------------------- 1 | ,protein_sequence,organism,predicted_dna 2 | 0,MSEKYIVTWDMLQIHARKLASRLMPSEQWKGIIAVSRGGLVPGALLARELGIRHVDTVCISSYDHDNQRELKVLKRAEGDGEGFIVIDDLVDTGGTAVAIREMYPKAHFVTIFAKPAGRPLVDDYVVDIPQDTWIEQPWDMGVVFVPPISGR_,Escherichia coli general, 3 | 1,MKNIIRTPETHPLTWRLRDDKQPVWLDEYRSKNGYEGARKALTGLSPDEIVNQVKDAGLKGRGGAGFSTGLKWSLMPKDESMNIRYLLCNADEMEPGTYKDRLLMEQLPHLLVEGMLISAFALKAYRGYIFLRGEYIEAAVNLRRAIAEATEAGLLGKNIMGTGFDFELFVHTGAGRYICGEETALINSLEGRRANPRSKPPFPATSGAWGKPTCVNNVETLCNVPAILANGVEWYQNISKSKDAGTKLMGFSGRVKNPGLWELPFGTTAREILEDYAGGMRDGLKFKAWQPGGAGTDFLTEAHLDLPMEFESIGKAGSRLGTALAMAVDHEINMVSLVRNLEEFFARESCGWCTPCRDGLPWSVKILRALERGEGQPGDIETLEQLCRFLGPGKTFCAHAPGAVEPLQSAIKYFREEFEAGIKQPFSNTHLINGIQPNLLKERW_,Escherichia coli general, 4 | 2,MDALQIAEDTLQTLVPHCPVPSGPRRIFLDANVKESYCPLVPHTMYCLPLWQGINLVLLTRSPSAPLALVLSQLMDGFSMLEKKLKEGPEPGASLRSQPLVGDLRQRMDKFVKNRGAQEIQSTWLEFKAKAFSKSEPGSSWELLQACGKLKRQLCAIYRLNFLTTAPSRGGPHLPQHLQDQVQRLMREKLTDWKDFLLVKSRRNITMVSYLEDFPGLVHFIYVDRTTGQMVAPSLNCSQKTSSELGKGPLAAFVKTKVWSLIQLARRYLQKGYTTLLFQEGDFYCSYFLWFENDMGYKLQMIEVPVLSDDSVPIGMLGGDYYRKLLRYYSKNRPTEAVRCYELLALHLSVIPTDLLVQQAGQLARRLWEASRIPLL_,Homo sapiens, 5 | 3,MAFANFRRILRLSTFEKRKSREYEHVRRDLDPNEVWEIVGELGDGAFGKVYKAKNKETGALAAAKVIETKSEEELEDYIVEIEILATCDHPYIVKLLGAYYHDGKLWIMIEFCPGGAVDAIMLELDRGLTEPQIQVVCRQMLEALNFLHSKRIIHRDLKAGNVLMTLEGDIRLADFGVSAKNLKTLQKRDSFIGTPYWMAPEVVMCETMKDTPYDYKADIWSLGITLIEMAQIEPPHHELNPMRVLLKIAKSDPPTLLTPSKWSVEFRDFLKIALDKNPETRPSAAQLLEHPFVSSITSNKALRELVAEAKAEVMEEIEDGRDEGEEEDAVDAASDPKLYKKTLKRTRKFVVDGVEVSITTSKIISEDEKKDEEMRFLRRQELRELRLLQKEEHRNQTQLSNKHELQLEQMHKRFEQEINAKKKFFDTELENLERQQKQQVEKMEQDHAVRRREEARRIRLEQDRDYTRFQEQLKLMKKEVKNEVEKLPRQQRKESMKQKMEEHTQKKQLLDRDFVAKQKEDLELAMKRLTTDNRREICDKERECLMKKQELLRDREAALWEMEEHQLQERHQLVKQQLKDQYFLQRHELLRKHEKEREQMQRYNQRMIEQLKVRQQQEKARLPKIQRSEGKTRMAMYKKSLHINGGGSAAEQREKIKQFSQQEEKRQKSERLQQQQKHENQMRDMLAQCESNMSELQQLQNEKCHLLVEHETQKLKALDESHNQNLKEWRDKLRPRKKALEEDLNQKKREQEMFFKLSEEAECPNPSTPSKAAKFFPYSSADAS_,Homo sapiens, 6 | 4,MTEKDAGGFNMSTFMNRKFQEPIQQIKTFSWMGFSWTCRKRRKHYQSYLRNGVRISVNDFVYVLAEQHKRLVAYIEDLYEDSKGKKMVVVRWFHKTEEVGSVLSDDDNDREIFFSLNRQDISIECIDYLATVLSPQHYEKFLKVPMHVQTVAFFCQKLYGDDGLKPYDITQLEGYWRQEMLRYLNVSILKSFEGAQAPGTDPGLKAPLVGCVGIRSRKRRRPSPVGTLNVSYAGDMKGDCKSSPDSVLAVTDASIFKGDEDGSSHHIKKGSLIEVLSEDSGIRGCWFKALVLKKHKDKVKVQYQDIQDADDESKKLEEWILTSRVAAGDHLGDLRIKGRKVVRPMLKPSKENDVCVIGVGMPVDVWWCDGWWEGIVVQEVSEEKFEVYLPGEKKMSAFHRNDLRQSREWLDDEWLNIRSRSDIVSSVLSLTKKKEMEVKHDEKSSDVGVCNGRMSPKTEAKRTISLPVATTKKSLPKRPIPDLLKDVLVTSDLKWKKSSRKRNRVVSCCPHDPSLNDGFSSERSLDCENCKFMEDTFGSSDGQHLTGLLMSR_,Arabidopsis thaliana, 7 | -------------------------------------------------------------------------------- /demo/sample_predictions.csv: -------------------------------------------------------------------------------- 1 | ,protein_sequence,organism,predicted_dna 2 | 0,MSEKYIVTWDMLQIHARKLASRLMPSEQWKGIIAVSRGGLVPGALLARELGIRHVDTVCISSYDHDNQRELKVLKRAEGDGEGFIVIDDLVDTGGTAVAIREMYPKAHFVTIFAKPAGRPLVDDYVVDIPQDTWIEQPWDMGVVFVPPISGR_,Escherichia coli 
general,ATGAGCGAAAAATATATTGTCACCTGGGACATGCTGCAGATCCATGCCCGCAAACTGGCCAGCCGCCTGATGCCGTCAGAACAGTGGAAAGGCATTATTGCCGTCAGCCGCGGCGGCCTGGTGCCGGGTGCGCTGCTGGCGCGTGAGCTGGGTATTCGCCACGTCGACACCGTGTGCATCAGCAGCTATGACCACGACAACCAGCGCGAGCTGAAAGTGCTGAAACGTGCGGAAGGCGATGGCGAAGGCTTTATCGTCATTGATGATCTGGTTGATACCGGCGGCACCGCGGTGGCGATCCGTGAAATGTACCCGAAAGCGCACTTTGTCACCATCTTTGCGAAACCGGCAGGCCGTCCGCTGGTTGATGATTATGTGGTTGATATTCCGCAGGACACCTGGATCGAACAGCCGTGGGACATGGGCGTGGTGTTTGTTCCGCCGATCAGCGGCCGCTAA 3 | 1,MKNIIRTPETHPLTWRLRDDKQPVWLDEYRSKNGYEGARKALTGLSPDEIVNQVKDAGLKGRGGAGFSTGLKWSLMPKDESMNIRYLLCNADEMEPGTYKDRLLMEQLPHLLVEGMLISAFALKAYRGYIFLRGEYIEAAVNLRRAIAEATEAGLLGKNIMGTGFDFELFVHTGAGRYICGEETALINSLEGRRANPRSKPPFPATSGAWGKPTCVNNVETLCNVPAILANGVEWYQNISKSKDAGTKLMGFSGRVKNPGLWELPFGTTAREILEDYAGGMRDGLKFKAWQPGGAGTDFLTEAHLDLPMEFESIGKAGSRLGTALAMAVDHEINMVSLVRNLEEFFARESCGWCTPCRDGLPWSVKILRALERGEGQPGDIETLEQLCRFLGPGKTFCAHAPGAVEPLQSAIKYFREEFEAGIKQPFSNTHLINGIQPNLLKERW_,Escherichia coli general,ATGAAAAATATTATTAGAACACCTGAAACCCATCCGCTGACCTGGCGTCTGCGCGATGACAAACAGCCGGTGTGGCTGGATGAGTACCGCAGCAAAAACGGCTATGAAGGTGCGCGTAAAGCGCTGACCGGTCTGTCTCCGGATGAGATTGTCAATCAGGTCAAAGATGCCGGCCTGAAAGGCCGTGGCGGTGCGGGTTTCTCCACCGGCCTGAAGTGGTCTCTGATGCCGAAAGATGAGAGCATGAACATCCGCTATCTGCTGTGCAATGCCGATGAGATGGAGCCGGGCACCTATAAAGACCGCCTGCTGATGGAGCAGCTGCCGCACCTGCTGGTTGAAGGTATGCTGATCTCTGCATTTGCGCTGAAAGCCTACCGTGGCTACATCTTCCTGCGTGGCGAGTACATCGAAGCGGCGGTGAACCTGCGCCGTGCGATTGCTGAAGCCACTGAAGCAGGTCTGCTGGGTAAAAACATCATGGGTACCGGTTTTGACTTTGAACTGTTCGTCCACACCGGTGCAGGGCGCTACATCTGCGGTGAAGAAACCGCGCTGATCAACAGCCTGGAAGGCCGTCGTGCGAACCCGCGCAGCAAACCGCCGTTCCCGGCAACCTCTGGTGCGTGGGGTAAACCGACCTGTGTTAATAACGTTGAAACCCTGTGCAACGTTCCGGCGATTCTGGCGAACGGTGTGGAATGGTATCAGAACATCTCCAAAAGCAAAGATGCTGGTACCAAGCTGATGGGTTTCTCCGGCCGTGTGAAAAACCCGGGCCTGTGGGAACTGCCGTTTGGTACCACCGCGCGTGAAATCCTGGAAGATTATGCCGGTGGCATGCGTGACGGCCTGAAGTTCAAAGCGTGGCAGCCGGGTGGTGCAGGTACCGATTTCCTGACTGAAGCGCACCTGGATCTGCCGATGGAGTTTGAGTCCATTGGTAAAGCGGGCAGCCGTCTGGGTACAGCGCTGGCGATGGCGGTTGACCATGAGATCAACATGGTGTCGCTGGTGCGTAACCTGGAAGAGTTCTTTGCCCGTGAAAGCTGCGGCTGGTGCACGCCGTGCCGCGACGGGCTGCCGTGGTCAGTGAAAATCCTGCGTGCGCTGGAGCGCGGTGAAGGCCAGCCGGGTGACATCGAAACGCTGGAACAGCTGTGCCGCTTCCTGGGTCCGGGTAAAACCTTCTGTGCACATGCACCCGGTGCGGTGGAACCGCTGCAGTCTGCGATCAAATATTTCCGTGAAGAGTTTGAAGCGGGGATCAAACAGCCGTTCTCCAACACCCATCTGATTAACGGGATTCAGCCGAACCTGCTGAAAGAGCGCTGGTAA 4 | 2,MDALQIAEDTLQTLVPHCPVPSGPRRIFLDANVKESYCPLVPHTMYCLPLWQGINLVLLTRSPSAPLALVLSQLMDGFSMLEKKLKEGPEPGASLRSQPLVGDLRQRMDKFVKNRGAQEIQSTWLEFKAKAFSKSEPGSSWELLQACGKLKRQLCAIYRLNFLTTAPSRGGPHLPQHLQDQVQRLMREKLTDWKDFLLVKSRRNITMVSYLEDFPGLVHFIYVDRTTGQMVAPSLNCSQKTSSELGKGPLAAFVKTKVWSLIQLARRYLQKGYTTLLFQEGDFYCSYFLWFENDMGYKLQMIEVPVLSDDSVPIGMLGGDYYRKLLRYYSKNRPTEAVRCYELLALHLSVIPTDLLVQQAGQLARRLWEASRIPLL_,Homo 
sapiens,ATGGATGCCCTGCAGATTGCTGAGGACACCCTGCAGACCCTGGTGCCCCACTGCCCTGTGCCCTCTGGGCCCAGGAGGATCTTCCTGGATGCCAATGTGAAGGAGAGCTACTGCCCCCTGGTGCCCCACACCATGTACTGCCTGCCCCTGTGGCAGGGCATCAACCTGGTCCTGCTGACCAGGTCTCCCTCTGCCCCCCTGGCCCTGGTGCTGTCCCAGCTGATGGATGGCTTTTCCATGCTGGAGAAGAAGCTGAAGGAGGGGCCCGAGCCTGGAGCCTCTCTGAGGAGCCAGCCCCTGGTGGGGGACCTGCGGCAGAGAATGGACAAATTTGTGAAGAACCGAGGGGCCCAGGAGATCCAGAGCACCTGGCTGGAATTCAAGGCCAAGGCCTTCTCCAAATCTGAGCCTGGCAGCAGCTGGGAGCTGCTGCAGGCCTGTGGGAAGCTGAAGAGACAGCTGTGTGCCATCTACAGGCTGAACTTCCTGACCACAGCCCCCTCCAGAGGAGGGCCCCACCTGCCCCAGCACCTGCAGGACCAGGTGCAGCGGCTGATGCGGGAGAAGCTGACTGACTGGAAGGACTTCCTGCTGGTGAAGAGCAGGAGGAACATCACCATGGTGTCCTACCTGGAGGACTTCCCTGGCCTGGTGCACTTCATCTATGTGGACAGGACCACTGGGCAGATGGTGGCCCCCAGCCTGAACTGCAGCCAGAAGACCAGCTCTGAGCTGGGCAAGGGGCCCCTGGCTGCCTTTGTGAAGACCAAAGTGTGGAGCCTGATCCAGCTGGCCCGGAGATACCTGCAGAAGGGCTATACCACCCTGCTGTTCCAGGAAGGAGACTTCTACTGCTCCTACTTCCTGTGGTTTGAGAATGATATGGGCTACAAGCTGCAGATGATTGAGGTGCCTGTGCTGTCTGATGACTCTGTCCCCATTGGCATGCTGGGAGGAGACTACTACCGGAAGCTGCTGCGCTATTACAGCAAGAACCGGCCCACTGAGGCTGTGCGCTGCTATGAGCTGCTGGCCCTGCACCTGTCTGTGATCCCCACTGACCTGCTGGTGCAGCAGGCTGGGCAGCTGGCCAGGAGGCTGTGGGAGGCCTCCAGGATCCCCCTGCTGTGA 5 | 3,MAFANFRRILRLSTFEKRKSREYEHVRRDLDPNEVWEIVGELGDGAFGKVYKAKNKETGALAAAKVIETKSEEELEDYIVEIEILATCDHPYIVKLLGAYYHDGKLWIMIEFCPGGAVDAIMLELDRGLTEPQIQVVCRQMLEALNFLHSKRIIHRDLKAGNVLMTLEGDIRLADFGVSAKNLKTLQKRDSFIGTPYWMAPEVVMCETMKDTPYDYKADIWSLGITLIEMAQIEPPHHELNPMRVLLKIAKSDPPTLLTPSKWSVEFRDFLKIALDKNPETRPSAAQLLEHPFVSSITSNKALRELVAEAKAEVMEEIEDGRDEGEEEDAVDAASDPKLYKKTLKRTRKFVVDGVEVSITTSKIISEDEKKDEEMRFLRRQELRELRLLQKEEHRNQTQLSNKHELQLEQMHKRFEQEINAKKKFFDTELENLERQQKQQVEKMEQDHAVRRREEARRIRLEQDRDYTRFQEQLKLMKKEVKNEVEKLPRQQRKESMKQKMEEHTQKKQLLDRDFVAKQKEDLELAMKRLTTDNRREICDKERECLMKKQELLRDREAALWEMEEHQLQERHQLVKQQLKDQYFLQRHELLRKHEKEREQMQRYNQRMIEQLKVRQQQEKARLPKIQRSEGKTRMAMYKKSLHINGGGSAAEQREKIKQFSQQEEKRQKSERLQQQQKHENQMRDMLAQCESNMSELQQLQNEKCHLLVEHETQKLKALDESHNQNLKEWRDKLRPRKKALEEDLNQKKREQEMFFKLSEEAECPNPSTPSKAAKFFPYSSADAS_,Homo 
sapiens,ATGGCCTTTGCCAACTTCCGGAGAATCCTGCGGCTGTCCACCTTTGAGAAGAGGAAGAGCCGGGAATATGAGCACGTGCGCAGGGACCTGGACCCCAACGAGGTGTGGGAGATTGTGGGGGAGCTGGGGGATGGTGCCTTTGGGAAGGTCTACAAGGCCAAGAACAAGGAGACTGGGGCCTTGGCTGCGGCCAAGGTAATTGAGACCAAATCTGAGGAGGAGCTGGAAGACTACATTGTGGAGATTGAGATTCTGGCCACCTGTGACCACCCCTACATTGTGAAGCTGCTGGGGGCCTACTACCATGATGGCAAGCTGTGGATCATGATCGAGTTCTGCCCTGGAGGGGCTGTGGATGCCATTATGCTGGAGCTGGACCGAGGCCTGACTGAGCCACAGATCCAGGTGGTGTGCAGGCAGATGCTGGAGGCCCTGAACTTCCTGCACAGCAAGAGAATCATTCACAGAGACCTGAAGGCTGGGAACGTGTTGATGACCCTGGAAGGAGACATCCGCTTGGCTGACTTTGGTGTCTCTGCCAAGAACCTGAAGACCTTACAGAAGAGGGACAGCTTCATTGGCACACCCTACTGGATGGCCCCTGAGGTGGTCATGTGTGAGACCATGAAGGACACACCCTATGACTACAAGGCTGACATTTGGAGCCTGGGCATCACCTTGATTGAGATGGCCCAGATTGAGCCACCACACCACGAGCTGAACCCAATGCGCGTGCTGCTGAAGATTGCCAAGTCAGACCCACCCACCTTGTTGACACCAAGCAAGTGGTCTGTGGAGTTCCGGGACTTCCTGAAGATTGCCCTGGACAAGAACCCTGAAACCAGGCCCTCTGCTGCCCAGCTGCTGGAGCACCCCTTTGTCTCCTCCATCACCTCCAACAAGGCCCTGCGGGAGCTGGTGGCTGAGGCCAAGGCTGAGGTGATGGAGGAGATTGAGGATGGCAGAGATGAAGGAGAGGAGGAGGATGCTGTGGATGCCGCCTCTGACCCCAAGCTGTACAAGAAGACCCTGAAGAGGACCCGCAAGTTTGTGGTGGACGGGGTGGAGGTGTCTATCACCACCTCCAAGATCATCTCTGAGGATGAGAAGAAGGATGAAGAGATGAGGTTCCTGCGGCGCCAGGAGCTGCGGGAGCTGCGGCTGCTGCAGAAGGAGGAGCACAGGAACCAGACCCAGCTGTCCAACAAGCATGAGCTGCAGCTGGAGCAGATGCACAAGAGGTTTGAGCAGGAGATCAATGCCAAGAAGAAGTTCTTTGACACTGAGCTGGAGAACCTGGAGAGGCAGCAGAAGCAGCAGGTGGAGAAGATGGAGCAGGACCATGCTGTGCGGAGGCGGGAGGAGGCCCGGAGGATCCGCCTGGAGCAGGACCGGGACTACACCCGCTTCCAGGAGCAGCTGAAGCTGATGAAGAAGGAGGTCAAGAACGAGGTGGAGAAGCTGCCCCGGCAGCAGAGGAAGGAAAGCATGAAGCAGAAGATGGAGGAGCACACCCAGAAGAAGCAGCTGCTGGACAGAGACTTTGTGGCCAAGCAGAAGGAGGACCTGGAGCTGGCTATGAAGAGACTGACCACTGATAACAGGAGGGAGATCTGTGACAAGGAGCGGGAGTGCCTGATGAAGAAGCAGGAGCTGCTGCGGGACCGGGAGGCTGCCCTGTGGGAGATGGAGGAGCACCAGCTGCAGGAGAGGCACCAGCTGGTCAAGCAGCAGCTGAAGGACCAGTACTTCCTGCAGAGGCACGAGCTGCTGAGGAAGCATGAGAAGGAGCGGGAGCAGATGCAGAGGTACAACCAGAGGATGATTGAGCAGCTGAAGGTGCGGCAGCAGCAGGAGAAGGCCAGACTGCCCAAGATCCAGAGATCTGAGGGGAAGACAAGGATGGCCATGTACAAGAAGTCTCTGCACATCAATGGGGGGGGCTCTGCCGCTGAGCAGCGGGAGAAGATCAAGCAGTTCTCCCAGCAGGAGGAGAAGAGGCAGAAGAGTGAGAGGCTGCAGCAGCAGCAGAAGCACGAGAACCAGATGAGGGACATGCTGGCCCAGTGTGAGTCCAACATGTCTGAGCTGCAGCAGCTGCAGAACGAGAAGTGTCACCTGCTGGTGGAGCATGAGACCCAGAAGCTGAAGGCCCTGGATGAGAGCCACAACCAGAATCTGAAGGAGTGGAGGGACAAGCTGAGGCCCAGGAAGAAGGCCCTGGAGGAGGACCTGAACCAGAAGAAGCGGGAGCAGGAGATGTTCTTCAAGCTGTCTGAGGAGGCCGAGTGCCCAAACCCAAGCACTCCATCCAAGGCTGCCAAGTTCTTCCCCTACAGCTCTGCCGACGCCAGCTAA 6 | 4,MTEKDAGGFNMSTFMNRKFQEPIQQIKTFSWMGFSWTCRKRRKHYQSYLRNGVRISVNDFVYVLAEQHKRLVAYIEDLYEDSKGKKMVVVRWFHKTEEVGSVLSDDDNDREIFFSLNRQDISIECIDYLATVLSPQHYEKFLKVPMHVQTVAFFCQKLYGDDGLKPYDITQLEGYWRQEMLRYLNVSILKSFEGAQAPGTDPGLKAPLVGCVGIRSRKRRRPSPVGTLNVSYAGDMKGDCKSSPDSVLAVTDASIFKGDEDGSSHHIKKGSLIEVLSEDSGIRGCWFKALVLKKHKDKVKVQYQDIQDADDESKKLEEWILTSRVAAGDHLGDLRIKGRKVVRPMLKPSKENDVCVIGVGMPVDVWWCDGWWEGIVVQEVSEEKFEVYLPGEKKMSAFHRNDLRQSREWLDDEWLNIRSRSDIVSSVLSLTKKKEMEVKHDEKSSDVGVCNGRMSPKTEAKRTISLPVATTKKSLPKRPIPDLLKDVLVTSDLKWKKSSRKRNRVVSCCPHDPSLNDGFSSERSLDCENCKFMEDTFGSSDGQHLTGLLMSR_,Arabidopsis 
thaliana,ATGACGGAGAAAGATGCTGGAGGTTTTAATATGTCAACTTTCATGAACAGGAAGTTTCAAGAACCAATTCAACAGATCAAAACTTTCTCCTGGATGGGTTTCTCATGGACTTGTAGGAAGAGGAGGAAACATTATCAATCTTACCTTAGGAATGGAGTGAGGATCTCTGTCAATGATTTTGTTTATGTTCTTGCTGAGCAACACAAGAGGCTTGTTGCTTACATTGAAGATCTTTATGAGGATAGCAAAGGGAAGAAGATGGTTGTTGTTAGGTGGTTCCACAAGACTGAAGAGGTTGGATCTGTTCTTAGCGATGATGACAACGACAGGGAGATCTTCTTCTCTCTCAACAGACAAGACATCAGCATTGAGTGCATTGATTACCTTGCCACTGTTCTCTCTCCTCAACATTACGAGAAGTTTCTCAAGGTTCCTATGCATGTTCAAACTGTTGCTTTCTTCTGCCAGAAGCTCTATGGAGATGATGGTTTGAAACCTTATGACATCACTCAGCTTGAAGGTTACTGGAGACAAGAAATGCTCAGATACCTCAATGTCTCCATTCTCAAGAGCTTTGAAGGAGCTCAAGCTCCTGGAACTGATCCTGGTTTGAAGGCTCCTTTGGTTGGTTGTGTTGGTATCAGAAGCAGGAAGAGGAGGAGACCATCACCGGTTGGAACTCTCAACGTCAGCTACGCTGGAGACATGAAAGGAGACTGCAAAAGCTCTCCTGATTCTGTTTTGGCTGTCACTGATGCTTCGATCTTCAAAGGAGATGAAGATGGATCTTCTCACCACATCAAGAAAGGAAGCTTGATTGAGGTTCTCAGCGAGGACTCTGGGATCCGTGGTTGCTGGTTCAAAGCTTTGGTGTTGAAGAAACACAAGGACAAGGTGAAGGTGCAGTACCAAGACATTCAAGATGCTGATGATGAGAGCAAGAAGCTTGAGGAGTGGATTCTCACTAGCCGTGTTGCTGCTGGAGATCATCTTGGTGATTTGAGGATCAAAGGAAGGAAAGTTGTGAGACCAATGCTCAAACCTTCCAAGGAGAACGATGTTTGTGTGATTGGTGTTGGAATGCCGGTTGATGTTTGGTGGTGTGATGGATGGTGGGAAGGGATTGTGGTTCAAGAGGTTTCTGAGGAGAAGTTTGAGGTTTATCTTCCTGGAGAGAAGAAAATGTCAGCTTTCCACAGAAATGATTTGAGACAAAGCAGAGAGTGGCTTGATGATGAGTGGCTCAACATTAGAAGCAGAAGTGACATTGTTTCTTCTGTTCTTTCTTTGACCAAGAAGAAAGAGATGGAGGTGAAGCATGATGAGAAAAGCAGCGATGTTGGTGTCTGCAATGGAAGAATGTCTCCAAAAACAGAAGCTAAGAGAACAATCTCTCTTCCTGTTGCTACAACCAAGAAATCTCTTCCTAAGAGACCAATTCCTGATCTTCTCAAGGATGTGTTGGTCACTTCTGATTTGAAGTGGAAGAAAAGCTCAAGGAAAAGAAACAGAGTTGTTTCCTGCTGTCCTCATGATCCATCTCTCAATGATGGTTTCTCCTCTGAGAGATCTCTTGATTGTGAGAACTGCAAGTTCATGGAAGACACTTTTGGTTCTTCTGATGGACAACATCTCACTGGTCTTCTTATGTCCAGATAA 7 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | """ 2 | File: finetune.py 3 | ------------------- 4 | Finetune the CodonTransformer model. 5 | 6 | The pretrained model is loaded directly from Hugging Face. 7 | The dataset is a JSON file. You can use prepare_training_data from CodonData to 8 | prepare the dataset. The repository README has a guide on how to prepare the 9 | dataset and use this script. 
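Example invocation (a sketch; every path below is a placeholder, and the flag
values simply mirror this script's defaults and slurm/finetune.sh):

    python finetune.py \
        --dataset_dir /path/to/dataset_dir \
        --checkpoint_dir /path/to/checkpoint_dir \
        --checkpoint_filename finetune.ckpt \
        --batch_size 6 \
        --max_epochs 15 \
        --num_gpus 4 \
        --learning_rate 5e-5 \
        --warmup_fraction 0.1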
10 | """ 11 | 12 | import argparse 13 | import os 14 | 15 | import pytorch_lightning as pl 16 | import torch 17 | from torch.utils.data import DataLoader 18 | from transformers import AutoTokenizer, BigBirdForMaskedLM 19 | 20 | from CodonTransformer.CodonUtils import ( 21 | MAX_LEN, 22 | TOKEN2MASK, 23 | IterableJSONData, 24 | ) 25 | 26 | 27 | class MaskedTokenizerCollator: 28 | def __init__(self, tokenizer): 29 | self.tokenizer = tokenizer 30 | 31 | def __call__(self, examples): 32 | tokenized = self.tokenizer( 33 | [ex["codons"] for ex in examples], 34 | return_attention_mask=True, 35 | return_token_type_ids=True, 36 | truncation=True, 37 | padding=True, 38 | max_length=MAX_LEN, 39 | return_tensors="pt", 40 | ) 41 | 42 | seq_len = tokenized["input_ids"].shape[-1] 43 | species_index = torch.tensor([[ex["organism"]] for ex in examples]) 44 | tokenized["token_type_ids"] = species_index.repeat(1, seq_len) 45 | 46 | inputs = tokenized["input_ids"] 47 | targets = tokenized["input_ids"].clone() 48 | 49 | prob_matrix = torch.full(inputs.shape, 0.15) 50 | prob_matrix[torch.where(inputs < 5)] = 0.0 51 | selected = torch.bernoulli(prob_matrix).bool() 52 | 53 | # 80% of the time, replace masked input tokens with respective mask tokens 54 | replaced = torch.bernoulli(torch.full(selected.shape, 0.8)).bool() & selected 55 | inputs[replaced] = torch.tensor( 56 | list((map(TOKEN2MASK.__getitem__, inputs[replaced].numpy()))) 57 | ) 58 | 59 | # 10% of the time, we replace masked input tokens with random vector. 60 | randomized = ( 61 | torch.bernoulli(torch.full(selected.shape, 0.1)).bool() 62 | & selected 63 | & ~replaced 64 | ) 65 | random_idx = torch.randint(26, 90, prob_matrix.shape, dtype=torch.long) 66 | inputs[randomized] = random_idx[randomized] 67 | 68 | tokenized["input_ids"] = inputs 69 | tokenized["labels"] = torch.where(selected, targets, -100) 70 | 71 | return tokenized 72 | 73 | 74 | class plTrainHarness(pl.LightningModule): 75 | def __init__(self, model, learning_rate, warmup_fraction): 76 | super().__init__() 77 | self.model = model 78 | self.learning_rate = learning_rate 79 | self.warmup_fraction = warmup_fraction 80 | 81 | def configure_optimizers(self): 82 | optimizer = torch.optim.AdamW( 83 | self.model.parameters(), 84 | lr=self.learning_rate, 85 | ) 86 | lr_scheduler = { 87 | "scheduler": torch.optim.lr_scheduler.OneCycleLR( 88 | optimizer, 89 | max_lr=self.learning_rate, 90 | total_steps=self.trainer.estimated_stepping_batches, 91 | pct_start=self.warmup_fraction, 92 | ), 93 | "interval": "step", 94 | "frequency": 1, 95 | } 96 | return [optimizer], [lr_scheduler] 97 | 98 | def training_step(self, batch, batch_idx): 99 | self.model.bert.set_attention_type("block_sparse") 100 | outputs = self.model(**batch) 101 | self.log_dict( 102 | dictionary={ 103 | "loss": outputs.loss, 104 | "lr": self.trainer.optimizers[0].param_groups[0]["lr"], 105 | }, 106 | on_step=True, 107 | prog_bar=True, 108 | ) 109 | return outputs.loss 110 | 111 | 112 | class DumpStateDict(pl.callbacks.ModelCheckpoint): 113 | def __init__(self, checkpoint_dir, checkpoint_filename, every_n_train_steps): 114 | super().__init__( 115 | dirpath=checkpoint_dir, every_n_train_steps=every_n_train_steps 116 | ) 117 | self.checkpoint_filename = checkpoint_filename 118 | 119 | def on_save_checkpoint(self, trainer, pl_module, checkpoint): 120 | model = trainer.model.model 121 | torch.save( 122 | model.state_dict(), os.path.join(self.dirpath, self.checkpoint_filename) 123 | ) 124 | 125 | 126 | def main(args): 127 | """Finetune the 
CodonTransformer model.""" 128 | pl.seed_everything(args.seed) 129 | torch.set_float32_matmul_precision("medium") 130 | 131 | # Load the tokenizer and model 132 | tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") 133 | model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer-base") 134 | harnessed_model = plTrainHarness(model, args.learning_rate, args.warmup_fraction) 135 | 136 | # Load the training data 137 | train_data = IterableJSONData(args.dataset_dir, dist_env="slurm") 138 | data_loader = DataLoader( 139 | dataset=train_data, 140 | collate_fn=MaskedTokenizerCollator(tokenizer), 141 | batch_size=args.batch_size, 142 | num_workers=0 if args.debug else args.num_workers, 143 | persistent_workers=False if args.debug else True, 144 | ) 145 | 146 | # Setup trainer and callbacks 147 | save_checkpoint = DumpStateDict( 148 | checkpoint_dir=args.checkpoint_dir, 149 | checkpoint_filename=args.checkpoint_filename, 150 | every_n_train_steps=args.save_every_n_steps, 151 | ) 152 | trainer = pl.Trainer( 153 | default_root_dir=args.checkpoint_dir, 154 | strategy="ddp_find_unused_parameters_true", 155 | accelerator="gpu", 156 | devices=1 if args.debug else args.num_gpus, 157 | precision="16-mixed", 158 | max_epochs=args.max_epochs, 159 | deterministic=False, 160 | enable_checkpointing=True, 161 | callbacks=[save_checkpoint], 162 | accumulate_grad_batches=args.accumulate_grad_batches, 163 | ) 164 | 165 | # Finetune the model 166 | trainer.fit(harnessed_model, data_loader) 167 | 168 | 169 | if __name__ == "__main__": 170 | parser = argparse.ArgumentParser(description="Finetune the CodonTransformer model.") 171 | parser.add_argument( 172 | "--dataset_dir", 173 | type=str, 174 | required=True, 175 | help="Directory containing the dataset", 176 | ) 177 | parser.add_argument( 178 | "--checkpoint_dir", 179 | type=str, 180 | required=True, 181 | help="Directory where checkpoints will be saved", 182 | ) 183 | parser.add_argument( 184 | "--checkpoint_filename", 185 | type=str, 186 | default="finetune.ckpt", 187 | help="Filename for the saved checkpoint", 188 | ) 189 | parser.add_argument( 190 | "--batch_size", type=int, default=6, help="Batch size for training" 191 | ) 192 | parser.add_argument( 193 | "--max_epochs", type=int, default=15, help="Maximum number of epochs to train" 194 | ) 195 | parser.add_argument( 196 | "--num_workers", type=int, default=5, help="Number of workers for data loading" 197 | ) 198 | parser.add_argument( 199 | "--accumulate_grad_batches", 200 | type=int, 201 | default=1, 202 | help="Number of batches to accumulate gradients", 203 | ) 204 | parser.add_argument( 205 | "--num_gpus", type=int, default=4, help="Number of GPUs to use for training" 206 | ) 207 | parser.add_argument( 208 | "--learning_rate", 209 | type=float, 210 | default=5e-5, 211 | help="Learning rate for the optimizer", 212 | ) 213 | parser.add_argument( 214 | "--warmup_fraction", 215 | type=float, 216 | default=0.1, 217 | help="Fraction of total steps to use for warmup", 218 | ) 219 | parser.add_argument( 220 | "--save_every_n_steps", 221 | type=int, 222 | default=512, 223 | help="Save checkpoint every N steps", 224 | ) 225 | parser.add_argument( 226 | "--seed", type=int, default=123, help="Random seed for reproducibility" 227 | ) 228 | parser.add_argument("--debug", action="store_true", help="Enable debug mode") 229 | args = parser.parse_args() 230 | main(args) 231 | -------------------------------------------------------------------------------- /pretrain.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | File: pretrain.py 3 | ------------------- 4 | Pretrain the CodonTransformer model. 5 | 6 | The dataset is a JSON file. You can use prepare_training_data from CodonData to 7 | prepare the dataset. The repository README has a guide on how to prepare the 8 | dataset and use this script. 9 | """ 10 | 11 | import argparse 12 | import os 13 | 14 | import pytorch_lightning as pl 15 | import torch 16 | from torch.utils.data import DataLoader 17 | from transformers import BigBirdConfig, BigBirdForMaskedLM, PreTrainedTokenizerFast 18 | 19 | from CodonTransformer.CodonUtils import ( 20 | MAX_LEN, 21 | NUM_ORGANISMS, 22 | TOKEN2MASK, 23 | IterableJSONData, 24 | ) 25 | 26 | 27 | class MaskedTokenizerCollator: 28 | def __init__(self, tokenizer): 29 | self.tokenizer = tokenizer 30 | 31 | def __call__(self, examples): 32 | tokenized = self.tokenizer( 33 | [ex["codons"] for ex in examples], 34 | return_attention_mask=True, 35 | return_token_type_ids=True, 36 | truncation=True, 37 | padding=True, 38 | max_length=MAX_LEN, 39 | return_tensors="pt", 40 | ) 41 | 42 | seq_len = tokenized["input_ids"].shape[-1] 43 | species_index = torch.tensor([[ex["organism"]] for ex in examples]) 44 | tokenized["token_type_ids"] = species_index.repeat(1, seq_len) 45 | 46 | inputs = tokenized["input_ids"] 47 | targets = inputs.clone() 48 | 49 | prob_matrix = torch.full(inputs.shape, 0.15) 50 | prob_matrix[inputs < 5] = 0.0 51 | selected = torch.bernoulli(prob_matrix).bool() 52 | 53 | # 80% of the time, replace masked input tokens with respective mask tokens 54 | replaced = torch.bernoulli(torch.full(selected.shape, 0.8)).bool() & selected 55 | inputs[replaced] = torch.tensor( 56 | list((map(TOKEN2MASK.__getitem__, inputs[replaced].numpy()))) 57 | ) 58 | 59 | # 10% of the time, we replace masked input tokens with random vector. 
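        # (Token-id layout, as listed in src/CodonTransformerTokenizer.json:
        # ids 0-4 are special tokens, ids 5-25 are the per-amino-acid "_unk"
        # tokens, and ids 26-89 are the codon tokens "k_aaa" through "f_ttt",
        # so the torch.randint(26, 90, ...) call below samples only valid
        # codon tokens as replacements.)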
60 | randomized = ( 61 | torch.bernoulli(torch.full(selected.shape, 0.1)).bool() 62 | & selected 63 | & ~replaced 64 | ) 65 | random_idx = torch.randint(26, 90, inputs.shape, dtype=torch.long) 66 | inputs[randomized] = random_idx[randomized] 67 | 68 | tokenized["input_ids"] = inputs 69 | tokenized["labels"] = torch.where(selected, targets, -100) 70 | 71 | return tokenized 72 | 73 | 74 | class plTrainHarness(pl.LightningModule): 75 | def __init__(self, model, learning_rate, warmup_fraction): 76 | super().__init__() 77 | self.model = model 78 | self.learning_rate = learning_rate 79 | self.warmup_fraction = warmup_fraction 80 | 81 | def configure_optimizers(self): 82 | optimizer = torch.optim.AdamW( 83 | self.model.parameters(), 84 | lr=self.learning_rate, 85 | ) 86 | lr_scheduler = { 87 | "scheduler": torch.optim.lr_scheduler.OneCycleLR( 88 | optimizer, 89 | max_lr=self.learning_rate, 90 | total_steps=self.trainer.estimated_stepping_batches, 91 | pct_start=self.warmup_fraction, 92 | ), 93 | "interval": "step", 94 | "frequency": 1, 95 | } 96 | return [optimizer], [lr_scheduler] 97 | 98 | def training_step(self, batch, batch_idx): 99 | self.model.bert.set_attention_type("block_sparse") 100 | outputs = self.model(**batch) 101 | self.log_dict( 102 | dictionary={ 103 | "loss": outputs.loss, 104 | "lr": self.trainer.optimizers[0].param_groups[0]["lr"], 105 | }, 106 | on_step=True, 107 | prog_bar=True, 108 | ) 109 | return outputs.loss 110 | 111 | 112 | class EpochCheckpoint(pl.Callback): 113 | def __init__(self, checkpoint_dir, save_interval): 114 | super().__init__() 115 | self.checkpoint_dir = checkpoint_dir 116 | self.save_interval = save_interval 117 | 118 | def on_train_epoch_end(self, trainer, pl_module): 119 | current_epoch = trainer.current_epoch 120 | if current_epoch % self.save_interval == 0 or current_epoch == 0: 121 | checkpoint_path = os.path.join( 122 | self.checkpoint_dir, f"epoch_{current_epoch}.ckpt" 123 | ) 124 | trainer.save_checkpoint(checkpoint_path) 125 | print(f"\nCheckpoint saved at {checkpoint_path}\n") 126 | 127 | 128 | def main(args): 129 | """Pretrain the CodonTransformer model.""" 130 | pl.seed_everything(args.seed) 131 | torch.set_float32_matmul_precision("medium") 132 | 133 | # Load the tokenizer and model 134 | tokenizer = PreTrainedTokenizerFast( 135 | tokenizer_file=args.tokenizer_path, 136 | bos_token="[CLS]", 137 | eos_token="[SEP]", 138 | unk_token="[UNK]", 139 | sep_token="[SEP]", 140 | pad_token="[PAD]", 141 | cls_token="[CLS]", 142 | mask_token="[MASK]", 143 | ) 144 | config = BigBirdConfig( 145 | vocab_size=len(tokenizer), 146 | type_vocab_size=NUM_ORGANISMS, 147 | sep_token_id=2, 148 | ) 149 | model = BigBirdForMaskedLM(config=config) 150 | harnessed_model = plTrainHarness(model, args.learning_rate, args.warmup_fraction) 151 | 152 | # Load the training data 153 | train_data = IterableJSONData(args.train_data_path, dist_env="slurm") 154 | data_loader = DataLoader( 155 | dataset=train_data, 156 | collate_fn=MaskedTokenizerCollator(tokenizer), 157 | batch_size=args.batch_size, 158 | num_workers=0 if args.debug else args.num_workers, 159 | persistent_workers=False if args.debug else True, 160 | ) 161 | 162 | # Setup trainer and callbacks 163 | save_checkpoint = EpochCheckpoint(args.checkpoint_dir, args.save_interval) 164 | trainer = pl.Trainer( 165 | default_root_dir=args.checkpoint_dir, 166 | strategy="ddp_find_unused_parameters_true", 167 | accelerator="gpu", 168 | devices=1 if args.debug else args.num_gpus, 169 | precision="16-mixed", 170 | 
max_epochs=args.max_epochs, 171 | deterministic=False, 172 | enable_checkpointing=True, 173 | callbacks=[save_checkpoint], 174 | accumulate_grad_batches=args.accumulate_grad_batches, 175 | ) 176 | 177 | # Pretrain the model 178 | trainer.fit(harnessed_model, data_loader) 179 | 180 | 181 | if __name__ == "__main__": 182 | parser = argparse.ArgumentParser(description="Pretrain the CodonTransformer model.") 183 | parser.add_argument( 184 | "--tokenizer_path", 185 | type=str, 186 | required=True, 187 | help="Path to the tokenizer model file", 188 | ) 189 | parser.add_argument( 190 | "--train_data_path", 191 | type=str, 192 | required=True, 193 | help="Path to the training data JSON file", 194 | ) 195 | parser.add_argument( 196 | "--checkpoint_dir", 197 | type=str, 198 | required=True, 199 | help="Directory where checkpoints will be saved", 200 | ) 201 | parser.add_argument( 202 | "--batch_size", type=int, default=6, help="Batch size for training" 203 | ) 204 | parser.add_argument( 205 | "--max_epochs", type=int, default=5, help="Maximum number of epochs to train" 206 | ) 207 | parser.add_argument( 208 | "--num_workers", type=int, default=5, help="Number of workers for data loading" 209 | ) 210 | parser.add_argument( 211 | "--accumulate_grad_batches", 212 | type=int, 213 | default=1, 214 | help="Number of batches to accumulate gradients", 215 | ) 216 | parser.add_argument( 217 | "--num_gpus", type=int, default=16, help="Number of GPUs to use for training" 218 | ) 219 | parser.add_argument( 220 | "--learning_rate", 221 | type=float, 222 | default=5e-5, 223 | help="Learning rate for the optimizer", 224 | ) 225 | parser.add_argument( 226 | "--warmup_fraction", 227 | type=float, 228 | default=0.1, 229 | help="Fraction of total steps to use for warmup", 230 | ) 231 | parser.add_argument( 232 | "--save_interval", type=int, default=5, help="Save checkpoint every N epochs" 233 | ) 234 | parser.add_argument( 235 | "--seed", type=int, default=123, help="Random seed for reproducibility" 236 | ) 237 | parser.add_argument("--debug", action="store_true", help="Enable debug mode") 238 | args = parser.parse_args() 239 | main(args) 240 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "CodonTransformer" 3 | version = "1.6.7" 4 | description = "The ultimate tool for codon optimization, transforming protein sequences into optimized DNA sequences specific for your target organisms." 
5 | authors = ["Adibvafa Fallahpour "] 6 | license = "Apache-2.0" 7 | readme = "README.md" 8 | homepage = "https://github.com/adibvafa/CodonTransformer" 9 | repository = "https://github.com/adibvafa/CodonTransformer" 10 | classifiers = [ 11 | "Programming Language :: Python :: 3", 12 | "License :: OSI Approved :: Apache Software License", 13 | "Operating System :: OS Independent", 14 | ] 15 | 16 | [tool.poetry.dependencies] 17 | python = "^3.9" 18 | biopython = "^1.83" 19 | ipywidgets = "^7.0.0" 20 | numpy = "<2.0.0" 21 | onnxruntime = "^1.16.3" 22 | pandas = "^2.0.0" 23 | python_codon_tables = "^0.1.12" 24 | pytorch_lightning = "^2.2.1" 25 | scikit-learn = "^1.2.2" 26 | scipy = "^1.13.1" 27 | setuptools = "^70.0.0" 28 | torch = "^2.0.0" 29 | tqdm = "^4.66.2" 30 | transformers = "^4.40.0" 31 | CAI-PyPI = "^2.0.1" 32 | 33 | [tool.poetry.dev-dependencies] 34 | coverage = {version = "^7.0", extras = ["toml"]} 35 | 36 | [build-system] 37 | requires = ["poetry-core>=1.0.0"] 38 | build-backend = "poetry.core.masonry.api" 39 | 40 | [tool.ruff] 41 | line-length = 88 42 | indent-width = 4 43 | target-version = "py310" 44 | 45 | [tool.ruff.lint] 46 | select = ["E", "F", "I"] 47 | ignore = [] 48 | 49 | [tool.ruff.format] 50 | quote-style = "double" 51 | indent-style = "space" 52 | skip-magic-trailing-comma = false 53 | line-ending = "auto" 54 | 55 | [tool.coverage.run] 56 | omit = [ 57 | # omit pytorch-generated files in /tmp 58 | "/tmp/*", 59 | ] 60 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | biopython>=1.83,<2.0 2 | CAI-PyPI>=2.0.1,<3.0 3 | ipywidgets>=7.0.0,<10.0 4 | numpy>=1.26.4,<2.0 5 | onnxruntime>=1.16.3,<3.0 6 | pandas>=2.0.0,<3.0 7 | python_codon_tables>=0.1.12,<1.0 8 | pytorch_lightning>=2.2.1,<3.0 9 | scikit-learn>=1.2.2,<2.0 10 | scipy>=1.13.1,<3.0 11 | setuptools>=70.0.0 12 | torch>=2.0.0,<3.0 13 | tqdm>=4.66.2,<5.0 14 | transformers>=4.40.0,<5.0 15 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # setup.py 2 | import os 3 | 4 | from setuptools import find_packages, setup 5 | 6 | 7 | def read_requirements(): 8 | with open("requirements.txt") as f: 9 | return [line.strip() for line in f if line.strip() and not line.startswith("#")] 10 | 11 | 12 | def read_readme(): 13 | here = os.path.abspath(os.path.dirname(__file__)) 14 | readme_path = os.path.join(here, "README.md") 15 | 16 | with open(readme_path, "r", encoding="utf-8") as f: 17 | return f.read() 18 | 19 | 20 | setup( 21 | name="CodonTransformer", 22 | version="1.6.7", 23 | packages=find_packages(), 24 | install_requires=read_requirements(), 25 | author="Adibvafa Fallahpour", 26 | author_email="Adibvafa.fallahpour@mail.utoronto.ca", 27 | description=( 28 | "The ultimate tool for codon optimization, " 29 | "transforming protein sequences into optimized DNA sequences " 30 | "specific for your target organisms." 
31 | ), 32 | long_description=read_readme(), 33 | long_description_content_type="text/markdown", 34 | url="https://github.com/adibvafa/CodonTransformer", 35 | classifiers=[ 36 | "Programming Language :: Python :: 3", 37 | "License :: OSI Approved :: Apache Software License", 38 | "Operating System :: OS Independent", 39 | ], 40 | python_requires=">=3.9", 41 | ) 42 | -------------------------------------------------------------------------------- /slurm/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=finetune 3 | #SBATCH --output=your_output_directory/output_%j.out 4 | #SBATCH --error=your_error_directory/error_%j.err 5 | #SBATCH --nodes=1 6 | #SBATCH --ntasks-per-node=4 7 | #SBATCH --gpus-per-node=4 8 | #SBATCH --cpus-per-task=6 9 | #SBATCH --time=15:00:00 10 | #SBATCH --partition=compute_full_node 11 | 12 | # Load required modules 13 | module --ignore_cache load cuda/11.4.4 14 | module --ignore_cache load anaconda3 15 | source activate your_environment 16 | 17 | # Change to the working directory 18 | cd your_working_directory 19 | 20 | # Set environment variables 21 | export CUBLAS_WORKSPACE_CONFIG=:4096:2 22 | export NCCL_DEBUG=INFO 23 | export PYTHONFAULTHANDLER=1 24 | 25 | # Run the Python script with arguments 26 | stdbuf -oL -eL srun python finetune.py \ 27 | --dataset_dir your_dataset_directory \ 28 | --checkpoint_dir your_checkpoint_directory \ 29 | --checkpoint_filename finetune.ckpt \ 30 | --batch_size 6 \ 31 | --max_epochs 15 \ 32 | --num_workers 5 \ 33 | --accumulate_grad_batches 1 \ 34 | --num_gpus 4 \ 35 | --learning_rate 0.00005 \ 36 | --warmup_fraction 0.1 \ 37 | --save_every_n_steps 512 \ 38 | --seed 123 39 | -------------------------------------------------------------------------------- /slurm/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=pretrain 3 | #SBATCH --output=your_output_directory/output_%j.out 4 | #SBATCH --error=your_error_directory/error_%j.err 5 | #SBATCH --nodes=4 6 | #SBATCH --gpus-per-node=4 7 | #SBATCH --ntasks-per-node=4 8 | #SBATCH --time=23:59:00 9 | #SBATCH -p compute_full_node 10 | 11 | # Load required modules 12 | module --ignore_cache load cuda/11.4.4 13 | module --ignore_cache load anaconda3 14 | source activate your_environment 15 | 16 | # Change to the working directory 17 | cd your_working_directory 18 | 19 | # Set environment variables 20 | export CUBLAS_WORKSPACE_CONFIG=:4096:2 21 | export NCCL_DEBUG=INFO 22 | export PYTHONFAULTHANDLER=1 23 | 24 | # Run the Python script with arguments 25 | stdbuf -oL -eL srun python pretrain.py \ 26 | --tokenizer_path your_tokenizer_path/CodonTransformerTokenizer.json \ 27 | --train_data_path your_data_directory/pretrain_dataset.json \ 28 | --checkpoint_dir your_checkpoint_directory \ 29 | --batch_size 6 \ 30 | --max_epochs 5 \ 31 | --num_workers 5 \ 32 | --accumulate_grad_batches 1 \ 33 | --num_gpus 16 \ 34 | --learning_rate 0.00005 \ 35 | --warmup_fraction 0.1 \ 36 | --save_interval 5 \ 37 | --seed 123 38 | -------------------------------------------------------------------------------- /src/CodonTransformerTokenizer.json: -------------------------------------------------------------------------------- 1 | {"version": "1.0", "truncation": null, "padding": null, "added_tokens": [{"id": 0, "special": true, "content": "[UNK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 1, "special": true, 
"content": "[CLS]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 2, "special": true, "content": "[SEP]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 3, "special": true, "content": "[PAD]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 4, "special": true, "content": "[MASK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}], "normalizer": {"type": "Sequence", "normalizers": [{"type": "Lowercase"}]}, "pre_tokenizer": {"type": "Sequence", "pretokenizers": [{"type": "Split", "pattern": {"String": " "}, "behavior": "Isolated", "invert": false}, {"type": "Whitespace"}]}, "post_processor": {"type": "TemplateProcessing", "single": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}], "pair": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}, {"Sequence": {"id": "B", "type_id": 1}}, {"SpecialToken": {"id": "[SEP]", "type_id": 1}}], "special_tokens": {"[CLS]": {"id": "[CLS]", "ids": [1], "tokens": ["[CLS]"]}, "[SEP]": {"id": "[SEP]", "ids": [2], "tokens": ["[SEP]"]}}}, "decoder": null, "model": {"type": "WordPiece", "unk_token": "[UNK]", "continuing_subword_prefix": "##", "max_input_chars_per_word": 100, "vocab": {"[UNK]": 0, "[CLS]": 1, "[SEP]": 2, "[PAD]": 3, "[MASK]": 4, "a_unk": 5, "c_unk": 6, "d_unk": 7, "e_unk": 8, "f_unk": 9, "g_unk": 10, "h_unk": 11, "i_unk": 12, "k_unk": 13, "l_unk": 14, "m_unk": 15, "n_unk": 16, "p_unk": 17, "q_unk": 18, "r_unk": 19, "s_unk": 20, "t_unk": 21, "v_unk": 22, "w_unk": 23, "y_unk": 24, "__unk": 25, "k_aaa": 26, "n_aac": 27, "k_aag": 28, "n_aat": 29, "t_aca": 30, "t_acc": 31, "t_acg": 32, "t_act": 33, "r_aga": 34, "s_agc": 35, "r_agg": 36, "s_agt": 37, "i_ata": 38, "i_atc": 39, "m_atg": 40, "i_att": 41, "q_caa": 42, "h_cac": 43, "q_cag": 44, "h_cat": 45, "p_cca": 46, "p_ccc": 47, "p_ccg": 48, "p_cct": 49, "r_cga": 50, "r_cgc": 51, "r_cgg": 52, "r_cgt": 53, "l_cta": 54, "l_ctc": 55, "l_ctg": 56, "l_ctt": 57, "e_gaa": 58, "d_gac": 59, "e_gag": 60, "d_gat": 61, "a_gca": 62, "a_gcc": 63, "a_gcg": 64, "a_gct": 65, "g_gga": 66, "g_ggc": 67, "g_ggg": 68, "g_ggt": 69, "v_gta": 70, "v_gtc": 71, "v_gtg": 72, "v_gtt": 73, "__taa": 74, "y_tac": 75, "__tag": 76, "y_tat": 77, "s_tca": 78, "s_tcc": 79, "s_tcg": 80, "s_tct": 81, "__tga": 82, "c_tgc": 83, "w_tgg": 84, "c_tgt": 85, "l_tta": 86, "f_ttc": 87, "l_ttg": 88, "f_ttt": 89}}} 2 | -------------------------------------------------------------------------------- /src/CodonTransformer_inference_template.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Adibvafa/CodonTransformer/6a45669f8b9ab8d81395dc917d4bbb05343d2e12/src/CodonTransformer_inference_template.xlsx -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | """Model weights, tokenizer, and other resources.""" 2 | -------------------------------------------------------------------------------- /src/banner_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Adibvafa/CodonTransformer/6a45669f8b9ab8d81395dc917d4bbb05343d2e12/src/banner_final.png 
-------------------------------------------------------------------------------- /src/organism2id.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Adibvafa/CodonTransformer/6a45669f8b9ab8d81395dc917d4bbb05343d2e12/src/organism2id.pkl -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Adibvafa/CodonTransformer/6a45669f8b9ab8d81395dc917d4bbb05343d2e12/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_CodonData.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import unittest 3 | 4 | import pandas as pd 5 | from Bio.Data.CodonTable import TranslationError 6 | 7 | from CodonTransformer.CodonData import ( 8 | build_amino2codon_skeleton, 9 | get_amino_acid_sequence, 10 | is_correct_seq, 11 | preprocess_protein_sequence, 12 | read_fasta_file, 13 | ) 14 | from CodonTransformer.CodonUtils import ProteinConfig 15 | 16 | 17 | class TestCodonData(unittest.TestCase): 18 | def test_preprocess_protein_sequence(self): 19 | with ProteinConfig() as config: 20 | config.set("ambiguous_aminoacid_behavior", "raise_error") 21 | protein = "Z_" 22 | try: 23 | preprocess_protein_sequence(protein) 24 | self.fail("Expected ValueError") 25 | except ValueError: 26 | pass 27 | config.set("ambiguous_aminoacid_behavior", "standardize_deterministic") 28 | for _ in range(10): 29 | preprocessed_protein = preprocess_protein_sequence(protein) 30 | self.assertEqual(preprocessed_protein, "Q_") 31 | config.set("ambiguous_aminoacid_behavior", "standardize_random") 32 | random_results = set() 33 | # The probability of getting the same result 30 times in a row is 34 | # 1 in 1.073741824*10^9 if there are only two possible results. 
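            # (For reference: 2 ** 30 == 1_073_741_824, so one particular
            # result repeating across all 30 fair two-way draws has
            # probability 1 / 2 ** 30.)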
35 |             for _ in range(30): 36 |                 preprocessed_protein = preprocess_protein_sequence(protein) 37 |                 random_results.add(preprocessed_protein) 38 |             self.assertGreater(len(random_results), 1) 39 | 40 |     def test_read_fasta_file(self): 41 |         fasta_content = ">sequence1\n" "ATGATGATGATGATG\n" ">sequence2\n" "TGATGATGATGA" 42 | 43 |         with tempfile.NamedTemporaryFile( 44 |             mode="w", delete=False, suffix=".fasta" 45 |         ) as temp_file: 46 |             temp_file.write(fasta_content) 47 |             temp_file_name = temp_file.name 48 | 49 |         try: 50 |             sequences = read_fasta_file(temp_file_name, save_to_file=None) 51 |             self.assertIsInstance(sequences, pd.DataFrame) 52 |             self.assertEqual(len(sequences), 2) 53 |             self.assertEqual(sequences.iloc[0]["dna"], "ATGATGATGATGATG") 54 |             self.assertEqual(sequences.iloc[1]["dna"], "TGATGATGATGA") 55 |         finally: 56 |             import os 57 | 58 |             os.unlink(temp_file_name) 59 | 60 |     def test_build_amino2codon_skeleton(self): 61 |         organism = "Homo sapiens" 62 |         codon_skeleton = build_amino2codon_skeleton(organism) 63 | 64 |         expected_amino_acids = "ARNDCQEGHILKMFPSTWYV_" 65 | 66 |         for amino_acid in expected_amino_acids: 67 |             self.assertIn(amino_acid, codon_skeleton) 68 |             codons, frequencies = codon_skeleton[amino_acid] 69 |             self.assertIsInstance(codons, list) 70 |             self.assertIsInstance(frequencies, list) 71 |             self.assertEqual(len(codons), len(frequencies)) 72 |             self.assertTrue(all(isinstance(codon, str) for codon in codons)) 73 |             self.assertTrue(all(freq == 0 for freq in frequencies)) 74 | 75 |         all_codons = set( 76 |             codon for codons, _ in codon_skeleton.values() for codon in codons 77 |         ) 78 |         self.assertEqual(len(all_codons), 64)  # There should be 64 unique codons 79 | 80 |     def test_get_amino_acid_sequence(self): 81 |         dna = "ATGGCCTGA" 82 |         protein, is_correct = get_amino_acid_sequence(dna, return_correct_seq=True) 83 |         self.assertEqual(protein, "MA_") 84 |         self.assertTrue(is_correct) 85 | 86 |     def test_is_correct_seq(self): 87 |         dna = "ATGGCCTGA" 88 |         protein = "MA_" 89 |         self.assertTrue(is_correct_seq(dna, protein)) 90 | 91 |     def test_read_fasta_file_raises_exception_for_non_dna(self): 92 |         non_dna_content = ">sequence1\nATGATGATGXYZATG\n>sequence2\nTGATGATGATGA" 93 | 94 |         with tempfile.NamedTemporaryFile( 95 |             mode="w", delete=False, suffix=".fasta" 96 |         ) as temp_file: 97 |             temp_file.write(non_dna_content) 98 |             temp_file_name = temp_file.name 99 | 100 |         try: 101 |             with self.assertRaises(TranslationError) as context: 102 |                 read_fasta_file(temp_file_name) 103 |             self.assertIn("Codon 'XYZ' is invalid", str(context.exception)) 104 |         finally: 105 |             import os 106 | 107 |             os.unlink(temp_file_name) 108 | 109 | 110 | if __name__ == "__main__": 111 |     unittest.main() 112 | -------------------------------------------------------------------------------- /tests/test_CodonJupyter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import unittest.mock 3 | import ipywidgets 4 | 5 | from CodonTransformer.CodonJupyter import ( 6 |     DNASequencePrediction, 7 |     UserContainer, 8 |     create_dropdown_options, 9 |     create_organism_dropdown, 10 |     display_organism_dropdown, 11 |     display_protein_input, 12 |     format_model_output, 13 | ) 14 | from CodonTransformer.CodonUtils import ORGANISM2ID 15 | 16 | 17 | class TestCodonJupyter(unittest.TestCase): 18 |     def test_UserContainer(self): 19 |         user_container = UserContainer() 20 |         self.assertEqual(user_container.organism, -1) 21 |         self.assertEqual(user_container.protein, "") 22 | 23 |     def test_create_organism_dropdown(self): 24 |         container = UserContainer() 
25 | dropdown = create_organism_dropdown(container) 26 | 27 | self.assertIsInstance(dropdown, ipywidgets.Dropdown) 28 | self.assertGreater(len(dropdown.options), 0) 29 | self.assertEqual(dropdown.description, "") 30 | self.assertEqual(dropdown.layout.width, "40%") 31 | self.assertEqual(dropdown.layout.margin, "0 0 10px 0") 32 | self.assertEqual(dropdown.style.description_width, "initial") 33 | 34 | # Test the dropdown options 35 | options = dropdown.options 36 | self.assertIn("", options) 37 | self.assertIn("Selected Organisms", options) 38 | self.assertIn("All Organisms", options) 39 | 40 | def test_create_dropdown_options(self): 41 | options = create_dropdown_options(ORGANISM2ID) 42 | self.assertIsInstance(options, list) 43 | self.assertGreater(len(options), 0) 44 | 45 | def test_display_organism_dropdown(self): 46 | container = UserContainer() 47 | with unittest.mock.patch( 48 | "CodonTransformer.CodonJupyter.display" 49 | ) as mock_display: 50 | display_organism_dropdown(container) 51 | 52 | # Check that display was called twice (for container_widget and HTML) 53 | self.assertEqual(mock_display.call_count, 2) 54 | 55 | # Check that the first call to display was with a VBox widget 56 | self.assertIsInstance(mock_display.call_args_list[0][0][0], ipywidgets.VBox) 57 | 58 | # Check that the VBox contains a Dropdown 59 | dropdown = mock_display.call_args_list[0][0][0].children[1] 60 | self.assertIsInstance(dropdown, ipywidgets.Dropdown) 61 | self.assertGreater(len(dropdown.options), 0) 62 | 63 | def test_display_protein_input(self): 64 | container = UserContainer() 65 | with unittest.mock.patch( 66 | "CodonTransformer.CodonJupyter.display" 67 | ) as mock_display: 68 | display_protein_input(container) 69 | 70 | # Check that display was called twice (for container_widget and HTML) 71 | self.assertEqual(mock_display.call_count, 2) 72 | 73 | # Check that the first call to display was with a VBox widget 74 | self.assertIsInstance(mock_display.call_args_list[0][0][0], ipywidgets.VBox) 75 | 76 | # Check that the VBox contains a Textarea 77 | textarea = mock_display.call_args_list[0][0][0].children[1] 78 | self.assertIsInstance(textarea, ipywidgets.Textarea) 79 | 80 | # Verify the properties of the Textarea 81 | self.assertEqual(textarea.value, "") 82 | self.assertEqual(textarea.placeholder, "Enter here...") 83 | self.assertEqual(textarea.description, "") 84 | self.assertEqual(textarea.layout.width, "100%") 85 | self.assertEqual(textarea.layout.height, "100px") 86 | self.assertEqual(textarea.layout.margin, "0 0 10px 0") 87 | self.assertEqual(textarea.style.description_width, "initial") 88 | 89 | def test_format_model_output(self): 90 | output = DNASequencePrediction( 91 | organism="Escherichia coli", 92 | protein="MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", 93 | processed_input="MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", 94 | predicted_dna="ATGAAAACTGTTCGTCAGGAACGTCTGAAATCTATTGTTCGTATTCTGGAACGTTCTAAAGAACCGGTTTCTGGTGCTCAACTGGCTGAAGAACTGTCTGTTTCTCGTCAGGTTATTGTTCAGGACATTGCTTACCTGCGTTCTCTGGGTTATAA", 95 | ) 96 | formatted_output = format_model_output(output) 97 | self.assertIsInstance(formatted_output, str) 98 | self.assertIn("Organism", formatted_output) 99 | self.assertIn("Escherichia coli", formatted_output) 100 | self.assertIn("Input Protein", formatted_output) 101 | self.assertIn( 102 | "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", 103 | formatted_output, 104 | ) 105 | self.assertIn("Processed Input", formatted_output) 106 | 
107 |             "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
108 |             formatted_output,
109 |         )
110 |         self.assertIn("Predicted DNA", formatted_output)
111 |         self.assertIn(
112 |             "ATGAAAACTGTTCGTCAGGAACGTCTGAAATCTATTGTTCGTATTCTGGAACGTTCTAAAGAACCGGTTTCTGGTGCTCAACTGGCTGAAGAACTGTCTGTTTCTCGTCAGGTTATTGTTCAGGACATTGCTTACCTGCGTTCTCTGGGTTATAA",
113 |             formatted_output,
114 |         )
115 | 
116 | 
117 | if __name__ == "__main__":
118 |     unittest.main()
119 | 
--------------------------------------------------------------------------------
/tests/test_CodonPrediction.py:
--------------------------------------------------------------------------------
1 | import random
2 | import unittest
3 | import warnings
4 | 
5 | import torch
6 | 
7 | from CodonTransformer.CodonData import get_amino_acid_sequence
8 | from CodonTransformer.CodonPrediction import (
9 |     load_model,
10 |     load_tokenizer,
11 |     predict_dna_sequence,
12 | )
13 | from CodonTransformer.CodonUtils import (
14 |     AMINO_ACIDS,
15 |     ORGANISM2ID,
16 |     STOP_SYMBOLS,
17 |     DNASequencePrediction,
18 | )
19 | 
20 | 
21 | class TestCodonPrediction(unittest.TestCase):
22 |     @classmethod
23 |     def setUpClass(cls):
24 |         cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25 | 
26 |         # Suppress warnings about loading from HuggingFace
27 |         for message in [
28 |             "Tokenizer path not provided. Loading from HuggingFace.",
29 |             "Model path not provided. Loading from HuggingFace.",
30 |         ]:
31 |             warnings.filterwarnings("ignore", message=message)
32 | 
33 |         cls.model = load_model(device=cls.device)
34 |         cls.tokenizer = load_tokenizer()
35 | 
36 |     def test_predict_dna_sequence_valid_input(self):
37 |         protein_sequence = "MWWMW"
38 |         organism = "Escherichia coli general"
39 |         result = predict_dna_sequence(
40 |             protein_sequence,
41 |             organism,
42 |             device=self.device,
43 |             tokenizer=self.tokenizer,
44 |             model=self.model,
45 |         )
46 |         self.assertIsInstance(result.predicted_dna, str)
47 |         self.assertTrue(
48 |             all(nucleotide in "ATCG" for nucleotide in result.predicted_dna)
49 |         )
50 |         self.assertEqual(result.predicted_dna, "ATGTGGTGGATGTGGTGA")
51 | 
52 |     def test_predict_dna_sequence_non_deterministic(self):
53 |         protein_sequence = "MFWY"
54 |         organism = "Escherichia coli general"
55 |         num_iterations = 100
56 |         temperatures = [0.2, 0.5, 0.8]
57 |         possible_outputs = set()
58 |         possible_encodings_wo_stop = {
59 |             "ATGTTTTGGTAT",
60 |             "ATGTTCTGGTAT",
61 |             "ATGTTTTGGTAC",
62 |             "ATGTTCTGGTAC",
63 |         }
64 |         for _ in range(num_iterations):
65 |             for temperature in temperatures:
66 |                 result = predict_dna_sequence(
67 |                     protein=protein_sequence,
68 |                     organism=organism,
69 |                     device=self.device,
70 |                     tokenizer=self.tokenizer,
71 |                     model=self.model,
72 |                     deterministic=False,
73 |                     temperature=temperature,
74 |                 )
75 |                 possible_outputs.add(result.predicted_dna[:-3])  # Remove stop codon
76 | 
77 |         self.assertEqual(possible_outputs, possible_encodings_wo_stop)
78 | 
79 |     def test_predict_dna_sequence_invalid_inputs(self):
80 |         test_cases = [
81 |             ("MKTZZFVLLL?", "Escherichia coli general", "invalid protein sequence"),
82 |             ("MKTFFVLLL", "Alien $%#@!", "invalid organism code"),
83 |             ("", "Escherichia coli general", "empty protein sequence"),
84 |         ]
85 | 
86 |         for protein_sequence, organism, error_type in test_cases:
87 |             with self.subTest(error_type=error_type):
88 |                 with self.assertRaises(ValueError):
89 |                     predict_dna_sequence(
90 |                         protein_sequence,
91 |                         organism,
92 |                         device=self.device,
93 |                         tokenizer=self.tokenizer,
94 |                         model=self.model,
95 |                     )
96 | 
97 |     def test_predict_dna_sequence_top_p_effect(self):
98 |         """Test that changing top_p affects the diversity of outputs."""
99 |         protein_sequence = "MFWY"
100 |         organism = "Escherichia coli general"
101 |         num_iterations = 50
102 |         temperature = 0.5
103 |         top_p_values = [0.8, 0.95]
104 |         outputs_by_top_p = {top_p: set() for top_p in top_p_values}
105 | 
106 |         for top_p in top_p_values:
107 |             for _ in range(num_iterations):
108 |                 result = predict_dna_sequence(
109 |                     protein=protein_sequence,
110 |                     organism=organism,
111 |                     device=self.device,
112 |                     tokenizer=self.tokenizer,
113 |                     model=self.model,
114 |                     deterministic=False,
115 |                     temperature=temperature,
116 |                     top_p=top_p,
117 |                 )
118 |                 outputs_by_top_p[top_p].add(
119 |                     result.predicted_dna[:-3]
120 |                 )  # Remove stop codon
121 | 
122 |         # Assert that higher top_p results in more diverse outputs
123 |         diversity_lower_top_p = len(outputs_by_top_p[0.8])
124 |         diversity_higher_top_p = len(outputs_by_top_p[0.95])
125 |         self.assertGreaterEqual(
126 |             diversity_higher_top_p,
127 |             diversity_lower_top_p,
128 |             "Higher top_p should result in more diverse outputs",
129 |         )
130 | 
131 |     def test_predict_dna_sequence_invalid_temperature_and_top_p(self):
132 |         """Test that invalid temperature and top_p values raise ValueError."""
133 |         protein_sequence = "MWWMW"
134 |         organism = "Escherichia coli general"
135 |         invalid_params = [
136 |             {"temperature": -0.1, "top_p": 0.95},
137 |             {"temperature": 0, "top_p": 0.95},
138 |             {"temperature": 0.5, "top_p": -0.1},
139 |             {"temperature": 0.5, "top_p": 1.1},
140 |         ]
141 | 
142 |         for params in invalid_params:
143 |             with self.subTest(params=params):
144 |                 with self.assertRaises(ValueError):
145 |                     predict_dna_sequence(
146 |                         protein=protein_sequence,
147 |                         organism=organism,
148 |                         device=self.device,
149 |                         tokenizer=self.tokenizer,
150 |                         model=self.model,
151 |                         deterministic=False,
152 |                         temperature=params["temperature"],
153 |                         top_p=params["top_p"],
154 |                     )
155 | 
156 |     def test_predict_dna_sequence_translation_consistency(self):
157 |         """Test that the predicted DNA translates back to the original protein."""
158 |         protein_sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVE"
159 |         organism = "Escherichia coli general"
160 |         result = predict_dna_sequence(
161 |             protein=protein_sequence,
162 |             organism=organism,
163 |             device=self.device,
164 |             tokenizer=self.tokenizer,
165 |             model=self.model,
166 |             deterministic=True,
167 |         )
168 | 
169 |         # Translate predicted DNA back to protein
170 |         translated_protein = get_amino_acid_sequence(result.predicted_dna[:-3])
171 | 
172 |         self.assertEqual(
173 |             translated_protein,
174 |             protein_sequence,
175 |             "Translated protein does not match the original protein sequence",
176 |         )
177 | 
178 |     def test_predict_dna_sequence_long_protein_sequence(self):
179 |         """Test the function with a very long protein sequence to check performance and correctness."""
180 |         protein_sequence = (
181 |             "M"
182 |             + "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG"
183 |             * 20
184 |             + STOP_SYMBOLS[0]
185 |         )
186 |         organism = "Escherichia coli general"
187 |         result = predict_dna_sequence(
188 |             protein=protein_sequence,
189 |             organism=organism,
190 |             device=self.device,
191 |             tokenizer=self.tokenizer,
192 |             model=self.model,
193 |             deterministic=True,
194 |         )
195 | 
196 |         # Check that the predicted DNA translates back to the original protein
197 |         dna_sequence = result.predicted_dna[:-3]
198 |         translated_protein = get_amino_acid_sequence(dna_sequence)
199 |         self.assertEqual(
200 |             translated_protein,
201 |             protein_sequence[:-1],
202 |             "Translated protein does not match the original long protein sequence",
203 |         )
204 | 
205 |     def test_predict_dna_sequence_edge_case_organisms(self):
206 |         """Test the function with organism IDs at the boundaries of the mapping."""
207 |         protein_sequence = "MWWMW"
208 |         # Assuming ORGANISM2ID has IDs starting from 0 to N
209 |         min_organism_id = min(ORGANISM2ID.values())
210 |         max_organism_id = max(ORGANISM2ID.values())
211 |         organisms = [min_organism_id, max_organism_id]
212 | 
213 |         for organism_id in organisms:
214 |             with self.subTest(organism_id=organism_id):
215 |                 result = predict_dna_sequence(
216 |                     protein=protein_sequence,
217 |                     organism=organism_id,
218 |                     device=self.device,
219 |                     tokenizer=self.tokenizer,
220 |                     model=self.model,
221 |                     deterministic=True,
222 |                 )
223 |                 self.assertIsInstance(result.predicted_dna, str)
224 |                 self.assertTrue(
225 |                     all(nucleotide in "ATCG" for nucleotide in result.predicted_dna)
226 |                 )
227 | 
228 |     def test_predict_dna_sequence_concurrent_calls(self):
229 |         """Test the function's behavior under concurrent execution."""
230 |         import threading
231 | 
232 |         protein_sequence = "MWWMW"
233 |         organism = "Escherichia coli general"
234 |         results = []
235 | 
236 |         def call_predict():
237 |             result = predict_dna_sequence(
238 |                 protein=protein_sequence,
239 |                 organism=organism,
240 |                 device=self.device,
241 |                 tokenizer=self.tokenizer,
242 |                 model=self.model,
243 |                 deterministic=True,
244 |             )
245 |             results.append(result.predicted_dna)
246 | 
247 |         threads = [threading.Thread(target=call_predict) for _ in range(10)]
248 |         for thread in threads:
249 |             thread.start()
250 |         for thread in threads:
251 |             thread.join()
252 | 
253 |         self.assertEqual(len(results), 10)
254 |         self.assertTrue(all(dna == results[0] for dna in results))
255 | 
256 |     def test_predict_dna_sequence_random_seed_consistency(self):
257 |         """Test that setting a random seed results in consistent outputs in non-deterministic mode."""
258 |         protein_sequence = "MFWY"
259 |         organism = "Escherichia coli general"
260 |         temperature = 0.5
261 |         top_p = 0.95
262 |         torch.manual_seed(42)
263 | 
264 |         result1 = predict_dna_sequence(
265 |             protein=protein_sequence,
266 |             organism=organism,
267 |             device=self.device,
268 |             tokenizer=self.tokenizer,
269 |             model=self.model,
270 |             deterministic=False,
271 |             temperature=temperature,
272 |             top_p=top_p,
273 |         )
274 | 
275 |         torch.manual_seed(42)
276 | 
277 |         result2 = predict_dna_sequence(
278 |             protein=protein_sequence,
279 |             organism=organism,
280 |             device=self.device,
281 |             tokenizer=self.tokenizer,
282 |             model=self.model,
283 |             deterministic=False,
284 |             temperature=temperature,
285 |             top_p=top_p,
286 |         )
287 | 
288 |         self.assertEqual(
289 |             result1.predicted_dna,
290 |             result2.predicted_dna,
291 |             "Outputs should be consistent when random seed is set",
292 |         )
293 | 
294 |     def test_predict_dna_sequence_invalid_tokenizer_and_model(self):
295 |         """Test that providing an invalid tokenizer or model raises an appropriate exception."""
296 |         protein_sequence = "MWWMW"
297 |         organism = "Escherichia coli general"
298 | 
299 |         with self.subTest("Invalid tokenizer"):
300 |             with self.assertRaises(Exception):
301 |                 predict_dna_sequence(
302 |                     protein=protein_sequence,
303 |                     organism=organism,
304 |                     device=self.device,
305 |                     tokenizer="invalid_tokenizer_path",
306 |                     model=self.model,
307 |                 )
308 | 
309 |         with self.subTest("Invalid model"):
310 |             with self.assertRaises(Exception):
311 |                 predict_dna_sequence(
312 |                     protein=protein_sequence,
313 |                     organism=organism,
314 |                     device=self.device,
315 |                     tokenizer=self.tokenizer,
316 |                     model="invalid_model_path",
317 |                 )
318 | 
319 |     def test_predict_dna_sequence_stop_codon_handling(self):
320 |         """Test handling of a protein sequence that ends with a stop symbol other than '_' or '*'."""
321 |         protein_sequence = "MWW/"
322 |         organism = "Escherichia coli general"
323 | 
324 |         with self.assertRaises(ValueError):
325 |             predict_dna_sequence(
326 |                 protein=protein_sequence,
327 |                 organism=organism,
328 |                 device=self.device,
329 |                 tokenizer=self.tokenizer,
330 |                 model=self.model,
331 |             )
332 | 
333 |     def test_predict_dna_sequence_device_compatibility(self):
334 |         """Test that the function works correctly on both CPU and GPU devices."""
335 |         protein_sequence = "MWWMW"
336 |         organism = "Escherichia coli general"
337 | 
338 |         devices = [torch.device("cpu")]
339 |         if torch.cuda.is_available():
340 |             devices.append(torch.device("cuda"))
341 | 
342 |         for device in devices:
343 |             with self.subTest(device=device):
344 |                 result = predict_dna_sequence(
345 |                     protein=protein_sequence,
346 |                     organism=organism,
347 |                     device=device,
348 |                     tokenizer=self.tokenizer,
349 |                     model=self.model,
350 |                     deterministic=True,
351 |                 )
352 |                 self.assertIsInstance(result.predicted_dna, str)
353 |                 self.assertTrue(
354 |                     all(nucleotide in "ATCG" for nucleotide in result.predicted_dna)
355 |                 )
356 | 
357 |     def test_predict_dna_sequence_random_proteins(self):
358 |         """Test random proteins to ensure translated DNA matches the original protein."""
359 |         organism = "Escherichia coli general"
360 |         num_tests = 200
361 | 
362 |         for _ in range(num_tests):
363 |             # Generate a random protein sequence of random length between 10 and 500
364 |             protein_length = random.randint(10, 500)
365 |             protein_sequence = "M" + "".join(
366 |                 random.choices(AMINO_ACIDS, k=protein_length - 1)
367 |             )
368 |             protein_sequence += random.choice(STOP_SYMBOLS)
369 | 
370 |             result = predict_dna_sequence(
371 |                 protein=protein_sequence,
372 |                 organism=organism,
373 |                 device=self.device,
374 |                 tokenizer=self.tokenizer,
375 |                 model=self.model,
376 |                 deterministic=True,
377 |             )
378 | 
379 |             # Remove stop codon from predicted DNA
380 |             dna_sequence = result.predicted_dna[:-3]
381 | 
382 |             # Translate predicted DNA back to protein
383 |             translated_protein = get_amino_acid_sequence(dna_sequence)
384 |             self.assertEqual(
385 |                 translated_protein,
386 |                 protein_sequence[:-1],  # Remove stop symbol
387 |                 f"Translated protein does not match the original protein sequence for protein: {protein_sequence}",
388 |             )
389 | 
390 |     def test_predict_dna_sequence_long_protein_over_max_length(self):
391 |         """Test that the model handles protein sequences longer than 2048 amino acids."""
392 |         # Create a protein sequence longer than 2048 amino acids
393 |         base_sequence = (
394 |             "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG"
395 |         )
396 |         protein_sequence = base_sequence * 100  # Length > 2048 amino acids
397 |         organism = "Escherichia coli general"
398 | 
399 |         result = predict_dna_sequence(
400 |             protein=protein_sequence,
401 |             organism=organism,
402 |             device=self.device,
403 |             tokenizer=self.tokenizer,
404 |             model=self.model,
405 |             deterministic=True,
406 |         )
407 | 
408 |         # Remove stop codon from predicted DNA
409 |         dna_sequence = result.predicted_dna[:-3]
410 |         translated_protein = get_amino_acid_sequence(dna_sequence)
411 | 
412 |         # The model may truncate inputs beyond its maximum length, so compare only the prefix that was actually translated
413 |         max_length = len(translated_protein)
414 |         self.assertEqual(
415 |             translated_protein[:max_length],
416 |             protein_sequence[:max_length],
417 |             "Translated protein does not match the original protein sequence up to the maximum length supported.",
418 |         )
419 | 
420 |     def test_predict_dna_sequence_multi_output(self):
421 |         """Test that the function returns multiple sequences when num_sequences > 1."""
422 |         protein_sequence = "MFQLLAPWY"
423 |         organism = "Escherichia coli general"
424 |         num_sequences = 20
425 | 
426 |         result = predict_dna_sequence(
427 |             protein=protein_sequence,
428 |             organism=organism,
429 |             device=self.device,
430 |             tokenizer=self.tokenizer,
431 |             model=self.model,
432 |             deterministic=False,
433 |             num_sequences=num_sequences,
434 |         )
435 | 
436 |         self.assertIsInstance(result, list)
437 |         self.assertEqual(len(result), num_sequences)
438 | 
439 |         for prediction in result:
440 |             self.assertIsInstance(prediction, DNASequencePrediction)
441 |             self.assertTrue(
442 |                 all(nucleotide in "ATCG" for nucleotide in prediction.predicted_dna)
443 |             )
444 | 
445 |             # Check that all predicted DNA sequences translate back to the original protein
446 |             translated_protein = get_amino_acid_sequence(prediction.predicted_dna[:-3])
447 |             self.assertEqual(translated_protein, protein_sequence)
448 | 
449 |     def test_predict_dna_sequence_deterministic_multi_raises_error(self):
450 |         """Test that requesting multiple sequences in deterministic mode raises an error."""
451 |         protein_sequence = "MFWY"
452 |         organism = "Escherichia coli general"
453 | 
454 |         with self.assertRaises(ValueError):
455 |             predict_dna_sequence(
456 |                 protein=protein_sequence,
457 |                 organism=organism,
458 |                 device=self.device,
459 |                 tokenizer=self.tokenizer,
460 |                 model=self.model,
461 |                 deterministic=True,
462 |                 num_sequences=3,
463 |             )
464 | 
465 |     def test_predict_dna_sequence_multi_diversity(self):
466 |         """Test that multiple generated sequences are diverse."""
467 |         protein_sequence = "MFWYMFWY"
468 |         organism = "Escherichia coli general"
469 |         num_sequences = 10
470 | 
471 |         result = predict_dna_sequence(
472 |             protein=protein_sequence,
473 |             organism=organism,
474 |             device=self.device,
475 |             tokenizer=self.tokenizer,
476 |             model=self.model,
477 |             deterministic=False,
478 |             num_sequences=num_sequences,
479 |             temperature=0.8,
480 |         )
481 | 
482 |         unique_sequences = set(prediction.predicted_dna for prediction in result)
483 | 
484 |         self.assertGreater(
485 |             len(unique_sequences),
486 |             2,
487 |             "Multiple sequence generation should produce diverse results",
488 |         )
489 | 
490 |         # Check that all sequences are valid translations of the input protein
491 |         for prediction in result:
492 |             translated_protein = get_amino_acid_sequence(prediction.predicted_dna[:-3])
493 |             self.assertEqual(translated_protein, protein_sequence)
494 | 
495 |     def test_predict_dna_sequence_match_protein_repetitive(self):
496 |         """Test that match_protein=True correctly handles highly repetitive and unconventional sequences."""
497 |         test_sequences = (
498 |             "QQQQQQQQQQQQQQQQ_",
499 |             "KRKRKRKRKRKRKRKR_",
500 |             "PGPGPGPGPGPGPGPG_",
501 |             "DEDEDEDEDEDEDEDEDE_",
502 |             "M_M_M_M_M_",
503 |             "MMMMMMMMMM_",
504 |             "WWWWWWWWWW_",
505 |             "CCCCCCCCCC_",
506 |             "MWCHMWCHMWCH_",
507 |             "Q_QQ_QQQ_QQQQ_",
508 |             "MWMWMWMWMWMW_",
509 |             "CCCHHHMMMWWW_",
510 |             "_",
511 |             "M_",
512 |             "MGWC_",
513 |         )
514 | 
515 |         organism = "Homo sapiens"
516 | 
517 |         for protein_sequence in test_sequences:
518 |             # Generate sequence with match_protein=True
519 |             result = predict_dna_sequence(
520 |                 protein=protein_sequence,
521 |                 organism=organism,
522 |                 device=self.device,
523 |                 tokenizer=self.tokenizer,
524 |                 model=self.model,
525 |                 deterministic=False,
526 |                 temperature=20,  # High temperature to test protein matching
527 |                 match_protein=True,
528 |             )
529 | 
530 |             dna_sequence = result.predicted_dna
531 |             translated_protein = get_amino_acid_sequence(dna_sequence)
532 | 
533 |             self.assertEqual(
534 |                 translated_protein,
535 |                 protein_sequence,
536 |                 f"Translated protein must match original when match_protein=True. Failed for sequence: {protein_sequence}",
537 |             )
538 | 
539 |     def test_predict_dna_sequence_match_protein_rare_amino_acids(self):
540 |         """Test match_protein with rare amino acids that have limited codon options."""
541 |         # Methionine (M) and Tryptophan (W) have only one codon each,
542 |         # while Leucine (L) has 6 codons - testing the contrast
543 |         protein_sequence = "MWLLLMWLLL"
544 |         organism = "Escherichia coli general"
545 | 
546 |         # Run multiple predictions
547 |         results = []
548 |         num_iterations = 10
549 | 
550 |         for _ in range(num_iterations):
551 |             result = predict_dna_sequence(
552 |                 protein=protein_sequence,
553 |                 organism=organism,
554 |                 device=self.device,
555 |                 tokenizer=self.tokenizer,
556 |                 model=self.model,
557 |                 deterministic=False,
558 |                 temperature=20,  # High temperature to test protein matching
559 |                 match_protein=True,
560 |             )
561 |             results.append(result.predicted_dna)
562 | 
563 |         # Check all sequences
564 |         for dna_sequence in results:
565 |             # Verify M always uses ATG
566 |             m_positions = [0, 5]  # Known positions of M in sequence
567 |             for pos in m_positions:
568 |                 self.assertEqual(
569 |                     dna_sequence[pos * 3 : (pos + 1) * 3],
570 |                     "ATG",
571 |                     "Methionine must use ATG codon.",
572 |                 )
573 | 
574 |             # Verify W always uses TGG
575 |             w_positions = [1, 6]  # Known positions of W in sequence
576 |             for pos in w_positions:
577 |                 self.assertEqual(
578 |                     dna_sequence[pos * 3 : (pos + 1) * 3],
579 |                     "TGG",
580 |                     "Tryptophan must use TGG codon.",
581 |                 )
582 | 
583 |             # Verify all L codons are valid
584 |             l_positions = [2, 3, 4, 7, 8, 9]  # Known positions of L in sequence
585 |             l_codons = [dna_sequence[pos * 3 : (pos + 1) * 3] for pos in l_positions]
586 |             valid_l_codons = {"TTA", "TTG", "CTT", "CTC", "CTA", "CTG"}
587 |             self.assertTrue(
588 |                 all(codon in valid_l_codons for codon in l_codons),
589 |                 "All Leucine codons must be valid",
590 |             )
591 | 
592 | 
593 | if __name__ == "__main__":
594 |     unittest.main()
595 | 
--------------------------------------------------------------------------------
/tests/test_CodonUtils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import tempfile
4 | import unittest.mock  # unittest.mock.patch is used below; importing the submodule also binds unittest
5 | 
6 | from CodonTransformer.CodonUtils import (
7 |     ProteinConfig,
8 |     find_pattern_in_fasta,
9 |     get_organism2id_dict,
10 |     get_taxonomy_id,
11 |     load_pkl_from_url,
12 |     load_python_object_from_disk,
13 |     save_python_object_to_disk,
14 |     sort_amino2codon_skeleton,
15 | )
16 | 
17 | 
18 | class TestCodonUtils(unittest.TestCase):
19 |     def test_config_manager(self):
20 |         with ProteinConfig() as config:
21 |             config.set("ambiguous_aminoacid_behavior", "standardize_deterministic")
22 |             self.assertEqual(
23 |                 config.get("ambiguous_aminoacid_behavior"), "standardize_deterministic"
24 |             )
25 |             config.set("ambiguous_aminoacid_map_override", {"X": ["A", "G"]})
26 |             self.assertEqual(
27 |                 config.get("ambiguous_aminoacid_map_override"), {"X": ["A", "G"]}
28 |             )
29 |             config.update(
30 |                 {
31 |                     "ambiguous_aminoacid_behavior": "raise_error",
32 |                     "ambiguous_aminoacid_map_override": {"X": ["A", "G"]},
"G"]}, 33 | } 34 | ) 35 | self.assertEqual(config.get("ambiguous_aminoacid_behavior"), "raise_error") 36 | self.assertEqual( 37 | config.get("ambiguous_aminoacid_map_override"), {"X": ["A", "G"]} 38 | ) 39 | try: 40 | config.set("invalid_key", "invalid_value") 41 | self.fail("Expected ValueError") 42 | except ValueError: 43 | pass 44 | with ProteinConfig() as config: 45 | self.assertEqual( 46 | config.get("ambiguous_aminoacid_behavior"), "standardize_random" 47 | ) 48 | self.assertEqual(config.get("ambiguous_aminoacid_map_override"), {}) 49 | 50 | def test_load_python_object_from_disk(self): 51 | test_obj = {"key1": "value1", "key2": 2} 52 | with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as temp_file: 53 | temp_file_name = temp_file.name 54 | save_python_object_to_disk(test_obj, temp_file_name) 55 | loaded_obj = load_python_object_from_disk(temp_file_name) 56 | self.assertEqual(test_obj, loaded_obj) 57 | os.remove(temp_file_name) 58 | 59 | def test_save_python_object_to_disk(self): 60 | test_obj = [1, 2, 3, 4, 5] 61 | with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as temp_file: 62 | temp_file_name = temp_file.name 63 | save_python_object_to_disk(test_obj, temp_file_name) 64 | self.assertTrue(os.path.exists(temp_file_name)) 65 | os.remove(temp_file_name) 66 | 67 | def test_find_pattern_in_fasta(self): 68 | text = ( 69 | ">seq1 [keyword=value1]\nATGCGTACGTAGCTAG\n" 70 | ">seq2 [keyword=value2]\nGGTACGATCGATCGAT" 71 | ) 72 | self.assertEqual(find_pattern_in_fasta("keyword", text), "value1") 73 | self.assertEqual(find_pattern_in_fasta("nonexistent", text), "") 74 | 75 | def test_get_organism2id_dict(self): 76 | with tempfile.NamedTemporaryFile( 77 | mode="w", delete=True, suffix=".csv" 78 | ) as temp_file: 79 | temp_file.write("0,Escherichia coli\n1,Homo sapiens\n2,Mus musculus") 80 | temp_file.flush() 81 | organism2id = get_organism2id_dict(temp_file.name) 82 | self.assertEqual( 83 | organism2id, 84 | {"Escherichia coli": 0, "Homo sapiens": 1, "Mus musculus": 2}, 85 | ) 86 | 87 | def test_get_taxonomy_id(self): 88 | taxonomy_dict = { 89 | "Escherichia coli": 562, 90 | "Homo sapiens": 9606, 91 | "Mus musculus": 10090, 92 | } 93 | with tempfile.NamedTemporaryFile(suffix=".pkl", delete=True) as temp_file: 94 | temp_file_name = temp_file.name 95 | save_python_object_to_disk(taxonomy_dict, temp_file_name) 96 | self.assertEqual(get_taxonomy_id(temp_file_name, "Escherichia coli"), 562) 97 | self.assertEqual( 98 | get_taxonomy_id(temp_file_name, return_dict=True), taxonomy_dict 99 | ) 100 | 101 | def test_sort_amino2codon_skeleton(self): 102 | amino2codon = { 103 | "A": (["GCT", "GCC", "GCA", "GCG"], [0.0, 0.0, 0.0, 0.0]), 104 | "C": (["TGT", "TGC"], [0.0, 0.0]), 105 | } 106 | sorted_amino2codon = sort_amino2codon_skeleton(amino2codon) 107 | self.assertEqual( 108 | sorted_amino2codon, 109 | { 110 | "A": (["GCA", "GCC", "GCG", "GCT"], [0.0, 0.0, 0.0, 0.0]), 111 | "C": (["TGC", "TGT"], [0.0, 0.0]), 112 | }, 113 | ) 114 | 115 | def test_load_pkl_from_url(self): 116 | url = "https://example.com/test.pkl" 117 | expected_obj = {"key": "value"} 118 | with unittest.mock.patch("requests.get") as mock_get: 119 | mock_get.return_value.content = pickle.dumps(expected_obj) 120 | loaded_obj = load_pkl_from_url(url) 121 | self.assertEqual(loaded_obj, expected_obj) 122 | 123 | 124 | if __name__ == "__main__": 125 | unittest.main() 126 | --------------------------------------------------------------------------------