├── .env ├── tf_bind_transformer ├── __init__.py ├── optimizer.py ├── gene_utils.py ├── cache_utils.py ├── context_utils.py ├── attention.py ├── protein_utils.py ├── training_utils.py ├── data_bigwig.py ├── training_utils_bigwig.py ├── tf_bind_transformer.py └── data.py ├── scripts ├── remap_increase_range_add_id.sh ├── calculated_scoped_negative_peaks.sh ├── negative_peak_to_bool_npy.py ├── filter_peaks_from_nonpeak.sh ├── remap_to_separate_exp_target_cell_beds.py ├── download_experiments.py └── fetch_factor_fastas.py ├── precache_proteins.py ├── LICENSE ├── .github └── workflows │ └── python-publish.yml ├── setup.py ├── finetune_track.py ├── finetune_binary_pred.py ├── .gitignore └── README.md /.env: -------------------------------------------------------------------------------- 1 | TORCH_HOME=.cache 2 | TRANSFORMERS_CACHE=.cache 3 | TF_BIND_CACHE_PATH=.cache 4 | CACHE_PATH=.cache 5 | -------------------------------------------------------------------------------- /tf_bind_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from tf_bind_transformer.tf_bind_transformer import AdapterModel 2 | from tf_bind_transformer.training_utils import Trainer 3 | from tf_bind_transformer.training_utils_bigwig import BigWigTrainer 4 | -------------------------------------------------------------------------------- /scripts/remap_increase_range_add_id.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # usage - sh ./scripts/remap_increase_range_add_id.sh remap.bed 4096 remap_out.bed 3 | 4 | filename=$1 5 | extend_len=$2 6 | output_filename=$3 7 | pad_len=$(($extend_len / 2)) 8 | 9 | if [ ! -f $filename ]; then 10 | echo "remap file not found" 11 | exit 1 12 | fi 13 | 14 | awk -v l="$pad_len" -F '\t' '{X=l; mid=(int($2)+int($3))/2;printf("%s\t%d\t%d\t%s\t%d\n",$1,(mid-X<0?0:mid-X),mid+X,$4,NR);}' $filename > $output_filename 15 | 16 | echo 'success' 17 | -------------------------------------------------------------------------------- /scripts/calculated_scoped_negative_peaks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | remap_path=$1 4 | 5 | if [ ! 
-f $remap_path ]; then 6 | echo "remap path with id and parsed exp-target-cell-type not found" 7 | exit 1 8 | fi 9 | 10 | numrows=$(wc -l $remap_path | cut -d " " -f 1) 11 | 12 | n=0 13 | for i in ./*.bed; do 14 | bedtools intersect -v -a $remap_path -b $i > $i.negative 15 | echo "created $i.negative" 16 | python to_npy.py $i.negative $numrows 17 | echo "processed $i" 18 | rm $i.negative 19 | 20 | n=$((n+1)) 21 | echo $n 22 | done 23 | -------------------------------------------------------------------------------- /scripts/negative_peak_to_bool_npy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | import polars as pl 4 | import numpy as np 5 | from pathlib import Path 6 | import sys 7 | 8 | NEGATIVE_PEAK_PATH = sys.argv[1] 9 | NUMROWS = int(sys.argv[2]) 10 | ID_COLUMN = 'column_6' 11 | 12 | df = pl.read_csv(NEGATIVE_PEAK_PATH, sep = '\t', has_headers = False) 13 | np_array = df.get_column(ID_COLUMN).to_numpy() 14 | 15 | to_save = np.full((NUMROWS,), False) 16 | to_save[np_array - 1] = True 17 | 18 | p = Path(NEGATIVE_PEAK_PATH) 19 | filename = f'{p.stem}.bool' 20 | np.save(filename, to_save) 21 | 22 | print(f'{filename} saved') 23 | -------------------------------------------------------------------------------- /scripts/filter_peaks_from_nonpeak.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | filename=$1 4 | peaks_path=$2 5 | num_lines_split="${3:-1000000}" 6 | 7 | if [ ! -f $filename ]; then 8 | echo "non-peaks file not found" 9 | exit 1 10 | fi 11 | 12 | if [ ! -f $peaks_path ]; then 13 | echo "peaks file not found" 14 | exit 1 15 | fi 16 | 17 | split -l $num_lines_split $peaks_path chunked_remap 18 | 19 | cp $filename "$filename.filtered" 20 | 21 | for i in ./chunked_remap*; do 22 | echo "filtering $filename.filtered with $i to $i.filtered"; 23 | bedtools intersect -v -a "$filename.filtered" -b "$i" > "$i.filtered" 24 | rm "$filename.filtered" 25 | filename=$i 26 | done 27 | 28 | echo "success" 29 | -------------------------------------------------------------------------------- /tf_bind_transformer/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import AdamW 2 | 3 | def separate_weight_decayable_params(params): 4 | no_wd_params = set([param for param in params if param.ndim < 2]) 5 | wd_params = set(params) - no_wd_params 6 | return wd_params, no_wd_params 7 | 8 | def get_optimizer(params, lr = 3e-4, wd = 1e-1, filter_by_requires_grad = False): 9 | if filter_by_requires_grad: 10 | params = list(filter(lambda t: t.requires_grad, params)) 11 | 12 | params = set(params) 13 | wd_params, no_wd_params = separate_weight_decayable_params(params) 14 | 15 | param_groups = [ 16 | {'params': list(wd_params)}, 17 | {'params': list(no_wd_params), 'weight_decay': 0}, 18 | ] 19 | 20 | return AdamW(param_groups, lr = lr, weight_decay = wd) 21 | -------------------------------------------------------------------------------- /precache_proteins.py: -------------------------------------------------------------------------------- 1 | import click 2 | from tqdm import tqdm 3 | from pathlib import Path 4 | from Bio import SeqIO 5 | from tf_bind_transformer.protein_utils import get_protein_embedder 6 | 7 | @click.command() 8 | @click.option('--model-name', default = 'protalbert', help = 'Protein model name') 9 | @click.option('--fasta-folder', help = 'Path to factor fastas', required = True) 10 | def cache_embeddings( 
11 | model_name, 12 | fasta_folder 13 | ): 14 | fn = get_protein_embedder(model_name)['fn'] 15 | fastas = [*Path(fasta_folder).glob('**/*.fasta')] 16 | 17 | assert len(fastas) > 0, f'no fasta files found at {fasta_folder}' 18 | 19 | for fasta in tqdm(fastas): 20 | seq = SeqIO.read(fasta, 'fasta') 21 | seq_str = str(seq.seq) 22 | fn([seq_str], device = 'cpu') 23 | 24 | if __name__ == '__main__': 25 | cache_embeddings() 26 | -------------------------------------------------------------------------------- /scripts/remap_to_separate_exp_target_cell_beds.py: -------------------------------------------------------------------------------- 1 | import polars as pl 2 | from pathlib import Path 3 | from tf_bind_transformer.data import read_bed, save_bed 4 | 5 | def generate_separate_exp_target_cell_beds( 6 | remap_file, 7 | *, 8 | output_folder = './negative-peaks-per-target', 9 | exp_target_cell_type_col = 'column_4' 10 | ): 11 | output_folder = Path(output_folder) 12 | output_folder.mkdir(exist_ok = True, parents = True) 13 | 14 | df = read_bed(remap_file) 15 | target_experiments = df.get_column(exp_target_cell_type_col).unique().to_list() 16 | 17 | for target_experiment in target_experiments: 18 | filtered_df = df.filter(pl.col(exp_target_cell_type_col) == target_experiment) 19 | 20 | target_bed_path = str(output_folder / f'{target_experiment}.bed') 21 | save_bed(filtered_df, target_bed_path) 22 | 23 | print('success') 24 | -------------------------------------------------------------------------------- /tf_bind_transformer/gene_utils.py: -------------------------------------------------------------------------------- 1 | # for fetching transcription factor sequences 2 | 3 | GENE_IDENTIFIER_MAP = { 4 | 'RXR': 'RXRA' 5 | } 6 | 7 | NAMES_WITH_HYPHENS = { 8 | 'NKX3-1', 9 | 'NKX2-1', 10 | 'NKX2-5', 11 | 'SS18-SSX' 12 | } 13 | 14 | def parse_gene_name(name): 15 | if '-' not in name or name in NAMES_WITH_HYPHENS: 16 | name = GENE_IDENTIFIER_MAP.get(name, name) 17 | 18 | if '_' in name: 19 | # for now, if target with modification 20 | # just search for the target factor name to the left of the underscore 21 | name, *_ = name.split('_') 22 | 23 | return (name,) 24 | 25 | first, *rest = name.split('-') 26 | 27 | parsed_rest = [] 28 | 29 | for name in rest: 30 | if len(name) == 1: 31 | name = f'{first[:-1]}{name}' 32 | parsed_rest.append(name) 33 | 34 | return tuple([first, *parsed_rest]) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'tf-bind-transformer', 5 | packages = find_packages(exclude=[]), 6 | version = '0.0.118', 7 | license='MIT', 8 | description = 'Transformer for Transcription Factor Binding', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | url = 'https://github.com/lucidrains/tf-bind-transformer', 12 | long_description_content_type = 'text/markdown', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'attention mechanism', 17 | 'transformers', 18 | 'transcription factors', 19 | 'gene expression' 20 | ], 21 | install_requires=[ 22 | 'bidirectional-cross-attention', 23 | 'biopython', 24 | 'click', 25 | 'einops>=0.3', 26 | 'enformer-pytorch>=0.5', 27 | 'fair-esm', 28 | 'logavgexp-pytorch', 29 | 'polars', 30 | 'python-dotenv', 31 | 'sentencepiece', 32 | 'torch>=1.6', 33 | 'transformers>=4.0', 34 | 'tqdm' 35 | ], 36 | classifiers=[ 37 | 'Development Status :: 4 - Beta', 38 | 'Intended Audience :: Developers', 39 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 40 | 'License :: OSI Approved :: MIT License', 41 | 'Programming Language :: Python :: 3.6', 42 | ], 43 | ) 44 | -------------------------------------------------------------------------------- /scripts/download_experiments.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tqdm 3 | import requests 4 | 5 | NCBI_TAX_ID = dict( 6 | human = 9606, 7 | mouse = 10090 8 | ) 9 | 10 | SPECIES = 'human' 11 | API_URL = 'https://remap.univ-amu.fr/api/v1/' 12 | 13 | def get_json(url, params = dict()): 14 | headers = dict(Accept = 'application/json') 15 | resp = requests.get(url, params = params, headers = headers) 16 
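    # the Accept header above asks the ReMap API for JSON; resp.json() below will raise a decode
    # error if the server answers with anything else (e.g. an HTML error page), so calling
    # resp.raise_for_status() here first is a reasonable extra guard (not part of the original script)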
| return resp.json() 17 | 18 | def get_experiments(species): 19 | assert species in NCBI_TAX_ID 20 | taxid = NCBI_TAX_ID[species] 21 | experiments = get_json(f'{API_URL}list/experiments/taxid={taxid}') 22 | return experiments 23 | 24 | def get_experiment(experiment_id, species): 25 | assert species in NCBI_TAX_ID 26 | taxid = NCBI_TAX_ID[species] 27 | experiment = get_json(f'http://remap.univ-amu.fr/api/v1/info/byExperiment/experiment={experiment_id}&taxid={taxid}') 28 | return experiment 29 | 30 | experiments = get_experiments(SPECIES) 31 | 32 | for experiment in tqdm.tqdm(experiments['experiments']): 33 | experiment_details = get_experiment(experiment['accession'], SPECIES) 34 | experiment['details'] = experiment_details 35 | 36 | with open('data/experiments.json', 'w+') as f: 37 | contents = json.dumps(experiments, indent = 4, sort_keys = True) 38 | f.write(contents) 39 | 40 | print('success') 41 | -------------------------------------------------------------------------------- /finetune_track.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | 3 | # set path to cache in .env and unset the next comment 4 | # load_dotenv() 5 | 6 | from enformer_pytorch import Enformer 7 | from tf_bind_transformer import AdapterModel, BigWigTrainer 8 | 9 | # training constants 10 | 11 | BATCH_SIZE = 1 12 | GRAD_ACCUM_STEPS = 8 13 | LEARNING_RATE = 1e-4 # Deepmind used 1e-4 for fine-tuning of Enformer 14 | 15 | # effective batch size of BATCH_SIZE * GRAD_ACCUM_STEPS = 16 16 | 17 | VALIDATE_EVERY = 250 18 | GRAD_CLIP_MAX_NORM = 1.5 19 | 20 | TFACTOR_FOLDER = './tfactor.fastas' 21 | HUMAN_FASTA_FILE_PATH = './hg38.ml.fa' 22 | MOUSE_FASTA_FILE_PATH = './mm10.ml.fa' 23 | 24 | HUMAN_LOCI_PATH = './chip_atlas/human_sequences.bed' 25 | MOUSE_LOCI_PATH = './chip_atlas/mouse_sequences.bed' 26 | BIGWIG_PATH = './chip_atlas/bigwig' 27 | BIGWIG_TRACKS_ONLY_PATH = './chip_atlas/bigwig_tracks_only' 28 | ANNOT_FILE_PATH = './chip_atlas/annot.tab' 29 | 30 | TARGET_LENGTH = 896 31 | 32 | HELD_OUT_TARGET = ['GATA2'] 33 | 34 | # instantiate enformer or load pretrained 35 | 36 | enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough', target_length = TARGET_LENGTH) 37 | 38 | # instantiate model wrapper that takes in enformer 39 | 40 | model = AdapterModel( 41 | enformer = enformer, 42 | use_aa_embeds = True, 43 | use_free_text_context = True, 44 | free_text_embed_method = 'mean_pool', 45 | aa_embed_encoder = 'esm', 46 | finetune_output_heads = dict( 47 | human = 12, 48 | mouse = 24 49 | ) 50 | ).cuda() 51 | 52 | 53 | # trainer class for fine-tuning 54 | 55 | trainer = BigWigTrainer( 56 | model, 57 | human_loci_path = HUMAN_LOCI_PATH, 58 | mouse_loci_path = MOUSE_LOCI_PATH, 59 | human_fasta_file = HUMAN_FASTA_FILE_PATH, 60 | mouse_fasta_file = MOUSE_FASTA_FILE_PATH, 61 | bigwig_folder_path = BIGWIG_PATH, 62 | bigwig_tracks_only_folder_path = BIGWIG_TRACKS_ONLY_PATH, 63 | annot_file_path = ANNOT_FILE_PATH, 64 | target_length = TARGET_LENGTH, 65 | lr = LEARNING_RATE, 66 | batch_size = BATCH_SIZE, 67 | shuffle = True, 68 | validate_every = VALIDATE_EVERY, 69 | grad_clip_norm = GRAD_CLIP_MAX_NORM, 70 | grad_accum_every = GRAD_ACCUM_STEPS, 71 | human_factor_fasta_folder = TFACTOR_FOLDER, 72 | mouse_factor_fasta_folder = TFACTOR_FOLDER, 73 | held_out_targets = HELD_OUT_TARGET 74 | ) 75 | 76 | # do gradient steps in a while loop 77 | 78 | while True: 79 | _ = trainer() 80 | 
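Both fine-tuning scripts lean on the .env file shown at the top of the repo for their cache locations; below is a minimal sketch of that wiring, assuming the default .cache paths from that file. The load_dotenv() call has to run before anything from tf_bind_transformer is imported, because cache_utils.py resolves TF_BIND_CACHE_PATH at import time.

# hedged sketch of the cache setup the finetune scripts expect
from dotenv import load_dotenv

load_dotenv()  # reads TORCH_HOME, TRANSFORMERS_CACHE, TF_BIND_CACHE_PATH, CACHE_PATH from ./.env

# import the package only after the environment is populated, otherwise cached embeddings
# land in the fallback ~/.cache.tf.bind.transformer directory instead of ./.cache
from enformer_pytorch import Enformer
from tf_bind_transformer import AdapterModel, BigWigTrainer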
-------------------------------------------------------------------------------- /finetune_binary_pred.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | 3 | # set path to cache in .env and unset the next comment 4 | # load_dotenv() 5 | 6 | from enformer_pytorch import Enformer 7 | from tf_bind_transformer import AdapterModel, Trainer 8 | 9 | # instantiate enformer or load pretrained 10 | 11 | enformer = Enformer.from_hparams( 12 | dim = 768, 13 | depth = 4, 14 | heads = 8, 15 | target_length = -1, 16 | use_convnext = True, 17 | num_downsamples = 6 # resolution of 2 ^ 6 == 64bp 18 | ) 19 | 20 | # instantiate model wrapper that takes in enformer 21 | 22 | model = AdapterModel( 23 | enformer = enformer, 24 | use_aa_embeds = True, 25 | use_free_text_context = True, 26 | free_text_embed_method = 'mean_pool', 27 | binary_target = True, 28 | target_mse_loss = False, 29 | use_squeeze_excite = True, 30 | aa_embed_encoder = 'protalbert' 31 | ).cuda() 32 | 33 | 34 | # training constants 35 | 36 | BATCH_SIZE = 2 37 | GRAD_ACCUM_STEPS = 8 38 | 39 | # effective batch size of BATCH_SIZE * GRAD_ACCUM_STEPS = 16 40 | 41 | VALIDATE_EVERY = 250 42 | GRAD_CLIP_MAX_NORM = 1.5 43 | 44 | REMAP_FILE_PATH = './remap2022_all.bed' 45 | TFACTOR_FOLDER = './tfactor.fastas' 46 | FASTA_FILE_PATH = './hg38.ml.fa' 47 | NON_PEAK_PATH = './generated-non-peaks.bed' 48 | 49 | CONTEXT_LENGTH = 4096 50 | 51 | SCOPED_NEGS_REMAP_PATH = './neg-npy/remap2022.bed' 52 | SCOPED_NEGS_PATH = './neg-npy' 53 | 54 | TRAIN_CHROMOSOMES = [*range(1, 24, 2), 'X'] # train on odd chromosomes 55 | VALID_CHROMOSOMES = [*range(2, 24, 2)] # validate on even 56 | 57 | HELD_OUT_TARGET = ['AFF4'] 58 | 59 | # trainer class for fine-tuning 60 | 61 | trainer = Trainer( 62 | model, 63 | context_length = CONTEXT_LENGTH, 64 | batch_size = BATCH_SIZE, 65 | validate_every = VALIDATE_EVERY, 66 | grad_clip_norm = GRAD_CLIP_MAX_NORM, 67 | grad_accum_every = GRAD_ACCUM_STEPS, 68 | remap_bed_file = REMAP_FILE_PATH, 69 | negative_bed_file = NON_PEAK_PATH, 70 | factor_fasta_folder = TFACTOR_FOLDER, 71 | fasta_file = FASTA_FILE_PATH, 72 | train_chromosome_ids = TRAIN_CHROMOSOMES, 73 | valid_chromosome_ids = VALID_CHROMOSOMES, 74 | held_out_targets = HELD_OUT_TARGET, 75 | include_scoped_negs = True, 76 | scoped_negs_remap_bed_path = SCOPED_NEGS_REMAP_PATH, 77 | scoped_negs_path = SCOPED_NEGS_PATH, 78 | ) 79 | 80 | # do gradient steps in a while loop 81 | 82 | while True: 83 | _ = trainer(finetune_enformer_ln_only = False) 84 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /tf_bind_transformer/cache_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from shutil import rmtree 3 | import torch 4 | import hashlib 5 | from functools import wraps 6 | from pathlib import Path 7 | 8 | def exists(val): 9 | return val is not None 10 | 11 | # constants 12 | 13 | CACHE_PATH = Path(os.getenv('TF_BIND_CACHE_PATH', os.path.expanduser('~/.cache.tf.bind.transformer'))) 14 | CACHE_PATH.mkdir(exist_ok = True, parents = True) 15 | 16 | CLEAR_CACHE = exists(os.getenv('CLEAR_CACHE', None)) 17 | VERBOSE = exists(os.getenv('VERBOSE', None)) 18 | 19 | # helper functions 20 | 21 | 22 | def log(s): 23 | if not VERBOSE: 24 | return 25 | print(s) 26 | 27 | def md5_hash_fn(s): 28 | encoded = s.encode('utf-8') 29 | return hashlib.md5(encoded).hexdigest() 30 | 31 | # run once function 32 | 33 | GLOBAL_RUN_RECORDS = dict() 34 | 35 | def run_once(global_id = None): 36 | def outer(fn): 37 | has_ran_local = False 38 | output = None 39 | 40 | @wraps(fn) 41 | def inner(*args, **kwargs): 42 | nonlocal has_ran_local 43 | nonlocal output 44 | 45 | has_ran = GLOBAL_RUN_RECORDS.get(global_id, False) if exists(global_id) else has_ran_local 46 | 47 | if has_ran: 48 | return output 49 | 50 | output = fn(*args, **kwargs) 51 | 52 | if exists(global_id): 53 | GLOBAL_RUN_RECORDS[global_id] = True 54 | 55 | has_ran = True 56 | return output 57 | 58 | return inner 59 | return outer 60 | 61 | # caching function 62 | 63 | def cache_fn( 64 | fn, 65 | path = '', 66 | hash_fn = md5_hash_fn, 67 | clear = False or CLEAR_CACHE, 68 | should_cache = True 69 | ): 70 | if not should_cache: 71 | 
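        # caching disabled - hand the wrapped function back untouched, no disk reads or writes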
return fn 72 | 73 | (CACHE_PATH / path).mkdir(parents = True, exist_ok = True) 74 | 75 | @run_once(path) 76 | def clear_cache_folder_(): 77 | cache_path = rmtree(str(CACHE_PATH / path)) 78 | (CACHE_PATH / path).mkdir(parents = True, exist_ok = True) 79 | 80 | @wraps(fn) 81 | def inner(t, *args, __cache_key = None, **kwargs): 82 | if clear: 83 | clear_cache_folder_() 84 | 85 | cache_str = __cache_key if exists(__cache_key) else t 86 | key = hash_fn(cache_str) 87 | 88 | entry_path = CACHE_PATH / path / f'{key}.pt' 89 | 90 | if entry_path.exists(): 91 | log(f'cache hit: fetching {t} from {str(entry_path)}') 92 | return torch.load(str(entry_path)) 93 | 94 | out = fn(t, *args, **kwargs) 95 | 96 | log(f'saving: {t} to {str(entry_path)}') 97 | torch.save(out, str(entry_path)) 98 | return out 99 | return inner 100 | -------------------------------------------------------------------------------- /tf_bind_transformer/context_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import logging 4 | from transformers import AutoTokenizer, AutoModelForMaskedLM, logging 5 | from tf_bind_transformer.cache_utils import cache_fn, run_once 6 | 7 | logging.set_verbosity_error() 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | def map_values(fn, dictionary): 13 | return {k: fn(v) for k, v in dictionary.items()} 14 | 15 | CONTEXT_EMBED_USE_CPU = os.getenv('CONTEXT_EMBED_USE_CPU', None) is not None 16 | 17 | if CONTEXT_EMBED_USE_CPU: 18 | print('calculating context embed only on cpu') 19 | 20 | MODELS = dict( 21 | pubmed = dict( 22 | dim = 768, 23 | path = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract', 24 | ) 25 | ) 26 | 27 | GLOBAL_VARIABLES = dict(model = None, tokenizer = None) 28 | 29 | def get_contextual_dim(model_name): 30 | assert model_name in MODELS 31 | return MODELS[model_name]['dim'] 32 | 33 | @run_once('init_transformer') 34 | def init_transformer(model_name): 35 | path = MODELS[model_name]['path'] 36 | GLOBAL_VARIABLES['tokenizer'] = AutoTokenizer.from_pretrained(path) 37 | 38 | model = AutoModelForMaskedLM.from_pretrained(path) 39 | 40 | if not CONTEXT_EMBED_USE_CPU: 41 | model = model.cuda() 42 | 43 | GLOBAL_VARIABLES['model'] = model 44 | 45 | @torch.no_grad() 46 | def tokenize_text( 47 | text, 48 | max_length = 256, 49 | model_name = 'pubmed', 50 | hidden_state_index = -1, 51 | return_cls_token = True 52 | ): 53 | init_transformer(model_name) 54 | 55 | model = GLOBAL_VARIABLES['model'] 56 | tokenizer = GLOBAL_VARIABLES['tokenizer'] 57 | 58 | encoding = tokenizer.batch_encode_plus( 59 | [text], 60 | add_special_tokens = True, 61 | padding = True, 62 | truncation = True, 63 | max_length = max_length, 64 | return_attention_mask = True, 65 | return_tensors = 'pt' 66 | ) 67 | 68 | if not CONTEXT_EMBED_USE_CPU: 69 | encoding = map_values(lambda t: t.cuda(), encoding) 70 | 71 | model.eval() 72 | with torch.no_grad(): 73 | outputs = model(**encoding, output_hidden_states = True) 74 | 75 | hidden_state = outputs.hidden_states[hidden_state_index][0] 76 | 77 | if return_cls_token: 78 | return hidden_state[0] 79 | 80 | return hidden_state.mean(dim = 0) 81 | 82 | def get_text_repr( 83 | texts, 84 | *, 85 | device, 86 | max_length = 256, 87 | model_name = 'pubmed', 88 | hidden_state_index = -1, 89 | return_cls_token = True, 90 | ): 91 | assert model_name in MODELS, f'{model_name} not found in available text transformers to use' 92 | 93 | if isinstance(texts, str): 94 | texts = [texts] 95 | 96 | get_context_repr_fn 
= cache_fn(tokenize_text, path = f'contexts/{model_name}') 97 | 98 | representations = [get_context_repr_fn(text, max_length = max_length, model_name = model_name, hidden_state_index = hidden_state_index, return_cls_token = return_cls_token) for text in texts] 99 | 100 | return torch.stack(representations).to(device) 101 | -------------------------------------------------------------------------------- /scripts/fetch_factor_fastas.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from pathlib import Path 3 | import click 4 | import polars as pl 5 | from tqdm import tqdm 6 | from tf_bind_transformer.gene_utils import parse_gene_name 7 | from tf_bind_transformer.data import read_bed 8 | 9 | # constants 10 | 11 | UNIPROT_URL = 'http://www.uniprot.org' 12 | 13 | DEFAULT_REMAP_PATH = dict( 14 | HUMAN = './remap2022_crm_macs2_hg38_v1_0.bed', 15 | MOUSE = './remap2022_crm_macs2_mm10_v1_0.bed', 16 | ) 17 | 18 | GENE_NAME_TO_ID_OVERRIDE = { 19 | 'SS18-SSX': ['Q8IZH1'], 20 | 'TFIIIC': ['A6ZV34'] # todo: figure out where the human entry is in Uniprot 21 | } 22 | 23 | # helper functions 24 | 25 | def uniprot_mapping(fromtype, totype, identifier): 26 | params = { 27 | 'from': fromtype, 28 | 'to': totype, 29 | 'format': 'tab', 30 | 'query': identifier, 31 | } 32 | 33 | response = requests.get(f'{UNIPROT_URL}/mapping', params = params) 34 | return response.text 35 | 36 | # main functions 37 | 38 | @click.command() 39 | @click.option('--species', help = 'Species', default = 'human', type = click.Choice(['human', 'mouse'])) 40 | @click.option('--remap-bed-path', help = 'Path to species specific remap file') 41 | @click.option('--fasta-folder', help = 'Path to factor fastas', default = './tfactor.fastas') 42 | def fetch_factors( 43 | species, 44 | remap_bed_path, 45 | fasta_folder 46 | ): 47 | species = species.upper() 48 | 49 | if remap_bed_path is None: 50 | remap_bed_path = DEFAULT_REMAP_PATH[species] 51 | 52 | remap_bed_path = Path(remap_bed_path) 53 | 54 | assert remap_bed_path.exists(), f'remap file does not exist at {str(remap_bed_path)}' 55 | 56 | # load bed file and get all unique targets from column 3 57 | 58 | df = read_bed(remap_bed_path) 59 | genes = set([target for targets in df[:, 3] for target in targets.split(',')]) 60 | 61 | print(f'{len(genes)} factors found') 62 | 63 | # load all saved fasta files, so can resume gracefully 64 | 65 | fasta_files = [str(path) for path in Path('./').glob('*.fasta')] 66 | processed_genes = set([*map(lambda t: str(t).split('.')[0], fasta_files)]) 67 | 68 | results_folder = Path(fasta_folder) 69 | results_folder.mkdir(exist_ok = True, parents = True) 70 | 71 | for unparsed_gene_name in tqdm(genes): 72 | for gene_name in parse_gene_name(unparsed_gene_name): 73 | 74 | if gene_name in processed_genes: 75 | continue 76 | 77 | # fetch uniprot id based on gene id 78 | 79 | if gene_name not in GENE_NAME_TO_ID_OVERRIDE: 80 | uniprot_resp = uniprot_mapping('GENENAME', 'ID', gene_name) 81 | 82 | # only get the human ones (todo: make species agnostic) 83 | 84 | entries = list(filter(lambda t: f'_{species}' in t, uniprot_resp.split('\n'))) 85 | entries = list(map(lambda t: t.split('\t')[1], entries)) 86 | else: 87 | entries = GENE_NAME_TO_ID_OVERRIDE[gene_name] 88 | 89 | if len(entries) == 0: 90 | print(f'no entries found for {gene_name}') 91 | continue 92 | 93 | # save all hits 94 | 95 | for entry in entries: 96 | response = requests.get(f'{UNIPROT_URL}/uniprot/{entry}.fasta') 97 | 98 | if response.status_code != 
200: 99 | print(f'<{response.status_code}> error fetching fasta file from gene {gene_name} {entry}') 100 | continue 101 | 102 | fasta_path = str(results_folder / f'{gene_name}.{entry}.fasta') 103 | 104 | with open(fasta_path, 'w') as f: 105 | f.write(response.text) 106 | 107 | print(f'gene {gene_name} written') 108 | 109 | # main function 110 | 111 | if __name__ == '__main__': 112 | fetch_factors() 113 | -------------------------------------------------------------------------------- /tf_bind_transformer/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops import rearrange 4 | from torch import einsum 5 | from bidirectional_cross_attention import BidirectionalCrossAttention 6 | 7 | def exists(val): 8 | return val is not None 9 | 10 | def default(val, d): 11 | return val if exists(val) else d 12 | 13 | # classes 14 | 15 | def FeedForward(dim, mult = 4, dropout = 0.): 16 | return nn.Sequential( 17 | nn.LayerNorm(dim), 18 | nn.Linear(dim, dim * mult), 19 | nn.GELU(), 20 | nn.Dropout(dropout), 21 | nn.Linear(dim * mult, dim) 22 | ) 23 | 24 | # self attention 25 | 26 | class SelfAttention(nn.Module): 27 | def __init__( 28 | self, 29 | *, 30 | dim, 31 | heads = 8, 32 | dim_head = 64, 33 | dropout = 0. 34 | ): 35 | super().__init__() 36 | self.norm = nn.LayerNorm(dim) 37 | 38 | self.heads = heads 39 | self.scale = dim_head ** -0.5 40 | inner_dim = dim_head * heads 41 | 42 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 43 | self.to_out = nn.Linear(inner_dim, dim) 44 | 45 | self.dropout = nn.Dropout(dropout) 46 | 47 | def forward( 48 | self, 49 | x, 50 | mask = None, 51 | ): 52 | h = self.heads 53 | x = self.norm(x) 54 | 55 | q, k, v = self.to_qkv(x).chunk(3, dim = -1) 56 | q = q * self.scale 57 | 58 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 59 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 60 | 61 | if exists(mask): 62 | mask_value = -torch.finfo(sim.dtype).max 63 | mask = rearrange(mask, 'b j -> b 1 1 j') 64 | sim = sim.masked_fill(~mask, mask_value) 65 | 66 | attn = sim.softmax(dim = -1) 67 | attn = self.dropout(attn) 68 | 69 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 70 | out = rearrange(out, 'b h n d -> b n (h d)') 71 | return self.to_out(out) 72 | 73 | class SelfAttentionBlock(nn.Module): 74 | def __init__( 75 | self, 76 | *, 77 | dim, 78 | dropout = 0., 79 | ff_mult = 4, 80 | **kwargs 81 | ): 82 | super().__init__() 83 | self.attn = SelfAttention(dim = dim, dropout = dropout, **kwargs) 84 | self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout) 85 | 86 | def forward(self, x, mask = None): 87 | x = self.attn(x, mask = mask) + x 88 | x = self.ff(x) + x 89 | return x 90 | 91 | # directional cross attention 92 | 93 | class CrossAttention(nn.Module): 94 | def __init__( 95 | self, 96 | *, 97 | dim, 98 | heads = 8, 99 | dim_head = 64, 100 | context_dim = None, 101 | dropout = 0. 
102 | ): 103 | super().__init__() 104 | context_dim = default(context_dim, dim) 105 | self.norm = nn.LayerNorm(dim) 106 | self.context_norm = nn.LayerNorm(context_dim) 107 | 108 | self.heads = heads 109 | self.scale = dim_head ** -0.5 110 | inner_dim = dim_head * heads 111 | 112 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 113 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) 114 | self.to_out = nn.Linear(inner_dim, dim) 115 | 116 | self.dropout = nn.Dropout(dropout) 117 | 118 | def forward( 119 | self, 120 | x, 121 | context, 122 | mask = None, 123 | context_mask = None 124 | ): 125 | h = self.heads 126 | 127 | x = self.norm(x) 128 | context = self.context_norm(context) 129 | 130 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) 131 | q = q * self.scale 132 | 133 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 134 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 135 | 136 | if exists(context_mask): 137 | mask_value = -torch.finfo(sim.dtype).max 138 | context_mask = rearrange(context_mask, 'b j -> b 1 1 j') 139 | sim = sim.masked_fill(~context_mask, mask_value) 140 | 141 | attn = sim.softmax(dim = -1) 142 | attn = self.dropout(attn) 143 | 144 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 145 | out = rearrange(out, 'b h n d -> b n (h d)') 146 | return self.to_out(out) 147 | 148 | class JointCrossAttentionBlock(nn.Module): 149 | def __init__( 150 | self, 151 | *, 152 | dim, 153 | context_dim = None, 154 | ff_mult = 4, 155 | dropout = 0., 156 | **kwargs 157 | ): 158 | super().__init__() 159 | context_dim = default(context_dim, dim) 160 | 161 | self.attn = BidirectionalCrossAttention(dim = dim, context_dim = context_dim, dropout = dropout, prenorm = True, **kwargs) 162 | self.ff = FeedForward(dim, mult = ff_mult, dropout = dropout) 163 | self.context_ff = FeedForward(context_dim, mult = ff_mult, dropout = dropout) 164 | 165 | def forward( 166 | self, 167 | x, 168 | context, 169 | mask = None, 170 | context_mask = None 171 | ): 172 | attn_out, context_attn_out = self.attn(x, context, mask = mask, context_mask = context_mask) 173 | 174 | x = x + attn_out 175 | context = context + context_attn_out 176 | 177 | x = self.ff(x) + x 178 | context = self.context_ff(context) + context 179 | 180 | return x, context 181 | -------------------------------------------------------------------------------- /tf_bind_transformer/protein_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import re 4 | from pathlib import Path 5 | from functools import partial 6 | import esm 7 | from torch.nn.utils.rnn import pad_sequence 8 | from transformers import AlbertTokenizer, AutoModelForMaskedLM, logging 9 | from tf_bind_transformer.cache_utils import cache_fn, run_once, md5_hash_fn 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def map_values(fn, dictionary): 15 | return {k: fn(v) for k, v in dictionary.items()} 16 | 17 | def to_device(t, *, device): 18 | return t.to(device) 19 | 20 | def cast_tuple(t): 21 | return (t,) if not isinstance(t, tuple) else t 22 | 23 | PROTEIN_EMBED_USE_CPU = os.getenv('PROTEIN_EMBED_USE_CPU', None) is not None 24 | 25 | if PROTEIN_EMBED_USE_CPU: 26 | print('calculating protein embed only on cpu') 27 | 28 | # global variables 29 | 30 | GLOBAL_VARIABLES = { 31 | 'model': None, 32 | 'tokenizer': None 33 | } 34 | 35 | # general helper functions 36 | 37 | def calc_protein_representations_with_subunits(proteins, 
get_repr_fn, *, device): 38 | representations = [] 39 | 40 | for subunits in proteins: 41 | subunits = cast_tuple(subunits) 42 | subunits_representations = list(map(get_repr_fn, subunits)) 43 | subunits_representations = list(map(partial(to_device, device = device), subunits_representations)) 44 | subunits_representations = torch.cat(subunits_representations, dim = 0) 45 | representations.append(subunits_representations) 46 | 47 | lengths = [seq_repr.shape[0] for seq_repr in representations] 48 | masks = torch.arange(max(lengths), device = device)[None, :] < torch.tensor(lengths, device = device)[:, None] 49 | padded_representations = pad_sequence(representations, batch_first = True) 50 | 51 | return padded_representations.to(device), masks.to(device) 52 | 53 | # esm related functions 54 | 55 | ESM_MAX_LENGTH = 1024 56 | ESM_EMBED_DIM = 1280 57 | 58 | INT_TO_AA_STR_MAP = { 59 | 0: 'A', 60 | 1: 'C', 61 | 2: 'D', 62 | 3: 'E', 63 | 4: 'F', 64 | 5: 'G', 65 | 6: 'H', 66 | 7: 'I', 67 | 8: 'K', 68 | 9: 'L', 69 | 10: 'M', 70 | 11: 'N', 71 | 12: 'P', 72 | 13: 'Q', 73 | 14: 'R', 74 | 15: 'S', 75 | 16: 'T', 76 | 17: 'V', 77 | 18: 'W', 78 | 19: 'Y', 79 | 20: '_' 80 | } 81 | 82 | def tensor_to_aa_str(t): 83 | str_seqs = [] 84 | for int_seq in t.unbind(dim = 0): 85 | str_seq = list(map(lambda t: INT_TO_AA_STR_MAP[t] if t != 20 else '', int_seq.tolist())) 86 | str_seqs.append(''.join(str_seq)) 87 | return str_seqs 88 | 89 | @run_once('init_esm') 90 | def init_esm(): 91 | model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S() 92 | batch_converter = alphabet.get_batch_converter() 93 | 94 | if not PROTEIN_EMBED_USE_CPU: 95 | model = model.cuda() 96 | 97 | GLOBAL_VARIABLES['model'] = (model, batch_converter) 98 | 99 | def get_single_esm_repr(protein_str): 100 | init_esm() 101 | model, batch_converter = GLOBAL_VARIABLES['model'] 102 | 103 | data = [('protein', protein_str)] 104 | batch_labels, batch_strs, batch_tokens = batch_converter(data) 105 | 106 | if batch_tokens.shape[1] > ESM_MAX_LENGTH: 107 | print(f'warning max length protein esm: {protein_str}') 108 | 109 | batch_tokens = batch_tokens[:, :ESM_MAX_LENGTH] 110 | 111 | if not PROTEIN_EMBED_USE_CPU: 112 | batch_tokens = batch_tokens.cuda() 113 | 114 | with torch.no_grad(): 115 | results = model(batch_tokens, repr_layers=[33]) 116 | 117 | token_representations = results['representations'][33] 118 | representation = token_representations[0][1 : len(protein_str) + 1] 119 | return representation 120 | 121 | def get_esm_repr(proteins, device): 122 | if isinstance(proteins, torch.Tensor): 123 | proteins = tensor_to_aa_str(proteins) 124 | 125 | get_protein_repr_fn = cache_fn(get_single_esm_repr, path = 'esm/proteins') 126 | 127 | return calc_protein_representations_with_subunits(proteins, get_protein_repr_fn, device = device) 128 | 129 | # prot-albert 2048 context length 130 | 131 | PROT_ALBERT_PATH = 'Rostlab/prot_albert' 132 | PROT_ALBERT_DIM = 4096 133 | PROT_ALBERT_MAX_LENGTH = 2048 134 | 135 | def protein_str_with_spaces(protein_str): 136 | protein_str = re.sub(r"[UZOB]", 'X', protein_str) 137 | return ' '.join([*protein_str]) 138 | 139 | @run_once('init_prot_albert') 140 | def init_prot_albert(): 141 | GLOBAL_VARIABLES['tokenizer'] = AlbertTokenizer.from_pretrained(PROT_ALBERT_PATH, do_lower_case = False) 142 | model = AutoModelForMaskedLM.from_pretrained(PROT_ALBERT_PATH) 143 | 144 | if not PROTEIN_EMBED_USE_CPU: 145 | model = model.cuda() 146 | 147 | GLOBAL_VARIABLES['model'] = model 148 | 149 | def get_single_prot_albert_repr( 150 | protein_str, 151 
| max_length = PROT_ALBERT_MAX_LENGTH, 152 | hidden_state_index = -1 153 | ): 154 | init_prot_albert() 155 | model = GLOBAL_VARIABLES['model'] 156 | tokenizer = GLOBAL_VARIABLES['tokenizer'] 157 | 158 | encoding = tokenizer.batch_encode_plus( 159 | [protein_str_with_spaces(protein_str)], 160 | add_special_tokens = True, 161 | padding = True, 162 | truncation = True, 163 | max_length = max_length, 164 | return_attention_mask = True, 165 | return_tensors = 'pt' 166 | ) 167 | 168 | if not PROTEIN_EMBED_USE_CPU: 169 | encoding = map_values(lambda t: t.cuda(), encoding) 170 | 171 | model.eval() 172 | with torch.no_grad(): 173 | outputs = model(**encoding, output_hidden_states = True) 174 | 175 | hidden_state = outputs.hidden_states[hidden_state_index][0] 176 | return hidden_state 177 | 178 | def get_prot_albert_repr( 179 | proteins, 180 | device, 181 | max_length = PROT_ALBERT_MAX_LENGTH, 182 | hidden_state_index = -1 183 | ): 184 | if isinstance(proteins, str): 185 | proteins = [proteins] 186 | 187 | if isinstance(proteins, torch.Tensor): 188 | proteins = tensor_to_aa_str(proteins) 189 | 190 | get_protein_repr_fn = cache_fn(get_single_prot_albert_repr, path = f'proteins/prot_albert') 191 | 192 | return calc_protein_representations_with_subunits(proteins, get_protein_repr_fn, device = device) 193 | 194 | # alphafold2 functions 195 | 196 | AF2_MAX_LENGTH = 2500 197 | AF2_EMBEDDING_DIM = 384 198 | 199 | AF2_DIRECTORY = os.getenv('TF_BIND_AF2_DIRECTORY', os.path.expanduser('~/.cache.tf.bind.transformer/.af2_embeddings')) 200 | AF2_DIRECTORY_PATH = Path(AF2_DIRECTORY) 201 | 202 | def get_single_alphafold2_repr( 203 | protein_str, 204 | max_length = AF2_MAX_LENGTH, 205 | ): 206 | md5 = md5_hash_fn(protein_str) 207 | embedding_path = AF2_DIRECTORY_PATH / f'{md5}.pt' 208 | assert embedding_path.exists(), f'af2 embedding not found for {protein_str}' 209 | 210 | tensor = torch.load(str(embedding_path)) 211 | return tensor[:max_length] 212 | 213 | def get_alphafold2_repr( 214 | proteins, 215 | device, 216 | max_length = AF2_MAX_LENGTH, 217 | **kwargs 218 | ): 219 | representations = [] 220 | 221 | for subunits in proteins: 222 | subunits = cast_tuple(subunits) 223 | subunits = list(map(lambda t: get_single_alphafold2_repr(t, max_length = max_length), subunits)) 224 | subunits = torch.cat(subunits, dim = 0) 225 | representations.append(subunits) 226 | 227 | lengths = [seq_repr.shape[0] for seq_repr in representations] 228 | masks = torch.arange(max(lengths), device = device)[None, :] < torch.tensor(lengths, device = device)[:, None] 229 | padded_representations = pad_sequence(representations, batch_first = True) 230 | 231 | return padded_representations.to(device), masks.to(device) 232 | 233 | # factory functions 234 | 235 | PROTEIN_REPR_CONFIG = { 236 | 'esm': { 237 | 'dim': ESM_EMBED_DIM, 238 | 'fn': get_esm_repr 239 | }, 240 | 'protalbert': { 241 | 'dim': PROT_ALBERT_DIM, 242 | 'fn': get_prot_albert_repr 243 | }, 244 | 'alphafold2': { 245 | 'dim': AF2_EMBEDDING_DIM, 246 | 'fn': get_alphafold2_repr 247 | } 248 | } 249 | 250 | def get_protein_embedder(name): 251 | allowed_protein_embedders = list(PROTEIN_REPR_CONFIG.keys()) 252 | assert name in allowed_protein_embedders, f"must be one of {', '.join(allowed_protein_embedders)}" 253 | 254 | config = PROTEIN_REPR_CONFIG[name] 255 | return config 256 | -------------------------------------------------------------------------------- /tf_bind_transformer/training_utils.py: -------------------------------------------------------------------------------- 1 | 
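# training_utils.py - Trainer for the binary peak prediction task: builds the ReMap positive,
# negative and (optionally) scoped-negative datasets, cycles their dataloaders, and runs
# gradient-accumulated updates with periodic validation on held-out targets plus checkpointing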
import torch 2 | from torch import nn 3 | from tf_bind_transformer.optimizer import get_optimizer 4 | from tf_bind_transformer.data import read_bed, collate_dl_outputs, get_dataloader, remap_df_add_experiment_target_cell 5 | from tf_bind_transformer.data import RemapAllPeakDataset, NegativePeakDataset, ScopedNegativePeakDataset 6 | 7 | def exists(val): 8 | return val is not None 9 | 10 | def default(val, d): 11 | return val if exists(val) else d 12 | 13 | # helpers for logging and accumulating values across gradient steps 14 | 15 | def accum_log(log, new_logs): 16 | for key, new_value in new_logs.items(): 17 | old_value = log.get(key, 0.) 18 | log[key] = old_value + new_value 19 | return log 20 | 21 | # simple Trainer class 22 | 23 | class Trainer(nn.Module): 24 | def __init__( 25 | self, 26 | model, 27 | *, 28 | remap_bed_file, 29 | negative_bed_file, 30 | factor_fasta_folder, 31 | fasta_file, 32 | train_chromosome_ids, 33 | valid_chromosome_ids, 34 | batch_size, 35 | context_length, 36 | lr = 3e-4, 37 | wd = 0.1, 38 | validate_every = 250, 39 | grad_clip_norm = None, 40 | grad_accum_every = 1, 41 | held_out_targets = [], 42 | held_out_cell_types = [], 43 | exclude_targets = [], 44 | exclude_cell_types = [], 45 | shuffle = False, 46 | train_sample_frac = 1., 47 | valid_sample_frac = 1., 48 | remap_sample_frac = 1., 49 | shift_aug_range = (-2, 2), 50 | rc_aug = False, 51 | experiments_json_path = None, 52 | read_value_aux_loss = False, 53 | checkpoint_filename = './checkpoint.pt', 54 | include_scoped_negs = False, 55 | scoped_negs_remap_bed_path = None, 56 | scoped_negs_path = None, 57 | scoped_negs_exts = '.bed.bool.npy', 58 | include_biotypes_metadata_in_context = False, 59 | biotypes_metadata_path = None, 60 | include_biotypes_metadata_columns = ['germ_layer', 'cellline_cat'], 61 | biotypes_metadata_delimiter = ' | ', 62 | balance_sampling_by_target = True, 63 | valid_balance_sampling_by_target = None, 64 | ): 65 | super().__init__() 66 | self.model = model 67 | valid_balance_sampling_by_target = default(valid_balance_sampling_by_target, balance_sampling_by_target) 68 | 69 | remap_df = read_bed(remap_bed_file) 70 | 71 | if remap_sample_frac < 1: 72 | remap_df = remap_df.sample(frac = remap_sample_frac) 73 | 74 | remap_df = remap_df_add_experiment_target_cell(remap_df) 75 | 76 | neg_df = read_bed(negative_bed_file) 77 | 78 | self.ds = RemapAllPeakDataset( 79 | remap_df = remap_df, 80 | fasta_file = fasta_file, 81 | factor_fasta_folder = factor_fasta_folder, 82 | filter_chromosome_ids = train_chromosome_ids, 83 | exclude_targets = [*held_out_targets, *exclude_targets], 84 | exclude_cell_types = [*held_out_cell_types, *exclude_cell_types], 85 | context_length = context_length, 86 | remap_df_frac = train_sample_frac, 87 | shift_augs = shift_aug_range, 88 | rc_aug = rc_aug, 89 | experiments_json_path = experiments_json_path, 90 | include_biotypes_metadata_in_context = include_biotypes_metadata_in_context, 91 | biotypes_metadata_path = biotypes_metadata_path, 92 | include_biotypes_metadata_columns = include_biotypes_metadata_columns, 93 | biotypes_metadata_delimiter = biotypes_metadata_delimiter, 94 | balance_sampling_by_target = balance_sampling_by_target 95 | ) 96 | 97 | self.neg_ds = NegativePeakDataset( 98 | remap_df = remap_df, 99 | negative_df = neg_df, 100 | fasta_file = fasta_file, 101 | factor_fasta_folder = factor_fasta_folder, 102 | filter_chromosome_ids = train_chromosome_ids, 103 | exclude_targets = [*held_out_targets, *exclude_targets], 104 | exclude_cell_types = 
[*held_out_cell_types, *exclude_cell_types], 105 | context_length = context_length, 106 | experiments_json_path = experiments_json_path, 107 | include_biotypes_metadata_in_context = include_biotypes_metadata_in_context, 108 | biotypes_metadata_path = biotypes_metadata_path, 109 | include_biotypes_metadata_columns = include_biotypes_metadata_columns, 110 | biotypes_metadata_delimiter = biotypes_metadata_delimiter, 111 | balance_sampling_by_target = balance_sampling_by_target 112 | ) 113 | 114 | self.valid_ds = RemapAllPeakDataset( 115 | remap_df = remap_df, 116 | fasta_file = fasta_file, 117 | factor_fasta_folder = factor_fasta_folder, 118 | include_targets = held_out_targets, 119 | include_cell_types = held_out_cell_types, 120 | exclude_targets = exclude_targets, 121 | exclude_cell_types = exclude_cell_types, 122 | filter_chromosome_ids = valid_chromosome_ids, 123 | context_length = context_length, 124 | remap_df_frac = valid_sample_frac, 125 | shift_augs = shift_aug_range, 126 | rc_aug = rc_aug, 127 | experiments_json_path = experiments_json_path, 128 | include_biotypes_metadata_in_context = include_biotypes_metadata_in_context, 129 | biotypes_metadata_path = biotypes_metadata_path, 130 | include_biotypes_metadata_columns = include_biotypes_metadata_columns, 131 | biotypes_metadata_delimiter = biotypes_metadata_delimiter, 132 | balance_sampling_by_target = valid_balance_sampling_by_target 133 | ) 134 | 135 | self.valid_neg_ds = NegativePeakDataset( 136 | remap_df = remap_df, 137 | negative_df = neg_df, 138 | fasta_file = fasta_file, 139 | factor_fasta_folder = factor_fasta_folder, 140 | filter_chromosome_ids = valid_chromosome_ids, 141 | include_targets = held_out_targets, 142 | include_cell_types = held_out_cell_types, 143 | exclude_targets = exclude_targets, 144 | exclude_cell_types = exclude_cell_types, 145 | context_length = context_length, 146 | experiments_json_path = experiments_json_path, 147 | include_biotypes_metadata_in_context = include_biotypes_metadata_in_context, 148 | biotypes_metadata_path = biotypes_metadata_path, 149 | include_biotypes_metadata_columns = include_biotypes_metadata_columns, 150 | biotypes_metadata_delimiter = biotypes_metadata_delimiter, 151 | balance_sampling_by_target = valid_balance_sampling_by_target 152 | ) 153 | 154 | self.include_scoped_negs = include_scoped_negs 155 | 156 | self.dl = get_dataloader(self.ds, cycle_iter = True, shuffle = shuffle, batch_size = batch_size) 157 | self.neg_dl = get_dataloader(self.neg_ds, cycle_iter = True, shuffle = shuffle, batch_size = batch_size) 158 | 159 | if include_scoped_negs: 160 | self.scoped_neg_ds = ScopedNegativePeakDataset( 161 | fasta_file = fasta_file, 162 | factor_fasta_folder = factor_fasta_folder, 163 | numpy_folder_with_scoped_negatives = scoped_negs_path, 164 | remap_bed_file = scoped_negs_remap_bed_path, 165 | exts = scoped_negs_exts, 166 | exclude_targets = [*held_out_targets, *exclude_targets], 167 | exclude_cell_types = [*held_out_cell_types, *exclude_cell_types], 168 | filter_chromosome_ids = train_chromosome_ids, 169 | include_biotypes_metadata_in_context = include_biotypes_metadata_in_context, 170 | biotypes_metadata_path = biotypes_metadata_path, 171 | include_biotypes_metadata_columns = include_biotypes_metadata_columns, 172 | biotypes_metadata_delimiter = biotypes_metadata_delimiter, 173 | balance_sampling_by_target = balance_sampling_by_target 174 | ) 175 | 176 | self.scoped_neg_dl = get_dataloader(self.scoped_neg_ds, cycle_iter = True, shuffle = shuffle, batch_size = batch_size) 177 | 
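        # the validation loaders below mirror the train positive / negative pair, but draw from the
        # held-out targets / cell types (when given) on the validation chromosomes; forward() collates
        # one batch from each pair per validation step via collate_dl_outputs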
178 | self.valid_dl = get_dataloader(self.valid_ds, cycle_iter = True, shuffle = shuffle, batch_size = batch_size) 179 | self.valid_neg_dl = get_dataloader(self.valid_neg_ds, cycle_iter = True, shuffle = shuffle, batch_size = batch_size) 180 | 181 | self.aux_read_value_loss = model.aux_read_value_loss 182 | 183 | if self.aux_read_value_loss: 184 | print(f'training with read value aux loss') 185 | 186 | self.optim = get_optimizer(model.parameters(), lr = lr, wd = wd) 187 | 188 | self.grad_accum_every = grad_accum_every 189 | self.grad_clip_norm = grad_clip_norm 190 | 191 | self.validate_every = validate_every 192 | self.register_buffer('steps', torch.Tensor([0.])) 193 | 194 | self.checkpoint_filename = checkpoint_filename 195 | 196 | def forward( 197 | self, 198 | finetune_enformer_ln_only = True, 199 | **kwargs 200 | ): 201 | grad_accum_every = self.grad_accum_every 202 | curr_step = int(self.steps.item()) 203 | self.model.train() 204 | 205 | log = {} 206 | 207 | for _ in range(self.grad_accum_every): 208 | dl_outputs = [next(self.dl), next(self.neg_dl)] 209 | 210 | if self.include_scoped_negs: 211 | dl_outputs.append(next(self.scoped_neg_dl)) 212 | 213 | seq, tf_aa, contextual_texts, peaks_nr, read_value, binary_target = collate_dl_outputs(*dl_outputs) 214 | seq, binary_target, read_value, peaks_nr = seq.cuda(), binary_target.cuda(), read_value.cuda(), peaks_nr.cuda() 215 | 216 | loss, aux_loss = self.model( 217 | seq, 218 | target = binary_target, 219 | aa = tf_aa, 220 | contextual_free_text = contextual_texts, 221 | finetune_enformer_ln_only = finetune_enformer_ln_only, 222 | read_value = read_value, 223 | peaks_nr = peaks_nr, 224 | **kwargs 225 | ) 226 | 227 | total_loss = self.model.combine_losses(loss, aux_loss) 228 | 229 | log = accum_log(log, { 230 | 'loss': loss.item() / grad_accum_every, 231 | 'aux_loss': aux_loss.item() / grad_accum_every, 232 | 'total_loss': total_loss.item() / grad_accum_every 233 | }) 234 | 235 | (total_loss / self.grad_accum_every).backward() 236 | 237 | print(f'{curr_step} loss: {log["total_loss"]}') 238 | 239 | if exists(self.grad_clip_norm): 240 | nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm) 241 | 242 | self.optim.step() 243 | self.optim.zero_grad() 244 | 245 | if (curr_step % self.validate_every) == 0: 246 | self.model.eval() 247 | 248 | for _ in range(self.grad_accum_every): 249 | seq, tf_aa, contextual_texts, peaks_nr, read_value, binary_target = collate_dl_outputs(next(self.valid_dl), next(self.valid_neg_dl)) 250 | seq, binary_target = seq.cuda(), binary_target.cuda() 251 | 252 | valid_logits = self.model( 253 | seq, 254 | aa = tf_aa, 255 | contextual_free_text = contextual_texts, 256 | ) 257 | 258 | valid_loss = self.model.loss_fn(valid_logits, binary_target.float()) 259 | valid_accuracy = ((valid_logits.sigmoid() > 0.5).int() == binary_target).sum() / (binary_target.numel()) 260 | 261 | log = accum_log(log, { 262 | 'valid_loss': valid_loss.item() / grad_accum_every, 263 | 'valid_accuracy': valid_accuracy.item() / grad_accum_every 264 | }) 265 | 266 | print(f'{curr_step} valid loss: {log["valid_loss"]}') 267 | print(f'{curr_step} valid accuracy: {log["valid_accuracy"]}') 268 | 269 | if curr_step > 0: 270 | torch.save(self.model.state_dict(), self.checkpoint_filename) 271 | 272 | self.steps += 1 273 | return log 274 | -------------------------------------------------------------------------------- /tf_bind_transformer/data_bigwig.py: -------------------------------------------------------------------------------- 1 | from 
pathlib import Path 2 | import polars as pl 3 | import numpy as np 4 | 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | from tf_bind_transformer.data import FactorProteinDataset, ContextDataset, cast_list, filter_df_by_tfactor_fastas 9 | from tf_bind_transformer.data import pl_isin, pl_notin, fetch_experiments_index, parse_exp_target_cell, read_bed, cycle, filter_by_col_isin 10 | from tf_bind_transformer.data import CHR_IDS, CHR_NAMES, get_chr_names 11 | from enformer_pytorch import FastaInterval 12 | 13 | try: 14 | import pyBigWig 15 | except ImportError: 16 | print('pyBigWig needs to be installed - conda install pyBigWig') 17 | exit() 18 | 19 | def exists(val): 20 | return val is not None 21 | 22 | def chip_atlas_add_experiment_target_cell( 23 | df, 24 | col_target = 'column_4', 25 | col_cell_type = 'column_5' 26 | ): 27 | df = df.clone() 28 | 29 | targets = df.select(col_target) 30 | targets = targets.to_series(0).str.to_uppercase().rename('target') 31 | df.insert_at_idx(2, targets) 32 | 33 | cell_type = df.select(col_cell_type) 34 | cell_type = cell_type.rename({col_cell_type: 'cell_type'}).to_series(0) 35 | df.insert_at_idx(2, cell_type) 36 | 37 | return df 38 | 39 | # dataset for CHIP ATLAS - all peaks 40 | 41 | class BigWigDataset(Dataset): 42 | def __init__( 43 | self, 44 | *, 45 | factor_fasta_folder, 46 | bigwig_folder, 47 | enformer_loci_path, 48 | fasta_file, 49 | annot_file = None, 50 | filter_chromosome_ids = None, 51 | exclude_targets = None, 52 | include_targets = None, 53 | exclude_cell_types = None, 54 | include_cell_types = None, 55 | df_frac = 1., 56 | experiments_json_path = None, 57 | include_biotypes_metadata_in_context = False, 58 | biotypes_metadata_path = None, 59 | filter_sequences_by = None, 60 | include_biotypes_metadata_columns = [], 61 | biotypes_metadata_delimiter = ' | ', 62 | only_ref = ['mm10', 'hg38'], 63 | factor_species_priority = ['human', 'mouse'], 64 | downsample_factor = 128, 65 | target_length = 896, 66 | bigwig_reduction_type = 'sum', 67 | **kwargs 68 | ): 69 | super().__init__() 70 | assert exists(annot_file) 71 | 72 | if not exists(bigwig_folder): 73 | self.invalid = True 74 | self.ntargets = 0 75 | return 76 | 77 | bigwig_folder = Path(bigwig_folder) 78 | assert bigwig_folder.exists(), 'bigwig folder does not exist' 79 | 80 | bw_experiments = [p.stem for p in bigwig_folder.glob('*.bw')] 81 | assert len(bw_experiments) > 0, 'no bigwig files found in bigwig folder' 82 | 83 | loci = read_bed(enformer_loci_path) 84 | annot_df = pl.read_csv(annot_file, sep = "\t", has_headers = False, columns = list(map(lambda i: f'column_{i + 1}', range(17)))) 85 | 86 | annot_df = annot_df.filter(pl_isin('column_2', only_ref)) 87 | annot_df = filter_by_col_isin(annot_df, 'column_1', bw_experiments) 88 | 89 | if df_frac < 1: 90 | annot_df = annot_df.sample(frac = df_frac) 91 | 92 | dataset_chr_ids = CHR_IDS 93 | 94 | if exists(filter_chromosome_ids): 95 | dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids)) 96 | 97 | # filtering loci by chromosomes 98 | # as well as training or validation 99 | 100 | loci = loci.filter(pl_isin('column_1', get_chr_names(dataset_chr_ids))) 101 | 102 | if exists(filter_sequences_by): 103 | col_name, col_val = filter_sequences_by 104 | loci = loci.filter(pl.col(col_name) == col_val) 105 | 106 | self.factor_ds = FactorProteinDataset(factor_fasta_folder, species_priority = factor_species_priority) 107 | 108 | exp_ids = set(annot_df.get_column('column_1').to_list()) 109 | 110 | annot_df = 
chip_atlas_add_experiment_target_cell(annot_df) 111 | annot_df = filter_df_by_tfactor_fastas(annot_df, factor_fasta_folder) 112 | 113 | filtered_exp_ids = set(annot_df.get_column('column_1').to_list()) 114 | 115 | filtered_out_exp_ids = exp_ids - filtered_exp_ids 116 | print(f'{", ".join(only_ref)} - {len(filtered_out_exp_ids)} experiments filtered out by lack of transcription factor fastas', filtered_out_exp_ids) 117 | 118 | # filter dataset by inclusion and exclusion list of targets 119 | # ( intersect ) subtract 120 | 121 | include_targets = cast_list(include_targets) 122 | exclude_targets = cast_list(exclude_targets) 123 | 124 | if include_targets: 125 | annot_df = annot_df.filter(pl_isin('target', include_targets)) 126 | 127 | if exclude_targets: 128 | annot_df = annot_df.filter(pl_notin('target', exclude_targets)) 129 | 130 | # filter dataset by inclusion and exclusion list of cell types 131 | # same logic as for targets 132 | 133 | include_cell_types = cast_list(include_cell_types) 134 | exclude_cell_types = cast_list(exclude_cell_types) 135 | 136 | # :TODO reformulate this 137 | # Cell_type should probably be column_6 138 | if include_cell_types: 139 | annot_df = annot_df.filter(pl_isin('cell_type', include_cell_types)) 140 | 141 | if exclude_cell_types: 142 | annot_df = annot_df.filter(pl_notin('cell_type', exclude_cell_types)) 143 | 144 | self.fasta = FastaInterval(fasta_file = fasta_file, **kwargs) 145 | 146 | self.df = loci 147 | self.annot = annot_df 148 | self.ntargets = self.annot.shape[0] 149 | 150 | # bigwigs 151 | 152 | self.bigwigs = [pyBigWig.open(str(bigwig_folder / f'{str(i)}.bw')) for i in self.annot.get_column("column_1")] 153 | 154 | self.downsample_factor = downsample_factor 155 | self.target_length = target_length 156 | 157 | self.bigwig_reduction_type = bigwig_reduction_type 158 | self.invalid = False 159 | 160 | def __len__(self): 161 | if self.invalid: 162 | return 0 163 | 164 | return len(self.df) * self.ntargets 165 | 166 | def __getitem__(self, ind): 167 | # TODO return all targets from an individual enformer locus 168 | chr_name, begin, end, _ = self.df.row(ind % self.df.shape[0]) 169 | 170 | targets = self.annot.select('target').to_series(0) 171 | cell_types = self.annot.select('cell_type').to_series(0) 172 | 173 | ix_target = ind // self.df.shape[0] 174 | 175 | #experiment, target, cell_type = parse_exp_target_cell(experiment_target_cell_type) 176 | 177 | target = targets[ix_target] 178 | context_str = cell_types[ix_target] 179 | exp_bw = self.bigwigs[ix_target] 180 | 181 | # figure out ref and fetch appropriate sequence 182 | 183 | aa_seq = self.factor_ds[target] 184 | seq = self.fasta(chr_name, begin, end) 185 | 186 | # calculate bigwig 187 | # properly downsample and then crop 188 | 189 | output = np.array(exp_bw.values(chr_name, begin, end)) 190 | output = output.reshape((-1, self.downsample_factor)) 191 | 192 | if self.bigwig_reduction_type == 'mean': 193 | om = np.nanmean(output, axis = 1) 194 | elif self.bigwig_reduction_type == 'sum': 195 | om = np.nansum(output, axis = 1) 196 | else: 197 | raise ValueError(f'unknown reduction type {self.bigwig_reduction_type}') 198 | 199 | output_length = output.shape[0] 200 | 201 | if output_length < self.target_length: 202 | raise ValueError(f'the downsampled output length {output_length} cannot be less than the target length {self.target_length}') 203 | 204 | trim = (output.shape[0] - self.target_length) // 2 205 | om = om[trim:-trim] 206 | 207 | np.nan_to_num(om, copy = False) 208 | 209 | label = torch.Tensor(om) 210 | return seq, aa_seq,
context_str, label 211 | 212 | # BigWig dataset for tracks only 213 | 214 | class BigWigTracksOnlyDataset(Dataset): 215 | def __init__( 216 | self, 217 | *, 218 | bigwig_folder, 219 | enformer_loci_path, 220 | fasta_file, 221 | ref, 222 | annot_file = None, 223 | filter_chromosome_ids = None, 224 | downsample_factor = 128, 225 | target_length = 896, 226 | bigwig_reduction_type = 'sum', 227 | filter_sequences_by = None, 228 | **kwargs 229 | ): 230 | super().__init__() 231 | assert exists(annot_file) 232 | 233 | if not exists(bigwig_folder): 234 | self.invalid = True 235 | self.ntargets = 0 236 | return 237 | 238 | bigwig_folder = Path(bigwig_folder) 239 | assert bigwig_folder.exists(), 'bigwig folder does not exist' 240 | 241 | bw_experiments = [p.stem for p in bigwig_folder.glob('*.bw')] 242 | assert len(bw_experiments) > 0, 'no bigwig files found in bigwig folder' 243 | 244 | loci = read_bed(enformer_loci_path) 245 | 246 | annot_df = pl.read_csv(annot_file, sep = "\t", has_headers = False, columns = list(map(lambda i: f'column_{i + 1}', range(17)))) 247 | 248 | annot_df = annot_df.filter(pl.col('column_2') == ref) 249 | annot_df = filter_by_col_isin(annot_df, 'column_1', bw_experiments) 250 | 251 | dataset_chr_ids = CHR_IDS 252 | 253 | if exists(filter_chromosome_ids): 254 | dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids)) 255 | 256 | # filtering loci by chromosomes 257 | # as well as training or validation 258 | 259 | loci = loci.filter(pl_isin('column_1', get_chr_names(dataset_chr_ids))) 260 | 261 | if exists(filter_sequences_by): 262 | col_name, col_val = filter_sequences_by 263 | loci = loci.filter(pl.col(col_name) == col_val) 264 | 265 | self.fasta = FastaInterval(fasta_file = fasta_file, **kwargs) 266 | 267 | self.df = loci 268 | self.annot = annot_df 269 | self.ntargets = self.annot.shape[0] 270 | 271 | # bigwigs 272 | 273 | self.bigwigs = [(str(i), pyBigWig.open(str(bigwig_folder / f'{str(i)}.bw'))) for i in self.annot.get_column("column_1")] 274 | 275 | self.downsample_factor = downsample_factor 276 | self.target_length = target_length 277 | 278 | self.bigwig_reduction_type = bigwig_reduction_type 279 | self.invalid = False 280 | 281 | def __len__(self): 282 | if self.invalid: 283 | return 0 284 | 285 | return len(self.df) * int(self.ntargets > 0) 286 | 287 | def __getitem__(self, ind): 288 | chr_name, begin, end, _ = self.df.row(ind) 289 | 290 | # figure out ref and fetch appropriate sequence 291 | 292 | seq = self.fasta(chr_name, begin, end) 293 | 294 | # calculate bigwig 295 | # properly downsample and then crop 296 | 297 | all_bw_values = [] 298 | 299 | for bw_path, bw in self.bigwigs: 300 | try: 301 | bw_values = bw.values(chr_name, begin, end) 302 | all_bw_values.append(bw_values) 303 | except: 304 | print(f'hitting invalid range for {bw_path} - ({chr_name}, {begin}, {end})') 305 | exit() 306 | 307 | output = np.stack(all_bw_values, axis = -1) 308 | output = output.reshape((-1, self.downsample_factor, self.ntargets)) 309 | 310 | if self.bigwig_reduction_type == 'mean': 311 | om = np.nanmean(output, axis = 1) 312 | elif self.bigwig_reduction_type == 'sum': 313 | om = np.nansum(output, axis = 1) 314 | else: 315 | raise ValueError(f'unknown reduction type {self.bigwig_reduction_type}') 316 | 317 | output_length = output.shape[0] 318 | 319 | if output_length < self.target_length: 320 | raise ValueError(f'the downsampled output length {output_length} cannot be less than the target length {self.target_length}') 321 | 322 | trim = (output.shape[0] - self.target_length) // 2 323 | om =
om[trim:-trim] 324 | 325 | np.nan_to_num(om, copy = False) 326 | 327 | label = torch.Tensor(om) 328 | return seq, label 329 | 330 | # data loader 331 | 332 | def bigwig_collate_fn(data): 333 | seq, aa_seq, context_str, labels = list(zip(*data)) 334 | return torch.stack(seq), tuple(aa_seq), tuple(context_str), torch.stack(labels) 335 | 336 | def get_bigwig_dataloader(ds, cycle_iter = False, **kwargs): 337 | dataset_len = len(ds) 338 | batch_size = kwargs.get('batch_size') 339 | drop_last = dataset_len > batch_size 340 | 341 | dl = DataLoader(ds, collate_fn = bigwig_collate_fn, drop_last = drop_last, **kwargs) 342 | wrapper = cycle if cycle_iter else iter 343 | return wrapper(dl) 344 | 345 | def get_bigwig_tracks_dataloader(ds, cycle_iter = False, **kwargs): 346 | dataset_len = len(ds) 347 | batch_size = kwargs.get('batch_size') 348 | drop_last = dataset_len > batch_size 349 | 350 | dl = DataLoader(ds, drop_last = drop_last, **kwargs) 351 | wrapper = cycle if cycle_iter else iter 352 | return wrapper(dl) 353 | -------------------------------------------------------------------------------- /tf_bind_transformer/training_utils_bigwig.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from tf_bind_transformer.optimizer import get_optimizer 4 | from tf_bind_transformer.data_bigwig import BigWigDataset, BigWigTracksOnlyDataset, get_bigwig_dataloader, get_bigwig_tracks_dataloader 5 | from enformer_pytorch.modeling_enformer import poisson_loss, pearson_corr_coef 6 | 7 | def exists(val): 8 | return val is not None 9 | 10 | def default(val, d): 11 | return val if exists(val) else d 12 | 13 | # helpers for logging and accumulating values across gradient steps 14 | 15 | def accum_log(log, new_logs): 16 | for key, new_value in new_logs.items(): 17 | old_value = log.get(key, 0.) 
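# note: the trainer's call sites pass values already divided by grad_accum_every, so summing here leaves the mean over the gradient accumulation sub-steps in the log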
18 | log[key] = old_value + new_value 19 | return log 20 | 21 | # simple Trainer class 22 | 23 | class BigWigTrainer(nn.Module): 24 | def __init__( 25 | self, 26 | model, 27 | *, 28 | human_factor_fasta_folder, 29 | annot_file_path, 30 | human_loci_path, 31 | mouse_loci_path, 32 | human_fasta_file, 33 | mouse_fasta_file, 34 | batch_size, 35 | bigwig_tracks_only_folder_path = None, 36 | bigwig_folder_path = None, 37 | train_chromosome_ids = None, 38 | valid_chromosome_ids = None, 39 | mouse_factor_fasta_folder = None, 40 | downsample_factor = 128, 41 | target_length = 896, 42 | lr = 3e-4, 43 | wd = 0.1, 44 | validate_every = 250, 45 | grad_clip_norm = None, 46 | grad_accum_every = 1, 47 | held_out_targets_human = [], 48 | held_out_targets_mouse = [], 49 | held_out_cell_types_human = [], 50 | held_out_cell_types_mouse = [], 51 | context_length = 4096, 52 | shuffle = False, 53 | shift_aug_range = (-2, 2), 54 | rc_aug = False, 55 | checkpoint_filename = './checkpoint.pt', 56 | include_biotypes_metadata_in_context = False, 57 | biotypes_metadata_path = None, 58 | include_biotypes_metadata_columns = ['germ_layer', 'cellline_cat'], 59 | biotypes_metadata_delimiter = ' | ', 60 | bigwig_reduction_type = 'sum', 61 | enformer_train_valid_split = True 62 | ): 63 | super().__init__() 64 | assert exists(bigwig_folder_path) or exists(bigwig_tracks_only_folder_path) 65 | 66 | self.model = model 67 | 68 | mouse_factor_fasta_folder = default(mouse_factor_fasta_folder, human_factor_fasta_folder) 69 | 70 | self.human_ds = BigWigDataset( 71 | filter_chromosome_ids = train_chromosome_ids, 72 | factor_fasta_folder = human_factor_fasta_folder, 73 | bigwig_folder = bigwig_folder_path, 74 | enformer_loci_path = human_loci_path, 75 | annot_file = annot_file_path, 76 | fasta_file = human_fasta_file, 77 | exclude_targets = held_out_targets_human, 78 | exclude_cell_types = held_out_cell_types_human, 79 | target_length = target_length, 80 | context_length = context_length, 81 | downsample_factor = downsample_factor, 82 | shift_augs = shift_aug_range, 83 | rc_aug = rc_aug, 84 | bigwig_reduction_type = bigwig_reduction_type, 85 | filter_sequences_by = ('column_4', 'train') if enformer_train_valid_split else None, 86 | only_ref = ['hg38'], 87 | factor_species_priority = ['human', 'mouse', 'unknown'] 88 | ) 89 | 90 | self.valid_human_ds = BigWigDataset( 91 | filter_chromosome_ids = valid_chromosome_ids, 92 | factor_fasta_folder = human_factor_fasta_folder, 93 | bigwig_folder = bigwig_folder_path, 94 | enformer_loci_path = human_loci_path, 95 | annot_file = annot_file_path, 96 | fasta_file = human_fasta_file, 97 | include_targets = held_out_targets_human, 98 | include_cell_types = held_out_cell_types_human, 99 | target_length = target_length, 100 | context_length = context_length, 101 | downsample_factor = downsample_factor, 102 | shift_augs = shift_aug_range, 103 | rc_aug = rc_aug, 104 | bigwig_reduction_type = bigwig_reduction_type, 105 | filter_sequences_by = ('column_4', 'valid') if enformer_train_valid_split else None, 106 | only_ref = ['hg38'], 107 | factor_species_priority = ['human', 'mouse', 'unknown'] 108 | ) 109 | 110 | self.mouse_ds = BigWigDataset( 111 | filter_chromosome_ids = train_chromosome_ids, 112 | factor_fasta_folder = mouse_factor_fasta_folder, 113 | bigwig_folder = bigwig_folder_path, 114 | enformer_loci_path = mouse_loci_path, 115 | annot_file = annot_file_path, 116 | fasta_file = mouse_fasta_file, 117 | exclude_targets = held_out_targets_mouse, 118 | exclude_cell_types = held_out_cell_types_mouse, 
119 | target_length = target_length, 120 | context_length = context_length, 121 | downsample_factor = downsample_factor, 122 | shift_augs = shift_aug_range, 123 | rc_aug = rc_aug, 124 | bigwig_reduction_type = bigwig_reduction_type, 125 | filter_sequences_by = ('column_4', 'train') if enformer_train_valid_split else None, 126 | only_ref = ['mm10'], 127 | factor_species_priority = ['mouse', 'human', 'unknown'] 128 | ) 129 | 130 | self.valid_mouse_ds = BigWigDataset( 131 | filter_chromosome_ids = valid_chromosome_ids, 132 | factor_fasta_folder = mouse_factor_fasta_folder, 133 | bigwig_folder = bigwig_folder_path, 134 | enformer_loci_path = mouse_loci_path, 135 | annot_file = annot_file_path, 136 | fasta_file = mouse_fasta_file, 137 | include_targets = held_out_targets_mouse, 138 | include_cell_types = held_out_cell_types_mouse, 139 | target_length = target_length, 140 | context_length = context_length, 141 | downsample_factor = downsample_factor, 142 | shift_augs = shift_aug_range, 143 | rc_aug = rc_aug, 144 | bigwig_reduction_type = bigwig_reduction_type, 145 | filter_sequences_by = ('column_4', 'valid') if enformer_train_valid_split else None, 146 | only_ref = ['mm10'], 147 | factor_species_priority = ['mouse', 'human', 'unknown'] 148 | ) 149 | 150 | self.human_head_ds = BigWigTracksOnlyDataset( 151 | ref = 'hg38', 152 | bigwig_folder = bigwig_tracks_only_folder_path, 153 | enformer_loci_path = human_loci_path, 154 | fasta_file = human_fasta_file, 155 | annot_file = annot_file_path, 156 | downsample_factor = downsample_factor, 157 | target_length = target_length, 158 | filter_sequences_by = ('column_4', 'train') 159 | ) 160 | 161 | self.valid_human_head_ds = BigWigTracksOnlyDataset( 162 | ref = 'hg38', 163 | bigwig_folder = bigwig_tracks_only_folder_path, 164 | enformer_loci_path = human_loci_path, 165 | fasta_file = human_fasta_file, 166 | annot_file = annot_file_path, 167 | downsample_factor = downsample_factor, 168 | target_length = target_length, 169 | filter_sequences_by = ('column_4', 'valid') 170 | ) 171 | 172 | self.mouse_head_ds = BigWigTracksOnlyDataset( 173 | ref = 'mm10', 174 | bigwig_folder = bigwig_tracks_only_folder_path, 175 | enformer_loci_path = mouse_loci_path, 176 | fasta_file = mouse_fasta_file, 177 | annot_file = annot_file_path, 178 | downsample_factor = downsample_factor, 179 | target_length = target_length, 180 | filter_sequences_by = ('column_4', 'train') 181 | ) 182 | 183 | self.valid_mouse_head_ds = BigWigTracksOnlyDataset( 184 | ref = 'mm10', 185 | bigwig_folder = bigwig_tracks_only_folder_path, 186 | enformer_loci_path = mouse_loci_path, 187 | fasta_file = mouse_fasta_file, 188 | annot_file = annot_file_path, 189 | downsample_factor = downsample_factor, 190 | target_length = target_length, 191 | filter_sequences_by = ('column_4', 'valid') 192 | ) 193 | 194 | len_train_human = len(self.human_ds) 195 | len_train_mouse = len(self.mouse_ds) 196 | len_valid_human = len(self.valid_human_ds) 197 | len_valid_mouse = len(self.valid_mouse_ds) 198 | 199 | len_train_human_head = len(self.human_head_ds) 200 | len_valid_human_head = len(self.valid_human_head_ds) 201 | len_train_mouse_head = len(self.mouse_head_ds) 202 | len_valid_mouse_head = len(self.valid_mouse_head_ds) 203 | 204 | self.has_train = len_train_human > 0 or len_train_mouse > 0 or len_train_human_head > 0 or len_train_mouse_head > 0 205 | self.has_valid = len_valid_human > 0 or len_valid_mouse > 0 or len_valid_human_head > 0 or len_valid_mouse_head > 0 206 | 207 | if self.has_train: 208 | print(f'training 
with {self.human_ds.ntargets} human targets and {self.mouse_ds.ntargets} mouse targets') 209 | print(f'training independent tracks with {self.human_head_ds.ntargets} human targets and {self.mouse_head_ds.ntargets} mouse targets') 210 | 211 | if self.has_valid: 212 | print(f'validating with {self.valid_human_ds.ntargets} human targets and {self.valid_mouse_ds.ntargets} mouse targets') 213 | print(f'validating independent tracks with {self.valid_human_head_ds.ntargets} human targets and {self.valid_mouse_head_ds.ntargets} mouse targets') 214 | 215 | assert self.has_train and self.has_valid, 'must have training and validation samples in order to proceed' 216 | 217 | self.train_human_dl = get_bigwig_dataloader(self.human_ds, cycle_iter = True, shuffle = shuffle, batch_size = batch_size) if len_train_human > 0 else None 218 | self.train_mouse_dl = get_bigwig_dataloader(self.mouse_ds, cycle_iter = True, shuffle = shuffle, batch_size = batch_size) if len_train_mouse > 0 else None 219 | 220 | self.valid_human_dl = get_bigwig_dataloader(self.valid_human_ds, cycle_iter = True, shuffle = shuffle, batch_size = batch_size) if len_valid_human > 0 else None 221 | self.valid_mouse_dl = get_bigwig_dataloader(self.valid_mouse_ds, cycle_iter = True, shuffle = shuffle, batch_size = batch_size) if len_valid_mouse > 0 else None 222 | 223 | # dataloaders for independent tracks, without transcription factor or cell type conditioning 224 | 225 | self.train_human_head_dl = get_bigwig_tracks_dataloader(self.human_head_ds, cycle_iter = True, shuffle = shuffle, batch_size = batch_size) if len_train_human_head > 0 else None 226 | self.train_mouse_head_dl = get_bigwig_tracks_dataloader(self.mouse_head_ds, cycle_iter = True, shuffle = shuffle, batch_size = batch_size) if len_train_mouse_head > 0 else None 227 | 228 | self.valid_human_head_dl = get_bigwig_tracks_dataloader(self.valid_human_head_ds, cycle_iter = True, shuffle = shuffle, batch_size = batch_size) if len_valid_human_head > 0 else None 229 | self.valid_mouse_head_dl = get_bigwig_tracks_dataloader(self.valid_mouse_head_ds, cycle_iter = True, shuffle = shuffle, batch_size = batch_size) if len_valid_mouse_head > 0 else None 230 | 231 | # optimizer 232 | 233 | self.optim = get_optimizer(model.parameters(), lr = lr, wd = wd) 234 | 235 | self.grad_accum_every = grad_accum_every 236 | self.grad_clip_norm = grad_clip_norm 237 | 238 | self.validate_every = validate_every 239 | self.register_buffer('steps', torch.Tensor([0.])) 240 | 241 | self.checkpoint_filename = checkpoint_filename 242 | 243 | def forward( 244 | self, 245 | finetune_enformer_ln_only = True, 246 | **kwargs 247 | ): 248 | grad_accum_every = self.grad_accum_every 249 | curr_step = int(self.steps.item()) 250 | self.model.train() 251 | 252 | log = {} 253 | loss_divisor = int(exists(self.train_human_dl)) + int(exists(self.train_mouse_dl)) + int(exists(self.train_human_head_dl)) + int(exists(self.train_mouse_head_dl)) 254 | 255 | if exists(self.train_human_dl): 256 | for _ in range(grad_accum_every): 257 | seq, tf_aa, contextual_texts, target = next(self.train_human_dl) 258 | seq, target = seq.cuda(), target.cuda() 259 | 260 | loss = self.model( 261 | seq, 262 | aa = tf_aa, 263 | contextual_free_text = contextual_texts, 264 | target = target, 265 | finetune_enformer_ln_only = finetune_enformer_ln_only, 266 | **kwargs 267 | ) 268 | 269 | log = accum_log(log, {'human_loss': loss.item() / grad_accum_every}) 270 | (loss / self.grad_accum_every / loss_divisor).backward() 271 | 272 | print(f'{curr_step} human loss: {log["human_loss"]}') 273 | 274 | if
exists(self.train_mouse_dl): 275 | for _ in range(grad_accum_every): 276 | seq, tf_aa, contextual_texts, target = next(self.train_mouse_dl) 277 | seq, target = seq.cuda(), target.cuda() 278 | 279 | loss = self.model( 280 | seq, 281 | aa = tf_aa, 282 | contextual_free_text = contextual_texts, 283 | target = target, 284 | finetune_enformer_ln_only = finetune_enformer_ln_only, 285 | **kwargs 286 | ) 287 | 288 | log = accum_log(log, {'mouse_loss': loss.item() / grad_accum_every}) 289 | (loss / self.grad_accum_every / loss_divisor).backward() 290 | 291 | print(f'{curr_step} mouse loss: {log["mouse_loss"]}') 292 | 293 | if exists(self.train_human_head_dl): 294 | for _ in range(grad_accum_every): 295 | seq, target = next(self.train_human_head_dl) 296 | seq, target = seq.cuda(), target.cuda() 297 | 298 | loss = self.model( 299 | seq, 300 | target = target, 301 | head = 'human', 302 | finetune_enformer_ln_only = finetune_enformer_ln_only, 303 | **kwargs 304 | ) 305 | 306 | log = accum_log(log, {'human_head_loss': loss.item() / grad_accum_every}) 307 | (loss / self.grad_accum_every / loss_divisor).backward() 308 | 309 | print(f'{curr_step} human head loss: {log["human_head_loss"]}') 310 | 311 | if exists(self.train_mouse_head_dl): 312 | for _ in range(grad_accum_every): 313 | seq, target = next(self.train_mouse_head_dl) 314 | seq, target = seq.cuda(), target.cuda() 315 | 316 | loss = self.model( 317 | seq, 318 | target = target, 319 | head = 'mouse', 320 | finetune_enformer_ln_only = finetune_enformer_ln_only, 321 | **kwargs 322 | ) 323 | 324 | log = accum_log(log, {'mouse_head_loss': loss.item() / grad_accum_every}) 325 | (loss / self.grad_accum_every / loss_divisor).backward() 326 | 327 | print(f'{curr_step} mouse head loss: {log["mouse_head_loss"]}') 328 | 329 | # gradient clipping 330 | 331 | if exists(self.grad_clip_norm): 332 | nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm) 333 | 334 | # take a gradient step 335 | 336 | self.optim.step() 337 | self.optim.zero_grad() 338 | 339 | # validation 340 | 341 | if (curr_step % self.validate_every) == 0: 342 | self.model.eval() 343 | 344 | if exists(self.valid_human_dl): 345 | for _ in range(grad_accum_every): 346 | seq, tf_aa, contextual_texts, target = next(self.valid_human_dl) 347 | seq, target = seq.cuda(), target.cuda() 348 | 349 | pred = self.model( 350 | seq, 351 | aa = tf_aa, 352 | contextual_free_text = contextual_texts, 353 | ) 354 | 355 | valid_loss = self.model.loss_fn(pred, target) 356 | valid_corr_coef = pearson_corr_coef(pred, target) 357 | 358 | log = accum_log(log, { 359 | 'human_valid_loss': valid_loss.item() / grad_accum_every, 360 | 'human_valid_corr_coef': valid_corr_coef.item() / grad_accum_every 361 | }) 362 | 363 | print(f'{curr_step} human valid loss: {log["human_valid_loss"]}') 364 | print(f'{curr_step} human valid pearson R: {log["human_valid_corr_coef"]}') 365 | 366 | if exists(self.valid_mouse_dl): 367 | for _ in range(grad_accum_every): 368 | seq, tf_aa, contextual_texts, target = next(self.valid_mouse_dl) 369 | seq, target = seq.cuda(), target.cuda() 370 | 371 | pred = self.model( 372 | seq, 373 | aa = tf_aa, 374 | contextual_free_text = contextual_texts, 375 | ) 376 | 377 | valid_loss = self.model.loss_fn(pred, target) 378 | valid_corr_coef = pearson_corr_coef(pred, target) 379 | 380 | log = accum_log(log, { 381 | 'mouse_valid_loss': valid_loss.item() / grad_accum_every, 382 | 'mouse_valid_corr_coef': valid_corr_coef.item() / grad_accum_every 383 | }) 384 | 385 | print(f'{curr_step} mouse valid 
loss: {log["mouse_valid_loss"]}') 386 | print(f'{curr_step} mouse valid pearson R: {log["mouse_valid_corr_coef"]}') 387 | 388 | if exists(self.valid_human_head_dl): 389 | for _ in range(grad_accum_every): 390 | seq, target = next(self.valid_human_head_dl) 391 | seq, target = seq.cuda(), target.cuda() 392 | 393 | pred = self.model(seq, head = 'human') 394 | 395 | valid_loss = self.model.loss_fn(pred, target) 396 | valid_corr_coef = pearson_corr_coef(pred, target).mean() 397 | 398 | log = accum_log(log, { 399 | 'human_head_valid_loss': valid_loss.item() / grad_accum_every, 400 | 'human_head_valid_corr_coef': valid_corr_coef.item() / grad_accum_every 401 | }) 402 | 403 | print(f'{curr_step} human head valid loss: {log["human_head_valid_loss"]}') 404 | print(f'{curr_step} human head valid pearson R: {log["human_head_valid_corr_coef"]}') 405 | 406 | if exists(self.valid_mouse_head_dl): 407 | for _ in range(grad_accum_every): 408 | seq, target = next(self.valid_mouse_head_dl) 409 | seq, target = seq.cuda(), target.cuda() 410 | 411 | pred = self.model(seq, head = 'mouse') 412 | 413 | valid_loss = self.model.loss_fn(pred, target) 414 | valid_corr_coef = pearson_corr_coef(pred, target).mean() 415 | 416 | log = accum_log(log, { 417 | 'mouse_head_valid_loss': valid_loss.item() / grad_accum_every, 418 | 'mouse_head_valid_corr_coef': valid_corr_coef.item() / grad_accum_every 419 | }) 420 | 421 | 422 | print(f'{curr_step} mouse head valid loss: {log["mouse_head_valid_loss"]}') 423 | print(f'{curr_step} mouse head valid pearson R: {log["mouse_head_valid_corr_coef"]}') 424 | 425 | if curr_step > 0: 426 | torch.save(self.model.state_dict(), self.checkpoint_filename) 427 | 428 | self.steps += 1 429 | return log 430 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Transcription Factor binding predictions with Attention and Transformers 2 | 3 | A repository with exploration into using transformers to predict DNA ↔ transcription factor binding. 
4 | 5 | ## Install 6 | 7 | Run the following at the project root to download dependencies 8 | 9 | ```bash 10 | $ python setup.py install --user 11 | ``` 12 | 13 | Then you must install `pybedtools` as well as `pyBigWig` 14 | 15 | ```bash 16 | $ conda install --channel conda-forge --channel bioconda pybedtools pyBigWig 17 | ``` 18 | 19 | ## Usage 20 | 21 | ```python 22 | import torch 23 | from tf_bind_transformer import AdapterModel 24 | 25 | # instantiate enformer or load pretrained 26 | 27 | from enformer_pytorch import Enformer 28 | enformer = Enformer.from_hparams( 29 | dim = 1536, 30 | depth = 2, 31 | target_length = 256 32 | ) 33 | 34 | # instantiate model wrapper that takes in enformer 35 | 36 | model = AdapterModel( 37 | enformer = enformer, 38 | aa_embed_dim = 512, 39 | contextual_embed_dim = 256 40 | ).cuda() 41 | 42 | # mock data 43 | 44 | seq = torch.randint(0, 4, (1, 196_608 // 2)).cuda() # for ACGT 45 | 46 | aa_embed = torch.randn(1, 1024, 512).cuda() 47 | aa_mask = torch.ones(1, 1024).bool().cuda() 48 | 49 | contextual_embed = torch.randn(1, 256).cuda() # contextual embeddings, including cell type, species, experimental parameter embeddings 50 | 51 | target = torch.randn(1, 256).cuda() 52 | 53 | # train 54 | 55 | loss = model( 56 | seq, 57 | aa_embed = aa_embed, 58 | aa_mask = aa_mask, 59 | contextual_embed = contextual_embed, 60 | target = target 61 | ) 62 | 63 | loss.backward() 64 | 65 | # after a lot of training 66 | 67 | corr_coef = model( 68 | seq, 69 | aa_embed = aa_embed, 70 | aa_mask = aa_mask, 71 | contextual_embed = contextual_embed, 72 | target = target, 73 | return_corr_coef = True 74 | ) 75 | ``` 76 | 77 | ## Using ESM or ProtAlbert for fetching of transcription factor protein embeddings 78 | 79 | ```python 80 | import torch 81 | from enformer_pytorch import Enformer 82 | from tf_bind_transformer import AdapterModel 83 | 84 | enformer = Enformer.from_hparams( 85 | dim = 1536, 86 | depth = 2, 87 | target_length = 256 88 | ) 89 | 90 | model = AdapterModel( 91 | enformer = enformer, 92 | use_aa_embeds = True, # set this to True 93 | aa_embed_encoder = 'esm', # by default, will use esm, but can be set to 'protalbert', which has a longer context length of 2048 (vs esm's 1024) 94 | contextual_embed_dim = 256 95 | ).cuda() 96 | 97 | # mock data 98 | 99 | seq = torch.randint(0, 4, (1, 196_608 // 2)).cuda() 100 | tf_aa = torch.randint(0, 21, (1, 4)).cuda() # transcription factor amino acid sequence, from 0 to 20 101 | 102 | contextual_embed = torch.randn(1, 256).cuda() 103 | target = torch.randn(1, 256).cuda() 104 | 105 | # train 106 | 107 | loss = model( 108 | seq, 109 | aa = tf_aa, 110 | contextual_embed = contextual_embed, 111 | target = target 112 | ) 113 | 114 | loss.backward() 115 | ``` 116 | 117 | - [ ] add alphafold2 118 | 119 | ## Context passed in as free text 120 | 121 | One can also pass the context (cell type, experimental parameters) directly as free text, which will be encoded by a text transformer trained on pubmed abstracts. 
122 | 123 | ```python 124 | import torch 125 | from tf_bind_transformer import AdapterModel 126 | 127 | # instantiate enformer or load pretrained 128 | 129 | from enformer_pytorch import Enformer 130 | enformer = Enformer.from_hparams( 131 | dim = 1536, 132 | depth = 2, 133 | target_length = 256 134 | ) 135 | 136 | # instantiate model wrapper that takes in enformer 137 | 138 | model = AdapterModel( 139 | enformer = enformer, 140 | use_aa_embeds = True, 141 | use_free_text_context = True, # this must be set to True 142 | free_text_embed_method = 'mean_pool' # allow for mean pooling of embeddings, instead of using CLS token 143 | ).cuda() 144 | 145 | # mock data 146 | 147 | seq = torch.randint(0, 4, (2, 196_608 // 2)).cuda() # for ACGT 148 | target = torch.randn(2, 256).cuda() 149 | 150 | tf_aa = [ 151 | 'KVFGRCELAA', # single protein 152 | ('AMKRHGLDNY', 'YNDLGHRKMA') # complex, representations will be concatted together 153 | ] 154 | 155 | contextual_texts = [ 156 | 'cell type: GM12878 | dual cross-linked', 157 | 'cell type: H1-hESC' 158 | ] 159 | 160 | # train 161 | 162 | loss = model( 163 | seq, 164 | target = target, 165 | aa = tf_aa, 166 | contextual_free_text = contextual_texts, 167 | ) 168 | 169 | loss.backward() 170 | ``` 171 | 172 | ## Binary prediction 173 | 174 | For predicting binary outcome (bind or not bind), just set the `binary_targets = True` when initializing either adapters 175 | 176 | ex. 177 | 178 | ```python 179 | import torch 180 | from tf_bind_transformer import AdapterModel 181 | from enformer_pytorch import Enformer 182 | 183 | # instantiate enformer or load pretrained 184 | 185 | enformer = Enformer.from_hparams( 186 | dim = 1536, 187 | depth = 2, 188 | target_length = 256 189 | ) 190 | 191 | # instantiate model wrapper that takes in enformer 192 | 193 | model = AdapterModel( 194 | enformer = enformer, 195 | use_aa_embeds = True, 196 | use_free_text_context = True, 197 | free_text_embed_method = 'mean_pool', 198 | use_squeeze_excite = True, 199 | binary_target = True, # set this to True 200 | target_mse_loss = False # whether to use MSE loss with target value 201 | ).cuda() 202 | 203 | # mock data 204 | 205 | seq = torch.randint(0, 4, (1, 196_608 // 2)).cuda() # for ACGT 206 | binary_target = torch.randint(0, 2, (2,)).cuda() # bind or not bind 207 | 208 | tf_aa = [ 209 | 'KVFGRCELAA', 210 | ('AMKRHGLDNY', 'YNDLGHRKMA') 211 | ] 212 | 213 | contextual_texts = [ 214 | 'cell type: GM12878 | chip-seq dual cross-linked', 215 | 'cell type: H1-hESC | chip-seq single cross-linked' 216 | ] 217 | 218 | # train 219 | 220 | loss = model( 221 | seq, 222 | target = binary_target, 223 | aa = tf_aa, 224 | contextual_free_text = contextual_texts, 225 | ) 226 | 227 | loss.backward() 228 | ``` 229 | 230 | ## Predicting Tracks from BigWig 231 | 232 | ```python 233 | from pathlib import Path 234 | import torch 235 | from enformer_pytorch import Enformer 236 | 237 | from tf_bind_transformer import AdapterModel 238 | from tf_bind_transformer.data_bigwig import BigWigDataset, get_bigwig_dataloader 239 | 240 | # constants 241 | 242 | ROOT = Path('.') 243 | TFACTOR_TF = str(ROOT / 'tfactor.fastas') 244 | ENFORMER_DATA = str(ROOT / 'chip_atlas' / 'sequences.bed') 245 | FASTA_FILE_PATH = str(ROOT / 'hg38.ml.fa') 246 | BIGWIG_PATH = str(ROOT / 'chip_atlas') 247 | ANNOT_FILE_PATH = str(ROOT / 'chip_atlas' / 'annot.tab') 248 | 249 | # bigwig dataset and dataloader 250 | 251 | ds = BigWigDataset( 252 | factor_fasta_folder = TFACTOR_TF, 253 | bigwig_folder = BIGWIG_PATH, 254 | enformer_loci_path = 
ENFORMER_DATA, 255 | annot_file = ANNOT_FILE_PATH, 256 | fasta_file = FASTA_FILE_PATH 257 | ) 258 | 259 | dl = get_bigwig_dataloader(ds, batch_size = 2) 260 | 261 | # enformer 262 | 263 | enformer = Enformer.from_hparams( 264 | dim = 384, 265 | depth = 1, 266 | target_length = 896 267 | ) 268 | 269 | model = AdapterModel( 270 | enformer = enformer, 271 | use_aa_embeds = True, 272 | use_free_text_context = True 273 | ).cuda() 274 | 275 | # mock data 276 | 277 | seq, tf_aa, context_str, target = next(dl) 278 | seq, target = seq.cuda(), target.cuda() 279 | 280 | # train 281 | 282 | loss = model( 283 | seq, 284 | aa = tf_aa, 285 | contextual_free_text = context_str, 286 | target = target 287 | ) 288 | 289 | loss.backward() 290 | ``` 291 | ## Data 292 | 293 | The data needed for training is at this download page. 294 | 295 | ### Transcription factors for Human and Mouse 296 | 297 | To download the protein sequences for both species, first download the remap CRMs bed files, from which all the targets will be extracted; the fastas themselves are then fetched from Uniprot. 298 | 299 | Download human remap CRMs 300 | 301 | ```bash 302 | $ wget https://remap.univ-amu.fr/storage/remap2022/hg38/MACS2/remap2022_crm_macs2_hg38_v1_0.bed.gz 303 | $ gzip -d remap2022_crm_macs2_hg38_v1_0.bed.gz 304 | ``` 305 | 306 | Download mouse remap CRMs 307 | 308 | ```bash 309 | $ wget https://remap.univ-amu.fr/storage/remap2022/mm10/MACS2/remap2022_crm_macs2_mm10_v1_0.bed.gz 310 | $ gzip -d remap2022_crm_macs2_mm10_v1_0.bed.gz 311 | ``` 312 | 313 | Download all human transcription factors 314 | 315 | ```bash 316 | $ python scripts/fetch_factor_fastas.py --species human 317 | ``` 318 | 319 | For mouse transcription factors 320 | 321 | ```bash 322 | $ python scripts/fetch_factor_fastas.py --species mouse 323 | ``` 324 | 325 | ## Generating Negatives 326 | 327 | ### Generating Hard Negatives 328 | 329 | For starters, the `RemapAllPeakDataset` will allow you to load data easily from the full remap peaks bed file for training.
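A rough sketch of loading it directly follows; note that the import path, the constructor keywords (`bed_file`, `factor_fasta_folder`, `fasta_file`, `context_length`) and the returned tuple are assumptions modeled on the `Trainer` options and the `BigWigDataset` example above, so consult `tf_bind_transformer/data.py` for the authoritative signature.

```python
import torch
from tf_bind_transformer.data import RemapAllPeakDataset  # assumed import path

# hypothetical construction - keyword names mirror the Trainer arguments
# shown further down in this README and may differ from the real API
ds = RemapAllPeakDataset(
    bed_file = './remap2022_all.bed',          # full remap peaks bed file
    factor_fasta_folder = './tfactor.fastas',  # downloaded transcription factor fastas
    fasta_file = './hg38.ml.fa',               # reference genome fasta
    context_length = 4096                      # genetic sequence length per sample
)

seq, tf_aa, context_str, label = ds[0]         # assumed item layout, analogous to BigWigDataset
```

In practice the `Trainer` further down constructs the peak and non-peak datasets for you; the steps below cover generating the non-peaks file it expects.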
330 | 331 | Firstly you'll need to generate the non-peaks dataset by running the following function 332 | 333 | ```python 334 | from tf_bind_transformer.data import generate_random_ranges_from_fasta 335 | 336 | generate_random_ranges_from_fasta( 337 | './hg38.ml.fa', 338 | output_filename = './path/to/generated-non-peaks.bed', # path to output file 339 | context_length = 4096, 340 | num_entries_per_key = 1_000_000, # number of negative samples 341 | filter_bed_files = [ 342 | './remap_all.bed', # filter out by all peak ranges (todo, allow filtering namespaced to experiment and target) 343 | './hg38.blacklist.rep.bed' # further filtering by blacklisted regions (gs://basenji_barnyard/hg38.blacklist.rep.bed) 344 | ] 345 | ) 346 | ``` 347 | 348 | ### Generating Scoped Negatives - Negatives per Dataset (experiment + target + cell type) 349 | 350 | Todo 351 | 352 | ## Simple Trainer class for fine-tuning 353 | 354 | working fine-tuning training loop for bind / no-bind prediction 355 | 356 | ```python 357 | import torch 358 | from enformer_pytorch import Enformer 359 | 360 | from tf_bind_transformer import AdapterModel, Trainer 361 | 362 | # instantiate enformer or load pretrained 363 | 364 | enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough', target_length = -1) 365 | 366 | # instantiate model wrapper that takes in enformer 367 | 368 | model = AdapterModel( 369 | enformer = enformer, 370 | use_aa_embeds = True, 371 | use_free_text_context = True, 372 | free_text_embed_method = 'mean_pool', 373 | binary_target = True, 374 | target_mse_loss = True, 375 | use_squeeze_excite = True, 376 | aux_read_value_loss = True # use auxiliary read value loss, can be turned off 377 | ).cuda() 378 | 379 | # pass the model (adapter + enformer) to the Trainer 380 | 381 | trainer = Trainer( 382 | model, 383 | batch_size = 2, # batch size 384 | context_length = 4096, # genetic sequence length 385 | grad_accum_every = 8, # gradient accumulation steps 386 | grad_clip_norm = 2.0, # gradient clipping 387 | validate_every = 250, 388 | remap_bed_file = './remap2022_all.bed', # path to remap bed peaks 389 | negative_bed_file = './generated-non-peaks.bed', # path to generated non-peaks 390 | factor_fasta_folder = './tfactor.fastas', # path to factor fasta files 391 | fasta_file = './hg38.ml.fa', # human genome sequences 392 | train_chromosome_ids = [*range(1, 24, 2), 'X'], # chromosomes to train on 393 | valid_chromosome_ids = [*range(2, 24, 2)], # chromosomes to validate on 394 | held_out_targets = ['AFF4'], # targets to hold out for validation 395 | experiments_json_path = './data/experiments.json' # path to all experiments data, at this path relative to the project root, if repository is git cloned 396 | ) 397 | 398 | while True: 399 | _ = trainer() 400 | 401 | ``` 402 | 403 | working fine-tuning script for training on new enformer tracks, with cross-attending transcription factor protein embeddings and cell type conditioning 404 | 405 | ```python 406 | from dotenv import load_dotenv 407 | 408 | # set path to cache in .env and unset the next comment 409 | # load_dotenv() 410 | 411 | from enformer_pytorch import Enformer 412 | from tf_bind_transformer import AdapterModel, BigWigTrainer 413 | 414 | # training constants 415 | 416 | BATCH_SIZE = 1 417 | GRAD_ACCUM_STEPS = 8 418 | 419 | # effective batch size of BATCH_SIZE * GRAD_ACCUM_STEPS = 16 420 | 421 | VALIDATE_EVERY = 250 422 | GRAD_CLIP_MAX_NORM = 1.5 423 | 424 | TFACTOR_FOLDER = './tfactor.fastas' 425 | FASTA_FILE_PATH = './hg38.ml.fa' 426 | 427 
| LOCI_PATH = './sequences.bed' 428 | BIGWIG_PATH = './bigwig_folder' 429 | ANNOT_FILE_PATH = './experiments.tab' 430 | TARGET_LENGTH = 896 431 | 432 | TRAIN_CHROMOSOMES = [*range(1, 24, 2), 'X'] # train on odd chromosomes 433 | VALID_CHROMOSOMES = [*range(2, 24, 2)] # validate on even 434 | 435 | HELD_OUT_TARGET = ['SOX2'] 436 | 437 | # instantiate enformer or load pretrained 438 | 439 | enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough', target_length = TARGET_LENGTH) 440 | 441 | # instantiate model wrapper that takes in enformer 442 | 443 | model = AdapterModel( 444 | enformer = enformer, 445 | use_aa_embeds = True, 446 | use_free_text_context = True, 447 | free_text_embed_method = 'mean_pool', 448 | aa_embed_encoder = 'protalbert' 449 | ).cuda() 450 | 451 | 452 | # trainer class for fine-tuning 453 | 454 | trainer = BigWigTrainer( 455 | model, 456 | loci_path = LOCI_PATH, 457 | bigwig_folder_path = BIGWIG_PATH, 458 | annot_file_path = ANNOT_FILE_PATH, 459 | target_length = TARGET_LENGTH, 460 | batch_size = BATCH_SIZE, 461 | validate_every = VALIDATE_EVERY, 462 | grad_clip_norm = GRAD_CLIP_MAX_NORM, 463 | grad_accum_every = GRAD_ACCUM_STEPS, 464 | factor_fasta_folder = TFACTOR_FOLDER, 465 | fasta_file = FASTA_FILE_PATH, 466 | train_chromosome_ids = TRAIN_CHROMOSOMES, 467 | valid_chromosome_ids = VALID_CHROMOSOMES, 468 | held_out_targets = HELD_OUT_TARGET 469 | ) 470 | 471 | # do gradient steps in a while loop 472 | 473 | while True: 474 | _ = trainer() 475 | ``` 476 | 477 | ## Resources 478 | 479 | If you are low on GPU memory, you can save by making sure the protein and contextual embeddings are executed on CPU 480 | 481 | ```bash 482 | CONTEXT_EMBED_USE_CPU=1 PROTEIN_EMBED_USE_CPU=1 python train.py 483 | ``` 484 | 485 | ## Data 486 | 487 | Transcription factor dataset 488 | 489 | ```python 490 | from tf_bind_transformer.data import FactorProteinDataset 491 | 492 | ds = FactorProteinDataset( 493 | folder = 'path/to/tfactor/fastas' 494 | ) 495 | 496 | # single factor 497 | 498 | ds['ETV1'] # 499 | 500 | # multi-complexes 501 | 502 | ds['PAX3-FOXO1'] # (, ) 503 | 504 | ``` 505 | 506 | ## Preprocessing (wip) 507 | 508 | get a copy of hg38 blacklist bed file from calico 509 | 510 | ```bash 511 | $ gsutil cp gs://basenji_barnyard/hg38.blacklist.rep.bed ./ 512 | ``` 513 | 514 | using bedtools to filter out repetitive regions of the genome 515 | 516 | ```bash 517 | $ bedtools intersect -v -a ./remap2022_all_macs2_hg38_v1_0.bed -b ./hg38.blacklist.rep.bed > remap2022_all_filtered.bed 518 | ``` 519 | 520 | ## Caching 521 | 522 | During training, protein sequences and contextual strings are cached to `~/.cache.tf.bind.transformer` directory. If you would like to make sure the caching is working, you just need to run your training script with `VERBOSE=1` 523 | 524 | ex. 
525 | 526 | ```bash 527 | $ VERBOSE=1 python train.py 528 | ``` 529 | 530 | You can also force a cache clearance 531 | 532 | ```bash 533 | $ CLEAR_CACHE=1 python train.py 534 | ``` 535 | 536 | ## Todo 537 | 538 | - [x] ESM and AF2 embedding fetching integrations 539 | - [x] HF transformers integration for conditioning on free text 540 | - [x] allow for fine-tuning layernorms of Enformer easily 541 | - [x] add caching for external embeddings 542 | - [x] figure out a way for external models (ESM, transformers) to be omitted from state dictionary on saving (use singletons) 543 | - [x] take care of caching genetic sequences when enformer is frozen 544 | - [x] offer a fully transformer variant with cross-attention with shared attention matrix and FiLM conditioning with contextual embed 545 | - [x] also offer using pooled genetic / protein sequence concatted with context -> project -> squeeze excitation type conditioning 546 | - [x] use checkpointing when fine-tuning enformer 547 | - [x] take care of prepping dataframe with proper chromosome and training / validation split 548 | - [x] use basenji blacklist bed file for filtering out rows in remap 549 | - [x] filter remap dataframe based on tfactor fasta folder 550 | - [x] filter remap dataframe with hg38 blacklist 551 | - [x] handle targets with modifications from remap with all peaks (underscore in name) 552 | - [x] grad clipping 553 | - [x] add a safe initialization whereby rows of dataframe with targets not found in the tfactor fasta folder will be filtered out 554 | - [x] add accuracy metric to fine tune script 555 | - [x] master trainer class that handles both training / validation splitting, efficient instantiation of dataframe, filtering etc 556 | - [x] write a simple trainer class that takes care of the training loop 557 | - [x] create faster protein and context embedding derivation by optionally moving model to gpu and back to cpu when done 558 | - [x] use ProtTrans for longer context proteins, look into AF2 559 | - [x] make protalbert usable with one flag 560 | - [x] log auxiliary losses appropriately (read value) 561 | - [x] write fine-tuning script for finetuning on merged genomic track(s) from remap 562 | - [ ] support for custom transformers other than enformer 563 | - [ ] warmup in training loop 564 | - [ ] mixed precision 565 | - [ ] use wandb for experiment tracking 566 | - [ ] cleanup tech debt in data and protein_utils 567 | - [ ] explore protein model fine-tuning of layernorm 568 | - [ ] auto-auroc calc 569 | - [ ] k-fold cross validation 570 | - [ ] output attention intermediates (or convolution output for hypertransformer), for interpreting binding site 571 | - [ ] use prefect.io to manage downloading of tfactors fastas, remap scoped negative peaks, blacklist filtering etc 572 | 573 | ## Appreciation 574 | 575 | This work was generously sponsored by Jeff Hsu to be done completely open sourced. 576 | 577 | ## Citations 578 | 579 | ```bibtex 580 | @article {Avsec2021.04.07.438649, 581 | author = {Avsec, {\v Z}iga and Agarwal, Vikram and Visentin, Daniel and Ledsam, Joseph R. and Grabska-Barwinska, Agnieszka and Taylor, Kyle R. 
and Assael, Yannis and Jumper, John and Kohli, Pushmeet and Kelley, David R.}, 582 | title = {Effective gene expression prediction from sequence by integrating long-range interactions}, 583 | elocation-id = {2021.04.07.438649}, 584 | year = {2021}, 585 | doi = {10.1101/2021.04.07.438649}, 586 | publisher = {Cold Spring Harbor Laboratory}, 587 | URL = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649}, 588 | eprint = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649.full.pdf}, 589 | journal = {bioRxiv} 590 | } 591 | ``` 592 | 593 | ```bibtex 594 | @misc{yao2021filip, 595 | title = {FILIP: Fine-grained Interactive Language-Image Pre-Training}, 596 | author = {Lewei Yao and Runhui Huang and Lu Hou and Guansong Lu and Minzhe Niu and Hang Xu and Xiaodan Liang and Zhenguo Li and Xin Jiang and Chunjing Xu}, 597 | year = {2021}, 598 | eprint = {2111.07783}, 599 | archivePrefix = {arXiv}, 600 | primaryClass = {cs.CV} 601 | } 602 | ``` 603 | 604 | ```bibtex 605 | @misc{tay2020hypergrid, 606 | title = {HyperGrid: Efficient Multi-Task Transformers with Grid-wise Decomposable Hyper Projections}, 607 | author = {Yi Tay and Zhe Zhao and Dara Bahri and Donald Metzler and Da-Cheng Juan}, 608 | year = {2020}, 609 | eprint = {2007.05891}, 610 | archivePrefix = {arXiv}, 611 | primaryClass = {cs.CL} 612 | } 613 | ``` 614 | 615 | ```bibtex 616 | @misc{lowe2021logavgexp, 617 | title = {LogAvgExp Provides a Principled and Performant Global Pooling Operator}, 618 | author = {Scott C. Lowe and Thomas Trappenberg and Sageev Oore}, 619 | year = {2021}, 620 | eprint = {2111.01742}, 621 | archivePrefix = {arXiv}, 622 | primaryClass = {cs.LG} 623 | } 624 | ``` 625 | 626 | ```bibtex 627 | @article{10.1093/nar/gkab996, 628 | author = {Hammal, Fayrouz and de Langen, Pierre and Bergon, Aurélie and Lopez, Fabrice and Ballester, Benoit}, 629 | title = "{ReMap 2022: a database of Human, Mouse, Drosophila and Arabidopsis regulatory regions from an integrative analysis of DNA-binding sequencing experiments}", 630 | journal = {Nucleic Acids Research}, 631 | issn = {0305-1048}, 632 | doi = {10.1093/nar/gkab996}, 633 | url = {https://doi.org/10.1093/nar/gkab996}, 634 | eprint = {https://academic.oup.com/nar/article-pdf/50/D1/D316/42058627/gkab996.pdf}, 635 | } 636 | ``` 637 | -------------------------------------------------------------------------------- /tf_bind_transformer/tf_bind_transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from functools import wraps 7 | 8 | from einops import rearrange, reduce, repeat 9 | from einops.layers.torch import Rearrange, Reduce 10 | 11 | from contextlib import contextmanager 12 | 13 | from enformer_pytorch import Enformer 14 | from enformer_pytorch.modeling_enformer import poisson_loss, pearson_corr_coef 15 | from enformer_pytorch.finetune import freeze_batchnorms_, freeze_all_but_layernorms_, unfreeze_last_n_layers_, unfreeze_all_layers_ 16 | 17 | from logavgexp_pytorch import logavgexp 18 | 19 | from tf_bind_transformer.cache_utils import cache_fn 20 | from tf_bind_transformer.protein_utils import get_protein_embedder 21 | from tf_bind_transformer.context_utils import get_text_repr, get_contextual_dim 22 | 23 | from tf_bind_transformer.attention import FeedForward, JointCrossAttentionBlock, CrossAttention, SelfAttentionBlock 24 | 25 | # helper functions 26 | 27 | def exists(val): 28 | return 
val is not None 29 | 30 | def default(val, d): 31 | return val if exists(val) else d 32 | 33 | def identity(fn, *args, **kwargs): 34 | return fn 35 | 36 | @contextmanager 37 | def null_context(): 38 | yield 39 | 40 | # tensor helpers 41 | 42 | def l2norm(t): 43 | return F.normalize(t, dim = -1) 44 | 45 | def prob_mask_like(t, prob): 46 | return torch.zeros_like(t).float().uniform_(0, 1) < prob 47 | 48 | def fourier_encode(x, dims, theta = 20000): 49 | device, dtype = x.device, x.dtype 50 | emb = math.log(theta) / (dims // 2) 51 | emb = torch.exp(torch.arange(dims // 2, device = device) * -emb) 52 | emb = rearrange(x, 'n -> n 1') * rearrange(emb, 'd -> 1 d') 53 | emb = torch.cat((emb.sin(), emb.cos()), dim = -1) 54 | return emb 55 | 56 | def corr_coef_loss(pred, target): 57 | return 1 - pearson_corr_coef(pred, target).mean() 58 | 59 | # genetic sequence caching enformer forward decorator 60 | 61 | def cache_enformer_forward(fn): 62 | cached_forward = cache_fn(fn, clear = True, path = 'genetic') 63 | 64 | @wraps(fn) 65 | def inner(seqs, *args, **kwargs): 66 | if seqs.ndim == 3: 67 | seqs = seqs.argmax(dim = -1) 68 | 69 | seq_list = seqs.unbind(dim = 0) 70 | seq_cache_keys = [''.join(list(map(str, one_seq.tolist()))) for one_seq in seq_list] 71 | outputs = [cached_forward(one_seq, *args, __cache_key = seq_cache_key, **kwargs) for one_seq, seq_cache_key in zip(seq_list, seq_cache_keys)] 72 | return torch.stack(outputs) 73 | 74 | return inner 75 | 76 | # model 77 | 78 | class FiLM(nn.Module): 79 | def __init__( 80 | self, 81 | dim, 82 | conditioned_dim 83 | ): 84 | super().__init__() 85 | self.to_gamma = nn.Linear(dim, conditioned_dim) 86 | self.to_bias = nn.Linear(dim, conditioned_dim) 87 | 88 | def forward(self, x, condition, mask = None): 89 | gamma = self.to_gamma(condition) 90 | bias = self.to_bias(condition) 91 | 92 | x = x * rearrange(gamma, 'b d -> b 1 d') 93 | x = x + rearrange(bias, 'b d -> b 1 d') 94 | return x 95 | 96 | class SqueezeExcitation(nn.Module): 97 | def __init__( 98 | self, 99 | dim, 100 | conditioned_dim, 101 | eps = 1e-8 102 | ): 103 | super().__init__() 104 | self.eps = eps 105 | self.to_gate = nn.Linear(dim + conditioned_dim, conditioned_dim) 106 | 107 | def forward(self, x, condition, mask = None): 108 | if exists(mask): 109 | numer = x.masked_fill(mask[..., None], 0.).sum(dim = 1) 110 | denom = mask.sum(dim = 1)[..., None].clamp(min = self.eps) 111 | mean_x = numer / denom 112 | else: 113 | mean_x = x.mean(dim = 1) 114 | 115 | condition = torch.cat((condition, mean_x), dim = -1) 116 | gate = self.to_gate(condition) 117 | 118 | x = x * rearrange(gate, 'b d -> b 1 d').sigmoid() 119 | return x 120 | 121 | # read value MLP for calculating auxiliary loss 122 | 123 | class ReadValueMLP(nn.Module): 124 | def __init__( 125 | self, 126 | dim, 127 | *, 128 | fourier_dims = 256, 129 | norm_factor_fourier = 50, 130 | norm_factor_linear = 8000, 131 | eps = 1e-20 132 | ): 133 | super().__init__() 134 | self.eps = eps 135 | self.fourier_dims = fourier_dims 136 | self.norm_factor_fourier = norm_factor_fourier 137 | self.norm_factor_linear = norm_factor_linear 138 | 139 | self.logits_norm = nn.Sequential( 140 | Reduce('b n d -> b d', 'mean'), 141 | nn.LayerNorm(dim) 142 | ) 143 | 144 | self.mlp = nn.Sequential( 145 | nn.Linear(dim + fourier_dims + 2, dim * 2), 146 | nn.GELU(), 147 | nn.Linear(dim * 2, 1), 148 | Rearrange('... 
1 -> ...') 149 | ) 150 | 151 | def forward(self, logits, peaks_nr, read_value): 152 | logits = self.logits_norm(logits) 153 | 154 | peaks_nr_log_space = torch.log(peaks_nr + self.eps) 155 | 156 | peaks_nr = rearrange(peaks_nr, '... -> (...)') 157 | peaks_nr_encoded = fourier_encode(peaks_nr / self.norm_factor_fourier, self.fourier_dims) 158 | peaks_nr_normed = rearrange(peaks_nr, '... -> ... 1') / self.norm_factor_linear 159 | 160 | peaks_nr_encoded_with_self = torch.cat((peaks_nr_normed, peaks_nr_log_space, peaks_nr_encoded), dim = -1) 161 | 162 | logits_with_peaks = torch.cat((logits, peaks_nr_encoded_with_self), dim = -1) 163 | 164 | pred = self.mlp(logits_with_peaks) 165 | read_value = rearrange(read_value, '... -> (...)') 166 | 167 | return F.smooth_l1_loss(pred, read_value) 168 | 169 | class HypergridLinear(nn.Module): 170 | def __init__( 171 | self, 172 | dim, 173 | dim_out, 174 | *, 175 | context_dim 176 | ): 177 | super().__init__() 178 | self.weights = nn.Parameter(torch.randn(dim, dim_out)) 179 | self.contextual_projection = nn.Linear(context_dim, dim * dim_out) 180 | 181 | def forward(self, x, context): 182 | # derive contextual gating, from hypergrids paper 183 | 184 | gating = self.contextual_projection(context).sigmoid() 185 | gating = rearrange(gating, 'b (i o) -> b i o', i = int(math.sqrt(gating.shape[-1]))) 186 | 187 | # gate interactions projection with context 188 | 189 | to_logits_w = rearrange(self.weights, 'i o -> 1 i o') * gating 190 | return einsum('b n d, b d e -> b n e', x, to_logits_w) 191 | 192 | # FILIP adapter model 193 | 194 | class FILIP(nn.Module): 195 | def __init__( 196 | self, 197 | dim, 198 | context_dim, 199 | heads, 200 | dim_head = 64, 201 | dropout = 0. 202 | ): 203 | super().__init__() 204 | self.heads = heads 205 | inner_latent_dim = heads * dim_head 206 | 207 | self.to_latent_w = nn.Parameter(torch.randn(dim, inner_latent_dim)) 208 | self.to_latent_b = nn.Parameter(torch.randn(inner_latent_dim)) 209 | 210 | self.pre_attn_dropout = dropout 211 | 212 | self.null_context = nn.Parameter(torch.randn(heads, dim_head)) 213 | self.context_to_latent_w = nn.Parameter(torch.randn(context_dim, inner_latent_dim)) 214 | self.context_to_latent_b = nn.Parameter(torch.randn(inner_latent_dim)) 215 | 216 | def forward( 217 | self, 218 | x, 219 | context, 220 | context_mask = None 221 | ): 222 | b, heads, device = x.shape[0], self.heads, x.device 223 | 224 | x = einsum('b n d, d e -> b n e', x, self.to_latent_w) 225 | x = x + self.to_latent_b 226 | 227 | x = rearrange(x, 'b n (h d) -> b h n d', h = heads) 228 | 229 | context = einsum('b n d, d e -> b n e', context, self.context_to_latent_w) 230 | context = context + self.context_to_latent_b 231 | 232 | context = rearrange(context, 'b n (h d) -> b h n d', h = heads) 233 | 234 | context, x = map(l2norm, (context, x)) 235 | 236 | # fine grained interaction between dna and protein sequences 237 | # FILIP https://arxiv.org/abs/2111.07783 238 | 239 | if x.shape[0] == 1: 240 | # in the case one passes in 1 genomic sequence track 241 | # but multiple factors + contexts, as in enformer training 242 | x = rearrange(x, '1 ... 
-> ...') 243 | einsum_eq = 'h i d, b h j d -> b h i j' 244 | else: 245 | einsum_eq = 'b h i d, b h j d -> b h i j' 246 | 247 | # create context mask if not exist 248 | 249 | if not exists(context_mask): 250 | context_mask = torch.ones((b, context.shape[-1]), device = device).bool() 251 | 252 | # dropout mask by dropout prob 253 | 254 | if self.training: 255 | keep_mask = prob_mask_like(context_mask, 1 - self.pre_attn_dropout) 256 | context_mask = context_mask & keep_mask 257 | 258 | # add null context and modify mask 259 | 260 | context_mask = F.pad(context_mask, (1, 0), value = True) 261 | context_mask = rearrange(context_mask, 'b j -> b 1 1 j') 262 | 263 | null_context = repeat(self.null_context, 'h d -> b h 1 d', b = b) 264 | context = torch.cat((null_context, context), dim = -2) 265 | 266 | # differentiable max, as in FILIP paper 267 | 268 | interactions = einsum(einsum_eq, x, context) 269 | interactions = logavgexp(interactions, mask = context_mask, dim = -1, temp = 0.05) 270 | interactions = rearrange(interactions, 'b h i -> b i h') 271 | return interactions 272 | 273 | class AdapterModel(nn.Module): 274 | def __init__( 275 | self, 276 | *, 277 | enformer, 278 | latent_dim = 64, 279 | latent_heads = 32, 280 | aa_embed_dim = None, 281 | aa_embed_encoder = 'esm', 282 | contextual_embed_dim = None, 283 | use_aa_embeds = False, 284 | use_free_text_context = False, 285 | free_text_context_encoder = 'pubmed', 286 | free_text_embed_method = 'cls', 287 | dropout = 0., 288 | binary_target = False, 289 | target_mse_loss = False, 290 | aux_read_value_loss = False, 291 | read_value_aux_loss_weight = 0.05, 292 | joint_cross_attn_depth = 1, 293 | genome_self_attn_depth = 0, 294 | fourier_dims = 256, 295 | condition_squeeze_excite = False, 296 | condition_film = False, 297 | condition_hypergrid = True, 298 | use_corr_coef_loss = False, 299 | finetune_output_heads = None, 300 | **kwargs 301 | ): 302 | super().__init__() 303 | assert isinstance(enformer, Enformer), 'enformer must be an instance of Enformer' 304 | self.enformer = enformer 305 | enformer_dim = enformer.dim * 2 306 | 307 | if exists(finetune_output_heads): 308 | self.enformer.add_heads(**finetune_output_heads) 309 | 310 | self.norm_seq_embed = nn.LayerNorm(enformer_dim) 311 | 312 | # contextual embedding related variables 313 | 314 | assert free_text_embed_method in {'cls', 'mean_pool'}, 'must be either cls or mean_pool' 315 | self.free_text_embed_method = free_text_embed_method 316 | self.use_free_text_context = use_free_text_context 317 | 318 | if use_free_text_context: 319 | contextual_embed_dim = get_contextual_dim(free_text_context_encoder) 320 | else: 321 | assert exists(contextual_embed_dim), 'contextual embedding dimension must be given if not using transformer encoder' 322 | 323 | # protein embedding related variables 324 | 325 | self.use_aa_embeds = use_aa_embeds 326 | self.aa_embed_config = get_protein_embedder(aa_embed_encoder) 327 | self.get_aa_embed = self.aa_embed_config['fn'] 328 | 329 | if use_aa_embeds: 330 | aa_embed_dim = self.aa_embed_config['dim'] 331 | else: 332 | assert exists(aa_embed_dim), 'AA embedding dimensions must be set if not using ESM' 333 | 334 | # conditioning 335 | 336 | self.cond_genetic = None 337 | self.cond_protein = None 338 | 339 | if condition_squeeze_excite or condition_film: 340 | condition_klass = SqueezeExcitation if condition_squeeze_excite else FiLM 341 | 342 | self.cond_genetic = condition_klass(contextual_embed_dim, enformer_dim) 343 | self.cond_protein = 
condition_klass(contextual_embed_dim, aa_embed_dim) 344 | 345 | # genome self attn 346 | 347 | self.genome_self_attns = nn.ModuleList([]) 348 | 349 | for _ in range(genome_self_attn_depth): 350 | attn = SelfAttentionBlock( 351 | dim = enformer_dim, 352 | dropout = dropout 353 | ) 354 | self.genome_self_attns.append(attn) 355 | 356 | # joint attn 357 | 358 | self.joint_cross_attns = nn.ModuleList([]) 359 | 360 | for _ in range(joint_cross_attn_depth): 361 | attn = JointCrossAttentionBlock( 362 | dim = enformer_dim, 363 | context_dim = aa_embed_dim, 364 | dropout = dropout 365 | ) 366 | 367 | self.joint_cross_attns.append(attn) 368 | 369 | # latents 370 | 371 | self.filip = FILIP( 372 | dim = enformer_dim, 373 | context_dim = aa_embed_dim, 374 | dim_head = latent_dim, 375 | heads = latent_heads, 376 | dropout = dropout 377 | ) 378 | 379 | # hypergrid conditioning 380 | 381 | if condition_hypergrid: 382 | self.linear_with_hypergrid = HypergridLinear(latent_heads, latent_heads, context_dim = contextual_embed_dim) 383 | else: 384 | self.linear_to_logits = nn.Linear(latent_heads, latent_heads) 385 | 386 | # to prediction 387 | 388 | self.binary_target = binary_target 389 | self.aux_read_value_loss = aux_read_value_loss 390 | self.read_value_aux_loss_weight = read_value_aux_loss_weight 391 | 392 | if binary_target: 393 | self.loss_fn = F.binary_cross_entropy_with_logits if not target_mse_loss else F.mse_loss 394 | 395 | self.to_pred = nn.Sequential( 396 | Reduce('... n d -> ... d', 'mean'), 397 | nn.LayerNorm(latent_heads), 398 | nn.Linear(latent_heads, 1), 399 | Rearrange('... 1 -> ...') 400 | ) 401 | 402 | self.to_read_value_aux_loss = ReadValueMLP( 403 | dim = latent_heads, 404 | fourier_dims = fourier_dims 405 | ) 406 | 407 | else: 408 | self.loss_fn = poisson_loss if not use_corr_coef_loss else corr_coef_loss 409 | 410 | self.to_pred = nn.Sequential( 411 | nn.Linear(latent_heads, 1), 412 | Rearrange('... 
1 -> ...'), 413 | nn.Softplus() 414 | ) 415 | 416 | def combine_losses(self, loss, aux_loss): 417 | if not self.aux_read_value_loss: 418 | return loss 419 | 420 | return loss + self.read_value_aux_loss_weight * aux_loss 421 | 422 | def forward_enformer_head( 423 | self, 424 | seq_embed, 425 | *, 426 | head, 427 | target = None, 428 | return_corr_coef = False 429 | ): 430 | assert not self.binary_target, 'cannot finetune on tracks if binary_target training is turned on' 431 | 432 | unfreeze_all_layers_(self.enformer._heads) 433 | 434 | assert head in self.enformer._heads, f'{head} head not found in enformer' 435 | 436 | pred = self.enformer._heads[head](seq_embed) 437 | 438 | if not exists(target): 439 | return pred 440 | 441 | assert pred.shape[-1] == target.shape[-1], f'{head} head on enformer produced {pred.shape[-1]} tracks, but the supplied target only has {target.shape[-1]}' 442 | 443 | if exists(target) and return_corr_coef: 444 | return pearson_corr_coef(pred, target) 445 | 446 | return self.loss_fn(pred, target) 447 | 448 | def forward( 449 | self, 450 | seq, 451 | *, 452 | aa = None, 453 | aa_embed = None, 454 | contextual_embed = None, 455 | contextual_free_text = None, 456 | aa_mask = None, 457 | target = None, 458 | read_value = None, 459 | peaks_nr = None, 460 | return_corr_coef = False, 461 | finetune_enformer = False, 462 | finetune_enformer_ln_only = False, 463 | unfreeze_enformer_last_n_layers = 0, 464 | head = None 465 | ): 466 | device = seq.device 467 | 468 | # prepare enformer for training 469 | # - set to eval and no_grad if not fine-tuning 470 | # - always freeze the batchnorms 471 | 472 | freeze_batchnorms_(self.enformer) 473 | enformer_forward = self.enformer.forward 474 | 475 | if finetune_enformer: 476 | enformer_context = null_context() 477 | elif finetune_enformer_ln_only: 478 | enformer_context = null_context() 479 | freeze_all_but_layernorms_(self.enformer) 480 | else: 481 | self.enformer.eval() 482 | enformer_context = torch.no_grad() 483 | enformer_forward_wrapper = cache_enformer_forward if self.training else identity 484 | enformer_forward = enformer_forward_wrapper(enformer_forward) 485 | 486 | # if unfreezing last N layers of enformer 487 | 488 | if unfreeze_enformer_last_n_layers > 0: 489 | unfreeze_last_n_layers_(self.enformer, unfreeze_enformer_last_n_layers) 490 | 491 | # genetic sequence embedding 492 | 493 | with enformer_context: 494 | seq_embed = enformer_forward(seq, return_only_embeddings = True) 495 | 496 | # if training off an enformer head 497 | 498 | if exists(head): 499 | return self.forward_enformer_head(seq_embed, head = head, target = target) 500 | 501 | # norm sequence embedding 502 | 503 | seq_embed = self.norm_seq_embed(seq_embed) 504 | 505 | for self_attn_block in self.genome_self_attns: 506 | seq_embed = self_attn_block(seq_embed) 507 | 508 | # protein related embeddings 509 | 510 | if self.use_aa_embeds: 511 | assert exists(aa), 'aa must be passed in as tensor of integers from 0 - 20 (20 being padding)' 512 | aa_embed, aa_mask = self.get_aa_embed(aa, device = seq.device) 513 | else: 514 | assert exists(aa_embed), 'protein embeddings must be given as aa_embed' 515 | 516 | # free text embeddings, for cell types and experimental params 517 | 518 | if not exists(contextual_embed): 519 | assert self.use_free_text_context, 'use_free_text_context must be set to True if one is not passing in contextual_embed tensor' 520 | assert exists(contextual_free_text), 'context must be supplied as array of strings as contextual_free_text if 
contextual_embed is not supplied' 521 | 522 | contextual_embed = get_text_repr( 523 | contextual_free_text, 524 | return_cls_token = (self.free_text_embed_method == 'cls'), 525 | device = seq.device 526 | ) 527 | 528 | # contextual conditioning 529 | # film or squeeze-excite for both genetic and protein sequences 530 | 531 | if exists(self.cond_genetic): 532 | seq_embed = self.cond_genetic(seq_embed, contextual_embed) 533 | 534 | if exists(self.cond_protein): 535 | aa_embed = self.cond_protein(aa_embed, contextual_embed, mask = aa_mask) 536 | 537 | # joint cross attention 538 | 539 | for cross_attn in self.joint_cross_attns: 540 | seq_embed, aa_embed = cross_attn( 541 | seq_embed, 542 | context = aa_embed, 543 | context_mask = aa_mask 544 | ) 545 | 546 | # project both embeddings into shared latent space 547 | 548 | interactions = self.filip( 549 | seq_embed, 550 | aa_embed, 551 | context_mask = aa_mask 552 | ) 553 | 554 | 555 | # linear with hypergrid conditioning 556 | 557 | if exists(self.linear_with_hypergrid): 558 | logits = self.linear_with_hypergrid(interactions, context = contextual_embed) 559 | else: 560 | logits = self.linear_to_logits(interactions) 561 | 562 | # to *-seq prediction 563 | 564 | pred = self.to_pred(logits) 565 | 566 | if not exists(target): 567 | return pred 568 | 569 | if exists(target) and return_corr_coef: 570 | return pearson_corr_coef(pred, target) 571 | 572 | if exists(target) and not self.binary_target: 573 | return self.loss_fn(pred, target) 574 | 575 | # binary loss w/ optional auxiliary loss 576 | 577 | loss = self.loss_fn(pred, target.float()) 578 | 579 | if not self.aux_read_value_loss: 580 | return loss, torch.Tensor([0.]).to(device) 581 | 582 | # return prediction if not auto-calculating loss 583 | 584 | assert exists(read_value) and exists(peaks_nr), 'peaks NR must be supplied if doing auxiliary read value loss' 585 | 586 | aux_loss = self.to_read_value_aux_loss( 587 | logits, 588 | peaks_nr, 589 | read_value = read_value 590 | ) 591 | 592 | return loss, aux_loss 593 | -------------------------------------------------------------------------------- /tf_bind_transformer/data.py: -------------------------------------------------------------------------------- 1 | from Bio import SeqIO 2 | from random import choice, randrange 3 | from pathlib import Path 4 | import functools 5 | import polars as pl 6 | from collections import defaultdict 7 | 8 | import os 9 | import json 10 | import shutil 11 | import numpy as np 12 | 13 | import torch 14 | from torch.utils.data import DataLoader 15 | from torch.utils.data import Dataset 16 | 17 | from tf_bind_transformer.gene_utils import parse_gene_name 18 | from enformer_pytorch import FastaInterval 19 | 20 | from pyfaidx import Fasta 21 | import pybedtools 22 | 23 | def exists(val): 24 | return val is not None 25 | 26 | def default(val, d): 27 | return val if exists(val) else d 28 | 29 | def find_first_index(cond, arr): 30 | for ind, el in enumerate(arr): 31 | if cond(el): 32 | return ind 33 | return -1 34 | 35 | def cast_list(val = None): 36 | if not exists(val): 37 | return [] 38 | return [val] if not isinstance(val, (tuple, list)) else val 39 | 40 | def read_bed(path): 41 | return pl.read_csv(path, sep = '\t', has_headers = False) 42 | 43 | def save_bed(df, path): 44 | df.to_csv(path, sep = '\t', has_header = False) 45 | 46 | def parse_exp_target_cell(exp_target_cell): 47 | experiment, target, *cell_type = exp_target_cell.split('.') 48 | cell_type = '.'.join(cell_type) # handle edge case where cell type contains 
periods 49 | return experiment, target, cell_type 50 | 51 | # fetch index of datasets, for providing the sequencing reads 52 | # for auxiliary read value prediction 53 | 54 | def fetch_experiments_index(path): 55 | if not exists(path): 56 | return dict() 57 | 58 | exp_path = Path(path) 59 | assert exp_path.exists(), 'path to experiments json must exist' 60 | 61 | root_json = json.loads(exp_path.read_text()) 62 | experiments = root_json['experiments'] 63 | 64 | index = {} 65 | for experiment in experiments: 66 | exp_id = experiment['accession'] 67 | 68 | if 'details' not in experiment: 69 | continue 70 | 71 | details = experiment['details'] 72 | 73 | if 'datasets' not in details: 74 | continue 75 | 76 | datasets = details['datasets'] 77 | 78 | for dataset in datasets: 79 | dataset_name = dataset['dataset_name'] 80 | index[dataset_name] = dataset['peaks_NR'] 81 | 82 | return index 83 | 84 | # fetch protein sequences by gene name and uniprot id 85 | 86 | class FactorProteinDatasetByUniprotID(Dataset): 87 | def __init__( 88 | self, 89 | folder, 90 | species_priority = ['human', 'mouse'] 91 | ): 92 | super().__init__() 93 | fasta_paths = [*Path(folder).glob('*.fasta')] 94 | assert len(fasta_paths) > 0, f'no fasta files found at {folder}' 95 | self.paths = fasta_paths 96 | self.index_by_id = dict() 97 | 98 | for path in fasta_paths: 99 | gene, uniprotid, *_ = path.stem.split('.') 100 | self.index_by_id[uniprotid] = path 101 | 102 | def __len__(self): 103 | return len(self.paths) 104 | 105 | def __getitem__(self, uid): 106 | index = self.index_by_id 107 | 108 | if uid not in index: 109 | return None 110 | 111 | entry = index[uid] 112 | fasta = SeqIO.read(entry, 'fasta') 113 | return str(fasta.seq) 114 | 115 | # fetch 116 | 117 | class FactorProteinDataset(Dataset): 118 | def __init__( 119 | self, 120 | folder, 121 | species_priority = ['human', 'mouse', 'unknown'], 122 | return_tuple_only = False 123 | ): 124 | super().__init__() 125 | fasta_paths = [*Path(folder).glob('*.fasta')] 126 | assert len(fasta_paths) > 0, f'no fasta files found at {folder}' 127 | self.paths = fasta_paths 128 | 129 | index_by_gene = defaultdict(list) 130 | self.return_tuple_only = return_tuple_only # whether to return tuple even if there is only one subunit 131 | 132 | for path in fasta_paths: 133 | gene, uniprotid, *_ = path.stem.split('.') 134 | index_by_gene[gene].append(path) 135 | 136 | # prioritize fasta files of certain species 137 | # but allow for appropriate fallback, by order of species_priority 138 | 139 | get_species_from_path = lambda p: p.stem.split('_')[-1].lower() if '_' in p.stem else 'unknown' 140 | 141 | filtered_index_by_gene = defaultdict(list) 142 | 143 | for gene, gene_paths in index_by_gene.items(): 144 | species_count = list(map(lambda specie: len(list(filter(lambda p: get_species_from_path(p) == specie, gene_paths))), species_priority)) 145 | species_ind_non_zero = find_first_index(lambda t: t > 0, species_count) 146 | 147 | if species_ind_non_zero == -1: 148 | continue 149 | 150 | species = species_priority[species_ind_non_zero] 151 | filtered_index_by_gene[gene] = list(filter(lambda p: get_species_from_path(p) == species, gene_paths)) 152 | 153 | self.index_by_gene = filtered_index_by_gene 154 | 155 | def __len__(self): 156 | return len(self.paths) 157 | 158 | def __getitem__(self, unparsed_gene_name): 159 | index = self.index_by_gene 160 | 161 | genes = parse_gene_name(unparsed_gene_name) 162 | seqs = [] 163 | 164 | for gene in genes: 165 | entry = index[gene] 166 | 167 | if len(entry) == 0: 
168 | print(f'no entries for {gene}') 169 | continue 170 | 171 | path = choice(entry) if isinstance(entry, list) else entry 172 | 173 | fasta = SeqIO.read(path, 'fasta') 174 | seqs.append(str(fasta.seq)) 175 | 176 | seqs = tuple(seqs) 177 | 178 | if len(seqs) == 1 and not self.return_tuple_only: 179 | return seqs[0] 180 | 181 | return seqs 182 | 183 | # remap dataframe functions 184 | 185 | def get_chr_names(ids): 186 | return set(map(lambda t: f'chr{t}', ids)) 187 | 188 | CHR_IDS = set([*range(1, 23), 'X']) 189 | CHR_NAMES = get_chr_names(CHR_IDS) 190 | 191 | def remap_df_add_experiment_target_cell(df, col = 'column_4'): 192 | df = df.clone() 193 | 194 | exp_id = df.select([pl.col(col).str.extract(r"^([\w\-]+)\.*")]) 195 | exp_id = exp_id.rename({col: 'experiment'}).to_series(0) 196 | df.insert_at_idx(3, exp_id) 197 | 198 | targets = df.select([pl.col(col).str.extract(r"[\w\-]+\.([\w\-]+)\.[\w\-]+")]) 199 | targets = targets.rename({col: 'target'}).to_series(0) 200 | df.insert_at_idx(3, targets) 201 | 202 | cell_type = df.select([pl.col(col).str.extract(r"^.*\.([\w\-]+)$")]) 203 | cell_type = cell_type.rename({col: 'cell_type'}).to_series(0) 204 | df.insert_at_idx(3, cell_type) 205 | 206 | return df 207 | 208 | def pl_isin(col, arr): 209 | equalities = list(map(lambda t: pl.col(col) == t, arr)) 210 | return functools.reduce(lambda a, b: a | b, equalities) 211 | 212 | def pl_notin(col, arr): 213 | equalities = list(map(lambda t: pl.col(col) != t, arr)) 214 | return functools.reduce(lambda a, b: a & b, equalities) 215 | 216 | def filter_by_col_isin(df, col, arr, chunk_size = 25): 217 | """ 218 | polars seem to have a bug 219 | where OR more than 25 conditions freezes (for pl_isin) 220 | do in chunks of 25 and then concat instead 221 | """ 222 | dataframes = [] 223 | for i in range(0, len(arr), chunk_size): 224 | sub_arr = arr[i:(i + chunk_size)] 225 | filtered_df = df.filter(pl_isin(col, sub_arr)) 226 | dataframes.append(filtered_df) 227 | return pl.concat(dataframes) 228 | 229 | def filter_bed_file_by_(bed_file_1, bed_file_2, output_file): 230 | # generated by OpenAI Codex 231 | 232 | bed_file_1_bedtool = pybedtools.BedTool(bed_file_1) 233 | bed_file_2_bedtool = pybedtools.BedTool(bed_file_2) 234 | bed_file_1_bedtool_intersect_bed_file_2_bedtool = bed_file_1_bedtool.intersect(bed_file_2_bedtool, v = True) 235 | bed_file_1_bedtool_intersect_bed_file_2_bedtool.saveas(output_file) 236 | 237 | def filter_df_by_tfactor_fastas(df, folder): 238 | files = [*Path(folder).glob('**/*.fasta')] 239 | present_target_names = set([f.stem.split('.')[0] for f in files]) 240 | all_df_targets = df.get_column('target').unique().to_list() 241 | 242 | all_df_targets_with_parsed_name = [(target, parse_gene_name(target)) for target in all_df_targets] 243 | unknown_targets = [target for target, parsed_target_name in all_df_targets_with_parsed_name for parsed_target_name_sub_el in parsed_target_name if parsed_target_name_sub_el not in present_target_names] 244 | 245 | if len(unknown_targets) > 0: 246 | df = df.filter(pl_notin('target', unknown_targets)) 247 | return df 248 | 249 | def generate_random_ranges_from_fasta( 250 | fasta_file, 251 | *, 252 | output_filename = 'random-ranges.bed', 253 | context_length, 254 | filter_bed_files = [], 255 | num_entries_per_key = 10, 256 | keys = None, 257 | ): 258 | fasta = Fasta(fasta_file) 259 | tmp_file = f'/tmp/{output_filename}' 260 | 261 | with open(tmp_file, 'w') as f: 262 | for chr_name in sorted(CHR_NAMES): 263 | print(f'generating ranges for {chr_name}') 264 | 265 | 
if chr_name not in fasta: 266 | print(f'{chr_name} not found in fasta file') 267 | continue 268 | 269 | chromosome = fasta[chr_name] 270 | chromosome_length = len(chromosome) 271 | 272 | start = np.random.randint(0, chromosome_length - context_length, (num_entries_per_key,)) 273 | end = start + context_length 274 | start_and_end = np.stack((start, end), axis = -1) 275 | 276 | for row in start_and_end.tolist(): 277 | start, end = row 278 | f.write('\t'.join((chr_name, str(start), str(end))) + '\n') 279 | 280 | for file in filter_bed_files: 281 | filter_bed_file_by_(tmp_file, file, tmp_file) 282 | 283 | shutil.move(tmp_file, f'./{output_filename}') 284 | 285 | print('success') 286 | 287 | # context string creator class 288 | 289 | class ContextDataset(Dataset): 290 | def __init__( 291 | self, 292 | *, 293 | biotypes_metadata_path = None, 294 | include_biotypes_metadata_in_context = False, 295 | include_biotypes_metadata_columns = [], 296 | biotypes_metadata_delimiter = ' | ', 297 | ): 298 | self.include_biotypes_metadata_in_context = include_biotypes_metadata_in_context 299 | self.include_biotypes_metadata_columns = include_biotypes_metadata_columns 300 | self.biotypes_metadata_delimiter = biotypes_metadata_delimiter 301 | 302 | if include_biotypes_metadata_in_context: 303 | assert len(self.include_biotypes_metadata_columns) > 0, 'must have more than one biotype metadata column to include' 304 | assert exists(biotypes_metadata_path), 'biotypes metadata path must be supplied if to be included in context string' 305 | 306 | p = Path(biotypes_metadata_path) 307 | 308 | if p.suffix == '.csv': 309 | sep = ',' 310 | elif p.suffix == '.tsv': 311 | sep = '\t' 312 | else: 313 | raise ValueError(f'invalid suffix {p.suffix} for biotypes') 314 | 315 | self.df = pl.read_csv(str(p), sep = sep) 316 | 317 | def __len__(): 318 | return len(self.df) if self.include_biotypes_metadata_in_context else -1 319 | 320 | def __getitem__(self, biotype): 321 | if not self.include_biotypes_metadata_in_context: 322 | return biotype 323 | 324 | col_indices = list(map(self.df.columns.index, self.include_biotypes_metadata_columns)) 325 | filtered = self.df.filter(pl.col('biotype') == biotype) 326 | 327 | if len(filtered) == 0: 328 | print(f'no rows found for {biotype} in biotype metadata file') 329 | return biotype 330 | 331 | row = filtered.row(0) 332 | columns = list(map(lambda t: row[t], col_indices)) 333 | 334 | context_string = self.biotypes_metadata_delimiter.join([biotype, *columns]) 335 | return context_string 336 | 337 | # dataset for remap data - all peaks 338 | 339 | class RemapAllPeakDataset(Dataset): 340 | def __init__( 341 | self, 342 | *, 343 | factor_fasta_folder, 344 | bed_file = None, 345 | remap_df = None, 346 | filter_chromosome_ids = None, 347 | exclude_targets = None, 348 | include_targets = None, 349 | exclude_cell_types = None, 350 | include_cell_types = None, 351 | remap_df_frac = 1., 352 | experiments_json_path = None, 353 | include_biotypes_metadata_in_context = False, 354 | biotypes_metadata_path = None, 355 | include_biotypes_metadata_columns = [], 356 | biotypes_metadata_delimiter = ' | ', 357 | balance_sampling_by_target = False, 358 | **kwargs 359 | ): 360 | super().__init__() 361 | assert exists(remap_df) ^ exists(bed_file), 'either remap bed file or remap dataframe must be passed in' 362 | 363 | if not exists(remap_df): 364 | remap_df = read_bed(bed_file) 365 | 366 | if remap_df_frac < 1: 367 | remap_df = remap_df.sample(frac = remap_df_frac) 368 | 369 | dataset_chr_ids = CHR_IDS 370 | 371 
| if exists(filter_chromosome_ids): 372 | dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids)) 373 | 374 | remap_df = remap_df.filter(pl_isin('column_1', get_chr_names(dataset_chr_ids))) 375 | remap_df = filter_df_by_tfactor_fastas(remap_df, factor_fasta_folder) 376 | 377 | self.factor_ds = FactorProteinDataset(factor_fasta_folder) 378 | 379 | # filter dataset by inclusion and exclusion list of targets 380 | # ( intersect ) subtract 381 | 382 | include_targets = cast_list(include_targets) 383 | exclude_targets = cast_list(exclude_targets) 384 | 385 | if include_targets: 386 | remap_df = remap_df.filter(pl_isin('target', include_targets)) 387 | 388 | if exclude_targets: 389 | remap_df = remap_df.filter(pl_notin('target', exclude_targets)) 390 | 391 | # filter dataset by inclusion and exclusion list of cell types 392 | # same logic as for targets 393 | 394 | include_cell_types = cast_list(include_cell_types) 395 | exclude_cell_types = cast_list(exclude_cell_types) 396 | 397 | if include_cell_types: 398 | remap_df = remap_df.filter(pl_isin('cell_type', include_cell_types)) 399 | 400 | if exclude_cell_types: 401 | remap_df = remap_df.filter(pl_notin('cell_type', exclude_cell_types)) 402 | 403 | assert len(remap_df) > 0, 'dataset is empty by filter criteria' 404 | 405 | self.df = remap_df 406 | self.fasta = FastaInterval(**kwargs) 407 | 408 | self.experiments_index = fetch_experiments_index(experiments_json_path) 409 | 410 | # balanced target sampling logic 411 | 412 | self.balance_sampling_by_target = balance_sampling_by_target 413 | 414 | if self.balance_sampling_by_target: 415 | self.df_indexed_by_target = [] 416 | 417 | for target in self.df.get_column('target').unique().to_list(): 418 | df_by_target = self.df.filter(pl.col('target') == target) 419 | self.df_indexed_by_target.append(df_by_target) 420 | 421 | # context string creator 422 | 423 | self.context_ds = ContextDataset( 424 | include_biotypes_metadata_in_context = include_biotypes_metadata_in_context, 425 | biotypes_metadata_path = biotypes_metadata_path, 426 | include_biotypes_metadata_columns = include_biotypes_metadata_columns, 427 | biotypes_metadata_delimiter = biotypes_metadata_delimiter 428 | ) 429 | 430 | def __len__(self): 431 | if self.balance_sampling_by_target: 432 | return len(self.df_indexed_by_target) 433 | else: 434 | return len(self.df) 435 | 436 | def __getitem__(self, ind): 437 | # if balancing by target, randomly draw sample from indexed dataframe 438 | 439 | if self.balance_sampling_by_target: 440 | filtered_df = self.df_indexed_by_target[ind] 441 | rand_ind = randrange(0, len(filtered_df)) 442 | sample = filtered_df.row(rand_ind) 443 | else: 444 | sample = self.df.row(ind) 445 | 446 | chr_name, begin, end, _, _, _, experiment_target_cell_type, reading, *_ = sample 447 | 448 | # now aggregate all the data 449 | 450 | experiment, target, cell_type = parse_exp_target_cell(experiment_target_cell_type) 451 | 452 | seq = self.fasta(chr_name, begin, end) 453 | aa_seq = self.factor_ds[target] 454 | context_str = self.context_ds[cell_type] 455 | 456 | read_value = torch.Tensor([reading]) 457 | 458 | peaks_nr = self.experiments_index.get(experiment_target_cell_type, 0.) 
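        # note: experiments absent from the index fall back to a peaks NR of 0.,
        # so the auxiliary read value loss always receives a defined scalar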
459 | peaks_nr = torch.Tensor([peaks_nr]) 460 | 461 | label = torch.Tensor([1.]) 462 | 463 | return seq, aa_seq, context_str, peaks_nr, read_value, label 464 | 465 | # filter functions for exp-target-cells based on heldouts 466 | 467 | def filter_exp_target_cell( 468 | arr, 469 | *, 470 | exclude_targets = None, 471 | include_targets = None, 472 | exclude_cell_types = None, 473 | include_cell_types = None, 474 | ): 475 | out = [] 476 | 477 | for el in arr: 478 | experiment, target, cell_type = parse_exp_target_cell(el) 479 | 480 | if exists(include_targets) and len(include_targets) > 0 and target not in include_targets: 481 | continue 482 | 483 | if exists(exclude_targets) and target in exclude_targets: 484 | continue 485 | 486 | if exists(include_cell_types) and len(include_cell_types) > 0 and cell_type not in include_cell_types: 487 | continue 488 | 489 | if exists(exclude_cell_types) and cell_type in exclude_cell_types: 490 | continue 491 | 492 | out.append(el) 493 | 494 | return out 495 | 496 | 497 | # dataset for negatives scoped to a specific exp-target-celltype 498 | 499 | class ScopedNegativePeakDataset(Dataset): 500 | def __init__( 501 | self, 502 | *, 503 | fasta_file, 504 | factor_fasta_folder, 505 | numpy_folder_with_scoped_negatives, 506 | exts = '.bed.bool.npy', 507 | remap_bed_file = None, 508 | remap_df = None, 509 | filter_chromosome_ids = None, 510 | experiments_json_path = None, 511 | exclude_targets = None, 512 | include_targets = None, 513 | exclude_cell_types = None, 514 | include_cell_types = None, 515 | include_biotypes_metadata_in_context = False, 516 | biotypes_metadata_path = None, 517 | include_biotypes_metadata_columns = [], 518 | biotypes_metadata_delimiter = ' | ', 519 | balance_sampling_by_target = False, 520 | **kwargs 521 | ): 522 | super().__init__() 523 | assert exists(remap_df) ^ exists(remap_bed_file), 'either remap bed file or remap dataframe must be passed in' 524 | 525 | if not exists(remap_df): 526 | remap_df = read_bed(remap_bed_file) 527 | 528 | dataset_chr_ids = CHR_IDS 529 | 530 | if exists(filter_chromosome_ids): 531 | dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids)) 532 | 533 | filter_map_df = remap_df.with_column(pl.when(pl_isin('column_1', get_chr_names(dataset_chr_ids))).then(True).otherwise(False).alias('mask')) 534 | mask = filter_map_df.get_column('mask').to_numpy() 535 | 536 | num_scoped_negs = mask.sum() 537 | 538 | print(f'{num_scoped_negs} scoped negative rows found for training') 539 | 540 | assert num_scoped_negs > 0, 'all remap rows filtered out for scoped negative peak dataset' 541 | 542 | self.df = remap_df 543 | self.chromosome_mask = mask 544 | 545 | # get dictionary with exp-target-cell to boolean numpy indicating which ones are negatives 546 | 547 | npys_paths = [*Path(numpy_folder_with_scoped_negatives).glob('**/*.npy')] 548 | exp_target_cell_negatives = [(path.name.rstrip(exts), path) for path in npys_paths] 549 | 550 | exp_target_cells = [el[0] for el in exp_target_cell_negatives] 551 | 552 | exp_target_cells = filter_exp_target_cell( 553 | exp_target_cells, 554 | include_targets = include_targets, 555 | exclude_targets = exclude_targets, 556 | include_cell_types = include_cell_types, 557 | exclude_cell_types = exclude_cell_types 558 | ) 559 | 560 | filtered_exp_target_cell_negatives = list(filter(lambda el: el[0] in exp_target_cells, exp_target_cell_negatives)) 561 | 562 | self.exp_target_cell_negatives = filtered_exp_target_cell_negatives 563 | assert len(self.exp_target_cell_negatives) > 0, 
'no experiment-target-cell scoped negatives to select from after filtering' 564 | 565 | # balanced target sampling 566 | 567 | self.balance_sampling_by_target = balance_sampling_by_target 568 | 569 | if balance_sampling_by_target: 570 | self.exp_target_cell_by_target = defaultdict(list) 571 | 572 | for exp_target_cell, filepath in self.exp_target_cell_negatives: 573 | _, target, *_ = parse_exp_target_cell(exp_target_cell) 574 | self.exp_target_cell_by_target[target].append((exp_target_cell, filepath)) 575 | 576 | # tfactor dataset 577 | 578 | self.factor_ds = FactorProteinDataset(factor_fasta_folder) 579 | 580 | self.fasta = FastaInterval(fasta_file = fasta_file, **kwargs) 581 | self.experiments_index = fetch_experiments_index(experiments_json_path) 582 | 583 | # context string creator 584 | 585 | self.context_ds = ContextDataset( 586 | include_biotypes_metadata_in_context = include_biotypes_metadata_in_context, 587 | biotypes_metadata_path = biotypes_metadata_path, 588 | include_biotypes_metadata_columns = include_biotypes_metadata_columns, 589 | biotypes_metadata_delimiter = biotypes_metadata_delimiter 590 | ) 591 | 592 | def __len__(self): 593 | if self.balance_sampling_by_target: 594 | return len(self.exp_target_cell_by_target) 595 | else: 596 | return len(self.exp_target_cell_negatives) 597 | 598 | def __getitem__(self, idx): 599 | if self.balance_sampling_by_target: 600 | negatives = list(self.exp_target_cell_by_target.values())[idx] 601 | sample = choice(negatives) 602 | else: 603 | sample = self.exp_target_cell_negatives[idx] 604 | 605 | exp_target_cell, bool_numpy_path = sample 606 | experiment, target, cell_type = parse_exp_target_cell(exp_target_cell) 607 | 608 | # load boolean numpy array 609 | # and select random peak that is a negative 610 | 611 | np_arr = np.load(str(bool_numpy_path)) 612 | np_arr_noised = np_arr.astype(np.float32) + np.random.uniform(low = -1e-1, high = 1e-1, size = np_arr.shape[0]) 613 | 614 | # mask with chromosomes allowed 615 | 616 | np_arr_noised *= self.chromosome_mask.astype(np.float32) 617 | 618 | # select random negative peak 619 | 620 | random_neg_peak_index = np_arr_noised.argmax() 621 | 622 | chr_name, begin, end, *_ = self.df.row(random_neg_peak_index) 623 | seq = self.fasta(chr_name, begin, end) 624 | 625 | aa_seq = self.factor_ds[target] 626 | context_str = self.context_ds[cell_type] 627 | 628 | peaks_nr = self.experiments_index.get(exp_target_cell, 0.) 
629 | peaks_nr = torch.Tensor([peaks_nr]) 630 | 631 | read_value = torch.Tensor([0.]) 632 | 633 | label = torch.Tensor([0.]) 634 | 635 | return seq, aa_seq, context_str, peaks_nr, read_value, label 636 | 637 | # dataset for hard negatives (negatives to all peaks) 638 | 639 | class NegativePeakDataset(Dataset): 640 | def __init__( 641 | self, 642 | *, 643 | factor_fasta_folder, 644 | negative_bed_file = None, 645 | remap_bed_file = None, 646 | remap_df = None, 647 | negative_df = None, 648 | filter_chromosome_ids = None, 649 | exclude_targets = None, 650 | include_targets = None, 651 | exclude_cell_types = None, 652 | include_cell_types = None, 653 | exp_target_cell_column = 'column_4', 654 | experiments_json_path = None, 655 | include_biotypes_metadata_in_context = False, 656 | biotypes_metadata_path = None, 657 | include_biotypes_metadata_columns = [], 658 | biotypes_metadata_delimiter = ' | ', 659 | balance_sampling_by_target = False, 660 | **kwargs 661 | ): 662 | super().__init__() 663 | assert exists(remap_df) ^ exists(remap_bed_file), 'either remap bed file or remap dataframe must be passed in' 664 | assert exists(negative_df) ^ exists(negative_bed_file), 'either negative bed file or negative dataframe must be passed in' 665 | 666 | # instantiate dataframes if not passed in 667 | 668 | if not exists(remap_df): 669 | remap_df = read_bed(remap_bed_file) 670 | 671 | neg_df = negative_df 672 | if not exists(negative_df): 673 | neg_df = read_bed(negative_bed_file) 674 | 675 | # filter remap dataframe 676 | 677 | remap_df = filter_df_by_tfactor_fastas(remap_df, factor_fasta_folder) 678 | 679 | dataset_chr_ids = CHR_IDS 680 | 681 | if exists(filter_chromosome_ids): 682 | dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids)) 683 | 684 | neg_df = neg_df.filter(pl_isin('column_1', get_chr_names(dataset_chr_ids))) 685 | 686 | assert len(neg_df) > 0, 'dataset is empty by filter criteria' 687 | 688 | self.neg_df = neg_df 689 | 690 | # get all exp-target-cells and filter by above 691 | 692 | exp_target_cells = remap_df.get_column(exp_target_cell_column).unique().to_list() 693 | 694 | self.filtered_exp_target_cells = filter_exp_target_cell( 695 | exp_target_cells, 696 | include_targets = include_targets, 697 | exclude_targets = exclude_targets, 698 | include_cell_types = include_cell_types, 699 | exclude_cell_types = exclude_cell_types 700 | ) 701 | 702 | assert len(self.filtered_exp_target_cells), 'no experiment-target-cell left for hard negative set' 703 | 704 | # balanced sampling of targets 705 | 706 | self.balance_sampling_by_target = balance_sampling_by_target 707 | 708 | if balance_sampling_by_target: 709 | self.exp_target_cell_by_target = defaultdict(list) 710 | 711 | for exp_target_cell in self.filtered_exp_target_cells: 712 | _, target, *_ = parse_exp_target_cell(exp_target_cell) 713 | self.exp_target_cell_by_target[target].append(exp_target_cell) 714 | 715 | # factor ds 716 | 717 | self.factor_ds = FactorProteinDataset(factor_fasta_folder) 718 | self.fasta = FastaInterval(**kwargs) 719 | 720 | self.experiments_index = fetch_experiments_index(experiments_json_path) 721 | 722 | # context string creator 723 | 724 | self.context_ds = ContextDataset( 725 | include_biotypes_metadata_in_context = include_biotypes_metadata_in_context, 726 | biotypes_metadata_path = biotypes_metadata_path, 727 | include_biotypes_metadata_columns = include_biotypes_metadata_columns, 728 | biotypes_metadata_delimiter = biotypes_metadata_delimiter 729 | ) 730 | 731 | def __len__(self): 732 | 
return len(self.neg_df) 733 | 734 | def __getitem__(self, ind): 735 | chr_name, begin, end = self.neg_df.row(ind) 736 | 737 | if self.balance_sampling_by_target: 738 | rand_ind = randrange(0, len(self.exp_target_cell_by_target)) 739 | exp_target_cell_by_target_list = list(self.exp_target_cell_by_target.values()) 740 | random_exp_target_cell_type = choice(exp_target_cell_by_target_list[rand_ind]) 741 | else: 742 | random_exp_target_cell_type = choice(self.filtered_exp_target_cells) 743 | 744 | experiment, target, cell_type = parse_exp_target_cell(random_exp_target_cell_type) 745 | 746 | seq = self.fasta(chr_name, begin, end) 747 | aa_seq = self.factor_ds[target] 748 | context_str = self.context_ds[cell_type] 749 | 750 | read_value = torch.Tensor([0.]) 751 | 752 | peaks_nr = self.experiments_index.get(random_exp_target_cell_type, 0.) 753 | peaks_nr = torch.Tensor([peaks_nr]) 754 | 755 | label = torch.Tensor([0.]) 756 | 757 | return seq, aa_seq, context_str, peaks_nr, read_value, label 758 | 759 | # dataloader related functions 760 | 761 | def collate_fn(data): 762 | seq, aa_seq, context_str, peaks_nr, read_values, labels = list(zip(*data)) 763 | return torch.stack(seq), tuple(aa_seq), tuple(context_str), torch.stack(peaks_nr, dim = 0), torch.stack(read_values, dim = 0), torch.cat(labels, dim = 0) 764 | 765 | def collate_dl_outputs(*dl_outputs): 766 | outputs = list(zip(*dl_outputs)) 767 | ret = [] 768 | for entry in outputs: 769 | if isinstance(entry[0], torch.Tensor): 770 | entry = torch.cat(entry, dim = 0) 771 | else: 772 | entry = (sub_el for el in entry for sub_el in el) 773 | ret.append(entry) 774 | return tuple(ret) 775 | 776 | def cycle(loader): 777 | while True: 778 | for data in loader: 779 | yield data 780 | 781 | def get_dataloader(ds, cycle_iter = False, **kwargs): 782 | dataset_len = len(ds) 783 | batch_size = kwargs.get('batch_size') 784 | drop_last = dataset_len > batch_size 785 | 786 | dl = DataLoader(ds, collate_fn = collate_fn, drop_last = drop_last, **kwargs) 787 | wrapper = cycle if cycle_iter else iter 788 | return wrapper(dl) 789 | --------------------------------------------------------------------------------
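To show how the pieces above fit together, here is a minimal training-step sketch: the datasets in data.py yield (seq, aa_seq, context_str, peaks_nr, read_value, label) tuples, get_dataloader wraps them with collate_fn, collate_dl_outputs merges one positive and one hard-negative batch, and AdapterModel turns the merged batch into a binary loss plus the optional read-value auxiliary loss. This is an illustrative sketch only: the file paths, the Enformer constructor arguments, and the optimizer are assumptions, not part of the repository.

import torch
from enformer_pytorch import Enformer

from tf_bind_transformer.tf_bind_transformer import AdapterModel
from tf_bind_transformer.data import (
    RemapAllPeakDataset,
    NegativePeakDataset,
    get_dataloader,
    collate_dl_outputs
)

# enformer backbone - assumed constructor and hyperparameters,
# adjust to the enformer_pytorch version in use
enformer = Enformer.from_hparams(dim = 1536, depth = 11, heads = 8)

model = AdapterModel(
    enformer = enformer,
    use_aa_embeds = True,            # embed factors from their amino acid sequences
    use_free_text_context = True,    # embed cell type / context strings with the text encoder
    binary_target = True,            # peak vs non-peak classification
    aux_read_value_loss = True       # auxiliary read value prediction via ReadValueMLP
)

# positive peaks and hard negatives - all paths below are hypothetical placeholders
pos_ds = RemapAllPeakDataset(
    factor_fasta_folder = './tfactor.fastas',
    bed_file = './remap.bed',
    fasta_file = './hg38.ml.fa',          # forwarded to FastaInterval
    experiments_json_path = './experiments.json'
)

neg_ds = NegativePeakDataset(
    factor_fasta_folder = './tfactor.fastas',
    remap_bed_file = './remap.bed',
    negative_bed_file = './random-ranges.bed',
    fasta_file = './hg38.ml.fa',
    experiments_json_path = './experiments.json'
)

pos_dl = get_dataloader(pos_ds, cycle_iter = True, batch_size = 2, shuffle = True)
neg_dl = get_dataloader(neg_ds, cycle_iter = True, batch_size = 2, shuffle = True)

optim = torch.optim.AdamW(model.parameters(), lr = 3e-4)

# one training step: merge a positive and a negative batch, then combine the losses
seq, aa_seq, context_str, peaks_nr, read_value, label = collate_dl_outputs(next(pos_dl), next(neg_dl))

loss, aux_loss = model(
    seq,
    aa = aa_seq,
    contextual_free_text = context_str,
    target = label,
    read_value = read_value,
    peaks_nr = peaks_nr
)

model.combine_losses(loss, aux_loss).backward()
optim.step()
optim.zero_grad()

ScopedNegativePeakDataset can be wired in the same way: give it its own cycling dataloader and pass the extra batch as another argument to collate_dl_outputs.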