├── .gitignore ├── .readthedocs.yaml ├── Dockerfile ├── README.md ├── chrombert ├── __init__.py ├── base │ ├── __init__.py │ ├── model.py │ ├── model_config.py │ └── utils │ │ ├── __init__.py │ │ ├── emb_manager.py │ │ ├── embedding.py │ │ ├── feed_forward.py │ │ ├── gelu.py │ │ ├── layer_norm.py │ │ ├── sublayer.py │ │ └── transformer.py ├── finetune │ ├── README.md │ ├── __init__.py │ ├── dataset │ │ ├── __init__.py │ │ ├── basic_dataset.py │ │ ├── data_module.py │ │ ├── dataset_config.py │ │ ├── general_dataset.py │ │ ├── multi_flankwindow_dataset.py │ │ ├── presets │ │ │ ├── general.json │ │ │ ├── general_mm10.json │ │ │ ├── multi_flank_window.json │ │ │ ├── multi_flank_window_mm10.json │ │ │ ├── prompt_cistrome.json │ │ │ ├── prompt_dna.json │ │ │ ├── prompt_dnase.json │ │ │ ├── prompt_exp.json │ │ │ └── prompt_exp_pbmc.json │ │ └── prompt_dataset │ │ │ ├── __init__.py │ │ │ ├── interface.py │ │ │ ├── interface_manager.py │ │ │ ├── prompt_dataset.py │ │ │ ├── prompt_dataset_single.py │ │ │ └── prompt_dataset_two.py │ ├── model │ │ ├── __init__.py │ │ ├── basic_model.py │ │ ├── general_ft_model.py │ │ ├── gep_ft_model.py │ │ ├── model_config.py │ │ ├── presets │ │ │ ├── general.json │ │ │ ├── general_mm10.json │ │ │ ├── gep.json │ │ │ ├── gep_mm10.json │ │ │ ├── prompt_cistrome.json │ │ │ ├── prompt_cistrome_mm10.json │ │ │ ├── prompt_dna.json │ │ │ ├── prompt_dnase.json │ │ │ ├── prompt_dnase_mm10.json │ │ │ └── prompt_exp.json │ │ ├── prompt_dna_model.py │ │ ├── prompt_ft_model.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── dnabert2.py │ │ │ ├── emb_manager.py │ │ │ ├── general_header.py │ │ │ ├── gep_header.py │ │ │ ├── layer_norm.py │ │ │ ├── pool_flank_window.py │ │ │ ├── prompt_header.py │ │ │ └── residual_block.py │ └── train │ │ ├── __init__.py │ │ ├── basic_pl_module.py │ │ ├── pl_module.py │ │ ├── train_config.py │ │ └── utils │ │ ├── __init__.py │ │ ├── logger.py │ │ └── loss.py └── scripts │ ├── chrombert_get_cistrome_emb.py │ ├── chrombert_get_region_emb.py │ ├── chrombert_get_regulator_emb.py │ ├── chrombert_imputation.py │ ├── chrombert_imputation_sc.py │ ├── chrombert_make_dataset.py │ ├── chrombert_prepare_env.py │ ├── demo.py │ └── utils │ ├── __init__.py │ └── h5_manager.py ├── docs ├── Makefile ├── _static │ ├── 1_ChromBERT_framework.png │ └── ChromBERT_framework.png ├── make.bat ├── requirements.txt └── source │ ├── cli.rst │ ├── conf.py │ ├── finetune.rst │ ├── index.rst │ ├── installation.rst │ ├── quick_tour.rst │ ├── scripts │ ├── chrombert_get_cistrome_emb.rst │ ├── chrombert_get_region_emb.rst │ ├── chrombert_get_regulator_emb.rst │ ├── chrombert_imputation_cistrome.rst │ ├── chrombert_imputation_cistrome_sc.rst │ ├── chrombert_make_dataset.rst │ ├── chrombert_prepare_env.rst │ ├── ft_general.rst │ ├── ft_gep.rst │ └── ft_prompt_enhanced.rst │ ├── tutorial_embedding_extraction.nblink │ ├── tutorial_finetuning_ChromBERT.nblink │ ├── tutorial_locus_specific_TRN_eqtl.nblink │ ├── tutorial_locus_specific_TRN_ezh2.nblink │ ├── tutorial_locus_specific_TRN_starr.nblink │ ├── tutorial_prompt_cistrome_imputation.nblink │ ├── tutorial_transdifferentiation.rst │ ├── tutorial_transdifferentiation_chromatin_accessibility.nblink │ └── tutorial_transdifferentiation_transcriptome.nblink ├── examples ├── readme.md ├── train │ ├── ft_general.py │ ├── ft_gep.py │ └── ft_prompt_enhanced.py └── tutorials │ ├── tutorial_embedding_extraction.ipynb │ ├── tutorial_finetuning_ChromBERT.ipynb │ ├── tutorial_locus_specific_TRN_ezh2.ipynb │ ├── tutorial_locus_specific_TRN_starr.ipynb │ 
├── tutorial_prompt_cistrome_imputation.ipynb │ ├── tutorial_prompt_eqtl.ipynb │ ├── tutorial_transdifferentiation_chromatin_accessibility.ipynb │ └── tutorial_transdifferentiation_transcriptome.ipynb ├── lumache.py └── pyproject.toml /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks 4 | 5 | ### JupyterNotebooks ### 6 | # gitignore template for Jupyter Notebooks 7 | # website: http://jupyter.org/ 8 | 9 | .ipynb_checkpoints 10 | */.ipynb_checkpoints/* 11 | 12 | # IPython 13 | profile_default/ 14 | ipython_config.py 15 | 16 | # Remove previous ipynb_checkpoints 17 | # git rm -r .ipynb_checkpoints/ 18 | 19 | ### Python ### 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | share/python-wheels/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | MANIFEST 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .nox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *.cover 68 | *.py,cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | cover/ 72 | 73 | # Translations 74 | *.mo 75 | *.pot 76 | 77 | # Django stuff: 78 | *.log 79 | local_settings.py 80 | db.sqlite3 81 | db.sqlite3-journal 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | .pybuilder/ 95 | target/ 96 | 97 | # Jupyter Notebook 98 | 99 | # IPython 100 | 101 | # pyenv 102 | # For a library or package, you might want to ignore these files since the code is 103 | # intended to run in multiple environments; otherwise, check them in: 104 | # .python-version 105 | 106 | # pipenv 107 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 108 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 109 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 110 | # install all needed dependencies. 111 | #Pipfile.lock 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks 157 | 158 | .DS_Store 159 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | sphinx: 3 | # Path to your Sphinx configuration file. 4 | configuration: docs/source/conf.py 5 | 6 | build: 7 | os: "ubuntu-22.04" 8 | tools: 9 | python: "3.10" 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.3.2-devel-ubuntu22.04 2 | 3 | # Install essential packages 4 | RUN apt-get update && apt-get install -y wget bzip2 tar gzip && \ 5 | rm -rf /var/lib/apt/lists/* 6 | 7 | # Install miniconda 8 | RUN wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 9 | bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/miniconda && \ 10 | rm Miniconda3-latest-Linux-x86_64.sh 11 | 12 | # Add conda to PATH 13 | ENV PATH="/opt/miniconda/bin:${PATH}" 14 | 15 | # Install python 3.9 using conda 16 | RUN conda install -y python=3.9 && \ 17 | conda clean -a -y 18 | 19 | # Install PyTorch for CUDA 12.1 using pip 20 | RUN pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121 && \ 21 | pip cache purge 22 | 23 | # Install packaging and ninja (required by flash_attn) 24 | RUN pip install packaging ninja && pip cache purge 25 | 26 | # Copy and install flash_attn 27 | COPY ./flash_attn-2.4.3.post1+cu122torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl /tmp/ 28 | RUN pip install /tmp/flash_attn-2.4.3.post1+cu122torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl && \ 29 | rm /tmp/flash_attn-*.whl && \ 30 | pip cache purge 31 | 32 | # Copy, extract, and install ChromBERT 33 | COPY ./ChromBERT.tar.gz /tmp/ 34 | RUN tar -xzf /tmp/ChromBERT.tar.gz -C /tmp && \ 35 | pip install /tmp/ChromBERT && \ 36 | rm -rf /tmp/ChromBERT* && \ 37 | pip cache purge 38 | 39 | # Install other dependencies 40 | RUN pip install scipy jupyterlab && pip cache purge 41 | 42 | # Set the working directory 43 | WORKDIR /workspace 44 | 45 | # Set Python as entrypoint 46 | ENTRYPOINT ["python"] 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ChromBERT: A pre-trained foundation model for context-specific transcription regulatory network 2 | [![Documentation](https://img.shields.io/badge/docs-available-brightgreen)](https://chrombert.readthedocs.io/en/) 3 | [![License: GPL 
v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0) 4 | [![Version: 1.1.0](https://img.shields.io/badge/Version-1.1.0-brightgreen.svg)](https://chrombert.readthedocs.io/en/) 5 | 6 | **ChromBERT** is a pre-trained deep learning model designed to capture the genome-wide co-association patterns of approximately one thousand transcription regulators, thereby enabling accurate representations of context-specific transcriptional regulatory networks (TRNs). As a foundational model, ChromBERT can be fine-tuned to adapt to various biological contexts through transfer learning. This significantly enhances our understanding of transcription regulation and offers a powerful tool for a broad range of research and clinical applications in different biological settings. 7 | 8 | ![ChromBERT Framework](docs/_static/ChromBERT_framework.png "Framework") 9 | 10 | ## Installation 11 | For direct usage, it is recommended to use the [Singularity image](#installation-using-singularity). For development purposes, installing from [source](#installation-from-source) is advised. 12 | 13 | ### Installation From Source 14 | 15 | ChromBERT is compatible with Python 3.8 or higher and requires PyTorch 2.0 or above, along with FlashAttention-2. These dependencies must be installed prior to ChromBERT. 16 | 17 | 18 | #### Installing PyTorch 19 | Follow the detailed instructions on [PyTorch’s official site](https://pytorch.org/get-started/locally/) to install PyTorch according to your device and CUDA version specifications. 20 | 21 | **Note: ChromBERT has been tested with Python 3.9+ and Torch 2.0 to 2.4 (inclusive). Compatibility with other environments is not guaranteed.** 22 | 23 | #### Installing FlashAttention-2 24 | Execute the following commands to install the required packages and [FlashAttention-2](https://github.com/Dao-AILab/flash-attention). 25 | ```shell 26 | # install the required packages for FlashAttention-2 27 | pip install packaging 28 | pip install ninja 29 | 30 | pip install flash-attn==2.4.* --no-build-isolation # FlashAttention-3 is not supported yet, please install FlashAttention-2 31 | ``` 32 | 33 | #### Installing ChromBERT 34 | Clone the repository and install ChromBERT using the commands below: 35 | ```shell 36 | git clone https://github.com/TongjiZhanglab/ChromBERT.git 37 | cd ChromBERT 38 | pip install . 39 | ``` 40 | 41 | Installation typically takes less than five minutes. 42 | 43 | 44 | Then download the required pre-trained model and annotation data files from Hugging Face to `~/.cache/chrombert/data`. 45 | ```shell 46 | chrombert_prepare_env 47 | ``` 48 | 49 | Alternatively, if you are experiencing connectivity issues with Hugging Face, you can use the `--hf-endpoint` option to connect to an available Hugging Face mirror. 50 | ```shell 51 | chrombert_prepare_env --hf-endpoint 52 | ``` 53 | 54 | #### Verifying Installation 55 | 56 | To verify the installation, execute the following command: 57 | ```python 58 | import chrombert 59 | ``` 60 | 61 | ### Installation Using Singularity 62 | 63 | We provide a pre-built Singularity image available [here](https://drive.google.com/file/d/1ePmDK6DANSq-zkRgVBTxSBnKBZk-cEzM/view?usp=sharing).
64 | 65 | After installing `Singularity` (or `Apptainer`) and downloading the image (`chrombert.sif`), you can use the built-in `python` environment with: 66 | 67 | ```bash 68 | singularity exec --nv chrombert.sif python -c "import chrombert; print('hello chrombert')" 69 | ``` 70 | 71 | You can execute other built-in commands through the image as well. For example, to download the required pre-trained models and annotation files from Hugging Face to `~/.cache/chrombert/data`, run: 72 | 73 | > **Note:** You must execute this command to prepare the environment, as the image does not include checkpoints and additional data by default to minimize size. 74 | 75 | ```bash 76 | singularity exec --nv chrombert.sif chrombert_prepare_env 77 | ``` 78 | 79 | To run your own Python scripts, use: 80 | 81 | ```bash 82 | singularity exec --nv chrombert.sif python your_script.py 83 | ``` 84 | 85 | The image also includes a built-in Jupyter kernel for interactive script development via `jupyter notebook` or editors like `VSCode`: 86 | 87 | ```bash 88 | singularity exec --nv chrombert.sif jupyter notebook [other parameters] 89 | ``` 90 | 91 | By default, Singularity mounts your home directory inside the container. If you need to mount additional directories, use the `--bind` parameter. Refer to the [Singularity documentation](https://docs.sylabs.io/guides/3.0/user-guide/bind_paths_and_mounts.html) for more details. 92 | 93 | 94 | ## Usage 95 | 96 | For detailed information on usage, please check out the documentation and tutorials at [chrombert.readthedocs.io](https://chrombert.readthedocs.io/en/latest/). 97 | 98 | 99 | ## Pre-trained Model Zoo 100 | 101 | ChromBERT was initially trained on the human Cistrome-Human-6K dataset at 1-kb resolution. Currently available pre-trained models include: 102 | | Model Name | Description | Download Link | 103 | | :------------ | :------------------------------------------------------------- | :------------------------------------------------------------------ | 104 | | Human-6K-1kb | Pre-trained on the Cistrome-Human-6K dataset at 1-kb resolution | [Download](https://huggingface.co/datasets/TongjiZhanglab/chrombert) | 105 | | Mouse-5K-1kb | Pre-trained on the Cistrome-Mouse-5K dataset at 1-kb resolution | [Download](https://huggingface.co/datasets/TongjiZhanglab/chrombert) | 106 | 107 | Note: Models can also be downloaded via the `chrombert_prepare_env` command, as outlined in the installation section. 108 | 109 | ## Fine-tuning ChromBERT for downstream tasks 110 | 111 | Explore detailed examples of fine-tuning ChromBERT for downstream tasks, including prompt-enhanced fine-tuning for generative prediction and analyses of locus specificity and cellular dynamics of TRNs, on our examples page at [chrombert.readthedocs.io](https://chrombert.readthedocs.io/en/latest/).
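For a quick sense of the Python API, the minimal sketch below loads the pre-trained backbone and freezes all but the last transformer blocks before fine-tuning, using `ChromBERTConfig.load`, `init_model`, and `freeze` as defined in `chrombert/base/model_config.py` and `chrombert/base/model.py`. The checkpoint filename is a placeholder assumption; substitute the file that `chrombert_prepare_env` downloads under `~/.cache/chrombert/data`.

```python
import os
from chrombert import ChromBERTConfig

# Placeholder path (assumption): substitute the checkpoint downloaded by `chrombert_prepare_env`.
ckpt = os.path.expanduser("~/.cache/chrombert/data/hg38_6k_1kb_pretrain.ckpt")

# Build the configuration (genome: "hg38" or "mm10") and instantiate the model;
# if the checkpoint path does not exist, the model falls back to random initialization.
config = ChromBERTConfig.load(genome="hg38", ckpt=ckpt)
model = config.init_model()

# Freeze everything except the last two transformer blocks, then report parameter counts.
model.freeze(trainable=2)
model.display_trainable_parameters(verbose=False)
```

Complete training pipelines, built on the preset dataset and model configurations (`get_preset_dataset_config`, `get_preset_model_config`) and PyTorch Lightning, are provided in `examples/train/` (`ft_general.py`, `ft_gep.py`, `ft_prompt_enhanced.py`).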
112 | 113 | ## Citing ChromBERT 114 | 115 | ```bibtex 116 | @article {Yu2025.03.29.646077, 117 | author = {Yu, Zhaowei and Yang, Dongxu and Chen, Qianqian and Zhang, Yuxuan and Li, Zhanhao and Wang, Yucheng and Wang, Chenfei and Zhang, Yong}, 118 | title = {Learning interpretable representation for context-specific transcription regulatory networks using a foundation model}, 119 | year = {2025}, 120 | doi = {10.1101/2025.03.29.646077}, 121 | publisher = {Cold Spring Harbor Laboratory}, 122 | journal = {bioRxiv} 123 | } 124 | ``` -------------------------------------------------------------------------------- /chrombert/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ChromBERT, ChromBERTConfig 2 | from .finetune import ChromBERTFTConfig, get_preset_model_config,ChromBERTEmbedding,LitChromBERTFTDataModule 3 | from .finetune import DatasetConfig, get_preset_dataset_config 4 | 5 | VERSION = "1.0.0" -------------------------------------------------------------------------------- /chrombert/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import ChromBERT 2 | from .model_config import ChromBERTConfig 3 | -------------------------------------------------------------------------------- /chrombert/base/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import lightning.pytorch as pl 4 | from .utils import BERTEmbedding 5 | from .utils import EncoderTransformerBlock 6 | from .utils import ChromBERTEmbedding 7 | 8 | class ChromBERT(nn.Module): 9 | def __init__(self, config): 10 | """ 11 | ChromBERT: pre-trained foundation model for context-specific transcription regulatory network. 12 | Args: 13 | config (:obj:`ChromBERTConfig`): configuration of the model. 
14 | """ 15 | super().__init__() 16 | self.config = config 17 | 18 | self.hidden = config.hidden_dim 19 | self.n_layers = config.num_layers 20 | self.attn_heads = config.num_attention_heads 21 | 22 | self.feed_forward_hidden = config.hidden_dim * 4 23 | 24 | # BERT-like embedding, sum of position and token embeddings 25 | self.embedding = BERTEmbedding(config) 26 | 27 | # multi-layers transformer blocks 28 | self.transformer_blocks = nn.ModuleList( 29 | [EncoderTransformerBlock(config) for _ in range(self.n_layers)]) 30 | 31 | 32 | def forward(self, x, position_ids, key_padding_mask = None, attn_weight = False, attn_layer = None): 33 | # attention masking for padded token 34 | x = self.embedding(x, position_ids) 35 | 36 | if attn_layer == -1: 37 | attn = [] 38 | # running over multiple transformer blocks 39 | for i,transformer in enumerate(self.transformer_blocks): 40 | if attn_weight: 41 | if attn_layer == -1: 42 | x, attn_score = transformer.forward(x, key_padding_mask, attn_weight = True, ) 43 | attn.append(attn_score) 44 | elif i == attn_layer: 45 | x, attn = transformer.forward(x, key_padding_mask, attn_weight = attn_weight, ) 46 | # attn.append(attn) 47 | else: 48 | x = transformer.forward(x, key_padding_mask, attn_weight = False,) 49 | else: 50 | x = transformer.forward(x, key_padding_mask) 51 | 52 | # return outs 53 | return (x, attn) if attn_weight else x 54 | 55 | def load_ckpt(self, ckpt_path): 56 | ck = torch.load(ckpt_path, map_location=torch.device('cpu')) 57 | self.load_state_dict(ck) 58 | return None 59 | 60 | def freeze(self, trainable = 2): 61 | ''' 62 | Freeze the model's parameters, allowing fine-tuning of specific transformer blocks. 63 | For trainable = N layers: 64 | - If `N = 0`, all transformer blocks are frozen. 65 | - If `N > 0`, only the last N transformer blocks are trainable and all other blocks are frozen. 66 | ''' 67 | assert isinstance(trainable, int), 'trainable should be an integer' 68 | assert trainable >= 0 69 | if trainable >= 0: 70 | for name, parameter in self.named_parameters(): 71 | parameter.requires_grad = False 72 | 73 | total_layers = len(self.transformer_blocks) 74 | assert trainable <= total_layers, 'trainable should not be greater than total transformer blocks' 75 | for i in range(total_layers - trainable, total_layers): 76 | for name, parameter in self.transformer_blocks[i].named_parameters(): 77 | parameter.requires_grad = True 78 | 79 | # if trainable < 0: 80 | # for name, parameter in self.named_parameters(): 81 | # parameter.requires_grad = True 82 | 83 | return None 84 | 85 | def display_trainable_parameters(self, verbose = True): 86 | ''' 87 | display the number of trainable parameters in the model 88 | ''' 89 | total_params = sum(p.numel() for p in self.parameters()) 90 | trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) 91 | o = {"total_params": total_params, "trainable_params": trainable_params} 92 | print(o) 93 | if verbose: 94 | for name, parameter in self.named_parameters(): 95 | if parameter.requires_grad: 96 | print(name, ": trainable") 97 | else: 98 | print(name, ": frozen") 99 | return o 100 | 101 | def get_embedding_manager(self, mtx_mask, ignore = False, ignore_index= None): 102 | ''' 103 | get an embedding manager for the pretrain model. 104 | params: 105 | mtx_mask: a matrix that mask the embedding, 1 for available, 0 for unavailable. 106 | ignore: if True, ignore the embedding of the specified index. 107 | ignore_index: the index to be ignored. 
108 | ''' 109 | model_emb = ChromBERTEmbedding(self, mtx_mask = mtx_mask, ignore = ignore, ignore_index = ignore_index) 110 | return model_emb -------------------------------------------------------------------------------- /chrombert/base/model_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | 5 | from dataclasses import dataclass, field, asdict, fields 6 | from typing import Dict, Any, Union, Optional, List 7 | from .model import ChromBERT 8 | 9 | 10 | @dataclass 11 | class ChromBERTConfig: 12 | genome: str = 'hg38' 13 | dropout: float = 0.1 14 | dtype_str: str = field(default='bfloat16', repr=False) 15 | ckpt: str = None 16 | 17 | 18 | 19 | def __post_init__(self): 20 | """ 21 | ChromBERTConfig, the configuration of ChromBERT model. It is able to instantiate a ChromBERT model through from the pretrained model. 22 | """ 23 | assert self.genome in ['hg38', 'mm10'], f"genome should be hg38 for human, or mm10 for mouse, but got {self.genome}" 24 | print(f"use organisim {self.genome}; max sequence length is {self.n_datasets - 1}") 25 | 26 | @property 27 | def n_datasets(self): 28 | if self.genome == 'hg38': 29 | return 6392 30 | elif self.genome == 'mm10': 31 | return 5616 32 | 33 | @property 34 | def dtype(self): 35 | return getattr(torch, self.dtype_str) 36 | 37 | @property 38 | def vocab_size(self): 39 | return 10 40 | 41 | @property 42 | def vocab_size_shift(self): 43 | return 5 44 | 45 | @property 46 | def hidden_dim(self): 47 | return 768 48 | 49 | @property 50 | def num_layers(self): 51 | return 8 52 | 53 | @property 54 | def feed_forward_dim(self): 55 | return 3072 56 | 57 | @property 58 | def num_attention_heads(self): 59 | return 8 60 | 61 | @property 62 | def token_id_pad(self): 63 | return 0 64 | 65 | @property 66 | def pe_mode(self): 67 | return 'train' 68 | 69 | 70 | @property 71 | def flash_bias(self): 72 | return True 73 | 74 | @property 75 | def flash_batch_first(self): 76 | return True 77 | 78 | @property 79 | def flash_causal(self): 80 | return False 81 | 82 | @property 83 | def flash_device(self): 84 | return None 85 | 86 | def save(self, config_file: str): 87 | values = asdict(self) 88 | with open(config_file, 'w') as f: 89 | json.dump(values, f, indent=4) 90 | 91 | def __repr__(self): 92 | values = asdict(self) 93 | return f"ChromBERTConfig({values})" 94 | 95 | def __str__(self): 96 | values = asdict(self) 97 | return json.dumps(values, indent=4) 98 | 99 | 100 | @classmethod 101 | def load(cls, config: Union[str, Dict[str, Any], "ChromBERTConfig",None]=None, **kwargs: Any): 102 | if config == None: 103 | config_dict = {} 104 | elif isinstance(config, str): 105 | with open(config, 'r') as f: 106 | config_dict = json.load(f) 107 | elif isinstance(config, Dict): 108 | config_dict = config 109 | elif isinstance(config, ChromBERTConfig): 110 | config_dict = asdict(config) 111 | else: 112 | raise TypeError(f"config must be a str, Dict, or ChromBERTConfig, but got {type(config)}") 113 | 114 | config_dict.update(kwargs) 115 | 116 | return cls(**config_dict) 117 | 118 | def init_model(self, ckpt=None): 119 | ''' 120 | Instantiate the model using the configuration. 
121 | ''' 122 | model = ChromBERT(self) 123 | if ckpt is None: 124 | ckpt = self.ckpt 125 | if ckpt is None: 126 | print(f"Warning: no ckpt provided, use random initialization!") 127 | elif os.path.exists(ckpt): 128 | model.load_ckpt(ckpt) 129 | else: 130 | print(f"Warning: ckpt {ckpt} not exists, use random initialization!") 131 | return model 132 | 133 | @classmethod 134 | def get_ckpt_type(cls, ckpt): 135 | assert isinstance(ckpt, str) 136 | ckpt = torch.load(ckpt, map_location='cpu') 137 | if "state" in ckpt: 138 | return "finetune" 139 | else: 140 | return "pretrain" 141 | 142 | 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /chrombert/base/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .feed_forward import PositionwiseFeedForward 2 | from .sublayer import SublayerConnection 3 | from .gelu import GELU 4 | from .layer_norm import LayerNorm 5 | from .embedding import BERTEmbedding 6 | from .transformer import EncoderTransformerBlock 7 | from .emb_manager import ChromBERTEmbedding -------------------------------------------------------------------------------- /chrombert/base/utils/emb_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pandas as pd 4 | from torch import nn 5 | class CistromeEmbeddingManager(nn.Module): 6 | def __init__(self, mtx_mask, ignore = False,ignore_index = None) -> None: 7 | super().__init__() 8 | assert mtx_mask is not None, "mtx_mask must be specified" 9 | assert isinstance(mtx_mask, str), "mtx_mask must be a path to a mtx_mask" 10 | assert os.path.exists(mtx_mask), f"{mtx_mask} does not exist" 11 | self.mtx_mask_df = pd.read_csv(mtx_mask, sep='\t', index_col=0) 12 | self.mtx_mask = torch.tensor(self.mtx_mask_df.values) # (datasets, factors) 13 | self.gsmid_names = self.mtx_mask_df.index.tolist() 14 | self.regulator_names = self.mtx_mask_df.columns.tolist() 15 | 16 | if ignore: 17 | ignore_gsmid_index = ignore_index[0] 18 | ignore_regulator_index = ignore_index[1] 19 | print(f"Ignoring {len(ignore_gsmid_index)} cistromes and {len(ignore_regulator_index)} regulators") 20 | rows_to_keep = torch.tensor([i not in ignore_gsmid_index for i in range(self.mtx_mask.shape[0])]) 21 | cols_to_keep = torch.tensor([j not in ignore_regulator_index for j in range(self.mtx_mask.shape[1])]) 22 | self.mtx_mask=self.mtx_mask[rows_to_keep][:, cols_to_keep] 23 | self.mtx_mask_df = self.mtx_mask_df.iloc[rows_to_keep.numpy(), cols_to_keep.numpy()] 24 | self.gsmid_names = self.mtx_mask_df.index.tolist() 25 | self.regulator_names = self.mtx_mask_df.columns.tolist() 26 | 27 | factor_num = (self.mtx_mask != 0).sum(dim=0) 28 | self.normalization_factors = factor_num.clamp(min=1) 29 | self.normalized_mtx_mask = (self.mtx_mask / self.normalization_factors) 30 | return None 31 | 32 | def forward(self, x): 33 | # x: [batch_size, datasets, hidden] 34 | self.normalized_mtx_mask = self.normalized_mtx_mask.to(x.device) 35 | self.normalized_mtx_mask = self.normalized_mtx_mask.to(x.dtype) 36 | x = x.transpose(1, 2) 37 | x = torch.matmul(x, self.normalized_mtx_mask) 38 | x = x.transpose(1, 2) # [batch_size, factors, hidden] 39 | 40 | return x 41 | 42 | 43 | def get_cistrome_embedding(self, x, gsmid): 44 | assert gsmid in self.gsmid_names, f"{gsmid} not found in GSMID names" 45 | index = self.gsmid_names.index(gsmid) 46 | return x[:, index, :] 47 | 48 | def get_regulator_embedding(self, x, regulator): 
49 | # x: [batch_size, datasets, hidden] 50 | assert regulator in self.regulator_names, f"{regulator} not found in regulator names" 51 | index = self.regulator_names.index(regulator) 52 | x = x.transpose(1, 2) 53 | self.normalized_mtx_mask = self.normalized_mtx_mask.to(x.device) 54 | extracted_x = torch.matmul(x, self.normalized_mtx_mask[:, index:index+1]) 55 | return extracted_x.transpose(1, 2) 56 | 57 | def get_region_embedding(self, x): 58 | return x.mean(dim=1) 59 | 60 | 61 | class ChromBERTEmbedding(nn.Module): 62 | def __init__(self, pretrain_model, mtx_mask, ignore = False,ignore_index = None) -> None: 63 | super().__init__() 64 | self.pretrain_model = pretrain_model 65 | self.CistromeEmbeddingManager = CistromeEmbeddingManager(mtx_mask, ignore = ignore,ignore_index = ignore_index) 66 | self.__hidden_cistrome = None 67 | self.__hidden_regulator = None 68 | self.__training = pretrain_model.training 69 | self.list_regulator = self.CistromeEmbeddingManager.regulator_names 70 | self.list_cistrome = self.CistromeEmbeddingManager.gsmid_names 71 | 72 | 73 | def forward(self, batch): 74 | with torch.no_grad(): 75 | self.pretrain_model.eval() 76 | x = self.pretrain_model(batch["input_ids"], batch["position_ids"]) 77 | self.__hidden_cistrome = x 78 | emb = self.CistromeEmbeddingManager(x) 79 | self.__hidden_regulator = emb 80 | if self.__training: 81 | self.pretrain_model.train() 82 | return emb 83 | 84 | def get_hidden_state(self): 85 | return self.__hidden_state 86 | 87 | def get_cistrome_embedding(self, gsmid): 88 | gsmid = gsmid.lower() 89 | return self.CistromeEmbeddingManager.get_cistrome_embedding(self.__hidden_cistrome, gsmid) 90 | 91 | def get_regulator_embedding(self, regulator): 92 | regulator = regulator.lower() 93 | # return self.CistromeEmbeddingManager.get_regulator_embedding(self.__hidden_state, regulator) 94 | assert regulator in self.list_regulator, f"{regulator} not found in regulator names" 95 | index = self.list_regulator.index(regulator) 96 | return self.__hidden_regulator[:, index, :] 97 | 98 | def get_region_embedding(self): 99 | return self.CistromeEmbeddingManager.get_region_embedding(self.__hidden_cistrome) 100 | 101 | -------------------------------------------------------------------------------- /chrombert/base/utils/embedding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class TokenEmbedding(nn.Embedding): 6 | def __init__(self, config): 7 | super().__init__(config.vocab_size, config.hidden_dim, config.token_id_pad) 8 | 9 | 10 | class PositionalEmbeddingTrainable(nn.Module): 11 | def __init__(self, config): 12 | super().__init__() 13 | self.pe = nn.Embedding(config.n_datasets, config.hidden_dim) 14 | self.d_model = config.hidden_dim 15 | self.n_datasets = config.n_datasets 16 | 17 | def forward(self, x): 18 | return self.pe(x) 19 | 20 | class PositionalEmbedding(nn.Module): 21 | def __init__(self, config): 22 | super().__init__() 23 | if config.pe_mode == "train": 24 | self.pe = PositionalEmbeddingTrainable(config) 25 | else: 26 | raise ValueError(f"only support train mode for positional embedding! {config.pe_mode} is not supported!") 27 | 28 | def forward(self, x): 29 | return self.pe(x) 30 | 31 | 32 | class BERTEmbedding(nn.Module): 33 | """ 34 | BERT Embedding which is consisted with under features 35 | 1. TokenEmbedding : normal embedding matrix 36 | 2. PositionalEmbedding : adding positional information using sin, cos 37 | 2. 
SegmentEmbedding : adding sentence segment info, (sent_A:1, sent_B:2) 38 | 39 | sum of all these features are output of BERTEmbedding 40 | """ 41 | 42 | def __init__(self, config): 43 | """ 44 | :param config.vocab_size: total vocab size 45 | :param config.hidden_dim: embedding size of token embedding 46 | :param config.dropout: dropout rate 47 | :param config.pe_mode: train or word2vec, for positional embedding choice 48 | :param config.dtype 49 | """ 50 | super().__init__() 51 | self.token = TokenEmbedding(config) 52 | self.position = PositionalEmbedding(config) 53 | 54 | self.dropout = nn.Dropout(p=config.dropout) 55 | self.embed_size = config.hidden_dim 56 | self.dtype = config.dtype 57 | self.config = config 58 | 59 | def forward(self, sequence, position_ids): 60 | sequence = sequence.long() 61 | x = self.token(sequence) + self.position(position_ids) 62 | return self.dropout(x).to(self.dtype) 63 | 64 | -------------------------------------------------------------------------------- /chrombert/base/utils/feed_forward.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .gelu import GELU 3 | 4 | 5 | class PositionwiseFeedForward(nn.Module): 6 | "Implements FFN equation." 7 | 8 | def __init__(self, d_model, d_ff, dropout=0.1): 9 | super(PositionwiseFeedForward, self).__init__() 10 | self.w_1 = nn.Linear(d_model, d_ff) 11 | self.w_2 = nn.Linear(d_ff, d_model) 12 | self.dropout = nn.Dropout(dropout) 13 | self.activation = GELU() 14 | 15 | def forward(self, x): 16 | return self.w_2(self.dropout(self.activation(self.w_1(x)))) 17 | -------------------------------------------------------------------------------- /chrombert/base/utils/gelu.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | 6 | class GELU(nn.Module): 7 | """ 8 | Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU 9 | """ 10 | 11 | def forward(self, x): 12 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 13 | -------------------------------------------------------------------------------- /chrombert/base/utils/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class LayerNorm(nn.Module): 6 | "Construct a layernorm module (See citation for details)." 7 | 8 | def __init__(self, features, eps=1e-6): 9 | super(LayerNorm, self).__init__() 10 | self.a_2 = nn.Parameter(torch.ones(features)) 11 | self.b_2 = nn.Parameter(torch.zeros(features)) 12 | self.eps = eps 13 | 14 | def forward(self, x): 15 | mean = x.mean(-1, keepdim=True) 16 | std = x.std(-1, keepdim=True) 17 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 18 | -------------------------------------------------------------------------------- /chrombert/base/utils/sublayer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .layer_norm import LayerNorm 3 | 4 | 5 | class SublayerConnection(nn.Module): 6 | """ 7 | A residual connection followed by a layer norm. 8 | Note for code simplicity the norm is first as opposed to last. 
9 | """ 10 | 11 | def __init__(self, size, dropout): 12 | super(SublayerConnection, self).__init__() 13 | self.norm = LayerNorm(size) 14 | self.dropout = nn.Dropout(dropout) 15 | 16 | def forward(self, x, sublayer, index=None): 17 | "Apply residual connection to any sublayer with the same size." 18 | y = sublayer(self.norm(x)) 19 | if index is not None: 20 | y1 = y[index] 21 | o = x + self.dropout(y1), y[-1] # pick attention_weights 22 | else: 23 | o = x + self.dropout(y) 24 | return o 25 | -------------------------------------------------------------------------------- /chrombert/base/utils/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import flash_attn 6 | 7 | if flash_attn.__version__.split(".")[0] == "1": 8 | from flash_attn.flash_attention import FlashAttention 9 | print("flash attention version 2 is not installed, using version 1 instead") 10 | flash_attention_version = 1 11 | else: 12 | from flash_attn import flash_attn_qkvpacked_func, flash_attn_func 13 | flash_attention_version = 2 14 | 15 | 16 | from functools import partial 17 | from einops import rearrange 18 | from .feed_forward import PositionwiseFeedForward 19 | from .sublayer import SublayerConnection 20 | 21 | 22 | class SelfAttentionFlashMHA(nn.Module): 23 | 24 | def __init__(self, config) -> None: 25 | assert config.flash_batch_first 26 | factory_kwargs = {'device': config.flash_device, 'dtype': config.dtype} 27 | super().__init__() 28 | self.embed_dim = config.hidden_dim 29 | self.causal = config.flash_causal 30 | self.num_heads = config.num_attention_heads 31 | assert self.embed_dim % self.num_heads == 0, "self.kdim must be divisible by num_heads" 32 | self.head_dim = self.embed_dim // self.num_heads 33 | assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" 34 | 35 | self.Wqkv = nn.Linear(self.embed_dim, 3 *self.embed_dim, bias=config.flash_bias, **factory_kwargs) 36 | if flash_attention_version == 1: 37 | f = FlashAttention(attention_dropout=config.dropout) 38 | self.inner_attn = f 39 | else: 40 | self.inner_attn = partial(flash_attn_qkvpacked_func, dropout_p = config.dropout, causal=self.causal) 41 | 42 | self.dtype = config.dtype 43 | 44 | def forward(self, x, key_padding_mask=None, need_weights=False, attn_weight = False): 45 | """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) 46 | key_padding_mask: bool tensor of shape (batch, seqlen) 47 | """ 48 | x = x.to(self.dtype) 49 | qkv = self.Wqkv(x) 50 | qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) 51 | 52 | if flash_attention_version == 2: 53 | context = self.inner_attn(qkv) 54 | else: 55 | context,_ = self.inner_attn(qkv, key_padding_mask=key_padding_mask, 56 | need_weights=need_weights, causal=self.causal) 57 | if attn_weight : 58 | with torch.no_grad(): 59 | qkvhp = qkv.permute(0,2,3,1,4) 60 | q,k,v = qkvhp[:,0,:,:,:], qkvhp[:,1,:,:,:], qkvhp[:,2,:,:,:] 61 | attn = torch.softmax(torch.matmul(q, k.transpose(-2, -1))/ math.sqrt(q.shape[-1]),dim = -1).detach().cpu() 62 | # attn = attn.sum(axis = -2).squeeze(0) # key dim sum 63 | else: 64 | attn = None 65 | 66 | return (rearrange(context, 'b s h d -> b s (h d)'), attn) if attn_weight else rearrange(context, 'b s h d -> b s (h d)') 67 | 68 | 69 | class EncoderTransformerBlock(nn.Module): 70 | """ 71 | Bidirectional Encoder = Transformer (self-attention) 72 | 
Transformer = MultiHead_Attention + Feed_Forward with sublayer connection 73 | """ 74 | def __init__(self, config): 75 | """ 76 | :param hidden: hidden size of transformer 77 | :param attn_heads: head sizes of multi-head attention 78 | :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size 79 | :param dropout: dropout rate 80 | """ 81 | 82 | super().__init__() 83 | self.attention = SelfAttentionFlashMHA(config) 84 | self.feed_forward = PositionwiseFeedForward(d_model=config.hidden_dim, d_ff=config.feed_forward_dim, dropout=config.dropout) 85 | self.input_sublayer = SublayerConnection(size=config.hidden_dim, dropout=config.dropout) 86 | self.output_sublayer = SublayerConnection(size=config.hidden_dim, dropout=config.dropout) 87 | self.dropout = nn.Dropout(p=config.dropout) 88 | 89 | self.dtype = config.dtype 90 | 91 | def forward(self, x, mask, attn_weight = False): 92 | x = x.to(self.dtype) 93 | if attn_weight : 94 | x, out_attn = self.input_sublayer(x, lambda x: self.attention.forward(x, mask, need_weights=False, attn_weight=attn_weight), index = 0) # get context and attention_score 95 | else: 96 | x = self.input_sublayer(x, lambda x: self.attention.forward(x, mask)) 97 | 98 | x = self.output_sublayer(x, self.feed_forward) 99 | 100 | if attn_weight: 101 | out = x, out_attn 102 | else: 103 | out = x 104 | # return out 105 | return out 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /chrombert/finetune/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaoweiyu-github/ChromBERT/a5c39e9960c235038e710c6afffda821b926bfa4/chrombert/finetune/README.md -------------------------------------------------------------------------------- /chrombert/finetune/__init__.py: -------------------------------------------------------------------------------- 1 | from ..base import ChromBERT, ChromBERTConfig 2 | 3 | from .dataset import DatasetConfig, get_preset_dataset_config,LitChromBERTFTDataModule 4 | from .model import ChromBERTFTConfig, get_preset_model_config, ChromBERTEmbedding 5 | # from chrombert.base import ChromBERTConfig 6 | from .train import TrainConfig, ClassificationPLModule, RegressionPLModule, ZeroInflationPLModule 7 | -------------------------------------------------------------------------------- /chrombert/finetune/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_config import DatasetConfig, get_preset_dataset_config 2 | from .general_dataset import GeneralDataset 3 | from .multi_flankwindow_dataset import MultiFlankwindowDataset 4 | from .prompt_dataset import PromptDataset 5 | from .data_module import LitChromBERTFTDataModule -------------------------------------------------------------------------------- /chrombert/finetune/dataset/data_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lightning.pytorch import LightningDataModule 3 | from torch.utils.data import DataLoader, Subset 4 | from typing import Optional 5 | from .dataset_config import DatasetConfig 6 | from .multi_flankwindow_dataset import MultiFlankwindowDataset 7 | from .general_dataset import GeneralDataset 8 | from .prompt_dataset import PromptDataset 9 | 10 | class LitChromBERTFTDataModule(LightningDataModule): 11 | ''' 12 | For training with pytorch lightning. 
13 | ''' 14 | def __init__(self, config=None, train_params={}, val_params={}, test_params={}, **params): 15 | ''' 16 | LitENBERTDataModule is a class that defines the configuration of the dataset. 17 | Args: 18 | config: DatasetConfig. 19 | {train|val|test}_params: specific params to modify config for train|val|test respetively. 20 | 21 | ''' 22 | if isinstance(config, str): 23 | config = DatasetConfig(config, **params) 24 | 25 | assert isinstance(config, DatasetConfig), f"config must be a DatasetConfig object, but got {type(config)}" 26 | self.basic_config = type(config)(config, **params) 27 | self.train_config = type(config)(config=self.basic_config, **train_params) 28 | self.val_config = type(config)(config=self.basic_config, **val_params) 29 | self.test_config = type(config)(config=self.basic_config, **test_params) 30 | 31 | assert self.train_config.kind == self.val_config.kind == self.test_config.kind 32 | 33 | super().__init__() 34 | self.num_train_epochs = None 35 | self.num_val_epochs = None 36 | self.num_test_epochs = None 37 | self.train_dataset = None 38 | self.val_dataset = None 39 | self.test_dataset = None 40 | 41 | self.has_train = train_params != {} 42 | self.has_val = val_params != {} 43 | self.has_test = test_params != {} 44 | 45 | def setup(self, stage: Optional[str] = None): 46 | 47 | if stage == "fit" or stage is None: 48 | if self.has_train: 49 | self.train_dataset = self.train_config.init_dataset() 50 | self.num_train_epochs = len(self.train_dataset) // self.train_config.batch_size 51 | 52 | if self.has_val: 53 | self.val_dataset = self.val_config.init_dataset() 54 | self.num_val_epochs = len(self.val_dataset) // self.val_config.batch_size 55 | indices = list(range(len(self.val_dataset))) 56 | np.random.shuffle(indices) 57 | self.shuffled_val_dataset = Subset(self.val_dataset, indices[:len(self.val_dataset)]) 58 | 59 | if self.has_test: 60 | self.test_dataset = self.test_config.init_dataset() 61 | self.num_test_epochs = len(self.test_dataset) // self.test_config.batch_size 62 | 63 | elif stage == "val": 64 | if self.has_val: 65 | indices = list(range(len(self.val_dataset))) 66 | np.random.shuffle(indices) 67 | self.shuffled_val_dataset = Subset(self.val_dataset, indices[:len(self.val_dataset)]) 68 | 69 | def train_dataloader(self): 70 | if self.train_dataset: 71 | dl = DataLoader(self.train_dataset, batch_size=self.train_config.batch_size, shuffle=True, num_workers=self.train_config.num_workers) 72 | self.num_train_epochs = len(dl) # force shuffling 73 | return dl 74 | return None 75 | 76 | def val_dataloader(self): 77 | if self.val_dataset: 78 | dl = DataLoader(self.shuffled_val_dataset, batch_size=self.val_config.batch_size, shuffle=True, num_workers=self.val_config.num_workers) # specifically for validation, because the foreced shuffing 79 | self.num_val_epochs = len(dl) 80 | return dl 81 | return None 82 | 83 | def test_dataloader(self): 84 | if self.test_dataset: 85 | dl = DataLoader(self.test_dataset, batch_size=self.test_config.batch_size, shuffle=False, num_workers=self.test_config.num_workers) # forced no shuffling 86 | self.num_test_epochs = len(dl) 87 | return dl 88 | return None 89 | -------------------------------------------------------------------------------- /chrombert/finetune/dataset/general_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import pandas as pd 3 | from typing import Any 4 | from .basic_dataset import IgnoreDataset 5 | import numpy as np 6 | class GeneralDataset(IgnoreDataset): 7 | 
''' 8 | Dataset class for general purposes. 9 | ''' 10 | 11 | def __init__(self,config = None, **params: Any): 12 | ''' 13 | It's recommend to instantiate the class using DatasetConfig.init(). 14 | params: 15 | config: DatasetConfig. supervised_file must be provided. 16 | 17 | ''' 18 | super().__init__(config, **params) 19 | self.config = config 20 | self.supervised(config.supervised_file) 21 | self.__getitem__(0) # make sure initiation 22 | 23 | def supervised(self, supervised_file = None): 24 | ''' 25 | process supervised file to obtain necessary information 26 | ''' 27 | assert isinstance(supervised_file, str) 28 | if supervised_file.endswith('.csv'): 29 | df_supervised = pd.read_csv(supervised_file, header = 0) # csv format, [chrom, start, end, build_region_index, label, other meta datas] 30 | elif supervised_file.endswith('.tsv'): 31 | df_supervised = pd.read_csv(supervised_file, header = 0,sep='\t') # tsv format, [chrom, start, end, build_region_index, label, other meta datas] 32 | elif supervised_file.endswith('.feather'): 33 | df_supervised = pd.read_feather(supervised_file) 34 | else: 35 | raise(ValueError(f"supervised_file must be csv, tsv or feather file!")) 36 | 37 | self.supervised_indices = df_supervised["build_region_index"] 38 | self.supervised_indices_len = len(self.supervised_indices) 39 | 40 | neccessary_columns = ["chrom","start","end","build_region_index"] 41 | for column in neccessary_columns: 42 | if column not in df_supervised.columns: 43 | raise(ValueError(f"{column} not in supervised_file! it must contain headers: {neccessary_columns}")) 44 | 45 | self.optional_columns(df_supervised) 46 | 47 | def optional_columns(self,df): 48 | if 'label' not in df.columns: ### only "chrom","start","end","build_region_index" columns and to predict 49 | self.supervised_labels = None 50 | print(f"Your supervised_file does not contain the 'label' column. Please verify whether ground truth column ('label') is required. If it is not needed, you may disregard this message.") 51 | else: 52 | self.supervised_labels = df['label'].values 53 | 54 | if self.config.perturbation: 55 | if self.config.perturbation_object is not None: 56 | self.perturbation_object = [self.config.perturbation_object] * (self.supervised_indices_len) 57 | print("use perturbation_object in dataset config which high priority than supervised_file") 58 | elif "perturbation_object" in df.columns: 59 | self.perturbation_object = df['perturbation_object'].fillna("none").values 60 | print("use perturbation_object in supervised_file") 61 | else: 62 | raise AttributeError("When perturbation is set, perturbation_object should be set correctly. you can provided 'perturbation_object' column in your supervised_file or you can set perturbation_object in dataset config") 63 | 64 | if "ignore_object" in df.columns: 65 | self.ignore_object = df['ignore_object'].unique().tolist() 66 | assert(len(self.ignore_object)==1) 67 | if self.config.ignore_object is None: 68 | self.config.ignore_object = self.ignore_object[0] 69 | 70 | else: 71 | self.ignore_object = None 72 | 73 | 74 | def __len__(self): 75 | return self.supervised_indices_len 76 | 77 | def __getitem__(self, index): 78 | basic_index = self.supervised_indices[index] 79 | 80 | if self.config.perturbation: 81 | self.config.perturbation_object = self.perturbation_object[index] 82 | 83 | if self.config.ignore and self.config.ignore_object is None: 84 | raise AttributeError("When ignore is set, ignore_object should be set correctly. 
you can provided 'ignore_object' column in your supervised_file or you can set ignore_object in dataset config") 85 | 86 | item = super().__getitem__(basic_index) 87 | 88 | if self.supervised_labels is not None: 89 | item['label'] = self.supervised_labels[index] 90 | 91 | return item 92 | 93 | -------------------------------------------------------------------------------- /chrombert/finetune/dataset/multi_flankwindow_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | from typing import Any 5 | from .basic_dataset import IgnoreDataset 6 | 7 | 8 | class MultiFlankwindowDataset(IgnoreDataset): 9 | ''' 10 | Dataset class for process multi-flank-window dataset. Supervised file is required. 11 | ''' 12 | def __init__(self, 13 | config = None, 14 | **params: Any): 15 | ''' 16 | It's recommend to instantiate the class using DatasetConfig.init(). 17 | params: 18 | config: DatasetConfig. supervised_file must be provided. 19 | ''' 20 | super().__init__(config) 21 | self.flank_window = config.flank_window 22 | self.max_region_idx = self.len - 1 # self.len from BasicDataset, mean the maximum region idx 23 | self.supervised(config.supervised_file) 24 | 25 | def supervised(self, supervised_file = None): 26 | assert isinstance(supervised_file, str) 27 | if supervised_file.endswith('.csv'): 28 | df_supervised = pd.read_csv(supervised_file, header = 0) # csv format, [chrom, start, end, build_region_index, label,other_meta] 29 | elif supervised_file.endswith('.tsv'): 30 | df_supervised = pd.read_csv(supervised_file, header = 0,sep='\t') # tsv format, [chrom, start, end, build_region_index, label,other_meta] 31 | elif supervised_file.endswith('.feather'): 32 | df_supervised = pd.read_feather(supervised_file) 33 | else: 34 | raise(ValueError(f"supervised_file must be csv, tsv, feather file!")) 35 | neccessary_columns = ["chrom","start","end","build_region_index"] 36 | for column in neccessary_columns: 37 | if column not in df_supervised.columns: 38 | raise(ValueError(f"{column} not in supervised_file! it must contain headers: {neccessary_columns}")) 39 | self.supervised_indices = df_supervised["build_region_index"] 40 | self.supervised_indices_len = len(self.supervised_indices) 41 | 42 | ### labels 43 | if 'label' not in df_supervised.columns: ### only "chrom","start","end","build_region_index" columns and to predict 44 | self.supervised_labels = None 45 | print(f"Your file '{supervised_file}' does not contain the 'label' column. Please verify whether the true ground truth ('label') is required. 
If it is not needed, you may disregard this message.") 46 | else: 47 | self.supervised_labels = df_supervised['label'].values 48 | 49 | 50 | 51 | def __len__(self): 52 | return self.supervised_indices_len 53 | 54 | def _get_item_from_super(self, idx): 55 | return super().__getitem__(idx) 56 | 57 | def __getitem__(self, index): 58 | 59 | label = torch.tensor(self.supervised_labels[index]) 60 | 61 | region_id = int(self.supervised_indices[index]) 62 | flank_region_id = np.arange(region_id - self.flank_window, region_id + self.flank_window + 1) 63 | flank_region_id[flank_region_id < 0] = 0 64 | flank_region_id[flank_region_id > self.max_region_idx] = self.max_region_idx 65 | input_ids = torch.stack([self._get_item_from_super(id)['input_ids'] for id in flank_region_id]) 66 | position_ids = torch.stack([self._get_item_from_super(id)['position_ids'] for id in flank_region_id]) 67 | center_region = self._get_item_from_super(region_id)['region'] 68 | center_build_region_index = self._get_item_from_super(region_id)['build_region_index'] 69 | return { 70 | 'label': label, 71 | 'flank_region_id': flank_region_id, 72 | 'input_ids': input_ids, 73 | 'position_ids': position_ids, 74 | 'center_region': center_region, 75 | 'center_build_region_index': center_build_region_index 76 | } -------------------------------------------------------------------------------- /chrombert/finetune/dataset/presets/general.json: -------------------------------------------------------------------------------- 1 | { 2 | "kind":"GeneralDataset", 3 | "hdf5_file":"hg38_6k_1kb.hdf5", 4 | "meta_file":"config/hg38_6k_meta.json", 5 | "ignore": false, 6 | "perturbation": false 7 | } -------------------------------------------------------------------------------- /chrombert/finetune/dataset/presets/general_mm10.json: -------------------------------------------------------------------------------- 1 | { 2 | "kind":"GeneralDataset", 3 | "hdf5_file":"mm10_5k_1kb.hdf5", 4 | "meta_file":"config/mm10_5k_meta.json", 5 | "ignore": false, 6 | "perturbation": false 7 | } -------------------------------------------------------------------------------- /chrombert/finetune/dataset/presets/multi_flank_window.json: -------------------------------------------------------------------------------- 1 | { 2 | "kind":"MultiFlankwindowDataset", 3 | "hdf5_file":"hg38_6k_1kb.hdf5", 4 | "meta_file":"config/hg38_6k_meta.json", 5 | "ignore": false, 6 | "perturbation": false, 7 | "flank_window":4 8 | } -------------------------------------------------------------------------------- /chrombert/finetune/dataset/presets/multi_flank_window_mm10.json: -------------------------------------------------------------------------------- 1 | { 2 | "kind":"MultiFlankwindowDataset", 3 | "hdf5_file":"mm10_5k_1kb.hdf5", 4 | "meta_file":"config/mm10_5k_meta.json", 5 | "ignore": false, 6 | "perturbation": false, 7 | "flank_window":4 8 | } -------------------------------------------------------------------------------- /chrombert/finetune/dataset/presets/prompt_cistrome.json: -------------------------------------------------------------------------------- 1 | { 2 | "kind":"PromptDataset", 3 | "hdf5_file":"hg38_6k_1kb.hdf5", 4 | "meta_file":"config/hg38_6k_meta.json", 5 | "ignore": false, 6 | "perturbation": false, 7 | "prompt_kind": "cistrome" 8 | } 9 | -------------------------------------------------------------------------------- /chrombert/finetune/dataset/presets/prompt_dna.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"kind":"PromptDataset", 3 | "prompt_kind": "dna", 4 | "hdf5_file":"hg38_6k_1kb.hdf5", 5 | "meta_file":"config/hg38_6k_meta.json", 6 | "ignore": false, 7 | "perturbation": false, 8 | "fasta_file":"other/hg38.fa" 9 | } -------------------------------------------------------------------------------- /chrombert/finetune/dataset/presets/prompt_dnase.json: -------------------------------------------------------------------------------- 1 | { 2 | "kind":"PromptDataset", 3 | "hdf5_file":"hg38_6k_1kb.hdf5", 4 | "meta_file":"config/hg38_6k_meta.json", 5 | "ignore": false, 6 | "perturbation": false, 7 | "prompt_kind": "cistrome", 8 | "prompt_regulator_cache_file": "cache/hg38_6k_1kb_regulator_prompt_chr1_cache.h5", 9 | "prompt_celltype_cache_file": "cache/hg38_6k_1kb_cistrome_cell_prompt_chr1_cache.h5" 10 | } 11 | -------------------------------------------------------------------------------- /chrombert/finetune/dataset/presets/prompt_exp.json: -------------------------------------------------------------------------------- 1 | { 2 | "kind":"PromptDataset", 3 | "hdf5_file":"hg38_6k_1kb.hdf5", 4 | "meta_file":"config/hg38_6k_meta.json", 5 | "ignore": false, 6 | "perturbation": false, 7 | "prompt_kind": "expression", 8 | "prompt_regulator_cache_file": "cache/hg38_6k_1kb_regulator_prompt_chr1_cache.h5", 9 | "prompt_celltype_cache_file": "cache/hg38_6k_1kb_expression_cell_prompt_cache.pkl" 10 | } 11 | -------------------------------------------------------------------------------- /chrombert/finetune/dataset/presets/prompt_exp_pbmc.json: -------------------------------------------------------------------------------- 1 | { 2 | "kind":"PromptDataset", 3 | "hdf5_file":"hg38_6k_1kb.hdf5", 4 | "meta_file":"config/hg38_6k_meta.json", 5 | "ignore": false, 6 | "perturbation": false, 7 | "prompt_kind": "expression", 8 | "prompt_regulator_cache_file": "cache/hg38_6k_1kb_regulator_prompt_chr1_cache.h5", 9 | "prompt_celltype_cache_file": "cache/pbmc10k_scgpt_cell_prompt_cache.pkl" 10 | } 11 | -------------------------------------------------------------------------------- /chrombert/finetune/dataset/prompt_dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .prompt_dataset import PromptDataset -------------------------------------------------------------------------------- /chrombert/finetune/dataset/prompt_dataset/interface_manager.py: -------------------------------------------------------------------------------- 1 | from .interface import RegulatorEmbInterface, CistromeCellEmbInterface, PromptsCistromInterface, ExpCellEmbInterface 2 | 3 | class RegulatorInterfaceManager(): 4 | def __init__(self,config,prompt_map): 5 | super().__init__() 6 | self.prompt_regulator_cache_file = config.prompt_regulator_cache_file 7 | self.meta_file = config.meta_file 8 | self.prompt_map = prompt_map 9 | self.config = config 10 | self.interface = self._create_interface() 11 | 12 | def _create_interface(self): 13 | if self.prompt_regulator_cache_file is not None: 14 | return RegulatorEmbInterface(self.prompt_regulator_cache_file, cache=self.config.prompt_regulator_cache_pin_memory, cache_limit=self.config.prompt_regulator_cache_limit) 15 | else: 16 | return PromptsCistromInterface(self.meta_file, self.prompt_map) 17 | 18 | def get_prompt_item(self, build_region_index, regulator, seq_len): 19 | if isinstance(self.interface, RegulatorEmbInterface): 20 | return self.interface.get_emb(build_region_index, regulator) 21 | elif isinstance(self.interface, PromptsCistromInterface): 
22 | return self.interface.regulator_parse_prompts(regulator, seq_len) 23 | else: 24 | raise ValueError("Invalid interface type.") 25 | 26 | class CelltypeInterfaceManager(): 27 | def __init__(self,config,prompt_map): 28 | super().__init__() 29 | self.prompt_celltype_cache_file = config.prompt_celltype_cache_file 30 | self.prompt_kind = config.prompt_kind 31 | self.meta_file = config.meta_file 32 | self.prompt_map = prompt_map 33 | self.interface = self._create_interface() 34 | 35 | def _create_interface(self): 36 | if self.prompt_kind in ["cistrome", "cctp_sequence"]: 37 | if self.prompt_celltype_cache_file is not None: 38 | return CistromeCellEmbInterface(self.prompt_celltype_cache_file) 39 | else: 40 | return PromptsCistromInterface(self.meta_file,self.prompt_map) 41 | elif self.prompt_kind == "expression": 42 | return ExpCellEmbInterface(self.prompt_celltype_cache_file) 43 | 44 | def get_prompt_item(self, build_region_index, cell, seq_len): 45 | if isinstance(self.interface, CistromeCellEmbInterface): 46 | return self.interface.get_emb(build_region_index, cell) 47 | elif isinstance(self.interface, PromptsCistromInterface): 48 | return self.interface.cistrome_celltype_parse_prompts(cell, seq_len) 49 | elif isinstance(self.interface, ExpCellEmbInterface): 50 | return self.interface.get_emb(cell) 51 | else: 52 | raise ValueError("Invalid interface type.") 53 | -------------------------------------------------------------------------------- /chrombert/finetune/dataset/prompt_dataset/prompt_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from torch.utils.data import Dataset 4 | from .prompt_dataset_two import PromptDatasetForCCTP 5 | from .prompt_dataset_single import PromptDatasetForDNA 6 | 7 | class PromptDataset(Dataset): 8 | def __init__(self, config): 9 | ''' 10 | It's recommended to instantiate this class using DatasetConfig.init(). 11 | params: 12 | config: DatasetConfig. supervised_file must be provided. 13 | ''' 14 | super().__init__() 15 | self.config = config 16 | if isinstance(self.config.supervised_file, str): 17 | assert os.path.exists(self.config.supervised_file) 18 | else: 19 | assert isinstance(self.config.supervised_file, pd.DataFrame) and self.config.prompt_kind == "dna", "only the dna prompt supports a DataFrame as supervised_file" 20 | 21 | list_available_prompt_kind = ["dna", "cistrome", "expression"] 22 | assert self.config.prompt_kind in list_available_prompt_kind, f"prompt_kind must be one of {list_available_prompt_kind}" 23 | if self.config.prompt_kind == "dna": 24 | self.dataset = PromptDatasetForDNA(config) 25 | elif self.config.prompt_kind in ["cistrome", "expression"]: 26 | self.dataset = PromptDatasetForCCTP(config) 27 | else: 28 | raise AttributeError(f"Warning: '{self.config.prompt_kind}' is not a valid prompt kind") 29 | 30 | def __len__(self): 31 | return len(self.dataset) 32 | 33 | def __getitem__(self, index): 34 | return self.dataset[index] 35 | def __getattr__(self, name): 36 | """ 37 | Delegate attribute and method access to the dataset object if it's not an attribute of this Proxy class.
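For example, prompt_dataset.supervised_file resolves to the supervised_file attribute defined on the wrapped PromptDatasetForDNA or PromptDatasetForCCTP instance.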
38 | """ 39 | return getattr(self.dataset, name) 40 | 41 | 42 | -------------------------------------------------------------------------------- /chrombert/finetune/dataset/prompt_dataset/prompt_dataset_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | 4 | from .interface import FastaInterface 5 | from .interface_manager import RegulatorInterfaceManager 6 | from ..basic_dataset import BasicDataset 7 | 8 | ''' 9 | This file implements classes for the prompt-enhanced dataset used for DNA variation. 10 | Direct usage is not recommended; please use through PromptDataset or DatasetConfig instead. 11 | ''' 12 | 13 | class PromptDatasetForDNA(BasicDataset): 14 | def __init__(self,config): 15 | super().__init__(config) 16 | self.config = config 17 | assert isinstance(config.fasta_file, str) 18 | assert os.path.exists(config.fasta_file), f"fasta file {config.fasta_file=} does not exist" 19 | self.fasta_interface = FastaInterface(config.fasta_file) 20 | self.supervised_file = config.supervised_file 21 | self.supervised(self.supervised_file) 22 | 23 | def supervised(self, supervised_file = None): 24 | assert isinstance(supervised_file, str) or isinstance(supervised_file, pd.DataFrame) 25 | 26 | if isinstance(supervised_file, pd.DataFrame): 27 | df_supervised = supervised_file.copy().reset_index(drop=True) 28 | elif supervised_file.endswith(".csv"): 29 | df_supervised = pd.read_csv(supervised_file, header = 0) # csv format, [chrom, start, end, build_region_index, label, pos_alt, base_ref, base_ref, metadata] 30 | elif supervised_file.endswith(".tsv"): 31 | df_supervised = pd.read_csv(supervised_file, sep="\t", header = 0) 32 | else: 33 | raise ValueError(f"suffix of supervised_file {supervised_file} should be csv or tsv") 34 | 35 | self.df_supervised = df_supervised 36 | self.regions = df_supervised[['chrom', 'start', 'end']].values 37 | self.supervised_indices = df_supervised["build_region_index"] 38 | self.supervised_indices_len = len(self.supervised_indices) 39 | self.pos_alt = df_supervised["pos"].values - df_supervised["start"].values -1 40 | self.base_ref = df_supervised["base_ref"].values 41 | self.base_alt = df_supervised["base_alt"].values 42 | self.variant_id = df_supervised["variant_id"].values 43 | 44 | if "sv_label" in df_supervised.columns: 45 | self.supervised_labels = df_supervised['sv_label'].values 46 | elif "label" in df_supervised.columns: 47 | self.supervised_labels = df_supervised['label'].values 48 | else: 49 | self.supervised_labels = [None] * len(df_supervised) 50 | 51 | 52 | def get_mutant(self, seq, loci, alt): 53 | seq = list(seq) 54 | seq[loci] = alt 55 | return "".join(seq) 56 | 57 | def __len__(self): 58 | return self.supervised_indices_len 59 | 60 | def __getitem__(self, index): 61 | basic_index = self.supervised_indices[index] 62 | fw = 500 63 | 64 | item = super().__getitem__(basic_index) 65 | item['label'] = self.supervised_labels[index] 66 | 67 | pos_alt = self.pos_alt[index] 68 | region = self.regions[index, :] 69 | coord = [region[0], region[1] + pos_alt - fw, region[1] + pos_alt + fw] 70 | seq_raw = self.fasta_interface[coord] 71 | item['seq_raw'] = seq_raw 72 | 73 | variant_id = self.variant_id[index] 74 | if self.base_ref[index] == "N" : 75 | assert item["label"] == 0 76 | else: 77 | assert seq_raw[fw] == self.base_ref[index], f"{seq_raw[fw]=} != {self.base_ref[index]=} at {seq_raw[fw-3:fw+3]} of {variant_id=}, {item['region']}, {fw=}" 78 | seq_alt = self.get_mutant(seq_raw, fw, 
self.base_alt[index]) 79 | item['seq_alt'] = seq_alt 80 | 81 | # if self.metadata is not None: 82 | # item['metadata'] = self.metadata[index,:] 83 | 84 | return item 85 | 86 | -------------------------------------------------------------------------------- /chrombert/finetune/dataset/prompt_dataset/prompt_dataset_two.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import torch 4 | import numpy as np 5 | import pandas as pd 6 | from torch.utils.data import Dataset 7 | 8 | from ..basic_dataset import BasicDataset 9 | from .interface_manager import RegulatorInterfaceManager, CelltypeInterfaceManager 10 | 11 | ''' 12 | This file implements classes for the prompt-enhanced dataset used for TFBS prediction. 13 | Direct usage is not recommended; please use it through PromptDataset or DatasetConfig instead. 14 | ''' 15 | 16 | 17 | class SupervisedForH5(): 18 | ''' 19 | For hdf5 format supervised file processing. 20 | input: h5 format supervised_file, which contains cell, regulator, label, build_region_index. 21 | regulator is optional if provided in config, but must match the length of cell if provided. 22 | label is optional, with shape of (num_regions, number of cell-regulator pairs). 23 | See tutorials for the detailed format. 24 | return: cell, regulator, label, build_region_index 25 | ''' 26 | def __init__(self, config): 27 | super().__init__() 28 | self.config = config 29 | self.load_data(config.supervised_file) 30 | 31 | def load_data(self, supervised_file=None): 32 | with h5py.File(supervised_file, 'r') as hdf: 33 | assert 'regions' in hdf.keys(), "regions key is missing in h5 file" 34 | self.h5_regions = hdf['regions'][:] 35 | 36 | if self.config.prompt_celltype: 37 | self.prompt_celltype = [self.config.prompt_celltype] 38 | elif 'cell' in hdf.keys(): 39 | self.prompt_celltype = [item.decode('utf-8') for item in hdf['cell'][:]] 40 | else: 41 | raise ValueError('prompt of cell type needs to be set') 42 | 43 | if self.config.prompt_regulator: 44 | self.prompt_regulator = [self.config.prompt_regulator] * len(self.prompt_celltype) 45 | elif 'regulator' in hdf.keys(): 46 | assert len(hdf['regulator'][:]) == len(self.prompt_celltype), "Celltype and regulator lengths do not match" 47 | self.prompt_regulator = [item.decode('utf-8') for item in hdf['regulator'][:]] 48 | else: 49 | raise ValueError('prompt regulator needs to be set') 50 | 51 | assert len(self.prompt_celltype) == len(self.prompt_regulator), "Celltype and regulator lengths do not match" 52 | self.supervised_indices = self.h5_regions[:, 3] 53 | self.supervised_indices_len = self.h5_regions.shape[0] * len(self.prompt_celltype) 54 | 55 | self.supervised_labels = hdf['label'][:] > 0 if 'label' in hdf.keys() else None 56 | 57 | def __len__(self): 58 | return self.supervised_indices_len 59 | 60 | def __getitem__(self, index): 61 | index_row = index % len(self.h5_regions) 62 | index_col = index // len(self.h5_regions) 63 | return { 64 | 'build_region_index': self.supervised_indices[index_row], 65 | 'cell': self.prompt_celltype[index_col], 66 | 'regulator': self.prompt_regulator[index_col], 67 | 'label': self.supervised_labels[index_row, index_col] if self.supervised_labels is not None else None 68 | } 69 | 70 | class SupervisedForTable(): 71 | ''' 72 | For table format supervised file processing.
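Supports csv, tsv and feather files. The table must contain chrom, start, end and build_region_index columns; cell, regulator and label columns are optional when the corresponding prompts are given in the config.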
73 | 74 | input: supervised_file 75 | return: cell, regulator, label, build_region_index 76 | ''' 77 | def __init__(self, config): 78 | super().__init__() 79 | self.config = config 80 | self.load_data(config.supervised_file) 81 | def load_data(self, supervised_file=None): 82 | if supervised_file.endswith(".csv"): 83 | df_supervised = pd.read_csv(supervised_file) 84 | elif supervised_file.endswith(".tsv"): 85 | df_supervised = pd.read_csv(supervised_file, sep = "\t") 86 | elif supervised_file.endswith(".feather"): 87 | df_supervised = pd.read_feather(supervised_file) 88 | else: 89 | raise ValueError("supervised_file must be an h5, csv, tsv or feather file!") 90 | necessary_columns = ["chrom","start","end","build_region_index"] 91 | for column in necessary_columns: 92 | if column not in df_supervised.columns: 93 | raise ValueError(f"{column} not in supervised_file! It must contain headers: {necessary_columns}") 94 | 95 | if self.config.prompt_celltype is not None: 96 | self.prompt_celltype =[self.config.prompt_celltype]*len(df_supervised) 97 | 98 | elif "cell" in df_supervised.columns: 99 | self.prompt_celltype = df_supervised["cell"].values 100 | else: 101 | raise ValueError('prompt cell type needs to be set') 102 | 103 | if self.config.prompt_regulator is not None: 104 | self.prompt_regulator = [self.config.prompt_regulator]*len(df_supervised) 105 | elif "regulator" in df_supervised.columns: 106 | self.prompt_regulator = df_supervised["regulator"].values 107 | else: 108 | raise ValueError('prompt regulator needs to be set') 109 | 110 | assert len(self.prompt_celltype) == len(self.prompt_regulator), "Celltype and regulator lengths do not match" 111 | 112 | self.supervised_indices = df_supervised["build_region_index"].values 113 | self.supervised_indices_len = len(self.supervised_indices) 114 | self.supervised_labels = df_supervised['label'].values if 'label' in df_supervised.columns else None 115 | 116 | def __len__(self): 117 | return self.supervised_indices_len 118 | 119 | def __getitem__(self, index): 120 | return { 121 | 'build_region_index': self.supervised_indices[index], 122 | 'cell': self.prompt_celltype[index], 123 | 'regulator': self.prompt_regulator[index], 124 | 'label': self.supervised_labels[index] if self.supervised_labels is not None else None 125 | } 126 | 127 | 128 | class PromptDatasetForCCTP(BasicDataset): 129 | def __init__(self,config): 130 | super().__init__(config) 131 | self.config = config 132 | self.prompt_map = {i:j for i,j in self.gsmid_to_did.items()} #self.gsmid_to_did from BasicDataset 133 | self.seq_len = self.gsm_num 134 | self.cell_interface = CelltypeInterfaceManager(config,self.prompt_map) 135 | self.regulator_interface = RegulatorInterfaceManager(config,self.prompt_map) 136 | 137 | self.supervised_file = config.supervised_file 138 | if self.supervised_file.endswith("h5"): 139 | self.sv_dataset = SupervisedForH5(config) 140 | else: 141 | self.sv_dataset = SupervisedForTable(config) 142 | 143 | def __len__(self): 144 | return len(self.sv_dataset) 145 | 146 | def __getitem__(self,index): 147 | sv_item = self.sv_dataset[index] 148 | cell = sv_item['cell'] 149 | regulator = sv_item['regulator'] 150 | label = sv_item['label'] 151 | build_region_index = sv_item['build_region_index'] 152 | 153 | celltype_item = self.cell_interface.get_prompt_item(build_region_index,cell,self.seq_len) 154 | regulator_item = self.regulator_interface.get_prompt_item(build_region_index, regulator, self.seq_len) 155 | if self.config.prompt_regulator_cache_file is None or
self.config.prompt_celltype_cache_file is None: 156 | item = super().__getitem__(build_region_index) 157 | else: 158 | item = {"build_region_index": build_region_index} 159 | if label is not None: 160 | item['label'] = label 161 | else: 162 | del sv_item['label'] 163 | item.update(celltype_item) 164 | item.update(regulator_item) 165 | item.update(sv_item) 166 | return item 167 | -------------------------------------------------------------------------------- /chrombert/finetune/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .general_ft_model import ChromBERTGeneral 2 | from .gep_ft_model import ChromBERTGEP 3 | from .prompt_ft_model import ChromBERTPrompt 4 | from .model_config import ChromBERTFTConfig, get_preset_model_config 5 | from .utils import ChromBERTEmbedding -------------------------------------------------------------------------------- /chrombert/finetune/model/basic_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from abc import abstractmethod, ABC 5 | from chrombert import ChromBERT 6 | from .utils import ChromBERTEmbedding 7 | from .utils import PoolFlankWindow 8 | 9 | class BasicModel(nn.Module, ABC): 10 | ''' 11 | An abstract class for fine-tuning ChromBERT, which should not be instantiated directly. 12 | ''' 13 | def __init__(self, pretrain_config, finetune_config): 14 | ''' 15 | pretrain_config: ChromBERTConfig object 16 | finetune_config: FinetuneConfig 17 | 18 | The model will be initialized using the following steps: 19 | self.pretrain_config = pretrain_config 20 | self.finetune_config = finetune_config 21 | self.create_layers() 22 | ''' 23 | super().__init__() 24 | self.pretrain_config = pretrain_config 25 | self.finetune_config = finetune_config 26 | self.create_layers() 27 | return None 28 | 29 | @abstractmethod 30 | def create_layers(self): 31 | ''' 32 | add a supervised header to the model 33 | ''' 34 | raise NotImplementedError 35 | 36 | def load_ckpt(self, ckpt = None): 37 | if ckpt is not None: 38 | assert os.path.exists(ckpt), f"Checkpoint file does not exist: {ckpt}" 39 | else: 40 | print("No checkpoint file specified, load from finetune_config.finetune_ckpt") 41 | if self.finetune_config.finetune_ckpt is not None: 42 | ckpt = self.finetune_config.finetune_ckpt 43 | assert os.path.exists(ckpt), f"Checkpoint file does not exist: {ckpt}" 44 | else: 45 | raise ValueError(f"{ckpt} is not specified!") 46 | print(f"Loading checkpoint from {ckpt}") 47 | 48 | old_state = self.state_dict() 49 | new_state = torch.load(ckpt) 50 | 51 | if "state_dict" in new_state: 52 | new_state = new_state["state_dict"] 53 | 54 | # check whether ckpt from pl module, which has prefix "model." 
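# If more than 90% of the state dict keys start with "model.", the checkpoint is treated as coming from a Lightning module wrapper and the prefix is stripped so the keys match this model's state dict.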
55 | num = len([key for key in new_state.keys() if key.startswith("model.")]) 56 | if num/len(new_state) > 0.9: 57 | new_state = {k[6:]: v for k, v in new_state.items() if k.startswith("model.")} 58 | print("Loading from pl module, remove prefix 'model.'") 59 | 60 | num = len(new_state) 61 | new_state = {k: v for k, v in new_state.items() if k in old_state} # only load the keys that are in the model 62 | print(f"Loaded {len(new_state)}/{num} parameters") 63 | old_state.update(new_state) 64 | self.load_state_dict(old_state) 65 | return None 66 | 67 | def display_trainable_parameters(self, verbose = True): 68 | ''' 69 | display the number of trainable parameters in the model 70 | ''' 71 | total_params = sum(p.numel() for p in self.parameters()) 72 | trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) 73 | o = {"total_params": total_params, "trainable_params": trainable_params} 74 | print(o) 75 | if verbose: 76 | for name, parameter in self.named_parameters(): 77 | if parameter.requires_grad: 78 | print(name, ": trainable") 79 | else: 80 | print(name, ": frozen") 81 | return o 82 | 83 | def get_pretrain(self): 84 | ''' 85 | get the pretrain part of the model 86 | ''' 87 | if hasattr(self, "pretrain_model"): 88 | assert isinstance(self.pretrain_model, ChromBERT) 89 | pretrain_model = self.pretrain_model 90 | else: 91 | if self.finetune_config.task == "gep": 92 | pretrain_model = self.pool_flank_window.pretrain_model 93 | assert isinstance(pretrain_model, ChromBERT) 94 | else: 95 | raise ValueError("pretrain_model is not specified! Please specify the pretrain_model attribute in the model, or overwrite this method.") 96 | return pretrain_model 97 | 98 | 99 | def freeze_pretrain(self, trainable = 2): 100 | ''' 101 | Freeze the model's parameters, allowing fine-tuning of specific transformer blocks. 102 | For trainable = N layers: 103 | - If `N = 0`, all transformer blocks are frozen. 104 | - If `N > 0`, only the last N transformer blocks are trainable and all other blocks are frozen. 105 | ''' 106 | pretrain_model = self.get_pretrain() 107 | pretrain_model.freeze(trainable) 108 | return self 109 | 110 | def save_pretrain(self, save_path): 111 | ''' 112 | save the pretrained part of the model to enable loading it later. 113 | ''' 114 | pretrain_model = self.get_pretrain() 115 | state_dict = pretrain_model.state_dict() 116 | torch.save(state_dict, save_path) 117 | return state_dict 118 | 119 | def get_embedding_manager(self, **kwargs): 120 | ''' 121 | get a embedding manager for the pretrain model. 
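Returns a ChromBERTEmbedding wrapper around the pretrained backbone, which exposes cistrome-, regulator- and region-level embeddings.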
122 | params: 123 | kwargs: additional parameters for EmbManager 124 | ''' 125 | pretrain_model = self.get_pretrain() 126 | finetune_config = self.finetune_config.clone() 127 | finetune_config.update(**kwargs) 128 | model_emb = ChromBERTEmbedding(pretrain_model, finetune_config.mtx_mask, finetune_config.ignore, finetune_config.ignore_index) 129 | return model_emb 130 | 131 | def save_ckpt(self, save_path): 132 | ''' 133 | save the model checkpoint 134 | ''' 135 | state_dict = self.state_dict() 136 | torch.save(state_dict, save_path) 137 | return None 138 | -------------------------------------------------------------------------------- /chrombert/finetune/model/general_ft_model.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import torch 3 | import torch.nn as nn 4 | import lightning.pytorch as pl 5 | from chrombert import ChromBERTConfig 6 | from .utils import GeneralHeader 7 | from .basic_model import BasicModel 8 | 9 | class ChromBERTGeneral(BasicModel): 10 | ''' 11 | Fine-tuning a pre-trained ChromBERT model for general purposes. 12 | 13 | pretrain_config: ChromBERTConfig object 14 | finetune_config: FinetuneConfig 15 | 16 | The model will be initialized using the following steps: 17 | self.pretrain_config = pretrain_config 18 | self.finetune_config = finetune_config 19 | self.create_layers() 20 | ''' 21 | 22 | def create_layers(self): 23 | """ 24 | add a supervised header to fine-tune model. 25 | """ 26 | self.pretrain_model = self.pretrain_config.init_model() 27 | 28 | self.ft_header = GeneralHeader( 29 | self.pretrain_config.hidden_dim, 30 | self.finetune_config.dim_output, 31 | self.finetune_config.mtx_mask, 32 | self.finetune_config.ignore, 33 | self.finetune_config.ignore_index, 34 | self.finetune_config.dropout 35 | ) 36 | return None 37 | 38 | def forward(self, batch): 39 | input_ids = batch["input_ids"] 40 | position_ids = batch["position_ids"] 41 | chrombert_out= self.pretrain_model.forward( 42 | input_ids.long(), position_ids) 43 | header_out = self.ft_header(chrombert_out) 44 | return header_out 45 | -------------------------------------------------------------------------------- /chrombert/finetune/model/gep_ft_model.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import torch 3 | import torch.nn as nn 4 | from chrombert import ChromBERTConfig 5 | from .utils import PoolFlankWindow,GepHeader,GeneralHeader 6 | from .basic_model import BasicModel 7 | from .utils import ChromBERTEmbedding 8 | class ChromBERTGEP(BasicModel): 9 | """ 10 | Fine-tuning a pre-trained ChromBERT model for multi window based gene expression prediction. 
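Each sample carries 2 * gep_flank_window + 1 neighboring regions; their ChromBERT embeddings are max-pooled by PoolFlankWindow before being passed to the prediction header.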
11 | 12 | pretrain_config: ChromBERTConfig object 13 | finetune_config: FinetuneConfig 14 | 15 | The model will be initialized using the following steps: 16 | self.pretrain_config = pretrain_config 17 | self.finetune_config = finetune_config 18 | self.create_layers() 19 | """ 20 | 21 | def create_layers(self): 22 | """add a supervised header to fine-tune model""" 23 | 24 | pretrain_model = self.pretrain_config.init_model() 25 | self.flank_region_num = int(self.finetune_config.gep_flank_window) * 2 + 1 26 | self.pool_flank_window = PoolFlankWindow( 27 | flank_region_num = self.flank_region_num, 28 | pretrain_model = pretrain_model, 29 | parallel_embedding = self.finetune_config.gep_parallel_embedding, 30 | gradient_checkpoint=self.finetune_config.gep_gradient_checkpoint 31 | ) 32 | # use zero inflation 33 | if self.finetune_config.gep_zero_inflation: 34 | self.ft_header = GepHeader( 35 | self.pretrain_config.hidden_dim, 36 | self.finetune_config.dim_output, 37 | self.finetune_config.mtx_mask, 38 | self.finetune_config.ignore, 39 | self.finetune_config.ignore_index, 40 | self.finetune_config.dropout 41 | ) 42 | else: 43 | self.ft_header = GeneralHeader( 44 | self.pretrain_config.hidden_dim, 45 | self.finetune_config.dim_output, 46 | self.finetune_config.mtx_mask, 47 | self.finetune_config.ignore, 48 | self.finetune_config.ignore_index, 49 | self.finetune_config.dropout 50 | ) 51 | return None 52 | 53 | def forward(self,batch): 54 | input_ids = batch["input_ids"] 55 | position_ids = batch["position_ids"] 56 | pool_out= self.pool_flank_window.forward( 57 | input_ids, position_ids) 58 | header_out = self.ft_header(pool_out) 59 | return header_out 60 | 61 | # While it differs slightly from other models, its usage remains the same. 62 | def get_embedding_manager(self, **kwargs): 63 | ''' 64 | get a embedding manager for the pretrain model. 
65 | params: 66 | kwargs: additional parameters for EmbManager 67 | ''' 68 | pretrain_model = self.get_pretrain() 69 | finetune_config = self.finetune_config.clone() 70 | finetune_config.update(**kwargs) 71 | PoolFlankWindow_model = PoolFlankWindow( 72 | flank_region_num = int(finetune_config.gep_flank_window) * 2 + 1, 73 | pretrain_model = pretrain_model, 74 | parallel_embedding = finetune_config.gep_parallel_embedding, 75 | gradient_checkpoint = finetune_config.gep_gradient_checkpoint 76 | ) 77 | model_emb = ChromBERTEmbedding(PoolFlankWindow_model, finetune_config.mtx_mask, finetune_config.ignore, finetune_config.ignore_index) 78 | return model_emb 79 | -------------------------------------------------------------------------------- /chrombert/finetune/model/presets/general.json: -------------------------------------------------------------------------------- 1 | { 2 | "task":"general", 3 | "genome":"hg38", 4 | "mtx_mask":"config/hg38_6k_mask_matrix.tsv", 5 | "ignore": false, 6 | "pretrain_ckpt":"checkpoint/hg38_6k_1kb_pretrain.ckpt" 7 | } -------------------------------------------------------------------------------- /chrombert/finetune/model/presets/general_mm10.json: -------------------------------------------------------------------------------- 1 | { 2 | "task":"general", 3 | "genome":"mm10", 4 | "mtx_mask":"config/mm10_5k_mask_matrix.tsv", 5 | "ignore": false, 6 | "pretrain_ckpt":"checkpoint/mm10_5k_1kb_pretrain.ckpt" 7 | } -------------------------------------------------------------------------------- /chrombert/finetune/model/presets/gep.json: -------------------------------------------------------------------------------- 1 | { 2 | "task":"gep", 3 | "genome":"hg38", 4 | "mtx_mask":"config/hg38_6k_mask_matrix.tsv", 5 | "ignore": false, 6 | "pretrain_ckpt":"checkpoint/hg38_6k_1kb_pretrain.ckpt", 7 | "gep_flank_window":4, 8 | "gep_parallel_embedding":false, 9 | "gep_gradient_checkpoint":false, 10 | "gep_zero_inflation":false 11 | } -------------------------------------------------------------------------------- /chrombert/finetune/model/presets/gep_mm10.json: -------------------------------------------------------------------------------- 1 | { 2 | "task":"gep", 3 | "genome":"mm10", 4 | "mtx_mask":"config/mm10_5k_mask_matrix.tsv", 5 | "ignore": false, 6 | "pretrain_ckpt":"checkpoint/mm10_5k_1kb_pretrain.ckpt", 7 | "gep_flank_window":4, 8 | "gep_parallel_embedding":false, 9 | "gep_gradient_checkpoint":false, 10 | "gep_zero_inflation":false 11 | } -------------------------------------------------------------------------------- /chrombert/finetune/model/presets/prompt_cistrome.json: -------------------------------------------------------------------------------- 1 | { 2 | "task":"prompt", 3 | "genome":"hg38", 4 | "prompt_kind":"cistrome", 5 | "prompt_dim_external":768, 6 | "ignore": false, 7 | "pretrain_ckpt":"checkpoint/hg38_6k_1kb_pretrain.ckpt", 8 | "finetune_ckpt":"checkpoint/hg38_6k_1kb_prompt_cistrome.ckpt" 9 | 10 | } 11 | -------------------------------------------------------------------------------- /chrombert/finetune/model/presets/prompt_cistrome_mm10.json: -------------------------------------------------------------------------------- 1 | { 2 | "task":"prompt", 3 | "genome":"mm10", 4 | "prompt_kind":"cistrome", 5 | "prompt_dim_external":768, 6 | "ignore": false, 7 | "pretrain_ckpt":"checkpoint/mm10_5k_1kb_pretrain.ckpt", 8 | "finetune_ckpt":"checkpoint/mm10_5k_1kb_prompt_cistrome.ckpt" 9 | 10 | } 11 | 
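# A minimal sketch of how these presets can be read: they are plain JSON dictionaries, loaded
# here with only the standard library. The path assumes the repository root as the working
# directory; in practice they are consumed via ChromBERTFTConfig / get_preset_model_config
# (exported from chrombert.finetune.model), whose exact signatures are not shown in this file.
import json

with open("chrombert/finetune/model/presets/prompt_cistrome_mm10.json") as fh:
    preset = json.load(fh)

# Fields in the prompt-enhanced mm10 preset: task, genome, prompt_kind,
# prompt_dim_external, ignore, pretrain_ckpt and finetune_ckpt.
print(preset["task"], preset["genome"], preset["prompt_kind"])  # prompt mm10 cistrome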
-------------------------------------------------------------------------------- /chrombert/finetune/model/presets/prompt_dna.json: -------------------------------------------------------------------------------- 1 | { 2 | "task":"prompt", 3 | "genome":"hg38", 4 | "prompt_kind":"dna", 5 | "prompt_dim_external":768, 6 | "mtx_mask":"config/hg38_6k_mask_matrix.tsv", 7 | "ignore": false, 8 | "pretrain_ckpt":"checkpoint/hg38_6k_1kb_pretrain.ckpt" 9 | } 10 | -------------------------------------------------------------------------------- /chrombert/finetune/model/presets/prompt_dnase.json: -------------------------------------------------------------------------------- 1 | { 2 | "task":"prompt", 3 | "genome":"hg38", 4 | "prompt_kind":"cistrome", 5 | "prompt_dim_external":768, 6 | "ignore": false, 7 | "pretrain_ckpt":"checkpoint/hg38_6k_1kb_pretrain.ckpt", 8 | "finetune_ckpt":"checkpoint/hg38_6k_1kb_prompt_cistrome.ckpt" 9 | 10 | } 11 | -------------------------------------------------------------------------------- /chrombert/finetune/model/presets/prompt_dnase_mm10.json: -------------------------------------------------------------------------------- 1 | { 2 | "task":"prompt", 3 | "genome":"mm10", 4 | "prompt_kind":"cistrome", 5 | "prompt_dim_external":768, 6 | "ignore": false, 7 | "pretrain_ckpt":"checkpoint/mm10_5k_1kb_pretrain.ckpt", 8 | "finetune_ckpt":"checkpoint/mm10_5k_1kb_prompt_cistrome.ckpt" 9 | 10 | } 11 | -------------------------------------------------------------------------------- /chrombert/finetune/model/presets/prompt_exp.json: -------------------------------------------------------------------------------- 1 | { 2 | "task":"prompt", 3 | "genome":"hg38", 4 | "prompt_kind":"expression", 5 | "prompt_dim_external":512, 6 | "ignore": false, 7 | "pretrain_ckpt":"checkpoint/hg38_6k_1kb_pretrain.ckpt", 8 | "finetune_ckpt":"checkpoint/hg38_6k_1kb_prompt_expression.ckpt" 9 | } 10 | -------------------------------------------------------------------------------- /chrombert/finetune/model/prompt_dna_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | 5 | from chrombert.base import ChromBERTConfig 6 | from .utils.dnabert2 import DNABERT2Interface 7 | from .utils.prompt_header import AdapterExternalEmb, PromptHeader, PromptsEmb 8 | from .utils.general_header import GeneralHeader 9 | from .basic_model import BasicModel 10 | 11 | class ChromBERTPromptDNA(BasicModel): 12 | ''' 13 | Fine-tuning a pre-trained ChromBERT model using DNA-enhanced prompt. 14 | 15 | pretrain_config: ChromBERTConfig object 16 | finetune_config: FinetuneConfig 17 | 18 | The model will be initialized using the following steps: 19 | self.pretrain_config = pretrain_config 20 | self.finetune_config = finetune_config 21 | self.create_layers() 22 | ''' 23 | 24 | NECESSARY_KEYS = ["input_ids", "position_ids","seq_raw", "seq_alt"] 25 | 26 | def create_layers(self): 27 | """add a supervised header to fine-tune model""" 28 | assert self.finetune_config.prompt_kind in ["dna"], "prompt_kind must be dna!" 29 | assert self.finetune_config.prompt_dim_external == 768, "prompt_dim_external must be 768 here, only DNABERT2 supported now!" 
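# 768 matches DNABERT2Interface.embedding_dim; the AdapterExternalEmb below projects the pooled DNA-BERT2 embedding into ChromBERT's 768-dim space.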
30 | 31 | self.pretrain_model = self.pretrain_config.init_model() 32 | self.dnabert2 = DNABERT2Interface(self.finetune_config.dnabert2_ckpt, pooling="mean") 33 | self.adapter_dna_emb = AdapterExternalEmb(self.finetune_config.prompt_dim_external, 34 | dropout = self.finetune_config.dropout) 35 | 36 | self.adapter_chrombert = GeneralHeader( 37 | self.pretrain_config.hidden_dim, 38 | self.finetune_config.dim_output, 39 | self.finetune_config.mtx_mask, 40 | self.finetune_config.ignore, 41 | self.finetune_config.ignore_index, 42 | self.finetune_config.dropout 43 | ) 44 | self.head_output = PromptHeader(n_parts = 2,dropout=self.finetune_config.dropout) 45 | return None 46 | 47 | def valid_batch(self, batch): 48 | for key in self.NECESSARY_KEYS: 49 | assert key in batch, f"{key} not in batch" 50 | return None 51 | 52 | def forward(self, batch): 53 | self.valid_batch(batch) 54 | 55 | dna_embed_alt = self.dnabert2(batch["seq_alt"])["embedding_dna"] 56 | dna_emb = self.adapter_dna_emb(dna_embed_alt) 57 | 58 | chrom_embedding = self.pretrain_model( 59 | batch["input_ids"], batch["position_ids"] 60 | ) 61 | chrom_embedding = self.adapter_chrombert(chrom_embedding, return_emb = True) # (batch_size, 768) 62 | 63 | logit = self.head_output(dna_emb, chrom_embedding) 64 | return logit 65 | 66 | @DeprecationWarning 67 | def get_factor_emb(self, batch): 68 | self.valid_batch(batch) 69 | chrom_embedding = self.pretrain_model( 70 | batch["input_ids"], batch["position_ids"] 71 | ) 72 | emb_factor = self.adapter_chrombert.interface(chrom_embedding) 73 | return emb_factor 74 | -------------------------------------------------------------------------------- /chrombert/finetune/model/prompt_ft_model.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import torch 3 | import torch.nn as nn 4 | import lightning.pytorch as pl 5 | 6 | from chrombert import ChromBERTConfig 7 | from .utils import PromptHeader,PromptsEmb,AdapterExternalEmb 8 | from .basic_model import BasicModel 9 | 10 | class ChromBERTPrompt(BasicModel): 11 | ''' 12 | Fine-tuning a pre-trained ChromBERT model using enhanced prompt, for TFBS prediction. 13 | 14 | pretrain_config: ChromBERTConfig object 15 | finetune_config: FinetuneConfig 16 | 17 | The model will be initialized using the following steps: 18 | self.pretrain_config = pretrain_config 19 | self.finetune_config = finetune_config 20 | self.create_layers() 21 | ''' 22 | 23 | def create_layers(self): 24 | """add a supervised header to fine-tune model""" 25 | self.pretrain_model = self.pretrain_config.init_model() 26 | 27 | if self.finetune_config.prompt_kind == 'expression': 28 | self.adapter_cell_emb = AdapterExternalEmb( 29 | prompt_dim_external = self.finetune_config.prompt_dim_external, 30 | dropout=self.finetune_config.dropout 31 | ) 32 | 33 | self.gather_emb = PromptsEmb() # for gather regulator and cell prompt 34 | 35 | self.ft_header = PromptHeader(n_parts = self.finetune_config.n_prompt_parts + 1, 36 | dropout = self.finetune_config.dropout) 37 | return None 38 | 39 | def forward(self,batch): 40 | 41 | emb_cell,emb_regulator,emb_all = self.get_emb_parts(batch, dtype = self.ft_header.fcs[0].fc1.weight.dtype) 42 | header_out = self.ft_header(emb_cell,emb_regulator,emb_all) 43 | 44 | return header_out 45 | 46 | def get_emb_parts(self,batch, dtype =torch.bfloat16): 47 | ''' 48 | Gather the necessary inputs for forwarding, handling cached embedding or forwarding directly. 
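Returns a tuple (emb_cell, emb_regulator, emb_all), cast to the given dtype. Cached embeddings provided in the batch are used directly; otherwise they are pooled from a ChromBERT forward pass via gather_emb.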
49 | ''' 50 | 51 | if 'emb_cell' not in batch.keys() or 'emb_regulator' not in batch.keys(): 52 | input_ids = batch["input_ids"] 53 | position_ids = batch["position_ids"] 54 | chrombert_out= self.pretrain_model.forward( 55 | input_ids.long(), position_ids 56 | ) 57 | 58 | if 'emb_cell' in batch.keys(): 59 | emb_cell = batch["emb_cell"] 60 | else: 61 | prompts_cell = batch["prompts_cell"] 62 | emb_cell = self.gather_emb(chrombert_out,prompts_cell) 63 | 64 | if 'emb_regulator' in batch.keys(): 65 | emb_regulator = batch["emb_regulator"] 66 | emb_all = batch["emb_all"] 67 | else: 68 | prompts_all = batch["prompts_all"] 69 | prompts_regulator = batch["prompts_regulator"] 70 | emb_regulator = self.gather_emb(chrombert_out,prompts_regulator) 71 | emb_all = self.gather_emb(chrombert_out,prompts_all) 72 | 73 | if self.finetune_config.prompt_kind == 'expression': 74 | emb_cell = self.adapter_cell_emb(emb_cell) 75 | 76 | return emb_cell.to(dtype), emb_regulator.to(dtype), emb_all.to(dtype) 77 | -------------------------------------------------------------------------------- /chrombert/finetune/model/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .general_header import GeneralHeader 2 | from .gep_header import GepHeader 3 | from .prompt_header import PromptHeader,PromptsEmb,AdapterExternalEmb 4 | from .pool_flank_window import PoolFlankWindow 5 | 6 | from .dnabert2 import DNABERT2Interface 7 | from .residual_block import ResidualBlock 8 | from .emb_manager import CistromeEmbeddingManager, ChromBERTEmbedding 9 | -------------------------------------------------------------------------------- /chrombert/finetune/model/utils/dnabert2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from transformers import AutoTokenizer, AutoModel 6 | 7 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 8 | class DNABERT2Interface(nn.Module): 9 | """DNA-BERT2 model from hugging face""" 10 | 11 | def __init__(self, dnabert2_checkpoint, pooling = 'mean'): 12 | super().__init__() 13 | self.tokenizer = AutoTokenizer.from_pretrained(dnabert2_checkpoint, trust_remote_code=True) 14 | self.model = AutoModel.from_pretrained(dnabert2_checkpoint, trust_remote_code=True) 15 | self.pooling = pooling 16 | self.embedding_dim = 768 # the output embedding dimension of DNA-BERT2 17 | 18 | def freeze(self): 19 | for param in self.model.parameters(): 20 | param.requires_grad = False 21 | return None 22 | 23 | def forward(self, dna): 24 | # dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC" 25 | inputs = self.tokenizer(dna, padding = True, truncation = True, return_tensors = 'pt')["input_ids"] 26 | hidden_states = self.model(inputs.to(self.model.device))[0] # [batch_size, dna_length, embedding_dim] 27 | 28 | if self.pooling == 'mean': 29 | # embedding with mean pooling 30 | embedding_dna = torch.mean(hidden_states, dim=1) # [batch_size, embedding_dim] 31 | elif self.pooling == 'max': 32 | embedding_dna = torch.max(hidden_states, dim=1)[0] # [batch_size, embedding_dim] 33 | elif self.pooling == 'cls': 34 | embedding_dna = hidden_states[:, 0, :] 35 | else: 36 | raise(ValueError("Pooling method not supported")) 37 | 38 | return {"embedding_dna": embedding_dna, "dna_states": hidden_states} 39 | -------------------------------------------------------------------------------- /chrombert/finetune/model/utils/emb_manager.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pandas as pd 4 | from torch import nn 5 | class CistromeEmbeddingManager(nn.Module): 6 | def __init__(self, mtx_mask, ignore = False,ignore_index = None) -> None: 7 | super().__init__() 8 | assert mtx_mask is not None, "mtx_mask must be specified" 9 | assert isinstance(mtx_mask, str), "mtx_mask must be a path to a mtx_mask" 10 | assert os.path.exists(mtx_mask), f"{mtx_mask} does not exist" 11 | self.mtx_mask_df = pd.read_csv(mtx_mask, sep='\t', index_col=0) 12 | self.mtx_mask_df = self.mtx_mask_df[sorted(self.mtx_mask_df.columns)] 13 | self.mtx_mask = torch.tensor(self.mtx_mask_df.values) # (datasets, factors) 14 | self.gsmid_names = self.mtx_mask_df.index.tolist() 15 | self.regulator_names = self.mtx_mask_df.columns.tolist() 16 | 17 | if ignore: 18 | ignore_gsmid_index = ignore_index[0] 19 | ignore_regulator_index = ignore_index[1] 20 | print(f"Ignoring {len(ignore_gsmid_index)} cistromes and {len(ignore_regulator_index)} regulators") 21 | rows_to_keep = torch.tensor([i not in ignore_gsmid_index for i in range(self.mtx_mask.shape[0])]) 22 | cols_to_keep = torch.tensor([j not in ignore_regulator_index for j in range(self.mtx_mask.shape[1])]) 23 | self.mtx_mask=self.mtx_mask[rows_to_keep][:, cols_to_keep] 24 | self.mtx_mask_df = self.mtx_mask_df.iloc[rows_to_keep.numpy(), cols_to_keep.numpy()] 25 | self.gsmid_names = self.mtx_mask_df.index.tolist() 26 | self.regulator_names = self.mtx_mask_df.columns.tolist() 27 | 28 | factor_num = (self.mtx_mask != 0).sum(dim=0) 29 | self.normalization_factors = factor_num.clamp(min=1) 30 | self.normalized_mtx_mask = (self.mtx_mask / self.normalization_factors) 31 | return None 32 | 33 | def forward(self, x): 34 | # x: [batch_size, datasets, hidden] 35 | self.normalized_mtx_mask = self.normalized_mtx_mask.to(x.device) 36 | self.normalized_mtx_mask = self.normalized_mtx_mask.to(x.dtype) 37 | x = x.transpose(1, 2) 38 | x = torch.matmul(x, self.normalized_mtx_mask) 39 | x = x.transpose(1, 2) # [batch_size, factors, hidden] 40 | 41 | return x 42 | 43 | 44 | def get_cistrome_embedding(self, x, gsmid): 45 | assert gsmid in self.gsmid_names, f"{gsmid} not found in GSMID names" 46 | index = self.gsmid_names.index(gsmid) 47 | return x[:, index, :] 48 | 49 | def get_regulator_embedding(self, x, regulator): 50 | # x: [batch_size, datasets, hidden] 51 | assert regulator in self.regulator_names, f"{regulator} not found in regulator names" 52 | index = self.regulator_names.index(regulator) 53 | x = x.transpose(1, 2) 54 | self.normalized_mtx_mask = self.normalized_mtx_mask.to(x.device) 55 | extracted_x = torch.matmul(x, self.normalized_mtx_mask[:, index:index+1]) 56 | return extracted_x.transpose(1, 2) 57 | 58 | def get_region_embedding(self, x): 59 | return x.mean(dim=1) 60 | 61 | 62 | class ChromBERTEmbedding(nn.Module): 63 | def __init__(self, pretrain_model, mtx_mask, ignore = False,ignore_index = None) -> None: 64 | super().__init__() 65 | self.pretrain_model = pretrain_model 66 | self.CistromeEmbeddingManager = CistromeEmbeddingManager(mtx_mask, ignore = ignore,ignore_index = ignore_index) 67 | self.__hidden_cistrome = None 68 | self.__hidden_regulator = None 69 | self.__training = pretrain_model.training 70 | self.list_regulator = self.CistromeEmbeddingManager.regulator_names 71 | self.list_cistrome = self.CistromeEmbeddingManager.gsmid_names 72 | 73 | 74 | def forward(self, batch): 75 | with torch.no_grad(): 76 | self.pretrain_model.eval() 77 
| x = self.pretrain_model(batch["input_ids"], batch["position_ids"]) 78 | self.__hidden_cistrome = x 79 | emb = self.CistromeEmbeddingManager(x) 80 | self.__hidden_regulator = emb 81 | if self.__training: 82 | self.pretrain_model.train() 83 | return emb 84 | 85 | def get_hidden_state(self): 86 | return self.__hidden_cistrome 87 | 88 | def get_cistrome_embedding(self, gsmid): 89 | gsmid = gsmid.lower() 90 | return self.CistromeEmbeddingManager.get_cistrome_embedding(self.__hidden_cistrome, gsmid) 91 | 92 | def get_regulator_embedding(self, regulator): 93 | regulator = regulator.lower() 94 | # return self.CistromeEmbeddingManager.get_regulator_embedding(self.__hidden_state, regulator) 95 | assert regulator in self.list_regulator, f"{regulator} not found in regulator names" 96 | index = self.list_regulator.index(regulator) 97 | return self.__hidden_regulator[:, index, :] 98 | 99 | def get_region_embedding(self): 100 | return self.CistromeEmbeddingManager.get_region_embedding(self.__hidden_cistrome) 101 | 102 | -------------------------------------------------------------------------------- /chrombert/finetune/model/utils/general_header.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .residual_block import ResidualBlock 3 | from .emb_manager import CistromeEmbeddingManager 4 | 5 | 6 | class GeneralHeader(nn.Module): 7 | """ 8 | general 9 | """ 10 | 11 | def __init__(self, hidden_dim, dim_output, mtx_mask, ignore=False,ignore_index=None,dropout=0.1,medium_dim = 256): 12 | """ 13 | :param hidden_dim: output size of the BERT model 14 | :param dim_output: number of output classes 15 | """ 16 | super().__init__() 17 | self.interface = CistromeEmbeddingManager(mtx_mask = mtx_mask, 18 | ignore = ignore, 19 | ignore_index = ignore_index) 20 | self.conv = nn.Conv2d(1, 1, (1, hidden_dim)) 21 | self.activation = nn.ReLU() 22 | self.res1 = ResidualBlock(in_features=self.interface.normalized_mtx_mask.shape[1], out_features=1024,dropout=dropout) 23 | self.res2 = ResidualBlock(in_features=1024, out_features=hidden_dim,dropout=dropout) 24 | self.res3 = ResidualBlock(in_features=hidden_dim, out_features=medium_dim,dropout=dropout) 25 | self.fc = nn.Linear(in_features=medium_dim, out_features=dim_output, bias=True) 26 | 27 | 28 | def forward(self, x, return_emb = False): 29 | x = self.interface(x) # [batch_size, factors, hidden] 30 | x = x.permute(0, 2, 1) # [batch_size, hidden, factors] 31 | x = self.res1(x) # [batch_size, hidden, 1024] 32 | x = self.res2(x) # [batch_size, hidden, 768] 33 | x = x[:,None,:,:] # [batch_size, 1, hidden, 768] 34 | x = self.conv(x) # [batch_size, 1, hidden, 1] 35 | x = self.activation(x) 36 | x = x.view(x.shape[0],-1) # [batch_size, hidden] 37 | if return_emb: 38 | return x 39 | 40 | x = self.res3(x) 41 | x = self.fc(x) 42 | 43 | return x -------------------------------------------------------------------------------- /chrombert/finetune/model/utils/gep_header.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .residual_block import ResidualBlock 3 | from .emb_manager import CistromeEmbeddingManager 4 | 5 | class GepHeader(nn.Module): 6 | """ 7 | predicting gene expression changes 8 | """ 9 | 10 | def __init__(self, hidden_dim, dim_output, mtx_mask,ignore=False,ignore_index=None,dropout=0.1,medium_dim = 256): 11 | """ 12 | :param hidden_dim: output size of the BERT model 13 | :param dim_output: number of output classes 14 | """ 15 | super().__init__() 16 | self.interface =
CistromeEmbeddingManager(mtx_mask = mtx_mask, 17 | ignore=ignore, 18 | ignore_index=ignore_index) 19 | self.conv = nn.Conv2d(1, 1, (1, hidden_dim)) 20 | self.activation = nn.ReLU() 21 | self.res1 = ResidualBlock(in_features=self.interface.normalized_mtx_mask.shape[1], out_features=1024,dropout=dropout) 22 | self.res2 = ResidualBlock(in_features=1024, out_features=hidden_dim,dropout=dropout) 23 | self.res3 = ResidualBlock(in_features=hidden_dim, out_features=medium_dim,dropout=dropout) 24 | self.zero_inflation = nn.Sequential( 25 | nn.Linear(in_features=medium_dim, out_features=1), 26 | ) 27 | self.regression = nn.Linear(in_features=medium_dim, out_features=1, bias=True) 28 | 29 | def forward(self, x, **kwargs): 30 | x = self.interface(x) 31 | x = x.permute(0, 2, 1) 32 | x = self.res1(x) 33 | x = self.res2(x) 34 | 35 | x = x[:,None,:,:] 36 | x = self.conv(x) 37 | x = self.activation(x) 38 | x = x.view(x.shape[0],-1) 39 | x = self.res3(x) 40 | zero_prob_logit = self.zero_inflation(x) 41 | reg_value = self.regression(x) 42 | return zero_prob_logit, reg_value -------------------------------------------------------------------------------- /chrombert/finetune/model/utils/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class LayerNorm(nn.Module): 5 | "Construct a layernorm module (See citation for details)." 6 | 7 | def __init__(self, features, eps=1e-6): 8 | super(LayerNorm, self).__init__() 9 | self.a_2 = nn.Parameter(torch.ones(features)) 10 | self.b_2 = nn.Parameter(torch.zeros(features)) 11 | self.eps = eps 12 | 13 | def forward(self, x): 14 | mean = x.mean(-1, keepdim=True) 15 | std = x.std(-1, keepdim=True) 16 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 17 | -------------------------------------------------------------------------------- /chrombert/finetune/model/utils/pool_flank_window.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from einops import rearrange 5 | from torch.utils.checkpoint import checkpoint 6 | 7 | class PoolFlankWindow(nn.Module): 8 | """ 9 | When using multiple flanking windows for pooling 10 | """ 11 | def __init__(self, flank_region_num=9,pretrain_model=None,parallel_embedding=False,gradient_checkpoint=False): 12 | super().__init__() 13 | self.flank_region_num = flank_region_num 14 | self.pretrain_model = pretrain_model 15 | self.parallel_embedding = parallel_embedding 16 | self.gradient_checkpoint = gradient_checkpoint 17 | 18 | def forward(self,x,position_ids): 19 | batch_size = x.shape[0] 20 | seq_len = x.shape[-1] 21 | x = x.float() 22 | x.requires_grad = True 23 | if not self.parallel_embedding: 24 | embeddings = [] 25 | if not self.gradient_checkpoint: 26 | for i in range(self.flank_region_num): 27 | x_i = x[:,i,:].clone() 28 | position_ids_i = position_ids[:,i,:].clone() 29 | x_i = self.pretrain_model(x_i, position_ids_i) 30 | embeddings.append(x_i) 31 | x = torch.stack(embeddings, dim = 1) 32 | else: 33 | all_embeddings = torch.zeros((batch_size, self.flank_region_num, seq_len, 768), device=x.device) 34 | for i in range(self.flank_region_num): 35 | all_embeddings[:, i, :, :] = checkpoint( 36 | self.pretrain_model, x[:, i, :], position_ids[:, i, :] 37 | ) 38 | x = all_embeddings 39 | else: 40 | x = rearrange(x, 'b n l -> (b n) l') 41 | position_ids = rearrange(position_ids, 'b n l -> (b n) l') 42 | x = 
self.pretrain_model(x, position_ids) 43 | x = rearrange(x, '(b n) l h -> b n l h', b = batch_size) 44 | 45 | x = rearrange(x, 'b n l h -> b l n h', b = batch_size) 46 | x = torch.max(x, dim=-2).values #[b,l,h] 47 | return x 48 | 49 | -------------------------------------------------------------------------------- /chrombert/finetune/model/utils/prompt_header.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .residual_block import ResidualBlock 4 | 5 | class PromptHeader(nn.Module): 6 | def __init__(self, n_parts = 3,dropout = 0.1): 7 | super().__init__() 8 | self.fcs = nn.Sequential( 9 | ResidualBlock(n_parts * 768, n_parts * 768, dropout = dropout), 10 | ResidualBlock(n_parts * 768, 768, dropout = dropout), 11 | ResidualBlock(768, 768, dropout = dropout), 12 | ResidualBlock(768, 64, dropout = dropout), 13 | nn.Linear(64, 1), 14 | ) 15 | def forward(self, *args): 16 | for arg in args: 17 | assert isinstance(arg, torch.Tensor) 18 | 19 | full_emb = torch.cat(args, dim = -1) 20 | logit = self.fcs(full_emb).squeeze(-1) 21 | assert len(logit.shape) == 1 22 | return logit 23 | 24 | 25 | class AdapterExternalEmb(nn.Module): 26 | def __init__(self, prompt_dim_external, dropout = 0.1): 27 | super().__init__() 28 | dim1 = prompt_dim_external 29 | dim2 = 768 30 | dropout = dropout 31 | self.fc1 = ResidualBlock(dim1, dim2, dropout = dropout) 32 | self.fc2 = ResidualBlock(dim2, dim2, dropout = dropout) 33 | 34 | def forward(self, x): 35 | # x = x.bfloat16() 36 | x = x.to(self.fc1.fc1.weight.dtype) 37 | x = self.fc1(x) 38 | x = self.fc2(x) 39 | return x 40 | 41 | 42 | class Pooling(nn.Module): 43 | def __init__(self, operation): 44 | super().__init__() 45 | 46 | if operation in ["mean", "max"]: 47 | self.operation = operation 48 | else: 49 | raise ValueError(f"operation must be one of ['mean', 'max'], but got {operation}") 50 | 51 | def forward(self, x): 52 | if self.operation == "mean": 53 | return torch.mean(x, dim=1) 54 | elif self.operation == "max": 55 | # torch.max returns both values and indices, we only need the values 56 | return torch.max(x, dim=1).values 57 | 58 | class PromptsEmb(nn.Module): 59 | def __init__(self): 60 | super().__init__() 61 | self.pooling = Pooling('mean') 62 | def forward(self,x,prompts): 63 | prompts = prompts.unsqueeze(2) 64 | emb_sum = x.mul(prompts).sum(dim=1) 65 | emb_count = prompts.sum(dim=1) 66 | emb = emb_sum/emb_count 67 | return emb -------------------------------------------------------------------------------- /chrombert/finetune/model/utils/residual_block.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | from .layer_norm import LayerNorm 4 | 5 | class ResidualBlock(nn.Module): 6 | def __init__(self, in_features, out_features, dropout = 0.1): 7 | super(ResidualBlock, self).__init__() 8 | 9 | self.fc1 = nn.Linear(in_features, out_features) 10 | self.fc2 = nn.Linear(out_features, out_features) 11 | self.norm = LayerNorm(out_features) 12 | 13 | if in_features != out_features: 14 | self.shortcut = nn.Linear(in_features, out_features) 15 | else: 16 | self.shortcut = nn.Sequential() 17 | 18 | self.dropout = nn.Dropout(p = dropout) 19 | 20 | def forward(self, x): 21 | out = F.relu(self.fc1(x)) 22 | out = self.norm(self.fc2(out)) 23 | out = self.dropout(out) 24 | out += self.shortcut(x) 25 | out = F.relu(out) 26 | return out 
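# A minimal shape sketch for ResidualBlock, assuming the package and torch are importable.
# The block applies fc1 -> ReLU -> fc2 -> LayerNorm -> dropout, adds a (linear, if the sizes
# differ) shortcut of the input, and finishes with a ReLU, so it maps (..., in_features)
# tensors to (..., out_features).
import torch
from chrombert.finetune.model.utils import ResidualBlock

block = ResidualBlock(in_features=1536, out_features=768, dropout=0.1)
x = torch.randn(4, 1536)
print(block(x).shape)  # torch.Size([4, 768])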
-------------------------------------------------------------------------------- /chrombert/finetune/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_config import TrainConfig 2 | from .pl_module import ClassificationPLModule, RegressionPLModule, ZeroInflationPLModule -------------------------------------------------------------------------------- /chrombert/finetune/train/basic_pl_module.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import torch 3 | import torch.nn as nn 4 | import torchmetrics as tm 5 | from torch.optim import AdamW 6 | import lightning.pytorch as pl 7 | from transformers import get_linear_schedule_with_warmup 8 | from abc import ABC, abstractmethod 9 | 10 | from .utils.loss import FocalLoss,RMSELoss,ZeroInflationLoss 11 | from .utils.logger import LogTensorValues 12 | 13 | 14 | class BasicPLModule(pl.LightningModule, ABC): 15 | def __init__(self, model, config): 16 | super().__init__() 17 | self.model = model 18 | self.config = config 19 | self.log_values = LogTensorValues() 20 | self.configure_loss_and_metrics() 21 | 22 | def forward(self, x): 23 | return self.model(x) 24 | 25 | @abstractmethod 26 | def configure_loss_and_metrics(self): 27 | pass 28 | 29 | @abstractmethod 30 | def logging_temp_states(self, logit, batch): 31 | pass 32 | 33 | @abstractmethod 34 | def process_metrics_validation_end(self): 35 | pass 36 | 37 | def calculate_metrics(self, logits, labels, mode = 'train',pbar=False): 38 | logits, labels = self.clean_inputs(logits, labels) 39 | # only calculate loss in training mode 40 | metrics = {name: func.to(labels.device)(logits, labels) for name, func in self.loss_funcs.items()} 41 | loss = metrics[self.config.loss] 42 | 43 | state = {f"{self.config.tag}_{mode}/{name}": value for name, value in metrics.items()} 44 | for name, value in state.items(): 45 | self.log(name, value, sync_dist = True, on_step = True, on_epoch = True, prog_bar = pbar) 46 | return loss, metrics 47 | 48 | def clean_inputs(self, logits, labels, *args, **kwargs): 49 | return logits.view(-1), labels.view(-1).float() 50 | 51 | def training_step(self, batch, batch_idx): 52 | logits = self.model(batch) 53 | loss, mertrics = self.calculate_metrics(logits, batch["label"], mode='train') 54 | return loss 55 | 56 | def on_validation_epoch_start(self): 57 | self.logger_values = LogTensorValues() 58 | 59 | def validation_step(self, batch, batch_idx): 60 | logits = self.model(batch) 61 | loss, mertrics = self.calculate_metrics(logits, batch["label"], mode='validation') 62 | self.logging_temp_states(logits, batch) 63 | return loss 64 | 65 | def test_step(self, batch, batch_idx): 66 | return self.forward(batch) 67 | 68 | 69 | def on_validation_epoch_end(self): 70 | self.process_metrics_validation_end() 71 | self.logger_values = None 72 | self.trainer.datamodule.setup('val') 73 | 74 | return None 75 | 76 | def configure_optimizers(self): 77 | optimizer = AdamW(self.model.parameters(), 78 | lr = self.config.lr, 79 | betas = (self.config.adam_beta1, self.config.adam_beta2), 80 | weight_decay = self.config.weight_decay) 81 | total_steps = self.trainer.estimated_stepping_batches 82 | warmup_steps = int(self.config.warmup_ratio * total_steps) 83 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = warmup_steps, num_training_steps = total_steps) 84 | lr_scheduler_config = { 85 | "optimizer": optimizer, 86 | "lr_scheduler": { 87 | "scheduler": scheduler, 88 
| "frequency": 1, 89 | "interval": "step", 90 | "monitor": "train_loss" 91 | }, 92 | } 93 | return lr_scheduler_config 94 | 95 | 96 | def freeze(self, trainable = 2): 97 | self.model.freeze_pretrain(trainable) 98 | 99 | def save_ckpt(self, ckpt_path): 100 | torch.save({"state_dict":self.model.state_dict()}, ckpt_path) 101 | return None -------------------------------------------------------------------------------- /chrombert/finetune/train/pl_module.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import torch 3 | import torch.nn as nn 4 | import torchmetrics as tm 5 | from torch.optim import AdamW 6 | import lightning.pytorch as pl 7 | from transformers import get_linear_schedule_with_warmup 8 | from abc import ABC, abstractmethod 9 | 10 | from .utils.loss import FocalLoss,RMSELoss,ZeroInflationLoss 11 | from .basic_pl_module import BasicPLModule 12 | class ClassificationPLModule(BasicPLModule): 13 | def configure_loss_and_metrics(self): 14 | self.loss_funcs = { 15 | 'bce': nn.BCEWithLogitsLoss(), 16 | 'focal': FocalLoss(), 17 | } 18 | 19 | self.metric_funcs = { 20 | 'precision': tm.Precision(task="binary"), 21 | 'recall': tm.Recall(task="binary"), 22 | 'mcc': tm.MatthewsCorrCoef(task="binary"), 23 | 'f1': tm.F1Score(task="binary"), 24 | 'auroc': tm.AUROC(task="binary"), 25 | 'auprc': tm.AveragePrecision(task="binary"), 26 | 'acc': tm.Accuracy(task="binary"), 27 | } 28 | return None 29 | 30 | def logging_temp_states(self, logits, batch): 31 | self.logger_values.log("logits", logits) 32 | self.logger_values.log("label", batch["label"]) 33 | return None 34 | 35 | def process_metrics_validation_end(self): 36 | logits = self.logger_values.get_values("logits") 37 | labels = self.logger_values.get_values("label") 38 | logits, labels = self.clean_inputs(logits, labels) 39 | 40 | probs = torch.sigmoid(logits) 41 | metrics_loss = {name: func.to(labels.device)(logits, labels) for name, func in self.loss_funcs.items()} 42 | metrics_metrics = {name: func.to(labels.device)(logits, labels.long()) for name, func in self.metric_funcs.items()} 43 | 44 | metrics = {} 45 | metrics.update(metrics_loss) 46 | metrics.update(metrics_metrics) 47 | 48 | metrics["mean_logit"] = torch.mean(logits) 49 | metrics["median_logit"] = torch.median(logits) 50 | metrics["mean_prob"] = torch.mean(probs) 51 | metrics["median_prob"] = torch.median(probs) 52 | 53 | state = {f"{self.config.tag}_validation/{name}": value for name, value in metrics.items()} 54 | for name, value in state.items(): 55 | self.log(name, value, sync_dist = True, on_step = False, on_epoch = True, prog_bar = True) 56 | 57 | for k, func in self.metric_funcs.items(): 58 | func.reset() 59 | return None 60 | 61 | 62 | class RegressionPLModule(BasicPLModule): 63 | def configure_loss_and_metrics(self): 64 | self.loss_funcs = { 65 | 'mae': nn.L1Loss(), 66 | 'mse': nn.MSELoss(), 67 | 'rmse': RMSELoss(), 68 | } 69 | 70 | self.metric_funcs = { 71 | 'r2': tm.R2Score(), 72 | 'pcc': tm.PearsonCorrCoef(), 73 | 'scc': tm.SpearmanCorrCoef(), 74 | } 75 | return None 76 | 77 | def logging_temp_states(self, logits, batch): 78 | self.logger_values.log("logits", logits) 79 | self.logger_values.log("label", batch["label"]) 80 | return None 81 | 82 | def process_metrics_validation_end(self): 83 | logits = self.logger_values.get_values("logits") 84 | labels = self.logger_values.get_values("label") 85 | logits, labels = self.clean_inputs(logits, labels) 86 | 87 | metrics_loss = {name: func.to(labels.device)(logits, 
labels) for name, func in self.loss_funcs.items()} 88 | metrics_metrics = {name: func.to(labels.device)(logits, labels) for name, func in self.metric_funcs.items()} 89 | metrics_metrics.update(metrics_loss) 90 | metrics = metrics_metrics 91 | metrics["mean"] = torch.mean(logits) 92 | metrics["median"] = torch.median(logits) 93 | 94 | state = {f"{self.config.tag}_validation/{name}": value for name, value in metrics.items()} 95 | for name, value in state.items(): 96 | self.log(name, value, sync_dist = True, on_step = False, on_epoch = True, prog_bar = True) 97 | 98 | for k, func in self.metric_funcs.items(): 99 | func.reset() 100 | return None 101 | 102 | 103 | class ZeroInflationPLModule(BasicPLModule): 104 | 105 | def clean_inputs(self,logits, labels): 106 | probs, regs = logits 107 | return (probs.view(-1), regs.view(-1)), labels.view(-1) 108 | 109 | def configure_loss_and_metrics(self): 110 | self.loss_funcs = { 111 | 'zero_inflation': ZeroInflationLoss(), 112 | } 113 | self.loss_for_reg = { 114 | 'mae': nn.L1Loss(), 115 | 'mse': nn.MSELoss(), 116 | 'rmse': RMSELoss(), 117 | } 118 | self.metrics_for_reg = { 119 | 'r2': tm.R2Score(), 120 | 'pcc': tm.PearsonCorrCoef(), 121 | 'scc': tm.SpearmanCorrCoef(), 122 | } 123 | return None 124 | 125 | def logging_temp_states(self, logits, batch): 126 | prob, reg_value = logits 127 | self.logger_values.log("zero_prob_logit", prob) 128 | self.logger_values.log("reg_value", reg_value) 129 | self.logger_values.log("label", batch["label"]) 130 | return None 131 | 132 | def process_metrics_validation_end(self): 133 | zero_prob_logit = self.logger_values.get_values("zero_prob_logit") 134 | reg_value = self.logger_values.get_values("reg_value") 135 | labels = self.logger_values.get_values("label") 136 | 137 | metrics_loss = {name: func.to(labels.device)((zero_prob_logit, reg_value), labels) for name, func in self.loss_funcs.items()} 138 | loss_for_reg = {name: func.to(labels.device)(reg_value, labels) for name, func in self.loss_for_reg.items()} 139 | metrics_for_reg = {name: func.to(labels.device)(reg_value, labels) for name, func in self.metrics_for_reg.items()} 140 | 141 | metrics = {} 142 | metrics.update(metrics_loss) 143 | metrics.update(loss_for_reg) 144 | metrics.update(metrics_for_reg) 145 | 146 | state = {f"{self.config.tag}_validation/{name}": value for name, value in metrics.items()} 147 | for name, value in state.items(): 148 | self.log(name, value, sync_dist = True, on_step = False, on_epoch = True, prog_bar = True) 149 | 150 | for k, func in self.metrics_for_reg.items(): 151 | func.reset() 152 | return None 153 | -------------------------------------------------------------------------------- /chrombert/finetune/train/train_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | from copy import deepcopy 5 | from typing import Optional, Union, Any, Dict,Tuple 6 | from dataclasses import dataclass, field, asdict 7 | import numpy as np 8 | 9 | @dataclass 10 | class TrainConfig: 11 | kind: str = field(default='classification', metadata={"help": "kind of the model"}) 12 | loss: str = field(default='bce', metadata={"help": "loss function"}) 13 | tag: str = field(default='default', metadata={"help": "tag of the trainer, used for grouping logged results"}) 14 | 15 | adam_beta1: float = field(default=0.9, metadata={"help": "Adam beta1"}) 16 | adam_beta2: float = field(default=0.999, metadata={"help": "Adam beta2"}) 17 | weight_decay: float = field(default=0.01, 
metadata={"help": "weight decay"}) 18 | 19 | lr: float = field(default=1e-4, metadata={"help": "learning rate"}) 20 | warmup_ratio: float = field(default=0.1, metadata={"help": "warmup ratio"}) 21 | max_epochs: int = field(default=10, metadata={"help": "number of epochs"}) 22 | 23 | accumulate_grad_batches: int = field(default=1, metadata={"help": "gradient accumulation steps"}) 24 | limit_val_batches: Union[int, float] = field(default=64, metadata={"help":'number of batches to use for each validation'}) 25 | val_check_interval: Union[int, float] = field(default=64, metadata={"help":'validation check interval'}) 26 | checkpoint_metric: str = field(default=None, metadata={"help": "checkpoint metric"}) 27 | checkpoint_mode: str = field(default='min', metadata={"help": "checkpoint mode"}) 28 | 29 | 30 | def __post_init__(self): 31 | if self.checkpoint_metric is None: 32 | self.checkpoint_metric = self.loss 33 | self.validation() 34 | 35 | 36 | def to_dict(self): 37 | state = {} 38 | for k, v in self.__dataclass_fields__.items(): 39 | state[k] = deepcopy(getattr(self, k)) 40 | return state 41 | 42 | def __str__(self): 43 | return json.dumps(self.to_dict(), indent=4) 44 | 45 | def __repr__(self): 46 | return f"{self.__class__.__name__}({self.__str__()})" 47 | 48 | def __iter__(self): 49 | for name, value in self.to_dict().items(): 50 | yield name, value 51 | 52 | @classmethod 53 | def load(cls, config: Union[str, Dict[str, Any], "TrainConfig", None] = None, **kwargs: Any): 54 | if config is None: 55 | config_dict = {} 56 | elif isinstance(config, str): 57 | with open(config, 'r') as f: 58 | config_dict = json.load(f) 59 | elif isinstance(config, Dict): 60 | config_dict = deepcopy(config) 61 | elif isinstance(config, TrainConfig): 62 | config_dict = config.to_dict() 63 | else: 64 | raise TypeError(f"config must be a str, Dict, or TrainConfig, but got {type(config)}") 65 | 66 | config_dict.update(kwargs) 67 | 68 | config = cls(**config_dict) 69 | config.validation() 70 | return config 71 | 72 | def clone(self): 73 | return TrainConfig.load(self.to_dict()) 74 | 75 | def validation(self): 76 | assert self.kind in ['classification', 'regression', 'zero_inflation'], f"{self.kind=} must be one of ['classification', 'regression', 'zero_inflation']" 77 | 78 | if self.kind == 'classification': 79 | assert self.loss in ['bce', 'focal'], f"{self.loss=} must be one of ['bce', 'focal']" 80 | elif self.kind == 'regression': 81 | assert self.loss in ['mae', 'mse', 'rmse'], f"{self.loss=} must be one of ['mae', 'mse', 'rmse']" 82 | else: 83 | assert self.loss in ['zero_inflation'], f"{self.loss=} must be one of ['zero_inflation']" 84 | 85 | return None 86 | 87 | def update(self, **kwargs): 88 | for key, value in kwargs.items(): 89 | if hasattr(self, key): 90 | setattr(self, key, value) 91 | else: 92 | raise AttributeError(f"Warning: '{key}' is not a valid field name in DatasetConfig") 93 | return None 94 | 95 | def init_pl_module(self, model, **kwargs): 96 | # raise NotImplementedError("init_model method must be implemented in the subclass") 97 | train_config = self.clone() 98 | train_config.update(**kwargs) 99 | if train_config.kind == "classification": 100 | from . import ClassificationPLModule as T 101 | elif train_config.kind == "regression": 102 | from . import RegressionPLModule as T 103 | elif train_config.kind == "zero_inflation": 104 | from . 
import ZeroInflationPLModule as T 105 | else: 106 | raise(ValueError("Not supported kind!")) 107 | 108 | pl_module = T(model, train_config) 109 | 110 | return pl_module 111 | 112 | def init_trainer(self, name="chrombert-ft", **kwargs): 113 | ''' 114 | a simple wrapper for PyTorch Lightning Trainer. For advanced usage, please use PyTorch Lightning Trainer directly. 115 | ''' 116 | import lightning.pytorch as pl 117 | 118 | # trainer = Trainer(**kwargs) 119 | params = { 120 | "max_epochs": self.max_epochs, 121 | "accumulate_grad_batches": self.accumulate_grad_batches, 122 | "limit_val_batches": self.limit_val_batches, 123 | "val_check_interval": self.val_check_interval, 124 | } 125 | params.update(kwargs) 126 | checkpoint_metric = kwargs.get("checkpoint_metric", self.checkpoint_metric) 127 | checkpoint_mode = kwargs.get("checkpoint_mode", self.checkpoint_mode) 128 | tag = kwargs.get("tag", self.tag) 129 | if "checkpoint_metric" in kwargs: 130 | params.pop("checkpoint_metric") 131 | if "checkpoint_mode" in kwargs: 132 | params.pop("checkpoint_mode") 133 | if "tag" in kwargs: 134 | params.pop("tag") 135 | 136 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 137 | monitor=f"{tag}_validation/{checkpoint_metric}", 138 | mode= checkpoint_mode, 139 | save_top_k=kwargs.get("save_top_k", 1), 140 | save_last=True, 141 | filename='{epoch}-{step}', 142 | verbose=True, 143 | ) 144 | params.pop("save_top_k", None) 145 | trainer = pl.Trainer( 146 | logger = pl.loggers.TensorBoardLogger(save_dir=os.path.join(os.getcwd(),"lightning_logs"),name = name), 147 | callbacks = [checkpoint_callback, pl.callbacks.LearningRateMonitor()], 148 | **params 149 | ) 150 | return trainer 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /chrombert/finetune/train/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaoweiyu-github/ChromBERT/a5c39e9960c235038e710c6afffda821b926bfa4/chrombert/finetune/train/utils/__init__.py -------------------------------------------------------------------------------- /chrombert/finetune/train/utils/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LogTensorValues: 4 | def __init__(self): 5 | self.log_data = {} 6 | 7 | def log(self, key, values): 8 | if key not in self.log_data: 9 | self.log_data[key] = [] 10 | v = values.reshape(-1) 11 | # make sure v is not a zero-dim tensor 12 | assert v.dim() > 0, f"LogTensorValues: {key}={v} has zero-dim tensor" 13 | self.log_data[key].append(v) 14 | 15 | def get_values(self, key): 16 | if key not in self.log_data: 17 | return None 18 | return torch.cat(self.log_data[key]) 19 | -------------------------------------------------------------------------------- /chrombert/finetune/train/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.ops import sigmoid_focal_loss 5 | 6 | 7 | class FocalLoss(nn.Module): 8 | """Multi-class Focal loss implementation. 9 | Args: 10 | gamma (float): The larger the gamma, the smaller 11 | the loss weight of easier samples. 12 | alpha (float): Weighting factor in range (0,1) to balance 13 | positive vs negative examples or -1 for ignore. Default: ``-1``. 14 | reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` 15 | ``'none'``: No reduction will be applied to the output. 
16 | ``'mean'``: The output will be averaged. 17 | ``'sum'``: The output will be summed. Default: ``'none'``. 18 | """ 19 | 20 | def __init__(self, gamma = 2, alpha = -1, reduction='mean'): 21 | super().__init__() 22 | self.gamma = gamma 23 | self.alpha = alpha 24 | self.reduction = reduction 25 | 26 | def forward(self, input, target): 27 | loss = sigmoid_focal_loss(inputs=input, targets=target, gamma=self.gamma, alpha=self.alpha, reduction=self.reduction) 28 | return loss 29 | 30 | class RMSELoss(nn.Module): 31 | def __init__(self): 32 | super().__init__() 33 | self.loss = nn.MSELoss(reduction = 'none') 34 | 35 | def forward(self, y_logit, y_true): 36 | loss = self.loss(y_logit, y_true) 37 | loss = torch.sqrt(loss.mean()) 38 | return loss 39 | 40 | class ZeroInflationLoss(nn.Module): 41 | def __init__(self): 42 | super().__init__() 43 | 44 | def forward(self, logit, target): 45 | zero_prob_logit, reg_value = logit 46 | changes_target = (target != 0).float() 47 | zero_mask = (target == 0).float() 48 | reg_mask = (target != 0).float() 49 | zero_loss = F.binary_cross_entropy_with_logits(zero_prob_logit, changes_target,reduction='none') 50 | zero_loss = (zero_loss * zero_mask).sum() / (zero_mask.sum() + 1e-10) 51 | mae_loss = torch.abs(reg_value - target) * reg_mask 52 | mae_loss = mae_loss.sum() / (reg_mask.sum() + 1e-10) 53 | total_loss = zero_loss + mae_loss 54 | return total_loss -------------------------------------------------------------------------------- /chrombert/scripts/chrombert_get_cistrome_emb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import h5py 4 | import json 5 | import argparse 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | import torch 11 | from torch import nn 12 | from tqdm import tqdm 13 | 14 | import chrombert 15 | from chrombert import ChromBERTFTConfig, DatasetConfig 16 | from .utils import HDF5Manager 17 | 18 | DEFAULT_BASEDIR = os.path.expanduser("~/.cache/chrombert/data") 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description="Extract cistrome embeddings from ChromBERT") 23 | parser.add_argument("supervised_file", type=str, help="Path to the supervised file") 24 | parser.add_argument("ids", nargs="+", type=str, help="IDs to extract. can be GSMID or regulator:cellline format id. To generate cache file for prompt, use 'regulator:cellline' format. ") 25 | 26 | parser.add_argument("-o", "--oname", type=str, required=True, help="Path to the output hdf5 file") 27 | 28 | parser.add_argument("--basedir", type=str, default = DEFAULT_BASEDIR, help="Base directory for the required files") 29 | 30 | parser.add_argument("-g", "--genome", type=str, default = "hg38", help="genome version. For example, hg38 or mm10. ") 31 | parser.add_argument("-k", "--ckpt", type=str, required=False, default=None, help="Path to the pretrain or fine-tuned checkpoint. Optial if it could infered from other arguments") 32 | parser.add_argument("--meta", type=str, required=False, default=None, help="Path to the meta file. Optional if it could infered from other arguments") 33 | parser.add_argument("--mask", type=str, required=False, default=None, help="Path to the mtx mask file. Optional if it could infered from other arguments") 34 | 35 | parser.add_argument("-d","--hdf5-file", type=str, required=False, default=None, help="Path to the hdf5 file that contains the dataset. 
Optional if it could infered from other arguments") 36 | parser.add_argument("-hr","--high-resolution", dest = "hr", action = "store_true", help="Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is preparing for the future release of ChromBERT, which is not available yet.") 37 | 38 | parser.add_argument("--batch-size", dest="batch_size", type=int, required=False, default=8, help="batch size") 39 | parser.add_argument("--num-workers",dest="num_workers", type=int, required=False, default=8, help="number of workers for dataloader") 40 | 41 | return parser.parse_args() 42 | 43 | def validate_args(args): 44 | assert os.path.exists(args.supervised_file), f"Supervised file does not exist: {args.supervised_file}" 45 | assert args.genome in ["hg38", "mm10"], f"Genome {args.genome} is not supported. " 46 | assert args.hr == False, "200-bp resolution is not supported now. " 47 | print(f"Extracting embeddings for {len(args.ids)} ids") 48 | print(f"{args.ids}") 49 | 50 | 51 | def get_model_config(args): 52 | assert args.genome in ["hg38", "mm10"], f"Genome {args.genome} is not supported. " 53 | if args.ckpt is not None: 54 | ckpt = args.ckpt 55 | else: 56 | assert os.path.exists(args.basedir), f"Basedir does not exist: {args.basedir}. If you use default basedir, please make sure environment initialized correctly (see readme of the repo). " 57 | if args.hr: 58 | res = "200bp" 59 | else: 60 | res = "1kb" 61 | if args.genome == "hg38": 62 | ckpt = os.path.join(args.basedir, "checkpoint", f"{args.genome}_6k_{res}_pretrain.ckpt") 63 | elif args.genome == "mm10": 64 | ckpt = os.path.join(args.basedir, "checkpoint", f"{args.genome}_5k_{res}_pretrain.ckpt") 65 | else: 66 | raise ValueError(f"Genome {args.genome} is not supported. ") 67 | parameters = { 68 | "genome": args.genome, 69 | "dropout": 0, 70 | "preset": "general", 71 | } 72 | if ChromBERTFTConfig.get_ckpt_type(ckpt) == "pretrain": 73 | parameters["pretrain_ckpt"] = ckpt 74 | else: 75 | parameters["finetune_ckpt"] = ckpt 76 | 77 | if args.mask is not None: 78 | parameters["mtx_mask"] = args.mask 79 | 80 | config = chrombert.get_preset_model_config( 81 | basedir = args.basedir, 82 | **parameters 83 | ) 84 | 85 | return config 86 | 87 | def get_meta_file(meta_file,basedir, genome): 88 | 89 | if meta_file is None: 90 | if genome == "hg38": 91 | meta_file = os.path.join(basedir, "config", f"{genome}_6k_meta.json") 92 | elif genome == "mm10": 93 | meta_file = os.path.join(basedir, "config", f"{genome}_5k_meta.json") 94 | else: 95 | raise ValueError(f"Genome {genome} is not supported now") 96 | return meta_file 97 | 98 | 99 | def get_dataset_config(args): 100 | if args.hr: 101 | res = "200bp" 102 | else: 103 | res = "1kb" 104 | if args.hdf5_file is not None: 105 | hdf5_file = args.hdf5_file 106 | else: 107 | assert os.path.exists(args.basedir), f"Basedir does not exist: {args.basedir}. If you use default basedir, please make sure environment initialized correctly (see readme of the repo). " 108 | if args.genome == "hg38": 109 | hdf5_file = os.path.join(args.basedir, f"{args.genome}_6k_{res}.hdf5") 110 | elif args.genome == "mm10": 111 | hdf5_file = os.path.join(args.basedir, f"{args.genome}_5k_{res}.hdf5") 112 | else: 113 | raise ValueError(f"Genome {args.genome} is not supported. 
") 114 | 115 | dataset_config = DatasetConfig( 116 | kind = "GeneralDataset", 117 | supervised_file = args.supervised_file, 118 | hdf5_file = hdf5_file, 119 | batch_size = args.batch_size, 120 | num_workers = args.num_workers, 121 | ) 122 | return dataset_config 123 | 124 | def get_cistrome_ids(ids, meta_file): 125 | 126 | ids = [i.strip() for i in ids] 127 | gsm_ids = [i for i in ids if ":" not in i ] 128 | reg_ids = [i for i in ids if ":" in i] 129 | 130 | with open(meta_file) as f: 131 | meta = json.load(f) 132 | 133 | dict_ids = {i:i for i in gsm_ids} 134 | try: 135 | dict_ids.update({k:meta[k] for k in reg_ids}) 136 | except: 137 | for k in reg_ids: 138 | if k not in meta: 139 | print(f"{k} is not in the meta file!") 140 | sys.exit(1) 141 | 142 | return dict_ids 143 | 144 | 145 | 146 | def main(): 147 | args = parse_args() 148 | validate_args(args) 149 | config = get_model_config(args) 150 | model = config.init_model().get_embedding_manager().cuda().bfloat16() 151 | dc = get_dataset_config(args) 152 | dl = dc.init_dataloader() 153 | ds = dc.init_dataset() 154 | 155 | meta_file = get_meta_file(args.meta, args.basedir, args.genome) 156 | dict_ids = get_cistrome_ids(args.ids, meta_file) 157 | 158 | shapes = {f"emb/{k}": [(len(ds),768), np.float16] for k in dict_ids} 159 | with HDF5Manager(args.oname, region=[(len(ds),4), np.int64],**shapes) as h5: 160 | with torch.no_grad(): 161 | for batch in tqdm(dl, total = len(dl)): 162 | for k,v in batch.items(): 163 | if isinstance(v, torch.Tensor): 164 | batch[k] = v.cuda() 165 | model(batch) # initialize the cache 166 | region = np.concatenate([ 167 | batch["region"].long().cpu().numpy(), 168 | batch["build_region_index"].long().cpu().unsqueeze(-1).numpy() 169 | ], axis = 1 170 | ) 171 | embs = { 172 | f"emb/{k}": model.get_cistrome_embedding(v).float().cpu().detach().numpy() 173 | for k,v in dict_ids.items() 174 | } 175 | h5.insert(region = region, **embs) 176 | return None 177 | 178 | -------------------------------------------------------------------------------- /chrombert/scripts/chrombert_get_region_emb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import h5py 4 | import argparse 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | import torch 10 | from torch import nn 11 | from tqdm import tqdm 12 | 13 | import chrombert 14 | from chrombert import ChromBERTFTConfig, DatasetConfig 15 | from .utils import HDF5Manager 16 | 17 | DEFAULT_BASEDIR = os.path.expanduser("~/.cache/chrombert/data") 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description="Extract region embeddings from ChromBERT") 22 | parser.add_argument("supervised_file", type=str, help="Path to the supervised file") 23 | parser.add_argument("-o", "--oname", type=str, required=True, help="Path to the output hdf5 file") 24 | 25 | parser.add_argument("--basedir", type=str, default = DEFAULT_BASEDIR, help="Base directory for the required files") 26 | 27 | parser.add_argument("-g", "--genome", type=str, default = "hg38", help="genome version. For example, hg38 or mm10. ") 28 | parser.add_argument("-k", "--ckpt", type=str, required=False, default=None, help="Path to the pretrain or fine-tuned checkpoint. Optial if it could infered from other arguments") 29 | parser.add_argument("--mask", type=str, required=False, default=None, help="Path to the mtx mask file. 
Optional if it could infered from other arguments") 30 | 31 | parser.add_argument("-d","--hdf5-file", type=str, required=False, default=None, help="Path to the hdf5 file that contains the dataset. Optional if it could infered from other arguments") 32 | parser.add_argument("-hr","--high-resolution", dest = "hr", action = "store_true", help="Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is preparing for the future release of ChromBERT, which is not available yet.") 33 | 34 | parser.add_argument("--batch-size", dest="batch_size", type=int, required=False, default=8, help="batch size") 35 | parser.add_argument("--num-workers",dest="num_workers", type=int, required=False, default=8, help="number of workers for dataloader") 36 | 37 | return parser.parse_args() 38 | 39 | def validate_args(args): 40 | assert os.path.exists(args.supervised_file), f"Supervised file does not exist: {args.supervised_file}" 41 | assert args.genome in ["hg38", "mm10"], f"Genome {args.genome} is not supported. " 42 | assert args.hr == False, "200-bp resolution is not supported now. " 43 | 44 | def get_model_config(args): 45 | assert args.genome in ["hg38", "mm10"], f"Genome {args.genome} is not supported. " 46 | if args.ckpt is not None: 47 | ckpt = args.ckpt 48 | else: 49 | assert os.path.exists(args.basedir), f"Basedir does not exist: {args.basedir}. If you use default basedir, please make sure environment initialized correctly (see readme of the repo). " 50 | 51 | if args.hr: 52 | res = "200bp" 53 | else: 54 | res = "1kb" 55 | if args.genome == "hg38": 56 | ckpt = os.path.join(args.basedir, "checkpoint", f"{args.genome}_6k_{res}_pretrain.ckpt") 57 | elif args.genome == "mm10": 58 | ckpt = os.path.join(args.basedir, "checkpoint", f"{args.genome}_5k_{res}_pretrain.ckpt") 59 | else: 60 | raise ValueError(f"Genome {args.genome} is not supported. ") 61 | parameters = { 62 | "genome": args.genome, 63 | "dropout": 0, 64 | "preset": "general", 65 | } 66 | if ChromBERTFTConfig.get_ckpt_type(ckpt) == "pretrain": 67 | parameters["pretrain_ckpt"] = ckpt 68 | else: 69 | parameters["finetune_ckpt"] = ckpt 70 | 71 | if args.mask is not None: 72 | parameters["mtx_mask"] = args.mask 73 | 74 | config = chrombert.get_preset_model_config( 75 | basedir = args.basedir, 76 | **parameters 77 | ) 78 | return config 79 | 80 | def get_dataset_config(args): 81 | if args.hr: 82 | res = "200bp" 83 | else: 84 | res = "1kb" 85 | if args.hdf5_file is not None: 86 | hdf5_file = args.hdf5_file 87 | else: 88 | assert os.path.exists(args.basedir), f"Basedir does not exist: {args.basedir}. If you use default basedir, please make sure environment initialized correctly (see readme of the repo). " 89 | 90 | if args.genome == "hg38": 91 | hdf5_file = os.path.join(args.basedir, f"{args.genome}_6k_{res}.hdf5") 92 | elif args.genome == "mm10": 93 | hdf5_file = os.path.join(args.basedir, f"{args.genome}_5k_{res}.hdf5") 94 | else: 95 | raise ValueError(f"Genome {args.genome} is not supported. 
") 96 | 97 | dataset_config = DatasetConfig( 98 | kind = "GeneralDataset", 99 | supervised_file = args.supervised_file, 100 | hdf5_file = hdf5_file, 101 | batch_size = args.batch_size, 102 | num_workers = args.num_workers, 103 | ) 104 | return dataset_config 105 | 106 | def main(): 107 | args = parse_args() 108 | validate_args(args) 109 | config = get_model_config(args) 110 | model = config.init_model().get_embedding_manager().cuda().bfloat16() 111 | dc = get_dataset_config(args) 112 | dl = dc.init_dataloader() 113 | ds = dc.init_dataset() 114 | 115 | with HDF5Manager(args.oname, region=[(len(ds),4), np.int64], emb = [(len(ds), 768), np.float16]) as h5: 116 | with torch.no_grad(): 117 | for batch in tqdm(dl, total = len(dl)): 118 | for k,v in batch.items(): 119 | if isinstance(v, torch.Tensor): 120 | batch[k] = v.cuda() 121 | model(batch) # initialize the cache 122 | region = np.concatenate([ 123 | batch["region"].long().cpu().numpy(), 124 | batch["build_region_index"].long().cpu().unsqueeze(-1).numpy() 125 | ], axis = 1 126 | ) 127 | h5.insert(region = region, emb = model.get_region_embedding().float().cpu().detach().numpy()) 128 | return None 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /chrombert/scripts/chrombert_get_regulator_emb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import h5py 4 | import json 5 | import argparse 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | import torch 11 | from torch import nn 12 | from tqdm import tqdm 13 | 14 | import chrombert 15 | from chrombert import ChromBERTFTConfig, DatasetConfig 16 | from .utils import HDF5Manager 17 | 18 | DEFAULT_BASEDIR = os.path.expanduser("~/.cache/chrombert/data") 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description="Extract regulator embeddings from ChromBERT") 23 | parser.add_argument("supervised_file", type=str, help="Path to the supervised file") 24 | parser.add_argument("ids", nargs="+", type=str, help="regulator names to extract. In lower case. ") 25 | 26 | parser.add_argument("-o", "--oname", type=str, required=True, help="Path to the output hdf5 file") 27 | 28 | parser.add_argument("--basedir", type=str, default = DEFAULT_BASEDIR, help="Base directory for the required files") 29 | 30 | parser.add_argument("-g", "--genome", type=str, default = "hg38", help="genome version. For example, hg38 or mm10. ") 31 | parser.add_argument("-k", "--ckpt", type=str, required=False, default=None, help="Path to the pretrain checkpoint or fine-tuned. Optial if it could infered from other arguments") 32 | parser.add_argument("--meta", type=str, required=False, default=None, help="Path to the meta file. Optional if it could infered from other arguments") 33 | parser.add_argument("--mask", type=str, required=False, default=None, help="Path to the mtx mask file. Optional if it could infered from other arguments") 34 | 35 | parser.add_argument("-d","--hdf5-file", type=str, required=False, default=None, help="Path to the hdf5 file that contains the dataset. Optional if it could infered from other arguments") 36 | parser.add_argument("-hr","--high-resolution", dest = "hr", action = "store_true", help="Use 200-bp resolution instead of 1-kb resolution. 
Caution: 200-bp resolution is preparing for the future release of ChromBERT, which is not available yet.") 37 | 38 | parser.add_argument("--batch-size", dest="batch_size", type=int, required=False, default=8, help="batch size") 39 | parser.add_argument("--num-workers",dest="num_workers", type=int, required=False, default=8, help="number of workers for dataloader") 40 | 41 | return parser.parse_args() 42 | 43 | def validate_args(args): 44 | assert os.path.exists(args.supervised_file), f"Supervised file does not exist: {args.supervised_file}" 45 | assert args.genome in ["hg38", "mm10"], f"Genome {args.genome} is not supported. " 46 | assert args.hr == False, "200-bp resolution is not supported now. " 47 | print(f"Extracting embeddings for {len(args.ids)} ids") 48 | args.ids = [i.lower().strip() for i in args.ids] 49 | print(f"{args.ids}") 50 | 51 | 52 | def get_model_config(args): 53 | assert args.genome in ["hg38", "mm10"], f"Genome {args.genome} is not supported. " 54 | if args.ckpt is not None: 55 | ckpt = args.ckpt 56 | else: 57 | assert os.path.exists(args.basedir), f"Basedir does not exist: {args.basedir}. If you use default basedir, please make sure environment initialized correctly (see readme of the repo). " 58 | if args.hr: 59 | res = "200bp" 60 | else: 61 | res = "1kb" 62 | if args.genome == "hg38": 63 | ckpt = os.path.join(args.basedir, "checkpoint", f"{args.genome}_6k_{res}_pretrain.ckpt") 64 | elif args.genome == "mm10": 65 | ckpt = os.path.join(args.basedir, "checkpoint", f"{args.genome}_5k_{res}_pretrain.ckpt") 66 | else: 67 | raise ValueError(f"Genome {args.genome} is not supported. ") 68 | parameters = { 69 | "genome": args.genome, 70 | "dropout": 0, 71 | "preset": "general", 72 | } 73 | if ChromBERTFTConfig.get_ckpt_type(ckpt) == "pretrain": 74 | parameters["pretrain_ckpt"] = ckpt 75 | else: 76 | parameters["finetune_ckpt"] = ckpt 77 | 78 | if args.mask is not None: 79 | parameters["mtx_mask"] = args.mask 80 | 81 | config = chrombert.get_preset_model_config( 82 | basedir = args.basedir, 83 | **parameters 84 | ) 85 | 86 | return config 87 | 88 | def get_meta_file(meta_file,basedir, genome): 89 | 90 | if meta_file is None: 91 | if genome == "hg38": 92 | meta_file = os.path.join(basedir, "config", f"{genome}_6k_meta.json") 93 | elif genome == "mm10": 94 | meta_file = os.path.join(basedir, "config", f"{genome}_5k_meta.json") 95 | else: 96 | raise ValueError(f"Genome {genome} is not supported now") 97 | return meta_file 98 | 99 | 100 | def get_dataset_config(args): 101 | if args.hr: 102 | res = "200bp" 103 | else: 104 | res = "1kb" 105 | if args.hdf5_file is not None: 106 | hdf5_file = args.hdf5_file 107 | else: 108 | assert os.path.exists(args.basedir), f"Basedir does not exist: {args.basedir}. If you use default basedir, please make sure environment initialized correctly (see readme of the repo). " 109 | if args.genome == "hg38": 110 | hdf5_file = os.path.join(args.basedir, f"{args.genome}_6k_{res}.hdf5") 111 | elif args.genome == "mm10": 112 | hdf5_file = os.path.join(args.basedir, f"{args.genome}_5k_{res}.hdf5") 113 | else: 114 | raise ValueError(f"Genome {args.genome} is not supported. 
") 115 | 116 | dataset_config = DatasetConfig( 117 | kind = "GeneralDataset", 118 | supervised_file = args.supervised_file, 119 | hdf5_file = hdf5_file, 120 | batch_size = args.batch_size, 121 | num_workers = args.num_workers, 122 | ) 123 | return dataset_config 124 | 125 | 126 | def get_cistrome_ids(ids, meta_file): 127 | 128 | with open(meta_file) as f: 129 | meta = json.load(f) 130 | 131 | for i in ids: 132 | assert i in meta["regulator"], f"Regulator {i} is not in the meta file" 133 | dict_ids = {i: i for i in ids} 134 | return dict_ids 135 | 136 | 137 | def main(): 138 | args = parse_args() 139 | validate_args(args) 140 | config = get_model_config(args) 141 | model = config.init_model().get_embedding_manager().cuda().bfloat16() 142 | dc = get_dataset_config(args) 143 | dl = dc.init_dataloader() 144 | ds = dc.init_dataset() 145 | 146 | meta_file = get_meta_file(args.meta, args.basedir, args.genome) 147 | dict_ids = get_cistrome_ids(args.ids, meta_file) 148 | 149 | shapes = {f"emb/{k}": [(len(ds),768), np.float16] for k in dict_ids} 150 | with HDF5Manager(args.oname, region=[(len(ds),4), np.int64], all=[(len(ds),768), np.float16],**shapes) as h5: 151 | with torch.no_grad(): 152 | for batch in tqdm(dl, total = len(dl)): 153 | for k,v in batch.items(): 154 | if isinstance(v, torch.Tensor): 155 | batch[k] = v.cuda() 156 | model(batch) # initialize the cache 157 | region = np.concatenate([ 158 | batch["region"].long().cpu().numpy(), 159 | batch["build_region_index"].long().cpu().unsqueeze(-1).numpy() 160 | ], axis = 1 161 | ) 162 | embs = { 163 | f"emb/{k}": model.get_regulator_embedding(v).float().cpu().detach().numpy() 164 | for k,v in dict_ids.items() 165 | } 166 | all_emb = model.get_region_embedding().float().cpu().detach().numpy() 167 | h5.insert(region = region, all=all_emb, **embs) 168 | return None 169 | 170 | -------------------------------------------------------------------------------- /chrombert/scripts/chrombert_make_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import sys 4 | import subprocess as sp 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | import argparse 10 | 11 | DEFAULT_BASEDIR = os.path.expanduser("~/.cache/chrombert/data") 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description="Generate general datasets for ChromBERT from bed3 files") 15 | parser.add_argument("bed", type=str, help="Path to bed file") 16 | parser.add_argument("-o","--oname", type=str, required = False, default=None, help="Path to output file. Stdout if not specified. Must end with .tsv or .txt. ") 17 | 18 | parser.add_argument("--mode", type=str, choices=["region","all"], default="region", help="Mode to generate the dataset. \nregion: only consider overlap between input regions to determine the label generated. Useful for narrowPeak like input. \nall: report all overlapping status like bedtools intersect -wao. You should determine the label column by your self. ") 19 | parser.add_argument("--center", action="store_true", help="If used, only consider the center of the input regions." ) 20 | parser.add_argument("--label", type=int, default = 4, help="if mode is not region, this column will be used as label. Default is 4th. 1-based. ") 21 | 22 | parser.add_argument("--no-filter",dest="no_filter", default=False, action = "store_true", help="Do not filter the regions that are not overlapped. 
") 23 | 24 | parser.add_argument("--basedir", type=str, default = DEFAULT_BASEDIR, help="Base directory for the required files") 25 | parser.add_argument("-g", "--genome", type=str, default = "hg38", help="genome version. For example, hg38 or mm10. ") 26 | parser.add_argument("-hr","--high-resolution", dest = "hr", action = "store_true", help="Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is preparing for the future release of ChromBERT, which is not available yet.") 27 | 28 | return parser.parse_args() 29 | 30 | def validate_args(args): 31 | assert os.path.exists(args.bed), f"Bed file does not exist: {args.bed}" 32 | assert os.path.exists(args.basedir), f"Basedir does not exist: {args.basedir}. If you use default basedir, please make sure environment initialized correctly (see readme of the repo). " 33 | assert args.genome in ["hg38", "mm10"], f"Genome version {args.genome} is not supported. " 34 | assert args.hr == False, "200-bp resolution is not supported now. " 35 | if args.oname is not None: 36 | assert isinstance(args.oname, str), f"Output file name should be string. Given: {args.oname}" 37 | assert args.oname.endswith(".tsv") or args.oname.endswith(".txt"), f"Output file should be tsv or txt file. Given: {args.oname}" 38 | 39 | def run_cmd(cmd): 40 | try: 41 | run = sp.run(cmd, shell = True, stdout = sp.PIPE, stderr = sp.PIPE, check = True, text = True) 42 | except sp.CalledProcessError as e: 43 | print(e) 44 | print(e.stderr, file = sys.stderr) 45 | sys.exit(1) 46 | except Exception as e: 47 | raise(e) 48 | sys.exit(1) 49 | return run 50 | 51 | def get_regions(basedir = DEFAULT_BASEDIR, genome="hg38", high_resolution = False): 52 | if genome == "hg38": 53 | if high_resolution: 54 | oname = os.path.join(basedir, "config", f"{genome}_6k_200bp_region.bed") 55 | else: 56 | oname = os.path.join(basedir, "config", f"{genome}_6k_1kb_region.bed") 57 | elif genome == "mm10": 58 | if high_resolution: 59 | oname = os.path.join(basedir, "config", f"{genome}_5k_200bp_region.bed") 60 | else: 61 | oname = os.path.join(basedir, "config", f"{genome}_5k_1kb_region.bed") 62 | else: 63 | raise ValueError(f"Genome {genome} is not supported. ") 64 | return oname 65 | 66 | def get_overlap(supervised, regions, no_filter = False, center = False): 67 | assert os.path.exists(supervised), f"Supervised file does not exist: {supervised}" 68 | if center: 69 | cmd = f''' 70 | cut -f 1-3 {supervised} | awk 'BEGIN{{OFS="\\t"}}{{c=int(($2+$3)/2);$2=c;$3=$2+1;print $0;}}' | sort -k1,1 -k2,2n | bedtools merge | bedtools intersect -c -e -f 0.5 -F 0.5 -a {regions} -b - \ 71 | ''' 72 | else: 73 | cmd = f''' 74 | cut -f 1-3 {supervised} | sort -k1,1 -k2,2n | bedtools merge | bedtools intersect -c -e -f 0.5 -F 0.5 -a {regions} -b - \ 75 | ''' 76 | if not no_filter: 77 | cmd += ''' | awk '$5 > 0' ''' 78 | 79 | run = run_cmd(cmd) 80 | 81 | if len(run.stdout) == 0: 82 | print("No overlapping regions found. 
", file = sys.stderr) 83 | sys.exit(1) 84 | 85 | df_supervised = pd.read_csv(io.StringIO(run.stdout), sep = "\t", header = None) 86 | df_supervised.columns = ["chrom", "start", "end", "build_region_index", "label"] 87 | 88 | return df_supervised 89 | 90 | 91 | def get_overlap_all(supervised, regions, no_filter = False, col_label = 4, center = False): 92 | assert os.path.exists(supervised), f"Supervised file does not exist: {supervised}" 93 | 94 | if center: 95 | cmd = f''' 96 | cat {supervised} | awk -F '\\t' 'BEGIN{{OFS="\\t"}}{{c=int(($2+$3)/2);$2=c;$3=$2+1;print $0;}}' | sort -k1,1 -k2,2n | bedtools intersect -wao -e -f 0.5 -F 0.5 -a {regions} -b - \ 97 | ''' 98 | else: 99 | cmd = f''' 100 | cat {supervised} | sort -k1,1 -k2,2n | bedtools intersect -wao -e -f 0.5 -F 0.5 -a {regions} -b - \ 101 | ''' 102 | if not no_filter: 103 | cmd += ''' | awk '$5 != "." ' ''' 104 | 105 | run = run_cmd(cmd) 106 | 107 | if len(run.stdout) == 0: 108 | print("No overlapping regions found. ", file = sys.stderr) 109 | sys.exit(1) 110 | 111 | df_supervised = pd.read_csv(io.StringIO(run.stdout), sep = "\t", header = None) 112 | n_cols = df_supervised.shape[1] 113 | col_label += 4 - 1 # 1-based to 0-based, and shift to the right 114 | assert n_cols >= 8, f"Input file should have at least 3 columns. Given: {n_cols - 5 }" 115 | assert col_label < n_cols - 1, f"Label column {col_label -3 } is out of range. Total columns of your input file: {n_cols-4}" 116 | colnames = ["chrom","start","end","build_region_index","chrom_s","start_s","end_s"] 117 | 118 | n = 1 119 | for i in range(7, n_cols-1): 120 | if i != col_label: 121 | colnames.append(f"extra_{n}") 122 | n += 1 123 | else: 124 | colnames.append("label") 125 | colnames.append("coverage") 126 | df_supervised.columns = colnames 127 | df_supervised = df_supervised[["chrom","start","end","build_region_index","label", "chrom_s","start_s","end_s", "label", "coverage"] + [f"extra_{i}" for i in range(1,n)]] 128 | 129 | return df_supervised 130 | 131 | def process(supervised, regions, mode = "region", no_filter = False, col_label = 4, center = False): 132 | if mode == "region": 133 | df = get_overlap(supervised, regions, no_filter = no_filter, center = center) 134 | elif mode == "all": 135 | df = get_overlap_all(supervised, regions, no_filter = no_filter, col_label = col_label, center = center) 136 | else: 137 | raise ValueError(f"Mode {mode} is not supported. 
") 138 | 139 | return df 140 | 141 | def main(): 142 | args = parse_args() 143 | validate_args(args) 144 | regions = get_regions(args.basedir, args.genome, args.hr) 145 | # df_supervised = get_overlap(args.bed, regions, no_merge = args.no_merge) 146 | df_supervised = process(args.bed, regions, mode = args.mode, no_filter = args.no_filter, col_label = args.label, center = args.center) 147 | if args.oname is None: 148 | text = df_supervised.to_csv(sep = "\t", index = False) 149 | try: 150 | sys.stdout.write(text) 151 | except BrokenPipeError: 152 | pass 153 | else: 154 | df_supervised.to_csv(args.oname, sep = "\t", index = False) 155 | return 0 156 | 157 | if __name__ == "__main__": 158 | main() 159 | 160 | 161 | -------------------------------------------------------------------------------- /chrombert/scripts/chrombert_prepare_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import subprocess 5 | import argparse 6 | 7 | class FileManager: 8 | @staticmethod 9 | def create_directories(directories): 10 | for directory in directories: 11 | os.makedirs(directory, exist_ok=True) 12 | 13 | @staticmethod 14 | def decompress_file(file_path): 15 | command = ["gzip", "-d", file_path] 16 | result = subprocess.run(command) 17 | if result.returncode != 0: 18 | print(f"Error decompressing {file_path}") 19 | 20 | @staticmethod 21 | def unpack_tar(file_path, output_dir): 22 | command = ["tar", "-xzf", file_path, "-C", output_dir] 23 | result = subprocess.run(command) 24 | if result.returncode != 0: 25 | print(f"Error unpacking {file_path}") 26 | 27 | class HuggingFaceDownloader: 28 | @staticmethod 29 | def download(ifile, odir, hf_endpoint="https://huggingface.co"): 30 | # huggingface_cli_path = os.path.join(os.path.dirname(sys.executable), "huggingface-cli") 31 | huggingface_cli_path = shutil.which("huggingface-cli") 32 | if huggingface_cli_path is None: 33 | raise FileNotFoundError("The 'huggingface-cli' command was not found in the system PATH.") 34 | 35 | cmd = [ 36 | huggingface_cli_path, 37 | "download", 38 | "--repo-type", 39 | "dataset", 40 | "--local-dir", 41 | odir, 42 | "TongjiZhanglab/chrombert", 43 | ifile 44 | ] 45 | # cmd = f"huggingface-cli download --repo-type dataset --local-dir {odir} TongjiZhanglab/chrombert {ifile}" 46 | result = subprocess.run(cmd, env={"HF_ENDPOINT": hf_endpoint}) 47 | if result.returncode != 0: 48 | print(f"Error downloading {ifile}") 49 | 50 | 51 | def download(basedir = "~/.cache/chrombert/data", hf_endpoint="https://huggingface.co"): 52 | basedir = os.path.expanduser(basedir) 53 | os.makedirs(basedir, exist_ok=True) 54 | if hf_endpoint.endswith("/"): 55 | hf_endpoint = hf_endpoint[:-1] 56 | 57 | print(f"Downloading files to {basedir}") 58 | 59 | directories = [ 60 | "config", 61 | "checkpoint", 62 | "cache", 63 | "other", 64 | "demo" 65 | ] 66 | directories = [os.path.join(basedir, directory) for directory in directories] 67 | 68 | FileManager.create_directories(directories) 69 | 70 | files_to_download = [ 71 | ("hg38_6k_1kb.hdf5.gz", "."), 72 | ("hg38_6k_1kb_cistrome_cell_prompt_chr1_cache.h5", "cache"), 73 | ("hg38_6k_1kb_expression_cell_prompt_cache.pkl", "cache"), 74 | ("hg38_6k_1kb_regulator_prompt_chr1_cache.h5", "cache"), 75 | ("pbmc10k_scgpt_cell_prompt_cache.pkl","cache"), 76 | ("hg38_6k_1kb_pretrain.ckpt", "checkpoint"), 77 | ("hg38_6k_1kb_prompt_cistrome.ckpt", "checkpoint"), 78 | ("hg38_6k_1kb_prompt_expression.ckpt", "checkpoint"), 79 | 
("hg38_6k_factors_list.txt", "config"), 80 | ("hg38_6k_meta.tsv", "config"), 81 | ("hg38_6k_regulators_list.txt", "config"), 82 | ("hg38_6k_1kb_region.bed", "config"), 83 | ("hg38_6k_meta.json", "config"), 84 | ("hg38_6k_mask_matrix.tsv", "config"), 85 | ("hg38.fa", "other"), 86 | ("demo.tar.gz","."), 87 | 88 | ("mm10_5k_1kb.hdf5.gz", "."), 89 | ("mm10_5k_1kb_pretrain.ckpt", "checkpoint"), 90 | ("mm10_5k_1kb_region.bed", "config"), 91 | ("mm10_5k_meta.tsv", "config"), 92 | ("mm10_5k_regulators_list.txt", "config"), 93 | ("mm10_5k_meta.json", "config"), 94 | ("mm10_5k_mask_matrix.tsv", "config") 95 | 96 | ] 97 | 98 | files_to_decompress = [ 99 | "hg38_6k_1kb.hdf5.gz", 100 | "mm10_5k_1kb.hdf5.gz" 101 | ] 102 | 103 | files_to_unpack = [ 104 | ("demo.tar.gz", ".") 105 | ] 106 | 107 | for ifile, odir in files_to_download: 108 | if ifile in files_to_decompress and os.path.exists(os.path.join(basedir, ifile.replace(".gz", ""))): 109 | continue 110 | HuggingFaceDownloader.download(ifile, os.path.join(basedir, odir), hf_endpoint) 111 | 112 | for file in files_to_decompress: 113 | if not os.path.exists(os.path.join(basedir, file.replace(".gz", ""))): 114 | FileManager.decompress_file(os.path.join(basedir, file)) 115 | 116 | for file, output_dir in files_to_unpack: 117 | FileManager.unpack_tar(os.path.join(basedir, file), os.path.join(basedir, output_dir)) 118 | return basedir 119 | 120 | def main(): 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument("--basedir", type=str, default="~/.cache/chrombert/data", help="Base directory to download files") 123 | parser.add_argument("--hf-endpoint", type=str, default="https://huggingface.co", help="Huggingface endpoint") 124 | args = parser.parse_args() 125 | download(args.basedir, args.hf_endpoint) 126 | return None 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /chrombert/scripts/demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | def main(): 3 | print("This is script1") 4 | 5 | if __name__ == '__main__': 6 | main() 7 | -------------------------------------------------------------------------------- /chrombert/scripts/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .h5_manager import HDF5Manager -------------------------------------------------------------------------------- /chrombert/scripts/utils/h5_manager.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | 4 | class HDF5Manager: 5 | def __init__(self, o_file, chunks = True, **kwargs): 6 | ''' 7 | Initializes an HDF5 file with specified datasets. 8 | 9 | Parameters: 10 | - o_file (str): The name of the output HDF5 file. 11 | - chunks: If True, datasets will be chunked automatically. Else, it should be a positive integer (>=2), indicating the chunk size of first dimension. Or use none to disable chunking. 12 | - kwargs (dict): Keyword arguments where each key is the dataset name and each value is the shape and dtype((*shapes), dtype) of the dataset. 
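        Example (illustrative; dataset names, shapes and dtypes are arbitrary):
            data = np.random.rand(100, 768).astype(np.float16)
            with HDF5Manager("output.h5", emb=[(1000, 768), np.float16]) as h5:
                h5.insert(emb=data)  # call insert() repeatedly to append further batches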
13 | ''' 14 | self.o_file = o_file 15 | self.chunks = chunks 16 | self.kwargs = kwargs 17 | self.n_samples = 0 18 | self.file = None 19 | 20 | def __enter__(self): 21 | # Open the HDF5 file in write mode 22 | self.file = h5py.File(self.o_file, 'w') 23 | # Create datasets based on provided shapes 24 | for key, info in self.kwargs.items(): 25 | shape, dtype = info[0], info[1] 26 | if self.chunks is True: 27 | self.file.create_dataset(key, shape=shape, dtype = dtype, chunks = True, maxshape = (None, *shape[1:])) 28 | elif self.chunks is None: 29 | self.file.create_dataset(key, shape=shape, dtype = dtype) 30 | else: 31 | self.file.create_dataset(key, shape=shape, dtype = dtype, chunks = (self.chunks, *shape[1:])) 32 | return self 33 | 34 | def __exit__(self, exc_type, exc_val, exc_tb): 35 | # Close the HDF5 file 36 | if self.file: 37 | self.file.close() 38 | 39 | def insert(self, **data): 40 | ''' 41 | Inserts data into the HDF5 file. 42 | 43 | Parameters: 44 | - data (dict): Keyword arguments where each key is the dataset name and each value is the data to be inserted. 45 | 46 | Raises: 47 | - AssertionError: If keys in data do not match the datasets in the file. 48 | - AssertionError: If the first dimension of all values is not the same. 49 | - AssertionError: If inserting more samples than the dataset can hold. 50 | ''' 51 | assert self.file is not None, "File is not open." 52 | assert set(data.keys()) == set(self.kwargs.keys()), ( 53 | f"Please ensure all data are passed and correctly saved in the file. {set(self.kwargs.keys())-set(set(data.keys()))} not be save" 54 | ) 55 | # Ensure all keys in data match the datasets in the file 56 | assert all(key in self.file for key in data.keys()), \ 57 | "All keys in data must match the datasets in the file." 58 | 59 | # Ensure the first dimension of all values is the same 60 | try: 61 | samples = [data[key].shape[0] for key in data.keys()] 62 | except: 63 | samples = [1 for key in data.keys()] 64 | 65 | # samples = [data[key].shape[0] if hasattr(data[key], 'shape') else 1 for key in data.keys()] 66 | 67 | assert len(set(samples)) == 1, "First dimension of all values should be the same." 68 | 69 | # Check if the total number of samples will exceed dataset size 70 | new_samples = samples[0] 71 | assert self.n_samples + new_samples <= self.file[list(data.keys())[0]].shape[0], \ 72 | "Inserting more samples than the dataset can hold." 73 | 74 | # Insert data into the datasets 75 | for key in data.keys(): 76 | self.file[key][self.n_samples:self.n_samples + new_samples] = data[key] 77 | 78 | # Update sample counter 79 | self.n_samples += new_samples 80 | 81 | # Example Usage 82 | if __name__ == "__main__": 83 | shapes = { 84 | 'dataset1': (1000, 10), 85 | 'dataset2': (1000, 20) 86 | } 87 | 88 | data1 = np.random.rand(100, 10) 89 | data2 = np.random.rand(100, 20) 90 | 91 | with HDF5Manager('output.h5', **shapes) as manager: 92 | manager.insert(dataset1=data1, dataset2=data2) 93 | # You can call manager.insert multiple times as needed 94 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
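# Typical invocations: "make html" renders the documentation into build/html,
# and "make clean" removes previous build output.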
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/1_ChromBERT_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaoweiyu-github/ChromBERT/a5c39e9960c235038e710c6afffda821b926bfa4/docs/_static/1_ChromBERT_framework.png -------------------------------------------------------------------------------- /docs/_static/ChromBERT_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaoweiyu-github/ChromBERT/a5c39e9960c235038e710c6afffda821b926bfa4/docs/_static/ChromBERT_framework.png -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==7.1.2 2 | sphinx-rtd-theme==1.3.0rc1 3 | nbsphinx 4 | nbsphinx_link 5 | recommonmark 6 | lumache 7 | -------------------------------------------------------------------------------- /docs/source/cli.rst: -------------------------------------------------------------------------------- 1 | CLI Reference 2 | ============= 3 | 4 | Overview 5 | ----------- 6 | 7 | We provide a set of command line scripts for your convenience. All scripts can be called in your terminal directly. See the following sections for more details. 8 | 9 | .. csv-table:: Scripts Instruction 10 | :header: "Script", "Description" 11 | 12 | "`chrombert_prepare_env`_", "Download required files to ~/.cache/chrombert/data, or other path your like." 13 | "`chrombert_make_dataset`_", "Make dataset for ChromBERT forward. " 14 | "`chrombert_get_region_emb`_", "Get mean pooled TRN embedding (region embedding) and store in a file." 15 | "`chrombert_get_cistrome_emb`_", "Get cistrome embedding and store in a file. " 16 | "`chrombert_get_regulator_emb`_", "Get regulator embedding and store in a file." 17 | "`chrombert_imputation_cistrome`_", "Generate cistromes using prompt-enhanced ChromBERT. 
" 18 | "`chrombert_imputation_cistrome_sc`_", "Generate cistromes using prompt-enhanced ChromBERT, specified for single-cell data. " 19 | 20 | ----- 21 | 22 | Details 23 | --------- 24 | 25 | .. include:: scripts/chrombert_prepare_env.rst 26 | 27 | ---- 28 | 29 | .. include:: scripts/chrombert_make_dataset.rst 30 | 31 | ---- 32 | 33 | .. include:: scripts/chrombert_get_region_emb.rst 34 | 35 | ---- 36 | 37 | .. include:: scripts/chrombert_get_cistrome_emb.rst 38 | 39 | ---- 40 | 41 | .. include:: scripts/chrombert_get_regulator_emb.rst 42 | 43 | ---- 44 | 45 | .. include:: scripts/chrombert_imputation_cistrome.rst 46 | 47 | ---- 48 | 49 | .. include:: scripts/chrombert_imputation_cistrome_sc.rst 50 | 51 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | 3 | # -- Project information 4 | import sphinx_rtd_theme 5 | 6 | project = 'ChromBERT' 7 | copyright = '2024, Zhang Lab' 8 | author = 'Zhaowei Yu, Dongxu Yang, Qianqian Chen, Yuxuan Zhang' 9 | 10 | release = '1.1.0' 11 | version = '1.1.0' 12 | 13 | # -- General configuration 14 | 15 | extensions = [ 16 | 'sphinx.ext.duration', 17 | 'sphinx.ext.doctest', 18 | 'sphinx.ext.autodoc', 19 | 'sphinx.ext.autosummary', 20 | 'sphinx.ext.intersphinx', 21 | "nbsphinx", 22 | "nbsphinx_link", 23 | "recommonmark", 24 | "sphinx.ext.viewcode" 25 | ] 26 | 27 | intersphinx_mapping = { 28 | 'python': ('https://docs.python.org/3/', None), 29 | 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), 30 | } 31 | intersphinx_disabled_domains = ['std'] 32 | 33 | templates_path = ['_templates'] 34 | 35 | # -- Options for HTML output 36 | 37 | html_theme = 'sphinx_rtd_theme' 38 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 39 | 40 | html_context = { 41 | "display_github": True, # Integrate GitHub 42 | "github_user": "zhaoweiyu-github", # Username 43 | "github_repo": "ChromBERT", # Repo name 44 | "github_version": "main/", # Branch 45 | "conf_py_path": "/docs/source/", # Path in the repo to conf.py 46 | } 47 | 48 | # -- Options for EPUB output 49 | epub_show_urls = 'footnote' 50 | nbsphinx_execute = "never" -------------------------------------------------------------------------------- /docs/source/finetune.rst: -------------------------------------------------------------------------------- 1 | Compiled Scripts for fine-tuning of ChromBERT 2 | ============================================= 3 | 4 | 5 | Overview 6 | ---------- 7 | 8 | For a hands-on tutorial, see the documentation on :doc:`tutorial_finetuning_ChromBERT`. 9 | 10 | We provide three scripts for fine-tuning, designed for your convenience. All scripts can be downloaded and executed anywhere, provided that your installation is correct. 11 | 12 | For detailed usage instructions, run the following command: 13 | 14 | .. code-block:: bash 15 | 16 | python --help 17 | 18 | .. csv-table:: Fine-Tune Scripts 19 | :header: "Type", "Download", "Description" 20 | 21 | "`Cell-type-specific regulatory effects`_", "`download `_ ", "Designed for scenarios where the model fine-tuning for cell-type-specific regulatory effects." 22 | "`Prompt-enhanced`_", "`download `_", "Designed for scenarios that require incorporating additional information into the model." 23 | "`Gene expression prediction`_", "`download `_", "Intended for tasks that use multiple 1-kb bins as input, such as gene expression prediction." 
24 | 25 | Details 26 | --------- 27 | 28 | .. include:: scripts/ft_general.rst 29 | 30 | ---- 31 | 32 | .. include:: scripts/ft_prompt_enhanced.rst 33 | 34 | ---- 35 | 36 | .. include:: scripts/ft_gep.rst -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to ChromBERT's documentation! 2 | ====================================== 3 | 4 | ``ChromBERT`` is a pre-trained deep learning model designed to capture the genome-wide co-association patterns of approximately one thousand transcription regulators, thereby enabling accurate representations of context-specific transcriptional regulatory networks (TRNs). As a foundational model, ``ChromBERT`` can be fine-tuned to adapt to various biological contexts through transfer learning and provide insights into the roles of transcription regulators in the specific biological contexts without the need of additional genomic data for each regulator. 5 | 6 | .. note:: 7 | 8 | This project is under active development. 9 | 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | :caption: Getting started: 14 | 15 | installation 16 | quick_tour 17 | 18 | .. toctree:: 19 | :maxdepth: 1 20 | :caption: Tutorials: 21 | 22 | tutorial_finetuning_ChromBERT 23 | finetune 24 | tutorial_embedding_extraction 25 | cli 26 | 27 | .. toctree:: 28 | :maxdepth: 1 29 | :caption: Examples: 30 | 31 | tutorial_prompt_cistrome_imputation 32 | tutorial_locus_specific_TRN_eqtl 33 | tutorial_locus_specific_TRN_ezh2 34 | tutorial_locus_specific_TRN_starr 35 | tutorial_transdifferentiation -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | For direct usage, it is recommended to utilize the Singularity image. For development purposes, installing from source is advised as below. 5 | 6 | 7 | Installation From Source 8 | ************************* 9 | 10 | ``ChromBERT`` is compatible with Python versions 3.8 or higher and requires ``PyTorch`` version 2.0 or above, along with ``FlashAttention-2``. These dependencies must be installed prior to ``ChromBERT``. 11 | 12 | Installing PyTorch 13 | ------------------ 14 | Follow the detailed instructions on `PyTorch’s official site `__ to install ``PyTorch`` according to your device and CUDA version specifications. 15 | 16 | .. note:: 17 | ChromBERT has been tested with Python 3.9+ and Torch 2.0 to 2.4 (inclusive). Compatibility with other environments is not guaranteed. 18 | 19 | Installing FlashAttention-2 20 | --------------------------- 21 | Execute the following commands to install the required packages and `FlashAttention-2 `__. 22 | 23 | .. code-block:: shell 24 | 25 | # Install the required packages for FlashAttention-2 26 | pip install packaging 27 | pip install ninja 28 | # FlashAttention-3 is not supported yet, please install FlashAttention-2 29 | pip install flash-attn==2.4.* --no-build-isolation 30 | 31 | Installing ChromBERT 32 | -------------------- 33 | Clone the repository and install ``ChromBERT`` using the commands below: 34 | 35 | .. code-block:: shell 36 | 37 | git clone https://github.com/TongjiZhanglab/ChromBERT.git 38 | cd ChromBERT 39 | pip install . 40 | 41 | Installation typically takes less than five minutes. 
42 | 43 | 44 | Then download required pre-trained model and annotation data files from Hugging Face to ~/.cache/chrombert/data. 45 | 46 | .. code-block:: shell 47 | 48 | chrombert_prepare_env 49 | 50 | Alternatively, if you're experiencing significant connectivity issues with Hugging Face, you can try to use the ``--hf-endpoint`` option to connect to an available mirror of Hugging Face for you. 51 | 52 | .. code-block:: shell 53 | 54 | chrombert_prepare_env --hf-endpoint 55 | 56 | For built-in dataset preparation, it is recommended to install `bedtools `_. 57 | 58 | Verifying Installation 59 | ---------------------- 60 | To verify installation, execute the following python code: 61 | 62 | .. code-block:: python 63 | 64 | import chrombert 65 | 66 | 67 | Installation Using Singularity 68 | ***************************** 69 | 70 | We provide a pre-built Singularity image available: `chrombert.sif `_. 71 | 72 | 73 | After installing ``Singularity`` (or ``Apptainer``) and downloading the image (`chrombert.sif`), you can use the built-in ``python`` environment with: 74 | 75 | .. code-block:: bash 76 | 77 | singularity exec --nv chrombert.sif python -c "import chrombert; print('hello chrombert')" 78 | 79 | 80 | You can execute other built-in commands through the image as well. For example, to download the required pre-trained models and annotation files from Hugging Face to `~/.cache/chrombert/data`, run: 81 | 82 | .. note:: 83 | You must execute this command to prepare the environment, as the image does not include checkpoints and additional data by default to minimize size. 84 | 85 | .. code-block:: bash 86 | 87 | singularity exec --nv chrombert.sif chrombert_prepare_env 88 | 89 | To run your own Python scripts, use: 90 | 91 | .. code-block:: bash 92 | 93 | singularity exec --nv chrombert.sif python 94 | 95 | The image also includes a built-in Jupyter kernel for interactive script development via ``jupyter notebook`` or editors like ``VSCode``: 96 | 97 | .. code-block:: bash 98 | 99 | singularity exec --nv chrombert.sif jupyter notebook [other parameters] 100 | 101 | By default, Singularity mounts your home directory inside the container. If you need to mount additional directories, use the ``--bind`` parameter. Refer to the `Singularity documentation `_ for more details. 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /docs/source/quick_tour.rst: -------------------------------------------------------------------------------- 1 | Quick Tour 2 | ========== 3 | 4 | Before starting this quick tour, ensure you are familiar with the basics 5 | of the PyTorch Lightning framework, including the 6 | `LightningDataModule `__ 7 | and 8 | `LightningModule `__. 9 | What’s more, make sure you have downloaded the necessary file by 10 | executing the ``chrombert_prepare_env`` command, see :doc:`chrombert_prepare_env ` for 11 | details. 12 | 13 | OK! Let’s get started! 14 | 15 | 1 Customize the input 16 | --------------------- 17 | 18 | You can customize the model’s input by assigning parameters to the ``chrombert.DatasetConfig`` class. Key 19 | parameters include: 20 | 21 | ``kind``: Specifies the input format, which varies across tasks. It is 22 | crucial to assign this based on your specific task. Different tasks may 23 | require additional parameters, which you can find 24 | `here `__. 25 | 26 | ``hdf5_file``: A preprocessed HDF5 file containing features for 1kb bins 27 | across the genome. 
This file is cached in the default directory 28 | (``~/.cache/chrombert/data/hg38_6k_1kb.hdf5``) upon installation, unless 29 | customized. 30 | 31 | ``supervised_file``: A input dataset containing at least four columns: 32 | ``chrom``,\ ``start``,\ ``end``, ``build_region_index``. These four 33 | columns are used to locate and retrieve features for the regions. Depending on the task, you can add additional 34 | columns like ``label``. The ``build_region_index`` for each region is cached in the default directory 35 | (``~/.cache/chrombert/config/hg38_6k_1kb_region.bed``) upon installation, unless customized. 36 | 37 | You can also configure other parameters like ``batch_size`` and 38 | ``num_workers``. 39 | 40 | .. code:: python 41 | 42 | import chrombert 43 | 44 | # Create a DatasetConfig object with your settings 45 | dc = chrombert.DatasetConfig(hdf5_file="~/.cache/chrombert/data/hg38_6k_1kb.hdf5", 46 | kind="GeneralDataset", supervised_file="") 47 | 48 | # Initialize inputs in whatever formats you want 49 | ds = dc.init_dataset() # Dataset 50 | 51 | dl = dc.init_dataloader() # Dataloader 52 | 53 | dm = chrombert.LitChromBERTFTDataModule(config=dc, 54 | train_params={"supervised_file": args.train}, 55 | val_params={"supervised_file": args.valid}, 56 | test_params={"supervised_file": args.test}) # LightningDataModule 57 | 58 | 2 Customize the model 59 | --------------------- 60 | 61 | The model structure depends on the task at hand. Use the 62 | ``ChromBERTFTConfig``\ class to specify the task and configure its 63 | parameters. Remember to assign the ``pretrain_ckpt`` parameter if you 64 | want to use the pre-trained ChromBERT model. The checkpoint file is 65 | cached in the default directory 66 | (``~/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt``) upon 67 | installation, unless customized. 68 | 69 | Different tasks may require additional parameters, which you can find 70 | `here `__. 71 | 72 | .. code:: python 73 | 74 | # Configure the model for your task 75 | mc = chrombert.ChromBERTFTConfig(task='general', 76 | pretrain_ckpt="~/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt") 77 | model = mc.init_model() 78 | 79 | # Optional: Manage trainable parameters 80 | # model.freeze_pretrain(trainable:int) # Freeze transformer layers 81 | # model.display_trainable_parameters() # Display the number of trainable layers 82 | 83 | 3 Customize the training process 84 | -------------------------------- 85 | 86 | Once the input and model are configured, you can customize the training 87 | process, including: 88 | 89 | • Task Type (``kind``): ``"regression"`` or ``"classification"``. 90 | • Loss Function (``loss``): Specify the type of loss (e.g., ``"bce"`` for binary cross-entropy). 91 | • Learning Rate (``lr``): Set the desired learning rate. 92 | 93 | Explore other customizable training parameters `here `__. 94 | 95 | .. code:: python 96 | 97 | config_train = chrombert.finetune.TrainConfig(kind="classification", 98 | loss="bce", lr=1e-4) 99 | pl_module = config_train.init_pl_module(model) 100 | trainer = config_train.init_trainer() 101 | 102 | 4 Start training ! 103 | ------------------ 104 | 105 | With everything in place, you’re ready to train the model: 106 | 107 | .. code:: python 108 | 109 | trainer.fit(pl_module, datamodule = dm) 110 | 111 | 5 Task templates 112 | ------------------ 113 | 114 | To make your workflow easier, we’ve prepared a collection of ready-to-use scripts for different tasks. You can find detailed instructions and examples :doc:`here `. 
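Putting the pieces above together, a minimal end-to-end sketch might look like the following. The supervised files (``train.tsv``, ``valid.tsv``, ``test.tsv``) are placeholders that you should prepare yourself (for example with ``chrombert_make_dataset``), and the other values simply mirror the defaults shown in the sections above:

.. code:: python

    import chrombert

    # 1. Input: cached genome-wide features plus your supervised files
    dc = chrombert.DatasetConfig(
        kind="GeneralDataset",
        hdf5_file="~/.cache/chrombert/data/hg38_6k_1kb.hdf5",
        supervised_file="",  # overridden per split by the data module below
        batch_size=8,
        num_workers=4,
    )
    dm = chrombert.LitChromBERTFTDataModule(
        config=dc,
        train_params={"supervised_file": "train.tsv"},
        val_params={"supervised_file": "valid.tsv"},
        test_params={"supervised_file": "test.tsv"},
    )

    # 2. Model: initialize the general fine-tuning model from the pre-trained checkpoint
    mc = chrombert.ChromBERTFTConfig(
        task="general",
        pretrain_ckpt="~/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt",
    )
    model = mc.init_model()

    # 3. Training: a classification task with binary cross-entropy loss
    config_train = chrombert.finetune.TrainConfig(kind="classification", loss="bce", lr=1e-4)
    pl_module = config_train.init_pl_module(model)
    trainer = config_train.init_trainer()

    # 4. Fit
    trainer.fit(pl_module, datamodule=dm)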
115 | -------------------------------------------------------------------------------- /docs/source/scripts/chrombert_get_cistrome_emb.rst: -------------------------------------------------------------------------------- 1 | chrombert_get_cistrome_emb 2 | **************************** 3 | Extract cistrome embeddings from ChromBERT. 4 | 5 | .. code-block:: shell 6 | 7 | chrombert_get_cistrome_emb [OPTIONS] SUPERVISED_FILE IDS... -o ONAME 8 | 9 | .. rubric:: Options 10 | 11 | .. option:: SUPERVISED_FILE 12 | 13 | Path to the supervised file. 14 | 15 | .. option:: IDS 16 | 17 | IDs to extract. Can be in GSMID format or the `regulator:cellline` format. To generate a cache file for prompts, use the `regulator:cellline` format. 18 | 19 | .. option:: -o, --oname 20 | 21 | Path to the output HDF5 file. This option is required. 22 | 23 | .. option:: --basedir 24 | 25 | Base directory for the required files. Default is set to the value of `DEFAULT_BASEDIR`. 26 | 27 | .. option:: -g, --genome 28 | 29 | Genome version. For example, hg38 or mm10. Only hg38 is supported now. Default is *hg38*. 30 | 31 | .. option:: -k, --ckpt 32 | 33 | Path to the pretrain or **fine-tuned** checkpoint. Optional if it can be inferred from other arguments. 34 | 35 | .. option:: --meta 36 | 37 | Path to the meta file. Optional if it can be inferred from other arguments. 38 | 39 | .. option:: --mask 40 | 41 | Path to the matrix mask file. Optional if it can be inferred from other arguments. 42 | 43 | .. option:: -d, --hdf5-file 44 | 45 | Path to the HDF5 file that contains the dataset. Optional if it can be inferred from other arguments. 46 | 47 | .. option:: -hr, --high-resolution 48 | 49 | Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is preparing for the future release of ChromBERT, which is not available yet. 50 | 51 | .. option:: --batch-size 52 | 53 | Batch size. Default is *8*. 54 | 55 | .. option:: --num-workers 56 | 57 | Number of workers for the dataloader. Default is *8*. 58 | -------------------------------------------------------------------------------- /docs/source/scripts/chrombert_get_region_emb.rst: -------------------------------------------------------------------------------- 1 | chrombert_get_region_emb 2 | ************************** 3 | 4 | Extract mean pooled TRN embeddings (region embeddings) from ChromBERT. 5 | 6 | .. code-block:: shell 7 | 8 | chrombert_get_region_emb [OPTIONS] SUPERVISED_FILE -o ONAME 9 | 10 | .. rubric:: Options 11 | 12 | .. option:: SUPERVISED_FILE 13 | 14 | Path to the supervised file. 15 | 16 | .. option:: -o, --oname 17 | 18 | Path to the output HDF5 file. This option is required. 19 | 20 | .. option:: --basedir 21 | 22 | Base directory for the required files. Default is set to the value of `DEFAULT_BASEDIR`. 23 | 24 | .. option:: -g, --genome 25 | 26 | Genome version. For example, hg38 or mm10. Only hg38 is supported now. Default is *hg38*. 27 | 28 | .. option:: -k, --ckpt 29 | 30 | Path to the pretrain or **fine-tuned** checkpoint. Optional if it can be inferred from other arguments. 31 | 32 | .. option:: --mask 33 | 34 | Path to the matrix mask file. Optional if it can be inferred from other arguments. 35 | 36 | 37 | .. option:: -d, --hdf5-file 38 | 39 | Path to the HDF5 file that contains the dataset. Optional if it can be inferred from other arguments. 40 | 41 | .. option:: -hr, --high-resolution 42 | 43 | Use 200-bp resolution instead of 1-kb resolution. 
Caution: 200-bp resolution is preparing for the future release of ChromBERT, which is not available yet. 44 | 45 | .. option:: --gpu 46 | 47 | GPU index. Default is *0*. 48 | 49 | .. option:: --batch-size 50 | 51 | Batch size. Default is *8*. 52 | 53 | .. option:: --num-workers 54 | 55 | Number of workers for the dataloader. Default is *8*. -------------------------------------------------------------------------------- /docs/source/scripts/chrombert_get_regulator_emb.rst: -------------------------------------------------------------------------------- 1 | chrombert_get_regulator_emb 2 | ***************************** 3 | 4 | Extract regulator embeddings from ChromBERT. 5 | 6 | .. code-block:: shell 7 | 8 | chrombert_get_regulator_emb [OPTIONS] SUPERVISED_FILE IDS... -o ONAME 9 | 10 | .. rubric:: Options 11 | 12 | .. option:: SUPERVISED_FILE 13 | 14 | Path to the supervised file. 15 | 16 | .. option:: IDS 17 | 18 | Regulator names to extract. Must be in lower case. 19 | 20 | .. option:: -o, --oname 21 | 22 | Path to the output HDF5 file. This option is required. 23 | 24 | .. option:: --basedir 25 | 26 | Base directory for the required files. Default is set to the value of `DEFAULT_BASEDIR`. 27 | 28 | .. option:: -g, --genome 29 | 30 | Genome version. For example, hg38 or mm10. Only hg38 is supported now. Default is *hg38*. 31 | 32 | .. option:: -k, --ckpt 33 | 34 | Path to the pretrain or **fine-tuned** checkpoint. Optional if it can be inferred from other arguments. 35 | 36 | .. option:: --meta 37 | 38 | Path to the meta file. Optional if it can be inferred from other arguments. 39 | 40 | .. option:: --mask 41 | 42 | Path to the matrix mask file. Optional if it can be inferred from other arguments. 43 | 44 | .. option:: -d, --hdf5-file 45 | 46 | Path to the HDF5 file that contains the dataset. Optional if it can be inferred from other arguments. 47 | 48 | .. option:: -hr, --high-resolution 49 | 50 | Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is preparing for the future release of ChromBERT, which is not available yet. 51 | 52 | .. option:: --batch-size 53 | 54 | Batch size. Default is *8*. 55 | 56 | .. option:: --num-workers 57 | 58 | Number of workers for the dataloader. Default is *8*. 59 | -------------------------------------------------------------------------------- /docs/source/scripts/chrombert_imputation_cistrome.rst: -------------------------------------------------------------------------------- 1 | chrombert_imputation_cistrome 2 | ************************************ 3 | 4 | Generate prediction result (full bigwig file or table) from ChromBERT when given cell type name, region and regulator. 5 | 6 | .. note:: 7 | 8 | Either --o-bw or --o-table must be provided, depends on which format you want to output the results. 9 | 10 | .. code-block:: shell 11 | 12 | chrombert_imputation_cistrome [OPTIONS] SUPERVISED_FILE --o-bw BW_PATH --o-table TABLE_PATH --finetune-ckpt CKPT --prompt-kind KIND 13 | 14 | .. rubric:: Options 15 | 16 | .. option:: supervised_file 17 | 18 | Path to the supervised file. 19 | 20 | .. option:: --o-bw 21 | 22 | Path of the output BigWig file. 23 | 24 | .. option:: --o-table 25 | 26 | Path to the output table if you want to output the table. 27 | 28 | .. option:: --prompt-kind 29 | 30 | Prompt data class. Choose from *cistrome* or *expression*. This option is required. 31 | 32 | .. option:: --basedir 33 | 34 | Base directory for the required files. Default is set to the value of `DEFAULT_BASEDIR`. 35 | 36 | .. 
option:: -g, --genome 37 | 38 | Genome version. For example, *hg38* or *mm10*. Only *hg38* is supported now. Default is *hg38*. 39 | 40 | .. option:: --pretrain-ckpt 41 | 42 | Path to the pretrain checkpoint. Optional if it could be inferred from other arguments. 43 | 44 | .. option:: -d, --hdf5-file 45 | 46 | Path to the HDF5 file that contains the dataset. Optional if it could be inferred from other arguments. 47 | 48 | .. option:: -hr, --high-resolution 49 | 50 | Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is preparing for the future release of ChromBERT, which is not available yet. 51 | 52 | .. option:: --finetune-ckpt 53 | 54 | Path to the finetune checkpoint. Optional. 55 | 56 | .. option:: --prompt-dim-external 57 | 58 | Dimension of external data. Use *512* for *scGPT* and *768* for *ChromBERT*'s embedding. Default is *512*. 59 | 60 | .. option:: --prompt-celltype-cache-file 61 | 62 | Path to the cell-type-specific prompt cache file. Optional. 63 | 64 | .. option:: --prompt-regulator-cache-file 65 | 66 | Path to the regulator prompt cache file. Optional. 67 | 68 | .. option:: --prompt-celltype 69 | 70 | The cell-type-specific prompt. For example, *dnase:k562* for cistrome prompt and *k562* for expression prompt. It can also be provided in the supervised file if the format supports. Optional. 71 | 72 | .. option:: --prompt-regulator 73 | 74 | The regulator prompt. Determine the kind of output. For example, *ctcf* or *h3k27ac*. It can also be provided in the supervised file if the format supports. Optional. 75 | 76 | .. option:: --batch-size 77 | 78 | Batch size. Default is *8*. 79 | 80 | .. option:: --num-workers 81 | 82 | Number of workers for the dataloader. Default is *8*. 83 | -------------------------------------------------------------------------------- /docs/source/scripts/chrombert_imputation_cistrome_sc.rst: -------------------------------------------------------------------------------- 1 | chrombert_imputation_cistrome_sc 2 | ************************************ 3 | 4 | Generate prediction result (hdf5 format) from ChromBERT when given single cell, region and regulator. 5 | 6 | .. code-block:: shell 7 | 8 | chrombert_imputation_cistrome_sc [OPTIONS] SUPERVISED_FILE --o-h5 H5_PATH --finetune-ckpt CKPT --prompt-kind KIND 9 | 10 | .. rubric:: Options 11 | 12 | .. option:: supervised_file 13 | 14 | Path to the supervised file. 15 | 16 | .. option:: --o-h5 17 | 18 | Path of the output HDF5 file. This option is required. 19 | 20 | .. option:: --prompt-kind 21 | 22 | Prompt data class. Choose from *cistrome* or *expression*. This option is required. 23 | 24 | .. option:: --basedir 25 | 26 | Base directory for the required files. Default is set to the value of `DEFAULT_BASEDIR`. 27 | 28 | .. option:: -g, --genome 29 | 30 | Genome version. For example, *hg38* or *mm10*. Only *hg38* is supported now. Default is *hg38*. 31 | 32 | .. option:: --pretrain-ckpt 33 | 34 | Path to the pretrain checkpoint. Optional if it could be inferred from other arguments. 35 | 36 | .. option:: -d, --hdf5-file 37 | 38 | Path to the HDF5 file that contains the dataset. Optional if it could be inferred from other arguments. 39 | 40 | .. option:: -hr, --high-resolution 41 | 42 | Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is preparing for the future release of ChromBERT, which is not available yet. 43 | 44 | .. option:: --finetune-ckpt 45 | 46 | Path to the finetune checkpoint. Optional. 47 | 48 | .. 
option:: --prompt-dim-external 49 | 50 | Dimension of external data. Use *512* for *scGPT* and *768* for *ChromBERT*'s embedding. Default is *512*. 51 | 52 | .. option:: --prompt-celltype-cache-file 53 | 54 | Path to the cell-type-specific prompt cache file. Optional. 55 | 56 | .. option:: --prompt-regulator-cache-file 57 | 58 | Path to the regulator prompt cache file. Optional. 59 | 60 | .. option:: --prompt-regulator-cache-pin-memory 61 | Pin memory for the regulator prompt cache for further acceleration. Default is False. 62 | 63 | .. option:: --prompt-regulator-cache-limit 64 | The maximum number of regulator prompts cached in memory. Be mindful of your memory usage! 65 | 66 | .. option:: --prompt-celltype 67 | 68 | The cell-type-specific prompt. For example, *dnase:k562* for cistrome prompt and *k562* for expression prompt. It can also be provided in the supervised file if the format supports it. Optional. 69 | 70 | .. option:: --prompt-regulator 71 | 72 | The regulator prompt. Determines the kind of output. For example, *ctcf* or *h3k27ac*. It can also be provided in the supervised file if the format supports it. Optional. 73 | 74 | .. option:: --batch-size 75 | 76 | Batch size. Default is *8*. 77 | 78 | .. option:: --num-workers 79 | 80 | Number of workers for the dataloader. Default is *8*. 81 |
-------------------------------------------------------------------------------- /docs/source/scripts/chrombert_make_dataset.rst: --------------------------------------------------------------------------------
1 | chrombert_make_dataset 2 | ********************** 3 | 4 | Generate general datasets for ChromBERT from BED files. 5 | 6 | .. code-block:: shell 7 | 8 | chrombert_make_dataset [OPTIONS] BED 9 | 10 | .. rubric:: Options 11 | 12 | .. option:: BED 13 | 14 | Path to the BED file. 15 | 16 | .. option:: -o, --oname 17 | 18 | Path to the output file. Stdout if not specified. Must end with .tsv or .txt. 19 | 20 | .. option:: --mode 21 | 22 | Mode to generate the dataset. Choices are: 23 | 24 | - *region*: only consider overlap between input regions to determine the generated label. Useful for narrowPeak-like input. 25 | - *all*: report all overlap statuses, like ``bedtools intersect -wao``. You should choose the label column yourself. 26 | 27 | Default is *region*. 28 | 29 | .. option:: --center 30 | 31 | If used, only consider the center of the input regions. 32 | 33 | .. option:: --label 34 | 35 | If mode is not *region*, this column will be used as the label. Default is the 4th column (1-based). 36 | 37 | .. option:: --no-filter 38 | 39 | Do not filter out regions that have no overlap. 40 | 41 | .. option:: --basedir 42 | 43 | Base directory for the required files. Default is set to the value of `DEFAULT_BASEDIR`. 44 | 45 | .. option:: -g, --genome 46 | 47 | Genome version. For example, hg38 or mm10. Only hg38 is supported now. Default is *hg38*. 48 | 49 | .. option:: -hr, --high-resolution 50 | 51 | Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is preparing for the future release of ChromBERT, which is not available yet. 52 | 53 |
-------------------------------------------------------------------------------- /docs/source/scripts/chrombert_prepare_env.rst: --------------------------------------------------------------------------------
1 | chrombert_prepare_env 2 | ************************* 3 | 4 | 5 | Download required files to ~/.cache/chrombert/data, or another path of your choice. 6 | 7 | .. code-block:: shell 8 | 9 | chrombert_prepare_env [OPTIONS] 10 | 11 | ..
rubric:: Options 12 | 13 | .. option:: --help 14 | 15 | Show this message and exit. 16 | 17 | .. option:: --basedir 18 | 19 | The directory to store the data. Default is ~/.cache/chrombert/data. 20 | 21 | .. option:: --hf-endpoint 22 | 23 | The endpoint of the Hugging Face model. 24 | 25 | -------------------------------------------------------------------------------- /docs/source/scripts/ft_general.rst: -------------------------------------------------------------------------------- 1 | Cell-type-specific regulatory effects 2 | **************************************** 3 | 4 | This script enables fine-tuning ChromBERT for analyzing cell-type-specific regulatory effects. Users can selectively perturb or omit specific genomic features, making it valuable for simulating regulatory changes and testing hypotheses about the role of individual regulatory elements in cell-type-specific gene regulation. 5 | 6 | .. code-block:: shell 7 | 8 | python ft_general.py [OPTIONS] --train TRAIN_PATH --valid VALID_PATH --test TEST_PATH 9 | 10 | 11 | .. rubric:: Options 12 | 13 | .. option:: --lr 14 | 15 | Learning rate. Default is *1e-4*. 16 | 17 | .. option:: --warmup-ratio 18 | 19 | Warmup ratio. Default is *0.1*. 20 | 21 | .. option:: --grad-samples 22 | 23 | Number of gradient samples. Automatically scaled according to the batch size and GPU number. Default is *512*. 24 | 25 | .. option:: --max-epochs 26 | 27 | Number of epochs to train. Default is *10*. 28 | 29 | .. option:: --pretrain-trainable 30 | 31 | Number of pretrained layers to be trainable. Default is *2*. 32 | 33 | .. option:: --tag 34 | 35 | Tag of the trainer, used for grouping logged results. Default is *default*. 36 | 37 | .. option:: --limit-val-batches 38 | 39 | Number of batches to use for each validation. Default is *64*. 40 | 41 | .. option:: --val-check-interval 42 | 43 | Validation check interval. Default is *64*. 44 | 45 | .. option:: --name 46 | 47 | Name of the trainer. Default is *chrombert-ft-general*. 48 | 49 | .. option:: --save-top-k 50 | 51 | Save top k checkpoints. Default is *3*. 52 | 53 | .. option:: --checkpoint-metric 54 | 55 | Checkpoint metric. Default is the same as the loss function if not specified. 56 | 57 | .. option:: --checkpoint-mode 58 | 59 | Checkpoint mode. Default is *min*. 60 | 61 | .. option:: --log-every-n-steps 62 | 63 | Log every n steps. Default is *50*. 64 | 65 | .. option:: --kind 66 | 67 | Kind of the task. Choose from *classification*, *regression*, or *zero_inflation*. Default is *classification*. 68 | 69 | .. option:: --loss 70 | 71 | Loss function. Default is *focal*. 72 | 73 | .. option:: --train 74 | 75 | Path to the training data. This option is required. 76 | 77 | .. option:: --valid 78 | 79 | Path to the validation data. This option is required. 80 | 81 | .. option:: --test 82 | 83 | Path to the test data. This option is required. 84 | 85 | .. option:: --batch-size 86 | 87 | Batch size. Default is *8*. 88 | 89 | .. option:: --num-workers 90 | 91 | Number of workers. Default is *4*. 92 | 93 | .. option:: --basedir 94 | 95 | Path to the base directory. Default is set to the value of ``os.path.expanduser("~/.cache/chrombert/data")``. 96 | 97 | .. option:: -g, --genome 98 | 99 | Genome version. For example, *hg38* or *mm10*. Only *hg38* is supported now. Default is *hg38*. 100 | 101 | .. option:: -k, --ckpt 102 | 103 | Path to the pretrain checkpoint. Optional if it could be inferred from other arguments. 104 | 105 | .. option:: --mask 106 | 107 | Path to the mtx mask file. 
Optional if it could be inferred from other arguments. 108 | 109 | .. option:: -d, --hdf5-file 110 | 111 | Path to the HDF5 file that contains the dataset. Optional if it could be inferred from other arguments. 112 | 113 | .. option:: --dropout 114 | 115 | Dropout rate. Default is *0.1*. 116 | 117 | .. option:: -hr, --high-resolution 118 | 119 | Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is preparing for the future release of ChromBERT, which is not available yet. 120 | 121 | .. option:: --ignore 122 | 123 | Ignore given targets. 124 | 125 | .. option:: --ignore-object 126 | 127 | Ignore object. Regulator, or dataset IDs separated by *;*. 128 | 129 | .. option:: --perturbation 130 | 131 | Use perturbation model. 132 | 133 | .. option:: --perturbation-object 134 | 135 | Perturbation object. Regulator, or dataset IDs separated by *;*. 136 | 137 | .. option:: --perturbation-value 138 | 139 | Perturbation target level. *0* means knock-out perturbation, and *4* means over-expression perturbation. Default is *0*. 140 | 141 | -------------------------------------------------------------------------------- /docs/source/scripts/ft_gep.rst: -------------------------------------------------------------------------------- 1 | Gene expression prediction 2 | ************************** 3 | 4 | Gene expression is influenced by multiple regulatory regions, often extending over significant genomic distances, particularly near the transcription start site (TSS). This task uses a flank window to consider multiple nearby regions, providing a holistic view of regulatory impacts on gene expression. 5 | 6 | .. code-block:: shell 7 | 8 | python ft_gep.py [OPTIONS] --flank-window FLANK_WINDOW_SIZE \ 9 | --train TRAIN_PATH \ 10 | --valid VALID_PATH \ 11 | --test TEST_PATH 12 | 13 | .. rubric:: Options 14 | 15 | .. option:: --lr 16 | 17 | Learning rate. Default is *1e-4*. 18 | 19 | .. option:: --warmup-ratio 20 | 21 | Warmup ratio for the learning rate. Default is *0.1*. 22 | 23 | .. option:: --grad-samples 24 | 25 | Number of gradient samples, scaled by batch size and GPU count. Default is *128*. 26 | 27 | .. option:: --pretrain-trainable 28 | 29 | Number of pretrained layers to be trainable. Default is *2*. 30 | 31 | .. option:: --max-epochs 32 | 33 | Maximum number of training epochs. Default is *10*. 34 | 35 | .. option:: --tag 36 | 37 | Tag of the trainer, used for grouping logged results. Default is *default*. 38 | 39 | .. option:: --limit-val-batches 40 | 41 | Number of batches to use for each validation. Default is *64*. 42 | 43 | .. option:: --val-check-interval 44 | 45 | Interval for validation checks. Default is *64*. 46 | 47 | .. option:: --name 48 | 49 | Name of the training session. Default is *chrombert-ft-gep*. 50 | 51 | .. option:: --save-top-k 52 | 53 | Number of top-performing checkpoints to save. Default is *3*. 54 | 55 | .. option:: --checkpoint-metric 56 | 57 | Metric for checkpointing. Default is *pcc*. 58 | 59 | .. option:: --checkpoint-mode 60 | 61 | Mode for checkpointing. Default is *max*. 62 | 63 | .. option:: --log-every-n-steps 64 | 65 | Logging frequency in terms of steps. Default is *50*. 66 | 67 | .. option:: --kind 68 | 69 | Type of task, such as *regression*, *zero_inflation*. Default is *regression*. 70 | 71 | .. option:: --loss 72 | 73 | Loss function to be used. Default is *rmse*. 74 | 75 | .. option:: --train 76 | 77 | Path to the training data. This option is required. 78 | 79 | .. option:: --valid 80 | 81 | Path to the validation data. 
This option is required. 82 | 83 | .. option:: --test 84 | 85 | Path to the test data. This option is required. 86 | 87 | .. option:: --batch-size 88 | 89 | Batch size for training. Default is *2*. 90 | 91 | .. option:: --num-workers 92 | 93 | Number of workers for data loading. Default is *4*. 94 | 95 | .. option:: --basedir 96 | 97 | Path to the base directory for model and data files. Default is ``os.path.expanduser("~/.cache/chrombert/data")``. 98 | 99 | .. option:: -g, --genome 100 | 101 | Genome version. Only *hg38* is supported now. Default is *hg38*. 102 | 103 | .. option:: -k, --ckpt 104 | 105 | Path to the pretrained checkpoint. Optional if it could be inferred from other arguments. 106 | 107 | .. option:: --mask 108 | 109 | Path to the mtx mask file. Optional if it could be inferred from other arguments. 110 | 111 | .. option:: -d, --hdf5-file 112 | 113 | Path to the HDF5 file that contains the dataset. Optional if it could be inferred from other arguments. 114 | 115 | .. option:: --dropout 116 | 117 | Dropout rate for the model. Default is *0.1*. 118 | 119 | .. option:: -hr, --high-resolution 120 | 121 | Use 200-bp resolution instead of 1-kb. Note: 200-bp resolution is not available yet, preparing for future release. 122 | 123 | .. option:: --flank-window 124 | 125 | Flank window size for genomic data embedding. Default is *4*. 126 | 127 | .. option:: --gep-zero-inflation 128 | 129 | Specifies whether to include zero inflation in the GEP header. Default is *False*. 130 | 131 | .. option:: --gep-parallel-embedding 132 | 133 | Enable parallel embedding, which is faster but requires more GPU memory. 134 | 135 | .. option:: --gep-gradient-checkpoint 136 | 137 | Use gradient checkpointing to reduce GPU memory usage during training. 138 | -------------------------------------------------------------------------------- /docs/source/scripts/ft_prompt_enhanced.rst: -------------------------------------------------------------------------------- 1 | Prompt-enhanced 2 | ******************* 3 | 4 | This script allows you to fine-tune ChromBERT by adding extra information as prompts. You can include things like cell-type features or DNA sequence patterns to help the model make better predictions. The model uses these prompts as additional clues when analyzing genomic data. 5 | 6 | .. code-block:: shell 7 | 8 | python ft_prompt_enhanced.py [OPTIONS] --prompt-kind KIND \ 9 | --train TRAIN_PATH \ 10 | --valid VALID_PATH \ 11 | --test TEST_PATH 12 | 13 | # use cache file for acceleration 14 | python ft_prompt_enhanced.py [OPTIONS] \ 15 | --prompt-kind KIND \ 16 | --prompt-regulator-cache-file CACHE_PATH1 \ 17 | --prompt-celltype-cache-file CACHE_PATH2 \ 18 | --train TRAIN_PATH \ 19 | --valid VALID_PATH \ 20 | --test TEST_PATH 21 | 22 | 23 | .. rubric:: Options 24 | 25 | .. option:: --lr 26 | 27 | Learning rate. Default is *1e-4*. 28 | 29 | .. option:: --warmup-ratio 30 | 31 | Warmup ratio. Default is *0.1*. 32 | 33 | .. option:: --grad-samples 34 | 35 | Number of gradient samples. Automatically scaled according to the batch size and GPU number. Default is *512*. 36 | 37 | .. option:: --pretrain-trainable 38 | 39 | Number of pretrained layers to be trainable. Default is *0*. 40 | 41 | .. option:: --max-epochs 42 | 43 | Number of epochs to train. Default is *10*. 44 | 45 | .. option:: --tag 46 | 47 | Tag of the trainer, used for grouping logged results. Default is *default*. 48 | 49 | .. option:: --limit-val-batches 50 | 51 | Number of batches to use for each validation. Default is *64*. 
52 | 53 | .. option:: --val-check-interval 54 | 55 | Validation check interval. Default is *64*. 56 | 57 | .. option:: --name 58 | 59 | Name of the trainer. Default is *chrombert-ft-prompt-enhanced*. 60 | 61 | .. option:: --save-top-k 62 | 63 | Save top k checkpoints. Default is *3*. 64 | 65 | .. option:: --checkpoint-metric 66 | 67 | Checkpoint metric. Default is *bce*. 68 | 69 | .. option:: --checkpoint-mode 70 | 71 | Checkpoint mode. Default is *min*. 72 | 73 | .. option:: --log-every-n-steps 74 | 75 | Log every n steps. Default is *50*. 76 | 77 | .. option:: --kind 78 | 79 | Kind of the task. Choose from *classification*, *regression*, or *zero_inflation*. Default is *classification*. 80 | 81 | .. option:: --loss 82 | 83 | Loss function. Default is *focal*. 84 | 85 | .. option:: --train 86 | 87 | Path to the training data. This option is required. 88 | 89 | .. option:: --valid 90 | 91 | Path to the validation data. This option is required. 92 | 93 | .. option:: --test 94 | 95 | Path to the test data. This option is required. 96 | 97 | .. option:: --batch-size 98 | 99 | Batch size. Default is *8*. A larger value is recommended here to accelerate training. 100 | 101 | .. option:: --num-workers 102 | 103 | Number of workers. Default is *4*. 104 | 105 | .. option:: --basedir 106 | 107 | Path to the base directory. Default is set to the value of ``os.path.expanduser("~/.cache/chrombert/data")``. 108 | 109 | .. option:: -g, --genome 110 | 111 | Genome version. For example, *hg38* or *mm10*. Only *hg38* is supported now. Default is *hg38*. 112 | 113 | .. option:: -k, --ckpt 114 | 115 | Path to the checkpoints used to initialize the model. Optional. Default is the pretrain checkpoint provided in the base directory. 116 | 117 | .. option:: --mask 118 | Path to the mtx mask file. Optional if it could be inferred from other arguments. 119 | 120 | .. option:: -d, --hdf5-file 121 | 122 | Path to the HDF5 file that contains the dataset. Optional if it could be inferred from other arguments. 123 | 124 | .. option:: --dropout 125 | 126 | Dropout rate. Default is *0.1*. 127 | 128 | .. option:: -hr, --high-resolution 129 | 130 | Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is preparing for the future release of ChromBERT, which is not available yet. 131 | 132 | .. option:: --prompt-kind 133 | 134 | Prompt data class. Choose from *cistrome* or *expression*. Default is *None*. This option is required. 135 | 136 | .. option:: --prompt-dim-external 137 | 138 | Dimension of external data. Use *512* for *scGPT*, and *768* for *ChromBERT*'s embedding. Default is *512*. 139 | 140 | .. option:: --prompt-celltype-cache-file 141 | 142 | Path to the cell-type-specific prompt cache file. Provide it if you want to use a cache file to accelerate the training process. Optional. Not used by default. 143 | 144 | .. option:: --prompt-regulator-cache-file 145 | 146 | Path to the regulator prompt cache file. Provide it if you want to use a cache file to accelerate the training process. Optional. Not used by default.
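As a concrete illustration, a cistrome-prompt fine-tuning run might be launched as below. All file paths are placeholders for your own data, and the option values simply restate defaults or documented choices from the list above (*768* is the documented dimension for ChromBERT's own embedding):

.. code-block:: shell

    python ft_prompt_enhanced.py \
        --prompt-kind cistrome \
        --prompt-dim-external 768 \
        --kind classification \
        --loss focal \
        --batch-size 64 \
        --train train.tsv \
        --valid valid.tsv \
        --test test.tsv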
147 | -------------------------------------------------------------------------------- /docs/source/tutorial_embedding_extraction.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path":"../../examples/tutorials/tutorial_embedding_extraction.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorial_finetuning_ChromBERT.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path":"../../examples/tutorials/tutorial_finetuning_ChromBERT.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorial_locus_specific_TRN_eqtl.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path":"../../examples/tutorials/tutorial_prompt_eqtl.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorial_locus_specific_TRN_ezh2.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path":"../../examples/tutorials/tutorial_locus_specific_TRN_ezh2.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorial_locus_specific_TRN_starr.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path":"../../examples/tutorials/tutorial_locus_specific_TRN_starr.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorial_prompt_cistrome_imputation.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path":"../../examples/tutorials/tutorial_prompt_cistrome_imputation.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorial_transdifferentiation.rst: -------------------------------------------------------------------------------- 1 | Example for key regulators inference during cell state transition 2 | ==================================================================== 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 1 7 | 8 | tutorial_transdifferentiation_chromatin_accessibility 9 | tutorial_transdifferentiation_transcriptome 10 | -------------------------------------------------------------------------------- /docs/source/tutorial_transdifferentiation_chromatin_accessibility.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path":"../../examples/tutorials/tutorial_transdifferentiation_chromatin_accessibility.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorial_transdifferentiation_transcriptome.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path":"../../examples/tutorials/tutorial_transdifferentiation_transcriptome.ipynb" 3 | } -------------------------------------------------------------------------------- /examples/readme.md: -------------------------------------------------------------------------------- 1 | 2 | - The [tutorials](tutorials) directory provides examples of how to use ChromBERT for different tasks. 3 | 4 | - The [train](train) directory provides examples of how to train ChromBERT for different tasks. 
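For example, after preparing your datasets, a typical invocation of the general fine-tuning script from this directory looks like the following (file paths are placeholders; run `python train/ft_general.py --help` and see the documentation below for the full option list):

```shell
python train/ft_general.py --train train.tsv --valid valid.tsv --test test.tsv
```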
5 | 6 | **See the [documentation](https://chrombert.readthedocs.io/en/) for more information.** -------------------------------------------------------------------------------- /lumache.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lumache - Python library for cooks and food lovers. 3 | """ 4 | 5 | __version__ = "0.1.0" 6 | 7 | 8 | class InvalidKindError(Exception): 9 | """Raised if the kind is invalid.""" 10 | pass 11 | 12 | 13 | def get_random_ingredients(kind=None): 14 | """ 15 | Return a list of random ingredients as strings. 16 | 17 | :param kind: Optional "kind" of ingredients. 18 | :type kind: list[str] or None 19 | :raise lumache.InvalidKindError: If the kind is invalid. 20 | :return: The ingredients list. 21 | :rtype: list[str] 22 | """ 23 | return ["shells", "gorgonzola", "parsley"] 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >=61", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 6 | [project] 7 | name = "chrombert" 8 | version = "1.1.0" 9 | description = "ChromBERT: A pre-trained foundation model for context-specific transcription regulatory network" 10 | authors = [ 11 | {name = "Zhaowei Yu", email = "zhaoweiyu@tongji.edu.cn"}, 12 | {name = "Dongxu Yang", email= "dx_yang@tongji.edu.cn"}, 13 | {name = "Qianqian Chen", email= "2211083@tongji.edu.cn"}, 14 | {name = "Yuxuan Zhang", email= "2211289@tongji.edu.cn"} 15 | ] 16 | readme = "README.md" 17 | # license = "MIT" 18 | 19 | requires-python = ">=3.8" 20 | dependencies = [ 21 | "jupyter", 22 | "pyfaidx", 23 | "ninja", 24 | "packaging", 25 | "torch", 26 | "numpy", 27 | "pandas", 28 | "matplotlib", 29 | "torchinfo", 30 | "h5py", 31 | "lightning >= 2.0.0", 32 | "transformers == 4.28.1", 33 | "huggingface_hub[cli]", 34 | "pyarrow", 35 | "torchvision", 36 | "tensorboard", 37 | "scikit-learn", 38 | "pyBigWig" 39 | ] 40 | 41 | 42 | [project.scripts] 43 | chrombert_make_dataset = "chrombert.scripts.chrombert_make_dataset:main" 44 | chrombert_get_region_emb = "chrombert.scripts.chrombert_get_region_emb:main" 45 | chrombert_get_cistrome_emb = "chrombert.scripts.chrombert_get_cistrome_emb:main" 46 | chrombert_get_regulator_emb = "chrombert.scripts.chrombert_get_regulator_emb:main" 47 | chrombert_prepare_env = "chrombert.scripts.chrombert_prepare_env:main" 48 | chrombert_imputation_cistrome = "chrombert.scripts.chrombert_imputation:main" 49 | chrombert_imputation_cistrome_sc = "chrombert.scripts.chrombert_imputation_sc:main" 50 | 51 | [tool.setuptools] 52 | include-package-data = true 53 | 54 | [tool.setuptools.packages.find] 55 | where = ["."] # list of folders that contain the packages (["."] by default) 56 | include = ["chrombert*"] # package names should match these glob patterns (["*"] by default) 57 | 58 | 59 | [tool.setuptools.package-data] 60 | "*" = ["*.json", "*.md", "*.rst"] 61 | 62 | [tool.setuptools.dynamic] 63 | version = {attr = "chrombert.VERSION"} 64 | readme = {file = ["README.md"]} 65 | --------------------------------------------------------------------------------