├── .gitattributes
├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── data
│   └── test-sample.pt
├── enformer.png
├── enformer_pytorch
│   ├── __init__.py
│   ├── config_enformer.py
│   ├── data.py
│   ├── finetune.py
│   ├── metrics.py
│   ├── modeling_enformer.py
│   └── precomputed
│       └── tf_gammas.pt
├── evaluate_enformer_pytorch_correlation.ipynb
├── scripts
│   └── tf_to_torch.py
├── setup.py
└── test_pretrained.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | * text=auto
2 | *.ipynb filter=nbstripout
3 | *.ipynb diff=ipynb
4 | *.ipynb linguist-language=Python
5 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include enformer_pytorch/precomputed/tf_gammas.pt
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Enformer - Pytorch
4 |
5 | Implementation of Enformer, DeepMind's attention network for predicting gene expression, in PyTorch. This repository also contains the means to fine-tune pretrained models for your downstream tasks. The original TensorFlow Sonnet code can be found [here](https://github.com/deepmind/deepmind-research/tree/master/enformer).
6 |
7 | Update: finetuned for predicting pseudobulk chromatin accessibility here
8 |
9 | ## Install
10 |
11 | ```bash
12 | $ pip install enformer-pytorch
13 | ```
14 |
15 | ## Usage
16 |
17 | ```python
18 | import torch
19 | from enformer_pytorch import Enformer
20 |
21 | model = Enformer.from_hparams(
22 | dim = 1536,
23 | depth = 11,
24 | heads = 8,
25 | output_heads = dict(human = 5313, mouse = 1643),
26 | target_length = 896,
27 | )
28 |
29 | seq = torch.randint(0, 5, (1, 196_608)) # for ACGTN, in that order (-1 for padding)
30 | output = model(seq)
31 |
32 | output['human'] # (1, 896, 5313)
33 | output['mouse'] # (1, 896, 1643)
34 | ```
35 |
36 | You can also directly pass in the sequence as one-hot encodings, which must be float values
37 |
38 | ```python
39 | import torch
40 | from enformer_pytorch import Enformer, seq_indices_to_one_hot
41 |
42 | model = Enformer.from_hparams(
43 | dim = 1536,
44 | depth = 11,
45 | heads = 8,
46 | output_heads = dict(human = 5313, mouse = 1643),
47 | target_length = 896,
48 | )
49 |
50 | seq = torch.randint(0, 5, (1, 196_608))
51 | one_hot = seq_indices_to_one_hot(seq)
52 |
53 | output = model(one_hot)
54 |
55 | output['human'] # (1, 896, 5313)
56 | output['mouse'] # (1, 896, 1643)
57 | ```
58 |
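The forward pass will also accept a list of raw sequence strings, which are one-hot encoded internally (see the `forward` of `modeling_enformer.py`). A minimal sketch, using a repeated dummy sequence:

```python
from enformer_pytorch import Enformer

model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 896,
)

seq = ['ACGT' * (196_608 // 4)] # dummy sequence string of length 196,608
output = model(seq)

output['human'] # (1, 896, 5313)
```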
59 | Finally, one can fetch the embeddings, for fine-tuning and otherwise, by setting the `return_embeddings` flag to `True` on the forward pass
60 |
61 | ```python
62 | import torch
63 | from enformer_pytorch import Enformer, seq_indices_to_one_hot
64 |
65 | model = Enformer.from_hparams(
66 | dim = 1536,
67 | depth = 11,
68 | heads = 8,
69 | output_heads = dict(human = 5313, mouse = 1643),
70 | target_length = 896,
71 | )
72 |
73 | seq = torch.randint(0, 5, (1, 196_608))
74 | one_hot = seq_indices_to_one_hot(seq)
75 |
76 | output, embeddings = model(one_hot, return_embeddings = True)
77 |
78 | embeddings # (1, 896, 3072)
79 | ```
80 |
81 | For training, you can pass the head and target in directly to get the Poisson loss
82 |
83 | ```python
84 | import torch
85 | from enformer_pytorch import Enformer, seq_indices_to_one_hot
86 |
87 | model = Enformer.from_hparams(
88 | dim = 1536,
89 | depth = 11,
90 | heads = 8,
91 | output_heads = dict(human = 5313, mouse = 1643),
92 | target_length = 200,
93 | ).cuda()
94 |
95 | seq = torch.randint(0, 5, (196_608 // 2,)).cuda()
96 | target = torch.randn(200, 5313).cuda()
97 |
98 | loss = model(
99 | seq,
100 | head = 'human',
101 | target = target
102 | )
103 |
104 | loss.backward()
105 |
106 | # after much training
107 |
108 | corr_coef = model(
109 | seq,
110 | head = 'human',
111 | target = target,
112 | return_corr_coef = True
113 | )
114 |
115 | corr_coef # Pearson R, used as a metric in the paper
116 | ```
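For aggregating the correlation over many batches, the repository also ships a `torchmetrics`-based implementation in `enformer_pytorch.metrics`. A minimal sketch, with random tensors standing in for real predictions and targets:

```python
import torch
from enformer_pytorch.metrics import MeanPearsonCorrCoefPerChannel

metric = MeanPearsonCorrCoefPerChannel(n_channels = 5313)

for preds, target in [(torch.rand(1, 896, 5313), torch.rand(1, 896, 5313))] * 4: # stand-in validation batches
    metric.update(preds, target)

corr = metric.compute() # (5313,) per-channel Pearson R
corr.mean() # aggregate across channels
```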
117 |
118 | ## Pretrained Model
119 |
120 | DeepMind has released the weights for their TensorFlow Sonnet Enformer model! I have ported it over to PyTorch and uploaded it to 🤗 Hugging Face (~1GB). There are still some rounding errors that seem to be accruing across the layers, resulting in an absolute error as high as `0.5`. However, the correlation coefficients look good, so I am releasing the 'rough'ly working version. Will keep working on figuring out where the numerical errors are happening (it may be the attention pooling module, as I noticed the attention logits are pretty high).
121 |
122 | Update: John St. John did some work and found that the `enformer-official-rough` model hits the reported marks in the paper - human Pearson R of `0.625` for validation, and `0.65` for test.
123 |
124 | Update: As of version 0.8.0, if you load the pretrained model with the `from_pretrained` function, it will automatically use precomputed gamma positions to address a difference between the TensorFlow and PyTorch implementations of `xlogy`. This should resolve the numerical discrepancy above. If you instead instantiate the `Enformer` with `.from_hparams` and intend to fine-tune from the ported weights, please make sure to set `use_tf_gamma = True`, as in the sketch below
125 |
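A minimal sketch of instantiating with the pretrained hyperparameters and the precomputed tensorflow gammas:

```python
from enformer_pytorch import Enformer

model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 896,
    use_tf_gamma = True # use precomputed gamma positions, matching tensorflow's xlogy
)
```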
126 | ```bash
127 | $ pip install 'enformer-pytorch>=0.5'
128 | ```
129 |
130 | Loading the model
131 |
132 | ```python
133 | from enformer_pytorch import from_pretrained
134 |
135 | enformer = from_pretrained('EleutherAI/enformer-official-rough')
136 | ```
137 |
138 | Quick sanity check on a single human validation point
139 |
140 | ```bash
141 | $ python test_pretrained.py
142 | # 0.5963 correlation coefficient on a validation sample
143 | ```
144 |
145 | This is all made possible thanks to HuggingFace's [custom model](https://huggingface.co/docs/transformers/master/en/custom_models) feature.
146 |
147 | You can also load the model with the `target_length` parameter overridden, if you are working with shorter sequence lengths
148 |
149 | ```python
150 | from enformer_pytorch import from_pretrained
151 |
152 | model = from_pretrained('EleutherAI/enformer-official-rough', target_length = 128, dropout_rate = 0.1)
153 |
154 | # do your fine-tuning
155 | ```
156 |
157 | To save memory during fine-tuning of a large Enformer model
158 |
159 | ```python
160 | from enformer_pytorch import from_pretrained
161 |
162 | enformer = from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True)
163 |
164 | # finetune enformer on a limited budget
165 | ```
166 |
167 | ## Fine-tuning
168 |
169 | This repository also allows for easy fine-tuning of Enformer.
170 |
171 | Fine-tuning on new tracks
172 |
173 | ```python
174 | import torch
175 | from enformer_pytorch import from_pretrained
176 | from enformer_pytorch.finetune import HeadAdapterWrapper
177 |
178 | enformer = from_pretrained('EleutherAI/enformer-official-rough')
179 |
180 | model = HeadAdapterWrapper(
181 | enformer = enformer,
182 | num_tracks = 128,
183 | post_transformer_embed = False # by default, embeddings are taken from after the final pointwise block w/ conv -> gelu - but if you'd like the embeddings right after the transformer block with a learned layernorm, set this to True
184 | ).cuda()
185 |
186 | seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()
187 | target = torch.randn(1, 200, 128).cuda() # 128 tracks
188 |
189 | loss = model(seq, target = target)
190 | loss.backward()
191 | ```
192 |
193 | Fine-tuning on contextual data (cell type, transcription factor, etc.)
194 |
195 | ```python
196 | import torch
197 | from enformer_pytorch import from_pretrained
198 | from enformer_pytorch.finetune import ContextAdapterWrapper
199 |
200 | enformer = from_pretrained('EleutherAI/enformer-official-rough')
201 |
202 | model = ContextAdapterWrapper(
203 | enformer = enformer,
204 | context_dim = 1024
205 | ).cuda()
206 |
207 | seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()
208 |
209 | target = torch.randn(1, 200, 4).cuda() # 4 tracks
210 | context = torch.randn(4, 1024).cuda() # 4 contexts for the different 'tracks'
211 |
212 | loss = model(
213 | seq,
214 | context = context,
215 | target = target
216 | )
217 |
218 | loss.backward()
219 | ```
220 |
221 | Finally, there is also a way to use attention aggregation from a set of context embeddings (or a single context embedding). Simply use the `ContextAttentionAdapterWrapper`
222 |
223 | ```python
224 | import torch
225 | from enformer_pytorch import from_pretrained
226 | from enformer_pytorch.finetune import ContextAttentionAdapterWrapper
227 |
228 | enformer = from_pretrained('EleutherAI/enformer-official-rough')
229 |
230 | model = ContextAttentionAdapterWrapper(
231 | enformer = enformer,
232 | context_dim = 1024,
233 | heads = 8, # number of heads in the cross attention
234 | dim_head = 64 # dimension per head
235 | ).cuda()
236 |
237 | seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()
238 |
239 | target = torch.randn(1, 200, 4).cuda() # 4 tracks
240 | context = torch.randn(4, 16, 1024).cuda() # 4 contexts for the different 'tracks', each with 16 tokens
241 |
242 | context_mask = torch.ones(4, 16).bool().cuda() # optional context mask; in this example, all context tokens are attended to
243 |
244 | loss = model(
245 | seq,
246 | context = context,
247 | context_mask = context_mask,
248 | target = target
249 | )
250 |
251 | loss.backward()
252 | ```
253 |
254 | ## Data
255 |
256 | You can use the `GenomeIntervalDataset` to easily fetch sequences of any length from a `.bed` file, with a greater context length dynamically computed if specified
257 |
258 | ```python
259 | import torch
260 | import polars as pl
261 | from enformer_pytorch import Enformer, GenomeIntervalDataset
262 |
263 | filter_train = lambda df: df.filter(pl.col('column_4') == 'train')
264 |
265 | ds = GenomeIntervalDataset(
266 |     bed_file = './sequences.bed', # bed file - columns 0, 1, 2 must be the chromosome, start position, end position
267 | fasta_file = './hg38.ml.fa', # path to fasta file
268 | filter_df_fn = filter_train, # filter dataframe function
269 | return_seq_indices = True, # return nucleotide indices (ACGTN) or one hot encodings
270 | shift_augs = (-2, 2), # random shift augmentations from -2 to +2 basepairs
271 | context_length = 196_608,
272 | # this can be longer than the interval designated in the .bed file,
273 |     # in which case it will take care of lengthening the interval on either side
274 |     # as well as proper padding at the ends of the chromosomes
275 | chr_bed_to_fasta_map = {
276 | 'chr1': 'chromosome1', # if the chromosome name in the .bed file is different than the key name in the fasta file, you can rename them on the fly
277 | 'chr2': 'chromosome2',
278 | 'chr3': 'chromosome3',
279 | # etc etc
280 | }
281 | )
282 |
283 | model = Enformer.from_hparams(
284 | dim = 1536,
285 | depth = 11,
286 | heads = 8,
287 | output_heads = dict(human = 5313, mouse = 1643),
288 | target_length = 896,
289 | )
290 |
291 | seq = ds[0] # (196608,)
292 | pred = model(seq, head = 'human') # (896, 5313)
293 | ```
294 |
295 | To return the random shift value, as well as whether reverse complement was activated (in case you need to reverse-complement the corresponding ChIP-seq target data), just set `return_augs = True` when initializing the `GenomeIntervalDataset`
296 |
297 | ```python
298 | import torch
299 | import polars as pl
300 | from enformer_pytorch import Enformer, GenomeIntervalDataset
301 |
302 | filter_train = lambda df: df.filter(pl.col('column_4') == 'train')
303 |
304 | ds = GenomeIntervalDataset(
305 |     bed_file = './sequences.bed', # bed file - columns 0, 1, 2 must be the chromosome, start position, end position
306 | fasta_file = './hg38.ml.fa', # path to fasta file
307 | filter_df_fn = filter_train, # filter dataframe function
308 | return_seq_indices = True, # return nucleotide indices (ACGTN) or one hot encodings
309 | shift_augs = (-2, 2), # random shift augmentations from -2 to +2 basepairs
310 | rc_aug = True, # use reverse complement augmentation with 50% probability
311 | context_length = 196_608,
312 | return_augs = True # return the augmentation meta data
313 | )
314 |
315 | seq, rand_shift_val, rc_bool = ds[0] # (196608,), (1,), (1,)
316 | ```
317 |
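`GenomeIntervalDataset` is a standard `torch.utils.data.Dataset`, so it composes directly with a `DataLoader` for batched training. A minimal sketch (the `.bed` and fasta paths are placeholders):

```python
import torch
from torch.utils.data import DataLoader
from enformer_pytorch import GenomeIntervalDataset

ds = GenomeIntervalDataset(
    bed_file = './sequences.bed', # placeholder path
    fasta_file = './hg38.ml.fa', # placeholder path
    return_seq_indices = True,
    context_length = 196_608
)

loader = DataLoader(ds, batch_size = 2, shuffle = True)
seq = next(iter(loader)) # (2, 196608)
```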
318 | ## Appreciation
319 |
320 | Special thanks goes out to EleutherAI for providing the resources to retrain the model, during a time when the official model from DeepMind had not been released yet.
321 |
322 | Thanks also goes out to @johahi for finding out that there are numerical differences between the PyTorch and TensorFlow implementations of `xlogy`. He provided a fix for this difference, which is adopted in this repository as of `v0.8.0`
323 |
324 | ## Todo
325 |
326 | - [x] script to load weights from trained tensorflow enformer model to pytorch model
327 | - [x] add loss wrapper with poisson loss
328 | - [x] move the metrics code over to pytorch as well
329 | - [x] train enformer model
330 | - [x] build context manager for fine-tuning with unfrozen enformer but with frozen batchnorm
331 | - [x] allow for plain fine-tune with fixed static context
332 | - [x] allow for fine tuning with only unfrozen layernorms (technique from fine tuning transformers)
333 | - [x] fix handling of 'N' in sequence, figure out representation of N in basenji barnyard
334 | - [x] take care of shift augmentation in `GenomicIntervalDataset`
335 | - [x] speed up `str_to_seq_indices`
336 | - [x] add to EleutherAI huggingface (done thanks to Niels)
337 | - [ ] offer some basic training utils, as gradient accumulation will be needed for fine-tuning (a minimal sketch is given below)
338 |
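Until then, here is a minimal sketch of gradient accumulation around the `HeadAdapterWrapper` above - `train_loader`, the batch size, and the learning rate are all placeholder assumptions:

```python
import torch
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import HeadAdapterWrapper

enformer = from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True)
model = HeadAdapterWrapper(enformer = enformer, num_tracks = 128).cuda()

optim = torch.optim.Adam(model.parameters(), lr = 1e-4) # placeholder learning rate
grad_accum_every = 8 # effective batch size = loader batch size * 8

for step, (seq, target) in enumerate(train_loader): # train_loader assumed to yield (seq, target) pairs
    loss = model(seq.cuda(), target = target.cuda())
    (loss / grad_accum_every).backward() # scale so accumulated gradients match one large batch

    if (step + 1) % grad_accum_every == 0:
        optim.step()
        optim.zero_grad()
```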
339 | ## Citations
340 |
341 | ```bibtex
342 | @article {Avsec2021.04.07.438649,
343 | author = {Avsec, {\v Z}iga and Agarwal, Vikram and Visentin, Daniel and Ledsam, Joseph R. and Grabska-Barwinska, Agnieszka and Taylor, Kyle R. and Assael, Yannis and Jumper, John and Kohli, Pushmeet and Kelley, David R.},
344 | title = {Effective gene expression prediction from sequence by integrating long-range interactions},
345 | elocation-id = {2021.04.07.438649},
346 | year = {2021},
347 | doi = {10.1101/2021.04.07.438649},
348 | publisher = {Cold Spring Harbor Laboratory},
349 | URL = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649},
350 | eprint = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649.full.pdf},
351 | journal = {bioRxiv}
352 | }
353 | ```
354 |
355 | ```bibtex
356 | @misc{liu2022convnet,
357 | title = {A ConvNet for the 2020s},
358 | author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
359 | year = {2022},
360 | eprint = {2201.03545},
361 | archivePrefix = {arXiv},
362 | primaryClass = {cs.CV}
363 | }
364 | ```
365 |
--------------------------------------------------------------------------------
/data/test-sample.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/enformer-pytorch/5a5974d2821c728f93294731c50b55f1f55fd86d/data/test-sample.pt
--------------------------------------------------------------------------------
/enformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/enformer-pytorch/5a5974d2821c728f93294731c50b55f1f55fd86d/enformer.png
--------------------------------------------------------------------------------
/enformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from enformer_pytorch.config_enformer import EnformerConfig
2 | from enformer_pytorch.modeling_enformer import Enformer, from_pretrained, SEQUENCE_LENGTH, AttentionPool
3 | from enformer_pytorch.data import seq_indices_to_one_hot, str_to_one_hot, GenomeIntervalDataset, FastaInterval
--------------------------------------------------------------------------------
/enformer_pytorch/config_enformer.py:
--------------------------------------------------------------------------------
1 | from transformers import PretrainedConfig
2 |
3 | class EnformerConfig(PretrainedConfig):
4 | model_type = "enformer"
5 |
6 | def __init__(
7 | self,
8 | dim = 1536,
9 | depth = 11,
10 | heads = 8,
11 |         output_heads = dict(human = 5313, mouse = 1643),
12 | target_length = 896,
13 | attn_dim_key = 64,
14 | dropout_rate = 0.4,
15 | attn_dropout = 0.05,
16 | pos_dropout = 0.01,
17 | use_checkpointing = False,
18 |         use_convnext = False, # accepted but currently unused (not stored on the config)
19 | num_downsamples = 7, # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution
20 | dim_divisible_by = 128,
21 | use_tf_gamma = False,
22 | **kwargs,
23 | ):
24 | self.dim = dim
25 | self.depth = depth
26 | self.heads = heads
27 | self.output_heads = output_heads
28 | self.target_length = target_length
29 | self.attn_dim_key = attn_dim_key
30 | self.dropout_rate = dropout_rate
31 | self.attn_dropout = attn_dropout
32 | self.pos_dropout = pos_dropout
33 | self.use_checkpointing = use_checkpointing
34 | self.num_downsamples = num_downsamples
35 | self.dim_divisible_by = dim_divisible_by
36 | self.use_tf_gamma = use_tf_gamma
37 |
38 | super().__init__(**kwargs)
--------------------------------------------------------------------------------
/enformer_pytorch/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.utils.data import Dataset
4 |
5 | import polars as pl
6 | import numpy as np
7 | from random import randrange, random
8 | from pathlib import Path
9 | from pyfaidx import Fasta
10 |
11 | # helper functions
12 |
13 | def exists(val):
14 | return val is not None
15 |
16 | def identity(t):
17 | return t
18 |
19 | def cast_list(t):
20 | return t if isinstance(t, list) else [t]
21 |
22 | def coin_flip():
23 | return random() > 0.5
24 |
25 | # genomic function transforms
26 |
27 | seq_indices_embed = torch.zeros(256).long()
28 | seq_indices_embed[ord('a')] = 0
29 | seq_indices_embed[ord('c')] = 1
30 | seq_indices_embed[ord('g')] = 2
31 | seq_indices_embed[ord('t')] = 3
32 | seq_indices_embed[ord('n')] = 4
33 | seq_indices_embed[ord('A')] = 0
34 | seq_indices_embed[ord('C')] = 1
35 | seq_indices_embed[ord('G')] = 2
36 | seq_indices_embed[ord('T')] = 3
37 | seq_indices_embed[ord('N')] = 4
38 | seq_indices_embed[ord('.')] = -1
39 |
40 | one_hot_embed = torch.zeros(256, 4)
41 | one_hot_embed[ord('a')] = torch.Tensor([1., 0., 0., 0.])
42 | one_hot_embed[ord('c')] = torch.Tensor([0., 1., 0., 0.])
43 | one_hot_embed[ord('g')] = torch.Tensor([0., 0., 1., 0.])
44 | one_hot_embed[ord('t')] = torch.Tensor([0., 0., 0., 1.])
45 | one_hot_embed[ord('n')] = torch.Tensor([0., 0., 0., 0.])
46 | one_hot_embed[ord('A')] = torch.Tensor([1., 0., 0., 0.])
47 | one_hot_embed[ord('C')] = torch.Tensor([0., 1., 0., 0.])
48 | one_hot_embed[ord('G')] = torch.Tensor([0., 0., 1., 0.])
49 | one_hot_embed[ord('T')] = torch.Tensor([0., 0., 0., 1.])
50 | one_hot_embed[ord('N')] = torch.Tensor([0., 0., 0., 0.])
51 | one_hot_embed[ord('.')] = torch.Tensor([0.25, 0.25, 0.25, 0.25])
52 |
53 | reverse_complement_map = torch.Tensor([3, 2, 1, 0, 4]).long()
54 |
55 | def torch_fromstring(seq_strs):
56 | batched = not isinstance(seq_strs, str)
57 | seq_strs = cast_list(seq_strs)
58 |     np_seq_chrs = list(map(lambda t: np.frombuffer(t.encode(), dtype = np.uint8).copy(), seq_strs)) # np.fromstring is deprecated; copy keeps the buffer writable for torch.from_numpy
59 | seq_chrs = list(map(torch.from_numpy, np_seq_chrs))
60 | return torch.stack(seq_chrs) if batched else seq_chrs[0]
61 |
62 | def str_to_seq_indices(seq_strs):
63 | seq_chrs = torch_fromstring(seq_strs)
64 | return seq_indices_embed[seq_chrs.long()]
65 |
66 | def str_to_one_hot(seq_strs):
67 | seq_chrs = torch_fromstring(seq_strs)
68 | return one_hot_embed[seq_chrs.long()]
69 |
70 | def seq_indices_to_one_hot(t, padding = -1):
71 | is_padding = t == padding
72 | t = t.clamp(min = 0)
73 | one_hot = F.one_hot(t, num_classes = 5)
74 | out = one_hot[..., :4].float()
75 | out = out.masked_fill(is_padding[..., None], 0.25)
76 | return out
77 |
78 | # augmentations
79 |
80 | def seq_indices_reverse_complement(seq_indices):
81 | complement = reverse_complement_map[seq_indices.long()]
82 | return torch.flip(complement, dims = (-1,))
83 |
84 | def one_hot_reverse_complement(one_hot):
85 | *_, n, d = one_hot.shape
86 | assert d == 4, 'must be one hot encoding with last dimension equal to 4'
87 |     return torch.flip(one_hot, (-1, -2)) # flipping the channel dim maps ACGT -> TGCA (complement), flipping the length dim reverses
88 |
89 | # processing bed files
90 |
91 | class FastaInterval():
92 | def __init__(
93 | self,
94 | *,
95 | fasta_file,
96 | context_length = None,
97 | return_seq_indices = False,
98 | shift_augs = None,
99 | rc_aug = False
100 | ):
101 | fasta_file = Path(fasta_file)
102 | assert fasta_file.exists(), 'path to fasta file must exist'
103 |
104 | self.seqs = Fasta(str(fasta_file))
105 | self.return_seq_indices = return_seq_indices
106 | self.context_length = context_length
107 | self.shift_augs = shift_augs
108 | self.rc_aug = rc_aug
109 |
110 | def __call__(self, chr_name, start, end, return_augs = False):
111 | interval_length = end - start
112 | chromosome = self.seqs[chr_name]
113 | chromosome_length = len(chromosome)
114 |         rand_shift = 0 # default shift, in case shift augmentation is disabled but return_augs is requested
115 | if exists(self.shift_augs):
116 | min_shift, max_shift = self.shift_augs
117 | max_shift += 1
118 |
119 | min_shift = min(max(start + min_shift, 0) - start, 0)
120 | max_shift = max(min(end + max_shift, chromosome_length) - end, 1)
121 |
122 | rand_shift = randrange(min_shift, max_shift)
123 | start += rand_shift
124 | end += rand_shift
125 |
126 | left_padding = right_padding = 0
127 |
128 | if exists(self.context_length) and interval_length < self.context_length:
129 | extra_seq = self.context_length - interval_length
130 |
131 | extra_left_seq = extra_seq // 2
132 | extra_right_seq = extra_seq - extra_left_seq
133 |
134 | start -= extra_left_seq
135 | end += extra_right_seq
136 |
137 | if start < 0:
138 | left_padding = -start
139 | start = 0
140 |
141 | if end > chromosome_length:
142 | right_padding = end - chromosome_length
143 | end = chromosome_length
144 |
145 | seq = ('.' * left_padding) + str(chromosome[start:end]) + ('.' * right_padding)
146 |
147 | should_rc_aug = self.rc_aug and coin_flip()
148 |
149 | if self.return_seq_indices:
150 | seq = str_to_seq_indices(seq)
151 |
152 | if should_rc_aug:
153 | seq = seq_indices_reverse_complement(seq)
154 |
155 | return seq
156 |
157 | one_hot = str_to_one_hot(seq)
158 |
159 | if should_rc_aug:
160 | one_hot = one_hot_reverse_complement(one_hot)
161 |
162 | if not return_augs:
163 | return one_hot
164 |
165 | # returns the shift integer as well as the bool (for whether reverse complement was activated)
166 | # for this particular genomic sequence
167 |
168 | rand_shift_tensor = torch.tensor([rand_shift])
169 | rand_aug_bool_tensor = torch.tensor([should_rc_aug])
170 |
171 | return one_hot, rand_shift_tensor, rand_aug_bool_tensor
172 |
173 |
174 | class GenomeIntervalDataset(Dataset):
175 | def __init__(
176 | self,
177 | bed_file,
178 | fasta_file,
179 | filter_df_fn = identity,
180 | chr_bed_to_fasta_map = dict(),
181 | context_length = None,
182 | return_seq_indices = False,
183 | shift_augs = None,
184 | rc_aug = False,
185 | return_augs = False
186 | ):
187 | super().__init__()
188 | bed_path = Path(bed_file)
189 | assert bed_path.exists(), 'path to .bed file must exist'
190 |
191 | df = pl.read_csv(str(bed_path), separator = '\t', has_header = False)
192 | df = filter_df_fn(df)
193 | self.df = df
194 |
195 | # if the chromosome name in the bed file is different than the keyname in the fasta
196 | # can remap on the fly
197 | self.chr_bed_to_fasta_map = chr_bed_to_fasta_map
198 |
199 | self.fasta = FastaInterval(
200 | fasta_file = fasta_file,
201 | context_length = context_length,
202 | return_seq_indices = return_seq_indices,
203 | shift_augs = shift_augs,
204 | rc_aug = rc_aug
205 | )
206 |
207 | self.return_augs = return_augs
208 |
209 | def __len__(self):
210 | return len(self.df)
211 |
212 | def __getitem__(self, ind):
213 | interval = self.df.row(ind)
214 | chr_name, start, end = (interval[0], interval[1], interval[2])
215 | chr_name = self.chr_bed_to_fasta_map.get(chr_name, chr_name)
216 | return self.fasta(chr_name, start, end, return_augs = self.return_augs)
217 |
--------------------------------------------------------------------------------
/enformer_pytorch/finetune.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional
3 |
4 | from copy import deepcopy
5 | from contextlib import contextmanager
6 | import torch.nn.functional as F
7 | from torch import nn, einsum
8 |
9 | from einops import rearrange, repeat
10 | from einops.layers.torch import Rearrange
11 | from enformer_pytorch.modeling_enformer import Enformer, poisson_loss
12 |
13 | from discrete_key_value_bottleneck_pytorch import DiscreteKeyValueBottleneck
14 |
15 | def exists(val):
16 | return val is not None
17 |
18 | def default(val, d):
19 | return val if exists(val) else d
20 |
21 | @contextmanager
22 | def null_context():
23 | yield
24 |
25 | # better sequential
26 |
27 | def Sequential(*modules):
28 | return nn.Sequential(*filter(exists, modules))
29 |
30 | # controlling freezing of layers
31 |
32 | def set_module_requires_grad_(module, requires_grad):
33 | for param in module.parameters():
34 | param.requires_grad = requires_grad
35 |
36 | def freeze_all_layers_(module):
37 | set_module_requires_grad_(module, False)
38 |
39 | def unfreeze_all_layers_(module):
40 | set_module_requires_grad_(module, True)
41 |
42 | def freeze_batchnorms_(model):
43 | bns = [m for m in model.modules() if isinstance(m, nn.BatchNorm1d)]
44 |
45 | for bn in bns:
46 | bn.eval()
47 | bn.track_running_stats = False
48 | set_module_requires_grad_(bn, False)
49 |
50 | def freeze_all_but_layernorms_(model):
51 | for m in model.modules():
52 | set_module_requires_grad_(m, isinstance(m, nn.LayerNorm))
53 |
54 | def freeze_all_but_last_n_layers_(enformer, n):
55 | assert isinstance(enformer, Enformer)
56 | freeze_all_layers_(enformer)
57 |
58 | transformer_blocks = enformer.transformer
59 |
60 | for module in transformer_blocks[-n:]:
61 | set_module_requires_grad_(module, True)
62 |
63 | # get enformer embeddings
64 |
65 | def get_enformer_embeddings(
66 | model,
67 | seq,
68 | freeze = False,
69 | train_layernorms_only = False,
70 | train_last_n_layers_only = None,
71 | enformer_kwargs: dict = {}
72 | ):
73 | freeze_batchnorms_(model)
74 |
75 | if train_layernorms_only:
76 | assert not freeze, 'you set the intent to train the layernorms of the enformer, yet also indicated you wanted to freeze the entire model'
77 | freeze_all_but_layernorms_(model)
78 |
79 | if exists(train_last_n_layers_only):
80 | assert not freeze, 'you set the intent to train last N layers of enformer, but also indicated you wanted to freeze the entire network'
81 | freeze_all_but_last_n_layers_(model, train_last_n_layers_only)
82 |
83 | enformer_context = null_context() if not freeze else torch.no_grad()
84 |
85 | with enformer_context:
86 | embeddings = model(seq, return_only_embeddings = True, **enformer_kwargs)
87 |
88 | if freeze:
89 | embeddings.detach_()
90 |
91 | return embeddings
92 |
93 | # fine-tune wrapper classes
94 |
95 | # extra head projection, akin to how human and mouse tracks were trained
96 |
97 | class HeadAdapterWrapper(nn.Module):
98 | def __init__(
99 | self,
100 | *,
101 | enformer,
102 | num_tracks,
103 | post_transformer_embed = False, # whether to take the embeddings from right after the transformer, instead of after the final pointwise convolutional - this would add another layernorm
104 | discrete_key_value_bottleneck = False,
105 | bottleneck_num_memories = 256,
106 | bottleneck_num_codebooks = 4,
107 | bottleneck_decay = 0.9,
108 | transformer_embed_fn: nn.Module = nn.Identity(),
109 | output_activation: Optional[nn.Module] = nn.Softplus(),
110 | auto_set_target_length = True
111 | ):
112 | super().__init__()
113 | assert isinstance(enformer, Enformer)
114 | enformer_hidden_dim = enformer.dim * (2 if not post_transformer_embed else 1)
115 |
116 | self.discrete_key_value_bottleneck = discrete_key_value_bottleneck
117 |
118 | if discrete_key_value_bottleneck:
119 | enformer = DiscreteKeyValueBottleneck(
120 | encoder = enformer,
121 | dim = enformer_hidden_dim,
122 | num_memory_codebooks = bottleneck_num_codebooks,
123 | num_memories = bottleneck_num_memories,
124 | dim_memory = enformer_hidden_dim // bottleneck_num_codebooks,
125 | decay = bottleneck_decay,
126 | )
127 |
128 | self.post_transformer_embed = post_transformer_embed
129 |
130 | self.enformer = enformer
131 |
132 | self.auto_set_target_length = auto_set_target_length
133 |
134 | if post_transformer_embed:
135 | self.enformer = deepcopy(enformer)
136 | self.enformer._trunk[-1] = nn.Identity()
137 | self.enformer.final_pointwise = nn.Identity()
138 |
139 | self.post_embed_transform = Sequential(
140 | transformer_embed_fn,
141 | nn.LayerNorm(enformer_hidden_dim) if post_transformer_embed else None
142 | )
143 |
144 | self.to_tracks = Sequential(
145 | nn.Linear(enformer_hidden_dim, num_tracks),
146 | output_activation
147 | )
148 |
149 | def forward(
150 | self,
151 | seq,
152 | *,
153 | target = None,
154 | freeze_enformer = False,
155 | finetune_enformer_ln_only = False,
156 | finetune_last_n_layers_only = None
157 | ):
158 | enformer_kwargs = dict()
159 |
160 | if exists(target) and self.auto_set_target_length:
161 | enformer_kwargs = dict(target_length = target.shape[-2])
162 |
163 | if self.discrete_key_value_bottleneck:
164 | embeddings = self.enformer(seq, return_only_embeddings = True, **enformer_kwargs)
165 | else:
166 | embeddings = get_enformer_embeddings(self.enformer, seq, freeze = freeze_enformer, train_layernorms_only = finetune_enformer_ln_only, train_last_n_layers_only = finetune_last_n_layers_only, enformer_kwargs = enformer_kwargs)
167 |
168 |         preds = self.to_tracks(self.post_embed_transform(embeddings)) # apply the optional post-transformer transform / layernorm before the head
169 |
170 | if not exists(target):
171 | return preds
172 |
173 | return poisson_loss(preds, target)
174 |
175 | # wrapper that allows one to supply each track with a context dimension
176 | # the context embedding will be projected into the weights and biases of the head linear projection (hypernetwork)
177 |
178 | class ContextAdapterWrapper(nn.Module):
179 | def __init__(
180 | self,
181 | *,
182 | enformer,
183 | context_dim,
184 | discrete_key_value_bottleneck = False,
185 | bottleneck_num_memories = 256,
186 | bottleneck_num_codebooks = 4,
187 | bottleneck_decay = 0.9,
188 | auto_set_target_length = True,
189 | output_activation: Optional[nn.Module] = nn.Softplus()
190 | ):
191 | super().__init__()
192 | assert isinstance(enformer, Enformer)
193 | enformer_hidden_dim = enformer.dim * 2
194 |
195 | self.discrete_key_value_bottleneck = discrete_key_value_bottleneck
196 |
197 | if discrete_key_value_bottleneck:
198 | enformer = DiscreteKeyValueBottleneck(
199 | encoder = enformer,
200 | dim = enformer_hidden_dim,
201 | num_memory_codebooks = bottleneck_num_codebooks,
202 | num_memories = bottleneck_num_memories,
203 | dim_memory = enformer_hidden_dim // bottleneck_num_codebooks,
204 | decay = bottleneck_decay,
205 | )
206 |
207 | self.enformer = enformer
208 |
209 | self.auto_set_target_length = auto_set_target_length
210 |
211 | self.to_context_weights = nn.Parameter(torch.randn(context_dim, enformer_hidden_dim))
212 | self.to_context_bias = nn.Parameter(torch.randn(context_dim))
213 |
214 | self.activation = default(output_activation, nn.Identity())
215 |
216 | def forward(
217 | self,
218 | seq,
219 | *,
220 | context,
221 | target = None,
222 | freeze_enformer = False,
223 | finetune_enformer_ln_only = False,
224 | finetune_last_n_layers_only = None
225 | ):
226 | enformer_kwargs = dict()
227 |
228 | if exists(target) and self.auto_set_target_length:
229 | enformer_kwargs = dict(target_length = target.shape[-2])
230 |
231 | if self.discrete_key_value_bottleneck:
232 | embeddings = self.enformer(seq, return_only_embeddings = True, **enformer_kwargs)
233 | else:
234 | embeddings = get_enformer_embeddings(self.enformer, seq, freeze = freeze_enformer, train_layernorms_only = finetune_enformer_ln_only, train_last_n_layers_only = finetune_last_n_layers_only, enformer_kwargs = enformer_kwargs)
235 |
236 | weights = einsum('t d, d e -> t e', context, self.to_context_weights)
237 | bias = einsum('t d, d -> t', context, self.to_context_bias)
238 |
239 | pred = einsum('b n d, t d -> b n t', embeddings, weights) + bias
240 |
241 | pred = self.activation(pred)
242 |
243 | if not exists(target):
244 | return pred
245 |
246 | return poisson_loss(pred, target)
247 |
248 | # wrapper that does attention aggregation of the context, which can be a list of tokens (batch x seq x dim)
249 |
250 | class ContextAttentionAdapterWrapper(nn.Module):
251 | def __init__(
252 | self,
253 | *,
254 | enformer,
255 | context_dim,
256 | heads = 8,
257 | dim_head = 64,
258 | discrete_key_value_bottleneck = False,
259 | bottleneck_num_memories = 256,
260 | bottleneck_num_codebooks = 4,
261 | bottleneck_decay = 0.9,
262 | auto_set_target_length = True,
263 | output_activation: Optional[nn.Module] = nn.Softplus()
264 | ):
265 | super().__init__()
266 | assert isinstance(enformer, Enformer)
267 | enformer_hidden_dim = enformer.dim * 2
268 |
269 | self.discrete_key_value_bottleneck = discrete_key_value_bottleneck
270 |
271 | if discrete_key_value_bottleneck:
272 | enformer = DiscreteKeyValueBottleneck(
273 | encoder = enformer,
274 | dim = enformer_hidden_dim,
275 | num_memory_codebooks = bottleneck_num_codebooks,
276 | num_memories = bottleneck_num_memories,
277 | dim_memory = enformer_hidden_dim // bottleneck_num_codebooks,
278 | decay = bottleneck_decay,
279 | )
280 |
281 | self.enformer = enformer
282 |
283 | self.auto_set_target_length = auto_set_target_length
284 |
285 | self.query_norm = nn.LayerNorm(enformer_hidden_dim)
286 | self.key_values_norm = nn.LayerNorm(context_dim)
287 |
288 | self.scale = dim_head ** -0.5
289 | self.heads = heads
290 | inner_dim = heads * dim_head
291 | self.to_queries = nn.Linear(enformer_hidden_dim, inner_dim, bias = False)
292 |
293 | self.null_key = nn.Parameter(torch.randn(inner_dim))
294 | self.null_value = nn.Parameter(torch.randn(inner_dim))
295 |
296 | self.to_key_values = nn.Linear(context_dim, inner_dim * 2, bias = False)
297 | self.to_out = nn.Linear(inner_dim, enformer_hidden_dim)
298 |
299 | self.to_pred = Sequential(
300 | nn.Linear(enformer_hidden_dim, 1),
301 | Rearrange('b c ... 1 -> b ... c'),
302 | output_activation
303 | )
304 |
305 | def forward(
306 | self,
307 | seq,
308 | *,
309 | context,
310 | context_mask = None,
311 | target = None,
312 | freeze_enformer = False,
313 | finetune_enformer_ln_only = False,
314 | finetune_last_n_layers_only = None
315 | ):
316 | """
317 | b - batch
318 | n - sequence length
319 | c - number of contexts (tracks)
320 | d - dimension
321 | i - sequence length (query embeddings)
322 | j - sequence length (keys / values contexts)
323 | h - attention heads
324 | """
325 |
326 | h = self.heads
327 |
328 | enformer_kwargs = dict()
329 |
330 | if exists(target) and self.auto_set_target_length:
331 | enformer_kwargs = dict(target_length = target.shape[-2])
332 |
333 | if self.discrete_key_value_bottleneck:
334 | embeddings = self.enformer(seq, return_only_embeddings = True, **enformer_kwargs)
335 | else:
336 | embeddings = get_enformer_embeddings(self.enformer, seq, freeze = freeze_enformer, train_layernorms_only = finetune_enformer_ln_only, train_last_n_layers_only = finetune_last_n_layers_only, enformer_kwargs = enformer_kwargs)
337 |
338 | # perform cross attention from genetic -> context
339 |
340 | if context.ndim == 2:
341 | context = rearrange(context, 'b d -> b 1 d')
342 |
343 | q = self.to_queries(self.query_norm(embeddings))
344 | k, v = self.to_key_values(self.key_values_norm(context)).chunk(2, dim = -1)
345 |
346 | null_k, null_v = map(lambda t: repeat(t, 'd -> b 1 d', b = context.shape[0]), (self.null_key, self.null_value))
347 |
348 | k = torch.cat((null_k, k), dim = 1)
349 | v = torch.cat((null_v, v), dim = 1)
350 |
351 | # split out head
352 |
353 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
354 | sim = einsum('b h i d, c h j d -> b c h i j', q, k) * self.scale
355 |
356 | # masking
357 |
358 | if exists(context_mask):
359 | context_mask = F.pad(context_mask, (1, 0), value = True)
360 |             context_mask = rearrange(context_mask, 'b j -> b 1 1 1 j')
361 | sim = sim.masked_fill(~context_mask, -torch.finfo(sim.dtype).max)
362 |
363 | # attention
364 |
365 | attn = sim.softmax(dim = -1)
366 |
367 | # aggregate
368 |
369 | out = einsum('b c h i j, c h j d -> b c h i d', attn, v)
370 |
371 | out = rearrange(out, 'b c h n d -> b c n (h d)', h = h)
372 |
373 | # combine heads
374 |
375 | branch_out = self.to_out(out)
376 |
377 | # residual
378 |
379 | embeddings = embeddings + branch_out
380 |
381 | # to prediction
382 |
383 | pred = self.to_pred(embeddings)
384 |
385 | if not exists(target):
386 | return pred
387 |
388 | return poisson_loss(pred, target)
389 |
--------------------------------------------------------------------------------
/enformer_pytorch/metrics.py:
--------------------------------------------------------------------------------
1 | from torchmetrics import Metric
2 | from typing import Optional
3 | import torch
4 |
5 |
6 | class MeanPearsonCorrCoefPerChannel(Metric):
7 | is_differentiable: Optional[bool] = False
8 | higher_is_better: Optional[bool] = True
9 |     def __init__(self, n_channels: int, dist_sync_on_step = False):
10 |         """Calculates the mean Pearson correlation across channels aggregated over regions"""
11 | super().__init__(dist_sync_on_step=dist_sync_on_step)
12 | self.reduce_dims=(0, 1)
13 | self.add_state("product", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
14 | self.add_state("true", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
15 | self.add_state("true_squared", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
16 | self.add_state("pred", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
17 | self.add_state("pred_squared", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
18 | self.add_state("count", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")
19 |
20 | def update(self, preds: torch.Tensor, target: torch.Tensor):
21 | assert preds.shape == target.shape
22 |
23 | self.product += torch.sum(preds * target, dim=self.reduce_dims)
24 | self.true += torch.sum(target, dim=self.reduce_dims)
25 | self.true_squared += torch.sum(torch.square(target), dim=self.reduce_dims)
26 | self.pred += torch.sum(preds, dim=self.reduce_dims)
27 | self.pred_squared += torch.sum(torch.square(preds), dim=self.reduce_dims)
28 | self.count += torch.sum(torch.ones_like(target), dim=self.reduce_dims)
29 |
30 | def compute(self):
31 | true_mean = self.true / self.count
32 | pred_mean = self.pred / self.count
33 |
34 | covariance = (self.product
35 | - true_mean * self.pred
36 | - pred_mean * self.true
37 | + self.count * true_mean * pred_mean)
38 |
39 | true_var = self.true_squared - self.count * torch.square(true_mean)
40 | pred_var = self.pred_squared - self.count * torch.square(pred_mean)
41 | tp_var = torch.sqrt(true_var) * torch.sqrt(pred_var)
42 | correlation = covariance / tp_var
43 |         return correlation # per-channel Pearson R - take .mean() for the aggregate metric
44 |
--------------------------------------------------------------------------------
/enformer_pytorch/modeling_enformer.py:
--------------------------------------------------------------------------------
1 | import math
2 | from pathlib import Path
3 |
4 | import torch
5 | from torch import nn, einsum
6 | import torch.nn.functional as F
7 | import torch.distributed as dist
8 | from torch.utils.checkpoint import checkpoint_sequential
9 |
10 | from einops import rearrange, reduce
11 | from einops.layers.torch import Rearrange
12 |
13 | from enformer_pytorch.data import str_to_one_hot, seq_indices_to_one_hot
14 |
15 | from enformer_pytorch.config_enformer import EnformerConfig
16 |
17 | from transformers import PreTrainedModel
18 |
19 | # constants
20 |
21 | SEQUENCE_LENGTH = 196_608
22 | TARGET_LENGTH = 896
23 |
24 | # gamma positions from tensorflow
25 | # addressing a difference between xlogy results from tensorflow and pytorch
26 | # solution came from @johahi
27 |
28 | DIR = Path(__file__).parents[0]
29 | TF_GAMMAS = torch.load(str(DIR / "precomputed"/ "tf_gammas.pt"), weights_only=True)
30 |
31 | # helpers
32 |
33 | def exists(val):
34 | return val is not None
35 |
36 | def default(val, d):
37 | return val if exists(val) else d
38 |
39 | def always(val):
40 | def inner(*args, **kwargs):
41 | return val
42 | return inner
43 |
44 | def map_values(fn, d):
45 | return {key: fn(values) for key, values in d.items()}
46 |
47 | def exponential_linspace_int(start, end, num, divisible_by = 1):
48 | def _round(x):
49 | return int(round(x / divisible_by) * divisible_by)
50 |
51 | base = math.exp(math.log(end / start) / (num - 1))
52 | return [_round(start * base**i) for i in range(num)]
53 |
54 | def log(t, eps = 1e-20):
55 | return torch.log(t.clamp(min = eps))
56 |
57 | # maybe sync batchnorm, for distributed training
58 |
59 | def MaybeSyncBatchnorm(is_distributed = None):
60 | is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
61 | return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d
62 |
63 | # losses and metrics
64 |
65 | def poisson_loss(pred, target):
66 | return (pred - target * log(pred)).mean()
67 |
68 | def pearson_corr_coef(x, y, dim = 1, reduce_dims = (-1,)):
69 | x_centered = x - x.mean(dim = dim, keepdim = True)
70 | y_centered = y - y.mean(dim = dim, keepdim = True)
71 | return F.cosine_similarity(x_centered, y_centered, dim = dim).mean(dim = reduce_dims)
72 |
73 | # relative positional encoding functions
74 |
75 | def get_positional_features_exponential(positions, features, seq_len, min_half_life = 3., dtype = torch.float):
76 | max_range = math.log(seq_len) / math.log(2.)
77 | half_life = 2 ** torch.linspace(min_half_life, max_range, features, device = positions.device)
78 | half_life = half_life[None, ...]
79 | positions = positions.abs()[..., None]
80 | return torch.exp(-math.log(2.) / half_life * positions)
81 |
82 | def get_positional_features_central_mask(positions, features, seq_len, dtype = torch.float):
83 | center_widths = 2 ** torch.arange(1, features + 1, device = positions.device).to(dtype)
84 | center_widths = center_widths - 1
85 | return (center_widths[None, ...] > positions.abs()[..., None]).to(dtype)
86 |
87 | def gamma_pdf(x, concentration, rate):
88 | log_unnormalized_prob = torch.xlogy(concentration - 1., x) - rate * x
89 | log_normalization = (torch.lgamma(concentration) - concentration * torch.log(rate))
90 | return torch.exp(log_unnormalized_prob - log_normalization)
91 |
92 | def get_positional_features_gamma(positions, features, seq_len, stddev = None, start_mean = None, eps = 1e-8, dtype = torch.float):
93 | if not exists(stddev):
94 | stddev = seq_len / (2 * features)
95 |
96 | if not exists(start_mean):
97 | start_mean = seq_len / features
98 |
99 | mean = torch.linspace(start_mean, seq_len, features, device = positions.device)
100 |
101 | mean = mean[None, ...]
102 | concentration = (mean / stddev) ** 2
103 | rate = mean / stddev ** 2
104 |
105 | probabilities = gamma_pdf(positions.to(dtype).abs()[..., None], concentration, rate)
106 | probabilities = probabilities + eps
107 | outputs = probabilities / torch.amax(probabilities, dim = -1, keepdim = True)
108 | return outputs
109 |
110 | def get_positional_embed(seq_len, feature_size, device, use_tf_gamma, dtype = torch.float):
111 | distances = torch.arange(-seq_len + 1, seq_len, device = device)
112 |
113 | assert not use_tf_gamma or seq_len == 1536, 'if using tf gamma, only sequence length of 1536 allowed for now'
114 |
115 | feature_functions = [
116 | get_positional_features_exponential,
117 | get_positional_features_central_mask,
118 | get_positional_features_gamma if not use_tf_gamma else always(TF_GAMMAS.to(device))
119 | ]
120 |
121 | num_components = len(feature_functions) * 2
122 |
123 | if (feature_size % num_components) != 0:
124 | raise ValueError(f'feature size is not divisible by number of components ({num_components})')
125 |
126 | num_basis_per_class = feature_size // num_components
127 |
128 | embeddings = []
129 | for fn in feature_functions:
130 | embeddings.append(fn(distances, num_basis_per_class, seq_len, dtype = dtype))
131 |
132 | embeddings = torch.cat(embeddings, dim = -1)
133 | embeddings = torch.cat((embeddings, torch.sign(distances)[..., None] * embeddings), dim = -1)
134 | return embeddings.to(dtype)
135 |
136 | def relative_shift(x):
137 |     to_pad = torch.zeros_like(x[..., :1])
138 |     x = torch.cat((to_pad, x), dim = -1) # pad one zero column so the reshape below skews rows into alignment
139 |     _, h, t1, t2 = x.shape
140 |     x = x.reshape(-1, h, t2, t1)
141 |     x = x[:, :, 1:, :]
142 |     x = x.reshape(-1, h, t1, t2 - 1)
143 |     return x[..., :((t2 + 1) // 2)] # keep only the valid relative positions (one per key position)
144 |
145 | # classes
146 |
147 | class Residual(nn.Module):
148 | def __init__(self, fn):
149 | super().__init__()
150 | self.fn = fn
151 |
152 | def forward(self, x, **kwargs):
153 | return self.fn(x, **kwargs) + x
154 |
155 | class GELU(nn.Module):
156 | def forward(self, x):
157 |         return torch.sigmoid(1.702 * x) * x # sigmoid-based approximation of GELU
158 |
159 | class AttentionPool(nn.Module):
160 | def __init__(self, dim, pool_size = 2):
161 | super().__init__()
162 | self.pool_size = pool_size
163 | self.pool_fn = Rearrange('b d (n p) -> b d n p', p = pool_size)
164 |
165 | self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias = False)
166 |
167 | nn.init.dirac_(self.to_attn_logits.weight)
168 |
169 | with torch.no_grad():
170 | self.to_attn_logits.weight.mul_(2)
171 |
172 | def forward(self, x):
173 | b, _, n = x.shape
174 | remainder = n % self.pool_size
175 | needs_padding = remainder > 0
176 |
177 | if needs_padding:
178 |             x = F.pad(x, (0, self.pool_size - remainder), value = 0) # pad up to a multiple of the pool size (equals `remainder` only when pool_size = 2)
179 |             mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device)
180 |             mask = F.pad(mask, (0, self.pool_size - remainder), value = True)
181 |
182 | x = self.pool_fn(x)
183 | logits = self.to_attn_logits(x)
184 |
185 | if needs_padding:
186 | mask_value = -torch.finfo(logits.dtype).max
187 | logits = logits.masked_fill(self.pool_fn(mask), mask_value)
188 |
189 | attn = logits.softmax(dim = -1)
190 |
191 | return (x * attn).sum(dim = -1)
192 |
193 | class TargetLengthCrop(nn.Module):
194 | def __init__(self, target_length):
195 | super().__init__()
196 | self.target_length = target_length
197 |
198 | def forward(self, x):
199 | seq_len, target_len = x.shape[-2], self.target_length
200 |
201 | if target_len == -1:
202 | return x
203 |
204 | if seq_len < target_len:
205 | raise ValueError(f'sequence length {seq_len} is less than target length {target_len}')
206 |
207 | trim = (target_len - seq_len) // 2
208 |
209 | if trim == 0:
210 | return x
211 |
212 |         return x[:, -trim:trim] # trim is negative, so this crops -trim positions from both ends
213 |
214 | def ConvBlock(dim, dim_out = None, kernel_size = 1, is_distributed = None):
215 | batchnorm_klass = MaybeSyncBatchnorm(is_distributed = is_distributed)
216 |
217 | return nn.Sequential(
218 | batchnorm_klass(dim),
219 | GELU(),
220 | nn.Conv1d(dim, default(dim_out, dim), kernel_size, padding = kernel_size // 2)
221 | )
222 |
223 | # attention classes
224 |
225 | class Attention(nn.Module):
226 | def __init__(
227 | self,
228 | dim,
229 | *,
230 | num_rel_pos_features,
231 | heads = 8,
232 | dim_key = 64,
233 | dim_value = 64,
234 | dropout = 0.,
235 | pos_dropout = 0.,
236 | use_tf_gamma = False
237 | ):
238 | super().__init__()
239 | self.scale = dim_key ** -0.5
240 | self.heads = heads
241 |
242 | self.to_q = nn.Linear(dim, dim_key * heads, bias = False)
243 | self.to_k = nn.Linear(dim, dim_key * heads, bias = False)
244 | self.to_v = nn.Linear(dim, dim_value * heads, bias = False)
245 |
246 | self.to_out = nn.Linear(dim_value * heads, dim)
247 | nn.init.zeros_(self.to_out.weight)
248 | nn.init.zeros_(self.to_out.bias)
249 |
250 | # relative positional encoding
251 |
252 | self.num_rel_pos_features = num_rel_pos_features
253 |
254 | self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias = False)
255 | self.rel_content_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))
256 | self.rel_pos_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))
257 |
258 | # dropouts
259 |
260 | self.pos_dropout = nn.Dropout(pos_dropout)
261 | self.attn_dropout = nn.Dropout(dropout)
262 |
263 | # whether to use tf gamma
264 |
265 | self.use_tf_gamma = use_tf_gamma
266 |
267 | def forward(self, x):
268 | n, h, device = x.shape[-2], self.heads, x.device
269 |
270 | q = self.to_q(x)
271 | k = self.to_k(x)
272 | v = self.to_v(x)
273 |
274 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
275 |
276 | q = q * self.scale
277 |
278 | content_logits = einsum('b h i d, b h j d -> b h i j', q + self.rel_content_bias, k)
279 |
280 | positions = get_positional_embed(n, self.num_rel_pos_features, device, use_tf_gamma = self.use_tf_gamma, dtype = self.to_rel_k.weight.dtype)
281 | positions = self.pos_dropout(positions)
282 | rel_k = self.to_rel_k(positions)
283 |
284 | rel_k = rearrange(rel_k, 'n (h d) -> h n d', h = h)
285 | rel_logits = einsum('b h i d, h j d -> b h i j', q + self.rel_pos_bias, rel_k)
286 | rel_logits = relative_shift(rel_logits)
287 |
288 | logits = content_logits + rel_logits
289 | attn = logits.softmax(dim = -1)
290 | attn = self.attn_dropout(attn)
291 |
292 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
293 | out = rearrange(out, 'b h n d -> b n (h d)')
294 | return self.to_out(out)
295 |
296 | # main class
297 |
298 | class Enformer(PreTrainedModel):
299 | config_class = EnformerConfig
300 | base_model_prefix = "enformer"
301 |
302 | @staticmethod
303 | def from_hparams(**kwargs):
304 | return Enformer(EnformerConfig(**kwargs))
305 |
306 | def __init__(self, config):
307 | super().__init__(config)
308 | self.dim = config.dim
309 | half_dim = config.dim // 2
310 | twice_dim = config.dim * 2
311 |
312 | # create stem
313 |
314 | self.stem = nn.Sequential(
315 | nn.Conv1d(4, half_dim, 15, padding = 7),
316 | Residual(ConvBlock(half_dim)),
317 | AttentionPool(half_dim, pool_size = 2)
318 | )
319 |
320 | # create conv tower
321 |
322 | filter_list = exponential_linspace_int(half_dim, config.dim, num = (config.num_downsamples - 1), divisible_by = config.dim_divisible_by)
323 | filter_list = [half_dim, *filter_list]
324 |
325 | conv_layers = []
326 | for dim_in, dim_out in zip(filter_list[:-1], filter_list[1:]):
327 | conv_layers.append(nn.Sequential(
328 | ConvBlock(dim_in, dim_out, kernel_size = 5),
329 | Residual(ConvBlock(dim_out, dim_out, 1)),
330 | AttentionPool(dim_out, pool_size = 2)
331 | ))
332 |
333 | self.conv_tower = nn.Sequential(*conv_layers)
334 |
335 | # whether to use tensorflow gamma positions
336 |
337 | use_tf_gamma = config.use_tf_gamma
338 | self.use_tf_gamma = use_tf_gamma
339 |
340 | # transformer
341 |
342 | transformer = []
343 | for _ in range(config.depth):
344 | transformer.append(nn.Sequential(
345 | Residual(nn.Sequential(
346 | nn.LayerNorm(config.dim),
347 | Attention(
348 | config.dim,
349 | heads = config.heads,
350 | dim_key = config.attn_dim_key,
351 | dim_value = config.dim // config.heads,
352 | dropout = config.attn_dropout,
353 | pos_dropout = config.pos_dropout,
354 | num_rel_pos_features = config.dim // config.heads,
355 | use_tf_gamma = use_tf_gamma
356 | ),
357 | nn.Dropout(config.dropout_rate)
358 | )),
359 | Residual(nn.Sequential(
360 | nn.LayerNorm(config.dim),
361 | nn.Linear(config.dim, config.dim * 2),
362 | nn.Dropout(config.dropout_rate),
363 | nn.ReLU(),
364 | nn.Linear(config.dim * 2, config.dim),
365 | nn.Dropout(config.dropout_rate)
366 | ))
367 | ))
368 |
369 | self.transformer = nn.Sequential(*transformer)
370 |
371 | # target cropping
372 |
373 | self.target_length = config.target_length
374 | self.crop_final = TargetLengthCrop(config.target_length)
375 |
376 | # final pointwise
377 |
378 | self.final_pointwise = nn.Sequential(
379 | Rearrange('b n d -> b d n'),
380 | ConvBlock(filter_list[-1], twice_dim, 1),
381 | Rearrange('b d n -> b n d'),
382 | nn.Dropout(config.dropout_rate / 8),
383 | GELU()
384 | )
385 |
386 | # create trunk sequential module
387 |
388 | self._trunk = nn.Sequential(
389 | Rearrange('b n d -> b d n'),
390 | self.stem,
391 | self.conv_tower,
392 | Rearrange('b d n -> b n d'),
393 | self.transformer,
394 | self.crop_final,
395 | self.final_pointwise
396 | )
397 |
398 | # create final heads for human and mouse
399 |
400 | self.add_heads(**config.output_heads)
401 |
402 | # use checkpointing on transformer trunk
403 |
404 | self.use_checkpointing = config.use_checkpointing
405 |
406 | def add_heads(self, **kwargs):
407 | self.output_heads = kwargs
408 |
409 | self._heads = nn.ModuleDict(map_values(lambda features: nn.Sequential(
410 | nn.Linear(self.dim * 2, features),
411 | nn.Softplus()
412 | ), kwargs))
413 |
414 | def set_target_length(self, target_length):
415 |         crop_module = self._trunk[-2]  # the TargetLengthCrop module, which sits just before final_pointwise
416 | crop_module.target_length = target_length
417 |
418 | @property
419 | def trunk(self):
420 | return self._trunk
421 |
422 | @property
423 | def heads(self):
424 | return self._heads
425 |
426 | def trunk_checkpointed(self, x):
427 | x = rearrange(x, 'b n d -> b d n')
428 | x = self.stem(x)
429 | x = self.conv_tower(x)
430 | x = rearrange(x, 'b d n -> b n d')
431 | x = checkpoint_sequential(self.transformer, len(self.transformer), x)
432 | x = self.crop_final(x)
433 | x = self.final_pointwise(x)
434 | return x
435 |
436 | def forward(
437 | self,
438 | x,
439 | target = None,
440 | return_corr_coef = False,
441 | return_embeddings = False,
442 | return_only_embeddings = False,
443 | head = None,
444 | target_length = None
445 | ):
446 | if isinstance(x, list):
447 | x = str_to_one_hot(x)
448 |
449 |         elif isinstance(x, torch.Tensor) and x.dtype == torch.long:
450 |             x = seq_indices_to_one_hot(x)
451 |         x = x.to(self.device)  # Tensor.to is not in-place; assign the result so every input form is moved
452 |
453 | no_batch = x.ndim == 2
454 |
455 | if no_batch:
456 | x = rearrange(x, '... -> () ...')
457 |
458 | if exists(target_length):
459 | self.set_target_length(target_length)
460 |
461 | trunk_fn = self.trunk_checkpointed if self.use_checkpointing else self._trunk
462 | x = trunk_fn(x)
463 |
464 | if no_batch:
465 | x = rearrange(x, '() ... -> ...')
466 |
467 | if return_only_embeddings:
468 | return x
469 |
470 | out = map_values(lambda fn: fn(x), self._heads)
471 |
472 | if exists(head):
473 | assert head in self._heads, f'head {head} not found'
474 | out = out[head]
475 |
476 | if exists(target):
477 |             assert exists(head), 'head must be specified when computing the loss against targets'
478 |
479 | if return_corr_coef:
480 | return pearson_corr_coef(out, target)
481 |
482 | return poisson_loss(out, target)
483 |
484 | if return_embeddings:
485 | return out, x
486 |
487 | return out
488 |
489 | # from pretrained function
490 |
491 | def from_pretrained(name, use_tf_gamma = None, **kwargs):
492 | enformer = Enformer.from_pretrained(name, **kwargs)
493 |
494 | if name == 'EleutherAI/enformer-official-rough':
495 |         use_tf_gamma = default(use_tf_gamma, True)  # the official checkpoint expects the precomputed TF gamma positions
496 |
497 | for module in enformer.modules():
498 | if isinstance(module, Attention):
499 | module.use_tf_gamma = use_tf_gamma
500 |
501 | return enformer
502 |
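503 | # usage sketch (editor's addition, not part of the original file): the forward
504 | # pass accepts raw ACGT strings, long-integer indices, or one-hot tensors. the
505 | # hyperparameters below mirror the README defaults and are illustrative only.
506 | if __name__ == '__main__':
507 |     model = Enformer.from_hparams(
508 |         dim = 1536,
509 |         depth = 11,
510 |         heads = 8,
511 |         output_heads = dict(human = 5313, mouse = 1643),
512 |         target_length = 896
513 |     ).eval()
514 |
515 |     seq = torch.randint(0, 5, (1, 196_608))   # indices: 0-3 = ACGT, 4 = N
516 |     with torch.no_grad():
517 |         preds = model(seq)                                        # dict keyed by head name
518 |         embeddings = model(seq, return_only_embeddings = True)    # trunk output
519 |
520 |     print(preds['human'].shape)   # torch.Size([1, 896, 5313])
521 |     print(embeddings.shape)       # torch.Size([1, 896, 3072]) -> (batch, target_length, dim * 2)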
--------------------------------------------------------------------------------
/enformer_pytorch/precomputed/tf_gammas.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/enformer-pytorch/5a5974d2821c728f93294731c50b55f1f55fd86d/enformer_pytorch/precomputed/tf_gammas.pt
--------------------------------------------------------------------------------
/evaluate_enformer_pytorch_correlation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 |         ""
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {
17 | "colab": {
18 | "base_uri": "https://localhost:8080/"
19 | },
20 | "id": "beViQ8vrM6LY",
21 | "outputId": "1a821c07-6701-457d-ddd3-a3f52a03d4a6"
22 | },
23 | "outputs": [
24 | {
25 | "output_type": "stream",
26 | "name": "stdout",
27 | "text": [
28 | "Cloning into 'enformer-pytorch'...\n",
29 | "remote: Enumerating objects: 643, done.\u001b[K\n",
30 | "remote: Counting objects: 100% (132/132), done.\u001b[K\n",
31 | "remote: Compressing objects: 100% (117/117), done.\u001b[K\n",
32 | "remote: Total 643 (delta 28), reused 28 (delta 13), pack-reused 511\u001b[K\n",
33 | "Receiving objects: 100% (643/643), 8.88 MiB | 3.09 MiB/s, done.\n",
34 | "Resolving deltas: 100% (439/439), done.\n"
35 | ]
36 | }
37 | ],
38 | "source": [
39 | "!git clone https://github.com/lucidrains/enformer-pytorch.git"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {
46 | "colab": {
47 | "base_uri": "https://localhost:8080/"
48 | },
49 | "id": "PMXWgyEfNT_s",
50 | "outputId": "804379ef-37af-481c-c870-b3e9c7b8e2bb"
51 | },
52 | "outputs": [
53 | {
54 | "output_type": "stream",
55 | "name": "stdout",
56 | "text": [
57 | "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
58 | "Processing /content/enformer-pytorch\n",
59 | "\u001b[33m DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.\n",
60 | " pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.\u001b[0m\n",
61 | "Collecting einops>=0.3\n",
62 | " Downloading einops-0.4.1-py3-none-any.whl (28 kB)\n",
63 | "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from enformer-pytorch==0.5.1) (1.21.6)\n",
64 | "Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.7/dist-packages (from enformer-pytorch==0.5.1) (1.11.0+cu113)\n",
65 | "Collecting polars\n",
66 | " Downloading polars-0.13.40-cp37-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (11.9 MB)\n",
67 | "\u001b[K |████████████████████████████████| 11.9 MB 19.1 MB/s \n",
68 | "\u001b[?25hCollecting pyfaidx\n",
69 | " Downloading pyfaidx-0.7.0.tar.gz (102 kB)\n",
70 | "\u001b[K |████████████████████████████████| 102 kB 70.3 MB/s \n",
71 | "\u001b[?25hRequirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from enformer-pytorch==0.5.1) (3.13)\n",
72 | "Collecting transformers\n",
73 | " Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)\n",
74 | "\u001b[K |████████████████████████████████| 4.2 MB 56.0 MB/s \n",
75 | "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.6->enformer-pytorch==0.5.1) (4.2.0)\n",
76 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from pyfaidx->enformer-pytorch==0.5.1) (1.15.0)\n",
77 | "Requirement already satisfied: setuptools>=0.7 in /usr/local/lib/python3.7/dist-packages (from pyfaidx->enformer-pytorch==0.5.1) (57.4.0)\n",
78 | "Collecting tokenizers!=0.11.3,<0.13,>=0.11.1\n",
79 | " Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)\n",
80 | "\u001b[K |████████████████████████████████| 6.6 MB 44.0 MB/s \n",
81 | "\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers->enformer-pytorch==0.5.1) (4.64.0)\n",
82 | "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers->enformer-pytorch==0.5.1) (3.7.0)\n",
83 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers->enformer-pytorch==0.5.1) (2019.12.20)\n",
84 | "Collecting huggingface-hub<1.0,>=0.1.0\n",
85 | " Downloading huggingface_hub-0.7.0-py3-none-any.whl (86 kB)\n",
86 | "\u001b[K |████████████████████████████████| 86 kB 6.6 MB/s \n",
87 | "\u001b[?25hRequirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers->enformer-pytorch==0.5.1) (4.11.4)\n",
88 | "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers->enformer-pytorch==0.5.1) (21.3)\n",
89 | "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers->enformer-pytorch==0.5.1) (2.23.0)\n",
90 | "Collecting pyyaml\n",
91 | " Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)\n",
92 | "\u001b[K |████████████████████████████████| 596 kB 64.8 MB/s \n",
93 | "\u001b[?25hRequirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers->enformer-pytorch==0.5.1) (3.0.9)\n",
94 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers->enformer-pytorch==0.5.1) (3.8.0)\n",
95 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->enformer-pytorch==0.5.1) (2022.5.18.1)\n",
96 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->enformer-pytorch==0.5.1) (3.0.4)\n",
97 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->enformer-pytorch==0.5.1) (1.24.3)\n",
98 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->enformer-pytorch==0.5.1) (2.10)\n",
99 | "Building wheels for collected packages: enformer-pytorch, pyfaidx\n",
100 | " Building wheel for enformer-pytorch (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
101 | " Created wheel for enformer-pytorch: filename=enformer_pytorch-0.5.1-py3-none-any.whl size=11587 sha256=3888ddb9602b37eedeb4b73f60c5d56609dd500e54536df6fb1bb50f481fe6d5\n",
102 | " Stored in directory: /root/.cache/pip/wheels/e9/30/2d/17bcf281153d214e2a57d1ed68544840825283066370c73eda\n",
103 | " Building wheel for pyfaidx (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
104 | " Created wheel for pyfaidx: filename=pyfaidx-0.7.0-py3-none-any.whl size=27697 sha256=464766257aa70062f3ac14de5005e39af8e5375eb2be08e6f32e0a05dd994a12\n",
105 | " Stored in directory: /root/.cache/pip/wheels/df/6b/ce/46374a70af569061fa10a6c16525b0d8efe2d9a4069f8a144a\n",
106 | "Successfully built enformer-pytorch pyfaidx\n",
107 | "Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers, pyfaidx, polars, einops, enformer-pytorch\n",
108 | " Attempting uninstall: pyyaml\n",
109 | " Found existing installation: PyYAML 3.13\n",
110 | " Uninstalling PyYAML-3.13:\n",
111 | " Successfully uninstalled PyYAML-3.13\n",
112 | "Successfully installed einops-0.4.1 enformer-pytorch-0.5.1 huggingface-hub-0.7.0 polars-0.13.40 pyfaidx-0.7.0 pyyaml-6.0 tokenizers-0.12.1 transformers-4.19.2\n"
113 | ]
114 | }
115 | ],
116 | "source": [
117 | "!cd enformer-pytorch && pip install ."
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": 4,
123 | "metadata": {
124 | "id": "0XbZipT7O0yD"
125 | },
126 | "outputs": [],
127 | "source": [
128 | "!pip install torchmetrics kipoiseq==0.5.2 BioPython --quiet > /dev/null"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 5,
134 | "metadata": {
135 | "colab": {
136 | "base_uri": "https://localhost:8080/"
137 | },
138 | "id": "H8miDmi7OMK1",
139 | "outputId": "4249c35d-2c6f-4c0f-cf70-16d74da61f24"
140 | },
141 | "outputs": [
142 | {
143 | "output_type": "stream",
144 | "name": "stdout",
145 | "text": [
146 | "Copying gs://basenji_barnyard/hg38.ml.fa.gz...\n",
147 | "/ [1/1 files][839.8 MiB/839.8 MiB] 100% Done 58.5 MiB/s ETA 00:00:00 \n",
148 | "Operation completed over 1 objects/839.8 MiB. \n",
149 | "Copying gs://basenji_barnyard/mm10.ml.fa.gz...\n",
150 | "/ [1/1 files][800.8 MiB/800.8 MiB] 100% Done 65.9 MiB/s ETA 00:00:00 \n",
151 | "Operation completed over 1 objects/800.8 MiB. \n"
152 | ]
153 | }
154 | ],
155 | "source": [
156 | "import torch\n",
157 | "import numpy as np\n",
158 | "import tensorflow as tf\n",
159 | "import os \n",
160 | "import json\n",
161 | "import pandas as pd\n",
162 | "import pyfaidx\n",
163 | "import kipoiseq\n",
164 | "import functools\n",
165 | "from kipoiseq import Interval\n",
166 | "\n",
167 | "SEQUENCE_LENGTH = 196_608\n",
168 | "BIN_SIZE = 128\n",
169 | "TARGET_LENGTH = 896\n",
171 | "fasta_dir = \"/root/data/\"\n",
172 | "!mkdir -p {fasta_dir}\n",
173 | "human_fasta_f = 'hg38.ml.fa.gz'\n",
174 | "mouse_fasta_f = 'mm10.ml.fa.gz'\n",
175 | "human_fasta_gz_path = f\"{fasta_dir}/{human_fasta_f}\"\n",
176 | "mouse_fasta_gz_path = f\"{fasta_dir}/{mouse_fasta_f}\"\n",
177 | "human_fasta_path = human_fasta_gz_path[:-len('.gz')]  # rstrip('.gz') would strip a character set, not the suffix\n",
178 | "mouse_fasta_path = mouse_fasta_gz_path[:-len('.gz')]\n",
179 | "\n",
180 | "if not os.path.isfile(human_fasta_path):\n",
181 | " !gsutil -m cp -n gs://basenji_barnyard/{human_fasta_f} {human_fasta_gz_path}\n",
182 | " !gunzip {human_fasta_gz_path}\n",
183 | "if not os.path.isfile(mouse_fasta_path):\n",
184 | " !gsutil -m cp -n gs://basenji_barnyard/{mouse_fasta_f} {mouse_fasta_gz_path}\n",
185 | " !gunzip {mouse_fasta_gz_path}\n",
186 | "\n",
187 | "class FastaStringExtractor:\n",
188 | " \n",
189 | " def __init__(self, fasta_file):\n",
190 | " self.fasta = pyfaidx.Fasta(fasta_file)\n",
191 | " self._chromosome_sizes = {k: len(v) for k, v in self.fasta.items()}\n",
192 | "\n",
193 | " def extract(self, interval: Interval, **kwargs) -> str:\n",
194 | " # Truncate interval if it extends beyond the chromosome lengths.\n",
195 | " chromosome_length = self._chromosome_sizes[interval.chrom]\n",
196 | " trimmed_interval = Interval(interval.chrom,\n",
197 | " max(interval.start, 0),\n",
198 | " min(interval.end, chromosome_length),\n",
199 | " )\n",
200 | " # pyfaidx wants a 1-based interval\n",
201 | " sequence = str(self.fasta.get_seq(trimmed_interval.chrom,\n",
202 | " trimmed_interval.start + 1,\n",
203 | " trimmed_interval.stop).seq).upper()\n",
204 | " # Fill truncated values with N's.\n",
205 | " pad_upstream = 'N' * max(-interval.start, 0)\n",
206 | " pad_downstream = 'N' * max(interval.end - chromosome_length, 0)\n",
207 | " return pad_upstream + sequence + pad_downstream\n",
208 | "\n",
209 | " def close(self):\n",
210 | " return self.fasta.close()\n",
211 | "\n",
212 | "\n",
213 | "class BasenjiDataSet(torch.utils.data.IterableDataset):\n",
214 | " @staticmethod\n",
215 | " def get_organism_path(organism):\n",
216 | " return os.path.join('gs://basenji_barnyard/data', organism)\n",
217 | " @classmethod\n",
218 | " def get_metadata(cls, organism):\n",
219 | " # Keys:\n",
220 | " # num_targets, train_seqs, valid_seqs, test_seqs, seq_length,\n",
221 | " # pool_width, crop_bp, target_length\n",
222 | " path = os.path.join(cls.get_organism_path(organism), 'statistics.json')\n",
223 | " with tf.io.gfile.GFile(path, 'r') as f:\n",
224 | " return json.load(f)\n",
225 | " @staticmethod\n",
226 | " def one_hot_encode(sequence):\n",
227 | " return kipoiseq.transforms.functional.one_hot_dna(sequence).astype(np.float32)\n",
228 | "\n",
229 | " @classmethod\n",
230 | " def get_tfrecord_files(cls, organism, subset):\n",
231 | " # Sort the values by int(*).\n",
232 | " return sorted(tf.io.gfile.glob(os.path.join(\n",
233 | " cls.get_organism_path(organism), 'tfrecords', f'{subset}-*.tfr'\n",
234 | " )), key=lambda x: int(x.split('-')[-1].split('.')[0]))\n",
235 | " \n",
236 | " @property\n",
237 | " def num_channels(self):\n",
238 | " metadata = self.get_metadata(self.organism)\n",
239 | " return metadata['num_targets']\n",
240 | "\n",
241 | " @staticmethod\n",
242 | " def deserialize(serialized_example, metadata):\n",
243 | " \"\"\"Deserialize bytes stored in TFRecordFile.\"\"\"\n",
244 | " # Deserialization\n",
245 | " feature_map = {\n",
246 | "        'sequence': tf.io.FixedLenFeature([], tf.string),  # ignored downstream; we extract a longer sequence from the FASTA ourselves\n",
247 | " 'target': tf.io.FixedLenFeature([], tf.string),\n",
248 | " }\n",
249 | " example = tf.io.parse_example(serialized_example, feature_map)\n",
250 | " sequence = tf.io.decode_raw(example['sequence'], tf.bool)\n",
251 | " sequence = tf.reshape(sequence, (metadata['seq_length'], 4))\n",
252 | " sequence = tf.cast(sequence, tf.float32)\n",
253 | "\n",
254 | " target = tf.io.decode_raw(example['target'], tf.float16)\n",
255 | " target = tf.reshape(target,\n",
256 | " (metadata['target_length'], metadata['num_targets']))\n",
257 | " target = tf.cast(target, tf.float32)\n",
258 | "\n",
259 | " return {'sequence_old': sequence,\n",
260 | " 'target': target}\n",
261 | "\n",
262 | " @classmethod\n",
263 | " def get_dataset(cls, organism, subset, num_threads=8):\n",
264 | " metadata = cls.get_metadata(organism)\n",
265 | " dataset = tf.data.TFRecordDataset(cls.get_tfrecord_files(organism, subset),\n",
266 | " compression_type='ZLIB',\n",
267 | " num_parallel_reads=num_threads).map(\n",
268 | " functools.partial(cls.deserialize, metadata=metadata)\n",
269 | " )\n",
270 | " return dataset\n",
271 | "\n",
272 | " def __init__(self, organism:str, subset:str, seq_len:int, fasta_path:str, n_to_test:int = -1):\n",
273 | " assert subset in {\"train\", \"valid\", \"test\"}\n",
274 | " assert organism in {\"human\", \"mouse\"}\n",
275 | " self.organism = organism\n",
276 | " self.subset = subset\n",
277 | " self.base_dir = self.get_organism_path(organism)\n",
278 | " self.seq_len = seq_len\n",
279 | " self.fasta_reader = FastaStringExtractor(fasta_path)\n",
280 | " self.n_to_test = n_to_test\n",
281 | " with tf.io.gfile.GFile(f\"{self.base_dir}/sequences.bed\", 'r') as f:\n",
282 | " region_df = pd.read_csv(f, sep=\"\\t\", header=None)\n",
283 | " region_df.columns = ['chrom', 'start', 'end', 'subset']\n",
284 | " self.region_df = region_df.query('subset==@subset').reset_index(drop=True)\n",
285 | " \n",
286 | " def __iter__(self):\n",
287 | " worker_info = torch.utils.data.get_worker_info()\n",
288 | " assert worker_info is None, \"Only support single process loading\"\n",
289 | "    # If num_threads > 1, the following will actually shuffle the inputs! Luckily we catch this with the sequence comparison below.\n",
290 | " basenji_iterator = self.get_dataset(self.organism, self.subset, num_threads=1).as_numpy_iterator()\n",
291 | " for i, records in enumerate(basenji_iterator):\n",
292 | " loc_row = self.region_df.iloc[i]\n",
293 | " target_interval = Interval(loc_row['chrom'], loc_row['start'], loc_row['end'])\n",
294 | " sequence_one_hot = self.one_hot_encode(self.fasta_reader.extract(target_interval.resize(self.seq_len)))\n",
295 | " if self.n_to_test >= 0 and i < self.n_to_test:\n",
296 | " old_sequence_onehot = records[\"sequence_old\"]\n",
297 | " if old_sequence_onehot.shape[0] > sequence_one_hot.shape[0]:\n",
298 | " diff = old_sequence_onehot.shape[0] - sequence_one_hot.shape[0]\n",
299 | " trim = diff//2\n",
300 | " np.testing.assert_equal(old_sequence_onehot[trim:(-trim)], sequence_one_hot)\n",
301 | " elif sequence_one_hot.shape[0] > old_sequence_onehot.shape[0]:\n",
302 | " diff = sequence_one_hot.shape[0] - old_sequence_onehot.shape[0]\n",
303 | " trim = diff//2\n",
304 | " np.testing.assert_equal(old_sequence_onehot, sequence_one_hot[trim:(-trim)])\n",
305 | " else:\n",
306 | " np.testing.assert_equal(old_sequence_onehot, sequence_one_hot)\n",
307 | " yield {\n",
308 | " \"sequence\": sequence_one_hot,\n",
309 | " \"target\": records[\"target\"],\n",
310 | " }"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": 6,
316 | "metadata": {
317 | "colab": {
318 | "base_uri": "https://localhost:8080/",
319 | "height": 81,
320 | "referenced_widgets": [
321 | "a76d636b937c4b5691d36cb6b7ec7b17",
322 | "e06f0950617847e7a894afc8cc8ec69f",
323 | "fd6815290f9e45f69d9ac29826de71ef",
324 | "109781990b25433daa411cafa44c662a",
325 | "69789023fe8544ff8a0d7540de793041",
326 | "ef2d9fda87da4dbc825f4776a982856d",
327 | "20f0fc109b5249a3bd0e5738f6b8c06f",
328 | "7ab66e50e29a48b38a1d5c1820422a6a",
329 | "b0b4b1ebc3ac4aaf91b0112969fc64f9",
330 | "38e1b0815aaf4c01b1988f73fc3c38d2",
331 | "c177e2c35da948918a29fdcea6840d89",
332 | "beeeb50adabb48978b1c502d799713be",
333 | "a34e3298896c454abb1b0dda7c225c2b",
334 | "579ec7740cbc455a9c6d3e1bc4442612",
335 | "e486e3955a574b818227964293fa76e1",
336 | "44a8b215df314ba9beb66b44d2e13d29",
337 | "36a42d5cd33541b5aa24db50b03cef9a",
338 | "e47a62c2ab1a4fcea674274f3435f031",
339 | "02df913b81f3473cbbe0189880e20586",
340 | "9d3909432e2d42a5849427f6d1ad1371",
341 | "ec0bf3f9df7c40b8b64aef5ac62f0728",
342 | "7ba9a6bb917746d2bdb7143843480038"
343 | ]
344 | },
345 | "id": "W_QtUyQXNfgJ",
346 | "outputId": "2b58967c-2b79-439f-a901-687a164282d3"
347 | },
348 | "outputs": [
349 | {
350 | "output_type": "display_data",
351 | "data": {
352 | "text/plain": [
353 | "Downloading: 0%| | 0.00/464 [00:00, ?B/s]"
354 | ],
355 | "application/vnd.jupyter.widget-view+json": {
356 | "version_major": 2,
357 | "version_minor": 0,
358 | "model_id": "a76d636b937c4b5691d36cb6b7ec7b17"
359 | }
360 | },
361 | "metadata": {}
362 | },
363 | {
364 | "output_type": "display_data",
365 | "data": {
366 | "text/plain": [
367 | "Downloading: 0%| | 0.00/959M [00:00, ?B/s]"
368 | ],
369 | "application/vnd.jupyter.widget-view+json": {
370 | "version_major": 2,
371 | "version_minor": 0,
372 | "model_id": "beeeb50adabb48978b1c502d799713be"
373 | }
374 | },
375 | "metadata": {}
376 | }
377 | ],
378 | "source": [
379 | "import torch\n",
380 | "from enformer_pytorch import Enformer\n",
381 | "\n",
382 | "model = Enformer.from_pretrained(\"EleutherAI/enformer-official-rough\")"
383 | ]
384 | },
385 | {
386 | "cell_type": "code",
387 | "execution_count": 7,
388 | "metadata": {
389 | "id": "3wgrVVQIRB-w"
390 | },
391 | "outputs": [],
392 | "source": [
393 | "model = model.eval().cuda()"
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "execution_count": 14,
399 | "metadata": {
400 | "id": "2MCoTxFDVPuV"
401 | },
402 | "outputs": [],
403 | "source": [
404 | "from enformer_pytorch.metrics import MeanPearsonCorrCoefPerChannel"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": 15,
410 | "metadata": {
411 | "colab": {
412 | "base_uri": "https://localhost:8080/"
413 | },
414 | "id": "0GFJBmbLQ1-R",
415 | "outputId": "d10a6a75-d141-4d45-bf76-7f5e2d23c6ec"
416 | },
417 | "outputs": [
418 | {
419 | "output_type": "stream",
420 | "name": "stderr",
421 | "text": [
422 | "100%|██████████| 100/100 [01:22<00:00, 1.21it/s]\n"
423 | ]
424 | },
425 | {
426 | "output_type": "execute_result",
427 | "data": {
428 | "text/plain": [
429 | "tensor(0.6270)"
430 | ]
431 | },
432 | "metadata": {},
433 | "execution_count": 15
434 | }
435 | ],
436 | "source": [
437 | "from tqdm import tqdm\n",
438 | "from torchmetrics.regression.pearson import PearsonCorrCoef\n",
439 | "def compute_correlation(model, organism:str=\"human\", subset:str=\"valid\", max_steps=-1):\n",
440 | " fasta_path = human_fasta_path if organism == \"human\" else mouse_fasta_path\n",
441 | " ds = BasenjiDataSet(organism, subset, SEQUENCE_LENGTH, fasta_path)\n",
442 | " total = len(ds.region_df) # number of records\n",
443 | " dl = torch.utils.data.DataLoader(ds, num_workers=0, batch_size=1)\n",
444 | " corr_coef = MeanPearsonCorrCoefPerChannel(n_channels=ds.num_channels)\n",
445 | " n_steps = total if max_steps <= 0 else max_steps\n",
446 | " for i,batch in enumerate(tqdm(dl, total=n_steps)):\n",
447 | " if max_steps > 0 and i >= max_steps:\n",
448 | " break\n",
449 | " batch_gpu = {k:v.to(model.device) for k,v in batch.items()}\n",
450 | " sequence = batch_gpu['sequence']\n",
451 | " target = batch_gpu['target']\n",
452 | " with torch.no_grad():\n",
453 | " pred = model(sequence)[organism]\n",
454 | " corr_coef(preds=pred.cpu(), target=target.cpu())\n",
455 | " return corr_coef.compute().mean()\n",
456 | "compute_correlation(model, organism=\"human\", subset=\"valid\", max_steps=100)"
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": null,
462 | "metadata": {
463 | "colab": {
464 | "background_save": true,
465 | "base_uri": "https://localhost:8080/"
466 | },
467 | "id": "raZi83xyg8SV",
468 | "outputId": "c00dae32-0838-4b4d-a9a1-be6040133011"
469 | },
470 | "outputs": [
471 | {
472 | "name": "stderr",
473 | "output_type": "stream",
474 | "text": [
475 | "/usr/local/lib/python3.7/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n",
476 | " not been set for this class (MeanPearsonCorrCoefPerChannel). The property determines if `update` by\n",
477 | " default needs access to the full metric state. If this is not the case, significant speedups can be\n",
478 | " achieved and we recommend setting this to `False`.\n",
479 | " We provide an checking function\n",
480 | " `from torchmetrics.utilities import check_forward_no_full_state`\n",
481 | " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n",
482 | " default for now) or if `full_state_update=False` can be used safely.\n",
483 | " \n",
484 | " warnings.warn(*args, **kwargs)\n",
485 | "100%|██████████| 2213/2213 [30:51<00:00, 1.20it/s]\n"
486 | ]
487 | },
488 | {
489 | "data": {
490 | "text/plain": [
491 | "tensor(0.6252)"
492 | ]
493 | },
494 | "execution_count": null,
495 | "metadata": {},
496 | "output_type": "execute_result"
497 | }
498 | ],
499 | "source": [
500 | "compute_correlation(model, organism=\"human\", subset=\"valid\", max_steps=-1)"
501 | ]
502 | },
503 | {
504 | "cell_type": "code",
505 | "execution_count": null,
506 | "metadata": {
507 | "colab": {
508 | "background_save": true
509 | },
510 | "id": "PtWriiqConiR",
511 | "outputId": "20acb7f9-3ccc-425c-83d6-99ad88100cba"
512 | },
513 | "outputs": [
514 | {
515 | "name": "stderr",
516 | "output_type": "stream",
517 | "text": [
518 | "/usr/local/lib/python3.7/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n",
519 | " not been set for this class (MeanPearsonCorrCoefPerChannel). The property determines if `update` by\n",
520 | " default needs access to the full metric state. If this is not the case, significant speedups can be\n",
521 | " achieved and we recommend setting this to `False`.\n",
522 | " We provide an checking function\n",
523 | " `from torchmetrics.utilities import check_forward_no_full_state`\n",
524 | " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n",
525 | " default for now) or if `full_state_update=False` can be used safely.\n",
526 | " \n",
527 | " warnings.warn(*args, **kwargs)\n",
528 | "100%|██████████| 1937/1937 [20:05<00:00, 1.61it/s]\n"
529 | ]
530 | },
531 | {
532 | "data": {
533 | "text/plain": [
534 | "tensor(0.6503)"
535 | ]
536 | },
537 | "execution_count": null,
538 | "metadata": {},
539 | "output_type": "execute_result"
540 | }
541 | ],
542 | "source": [
543 | "compute_correlation(model, organism=\"human\", subset=\"test\", max_steps=-1)"
544 | ]
545 | },
546 | {
547 | "cell_type": "code",
548 | "execution_count": null,
549 | "metadata": {
550 | "colab": {
551 | "background_save": true
552 | },
553 | "id": "oRMRZUMHo5XX",
554 | "outputId": "68b054a6-e6b7-4949-9866-ad4b9615fc9e"
555 | },
556 | "outputs": [
557 | {
558 | "name": "stderr",
559 | "output_type": "stream",
560 | "text": [
561 | "/usr/local/lib/python3.7/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n",
562 | " not been set for this class (MeanPearsonCorrCoefPerChannel). The property determines if `update` by\n",
563 | " default needs access to the full metric state. If this is not the case, significant speedups can be\n",
564 | " achieved and we recommend setting this to `False`.\n",
565 | " We provide an checking function\n",
566 | " `from torchmetrics.utilities import check_forward_no_full_state`\n",
567 | " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n",
568 | " default for now) or if `full_state_update=False` can be used safely.\n",
569 | " \n",
570 | " warnings.warn(*args, **kwargs)\n",
571 | "100%|██████████| 34021/34021 [4:58:22<00:00, 1.90it/s]\n"
572 | ]
573 | },
574 | {
575 | "data": {
576 | "text/plain": [
577 | "tensor(0.7415)"
578 | ]
579 | },
580 | "execution_count": null,
581 | "metadata": {},
582 | "output_type": "execute_result"
583 | }
584 | ],
585 | "source": [
586 | "compute_correlation(model, organism=\"human\", subset=\"train\", max_steps=-1)"
587 | ]
588 | }
589 | ],
590 | "metadata": {
591 | "accelerator": "GPU",
592 | "colab": {
593 | "background_execution": "on",
594 | "collapsed_sections": [],
595 | "name": "evaluate_enformer_pytorch_correlation.ipynb",
596 | "provenance": [],
597 | "authorship_tag": "ABX9TyP5rds3GvFsVbB6KySflndx",
598 | "include_colab_link": true
599 | },
600 | "kernelspec": {
601 | "display_name": "Python 3",
602 | "name": "python3"
603 | },
604 | "language_info": {
605 | "name": "python"
606 | },
607 | "widgets": {
608 | "application/vnd.jupyter.widget-state+json": {
609 | "a76d636b937c4b5691d36cb6b7ec7b17": {
610 | "model_module": "@jupyter-widgets/controls",
611 | "model_name": "HBoxModel",
612 | "model_module_version": "1.5.0",
613 | "state": {
614 | "_dom_classes": [],
615 | "_model_module": "@jupyter-widgets/controls",
616 | "_model_module_version": "1.5.0",
617 | "_model_name": "HBoxModel",
618 | "_view_count": null,
619 | "_view_module": "@jupyter-widgets/controls",
620 | "_view_module_version": "1.5.0",
621 | "_view_name": "HBoxView",
622 | "box_style": "",
623 | "children": [
624 | "IPY_MODEL_e06f0950617847e7a894afc8cc8ec69f",
625 | "IPY_MODEL_fd6815290f9e45f69d9ac29826de71ef",
626 | "IPY_MODEL_109781990b25433daa411cafa44c662a"
627 | ],
628 | "layout": "IPY_MODEL_69789023fe8544ff8a0d7540de793041"
629 | }
630 | },
631 | "e06f0950617847e7a894afc8cc8ec69f": {
632 | "model_module": "@jupyter-widgets/controls",
633 | "model_name": "HTMLModel",
634 | "model_module_version": "1.5.0",
635 | "state": {
636 | "_dom_classes": [],
637 | "_model_module": "@jupyter-widgets/controls",
638 | "_model_module_version": "1.5.0",
639 | "_model_name": "HTMLModel",
640 | "_view_count": null,
641 | "_view_module": "@jupyter-widgets/controls",
642 | "_view_module_version": "1.5.0",
643 | "_view_name": "HTMLView",
644 | "description": "",
645 | "description_tooltip": null,
646 | "layout": "IPY_MODEL_ef2d9fda87da4dbc825f4776a982856d",
647 | "placeholder": "",
648 | "style": "IPY_MODEL_20f0fc109b5249a3bd0e5738f6b8c06f",
649 | "value": "Downloading: 100%"
650 | }
651 | },
652 | "fd6815290f9e45f69d9ac29826de71ef": {
653 | "model_module": "@jupyter-widgets/controls",
654 | "model_name": "FloatProgressModel",
655 | "model_module_version": "1.5.0",
656 | "state": {
657 | "_dom_classes": [],
658 | "_model_module": "@jupyter-widgets/controls",
659 | "_model_module_version": "1.5.0",
660 | "_model_name": "FloatProgressModel",
661 | "_view_count": null,
662 | "_view_module": "@jupyter-widgets/controls",
663 | "_view_module_version": "1.5.0",
664 | "_view_name": "ProgressView",
665 | "bar_style": "success",
666 | "description": "",
667 | "description_tooltip": null,
668 | "layout": "IPY_MODEL_7ab66e50e29a48b38a1d5c1820422a6a",
669 | "max": 464,
670 | "min": 0,
671 | "orientation": "horizontal",
672 | "style": "IPY_MODEL_b0b4b1ebc3ac4aaf91b0112969fc64f9",
673 | "value": 464
674 | }
675 | },
676 | "109781990b25433daa411cafa44c662a": {
677 | "model_module": "@jupyter-widgets/controls",
678 | "model_name": "HTMLModel",
679 | "model_module_version": "1.5.0",
680 | "state": {
681 | "_dom_classes": [],
682 | "_model_module": "@jupyter-widgets/controls",
683 | "_model_module_version": "1.5.0",
684 | "_model_name": "HTMLModel",
685 | "_view_count": null,
686 | "_view_module": "@jupyter-widgets/controls",
687 | "_view_module_version": "1.5.0",
688 | "_view_name": "HTMLView",
689 | "description": "",
690 | "description_tooltip": null,
691 | "layout": "IPY_MODEL_38e1b0815aaf4c01b1988f73fc3c38d2",
692 | "placeholder": "",
693 | "style": "IPY_MODEL_c177e2c35da948918a29fdcea6840d89",
694 | "value": " 464/464 [00:00<00:00, 12.3kB/s]"
695 | }
696 | },
697 | "69789023fe8544ff8a0d7540de793041": {
698 | "model_module": "@jupyter-widgets/base",
699 | "model_name": "LayoutModel",
700 | "model_module_version": "1.2.0",
701 | "state": {
702 | "_model_module": "@jupyter-widgets/base",
703 | "_model_module_version": "1.2.0",
704 | "_model_name": "LayoutModel",
705 | "_view_count": null,
706 | "_view_module": "@jupyter-widgets/base",
707 | "_view_module_version": "1.2.0",
708 | "_view_name": "LayoutView",
709 | "align_content": null,
710 | "align_items": null,
711 | "align_self": null,
712 | "border": null,
713 | "bottom": null,
714 | "display": null,
715 | "flex": null,
716 | "flex_flow": null,
717 | "grid_area": null,
718 | "grid_auto_columns": null,
719 | "grid_auto_flow": null,
720 | "grid_auto_rows": null,
721 | "grid_column": null,
722 | "grid_gap": null,
723 | "grid_row": null,
724 | "grid_template_areas": null,
725 | "grid_template_columns": null,
726 | "grid_template_rows": null,
727 | "height": null,
728 | "justify_content": null,
729 | "justify_items": null,
730 | "left": null,
731 | "margin": null,
732 | "max_height": null,
733 | "max_width": null,
734 | "min_height": null,
735 | "min_width": null,
736 | "object_fit": null,
737 | "object_position": null,
738 | "order": null,
739 | "overflow": null,
740 | "overflow_x": null,
741 | "overflow_y": null,
742 | "padding": null,
743 | "right": null,
744 | "top": null,
745 | "visibility": null,
746 | "width": null
747 | }
748 | },
749 | "ef2d9fda87da4dbc825f4776a982856d": {
750 | "model_module": "@jupyter-widgets/base",
751 | "model_name": "LayoutModel",
752 | "model_module_version": "1.2.0",
753 | "state": {
754 | "_model_module": "@jupyter-widgets/base",
755 | "_model_module_version": "1.2.0",
756 | "_model_name": "LayoutModel",
757 | "_view_count": null,
758 | "_view_module": "@jupyter-widgets/base",
759 | "_view_module_version": "1.2.0",
760 | "_view_name": "LayoutView",
761 | "align_content": null,
762 | "align_items": null,
763 | "align_self": null,
764 | "border": null,
765 | "bottom": null,
766 | "display": null,
767 | "flex": null,
768 | "flex_flow": null,
769 | "grid_area": null,
770 | "grid_auto_columns": null,
771 | "grid_auto_flow": null,
772 | "grid_auto_rows": null,
773 | "grid_column": null,
774 | "grid_gap": null,
775 | "grid_row": null,
776 | "grid_template_areas": null,
777 | "grid_template_columns": null,
778 | "grid_template_rows": null,
779 | "height": null,
780 | "justify_content": null,
781 | "justify_items": null,
782 | "left": null,
783 | "margin": null,
784 | "max_height": null,
785 | "max_width": null,
786 | "min_height": null,
787 | "min_width": null,
788 | "object_fit": null,
789 | "object_position": null,
790 | "order": null,
791 | "overflow": null,
792 | "overflow_x": null,
793 | "overflow_y": null,
794 | "padding": null,
795 | "right": null,
796 | "top": null,
797 | "visibility": null,
798 | "width": null
799 | }
800 | },
801 | "20f0fc109b5249a3bd0e5738f6b8c06f": {
802 | "model_module": "@jupyter-widgets/controls",
803 | "model_name": "DescriptionStyleModel",
804 | "model_module_version": "1.5.0",
805 | "state": {
806 | "_model_module": "@jupyter-widgets/controls",
807 | "_model_module_version": "1.5.0",
808 | "_model_name": "DescriptionStyleModel",
809 | "_view_count": null,
810 | "_view_module": "@jupyter-widgets/base",
811 | "_view_module_version": "1.2.0",
812 | "_view_name": "StyleView",
813 | "description_width": ""
814 | }
815 | },
816 | "7ab66e50e29a48b38a1d5c1820422a6a": {
817 | "model_module": "@jupyter-widgets/base",
818 | "model_name": "LayoutModel",
819 | "model_module_version": "1.2.0",
820 | "state": {
821 | "_model_module": "@jupyter-widgets/base",
822 | "_model_module_version": "1.2.0",
823 | "_model_name": "LayoutModel",
824 | "_view_count": null,
825 | "_view_module": "@jupyter-widgets/base",
826 | "_view_module_version": "1.2.0",
827 | "_view_name": "LayoutView",
828 | "align_content": null,
829 | "align_items": null,
830 | "align_self": null,
831 | "border": null,
832 | "bottom": null,
833 | "display": null,
834 | "flex": null,
835 | "flex_flow": null,
836 | "grid_area": null,
837 | "grid_auto_columns": null,
838 | "grid_auto_flow": null,
839 | "grid_auto_rows": null,
840 | "grid_column": null,
841 | "grid_gap": null,
842 | "grid_row": null,
843 | "grid_template_areas": null,
844 | "grid_template_columns": null,
845 | "grid_template_rows": null,
846 | "height": null,
847 | "justify_content": null,
848 | "justify_items": null,
849 | "left": null,
850 | "margin": null,
851 | "max_height": null,
852 | "max_width": null,
853 | "min_height": null,
854 | "min_width": null,
855 | "object_fit": null,
856 | "object_position": null,
857 | "order": null,
858 | "overflow": null,
859 | "overflow_x": null,
860 | "overflow_y": null,
861 | "padding": null,
862 | "right": null,
863 | "top": null,
864 | "visibility": null,
865 | "width": null
866 | }
867 | },
868 | "b0b4b1ebc3ac4aaf91b0112969fc64f9": {
869 | "model_module": "@jupyter-widgets/controls",
870 | "model_name": "ProgressStyleModel",
871 | "model_module_version": "1.5.0",
872 | "state": {
873 | "_model_module": "@jupyter-widgets/controls",
874 | "_model_module_version": "1.5.0",
875 | "_model_name": "ProgressStyleModel",
876 | "_view_count": null,
877 | "_view_module": "@jupyter-widgets/base",
878 | "_view_module_version": "1.2.0",
879 | "_view_name": "StyleView",
880 | "bar_color": null,
881 | "description_width": ""
882 | }
883 | },
884 | "38e1b0815aaf4c01b1988f73fc3c38d2": {
885 | "model_module": "@jupyter-widgets/base",
886 | "model_name": "LayoutModel",
887 | "model_module_version": "1.2.0",
888 | "state": {
889 | "_model_module": "@jupyter-widgets/base",
890 | "_model_module_version": "1.2.0",
891 | "_model_name": "LayoutModel",
892 | "_view_count": null,
893 | "_view_module": "@jupyter-widgets/base",
894 | "_view_module_version": "1.2.0",
895 | "_view_name": "LayoutView",
896 | "align_content": null,
897 | "align_items": null,
898 | "align_self": null,
899 | "border": null,
900 | "bottom": null,
901 | "display": null,
902 | "flex": null,
903 | "flex_flow": null,
904 | "grid_area": null,
905 | "grid_auto_columns": null,
906 | "grid_auto_flow": null,
907 | "grid_auto_rows": null,
908 | "grid_column": null,
909 | "grid_gap": null,
910 | "grid_row": null,
911 | "grid_template_areas": null,
912 | "grid_template_columns": null,
913 | "grid_template_rows": null,
914 | "height": null,
915 | "justify_content": null,
916 | "justify_items": null,
917 | "left": null,
918 | "margin": null,
919 | "max_height": null,
920 | "max_width": null,
921 | "min_height": null,
922 | "min_width": null,
923 | "object_fit": null,
924 | "object_position": null,
925 | "order": null,
926 | "overflow": null,
927 | "overflow_x": null,
928 | "overflow_y": null,
929 | "padding": null,
930 | "right": null,
931 | "top": null,
932 | "visibility": null,
933 | "width": null
934 | }
935 | },
936 | "c177e2c35da948918a29fdcea6840d89": {
937 | "model_module": "@jupyter-widgets/controls",
938 | "model_name": "DescriptionStyleModel",
939 | "model_module_version": "1.5.0",
940 | "state": {
941 | "_model_module": "@jupyter-widgets/controls",
942 | "_model_module_version": "1.5.0",
943 | "_model_name": "DescriptionStyleModel",
944 | "_view_count": null,
945 | "_view_module": "@jupyter-widgets/base",
946 | "_view_module_version": "1.2.0",
947 | "_view_name": "StyleView",
948 | "description_width": ""
949 | }
950 | },
951 | "beeeb50adabb48978b1c502d799713be": {
952 | "model_module": "@jupyter-widgets/controls",
953 | "model_name": "HBoxModel",
954 | "model_module_version": "1.5.0",
955 | "state": {
956 | "_dom_classes": [],
957 | "_model_module": "@jupyter-widgets/controls",
958 | "_model_module_version": "1.5.0",
959 | "_model_name": "HBoxModel",
960 | "_view_count": null,
961 | "_view_module": "@jupyter-widgets/controls",
962 | "_view_module_version": "1.5.0",
963 | "_view_name": "HBoxView",
964 | "box_style": "",
965 | "children": [
966 | "IPY_MODEL_a34e3298896c454abb1b0dda7c225c2b",
967 | "IPY_MODEL_579ec7740cbc455a9c6d3e1bc4442612",
968 | "IPY_MODEL_e486e3955a574b818227964293fa76e1"
969 | ],
970 | "layout": "IPY_MODEL_44a8b215df314ba9beb66b44d2e13d29"
971 | }
972 | },
973 | "a34e3298896c454abb1b0dda7c225c2b": {
974 | "model_module": "@jupyter-widgets/controls",
975 | "model_name": "HTMLModel",
976 | "model_module_version": "1.5.0",
977 | "state": {
978 | "_dom_classes": [],
979 | "_model_module": "@jupyter-widgets/controls",
980 | "_model_module_version": "1.5.0",
981 | "_model_name": "HTMLModel",
982 | "_view_count": null,
983 | "_view_module": "@jupyter-widgets/controls",
984 | "_view_module_version": "1.5.0",
985 | "_view_name": "HTMLView",
986 | "description": "",
987 | "description_tooltip": null,
988 | "layout": "IPY_MODEL_36a42d5cd33541b5aa24db50b03cef9a",
989 | "placeholder": "",
990 | "style": "IPY_MODEL_e47a62c2ab1a4fcea674274f3435f031",
991 | "value": "Downloading: 100%"
992 | }
993 | },
994 | "579ec7740cbc455a9c6d3e1bc4442612": {
995 | "model_module": "@jupyter-widgets/controls",
996 | "model_name": "FloatProgressModel",
997 | "model_module_version": "1.5.0",
998 | "state": {
999 | "_dom_classes": [],
1000 | "_model_module": "@jupyter-widgets/controls",
1001 | "_model_module_version": "1.5.0",
1002 | "_model_name": "FloatProgressModel",
1003 | "_view_count": null,
1004 | "_view_module": "@jupyter-widgets/controls",
1005 | "_view_module_version": "1.5.0",
1006 | "_view_name": "ProgressView",
1007 | "bar_style": "success",
1008 | "description": "",
1009 | "description_tooltip": null,
1010 | "layout": "IPY_MODEL_02df913b81f3473cbbe0189880e20586",
1011 | "max": 1005149571,
1012 | "min": 0,
1013 | "orientation": "horizontal",
1014 | "style": "IPY_MODEL_9d3909432e2d42a5849427f6d1ad1371",
1015 | "value": 1005149571
1016 | }
1017 | },
1018 | "e486e3955a574b818227964293fa76e1": {
1019 | "model_module": "@jupyter-widgets/controls",
1020 | "model_name": "HTMLModel",
1021 | "model_module_version": "1.5.0",
1022 | "state": {
1023 | "_dom_classes": [],
1024 | "_model_module": "@jupyter-widgets/controls",
1025 | "_model_module_version": "1.5.0",
1026 | "_model_name": "HTMLModel",
1027 | "_view_count": null,
1028 | "_view_module": "@jupyter-widgets/controls",
1029 | "_view_module_version": "1.5.0",
1030 | "_view_name": "HTMLView",
1031 | "description": "",
1032 | "description_tooltip": null,
1033 | "layout": "IPY_MODEL_ec0bf3f9df7c40b8b64aef5ac62f0728",
1034 | "placeholder": "",
1035 | "style": "IPY_MODEL_7ba9a6bb917746d2bdb7143843480038",
1036 | "value": " 959M/959M [00:36<00:00, 58.7MB/s]"
1037 | }
1038 | },
1039 | "44a8b215df314ba9beb66b44d2e13d29": {
1040 | "model_module": "@jupyter-widgets/base",
1041 | "model_name": "LayoutModel",
1042 | "model_module_version": "1.2.0",
1043 | "state": {
1044 | "_model_module": "@jupyter-widgets/base",
1045 | "_model_module_version": "1.2.0",
1046 | "_model_name": "LayoutModel",
1047 | "_view_count": null,
1048 | "_view_module": "@jupyter-widgets/base",
1049 | "_view_module_version": "1.2.0",
1050 | "_view_name": "LayoutView",
1051 | "align_content": null,
1052 | "align_items": null,
1053 | "align_self": null,
1054 | "border": null,
1055 | "bottom": null,
1056 | "display": null,
1057 | "flex": null,
1058 | "flex_flow": null,
1059 | "grid_area": null,
1060 | "grid_auto_columns": null,
1061 | "grid_auto_flow": null,
1062 | "grid_auto_rows": null,
1063 | "grid_column": null,
1064 | "grid_gap": null,
1065 | "grid_row": null,
1066 | "grid_template_areas": null,
1067 | "grid_template_columns": null,
1068 | "grid_template_rows": null,
1069 | "height": null,
1070 | "justify_content": null,
1071 | "justify_items": null,
1072 | "left": null,
1073 | "margin": null,
1074 | "max_height": null,
1075 | "max_width": null,
1076 | "min_height": null,
1077 | "min_width": null,
1078 | "object_fit": null,
1079 | "object_position": null,
1080 | "order": null,
1081 | "overflow": null,
1082 | "overflow_x": null,
1083 | "overflow_y": null,
1084 | "padding": null,
1085 | "right": null,
1086 | "top": null,
1087 | "visibility": null,
1088 | "width": null
1089 | }
1090 | },
1091 | "36a42d5cd33541b5aa24db50b03cef9a": {
1092 | "model_module": "@jupyter-widgets/base",
1093 | "model_name": "LayoutModel",
1094 | "model_module_version": "1.2.0",
1095 | "state": {
1096 | "_model_module": "@jupyter-widgets/base",
1097 | "_model_module_version": "1.2.0",
1098 | "_model_name": "LayoutModel",
1099 | "_view_count": null,
1100 | "_view_module": "@jupyter-widgets/base",
1101 | "_view_module_version": "1.2.0",
1102 | "_view_name": "LayoutView",
1103 | "align_content": null,
1104 | "align_items": null,
1105 | "align_self": null,
1106 | "border": null,
1107 | "bottom": null,
1108 | "display": null,
1109 | "flex": null,
1110 | "flex_flow": null,
1111 | "grid_area": null,
1112 | "grid_auto_columns": null,
1113 | "grid_auto_flow": null,
1114 | "grid_auto_rows": null,
1115 | "grid_column": null,
1116 | "grid_gap": null,
1117 | "grid_row": null,
1118 | "grid_template_areas": null,
1119 | "grid_template_columns": null,
1120 | "grid_template_rows": null,
1121 | "height": null,
1122 | "justify_content": null,
1123 | "justify_items": null,
1124 | "left": null,
1125 | "margin": null,
1126 | "max_height": null,
1127 | "max_width": null,
1128 | "min_height": null,
1129 | "min_width": null,
1130 | "object_fit": null,
1131 | "object_position": null,
1132 | "order": null,
1133 | "overflow": null,
1134 | "overflow_x": null,
1135 | "overflow_y": null,
1136 | "padding": null,
1137 | "right": null,
1138 | "top": null,
1139 | "visibility": null,
1140 | "width": null
1141 | }
1142 | },
1143 | "e47a62c2ab1a4fcea674274f3435f031": {
1144 | "model_module": "@jupyter-widgets/controls",
1145 | "model_name": "DescriptionStyleModel",
1146 | "model_module_version": "1.5.0",
1147 | "state": {
1148 | "_model_module": "@jupyter-widgets/controls",
1149 | "_model_module_version": "1.5.0",
1150 | "_model_name": "DescriptionStyleModel",
1151 | "_view_count": null,
1152 | "_view_module": "@jupyter-widgets/base",
1153 | "_view_module_version": "1.2.0",
1154 | "_view_name": "StyleView",
1155 | "description_width": ""
1156 | }
1157 | },
1158 | "02df913b81f3473cbbe0189880e20586": {
1159 | "model_module": "@jupyter-widgets/base",
1160 | "model_name": "LayoutModel",
1161 | "model_module_version": "1.2.0",
1162 | "state": {
1163 | "_model_module": "@jupyter-widgets/base",
1164 | "_model_module_version": "1.2.0",
1165 | "_model_name": "LayoutModel",
1166 | "_view_count": null,
1167 | "_view_module": "@jupyter-widgets/base",
1168 | "_view_module_version": "1.2.0",
1169 | "_view_name": "LayoutView",
1170 | "align_content": null,
1171 | "align_items": null,
1172 | "align_self": null,
1173 | "border": null,
1174 | "bottom": null,
1175 | "display": null,
1176 | "flex": null,
1177 | "flex_flow": null,
1178 | "grid_area": null,
1179 | "grid_auto_columns": null,
1180 | "grid_auto_flow": null,
1181 | "grid_auto_rows": null,
1182 | "grid_column": null,
1183 | "grid_gap": null,
1184 | "grid_row": null,
1185 | "grid_template_areas": null,
1186 | "grid_template_columns": null,
1187 | "grid_template_rows": null,
1188 | "height": null,
1189 | "justify_content": null,
1190 | "justify_items": null,
1191 | "left": null,
1192 | "margin": null,
1193 | "max_height": null,
1194 | "max_width": null,
1195 | "min_height": null,
1196 | "min_width": null,
1197 | "object_fit": null,
1198 | "object_position": null,
1199 | "order": null,
1200 | "overflow": null,
1201 | "overflow_x": null,
1202 | "overflow_y": null,
1203 | "padding": null,
1204 | "right": null,
1205 | "top": null,
1206 | "visibility": null,
1207 | "width": null
1208 | }
1209 | },
1210 | "9d3909432e2d42a5849427f6d1ad1371": {
1211 | "model_module": "@jupyter-widgets/controls",
1212 | "model_name": "ProgressStyleModel",
1213 | "model_module_version": "1.5.0",
1214 | "state": {
1215 | "_model_module": "@jupyter-widgets/controls",
1216 | "_model_module_version": "1.5.0",
1217 | "_model_name": "ProgressStyleModel",
1218 | "_view_count": null,
1219 | "_view_module": "@jupyter-widgets/base",
1220 | "_view_module_version": "1.2.0",
1221 | "_view_name": "StyleView",
1222 | "bar_color": null,
1223 | "description_width": ""
1224 | }
1225 | },
1226 | "ec0bf3f9df7c40b8b64aef5ac62f0728": {
1227 | "model_module": "@jupyter-widgets/base",
1228 | "model_name": "LayoutModel",
1229 | "model_module_version": "1.2.0",
1230 | "state": {
1231 | "_model_module": "@jupyter-widgets/base",
1232 | "_model_module_version": "1.2.0",
1233 | "_model_name": "LayoutModel",
1234 | "_view_count": null,
1235 | "_view_module": "@jupyter-widgets/base",
1236 | "_view_module_version": "1.2.0",
1237 | "_view_name": "LayoutView",
1238 | "align_content": null,
1239 | "align_items": null,
1240 | "align_self": null,
1241 | "border": null,
1242 | "bottom": null,
1243 | "display": null,
1244 | "flex": null,
1245 | "flex_flow": null,
1246 | "grid_area": null,
1247 | "grid_auto_columns": null,
1248 | "grid_auto_flow": null,
1249 | "grid_auto_rows": null,
1250 | "grid_column": null,
1251 | "grid_gap": null,
1252 | "grid_row": null,
1253 | "grid_template_areas": null,
1254 | "grid_template_columns": null,
1255 | "grid_template_rows": null,
1256 | "height": null,
1257 | "justify_content": null,
1258 | "justify_items": null,
1259 | "left": null,
1260 | "margin": null,
1261 | "max_height": null,
1262 | "max_width": null,
1263 | "min_height": null,
1264 | "min_width": null,
1265 | "object_fit": null,
1266 | "object_position": null,
1267 | "order": null,
1268 | "overflow": null,
1269 | "overflow_x": null,
1270 | "overflow_y": null,
1271 | "padding": null,
1272 | "right": null,
1273 | "top": null,
1274 | "visibility": null,
1275 | "width": null
1276 | }
1277 | },
1278 | "7ba9a6bb917746d2bdb7143843480038": {
1279 | "model_module": "@jupyter-widgets/controls",
1280 | "model_name": "DescriptionStyleModel",
1281 | "model_module_version": "1.5.0",
1282 | "state": {
1283 | "_model_module": "@jupyter-widgets/controls",
1284 | "_model_module_version": "1.5.0",
1285 | "_model_name": "DescriptionStyleModel",
1286 | "_view_count": null,
1287 | "_view_module": "@jupyter-widgets/base",
1288 | "_view_module_version": "1.2.0",
1289 | "_view_name": "StyleView",
1290 | "description_width": ""
1291 | }
1292 | }
1293 | }
1294 | }
1295 | },
1296 | "nbformat": 4,
1297 | "nbformat_minor": 0
1298 | }
--------------------------------------------------------------------------------
/scripts/tf_to_torch.py:
--------------------------------------------------------------------------------
1 | import torch, numpy as np  # torch and numpy are required by get_tf_vars below
2 | from einops import rearrange
3 | def copy_bn(mod, vars, path):
4 | bn_offset = vars[f'{path}offset:0']
5 | bn_scale = vars[f'{path}scale:0']
6 |
7 | ema_path = '/'.join(path.split('/')[:-1]) + '/'
8 | bn_running_mean = vars[f'{ema_path}moving_mean/average:0']
9 | bn_running_var = vars[f'{ema_path}moving_variance/average:0']
10 |
11 | mod.weight.data.copy_(bn_scale)
12 | mod.bias.data.copy_(bn_offset)
13 |
14 | mod.running_var.data.copy_(rearrange(bn_running_var, '1 1 d -> d'))
15 | mod.running_mean.data.copy_(rearrange(bn_running_mean, '1 1 d -> d'))
16 |
17 | def copy_conv(mod, vars, path):
18 | bias = vars[f'{path}b:0']
19 | weight = vars[f'{path}w:0']
20 | mod.weight.data.copy_(rearrange(weight, 'k i o -> o i k'))
21 | mod.bias.data.copy_(bias)
22 |
23 | def copy_attn_pool(mod, vars, path):
24 | attn_pool_proj = vars[path]
25 | mod.to_attn_logits.weight.data.copy_(rearrange(attn_pool_proj, 'i o -> o i 1 1'))
26 |
27 | def copy_linear(mod, vars, path, has_bias = True):
28 | weight = vars[f'{path}w:0']
29 | mod.weight.data.copy_(rearrange(weight, 'i o -> o i'))
30 |
31 | if not has_bias:
32 | return
33 |
34 | bias = vars[f'{path}b:0']
35 | mod.bias.data.copy_(bias)
36 |
37 | def copy_ln(mod, vars, path):
38 | weight = vars[f'{path}scale:0']
39 | bias = vars[f'{path}offset:0']
40 | mod.weight.data.copy_(weight)
41 | mod.bias.data.copy_(bias)
42 |
43 | def get_tf_vars(tf_model):
44 | return {v.name: (torch.from_numpy(v.numpy()) if isinstance(v.numpy(), np.ndarray) else None) for v in tf_model.variables}
45 |
46 | def copy_tf_to_pytorch(tf_model, pytorch_model):
47 | tf_vars = get_tf_vars(tf_model)
48 | stem_conv = pytorch_model.stem[0]
49 | stem_point_bn = pytorch_model.stem[1].fn[0]
50 | stem_point_conv = pytorch_model.stem[1].fn[2]
51 | stem_attn_pool = pytorch_model.stem[2]
52 |
53 | copy_conv(stem_conv, tf_vars, 'enformer/trunk/stem/conv1_d/')
54 | copy_bn(stem_point_bn, tf_vars, 'enformer/trunk/stem/pointwise_conv_block/cross_replica_batch_norm/')
55 | copy_conv(stem_point_conv, tf_vars, 'enformer/trunk/stem/pointwise_conv_block/conv1_d/')
56 | copy_attn_pool(stem_attn_pool, tf_vars, 'enformer/trunk/stem/softmax_pooling/linear/w:0')
57 |
58 | for ind, tower_block in enumerate(pytorch_model.conv_tower):
59 | tower_bn = tower_block[0][0]
60 | tower_conv = tower_block[0][2]
61 | tower_point_bn = tower_block[1].fn[0]
62 | tower_point_conv = tower_block[1].fn[2]
63 | tower_attn_pool = tower_block[2]
64 |
65 | conv_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/conv_block/conv1_d/'
66 | bn_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/conv_block/cross_replica_batch_norm/'
67 | point_conv_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/pointwise_conv_block/conv1_d/'
68 | point_bn_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/pointwise_conv_block/cross_replica_batch_norm/'
69 | attn_pool_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/softmax_pooling/linear/w:0'
70 |
71 | copy_bn(tower_bn, tf_vars, bn_path)
72 | copy_conv(tower_conv, tf_vars, conv_path)
73 | copy_bn(tower_point_bn, tf_vars, point_bn_path)
74 | copy_conv(tower_point_conv, tf_vars, point_conv_path)
75 | copy_attn_pool(tower_attn_pool, tf_vars, attn_pool_path)
76 |
77 | for ind, transformer_block in enumerate(pytorch_model.transformer):
78 | attn_ln_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/layer_norm/'
79 | attn_q_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/q_layer/'
80 | attn_k_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/k_layer/'
81 | attn_r_k_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/r_k_layer/'
82 | attn_v_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/v_layer/'
83 | attn_out_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/embedding_layer/'
84 |
85 | attn_content_bias_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/r_w_bias:0'
86 | attn_rel_bias_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/r_r_bias:0'
87 |
88 | ff_ln_path = f'enformer/trunk/transformer/transformer_block_{ind}/mlp/layer_norm/'
89 |
90 |         # https://github.com/deepmind/deepmind-research/blob/master/enformer/enformer.py#L119
91 |         # must be edited to snt.Linear(channels * 2, name = 'project_in') and snt.Linear(channels, name = 'project_out'), otherwise these variables are not accessible by name
92 | ff_linear1_path = f'enformer/trunk/transformer/transformer_block_{ind}/mlp/project_in/'
93 | ff_linear2_path = f'enformer/trunk/transformer/transformer_block_{ind}/mlp/project_out/'
94 |
95 | attn = transformer_block[0]
96 | attn_ln = attn.fn[0]
97 | mha = attn.fn[1]
98 |
99 | copy_linear(mha.to_q, tf_vars, attn_q_path, has_bias = False)
100 | copy_linear(mha.to_k, tf_vars, attn_k_path, has_bias = False)
101 | copy_linear(mha.to_rel_k, tf_vars, attn_r_k_path, has_bias = False)
102 | copy_linear(mha.to_v, tf_vars, attn_v_path, has_bias = False)
103 | copy_linear(mha.to_out, tf_vars, attn_out_path)
104 |
105 | mha.rel_content_bias.data.copy_(tf_vars[attn_content_bias_path])
106 | mha.rel_pos_bias.data.copy_(tf_vars[attn_rel_bias_path])
107 |
108 | ff = transformer_block[-1]
109 | ff_ln = ff.fn[0]
110 | ff_linear1 = ff.fn[1]
111 | ff_linear2 = ff.fn[4]
112 |
113 | copy_ln(attn_ln, tf_vars, attn_ln_path)
114 |
115 | copy_ln(ff_ln, tf_vars, ff_ln_path)
116 | copy_linear(ff_linear1, tf_vars, ff_linear1_path)
117 | copy_linear(ff_linear2, tf_vars, ff_linear2_path)
118 |
119 | final_bn = pytorch_model.final_pointwise[1][0]
120 | final_conv = pytorch_model.final_pointwise[1][2]
121 |
122 | copy_bn(final_bn, tf_vars, 'enformer/trunk/final_pointwise/conv_block/cross_replica_batch_norm/')
123 | copy_conv(final_conv, tf_vars, 'enformer/trunk/final_pointwise/conv_block/conv1_d/')
124 |
125 | human_linear = pytorch_model._heads['human'][0]
126 | mouse_linear = pytorch_model._heads['mouse'][0]
127 |
128 | copy_linear(human_linear, tf_vars, 'enformer/heads/head_human/linear/')
129 | copy_linear(mouse_linear, tf_vars, 'enformer/heads/head_mouse/linear/')
130 |
131 | print('success')
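132 |
133 | # sanity-check sketch (editor's addition): verifies the layout conventions the
134 | # copy_* helpers above rely on. Sonnet stores conv1d kernels as (kernel, in, out)
135 | # and linear weights as (in, out), while PyTorch wants (out, in, kernel) and (out, in).
136 | if __name__ == '__main__':
137 |     import torch.nn.functional as F
138 |
139 |     tf_kernel = torch.randn(15, 4, 32)                  # (kernel, in, out)
140 |     pt_weight = rearrange(tf_kernel, 'k i o -> o i k')  # (out, in, kernel)
141 |     x = torch.randn(1, 4, 100)
142 |     print(F.conv1d(x, pt_weight, padding = 7).shape)    # torch.Size([1, 32, 100])
143 |
144 |     tf_linear = torch.randn(64, 128)                    # (in, out)
145 |     pt_linear = rearrange(tf_linear, 'i o -> o i')      # (out, in), as nn.Linear stores it
146 |     print((torch.randn(1, 64) @ pt_linear.t()).shape)   # torch.Size([1, 128])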
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'enformer-pytorch',
5 | packages = find_packages(exclude=[]),
6 | include_package_data = True,
7 | version = '0.8.10',
8 | license='MIT',
9 | description = 'Enformer - Pytorch',
10 | author = 'Phil Wang',
11 | author_email = 'lucidrains@gmail.com',
12 | long_description_content_type = 'text/markdown',
13 | url = 'https://github.com/lucidrains/enformer-pytorch',
14 | keywords = [
15 | 'artificial intelligence',
16 | 'transformer',
17 | 'gene-expression'
18 | ],
19 | install_requires=[
20 | 'discrete-key-value-bottleneck-pytorch>=0.0.8',
21 | 'einops>=0.3',
22 | 'numpy',
23 | 'torch>=1.6',
24 | 'torchmetrics',
25 | 'polars',
26 | 'pyfaidx',
27 | 'pyyaml',
28 | 'transformers[torch]',
29 | ],
30 | classifiers=[
31 | 'Development Status :: 4 - Beta',
32 | 'Intended Audience :: Developers',
33 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
34 | 'License :: OSI Approved :: MIT License',
35 | 'Programming Language :: Python :: 3.6',
36 | ],
37 | )
38 |
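39 | # editor's note: for local development, install with `pip install -e .`;
40 | # the released package name on PyPI is `enformer-pytorch`.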
--------------------------------------------------------------------------------
/test_pretrained.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from enformer_pytorch import from_pretrained
3 |
4 | enformer = from_pretrained('EleutherAI/enformer-official-rough', use_tf_gamma = False).cuda()
5 | enformer.eval()
6 |
7 | data = torch.load('./data/test-sample.pt')
8 | seq, target = data['sequence'].cuda(), data['target'].cuda()
9 |
10 | with torch.no_grad():
11 | corr_coef = enformer(
12 | seq,
13 | target = target,
14 | return_corr_coef = True,
15 | head = 'human'
16 | )
17 |
18 | print(corr_coef)
19 | assert corr_coef > 0.1
20 |
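21 | # sketch (editor's addition): the pretrained model can also re-crop its output
22 | # on the fly; passing `target_length` adjusts the TargetLengthCrop module before
23 | # the head is applied (note this mutates the module's stored target length).
24 | with torch.no_grad():
25 |     pred = enformer(seq, head = 'human', target_length = 200)
26 |
27 | print(pred.shape)  # (200, 5313) for this unbatched input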
--------------------------------------------------------------------------------