├── nucleotide_transformer ├── __init__.py ├── borzoi │ ├── __init__.py │ ├── pretrained.py │ └── layers.py ├── chatNT │ ├── __init__.py │ ├── types.py │ ├── params.py │ ├── pretrained.py │ └── gpt_rotary.py ├── enformer │ ├── __init__.py │ ├── features.py │ ├── params.py │ ├── pretrained.py │ ├── heads.py │ └── layers.py ├── mojo │ ├── __init__.py │ ├── pretrained.py │ ├── config.py │ ├── layers.py │ └── model.py ├── bulk_rna_bert │ ├── __init__.py │ ├── layers.py │ ├── config.py │ ├── pretrained.py │ ├── model.py │ └── tokenizer.py ├── sCellTransformer │ ├── __init__.py │ ├── __pycache__ │ │ ├── model.cpython-310.pyc │ │ └── params.cpython-310.pyc │ └── params.py ├── constants.py ├── types.py ├── utils.py └── heads.py ├── imgs ├── agront.webp ├── codon_nt.png ├── isoformer.png ├── chatNT_figures.png ├── instadeep_logo.png ├── nt_rebuttal_results.png ├── nt_results_rebuttal_2.png ├── Agro_NT_Gene_Expression.png ├── finetuning_results_transp.png └── segment_nt_panel1_screen.png ├── .pre-commit-config.yaml ├── setup.py ├── docs ├── isoformer.md ├── mojo.md ├── bulk_rna_bert.md ├── sct.md ├── codon_nt.md ├── agro_nucleotide_transformer.md ├── chat_nt.md ├── nucleotide_transformer.md └── segment_nt.md ├── notebooks ├── mojo │ ├── inference_mojo_pytorch_example.ipynb │ └── inference_mojo_jax_example.ipynb ├── bulk_rna_bert │ ├── inference_bulkrnabert_pytorch_example.ipynb │ └── inference_bulkrnaert_jax_example.ipynb ├── isoformer │ └── inference.ipynb └── chat_nt │ └── inference.ipynb └── README.md /nucleotide_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nucleotide_transformer/borzoi/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nucleotide_transformer/chatNT/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nucleotide_transformer/enformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nucleotide_transformer/mojo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nucleotide_transformer/bulk_rna_bert/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nucleotide_transformer/sCellTransformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/agront.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/nucleotide-transformer/HEAD/imgs/agront.webp -------------------------------------------------------------------------------- /imgs/codon_nt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/nucleotide-transformer/HEAD/imgs/codon_nt.png -------------------------------------------------------------------------------- 
/imgs/isoformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/nucleotide-transformer/HEAD/imgs/isoformer.png -------------------------------------------------------------------------------- /imgs/chatNT_figures.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/nucleotide-transformer/HEAD/imgs/chatNT_figures.png -------------------------------------------------------------------------------- /imgs/instadeep_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/nucleotide-transformer/HEAD/imgs/instadeep_logo.png -------------------------------------------------------------------------------- /imgs/nt_rebuttal_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/nucleotide-transformer/HEAD/imgs/nt_rebuttal_results.png -------------------------------------------------------------------------------- /imgs/nt_results_rebuttal_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/nucleotide-transformer/HEAD/imgs/nt_results_rebuttal_2.png -------------------------------------------------------------------------------- /imgs/Agro_NT_Gene_Expression.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/nucleotide-transformer/HEAD/imgs/Agro_NT_Gene_Expression.png -------------------------------------------------------------------------------- /imgs/finetuning_results_transp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/nucleotide-transformer/HEAD/imgs/finetuning_results_transp.png -------------------------------------------------------------------------------- /imgs/segment_nt_panel1_screen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/nucleotide-transformer/HEAD/imgs/segment_nt_panel1_screen.png -------------------------------------------------------------------------------- /nucleotide_transformer/sCellTransformer/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/nucleotide-transformer/HEAD/nucleotide_transformer/sCellTransformer/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /nucleotide_transformer/sCellTransformer/__pycache__/params.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/nucleotide-transformer/HEAD/nucleotide_transformer/sCellTransformer/__pycache__/params.cpython-310.pyc -------------------------------------------------------------------------------- /nucleotide_transformer/enformer/features.py: -------------------------------------------------------------------------------- 1 | FEATURES = [ 2 | "protein_coding_gene", 3 | "lncRNA", 4 | "exon", 5 | "intron", 6 | "splice_donor", 7 | "splice_acceptor", 8 | "5UTR", 9 | "3UTR", 10 | "CTCF-bound", 11 | "polyA_signal", 12 | "enhancer_Tissue_specific", 13 | "enhancer_Tissue_invariant", 14 | "promoter_Tissue_specific", 
15 | "promoter_Tissue_invariant", 16 | ] 17 | -------------------------------------------------------------------------------- /nucleotide_transformer/chatNT/types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import jax.numpy as jnp 4 | from typing_extensions import TypeAlias 5 | 6 | RNGKey: TypeAlias = jnp.ndarray 7 | Embedding: TypeAlias = jnp.ndarray 8 | Tokens: TypeAlias = jnp.ndarray 9 | Labels: TypeAlias = jnp.ndarray 10 | Images: TypeAlias = jnp.ndarray 11 | AttentionMask: TypeAlias = jnp.ndarray 12 | SequenceMask: TypeAlias = jnp.ndarray 13 | AttentionWeights: TypeAlias = jnp.ndarray 14 | TransformerOutput: TypeAlias = Dict[str, jnp.ndarray] 15 | MultiOmicsTokens: TypeAlias = Tuple[jnp.ndarray, jnp.ndarray] 16 | -------------------------------------------------------------------------------- /nucleotide_transformer/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 InstaDeep Ltd 2 | # 3 | # Licensed under the Creative Commons BY-NC-SA 4.0 License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://creativecommons.org/licenses/by-nc-sa/4.0/ 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | NUCLEOTIDES = ["A", "T", "C", "G"] 16 | VALID_EXTRA_NUCLEOTIDES = ["N", "M", "Y", "B", "S", "W", "K", "H", "D", "V", "R"] 17 | EXTRA_NUCLEOTIDES = ["N"] 18 | -------------------------------------------------------------------------------- /nucleotide_transformer/types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 InstaDeep Ltd 2 | # 3 | # Licensed under the Creative Commons BY-NC-SA 4.0 License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://creativecommons.org/licenses/by-nc-sa/4.0/ 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Dict 16 | 17 | import jax.numpy as jnp 18 | from typing_extensions import TypeAlias 19 | 20 | Embedding: TypeAlias = jnp.ndarray 21 | Tokens: TypeAlias = jnp.ndarray 22 | AttentionMask: TypeAlias = jnp.ndarray 23 | SequenceMask: TypeAlias = jnp.ndarray 24 | TransformerOutput: TypeAlias = Dict[str, jnp.ndarray] # type: ignore 25 | -------------------------------------------------------------------------------- /nucleotide_transformer/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | SUPPORTED_FFN_ACTIVATIONS = ["gelu", "gelu-no-approx", "relu", "swish", "silu", "sin"] 7 | 8 | 9 | def get_activation_fn(activation_name: str) -> Callable: 10 | """ 11 | Return activation fn given its name. 
12 | Args: 13 | activation_name: Activation name. 14 | 15 | Returns: 16 | activation function. 17 | """ 18 | if activation_name not in SUPPORTED_FFN_ACTIVATIONS: 19 | raise NotImplementedError( 20 | f"Activation {activation_name} not supported yet. " 21 | f"Supported activations for feed forward " 22 | f"block are {SUPPORTED_FFN_ACTIVATIONS}" 23 | ) 24 | if activation_name == "gelu-no-approx": 25 | activation_fn = lambda x: jax.nn.gelu(x, approximate=False) # noqa: E731 26 | elif activation_name == "sin": 27 | activation_fn = lambda x: jnp.sin(x) # noqa: E731 28 | else: 29 | activation_fn = getattr(jax.nn, activation_name) 30 | return activation_fn 31 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pycqa/isort 3 | rev: 5.12.0 4 | hooks: 5 | - id: isort 6 | args: ["--profile", "black"] 7 | - repo: https://github.com/psf/black 8 | rev: 22.10.0 9 | hooks: 10 | - id: black 11 | name: "Format code (black)" 12 | - repo: https://github.com/pycqa/flake8 13 | rev: 6.0.0 14 | hooks: 15 | - id: flake8 16 | args: 17 | [ 18 | "--max-line-length=88", 19 | "--extend-ignore=E203", 20 | "--kwargs-max-positional-arguments=6", 21 | ] 22 | additional_dependencies: 23 | [ 24 | "flake8-bugbear==23.2.13", 25 | "flake8-builtins==2.1.0", 26 | "flake8-comprehensions==3.10.1", 27 | "flake8-class-attributes-order==0.1.3", 28 | "pep8-naming==0.13.3", 29 | "flake8-force-keyword-arguments==1.0.4", 30 | ] 31 | - repo: https://github.com/pre-commit/mirrors-mypy 32 | rev: v0.991 33 | hooks: 34 | - id: mypy 35 | - repo: https://github.com/pre-commit/pre-commit-hooks 36 | rev: v4.2.0 37 | hooks: 38 | - id: check-merge-conflict 39 | - id: end-of-file-fixer 40 | - id: requirements-txt-fixer 41 | - id: trailing-whitespace 42 | - repo: https://github.com/asottile/pyupgrade 43 | rev: v3.3.1 44 | hooks: 45 | - id: pyupgrade 46 | args: [--py38-plus] 47 | -------------------------------------------------------------------------------- /nucleotide_transformer/bulk_rna_bert/layers.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import haiku as hk 4 | import jax.numpy as jnp 5 | from haiku import initializers 6 | 7 | 8 | class SimpleLMHead(hk.Module): 9 | """ 10 | Basic Language Model head. Transforms final attention block output 11 | into a distribution over tokens at each sequence position. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | embed_dim: int, 17 | alphabet_size: int, 18 | add_bias_lm_head: bool = True, 19 | name: Optional[str] = None, 20 | ): 21 | """ 22 | Args: 23 | embed_dim: Embedding dimension. 24 | alphabet_size: Number of tokens in the alphabet. 25 | name: Name of the layer. Defaults to None. 
26 | """ 27 | super().__init__(name=name) 28 | self.embed_dim = embed_dim 29 | self.alphabet_size = alphabet_size 30 | 31 | # Define layers 32 | w_init = initializers.VarianceScaling(2.0, "fan_in", "uniform") 33 | b_init = initializers.VarianceScaling(2.0, "fan_in", "uniform") 34 | self._final_fc = hk.Linear( 35 | self.alphabet_size, 36 | w_init=w_init, 37 | b_init=b_init, 38 | with_bias=add_bias_lm_head, 39 | name="lm_final_fc", 40 | ) 41 | 42 | def __call__(self, x: jnp.ndarray) -> dict[str, jnp.ndarray]: 43 | # Compute logits 44 | logits = self._final_fc(x) 45 | return {"logits": logits} 46 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import find_packages, setup 4 | 5 | CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) 6 | 7 | with open(os.path.join(CURRENT_DIR, "README.md"), encoding="utf-8") as f: 8 | long_description = f.read() 9 | 10 | setup( 11 | name="nucleotide_transformer", 12 | version="0.0.1", 13 | packages=find_packages(), 14 | url="https://github.com/instadeepai/nucleotide-transformer", 15 | license="CC BY-NC-SA 4.0", 16 | author="InstaDeep Ltd", 17 | python_requires=">=3.9", 18 | description="The Nucleotide Transformer: Building and Evaluating " 19 | "Robust Foundation Models for Human Genomics ", 20 | long_description=long_description, 21 | long_description_content_type="text/markdown", 22 | install_requires=[ 23 | "absl-py>=1.0.0", 24 | "jax>=0.3.25", 25 | "jaxlib>=0.3.25", 26 | "dm-haiku>=0.0.9", 27 | "numpy>=1.23.5,<2.0.0", 28 | "typing_extensions>=3.10.0", 29 | "joblib>=1.2.0", 30 | "tqdm>=4.56.0", 31 | "regex>=2022.1.18", 32 | "huggingface-hub>=0.23.0", 33 | "pydantic==1.10.13", 34 | ], 35 | dependency_links=[ 36 | "https://storage.googleapis.com/jax-releases/jax_releases.html", 37 | ], 38 | keywords=["Genomics", "Language Model", "Deep Learning", "JAX"], 39 | classifiers=[ 40 | "Development Status :: 4 - Beta", 41 | "Environment :: Console", 42 | "Intended Audience :: Science/Research", 43 | "License :: OSI Approved :: MIT License", 44 | "Operating System :: POSIX :: Linux", 45 | "Programming Language :: Python :: 3", 46 | "Programming Language :: Python :: 3.8", 47 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 48 | ], 49 | ) 50 | -------------------------------------------------------------------------------- /nucleotide_transformer/chatNT/params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 InstaDeep Ltd 2 | # 3 | # Licensed under the Creative Commons BY-NC-SA 4.0 License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://creativecommons.org/licenses/by-nc-sa/4.0/ 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import json 16 | import os 17 | from typing import Any, Dict, Tuple 18 | 19 | import haiku as hk 20 | import joblib 21 | from huggingface_hub import hf_hub_download 22 | 23 | ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" 24 | DEFAULT_CACHE_DIR = "~/.cache" 25 | 26 | 27 | def _get_dir(model_name: str) -> str: 28 | """ 29 | Get directory to save files on user machine. 30 | """ 31 | return os.path.expanduser( 32 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), model_name) 33 | ) 34 | 35 | 36 | def download_ckpt() -> hk.Params: 37 | """ 38 | Download the ChatNT checkpoint from the Hugging Face Hub. 39 | 40 | The parameters are fetched from the `InstaDeepAI/ChatNT` repository and 41 | cached locally. 42 | 43 | Returns: 44 | Model parameters. 45 | 46 | 47 | 48 | """ 49 | 50 | save_dir = os.path.join(_get_dir("ChatNT"), "ChatNT") 51 | 52 | repo_id = "InstaDeepAI/ChatNT" 53 | 54 | # Download parameters 55 | print("Downloading model's weights...") 56 | params = joblib.load( 57 | hf_hub_download( 58 | repo_id=repo_id, 59 | filename="jax_params/params.joblib", 60 | cache_dir=save_dir, 61 | ) 62 | ) 63 | 64 | return params 65 | -------------------------------------------------------------------------------- /nucleotide_transformer/enformer/params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 InstaDeep Ltd 2 | # 3 | # Licensed under the Creative Commons BY-NC-SA 4.0 License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://creativecommons.org/licenses/by-nc-sa/4.0/ 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from typing import Any, Dict, Tuple 17 | 18 | import haiku as hk 19 | import joblib 20 | from huggingface_hub import hf_hub_download 21 | 22 | ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" 23 | DEFAULT_CACHE_DIR = "~/.cache" 24 | 25 | 26 | def _get_dir(model_name: str) -> str: 27 | """ 28 | Get directory to save files on user machine. 29 | """ 30 | assert model_name in ["segment_enformer", "segment_borzoi"] 31 | return os.path.expanduser( 32 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), model_name) 33 | ) 34 | 35 | 36 | def download_ckpt(model_name: str) -> Tuple[hk.Params, Dict[str, Any]]: 37 | """ 38 | Download the checkpoint from the Hugging Face Hub. 39 | 40 | Args: 41 | model_name: Name of the model. 42 | 43 | Returns: 44 | Model parameters.
45 | Model state 46 | 47 | 48 | """ 49 | assert model_name in ["segment_enformer", "segment_borzoi"] 50 | save_dir = os.path.join(_get_dir(model_name), model_name) 51 | 52 | repo_id = f"InstaDeepAI/{model_name}" 53 | 54 | # Download parameters 55 | print("Downloading model's weights...") 56 | params = joblib.load( 57 | hf_hub_download( 58 | repo_id=repo_id, filename="jax_model/params.joblib", cache_dir=save_dir 59 | ) 60 | ) 61 | state = joblib.load( 62 | hf_hub_download( 63 | repo_id=repo_id, filename="jax_model/state.joblib", cache_dir=save_dir 64 | ) 65 | ) 66 | 67 | return params, state 68 | -------------------------------------------------------------------------------- /docs/isoformer.md: -------------------------------------------------------------------------------- 1 | # Isoformer 2 | 3 | Isoformer is a model able to accurately predict differential transcript expression, outperforming existing methods and leveraging multiple modalities. 4 | Our framework achieves efficient knowledge transfer from three pre-trained encoders: Enformer for the DNA modality, Nucleotide Transformer v2 for 5 | the RNA modality and ESM2 for the protein modality. 6 | 7 | * 📜 **[Read the Paper (NeurIPS 2024)](https://papers.nips.cc/paper_files/paper/2024/file/8f6b3692297e49e5d5c91ba00281379c-Paper-Conference.pdf)** 8 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/isoformer)** 9 | * 🚀 **[Isoformer Inference Notebook (HF)](../notebooks/isoformer/inference.ipynb)** 10 | 11 | Isoformer 12 | 13 | ## Training data 14 | 15 | Isoformer is trained on RNA transcript expression data obtained from the GTEx portal, 16 | namely transcript TPM measurements across 30 tissues which come from more than 5000 individuals. 17 | In total, the dataset is made of ∼170k unique transcripts, of which 90k are protein-coding and correspond to ∼20k unique genes. 18 | 19 | ## How to use 🚀 20 | 21 | We make Isoformer available on HuggingFace and provide an example of how to use it at `./notebooks/isoformer/inference.ipynb`. 22 | 23 | ## Citing our work 📚 24 | 25 | You can cite our model at: 26 | 27 | ```bibtex 28 | @inproceedings{NEURIPS2024_8f6b3692, 29 | author = {Garau-Luis, Juan Jose and Bordes, Patrick and Gonzalez, Liam and Roller, Masa and de Almeida, Bernardo P. and Hexemer, Lorenz and Blum, Christopher and Laurent, Stefan and Grzegorzewski, Jan and Lang, Maren and Pierrot, Thomas and Richard, Guillaume}, 30 | booktitle = {Advances in Neural Information Processing Systems}, 31 | editor = {A. Globerson and L. Mackey and D. Belgrave and A. Fan and U. Paquet and J. Tomczak and C. Zhang}, 32 | pages = {78431--78450}, 33 | publisher = {Curran Associates, Inc.}, 34 | title = {Multi-modal Transfer Learning between Biological Foundation Models}, 35 | url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/8f6b3692297e49e5d5c91ba00281379c-Paper-Conference.pdf}, 36 | volume = {37}, 37 | year = {2024} 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /nucleotide_transformer/sCellTransformer/params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 InstaDeep Ltd 2 | # 3 | # Licensed under the Creative Commons BY-NC-SA 4.0 License (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # https://creativecommons.org/licenses/by-nc-sa/4.0/ 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | import os 17 | from typing import Any, Dict, Tuple 18 | 19 | import haiku as hk 20 | import joblib 21 | from huggingface_hub import hf_hub_download 22 | 23 | from nucleotide_transformer.sCellTransformer.model import sCTConfig 24 | 25 | ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" 26 | DEFAULT_CACHE_DIR = "~/.cache" 27 | 28 | 29 | def _get_dir(model_name: str) -> str: 30 | """ 31 | Get directory to save files on user machine. 32 | """ 33 | return os.path.expanduser( 34 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), model_name) 35 | ) 36 | 37 | 38 | def download_ckpt() -> Tuple[hk.Params, Any]: 39 | """ 40 | Download the sCellTransformer checkpoint from the Hugging Face Hub. 41 | 42 | The parameters and configuration are fetched from the 43 | `InstaDeepAI/sCellTransformer` repository and cached locally. 44 | 45 | Returns: 46 | Model parameters. 47 | Model configuration. 48 | 49 | 50 | """ 51 | 52 | save_dir = os.path.join(_get_dir("sCellTransformer"), "sCellTransformer") 53 | 54 | repo_id = "InstaDeepAI/sCellTransformer" 55 | 56 | # Download parameters 57 | print("Downloading model's weights...") 58 | params = joblib.load( 59 | hf_hub_download( 60 | repo_id=repo_id, 61 | filename="jax_params/params.joblib", 62 | cache_dir=save_dir, 63 | ) 64 | ) 65 | 66 | config_path = hf_hub_download( 67 | repo_id=repo_id, 68 | filename="jax_params/config.json", 69 | ) 70 | with open(config_path, "r") as f: 71 | config_dict = json.load(f) 72 | config = sCTConfig(**config_dict) 73 | 74 | return params, config 75 | -------------------------------------------------------------------------------- /docs/mojo.md: -------------------------------------------------------------------------------- 1 | # MOJO 2 | 3 | MOJO (MultiOmics JOint representation learning) is a multimodal model designed to learn embeddings of multi-omics data. It integrates bulk RNA-seq and DNA methylation data to generate powerful joint representations. These representations are tailored for improving predictive performance on downstream tasks like cancer-type classification and survival analysis. 4 | 5 | * 📜 **[Read the Paper (ICML Workshop on Generative AI and Biology 2025)](https://www.biorxiv.org/content/10.1101/2025.06.25.661237v1)** 6 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/MOJO)** 7 | 8 | ## Training data 9 | 10 | MOJO is trained on multi-modal data from The Cancer Genome Atlas (TCGA) dataset. The training data consists of paired bulk RNA-seq and DNA methylation profiles, enabling the model to learn the complex interplay between transcription and epigenetic regulation. 11 | The model takes as input sequences of 17,116 gene expression and DNA methylation values. The gene IDs that must be used, and the order in which they should appear in the sequence, are provided in `../notebooks/data/mojo_gene_names.txt`. 12 | 13 | ## Training procedure 14 | 15 | MOJO uses a bimodal masked language modeling objective. It is trained to simultaneously predict masked values in both the RNA-seq and DNA methylation data modalities.
This process forces the model to learn the intricate cross-modal relationships between gene expression and epigenetic regulation, leading to robust, integrated representations. 16 | 17 | ## How to use 🚀 18 | 19 | We make MOJO available in Jax in this repository and in PyTorch on HuggingFace. Examples on how to use it at: 20 | - Jax: `../notebooks/mojo/inference_mojo_jax_example.ipynb`. 21 | - PyTorch: `../notebooks/mojo/inference_mojo_pytorch_example.ipynb`. 22 | 23 | ## Citing our work 📚 24 | 25 | You can cite our model at: 26 | 27 | ```bibtex 28 | @article {G{\'e}lard2025.06.25.661237, 29 | author = {G{\'e}lard, Maxence and Benkirane, Hakim and Pierrot, Thomas and Richard, Guillaume and Courn{\`e}de, Paul-Henry}, 30 | title = {Bimodal masked language modeling for bulk RNA-seq and DNA methylation representation learning}, 31 | elocation-id = {2025.06.25.661237}, 32 | year = {2025}, 33 | doi = {10.1101/2025.06.25.661237}, 34 | publisher = {Cold Spring Harbor Laboratory}, 35 | URL = {https://www.biorxiv.org/content/early/2025/06/27/2025.06.25.661237}, 36 | journal = {bioRxiv} 37 | } 38 | ``` 39 | -------------------------------------------------------------------------------- /nucleotide_transformer/borzoi/pretrained.py: -------------------------------------------------------------------------------- 1 | """Implementation of utilities to load a pretrained Borzoi model in Trix.""" 2 | 3 | from typing import Callable, Tuple 4 | 5 | import haiku as hk 6 | import jax.numpy as jnp 7 | 8 | from nucleotide_transformer.borzoi.model import ( 9 | BorzoiConfig, 10 | build_borzoi_fn_with_head_fn, 11 | ) 12 | from nucleotide_transformer.enformer.features import FEATURES 13 | from nucleotide_transformer.enformer.heads import UNetHead 14 | from nucleotide_transformer.enformer.params import download_ckpt 15 | from nucleotide_transformer.enformer.tokenizer import NucleotidesKmersTokenizer 16 | 17 | 18 | def get_pretrained_segment_borzoi_model() -> ( 19 | Tuple[hk.Params, hk.State, Callable, NucleotidesKmersTokenizer, BorzoiConfig] 20 | ): 21 | """ 22 | Loads the pretrained SegmentBorzoi model. 
23 | 24 | Returns: 25 | hk.Params: Model parameters 26 | hk.State: Model state 27 | Callable: Haiku forward function 28 | NucleotidesKmersTokenizer: Tokenizer 29 | BorzoiConfig: Configuration of the Borzoi model 30 | 31 | Example: 32 | >>> import jax 33 | >>> import haiku as hk 34 | >>> parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_borzoi_model() 35 | >>> apply_fn = hk.transform_with_state(forward_fn).apply 36 | >>> random_key = jax.random.PRNGKey(seed=0) 37 | >>> sequences = ["A" * 524_288] 38 | >>> tokens = jax.numpy.asarray([b[1] for b in tokenizer.batch_tokenize(sequences)]) 39 | >>> outs, _ = apply_fn(parameters, state, random_key, tokens) 40 | """ 41 | config = BorzoiConfig() 42 | tokenizer = NucleotidesKmersTokenizer( 43 | k_mers=1, 44 | prepend_bos_token=False, 45 | prepend_cls_token=False, 46 | append_eos_token=False, 47 | tokens_to_ids=None, 48 | ) 49 | 50 | def head_fn() -> hk.Module: 51 | return UNetHead( 52 | features=FEATURES, 53 | embed_dimension=config.embed_dim, 54 | nucl_per_token=config.dim_divisible_by, 55 | remove_cls_token=False, 56 | ) 57 | 58 | forward_fn = build_borzoi_fn_with_head_fn( 59 | config=config, 60 | head_fn=head_fn, 61 | embedding_name="embedding", 62 | name="Borzoi", 63 | compute_dtype=jnp.float32, 64 | ) 65 | 66 | parameters, state = download_ckpt("segment_borzoi") 67 | 68 | return parameters, state, forward_fn, tokenizer, config 69 | -------------------------------------------------------------------------------- /nucleotide_transformer/enformer/pretrained.py: -------------------------------------------------------------------------------- 1 | """Implementation of utilities to load a pretrained Enformer model in Trix.""" 2 | 3 | from typing import Callable, Tuple 4 | 5 | import haiku as hk 6 | import jax.numpy as jnp 7 | 8 | from nucleotide_transformer.enformer.features import FEATURES 9 | from nucleotide_transformer.enformer.heads import UNetHead 10 | from nucleotide_transformer.enformer.model import ( 11 | EnformerConfig, 12 | build_enformer_with_head_fn, 13 | ) 14 | from nucleotide_transformer.enformer.params import download_ckpt 15 | from nucleotide_transformer.enformer.tokenizer import NucleotidesKmersTokenizer 16 | 17 | 18 | def get_pretrained_segment_enformer_model() -> ( 19 | Tuple[hk.Params, hk.State, Callable, NucleotidesKmersTokenizer, EnformerConfig] 20 | ): 21 | """ 22 | Loads the pretrained SegmentEnformer model. 
23 | 24 | Returns: 25 | hk.Params: Model parameters 26 | hk.State: Model state 27 | Callable: Haiku forward function 28 | NucleotidesKmersTokenizer: Tokenizer 29 | EnformerConfig: Configuration of the Enformer model 30 | 31 | Example: 32 | >>> import jax 33 | >>> import haiku as hk 34 | >>> parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_enformer_model() 35 | >>> apply_fn = hk.transform_with_state(forward_fn).apply 36 | >>> random_key = jax.random.PRNGKey(seed=0) 37 | >>> sequences = ["A" * 196608] 38 | >>> tokens = jax.numpy.asarray([b[1] for b in tokenizer.batch_tokenize(sequences)]) 39 | >>> outs, _ = apply_fn(parameters, state, random_key, tokens) 40 | """ 41 | config = EnformerConfig() 42 | tokenizer = NucleotidesKmersTokenizer( 43 | k_mers=1, 44 | prepend_bos_token=False, 45 | prepend_cls_token=False, 46 | append_eos_token=False, 47 | tokens_to_ids=None, 48 | ) 49 | 50 | def head_fn() -> hk.Module: 51 | return UNetHead( 52 | features=FEATURES, 53 | embed_dimension=config.embed_dim, 54 | nucl_per_token=config.dim_divisible_by, 55 | remove_cls_token=False, 56 | ) 57 | 58 | forward_fn = build_enformer_with_head_fn( 59 | config=config, 60 | head_fn=head_fn, 61 | embedding_name="embedding_transformer_tower", 62 | name="Enformer", 63 | compute_dtype=jnp.float32, 64 | ) 65 | 66 | parameters, state = download_ckpt("segment_enformer") 67 | 68 | return parameters, state, forward_fn, tokenizer, config 69 | -------------------------------------------------------------------------------- /docs/bulk_rna_bert.md: -------------------------------------------------------------------------------- 1 | # BulkRNABert 2 | 3 | BulkRNABert is a transformer-based, encoder-only foundation model designed for bulk RNA-seq data. It learns biologically meaningful representations from large-scale transcriptomic profiles. Once pre-trained, BulkRNABert can be fine-tuned for various cancer-related downstream tasks, such as cancer type classification or survival analysis, by using its learned embeddings. 4 | 5 | * 📜 **[Read the Paper (Machine Learning for Health 2024)](https://proceedings.mlr.press/v259/gelard25a.html)** 6 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/BulkRNABert)** 7 | 8 | ## Training data 9 | 10 | The model was pre-trained on a large dataset of bulk RNA-seq profiles from The Cancer Genome Atlas (TCGA) dataset. 11 | The model takes as input sequences of 19,062 gene expression values. The gene IDs that must be used, and the order in which they should appear in the sequence, are provided in `../notebooks/data/bulkrnabert_gene_ids.txt`. 12 | 13 | 14 | 15 | ## Training procedure 16 | 17 | Following the original BERT framework, BulkRNABert uses a self-supervised, masked language modeling objective. During pre-training, gene expression values are randomly masked, and the model is tasked with reconstructing these values from their surrounding genomic context. This process enables the model to learn rich, contextual representations of transcriptomic profiles. 18 | 19 | ## How to use 🚀 20 | 21 | We make BulkRNABert available in Jax in this repository and in PyTorch on HuggingFace. Examples of how to use it are available in the notebooks below; a minimal JAX loading sketch is also shown after the list: 22 | - Jax: `../notebooks/bulk_rna_bert/inference_bulkrnabert_jax_example.ipynb`. 23 | - PyTorch: `../notebooks/bulk_rna_bert/inference_bulkrnabert_pytorch_example.ipynb`.
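A minimal JAX loading sketch, based on the `get_pretrained_bulkrnabert_model` helper defined in `nucleotide_transformer/bulk_rna_bert/pretrained.py`, is shown below. It is only a sketch: it assumes the returned forward function is wrapped with `hk.transform` like the other JAX models in this repository, and it feeds a dummy all-zeros token array of shape `(1, config.n_genes)` instead of properly tokenized RNA-seq profiles (see the JAX notebook for real tokenization with the returned `BinnedOmicTokenizer`).

```python
import haiku as hk
import jax
import jax.numpy as jnp

from nucleotide_transformer.bulk_rna_bert.pretrained import (
    get_pretrained_bulkrnabert_model,
)

# Download pre-trained weights, forward function, tokenizer and config.
parameters, forward_fn, tokenizer, config = get_pretrained_bulkrnabert_model()

# Assumption: the forward function follows the usual Haiku transform pattern.
forward_fn = hk.transform(forward_fn)

# Dummy input: one sample of config.n_genes binned expression tokens.
# Real inputs should be built with the returned tokenizer (see the notebook).
tokens = jnp.zeros((1, config.n_genes), dtype=jnp.int32)

random_key = jax.random.PRNGKey(0)
outs = forward_fn.apply(parameters, random_key, tokens)

# Inspect the output dictionary (keys and shapes depend on the configuration).
print({k: v.shape for k, v in outs.items()})
```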
24 | 25 | ## Citing our work 📚 26 | 27 | You can cite our model at: 28 | 29 | ```bibtex 30 | @InProceedings{pmlr-v259-gelard25a, 31 | title = {BulkRNABert: Cancer prognosis from bulk RNA-seq based language models}, 32 | author = {G{\'{e}}lard, Maxence and Richard, Guillaume and Pierrot, Thomas and Courn{\`{e}}de, Paul-Henry}, 33 | booktitle = {Proceedings of the 4th Machine Learning for Health Symposium}, 34 | pages = {384--400}, 35 | year = {2025}, 36 | editor = {Hegselmann, Stefan and Zhou, Helen and Healey, Elizabeth and Chang, Trenton and Ellington, Caleb and Mhasawade, Vishwali and Tonekaboni, Sana and Argaw, Peniel and Zhang, Haoran}, 37 | volume = {259}, 38 | series = {Proceedings of Machine Learning Research}, 39 | month = {15--16 Dec}, 40 | publisher = {PMLR}, 41 | url = {https://proceedings.mlr.press/v259/gelard25a.html}, 42 | } 43 | ``` 44 | -------------------------------------------------------------------------------- /docs/sct.md: -------------------------------------------------------------------------------- 1 | # sCT 2 | 3 | sCT (single-Cell Transformer) is our foundational transformer model for single-cell and spatial transcriptomics data. sCT aims to learn rich representations from complex, high-dimensional single-cell datasets to improve various downstream analytical tasks. 4 | 5 | sCT processes raw gene expression profiles across multiple cells to predict discretized gene expression levels for unseen cells without retraining. 6 | The model can handle up to 20,000 protein-coding genes and a bag of 50 cells in the same sample. 7 | This capacity (around one million gene expression tokens) allows it to learn cross-cell relationships and 8 | capture long-range dependencies in gene expression data, and to mitigate the sparsity typical in single-cell datasets. 9 | 10 | sCT is trained on a large dataset of single-cell RNA-seq and finetuned on spatial transcriptomics data. Evaluation tasks include zero-shot imputation of masked gene expression, and zero-shot prediction of cell types. 11 | 12 | * 📜 **[Read the Paper (OpenReview preprint)](https://openreview.net/forum?id=VdX9tL3VXH)** 13 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/sCellTransformer)** 14 | 15 | ## Training data 16 | 17 | The model was trained following a two-step procedure: pre-training on single-cell data, then finetuning on spatial transcriptomics data. The single-cell data used for pre-training comes from the Cellxgene Census collection datasets used to train the scGPT models. It consists of around 50 million cells and approximately 60,000 genes. The spatial data comes from both the human breast cell atlas and the human heart atlas. 18 | 19 | ## Training procedure 20 | 21 | As detailed in the paper, the gene expressions are first binned into a pre-defined number of bins. This allows the model to better learn the distribution of the gene expressions through sparsity mitigation, noise reduction, and extreme-values handling. Then, the training objective is to predict the masked gene expressions in a cell, following a BERT-style training scheme. 22 | 23 | ## How to use 🚀 24 | 25 | We make sCT available in Jax in this repository and in PyTorch on HuggingFace. Examples of how to use it are available in the notebooks below; a short checkpoint-download sketch in JAX is also shown after the list: 26 | - Jax: `../notebooks/sct/inference_sCT_jax_example.ipynb` 27 | - PyTorch: `../notebooks/sct/inference_sCT_pytorch_example.ipynb`.
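The checkpoint-download sketch below uses the `download_ckpt` helper from `nucleotide_transformer/sCellTransformer/params.py`, which fetches the weights and configuration from the `InstaDeepAI/sCellTransformer` Hugging Face repository. It only downloads and inspects the checkpoint; building the forward function and running inference are covered in the notebooks above.

```python
from nucleotide_transformer.sCellTransformer.params import download_ckpt

# Download pre-trained weights and the sCT configuration from Hugging Face.
params, config = download_ckpt()

# Inspect what was downloaded: top-level parameter scopes and the configuration.
print(list(params.keys())[:5])
print(config)
```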
28 | 29 | ## Citing our work 📚 30 | 31 | You can cite our model at: 32 | 33 | ```bibtex 34 | @misc{joshi2025a, 35 | title={A long range foundation model for zero-shot predictions in single-cell and 36 | spatial transcriptomics data}, 37 | author={Ameya Joshi and Raphael Boige and Lee Zamparo and Ugo Tanielian and Juan Jose 38 | Garau-Luis and Michail Chatzianastasis and Priyanka Pandey and Janik Sielemann and 39 | Alexander Seifert and Martin Brand and Maren Lang and Karim Beguir and Thomas PIERROT}, 40 | year={2025}, 41 | url={https://openreview.net/forum?id=VdX9tL3VXH} 42 | } 43 | ``` -------------------------------------------------------------------------------- /nucleotide_transformer/bulk_rna_bert/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any 3 | 4 | from pydantic import root_validator 5 | from pydantic.main import BaseModel 6 | 7 | 8 | class BulkRNABertConfig(BaseModel): 9 | n_genes: int 10 | n_expressions_bins: int 11 | embed_dim: int 12 | init_gene_embed_dim: int 13 | project_gene_embedding: bool = False 14 | use_gene_embedding: bool = True 15 | 16 | # architecture 17 | num_attention_heads: int 18 | key_size: int | None = None 19 | ffn_embed_dim: int 20 | num_layers: int 21 | use_memory_efficient_attention: bool = False 22 | 23 | use_gradient_checkpointing: bool = False 24 | 25 | gene2vec_weights_path: str 26 | 27 | embeddings_layers_to_save: tuple[int, ...] = () 28 | attention_layers_to_save: tuple[int, ...] = () 29 | 30 | # RNASeq data processing 31 | rnaseq_tokenizer_bins: list[float] | None = None 32 | use_log_normalization: bool = True 33 | use_max_normalization: bool = True 34 | normalization_factor: float | None = None 35 | 36 | @root_validator 37 | @classmethod 38 | def validate_key_size(cls, values: dict[str, Any]) -> dict[str, Any]: 39 | """ 40 | Checks that the given values are compatible. 41 | """ 42 | key_size = values.get("key_size") 43 | if key_size is None: 44 | embed_dim = values["embed_dim"] 45 | num_attention_heads = values["num_attention_heads"] 46 | if not embed_dim % num_attention_heads == 0: 47 | raise ValueError( 48 | f"When no key size is provided, the embedding dimension should be " 49 | f"divisible by the number of heads, however provided embedding " 50 | f"dimension is {embed_dim} and the number of heads is " 51 | f"{num_attention_heads}." 52 | ) 53 | values["key_size"] = embed_dim // num_attention_heads 54 | return values 55 | 56 | @root_validator 57 | @classmethod 58 | def validate_gene_embedding(cls, values: dict[str, Any]) -> dict[str, Any]: 59 | """ 60 | Checks that the given values are compatible. 61 | """ 62 | use_gene_embedding = values.get("use_gene_embedding") 63 | if use_gene_embedding: 64 | init_gene_embed_dim = values["init_gene_embed_dim"] 65 | embed_dim = values["embed_dim"] 66 | if init_gene_embed_dim != embed_dim: 67 | project_gene_embedding = values["project_gene_embedding"] 68 | if not project_gene_embedding: 69 | logging.warning( 70 | f"Init gene embedding dimension ({init_gene_embed_dim})" 71 | f"different than embedding dimension ({embed_dim})." 
72 | f"Setting `project_gene_embedding` to True" 73 | ) 74 | values["project_gene_embedding"] = True 75 | return values 76 | -------------------------------------------------------------------------------- /nucleotide_transformer/enformer/heads.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, List, Optional 3 | 4 | import haiku as hk 5 | import jax 6 | import jax.numpy as jnp 7 | from haiku import initializers 8 | from typing_extensions import TypeAlias 9 | 10 | from nucleotide_transformer.heads import UNET1DSegmentationHead 11 | 12 | SequenceMask: TypeAlias = jnp.ndarray 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class UNetHead(hk.Module): 19 | """ 20 | Returns a probability between 0 and 1 over a target feature presence 21 | for each nucleotide in the input sequence. Assumes the sequence has been tokenized 22 | with non-overlapping k-mers with k being nucl_per_token. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | features: List[str], 28 | num_classes: int = 2, 29 | embed_dimension: int = 1024, 30 | nucl_per_token: int = 6, 31 | num_layers: int = 2, 32 | remove_cls_token: bool = True, 33 | name: Optional[str] = None, 34 | ): 35 | """ 36 | Args: 37 | features: List of features names. 38 | num_classes: Number of classes. 39 | embed_dimension: Embedding dimension. 40 | nucl_per_token: Number of nucleotides per token. 41 | num_layers: Number of layers. 42 | remove_cls_token: Whether to remove the CLS token. 43 | name: Name of the layer. Defaults to None. 44 | """ 45 | super().__init__(name=name) 46 | self._num_features = len(features) 47 | self._num_classes = num_classes 48 | self.nucl_per_token = nucl_per_token 49 | self.remove_cls_token = remove_cls_token 50 | 51 | w_init = initializers.VarianceScaling(2.0, "fan_in", "uniform") 52 | b_init = initializers.VarianceScaling(2.0, "fan_in", "uniform") 53 | unet = UNET1DSegmentationHead( 54 | num_classes=embed_dimension // 2, 55 | output_channels_list=tuple( 56 | embed_dimension * (2**i) for i in range(num_layers) 57 | ), 58 | ) 59 | fc = hk.Linear( 60 | self.nucl_per_token * self._num_classes * self._num_features, 61 | w_init=w_init, 62 | b_init=b_init, 63 | name="fc", 64 | ) 65 | self._fc = hk.Sequential([unet, jax.nn.swish, fc]) 66 | 67 | def __call__( 68 | self, x: jnp.ndarray, sequence_mask: SequenceMask 69 | ) -> Dict[str, jnp.ndarray]: 70 | """ 71 | Input shape: (batch_size, sequence_length, embed_dim) 72 | Output_shape: (batch_size, sequence_length * nucl_per_token, 73 | num_features, num_classes) 74 | """ 75 | if self.remove_cls_token: 76 | x = x[:, 1:] 77 | logits = self._fc(x) 78 | batch_size, seq_len = x.shape[0], x.shape[1] 79 | logits = jnp.reshape( 80 | logits, 81 | ( 82 | batch_size, 83 | seq_len * self.nucl_per_token, 84 | self._num_features, 85 | self._num_classes, 86 | ), 87 | ) 88 | return {"logits": logits} 89 | -------------------------------------------------------------------------------- /nucleotide_transformer/chatNT/pretrained.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | 3 | from nucleotide_transformer.chatNT.gpt_decoder import GptConfig, RotaryEmbeddingConfig 4 | from nucleotide_transformer.chatNT.model import build_chat_nt_fn 5 | from nucleotide_transformer.chatNT.multi_modal_perceiver_projection import ( 6 | PerceiverResamplerConfig, 7 | ) 8 | from nucleotide_transformer.chatNT.params import download_ckpt 9 | from 
nucleotide_transformer.model import NucleotideTransformerConfig 10 | 11 | 12 | def get_chatNT(): 13 | gpt_config = GptConfig( 14 | vocab_size=32000, 15 | eos_token_id=2, 16 | embed_dim=4096, 17 | ffn_embed_dim=11008, 18 | num_heads=32, 19 | num_kv_heads=32, 20 | num_layers=32, 21 | rope_config=RotaryEmbeddingConfig(max_seq_len=2048, dim=128, theta=10000.0), 22 | add_bias_ffn=False, 23 | ffn_activation_name="silu", 24 | use_glu_in_ffn=True, 25 | add_bias_lm_head=False, 26 | norm_type="RMS_norm", 27 | rms_norm_eps=1e-06, 28 | parallel_attention_ff=False, 29 | use_gradient_checkpointing=False, 30 | add_bias_attn=False, 31 | ) 32 | nt_config = NucleotideTransformerConfig( 33 | alphabet_size=4107, 34 | pad_token_id=1, 35 | mask_token_id=2, 36 | max_positions=2048, 37 | embed_scale=1.0, 38 | emb_layer_norm_before=False, 39 | attention_heads=16, 40 | key_size=64, 41 | embed_dim=1024, 42 | ffn_embed_dim=4096, 43 | num_layers=29, 44 | positional_embedding=None, 45 | lm_head="roberta", 46 | add_bias_kv=False, 47 | add_bias_ffn=False, 48 | use_rotary_embedding=True, 49 | rescaling_factor=None, 50 | ffn_activation_name="swish", 51 | use_glu_in_ffn=True, 52 | mask_before_attention=False, 53 | layer_norm_eps=1e-05, 54 | pre_layer_norm=True, 55 | bias_word_embedding=False, 56 | token_dropout=False, 57 | masking_ratio=0.0, 58 | masking_prob=0.0, 59 | use_gradient_checkpointing=False, 60 | embeddings_layers_to_save=(21,), 61 | attention_maps_to_save=[], 62 | ) 63 | perceiver_resampler_config = PerceiverResamplerConfig( 64 | emb_layer_norm_before=False, 65 | attention_heads=32, 66 | key_size=128, 67 | embed_dim=4096, 68 | ffn_embed_dim=11008, 69 | num_layers=3, 70 | add_bias_kv=False, 71 | add_bias_ffn=True, 72 | ffn_activation_name="gelu-no-approx", 73 | use_glu_in_ffn=False, 74 | resampled_length=64, 75 | use_gradient_checkpointing=False, 76 | ) 77 | 78 | bio_tokenizer = AutoTokenizer.from_pretrained( 79 | "InstaDeepAI/ChatNT", 80 | subfolder="bio_tokenizer", 81 | ) 82 | english_tokenizer = AutoTokenizer.from_pretrained( 83 | "InstaDeepAI/ChatNT", 84 | subfolder="english_tokenizer", 85 | ) 86 | seq_token_id = 32000 87 | 88 | forward_fn = build_chat_nt_fn( 89 | nt_config=nt_config, 90 | gpt_config=gpt_config, 91 | seq_token_id=seq_token_id, 92 | bio_pad_token_id=bio_tokenizer.pad_token_id, 93 | english_pad_token_id=english_tokenizer.pad_token_id, 94 | perceiver_resampler_config=perceiver_resampler_config, 95 | nt_name="dcnuc_v2_500M_multi_species", 96 | gpt_name="llama_decoder", 97 | ) 98 | 99 | params = download_ckpt() 100 | params = { 101 | key.replace("bio_brain_decoder", "chat_nt_decoder").replace( 102 | "bio_brain_encoder", "chat_nt_encoder" 103 | ): value 104 | for key, value in params.items() 105 | } 106 | 107 | return forward_fn, params, english_tokenizer, bio_tokenizer 108 | -------------------------------------------------------------------------------- /nucleotide_transformer/bulk_rna_bert/pretrained.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Callable 4 | 5 | import haiku as hk 6 | import jax.numpy as jnp 7 | import joblib 8 | from huggingface_hub import hf_hub_download 9 | 10 | from nucleotide_transformer.bulk_rna_bert.config import BulkRNABertConfig 11 | from nucleotide_transformer.bulk_rna_bert.model import build_bulk_rna_bert_forward_fn 12 | from nucleotide_transformer.bulk_rna_bert.tokenizer import BinnedOmicTokenizer 13 | 14 | ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" 15 | DEFAULT_CACHE_DIR = 
"~/.cache" 16 | 17 | 18 | def _get_dir(model_name: str) -> str: 19 | """ 20 | Get directory to save files on user machine. 21 | """ 22 | return os.path.expanduser( 23 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), model_name) 24 | ) 25 | 26 | 27 | def download_bulkrnabert_ckpt() -> tuple[hk.Params, BulkRNABertConfig]: 28 | """ 29 | Download BulkRNABert checkpoint from Hugging Face. 30 | 31 | 32 | Returns: 33 | Model parameters. 34 | Model configuration 35 | """ 36 | 37 | save_dir = os.path.join(_get_dir("bulkrnabert"), "bulkrnabert") 38 | 39 | repo_id = "InstaDeepAI/BulkRNABert" 40 | 41 | # Download parameters 42 | print("Downloading model's weights...") 43 | params = joblib.load( 44 | hf_hub_download( 45 | repo_id=repo_id, 46 | filename="jax_params/params.joblib", 47 | cache_dir=save_dir, 48 | ) 49 | ) 50 | 51 | config_path = hf_hub_download( 52 | repo_id=repo_id, 53 | filename="jax_params/config.json", 54 | ) 55 | with open(config_path) as f: 56 | config_dict = json.load(f) 57 | config = BulkRNABertConfig(**config_dict) 58 | 59 | return params, config 60 | 61 | 62 | def get_pretrained_bulkrnabert_model( 63 | compute_dtype: jnp.dtype = jnp.float32, 64 | param_dtype: jnp.dtype = jnp.float32, 65 | output_dtype: jnp.dtype = jnp.float32, 66 | embeddings_layers_to_save: tuple[int, ...] = (), 67 | ) -> tuple[hk.Params, Callable, BinnedOmicTokenizer, BulkRNABertConfig]: 68 | """ 69 | Create a Haiku BulkRNABert model. 70 | model by downloading pre-trained weights and hyperparameters. 71 | 72 | Args: 73 | compute_dtype: the type of the activations. fp16 runs faster and is lighter in 74 | memory. bf16 handles better large int, and is hence more stable ( it avoids 75 | float overflows ). 76 | param_dtype: if compute_dtype is fp16, the model weights will be cast to fp16 77 | during the forward pass anyway. So in inference mode ( not training mode ), 78 | it is better to use params in fp16 if compute_dtype is fp16 too. During 79 | training, it is preferable to keep parameters in float32 for better 80 | numerical stability. 81 | output_dtype: the output type of the model. it determines the float precision 82 | of the gradient when training the model. 83 | embeddings_layers_to_save: Intermediate embeddings to return in the output. 84 | 85 | Returns: 86 | Model parameters. 87 | Haiku function to call the model. 88 | Tokenizer. 89 | Model config (hyperparameters). 
90 | 91 | """ 92 | parameters, config = download_bulkrnabert_ckpt() 93 | tokenizer = BinnedOmicTokenizer( 94 | n_expressions_bins=config.n_expressions_bins, 95 | use_max_normalization=config.use_max_normalization, 96 | normalization_factor=config.normalization_factor, # type: ignore 97 | prepend_cls_token=False, 98 | ) 99 | 100 | config.embeddings_layers_to_save = embeddings_layers_to_save 101 | 102 | forward_fn = build_bulk_rna_bert_forward_fn( 103 | model_config=config, 104 | compute_dtype=compute_dtype, 105 | param_dtype=param_dtype, 106 | output_dtype=output_dtype, 107 | model_name="bulk_bert", 108 | ) 109 | 110 | return parameters, forward_fn, tokenizer, config 111 | -------------------------------------------------------------------------------- /nucleotide_transformer/mojo/pretrained.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Callable 4 | 5 | import haiku as hk 6 | import jax.numpy as jnp 7 | import joblib 8 | from huggingface_hub import hf_hub_download 9 | 10 | from nucleotide_transformer.bulk_rna_bert.tokenizer import BinnedOmicTokenizer 11 | from nucleotide_transformer.mojo.config import MOJOConfig 12 | from nucleotide_transformer.mojo.model import build_mojo_fn 13 | 14 | ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" 15 | DEFAULT_CACHE_DIR = "~/.cache" 16 | 17 | 18 | def _get_dir(model_name: str) -> str: 19 | """ 20 | Get directory to save files on user machine. 21 | """ 22 | return os.path.expanduser( 23 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), model_name) 24 | ) 25 | 26 | 27 | def download_mojo_ckpt() -> tuple[hk.Params, MOJOConfig]: 28 | """ 29 | Download MOJO checkpoint from Hugging Face. 30 | 31 | 32 | Returns: 33 | Model parameters. 34 | Model configuration 35 | """ 36 | 37 | save_dir = os.path.join(_get_dir("mojo"), "mojo") 38 | 39 | repo_id = "InstaDeepAI/MOJO" 40 | 41 | # Download parameters 42 | print("Downloading model's weights...") 43 | params = joblib.load( 44 | hf_hub_download( 45 | repo_id=repo_id, 46 | filename="jax_params/params.joblib", 47 | cache_dir=save_dir, 48 | ) 49 | ) 50 | 51 | config_path = hf_hub_download( 52 | repo_id=repo_id, 53 | filename="jax_params/config.json", 54 | ) 55 | with open(config_path) as f: 56 | config_dict = json.load(f) 57 | config = MOJOConfig(**config_dict) 58 | 59 | return params, config 60 | 61 | 62 | def get_mojo_pretrained_model( 63 | compute_dtype: jnp.dtype = jnp.float32, 64 | param_dtype: jnp.dtype = jnp.float32, 65 | output_dtype: jnp.dtype = jnp.float32, 66 | ) -> tuple[hk.Params, Callable, dict[str, BinnedOmicTokenizer], MOJOConfig]: 67 | """ 68 | Create a Haiku MOJO model. 69 | model by downloading pre-trained weights and hyperparameters. 70 | 71 | Args: 72 | compute_dtype: the type of the activations. fp16 runs faster and is lighter in 73 | memory. bf16 handles better large int, and is hence more stable ( it avoids 74 | float overflows ). 75 | param_dtype: if compute_dtype is fp16, the model weights will be cast to fp16 76 | during the forward pass anyway. So in inference mode ( not training mode ), 77 | it is better to use params in fp16 if compute_dtype is fp16 too. During 78 | training, it is preferable to keep parameters in float32 for better 79 | numerical stability. 80 | output_dtype: the output type of the model. it determines the float precision 81 | of the gradient when training the model. 82 | 83 | Returns: 84 | Model parameters. 85 | Haiku function to call the model. 86 | Tokenizer. 
87 | Model config (hyperparameters). 88 | 89 | """ 90 | parameters, config = download_mojo_ckpt() 91 | tokenizers = { 92 | omic: BinnedOmicTokenizer( 93 | n_expressions_bins=config.n_expressions_bins[omic], 94 | min_omic_value=config.min_omic_value[omic], 95 | max_omic_value=config.max_omic_value[omic], 96 | use_max_normalization=config.use_max_normalization[omic], 97 | normalization_factor=config.normalization_factor[omic], 98 | prepend_cls_token=False, 99 | fixed_sequence_length=config.fixed_sequence_length, 100 | unpadded_length=config.sequence_length, 101 | ) 102 | for omic in config.n_expressions_bins.keys() 103 | } 104 | 105 | forward_fn = build_mojo_fn( 106 | model_config=config, 107 | compute_dtype=compute_dtype, 108 | param_dtype=param_dtype, 109 | output_dtype=output_dtype, 110 | model_name="multi_omics_lm", 111 | ) 112 | 113 | return parameters, forward_fn, tokenizers, config 114 | -------------------------------------------------------------------------------- /docs/codon_nt.md: -------------------------------------------------------------------------------- 1 | # Codon-NT 2 | 3 | Codon-NT is a Nucleotide Transformer model variant trained on 3-mers (codons). 4 | It is a 50M-parameter transformer pre-trained on a collection of 850 genomes from a wide range of species, including model and non-model organisms. 5 | This work investigates alternative tokenization strategies for genomic language models and their impact on downstream performance and interpretability. 6 | 7 | * 📜 **[Read the Paper (Bioinformatics 2024)](https://academic.oup.com/bioinformatics/article/40/9/btae529/7745814)** 8 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/nucleotide-transformer-v2-50m-3mer-multi-species)** 9 | 10 | Performance of Codon-NT 11 | 12 | *3mer tokenization achieves better performance than 6-mer on specific protein tasks* 13 | 14 | ## Training data 15 | 16 | The nucleotide-transformer-v2-50m-3mer-multi-species model was pretrained on a total of 850 genomes downloaded from NCBI. Plants and viruses are not included in these genomes, as their regulatory elements differ from those of interest in the paper's tasks. Some heavily studied model organisms were picked to be included in the collection of genomes, which represents a total of 174B nucleotides, i.e. roughly 29B tokens. The data has been released as a HuggingFace dataset. 17 | 18 | ## Training details 19 | 20 | ### Data pre-processing 21 | The DNA sequences are tokenized using the Nucleotide Transformer Tokenizer, which tokenizes sequences as 6-mers when possible, otherwise tokenizing each nucleotide separately, as described in the Tokenization section of the associated repository. This tokenizer has a vocabulary size of 4105. The inputs of the model are then of the form: 22 | 23 | 24 | 25 | The tokenized sequences have a maximum length of 1,000 tokens. 26 | 27 | The masking procedure used is the standard one for BERT-style training: 28 | 29 | 15% of the tokens are masked. 30 | In 80% of the cases, the masked tokens are replaced by [MASK]. 31 | In 10% of the cases, the masked tokens are replaced by a random token (different) from the one they replace. 32 | In the 10% remaining cases, the masked tokens are left as is. 33 | 34 | ### Pre-training 35 | The model was trained with 64 TPUv4s on 300B tokens, with an effective batch size of 1M tokens. The sequence length used was 1000 tokens.
The Adam optimizer [38] was used with a learning rate schedule, and standard values for exponential decay rates and epsilon constants, β1 = 0.9, β2 = 0.999 and ε=1e-8. During a first warmup period, the learning rate was increased linearly between 5e-5 and 1e-4 over 16k steps before decreasing following a square root decay until the end of training. 36 | 37 | ### Model architecture 38 | The model belongs to the second generation of nucleotide transformers, with the changes in architecture consisting the use of rotary positional embeddings instead of learned ones, as well as the introduction of Gated Linear Units. 39 | 40 | ## Available Resources 41 | 42 | Benchmark datasets https://huggingface.co/datasets/InstaDeepAI/true-cds-protein-tasks. 43 | 44 | ## How to use 🚀 45 | 46 | To use the code and pre-trained models in jax: 47 | 48 | ```python 49 | import haiku as hk 50 | import jax 51 | import jax.numpy as jnp 52 | from nucleotide_transformer.pretrained import get_pretrained_model 53 | 54 | # Get pretrained model 55 | parameters, forward_fn, tokenizer, config = get_pretrained_model( 56 | model_name="codon_nt", 57 | embeddings_layers_to_save=(12,), 58 | max_positions=32, 59 | ) 60 | forward_fn = hk.transform(forward_fn) 61 | 62 | # Get data and tokenize it 63 | sequences = ["ATTCCGATTCCGATTCCG", "ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT"] 64 | tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)] 65 | tokens = jnp.asarray(tokens_ids, dtype=jnp.int32) 66 | 67 | # Initialize random key 68 | random_key = jax.random.PRNGKey(0) 69 | 70 | # Infer 71 | outs = forward_fn.apply(parameters, random_key, tokens) 72 | 73 | # Get embeddings at layer 20 74 | print(outs["embeddings_12"].shape) 75 | ``` 76 | You can also run our models and find more example code in `../notebooks/codon_nt/inference.ipynb`. 77 | 78 | The code runs both on GPU and TPU thanks to Jax! 79 | 80 | ## Citing our work 📚 81 | 82 | You can cite our model at: 83 | 84 | ```bibtex 85 | @article{10.1093/bioinformatics/btae529, 86 | author = {Boshar, Sam and Trop, Evan and de Almeida, Bernardo P and Copoiu, Liviu and Pierrot, Thomas}, 87 | title = {Are genomic language models all you need? Exploring genomic language models on protein downstream tasks}, 88 | journal = {Bioinformatics}, 89 | volume = {40}, 90 | number = {9}, 91 | pages = {btae529}, 92 | year = {2024}, 93 | month = {08}, 94 | } 95 | ``` 96 | -------------------------------------------------------------------------------- /nucleotide_transformer/mojo/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import Any, Literal, Optional 4 | 5 | import numpy as np 6 | from pydantic import BaseModel, Field, root_validator 7 | 8 | 9 | class MOJOConfig(BaseModel): 10 | name: Literal["MOJO"] 11 | alphabet_size: dict[str, int] = Field(default_factory=dict) 12 | token_embed_dim: int 13 | conv_init_embed_dim: int 14 | embed_dim: int 15 | num_downsamples: int 16 | filter_list: list = Field(default_factory=list) 17 | init_gene_embed_dim: int = 200 18 | project_gene_embedding: bool = False 19 | use_gene_embedding: bool = True 20 | 21 | sequence_length: int # n_genes 22 | fixed_sequence_length: int | None = None 23 | 24 | use_remat_in_convs: bool = False 25 | use_remat_in_transformer: bool = False 26 | use_skip_connection: bool = True 27 | 28 | embeddings_layers_to_save: tuple[int, ...] = () 29 | attention_layers_to_save: tuple[int, ...] 
= () 30 | 31 | gene2vec_weights_path: str 32 | 33 | num_attention_heads: int = 16 34 | key_size: Optional[int] = None 35 | ffn_embed_dim: int = 512 36 | num_layers: int = 4 37 | layer_norm_eps: float = 1e-5 38 | stem_kernel_shape: int = 15 39 | num_hidden_layers_head: int = 0 40 | 41 | n_expressions_bins: dict[str, int] 42 | min_omic_value: dict[str, float] 43 | max_omic_value: dict[str, float] 44 | use_log_normalization: dict[str, bool] 45 | use_max_normalization: dict[str, bool] 46 | normalization_factor: dict[str, float] 47 | 48 | @root_validator 49 | @classmethod 50 | def validate_key_size(cls, values: dict[str, Any]) -> dict[str, Any]: 51 | """ 52 | Checks that the given values are compatible. 53 | """ 54 | key_size = values.get("key_size") 55 | if key_size is None: 56 | embed_dim = values["embed_dim"] 57 | num_attention_heads = values["num_attention_heads"] 58 | if not embed_dim % num_attention_heads == 0: 59 | raise ValueError( 60 | f"When no key size is provided, the embedding dimension should be " 61 | f"divisible by the number of heads, however provided embedding " 62 | f"dimension is {embed_dim} and the number of heads is " 63 | f"{num_attention_heads}." 64 | ) 65 | values["key_size"] = embed_dim // num_attention_heads 66 | return values 67 | 68 | @root_validator 69 | @classmethod 70 | def create_filter_list(cls, values: dict[str, Any]) -> dict[str, Any]: 71 | """ 72 | Checks that the given values are compatible. 73 | """ 74 | num_downsamples: int = values["num_downsamples"] 75 | filter_list = ( 76 | np.linspace( 77 | values.get("conv_init_embed_dim"), 78 | values.get("embed_dim"), 79 | num_downsamples + 1, 80 | ) 81 | .astype(int) 82 | .tolist() 83 | ) 84 | 85 | values["filter_list"] = filter_list 86 | return values 87 | 88 | @root_validator 89 | @classmethod 90 | def validate_gene_embedding(cls, values: dict[str, Any]) -> dict[str, Any]: 91 | """ 92 | Checks that the given values are compatible. 93 | """ 94 | use_gene_embedding = values.get("use_gene_embedding") 95 | if use_gene_embedding: 96 | init_gene_embed_dim = values["init_gene_embed_dim"] 97 | token_embed_dim = values["token_embed_dim"] 98 | if init_gene_embed_dim != token_embed_dim: 99 | project_gene_embedding = values["project_gene_embedding"] 100 | if not project_gene_embedding: 101 | logging.warning( 102 | f"Init gene embedding dimension ({init_gene_embed_dim})" 103 | f"different than token embedding dimension ({token_embed_dim})." 104 | f"Setting `project_gene_embedding` to True" 105 | ) 106 | values["project_gene_embedding"] = True 107 | return values 108 | 109 | @root_validator 110 | @classmethod 111 | def compute_fixed_sequence_length(cls, values: dict[str, Any]) -> dict[str, Any]: 112 | num_downsamples: int = values["num_downsamples"] 113 | sequence_length: int = values["sequence_length"] 114 | downsample_factor = 2**num_downsamples 115 | fixed_sequence_length = ( 116 | math.ceil(sequence_length / downsample_factor) * downsample_factor 117 | ) 118 | values["fixed_sequence_length"] = fixed_sequence_length 119 | return values 120 | -------------------------------------------------------------------------------- /nucleotide_transformer/mojo/layers.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import haiku as hk 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | 8 | class ConvBlock(hk.Module): 9 | """ 10 | Conv Block. 
11 | """ 12 | 13 | def __init__( 14 | self, 15 | dim: int, 16 | dim_out: Optional[int] = None, 17 | kernel_size: int = 1, 18 | layer_norm_axis: int = -1, 19 | name: Optional[str] = None, 20 | ): 21 | """ 22 | Args: 23 | dim: input dimension. 24 | dim_out: output dimension. 25 | kernel_size: kernel's size. 26 | name: model's name. 27 | """ 28 | super().__init__(name=name) 29 | self._dim = dim 30 | self._dim_out = dim_out 31 | self._kernel_size = kernel_size 32 | self._layer_norm_axis = layer_norm_axis 33 | 34 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 35 | conv = hk.Conv1D( 36 | output_channels=self._dim if self._dim_out is None else self._dim_out, 37 | kernel_shape=self._kernel_size, 38 | padding="SAME", 39 | data_format="NWC", 40 | ) 41 | 42 | layer_norm = hk.LayerNorm( 43 | axis=self._layer_norm_axis, 44 | create_scale=True, 45 | create_offset=True, 46 | eps=1e-5, 47 | param_axis=self._layer_norm_axis, 48 | ) 49 | 50 | x = layer_norm(x) 51 | x = conv(x) 52 | x = jax.nn.gelu(x) 53 | return x 54 | 55 | 56 | class ResidualConvBlock(hk.Module): 57 | """ 58 | Conv Block with Residual connection. 59 | """ 60 | 61 | def __init__( 62 | self, 63 | dim: int, 64 | dim_out: Optional[int] = None, 65 | kernel_size: int = 1, 66 | name: Optional[str] = None, 67 | ): 68 | """ 69 | Args: 70 | dim: input dimension. 71 | dim_out: output dimension. 72 | kernel_size: kernel's size. 73 | name: model's name. 74 | """ 75 | super().__init__(name=name) 76 | self._dim = dim 77 | self._dim_out = dim_out 78 | self._kernel_size = kernel_size 79 | 80 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 81 | conv_block = ConvBlock( 82 | dim=self._dim, 83 | dim_out=self._dim_out, 84 | kernel_size=self._kernel_size, 85 | ) 86 | return x + conv_block(x) 87 | 88 | 89 | class ResidualDeConvBlock(hk.Module): 90 | """ 91 | Conv Block with Residual connection. 92 | """ 93 | 94 | def __init__( 95 | self, 96 | dim: int, 97 | dim_out: Optional[int] = None, 98 | kernel_size: int = 1, 99 | stride: int = 1, 100 | name: Optional[str] = None, 101 | ): 102 | """ 103 | Args: 104 | dim: input dimension. 105 | dim_out: output dimension. 106 | kernel_size: kernel's size. 107 | stride: kernel's stride. 108 | name: model's name. 109 | """ 110 | super().__init__(name=name) 111 | self._dim = dim 112 | self._dim_out = dim_out 113 | self._kernel_size = kernel_size 114 | self._stride = stride 115 | 116 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 117 | conv_block = DeConvBlock( 118 | dim=self._dim, 119 | dim_out=self._dim_out, 120 | kernel_size=self._kernel_size, 121 | stride=self._stride, 122 | ) 123 | return x + conv_block(x) 124 | 125 | 126 | class DeConvBlock(hk.Module): 127 | """ 128 | Conv Block. 129 | """ 130 | 131 | def __init__( 132 | self, 133 | dim: int, 134 | dim_out: Optional[int] = None, 135 | kernel_size: int = 1, 136 | stride: int = 1, 137 | layer_norm_axis: int = -1, 138 | name: Optional[str] = None, 139 | ): 140 | """ 141 | Args: 142 | dim: input dimension. 143 | dim_out: output dimension. 144 | kernel_size: kernel's size. 145 | stride: kernel's stride. 146 | name: model's name. 
147 | """ 148 | super().__init__(name=name) 149 | self._dim = dim 150 | self._dim_out = dim_out 151 | self._kernel_size = kernel_size 152 | self._stride = stride 153 | self._layer_norm_axis = layer_norm_axis 154 | 155 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 156 | conv = hk.Conv1DTranspose( 157 | output_channels=self._dim if self._dim_out is None else self._dim_out, 158 | kernel_shape=self._kernel_size, 159 | padding="SAME", 160 | data_format="NWC", 161 | stride=self._stride, 162 | ) 163 | 164 | layer_norm = hk.LayerNorm( 165 | axis=self._layer_norm_axis, 166 | create_scale=True, 167 | create_offset=True, 168 | eps=1e-5, 169 | param_axis=self._layer_norm_axis, 170 | ) 171 | 172 | x = layer_norm(x) 173 | x = conv(x) 174 | x = jax.nn.gelu(x) 175 | return x 176 | -------------------------------------------------------------------------------- /notebooks/mojo/inference_mojo_pytorch_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Inference with MOJO - PyTorch version from HuggingFace" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "\n", 15 | "[![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/nucleotide-transformer/blob/main/notebooks/mojo/inference_mojo_pytorch_example.ipynb)" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "## Installation and imports" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "outputs": [], 29 | "source": [ 30 | "!pip install pandas\n", 31 | "!pip install transformers\n", 32 | "!pip install torch" 33 | ], 34 | "metadata": { 35 | "collapsed": false 36 | } 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "ExecuteTime": { 43 | "end_time": "2025-06-06T07:50:47.544612Z", 44 | "start_time": "2025-06-06T07:50:47.538478Z" 45 | } 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "try:\n", 50 | " import nucleotide_transformer\n", 51 | "except:\n", 52 | " !pip install git+https://github.com/instadeepai/nucleotide-transformer@main | tail -n 1\n", 53 | " import nucleotide_transformer" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "ExecuteTime": { 61 | "end_time": "2025-06-06T07:50:52.872346Z", 62 | "start_time": "2025-06-06T07:50:48.163017Z" 63 | } 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "import numpy as np\n", 68 | "import pandas as pd\n", 69 | "from transformers import AutoModel, AutoTokenizer\n", 70 | "from huggingface_hub import hf_hub_download" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "# Load model" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": { 84 | "ExecuteTime": { 85 | "end_time": "2025-06-06T07:50:57.729403Z", 86 | "start_time": "2025-06-06T07:50:54.022075Z" 87 | } 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "tokenizer = AutoTokenizer.from_pretrained(\"InstaDeepAI/MOJO\", trust_remote_code=True)\n", 92 | "model = AutoModel.from_pretrained(\n", 93 | " \"InstaDeepAI/MOJO\",\n", 94 | " trust_remote_code=True,\n", 95 | ")" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "## Download, load and preprocess the data" 103 | ] 104 | }, 105 | { 106 | 
"cell_type": "code", 107 | "execution_count": null, 108 | "metadata": { 109 | "ExecuteTime": { 110 | "end_time": "2025-06-06T07:56:21.091995Z", 111 | "start_time": "2025-06-06T07:55:07.965186Z" 112 | } 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "n_examples = 4\n", 117 | "omic_dict = {}\n", 118 | "\n", 119 | "for omic in [\"rnaseq\", \"methylation\"]:\n", 120 | " csv_path = hf_hub_download(\n", 121 | " repo_id=\"InstaDeepAI/MOJO\",\n", 122 | " filename=f\"data/tcga_{omic}_sample.csv\",\n", 123 | " repo_type=\"model\",\n", 124 | " )\n", 125 | " omic_array = pd.read_csv(csv_path).drop([\"identifier\", \"cohort\"], axis=1).to_numpy()[:n_examples, :]\n", 126 | " if omic == \"rnaseq\":\n", 127 | " omic_array = np.log10(1 + omic_array)\n", 128 | " assert omic_array.shape[1] == model.config.sequence_length\n", 129 | " omic_dict[omic] = omic_array" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "outputs": [], 136 | "source": [ 137 | "omic_ids = {\n", 138 | " omic: tokens[\"input_ids\"]\n", 139 | " for omic, tokens in tokenizer.batch_encode_plus(omic_dict, pad_to_fixed_length=True, return_tensors=\"pt\").items()\n", 140 | "}" 141 | ], 142 | "metadata": { 143 | "collapsed": false 144 | } 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "source": [ 149 | "# Inference" 150 | ], 151 | "metadata": { 152 | "collapsed": false 153 | } 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "outputs": [], 159 | "source": [ 160 | "omic_mean_embeddings = model(omic_ids)[\"after_transformer_embedding\"].mean(axis=1) # embeddings can be used for downstream tasks." 161 | ], 162 | "metadata": { 163 | "collapsed": false, 164 | "pycharm": { 165 | "is_executing": true 166 | } 167 | } 168 | } 169 | ], 170 | "metadata": { 171 | "kernelspec": { 172 | "display_name": "genomics-research-env", 173 | "language": "python", 174 | "name": "python3" 175 | }, 176 | "language_info": { 177 | "codemirror_mode": { 178 | "name": "ipython", 179 | "version": 3 180 | }, 181 | "file_extension": ".py", 182 | "mimetype": "text/x-python", 183 | "name": "python", 184 | "nbconvert_exporter": "python", 185 | "pygments_lexer": "ipython3", 186 | "version": "3.11.10" 187 | } 188 | }, 189 | "nbformat": 4, 190 | "nbformat_minor": 2 191 | } 192 | -------------------------------------------------------------------------------- /docs/agro_nucleotide_transformer.md: -------------------------------------------------------------------------------- 1 | # AgroNT 2 | 3 | Agronomic Nucleotide Transformer (AgroNT) is a novel foundational large language model trained on reference genomes from 48 plant species, with a predominant focus on crop species. AgroNT demonstrates state-of-the-art performance across several prediction tasks ranging from regulatory features, RNA processing, and gene expression in plants. 4 | 5 | * 📜 **[Read the Paper (Communications Biology 2024)](https://www.nature.com/articles/s42003-024-06465-2)** 6 | * 🤗 **[Hugging Face Collection](https://huggingface.co/collections/InstaDeepAI/agro-nucleotide-transformer-65b25c077cd0069ad6f6d344)** 7 | 8 | AgroNT 9 | 10 | *Overview of agronomic nucleotide transformer.* 11 | 12 | AgroNT gene expression performance 13 | 14 | *AgroNT provides gene expression prediction across different plant species. 15 | Gene expression prediction on holdout genes across all tissues are correlated with observed gene expression levels. 
16 | 17 | ## Model architecture 18 | 19 | AgroNT uses the transformer architecture with self-attention and a masked language modeling objective to leverage widely 20 | available genotype data from 48 different plant species to learn general representations of nucleotide sequences. 21 | AgroNT contains 1 billion parameters and has a context window of 1024 tokens. AgroNT uses a non-overlapping 6-mer 22 | tokenizer to convert genomic nucleotide sequences to tokens. As a result, the 1024 tokens correspond to approximately 6144 base pairs. 23 | 24 | ## Pre-training 25 | 26 | ### Data 27 | Our pre-training dataset was built from the reference genomes of (mostly) edible plants contained in the Ensembl Plants database. The dataset consists of approximately 10.5 million genomic sequences across 48 different species. 28 | 29 | ### Processing 30 | All reference genomes for each species were assembled into a single FASTA file. In this FASTA file, all nucleotides other than A, T, C, G were replaced by N. A tokenizer was used to convert strings of letters into sequences of tokens. The tokenizer's alphabet consisted of the 4^6 = 4096 possible 6-mer combinations obtained by combining A, T, C, G, as well as five additional tokens representing standalone A, T, C, G, and N. It also included three special tokens: the pad [PAD], mask [MASK], and class [CLS] tokens. This resulted in a vocabulary of 4104 tokens. To tokenize an input sequence, the tokenizer started with a class token and then converted the sequence from left to right, matching 6-mer tokens when possible, or using the standalone tokens when necessary (for instance, when the letter N was present or when the sequence length was not a multiple of 6). A minimal code sketch of this rule is given after the Hardware section below. 31 | 32 | ### Tokenization example 33 | 34 | ```python 35 | nucleotide sequence: ATCCCGGNNTCGACACN 36 | tokens: <CLS> <ATCCCG> <G> <N> <N> <TCGACA> <C> <N> 37 | ``` 38 | 39 | ### Training 40 | The MLM objective was used to pre-train AgroNT in a self-supervised manner. In a self-supervised setting, annotations (supervision) for each sequence are not needed: we can mask some proportion of the sequence and use the information contained in the unmasked portion to predict the masked locations. This allows us to leverage the vast amount of unlabeled genomic sequencing data available. Specifically, 15% of the tokens in the input sequence are selected for corruption: 80% of these are replaced with a mask token, 10% are randomly replaced by another token from the vocabulary, and the final 10% keep their original token. The tokenized sequence is passed through the model and a cross-entropy loss is computed for the masked tokens. Pre-training was carried out with a sequence length of 1024 tokens and an effective batch size of 1.5M tokens for 315k update steps, resulting in the model training on a total of 472.5B tokens. 41 | 42 | ### Hardware 43 | Model pre-training was carried out using Google TPU v4 accelerators, specifically a TPU v4-1024 slice containing 512 devices. We trained for a total of approximately four days.
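### Tokenization sketch

As referenced in the Processing section, the snippet below is a minimal sketch of the greedy 6-mer matching rule described above. It is illustrative only: `tokenize_6mer` is a hypothetical helper, not the tokenizer class shipped in this repository, and it omits padding, truncation and the full special-token vocabulary.

```python
def tokenize_6mer(sequence: str) -> list[str]:
    """Greedy 6-mer tokenization sketch: start with <CLS>, match 6-mers over
    A/T/C/G from left to right, and fall back to standalone tokens otherwise."""
    tokens = ["<CLS>"]
    i = 0
    while i < len(sequence):
        kmer = sequence[i : i + 6]
        # A 6-mer token is only used when it is complete and contains no N.
        if len(kmer) == 6 and "N" not in kmer:
            tokens.append(f"<{kmer}>")
            i += 6
        else:
            tokens.append(f"<{sequence[i]}>")
            i += 1
    return tokens


print(tokenize_6mer("ATCCCGGNNTCGACACN"))
# ['<CLS>', '<ATCCCG>', '<G>', '<N>', '<N>', '<TCGACA>', '<C>', '<N>']
```

This reproduces the tokenization example above; the real tokenizer additionally maps each token string to an integer id from its vocabulary.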
44 | 45 | ## How to use 🚀 46 | 47 | To use the code and pre-trained models in jax: 48 | 49 | ```python 50 | import haiku as hk 51 | import jax 52 | import jax.numpy as jnp 53 | from nucleotide_transformer.pretrained import get_pretrained_model 54 | 55 | # Get pretrained model 56 | parameters, forward_fn, tokenizer, config = get_pretrained_model( 57 | model_name="1B_agro_nt", 58 | embeddings_layers_to_save=(12,), 59 | max_positions=32, 60 | ) 61 | forward_fn = hk.transform(forward_fn) 62 | 63 | # Get data and tokenize it 64 | sequences = ["ATTCCGATTCCGATTCCG", "ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT"] 65 | tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)] 66 | tokens = jnp.asarray(tokens_ids, dtype=jnp.int32) 67 | 68 | # Initialize random key 69 | random_key = jax.random.PRNGKey(0) 70 | 71 | # Infer 72 | outs = forward_fn.apply(parameters, random_key, tokens) 73 | 74 | # Get embeddings at layer 20 75 | print(outs["embeddings_12"].shape) 76 | ``` 77 | You can also run our models and find more example code in `../notebooks/agro_nucleotide_transformer/inference.ipynb`. 78 | 79 | The code runs both on GPU and TPU thanks to Jax! 80 | 81 | ## Citing our work 📚 82 | 83 | You can cite our model at: 84 | 85 | ```bibtex 86 | @article{mendoza2024foundational, 87 | title={A foundational large language model for edible plant genomes}, 88 | author={Mendoza-Revilla, Javier and Trop, Evan and Gonzalez, Liam and Roller, Ma{\v{s}}a and Dalla-Torre, Hugo and de Almeida, Bernardo P and Richard, Guillaume and Caton, Jonathan and Lopez Carranza, Nicolas and Skwark, Marcin and others}, 89 | journal={Communications Biology}, 90 | volume={7}, 91 | number={1}, 92 | pages={835}, 93 | year={2024}, 94 | publisher={Nature Publishing Group UK London} 95 | } 96 | ``` -------------------------------------------------------------------------------- /notebooks/bulk_rna_bert/inference_bulkrnabert_pytorch_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Inference with BulkRNABert - PyTorch version from HuggingFace" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "[![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/nucleotide-transformer/blob/main/notebooks/bulk_rna_bert/inference_bulkrnabert_pytorch_example.ipynb)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Installation and imports" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "outputs": [], 28 | "source": [ 29 | "!pip install pandas\n", 30 | "!pip install transformers\n", 31 | "!pip install torch" 32 | ], 33 | "metadata": { 34 | "collapsed": false 35 | } 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "ExecuteTime": { 42 | "end_time": "2025-06-06T07:50:47.544612Z", 43 | "start_time": "2025-06-06T07:50:47.538478Z" 44 | } 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "try:\n", 49 | " import nucleotide_transformer\n", 50 | "except:\n", 51 | " !pip install git+https://github.com/instadeepai/nucleotide-transformer@main | tail -n 1\n", 52 | " import nucleotide_transformer" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": { 59 | "ExecuteTime": { 60 | "end_time": "2025-06-06T07:50:52.872346Z", 61 | "start_time": 
"2025-06-06T07:50:48.163017Z" 62 | } 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "from huggingface_hub import hf_hub_download\n", 67 | "import numpy as np\n", 68 | "import pandas as pd\n", 69 | "from transformers import AutoConfig, AutoModel, AutoTokenizer" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "# Load model" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": { 83 | "ExecuteTime": { 84 | "end_time": "2025-06-06T07:50:57.729403Z", 85 | "start_time": "2025-06-06T07:50:54.022075Z" 86 | } 87 | }, 88 | "outputs": [], 89 | "source": [ 90 | "# Load model and tokenizer from Hugging Face\n", 91 | "config = AutoConfig.from_pretrained(\n", 92 | " \"InstaDeepAI/BulkRNABert\",\n", 93 | " trust_remote_code=True,\n", 94 | ")\n", 95 | "config.embeddings_layers_to_save = (4,) # last transformer layer\n", 96 | "\n", 97 | "tokenizer = AutoTokenizer.from_pretrained(\"InstaDeepAI/BulkRNABert\", trust_remote_code=True)\n", 98 | "model = AutoModel.from_pretrained(\n", 99 | " \"InstaDeepAI/BulkRNABert\",\n", 100 | " config=config,\n", 101 | " trust_remote_code=True,\n", 102 | ")" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "## Download the data" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "ExecuteTime": { 117 | "end_time": "2025-06-06T07:56:21.091995Z", 118 | "start_time": "2025-06-06T07:55:07.965186Z" 119 | } 120 | }, 121 | "outputs": [], 122 | "source": [ 123 | "# Downloading the bulk RNA-seq file from HuggingFace\n", 124 | "csv_path = hf_hub_download(\n", 125 | " repo_id=\"InstaDeepAI/BulkRNABert\",\n", 126 | " filename=\"data/tcga_sample.csv\",\n", 127 | " repo_type=\"model\",\n", 128 | ")" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "# Load dataset and preprocess" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": { 142 | "ExecuteTime": { 143 | "end_time": "2025-06-06T07:56:41.705799Z", 144 | "start_time": "2025-06-06T07:56:41.675371Z" 145 | } 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "gene_expression_array = pd.read_csv(csv_path).drop([\"identifier\"], axis=1).to_numpy()[:1, :]\n", 150 | "gene_expression_array = np.log10(1 + gene_expression_array)\n", 151 | "assert gene_expression_array.shape[1] == config.n_genes\n", 152 | "\n", 153 | "# Tokenize\n", 154 | "gene_expression_ids = tokenizer.batch_encode_plus(gene_expression_array, return_tensors=\"pt\")[\"input_ids\"]" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "source": [ 160 | "# Inference" 161 | ], 162 | "metadata": { 163 | "collapsed": false 164 | } 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "outputs": [], 170 | "source": [ 171 | "# Compute BulkRNABert's embeddings\n", 172 | "gene_expression_mean_embeddings = model(gene_expression_ids)[\"embeddings_4\"].mean(axis=1) # embeddings can be used for downstream tasks." 
173 | ], 174 | "metadata": { 175 | "collapsed": false, 176 | "pycharm": { 177 | "is_executing": true 178 | } 179 | } 180 | } 181 | ], 182 | "metadata": { 183 | "kernelspec": { 184 | "display_name": "genomics-research-env", 185 | "language": "python", 186 | "name": "python3" 187 | }, 188 | "language_info": { 189 | "codemirror_mode": { 190 | "name": "ipython", 191 | "version": 3 192 | }, 193 | "file_extension": ".py", 194 | "mimetype": "text/x-python", 195 | "name": "python", 196 | "nbconvert_exporter": "python", 197 | "pygments_lexer": "ipython3", 198 | "version": "3.11.10" 199 | } 200 | }, 201 | "nbformat": 4, 202 | "nbformat_minor": 2 203 | } 204 | -------------------------------------------------------------------------------- /docs/chat_nt.md: -------------------------------------------------------------------------------- 1 | ## ChatNT 2 | 3 | ChatNT is the first multimodal conversational agent designed with a deep understanding of biological sequences (DNA, RNA, proteins). 4 | It enables users — even those with no coding background — to interact with biological data through natural language and it generalizes 5 | across multiple biological tasks and modalities. 6 | 7 | * 📜 **[Read the Paper (Nature Machine Intelligence 2025)](https://www.biorxiv.org/content/10.1101/2024.04.30.591835v1)** 8 | * 🤗 **[ChatNT on Hugging Face](https://huggingface.co/InstaDeepAI/ChatNT)** 9 | * 🚀 **[ChatNT Inference Notebook](../notebooks/chat_nt/inference.ipynb)** 10 | 11 | ChatNT and performance on downstream tasks 12 | 13 | ## Architecture and Parameters 14 | ChatNT is built on a three‑module design: a 500M‑parameter [Nucleotide Transformer v2](https://www.nature.com/articles/s41592-024-02523-z) DNA encoder pre‑trained on genomes from 850 species 15 | (handling up to 12 kb per sequence, Dalla‑Torre et al., 2024), an English‑aware Perceiver Resampler that linearly projects and gated‑attention compresses 16 | 2048 DNA‑token embeddings into 64 task‑conditioned vectors (REF), and a frozen 7B‑parameter [Vicuna‑7B](https://lmsys.org/blog/2023-03-30-vicuna/) decoder. 17 | 18 | Users provide a natural‑language prompt containing one or more `` placeholders and the corresponding DNA sequences (tokenized as 6‑mers). 19 | The projection layer inserts 64 resampled DNA embeddings at each placeholder, and the Vicuna decoder generates free‑form English responses in 20 | an autoregressive fashion, using low‑temperature sampling to produce classification labels, multi‑label statements, or numeric values. 21 | 22 | ## Training Data 23 | ChatNT was instruction‑tuned on a unified corpus covering 27 diverse tasks from DNA, RNA and proteins, spanning multiple species, tissues and biological processes. 24 | This amounted to 605 million DNA tokens (≈ 3.6 billion bases) and 273 million English tokens, sampled uniformly over tasks for 2 billion instruction tokens. 25 | Examples of questions and sequences for each task, as well as additional task information, can be found in [Datasets_overview.csv](https://huggingface.co/InstaDeepAI/ChatNT/blob/main/Datasets_overview.csv). 26 | 27 | ## Tokenization 28 | DNA inputs are broken into overlapping 6‑mer tokens and padded or truncated to 2048 tokens (~ 12 kb). English prompts and 29 | outputs use the LLaMA tokenizer, augmented with `` as a special token to mark sequence insertion points. 30 | 31 | ## Limitations and Disclaimer 32 | ChatNT can only handle questions related to the 27 tasks it has been trained on, including the same format of DNA sequences. 
ChatNT is **not** a clinical or diagnostic tool. 33 | It can produce incorrect or “hallucinated” answers, particularly on out‑of‑distribution inputs, and its numeric predictions may suffer digit‑level errors. Confidence 34 | estimates require post‑hoc calibration. Users should always validate critical outputs against experiments or specialized bioinformatics 35 | pipelines. 36 | 37 | ## How to use 🚀 38 | 39 | 🔍 The notebook `../notebooks/chat_nt/inference.ipynb` showcases how to generate text from an english input and a DNA sequence. 40 | 41 | ```python 42 | import haiku as hk 43 | import jax 44 | import jax.numpy as jnp 45 | import numpy as np 46 | 47 | from nucleotide_transformer.chatNT.pretrained import get_chatNT 48 | 49 | # Initialize CPU as default JAX device. This makes the code robust to memory leakage on 50 | # the devices. 51 | jax.config.update("jax_platform_name", "cpu") 52 | 53 | backend = "cpu" 54 | devices = jax.devices(backend) 55 | num_devices = len(devices) 56 | 57 | # Load model 58 | forward_fn, parameters, english_tokenizer, bio_tokenizer = get_chatNT() 59 | forward_fn = hk.transform(forward_fn) 60 | 61 | # Replicate over devices 62 | apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,)) 63 | random_key = jax.random.PRNGKey(seed=0) 64 | keys = jax.device_put_replicated(random_key, devices=devices) 65 | parameters = jax.device_put_replicated(parameters, devices=devices) 66 | 67 | # Define prompt 68 | english_sequence = "A chat between a curious user and an artificial intelligence assistant that can handle bio sequences. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Is there any evidence of an acceptor splice site in this sequence ? ASSISTANT:" 69 | dna_sequences = ["A"*600] 70 | 71 | # Tokenize 72 | english_tokens = english_tokenizer( 73 | [english_sequence], 74 | return_tensors="np", 75 | max_length=english_max_length, 76 | padding="max_length", 77 | truncation=True, 78 | ).input_ids 79 | bio_tokens = bio_tokenizer( 80 | dna_sequences, 81 | return_tensors="np", 82 | padding="max_length", 83 | max_length=bio_tokenized_sequence_length, 84 | truncation=True, 85 | ).input_ids 86 | bio_tokens = np.expand_dims(bio_tokens, axis=0) # add batch size dimension 87 | 88 | # Replicate over devices 89 | english_tokens = jnp.stack([jnp.asarray(english_tokens, dtype=jnp.int32)]*num_devices, axis=0) 90 | bio_tokens = jnp.stack([jnp.asarray(bio_tokens, dtype=jnp.int32)]*num_devices, axis=0) 91 | 92 | # Infer 93 | outs = apply_fn( 94 | parameters, 95 | keys, 96 | multi_omics_tokens_ids=(english_tokens, bio_tokens), 97 | projection_english_tokens_ids=english_tokens, 98 | ) 99 | 100 | # Obtain the logits 101 | logits = outs["logits"] 102 | ``` 103 | 104 | ## Citing our work 📚 105 | 106 | You can cite our model at: 107 | 108 | [ChatNT paper](https://www.nature.com/articles/s42256-025-01047-1) 109 | ```bibtex 110 | @article{dealmeida2024chatnt, 111 | title={A multimodal conversational agent for dna, rna and protein tasks}, 112 | author={de Almeida, Bernardo P and Richard, Guillaume and Dalla-Torre, Hugo and Blum, Christopher and Hexemer, Lorenz and Pandey, Priyanka and Laurent, Stefan and others}, 113 | journal={Nature Machine Intelligence}, 114 | year={2025}, 115 | } 116 | ``` -------------------------------------------------------------------------------- /notebooks/mojo/inference_mojo_jax_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": 
"markdown", 5 | "source": [ 6 | "# Inference with MOJO - Jax version" 7 | ], 8 | "metadata": { 9 | "collapsed": false 10 | } 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "[![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/nucleotide-transformer/blob/main/notebooks/mojo/inference_mojo_jax_example.ipynb)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "SWffCMcBfn37" 23 | }, 24 | "source": [ 25 | "## Installation and imports" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "outputs": [], 32 | "source": [ 33 | "!pip install pandas" 34 | ], 35 | "metadata": { 36 | "collapsed": false 37 | } 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": { 43 | "ExecuteTime": { 44 | "end_time": "2025-06-06T08:05:38.619015Z", 45 | "start_time": "2025-06-06T08:05:38.610404Z" 46 | }, 47 | "id": "alzkIxk9fn38" 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "import os\n", 52 | "\n", 53 | "try:\n", 54 | " import nucleotide_transformer\n", 55 | "except:\n", 56 | " !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1\n", 57 | " import nucleotide_transformer\n", 58 | "\n", 59 | "if \"COLAB_TPU_ADDR\" in os.environ:\n", 60 | " from jax.tools import colab_tpu\n", 61 | "\n", 62 | " colab_tpu.setup_tpu()" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "ExecuteTime": { 70 | "end_time": "2025-06-06T08:05:42.565213Z", 71 | "start_time": "2025-06-06T08:05:39.457648Z" 72 | }, 73 | "colab": { 74 | "base_uri": "https://localhost:8080/" 75 | }, 76 | "id": "zkTU4k4_fn39", 77 | "outputId": "a04ca440-be95-49e1-b683-bf5b70d00777" 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "import haiku as hk\n", 82 | "from huggingface_hub import hf_hub_download\n", 83 | "import jax\n", 84 | "import jax.numpy as jnp\n", 85 | "import numpy as np\n", 86 | "import pandas as pd\n", 87 | "\n", 88 | "from nucleotide_transformer.mojo.pretrained import get_mojo_pretrained_model" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "source": [ 94 | "# Load model\n" 95 | ], 96 | "metadata": { 97 | "collapsed": false 98 | } 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "outputs": [], 104 | "source": [ 105 | "# Get pretrained MOJO model\n", 106 | "parameters, forward_fn, tokenizers, config = get_mojo_pretrained_model()\n", 107 | "forward_fn = hk.transform(forward_fn)" 108 | ], 109 | "metadata": { 110 | "collapsed": false, 111 | "pycharm": { 112 | "is_executing": true 113 | } 114 | } 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "source": [ 119 | "## Download, load and preprocess the data" 120 | ], 121 | "metadata": { 122 | "collapsed": false 123 | } 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "outputs": [], 129 | "source": [ 130 | "n_examples = 4\n", 131 | "omic_dict = {}\n", 132 | "\n", 133 | "for omic in [\"rnaseq\", \"methylation\"]:\n", 134 | " csv_path = hf_hub_download(\n", 135 | " repo_id=\"InstaDeepAI/MOJO\",\n", 136 | " filename=f\"data/tcga_{omic}_sample.csv\",\n", 137 | " repo_type=\"model\",\n", 138 | " )\n", 139 | " omic_array = pd.read_csv(csv_path).drop([\"identifier\", \"cohort\"], axis=1).to_numpy()[:n_examples, :]\n", 140 | " if omic == \"rnaseq\":\n", 141 | " omic_array = np.log10(1 + omic_array)\n", 142 | " assert omic_array.shape[1] == 
config.sequence_length\n", 143 | " omic_dict[omic] = omic_array" 144 | ], 145 | "metadata": { 146 | "collapsed": false, 147 | "pycharm": { 148 | "is_executing": true 149 | } 150 | } 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "outputs": [], 156 | "source": [ 157 | "tokens_ids = {\n", 158 | " omic: jnp.asarray(tokenizers[omic].batch_tokenize(omic_array, pad_to_fixed_length=True), dtype=jnp.int32)\n", 159 | " for omic, omic_array in omic_dict.items()\n", 160 | "}" 161 | ], 162 | "metadata": { 163 | "collapsed": false 164 | } 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "source": [ 169 | "# Inference" 170 | ], 171 | "metadata": { 172 | "collapsed": false 173 | } 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "outputs": [], 179 | "source": [ 180 | "# Inference\n", 181 | "random_key = jax.random.PRNGKey(0)\n", 182 | "outs = forward_fn.apply(parameters, random_key, tokens_ids)\n", 183 | "\n", 184 | "# Get embedding from last transformer layer\n", 185 | "mean_embedding = outs[\"after_transformer_embedding\"].mean(axis=1)" 186 | ], 187 | "metadata": { 188 | "collapsed": false 189 | } 190 | } 191 | ], 192 | "metadata": { 193 | "accelerator": "GPU", 194 | "colab": { 195 | "gpuType": "T4", 196 | "provenance": [] 197 | }, 198 | "kernelspec": { 199 | "display_name": "debug_segment_enformer", 200 | "language": "python", 201 | "name": "python3" 202 | }, 203 | "language_info": { 204 | "codemirror_mode": { 205 | "name": "ipython", 206 | "version": 3 207 | }, 208 | "file_extension": ".py", 209 | "mimetype": "text/x-python", 210 | "name": "python", 211 | "nbconvert_exporter": "python", 212 | "pygments_lexer": "ipython3", 213 | "version": "3.10.16" 214 | } 215 | }, 216 | "nbformat": 4, 217 | "nbformat_minor": 0 218 | } 219 | -------------------------------------------------------------------------------- /nucleotide_transformer/bulk_rna_bert/model.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional 2 | 3 | import haiku as hk 4 | import jax.numpy as jnp 5 | import jmp 6 | 7 | from nucleotide_transformer.bulk_rna_bert.config import BulkRNABertConfig 8 | from nucleotide_transformer.bulk_rna_bert.layers import SimpleLMHead 9 | from nucleotide_transformer.layers import SelfAttentionBlock 10 | from nucleotide_transformer.types import ( 11 | AttentionMask, 12 | Embedding, 13 | Tokens, 14 | TransformerOutput, 15 | ) 16 | 17 | 18 | class BulkRNABert(hk.Module): 19 | """ 20 | Jax implementation of BulkRNABert model. 
21 | """ 22 | 23 | def __init__( 24 | self, 25 | config: BulkRNABertConfig, 26 | name: Optional[str] = None, 27 | ): 28 | super().__init__(name=name) 29 | self._config = config 30 | 31 | self._gene_embedding_layer = hk.Embed( 32 | self._config.n_genes, 33 | self._config.init_gene_embed_dim, 34 | name="gene_embedding", 35 | ) 36 | self._fc_gene_embedding = hk.Linear(self._config.embed_dim) 37 | self._expression_embedding_layer = hk.Embed( 38 | self._config.n_expressions_bins, 39 | self._config.embed_dim, 40 | name="expression_embedding", 41 | ) 42 | self._lm_head = SimpleLMHead( 43 | embed_dim=self._config.embed_dim, 44 | alphabet_size=self._config.n_expressions_bins, 45 | ) 46 | 47 | @hk.transparent 48 | def apply_attention_blocks( 49 | self, 50 | x: Embedding, 51 | outs: dict[str, Embedding], 52 | attention_mask: Optional[AttentionMask] = None, 53 | ) -> tuple[Embedding, dict[str, Embedding]]: 54 | """ 55 | Creates the blocks of attention layers and applies them. 56 | 57 | Args: 58 | x: Sequence embedding of shape (batch,seq_len,embed_dim). 59 | outs: A dictionary to carry through the attention layers which stores the 60 | intermediate sequence embedding and attention maps. 61 | attention_mask: attention mask of shape (batch_size, 1, seq_len, seq_len). 62 | 63 | Returns: 64 | Output sequence embedding. 65 | Dictionary of optional intermediate results (embeddings of the layer and 66 | attention weights). 67 | """ 68 | 69 | layers: list[Callable] = [ 70 | self._self_attention(layer_idx) 71 | for layer_idx in range(self._config.num_layers) 72 | ] 73 | 74 | if self._config.use_gradient_checkpointing: 75 | # the remat-ed function cannot take control flow arguments 76 | layers = [hk.remat(layer) for layer in layers] 77 | 78 | for layer_idx, layer in enumerate(layers): 79 | output = layer( 80 | x=x, 81 | attention_mask=attention_mask, 82 | ) 83 | x = output["embeddings"] 84 | 85 | if (layer_idx + 1) in self._config.embeddings_layers_to_save: 86 | outs[f"embeddings_{(layer_idx + 1)}"] = output["embeddings"] 87 | if (layer_idx + 1) in self._config.attention_layers_to_save: 88 | outs[f"attention_map_layer_{layer_idx + 1}"] = output[ 89 | "attention_weights" 90 | ] 91 | 92 | return x, outs 93 | 94 | @hk.transparent 95 | def _self_attention(self, layer_idx: int) -> SelfAttentionBlock: 96 | return SelfAttentionBlock( # type: ignore 97 | num_heads=self._config.num_attention_heads, 98 | embed_dim=self._config.embed_dim, 99 | key_size=self._config.key_size, 100 | ffn_embed_dim=self._config.ffn_embed_dim, 101 | name=f"self_attention_block_{layer_idx}", 102 | ) 103 | 104 | def __call__( 105 | self, 106 | tokens: Tokens, 107 | attention_mask: AttentionMask = None, 108 | ) -> TransformerOutput: 109 | batch_size, seq_len = tokens.shape 110 | outs: dict[str, jnp.ndarray] = {} 111 | 112 | x = self._expression_embedding_layer(tokens) 113 | if self._config.use_gene_embedding: 114 | gene_embedding = self._gene_embedding_layer( 115 | jnp.arange(self._config.n_genes) 116 | ) 117 | if self._config.project_gene_embedding: 118 | gene_embedding = self._fc_gene_embedding(gene_embedding) 119 | x = x + gene_embedding 120 | 121 | if attention_mask is None: 122 | attention_mask = jnp.ones((batch_size, 1, seq_len, seq_len)) 123 | 124 | x, outs = self.apply_attention_blocks( 125 | x=x, 126 | outs=outs, 127 | attention_mask=attention_mask, 128 | ) 129 | lm_head_outs = self._lm_head(x) 130 | outs["logits"] = lm_head_outs["logits"] 131 | return outs 132 | 133 | 134 | def build_bulk_rna_bert_forward_fn( 135 | model_config: 
BulkRNABertConfig, 136 | compute_dtype: jnp.dtype = jnp.float32, 137 | param_dtype: jnp.dtype = jnp.float32, 138 | output_dtype: jnp.dtype = jnp.float32, 139 | model_name: Optional[str] = None, 140 | ) -> Callable: 141 | assert {compute_dtype, param_dtype, output_dtype}.issubset( 142 | { 143 | jnp.bfloat16, 144 | jnp.float32, 145 | jnp.float16, 146 | } 147 | ), f"Please provide a dtype in {jnp.bfloat16, jnp.float32, jnp.float16}" 148 | 149 | policy = jmp.Policy( 150 | compute_dtype=compute_dtype, 151 | param_dtype=param_dtype, 152 | output_dtype=output_dtype, 153 | ) 154 | hk.mixed_precision.set_policy(BulkRNABert, policy) 155 | 156 | # Remove it in batch norm to avoid instabilities 157 | norm_policy = jmp.Policy( 158 | compute_dtype=jnp.float32, 159 | param_dtype=param_dtype, 160 | output_dtype=compute_dtype, 161 | ) 162 | hk.mixed_precision.set_policy(hk.BatchNorm, norm_policy) 163 | hk.mixed_precision.set_policy(hk.LayerNorm, norm_policy) 164 | hk.mixed_precision.set_policy(hk.RMSNorm, norm_policy) 165 | 166 | def forward_fn( 167 | tokens: jnp.ndarray, 168 | attention_mask: Optional[jnp.ndarray] = None, 169 | ): 170 | """Forward pass""" 171 | model = BulkRNABert( 172 | config=model_config, 173 | name=model_name, 174 | ) 175 | outs = model( 176 | tokens=tokens, 177 | attention_mask=attention_mask, 178 | ) 179 | return outs 180 | 181 | return forward_fn 182 | -------------------------------------------------------------------------------- /nucleotide_transformer/bulk_rna_bert/tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | 5 | 6 | class BinnedOmicTokenizer: 7 | """ 8 | Tokenizer that bins gene expressions or methylation to convert them to tokens. 
9 | """ 10 | 11 | def __init__( 12 | self, 13 | n_expressions_bins: int, 14 | min_omic_value: float = 0.0, 15 | max_omic_value: float = 1.0, 16 | use_max_normalization: bool = True, 17 | normalization_factor: float = 1.0, 18 | prepend_cls_token: bool = False, 19 | fixed_sequence_length: int | None = None, 20 | unpadded_length: int | None = None, 21 | ): 22 | self._n_expressions_bins = n_expressions_bins 23 | self._use_max_normalization = use_max_normalization 24 | self._normalization_factor = normalization_factor 25 | self._prepend_cls_token = prepend_cls_token 26 | 27 | self._gene_expression_bins = np.linspace( 28 | min_omic_value, max_omic_value, self._n_expressions_bins 29 | ) 30 | 31 | self._fixed_sequence_length = fixed_sequence_length 32 | self._unpadded_length = unpadded_length 33 | 34 | standard_tokens = list(map(str, range(len(self._gene_expression_bins)))) 35 | self._pad_token = "" 36 | self._mask_token = "" 37 | self._class_token = "" 38 | self._unk_token = "" 39 | self._eos_token = "" 40 | self._bos_token = "" 41 | self._missing_modality_token = "" 42 | 43 | if prepend_cls_token: 44 | special_tokens = [ 45 | self._class_token, 46 | self._pad_token, 47 | self._mask_token, 48 | ] 49 | else: 50 | special_tokens = [ 51 | self._pad_token, 52 | self._mask_token, 53 | ] 54 | 55 | self._all_tokens = standard_tokens + special_tokens 56 | self._standard_tokens = standard_tokens 57 | self._special_tokens = special_tokens 58 | 59 | self._tokens_to_ids = {tok: i for i, tok in enumerate(self._all_tokens)} 60 | self._ids_to_tokens = {i: tok for tok, i in self._tokens_to_ids.items()} 61 | 62 | @property 63 | def gene_expression_bins(self) -> np.ndarray: 64 | return self._gene_expression_bins 65 | 66 | @property 67 | def pad_token(self) -> str: 68 | return self._pad_token 69 | 70 | @property 71 | def mask_token(self) -> str: 72 | return self._mask_token 73 | 74 | @property 75 | def mask_token_id(self) -> int: 76 | """ 77 | Property that returns id (int representation) of the mask token. 78 | 79 | Returns: 80 | Id (int representation) of the mask token. 
81 | """ 82 | return self.token_to_id(self.mask_token) 83 | 84 | @property 85 | def class_token(self) -> str: 86 | return self._class_token 87 | 88 | @property 89 | def unk_token(self) -> str: 90 | return self.unk_token 91 | 92 | @property 93 | def eos_token(self) -> str: 94 | return self._eos_token 95 | 96 | @property 97 | def bos_token(self) -> str: 98 | return self._bos_token 99 | 100 | @property 101 | def pad_id(self) -> int: 102 | return self.token_to_id(self.pad_token) 103 | 104 | @property 105 | def mask_id(self) -> int: 106 | return self.token_to_id(self.mask_token) 107 | 108 | @property 109 | def class_id(self) -> int: 110 | return self.token_to_id(self.class_token) 111 | 112 | @property 113 | def vocabulary(self) -> List[str]: 114 | return self._all_tokens 115 | 116 | @property 117 | def standard_tokens(self) -> List[str]: 118 | return self._standard_tokens 119 | 120 | @property 121 | def special_tokens(self) -> List[str]: 122 | return self._special_tokens 123 | 124 | def id_to_token(self, token_id: int) -> str: 125 | try: 126 | return self._ids_to_tokens[token_id] 127 | except KeyError: 128 | raise KeyError(f"Token id {token_id} not found in vocabulary") 129 | 130 | def token_to_id(self, token: str) -> int: 131 | try: 132 | return self._tokens_to_ids[token] 133 | except KeyError: 134 | raise KeyError(f"Token {token} not found in vocabulary") 135 | 136 | def tokenize( 137 | self, 138 | gene_expressions: np.ndarray | None, 139 | pad_to_fixed_length: bool = False, 140 | max_length: int | None = None, 141 | ) -> np.ndarray: 142 | """ 143 | Tokenize a gene expression array and return an array of bin ids. 144 | 145 | Args: 146 | gene_expressions: Gene expressions sequence to be tokenized. 147 | pad_to_fixed_length: if True and fixed length is provided as attributed 148 | to the tokenizer, the sequence will be padded. 149 | max_length: allows to pass another max length than the one specified 150 | by self._fixed_sequence_length 151 | Returns: 152 | List of tokens ids. 153 | """ 154 | if gene_expressions is None: 155 | assert self._unpadded_length is not None 156 | tokens_ids = np.array([self.mask_token_id] * self._unpadded_length) 157 | else: 158 | if self._use_max_normalization: 159 | gene_expressions /= self._normalization_factor 160 | tokens_ids = np.digitize(gene_expressions, self._gene_expression_bins) 161 | tokens_ids[gene_expressions == 0.0] = 0 162 | if self._prepend_cls_token: 163 | tokens_ids = np.concatenate([[self.class_id], tokens_ids]) 164 | if pad_to_fixed_length: 165 | if self._fixed_sequence_length is not None: 166 | current_max_length = self._fixed_sequence_length 167 | else: 168 | assert max_length is not None 169 | current_max_length = max_length 170 | padded_tokens_ids = np.ones( 171 | current_max_length, dtype=tokens_ids.dtype 172 | ) * self.token_to_id(self._pad_token) 173 | padded_tokens_ids[: len(tokens_ids)] = tokens_ids 174 | return padded_tokens_ids 175 | return tokens_ids 176 | 177 | def batch_tokenize( 178 | self, 179 | gene_expressions: np.ndarray, 180 | pad_to_fixed_length: bool = False, 181 | max_length: int | None = None, 182 | ) -> np.ndarray: 183 | """ 184 | Tokenizes a batch of gene expressions. 185 | 186 | Args: 187 | gene_expressions: gene expressions sequence to be tokenized. 188 | pad_to_fixed_length: if True and fixed length is provided as attributed 189 | to the tokenizer, the sequence will be padded. 190 | max_length: max length in the batch 191 | 192 | Returns: 193 | Tokenized gene expressions. 
194 | """ 195 | return np.vstack( 196 | [ 197 | self.tokenize( 198 | g, pad_to_fixed_length=pad_to_fixed_length, max_length=max_length 199 | ) 200 | for g in gene_expressions 201 | ] 202 | ) 203 | -------------------------------------------------------------------------------- /notebooks/bulk_rna_bert/inference_bulkrnaert_jax_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# Inference with BulkRNABert - Jax version" 7 | ], 8 | "metadata": { 9 | "collapsed": false 10 | } 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "[![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/nucleotide-transformer/blob/main/notebooks/bulk_rna_bert/inference_bulkrnabert_jax_example.ipynb)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "SWffCMcBfn37" 23 | }, 24 | "source": [ 25 | "## Installation and imports" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 1, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "Requirement already satisfied: pandas in /Users/maxencegelard/miniconda3/envs/nucleotide-transformer-private/lib/python3.10/site-packages (2.3.0)\r\n", 37 | "Requirement already satisfied: numpy>=1.22.4 in /Users/maxencegelard/miniconda3/envs/nucleotide-transformer-private/lib/python3.10/site-packages (from pandas) (2.2.6)\r\n", 38 | "Requirement already satisfied: python-dateutil>=2.8.2 in /Users/maxencegelard/miniconda3/envs/nucleotide-transformer-private/lib/python3.10/site-packages (from pandas) (2.9.0.post0)\r\n", 39 | "Requirement already satisfied: pytz>=2020.1 in /Users/maxencegelard/miniconda3/envs/nucleotide-transformer-private/lib/python3.10/site-packages (from pandas) (2025.2)\r\n", 40 | "Requirement already satisfied: tzdata>=2022.7 in /Users/maxencegelard/miniconda3/envs/nucleotide-transformer-private/lib/python3.10/site-packages (from pandas) (2025.2)\r\n", 41 | "Requirement already satisfied: six>=1.5 in /Users/maxencegelard/miniconda3/envs/nucleotide-transformer-private/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas) (1.17.0)\r\n" 42 | ] 43 | } 44 | ], 45 | "source": [ 46 | "!pip install pandas" 47 | ], 48 | "metadata": { 49 | "collapsed": false 50 | } 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": { 56 | "ExecuteTime": { 57 | "end_time": "2025-06-06T08:05:38.619015Z", 58 | "start_time": "2025-06-06T08:05:38.610404Z" 59 | }, 60 | "id": "alzkIxk9fn38" 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "import os\n", 65 | "\n", 66 | "try:\n", 67 | " import nucleotide_transformer\n", 68 | "except:\n", 69 | " !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1\n", 70 | " import nucleotide_transformer\n", 71 | "\n", 72 | "if \"COLAB_TPU_ADDR\" in os.environ:\n", 73 | " from jax.tools import colab_tpu\n", 74 | "\n", 75 | " colab_tpu.setup_tpu()" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": { 82 | "ExecuteTime": { 83 | "end_time": "2025-06-06T08:05:42.565213Z", 84 | "start_time": "2025-06-06T08:05:39.457648Z" 85 | }, 86 | "colab": { 87 | "base_uri": "https://localhost:8080/" 88 | }, 89 | "id": "zkTU4k4_fn39", 90 | "outputId": "a04ca440-be95-49e1-b683-bf5b70d00777" 91 | }, 92 | "outputs": [], 93 | "source": [ 94 
| "import haiku as hk\n", 95 | "from huggingface_hub import hf_hub_download\n", 96 | "import jax\n", 97 | "import jax.numpy as jnp\n", 98 | "import numpy as np\n", 99 | "import pandas as pd\n", 100 | "\n", 101 | "from nucleotide_transformer.bulk_rna_bert.pretrained import get_pretrained_bulkrnabert_model" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "source": [ 107 | "# Load model" 108 | ], 109 | "metadata": { 110 | "collapsed": false 111 | } 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "outputs": [], 117 | "source": [ 118 | "parameters, forward_fn, tokenizer, config = get_pretrained_bulkrnabert_model(\n", 119 | " embeddings_layers_to_save=(4,),\n", 120 | ")\n", 121 | "forward_fn = hk.transform(forward_fn)" 122 | ], 123 | "metadata": { 124 | "collapsed": false, 125 | "pycharm": { 126 | "is_executing": true 127 | } 128 | } 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "source": [ 133 | "## Download the data" 134 | ], 135 | "metadata": { 136 | "collapsed": false 137 | } 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "outputs": [], 143 | "source": [ 144 | "# Downloading the bulk RNA-seq file from HuggingFace\n", 145 | "csv_path = hf_hub_download(\n", 146 | " repo_id=\"InstaDeepAI/BulkRNABert\",\n", 147 | " filename=\"data/tcga_sample.csv\",\n", 148 | " repo_type=\"model\",\n", 149 | ")" 150 | ], 151 | "metadata": { 152 | "collapsed": false, 153 | "pycharm": { 154 | "is_executing": true 155 | } 156 | } 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "source": [ 161 | "# Load dataset and preprocess" 162 | ], 163 | "metadata": { 164 | "collapsed": false 165 | } 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "outputs": [], 171 | "source": [ 172 | "gene_expression_array = pd.read_csv(csv_path).drop([\"identifier\"], axis=1).to_numpy()[:1, :]\n", 173 | "gene_expression_array = np.log10(1 + gene_expression_array)\n", 174 | "assert gene_expression_array.shape[1] == config.n_genes\n", 175 | "\n", 176 | "# Tokenize\n", 177 | "gene_expression_ids = tokenizer.batch_tokenize(gene_expression_array)\n", 178 | "gene_expression_ids = jnp.asarray(gene_expression_ids, dtype=jnp.int32)" 179 | ], 180 | "metadata": { 181 | "collapsed": false 182 | } 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "source": [ 187 | "# Inference" 188 | ], 189 | "metadata": { 190 | "collapsed": false 191 | } 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "outputs": [], 197 | "source": [ 198 | "# Inference\n", 199 | "random_key = jax.random.PRNGKey(0)\n", 200 | "outs = forward_fn.apply(parameters, random_key, gene_expression_ids)\n", 201 | "\n", 202 | "# Get mean embeddings from layer 4\n", 203 | "gene_expression_mean_embeddings = outs[\"embeddings_4\"].mean(axis=1)" 204 | ], 205 | "metadata": { 206 | "collapsed": false 207 | } 208 | } 209 | ], 210 | "metadata": { 211 | "accelerator": "GPU", 212 | "colab": { 213 | "gpuType": "T4", 214 | "provenance": [] 215 | }, 216 | "kernelspec": { 217 | "display_name": "debug_segment_enformer", 218 | "language": "python", 219 | "name": "python3" 220 | }, 221 | "language_info": { 222 | "codemirror_mode": { 223 | "name": "ipython", 224 | "version": 3 225 | }, 226 | "file_extension": ".py", 227 | "mimetype": "text/x-python", 228 | "name": "python", 229 | "nbconvert_exporter": "python", 230 | "pygments_lexer": "ipython3", 231 | "version": "3.10.16" 232 | } 233 | }, 234 | "nbformat": 4, 235 | "nbformat_minor": 0 236 | } 237 | 
-------------------------------------------------------------------------------- /notebooks/isoformer/inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "8e193673", 6 | "metadata": {}, 7 | "source": [ 8 | "# Inference with Isoformer\n", 9 | "\n", 10 | "This notebook demonstrates how to use the Isoformer model for multi-omics data analysis and gene expression prediction. It shows how to load the model, process DNA, RNA, and protein sequences, and perform inference to predict gene expression levels.\n", 11 | "\n", 12 | "For Google Colab: as model inference requires a significant amount of RAM, please use a v2-8 TPU for testing." 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "f3155bac", 18 | "metadata": {}, 19 | "source": [ 20 | "[![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/nucleotide-transformer/blob/main/notebooks/isoformer/inference.ipynb)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "76982781", 26 | "metadata": {}, 27 | "source": [ 28 | "## Installation and Dependencies" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "d748898b-5598-4eb9-9f76-c83decc3d463", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "! pip install -U huggingface_hub\n", 39 | "! pip install -U datasets\n", 40 | "! pip install transformers \n", 41 | "! pip install torch\n", 42 | "! pip install enformer_pytorch\n", 43 | "! pip install tqdm\n", 44 | "! pip install pyfaidx\n", 45 | "! pip install pandas\n", 46 | "! pip install pathlib\n", 47 | "! pip install ssl" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "id": "2509db60", 53 | "metadata": {}, 54 | "source": [ 55 | "## Import Required Libraries" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "955fa25a-1444-4798-80e4-83c7c0b77054", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "from datasets import load_dataset\n", 66 | "from transformers import AutoTokenizer, AutoModelForMaskedLM\n", 67 | "import numpy as np\n", 68 | "import torch" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "id": "9ffa6769", 74 | "metadata": {}, 75 | "source": [ 76 | "## Load Dataset\n", 77 | "\n", 78 | "Load the multi-omics transcript expression dataset. We will use the dataset light version for testing purposes, and the test split with a sequence length of 196,608 nucleotides." 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "460edff7-f4d7-4444-907f-19c26c4911d8", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "# Load the dataset\n", 89 | "transcript_expression_dataset = load_dataset(\n", 90 | " \"InstaDeepAI/multi_omics_transcript_expression\",\n", 91 | " task_name=\"transcript_expression_expression\",\n", 92 | " sequence_length=196608,\n", 93 | " filter_out_sequence_length=196608,\n", 94 | " split=\"test\",\n", 95 | " streaming=False,\n", 96 | " light_version=True, # Set to False to use the full dataset\n", 97 | ")\n", 98 | "dataset = iter(transcript_expression_dataset)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "id": "84aebeea", 104 | "metadata": {}, 105 | "source": [ 106 | "## Load Model and Tokenizer\n", 107 | "\n", 108 | "Load the pre-trained Isoformer model and its tokenizer from Hugging Face." 
109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "6ea96412-aa6d-41c0-8fa8-cb27fd1c3694", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "# Import the tokenizer and the model\n", 119 | "tokenizer = AutoTokenizer.from_pretrained(\"InstaDeepAI/isoformer\", trust_remote_code=True)\n", 120 | "model = AutoModelForMaskedLM.from_pretrained(\"InstaDeepAI/isoformer\",trust_remote_code=True)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "ebb78edc", 126 | "metadata": {}, 127 | "source": [ 128 | "## Prepare Input Data\n", 129 | "\n", 130 | "Prepare the input sequences for DNA, RNA, and protein data." 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "a50ba7cf-1278-4ac4-9e52-53c30654a298", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "# Sample data\n", 141 | "sample_data = next(dataset)\n", 142 | "protein_sequences = [sample_data[\"Protein\"]]\n", 143 | "rna_sequences = [sample_data[\"RNA\"]]\n", 144 | "dna_sequences = [sample_data[\"DNA\"]]\n", 145 | "sequence_length = 196_608" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "id": "7c12c787", 151 | "metadata": {}, 152 | "source": [ 153 | "## Tokenize Input Sequences\n", 154 | "\n", 155 | "Tokenize the input sequences for the model." 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "59a5de5f-34ac-48b1-87e4-d4b3e51c0d2f", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "# Tokenize\n", 166 | "torch_tokens = tokenizer(\n", 167 | " dna_input=dna_sequences, rna_input=rna_sequences, protein_input=protein_sequences\n", 168 | ")\n", 169 | "dna_torch_tokens = torch.tensor(torch_tokens[0][\"input_ids\"])\n", 170 | "rna_torch_tokens = torch.tensor(torch_tokens[1][\"input_ids\"])\n", 171 | "protein_torch_tokens = torch.tensor(torch_tokens[2][\"input_ids\"])" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "id": "63f5fc90", 177 | "metadata": {}, 178 | "source": [ 179 | "## Run Model Inference\n", 180 | "\n", 181 | "Perform inference using the Isoformer model to predict gene expression levels and obtain DNA embeddings." 
182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "03bf18a6-b356-4bb0-9f78-7f649a29a974", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "# Inference\n", 192 | "torch_output = model.forward(\n", 193 | " tensor_dna=dna_torch_tokens,\n", 194 | " tensor_rna=rna_torch_tokens,\n", 195 | " tensor_protein=protein_torch_tokens,\n", 196 | " attention_mask_rna=rna_torch_tokens != 1,\n", 197 | " attention_mask_protein=protein_torch_tokens != 1,\n", 198 | ")\n", 199 | "\n", 200 | "print(f\"Gene expression predictions: {torch_output['gene_expression_predictions']}\")\n", 201 | "print(f\"Final DNA embedding: {torch_output['final_dna_embeddings']}\")" 202 | ] 203 | } 204 | ], 205 | "metadata": { 206 | "kernelspec": { 207 | "display_name": "Python 3 (ipykernel)", 208 | "language": "python", 209 | "name": "python3" 210 | }, 211 | "language_info": { 212 | "codemirror_mode": { 213 | "name": "ipython", 214 | "version": 3 215 | }, 216 | "file_extension": ".py", 217 | "mimetype": "text/x-python", 218 | "name": "python", 219 | "nbconvert_exporter": "python", 220 | "pygments_lexer": "ipython3", 221 | "version": "3.12.0" 222 | } 223 | }, 224 | "nbformat": 4, 225 | "nbformat_minor": 5 226 | } 227 | -------------------------------------------------------------------------------- /nucleotide_transformer/chatNT/gpt_rotary.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple 3 | 4 | import haiku as hk 5 | import jax.numpy as jnp 6 | import numpy as np 7 | 8 | 9 | def create_sinusoidal_positions(num_pos: int, dim: int, theta: float) -> np.ndarray: 10 | """ 11 | Create the sinus and cosines for the RoPE 12 | 13 | Args: 14 | num_pos: the number of position to encode 15 | dim: the dimension of the RoPE 16 | theta: rotation angle 17 | 18 | Returns: 19 | Array of size (num_pos, 2*dim) containing the sinus and cosinus for RoPE 20 | """ 21 | 22 | inv_freq = 1.0 / (theta ** (np.arange(0, dim, 2) / dim)) 23 | sinusoid_inp = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq) 24 | sin, cos = np.sin(sinusoid_inp), np.cos(sinusoid_inp) 25 | 26 | sentinel = dim // 2 + dim % 2 27 | jmp_policy = hk.mixed_precision.current_policy() 28 | if jmp_policy is None: 29 | # default float32 30 | compute_dtype = np.float32 31 | else: 32 | # cast to jmp policy if specified 33 | compute_dtype = jmp_policy.compute_dtype 34 | 35 | sincos = np.zeros((num_pos, dim), dtype=compute_dtype) 36 | sincos[:, 0:sentinel] = sin 37 | sincos[:, sentinel:] = cos 38 | 39 | return np.array(sincos) 40 | 41 | 42 | def rotate_every_two(attention_tensor: jnp.ndarray) -> jnp.ndarray: 43 | """ 44 | Prepare a tensor to apply the RoPE mechanism 45 | 46 | Args: 47 | attention_tensor: Tensor of shape (batch_size, seq_len, num_heads, key_dim) 48 | It is in fact a key of query tensor 49 | 50 | Returns: 51 | The even indices in the last dimension have their sign flipped 52 | tensor size : (batch_size, seq_len, num_heads, key_dim) 53 | """ 54 | rotate_half_tensor = jnp.stack( 55 | (-attention_tensor[:, :, :, 1::2], attention_tensor[:, :, :, ::2]), axis=-1 56 | ) 57 | rotate_half_tensor = rotate_half_tensor.reshape( 58 | rotate_half_tensor.shape[:-2] + (-1,) 59 | ) 60 | return rotate_half_tensor 61 | 62 | 63 | def apply_rotary_pos_emb( 64 | attention_tensor: jnp.ndarray, sincos: jnp.ndarray 65 | ) -> jnp.ndarray: 66 | """ 67 | Apply the RoPE to attention_tensor 68 | Args: 69 | attention_tensor: Tensor of 
shape (batch_size, seq_len, num_heads, key_dim) 70 | It is in fact a key of query tensor 71 | sincos: the sincos generated by the function 'create_sinusoidal_positions' 72 | shape : 73 | 74 | Returns: 75 | the corresponding RoPE-encoded tensor 76 | """ 77 | sin_pos, cos_pos = sincos 78 | sin_pos = sin_pos[:, :, None, :].repeat(2, 3) 79 | cos_pos = cos_pos[:, :, None, :].repeat(2, 3) 80 | return (attention_tensor * cos_pos) + (rotate_every_two(attention_tensor) * sin_pos) 81 | 82 | 83 | @dataclass 84 | class RotaryEmbeddingConfig: 85 | """ 86 | Rotary Positional Embedding configuration 87 | max_seq_len: The number of positions to encode and cache. 88 | dim: Dimension of RoPE. 89 | theta: Rotation angle. 90 | """ 91 | 92 | max_seq_len: int 93 | dim: int 94 | theta: float 95 | 96 | 97 | class RotaryEmbedding(hk.Module): 98 | """ 99 | Rotary Positional Embedding inspired by GPT-like models (LLAMA and GPTJ). 100 | """ 101 | 102 | def __init__( 103 | self, 104 | config: RotaryEmbeddingConfig, 105 | name: Optional[str] = None, 106 | ): 107 | """ 108 | Args: 109 | config: Rotary Positional Embedding configuration. 110 | name: Name of the layer. Defaults to None. 111 | """ 112 | super().__init__(name=name) 113 | self.max_seq_len = config.max_seq_len 114 | self.dim = config.dim 115 | self.theta = config.theta 116 | self.sincos_cache = self._create_sinusoidal_positions() 117 | 118 | def _create_sinusoidal_positions(self) -> np.ndarray: 119 | """ 120 | Create the sines and cosines for the RoPE. 121 | 122 | Returns: 123 | Sinusoidal positions of shape (self.max_seq_len, self.dim). 124 | """ 125 | inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2) / self.dim)) 126 | sinusoid_inp = np.einsum("i , j -> i j", np.arange(self.max_seq_len), inv_freq) 127 | sin, cos = np.sin(sinusoid_inp), np.cos(sinusoid_inp) 128 | sentinel = self.dim // 2 + self.dim % 2 129 | jmp_policy = hk.mixed_precision.current_policy() 130 | if jmp_policy is None: 131 | compute_dtype = np.float32 132 | else: 133 | compute_dtype = jmp_policy.compute_dtype 134 | sincos = np.zeros((self.max_seq_len, self.dim), dtype=compute_dtype) 135 | sincos[:, 0:sentinel] = sin 136 | sincos[:, sentinel:] = cos 137 | return np.array(sincos) 138 | 139 | def _rotate_every_two(self, x: jnp.ndarray) -> jnp.ndarray: 140 | """ 141 | Prepare a tensor to apply the RoPE mechanism. 142 | 143 | Args: 144 | x: Tensor of shape (batch_size, seq_len, num_heads, head_dim), 145 | typically this is the key or query tensor. 146 | 147 | Returns: 148 | The even indices in the last dimension have their sign flipped. 149 | Tensor of shape (batch_size, seq_len, num_heads, head_dim). 150 | """ 151 | rotate_half = jnp.stack((-x[:, :, :, 1::2], x[:, :, :, ::2]), axis=-1) 152 | rotate_half = rotate_half.reshape(rotate_half.shape[:-2] + (-1,)) 153 | return rotate_half 154 | 155 | def _apply_rotary_pos_emb(self, x: jnp.ndarray, sincos: jnp.ndarray) -> jnp.ndarray: 156 | """ 157 | Applies rotary embeddings to x. 158 | 159 | Args: 160 | x: Tensor of shape (batch_size, seq_len, num_heads, head_dim), 161 | typically this is the key or query tensor. 162 | Returns: 163 | Rope embeddings tensor. 
164 | """ 165 | sin_pos, cos_pos = sincos 166 | sin_pos = sin_pos[:, :, None, :].repeat(2, 3) 167 | cos_pos = cos_pos[:, :, None, :].repeat(2, 3) 168 | return (x * cos_pos) + (self._rotate_every_two(x) * sin_pos) 169 | 170 | def __call__( 171 | self, k: jnp.ndarray, q: jnp.ndarray, positions: Optional[jnp.ndarray] = None 172 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 173 | """ 174 | Applies rotary embeddings to k and q. 175 | 176 | Args: 177 | k: key tensor of shape (batch_size, seq_len, num_heads, head_dim), 178 | q: value tensor of shape (batch_size, seq_len, num_heads, head_dim), 179 | positions: optional positions offset useful when caching, 180 | 181 | Returns: 182 | RoPE embeddings for the keys and values. 183 | """ 184 | position_ids = jnp.arange(0, k.shape[1], 1, dtype=jnp.int32) 185 | position_ids = jnp.expand_dims(position_ids, 0).repeat(k.shape[0], 0) 186 | if positions is not None: 187 | position_ids += positions 188 | sincos = jnp.take(self.sincos_cache, position_ids, axis=0) 189 | sincos = jnp.split(sincos, 2, axis=-1) 190 | k_rot = self._apply_rotary_pos_emb(k[:, :, :, : self.dim], sincos) 191 | k_pass = k[:, :, :, self.dim :] 192 | q_rot = self._apply_rotary_pos_emb(q[:, :, :, : self.dim], sincos) 193 | q_pass = q[:, :, :, self.dim :] 194 | keys = jnp.concatenate([k_rot, k_pass], axis=-1) 195 | values = jnp.concatenate([q_rot, q_pass], axis=-1) 196 | return keys, values 197 | -------------------------------------------------------------------------------- /docs/nucleotide_transformer.md: -------------------------------------------------------------------------------- 1 | # Nucleotide Transformer 2 | 3 | The Nucleotide Transformer (NT) project addresses the challenge of predicting molecular phenotypes from DNA by developing 4 | large-scale foundation models pre-trained on extensive DNA sequences. 5 | These models aim to learn general-purpose representations of DNA, overcoming limitations of 6 | task-specific models and data scarcity, similar to how models like BERT and GPT revolutionized natural language processing.   7 | 8 | Compared to other approaches, our models do not only integrate information from single reference genomes, 9 | but leverage DNA sequences from over 3,200 diverse human genomes, as well as 850 genomes from a wide range of species, 10 | including model and non-model organisms. Through robust and extensive evaluation, 11 | we show that these large models provide extremely accurate molecular phenotype prediction compared to existing methods. 12 | 13 | * 📜 **[Read the Paper (Nature Methods 2025)](https://www.nature.com/articles/s41592-024-02523-z)** 14 | * 🤗 **[Hugging Face Collection](https://huggingface.co/collections/InstaDeepAI/nucleotide-transformer-65099cdde13ff96230f2e592)** 15 | * 🚀 **Fine-tuning Notebooks (HF): ([LoRA](https://github.com/huggingface/notebooks/blob/main/examples/nucleotide_transformer_dna_sequence_modelling_with_peft.ipynb) and [regular](https://github.com/huggingface/notebooks/blob/main/examples/nucleotide_transformer_dna_sequence_modelling.ipynb))** 16 | 17 | 18 | Performance on downstream tasks 19 | 20 | *The Nucleotide Transformer model accurately predicts diverse genomics tasks 21 | after fine-tuning. We show the performance results across downstream tasks for fine-tuned transformer models. Error bars represent 2 SDs derived from 10-fold cross-validation.* 22 | 23 | ## Model Variants and Sizes 24 | 25 | NT models are transformer-based, varying in size (50 million to 2.5 billion parameters) and pre-training data. 
26 | Two main sets of NT models were developed:
27 | 
28 | 1 - Initial NT Models (NT-v1): This first generation of models explored a range of parameter sizes and pre-training data sources:
29 | - Human ref 500M: A 500-million-parameter model trained on the human reference genome (GRCh38/hg38).
30 | - 1000G 500M: A 500-million-parameter model trained on 3,202 diverse human genomes from the 1000 Genomes Project.
31 | - 1000G 2.5B: A 2.5-billion-parameter model, also trained on the 3,202 genomes from the 1000 Genomes Project.
32 | - Multispecies 2.5B: A 2.5-billion-parameter model trained on a diverse collection of 850 genomes from various species, including 11 model organisms.
33 | 
34 | 2 - Optimized NT Models (NT-v2): Following insights from the initial models, a second set of four models was developed, focusing on parameter efficiency and architectural advancements. These models range from 50 million to 500 million parameters.
35 | - Our second-version Nucleotide Transformer (v2) models include a series of architectural changes that proved more efficient: instead of learned positional embeddings, we use rotary embeddings applied at each attention layer, and Gated Linear Units with swish activations and no bias. These improved models also accept sequences of up to 2,048 tokens, leading to a longer context window of 12kbp.
36 | - Inspired by Chinchilla scaling laws, we also trained our NT-v2 models on our multi-species dataset for a longer duration (300B tokens for the 50M and 100M models; 1T tokens for the 250M and 500M models) compared to the v1 models (300B tokens for all four models).
37 | - NT-v2 models were all pre-trained on the multispecies dataset.
38 | 
39 | Both types of models are encoder-only transformers tokenizing DNA into 6-mers (4,104 total tokens including special ones).
40 | NT-v1 used learnable positional encodings for a 6kb context.
41 | NT-v2 models incorporated rotary positional embeddings (RoPE), SwiGLU activations, removed biases and dropout, and extended the context length to 12kb.
42 | 
43 | ## Available Resources
44 | 
45 | Pre-training datasets, models and downstream tasks can be found on our [HuggingFace space](https://huggingface.co/collections/InstaDeepAI/nucleotide-transformer-65099cdde13ff96230f2e592).
46 | An interactive leaderboard on downstream tasks is available at https://huggingface.co/spaces/InstaDeepAI/nucleotide_transformer_benchmark.
47 | 
48 | ## Tokenization :abc:
49 | 
50 | The models are trained on sequences of length up to 1000 (NT-v1) and 2000 (NT-v2) tokens, including the
51 | `<CLS>` token prepended automatically to the beginning of the sequence. The tokenizer
52 | starts tokenizing from left to right by grouping the letters "A", "C", "G" and "T" in
53 | 6-mers. The "N" letter is chosen not to be grouped inside the k-mers, therefore
54 | whenever the tokenizer encounters an "N", or if the number of nucleotides in the sequence
55 | is not a multiple of 6, it will tokenize the nucleotides without grouping them. Examples
56 | are given below:
57 | 
58 | ```python
59 | dna_sequence_1 = "ACGTGTACGTGCACGGACGACTAGTCAGCA"
60 | tokenized_dna_sequence_1 = [<CLS>,<ACGTGT>,<ACGTGC>,<ACGGAC>,<GACTAG>,<TCAGCA>]
61 | 
62 | dna_sequence_2 = "ACGTGTACNTGCACGGANCGACTAGTCTGA"
63 | tokenized_dna_sequence_2 = [<CLS>,<ACGTGT>,<A>,<C>,<N>,<TGCACG>,<G>,<A>,<N>,<CGACTA>,<GTCTGA>]
64 | ```
65 | 
66 | All the v1 and v2 transformers can therefore take sequences of up to 5994 and 12282 nucleotides respectively if there are
67 | no "N" inside.
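To make the grouping rule concrete, here is a minimal plain-Python sketch of it (illustrative only: the `kmer_split` helper below is not part of the package; in practice, use the tokenizer returned by `get_pretrained_model`, as shown in the next section):

```python
def kmer_split(sequence: str, k: int = 6) -> list:
    """Greedy left-to-right 6-mer grouping; falls back to single letters around "N"
    or when fewer than k nucleotides remain (a sketch of the rule described above)."""
    tokens, i = [], 0
    while i < len(sequence):
        chunk = sequence[i : i + k]
        if len(chunk) == k and "N" not in chunk:
            tokens.append(chunk)  # regular 6-mer token
            i += k
        else:
            tokens.append(sequence[i])  # single-nucleotide token
            i += 1
    return tokens

print(kmer_split("ACGTGTACGTGCACGGACGACTAGTCAGCA"))
# ['ACGTGT', 'ACGTGC', 'ACGGAC', 'GACTAG', 'TCAGCA']
```

The actual tokenizer additionally prepends the `<CLS>` token and maps each token to its integer id.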
68 | 69 | ## How to use 🚀 70 | 71 | To use the code and pre-trained models in jax: 72 | 73 | ```python 74 | import haiku as hk 75 | import jax 76 | import jax.numpy as jnp 77 | from nucleotide_transformer.pretrained import get_pretrained_model 78 | 79 | # Get pretrained model 80 | parameters, forward_fn, tokenizer, config = get_pretrained_model( 81 | model_name="250M_multi_species_v2", 82 | embeddings_layers_to_save=(20,), 83 | max_positions=32, 84 | ) 85 | forward_fn = hk.transform(forward_fn) 86 | 87 | # Get data and tokenize it 88 | sequences = ["ATTCCGATTCCGATTCCG", "ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT"] 89 | tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)] 90 | tokens = jnp.asarray(tokens_ids, dtype=jnp.int32) 91 | 92 | # Initialize random key 93 | random_key = jax.random.PRNGKey(0) 94 | 95 | # Infer 96 | outs = forward_fn.apply(parameters, random_key, tokens) 97 | 98 | # Get embeddings at layer 20 99 | print(outs["embeddings_20"].shape) 100 | ``` 101 | Supported model names are: 102 | - **500M_human_ref** 103 | - **500M_1000G** 104 | - **2B5_1000G** 105 | - **2B5_multi_species** 106 | - **50M_multi_species_v2** 107 | - **100M_multi_species_v2** 108 | - **250M_multi_species_v2** 109 | - **500M_multi_species_v2** 110 | 111 | You can also run our models and find more example code in `../notebooks/nucleotide_transformer/inference.ipynb`. 112 | 113 | The code runs both on GPU and TPU thanks to Jax! 114 | 115 | ### Embeddings retrieval 116 | The transformer layers are 1-indexed, which means that calling `get_pretrained_model` with the arguments `model_name="500M_human_ref"` and `embeddings_layers_to_save=(1, 20,)` will result in extracting embeddings after the first and 20-th transformer layer. For transformers using the Roberta LM head, it is common practice to extract the final embeddings after the first layer norm of the LM head rather than after the last transformer block. Therefore, if `get_pretrained_model` is called with the following arguments `embeddings_layers_to_save=(24,)`, the embeddings will not be extracted after the final transformer layer but rather after the first layer norm of the LM head. 
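As a minimal sketch of the behaviour described above (same API as the example earlier in this page; the 24-layer `500M_human_ref` checkpoint is assumed here), embeddings from several layers can be saved in a single forward pass:

```python
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model

# Layers are 1-indexed: save embeddings after the 1st and 20th transformer layers.
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name="500M_human_ref",
    embeddings_layers_to_save=(1, 20),
    max_positions=32,
)
forward_fn = hk.transform(forward_fn)

tokens_ids = [b[1] for b in tokenizer.batch_tokenize(["ATTCCGATTCCGATTCCG"])]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
outs = forward_fn.apply(parameters, jax.random.PRNGKey(0), tokens)

# Outputs are keyed by layer index, following the "embeddings_{layer}" pattern.
print(outs["embeddings_1"].shape, outs["embeddings_20"].shape)
```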
117 | 118 | 119 | 120 | ## Citing our work 📚 121 | 122 | You can cite our model at: 123 | 124 | ```bibtex 125 | @article{dalla2025nucleotide, 126 | title={Nucleotide Transformer: building and evaluating robust foundation models for human genomics}, 127 | author={Dalla-Torre, Hugo and Gonzalez, Liam and Mendoza Revilla, Javier and Lopez Carranza, Nicolas and Henryk Grywaczewski, Adam and Oteri, Francesco and Dallago, Christian and Trop, Evan and de Almeida, Bernardo P and Sirelkhatim, Hassan and Richard, Guillaume and others}, 128 | journal={Nature Methods}, 129 | volume={22}, 130 | pages={287-297}, 131 | year={2025}, 132 | publisher={Nature Publishing Group UK London} 133 | } 134 | ``` 135 | -------------------------------------------------------------------------------- /nucleotide_transformer/borzoi/layers.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from typing import Optional, Tuple, Union 3 | 4 | import haiku as hk 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | 9 | 10 | def _prepend_dims(x: np.ndarray, num_dims: int) -> np.ndarray: 11 | return jnp.reshape(x, tuple([1] * num_dims) + x.shape) 12 | 13 | 14 | def get_positional_features_central_mask_borzoi( 15 | positions: jnp.ndarray, feature_size: int, seq_length: int 16 | ) -> np.ndarray: 17 | """Positional features using a central mask (allow only central features).""" 18 | pow_rate = jnp.exp(jnp.log(seq_length + 1) / feature_size).astype("float32") 19 | center_widths = jnp.power(pow_rate, jnp.arange(1, feature_size + 1, 1)) 20 | center_widths = center_widths - 1 21 | center_widths = _prepend_dims(center_widths, positions.ndim) 22 | outputs = jnp.asarray( 23 | center_widths > jnp.abs(positions)[..., jnp.newaxis], jnp.float32 24 | ) 25 | 26 | return outputs 27 | 28 | 29 | def get_positional_embed_borzoi(seq_len: int, feature_size: int) -> jnp.ndarray: 30 | """ 31 | Compute positional embedding for Borzoi. Note that it is different than the one 32 | used in Enformer 33 | """ 34 | distances = jnp.arange(-seq_len + 1, seq_len) 35 | 36 | num_components = 2 37 | 38 | if (feature_size % num_components) != 0: 39 | raise ValueError( 40 | f"feature size is not divisible by number of components ({num_components})" 41 | ) 42 | 43 | num_basis_per_class = feature_size // num_components 44 | 45 | embeddings = [] 46 | 47 | embeddings.append( 48 | get_positional_features_central_mask_borzoi( 49 | distances, num_basis_per_class, seq_len 50 | ) 51 | ) 52 | 53 | embeddings = jnp.concatenate(embeddings, axis=-1) 54 | embeddings = jnp.concatenate( 55 | (embeddings, jnp.sign(distances)[..., None] * embeddings), axis=-1 56 | ) 57 | return embeddings 58 | 59 | 60 | class SeparableDepthwiseConv1D(hk.Module): 61 | """Separable 2-D Depthwise Convolution Module.""" 62 | 63 | def __init__( 64 | self, 65 | channel_multiplier: int, 66 | kernel_shape: int, 67 | stride: int = 1, 68 | padding: Union[str, Sequence[Tuple[int, int]]] = "SAME", 69 | with_bias: bool = True, 70 | w_init: Optional[hk.initializers.Initializer] = None, 71 | b_init: Optional[hk.initializers.Initializer] = None, 72 | data_format: str = "NWC", 73 | name: Optional[str] = None, 74 | ): 75 | """Construct a Separable 2D Depthwise Convolution module. 76 | 77 | Args: 78 | channel_multiplier: Multiplicity of output channels. To keep the number of 79 | output channels the same as the number of input channels, set 1. 80 | kernel_shape: The shape of the kernel. 
Either an integer or a sequence of 81 | length ``num_spatial_dims``. 82 | stride: Optional stride for the kernel. Either an integer or a sequence of 83 | length ``num_spatial_dims``. Defaults to 1. 84 | padding: Optional padding algorithm. Either ``VALID``, ``SAME`` or a 85 | sequence of ``before, after`` pairs. Defaults to ``SAME``. See: 86 | https://www.tensorflow.org/xla/operation_semantics#conv_convolution. 87 | with_bias: Whether to add a bias. By default, true. 88 | w_init: Optional weight initialization. By default, truncated normal. 89 | b_init: Optional bias initialization. By default, zeros. 90 | data_format: The data format of the input. Can be either 91 | ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By 92 | default, ``channels_last``. 93 | name: The name of the module. 94 | """ 95 | super().__init__(name=name) 96 | self._conv1 = hk.DepthwiseConv1D( 97 | channel_multiplier=channel_multiplier, 98 | kernel_shape=kernel_shape, 99 | stride=stride, 100 | padding=padding, 101 | with_bias=False, 102 | w_init=w_init, 103 | b_init=b_init, 104 | data_format=data_format, 105 | ) 106 | 107 | self._conv2 = hk.Conv1D( 108 | output_channels=1536, 109 | kernel_shape=1, 110 | stride=1, 111 | padding=padding, 112 | with_bias=with_bias, 113 | w_init=w_init, 114 | b_init=b_init, 115 | data_format=data_format, 116 | ) 117 | 118 | def __call__(self, inputs: jax.Array) -> jax.Array: 119 | 120 | x = self._conv1(inputs) 121 | x = self._conv2(x) 122 | 123 | return x 124 | 125 | 126 | class ConvBlock(hk.Module): 127 | """ 128 | Conv Block. 129 | """ 130 | 131 | def __init__( 132 | self, 133 | dim: int, 134 | dim_out: Optional[int] = None, 135 | kernel_size: int = 1, 136 | name: Optional[str] = None, 137 | ): 138 | """ 139 | Args: 140 | dim: input dimension. 141 | dim_out: output dimension. 142 | kernel_size: kernel's size. 143 | name: model's name. 144 | """ 145 | super().__init__(name=name) 146 | self._dim = dim 147 | self._dim_out = dim_out 148 | self._kernel_size = kernel_size 149 | 150 | def __call__(self, x: jnp.ndarray, is_training: bool = False) -> jnp.ndarray: 151 | batch_norm = hk.BatchNorm( 152 | create_scale=True, 153 | create_offset=True, 154 | decay_rate=0.99, 155 | eps=0.001, 156 | data_format="NC...", 157 | ) 158 | conv = hk.Conv1D( 159 | output_channels=self._dim if self._dim_out is None else self._dim_out, 160 | kernel_shape=self._kernel_size, 161 | padding=(self._kernel_size // 2, self._kernel_size // 2), 162 | data_format="NCW", 163 | ) 164 | 165 | x = batch_norm(x, is_training=is_training) 166 | x = jax.nn.gelu(x) 167 | x = conv(x) 168 | return x 169 | 170 | 171 | class UNetConvBlock(hk.Module): 172 | """ 173 | Conv Block. 174 | """ 175 | 176 | def __init__( 177 | self, 178 | dim: int, 179 | name: Optional[str] = None, 180 | ): 181 | """ 182 | Args: 183 | dim: input dimension. 184 | dim_out: output dimension. 185 | kernel_size: kernel's size. 186 | name: model's name. 
187 | """ 188 | super().__init__(name=name) 189 | self._dim = dim 190 | 191 | def __call__( 192 | self, x: jnp.ndarray, unet_repr: jnp.ndarray, is_training: bool = False 193 | ) -> jnp.ndarray: 194 | batch_norm = hk.BatchNorm( 195 | create_scale=True, 196 | create_offset=True, 197 | decay_rate=0.99, 198 | eps=0.001, 199 | data_format="NC...", 200 | name="batch_norm", 201 | ) 202 | batch_norm_unet_repr = hk.BatchNorm( 203 | create_scale=True, 204 | create_offset=True, 205 | decay_rate=0.99, 206 | eps=0.001, 207 | data_format="NC...", 208 | name="batch_norm_unet_repr", 209 | ) 210 | 211 | linear = hk.Linear(output_size=self._dim, with_bias=True, name="linear") 212 | linear_unet = hk.Linear( 213 | output_size=self._dim, with_bias=True, name="linear_unet" 214 | ) 215 | separable_conv = SeparableDepthwiseConv1D( 216 | channel_multiplier=1, kernel_shape=3, with_bias=False, data_format="NCW" 217 | ) 218 | 219 | x = batch_norm(x, is_training=is_training) 220 | unet_repr = batch_norm_unet_repr(unet_repr, is_training=is_training) 221 | 222 | x = jax.nn.gelu(x) 223 | unet_repr = jax.nn.gelu(unet_repr) 224 | 225 | x = jnp.transpose(x, axes=(0, 2, 1)) 226 | unet_repr = jnp.transpose(unet_repr, axes=(0, 2, 1)) 227 | 228 | x = linear(x) 229 | unet_repr = linear_unet(unet_repr) 230 | 231 | x = jnp.transpose(x, axes=(0, 2, 1)) 232 | unet_repr = jnp.transpose(unet_repr, axes=(0, 2, 1)) 233 | 234 | # Upsample x 235 | x = jax.image.resize( 236 | x, 237 | shape=(x.shape[0], x.shape[1], x.shape[2] * 2), 238 | method="nearest", 239 | ) 240 | 241 | x = x + unet_repr 242 | 243 | x = separable_conv(x) 244 | 245 | return x 246 | -------------------------------------------------------------------------------- /nucleotide_transformer/heads.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple 2 | 3 | import haiku as hk 4 | import jax 5 | import jax.numpy as jnp 6 | from haiku import initializers 7 | 8 | from nucleotide_transformer.types import SequenceMask 9 | from nucleotide_transformer.utils import get_activation_fn 10 | 11 | 12 | class DownSample1D(hk.Module): 13 | """ 14 | 1D-UNET downsampling block. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | output_channels: int, 20 | activation_fn: str = "swish", 21 | num_layers: int = 2, 22 | name: Optional[str] = None, 23 | ): 24 | """ 25 | Args: 26 | output_channels: number of output channels. 27 | activation_fn: name of the activation function to use. 28 | Should be one of "gelu", 29 | "gelu-no-approx", "relu", "swish", "silu", "sin". 30 | num_layers: number of convolution layers. 31 | name: module name. 32 | """ 33 | 34 | super().__init__(name=name) 35 | 36 | self._conv_layers = [ 37 | hk.Conv1D( 38 | output_channels=output_channels, 39 | kernel_shape=3, 40 | stride=1, 41 | rate=1, 42 | padding="SAME", 43 | data_format="NWC", 44 | name=f"conv{i}", 45 | ) 46 | for i in range(num_layers) 47 | ] 48 | 49 | self._avg_pool = hk.AvgPool( 50 | window_shape=2, 51 | strides=2, 52 | padding="SAME", 53 | channel_axis=-1, 54 | ) 55 | 56 | self._activation_fn = get_activation_fn(activation_name=activation_fn) 57 | 58 | def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: 59 | for _, conv_layer in enumerate(self._conv_layers): 60 | x = self._activation_fn(conv_layer(x)) 61 | hidden = x 62 | x = self._avg_pool(hidden) 63 | return x, hidden 64 | 65 | 66 | class UpSample1D(hk.Module): 67 | """ 68 | 1D-UNET upsampling block. 
69 | """ 70 | 71 | def __init__( 72 | self, 73 | output_channels: int, 74 | activation_fn: str = "swish", 75 | num_layers: int = 2, 76 | interpolation_method: str = "nearest", 77 | name: Optional[str] = None, 78 | ): 79 | """ 80 | Args: 81 | output_channels: number of output channels. 82 | activation_fn: name of the activation function to use. 83 | Should be one of "gelu", 84 | "gelu-no-approx", "relu", "swish", "silu", "sin". 85 | interpolation_method: Method to be used for upsampling interpolation. 86 | Should be one of "nearest", "linear", "cubic", "lanczos3", "lanczos5". 87 | num_layers: number of convolution layers. 88 | name: module name. 89 | """ 90 | super().__init__(name=name) 91 | 92 | self._conv_layers = [ 93 | hk.Conv1DTranspose( 94 | output_channels=output_channels, 95 | kernel_shape=3, 96 | stride=1, 97 | padding="SAME", 98 | data_format="NWC", 99 | name=f"conv_transpose{i}", 100 | ) 101 | for i in range(num_layers) 102 | ] 103 | 104 | self._interpolation_method = interpolation_method 105 | self._activation_fn = get_activation_fn(activation_name=activation_fn) 106 | 107 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 108 | for _, conv_layer in enumerate(self._conv_layers): 109 | x = self._activation_fn(conv_layer(x)) 110 | x = jax.image.resize( 111 | x, 112 | shape=(x.shape[0], 2 * x.shape[1], x.shape[2]), 113 | method=self._interpolation_method, 114 | ) 115 | return x 116 | 117 | 118 | class FinalConv1D(hk.Module): 119 | """ 120 | Final output block of the 1D-UNET. 121 | """ 122 | 123 | def __init__( 124 | self, 125 | output_channels: int, 126 | activation_fn: str = "swish", 127 | num_layers: int = 2, 128 | name: Optional[str] = None, 129 | ): 130 | """ 131 | Args: 132 | output_channels: number of output channels. 133 | activation_fn: name of the activation function to use. 134 | Should be one of "gelu", 135 | "gelu-no-approx", "relu", "swish", "silu", "sin". 136 | num_layers: number of convolution layers. 137 | name: module name. 138 | """ 139 | super().__init__(name=name) 140 | 141 | self._conv_layers = [ 142 | hk.Conv1D( 143 | output_channels=output_channels, 144 | kernel_shape=3, 145 | stride=1, 146 | rate=1, 147 | padding="SAME", 148 | data_format="NWC", 149 | name=f"conv{i}", 150 | ) 151 | for i in range(num_layers) 152 | ] 153 | 154 | self._activation_fn = get_activation_fn(activation_name=activation_fn) 155 | 156 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 157 | for i, conv_layer in enumerate(self._conv_layers): 158 | x = conv_layer(x) 159 | if i < len(self._conv_layers) - 1: 160 | x = self._activation_fn(x) 161 | return x 162 | 163 | 164 | class UNET1DSegmentationHead(hk.Module): 165 | """ 166 | 1D-UNET based head to be plugged on top of a pretrained model to perform 167 | semantic segmentation. 168 | """ 169 | 170 | def __init__( 171 | self, 172 | num_classes: int, 173 | output_channels_list: Tuple[int, ...] = (64, 128, 256), 174 | activation_fn: str = "swish", 175 | num_conv_layers_per_block: int = 2, 176 | upsampling_interpolation_method: str = "nearest", 177 | name: Optional[str] = None, 178 | ): 179 | """ 180 | Args: 181 | num_classes: number of classes to segment 182 | output_channels_list: list of the number of output channel at each level of 183 | the UNET 184 | activation_fn: name of the activation function to use. 185 | Should be one of "gelu", 186 | "gelu-no-approx", "relu", "swish", "silu", "sin". 187 | num_conv_layers_per_block: number of convolution layers per block. 
188 | upsampling_interpolation_method: Method to be used for 189 | interpolation in upsampling blocks. Should be one of "nearest", 190 | "linear", "cubic", "lanczos3", "lanczos5". 191 | name: module name. 192 | """ 193 | super().__init__(name=name) 194 | self._num_pooling_layers = len(output_channels_list) 195 | self._downsample_blocks = [ 196 | DownSample1D( 197 | output_channels=output_channels, 198 | activation_fn=activation_fn, 199 | num_layers=num_conv_layers_per_block, 200 | name=f"downsample_block_{i}", 201 | ) 202 | for i, output_channels in enumerate(output_channels_list) 203 | ] 204 | 205 | self._upsample_blocks = [ 206 | UpSample1D( 207 | output_channels=output_channels, 208 | activation_fn=activation_fn, 209 | num_layers=num_conv_layers_per_block, 210 | interpolation_method=upsampling_interpolation_method, 211 | name=f"upsample_block_{i}", 212 | ) 213 | for i, output_channels in enumerate(reversed(output_channels_list)) 214 | ] 215 | 216 | self._final_block = FinalConv1D( 217 | activation_fn=activation_fn, 218 | output_channels=num_classes * 2, 219 | num_layers=num_conv_layers_per_block, 220 | ) 221 | 222 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 223 | 224 | if x.shape[1] % 2**self._num_pooling_layers: 225 | raise ValueError( 226 | "Input length must be divisible by the 2 to the power of" 227 | " number of poolign layers." 228 | ) 229 | 230 | hiddens = [] 231 | for downsample_block in self._downsample_blocks: 232 | x, hidden = downsample_block(x) 233 | hiddens.append(hidden) 234 | 235 | for upsample_block, hidden in zip(self._upsample_blocks, reversed(hiddens)): 236 | x = upsample_block(x) + hidden 237 | 238 | x = self._final_block(x) 239 | return x 240 | 241 | 242 | class UNetHead(hk.Module): 243 | """ 244 | Returns a probability between 0 and 1 over a target feature presence 245 | for each nucleotide in the input sequence. Assumes the sequence has been tokenized 246 | with non-overlapping 6-mers. 247 | """ 248 | 249 | def __init__( 250 | self, 251 | num_features: int, 252 | embed_dimension: int = 1024, 253 | num_layers: int = 2, 254 | name: Optional[str] = None, 255 | ): 256 | """ 257 | Args: 258 | name: Name of the layer. Defaults to None. 
259 | """ 260 | super().__init__(name=name) 261 | self._num_features = num_features 262 | 263 | w_init = initializers.VarianceScaling(2.0, "fan_in", "uniform") 264 | b_init = initializers.VarianceScaling(2.0, "fan_in", "uniform") 265 | unet = UNET1DSegmentationHead( 266 | num_classes=embed_dimension // 2, 267 | output_channels_list=tuple( 268 | embed_dimension * (2**i) for i in range(num_layers) 269 | ), 270 | ) 271 | fc = hk.Linear( 272 | 6 * 2 * self._num_features, w_init=w_init, b_init=b_init, name="fc" 273 | ) 274 | self._fc = hk.Sequential([unet, jax.nn.swish, fc]) 275 | 276 | def __call__( 277 | self, x: jnp.ndarray, sequence_mask: SequenceMask 278 | ) -> Dict[str, jnp.ndarray]: 279 | """ 280 | Input shape: (batch_size, sequence_length + 1, embed_dim) 281 | Output_shape: (batch_size, 6 * sequence_length, 2) 282 | """ 283 | batch_size, seq_len = x.shape[0], x.shape[1] - 1 284 | logits = self._fc(x[:, 1:]) # remove CLS token 285 | logits = jnp.reshape(logits, (batch_size, seq_len * 6, self._num_features, 2)) 286 | return {"logits": logits} 287 | -------------------------------------------------------------------------------- /nucleotide_transformer/mojo/model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Callable, Optional 3 | 4 | import haiku as hk 5 | import jax 6 | import jax.numpy as jnp 7 | import jmp 8 | 9 | from nucleotide_transformer.layers import RotaryEmbeddingConfig, SelfAttentionBlock 10 | from nucleotide_transformer.mojo.config import MOJOConfig 11 | from nucleotide_transformer.mojo.layers import ( 12 | ConvBlock, 13 | DeConvBlock, 14 | ResidualConvBlock, 15 | ResidualDeConvBlock, 16 | ) 17 | from nucleotide_transformer.types import AttentionMask, Embedding, Tokens 18 | 19 | 20 | class MOJO(hk.Module): 21 | def __init__( 22 | self, 23 | config: MOJOConfig, 24 | name: Optional[str] = None, 25 | ): 26 | super().__init__(name=name) 27 | self._config = config 28 | self._embedding_layers = {} 29 | self._embedding_layers = { 30 | omic: hk.Embed( 31 | self._config.alphabet_size[omic], 32 | self._config.token_embed_dim, 33 | name=f"{omic}_embedding", 34 | ) 35 | for omic in self._config.alphabet_size 36 | } 37 | 38 | self._gene_embedding_layer = hk.Embed( 39 | self._config.fixed_sequence_length, 40 | self._config.init_gene_embed_dim, 41 | name="gene_embedding", 42 | ) 43 | self._fc_gene_embedding = hk.Linear(self._config.token_embed_dim) 44 | 45 | self._rotary_embedding_config = RotaryEmbeddingConfig(rescaling_factor=None) 46 | 47 | @hk.transparent 48 | def stem(self, x: jnp.ndarray) -> jnp.ndarray: 49 | with hk.experimental.name_scope("stem"): 50 | conv = hk.Conv1D( 51 | output_channels=self._config.conv_init_embed_dim, 52 | kernel_shape=self._config.stem_kernel_shape, 53 | padding="SAME", 54 | data_format="NWC", 55 | ) 56 | if self._config.use_remat_in_convs: 57 | conv = hk.remat(conv) 58 | 59 | x = conv(x) 60 | x = jax.nn.gelu(x) 61 | 62 | return x 63 | 64 | @hk.transparent 65 | def conv_tower(self, x: jnp.ndarray) -> tuple[jnp.ndarray, list[jnp.ndarray]]: 66 | filter_list = copy.deepcopy(self._config.filter_list) 67 | residuals = [] 68 | 69 | for i, (dim_in, dim_out) in enumerate(zip(filter_list[:-1], filter_list[1:])): 70 | with hk.experimental.name_scope(f"conv_block_{i}"): 71 | conv, res_conv = self._conv_block(dim_in, dim_out) 72 | avg_pool = hk.AvgPool(window_shape=2, strides=2, padding="SAME") 73 | if self._config.use_remat_in_convs: 74 | conv = hk.remat(conv) 75 | res_conv = hk.remat(res_conv) 
76 | avg_pool = hk.remat(avg_pool) 77 | 78 | residuals.append(x) 79 | x = conv(x) 80 | x = res_conv(x) 81 | x = avg_pool(x) 82 | 83 | return x, residuals 84 | 85 | @hk.transparent 86 | def deconv_tower(self, x: jnp.ndarray, residuals: list[jnp.ndarray]) -> jnp.ndarray: 87 | filter_list = copy.deepcopy(self._config.filter_list) 88 | filter_list.reverse() 89 | residuals_generator = reversed(residuals) 90 | 91 | for i, (dim_in, dim_out) in enumerate(zip(filter_list[:-1], filter_list[1:])): 92 | with hk.experimental.name_scope(f"deconv_block_{i}"): 93 | conv, res_conv = self._deconv_block(dim_in, dim_out) 94 | if self._config.use_remat_in_convs: 95 | conv = hk.remat(conv) 96 | res_conv = hk.remat(res_conv) 97 | 98 | x = conv(x) 99 | x = res_conv(x) 100 | 101 | if self._config.use_skip_connection: 102 | residuals = next(residuals_generator) 103 | x = x + residuals 104 | 105 | return x 106 | 107 | @hk.transparent 108 | def transformer_tower( 109 | self, 110 | x: Embedding, 111 | outs: dict[str, Embedding], 112 | attention_mask: Optional[AttentionMask] = None, 113 | ) -> tuple[Embedding, dict[str, Embedding]]: 114 | 115 | layers: list[Callable] = [ 116 | self._attention_block(layer_idx) 117 | for layer_idx in range(self._config.num_layers) 118 | ] 119 | 120 | if self._config.use_remat_in_transformer: 121 | layers = [hk.remat(layer) for layer in layers] 122 | 123 | for layer_idx, layer in enumerate(layers): 124 | output = layer( 125 | x=x, attention_mask=attention_mask, attention_weight_bias=None 126 | ) 127 | x = output["embeddings"] 128 | 129 | if (layer_idx + 1) in self._config.embeddings_layers_to_save: 130 | outs[f"embeddings_{(layer_idx + 1)}"] = output["embeddings"] 131 | if (layer_idx + 1) in self._config.attention_layers_to_save: 132 | outs[f"attention_map_layer_{layer_idx + 1}"] = output[ 133 | "attention_weights" 134 | ] 135 | 136 | return x, outs 137 | 138 | @hk.transparent 139 | def omic_lm_heads(self, x: jnp.ndarray, omic: str) -> jnp.ndarray: 140 | x = jax.nn.gelu(x) 141 | for i in range(self._config.num_hidden_layers_head): 142 | name = f"{omic}_head_linear_{i}" 143 | x = hk.Linear(self._config.embed_dim, name=name)(x) 144 | x = jax.nn.gelu(x) 145 | name = f"{omic}_head_linear_final" 146 | head = hk.Linear(self._config.alphabet_size[omic], name=name) 147 | return head(x) 148 | 149 | def get_embeddings( 150 | self, 151 | tokens: dict[str, Tokens], 152 | attention_masks: Optional[dict[str, Optional[AttentionMask]]], 153 | ) -> dict[str, Embedding]: 154 | omic_embeddings = {} 155 | for omic, omic_tokens in tokens.items(): 156 | omic_embeddings[omic] = self._embedding_layers[omic](omic_tokens) 157 | 158 | return omic_embeddings 159 | 160 | @hk.transparent 161 | def _conv_block(self, dim_in: int, dim_out: int) -> tuple[hk.Module, hk.Module]: 162 | conv = ConvBlock( 163 | dim=dim_in, 164 | dim_out=dim_out, 165 | kernel_size=5, 166 | ) 167 | res_conv = ResidualConvBlock( 168 | dim=dim_out, 169 | dim_out=dim_out, 170 | kernel_size=1, 171 | ) 172 | return conv, res_conv 173 | 174 | @hk.transparent 175 | def _deconv_block(self, dim_in: int, dim_out: int) -> tuple[hk.Module, hk.Module]: 176 | conv = DeConvBlock( 177 | dim=dim_in, 178 | dim_out=dim_out, 179 | kernel_size=5, 180 | stride=2, 181 | ) 182 | res_conv = ResidualDeConvBlock( 183 | dim=dim_out, 184 | dim_out=dim_out, 185 | kernel_size=1, 186 | ) 187 | return conv, res_conv 188 | 189 | @hk.transparent 190 | def _attention_block(self, layer_idx: int) -> SelfAttentionBlock: 191 | return SelfAttentionBlock( # type: ignore 192 | 
num_heads=self._config.num_attention_heads, 193 | embed_dim=self._config.embed_dim, 194 | key_size=self._config.key_size, 195 | ffn_embed_dim=self._config.ffn_embed_dim, 196 | add_bias_kv=False, 197 | add_bias_fnn=False, 198 | ffn_activation_name="swish", 199 | use_glu_in_ffn=True, 200 | rotary_embedding_config=self._rotary_embedding_config, 201 | layer_norm_eps=self._config.layer_norm_eps, 202 | pre_layer_norm=True, 203 | name=f"attention_layer_{layer_idx}", 204 | ) 205 | 206 | def __call__( 207 | self, 208 | tokens_dict: dict[str, Tokens], 209 | attention_masks: Optional[dict[str, Optional[AttentionMask]]] = None, 210 | ) -> dict[str, Embedding]: 211 | outs: dict[str, dict[str, jnp.ndarray] | jnp.ndarray | list[jnp.ndarray]] = {} 212 | 213 | embeddings = self.get_embeddings(tokens_dict, attention_masks) 214 | outs["omic_embeddings"] = embeddings 215 | 216 | x = jnp.sum(jnp.array(list(embeddings.values())), axis=0) 217 | outs["embeddings"] = x 218 | 219 | if self._config.use_gene_embedding: 220 | gene_embedding = self._gene_embedding_layer( 221 | jnp.arange(self._config.fixed_sequence_length) 222 | ) 223 | if self._config.project_gene_embedding: 224 | gene_embedding = self._fc_gene_embedding(gene_embedding) 225 | x = x + gene_embedding 226 | outs["embeddings_with_gene_embedding"] = x 227 | 228 | x = self.stem(x) 229 | outs["stem"] = x 230 | 231 | x, residuals = self.conv_tower(x) 232 | outs["conv_tower"] = x 233 | outs["conv_tower_residuals"] = residuals 234 | 235 | x, outs = self.transformer_tower(x, outs=outs, attention_mask=None) 236 | outs["after_transformer_embedding"] = x 237 | 238 | x = self.deconv_tower(x, residuals) 239 | outs["deconv_tower"] = x 240 | 241 | outs["logits"] = { 242 | omic: self.omic_lm_heads(x, omic) for omic in self._config.alphabet_size 243 | } 244 | 245 | return outs 246 | 247 | 248 | def build_mojo_fn( 249 | model_config: MOJOConfig, 250 | compute_dtype: jnp.dtype = jnp.float32, 251 | param_dtype: jnp.dtype = jnp.float32, 252 | output_dtype: jnp.dtype = jnp.float32, 253 | model_name: Optional[str] = None, 254 | ) -> Callable: 255 | assert {compute_dtype, param_dtype, output_dtype}.issubset( 256 | { 257 | jnp.bfloat16, 258 | jnp.float32, 259 | jnp.float16, 260 | } 261 | ), f"provide a dtype in {jnp.bfloat16, jnp.float32, jnp.float16}" 262 | 263 | policy = jmp.Policy( 264 | compute_dtype=compute_dtype, param_dtype=param_dtype, output_dtype=output_dtype 265 | ) 266 | hk.mixed_precision.set_policy(MOJO, policy) 267 | 268 | # Remove it in batch norm to avoid instabilities 269 | norm_policy = jmp.Policy( 270 | compute_dtype=jnp.float32, param_dtype=param_dtype, output_dtype=compute_dtype 271 | ) 272 | hk.mixed_precision.set_policy(hk.LayerNorm, norm_policy) 273 | hk.mixed_precision.set_policy(hk.BatchNorm, norm_policy) 274 | 275 | def multiomics_lm_fn( 276 | tokens_dict: dict[str, Tokens], 277 | attention_masks: Optional[dict[str, Optional[AttentionMask]]] = None, 278 | ) -> dict[str, jnp.ndarray]: 279 | model = MOJO(model_config, name=model_name) 280 | return model(tokens_dict, attention_masks) 281 | 282 | return multiomics_lm_fn 283 | -------------------------------------------------------------------------------- /docs/segment_nt.md: -------------------------------------------------------------------------------- 1 | ## Segmentation models 2 | 3 | Segmentation models using transformer backbones (Nucleotide Transformers, Enformer, Borzoi) for predicting genomic elements at single-nucleotide resolution. 
SegmentNT, for instance, predicts 14 different classes of human genomic elements in sequences up to 30kb (generalizing to 50kbp) and demonstrates superior performance.
4 | 
5 | All models are used with a 1-dimensional U-Net segmentation head to predict the location of several types of genomic elements in a sequence at single-nucleotide resolution. These include gene (protein-coding genes, lncRNAs, 5’UTR, 3’UTR, exon, intron, splice acceptor and donor sites) and regulatory (polyA signal, tissue-invariant and tissue-specific promoters and enhancers, and CTCF-bound sites) elements.
6 | 
7 | * 📜 **[Read the Paper (Nature Methods 2025)](https://www.nature.com/articles/s41592-025-02881-2)**
8 | * 🤗 **[SegmentNT Hugging Face Collection](https://huggingface.co/collections/InstaDeepAI/segmentnt-65eb4941c57808b4a3fe1319)**
9 | * 🚀 **[SegmentNT Inference Notebook (HF)](https://colab.research.google.com/#fileId=https%3A//huggingface.co/InstaDeepAI/segment_nt/blob/main/inference_segment_nt.ipynb)**
10 | 
11 | Performance on downstream tasks
12 | 
13 | *Fig. 1: SegmentNT localizes genomics elements at nucleotide resolution.*
14 | 
15 | ## How to use 🚀
16 | 
17 | ### SegmentNT
18 | 
19 | ⚠️ The SegmentNT models leverage the [Nucleotide Transformer (NT)](https://www.nature.com/articles/s41592-024-02523-z) backbone and have been trained on sequences of 30,000 nucleotides, or 5001 tokens (accounting for the CLS token). However, SegmentNT has been shown to generalize up to sequences of 50,000 bp. Because 30,000 bp is longer
20 | than the maximum length of 2048 6-mer tokens that the Nucleotide Transformer can handle, YaRN rescaling is employed during training.
21 | By default, the `rescaling_factor` is set to the one used during training. In case you need to infer on sequences between 30kbp and 50kbp, make sure to pass the `rescaling_factor` argument to the `get_pretrained_segment_nt_model` function with
22 | the value `rescaling_factor = num_dna_tokens_inference / max_num_tokens_nt`, where `num_dna_tokens_inference` is the number of tokens at inference (i.e. 6669 for a sequence of 40,008 base pairs) and `max_num_tokens_nt` is the maximum number of tokens the backbone Nucleotide Transformer was trained on, i.e. `2048`.
23 | 
24 | 🔍 The notebook `../notebooks/segment_nt/inference_segment_nt.ipynb` showcases how to infer on a 50kb sequence and plot the probabilities to reproduce Fig. 3 of the paper.
25 | 
26 | 🚧 The SegmentNT models do not handle any "N" in the input sequence, because every nucleotide needs to be part of a 6-mer token, which is not possible for sequences containing one or more "N" base pairs.
27 | 
28 | ```python
29 | import haiku as hk
30 | import jax
31 | import jax.numpy as jnp
32 | from nucleotide_transformer.pretrained import get_pretrained_segment_nt_model
33 | 
34 | # Initialize CPU as default JAX device. This makes the code robust to memory leakage on
35 | # the devices.
36 | jax.config.update("jax_platform_name", "cpu")
37 | 
38 | backend = "cpu"
39 | devices = jax.devices(backend)
40 | num_devices = len(devices)
41 | print(f"Devices found: {devices}")
42 | 
43 | # The number of DNA tokens (excluding the prepended CLS token) needs to be divisible by
44 | # 2 to the power of the number of downsampling blocks, i.e. 4.
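# Illustrative note (values not from the original example): for inference on sequences
# between 30 kbp and 50 kbp, the rescaling factor described above would be computed as,
# e.g. for a 40,008 bp input:
#   num_dna_tokens_inference = 40008 // 6 + 1   # 6669 tokens, including the CLS token
#   rescaling_factor = num_dna_tokens_inference / 2048
# and passed to get_pretrained_segment_nt_model(..., rescaling_factor=rescaling_factor).
# The short demo below keeps the default rescaling factor.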
45 | max_num_nucleotides = 8 46 | 47 | assert max_num_nucleotides % 4 == 0, ( 48 | "The number of DNA tokens (excluding the CLS token prepended) needs to be dividible by" 49 | "2 to the power of the number of downsampling block, i.e 4.") 50 | 51 | parameters, forward_fn, tokenizer, config = get_pretrained_segment_nt_model( 52 | model_name="segment_nt", 53 | embeddings_layers_to_save=(29,), 54 | attention_maps_to_save=((1, 4), (7, 10)), 55 | max_positions=max_num_nucleotides + 1, 56 | ) 57 | forward_fn = hk.transform(forward_fn) 58 | apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,)) 59 | 60 | 61 | # Get data and tokenize it 62 | sequences = ["ATTCCGATTCCGATTCCAACGGATTATTCCGATTAACCGATTCCAATT", "ATTTCTCTCTCTCTCTGAGATCGATGATTTCTCTCTCATCGAACTATG"] 63 | tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)] 64 | tokens = jnp.asarray(tokens_ids, dtype=jnp.int32) 65 | 66 | random_key = jax.random.PRNGKey(seed=0) 67 | keys = jax.device_put_replicated(random_key, devices=devices) 68 | parameters = jax.device_put_replicated(parameters, devices=devices) 69 | tokens = jax.device_put_replicated(tokens, devices=devices) 70 | 71 | # Infer on the sequence 72 | outs = apply_fn(parameters, keys, tokens) 73 | # Obtain the logits over the genomic features 74 | logits = outs["logits"] 75 | # Transform them in probabilities 76 | probabilities = jnp.asarray(jax.nn.softmax(logits, axis=-1))[...,-1] 77 | print(f"Probabilities shape: {probabilities.shape}") 78 | 79 | print(f"Features inferred: {config.features}") 80 | 81 | # Get probabilities associated with intron 82 | idx_intron = config.features.index("intron") 83 | probabilities_intron = probabilities[..., idx_intron] 84 | print(f"Intron probabilities shape: {probabilities_intron.shape}") 85 | ``` 86 | 87 | Supported model names are: 88 | - **segment_nt** 89 | - **segment_nt_multi_species** 90 | 91 | The code runs both on GPU and TPU thanks to Jax! 92 | 93 | --- 94 | ### SegmentEnformer 95 | 96 | SegmentEnformer leverages [Enformer](https://www.nature.com/articles/s41592-021-01252-x) by removing the prediction head and replacing it by a 1-dimensional U-Net segmentation head to predict the location of several types of genomics elements in a sequence at a single nucleotide resolution. 97 | 98 | 🔍 The notebook `../notebooks/segment_nt/inference_segment_enformer.ipynb` showcases how to infer on a 196,608bp sequence and plot the probabilities. 99 | 100 | ```python 101 | import haiku as hk 102 | import jax 103 | import jax.numpy as jnp 104 | import numpy as np 105 | 106 | from nucleotide_transformer.enformer.pretrained import get_pretrained_segment_enformer_model 107 | from nucleotide_transformer.enformer.features import FEATURES 108 | 109 | # Initialize CPU as default JAX device. This makes the code robust to memory leakage on 110 | # the devices. 
111 | jax.config.update("jax_platform_name", "cpu") 112 | 113 | backend = "cpu" 114 | devices = jax.devices(backend) 115 | num_devices = len(devices) 116 | 117 | # Load model 118 | parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_enformer_model() 119 | forward_fn = hk.transform_with_state(forward_fn) 120 | 121 | apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,)) 122 | random_key = jax.random.PRNGKey(seed=0) 123 | 124 | # Replicate over devices 125 | keys = jax.device_put_replicated(random_key, devices=devices) 126 | parameters = jax.device_put_replicated(parameters, devices=devices) 127 | state = jax.device_put_replicated(state, devices=devices) 128 | 129 | # Get data and tokenize it 130 | sequences = ["A" * 196_608] 131 | tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)] 132 | tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)] * num_devices, axis=0) 133 | 134 | # Infer 135 | outs, state = apply_fn(parameters, state, keys, tokens) 136 | 137 | # Obtain the logits over the genomic features 138 | logits = outs["logits"] 139 | # Transform them on probabilities 140 | probabilities = np.asarray(jax.nn.softmax(logits, axis=-1))[..., -1] 141 | 142 | # Get probabilities associated with intron 143 | idx_intron = FEATURES.index("intron") 144 | probabilities_intron = probabilities[..., idx_intron] 145 | print(f"Intron probabilities shape: {probabilities_intron.shape}") 146 | ``` 147 | 148 | ### SegmentBorzoi 149 | SegmentBorzoi leverages [Borzoi](https://www.nature.com/articles/s41588-024-02053-6) by removing the prediction head and replacing it by a 1-dimensional U-Net segmentation head to predict the location of several types of genomics elements in a sequence. 150 | 151 | 🔍 The notebook `../notebooks/segment_nt/inference_segment_borzoi.ipynb` showcases how to infer on a 196608bp sequence and plot the probabilities. 152 | 153 | ```python 154 | import haiku as hk 155 | import jax 156 | import jax.numpy as jnp 157 | import numpy as np 158 | 159 | from nucleotide_transformer.borzoi.pretrained import get_pretrained_segment_borzoi_model 160 | from nucleotide_transformer.enformer.features import FEATURES 161 | 162 | # Initialize CPU as default JAX device. This makes the code robust to memory leakage on 163 | # the devices. 
164 | jax.config.update("jax_platform_name", "cpu") 165 | 166 | backend = "cpu" 167 | devices = jax.devices(backend) 168 | num_devices = len(devices) 169 | 170 | # Load model 171 | parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_borzoi_model() 172 | forward_fn = hk.transform_with_state(forward_fn) 173 | apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,)) 174 | random_key = jax.random.PRNGKey(seed=0) 175 | 176 | # Replicate over devices 177 | keys = jax.device_put_replicated(random_key, devices=devices) 178 | parameters = jax.device_put_replicated(parameters, devices=devices) 179 | state = jax.device_put_replicated(state, devices=devices) 180 | 181 | # Get data and tokenize it 182 | sequences = ["A" * 524_288] 183 | tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)] 184 | tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)] * num_devices, axis=0) 185 | 186 | # Infer 187 | outs, state = apply_fn(parameters, state, keys, tokens) 188 | 189 | # Obtain the logits over the genomic features 190 | logits = outs["logits"] 191 | # Transform them on probabilities 192 | probabilities = np.asarray(jax.nn.softmax(logits, axis=-1))[..., -1] 193 | 194 | # Get probabilities associated with intron 195 | idx_intron = FEATURES.index("intron") 196 | probabilities_intron = probabilities[..., idx_intron] 197 | print(f"Intron probabilities shape: {probabilities_intron.shape}") 198 | 199 | ``` 200 | 201 | ## Citing our work 📚 202 | 203 | You can cite our models at: 204 | 205 | ```bibtex 206 | @Article{deAlmeida2025segmentNT, 207 | title={Annotating the genome at single-nucleotide resolution with DNA foundation models}, 208 | author={de Almeida, Bernardo P. and Dalla-Torre, Hugo and Richard, Guillaume and Blum, Christopher and Hexemer, Lorenz and G{\'e}lard, Maxence and Mendoza-Revilla, 209 | Javier and Tang, Ziqi and Marin, Frederikke I. and Emms, David M. and Pandey, Priyanka and Laurent, Stefan and Lopez, Marie and Laterre, Alexandre and Lang, Maren 210 | and {\c{S}}ahin, U{\u{g}}ur and Beguir, Karim and Pierrot, Thomas}, 211 | journal={Nature Methods}, 212 | year={2025}, 213 | month={Oct}, 214 | day={29}, 215 | issn={1548-7105}, 216 | doi={10.1038/s41592-025-02881-2}, 217 | url={https://doi.org/10.1038/s41592-025-02881-2} 218 | } 219 | ``` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | InstaDeep AI for Genomics Logo
5 | AI Foundation Models for Genomics
8 | A hub for InstaDeep's cutting-edge deep learning models and research for genomics, originating from the Nucleotide Transformer and its evolutions.
12 | License: CC BY-NC-SA 4.0
15 | Python 3.8
17 | Jax 0.3.25+
19 | Hugging Face Models
22 | 23 | --- 24 | 25 | ## 🎯 Our Focus: Advancing Genomics with AI 26 | 27 | Welcome to the InstaDeep AI for Genomics repository! This is where we feature our collection of transformer-based genomic language models and innovative downstream applications. Our work in the genomics space began with **The Nucleotide Transformer**, developed in collaboration with Nvidia and TUM and trained on Cambridge-1, and has expanded to include projects like the **Agro Nucleotide Transformer** (in collaboration with Google, trained on TPU-v4 accelerators), **SegmentNT**, and **ChatNT**. 28 | 29 | Our mission is to provide the scientific community with powerful, reproducible, and accessible tools to unlock new insights from biological sequences. This repository serves as the central place for sharing our models, inference code, pre-trained weights, and research contributions in the genomics domain, with explorations into future areas like single-cell transcriptomics. 30 | 31 | We are thrilled to open-source these works and provide the community with access to the code and pre-trained weights for our diverse set of genomics language models and segmentation models. 32 | 33 | ## ✨ Featured Models & Research Evolutions 34 | 35 | This section highlights the key models and research directions from our team. Each entry provides a brief overview and links to detailed documentation, publications, and resources. *(Detailed code examples, setup for specific models, and in-depth figures are now located in their respective documentation pages within the `./docs` folder.)* 36 | 37 | --- 38 | 39 | ### 🧬 The Nucleotide Transformer (NT) 40 | 41 | Our foundational language models leverage DNA sequences from over 3,200 diverse human genomes and 850 genomes from a wide range of species. These models provide extremely accurate molecular phenotype prediction compared to existing methods. *This family includes multiple variants (e.g., 500M_human_ref, 2B5_1000G, NT-v2 series) which are detailed further in the specific documentation.* 42 | 43 | * **Keywords:** Foundational Model, Genomics, DNA/RNA, Pre-trained, Sequence Embeddings, Phenotype Prediction 44 | * ➡️ **[Model Details, Variants & Usage](./docs/nucleotide_transformer.md)** 45 | * 📜 **[Read the Paper (Nature Methods 2025)](https://www.nature.com/articles/s41592-024-02523-z)** 46 | * 🤗 **[Hugging Face Collection](https://huggingface.co/collections/InstaDeepAI/nucleotide-transformer-65099cdde13ff96230f2e592)** 47 | * 🚀 **Fine-tuning Notebooks (HF): ([LoRA](https://github.com/huggingface/notebooks/blob/main/examples/nucleotide_transformer_dna_sequence_modelling_with_peft.ipynb) and [regular](https://github.com/huggingface/notebooks/blob/main/examples/nucleotide_transformer_dna_sequence_modelling.ipynb))** 48 | 49 | --- 50 | 51 | ### 🌾 Agro Nucleotide Transformer (AgroNT) 52 | 53 | A novel foundational large language model trained on reference genomes from 48 plant species, with a predominant focus on crop species. AgroNT demonstrates state-of-the-art performance across several prediction tasks spanning regulatory features, RNA processing, and gene expression in plants.
54 | 55 | * **Keywords:** Plant Genomics, Foundational Model, Crop Science, Gene Expression, Agriculture AI 56 | * ➡️ **[Model Details & Usage](./docs/agro_nucleotide_transformer.md)** 57 | * 📜 **[Read the Paper (Communications Biology 2024)](https://www.nature.com/articles/s42003-024-06465-2)** 58 | * 🤗 **[Hugging Face Collection](https://huggingface.co/collections/InstaDeepAI/agro-nucleotide-transformer-65b25c077cd0069ad6f6d344)** 59 | 60 | --- 61 | 62 | ### 🧩 SegmentNT (& family: SegmentEnformer, SegmentBorzoi) 63 | 64 | Segmentation models using transformer backbones (Nucleotide Transformers, Enformer, Borzoi) for predicting genomic elements at single-nucleotide resolution. SegmentNT, for instance, predicts 14 different classes of human genomic elements in sequences up to 30kb (generalizing to 50kbp) and demonstrates superior performance. 65 | 66 | * **Keywords:** Genome Segmentation, Single-Nucleotide Resolution, Genomic Elements, U-Net, Enformer, Borzoi 67 | * ➡️ **[Model Details & Usage](./docs/segment_nt.md)** (Covers SegmentNT, SegmentEnformer, SegmentBorzoi) 68 | * 📜 **[Read the Paper (Nature Methods 2025)](https://www.nature.com/articles/s41592-025-02881-2)** 69 | * 🤗 **[Hugging Face Collection](https://huggingface.co/collections/InstaDeepAI/segmentnt-65eb4941c57808b4a3fe1319)** 70 | * 🚀 **[SegmentNT Inference Notebook (HF)](https://colab.research.google.com/#fileId=https%3A//huggingface.co/InstaDeepAI/segment_nt/blob/main/inference_segment_nt.ipynb)** 71 | 72 | --- 73 | 74 | ### 💬 ChatNT 75 | 76 | A multimodal conversational agent designed with a deep understanding of DNA biological sequences, enabling interactive exploration and analysis of genomic data through natural language. 77 | 78 | * **Keywords:** Conversational AI, Multimodal, DNA Analysis, Genomics Chatbot, Interactive Biology 79 | * ➡️ **[Model Details & Usage](./docs/chat_nt.md)** 80 | * 📜 **[Read the Paper (Nature Machine Intelligence 2025)](https://www.nature.com/articles/s42256-025-01047-1)** 81 | * 🤗 **[ChatNT on Hugging Face](https://huggingface.co/InstaDeepAI/ChatNT)** 82 | * 🚀 **[ChatNT Inference Notebook (Jax)](./notebooks/chat_nt/inference.ipynb)** 83 | 84 | --- 85 | 86 | ### 3️⃣ Codon-NT (Exploring 3-mer Tokenization) 87 | 88 | A Nucleotide Transformer model variant trained on 3-mers (codons). This work investigates alternative tokenization strategies for genomic language models and their impact on downstream performance and interpretability. 89 | 90 | * **Keywords:** Genomics, Language Model, Codon, Tokenization, 3-mers, Nucleotide Transformer Variant 91 | * ➡️ **[Model Details & Usage](./docs/codon_nt.md)** 92 | * 📜 **[Read the Paper (Bioinformatics 2024)](https://academic.oup.com/bioinformatics/article/40/9/btae529/7745814)** 93 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/nucleotide-transformer-v2-50m-3mer-multi-species)** 94 | 95 | --- 96 | 97 | ### 🧬 Isoformer 98 | 99 | A model designed for learning isoform-aware embeddings directly from RNA-seq data, enabling a deeper understanding of transcript-specific expression and regulation. 
100 | 101 | * **Keywords:** RNA-seq, Transcriptomics, Isoforms, Gene Expression, Embeddings 102 | * ➡️ **[Model Details & Usage](./docs/isoformer.md)** 103 | * 📜 **[Read the Paper (NeurIPS 2024)](https://papers.nips.cc/paper_files/paper/2024/file/8f6b3692297e49e5d5c91ba00281379c-Paper-Conference.pdf)** 104 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/isoformer)** 105 | * 🚀 **[Isoformer Inference Notebook (HF)](./notebooks/isoformer/inference.ipynb)** 106 | 107 | --- 108 | 109 | ### 🔬 sCT (single-Cell Transformer) 110 | 111 | Our foundational transformer model for single-cell and spatial transcriptomics data. sCT aims to learn rich representations from complex, high-dimensional single-cell datasets to improve various downstream analytical tasks. 112 | 113 | * **Keywords:** Single-cell RNA-seq, Spatial Transcriptomics, Foundational Model, Transformer, Gene Expression 114 | * ➡️ **[Model Details & Usage](./docs/sct.md)** 115 | * 📜 **[Read the Paper (OpenReview preprint)](https://openreview.net/forum?id=VdX9tL3VXH)** 116 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/sCellTransformer)** 117 | * 🚀 **[sCT Inference Notebook (HF)](./notebooks/sct/inference_sCT_pytorch_example.ipynb)** 118 | 119 | --- 120 | 121 | ### 🧪 BulkRNABert 122 | 123 | BulkRNABert is a transformer-based, encoder-only foundation model designed for bulk RNA-seq data. It learns biologically meaningful representations from large-scale transcriptomic profiles. 124 | 125 | * **Keywords:** Bulk RNA-seq, Foundational Model, Transformer, Cancer prognosis 126 | * ➡️ **[Model Details & Usage](./docs/bulk_rna_bert.md)** 127 | * 📜 **[Read the Paper (Machine Learning for Health 2024)](https://proceedings.mlr.press/v259/gelard25a.html)** 128 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/BulkRNABert)** 129 | * 🚀 **[BulkRNABert Inference Notebook (HF)](notebooks/bulk_rna_bert/inference_bulkrnabert_pytorch_example.ipynb)** 130 | 131 | --- 132 | 133 | ### 🔗 MOJO (Multi-Omics JOint representation) 134 | 135 | MOJO is a multimodal model designed to learn embeddings of multi-omics data. It integrates bulk RNA-seq and DNA methylation data to generate powerful joint representations tailored for cancer-type classification and survival analysis. 136 | 137 | * **Keywords:** Bulk RNA-seq, DNA Methylation, Foundational Model, Transformer, Multimodal, Cancer prognosis 138 | * ➡️ **[Model Details & Usage](./docs/mojo.md)** 139 | * 📜 **[Read the Paper (ICML Workshop on Generative AI and Biology 2025)](https://www.biorxiv.org/content/10.1101/2025.06.25.661237v1)** 140 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/MOJO)** 141 | * 🚀 **[MOJO Inference Notebook (HF)](./notebooks/mojo/inference_mojo_pytorch_example.ipynb)** 142 | 143 | --- 144 | 145 | ## 💡 Why Choose InstaDeep's Genomic Models? 146 | 147 | * **Built on Strong Foundations:** Leveraging large-scale pre-training and diverse genomic datasets. 148 | * **Cutting-Edge Research:** Incorporating the latest advancements in deep learning for biological sequence analysis. 149 | * **High Performance:** Designed and validated to achieve state-of-the-art results on challenging genomic tasks. 150 | * **Open and Accessible:** We provide pre-trained weights, usage examples, and aim for easy integration into research workflows. 151 | * **Collaborative Spirit:** Developed with leading academic and industry partners. 152 | * **Focused Expertise:** Created by a dedicated team specializing in AI for genomics at InstaDeep.
153 | 154 | ## 🚀 Getting Started 155 | 156 | To begin using models from this repository: 157 | 158 | 1. **Clone the repository:** 159 | ```bash 160 | git clone https://github.com/instadeepai/nucleotide-transformer.git 161 | cd nucleotide-transformer 162 | ``` 163 | 2. **Set up your environment (virtual environment recommended):** 164 | ```bash 165 | python -m venv .venv 166 | source .venv/bin/activate # On Windows use `.venv\Scripts\activate` 167 | ``` 168 | 3. **Install the package and dependencies:** 169 | ```bash 170 | pip install . # Installs the local package 171 | # Or, for a general requirements file if you have one: 172 | # pip install -r requirements.txt 173 | ``` 174 | 175 | For detailed instructions on individual models, including specific dependencies, downloading pre-trained weights, and Python usage examples, please refer to their dedicated documentation pages linked in the "Featured Models & Research Evolutions" section above (e.g., `./docs/nucleotide_transformer.md`). 176 | 177 | ## 🤝 Community & Support 178 | 179 | * **Questions & Bug Reports:** Please use the [GitHub Issues](https://github.com/instadeepai/nucleotide-transformer/issues) page. 180 | * **Discussions:** For broader discussions or questions, please use the [GitHub Discussions](https://github.com/instadeepai/nucleotide-transformer/discussions) tab (if enabled). 181 | * **Stay Updated:** Follow InstaDeep's official channels for announcements on new model releases and research updates. 182 | -------------------------------------------------------------------------------- /notebooks/chat_nt/inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "edrfY09jfn32" 7 | }, 8 | "source": [ 9 | "# Inference with ChatNT" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "SOC2A0oIfn36" 16 | }, 17 | "source": [ 18 | "[![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/nucleotide-transformer/blob/main/examples/inference_chatNT.ipynb)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "SWffCMcBfn37" 25 | }, 26 | "source": [ 27 | "## Installation and imports" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": { 34 | "colab": { 35 | "base_uri": "https://localhost:8080/" 36 | }, 37 | "id": "BtaCigg-fn37", 38 | "outputId": "555b9ef3-f72c-4957-a535-b9e4f53307c2" 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "!pip install boto3\n", 43 | "!pip install matplotlib\n", 44 | "!pip install biopython\n", 45 | "!pip install dm-haiku" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "metadata": { 52 | "id": "alzkIxk9fn38" 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "import os\n", 57 | "\n", 58 | "try:\n", 59 | " import nucleotide_transformer\n", 60 | "except:\n", 61 | " !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1\n", 62 | " import nucleotide_transformer\n", 63 | "\n", 64 | "if \"COLAB_TPU_ADDR\" in os.environ:\n", 65 | " from jax.tools import colab_tpu\n", 66 | "\n", 67 | " colab_tpu.setup_tpu()" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": { 74 | "colab": { 75 | "base_uri": "https://localhost:8080/" 76 | }, 77 | "id": "zkTU4k4_fn39", 78 | "outputId": "a04ca440-be95-49e1-b683-bf5b70d00777" 79 |
}, 80 | "outputs": [ 81 | { 82 | "name": "stdout", 83 | "output_type": "stream", 84 | "text": [ 85 | "Devices found: [CpuDevice(id=0)]\n" 86 | ] 87 | } 88 | ], 89 | "source": [ 90 | "from Bio import SeqIO\n", 91 | "import gzip\n", 92 | "import haiku as hk\n", 93 | "import jax\n", 94 | "import jax.numpy as jnp\n", 95 | "import numpy as np\n", 96 | "import seaborn as sns\n", 97 | "from typing import List\n", 98 | "import matplotlib.pyplot as plt\n", 99 | "from tqdm import tqdm\n", 100 | "from nucleotide_transformer.chatNT.pretrained import get_chatNT\n", 101 | "\n", 102 | "jax.config.update(\"jax_platform_name\", \"cpu\")" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "# Specify your backend device" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "# Use either \"cpu\", \"gpu\" or \"tpu\"\n", 119 | "backend = \"cpu\"" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "devices = jax.devices(backend)\n", 129 | "num_devices = len(devices)\n", 130 | "print(f\"Devices found: {devices}\")" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "# Define function to generate answer" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 4, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "def generate_answer(apply_fn, parameters, random_keys, english_tokenizer, english_tokens, bio_tokens, max_num_tokens_to_decode):\n", 147 | " \"\"\"\n", 148 | " Note: the function expects that pmap is already applied to the forward function, the inputs and the parameters\n", 149 | " \"\"\"\n", 150 | " english_tokens = english_tokens.copy()\n", 151 | "\n", 152 | " idx_begin_generation = np.where(\n", 153 | " english_tokens[0, 0] == english_tokenizer.pad_token_id\n", 154 | " )[0][0]\n", 155 | " projected_bio_embeddings = jax.device_put_replicated(None, devices=devices)\n", 156 | " actual_nb_steps = 0\n", 157 | "\n", 158 | " for _ in tqdm(range(max_num_tokens_to_decode)):\n", 159 | " outs = apply_fn(\n", 160 | " parameters,\n", 161 | " random_keys,\n", 162 | " multi_omics_tokens_ids=(english_tokens, bio_tokens),\n", 163 | " projection_english_tokens_ids=english_tokens,\n", 164 | " projected_bio_embeddings=projected_bio_embeddings,\n", 165 | " )\n", 166 | " projected_bio_embeddings = outs[\"projected_bio_embeddings\"]\n", 167 | " logits = outs[\"logits\"]\n", 168 | "\n", 169 | " first_idx_pad_token = np.where(\n", 170 | " english_tokens[0, 0] == english_tokenizer.pad_token_id\n", 171 | " )[0][0]\n", 172 | " predicted_token = np.argmax(logits[0, 0, first_idx_pad_token - 1])\n", 173 | "\n", 174 | " if predicted_token == english_tokenizer.eos_token_id:\n", 175 | " break\n", 176 | " else:\n", 177 | " english_tokens = english_tokens.at[0, 0, first_idx_pad_token].set(\n", 178 | " predicted_token\n", 179 | " )\n", 180 | " actual_nb_steps += 1\n", 181 | "\n", 182 | " decoded_generated_sentence = english_tokenizer.decode(\n", 183 | " english_tokens[0, 0, idx_begin_generation : idx_begin_generation + actual_nb_steps]\n", 184 | " )\n", 185 | "\n", 186 | " return decoded_generated_sentence" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "# Load model" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 5, 199 | 
"metadata": {}, 200 | "outputs": [ 201 | { 202 | "name": "stdout", 203 | "output_type": "stream", 204 | "text": [ 205 | "Downloading model's weights...\n" 206 | ] 207 | } 208 | ], 209 | "source": [ 210 | "forward_fn, parameters, english_tokenizer, bio_tokenizer = get_chatNT()" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 8, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "forward_fn = hk.transform(forward_fn)\n", 220 | "apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))\n", 221 | "\n", 222 | "# Put required quantities for the inference on the devices. This step is not\n", 223 | "# reproduced in the second inference since the quantities will already be loaded\n", 224 | "# on the devices !\n", 225 | "random_key = jax.random.PRNGKey(seed=0)\n", 226 | "random_keys = jax.numpy.stack([random_key for _ in range(len(devices))])\n", 227 | "keys = jax.device_put_replicated(random_key, devices=devices)\n", 228 | "parameters = jax.device_put_replicated(parameters, devices=devices)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "# Define prompt" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 9, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "# Define custom inputs (note that the number of token in the english sequence must be equal to len(dna_sequences))\n", 245 | "english_sequence = \"A chat between a curious user and an artificial intelligence assistant that can handle bio sequences. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Is there any evidence of an acceptor splice site in this sequence ? ASSISTANT:\"\n", 246 | "dna_sequences = [\"ATCGGAAAAAGATCCAGAAAGTTATACCAGGCCAATGGGAATCACCTATTACGTGGATAATAGCGATAGTATGTTACCTATAAATTTAACTACGTGGATATCAGGCAGTTACGTTACCAGTCAAGGAGCACCCAAAACTGTCCAGCAACAAGTTAATTTACCCATGAAGATGTACTGCAAGCCTTGCCAACCAGTTAAAGTAGCTACTCATAAGGTAATAAACAGTAATATCGACTTTTTATCCATTTTGATAATTGATTTATAACAGTCTATAACTGATCGCTCTACATAATCTCTATCAGATTACTATTGACACAAACAGAAACCCCGTTAATTTGTATGATATATTTCCCGGTAAGCTTCGATTTTTAATCCTATCGTGACAATTTGGAATGTAACTTATTTCGTATAGGATAAACTAATTTACACGTTTGAATTCCTAGAATATGGAGAATCTAAAGGTCCTGGCAATGCCATCGGCTTTCAATATTATAATGGACCAAAAGTTACTCTATTAGCTTCCAAAACTTCGCGTGAGTACATTAGAACAGAAGAATAACCTTCAATATCGAGAGAGTTACTATCACTAACTATCCTATG\"]" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "# Tokenize" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 10, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "english_max_length = 512 # length of the tokenized english sequence\n", 263 | "bio_tokenized_sequence_length = 512 # length of the tokenized DNA sequences\n", 264 | "\n", 265 | "english_tokens = english_tokenizer(\n", 266 | " [english_sequence],\n", 267 | " return_tensors=\"np\",\n", 268 | " max_length=english_max_length,\n", 269 | " padding=\"max_length\",\n", 270 | " truncation=True,\n", 271 | ").input_ids\n", 272 | "\n", 273 | "bio_tokens = bio_tokenizer(\n", 274 | " dna_sequences,\n", 275 | " return_tensors=\"np\",\n", 276 | " padding=\"max_length\",\n", 277 | " max_length=bio_tokenized_sequence_length,\n", 278 | " truncation=True,\n", 279 | ").input_ids\n", 280 | "bio_tokens = np.expand_dims(bio_tokens, axis=0) # Add batch dimension -> result: (1, num_dna_sequences, bio_tokenized_sequence_length)\n", 281 | "\n", 282 | "\n", 283 | "# Replicate over devices\n", 284 | "english_tokens = 
jnp.stack([jnp.asarray(english_tokens, dtype=jnp.int32)]*num_devices, axis=0)\n", 285 | "bio_tokens = jnp.stack([jnp.asarray(bio_tokens, dtype=jnp.int32)]*num_devices, axis=0)" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": { 291 | "id": "yLOOWXYluU7p" 292 | }, 293 | "source": [ 294 | "## Inference" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "generated_answer = generate_answer(\n", 304 | " apply_fn=apply_fn,\n", 305 | " parameters=parameters,\n", 306 | " random_keys=random_keys,\n", 307 | " english_tokenizer=english_tokenizer,\n", 308 | " english_tokens=english_tokens,\n", 309 | " bio_tokens=bio_tokens,\n", 310 | " max_num_tokens_to_decode=20\n", 311 | ")" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 12, 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "name": "stdout", 321 | "output_type": "stream", 322 | "text": [ 323 | "Yes, an acceptor splice site is present in this nucleotide sequence.\n" 324 | ] 325 | } 326 | ], 327 | "source": [ 328 | "print(generated_answer)" 329 | ] 330 | } 331 | ], 332 | "metadata": { 333 | "accelerator": "GPU", 334 | "colab": { 335 | "gpuType": "T4", 336 | "provenance": [] 337 | }, 338 | "kernelspec": { 339 | "display_name": "genomics-research-env", 340 | "language": "python", 341 | "name": "python3" 342 | }, 343 | "language_info": { 344 | "codemirror_mode": { 345 | "name": "ipython", 346 | "version": 3 347 | }, 348 | "file_extension": ".py", 349 | "mimetype": "text/x-python", 350 | "name": "python", 351 | "nbconvert_exporter": "python", 352 | "pygments_lexer": "ipython3", 353 | "version": "3.11.10" 354 | } 355 | }, 356 | "nbformat": 4, 357 | "nbformat_minor": 0 358 | } 359 | -------------------------------------------------------------------------------- /nucleotide_transformer/enformer/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional 3 | 4 | import haiku as hk 5 | import jax 6 | from einops import rearrange 7 | from jax import numpy as jnp 8 | 9 | from nucleotide_transformer.borzoi.layers import get_positional_embed_borzoi 10 | 11 | 12 | def get_positional_features_exponential( 13 | positions: jnp.ndarray, features: int, seq_len: int, min_half_life: float = 3.0 14 | ) -> jnp.ndarray: 15 | """ 16 | Exponential features positional embeddings. 17 | 18 | Args: 19 | positions: positions. 20 | features: number of features. 21 | seq_len: sequence length. 22 | min_half_life: minimum half life. 23 | 24 | Returns: 25 | Positional embeddings. 26 | """ 27 | max_range = math.log(seq_len) / math.log(2.0) 28 | half_life = 2 ** jnp.linspace(min_half_life, max_range, features) 29 | half_life = half_life[None, ...] 30 | positions = jnp.abs(positions)[..., None] 31 | return jnp.exp(-math.log(2.0) / half_life * positions) 32 | 33 | 34 | def get_positional_features_central_mask( 35 | positions: jnp.ndarray, features: int 36 | ) -> jnp.ndarray: 37 | """ 38 | Exponential features central mask. 39 | 40 | Args: 41 | positions: positions. 42 | features: number of features. 43 | 44 | Returns: 45 | Positional embeddings. 46 | """ 47 | center_widths = 2 ** jnp.arange(1, features + 1).astype(jnp.float32) 48 | center_widths = center_widths - 1 49 | return (center_widths[None, ...] 
> jnp.abs(positions)[..., None]).astype( 50 | jnp.float32 51 | ) 52 | 53 | 54 | def gamma_pdf( 55 | x: jnp.ndarray, concentration: jnp.ndarray, rate: jnp.ndarray 56 | ) -> jnp.ndarray: 57 | """ 58 | Gamma law PDF function. 59 | 60 | Args: 61 | x: input tensor. 62 | concentration: gamma concentration. 63 | rate: gamma rate. 64 | 65 | Returns: 66 | gamma pdf value. 67 | """ 68 | log_unnormalized_prob = jax.scipy.special.xlogy(concentration - 1.0, x) - rate * x 69 | log_normalization = jax.lax.lgamma(concentration) - concentration * jnp.log(rate) 70 | return jnp.exp(log_unnormalized_prob - log_normalization) 71 | 72 | 73 | def get_positional_features_gamma( 74 | positions: jnp.ndarray, 75 | features: int, 76 | seq_len: int, 77 | stddev: Optional[float] = None, 78 | start_mean: Optional[float] = None, 79 | eps: float = 1e-8, 80 | ) -> jnp.ndarray: 81 | """ 82 | Get Gamma positional features. 83 | """ 84 | if stddev is None: 85 | stddev = seq_len / (2 * features) 86 | 87 | if start_mean is None: 88 | start_mean = seq_len / features 89 | 90 | mean = jnp.linspace(start_mean, seq_len, features) 91 | mean = mean[None, ...] 92 | concentration = (mean / stddev) ** 2 93 | rate = mean / stddev**2 94 | probabilities = gamma_pdf( 95 | jnp.abs(positions.astype(jnp.float32))[..., None], concentration, rate 96 | ) 97 | probabilities = probabilities + eps 98 | outputs = probabilities / jnp.amax(probabilities, axis=-1, keepdims=True) 99 | return outputs 100 | 101 | 102 | def get_positional_embed_enformer(seq_len: int, feature_size: int) -> jnp.ndarray: 103 | """ 104 | Compute positional embedding. 105 | """ 106 | distances = jnp.arange(-seq_len + 1, seq_len) 107 | 108 | feature_functions = [ 109 | get_positional_features_exponential, 110 | get_positional_features_central_mask, 111 | get_positional_features_gamma, 112 | ] 113 | 114 | num_components = len(feature_functions) * 2 115 | 116 | if (feature_size % num_components) != 0: 117 | raise ValueError( 118 | f"feature size is not divisible by number of components ({num_components})" 119 | ) 120 | 121 | num_basis_per_class = feature_size // num_components 122 | 123 | embeddings = [] 124 | embeddings.append( 125 | get_positional_features_exponential(distances, num_basis_per_class, seq_len) 126 | ) 127 | embeddings.append( 128 | get_positional_features_central_mask(distances, num_basis_per_class) 129 | ) 130 | embeddings.append( 131 | get_positional_features_gamma(distances, num_basis_per_class, seq_len) 132 | ) 133 | 134 | embeddings = jnp.concatenate(embeddings, axis=-1) 135 | embeddings = jnp.concatenate( 136 | (embeddings, jnp.sign(distances)[..., None] * embeddings), axis=-1 137 | ) 138 | return embeddings 139 | 140 | 141 | def relative_shift(x: jnp.ndarray) -> jnp.ndarray: 142 | """ 143 | Apply relative shift. 144 | """ 145 | to_pad = jnp.zeros_like(x[..., :1]) 146 | x = jnp.concatenate((to_pad, x), axis=-1) 147 | _, h, t1, t2 = x.shape 148 | x = x.reshape(-1, h, t2, t1) 149 | x = x[:, :, 1:, :] 150 | x = x.reshape(-1, h, t1, t2 - 1) 151 | return x[..., : ((t2 + 1) // 2)] 152 | 153 | 154 | def exponential_linspace_int( 155 | start: int, end: int, num: int, divisible_by: int = 1 156 | ) -> List[int]: 157 | """ 158 | Create list of dimensions to construct the Enformer model. 
159 | """ 160 | 161 | def _round(x: float) -> int: 162 | return int(round(x / divisible_by) * divisible_by) 163 | 164 | base = math.exp(math.log(end / start) / (num - 1)) 165 | return [_round(start * base**i) for i in range(num)] 166 | 167 | 168 | def gelu_fn(x: jnp.ndarray) -> jnp.ndarray: 169 | """ 170 | Custom GELU activation function. 171 | """ 172 | return jax.nn.sigmoid(1.702 * x) * x 173 | 174 | 175 | class Attention(hk.Module): 176 | """ 177 | Enformer Attention Layer. 178 | """ 179 | 180 | def __init__( 181 | self, 182 | dim: int, 183 | *, 184 | num_rel_pos_features: int, 185 | heads: int = 8, 186 | dim_key: int = 64, 187 | dim_value: int = 64, 188 | positional_encoding_type: str = "enformer", 189 | name: Optional[str] = None, 190 | ): 191 | super().__init__(name=name) 192 | self.scale = dim_key**-0.5 193 | self.heads = heads 194 | 195 | self.to_q = hk.Linear( 196 | output_size=dim_key * heads, with_bias=False, name="attn_q" 197 | ) 198 | self.to_k = hk.Linear( 199 | output_size=dim_key * heads, with_bias=False, name="attn_k" 200 | ) 201 | self.to_v = hk.Linear( 202 | output_size=dim_value * heads, with_bias=False, name="attn_v" 203 | ) 204 | 205 | self.to_out = hk.Linear(output_size=dim, name="attn_o") 206 | 207 | # relative positional encoding 208 | self.num_rel_pos_features = num_rel_pos_features 209 | 210 | self.to_rel_k = hk.Linear( 211 | output_size=dim_key * heads, with_bias=False, name="attn_to_rel_k" 212 | ) 213 | 214 | w_init = hk.initializers.RandomNormal() 215 | self.rel_content_bias = hk.get_parameter( 216 | "rel_content_bias", 217 | shape=(1, heads, 1, dim_key), 218 | init=w_init, 219 | ) 220 | w_init = hk.initializers.RandomNormal() 221 | self.rel_pos_bias = hk.get_parameter( 222 | "rel_pos_bias", 223 | shape=(1, heads, 1, dim_key), 224 | init=w_init, 225 | ) 226 | 227 | # Set the way the position encodings are computed 228 | if positional_encoding_type == "enformer": 229 | self.get_positional_embed = get_positional_embed_enformer 230 | elif positional_encoding_type == "borzoi": 231 | self.get_positional_embed = get_positional_embed_borzoi 232 | 233 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 234 | n = x.shape[-2] 235 | 236 | q = self.to_q(x) 237 | k = self.to_k(x) 238 | v = self.to_v(x) 239 | 240 | def _rearrange(x: jnp.ndarray) -> jnp.ndarray: 241 | return rearrange(x, "b n (h d) -> b h n d", h=self.heads) 242 | 243 | q = _rearrange(q) 244 | k = _rearrange(k) 245 | v = _rearrange(v) 246 | 247 | q = q * self.scale 248 | 249 | content_logits = jnp.einsum( 250 | "b h i d, b h j d -> b h i j", q + self.rel_content_bias, k 251 | ) 252 | 253 | positions = self.get_positional_embed(n, self.num_rel_pos_features) 254 | rel_k = self.to_rel_k(positions) 255 | 256 | def _rearrange_1(x: jnp.ndarray) -> jnp.ndarray: 257 | return rearrange(x, "n (h d) -> h n d", h=self.heads) 258 | 259 | rel_k = _rearrange_1(rel_k) 260 | rel_logits = jnp.einsum( 261 | "b h i d, h j d -> b h i j", q + self.rel_pos_bias, rel_k 262 | ) 263 | rel_logits = relative_shift(rel_logits) 264 | 265 | logits = content_logits + rel_logits 266 | attn = jax.nn.softmax(logits, axis=-1) 267 | 268 | out = jnp.einsum("b h i j, b h j d -> b h i d", attn, v) 269 | 270 | def _rearrange_2(x: jnp.ndarray) -> jnp.ndarray: 271 | return rearrange(x, "b h n d -> b n (h d)") 272 | 273 | out = _rearrange_2(out) 274 | 275 | return self.to_out(out) 276 | 277 | 278 | class AttentionPool(hk.Module): 279 | """Enformer Attention Pooling layer.""" 280 | 281 | def __init__( 282 | self, 283 | dim: int, 284 | pool_size: int = 
2, 285 | name: Optional[str] = None, 286 | ): 287 | """ 288 | Args: 289 | dim: input dimension. 290 | pool_size: pooling size. 291 | name: model's name. 292 | """ 293 | super().__init__(name=name) 294 | self._dim = dim 295 | self._pool_size = pool_size 296 | 297 | self._to_attn_logits = hk.Conv2D( 298 | output_channels=dim, kernel_shape=1, with_bias=False, data_format="NCHW" 299 | ) 300 | 301 | def _pool_fn(self, x: jnp.ndarray) -> jnp.ndarray: 302 | b, d, n = x.shape 303 | x = jnp.reshape(x, (b, d, n // self._pool_size, self._pool_size)) 304 | return x 305 | 306 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 307 | 308 | b, d, n = x.shape 309 | remainder = n % self._pool_size 310 | needs_padding = remainder > 0 311 | 312 | if needs_padding: 313 | x = jnp.concatenate( 314 | [x, jnp.zeros(shape=(b, d, remainder), dtype=x.dtype)], axis=-1 315 | ) 316 | mask = jnp.zeros((b, 1, n), dtype=jnp.bool_) 317 | mask = jnp.concatenate( 318 | [mask, jnp.ones(shape=(b, 1, remainder), dtype=jnp.bool_)], axis=-1 319 | ) 320 | 321 | x = self._pool_fn(x) 322 | logits = self._to_attn_logits(x) 323 | 324 | if needs_padding: 325 | mask_value = -jnp.inf 326 | logits = jnp.where( 327 | self._pool_fn(mask), 328 | mask_value * jnp.ones_like(logits, dtype=logits.dtype), 329 | logits, 330 | ) 331 | 332 | attn = jax.nn.softmax(logits, axis=-1) 333 | return jnp.sum((x * attn), axis=-1) 334 | 335 | 336 | class ResidualConvBlock(hk.Module): 337 | """ 338 | Conv Block with Residual connection. 339 | """ 340 | 341 | def __init__( 342 | self, 343 | dim: int, 344 | dim_out: Optional[int] = None, 345 | kernel_size: int = 1, 346 | name: Optional[str] = None, 347 | ): 348 | """ 349 | Args: 350 | dim: input dimension. 351 | dim_out: output dimension. 352 | kernel_size: kernel's size. 353 | name: model's name. 354 | """ 355 | super().__init__(name=name) 356 | self._dim = dim 357 | self._dim_out = dim_out 358 | self._kernel_size = kernel_size 359 | 360 | def __call__(self, x: jnp.ndarray, is_training: bool = False) -> jnp.ndarray: 361 | conv_block = ConvBlock( 362 | dim=self._dim, dim_out=self._dim_out, kernel_size=self._kernel_size 363 | ) 364 | return x + conv_block(x, is_training) 365 | 366 | 367 | class ConvBlock(hk.Module): 368 | """ 369 | Conv Block. 370 | """ 371 | 372 | def __init__( 373 | self, 374 | dim: int, 375 | dim_out: Optional[int] = None, 376 | kernel_size: int = 1, 377 | name: Optional[str] = None, 378 | ): 379 | """ 380 | Args: 381 | dim: input dimension. 382 | dim_out: output dimension. 383 | kernel_size: kernel's size. 384 | name: model's name. 385 | """ 386 | super().__init__(name=name) 387 | self._dim = dim 388 | self._dim_out = dim_out 389 | self._kernel_size = kernel_size 390 | 391 | def __call__(self, x: jnp.ndarray, is_training: bool = False) -> jnp.ndarray: 392 | batch_norm = hk.BatchNorm( 393 | create_scale=True, 394 | create_offset=True, 395 | decay_rate=0.9, 396 | data_format="NC...", 397 | ) 398 | conv = hk.Conv1D( 399 | output_channels=self._dim if self._dim_out is None else self._dim_out, 400 | kernel_shape=self._kernel_size, 401 | padding=(self._kernel_size // 2, self._kernel_size // 2), 402 | data_format="NCW", 403 | ) 404 | 405 | x = batch_norm(x, is_training=is_training) 406 | x = gelu_fn(x) 407 | x = conv(x) 408 | return x 409 | --------------------------------------------------------------------------------
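Usage sketch: the positional-embedding helpers above are easier to follow with concrete shapes. The snippet below is a minimal illustration (not one of the repository files) that assumes `nucleotide_transformer` and its dependencies are installed; the argument values (`start=768`, `end=1536`, `seq_len=16`, `feature_size=12`) are only examples.

```python
# Minimal sketch exercising two self-contained helpers from
# nucleotide_transformer/enformer/layers.py (assumes the package is installed).
from nucleotide_transformer.enformer.layers import (
    exponential_linspace_int,
    get_positional_embed_enformer,
)

# Exponentially spaced channel sizes, rounded to multiples of 128, of the kind
# used to size an Enformer-style convolutional tower.
dims = exponential_linspace_int(start=768, end=1536, num=6, divisible_by=128)
print(dims)  # [768, 896, 1024, 1152, 1280, 1536]

# Relative positional embeddings: one row per relative distance in
# [-(seq_len - 1), ..., seq_len - 1] (2 * seq_len - 1 rows) and `feature_size`
# columns. `feature_size` must be divisible by 6: three basis families
# (exponential, central mask, gamma), each mirrored by the sign of the distance.
pos = get_positional_embed_enformer(seq_len=16, feature_size=12)
print(pos.shape)  # (31, 12)
```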