12 |
13 | # Training data
14 |
15 | Isoformer is trained on RNA transcript expression data obtained from the GTEx portal,
16 | namely transcript TPM measurements across 30 tissues collected from more than 5,000 individuals.
17 | In total, the dataset comprises ∼170k unique transcripts, of which 90k are protein-coding and correspond to ∼20k unique genes.
18 |
19 | ## How to use 🚀
20 |
21 | We make Isoformer available on HuggingFace and provide an example of how to use it in `./notebooks/isoformer/inference.ipynb`.
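If you want to load the model outside the notebook, the snippet below is a minimal sketch using the `transformers` library; the repository id `InstaDeepAI/isoformer` and the use of `trust_remote_code=True` are assumptions, so treat the notebook as the authoritative example.

```python
from transformers import AutoModel

# Assumption: Isoformer is hosted under this repository id with custom modeling code.
model = AutoModel.from_pretrained("InstaDeepAI/isoformer", trust_remote_code=True)
```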
22 |
23 | ## Citing our work 📚
24 |
25 | You can cite our model at:
26 |
27 | ```bibtex
28 | @inproceedings{NEURIPS2024_8f6b3692,
29 | author = {Garau-Luis, Juan Jose and Bordes, Patrick and Gonzalez, Liam and Roller, Masa and de Almeida, Bernardo P. and Hexemer, Lorenz and Blum, Christopher and Laurent, Stefan and Grzegorzewski, Jan and Lang, Maren and Pierrot, Thomas and Richard, Guillaume},
30 | booktitle = {Advances in Neural Information Processing Systems},
31 | editor = {A. Globerson and L. Mackey and D. Belgrave and A. Fan and U. Paquet and J. Tomczak and C. Zhang},
32 | pages = {78431--78450},
33 | publisher = {Curran Associates, Inc.},
34 | title = {Multi-modal Transfer Learning between Biological Foundation Models},
35 | url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/8f6b3692297e49e5d5c91ba00281379c-Paper-Conference.pdf},
36 | volume = {37},
37 | year = {2024}
38 | }
39 | ```
40 |
--------------------------------------------------------------------------------
/nucleotide_transformer/sCellTransformer/params.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 InstaDeep Ltd
2 | #
3 | # Licensed under the Creative Commons BY-NC-SA 4.0 License (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://creativecommons.org/licenses/by-nc-sa/4.0/
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import json
16 | import os
17 | from typing import Any, Dict, Tuple
18 |
19 | import haiku as hk
20 | import joblib
21 | from huggingface_hub import hf_hub_download
22 |
23 | from nucleotide_transformer.sCellTransformer.model import sCTConfig
24 |
25 | ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
26 | DEFAULT_CACHE_DIR = "~/.cache"
27 |
28 |
29 | def _get_dir(model_name: str) -> str:
30 | """
31 | Get directory to save files on user machine.
32 | """
33 | return os.path.expanduser(
34 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), model_name)
35 | )
36 |
37 |
38 | def download_ckpt() -> Tuple[hk.Params, Any]:
39 | """
40 | Download the sCellTransformer checkpoint from Hugging Face.
41 | 
42 | Returns:
43 | Model parameters.
44 | Model configuration.
45 | """
51 |
52 | save_dir = os.path.join(_get_dir("sCellTransformer"), "sCellTransformer")
53 |
54 | repo_id = "InstaDeepAI/sCellTransformer"
55 |
56 | # Download parameters
57 | print("Downloading model's weights...")
58 | params = joblib.load(
59 | hf_hub_download(
60 | repo_id=repo_id,
61 | filename="jax_params/params.joblib",
62 | cache_dir=save_dir,
63 | )
64 | )
65 |
66 | config_path = hf_hub_download(
67 | repo_id=repo_id,
68 | filename="jax_params/config.json",
69 | )
70 | with open(config_path, "r") as f:
71 | config_dict = json.load(f)
72 | config = sCTConfig(**config_dict)
73 |
74 | return params, config
75 |
--------------------------------------------------------------------------------
/docs/mojo.md:
--------------------------------------------------------------------------------
1 | # MOJO
2 |
3 | MOJO (MultiOmics JOint representation learning) is a multimodal model designed to learn embeddings of multi-omics data. It integrates bulk RNA-seq and DNA methylation data to generate powerful joint representations. These representations are tailored for improving predictive performance on downstream tasks like cancer-type classification and survival analysis.
4 |
5 | * 📜 **[Read the Paper (ICML Workshop on Generative AI and Biology 2025)](https://www.biorxiv.org/content/10.1101/2025.06.25.661237v1)**
6 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/MOJO)**
7 |
8 | ## Training data
9 |
10 | MOJO is trained on multi-modal data from The Cancer Genome Atlas (TCGA) dataset. The training data consists of paired bulk RNA-seq and DNA methylation profiles, enabling the model to learn the complex interplay between transcription and epigenetic regulation.
11 | The model takes as input sequences of 17,116 gene expression and DNA methylation values. The gene ids that must be used, and the order in which they should appear in the sequence, are provided in `../notebooks/data/mojo_gene_names.txt`.
12 |
13 | ## Training procedure
14 |
15 | MOJO uses a bimodal masked language modeling objective. It is trained to simultaneously predict masked values in both the RNA-seq and DNA methylation data modalities. This process forces the model to learn the intricate cross-modal relationships between gene expression and epigenetic regulation, leading to robust, integrated representations.
16 |
17 | ## How to use 🚀
18 |
19 | We make MOJO available in Jax in this repository and in PyTorch on HuggingFace. Examples of how to use it are available at the following notebooks (a minimal loading sketch is shown after the list):
20 | - Jax: `../notebooks/mojo/inference_mojo_jax_example.ipynb`.
21 | - PyTorch: `../notebooks/mojo/inference_mojo_pytorch_example.ipynb`.
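The sketch below only downloads the pretrained weights and builds the per-omic tokenizers and configuration through the helper in `nucleotide_transformer/mojo/pretrained.py`; preparing inputs and running a forward pass are covered in the notebooks.

```python
from nucleotide_transformer.mojo.pretrained import get_mojo_pretrained_model

# Downloads the MOJO checkpoint from Hugging Face and builds the Haiku forward function.
parameters, forward_fn, tokenizers, config = get_mojo_pretrained_model()

# One BinnedOmicTokenizer per omic modality (bulk RNA-seq and DNA methylation).
print(list(tokenizers.keys()))
```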
22 |
23 | ## Citing our work 📚
24 |
25 | You can cite our model at:
26 |
27 | ```bibtex
28 | @article {G{\'e}lard2025.06.25.661237,
29 | author = {G{\'e}lard, Maxence and Benkirane, Hakim and Pierrot, Thomas and Richard, Guillaume and Courn{\`e}de, Paul-Henry},
30 | title = {Bimodal masked language modeling for bulk RNA-seq and DNA methylation representation learning},
31 | elocation-id = {2025.06.25.661237},
32 | year = {2025},
33 | doi = {10.1101/2025.06.25.661237},
34 | publisher = {Cold Spring Harbor Laboratory},
35 | URL = {https://www.biorxiv.org/content/early/2025/06/27/2025.06.25.661237},
36 | journal = {bioRxiv}
37 | }
38 | ```
39 |
--------------------------------------------------------------------------------
/nucleotide_transformer/borzoi/pretrained.py:
--------------------------------------------------------------------------------
1 | """Implementation of utilities to load a pretrained Borzoi model in Trix."""
2 |
3 | from typing import Callable, Tuple
4 |
5 | import haiku as hk
6 | import jax.numpy as jnp
7 |
8 | from nucleotide_transformer.borzoi.model import (
9 | BorzoiConfig,
10 | build_borzoi_fn_with_head_fn,
11 | )
12 | from nucleotide_transformer.enformer.features import FEATURES
13 | from nucleotide_transformer.enformer.heads import UNetHead
14 | from nucleotide_transformer.enformer.params import download_ckpt
15 | from nucleotide_transformer.enformer.tokenizer import NucleotidesKmersTokenizer
16 |
17 |
18 | def get_pretrained_segment_borzoi_model() -> (
19 | Tuple[hk.Params, hk.State, Callable, NucleotidesKmersTokenizer, BorzoiConfig]
20 | ):
21 | """
22 | Loads the pretrained SegmentBorzoi model.
23 |
24 | Returns:
25 | hk.Params: Model parameters
26 | hk.State: Model state
27 | Callable: Haiku forward function
28 | NucleotidesKmersTokenizer: Tokenizer
29 | BorzoiConfig: Configuration of the Borzoi model
30 |
31 | Example:
32 | >>> import jax
33 | >>> import haiku as hk
34 | >>> parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_borzoi_model()
35 | >>> apply_fn = hk.transform_with_state(forward_fn).apply
36 | >>> random_key = jax.random.PRNGKey(seed=0)
37 | >>> sequences = ["A" * 524_288]
38 | >>> tokens = jax.numpy.asarray([b[1] for b in tokenizer.batch_tokenize(sequences)])
39 | >>> outs, _ = apply_fn(parameters, state, random_key, tokens)
40 | """
41 | config = BorzoiConfig()
42 | tokenizer = NucleotidesKmersTokenizer(
43 | k_mers=1,
44 | prepend_bos_token=False,
45 | prepend_cls_token=False,
46 | append_eos_token=False,
47 | tokens_to_ids=None,
48 | )
49 |
50 | def head_fn() -> hk.Module:
51 | return UNetHead(
52 | features=FEATURES,
53 | embed_dimension=config.embed_dim,
54 | nucl_per_token=config.dim_divisible_by,
55 | remove_cls_token=False,
56 | )
57 |
58 | forward_fn = build_borzoi_fn_with_head_fn(
59 | config=config,
60 | head_fn=head_fn,
61 | embedding_name="embedding",
62 | name="Borzoi",
63 | compute_dtype=jnp.float32,
64 | )
65 |
66 | parameters, state = download_ckpt("segment_borzoi")
67 |
68 | return parameters, state, forward_fn, tokenizer, config
69 |
--------------------------------------------------------------------------------
/nucleotide_transformer/enformer/pretrained.py:
--------------------------------------------------------------------------------
1 | """Implementation of utilities to load a pretrained Enformer model in Trix."""
2 |
3 | from typing import Callable, Tuple
4 |
5 | import haiku as hk
6 | import jax.numpy as jnp
7 |
8 | from nucleotide_transformer.enformer.features import FEATURES
9 | from nucleotide_transformer.enformer.heads import UNetHead
10 | from nucleotide_transformer.enformer.model import (
11 | EnformerConfig,
12 | build_enformer_with_head_fn,
13 | )
14 | from nucleotide_transformer.enformer.params import download_ckpt
15 | from nucleotide_transformer.enformer.tokenizer import NucleotidesKmersTokenizer
16 |
17 |
18 | def get_pretrained_segment_enformer_model() -> (
19 | Tuple[hk.Params, hk.State, Callable, NucleotidesKmersTokenizer, EnformerConfig]
20 | ):
21 | """
22 | Loads the pretrained SegmentEnformer model.
23 |
24 | Returns:
25 | hk.Params: Model parameters
26 | hk.State: Model state
27 | Callable: Haiku forward function
28 | NucleotidesKmersTokenizer: Tokenizer
29 | EnformerConfig: Configuration of the Enformer model
30 |
31 | Example:
32 | >>> import jax
33 | >>> import haiku as hk
34 | >>> parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_enformer_model()
35 | >>> apply_fn = hk.transform_with_state(forward_fn).apply
36 | >>> random_key = jax.random.PRNGKey(seed=0)
37 | >>> sequences = ["A" * 196608]
38 | >>> tokens = jax.numpy.asarray([b[1] for b in tokenizer.batch_tokenize(sequences)])
39 | >>> outs, _ = apply_fn(parameters, state, random_key, tokens)
40 | """
41 | config = EnformerConfig()
42 | tokenizer = NucleotidesKmersTokenizer(
43 | k_mers=1,
44 | prepend_bos_token=False,
45 | prepend_cls_token=False,
46 | append_eos_token=False,
47 | tokens_to_ids=None,
48 | )
49 |
50 | def head_fn() -> hk.Module:
51 | return UNetHead(
52 | features=FEATURES,
53 | embed_dimension=config.embed_dim,
54 | nucl_per_token=config.dim_divisible_by,
55 | remove_cls_token=False,
56 | )
57 |
58 | forward_fn = build_enformer_with_head_fn(
59 | config=config,
60 | head_fn=head_fn,
61 | embedding_name="embedding_transformer_tower",
62 | name="Enformer",
63 | compute_dtype=jnp.float32,
64 | )
65 |
66 | parameters, state = download_ckpt("segment_enformer")
67 |
68 | return parameters, state, forward_fn, tokenizer, config
69 |
--------------------------------------------------------------------------------
/docs/bulk_rna_bert.md:
--------------------------------------------------------------------------------
1 | # BulkRNABert
2 |
3 | BulkRNABert is a transformer-based, encoder-only foundation model designed for bulk RNA-seq data. It learns biologically meaningful representations from large-scale transcriptomic profiles. Once pre-trained, BulkRNABert can be fine-tuned for various cancer-related downstream tasks, such as cancer type classification or survival analysis, by using its learned embeddings.
4 |
5 | * 📜 **[Read the Paper (Machine Learning for Health 2024)](https://proceedings.mlr.press/v259/gelard25a.html)**
6 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/BulkRNABert)**
7 |
8 | ## Training data
9 |
10 | The model was pre-trained on a large dataset of bulk RNA-seq profiles from The Cancer Genome Atlas (TCGA) dataset.
11 | The model takes as input sequences of 19,062 gene expression values. The gene ids that must be used, and the order in which they should appear in the sequence, are provided in `../notebooks/data/bulkrnabert_gene_ids.txt`.
12 |
13 |
14 |
15 | ## Training procedure
16 |
17 | Following the original BERT framework, BulkRNABert uses a self-supervised, masked language modeling objective. During pre-training, gene expression values are randomly masked, and the model is tasked with reconstructing these values from their surrounding genomic context. This process enables the model to learn rich, contextual representations of transcriptomic profiles.
18 |
19 | ## How to use 🚀
20 |
21 | We make BulkRNABert available in Jax in this repository and in PyTorch on HuggingFace. Examples of how to use it are available at the following notebooks (a minimal loading sketch is shown after the list):
22 | - Jax: `../notebooks/bulk_rna_bert/inference_bulkrnabert_jax_example.ipynb`.
23 | - PyTorch: `../notebooks/bulk_rna_bert/inference_bulkrnabert_pytorch_example.ipynb`.
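The sketch below only downloads the pretrained weights, tokenizer, and configuration through the helper in `nucleotide_transformer/bulk_rna_bert/pretrained.py`; tokenizing expression profiles and running a forward pass are demonstrated in the notebooks.

```python
from nucleotide_transformer.bulk_rna_bert.pretrained import get_pretrained_bulkrnabert_model

# Downloads the BulkRNABert checkpoint (weights + config) from Hugging Face.
parameters, forward_fn, tokenizer, config = get_pretrained_bulkrnabert_model(
    embeddings_layers_to_save=(4,),  # illustrative choice: also return layer-4 embeddings
)
print(config.n_genes)  # 19062 genes expected in each input expression profile
```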
24 |
25 | ## Citing our work 📚
26 |
27 | You can cite our model at:
28 |
29 | ```bibtex
30 | @InProceedings{pmlr-v259-gelard25a,
31 | title = {BulkRNABert: Cancer prognosis from bulk RNA-seq based language models},
32 | author = {G{\'{e}}lard, Maxence and Richard, Guillaume and Pierrot, Thomas and Courn{\`{e}}de, Paul-Henry},
33 | booktitle = {Proceedings of the 4th Machine Learning for Health Symposium},
34 | pages = {384--400},
35 | year = {2025},
36 | editor = {Hegselmann, Stefan and Zhou, Helen and Healey, Elizabeth and Chang, Trenton and Ellington, Caleb and Mhasawade, Vishwali and Tonekaboni, Sana and Argaw, Peniel and Zhang, Haoran},
37 | volume = {259},
38 | series = {Proceedings of Machine Learning Research},
39 | month = {15--16 Dec},
40 | publisher = {PMLR},
41 | url = {https://proceedings.mlr.press/v259/gelard25a.html},
42 | }
43 | ```
44 |
--------------------------------------------------------------------------------
/docs/sct.md:
--------------------------------------------------------------------------------
1 | # sCT
2 |
3 | sCT (single-Cell Transformer) is our foundational transformer model for single-cell and spatial transcriptomics data. sCT aims to learn rich representations from complex, high-dimensional single-cell datasets to improve various downstream analytical tasks.
4 |
5 | sCT processes raw gene expression profiles across multiple cells to predict discretized gene expression levels for unseen cells without retraining.
6 | The model can handle up to 20,000 protein-coding genes and a bag of 50 cells in the same sample.
7 | This ability (around one million gene-expression tokens) allows it to learn cross-cell relationships,
8 | capture long-range dependencies in gene expression data, and mitigate the sparsity typical of single-cell datasets.
9 |
10 | sCT is trained on a large dataset of single-cell RNA-seq and finetuned on spatial transcriptomics data. Evaluation tasks include zero-shot imputation of masked gene expression, and zero-shot prediction of cell types.
11 |
12 | * 📜 **[Read the Paper (OpenReview preprint)](https://openreview.net/forum?id=VdX9tL3VXH)**
13 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/sCellTransformer)**
14 |
15 | ## Training data
16 |
17 | The model was trained following a two-step procedure: pre-training on single-cell data, then finetuning on spatial transcriptomics data. The single-cell data used for pre-training comes from the Cellxgene Census collection datasets used to train the scGPT models. It consists of around 50 million cells and approximately 60,000 genes. The spatial data comes from both the human breast cell atlas and the human heart atlas.
18 |
19 | ## Training procedure
20 |
21 | As detailed in the paper, the gene expressions are first binned into a pre-defined number of bins. This allows the model to better learn the distribution of the gene expressions through sparsity mitigation, noise reduction, and handling of extreme values. The training objective is then to predict the masked gene expressions in a cell, following a BERT-style masking scheme; a schematic illustration of the binning step is shown below.
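The snippet below is a schematic illustration only (the log transform and the number and placement of bin edges are assumptions, not the model's actual preprocessing):

```python
import numpy as np

# Schematic binning of raw expression counts into discrete expression levels.
raw_counts = np.array([0.0, 1.0, 3.0, 250.0, 4_000.0])
log_expr = np.log1p(raw_counts)                       # assumed log1p transform
bin_edges = np.linspace(0.0, log_expr.max(), num=64)  # assumed 64 equally spaced bins
binned_expression = np.digitize(log_expr, bin_edges)  # integer bin index per gene
print(binned_expression)
```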
22 |
23 | ## How to use 🚀
24 |
25 | We make sCT available in Jax in this repository and in PyTorch on HuggingFace. Examples of how to use it are available at the following notebooks (a minimal checkpoint-loading sketch is shown after the list):
26 | - Jax: `../notebooks/sct/inference_sCT_jax_example.ipynb`
27 | - PyTorch: `../notebooks/sct/inference_sCT_pytorch_example.ipynb`.
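The sketch below only fetches the Jax weights and configuration using the helper in `nucleotide_transformer/sCellTransformer/params.py`; building the forward function and preparing cell bags are covered in the notebooks.

```python
from nucleotide_transformer.sCellTransformer.params import download_ckpt

# Downloads the sCT (sCellTransformer) weights and configuration from Hugging Face.
params, config = download_ckpt()
print(type(config))  # sCTConfig
```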
28 |
29 | ## Citing our work 📚
30 |
31 | You can cite our model at:
32 |
33 | ```bibtex
34 | @misc{joshi2025a,
35 | title={A long range foundation model for zero-shot predictions in single-cell and
36 | spatial transcriptomics data},
37 | author={Ameya Joshi and Raphael Boige and Lee Zamparo and Ugo Tanielian and Juan Jose
38 | Garau-Luis and Michail Chatzianastasis and Priyanka Pandey and Janik Sielemann and
39 | Alexander Seifert and Martin Brand and Maren Lang and Karim Beguir and Thomas PIERROT},
40 | year={2025},
41 | url={https://openreview.net/forum?id=VdX9tL3VXH}
42 | }
43 | ```
--------------------------------------------------------------------------------
/nucleotide_transformer/bulk_rna_bert/config.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any
3 |
4 | from pydantic import root_validator
5 | from pydantic.main import BaseModel
6 |
7 |
8 | class BulkRNABertConfig(BaseModel):
9 | n_genes: int
10 | n_expressions_bins: int
11 | embed_dim: int
12 | init_gene_embed_dim: int
13 | project_gene_embedding: bool = False
14 | use_gene_embedding: bool = True
15 |
16 | # architecture
17 | num_attention_heads: int
18 | key_size: int | None = None
19 | ffn_embed_dim: int
20 | num_layers: int
21 | use_memory_efficient_attention: bool = False
22 |
23 | use_gradient_checkpointing: bool = False
24 |
25 | gene2vec_weights_path: str
26 |
27 | embeddings_layers_to_save: tuple[int, ...] = ()
28 | attention_layers_to_save: tuple[int, ...] = ()
29 |
30 | # RNASeq data processing
31 | rnaseq_tokenizer_bins: list[float] | None = None
32 | use_log_normalization: bool = True
33 | use_max_normalization: bool = True
34 | normalization_factor: float | None = None
35 |
36 | @root_validator
37 | @classmethod
38 | def validate_key_size(cls, values: dict[str, Any]) -> dict[str, Any]:
39 | """
40 | Checks that the given values are compatible.
41 | """
42 | key_size = values.get("key_size")
43 | if key_size is None:
44 | embed_dim = values["embed_dim"]
45 | num_attention_heads = values["num_attention_heads"]
46 | if not embed_dim % num_attention_heads == 0:
47 | raise ValueError(
48 | f"When no key size is provided, the embedding dimension should be "
49 | f"divisible by the number of heads, however provided embedding "
50 | f"dimension is {embed_dim} and the number of heads is "
51 | f"{num_attention_heads}."
52 | )
53 | values["key_size"] = embed_dim // num_attention_heads
54 | return values
55 |
56 | @root_validator
57 | @classmethod
58 | def validate_gene_embedding(cls, values: dict[str, Any]) -> dict[str, Any]:
59 | """
60 | Checks that the given values are compatible.
61 | """
62 | use_gene_embedding = values.get("use_gene_embedding")
63 | if use_gene_embedding:
64 | init_gene_embed_dim = values["init_gene_embed_dim"]
65 | embed_dim = values["embed_dim"]
66 | if init_gene_embed_dim != embed_dim:
67 | project_gene_embedding = values["project_gene_embedding"]
68 | if not project_gene_embedding:
69 | logging.warning(
70 | f"Init gene embedding dimension ({init_gene_embed_dim})"
71 | f"different than embedding dimension ({embed_dim})."
72 | f"Setting `project_gene_embedding` to True"
73 | )
74 | values["project_gene_embedding"] = True
75 | return values
76 |
--------------------------------------------------------------------------------
/nucleotide_transformer/enformer/heads.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Dict, List, Optional
3 |
4 | import haiku as hk
5 | import jax
6 | import jax.numpy as jnp
7 | from haiku import initializers
8 | from typing_extensions import TypeAlias
9 |
10 | from nucleotide_transformer.heads import UNET1DSegmentationHead
11 |
12 | SequenceMask: TypeAlias = jnp.ndarray
13 |
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | class UNetHead(hk.Module):
19 | """
20 | Returns a probability between 0 and 1 over a target feature presence
21 | for each nucleotide in the input sequence. Assumes the sequence has been tokenized
22 | with non-overlapping k-mers with k being nucl_per_token.
23 | """
24 |
25 | def __init__(
26 | self,
27 | features: List[str],
28 | num_classes: int = 2,
29 | embed_dimension: int = 1024,
30 | nucl_per_token: int = 6,
31 | num_layers: int = 2,
32 | remove_cls_token: bool = True,
33 | name: Optional[str] = None,
34 | ):
35 | """
36 | Args:
37 | features: List of features names.
38 | num_classes: Number of classes.
39 | embed_dimension: Embedding dimension.
40 | nucl_per_token: Number of nucleotides per token.
41 | num_layers: Number of layers.
42 | remove_cls_token: Whether to remove the CLS token.
43 | name: Name of the layer. Defaults to None.
44 | """
45 | super().__init__(name=name)
46 | self._num_features = len(features)
47 | self._num_classes = num_classes
48 | self.nucl_per_token = nucl_per_token
49 | self.remove_cls_token = remove_cls_token
50 |
51 | w_init = initializers.VarianceScaling(2.0, "fan_in", "uniform")
52 | b_init = initializers.VarianceScaling(2.0, "fan_in", "uniform")
53 | unet = UNET1DSegmentationHead(
54 | num_classes=embed_dimension // 2,
55 | output_channels_list=tuple(
56 | embed_dimension * (2**i) for i in range(num_layers)
57 | ),
58 | )
59 | fc = hk.Linear(
60 | self.nucl_per_token * self._num_classes * self._num_features,
61 | w_init=w_init,
62 | b_init=b_init,
63 | name="fc",
64 | )
65 | self._fc = hk.Sequential([unet, jax.nn.swish, fc])
66 |
67 | def __call__(
68 | self, x: jnp.ndarray, sequence_mask: SequenceMask
69 | ) -> Dict[str, jnp.ndarray]:
70 | """
71 | Input shape: (batch_size, sequence_length, embed_dim)
72 | Output_shape: (batch_size, sequence_length * nucl_per_token,
73 | num_features, num_classes)
74 | """
75 | if self.remove_cls_token:
76 | x = x[:, 1:]
77 | logits = self._fc(x)
78 | batch_size, seq_len = x.shape[0], x.shape[1]
79 | logits = jnp.reshape(
80 | logits,
81 | (
82 | batch_size,
83 | seq_len * self.nucl_per_token,
84 | self._num_features,
85 | self._num_classes,
86 | ),
87 | )
88 | return {"logits": logits}
89 |
--------------------------------------------------------------------------------
/nucleotide_transformer/chatNT/pretrained.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 |
3 | from nucleotide_transformer.chatNT.gpt_decoder import GptConfig, RotaryEmbeddingConfig
4 | from nucleotide_transformer.chatNT.model import build_chat_nt_fn
5 | from nucleotide_transformer.chatNT.multi_modal_perceiver_projection import (
6 | PerceiverResamplerConfig,
7 | )
8 | from nucleotide_transformer.chatNT.params import download_ckpt
9 | from nucleotide_transformer.model import NucleotideTransformerConfig
10 |
11 |
12 | def get_chatNT():
13 | gpt_config = GptConfig(
14 | vocab_size=32000,
15 | eos_token_id=2,
16 | embed_dim=4096,
17 | ffn_embed_dim=11008,
18 | num_heads=32,
19 | num_kv_heads=32,
20 | num_layers=32,
21 | rope_config=RotaryEmbeddingConfig(max_seq_len=2048, dim=128, theta=10000.0),
22 | add_bias_ffn=False,
23 | ffn_activation_name="silu",
24 | use_glu_in_ffn=True,
25 | add_bias_lm_head=False,
26 | norm_type="RMS_norm",
27 | rms_norm_eps=1e-06,
28 | parallel_attention_ff=False,
29 | use_gradient_checkpointing=False,
30 | add_bias_attn=False,
31 | )
32 | nt_config = NucleotideTransformerConfig(
33 | alphabet_size=4107,
34 | pad_token_id=1,
35 | mask_token_id=2,
36 | max_positions=2048,
37 | embed_scale=1.0,
38 | emb_layer_norm_before=False,
39 | attention_heads=16,
40 | key_size=64,
41 | embed_dim=1024,
42 | ffn_embed_dim=4096,
43 | num_layers=29,
44 | positional_embedding=None,
45 | lm_head="roberta",
46 | add_bias_kv=False,
47 | add_bias_ffn=False,
48 | use_rotary_embedding=True,
49 | rescaling_factor=None,
50 | ffn_activation_name="swish",
51 | use_glu_in_ffn=True,
52 | mask_before_attention=False,
53 | layer_norm_eps=1e-05,
54 | pre_layer_norm=True,
55 | bias_word_embedding=False,
56 | token_dropout=False,
57 | masking_ratio=0.0,
58 | masking_prob=0.0,
59 | use_gradient_checkpointing=False,
60 | embeddings_layers_to_save=(21,),
61 | attention_maps_to_save=[],
62 | )
63 | perceiver_resampler_config = PerceiverResamplerConfig(
64 | emb_layer_norm_before=False,
65 | attention_heads=32,
66 | key_size=128,
67 | embed_dim=4096,
68 | ffn_embed_dim=11008,
69 | num_layers=3,
70 | add_bias_kv=False,
71 | add_bias_ffn=True,
72 | ffn_activation_name="gelu-no-approx",
73 | use_glu_in_ffn=False,
74 | resampled_length=64,
75 | use_gradient_checkpointing=False,
76 | )
77 |
78 | bio_tokenizer = AutoTokenizer.from_pretrained(
79 | "InstaDeepAI/ChatNT",
80 | subfolder="bio_tokenizer",
81 | )
82 | english_tokenizer = AutoTokenizer.from_pretrained(
83 | "InstaDeepAI/ChatNT",
84 | subfolder="english_tokenizer",
85 | )
86 | seq_token_id = 32000
87 |
88 | forward_fn = build_chat_nt_fn(
89 | nt_config=nt_config,
90 | gpt_config=gpt_config,
91 | seq_token_id=seq_token_id,
92 | bio_pad_token_id=bio_tokenizer.pad_token_id,
93 | english_pad_token_id=english_tokenizer.pad_token_id,
94 | perceiver_resampler_config=perceiver_resampler_config,
95 | nt_name="dcnuc_v2_500M_multi_species",
96 | gpt_name="llama_decoder",
97 | )
98 |
99 | params = download_ckpt()
100 | params = {
101 | key.replace("bio_brain_decoder", "chat_nt_decoder").replace(
102 | "bio_brain_encoder", "chat_nt_encoder"
103 | ): value
104 | for key, value in params.items()
105 | }
106 |
107 | return forward_fn, params, english_tokenizer, bio_tokenizer
108 |
--------------------------------------------------------------------------------
/nucleotide_transformer/bulk_rna_bert/pretrained.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import Callable
4 |
5 | import haiku as hk
6 | import jax.numpy as jnp
7 | import joblib
8 | from huggingface_hub import hf_hub_download
9 |
10 | from nucleotide_transformer.bulk_rna_bert.config import BulkRNABertConfig
11 | from nucleotide_transformer.bulk_rna_bert.model import build_bulk_rna_bert_forward_fn
12 | from nucleotide_transformer.bulk_rna_bert.tokenizer import BinnedOmicTokenizer
13 |
14 | ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
15 | DEFAULT_CACHE_DIR = "~/.cache"
16 |
17 |
18 | def _get_dir(model_name: str) -> str:
19 | """
20 | Get directory to save files on user machine.
21 | """
22 | return os.path.expanduser(
23 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), model_name)
24 | )
25 |
26 |
27 | def download_bulkrnabert_ckpt() -> tuple[hk.Params, BulkRNABertConfig]:
28 | """
29 | Download BulkRNABert checkpoint from Hugging Face.
30 |
31 |
32 | Returns:
33 | Model parameters.
34 | Model configuration
35 | """
36 |
37 | save_dir = os.path.join(_get_dir("bulkrnabert"), "bulkrnabert")
38 |
39 | repo_id = "InstaDeepAI/BulkRNABert"
40 |
41 | # Download parameters
42 | print("Downloading model's weights...")
43 | params = joblib.load(
44 | hf_hub_download(
45 | repo_id=repo_id,
46 | filename="jax_params/params.joblib",
47 | cache_dir=save_dir,
48 | )
49 | )
50 |
51 | config_path = hf_hub_download(
52 | repo_id=repo_id,
53 | filename="jax_params/config.json",
54 | )
55 | with open(config_path) as f:
56 | config_dict = json.load(f)
57 | config = BulkRNABertConfig(**config_dict)
58 |
59 | return params, config
60 |
61 |
62 | def get_pretrained_bulkrnabert_model(
63 | compute_dtype: jnp.dtype = jnp.float32,
64 | param_dtype: jnp.dtype = jnp.float32,
65 | output_dtype: jnp.dtype = jnp.float32,
66 | embeddings_layers_to_save: tuple[int, ...] = (),
67 | ) -> tuple[hk.Params, Callable, BinnedOmicTokenizer, BulkRNABertConfig]:
68 | """
69 | Create a Haiku BulkRNABert model by downloading pre-trained weights
70 | and hyperparameters.
71 |
72 | Args:
73 | compute_dtype: the type of the activations. fp16 runs faster and is lighter in
74 | memory. bf16 handles large values better and is hence more stable (it avoids
75 | float overflows).
76 | param_dtype: if compute_dtype is fp16, the model weights will be cast to fp16
77 | during the forward pass anyway. So in inference mode ( not training mode ),
78 | it is better to use params in fp16 if compute_dtype is fp16 too. During
79 | training, it is preferable to keep parameters in float32 for better
80 | numerical stability.
81 | output_dtype: the output type of the model. it determines the float precision
82 | of the gradient when training the model.
83 | embeddings_layers_to_save: Intermediate embeddings to return in the output.
84 |
85 | Returns:
86 | Model parameters.
87 | Haiku function to call the model.
88 | Tokenizer.
89 | Model config (hyperparameters).
90 |
91 | """
92 | parameters, config = download_bulkrnabert_ckpt()
93 | tokenizer = BinnedOmicTokenizer(
94 | n_expressions_bins=config.n_expressions_bins,
95 | use_max_normalization=config.use_max_normalization,
96 | normalization_factor=config.normalization_factor, # type: ignore
97 | prepend_cls_token=False,
98 | )
99 |
100 | config.embeddings_layers_to_save = embeddings_layers_to_save
101 |
102 | forward_fn = build_bulk_rna_bert_forward_fn(
103 | model_config=config,
104 | compute_dtype=compute_dtype,
105 | param_dtype=param_dtype,
106 | output_dtype=output_dtype,
107 | model_name="bulk_bert",
108 | )
109 |
110 | return parameters, forward_fn, tokenizer, config
111 |
--------------------------------------------------------------------------------
/nucleotide_transformer/mojo/pretrained.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import Callable
4 |
5 | import haiku as hk
6 | import jax.numpy as jnp
7 | import joblib
8 | from huggingface_hub import hf_hub_download
9 |
10 | from nucleotide_transformer.bulk_rna_bert.tokenizer import BinnedOmicTokenizer
11 | from nucleotide_transformer.mojo.config import MOJOConfig
12 | from nucleotide_transformer.mojo.model import build_mojo_fn
13 |
14 | ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
15 | DEFAULT_CACHE_DIR = "~/.cache"
16 |
17 |
18 | def _get_dir(model_name: str) -> str:
19 | """
20 | Get directory to save files on user machine.
21 | """
22 | return os.path.expanduser(
23 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), model_name)
24 | )
25 |
26 |
27 | def download_mojo_ckpt() -> tuple[hk.Params, MOJOConfig]:
28 | """
29 | Download MOJO checkpoint from Hugging Face.
30 |
31 |
32 | Returns:
33 | Model parameters.
34 | Model configuration
35 | """
36 |
37 | save_dir = os.path.join(_get_dir("mojo"), "mojo")
38 |
39 | repo_id = "InstaDeepAI/MOJO"
40 |
41 | # Download parameters
42 | print("Downloading model's weights...")
43 | params = joblib.load(
44 | hf_hub_download(
45 | repo_id=repo_id,
46 | filename="jax_params/params.joblib",
47 | cache_dir=save_dir,
48 | )
49 | )
50 |
51 | config_path = hf_hub_download(
52 | repo_id=repo_id,
53 | filename="jax_params/config.json",
54 | )
55 | with open(config_path) as f:
56 | config_dict = json.load(f)
57 | config = MOJOConfig(**config_dict)
58 |
59 | return params, config
60 |
61 |
62 | def get_mojo_pretrained_model(
63 | compute_dtype: jnp.dtype = jnp.float32,
64 | param_dtype: jnp.dtype = jnp.float32,
65 | output_dtype: jnp.dtype = jnp.float32,
66 | ) -> tuple[hk.Params, Callable, dict[str, BinnedOmicTokenizer], MOJOConfig]:
67 | """
68 | Create a Haiku MOJO model by downloading pre-trained weights
69 | and hyperparameters.
70 |
71 | Args:
72 | compute_dtype: the type of the activations. fp16 runs faster and is lighter in
73 | memory. bf16 handles large values better and is hence more stable (it avoids
74 | float overflows).
75 | param_dtype: if compute_dtype is fp16, the model weights will be cast to fp16
76 | during the forward pass anyway. So in inference mode ( not training mode ),
77 | it is better to use params in fp16 if compute_dtype is fp16 too. During
78 | training, it is preferable to keep parameters in float32 for better
79 | numerical stability.
80 | output_dtype: the output type of the model. it determines the float precision
81 | of the gradient when training the model.
82 |
83 | Returns:
84 | Model parameters.
85 | Haiku function to call the model.
86 | Dictionary of tokenizers, one per omic.
87 | Model config (hyperparameters).
88 |
89 | """
90 | parameters, config = download_mojo_ckpt()
91 | tokenizers = {
92 | omic: BinnedOmicTokenizer(
93 | n_expressions_bins=config.n_expressions_bins[omic],
94 | min_omic_value=config.min_omic_value[omic],
95 | max_omic_value=config.max_omic_value[omic],
96 | use_max_normalization=config.use_max_normalization[omic],
97 | normalization_factor=config.normalization_factor[omic],
98 | prepend_cls_token=False,
99 | fixed_sequence_length=config.fixed_sequence_length,
100 | unpadded_length=config.sequence_length,
101 | )
102 | for omic in config.n_expressions_bins.keys()
103 | }
104 |
105 | forward_fn = build_mojo_fn(
106 | model_config=config,
107 | compute_dtype=compute_dtype,
108 | param_dtype=param_dtype,
109 | output_dtype=output_dtype,
110 | model_name="multi_omics_lm",
111 | )
112 |
113 | return parameters, forward_fn, tokenizers, config
114 |
--------------------------------------------------------------------------------
/docs/codon_nt.md:
--------------------------------------------------------------------------------
1 | # Codon-NT
2 |
3 | Codon-NT is a Nucleotide Transformer model variant trained on 3-mers (codons).
4 | It is a 50M-parameter transformer pre-trained on a collection of 850 genomes from a wide range of species, including model and non-model organisms.
5 | This work investigates alternative tokenization strategies for genomic language models and their impact on downstream performance and interpretability.
6 |
7 | * 📜 **[Read the Paper (Bioinformatics 2024)](https://academic.oup.com/bioinformatics/article/40/9/btae529/7745814)**
8 | * 🤗 **[Hugging Face Link](https://huggingface.co/InstaDeepAI/nucleotide-transformer-v2-50m-3mer-multi-species)**
9 |
10 |
11 |
12 | *3mer tokenization achieves better performance than 6-mer on specific protein tasks*
13 |
14 | ## Training data
15 |
16 | The nucleotide-transformer-v2-50m-3mer-multi-species model was pretrained on a total of 850 genomes downloaded from NCBI. Plants and viruses are not included in these genomes, as their regulatory elements differ from those of interest in the paper's tasks. Some heavily studied model organisms were picked to be included in the collection of genomes, which represents a total of 174B nucleotides, i.e roughly 29B tokens. The data has been released as a HuggingFace dataset here.
17 |
18 | ## Training details
19 | 
20 | ### Data pre-processing
21 | The DNA sequences are tokenized using the Nucleotide Transformer Tokenizer, which tokenizes sequences as 6-mers when possible, otherwise tokenizing each nucleotide separately, as described in the Tokenization section of the associated repository. This tokenizer has a vocabulary size of 4105. The inputs of the model are then of the form:
22 |
23 |
9 |
10 | *Overview of agronomic nucleotide transformer.*
11 |
12 |
13 |
14 | *AgroNT provides gene expression prediction across different plant species.
15 | Gene expression prediction on holdout genes across all tissues is correlated with observed gene expression levels.*
16 |
17 | ## Model architecture
18 |
19 | AgroNT uses the transformer architecture with self-attention and a masked language modeling objective to leverage highly
20 | available genotype data from 48 different plant species to learn general representations of nucleotide sequences.
21 | AgroNT contains 1 billion parameters and has a context window of 1024 tokens. AgroNT uses a non-overlapping 6-mer
22 | tokenizer to convert genomic nucleotide sequences to tokens. As a result, the 1024 tokens correspond to approximately 6144 base pairs.
23 |
24 | ## Pre-training
25 |
26 | ### Data
27 | Our pre-training dataset was built from (mostly) edible plants reference genomes contained in the Ensembl Plants database. The dataset consists of approximately 10.5 million genomic sequences across 48 different species.
28 |
29 | ### Processing
30 | All reference genomes for each species were assembled into a single fasta file. In this fasta file, all nucleotides other than A, T, C, G were replaced by N. A tokenizer was used to convert strings of letters into sequences of tokens. The tokenizer's alphabet consisted of the 4^6 = 4096 possible 6-mer combinations obtained by combining A, T, C, G, as well as five additional tokens representing standalone A, T, C, G, and N. It also included three special tokens: the pad [PAD], mask [MASK], and class [CLS] tokens. This resulted in a vocabulary of 4104 tokens. To tokenize an input sequence, the tokenizer started with a class token and then converted the sequence from left to right, matching 6-mer tokens when possible, or using the standalone tokens when necessary (for instance, when the letter N was present or if the sequence length was not a multiple of 6).
31 |
32 | ### Tokenization example
33 |
34 | ```python
35 | nucleotide sequence: ATCCCGGNNTCGACACN
36 | tokens: <CLS> <ATCCCG> <G> <N> <N> <TCGACA> <C> <N>
37 | ```
12 |
13 | ## Architecture and Parameters
14 | ChatNT is built on a three‑module design: a 500M‑parameter [Nucleotide Transformer v2](https://www.nature.com/articles/s41592-024-02523-z) DNA encoder pre‑trained on genomes from 850 species
15 | (handling up to 12 kb per sequence, Dalla‑Torre et al., 2024), an English‑aware Perceiver Resampler that linearly projects and gated‑attention compresses
16 | 2048 DNA‑token embeddings into 64 task‑conditioned vectors (REF), and a frozen 7B‑parameter [Vicuna‑7B](https://lmsys.org/blog/2023-03-30-vicuna/) decoder.
17 |
18 | Users provide a natural‑language prompt containing one or more DNA-sequence placeholder tokens; each placeholder is replaced by the projected representation of a DNA sequence supplied alongside the prompt.
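In this repository, the Jax components can be loaded with the helper in `nucleotide_transformer/chatNT/pretrained.py`; the sketch below only builds the forward function and fetches the weights and tokenizers (prompt construction and decoding are not shown).

```python
from nucleotide_transformer.chatNT.pretrained import get_chatNT

# Builds the ChatNT forward function and downloads its weights and tokenizers.
forward_fn, params, english_tokenizer, bio_tokenizer = get_chatNT()
```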
19 |
20 | *The Nucleotide Transformer model accurately predicts diverse genomics tasks
21 | after fine-tuning. We show the performance results across downstream tasks for fine-tuned transformer models. Error bars represent 2 SDs derived from 10-fold cross-validation.*
22 |
23 | ## Model Variants and Sizes
24 |
25 | NT models are transformer-based, varying in size (50 million to 2.5 billion parameters) and pre-training data.
26 | Two main sets of NT models were developed:
27 |
28 | 1 - Initial NT Models (NT-v1): This first generation of models explored a range of parameter sizes and pre-training data sources:
29 | - Human ref 500M: A 500-million-parameter model trained on the human reference genome (GRCh38/hg38).
30 | - 1000G 500M: A 500-million-parameter model trained on 3,202 diverse human genomes from the 1000 Genomes Project.
31 | - 1000G 2.5B: A 2.5-billion-parameter model, also trained on the 3,202 genomes from the 1000 Genomes Project.
32 | - Multispecies 2.5B: A 2.5-billion-parameter model trained on a diverse collection of 850 genomes from various species, including 11 model organisms.
33 |
34 | 2 - Optimized NT Models (NT-v2): Following insights from the initial models, a second set of four models was developed, focusing on parameter efficiency and architectural advancements. These models range from 50 million to 500 million parameters.
35 | - Our second-generation Nucleotide Transformer v2 models include a series of architectural changes that proved more efficient: instead of learned positional embeddings, we use Rotary Embeddings at each attention layer, together with Gated Linear Units with swish activations and no bias. These improved models also accept sequences of up to 2,048 tokens, leading to a longer context window of 12kbp.
36 | - Inspired by Chinchilla scaling laws, we also trained our NT-v2 models on our multi-species dataset for a longer duration (300B tokens for the 50M and 100M models; 1T tokens for the 250M and 500M models) compared to the v1 models (300B tokens for all four models).
37 | - NT-v2 models were all pre-trained on the multispecies dataset
38 |
39 | Both types of models are encoder-only transformers tokenizing DNA into 6-mers (4,104 total tokens including special ones).
40 | NT-v1 used learnable positional encodings for a 6kb context.
41 | NT-v2 models incorporated rotary positional embeddings (RoPE), SwiGLU activations, removed biases and dropout, and extended context length to 12kb.
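
As a quick orientation, the sketch below shows how an NT model is typically loaded in Jax; it assumes the `get_pretrained_model` helper from `nucleotide_transformer.pretrained` and uses `500M_multi_species_v2` as an illustrative model name, so check the repository's main examples for the exact signature and the full list of model names.

```python
import haiku as hk
import jax
import jax.numpy as jnp

# Assumption: get_pretrained_model is the standard loading helper of this package.
from nucleotide_transformer.pretrained import get_pretrained_model

parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name="500M_multi_species_v2",  # illustrative NT-v2 variant name
    embeddings_layers_to_save=(20,),
    max_positions=32,
)
forward_fn = hk.transform(forward_fn)

# Tokenize a toy sequence and run a forward pass.
sequences = ["ATTCCGATTCCGATTCCG"]
tokens = jnp.asarray([b[1] for b in tokenizer.batch_tokenize(sequences)], dtype=jnp.int32)
outs = forward_fn.apply(parameters, jax.random.PRNGKey(0), tokens)
print(outs["embeddings_20"].shape)  # embeddings saved at layer 20
```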
42 |
43 | ## Available Resources
44 |
45 | Pre-training datasets, models and downstream tasks can be found on our [HuggingFace space](https://huggingface.co/collections/InstaDeepAI/nucleotide-transformer-65099cdde13ff96230f2e592).
46 | Interactive Leaderboard on downstream tasks is available at https://huggingface.co/spaces/InstaDeepAI/nucleotide_transformer_benchmark.
47 |
48 | ## Tokenization :abc:
49 |
50 | The models are trained on sequences of length up to 1000 (NT-v1) and 2000 (NT-v2) tokens, including the
51 | CLS token that is automatically prepended to the beginning of the sequence.
12 |
13 | *Fig. 1: SegmentNT localizes genomics elements at nucleotide resolution.*
14 |
15 | ## How to use 🚀
16 |
17 | ### SegmentNT
18 |
19 | ⚠️ The SegmentNT models leverage the [Nucleotide Transformer (NT)](https://www.nature.com/articles/s41592-024-02523-z) backbone and have been trained on sequences of 30,000 nucleotides, or 5001 tokens (accounting for the CLS token). However, SegmentNT has been shown to generalize up to sequences of 50,000 bp. Because 30,000 bp is longer
20 | than the maximum length of 2048 6-mer tokens that the Nucleotide Transformer can handle, YaRN rescaling is employed during training.
21 | By default, the `rescaling_factor` is set to the one used during training. If you need to infer on sequences between 30kbp and 50kbp, make sure to pass the `rescaling_factor` argument to the `get_pretrained_segment_nt_model` function with
22 | the value `rescaling_factor = num_dna_tokens_inference / max_num_tokens_nt`, where `num_dna_tokens_inference` is the number of tokens at inference (i.e. 6669 for a sequence of 40008 base pairs) and `max_num_tokens_nt` is the maximum number of tokens the backbone Nucleotide Transformer was trained on, i.e. `2048`.
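
For example, to load SegmentNT for a 40,008 bp input, the rescaling factor can be computed as follows (a minimal sketch reusing the function and argument names above):

```python
from nucleotide_transformer.pretrained import get_pretrained_segment_nt_model

# 40,008 bp -> 6,668 6-mer tokens, plus the CLS token = 6,669 tokens at inference.
num_dna_tokens_inference = 40_008 // 6 + 1
max_num_tokens_nt = 2048  # maximum number of tokens the NT backbone was trained on

parameters, forward_fn, tokenizer, config = get_pretrained_segment_nt_model(
    model_name="segment_nt",
    rescaling_factor=num_dna_tokens_inference / max_num_tokens_nt,  # ≈ 3.26
    max_positions=num_dna_tokens_inference,
)
```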
23 |
24 | 🔍 The notebook `../notebooks/segment_nt/inference_segment_nt.ipynb` showcases how to infer on a 50kb sequence and plot the probabilities to reproduce the Fig.3 of the paper.
25 |
26 | 🚧 The SegmentNT models do not handle any "N" in the input sequence because each nucleotide needs to be part of a 6-mer token, which cannot be the case when the sequence contains one or multiple "N" base pairs.
27 |
28 | ```python
29 | import haiku as hk
30 | import jax
31 | import jax.numpy as jnp
32 | from nucleotide_transformer.pretrained import get_pretrained_segment_nt_model
33 |
34 | # Initialize CPU as default JAX device. This makes the code robust to memory leakage on
35 | # the devices.
36 | jax.config.update("jax_platform_name", "cpu")
37 |
38 | backend = "cpu"
39 | devices = jax.devices(backend)
40 | num_devices = len(devices)
41 | print(f"Devices found: {devices}")
42 |
43 | # The number of DNA tokens (excluding the CLS token prepended) needs to be divisible by
44 | # 2 to the power of the number of downsampling blocks, i.e. 4.
45 | max_num_nucleotides = 8
46 |
47 | assert max_num_nucleotides % 4 == 0, (
48 | "The number of DNA tokens (excluding the CLS token prepended) needs to be dividible by"
49 | "2 to the power of the number of downsampling block, i.e 4.")
50 |
51 | parameters, forward_fn, tokenizer, config = get_pretrained_segment_nt_model(
52 | model_name="segment_nt",
53 | embeddings_layers_to_save=(29,),
54 | attention_maps_to_save=((1, 4), (7, 10)),
55 | max_positions=max_num_nucleotides + 1,
56 | )
57 | forward_fn = hk.transform(forward_fn)
58 | apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))
59 |
60 |
61 | # Get data and tokenize it
62 | sequences = ["ATTCCGATTCCGATTCCAACGGATTATTCCGATTAACCGATTCCAATT", "ATTTCTCTCTCTCTCTGAGATCGATGATTTCTCTCTCATCGAACTATG"]
63 | tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
64 | tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
65 |
66 | random_key = jax.random.PRNGKey(seed=0)
67 | keys = jax.device_put_replicated(random_key, devices=devices)
68 | parameters = jax.device_put_replicated(parameters, devices=devices)
69 | tokens = jax.device_put_replicated(tokens, devices=devices)
70 |
71 | # Infer on the sequence
72 | outs = apply_fn(parameters, keys, tokens)
73 | # Obtain the logits over the genomic features
74 | logits = outs["logits"]
74 | # Transform the logits into probabilities
76 | probabilities = jnp.asarray(jax.nn.softmax(logits, axis=-1))[...,-1]
77 | print(f"Probabilities shape: {probabilities.shape}")
78 |
79 | print(f"Features inferred: {config.features}")
80 |
81 | # Get probabilities associated with intron
82 | idx_intron = config.features.index("intron")
83 | probabilities_intron = probabilities[..., idx_intron]
84 | print(f"Intron probabilities shape: {probabilities_intron.shape}")
85 | ```
86 |
87 | Supported model names are:
88 | - **segment_nt**
89 | - **segment_nt_multi_species**
90 |
91 | The code runs both on GPU and TPU thanks to Jax!
92 |
93 | ---
94 | ### SegmentEnformer
95 |
96 | SegmentEnformer leverages [Enformer](https://www.nature.com/articles/s41592-021-01252-x) by removing the prediction head and replacing it by a 1-dimensional U-Net segmentation head to predict the location of several types of genomics elements in a sequence at a single nucleotide resolution.
97 |
98 | 🔍 The notebook `../notebooks/segment_nt/inference_segment_enformer.ipynb` showcases how to infer on a 196,608bp sequence and plot the probabilities.
99 |
100 | ```python
101 | import haiku as hk
102 | import jax
103 | import jax.numpy as jnp
104 | import numpy as np
105 |
106 | from nucleotide_transformer.enformer.pretrained import get_pretrained_segment_enformer_model
107 | from nucleotide_transformer.enformer.features import FEATURES
108 |
109 | # Initialize CPU as default JAX device. This makes the code robust to memory leakage on
110 | # the devices.
111 | jax.config.update("jax_platform_name", "cpu")
112 |
113 | backend = "cpu"
114 | devices = jax.devices(backend)
115 | num_devices = len(devices)
116 |
117 | # Load model
118 | parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_enformer_model()
119 | forward_fn = hk.transform_with_state(forward_fn)
120 |
121 | apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))
122 | random_key = jax.random.PRNGKey(seed=0)
123 |
124 | # Replicate over devices
125 | keys = jax.device_put_replicated(random_key, devices=devices)
126 | parameters = jax.device_put_replicated(parameters, devices=devices)
127 | state = jax.device_put_replicated(state, devices=devices)
128 |
129 | # Get data and tokenize it
130 | sequences = ["A" * 196_608]
131 | tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
132 | tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)] * num_devices, axis=0)
133 |
134 | # Infer
135 | outs, state = apply_fn(parameters, state, keys, tokens)
136 |
137 | # Obtain the logits over the genomic features
138 | logits = outs["logits"]
139 | # Transform the logits into probabilities
140 | probabilities = np.asarray(jax.nn.softmax(logits, axis=-1))[..., -1]
141 |
142 | # Get probabilities associated with intron
143 | idx_intron = FEATURES.index("intron")
144 | probabilities_intron = probabilities[..., idx_intron]
145 | print(f"Intron probabilities shape: {probabilities_intron.shape}")
146 | ```
147 |
148 | ### SegmentBorzoi
149 | SegmentBorzoi leverages [Borzoi](https://www.nature.com/articles/s41588-024-02053-6) by removing the prediction head and replacing it by a 1-dimensional U-Net segmentation head to predict the location of several types of genomics elements in a sequence.
150 |
151 | 🔍 The notebook `../notebooks/segment_nt/inference_segment_borzoi.ipynb` showcases how to infer on a 524,288 bp sequence and plot the probabilities.
152 |
153 | ```python
154 | import haiku as hk
155 | import jax
156 | import jax.numpy as jnp
157 | import numpy as np
158 |
159 | from nucleotide_transformer.borzoi.pretrained import get_pretrained_segment_borzoi_model
160 | from nucleotide_transformer.enformer.features import FEATURES
161 |
162 | # Initialize CPU as default JAX device. This makes the code robust to memory leakage on
163 | # the devices.
164 | jax.config.update("jax_platform_name", "cpu")
165 |
166 | backend = "cpu"
167 | devices = jax.devices(backend)
168 | num_devices = len(devices)
169 |
170 | # Load model
171 | parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_borzoi_model()
172 | forward_fn = hk.transform_with_state(forward_fn)
173 | apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))
174 | random_key = jax.random.PRNGKey(seed=0)
175 |
176 | # Replicate over devices
177 | keys = jax.device_put_replicated(random_key, devices=devices)
178 | parameters = jax.device_put_replicated(parameters, devices=devices)
179 | state = jax.device_put_replicated(state, devices=devices)
180 |
181 | # Get data and tokenize it
182 | sequences = ["A" * 524_288]
183 | tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
184 | tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)] * num_devices, axis=0)
185 |
186 | # Infer
187 | outs, state = apply_fn(parameters, state, keys, tokens)
188 |
189 | # Obtain the logits over the genomic features
190 | logits = outs["logits"]
191 | # Transform the logits into probabilities
192 | probabilities = np.asarray(jax.nn.softmax(logits, axis=-1))[..., -1]
193 |
194 | # Get probabilities associated with intron
195 | idx_intron = FEATURES.index("intron")
196 | probabilities_intron = probabilities[..., idx_intron]
197 | print(f"Intron probabilities shape: {probabilities_intron.shape}")
198 |
199 | ```
200 |
201 | ## Citing our work 📚
202 |
203 | You can cite our models at:
204 |
205 | ```bibtex
206 | @Article{deAlmeida2025segmentNT,
207 | title={Annotating the genome at single-nucleotide resolution with DNA foundation models},
208 | author={de Almeida, Bernardo P. and Dalla-Torre, Hugo and Richard, Guillaume and Blum, Christopher and Hexemer, Lorenz and G{\'e}lard, Maxence and Mendoza-Revilla,
209 | Javier and Tang, Ziqi and Marin, Frederikke I. and Emms, David M. and Pandey, Priyanka and Laurent, Stefan and Lopez, Marie and Laterre, Alexandre and Lang, Maren
210 | and {\c{S}}ahin, U{\u{g}}ur and Beguir, Karim and Pierrot, Thomas},
211 | journal={Nature Methods},
212 | year={2025},
213 | month={Oct},
214 | day={29},
215 | issn={1548-7105},
216 | doi={10.1038/s41592-025-02881-2},
217 | url={https://doi.org/10.1038/s41592-025-02881-2}
218 | }
219 | ```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | A hub for InstaDeep's cutting-edge deep learning models and research for genomics, originating from the Nucleotide Transformer and its evolutions.