├── .gitattributes ├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── data └── test-sample.pt ├── enformer.png ├── enformer_pytorch ├── __init__.py ├── config_enformer.py ├── data.py ├── finetune.py ├── metrics.py ├── modeling_enformer.py └── precomputed │ └── tf_gammas.pt ├── evaluate_enformer_pytorch_correlation.ipynb ├── scripts └── tf_to_torch.py ├── setup.py └── test_pretrained.py /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto 2 | *.ipynb filter=nbstripout 3 | *.ipynb diff=ipynb 4 | *.ipynb linguist-language=Python 5 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include enformer_pytorch/precomputed/tf_gammas.pt 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Enformer - Pytorch 4 | 5 | Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch. This repository also contains the means to fine tune pretrained models for your downstream tasks. The original tensorflow sonnet code can be found here. 6 | 7 | Update: finetuned for predicting pseudobulk chromatin accessibility here 8 | 9 | ## Install 10 | 11 | ```bash 12 | $ pip install enformer-pytorch 13 | ``` 14 | 15 | ## Usage 16 | 17 | ```python 18 | import torch 19 | from enformer_pytorch import Enformer 20 | 21 | model = Enformer.from_hparams( 22 | dim = 1536, 23 | depth = 11, 24 | heads = 8, 25 | output_heads = dict(human = 5313, mouse = 1643), 26 | target_length = 896, 27 | ) 28 | 29 | seq = torch.randint(0, 5, (1, 196_608)) # for ACGTN, in that order (-1 for padding) 30 | output = model(seq) 31 | 32 | output['human'] # (1, 896, 5313) 33 | output['mouse'] # (1, 896, 1643) 34 | ``` 35 | 36 | You can also directly pass in the sequence as one-hot encodings, which must be float values 37 | 38 | ```python 39 | import torch 40 | from enformer_pytorch import Enformer, seq_indices_to_one_hot 41 | 42 | model = Enformer.from_hparams( 43 | dim = 1536, 44 | depth = 11, 45 | heads = 8, 46 | output_heads = dict(human = 5313, mouse = 1643), 47 | target_length = 896, 48 | ) 49 | 50 | seq = torch.randint(0, 5, (1, 196_608)) 51 | one_hot = seq_indices_to_one_hot(seq) 52 | 53 | output = model(one_hot) 54 | 55 | output['human'] # (1, 896, 5313) 56 | output['mouse'] # (1, 896, 1643) 57 | ``` 58 | 59 | Finally, one can fetch the embeddings, for fine-tuning and otherwise, by setting the `return_embeddings` flag to be `True` on forward 60 | 61 | ```python 62 | import torch 63 | from enformer_pytorch import Enformer, seq_indices_to_one_hot 64 | 65 | model = Enformer.from_hparams( 66 | dim = 1536, 67 | depth = 11, 68 | heads = 8, 69 | output_heads = dict(human = 5313, mouse = 1643), 70 | target_length = 896, 71 | ) 72 | 73 | seq = torch.randint(0, 5, (1, 196_608)) 74 | one_hot = seq_indices_to_one_hot(seq) 75 | 76 | output, embeddings = model(one_hot, return_embeddings = True) 77 | 78 | embeddings # (1, 896, 3072) 79 | ``` 80 | 81 | For training, you can directly pass the head and target in to get the poisson loss 82 | 83 | ```python 84 | import torch 85 | from enformer_pytorch import Enformer, seq_indices_to_one_hot 86 | 87 | model = Enformer.from_hparams( 88 | dim = 1536, 89 | depth = 11, 90 | heads = 8, 91 | output_heads = dict(human = 5313, mouse = 1643), 92 | target_length = 200, 93 | ).cuda() 94 | 95 | seq = torch.randint(0, 5, (196_608 // 2,)).cuda() 96 | target = torch.randn(200, 5313).cuda() 97 | 98 | loss = model( 99 | seq, 100 | head = 'human', 101 | target = target 102 | ) 103 | 104 | loss.backward() 105 | 106 | # after much training 107 | 108 | corr_coef = model( 109 | seq, 110 | head = 'human', 111 | target = target, 112 | return_corr_coef = True 113 | ) 114 | 115 | corr_coef # pearson R, used as a metric in the paper 116 | ``` 117 | 118 | ## Pretrained Model 119 | 120 | Deepmind has released the weights for their tensorflow sonnet Enformer 
model! I have ported it over to Pytorch and uploaded it to 🤗 Huggingface (~1GB). There are still some rounding errors that seem to be accruing across the layers, resulting in an absolute error as high as `0.5`. However, the correlation coefficients look good, so I am releasing the 'rough'ly working version. I will keep working on figuring out where the numerical errors are happening (it may be the attention pooling module, as I noticed the attention logits are pretty high). 121 | 122 | Update: John St. John did some work and found that the `enformer-official-rough` model hits the reported marks in the paper - human Pearson R of `0.625` for validation, and `0.65` for test. 123 | 124 | Update: As of version 0.8.0, if one were to use the `from_pretrained` function to load the pretrained model, it should automatically use precomputed gamma positions to address a difference between tensorflow and pytorch `xlogy`. This should resolve the numerical discrepancy above. If you further finetune and are not using the `from_pretrained` function, please make sure to set `use_tf_gamma = True` when using `.from_hparams` to instantiate the `Enformer`. 125 | 126 | ```bash 127 | $ pip install enformer-pytorch>=0.5 128 | ``` 129 | 130 | Loading the model 131 | 132 | ```python 133 | from enformer_pytorch import from_pretrained 134 | 135 | enformer = from_pretrained('EleutherAI/enformer-official-rough') 136 | ``` 137 | 138 | Quick sanity check on a single human validation point 139 | 140 | ```bash 141 | $ python test_pretrained.py 142 | # 0.5963 correlation coefficient on a validation sample 143 | ``` 144 | 145 | This is all made possible thanks to HuggingFace's [custom model](https://huggingface.co/docs/transformers/master/en/custom_models) feature. 146 | 147 | You can also override the `target_length` parameter on load if you are working with shorter sequence lengths 148 | 149 | ```python 150 | from enformer_pytorch import from_pretrained 151 | 152 | model = from_pretrained('EleutherAI/enformer-official-rough', target_length = 128, dropout_rate = 0.1) 153 | 154 | # do your fine-tuning 155 | ``` 156 | 157 | To save on memory when fine-tuning a large Enformer model 158 | 159 | ```python 160 | from enformer_pytorch import from_pretrained 161 | 162 | enformer = from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True) 163 | 164 | # finetune enformer on a limited budget 165 | ``` 166 | 167 | ## Fine-tuning 168 | 169 | This repository will also allow for easy fine-tuning of Enformer.
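All of the fine-tuning wrappers shown below also accept flags on the forward call for controlling how much of the underlying Enformer is trained - `freeze_enformer`, `finetune_enformer_ln_only`, and `finetune_last_n_layers_only`, as defined in `enformer_pytorch/finetune.py` (batchnorms are kept frozen in every case). A minimal sketch using the `HeadAdapterWrapper` described next - the sequence length, track count, and target shape here are illustrative only:

```python
import torch
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import HeadAdapterWrapper

enformer = from_pretrained('EleutherAI/enformer-official-rough')

model = HeadAdapterWrapper(
    enformer = enformer,
    num_tracks = 128
).cuda()

seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()
target = torch.randn(1, 200, 128).cuda()

# train only the layernorms of the underlying Enformer

loss = model(seq, target = target, finetune_enformer_ln_only = True)
loss.backward()

# or keep the Enformer trunk frozen entirely and train just the new head

loss = model(seq, target = target, freeze_enformer = True)
loss.backward()

# or unfreeze only the last 2 transformer layers

loss = model(seq, target = target, finetune_last_n_layers_only = 2)
loss.backward()
```

Passing `target` returns the poisson loss directly, the same as in the examples that follow.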
170 | 171 | Fine-tuning on new tracks 172 | 173 | ```python 174 | import torch 175 | from enformer_pytorch import from_pretrained 176 | from enformer_pytorch.finetune import HeadAdapterWrapper 177 | 178 | enformer = from_pretrained('EleutherAI/enformer-official-rough') 179 | 180 | model = HeadAdapterWrapper( 181 | enformer = enformer, 182 | num_tracks = 128, 183 | post_transformer_embed = False # by default, embeddings are taken from after the final pointwise block w/ conv -> gelu - but if you'd like the embeddings right after the transformer block with a learned layernorm, set this to True 184 | ).cuda() 185 | 186 | seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda() 187 | target = torch.randn(1, 200, 128).cuda() # 128 tracks 188 | 189 | loss = model(seq, target = target) 190 | loss.backward() 191 | ``` 192 | 193 | Finetuning on contextual data (cell type, transcription factor, etc.) 194 | 195 | ```python 196 | import torch 197 | from enformer_pytorch import from_pretrained 198 | from enformer_pytorch.finetune import ContextAdapterWrapper 199 | 200 | enformer = from_pretrained('EleutherAI/enformer-official-rough') 201 | 202 | model = ContextAdapterWrapper( 203 | enformer = enformer, 204 | context_dim = 1024 205 | ).cuda() 206 | 207 | seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda() 208 | 209 | target = torch.randn(1, 200, 4).cuda() # 4 tracks 210 | context = torch.randn(4, 1024).cuda() # 4 contexts for the different 'tracks' 211 | 212 | loss = model( 213 | seq, 214 | context = context, 215 | target = target 216 | ) 217 | 218 | loss.backward() 219 | ``` 220 | 221 | Finally, there is also a way to use attention aggregation from a set of context embeddings (or a single context embedding). Simply use the `ContextAttentionAdapterWrapper`. 222 | 223 | ```python 224 | import torch 225 | from enformer_pytorch import from_pretrained 226 | from enformer_pytorch.finetune import ContextAttentionAdapterWrapper 227 | 228 | enformer = from_pretrained('EleutherAI/enformer-official-rough') 229 | 230 | model = ContextAttentionAdapterWrapper( 231 | enformer = enformer, 232 | context_dim = 1024, 233 | heads = 8, # number of heads in the cross attention 234 | dim_head = 64 # dimension per head 235 | ).cuda() 236 | 237 | seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda() 238 | 239 | target = torch.randn(1, 200, 4).cuda() # 4 tracks 240 | context = torch.randn(4, 16, 1024).cuda() # 4 contexts for the different 'tracks', each with 16 tokens 241 | 242 | context_mask = torch.ones(4, 16).bool().cuda() # optional context mask; in this example, all context tokens are included 243 | 244 | loss = model( 245 | seq, 246 | context = context, 247 | context_mask = context_mask, 248 | target = target 249 | ) 250 | 251 | loss.backward() 252 | ``` 253 | 254 | ## Data 255 | 256 | You can use the `GenomeIntervalDataset` to easily fetch sequences of any length from a `.bed` file, with a greater context length dynamically computed if specified 257 | 258 | ```python 259 | import torch 260 | import polars as pl 261 | from enformer_pytorch import Enformer, GenomeIntervalDataset 262 | 263 | filter_train = lambda df: df.filter(pl.col('column_4') == 'train') 264 | 265 | ds = GenomeIntervalDataset( 266 | bed_file = './sequences.bed', # bed file - columns 0, 1, 2 must be <chromosome>, <start>, <end> 267 | fasta_file = './hg38.ml.fa', # path to fasta file 268 | filter_df_fn = filter_train, # filter dataframe function 269 | return_seq_indices = True, # return nucleotide indices (ACGTN) or one hot encodings 270 | shift_augs = (-2, 2), # random shift augmentations from -2 to +2 basepairs 271 | context_length = 196_608, 272 | # this can be longer than the interval designated in the .bed file, 273 | # in which case it will take care of lengthening the interval on either side 274 | # as well as proper padding if at the end of the chromosomes 275 | chr_bed_to_fasta_map = { 276 | 'chr1': 'chromosome1', # if the chromosome name in the .bed file is different than the key name in the fasta file, you can rename them on the fly 277 | 'chr2': 'chromosome2', 278 | 'chr3': 'chromosome3', 279 | # etc etc 280 | } 281 | ) 282 | 283 | model = Enformer.from_hparams( 284 | dim = 1536, 285 | depth = 11, 286 | heads = 8, 287 | output_heads = dict(human = 5313, mouse = 1643), 288 | target_length = 896, 289 | ) 290 | 291 | seq = ds[0] # (196608,) 292 | pred = model(seq, head = 'human') # (896, 5313) 293 | ``` 294 | 295 | To return the random shift value, as well as whether the reverse complement was activated (in case you need to reverse the corresponding ChIP-seq target data), just set `return_augs = True` when initializing the `GenomeIntervalDataset` 296 | 297 | ```python 298 | import torch 299 | import polars as pl 300 | from enformer_pytorch import Enformer, GenomeIntervalDataset 301 | 302 | filter_train = lambda df: df.filter(pl.col('column_4') == 'train') 303 | 304 | ds = GenomeIntervalDataset( 305 | bed_file = './sequences.bed', # bed file - columns 0, 1, 2 must be <chromosome>, <start>, <end> 306 | fasta_file = './hg38.ml.fa', # path to fasta file 307 | filter_df_fn = filter_train, # filter dataframe function 308 | return_seq_indices = True, # return nucleotide indices (ACGTN) or one hot encodings 309 | shift_augs = (-2, 2), # random shift augmentations from -2 to +2 basepairs 310 | rc_aug = True, # use reverse complement augmentation with 50% probability 311 | context_length = 196_608, 312 | return_augs = True # return the augmentation meta data 313 | ) 314 | 315 | seq, rand_shift_val, rc_bool = ds[0] # (196608,), (1,), (1,) 316 | ``` 317 | 318 | ## Appreciation 319 | 320 | Special thanks goes out to EleutherAI for providing the resources to retrain the model, during a time when the official model from Deepmind had not been released yet. 321 | 322 | Thanks also go out to @johahi for finding the numerical differences between the torch and tensorflow implementations of `xlogy`. He provided a fix for this difference, which is adopted in this repository in `v0.8.0`. 323 | 324 | ## Todo 325 | 326 | - [x] script to load weights from trained tensorflow enformer model to pytorch model 327 | - [x] add loss wrapper with poisson loss 328 | - [x] move the metrics code over to pytorch as well 329 | - [x] train enformer model 330 | - [x] build context manager for fine-tuning with unfrozen enformer but with frozen batchnorm 331 | - [x] allow for plain fine-tune with fixed static context 332 | - [x] allow for fine tuning with only unfrozen layernorms (technique from fine tuning transformers) 333 | - [x] fix handling of 'N' in sequence, figure out representation of N in basenji barnyard 334 | - [x] take care of shift augmentation in `GenomeIntervalDataset` 335 | - [x] speed up `str_to_seq_indices` 336 | - [x] add to EleutherAI huggingface (done thanks to Niels) 337 | - [ ] offer some basic training utils, as gradient accumulation will be needed for fine tuning 338 | 339 | ## Citations 340 | 341 | ```bibtex 342 | @article {Avsec2021.04.07.438649, 343 | author = {Avsec, {\v Z}iga and Agarwal, Vikram and Visentin, Daniel and Ledsam, Joseph R. and Grabska-Barwinska, Agnieszka and Taylor, Kyle R.
and Assael, Yannis and Jumper, John and Kohli, Pushmeet and Kelley, David R.}, 344 | title = {Effective gene expression prediction from sequence by integrating long-range interactions}, 345 | elocation-id = {2021.04.07.438649}, 346 | year = {2021}, 347 | doi = {10.1101/2021.04.07.438649}, 348 | publisher = {Cold Spring Harbor Laboratory}, 349 | URL = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649}, 350 | eprint = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649.full.pdf}, 351 | journal = {bioRxiv} 352 | } 353 | ``` 354 | 355 | ```bibtex 356 | @misc{liu2022convnet, 357 | title = {A ConvNet for the 2020s}, 358 | author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie}, 359 | year = {2022}, 360 | eprint = {2201.03545}, 361 | archivePrefix = {arXiv}, 362 | primaryClass = {cs.CV} 363 | } 364 | ``` 365 | -------------------------------------------------------------------------------- /data/test-sample.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/enformer-pytorch/5a5974d2821c728f93294731c50b55f1f55fd86d/data/test-sample.pt -------------------------------------------------------------------------------- /enformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/enformer-pytorch/5a5974d2821c728f93294731c50b55f1f55fd86d/enformer.png -------------------------------------------------------------------------------- /enformer_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from enformer_pytorch.config_enformer import EnformerConfig 2 | from enformer_pytorch.modeling_enformer import Enformer, from_pretrained, SEQUENCE_LENGTH, AttentionPool 3 | from enformer_pytorch.data import seq_indices_to_one_hot, str_to_one_hot, GenomeIntervalDataset, FastaInterval -------------------------------------------------------------------------------- /enformer_pytorch/config_enformer.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | class EnformerConfig(PretrainedConfig): 4 | model_type = "enformer" 5 | 6 | def __init__( 7 | self, 8 | dim = 1536, 9 | depth = 11, 10 | heads = 8, 11 | output_heads = dict(human = 5313, mouse= 1643), 12 | target_length = 896, 13 | attn_dim_key = 64, 14 | dropout_rate = 0.4, 15 | attn_dropout = 0.05, 16 | pos_dropout = 0.01, 17 | use_checkpointing = False, 18 | use_convnext = False, 19 | num_downsamples = 7, # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution 20 | dim_divisible_by = 128, 21 | use_tf_gamma = False, 22 | **kwargs, 23 | ): 24 | self.dim = dim 25 | self.depth = depth 26 | self.heads = heads 27 | self.output_heads = output_heads 28 | self.target_length = target_length 29 | self.attn_dim_key = attn_dim_key 30 | self.dropout_rate = dropout_rate 31 | self.attn_dropout = attn_dropout 32 | self.pos_dropout = pos_dropout 33 | self.use_checkpointing = use_checkpointing 34 | self.num_downsamples = num_downsamples 35 | self.dim_divisible_by = dim_divisible_by 36 | self.use_tf_gamma = use_tf_gamma 37 | 38 | super().__init__(**kwargs) -------------------------------------------------------------------------------- /enformer_pytorch/data.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.nn.functional as F 3 | from torch.utils.data import Dataset 4 | 5 | import polars as pl 6 | import numpy as np 7 | from random import randrange, random 8 | from pathlib import Path 9 | from pyfaidx import Fasta 10 | 11 | # helper functions 12 | 13 | def exists(val): 14 | return val is not None 15 | 16 | def identity(t): 17 | return t 18 | 19 | def cast_list(t): 20 | return t if isinstance(t, list) else [t] 21 | 22 | def coin_flip(): 23 | return random() > 0.5 24 | 25 | # genomic function transforms 26 | 27 | seq_indices_embed = torch.zeros(256).long() 28 | seq_indices_embed[ord('a')] = 0 29 | seq_indices_embed[ord('c')] = 1 30 | seq_indices_embed[ord('g')] = 2 31 | seq_indices_embed[ord('t')] = 3 32 | seq_indices_embed[ord('n')] = 4 33 | seq_indices_embed[ord('A')] = 0 34 | seq_indices_embed[ord('C')] = 1 35 | seq_indices_embed[ord('G')] = 2 36 | seq_indices_embed[ord('T')] = 3 37 | seq_indices_embed[ord('N')] = 4 38 | seq_indices_embed[ord('.')] = -1 39 | 40 | one_hot_embed = torch.zeros(256, 4) 41 | one_hot_embed[ord('a')] = torch.Tensor([1., 0., 0., 0.]) 42 | one_hot_embed[ord('c')] = torch.Tensor([0., 1., 0., 0.]) 43 | one_hot_embed[ord('g')] = torch.Tensor([0., 0., 1., 0.]) 44 | one_hot_embed[ord('t')] = torch.Tensor([0., 0., 0., 1.]) 45 | one_hot_embed[ord('n')] = torch.Tensor([0., 0., 0., 0.]) 46 | one_hot_embed[ord('A')] = torch.Tensor([1., 0., 0., 0.]) 47 | one_hot_embed[ord('C')] = torch.Tensor([0., 1., 0., 0.]) 48 | one_hot_embed[ord('G')] = torch.Tensor([0., 0., 1., 0.]) 49 | one_hot_embed[ord('T')] = torch.Tensor([0., 0., 0., 1.]) 50 | one_hot_embed[ord('N')] = torch.Tensor([0., 0., 0., 0.]) 51 | one_hot_embed[ord('.')] = torch.Tensor([0.25, 0.25, 0.25, 0.25]) 52 | 53 | reverse_complement_map = torch.Tensor([3, 2, 1, 0, 4]).long() 54 | 55 | def torch_fromstring(seq_strs): 56 | batched = not isinstance(seq_strs, str) 57 | seq_strs = cast_list(seq_strs) 58 | np_seq_chrs = list(map(lambda t: np.fromstring(t, dtype = np.uint8), seq_strs)) 59 | seq_chrs = list(map(torch.from_numpy, np_seq_chrs)) 60 | return torch.stack(seq_chrs) if batched else seq_chrs[0] 61 | 62 | def str_to_seq_indices(seq_strs): 63 | seq_chrs = torch_fromstring(seq_strs) 64 | return seq_indices_embed[seq_chrs.long()] 65 | 66 | def str_to_one_hot(seq_strs): 67 | seq_chrs = torch_fromstring(seq_strs) 68 | return one_hot_embed[seq_chrs.long()] 69 | 70 | def seq_indices_to_one_hot(t, padding = -1): 71 | is_padding = t == padding 72 | t = t.clamp(min = 0) 73 | one_hot = F.one_hot(t, num_classes = 5) 74 | out = one_hot[..., :4].float() 75 | out = out.masked_fill(is_padding[..., None], 0.25) 76 | return out 77 | 78 | # augmentations 79 | 80 | def seq_indices_reverse_complement(seq_indices): 81 | complement = reverse_complement_map[seq_indices.long()] 82 | return torch.flip(complement, dims = (-1,)) 83 | 84 | def one_hot_reverse_complement(one_hot): 85 | *_, n, d = one_hot.shape 86 | assert d == 4, 'must be one hot encoding with last dimension equal to 4' 87 | return torch.flip(one_hot, (-1, -2)) 88 | 89 | # processing bed files 90 | 91 | class FastaInterval(): 92 | def __init__( 93 | self, 94 | *, 95 | fasta_file, 96 | context_length = None, 97 | return_seq_indices = False, 98 | shift_augs = None, 99 | rc_aug = False 100 | ): 101 | fasta_file = Path(fasta_file) 102 | assert fasta_file.exists(), 'path to fasta file must exist' 103 | 104 | self.seqs = Fasta(str(fasta_file)) 105 | self.return_seq_indices = return_seq_indices 106 | self.context_length = context_length 107 | self.shift_augs = 
shift_augs 108 | self.rc_aug = rc_aug 109 | 110 | def __call__(self, chr_name, start, end, return_augs = False): 111 | interval_length = end - start 112 | chromosome = self.seqs[chr_name] 113 | chromosome_length = len(chromosome) 114 | 115 | if exists(self.shift_augs): 116 | min_shift, max_shift = self.shift_augs 117 | max_shift += 1 118 | 119 | min_shift = min(max(start + min_shift, 0) - start, 0) 120 | max_shift = max(min(end + max_shift, chromosome_length) - end, 1) 121 | 122 | rand_shift = randrange(min_shift, max_shift) 123 | start += rand_shift 124 | end += rand_shift 125 | 126 | left_padding = right_padding = 0 127 | 128 | if exists(self.context_length) and interval_length < self.context_length: 129 | extra_seq = self.context_length - interval_length 130 | 131 | extra_left_seq = extra_seq // 2 132 | extra_right_seq = extra_seq - extra_left_seq 133 | 134 | start -= extra_left_seq 135 | end += extra_right_seq 136 | 137 | if start < 0: 138 | left_padding = -start 139 | start = 0 140 | 141 | if end > chromosome_length: 142 | right_padding = end - chromosome_length 143 | end = chromosome_length 144 | 145 | seq = ('.' * left_padding) + str(chromosome[start:end]) + ('.' * right_padding) 146 | 147 | should_rc_aug = self.rc_aug and coin_flip() 148 | 149 | if self.return_seq_indices: 150 | seq = str_to_seq_indices(seq) 151 | 152 | if should_rc_aug: 153 | seq = seq_indices_reverse_complement(seq) 154 | 155 | return seq 156 | 157 | one_hot = str_to_one_hot(seq) 158 | 159 | if should_rc_aug: 160 | one_hot = one_hot_reverse_complement(one_hot) 161 | 162 | if not return_augs: 163 | return one_hot 164 | 165 | # returns the shift integer as well as the bool (for whether reverse complement was activated) 166 | # for this particular genomic sequence 167 | 168 | rand_shift_tensor = torch.tensor([rand_shift]) 169 | rand_aug_bool_tensor = torch.tensor([should_rc_aug]) 170 | 171 | return one_hot, rand_shift_tensor, rand_aug_bool_tensor 172 | 173 | 174 | class GenomeIntervalDataset(Dataset): 175 | def __init__( 176 | self, 177 | bed_file, 178 | fasta_file, 179 | filter_df_fn = identity, 180 | chr_bed_to_fasta_map = dict(), 181 | context_length = None, 182 | return_seq_indices = False, 183 | shift_augs = None, 184 | rc_aug = False, 185 | return_augs = False 186 | ): 187 | super().__init__() 188 | bed_path = Path(bed_file) 189 | assert bed_path.exists(), 'path to .bed file must exist' 190 | 191 | df = pl.read_csv(str(bed_path), separator = '\t', has_header = False) 192 | df = filter_df_fn(df) 193 | self.df = df 194 | 195 | # if the chromosome name in the bed file is different than the keyname in the fasta 196 | # can remap on the fly 197 | self.chr_bed_to_fasta_map = chr_bed_to_fasta_map 198 | 199 | self.fasta = FastaInterval( 200 | fasta_file = fasta_file, 201 | context_length = context_length, 202 | return_seq_indices = return_seq_indices, 203 | shift_augs = shift_augs, 204 | rc_aug = rc_aug 205 | ) 206 | 207 | self.return_augs = return_augs 208 | 209 | def __len__(self): 210 | return len(self.df) 211 | 212 | def __getitem__(self, ind): 213 | interval = self.df.row(ind) 214 | chr_name, start, end = (interval[0], interval[1], interval[2]) 215 | chr_name = self.chr_bed_to_fasta_map.get(chr_name, chr_name) 216 | return self.fasta(chr_name, start, end, return_augs = self.return_augs) 217 | -------------------------------------------------------------------------------- /enformer_pytorch/finetune.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
typing import Optional 3 | 4 | from copy import deepcopy 5 | from contextlib import contextmanager 6 | import torch.nn.functional as F 7 | from torch import nn, einsum 8 | 9 | from einops import rearrange, repeat 10 | from einops.layers.torch import Rearrange 11 | from enformer_pytorch.modeling_enformer import Enformer, poisson_loss 12 | 13 | from discrete_key_value_bottleneck_pytorch import DiscreteKeyValueBottleneck 14 | 15 | def exists(val): 16 | return val is not None 17 | 18 | def default(val, d): 19 | return val if exists(val) else d 20 | 21 | @contextmanager 22 | def null_context(): 23 | yield 24 | 25 | # better sequential 26 | 27 | def Sequential(*modules): 28 | return nn.Sequential(*filter(exists, modules)) 29 | 30 | # controlling freezing of layers 31 | 32 | def set_module_requires_grad_(module, requires_grad): 33 | for param in module.parameters(): 34 | param.requires_grad = requires_grad 35 | 36 | def freeze_all_layers_(module): 37 | set_module_requires_grad_(module, False) 38 | 39 | def unfreeze_all_layers_(module): 40 | set_module_requires_grad_(module, True) 41 | 42 | def freeze_batchnorms_(model): 43 | bns = [m for m in model.modules() if isinstance(m, nn.BatchNorm1d)] 44 | 45 | for bn in bns: 46 | bn.eval() 47 | bn.track_running_stats = False 48 | set_module_requires_grad_(bn, False) 49 | 50 | def freeze_all_but_layernorms_(model): 51 | for m in model.modules(): 52 | set_module_requires_grad_(m, isinstance(m, nn.LayerNorm)) 53 | 54 | def freeze_all_but_last_n_layers_(enformer, n): 55 | assert isinstance(enformer, Enformer) 56 | freeze_all_layers_(enformer) 57 | 58 | transformer_blocks = enformer.transformer 59 | 60 | for module in transformer_blocks[-n:]: 61 | set_module_requires_grad_(module, True) 62 | 63 | # get enformer embeddings 64 | 65 | def get_enformer_embeddings( 66 | model, 67 | seq, 68 | freeze = False, 69 | train_layernorms_only = False, 70 | train_last_n_layers_only = None, 71 | enformer_kwargs: dict = {} 72 | ): 73 | freeze_batchnorms_(model) 74 | 75 | if train_layernorms_only: 76 | assert not freeze, 'you set the intent to train the layernorms of the enformer, yet also indicated you wanted to freeze the entire model' 77 | freeze_all_but_layernorms_(model) 78 | 79 | if exists(train_last_n_layers_only): 80 | assert not freeze, 'you set the intent to train last N layers of enformer, but also indicated you wanted to freeze the entire network' 81 | freeze_all_but_last_n_layers_(model, train_last_n_layers_only) 82 | 83 | enformer_context = null_context() if not freeze else torch.no_grad() 84 | 85 | with enformer_context: 86 | embeddings = model(seq, return_only_embeddings = True, **enformer_kwargs) 87 | 88 | if freeze: 89 | embeddings.detach_() 90 | 91 | return embeddings 92 | 93 | # fine-tune wrapper classes 94 | 95 | # extra head projection, akin to how human and mouse tracks were trained 96 | 97 | class HeadAdapterWrapper(nn.Module): 98 | def __init__( 99 | self, 100 | *, 101 | enformer, 102 | num_tracks, 103 | post_transformer_embed = False, # whether to take the embeddings from right after the transformer, instead of after the final pointwise convolutional - this would add another layernorm 104 | discrete_key_value_bottleneck = False, 105 | bottleneck_num_memories = 256, 106 | bottleneck_num_codebooks = 4, 107 | bottleneck_decay = 0.9, 108 | transformer_embed_fn: nn.Module = nn.Identity(), 109 | output_activation: Optional[nn.Module] = nn.Softplus(), 110 | auto_set_target_length = True 111 | ): 112 | super().__init__() 113 | assert isinstance(enformer, 
Enformer) 114 | enformer_hidden_dim = enformer.dim * (2 if not post_transformer_embed else 1) 115 | 116 | self.discrete_key_value_bottleneck = discrete_key_value_bottleneck 117 | 118 | if discrete_key_value_bottleneck: 119 | enformer = DiscreteKeyValueBottleneck( 120 | encoder = enformer, 121 | dim = enformer_hidden_dim, 122 | num_memory_codebooks = bottleneck_num_codebooks, 123 | num_memories = bottleneck_num_memories, 124 | dim_memory = enformer_hidden_dim // bottleneck_num_codebooks, 125 | decay = bottleneck_decay, 126 | ) 127 | 128 | self.post_transformer_embed = post_transformer_embed 129 | 130 | self.enformer = enformer 131 | 132 | self.auto_set_target_length = auto_set_target_length 133 | 134 | if post_transformer_embed: 135 | self.enformer = deepcopy(enformer) 136 | self.enformer._trunk[-1] = nn.Identity() 137 | self.enformer.final_pointwise = nn.Identity() 138 | 139 | self.post_embed_transform = Sequential( 140 | transformer_embed_fn, 141 | nn.LayerNorm(enformer_hidden_dim) if post_transformer_embed else None 142 | ) 143 | 144 | self.to_tracks = Sequential( 145 | nn.Linear(enformer_hidden_dim, num_tracks), 146 | output_activation 147 | ) 148 | 149 | def forward( 150 | self, 151 | seq, 152 | *, 153 | target = None, 154 | freeze_enformer = False, 155 | finetune_enformer_ln_only = False, 156 | finetune_last_n_layers_only = None 157 | ): 158 | enformer_kwargs = dict() 159 | 160 | if exists(target) and self.auto_set_target_length: 161 | enformer_kwargs = dict(target_length = target.shape[-2]) 162 | 163 | if self.discrete_key_value_bottleneck: 164 | embeddings = self.enformer(seq, return_only_embeddings = True, **enformer_kwargs) 165 | else: 166 | embeddings = get_enformer_embeddings(self.enformer, seq, freeze = freeze_enformer, train_layernorms_only = finetune_enformer_ln_only, train_last_n_layers_only = finetune_last_n_layers_only, enformer_kwargs = enformer_kwargs) 167 | 168 | preds = self.to_tracks(embeddings) 169 | 170 | if not exists(target): 171 | return preds 172 | 173 | return poisson_loss(preds, target) 174 | 175 | # wrapper that allows one to supply each track with a context dimension 176 | # the context embedding will be projected into the weights and biases of the head linear projection (hypernetwork) 177 | 178 | class ContextAdapterWrapper(nn.Module): 179 | def __init__( 180 | self, 181 | *, 182 | enformer, 183 | context_dim, 184 | discrete_key_value_bottleneck = False, 185 | bottleneck_num_memories = 256, 186 | bottleneck_num_codebooks = 4, 187 | bottleneck_decay = 0.9, 188 | auto_set_target_length = True, 189 | output_activation: Optional[nn.Module] = nn.Softplus() 190 | ): 191 | super().__init__() 192 | assert isinstance(enformer, Enformer) 193 | enformer_hidden_dim = enformer.dim * 2 194 | 195 | self.discrete_key_value_bottleneck = discrete_key_value_bottleneck 196 | 197 | if discrete_key_value_bottleneck: 198 | enformer = DiscreteKeyValueBottleneck( 199 | encoder = enformer, 200 | dim = enformer_hidden_dim, 201 | num_memory_codebooks = bottleneck_num_codebooks, 202 | num_memories = bottleneck_num_memories, 203 | dim_memory = enformer_hidden_dim // bottleneck_num_codebooks, 204 | decay = bottleneck_decay, 205 | ) 206 | 207 | self.enformer = enformer 208 | 209 | self.auto_set_target_length = auto_set_target_length 210 | 211 | self.to_context_weights = nn.Parameter(torch.randn(context_dim, enformer_hidden_dim)) 212 | self.to_context_bias = nn.Parameter(torch.randn(context_dim)) 213 | 214 | self.activation = default(output_activation, nn.Identity()) 215 | 216 | def 
forward( 217 | self, 218 | seq, 219 | *, 220 | context, 221 | target = None, 222 | freeze_enformer = False, 223 | finetune_enformer_ln_only = False, 224 | finetune_last_n_layers_only = None 225 | ): 226 | enformer_kwargs = dict() 227 | 228 | if exists(target) and self.auto_set_target_length: 229 | enformer_kwargs = dict(target_length = target.shape[-2]) 230 | 231 | if self.discrete_key_value_bottleneck: 232 | embeddings = self.enformer(seq, return_only_embeddings = True, **enformer_kwargs) 233 | else: 234 | embeddings = get_enformer_embeddings(self.enformer, seq, freeze = freeze_enformer, train_layernorms_only = finetune_enformer_ln_only, train_last_n_layers_only = finetune_last_n_layers_only, enformer_kwargs = enformer_kwargs) 235 | 236 | weights = einsum('t d, d e -> t e', context, self.to_context_weights) 237 | bias = einsum('t d, d -> t', context, self.to_context_bias) 238 | 239 | pred = einsum('b n d, t d -> b n t', embeddings, weights) + bias 240 | 241 | pred = self.activation(pred) 242 | 243 | if not exists(target): 244 | return pred 245 | 246 | return poisson_loss(pred, target) 247 | 248 | # wrapper that does attention aggregation of the context, which can be a list of tokens (batch x seq x dim) 249 | 250 | class ContextAttentionAdapterWrapper(nn.Module): 251 | def __init__( 252 | self, 253 | *, 254 | enformer, 255 | context_dim, 256 | heads = 8, 257 | dim_head = 64, 258 | discrete_key_value_bottleneck = False, 259 | bottleneck_num_memories = 256, 260 | bottleneck_num_codebooks = 4, 261 | bottleneck_decay = 0.9, 262 | auto_set_target_length = True, 263 | output_activation: Optional[nn.Module] = nn.Softplus() 264 | ): 265 | super().__init__() 266 | assert isinstance(enformer, Enformer) 267 | enformer_hidden_dim = enformer.dim * 2 268 | 269 | self.discrete_key_value_bottleneck = discrete_key_value_bottleneck 270 | 271 | if discrete_key_value_bottleneck: 272 | enformer = DiscreteKeyValueBottleneck( 273 | encoder = enformer, 274 | dim = enformer_hidden_dim, 275 | num_memory_codebooks = bottleneck_num_codebooks, 276 | num_memories = bottleneck_num_memories, 277 | dim_memory = enformer_hidden_dim // bottleneck_num_codebooks, 278 | decay = bottleneck_decay, 279 | ) 280 | 281 | self.enformer = enformer 282 | 283 | self.auto_set_target_length = auto_set_target_length 284 | 285 | self.query_norm = nn.LayerNorm(enformer_hidden_dim) 286 | self.key_values_norm = nn.LayerNorm(context_dim) 287 | 288 | self.scale = dim_head ** -0.5 289 | self.heads = heads 290 | inner_dim = heads * dim_head 291 | self.to_queries = nn.Linear(enformer_hidden_dim, inner_dim, bias = False) 292 | 293 | self.null_key = nn.Parameter(torch.randn(inner_dim)) 294 | self.null_value = nn.Parameter(torch.randn(inner_dim)) 295 | 296 | self.to_key_values = nn.Linear(context_dim, inner_dim * 2, bias = False) 297 | self.to_out = nn.Linear(inner_dim, enformer_hidden_dim) 298 | 299 | self.to_pred = Sequential( 300 | nn.Linear(enformer_hidden_dim, 1), 301 | Rearrange('b c ... 1 -> b ... 
c'), 302 | output_activation 303 | ) 304 | 305 | def forward( 306 | self, 307 | seq, 308 | *, 309 | context, 310 | context_mask = None, 311 | target = None, 312 | freeze_enformer = False, 313 | finetune_enformer_ln_only = False, 314 | finetune_last_n_layers_only = None 315 | ): 316 | """ 317 | b - batch 318 | n - sequence length 319 | c - number of contexts (tracks) 320 | d - dimension 321 | i - sequence length (query embeddings) 322 | j - sequence length (keys / values contexts) 323 | h - attention heads 324 | """ 325 | 326 | h = self.heads 327 | 328 | enformer_kwargs = dict() 329 | 330 | if exists(target) and self.auto_set_target_length: 331 | enformer_kwargs = dict(target_length = target.shape[-2]) 332 | 333 | if self.discrete_key_value_bottleneck: 334 | embeddings = self.enformer(seq, return_only_embeddings = True, **enformer_kwargs) 335 | else: 336 | embeddings = get_enformer_embeddings(self.enformer, seq, freeze = freeze_enformer, train_layernorms_only = finetune_enformer_ln_only, train_last_n_layers_only = finetune_last_n_layers_only, enformer_kwargs = enformer_kwargs) 337 | 338 | # perform cross attention from genetic -> context 339 | 340 | if context.ndim == 2: 341 | context = rearrange(context, 'b d -> b 1 d') 342 | 343 | q = self.to_queries(self.query_norm(embeddings)) 344 | k, v = self.to_key_values(self.key_values_norm(context)).chunk(2, dim = -1) 345 | 346 | null_k, null_v = map(lambda t: repeat(t, 'd -> b 1 d', b = context.shape[0]), (self.null_key, self.null_value)) 347 | 348 | k = torch.cat((null_k, k), dim = 1) 349 | v = torch.cat((null_v, v), dim = 1) 350 | 351 | # split out head 352 | 353 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 354 | sim = einsum('b h i d, c h j d -> b c h i j', q, k) * self.scale 355 | 356 | # masking 357 | 358 | if exists(context_mask): 359 | context_mask = F.pad(context_mask, (1, 0), value = True) 360 | context_mask =rearrange(context_mask, 'b j -> b 1 1 1 j') 361 | sim = sim.masked_fill(~context_mask, -torch.finfo(sim.dtype).max) 362 | 363 | # attention 364 | 365 | attn = sim.softmax(dim = -1) 366 | 367 | # aggregate 368 | 369 | out = einsum('b c h i j, c h j d -> b c h i d', attn, v) 370 | 371 | out = rearrange(out, 'b c h n d -> b c n (h d)', h = h) 372 | 373 | # combine heads 374 | 375 | branch_out = self.to_out(out) 376 | 377 | # residual 378 | 379 | embeddings = embeddings + branch_out 380 | 381 | # to prediction 382 | 383 | pred = self.to_pred(embeddings) 384 | 385 | if not exists(target): 386 | return pred 387 | 388 | return poisson_loss(pred, target) 389 | -------------------------------------------------------------------------------- /enformer_pytorch/metrics.py: -------------------------------------------------------------------------------- 1 | from torchmetrics import Metric 2 | from typing import Optional 3 | import torch 4 | 5 | 6 | class MeanPearsonCorrCoefPerChannel(Metric): 7 | is_differentiable: Optional[bool] = False 8 | higher_is_better: Optional[bool] = True 9 | def __init__(self, n_channels:int, dist_sync_on_step=False): 10 | """Calculates the mean pearson correlation across channels aggregated over regions""" 11 | super().__init__(dist_sync_on_step=dist_sync_on_step) 12 | self.reduce_dims=(0, 1) 13 | self.add_state("product", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", ) 14 | self.add_state("true", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", ) 15 | self.add_state("true_squared", default=torch.zeros(n_channels, 
dtype=torch.float32), dist_reduce_fx="sum", ) 16 | self.add_state("pred", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", ) 17 | self.add_state("pred_squared", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", ) 18 | self.add_state("count", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum") 19 | 20 | def update(self, preds: torch.Tensor, target: torch.Tensor): 21 | assert preds.shape == target.shape 22 | 23 | self.product += torch.sum(preds * target, dim=self.reduce_dims) 24 | self.true += torch.sum(target, dim=self.reduce_dims) 25 | self.true_squared += torch.sum(torch.square(target), dim=self.reduce_dims) 26 | self.pred += torch.sum(preds, dim=self.reduce_dims) 27 | self.pred_squared += torch.sum(torch.square(preds), dim=self.reduce_dims) 28 | self.count += torch.sum(torch.ones_like(target), dim=self.reduce_dims) 29 | 30 | def compute(self): 31 | true_mean = self.true / self.count 32 | pred_mean = self.pred / self.count 33 | 34 | covariance = (self.product 35 | - true_mean * self.pred 36 | - pred_mean * self.true 37 | + self.count * true_mean * pred_mean) 38 | 39 | true_var = self.true_squared - self.count * torch.square(true_mean) 40 | pred_var = self.pred_squared - self.count * torch.square(pred_mean) 41 | tp_var = torch.sqrt(true_var) * torch.sqrt(pred_var) 42 | correlation = covariance / tp_var 43 | return correlation 44 | -------------------------------------------------------------------------------- /enformer_pytorch/modeling_enformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch import nn, einsum 6 | import torch.nn.functional as F 7 | import torch.distributed as dist 8 | from torch.utils.checkpoint import checkpoint_sequential 9 | 10 | from einops import rearrange, reduce 11 | from einops.layers.torch import Rearrange 12 | 13 | from enformer_pytorch.data import str_to_one_hot, seq_indices_to_one_hot 14 | 15 | from enformer_pytorch.config_enformer import EnformerConfig 16 | 17 | from transformers import PreTrainedModel 18 | 19 | # constants 20 | 21 | SEQUENCE_LENGTH = 196_608 22 | TARGET_LENGTH = 896 23 | 24 | # gamma positions from tensorflow 25 | # addressing a difference between xlogy results from tensorflow and pytorch 26 | # solution came from @johahi 27 | 28 | DIR = Path(__file__).parents[0] 29 | TF_GAMMAS = torch.load(str(DIR / "precomputed"/ "tf_gammas.pt"), weights_only=True) 30 | 31 | # helpers 32 | 33 | def exists(val): 34 | return val is not None 35 | 36 | def default(val, d): 37 | return val if exists(val) else d 38 | 39 | def always(val): 40 | def inner(*args, **kwargs): 41 | return val 42 | return inner 43 | 44 | def map_values(fn, d): 45 | return {key: fn(values) for key, values in d.items()} 46 | 47 | def exponential_linspace_int(start, end, num, divisible_by = 1): 48 | def _round(x): 49 | return int(round(x / divisible_by) * divisible_by) 50 | 51 | base = math.exp(math.log(end / start) / (num - 1)) 52 | return [_round(start * base**i) for i in range(num)] 53 | 54 | def log(t, eps = 1e-20): 55 | return torch.log(t.clamp(min = eps)) 56 | 57 | # maybe sync batchnorm, for distributed training 58 | 59 | def MaybeSyncBatchnorm(is_distributed = None): 60 | is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1) 61 | return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d 62 | 63 | # losses and metrics 64 | 65 | def 
poisson_loss(pred, target): 66 | return (pred - target * log(pred)).mean() 67 | 68 | def pearson_corr_coef(x, y, dim = 1, reduce_dims = (-1,)): 69 | x_centered = x - x.mean(dim = dim, keepdim = True) 70 | y_centered = y - y.mean(dim = dim, keepdim = True) 71 | return F.cosine_similarity(x_centered, y_centered, dim = dim).mean(dim = reduce_dims) 72 | 73 | # relative positional encoding functions 74 | 75 | def get_positional_features_exponential(positions, features, seq_len, min_half_life = 3., dtype = torch.float): 76 | max_range = math.log(seq_len) / math.log(2.) 77 | half_life = 2 ** torch.linspace(min_half_life, max_range, features, device = positions.device) 78 | half_life = half_life[None, ...] 79 | positions = positions.abs()[..., None] 80 | return torch.exp(-math.log(2.) / half_life * positions) 81 | 82 | def get_positional_features_central_mask(positions, features, seq_len, dtype = torch.float): 83 | center_widths = 2 ** torch.arange(1, features + 1, device = positions.device).to(dtype) 84 | center_widths = center_widths - 1 85 | return (center_widths[None, ...] > positions.abs()[..., None]).to(dtype) 86 | 87 | def gamma_pdf(x, concentration, rate): 88 | log_unnormalized_prob = torch.xlogy(concentration - 1., x) - rate * x 89 | log_normalization = (torch.lgamma(concentration) - concentration * torch.log(rate)) 90 | return torch.exp(log_unnormalized_prob - log_normalization) 91 | 92 | def get_positional_features_gamma(positions, features, seq_len, stddev = None, start_mean = None, eps = 1e-8, dtype = torch.float): 93 | if not exists(stddev): 94 | stddev = seq_len / (2 * features) 95 | 96 | if not exists(start_mean): 97 | start_mean = seq_len / features 98 | 99 | mean = torch.linspace(start_mean, seq_len, features, device = positions.device) 100 | 101 | mean = mean[None, ...] 
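    # for a gamma distribution parameterized by concentration (shape) and rate, the mean is
    # concentration / rate and the variance is concentration / rate ** 2, so the two lines
    # below solve for the parameters that give the desired mean and stddev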
102 | concentration = (mean / stddev) ** 2 103 | rate = mean / stddev ** 2 104 | 105 | probabilities = gamma_pdf(positions.to(dtype).abs()[..., None], concentration, rate) 106 | probabilities = probabilities + eps 107 | outputs = probabilities / torch.amax(probabilities, dim = -1, keepdim = True) 108 | return outputs 109 | 110 | def get_positional_embed(seq_len, feature_size, device, use_tf_gamma, dtype = torch.float): 111 | distances = torch.arange(-seq_len + 1, seq_len, device = device) 112 | 113 | assert not use_tf_gamma or seq_len == 1536, 'if using tf gamma, only sequence length of 1536 allowed for now' 114 | 115 | feature_functions = [ 116 | get_positional_features_exponential, 117 | get_positional_features_central_mask, 118 | get_positional_features_gamma if not use_tf_gamma else always(TF_GAMMAS.to(device)) 119 | ] 120 | 121 | num_components = len(feature_functions) * 2 122 | 123 | if (feature_size % num_components) != 0: 124 | raise ValueError(f'feature size is not divisible by number of components ({num_components})') 125 | 126 | num_basis_per_class = feature_size // num_components 127 | 128 | embeddings = [] 129 | for fn in feature_functions: 130 | embeddings.append(fn(distances, num_basis_per_class, seq_len, dtype = dtype)) 131 | 132 | embeddings = torch.cat(embeddings, dim = -1) 133 | embeddings = torch.cat((embeddings, torch.sign(distances)[..., None] * embeddings), dim = -1) 134 | return embeddings.to(dtype) 135 | 136 | def relative_shift(x): 137 | to_pad = torch.zeros_like(x[..., :1]) 138 | x = torch.cat((to_pad, x), dim = -1) 139 | _, h, t1, t2 = x.shape 140 | x = x.reshape(-1, h, t2, t1) 141 | x = x[:, :, 1:, :] 142 | x = x.reshape(-1, h, t1, t2 - 1) 143 | return x[..., :((t2 + 1) // 2)] 144 | 145 | # classes 146 | 147 | class Residual(nn.Module): 148 | def __init__(self, fn): 149 | super().__init__() 150 | self.fn = fn 151 | 152 | def forward(self, x, **kwargs): 153 | return self.fn(x, **kwargs) + x 154 | 155 | class GELU(nn.Module): 156 | def forward(self, x): 157 | return torch.sigmoid(1.702 * x) * x 158 | 159 | class AttentionPool(nn.Module): 160 | def __init__(self, dim, pool_size = 2): 161 | super().__init__() 162 | self.pool_size = pool_size 163 | self.pool_fn = Rearrange('b d (n p) -> b d n p', p = pool_size) 164 | 165 | self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias = False) 166 | 167 | nn.init.dirac_(self.to_attn_logits.weight) 168 | 169 | with torch.no_grad(): 170 | self.to_attn_logits.weight.mul_(2) 171 | 172 | def forward(self, x): 173 | b, _, n = x.shape 174 | remainder = n % self.pool_size 175 | needs_padding = remainder > 0 176 | 177 | if needs_padding: 178 | x = F.pad(x, (0, remainder), value = 0) 179 | mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device) 180 | mask = F.pad(mask, (0, remainder), value = True) 181 | 182 | x = self.pool_fn(x) 183 | logits = self.to_attn_logits(x) 184 | 185 | if needs_padding: 186 | mask_value = -torch.finfo(logits.dtype).max 187 | logits = logits.masked_fill(self.pool_fn(mask), mask_value) 188 | 189 | attn = logits.softmax(dim = -1) 190 | 191 | return (x * attn).sum(dim = -1) 192 | 193 | class TargetLengthCrop(nn.Module): 194 | def __init__(self, target_length): 195 | super().__init__() 196 | self.target_length = target_length 197 | 198 | def forward(self, x): 199 | seq_len, target_len = x.shape[-2], self.target_length 200 | 201 | if target_len == -1: 202 | return x 203 | 204 | if seq_len < target_len: 205 | raise ValueError(f'sequence length {seq_len} is less than target length {target_len}') 206 | 207 | 
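        # target_len <= seq_len at this point, so trim is zero or negative; when negative, the
        # slice x[:, -trim:trim] below crops the excess positions evenly from both ends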
trim = (target_len - seq_len) // 2 208 | 209 | if trim == 0: 210 | return x 211 | 212 | return x[:, -trim:trim] 213 | 214 | def ConvBlock(dim, dim_out = None, kernel_size = 1, is_distributed = None): 215 | batchnorm_klass = MaybeSyncBatchnorm(is_distributed = is_distributed) 216 | 217 | return nn.Sequential( 218 | batchnorm_klass(dim), 219 | GELU(), 220 | nn.Conv1d(dim, default(dim_out, dim), kernel_size, padding = kernel_size // 2) 221 | ) 222 | 223 | # attention classes 224 | 225 | class Attention(nn.Module): 226 | def __init__( 227 | self, 228 | dim, 229 | *, 230 | num_rel_pos_features, 231 | heads = 8, 232 | dim_key = 64, 233 | dim_value = 64, 234 | dropout = 0., 235 | pos_dropout = 0., 236 | use_tf_gamma = False 237 | ): 238 | super().__init__() 239 | self.scale = dim_key ** -0.5 240 | self.heads = heads 241 | 242 | self.to_q = nn.Linear(dim, dim_key * heads, bias = False) 243 | self.to_k = nn.Linear(dim, dim_key * heads, bias = False) 244 | self.to_v = nn.Linear(dim, dim_value * heads, bias = False) 245 | 246 | self.to_out = nn.Linear(dim_value * heads, dim) 247 | nn.init.zeros_(self.to_out.weight) 248 | nn.init.zeros_(self.to_out.bias) 249 | 250 | # relative positional encoding 251 | 252 | self.num_rel_pos_features = num_rel_pos_features 253 | 254 | self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias = False) 255 | self.rel_content_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key)) 256 | self.rel_pos_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key)) 257 | 258 | # dropouts 259 | 260 | self.pos_dropout = nn.Dropout(pos_dropout) 261 | self.attn_dropout = nn.Dropout(dropout) 262 | 263 | # whether to use tf gamma 264 | 265 | self.use_tf_gamma = use_tf_gamma 266 | 267 | def forward(self, x): 268 | n, h, device = x.shape[-2], self.heads, x.device 269 | 270 | q = self.to_q(x) 271 | k = self.to_k(x) 272 | v = self.to_v(x) 273 | 274 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 275 | 276 | q = q * self.scale 277 | 278 | content_logits = einsum('b h i d, b h j d -> b h i j', q + self.rel_content_bias, k) 279 | 280 | positions = get_positional_embed(n, self.num_rel_pos_features, device, use_tf_gamma = self.use_tf_gamma, dtype = self.to_rel_k.weight.dtype) 281 | positions = self.pos_dropout(positions) 282 | rel_k = self.to_rel_k(positions) 283 | 284 | rel_k = rearrange(rel_k, 'n (h d) -> h n d', h = h) 285 | rel_logits = einsum('b h i d, h j d -> b h i j', q + self.rel_pos_bias, rel_k) 286 | rel_logits = relative_shift(rel_logits) 287 | 288 | logits = content_logits + rel_logits 289 | attn = logits.softmax(dim = -1) 290 | attn = self.attn_dropout(attn) 291 | 292 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 293 | out = rearrange(out, 'b h n d -> b n (h d)') 294 | return self.to_out(out) 295 | 296 | # main class 297 | 298 | class Enformer(PreTrainedModel): 299 | config_class = EnformerConfig 300 | base_model_prefix = "enformer" 301 | 302 | @staticmethod 303 | def from_hparams(**kwargs): 304 | return Enformer(EnformerConfig(**kwargs)) 305 | 306 | def __init__(self, config): 307 | super().__init__(config) 308 | self.dim = config.dim 309 | half_dim = config.dim // 2 310 | twice_dim = config.dim * 2 311 | 312 | # create stem 313 | 314 | self.stem = nn.Sequential( 315 | nn.Conv1d(4, half_dim, 15, padding = 7), 316 | Residual(ConvBlock(half_dim)), 317 | AttentionPool(half_dim, pool_size = 2) 318 | ) 319 | 320 | # create conv tower 321 | 322 | filter_list = exponential_linspace_int(half_dim, config.dim, num = (config.num_downsamples 
- 1), divisible_by = config.dim_divisible_by) 323 | filter_list = [half_dim, *filter_list] 324 | 325 | conv_layers = [] 326 | for dim_in, dim_out in zip(filter_list[:-1], filter_list[1:]): 327 | conv_layers.append(nn.Sequential( 328 | ConvBlock(dim_in, dim_out, kernel_size = 5), 329 | Residual(ConvBlock(dim_out, dim_out, 1)), 330 | AttentionPool(dim_out, pool_size = 2) 331 | )) 332 | 333 | self.conv_tower = nn.Sequential(*conv_layers) 334 | 335 | # whether to use tensorflow gamma positions 336 | 337 | use_tf_gamma = config.use_tf_gamma 338 | self.use_tf_gamma = use_tf_gamma 339 | 340 | # transformer 341 | 342 | transformer = [] 343 | for _ in range(config.depth): 344 | transformer.append(nn.Sequential( 345 | Residual(nn.Sequential( 346 | nn.LayerNorm(config.dim), 347 | Attention( 348 | config.dim, 349 | heads = config.heads, 350 | dim_key = config.attn_dim_key, 351 | dim_value = config.dim // config.heads, 352 | dropout = config.attn_dropout, 353 | pos_dropout = config.pos_dropout, 354 | num_rel_pos_features = config.dim // config.heads, 355 | use_tf_gamma = use_tf_gamma 356 | ), 357 | nn.Dropout(config.dropout_rate) 358 | )), 359 | Residual(nn.Sequential( 360 | nn.LayerNorm(config.dim), 361 | nn.Linear(config.dim, config.dim * 2), 362 | nn.Dropout(config.dropout_rate), 363 | nn.ReLU(), 364 | nn.Linear(config.dim * 2, config.dim), 365 | nn.Dropout(config.dropout_rate) 366 | )) 367 | )) 368 | 369 | self.transformer = nn.Sequential(*transformer) 370 | 371 | # target cropping 372 | 373 | self.target_length = config.target_length 374 | self.crop_final = TargetLengthCrop(config.target_length) 375 | 376 | # final pointwise 377 | 378 | self.final_pointwise = nn.Sequential( 379 | Rearrange('b n d -> b d n'), 380 | ConvBlock(filter_list[-1], twice_dim, 1), 381 | Rearrange('b d n -> b n d'), 382 | nn.Dropout(config.dropout_rate / 8), 383 | GELU() 384 | ) 385 | 386 | # create trunk sequential module 387 | 388 | self._trunk = nn.Sequential( 389 | Rearrange('b n d -> b d n'), 390 | self.stem, 391 | self.conv_tower, 392 | Rearrange('b d n -> b n d'), 393 | self.transformer, 394 | self.crop_final, 395 | self.final_pointwise 396 | ) 397 | 398 | # create final heads for human and mouse 399 | 400 | self.add_heads(**config.output_heads) 401 | 402 | # use checkpointing on transformer trunk 403 | 404 | self.use_checkpointing = config.use_checkpointing 405 | 406 | def add_heads(self, **kwargs): 407 | self.output_heads = kwargs 408 | 409 | self._heads = nn.ModuleDict(map_values(lambda features: nn.Sequential( 410 | nn.Linear(self.dim * 2, features), 411 | nn.Softplus() 412 | ), kwargs)) 413 | 414 | def set_target_length(self, target_length): 415 | crop_module = self._trunk[-2] 416 | crop_module.target_length = target_length 417 | 418 | @property 419 | def trunk(self): 420 | return self._trunk 421 | 422 | @property 423 | def heads(self): 424 | return self._heads 425 | 426 | def trunk_checkpointed(self, x): 427 | x = rearrange(x, 'b n d -> b d n') 428 | x = self.stem(x) 429 | x = self.conv_tower(x) 430 | x = rearrange(x, 'b d n -> b n d') 431 | x = checkpoint_sequential(self.transformer, len(self.transformer), x) 432 | x = self.crop_final(x) 433 | x = self.final_pointwise(x) 434 | return x 435 | 436 | def forward( 437 | self, 438 | x, 439 | target = None, 440 | return_corr_coef = False, 441 | return_embeddings = False, 442 | return_only_embeddings = False, 443 | head = None, 444 | target_length = None 445 | ): 446 | if isinstance(x, list): 447 | x = str_to_one_hot(x) 448 | 449 | elif type(x) == torch.Tensor and 
x.dtype == torch.long: 450 | x = seq_indices_to_one_hot(x) 451 | x.to(self.device) 452 | 453 | no_batch = x.ndim == 2 454 | 455 | if no_batch: 456 | x = rearrange(x, '... -> () ...') 457 | 458 | if exists(target_length): 459 | self.set_target_length(target_length) 460 | 461 | trunk_fn = self.trunk_checkpointed if self.use_checkpointing else self._trunk 462 | x = trunk_fn(x) 463 | 464 | if no_batch: 465 | x = rearrange(x, '() ... -> ...') 466 | 467 | if return_only_embeddings: 468 | return x 469 | 470 | out = map_values(lambda fn: fn(x), self._heads) 471 | 472 | if exists(head): 473 | assert head in self._heads, f'head {head} not found' 474 | out = out[head] 475 | 476 | if exists(target): 477 | assert exists(head), 'head must be passed in if one were to calculate loss directly with targets' 478 | 479 | if return_corr_coef: 480 | return pearson_corr_coef(out, target) 481 | 482 | return poisson_loss(out, target) 483 | 484 | if return_embeddings: 485 | return out, x 486 | 487 | return out 488 | 489 | # from pretrained function 490 | 491 | def from_pretrained(name, use_tf_gamma = None, **kwargs): 492 | enformer = Enformer.from_pretrained(name, **kwargs) 493 | 494 | if name == 'EleutherAI/enformer-official-rough': 495 | use_tf_gamma = default(use_tf_gamma, True) 496 | 497 | for module in enformer.modules(): 498 | if isinstance(module, Attention): 499 | module.use_tf_gamma = use_tf_gamma 500 | 501 | return enformer 502 | -------------------------------------------------------------------------------- /enformer_pytorch/precomputed/tf_gammas.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/enformer-pytorch/5a5974d2821c728f93294731c50b55f1f55fd86d/enformer_pytorch/precomputed/tf_gammas.pt -------------------------------------------------------------------------------- /evaluate_enformer_pytorch_correlation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": { 17 | "colab": { 18 | "base_uri": "https://localhost:8080/" 19 | }, 20 | "id": "beViQ8vrM6LY", 21 | "outputId": "1a821c07-6701-457d-ddd3-a3f52a03d4a6" 22 | }, 23 | "outputs": [ 24 | { 25 | "output_type": "stream", 26 | "name": "stdout", 27 | "text": [ 28 | "Cloning into 'enformer-pytorch'...\n", 29 | "remote: Enumerating objects: 643, done.\u001b[K\n", 30 | "remote: Counting objects: 100% (132/132), done.\u001b[K\n", 31 | "remote: Compressing objects: 100% (117/117), done.\u001b[K\n", 32 | "remote: Total 643 (delta 28), reused 28 (delta 13), pack-reused 511\u001b[K\n", 33 | "Receiving objects: 100% (643/643), 8.88 MiB | 3.09 MiB/s, done.\n", 34 | "Resolving deltas: 100% (439/439), done.\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "!git clone https://github.com/lucidrains/enformer-pytorch.git" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": { 46 | "colab": { 47 | "base_uri": "https://localhost:8080/" 48 | }, 49 | "id": "PMXWgyEfNT_s", 50 | "outputId": "804379ef-37af-481c-c870-b3e9c7b8e2bb" 51 | }, 52 | "outputs": [ 53 | { 54 | "output_type": "stream", 55 | "name": "stdout", 56 | "text": [ 57 | "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", 58 | "Processing 
/content/enformer-pytorch\n", 59 | "\u001b[33m DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.\n", 60 | " pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.\u001b[0m\n", 61 | "Collecting einops>=0.3\n", 62 | " Downloading einops-0.4.1-py3-none-any.whl (28 kB)\n", 63 | "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from enformer-pytorch==0.5.1) (1.21.6)\n", 64 | "Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.7/dist-packages (from enformer-pytorch==0.5.1) (1.11.0+cu113)\n", 65 | "Collecting polars\n", 66 | " Downloading polars-0.13.40-cp37-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (11.9 MB)\n", 67 | "\u001b[K |████████████████████████████████| 11.9 MB 19.1 MB/s \n", 68 | "\u001b[?25hCollecting pyfaidx\n", 69 | " Downloading pyfaidx-0.7.0.tar.gz (102 kB)\n", 70 | "\u001b[K |████████████████████████████████| 102 kB 70.3 MB/s \n", 71 | "\u001b[?25hRequirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from enformer-pytorch==0.5.1) (3.13)\n", 72 | "Collecting transformers\n", 73 | " Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)\n", 74 | "\u001b[K |████████████████████████████████| 4.2 MB 56.0 MB/s \n", 75 | "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.6->enformer-pytorch==0.5.1) (4.2.0)\n", 76 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from pyfaidx->enformer-pytorch==0.5.1) (1.15.0)\n", 77 | "Requirement already satisfied: setuptools>=0.7 in /usr/local/lib/python3.7/dist-packages (from pyfaidx->enformer-pytorch==0.5.1) (57.4.0)\n", 78 | "Collecting tokenizers!=0.11.3,<0.13,>=0.11.1\n", 79 | " Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)\n", 80 | "\u001b[K |████████████████████████████████| 6.6 MB 44.0 MB/s \n", 81 | "\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers->enformer-pytorch==0.5.1) (4.64.0)\n", 82 | "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers->enformer-pytorch==0.5.1) (3.7.0)\n", 83 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers->enformer-pytorch==0.5.1) (2019.12.20)\n", 84 | "Collecting huggingface-hub<1.0,>=0.1.0\n", 85 | " Downloading huggingface_hub-0.7.0-py3-none-any.whl (86 kB)\n", 86 | "\u001b[K |████████████████████████████████| 86 kB 6.6 MB/s \n", 87 | "\u001b[?25hRequirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers->enformer-pytorch==0.5.1) (4.11.4)\n", 88 | "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers->enformer-pytorch==0.5.1) (21.3)\n", 89 | "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers->enformer-pytorch==0.5.1) (2.23.0)\n", 90 | "Collecting pyyaml\n", 91 | " Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)\n", 92 | "\u001b[K |████████████████████████████████| 596 kB 64.8 MB/s \n", 93 | 
"\u001b[?25hRequirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers->enformer-pytorch==0.5.1) (3.0.9)\n", 94 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers->enformer-pytorch==0.5.1) (3.8.0)\n", 95 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->enformer-pytorch==0.5.1) (2022.5.18.1)\n", 96 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->enformer-pytorch==0.5.1) (3.0.4)\n", 97 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->enformer-pytorch==0.5.1) (1.24.3)\n", 98 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->enformer-pytorch==0.5.1) (2.10)\n", 99 | "Building wheels for collected packages: enformer-pytorch, pyfaidx\n", 100 | " Building wheel for enformer-pytorch (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 101 | " Created wheel for enformer-pytorch: filename=enformer_pytorch-0.5.1-py3-none-any.whl size=11587 sha256=3888ddb9602b37eedeb4b73f60c5d56609dd500e54536df6fb1bb50f481fe6d5\n", 102 | " Stored in directory: /root/.cache/pip/wheels/e9/30/2d/17bcf281153d214e2a57d1ed68544840825283066370c73eda\n", 103 | " Building wheel for pyfaidx (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 104 | " Created wheel for pyfaidx: filename=pyfaidx-0.7.0-py3-none-any.whl size=27697 sha256=464766257aa70062f3ac14de5005e39af8e5375eb2be08e6f32e0a05dd994a12\n", 105 | " Stored in directory: /root/.cache/pip/wheels/df/6b/ce/46374a70af569061fa10a6c16525b0d8efe2d9a4069f8a144a\n", 106 | "Successfully built enformer-pytorch pyfaidx\n", 107 | "Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers, pyfaidx, polars, einops, enformer-pytorch\n", 108 | " Attempting uninstall: pyyaml\n", 109 | " Found existing installation: PyYAML 3.13\n", 110 | " Uninstalling PyYAML-3.13:\n", 111 | " Successfully uninstalled PyYAML-3.13\n", 112 | "Successfully installed einops-0.4.1 enformer-pytorch-0.5.1 huggingface-hub-0.7.0 polars-0.13.40 pyfaidx-0.7.0 pyyaml-6.0 tokenizers-0.12.1 transformers-4.19.2\n" 113 | ] 114 | } 115 | ], 116 | "source": [ 117 | "!cd enformer-pytorch && pip install ." 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 4, 123 | "metadata": { 124 | "id": "0XbZipT7O0yD" 125 | }, 126 | "outputs": [], 127 | "source": [ 128 | "!pip install torchmetrics kipoiseq==0.5.2 BioPython --quiet > /dev/null" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 5, 134 | "metadata": { 135 | "colab": { 136 | "base_uri": "https://localhost:8080/" 137 | }, 138 | "id": "H8miDmi7OMK1", 139 | "outputId": "4249c35d-2c6f-4c0f-cf70-16d74da61f24" 140 | }, 141 | "outputs": [ 142 | { 143 | "output_type": "stream", 144 | "name": "stdout", 145 | "text": [ 146 | "Copying gs://basenji_barnyard/hg38.ml.fa.gz...\n", 147 | "/ [1/1 files][839.8 MiB/839.8 MiB] 100% Done 58.5 MiB/s ETA 00:00:00 \n", 148 | "Operation completed over 1 objects/839.8 MiB. \n", 149 | "Copying gs://basenji_barnyard/mm10.ml.fa.gz...\n", 150 | "/ [1/1 files][800.8 MiB/800.8 MiB] 100% Done 65.9 MiB/s ETA 00:00:00 \n", 151 | "Operation completed over 1 objects/800.8 MiB. 
\n" 152 | ] 153 | } 154 | ], 155 | "source": [ 156 | "import torch\n", 157 | "import numpy as np\n", 158 | "import tensorflow as tf\n", 159 | "import os \n", 160 | "import json\n", 161 | "import pandas as pd\n", 162 | "import pyfaidx\n", 163 | "import kipoiseq\n", 164 | "import functools\n", 165 | "from kipoiseq import Interval\n", 166 | "\n", 167 | "SEQUENCE_LENGTH = 196_608\n", 168 | "BIN_SIZE = 128\n", 169 | "TARGET_LENGTH = 896\n", 170 | "import os\n", 171 | "fasta_dir = \"/root/data/\"\n", 172 | "!mkdir -p {fasta_dir}\n", 173 | "human_fasta_f = 'hg38.ml.fa.gz'\n", 174 | "mouse_fasta_f = 'mm10.ml.fa.gz'\n", 175 | "human_fasta_gz_path = f\"{fasta_dir}/{human_fasta_f}\"\n", 176 | "mouse_fasta_gz_path = f\"{fasta_dir}/{mouse_fasta_f}\"\n", 177 | "human_fasta_path = human_fasta_gz_path.rstrip(\".gz\")\n", 178 | "mouse_fasta_path = mouse_fasta_gz_path.rstrip(\".gz\")\n", 179 | "\n", 180 | "if not os.path.isfile(human_fasta_path):\n", 181 | " !gsutil -m cp -n gs://basenji_barnyard/{human_fasta_f} {human_fasta_gz_path}\n", 182 | " !gunzip {human_fasta_gz_path}\n", 183 | "if not os.path.isfile(mouse_fasta_path):\n", 184 | " !gsutil -m cp -n gs://basenji_barnyard/{mouse_fasta_f} {mouse_fasta_gz_path}\n", 185 | " !gunzip {mouse_fasta_gz_path}\n", 186 | "\n", 187 | "class FastaStringExtractor:\n", 188 | " \n", 189 | " def __init__(self, fasta_file):\n", 190 | " self.fasta = pyfaidx.Fasta(fasta_file)\n", 191 | " self._chromosome_sizes = {k: len(v) for k, v in self.fasta.items()}\n", 192 | "\n", 193 | " def extract(self, interval: Interval, **kwargs) -> str:\n", 194 | " # Truncate interval if it extends beyond the chromosome lengths.\n", 195 | " chromosome_length = self._chromosome_sizes[interval.chrom]\n", 196 | " trimmed_interval = Interval(interval.chrom,\n", 197 | " max(interval.start, 0),\n", 198 | " min(interval.end, chromosome_length),\n", 199 | " )\n", 200 | " # pyfaidx wants a 1-based interval\n", 201 | " sequence = str(self.fasta.get_seq(trimmed_interval.chrom,\n", 202 | " trimmed_interval.start + 1,\n", 203 | " trimmed_interval.stop).seq).upper()\n", 204 | " # Fill truncated values with N's.\n", 205 | " pad_upstream = 'N' * max(-interval.start, 0)\n", 206 | " pad_downstream = 'N' * max(interval.end - chromosome_length, 0)\n", 207 | " return pad_upstream + sequence + pad_downstream\n", 208 | "\n", 209 | " def close(self):\n", 210 | " return self.fasta.close()\n", 211 | "\n", 212 | "\n", 213 | "class BasenjiDataSet(torch.utils.data.IterableDataset):\n", 214 | " @staticmethod\n", 215 | " def get_organism_path(organism):\n", 216 | " return os.path.join('gs://basenji_barnyard/data', organism)\n", 217 | " @classmethod\n", 218 | " def get_metadata(cls, organism):\n", 219 | " # Keys:\n", 220 | " # num_targets, train_seqs, valid_seqs, test_seqs, seq_length,\n", 221 | " # pool_width, crop_bp, target_length\n", 222 | " path = os.path.join(cls.get_organism_path(organism), 'statistics.json')\n", 223 | " with tf.io.gfile.GFile(path, 'r') as f:\n", 224 | " return json.load(f)\n", 225 | " @staticmethod\n", 226 | " def one_hot_encode(sequence):\n", 227 | " return kipoiseq.transforms.functional.one_hot_dna(sequence).astype(np.float32)\n", 228 | "\n", 229 | " @classmethod\n", 230 | " def get_tfrecord_files(cls, organism, subset):\n", 231 | " # Sort the values by int(*).\n", 232 | " return sorted(tf.io.gfile.glob(os.path.join(\n", 233 | " cls.get_organism_path(organism), 'tfrecords', f'{subset}-*.tfr'\n", 234 | " )), key=lambda x: int(x.split('-')[-1].split('.')[0]))\n", 235 | " \n", 236 | " @property\n", 
237 | " def num_channels(self):\n", 238 | " metadata = self.get_metadata(self.organism)\n", 239 | " return metadata['num_targets']\n", 240 | "\n", 241 | " @staticmethod\n", 242 | " def deserialize(serialized_example, metadata):\n", 243 | " \"\"\"Deserialize bytes stored in TFRecordFile.\"\"\"\n", 244 | " # Deserialization\n", 245 | " feature_map = {\n", 246 | " 'sequence': tf.io.FixedLenFeature([], tf.string), # Ignore this, resize our own bigger one\n", 247 | " 'target': tf.io.FixedLenFeature([], tf.string),\n", 248 | " }\n", 249 | " example = tf.io.parse_example(serialized_example, feature_map)\n", 250 | " sequence = tf.io.decode_raw(example['sequence'], tf.bool)\n", 251 | " sequence = tf.reshape(sequence, (metadata['seq_length'], 4))\n", 252 | " sequence = tf.cast(sequence, tf.float32)\n", 253 | "\n", 254 | " target = tf.io.decode_raw(example['target'], tf.float16)\n", 255 | " target = tf.reshape(target,\n", 256 | " (metadata['target_length'], metadata['num_targets']))\n", 257 | " target = tf.cast(target, tf.float32)\n", 258 | "\n", 259 | " return {'sequence_old': sequence,\n", 260 | " 'target': target}\n", 261 | "\n", 262 | " @classmethod\n", 263 | " def get_dataset(cls, organism, subset, num_threads=8):\n", 264 | " metadata = cls.get_metadata(organism)\n", 265 | " dataset = tf.data.TFRecordDataset(cls.get_tfrecord_files(organism, subset),\n", 266 | " compression_type='ZLIB',\n", 267 | " num_parallel_reads=num_threads).map(\n", 268 | " functools.partial(cls.deserialize, metadata=metadata)\n", 269 | " )\n", 270 | " return dataset\n", 271 | "\n", 272 | " def __init__(self, organism:str, subset:str, seq_len:int, fasta_path:str, n_to_test:int = -1):\n", 273 | " assert subset in {\"train\", \"valid\", \"test\"}\n", 274 | " assert organism in {\"human\", \"mouse\"}\n", 275 | " self.organism = organism\n", 276 | " self.subset = subset\n", 277 | " self.base_dir = self.get_organism_path(organism)\n", 278 | " self.seq_len = seq_len\n", 279 | " self.fasta_reader = FastaStringExtractor(fasta_path)\n", 280 | " self.n_to_test = n_to_test\n", 281 | " with tf.io.gfile.GFile(f\"{self.base_dir}/sequences.bed\", 'r') as f:\n", 282 | " region_df = pd.read_csv(f, sep=\"\\t\", header=None)\n", 283 | " region_df.columns = ['chrom', 'start', 'end', 'subset']\n", 284 | " self.region_df = region_df.query('subset==@subset').reset_index(drop=True)\n", 285 | " \n", 286 | " def __iter__(self):\n", 287 | " worker_info = torch.utils.data.get_worker_info()\n", 288 | " assert worker_info is None, \"Only support single process loading\"\n", 289 | " # If num_threads > 1, the following will actually shuffle the inputs! 
luckily we catch this with the sequence comparison\n", 290 | " basenji_iterator = self.get_dataset(self.organism, self.subset, num_threads=1).as_numpy_iterator()\n", 291 | " for i, records in enumerate(basenji_iterator):\n", 292 | " loc_row = self.region_df.iloc[i]\n", 293 | " target_interval = Interval(loc_row['chrom'], loc_row['start'], loc_row['end'])\n", 294 | " sequence_one_hot = self.one_hot_encode(self.fasta_reader.extract(target_interval.resize(self.seq_len)))\n", 295 | " if self.n_to_test >= 0 and i < self.n_to_test:\n", 296 | " old_sequence_onehot = records[\"sequence_old\"]\n", 297 | " if old_sequence_onehot.shape[0] > sequence_one_hot.shape[0]:\n", 298 | " diff = old_sequence_onehot.shape[0] - sequence_one_hot.shape[0]\n", 299 | " trim = diff//2\n", 300 | " np.testing.assert_equal(old_sequence_onehot[trim:(-trim)], sequence_one_hot)\n", 301 | " elif sequence_one_hot.shape[0] > old_sequence_onehot.shape[0]:\n", 302 | " diff = sequence_one_hot.shape[0] - old_sequence_onehot.shape[0]\n", 303 | " trim = diff//2\n", 304 | " np.testing.assert_equal(old_sequence_onehot, sequence_one_hot[trim:(-trim)])\n", 305 | " else:\n", 306 | " np.testing.assert_equal(old_sequence_onehot, sequence_one_hot)\n", 307 | " yield {\n", 308 | " \"sequence\": sequence_one_hot,\n", 309 | " \"target\": records[\"target\"],\n", 310 | " }" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 6, 316 | "metadata": { 317 | "colab": { 318 | "base_uri": "https://localhost:8080/", 319 | "height": 81, 320 | "referenced_widgets": [ 321 | "a76d636b937c4b5691d36cb6b7ec7b17", 322 | "e06f0950617847e7a894afc8cc8ec69f", 323 | "fd6815290f9e45f69d9ac29826de71ef", 324 | "109781990b25433daa411cafa44c662a", 325 | "69789023fe8544ff8a0d7540de793041", 326 | "ef2d9fda87da4dbc825f4776a982856d", 327 | "20f0fc109b5249a3bd0e5738f6b8c06f", 328 | "7ab66e50e29a48b38a1d5c1820422a6a", 329 | "b0b4b1ebc3ac4aaf91b0112969fc64f9", 330 | "38e1b0815aaf4c01b1988f73fc3c38d2", 331 | "c177e2c35da948918a29fdcea6840d89", 332 | "beeeb50adabb48978b1c502d799713be", 333 | "a34e3298896c454abb1b0dda7c225c2b", 334 | "579ec7740cbc455a9c6d3e1bc4442612", 335 | "e486e3955a574b818227964293fa76e1", 336 | "44a8b215df314ba9beb66b44d2e13d29", 337 | "36a42d5cd33541b5aa24db50b03cef9a", 338 | "e47a62c2ab1a4fcea674274f3435f031", 339 | "02df913b81f3473cbbe0189880e20586", 340 | "9d3909432e2d42a5849427f6d1ad1371", 341 | "ec0bf3f9df7c40b8b64aef5ac62f0728", 342 | "7ba9a6bb917746d2bdb7143843480038" 343 | ] 344 | }, 345 | "id": "W_QtUyQXNfgJ", 346 | "outputId": "2b58967c-2b79-439f-a901-687a164282d3" 347 | }, 348 | "outputs": [ 349 | { 350 | "output_type": "display_data", 351 | "data": { 352 | "text/plain": [ 353 | "Downloading: 0%| | 0.00/464 [00:00 0 and i >= max_steps:\n", 448 | " break\n", 449 | " batch_gpu = {k:v.to(model.device) for k,v in batch.items()}\n", 450 | " sequence = batch_gpu['sequence']\n", 451 | " target = batch_gpu['target']\n", 452 | " with torch.no_grad():\n", 453 | " pred = model(sequence)[organism]\n", 454 | " corr_coef(preds=pred.cpu(), target=target.cpu())\n", 455 | " return corr_coef.compute().mean()\n", 456 | "compute_correlation(model, organism=\"human\", subset=\"valid\", max_steps=100)" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": null, 462 | "metadata": { 463 | "colab": { 464 | "background_save": true, 465 | "base_uri": "https://localhost:8080/" 466 | }, 467 | "id": "raZi83xyg8SV", 468 | "outputId": "c00dae32-0838-4b4d-a9a1-be6040133011" 469 | }, 470 | "outputs": [ 471 | { 472 | "name": 
"stderr", 473 | "output_type": "stream", 474 | "text": [ 475 | "/usr/local/lib/python3.7/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", 476 | " not been set for this class (MeanPearsonCorrCoefPerChannel). The property determines if `update` by\n", 477 | " default needs access to the full metric state. If this is not the case, significant speedups can be\n", 478 | " achieved and we recommend setting this to `False`.\n", 479 | " We provide an checking function\n", 480 | " `from torchmetrics.utilities import check_forward_no_full_state`\n", 481 | " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", 482 | " default for now) or if `full_state_update=False` can be used safely.\n", 483 | " \n", 484 | " warnings.warn(*args, **kwargs)\n", 485 | "100%|██████████| 2213/2213 [30:51<00:00, 1.20it/s]\n" 486 | ] 487 | }, 488 | { 489 | "data": { 490 | "text/plain": [ 491 | "tensor(0.6252)" 492 | ] 493 | }, 494 | "execution_count": null, 495 | "metadata": {}, 496 | "output_type": "execute_result" 497 | } 498 | ], 499 | "source": [ 500 | "compute_correlation(model, organism=\"human\", subset=\"valid\", max_steps=-1)" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": null, 506 | "metadata": { 507 | "colab": { 508 | "background_save": true 509 | }, 510 | "id": "PtWriiqConiR", 511 | "outputId": "20acb7f9-3ccc-425c-83d6-99ad88100cba" 512 | }, 513 | "outputs": [ 514 | { 515 | "name": "stderr", 516 | "output_type": "stream", 517 | "text": [ 518 | "/usr/local/lib/python3.7/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", 519 | " not been set for this class (MeanPearsonCorrCoefPerChannel). The property determines if `update` by\n", 520 | " default needs access to the full metric state. If this is not the case, significant speedups can be\n", 521 | " achieved and we recommend setting this to `False`.\n", 522 | " We provide an checking function\n", 523 | " `from torchmetrics.utilities import check_forward_no_full_state`\n", 524 | " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", 525 | " default for now) or if `full_state_update=False` can be used safely.\n", 526 | " \n", 527 | " warnings.warn(*args, **kwargs)\n", 528 | "100%|██████████| 1937/1937 [20:05<00:00, 1.61it/s]\n" 529 | ] 530 | }, 531 | { 532 | "data": { 533 | "text/plain": [ 534 | "tensor(0.6503)" 535 | ] 536 | }, 537 | "execution_count": null, 538 | "metadata": {}, 539 | "output_type": "execute_result" 540 | } 541 | ], 542 | "source": [ 543 | "compute_correlation(model, organism=\"human\", subset=\"test\", max_steps=-1)" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "metadata": { 550 | "colab": { 551 | "background_save": true 552 | }, 553 | "id": "oRMRZUMHo5XX", 554 | "outputId": "68b054a6-e6b7-4949-9866-ad4b9615fc9e" 555 | }, 556 | "outputs": [ 557 | { 558 | "name": "stderr", 559 | "output_type": "stream", 560 | "text": [ 561 | "/usr/local/lib/python3.7/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", 562 | " not been set for this class (MeanPearsonCorrCoefPerChannel). 
The property determines if `update` by\n", 563 | " default needs access to the full metric state. If this is not the case, significant speedups can be\n", 564 | " achieved and we recommend setting this to `False`.\n", 565 | " We provide an checking function\n", 566 | " `from torchmetrics.utilities import check_forward_no_full_state`\n", 567 | " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", 568 | " default for now) or if `full_state_update=False` can be used safely.\n", 569 | " \n", 570 | " warnings.warn(*args, **kwargs)\n", 571 | "100%|██████████| 34021/34021 [4:58:22<00:00, 1.90it/s]\n" 572 | ] 573 | }, 574 | { 575 | "data": { 576 | "text/plain": [ 577 | "tensor(0.7415)" 578 | ] 579 | }, 580 | "execution_count": null, 581 | "metadata": {}, 582 | "output_type": "execute_result" 583 | } 584 | ], 585 | "source": [ 586 | "compute_correlation(model, organism=\"human\", subset=\"train\", max_steps=-1)" 587 | ] 588 | } 589 | ], 590 | "metadata": { 591 | "accelerator": "GPU", 592 | "colab": { 593 | "background_execution": "on", 594 | "collapsed_sections": [], 595 | "name": "evaluate_enformer_pytorch_correlation.ipynb", 596 | "provenance": [], 597 | "authorship_tag": "ABX9TyP5rds3GvFsVbB6KySflndx", 598 | "include_colab_link": true 599 | }, 600 | "kernelspec": { 601 | "display_name": "Python 3", 602 | "name": "python3" 603 | }, 604 | "language_info": { 605 | "name": "python" 606 | }, 607 | "widgets": { 608 | "application/vnd.jupyter.widget-state+json": { 609 | "a76d636b937c4b5691d36cb6b7ec7b17": { 610 | "model_module": "@jupyter-widgets/controls", 611 | "model_name": "HBoxModel", 612 | "model_module_version": "1.5.0", 613 | "state": { 614 | "_dom_classes": [], 615 | "_model_module": "@jupyter-widgets/controls", 616 | "_model_module_version": "1.5.0", 617 | "_model_name": "HBoxModel", 618 | "_view_count": null, 619 | "_view_module": "@jupyter-widgets/controls", 620 | "_view_module_version": "1.5.0", 621 | "_view_name": "HBoxView", 622 | "box_style": "", 623 | "children": [ 624 | "IPY_MODEL_e06f0950617847e7a894afc8cc8ec69f", 625 | "IPY_MODEL_fd6815290f9e45f69d9ac29826de71ef", 626 | "IPY_MODEL_109781990b25433daa411cafa44c662a" 627 | ], 628 | "layout": "IPY_MODEL_69789023fe8544ff8a0d7540de793041" 629 | } 630 | }, 631 | "e06f0950617847e7a894afc8cc8ec69f": { 632 | "model_module": "@jupyter-widgets/controls", 633 | "model_name": "HTMLModel", 634 | "model_module_version": "1.5.0", 635 | "state": { 636 | "_dom_classes": [], 637 | "_model_module": "@jupyter-widgets/controls", 638 | "_model_module_version": "1.5.0", 639 | "_model_name": "HTMLModel", 640 | "_view_count": null, 641 | "_view_module": "@jupyter-widgets/controls", 642 | "_view_module_version": "1.5.0", 643 | "_view_name": "HTMLView", 644 | "description": "", 645 | "description_tooltip": null, 646 | "layout": "IPY_MODEL_ef2d9fda87da4dbc825f4776a982856d", 647 | "placeholder": "​", 648 | "style": "IPY_MODEL_20f0fc109b5249a3bd0e5738f6b8c06f", 649 | "value": "Downloading: 100%" 650 | } 651 | }, 652 | "fd6815290f9e45f69d9ac29826de71ef": { 653 | "model_module": "@jupyter-widgets/controls", 654 | "model_name": "FloatProgressModel", 655 | "model_module_version": "1.5.0", 656 | "state": { 657 | "_dom_classes": [], 658 | "_model_module": "@jupyter-widgets/controls", 659 | "_model_module_version": "1.5.0", 660 | "_model_name": "FloatProgressModel", 661 | "_view_count": null, 662 | "_view_module": "@jupyter-widgets/controls", 663 | "_view_module_version": "1.5.0", 664 | "_view_name": "ProgressView", 665 | 
"bar_style": "success", 666 | "description": "", 667 | "description_tooltip": null, 668 | "layout": "IPY_MODEL_7ab66e50e29a48b38a1d5c1820422a6a", 669 | "max": 464, 670 | "min": 0, 671 | "orientation": "horizontal", 672 | "style": "IPY_MODEL_b0b4b1ebc3ac4aaf91b0112969fc64f9", 673 | "value": 464 674 | } 675 | }, 676 | "109781990b25433daa411cafa44c662a": { 677 | "model_module": "@jupyter-widgets/controls", 678 | "model_name": "HTMLModel", 679 | "model_module_version": "1.5.0", 680 | "state": { 681 | "_dom_classes": [], 682 | "_model_module": "@jupyter-widgets/controls", 683 | "_model_module_version": "1.5.0", 684 | "_model_name": "HTMLModel", 685 | "_view_count": null, 686 | "_view_module": "@jupyter-widgets/controls", 687 | "_view_module_version": "1.5.0", 688 | "_view_name": "HTMLView", 689 | "description": "", 690 | "description_tooltip": null, 691 | "layout": "IPY_MODEL_38e1b0815aaf4c01b1988f73fc3c38d2", 692 | "placeholder": "​", 693 | "style": "IPY_MODEL_c177e2c35da948918a29fdcea6840d89", 694 | "value": " 464/464 [00:00<00:00, 12.3kB/s]" 695 | } 696 | }, 697 | "69789023fe8544ff8a0d7540de793041": { 698 | "model_module": "@jupyter-widgets/base", 699 | "model_name": "LayoutModel", 700 | "model_module_version": "1.2.0", 701 | "state": { 702 | "_model_module": "@jupyter-widgets/base", 703 | "_model_module_version": "1.2.0", 704 | "_model_name": "LayoutModel", 705 | "_view_count": null, 706 | "_view_module": "@jupyter-widgets/base", 707 | "_view_module_version": "1.2.0", 708 | "_view_name": "LayoutView", 709 | "align_content": null, 710 | "align_items": null, 711 | "align_self": null, 712 | "border": null, 713 | "bottom": null, 714 | "display": null, 715 | "flex": null, 716 | "flex_flow": null, 717 | "grid_area": null, 718 | "grid_auto_columns": null, 719 | "grid_auto_flow": null, 720 | "grid_auto_rows": null, 721 | "grid_column": null, 722 | "grid_gap": null, 723 | "grid_row": null, 724 | "grid_template_areas": null, 725 | "grid_template_columns": null, 726 | "grid_template_rows": null, 727 | "height": null, 728 | "justify_content": null, 729 | "justify_items": null, 730 | "left": null, 731 | "margin": null, 732 | "max_height": null, 733 | "max_width": null, 734 | "min_height": null, 735 | "min_width": null, 736 | "object_fit": null, 737 | "object_position": null, 738 | "order": null, 739 | "overflow": null, 740 | "overflow_x": null, 741 | "overflow_y": null, 742 | "padding": null, 743 | "right": null, 744 | "top": null, 745 | "visibility": null, 746 | "width": null 747 | } 748 | }, 749 | "ef2d9fda87da4dbc825f4776a982856d": { 750 | "model_module": "@jupyter-widgets/base", 751 | "model_name": "LayoutModel", 752 | "model_module_version": "1.2.0", 753 | "state": { 754 | "_model_module": "@jupyter-widgets/base", 755 | "_model_module_version": "1.2.0", 756 | "_model_name": "LayoutModel", 757 | "_view_count": null, 758 | "_view_module": "@jupyter-widgets/base", 759 | "_view_module_version": "1.2.0", 760 | "_view_name": "LayoutView", 761 | "align_content": null, 762 | "align_items": null, 763 | "align_self": null, 764 | "border": null, 765 | "bottom": null, 766 | "display": null, 767 | "flex": null, 768 | "flex_flow": null, 769 | "grid_area": null, 770 | "grid_auto_columns": null, 771 | "grid_auto_flow": null, 772 | "grid_auto_rows": null, 773 | "grid_column": null, 774 | "grid_gap": null, 775 | "grid_row": null, 776 | "grid_template_areas": null, 777 | "grid_template_columns": null, 778 | "grid_template_rows": null, 779 | "height": null, 780 | "justify_content": null, 781 | "justify_items": null, 
782 | "left": null, 783 | "margin": null, 784 | "max_height": null, 785 | "max_width": null, 786 | "min_height": null, 787 | "min_width": null, 788 | "object_fit": null, 789 | "object_position": null, 790 | "order": null, 791 | "overflow": null, 792 | "overflow_x": null, 793 | "overflow_y": null, 794 | "padding": null, 795 | "right": null, 796 | "top": null, 797 | "visibility": null, 798 | "width": null 799 | } 800 | }, 801 | "20f0fc109b5249a3bd0e5738f6b8c06f": { 802 | "model_module": "@jupyter-widgets/controls", 803 | "model_name": "DescriptionStyleModel", 804 | "model_module_version": "1.5.0", 805 | "state": { 806 | "_model_module": "@jupyter-widgets/controls", 807 | "_model_module_version": "1.5.0", 808 | "_model_name": "DescriptionStyleModel", 809 | "_view_count": null, 810 | "_view_module": "@jupyter-widgets/base", 811 | "_view_module_version": "1.2.0", 812 | "_view_name": "StyleView", 813 | "description_width": "" 814 | } 815 | }, 816 | "7ab66e50e29a48b38a1d5c1820422a6a": { 817 | "model_module": "@jupyter-widgets/base", 818 | "model_name": "LayoutModel", 819 | "model_module_version": "1.2.0", 820 | "state": { 821 | "_model_module": "@jupyter-widgets/base", 822 | "_model_module_version": "1.2.0", 823 | "_model_name": "LayoutModel", 824 | "_view_count": null, 825 | "_view_module": "@jupyter-widgets/base", 826 | "_view_module_version": "1.2.0", 827 | "_view_name": "LayoutView", 828 | "align_content": null, 829 | "align_items": null, 830 | "align_self": null, 831 | "border": null, 832 | "bottom": null, 833 | "display": null, 834 | "flex": null, 835 | "flex_flow": null, 836 | "grid_area": null, 837 | "grid_auto_columns": null, 838 | "grid_auto_flow": null, 839 | "grid_auto_rows": null, 840 | "grid_column": null, 841 | "grid_gap": null, 842 | "grid_row": null, 843 | "grid_template_areas": null, 844 | "grid_template_columns": null, 845 | "grid_template_rows": null, 846 | "height": null, 847 | "justify_content": null, 848 | "justify_items": null, 849 | "left": null, 850 | "margin": null, 851 | "max_height": null, 852 | "max_width": null, 853 | "min_height": null, 854 | "min_width": null, 855 | "object_fit": null, 856 | "object_position": null, 857 | "order": null, 858 | "overflow": null, 859 | "overflow_x": null, 860 | "overflow_y": null, 861 | "padding": null, 862 | "right": null, 863 | "top": null, 864 | "visibility": null, 865 | "width": null 866 | } 867 | }, 868 | "b0b4b1ebc3ac4aaf91b0112969fc64f9": { 869 | "model_module": "@jupyter-widgets/controls", 870 | "model_name": "ProgressStyleModel", 871 | "model_module_version": "1.5.0", 872 | "state": { 873 | "_model_module": "@jupyter-widgets/controls", 874 | "_model_module_version": "1.5.0", 875 | "_model_name": "ProgressStyleModel", 876 | "_view_count": null, 877 | "_view_module": "@jupyter-widgets/base", 878 | "_view_module_version": "1.2.0", 879 | "_view_name": "StyleView", 880 | "bar_color": null, 881 | "description_width": "" 882 | } 883 | }, 884 | "38e1b0815aaf4c01b1988f73fc3c38d2": { 885 | "model_module": "@jupyter-widgets/base", 886 | "model_name": "LayoutModel", 887 | "model_module_version": "1.2.0", 888 | "state": { 889 | "_model_module": "@jupyter-widgets/base", 890 | "_model_module_version": "1.2.0", 891 | "_model_name": "LayoutModel", 892 | "_view_count": null, 893 | "_view_module": "@jupyter-widgets/base", 894 | "_view_module_version": "1.2.0", 895 | "_view_name": "LayoutView", 896 | "align_content": null, 897 | "align_items": null, 898 | "align_self": null, 899 | "border": null, 900 | "bottom": null, 901 | "display": null, 902 
| "flex": null, 903 | "flex_flow": null, 904 | "grid_area": null, 905 | "grid_auto_columns": null, 906 | "grid_auto_flow": null, 907 | "grid_auto_rows": null, 908 | "grid_column": null, 909 | "grid_gap": null, 910 | "grid_row": null, 911 | "grid_template_areas": null, 912 | "grid_template_columns": null, 913 | "grid_template_rows": null, 914 | "height": null, 915 | "justify_content": null, 916 | "justify_items": null, 917 | "left": null, 918 | "margin": null, 919 | "max_height": null, 920 | "max_width": null, 921 | "min_height": null, 922 | "min_width": null, 923 | "object_fit": null, 924 | "object_position": null, 925 | "order": null, 926 | "overflow": null, 927 | "overflow_x": null, 928 | "overflow_y": null, 929 | "padding": null, 930 | "right": null, 931 | "top": null, 932 | "visibility": null, 933 | "width": null 934 | } 935 | }, 936 | "c177e2c35da948918a29fdcea6840d89": { 937 | "model_module": "@jupyter-widgets/controls", 938 | "model_name": "DescriptionStyleModel", 939 | "model_module_version": "1.5.0", 940 | "state": { 941 | "_model_module": "@jupyter-widgets/controls", 942 | "_model_module_version": "1.5.0", 943 | "_model_name": "DescriptionStyleModel", 944 | "_view_count": null, 945 | "_view_module": "@jupyter-widgets/base", 946 | "_view_module_version": "1.2.0", 947 | "_view_name": "StyleView", 948 | "description_width": "" 949 | } 950 | }, 951 | "beeeb50adabb48978b1c502d799713be": { 952 | "model_module": "@jupyter-widgets/controls", 953 | "model_name": "HBoxModel", 954 | "model_module_version": "1.5.0", 955 | "state": { 956 | "_dom_classes": [], 957 | "_model_module": "@jupyter-widgets/controls", 958 | "_model_module_version": "1.5.0", 959 | "_model_name": "HBoxModel", 960 | "_view_count": null, 961 | "_view_module": "@jupyter-widgets/controls", 962 | "_view_module_version": "1.5.0", 963 | "_view_name": "HBoxView", 964 | "box_style": "", 965 | "children": [ 966 | "IPY_MODEL_a34e3298896c454abb1b0dda7c225c2b", 967 | "IPY_MODEL_579ec7740cbc455a9c6d3e1bc4442612", 968 | "IPY_MODEL_e486e3955a574b818227964293fa76e1" 969 | ], 970 | "layout": "IPY_MODEL_44a8b215df314ba9beb66b44d2e13d29" 971 | } 972 | }, 973 | "a34e3298896c454abb1b0dda7c225c2b": { 974 | "model_module": "@jupyter-widgets/controls", 975 | "model_name": "HTMLModel", 976 | "model_module_version": "1.5.0", 977 | "state": { 978 | "_dom_classes": [], 979 | "_model_module": "@jupyter-widgets/controls", 980 | "_model_module_version": "1.5.0", 981 | "_model_name": "HTMLModel", 982 | "_view_count": null, 983 | "_view_module": "@jupyter-widgets/controls", 984 | "_view_module_version": "1.5.0", 985 | "_view_name": "HTMLView", 986 | "description": "", 987 | "description_tooltip": null, 988 | "layout": "IPY_MODEL_36a42d5cd33541b5aa24db50b03cef9a", 989 | "placeholder": "​", 990 | "style": "IPY_MODEL_e47a62c2ab1a4fcea674274f3435f031", 991 | "value": "Downloading: 100%" 992 | } 993 | }, 994 | "579ec7740cbc455a9c6d3e1bc4442612": { 995 | "model_module": "@jupyter-widgets/controls", 996 | "model_name": "FloatProgressModel", 997 | "model_module_version": "1.5.0", 998 | "state": { 999 | "_dom_classes": [], 1000 | "_model_module": "@jupyter-widgets/controls", 1001 | "_model_module_version": "1.5.0", 1002 | "_model_name": "FloatProgressModel", 1003 | "_view_count": null, 1004 | "_view_module": "@jupyter-widgets/controls", 1005 | "_view_module_version": "1.5.0", 1006 | "_view_name": "ProgressView", 1007 | "bar_style": "success", 1008 | "description": "", 1009 | "description_tooltip": null, 1010 | "layout": 
"IPY_MODEL_02df913b81f3473cbbe0189880e20586", 1011 | "max": 1005149571, 1012 | "min": 0, 1013 | "orientation": "horizontal", 1014 | "style": "IPY_MODEL_9d3909432e2d42a5849427f6d1ad1371", 1015 | "value": 1005149571 1016 | } 1017 | }, 1018 | "e486e3955a574b818227964293fa76e1": { 1019 | "model_module": "@jupyter-widgets/controls", 1020 | "model_name": "HTMLModel", 1021 | "model_module_version": "1.5.0", 1022 | "state": { 1023 | "_dom_classes": [], 1024 | "_model_module": "@jupyter-widgets/controls", 1025 | "_model_module_version": "1.5.0", 1026 | "_model_name": "HTMLModel", 1027 | "_view_count": null, 1028 | "_view_module": "@jupyter-widgets/controls", 1029 | "_view_module_version": "1.5.0", 1030 | "_view_name": "HTMLView", 1031 | "description": "", 1032 | "description_tooltip": null, 1033 | "layout": "IPY_MODEL_ec0bf3f9df7c40b8b64aef5ac62f0728", 1034 | "placeholder": "​", 1035 | "style": "IPY_MODEL_7ba9a6bb917746d2bdb7143843480038", 1036 | "value": " 959M/959M [00:36<00:00, 58.7MB/s]" 1037 | } 1038 | }, 1039 | "44a8b215df314ba9beb66b44d2e13d29": { 1040 | "model_module": "@jupyter-widgets/base", 1041 | "model_name": "LayoutModel", 1042 | "model_module_version": "1.2.0", 1043 | "state": { 1044 | "_model_module": "@jupyter-widgets/base", 1045 | "_model_module_version": "1.2.0", 1046 | "_model_name": "LayoutModel", 1047 | "_view_count": null, 1048 | "_view_module": "@jupyter-widgets/base", 1049 | "_view_module_version": "1.2.0", 1050 | "_view_name": "LayoutView", 1051 | "align_content": null, 1052 | "align_items": null, 1053 | "align_self": null, 1054 | "border": null, 1055 | "bottom": null, 1056 | "display": null, 1057 | "flex": null, 1058 | "flex_flow": null, 1059 | "grid_area": null, 1060 | "grid_auto_columns": null, 1061 | "grid_auto_flow": null, 1062 | "grid_auto_rows": null, 1063 | "grid_column": null, 1064 | "grid_gap": null, 1065 | "grid_row": null, 1066 | "grid_template_areas": null, 1067 | "grid_template_columns": null, 1068 | "grid_template_rows": null, 1069 | "height": null, 1070 | "justify_content": null, 1071 | "justify_items": null, 1072 | "left": null, 1073 | "margin": null, 1074 | "max_height": null, 1075 | "max_width": null, 1076 | "min_height": null, 1077 | "min_width": null, 1078 | "object_fit": null, 1079 | "object_position": null, 1080 | "order": null, 1081 | "overflow": null, 1082 | "overflow_x": null, 1083 | "overflow_y": null, 1084 | "padding": null, 1085 | "right": null, 1086 | "top": null, 1087 | "visibility": null, 1088 | "width": null 1089 | } 1090 | }, 1091 | "36a42d5cd33541b5aa24db50b03cef9a": { 1092 | "model_module": "@jupyter-widgets/base", 1093 | "model_name": "LayoutModel", 1094 | "model_module_version": "1.2.0", 1095 | "state": { 1096 | "_model_module": "@jupyter-widgets/base", 1097 | "_model_module_version": "1.2.0", 1098 | "_model_name": "LayoutModel", 1099 | "_view_count": null, 1100 | "_view_module": "@jupyter-widgets/base", 1101 | "_view_module_version": "1.2.0", 1102 | "_view_name": "LayoutView", 1103 | "align_content": null, 1104 | "align_items": null, 1105 | "align_self": null, 1106 | "border": null, 1107 | "bottom": null, 1108 | "display": null, 1109 | "flex": null, 1110 | "flex_flow": null, 1111 | "grid_area": null, 1112 | "grid_auto_columns": null, 1113 | "grid_auto_flow": null, 1114 | "grid_auto_rows": null, 1115 | "grid_column": null, 1116 | "grid_gap": null, 1117 | "grid_row": null, 1118 | "grid_template_areas": null, 1119 | "grid_template_columns": null, 1120 | "grid_template_rows": null, 1121 | "height": null, 1122 | "justify_content": null, 
1123 | "justify_items": null, 1124 | "left": null, 1125 | "margin": null, 1126 | "max_height": null, 1127 | "max_width": null, 1128 | "min_height": null, 1129 | "min_width": null, 1130 | "object_fit": null, 1131 | "object_position": null, 1132 | "order": null, 1133 | "overflow": null, 1134 | "overflow_x": null, 1135 | "overflow_y": null, 1136 | "padding": null, 1137 | "right": null, 1138 | "top": null, 1139 | "visibility": null, 1140 | "width": null 1141 | } 1142 | }, 1143 | "e47a62c2ab1a4fcea674274f3435f031": { 1144 | "model_module": "@jupyter-widgets/controls", 1145 | "model_name": "DescriptionStyleModel", 1146 | "model_module_version": "1.5.0", 1147 | "state": { 1148 | "_model_module": "@jupyter-widgets/controls", 1149 | "_model_module_version": "1.5.0", 1150 | "_model_name": "DescriptionStyleModel", 1151 | "_view_count": null, 1152 | "_view_module": "@jupyter-widgets/base", 1153 | "_view_module_version": "1.2.0", 1154 | "_view_name": "StyleView", 1155 | "description_width": "" 1156 | } 1157 | }, 1158 | "02df913b81f3473cbbe0189880e20586": { 1159 | "model_module": "@jupyter-widgets/base", 1160 | "model_name": "LayoutModel", 1161 | "model_module_version": "1.2.0", 1162 | "state": { 1163 | "_model_module": "@jupyter-widgets/base", 1164 | "_model_module_version": "1.2.0", 1165 | "_model_name": "LayoutModel", 1166 | "_view_count": null, 1167 | "_view_module": "@jupyter-widgets/base", 1168 | "_view_module_version": "1.2.0", 1169 | "_view_name": "LayoutView", 1170 | "align_content": null, 1171 | "align_items": null, 1172 | "align_self": null, 1173 | "border": null, 1174 | "bottom": null, 1175 | "display": null, 1176 | "flex": null, 1177 | "flex_flow": null, 1178 | "grid_area": null, 1179 | "grid_auto_columns": null, 1180 | "grid_auto_flow": null, 1181 | "grid_auto_rows": null, 1182 | "grid_column": null, 1183 | "grid_gap": null, 1184 | "grid_row": null, 1185 | "grid_template_areas": null, 1186 | "grid_template_columns": null, 1187 | "grid_template_rows": null, 1188 | "height": null, 1189 | "justify_content": null, 1190 | "justify_items": null, 1191 | "left": null, 1192 | "margin": null, 1193 | "max_height": null, 1194 | "max_width": null, 1195 | "min_height": null, 1196 | "min_width": null, 1197 | "object_fit": null, 1198 | "object_position": null, 1199 | "order": null, 1200 | "overflow": null, 1201 | "overflow_x": null, 1202 | "overflow_y": null, 1203 | "padding": null, 1204 | "right": null, 1205 | "top": null, 1206 | "visibility": null, 1207 | "width": null 1208 | } 1209 | }, 1210 | "9d3909432e2d42a5849427f6d1ad1371": { 1211 | "model_module": "@jupyter-widgets/controls", 1212 | "model_name": "ProgressStyleModel", 1213 | "model_module_version": "1.5.0", 1214 | "state": { 1215 | "_model_module": "@jupyter-widgets/controls", 1216 | "_model_module_version": "1.5.0", 1217 | "_model_name": "ProgressStyleModel", 1218 | "_view_count": null, 1219 | "_view_module": "@jupyter-widgets/base", 1220 | "_view_module_version": "1.2.0", 1221 | "_view_name": "StyleView", 1222 | "bar_color": null, 1223 | "description_width": "" 1224 | } 1225 | }, 1226 | "ec0bf3f9df7c40b8b64aef5ac62f0728": { 1227 | "model_module": "@jupyter-widgets/base", 1228 | "model_name": "LayoutModel", 1229 | "model_module_version": "1.2.0", 1230 | "state": { 1231 | "_model_module": "@jupyter-widgets/base", 1232 | "_model_module_version": "1.2.0", 1233 | "_model_name": "LayoutModel", 1234 | "_view_count": null, 1235 | "_view_module": "@jupyter-widgets/base", 1236 | "_view_module_version": "1.2.0", 1237 | "_view_name": "LayoutView", 1238 | 
"align_content": null, 1239 | "align_items": null, 1240 | "align_self": null, 1241 | "border": null, 1242 | "bottom": null, 1243 | "display": null, 1244 | "flex": null, 1245 | "flex_flow": null, 1246 | "grid_area": null, 1247 | "grid_auto_columns": null, 1248 | "grid_auto_flow": null, 1249 | "grid_auto_rows": null, 1250 | "grid_column": null, 1251 | "grid_gap": null, 1252 | "grid_row": null, 1253 | "grid_template_areas": null, 1254 | "grid_template_columns": null, 1255 | "grid_template_rows": null, 1256 | "height": null, 1257 | "justify_content": null, 1258 | "justify_items": null, 1259 | "left": null, 1260 | "margin": null, 1261 | "max_height": null, 1262 | "max_width": null, 1263 | "min_height": null, 1264 | "min_width": null, 1265 | "object_fit": null, 1266 | "object_position": null, 1267 | "order": null, 1268 | "overflow": null, 1269 | "overflow_x": null, 1270 | "overflow_y": null, 1271 | "padding": null, 1272 | "right": null, 1273 | "top": null, 1274 | "visibility": null, 1275 | "width": null 1276 | } 1277 | }, 1278 | "7ba9a6bb917746d2bdb7143843480038": { 1279 | "model_module": "@jupyter-widgets/controls", 1280 | "model_name": "DescriptionStyleModel", 1281 | "model_module_version": "1.5.0", 1282 | "state": { 1283 | "_model_module": "@jupyter-widgets/controls", 1284 | "_model_module_version": "1.5.0", 1285 | "_model_name": "DescriptionStyleModel", 1286 | "_view_count": null, 1287 | "_view_module": "@jupyter-widgets/base", 1288 | "_view_module_version": "1.2.0", 1289 | "_view_name": "StyleView", 1290 | "description_width": "" 1291 | } 1292 | } 1293 | } 1294 | } 1295 | }, 1296 | "nbformat": 4, 1297 | "nbformat_minor": 0 1298 | } -------------------------------------------------------------------------------- /scripts/tf_to_torch.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | 3 | def copy_bn(mod, vars, path): 4 | bn_offset = vars[f'{path}offset:0'] 5 | bn_scale = vars[f'{path}scale:0'] 6 | 7 | ema_path = '/'.join(path.split('/')[:-1]) + '/' 8 | bn_running_mean = vars[f'{ema_path}moving_mean/average:0'] 9 | bn_running_var = vars[f'{ema_path}moving_variance/average:0'] 10 | 11 | mod.weight.data.copy_(bn_scale) 12 | mod.bias.data.copy_(bn_offset) 13 | 14 | mod.running_var.data.copy_(rearrange(bn_running_var, '1 1 d -> d')) 15 | mod.running_mean.data.copy_(rearrange(bn_running_mean, '1 1 d -> d')) 16 | 17 | def copy_conv(mod, vars, path): 18 | bias = vars[f'{path}b:0'] 19 | weight = vars[f'{path}w:0'] 20 | mod.weight.data.copy_(rearrange(weight, 'k i o -> o i k')) 21 | mod.bias.data.copy_(bias) 22 | 23 | def copy_attn_pool(mod, vars, path): 24 | attn_pool_proj = vars[path] 25 | mod.to_attn_logits.weight.data.copy_(rearrange(attn_pool_proj, 'i o -> o i 1 1')) 26 | 27 | def copy_linear(mod, vars, path, has_bias = True): 28 | weight = vars[f'{path}w:0'] 29 | mod.weight.data.copy_(rearrange(weight, 'i o -> o i')) 30 | 31 | if not has_bias: 32 | return 33 | 34 | bias = vars[f'{path}b:0'] 35 | mod.bias.data.copy_(bias) 36 | 37 | def copy_ln(mod, vars, path): 38 | weight = vars[f'{path}scale:0'] 39 | bias = vars[f'{path}offset:0'] 40 | mod.weight.data.copy_(weight) 41 | mod.bias.data.copy_(bias) 42 | 43 | def get_tf_vars(tf_model): 44 | return {v.name: (torch.from_numpy(v.numpy()) if isinstance(v.numpy(), np.ndarray) else None) for v in tf_model.variables} 45 | 46 | def copy_tf_to_pytorch(tf_model, pytorch_model): 47 | tf_vars = get_tf_vars(tf_model) 48 | stem_conv = pytorch_model.stem[0] 49 | stem_point_bn = 
pytorch_model.stem[1].fn[0] 50 | stem_point_conv = pytorch_model.stem[1].fn[2] 51 | stem_attn_pool = pytorch_model.stem[2] 52 | 53 | copy_conv(stem_conv, tf_vars, 'enformer/trunk/stem/conv1_d/') 54 | copy_bn(stem_point_bn, tf_vars, 'enformer/trunk/stem/pointwise_conv_block/cross_replica_batch_norm/') 55 | copy_conv(stem_point_conv, tf_vars, 'enformer/trunk/stem/pointwise_conv_block/conv1_d/') 56 | copy_attn_pool(stem_attn_pool, tf_vars, 'enformer/trunk/stem/softmax_pooling/linear/w:0') 57 | 58 | for ind, tower_block in enumerate(pytorch_model.conv_tower): 59 | tower_bn = tower_block[0][0] 60 | tower_conv = tower_block[0][2] 61 | tower_point_bn = tower_block[1].fn[0] 62 | tower_point_conv = tower_block[1].fn[2] 63 | tower_attn_pool = tower_block[2] 64 | 65 | conv_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/conv_block/conv1_d/' 66 | bn_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/conv_block/cross_replica_batch_norm/' 67 | point_conv_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/pointwise_conv_block/conv1_d/' 68 | point_bn_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/pointwise_conv_block/cross_replica_batch_norm/' 69 | attn_pool_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/softmax_pooling/linear/w:0' 70 | 71 | copy_bn(tower_bn, tf_vars, bn_path) 72 | copy_conv(tower_conv, tf_vars, conv_path) 73 | copy_bn(tower_point_bn, tf_vars, point_bn_path) 74 | copy_conv(tower_point_conv, tf_vars, point_conv_path) 75 | copy_attn_pool(tower_attn_pool, tf_vars, attn_pool_path) 76 | 77 | for ind, transformer_block in enumerate(pytorch_model.transformer): 78 | attn_ln_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/layer_norm/' 79 | attn_q_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/q_layer/' 80 | attn_k_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/k_layer/' 81 | attn_r_k_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/r_k_layer/' 82 | attn_v_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/v_layer/' 83 | attn_out_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/embedding_layer/' 84 | 85 | attn_content_bias_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/r_w_bias:0' 86 | attn_rel_bias_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/r_r_bias:0' 87 | 88 | ff_ln_path = f'enformer/trunk/transformer/transformer_block_{ind}/mlp/layer_norm/' 89 | 90 | # https://github.com/deepmind/deepmind-research/blob/master/enformer/enformer.py#L119 91 | # needs to be edited to snt.Linear(channels * 2, name = 'project_in') and snt.Linear(channels, name = 'project_out') or variables are not accessible 92 | ff_linear1_path = f'enformer/trunk/transformer/transformer_block_{ind}/mlp/project_in/' 93 | ff_linear2_path = f'enformer/trunk/transformer/transformer_block_{ind}/mlp/project_out/' 94 | 95 | attn = transformer_block[0] 96 | attn_ln = attn.fn[0] 97 | mha = attn.fn[1] 98 | 99 | copy_linear(mha.to_q, tf_vars, attn_q_path, has_bias = False) 100 | copy_linear(mha.to_k, tf_vars, attn_k_path, has_bias = False) 101 | copy_linear(mha.to_rel_k, tf_vars, attn_r_k_path, has_bias = False) 102 | copy_linear(mha.to_v, tf_vars, attn_v_path, has_bias = False) 103 | copy_linear(mha.to_out, tf_vars, attn_out_path) 104 | 105 | mha.rel_content_bias.data.copy_(tf_vars[attn_content_bias_path]) 106 | 
mha.rel_pos_bias.data.copy_(tf_vars[attn_rel_bias_path]) 107 | 108 | ff = transformer_block[-1] 109 | ff_ln = ff.fn[0] 110 | ff_linear1 = ff.fn[1] 111 | ff_linear2 = ff.fn[4] 112 | 113 | copy_ln(attn_ln, tf_vars, attn_ln_path) 114 | 115 | copy_ln(ff_ln, tf_vars, ff_ln_path) 116 | copy_linear(ff_linear1, tf_vars, ff_linear1_path) 117 | copy_linear(ff_linear2, tf_vars, ff_linear2_path) 118 | 119 | final_bn = pytorch_model.final_pointwise[1][0] 120 | final_conv = pytorch_model.final_pointwise[1][2] 121 | 122 | copy_bn(final_bn, tf_vars, 'enformer/trunk/final_pointwise/conv_block/cross_replica_batch_norm/') 123 | copy_conv(final_conv, tf_vars, 'enformer/trunk/final_pointwise/conv_block/conv1_d/') 124 | 125 | human_linear = pytorch_model._heads['human'][0] 126 | mouse_linear = pytorch_model._heads['mouse'][0] 127 | 128 | copy_linear(human_linear, tf_vars, 'enformer/heads/head_human/linear/') 129 | copy_linear(mouse_linear, tf_vars, 'enformer/heads/head_mouse/linear/') 130 | 131 | print('success') -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'enformer-pytorch', 5 | packages = find_packages(exclude=[]), 6 | include_package_data = True, 7 | version = '0.8.10', 8 | license='MIT', 9 | description = 'Enformer - Pytorch', 10 | author = 'Phil Wang', 11 | author_email = 'lucidrains@gmail.com', 12 | long_description_content_type = 'text/markdown', 13 | url = 'https://github.com/lucidrains/enformer-pytorch', 14 | keywords = [ 15 | 'artificial intelligence', 16 | 'transformer', 17 | 'gene-expression' 18 | ], 19 | install_requires=[ 20 | 'discrete-key-value-bottleneck-pytorch>=0.0.8', 21 | 'einops>=0.3', 22 | 'numpy', 23 | 'torch>=1.6', 24 | 'torchmetrics', 25 | 'polars', 26 | 'pyfaidx', 27 | 'pyyaml', 28 | 'transformers[torch]', 29 | ], 30 | classifiers=[ 31 | 'Development Status :: 4 - Beta', 32 | 'Intended Audience :: Developers', 33 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 34 | 'License :: OSI Approved :: MIT License', 35 | 'Programming Language :: Python :: 3.6', 36 | ], 37 | ) 38 | -------------------------------------------------------------------------------- /test_pretrained.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from enformer_pytorch import from_pretrained 3 | 4 | enformer = from_pretrained('EleutherAI/enformer-official-rough', use_tf_gamma = False).cuda() 5 | enformer.eval() 6 | 7 | data = torch.load('./data/test-sample.pt') 8 | seq, target = data['sequence'].cuda(), data['target'].cuda() 9 | 10 | with torch.no_grad(): 11 | corr_coef = enformer( 12 | seq, 13 | target = target, 14 | return_corr_coef = True, 15 | head = 'human' 16 | ) 17 | 18 | print(corr_coef) 19 | assert corr_coef > 0.1 20 | --------------------------------------------------------------------------------
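A minimal usage sketch tying the files above together: from_pretrained (defined at the bottom of modeling_enformer.py and exercised by test_pretrained.py) loads the ported weights, and Enformer.forward accepts a long tensor of base indices, an optional head name, and a return_only_embeddings flag. The random input, the assumed A/C/G/T/N integer encoding, and the CPU placement are illustrative assumptions rather than values taken from the repository; the sequence length follows the notebook's SEQUENCE_LENGTH.

import torch
from enformer_pytorch import from_pretrained

# Load the ported weights the same way test_pretrained.py does (kept on CPU here;
# in practice you would move it to GPU as that script does).
enformer = from_pretrained('EleutherAI/enformer-official-rough', use_tf_gamma = False)
enformer.eval()

# Random base indices as a stand-in for a real sequence (assumed A/C/G/T/N encoding).
# A long dtype routes the input through seq_indices_to_one_hot inside forward().
seq = torch.randint(0, 5, (1, 196_608))

with torch.no_grad():
    # Without a head argument, forward() returns a dict keyed by the output heads
    # ('human' and 'mouse' for the pretrained model).
    preds = enformer(seq)
    human_tracks = preds['human']          # (batch, target_length, num_human_tracks)

    # head selects a single head; return_only_embeddings skips the heads and
    # returns the trunk output instead.
    human_only = enformer(seq, head = 'human')
    embeddings = enformer(seq, return_only_embeddings = True)

When a target tensor is passed together with head, forward() instead returns the Poisson loss, or the Pearson correlation when return_corr_coef = True, which is the check test_pretrained.py performs against the bundled test sample.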