├── __init__.py ├── src ├── __init__.py ├── models │ ├── __init__.py │ ├── baseline │ │ ├── __init__.py │ │ └── genomics_benchmark_cnn.py │ ├── sequence │ │ └── __init__.py │ └── nn │ │ ├── __init__.py │ │ ├── activation.py │ │ └── utils.py ├── dataloaders │ ├── datasets │ │ ├── __init__.py │ │ ├── nucleotide_transformer_dataset.py │ │ ├── genomic_bench_dataset.py │ │ └── hg38_char_tokenizer.py │ ├── __init__.py │ ├── utils │ │ ├── rc.py │ │ └── mlm.py │ ├── fault_tolerant_sampler.py │ └── base.py ├── utils │ ├── __init__.py │ ├── registry.py │ ├── optim │ │ └── schedulers.py │ ├── config.py │ ├── train.py │ └── optim_groups.py ├── callbacks │ ├── validation.py │ ├── params.py │ └── timer.py ├── tasks │ ├── encoders.py │ └── torchmetrics.py └── ops │ └── fftconv.py ├── janusdna └── __init__.py ├── configs ├── callbacks │ ├── lr.yaml │ ├── gpu_affinity.yaml │ ├── val_every_n_global_steps.yaml │ ├── rich.yaml │ ├── base.yaml │ ├── wandb.yaml │ └── checkpoint.yaml ├── loader │ └── default.yaml ├── task │ ├── lm.yaml │ ├── regression.yaml │ ├── longrange_benchmark.yaml │ ├── multiclass_classification.yaml │ └── multilabel_classification.yaml ├── model │ ├── hf_caduceus.yaml │ ├── genomics_benchmark_cnn.yaml │ ├── layer │ │ └── hyena.yaml │ ├── hyena.yaml │ ├── hyena_hg.yaml │ ├── mamba.yaml │ ├── caduceus.yaml │ ├── caduceus_ph_131k_hg.yaml │ ├── caduceus_nt.yaml │ ├── caduceus_ffn_moe_attn.yaml │ └── janusdna.yaml ├── optimizer │ ├── adamw.yaml │ ├── sgd.yaml │ └── adam.yaml ├── scheduler │ ├── constant.yaml │ ├── step.yaml │ ├── multistep.yaml │ ├── cosine_warmup.yaml │ ├── linear_warmup.yaml │ ├── constant_warmup.yaml │ ├── cosine.yaml │ ├── cosine_warmup_timm.yaml │ └── plateau.yaml ├── dataset │ ├── hg38.yaml │ ├── akita_benchmark.yaml │ ├── eqtl_benchmark.yaml │ ├── genomic_benchmark.yaml │ └── nucleotide_transformer.yaml ├── trainer │ ├── debug.yaml │ ├── default.yaml │ ├── lm.yaml │ └── full.yaml ├── pipeline │ ├── akita_benchmark.yaml │ ├── eqtl_benchmark.yaml │ ├── genomic_benchmark.yaml │ ├── enhancer_target_gene.yaml │ ├── nucleotide_transformer.yaml │ ├── longrange_benchmark.yaml │ └── hg38.yaml ├── experiment │ └── hg38 │ │ ├── nucleotide_transformer.yaml │ │ ├── hg38.yaml │ │ ├── genomic_benchmark_cnn.yaml │ │ ├── genomic_benchmark.yaml │ │ └── eqtl.yaml └── config.yaml ├── assets └── JanusDNA.png ├── caduceus ├── __init__.py ├── configuration_caduceus.py └── tokenization_caduceus.py ├── evals ├── auroc.py ├── evaluate_contact_map.py ├── evaluate_auroc_janus.py ├── evaluate_auroc.py └── hg38_inference.py ├── .gitignore ├── scripts ├── benchmark │ ├── dnalong │ │ ├── eqtl_evaluation_janus.sh │ │ └── eqtl_train_janus_8gpu.sh │ ├── gb │ │ └── gb_janusdna.sh │ └── nt │ │ └── nt_janusdna.sh └── pre_train │ ├── slurm_JanusDNA_w_midattn_32dim.sh │ ├── slurm_JanusDNA_w_midattn_72dim.sh │ ├── slurm_JanusDNA_wo_midattn_144dim.sh │ ├── slurm_JanusDNA_wo_midattn_32dim.sh │ └── slurm_JanusDNA_wo_midattn_72dim.sh ├── janusdna.yml └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /janusdna/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
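The configs/ tree above follows a Hydra-style layout: a root config.yaml plus groups such as model/, optimizer/, scheduler/, dataset/, pipeline/ and experiment/ that are merged through defaults lists. As a minimal sketch of how such a layered tree is typically composed, the entry-point file and main() below are illustrative assumptions, not files from this repository:

import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    # Hydra merges the selected group files (model/, optimizer/, scheduler/, ...)
    # into one DictConfig; experiment/* files override groups via their defaults lists.
    # resolve=False leaves interpolations like ${eval:...} and ${div_up:...} untouched,
    # since those custom resolvers are presumably registered by the project's own code.
    print(OmegaConf.to_yaml(cfg, resolve=False))


if __name__ == "__main__":
    main()

With an entry point like this at the repository root, a run could be configured on the command line Hydra-style, e.g. experiment=hg38/hg38 together with a model override such as model=caduceus, since the experiment files leave /model as a required ??? placeholder.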
/src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/baseline/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/sequence/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dataloaders/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/callbacks/lr.yaml: -------------------------------------------------------------------------------- 1 | lr_monitor: 2 | logging_interval: step -------------------------------------------------------------------------------- /src/models/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .activation import Activation 2 | -------------------------------------------------------------------------------- /configs/callbacks/gpu_affinity.yaml: -------------------------------------------------------------------------------- 1 | gpu_affinity: 2 | _name_: gpu_affinity 3 | -------------------------------------------------------------------------------- /configs/loader/default.yaml: -------------------------------------------------------------------------------- 1 | num_workers: 0 2 | pin_memory: True 3 | drop_last: True 4 | -------------------------------------------------------------------------------- /assets/JanusDNA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qihao-Duan/JanusDNA/HEAD/assets/JanusDNA.png -------------------------------------------------------------------------------- /src/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import genomics 2 | from .base import SequenceDataset 3 | -------------------------------------------------------------------------------- /configs/callbacks/val_every_n_global_steps.yaml: -------------------------------------------------------------------------------- 1 | val_every_n_global_steps: 2 | every_n: 10000 3 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import is_list, is_dict, to_list, to_dict, get_class, instantiate 2 | -------------------------------------------------------------------------------- /configs/task/lm.yaml: -------------------------------------------------------------------------------- 1 | _name_: lm 2 | # loss: cross_entropy # Handled by task: cross entropy loss 3 | metrics: ppl 4 | -------------------------------------------------------------------------------- /configs/callbacks/rich.yaml: -------------------------------------------------------------------------------- 1 | rich_model_summary: 2 | max_depth: 2 3 | 4 | rich_progress_bar: 5 | refresh_rate_per_second: 1.0 6 | -------------------------------------------------------------------------------- /configs/task/regression.yaml: -------------------------------------------------------------------------------- 1 | # _target_: tasks.tasks.BaseTask 2 | _name_: base 3 | loss: mse 4 | metrics: mse 5 | torchmetrics: null 6 | -------------------------------------------------------------------------------- /configs/model/hf_caduceus.yaml: -------------------------------------------------------------------------------- 1 | # Caduceus downloaded form HF 2 | _name_: hf_caduceus 3 | pretrained_model_name_or_path: null 4 | trust_remote_code: true -------------------------------------------------------------------------------- /configs/task/longrange_benchmark.yaml: -------------------------------------------------------------------------------- 1 | # _target_: tasks.tasks.MultiClass 2 | _name_: lrb 3 | loss: cross_entropy 4 | metrics: 5 | - accuracy 6 | torchmetrics: null 7 | -------------------------------------------------------------------------------- /configs/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | # _target_: torch.optim.AdamW 2 | _name_: adamw 3 | lr: 0.001 # Initial learning rate 4 | weight_decay: 0.00 # Weight decay 5 | betas: [0.9, 0.999] 6 | -------------------------------------------------------------------------------- /configs/optimizer/sgd.yaml: -------------------------------------------------------------------------------- 1 | # _target_: torch.optim.SGD 2 | _name_: sgd 3 | lr: 0.001 # Initial learning rate 4 | momentum: 0.9 5 | weight_decay: 0.0 # Weight decay for adam|lamb 6 | -------------------------------------------------------------------------------- /configs/scheduler/constant.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: epoch 4 | scheduler: 5 | # _target_: transformers.get_constant_schedule 6 | _name_: constant 7 | -------------------------------------------------------------------------------- /configs/task/multiclass_classification.yaml: -------------------------------------------------------------------------------- 1 | # _target_: tasks.tasks.MultiClass 2 | _name_: multiclass 3 | loss: cross_entropy 4 | metrics: 5 | - accuracy 6 | torchmetrics: null 7 | 
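The commented _target_ hints in the optimizer and scheduler groups above point at standard library objects. Below is a self-contained sketch of the pairing described by configs/optimizer/adamw.yaml and configs/scheduler/cosine_warmup.yaml; the tiny nn.Linear stand-in model is an assumption made only so the optimizer has parameters to manage:

import torch
from torch import nn
from transformers import get_cosine_schedule_with_warmup

model = nn.Linear(16, 16)  # stand-in for the real backbone

# configs/optimizer/adamw.yaml: lr=0.001, weight_decay=0.0, betas=[0.9, 0.999]
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1e-3, weight_decay=0.0, betas=(0.9, 0.999)
)

# configs/scheduler/cosine_warmup.yaml: 1000 warmup steps out of 40000 total,
# stepped per optimizer step because the config sets train.interval to "step".
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=1000, num_training_steps=40000
)

for _ in range(5):  # training loop elided; only the update order matters here
    optimizer.step()
    scheduler.step()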
-------------------------------------------------------------------------------- /configs/scheduler/step.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: torch.optim.lr_scheduler.StepLR 6 | _name_: step 7 | step_size: 1 8 | gamma: 0.99 9 | -------------------------------------------------------------------------------- /configs/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | # _target_: torch.optim.Adam 2 | _name_: adam 3 | lr: 0.001 # Initial learning rate 4 | # weight_decay: 0.0 # Weight decay for adam|lamb; should use AdamW instead if desired 5 | betas: [0.9, 0.999] 6 | -------------------------------------------------------------------------------- /configs/scheduler/multistep.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: epoch 4 | # _target_: torch.optim.lr_scheduler.MultiStepLR 5 | scheduler: 6 | _name_: multistep 7 | milestones: [80,140,180] 8 | gamma: 0.2 9 | -------------------------------------------------------------------------------- /configs/scheduler/cosine_warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_cosine_schedule_with_warmup 6 | _name_: cosine_warmup 7 | num_warmup_steps: 1000 8 | num_training_steps: 40000 9 | -------------------------------------------------------------------------------- /configs/scheduler/linear_warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_linear_schedule_with_warmup 6 | _name_: linear_warmup 7 | num_warmup_steps: 1000 8 | num_training_steps: 40000 9 | -------------------------------------------------------------------------------- /configs/scheduler/constant_warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_constant_schedule_with_warmup 6 | _name_: constant_warmup 7 | num_warmup_steps: 1000 # Number of iterations for LR warmup 8 | -------------------------------------------------------------------------------- /caduceus/__init__.py: -------------------------------------------------------------------------------- 1 | """Hugging Face config, model, and tokenizer for Caduceus. 
2 | 3 | """ 4 | 5 | from .configuration_caduceus import CaduceusConfig 6 | from .modeling_caduceus import Caduceus, CaduceusForMaskedLM, CaduceusForSequenceClassification 7 | from .tokenization_caduceus import CaduceusTokenizer 8 | -------------------------------------------------------------------------------- /configs/scheduler/cosine.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: epoch 4 | scheduler: 5 | # _target_: torch.optim.lr_scheduler.CosineAnnealingLR 6 | _name_: cosine 7 | T_max: 100 # Max number of epochs steps for LR scheduler 8 | eta_min: 1e-6 # Min learning rate for cosine scheduler 9 | -------------------------------------------------------------------------------- /configs/task/multilabel_classification.yaml: -------------------------------------------------------------------------------- 1 | # _target_: 2 | _name_: base 3 | loss: binary_cross_entropy 4 | metrics: null 5 | torchmetrics: 6 | - MultilabelAUROC # AUROC 7 | - MultilabelAveragePrecision # Precision 8 | # - Recall # not supported in torchmetrics 9 | # - F1 # not supported in torchmetrics 10 | -------------------------------------------------------------------------------- /configs/scheduler/cosine_warmup_timm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_cosine_schedule_with_warmup 6 | _name_: cosine_warmup_timm 7 | t_in_epochs: False 8 | t_initial: 300 9 | lr_min: 1e-5 10 | warmup_lr_init: 1e-6 11 | warmup_t: 10 12 | -------------------------------------------------------------------------------- /configs/model/genomics_benchmark_cnn.yaml: -------------------------------------------------------------------------------- 1 | # Use open-source version of Mamba 2 | _name_: genomics_benchmark_cnn 3 | number_of_classes: ${dataset.d_output} 4 | vocab_size: 12 5 | embedding_dim: 100 # See: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments/torch_cnn_experiments 6 | input_len: ${dataset.__l_max} 7 | -------------------------------------------------------------------------------- /configs/model/layer/hyena.yaml: -------------------------------------------------------------------------------- 1 | _name_: hyena 2 | l_max: 1024 3 | order: 2 4 | filter_order: 64 5 | num_heads: 1 6 | inner_factor: 1 7 | num_blocks: 1 8 | fused_bias_fc: false 9 | outer_mixing: false 10 | dropout: 0.0 11 | filter_dropout: 0.0 12 | filter_cls: 'hyena-filter' 13 | post_order_ffn: false 14 | jit_filter: false 15 | short_filter_order: 3 16 | activation: "id" -------------------------------------------------------------------------------- /configs/dataset/hg38.yaml: -------------------------------------------------------------------------------- 1 | _name_: hg38 2 | bed_file: null 3 | fasta_file: null 4 | dataset_name: hg38 5 | tokenizer_name: null 6 | cache_dir: null 7 | max_length: 1024 8 | add_eos: True 9 | batch_size: 8 # per GPU 10 | batch_size_eval: ${eval:${.batch_size} * 2} 11 | num_workers: 4 # For preprocessing only 12 | shuffle: True 13 | __train_len: 34021 14 | __l_max: ${.max_length} 15 | -------------------------------------------------------------------------------- /configs/callbacks/base.yaml: -------------------------------------------------------------------------------- 1 | learning_rate_monitor: 2 | # _target_: pytorch_lightning.callbacks.LearningRateMonitor 3 | 
logging_interval: ${train.interval} 4 | 5 | timer: 6 | # _target_: callbacks.timer.Timer 7 | step: True 8 | inter_step: False 9 | epoch: True 10 | val: True 11 | 12 | params: 13 | # _target_: callbacks.params.ParamsLog 14 | total: True 15 | trainable: True 16 | fixed: True 17 | -------------------------------------------------------------------------------- /configs/scheduler/plateau.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: epoch 4 | monitor: ??? # must be specified 5 | scheduler: 6 | # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 7 | _name_: plateau 8 | mode: ${train.mode} # Which metric to monitor 9 | factor: 0.2 # Decay factor when ReduceLROnPlateau is used 10 | patience: 20 11 | min_lr: 0.0 # Minimum learning rate during annealing 12 | -------------------------------------------------------------------------------- /configs/trainer/debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | gpus: 1 5 | min_epochs: 1 6 | max_epochs: 10 7 | 8 | # prints 9 | progress_bar_refresh_rate: null 10 | weights_summary: full 11 | profiler: null 12 | 13 | # debugs 14 | fast_dev_run: False 15 | num_sanity_val_steps: 2 16 | overfit_batches: 0 17 | limit_train_batches: 0.1 18 | limit_val_batches: 0.1 19 | limit_test_batches: 0.1 20 | track_grad_norm: -1 21 | terminate_on_nan: False 22 | -------------------------------------------------------------------------------- /configs/pipeline/akita_benchmark.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: akita_benchmark 6 | - /optimizer: adamw 7 | - /scheduler: plateau 8 | - /callbacks: [base, checkpoint] 9 | - /model: crab 10 | 11 | train: 12 | monitor: val/loss # Needed for plateau scheduler 13 | mode: min 14 | 15 | encoder: id 16 | 17 | # we need this for classification! 18 | decoder: 19 | _name_: contact_map -------------------------------------------------------------------------------- /configs/pipeline/eqtl_benchmark.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: eqtl_benchmark 6 | - /optimizer: adamw 7 | - /scheduler: plateau 8 | - /callbacks: [base, checkpoint, lr] 9 | - /model: crab 10 | 11 | train: 12 | monitor: val/loss_epoch # Needed for plateau scheduler 13 | mode: min 14 | 15 | encoder: id 16 | 17 | # we need this for classification! 
18 | decoder: 19 | _name_: sequence 20 | mode: pool 21 | -------------------------------------------------------------------------------- /evals/auroc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | from sklearn.metrics import roc_auc_score 4 | from sklearn.metrics import average_precision_score 5 | 6 | preds, targets = [], [] 7 | path = "TO THE LOG FILE PATH" 8 | lines = open(path).readlines() 9 | for idx, line in enumerate(lines): 10 | items = line.strip().split() 11 | preds.append(float(items[0])) 12 | targets.append(float(items[1])) 13 | 14 | print(roc_auc_score(targets, preds)) 15 | print(average_precision_score(targets, preds)) -------------------------------------------------------------------------------- /configs/pipeline/genomic_benchmark.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: genomic_benchmark 6 | - /task: multiclass_classification 7 | - /optimizer: adamw 8 | - /scheduler: plateau 9 | - /callbacks: [base, checkpoint] 10 | 11 | train: 12 | monitor: val/accuracy # Needed for plateau scheduler 13 | mode: max 14 | 15 | encoder: id 16 | 17 | # we need this for classification! 18 | decoder: 19 | _name_: sequence 20 | mode: pool 21 | -------------------------------------------------------------------------------- /configs/dataset/akita_benchmark.yaml: -------------------------------------------------------------------------------- 1 | _name_: akita_benchmark 2 | dataset_name: akita_benchmark 3 | dest_path: "/qihao/data/data/benchmark/dnalong" 4 | max_length: 1048576 5 | d_output: ${.${.dataset_name}.classes} 6 | use_padding: True 7 | padding_side: 'left' 8 | add_eos: False 9 | batch_size: 1 10 | train_len: ${.${.dataset_name}.train_len} 11 | __l_max: ${.max_length} 12 | shuffle: true # set this as default! 13 | akita_benchmark: 14 | train_len: 450000 15 | classes: 1 16 | 17 | 18 | # it seems the train len and class have no meaning here. 19 | -------------------------------------------------------------------------------- /configs/pipeline/enhancer_target_gene.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default # no use 4 | - /loader: default # no use 5 | - /dataset: enhancer_target_gene 6 | - /task: multiclass_classification 7 | - /optimizer: adamw 8 | - /scheduler: plateau 9 | - /callbacks: [base, checkpoint] 10 | - /model: crab 11 | 12 | train: 13 | monitor: val/loss_epoch # Needed for plateau scheduler 14 | mode: min 15 | 16 | encoder: id 17 | 18 | # we need this for classification! 19 | decoder: 20 | _name_: sequence 21 | mode: pool 22 | -------------------------------------------------------------------------------- /configs/pipeline/nucleotide_transformer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: nucleotide_transformer 6 | - /task: multiclass_classification 7 | - /optimizer: adamw 8 | - /scheduler: plateau 9 | - /callbacks: [base, checkpoint] 10 | 11 | task: 12 | loss: 13 | _name_: cross_entropy 14 | metrics: 15 | - ${dataset.metric} 16 | 17 | train: 18 | monitor: val/${dataset.metric} 19 | mode: max 20 | 21 | encoder: id 22 | 23 | # we need this for classification! 
24 | decoder: 25 | _name_: sequence 26 | mode: pool -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | devices: 1 4 | accelerator: gpu 5 | accumulate_grad_batches: 1 # Gradient accumulation every n batches 6 | max_epochs: 200 7 | # accelerator: ddp # Automatically set if gpus > 1 8 | gradient_clip_val: 0.0 9 | log_every_n_steps: 10 10 | limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run 11 | limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run 12 | num_sanity_val_steps: 2 # default value: 2; override to 0 to skip sanity checking 13 | -------------------------------------------------------------------------------- /configs/pipeline/longrange_benchmark.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: long_range_benchmark # what is the name of automatic generated dict key? 6 | - /task: longrange_benchmark 7 | - /optimizer: adamw 8 | - /scheduler: plateau 9 | - /callbacks: [base, checkpoint] 10 | 11 | task: 12 | loss: 13 | _name_: cross_entropy 14 | metrics: 15 | - ${dataset.metric} 16 | 17 | train: 18 | monitor: val/${dataset.metric} 19 | mode: max 20 | 21 | encoder: id 22 | 23 | # we need this for classification! 24 | decoder: 25 | _name_: sequence_snp 26 | mode: pool -------------------------------------------------------------------------------- /configs/model/hyena.yaml: -------------------------------------------------------------------------------- 1 | _name_: hyena_lm 2 | d_model: 128 3 | n_layer: 2 4 | d_inner: ${eval:4 * ${.d_model}} 5 | vocab_size: 12 6 | resid_dropout: 0.0 7 | embed_dropout: 0.1 8 | fused_mlp: False 9 | fused_dropout_add_ln: False 10 | checkpoint_mixer: False # set true for memory reduction 11 | checkpoint_mlp: False # set true for memory reduction 12 | residual_in_fp32: True 13 | pad_vocab_size_multiple: 8 14 | layer: 15 | _name_: hyena 16 | emb_dim: 5 17 | filter_order: 64 18 | local_order: 3 19 | l_max: ${eval:${dataset.max_length}+2} 20 | modulate: True 21 | w: 10 22 | lr: ${optimizer.lr} 23 | wd: 0.0 24 | lr_pos_emb: 0.0 25 | -------------------------------------------------------------------------------- /configs/trainer/lm.yaml: -------------------------------------------------------------------------------- 1 | accumulate_grad_batches: 1 2 | # accelerator: null # set to 'ddp' for distributed 3 | # amp_backend: native # 'native' | 'apex' 4 | gpus: 8 5 | max_epochs: 50 6 | gradient_clip_val: 0.0 # Gradient clipping 7 | log_every_n_steps: 10 8 | precision: 16 9 | progress_bar_refresh_rate: 1 10 | weights_summary: top # Set to 'full' to see every layer 11 | track_grad_norm: -1 # Set to 2 to track norms of gradients 12 | limit_train_batches: 1.0 13 | limit_val_batches: 1.0 14 | # We use the dataloader from Transformer-XL to ensure adjacent minibatches 15 | # are from text that are next to each other. 16 | # So that dataloader has to deal with DDP, and we don't want PL to handle 17 | # that. 
18 | replace_sampler_ddp: False 19 | -------------------------------------------------------------------------------- /configs/pipeline/hg38.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: null 5 | - /dataset: hg38 6 | - /optimizer: adamw 7 | - /scheduler: cosine_warmup 8 | - /callbacks: [base, checkpoint] 9 | 10 | train: 11 | monitor: test/loss 12 | mode: min 13 | 14 | task: 15 | _name_: lm 16 | loss: 17 | _name_: cross_entropy 18 | ignore_index: 4 # Bake in tokenizer value for padding / EOS tokens 19 | torchmetrics: ['perplexity', 'num_tokens'] 20 | 21 | encoder: null 22 | decoder: null 23 | 24 | loader: 25 | num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"} 26 | pin_memory: True 27 | drop_last: True # There's enough data and epochs, ignore the edge case 28 | # shuffle: True 29 | -------------------------------------------------------------------------------- /configs/dataset/eqtl_benchmark.yaml: -------------------------------------------------------------------------------- 1 | _name_: eqtl_benchmark 2 | dataset_name: eqtl_benchmark 3 | dest_path: null 4 | max_length: 450000 5 | d_output: ${.${.dataset_name}.classes} 6 | use_padding: True 7 | padding_side: 'left' 8 | add_eos: False 9 | batch_size: 1 10 | train_len: ${.${.dataset_name}.train_len} 11 | __l_max: ${.max_length} 12 | shuffle: true # set this as default! 13 | cell_type: "Adipose_Subcutaneous" 14 | 15 | eqtl_benchmark: 16 | train_len: 450000 17 | classes: 2 18 | 19 | # cell_type: 20 | # ['Adipose_Subcutaneous' 21 | # 'Artery_Tibial' 22 | # 'Cells_Cultured_fibroblasts' 23 | # 'Muscle_Skeletal' 24 | # 'Nerve_Tibial' 25 | # 'Skin_Not_Sun_Exposed_Suprapubic' 26 | # 'Skin_Sun_Exposed_Lower_leg' 27 | # 'Thyroid' 28 | # 'Whole_Blood'] 29 | -------------------------------------------------------------------------------- /src/dataloaders/utils/rc.py: -------------------------------------------------------------------------------- 1 | """Utility functions for reverse complementing DNA sequences. 
2 | 3 | """ 4 | 5 | from random import random 6 | 7 | STRING_COMPLEMENT_MAP = { 8 | "A": "T", "C": "G", "G": "C", "T": "A", "a": "t", "c": "g", "g": "c", "t": "a", 9 | "N": "N", "n": "n", 10 | } 11 | 12 | def coin_flip(p=0.5): 13 | """Flip a (potentially weighted) coin.""" 14 | return random() > p 15 | 16 | 17 | def string_reverse_complement(seq): 18 | """Reverse complement a DNA sequence.""" 19 | rev_comp = "" 20 | for base in seq[::-1]: 21 | if base in STRING_COMPLEMENT_MAP: 22 | rev_comp += STRING_COMPLEMENT_MAP[base] 23 | # if bp not complement map, use the same bp 24 | else: 25 | rev_comp += base 26 | return rev_comp 27 | -------------------------------------------------------------------------------- /configs/callbacks/wandb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | watch_model: 5 | _target_: src.callbacks.wandb_callbacks.WatchModel 6 | log: "all" 7 | log_freq: 100 8 | 9 | upload_code_as_artifact: 10 | _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact 11 | code_dir: ${work_dir}/src 12 | 13 | upload_ckpts_as_artifact: 14 | _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact 15 | ckpt_dir: "checkpoints/" 16 | upload_best_only: True 17 | 18 | log_f1_precision_recall_heatmap: 19 | _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap 20 | 21 | log_confusion_matrix: 22 | _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix 23 | 24 | log_image_predictions: 25 | _target_: src.callbacks.wandb_callbacks.LogImagePredictions 26 | num_samples: 8 27 | -------------------------------------------------------------------------------- /configs/model/hyena_hg.yaml: -------------------------------------------------------------------------------- 1 | _name_: hyena_lm 2 | d_model: 256 3 | n_layer: 8 4 | d_inner: ${eval:4 * ${.d_model}} 5 | vocab_size: 12 6 | resid_dropout: 0.0 7 | embed_dropout: 0.1 8 | fused_mlp: False 9 | fused_dropout_add_ln: False 10 | checkpoint_mixer: False # set true for memory reduction 11 | checkpoint_mlp: False # set true for memory reduction 12 | residual_in_fp32: True 13 | pad_vocab_size_multiple: 8 14 | layer: 15 | _name_: hyena 16 | emb_dim: 5 17 | filter_order: 64 18 | hyena_dropout: 0.0 19 | filter_dropout: 0.0 20 | order: 2 21 | num_inner_mlps: 2 22 | short_filter_order: 3 23 | use_bias: True 24 | l_max: 450002 25 | 26 | local_order: 3 27 | # l_max: ${eval:${dataset.max_length}+2} 28 | modulate: True 29 | w: 10 30 | lr: ${optimizer.lr} 31 | wd: 0.0 32 | lr_pos_emb: 0.0 33 | -------------------------------------------------------------------------------- /configs/model/mamba.yaml: -------------------------------------------------------------------------------- 1 | # Use open-source version of Mamba 2 | _name_: mamba_lm 3 | config: 4 | _target_: mamba_ssm.models.config_mamba.MambaConfig 5 | d_model: 128 # Will be overwritten by CL in the scaling exps 6 | n_layer: 2 # Will be overwritten by CL in the scaling exps 7 | vocab_size: 12 8 | pad_vocab_size_multiple: 8 9 | rms_norm: true 10 | fused_add_norm: true 11 | residual_in_fp32: false 12 | ssm_cfg: 13 | d_state: 16 14 | d_conv: 4 15 | expand: 2 16 | dt_rank: "auto" 17 | dt_min: 0.001 18 | dt_max: 0.1 19 | dt_init: "random" 20 | dt_scale: 1.0 21 | dt_init_floor: 1e-4 22 | conv_bias: true 23 | bias: false 24 | use_fast_path: true 25 | initializer_cfg: 26 | initializer_range: 0.02 27 | rescale_prenorm_residual: true 28 | # n_residuals_per_layer: 1 29 | #norm_epsilon: 1e-5 # Default arg in mamba create_block 
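configs/model/mamba.yaml above targets mamba_ssm's MambaConfig. The sketch below shows roughly what that instantiation amounts to; it is an assumption about typical mamba_ssm usage rather than this project's exact loading path, whether MambaLMHeadModel accepts the config object positionally depends on the installed mamba_ssm version, and the package's CUDA kernels must be importable:

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Field names mirror configs/model/mamba.yaml (only a subset of ssm_cfg shown).
config = MambaConfig(
    d_model=128,
    n_layer=2,
    vocab_size=12,          # padded up to a multiple of 8 internally
    rms_norm=True,
    fused_add_norm=True,
    residual_in_fp32=False,
    pad_vocab_size_multiple=8,
    ssm_cfg=dict(d_state=16, d_conv=4, expand=2, dt_rank="auto"),
)
model = MambaLMHeadModel(config)  # typically run on GPU; the fast path needs CUDA
print(sum(p.numel() for p in model.parameters()), "parameters")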
30 | -------------------------------------------------------------------------------- /evals/evaluate_contact_map.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.stats 3 | 4 | 5 | # 'HFF': 0, 'H1hESC': 1, 'GM12878': 2, 'IMR90': 3, 'HCT116': 4 6 | lines1 = open("akita/HFF_pred.txt").readlines() 7 | lines2 = open("akita/HFF_tgt.txt").readlines() 8 | corrs = [] 9 | for line1, line2 in zip(lines1, lines2): 10 | preds, targets = [], [] 11 | items1 = line1.strip().split() 12 | items1 = np.array([float(item) for item in items1]).reshape([200, 200]) 13 | items2 = line2.strip().split() 14 | items2 = np.array([float(item) for item in items2]).reshape([200, 200]) 15 | for i in range(200): 16 | for j in range(i+2, 200): 17 | 18 | preds.append(items1[i][j]) 19 | targets.append(items2[i][j]) 20 | cor = scipy.stats.spearmanr(preds, targets)[0] 21 | 22 | corrs.append(cor) 23 | print(cor) 24 | 25 | print(np.average(corrs)) 26 | 27 | -------------------------------------------------------------------------------- /configs/model/caduceus.yaml: -------------------------------------------------------------------------------- 1 | # Use open-source version of Mamba 2 | _name_: caduceus_lm 3 | config: 4 | _target_: caduceus.configuration_caduceus.CaduceusConfig 5 | # From original MambaConfig 6 | d_model: 128 7 | n_layer: 2 8 | vocab_size: 12 9 | ssm_cfg: 10 | d_state: 16 11 | d_conv: 4 12 | expand: 2 13 | dt_rank: "auto" 14 | dt_min: 0.001 15 | dt_max: 0.1 16 | dt_init: "random" 17 | dt_scale: 1.0 18 | dt_init_floor: 1e-4 19 | conv_bias: true 20 | bias: false 21 | use_fast_path: true 22 | rms_norm: true 23 | fused_add_norm: true 24 | residual_in_fp32: false 25 | pad_vocab_size_multiple: 8 26 | # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm 27 | norm_epsilon: 1e-5 28 | 29 | # Used in init_weights 30 | initializer_cfg: 31 | initializer_range: 0.02 32 | rescale_prenorm_residual: true 33 | n_residuals_per_layer: 1 34 | 35 | # Caduceus-specific params 36 | bidirectional: true, 37 | bidirectional_strategy: "add" 38 | bidirectional_weight_tie: true 39 | rcps: false 40 | 41 | # Used for RCPSEmbedding / RCPSLMHead (will be filled in during model instantiation using info from tokenizer) 42 | complement_map: null 43 | -------------------------------------------------------------------------------- /configs/model/caduceus_ph_131k_hg.yaml: -------------------------------------------------------------------------------- 1 | # Use open-source version of Mamba 2 | _name_: caduceus_lm 3 | config: 4 | _target_: caduceus.configuration_caduceus.CaduceusConfig 5 | # From original MambaConfig 6 | d_model: 256 7 | n_layer: 16 8 | vocab_size: 12 9 | ssm_cfg: 10 | d_state: 16 11 | d_conv: 4 12 | expand: 2 13 | dt_rank: "auto" 14 | dt_min: 0.001 15 | dt_max: 0.1 16 | dt_init: "random" 17 | dt_scale: 1.0 18 | dt_init_floor: 1e-4 19 | conv_bias: true 20 | bias: false 21 | use_fast_path: true 22 | rms_norm: true 23 | fused_add_norm: true 24 | residual_in_fp32: false 25 | pad_vocab_size_multiple: 8 26 | # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm 27 | norm_epsilon: 1e-5 28 | 29 | # Used in init_weights 30 | initializer_cfg: 31 | initializer_range: 0.02 32 | rescale_prenorm_residual: true 33 | n_residuals_per_layer: 1 34 | 35 | # Caduceus-specific params 36 | bidirectional: true, 37 | bidirectional_strategy: "add" 38 | bidirectional_weight_tie: true 39 | 
rcps: false 40 | 41 | # Used for RCPSEmbedding / RCPSLMHead (will be filled in during model instantiation using info from tokenizer) 42 | complement_map: null 43 | -------------------------------------------------------------------------------- /configs/trainer/full.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # default values for all trainer parameters 4 | checkpoint_callback: True 5 | default_root_dir: null 6 | gradient_clip_val: 0.0 7 | process_position: 0 8 | num_nodes: 1 9 | num_processes: 1 10 | gpus: null 11 | auto_select_gpus: False 12 | tpu_cores: null 13 | log_gpu_memory: null 14 | overfit_batches: 0.0 15 | track_grad_norm: -1 16 | check_val_every_n_epoch: 1 17 | fast_dev_run: False 18 | accumulate_grad_batches: 1 19 | max_epochs: 1 20 | min_epochs: 1 21 | max_steps: null 22 | min_steps: null 23 | limit_train_batches: 1.0 24 | limit_val_batches: 1.0 25 | limit_test_batches: 1.0 26 | val_check_interval: 1.0 27 | flush_logs_every_n_steps: 100 28 | log_every_n_steps: 50 29 | accelerator: null 30 | sync_batchnorm: False 31 | precision: 32 32 | weights_summary: "top" 33 | weights_save_path: null 34 | num_sanity_val_steps: 2 35 | truncated_bptt_steps: null 36 | resume_from_checkpoint: null 37 | profiler: null 38 | benchmark: False 39 | deterministic: False 40 | reload_dataloaders_every_epoch: False 41 | auto_lr_find: False 42 | replace_sampler_ddp: True 43 | terminate_on_nan: False 44 | auto_scale_batch_size: False 45 | prepare_data_per_node: True 46 | plugins: null 47 | amp_backend: "native" 48 | amp_level: "O2" 49 | move_metrics_to_cpu: False 50 | -------------------------------------------------------------------------------- /configs/callbacks/checkpoint.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | monitor: ${train.monitor} # name of the logged metric which determines when model is improving 3 | mode: ${train.mode} # can be "max" or "min" 4 | save_top_k: 1 # save k best models (determined by above metric) 5 | save_last: False # True = additionally always save model from last epoch 6 | dirpath: "checkpoints/" 7 | filename: ${train.monitor} 8 | auto_insert_metric_name: False 9 | verbose: True 10 | 11 | model_checkpoint_every_n_steps: 12 | monitor: train/loss # name of the logged metric which determines when model is improving 13 | mode: min # can be "max" or "min" 14 | save_top_k: 0 # Do not save any "best" models; this callback is being used to save every n train steps 15 | save_last: True # additionally always save model from last epoch 16 | dirpath: "checkpoints/" 17 | filename: train/loss 18 | auto_insert_metric_name: False 19 | verbose: True 20 | every_n_train_steps: 100 21 | 22 | #model_checkpoint_every_epoch: 23 | # monitor: trainer/epoch # name of the logged metric which determines when model is improving 24 | # mode: max # can be "max" or "min" 25 | # save_top_k: 1 # Do not save any "best" models; this callback is being used to save every n train steps 26 | # save_last: False # additionally always save model from last epoch 27 | # dirpath: "checkpoints/" 28 | # filename: null 29 | # auto_insert_metric_name: False 30 | # verbose: True 31 | # every_n_epochs: 1 32 | -------------------------------------------------------------------------------- /configs/model/caduceus_nt.yaml: -------------------------------------------------------------------------------- 1 | # Use open-source version of Mamba 2 | _name_: caduceus_lm 3 | 
config: 4 | _target_: caduceus.configuration_caduceus.CaduceusConfig 5 | # From original MambaConfig 6 | d_model: 256 7 | n_layer: 4 8 | vocab_size: 16 9 | ssm_cfg: 10 | d_state: 16 11 | d_conv: 4 12 | expand: 2 13 | dt_rank: "auto" 14 | dt_min: 0.001 15 | dt_max: 0.1 16 | dt_init: "random" 17 | dt_scale: 1.0 18 | dt_init_floor: 1e-4 19 | conv_bias: true 20 | bias: false 21 | use_fast_path: true 22 | rms_norm: true 23 | fused_add_norm: true 24 | residual_in_fp32: false 25 | pad_vocab_size_multiple: 8 26 | # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm 27 | norm_epsilon: 1e-5 28 | 29 | # Used in init_weights 30 | initializer_cfg: 31 | initializer_range: 0.02 32 | rescale_prenorm_residual: true 33 | n_residuals_per_layer: 1 34 | 35 | # Caduceus-specific params 36 | bidirectional: true, 37 | bidirectional_strategy: "add" 38 | bidirectional_weight_tie: true 39 | rcps: True 40 | 41 | # Used for RCPSEmbedding / RCPSLMHead (will be filled in during model instantiation using info from tokenizer) 42 | complement_map: 43 | "0": 0, 44 | "1": 1, 45 | "2": 2, 46 | "3": 3, 47 | "4": 4, 48 | "5": 5, 49 | "6": 6, 50 | "7": 10, 51 | "8": 9, 52 | "9": 8, 53 | "10": 7, 54 | "11": 11, 55 | "12": 12, 56 | "13": 13, 57 | "14": 14, 58 | "15": 15 59 | 60 | -------------------------------------------------------------------------------- /src/callbacks/validation.py: -------------------------------------------------------------------------------- 1 | """Check validation every n **global** steps. 2 | 3 | Pytorch Lightning has a `val_check_interval` parameter that checks validation every n batches, but does not support 4 | checking every n **global** steps. 5 | """ 6 | 7 | from typing import Any 8 | 9 | from pytorch_lightning.callbacks import Callback 10 | from pytorch_lightning.trainer.states import RunningStage 11 | 12 | 13 | class ValEveryNGlobalSteps(Callback): 14 | """Check validation every n **global** steps.""" 15 | def __init__(self, every_n): 16 | self.every_n = every_n 17 | self.last_run = None 18 | 19 | def on_train_batch_end(self, trainer, *_: Any): 20 | """Check if we should run validation. 21 | 22 | Adapted from: https://github.com/Lightning-AI/pytorch-lightning/issues/2534#issuecomment-1085986529 23 | """ 24 | # Prevent Running validation many times in gradient accumulation 25 | if trainer.global_step == self.last_run: 26 | return 27 | else: 28 | self.last_run = None 29 | if trainer.global_step % self.every_n == 0 and trainer.global_step != 0: 30 | trainer.training = False 31 | stage = trainer.state.stage 32 | trainer.state.stage = RunningStage.VALIDATING 33 | trainer._run_evaluate() 34 | trainer.state.stage = stage 35 | trainer.training = True 36 | trainer._logger_connector._epoch_end_reached = False 37 | self.last_run = trainer.global_step 38 | -------------------------------------------------------------------------------- /src/callbacks/params.py: -------------------------------------------------------------------------------- 1 | """Callback to log the number of parameters of the model. 
2 | 3 | """ 4 | 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.utilities import rank_zero_only 7 | from pytorch_lightning.utilities.parsing import AttributeDict 8 | 9 | 10 | class ParamsLog(pl.Callback): 11 | """ Log the number of parameters of the model """ 12 | def __init__( 13 | self, 14 | total: bool = True, 15 | trainable: bool = True, 16 | fixed: bool = True, 17 | ): 18 | super().__init__() 19 | self._log_stats = AttributeDict( 20 | { 21 | 'total_params_log': total, 22 | 'trainable_params_log': trainable, 23 | 'non_trainable_params_log': fixed, 24 | } 25 | ) 26 | 27 | @rank_zero_only 28 | def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 29 | logs = {} 30 | if self._log_stats.total_params_log: 31 | logs["params/total"] = sum(p.numel() for p in pl_module.parameters()) 32 | if self._log_stats.trainable_params_log: 33 | logs["params/trainable"] = sum(p.numel() for p in pl_module.parameters() 34 | if p.requires_grad) 35 | if self._log_stats.non_trainable_params_log: 36 | logs["params/fixed"] = sum(p.numel() for p in pl_module.parameters() 37 | if not p.requires_grad) 38 | if trainer.logger: 39 | trainer.logger.log_hyperparams(logs) 40 | -------------------------------------------------------------------------------- /configs/model/caduceus_ffn_moe_attn.yaml: -------------------------------------------------------------------------------- 1 | # Use open-source version of Mamba 2 | _name_: caduceus_ffn_moe_attn_lm 3 | config: 4 | _target_: crab.configuration_caduceus_attention_moe.CaduceusConfig 5 | # From original MambaConfig 6 | d_model: 256 7 | n_layer: 8 8 | vocab_size: 12 9 | ssm_cfg: 10 | d_state: 16 11 | d_conv: 4 12 | expand: 2 13 | dt_rank: "auto" 14 | dt_min: 0.001 15 | dt_max: 0.1 16 | dt_init: "random" 17 | dt_scale: 1.0 18 | dt_init_floor: 1e-4 19 | conv_bias: true 20 | bias: false 21 | use_fast_path: true 22 | rms_norm: true 23 | fused_add_norm: true 24 | residual_in_fp32: false 25 | pad_vocab_size_multiple: 8 26 | # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm 27 | norm_epsilon: 1e-5 28 | 29 | # Used in init_weights 30 | initializer_cfg: 31 | initializer_range: 0.02 32 | rescale_prenorm_residual: true 33 | n_residuals_per_layer: 1 34 | 35 | # Caduceus-specific params 36 | bidirectional: true, 37 | bidirectional_strategy: "add" 38 | bidirectional_weight_tie: true 39 | rcps: false 40 | 41 | # Used for RCPSEmbedding / RCPSLMHead (will be filled in during model instantiation using info from tokenizer) 42 | complement_map: null 43 | 44 | # added specifically for caduceus_ffn_moe_attn 45 | output_hidden_states: false 46 | return_dict: true 47 | 48 | # expert 49 | expert_layer_period: 4 50 | expert_layer_offset: 2 51 | num_experts: 4 52 | intermediate_factor: 4 53 | # attn 54 | num_attention_heads: 4 55 | is_causal: False 56 | attention_dropout: 0.0 57 | 58 | attn_layer_period: 4 59 | attn_layer_offset: 2 60 | 61 | 62 | -------------------------------------------------------------------------------- /configs/experiment/hg38/nucleotide_transformer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: nucleotide_transformer 4 | - /model: ??? 
5 | - override /scheduler: cosine_warmup_timm 6 | 7 | model: 8 | _name_: dna_embedding 9 | 10 | trainer: 11 | accelerator: gpu 12 | devices: 1 13 | num_nodes: 1 14 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 15 | max_epochs: 100 16 | precision: 16 # bf16 only a100 17 | gradient_clip_val: 1.0 18 | 19 | dataset: 20 | tokenizer_name: char 21 | rc_aug: false # reverse complement augmentation 22 | 23 | scheduler: 24 | t_in_epochs: False 25 | t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 26 | warmup_lr_init: 1e-6 27 | warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 28 | lr_min: ${eval:0.1 * ${optimizer.lr}} 29 | 30 | optimizer: 31 | lr: 1e-3 32 | weight_decay: 0.1 33 | 34 | train: 35 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 36 | seed: 2222 37 | global_batch_size: ${dataset.batch_size} 38 | cross_validation: true 39 | remove_test_loader_in_eval: true # test only at the end of training 40 | pretrained_model_strict_load: false # false allows encoder/decoder to be used if new model uses it 41 | # for loading backbone and not head, requires both of these flags below 42 | pretrained_model_path: ??? 43 | pretrained_model_state_hook: 44 | _name_: load_backbone 45 | freeze_backbone: false 46 | -------------------------------------------------------------------------------- /configs/experiment/hg38/hg38.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: hg38 4 | - /model: ??? # Specify a model, e.g. 
model=mamba or model=hyena 5 | - override /scheduler: cosine_warmup_timm 6 | 7 | task: 8 | _name_: lm 9 | loss: 10 | _name_: cross_entropy 11 | ignore_index: 4 12 | 13 | trainer: 14 | accelerator: gpu 15 | devices: 1 16 | num_nodes: 1 17 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 18 | max_epochs: null 19 | max_steps: 10000 20 | precision: 16 # bf16 only a100 21 | gradient_clip_val: 1.0 22 | limit_val_batches: 0.125 23 | 24 | dataset: 25 | batch_size: ${eval:1024//${trainer.devices}} 26 | max_length: 1024 27 | # optional, default is max_length 28 | max_length_val: ${dataset.max_length} 29 | max_length_test: ${dataset.max_length} 30 | tokenizer_name: char 31 | pad_max_length: null # needed for bpe tokenizer 32 | add_eos: true 33 | rc_aug: false 34 | num_workers: 12 35 | use_fixed_len_val: false # placing a fixed length val here, but it's really the test 36 | mlm: false 37 | mlm_probability: 0.0 38 | 39 | scheduler: 40 | t_in_epochs: False 41 | t_initial: ${eval:${trainer.max_steps}-${.warmup_t}} 42 | warmup_prefix: True 43 | warmup_lr_init: 1e-6 44 | warmup_t: ${eval:0.1*${trainer.max_steps}} 45 | lr_min: 1e-4 46 | 47 | optimizer: 48 | lr: 6e-4 49 | weight_decay: 0.1 50 | betas: [0.9, 0.95] 51 | 52 | train: 53 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 54 | seed: 2222 55 | global_batch_size: 256 # effects the scheduler, need to set properly 56 | -------------------------------------------------------------------------------- /src/dataloaders/utils/mlm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mlm_getitem(seq, mlm_probability=0.15, contains_eos=False, tokenizer=None, eligible_replacements=None): 5 | """Helper method for creating MLM input / target. 6 | 7 | Adapted from: 8 | https://github.com/huggingface/transformers/blob/14666775a296a76c88e1aa686a9547f393d322e2/src/transformers/data/data_collator.py#L751 9 | """ 10 | data = seq[:-1].clone() if contains_eos else seq.clone() # remove eos, if applicable 11 | target = data.clone() 12 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) 13 | probability_matrix = torch.full(target.shape, mlm_probability) 14 | # TODO: Do we need to avoid "masking" special tokens as is done here? 
15 | # https://github.com/huggingface/transformers/blob/14666775a296a76c88e1aa686a9547f393d322e2/src/transformers/data/data_collator.py#L760-L766 16 | masked_indices = torch.bernoulli(probability_matrix).bool() 17 | target[~masked_indices] = tokenizer.pad_token_id # We only compute loss on masked tokens 18 | 19 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 20 | indices_replaced = torch.bernoulli(torch.full(target.shape, 0.8)).bool() & masked_indices 21 | data[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) 22 | 23 | # 10% of the time, we replace masked input tokens with random word 24 | indices_random = torch.bernoulli(torch.full(target.shape, 0.5)).bool() & masked_indices & ~indices_replaced 25 | if eligible_replacements is not None: 26 | rand_choice = torch.randint(eligible_replacements.shape[0], size=target.shape) 27 | random_words = eligible_replacements[rand_choice] 28 | else: 29 | random_words = torch.randint(len(tokenizer), size=target.shape, dtype=torch.long) 30 | data[indices_random] = random_words[indices_random] 31 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 32 | return data, target 33 | -------------------------------------------------------------------------------- /src/models/baseline/genomics_benchmark_cnn.py: -------------------------------------------------------------------------------- 1 | """Genomics Benchmark CNN model. 2 | 3 | Adapted from https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/blob/main/src/genomic_benchmarks/models/torch.py 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class GenomicsBenchmarkCNN(nn.Module): 11 | def __init__(self, number_of_classes, vocab_size, input_len, embedding_dim=100): 12 | """Genomics Benchmark CNN model. 
13 | 14 | `embedding_dim` = 100 comes from: 15 | https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments/torch_cnn_experiments 16 | """ 17 | super(GenomicsBenchmarkCNN, self).__init__() 18 | 19 | self.embeddings = nn.Embedding(vocab_size, embedding_dim) 20 | self.cnn_model = nn.Sequential( 21 | nn.Conv1d(in_channels=embedding_dim, out_channels=16, kernel_size=8, bias=True), 22 | nn.BatchNorm1d(16), 23 | nn.ReLU(), 24 | nn.MaxPool1d(2), 25 | 26 | nn.Conv1d(in_channels=16, out_channels=8, kernel_size=8, bias=True), 27 | nn.BatchNorm1d(8), 28 | nn.MaxPool1d(2), 29 | 30 | nn.Conv1d(in_channels=8, out_channels=4, kernel_size=8, bias=True), 31 | nn.BatchNorm1d(4), 32 | nn.MaxPool1d(2), 33 | 34 | nn.Flatten() 35 | ) 36 | self.dense_model = nn.Sequential( 37 | nn.Linear(self.count_flatten_size(input_len), 512), 38 | # To be consistent with SSM classifier decoders, we use num_classes (even when it's binary) 39 | nn.Linear(512, number_of_classes) 40 | ) 41 | 42 | def count_flatten_size(self, input_len): 43 | zeros = torch.zeros([1, input_len], dtype=torch.long) 44 | x = self.embeddings(zeros) 45 | x = x.transpose(1, 2) 46 | x = self.cnn_model(x) 47 | return x.size()[1] 48 | 49 | def forward(self, x, state=None): # Adding `state` to be consistent with other models 50 | x = self.embeddings(x) 51 | x = x.transpose(1, 2) 52 | x = self.cnn_model(x) 53 | x = self.dense_model(x) 54 | return x, state # Returning tuple to be consistent with other models 55 | -------------------------------------------------------------------------------- /configs/dataset/genomic_benchmark.yaml: -------------------------------------------------------------------------------- 1 | _name_: genomic_benchmark 2 | train_val_split_seed: ${train.seed} # Used for train/validation splitting 3 | dataset_name: dummy_mouse_enhancers_ensembl 4 | dest_path: null 5 | max_length: ${.${.dataset_name}.max_length} 6 | max_length_val: ${.max_length} 7 | max_length_test: ${.max_length} 8 | d_output: ${.${.dataset_name}.classes} 9 | use_padding: True 10 | padding_side: 'left' 11 | add_eos: False 12 | batch_size: 128 13 | train_len: ${.${.dataset_name}.train_len} 14 | __l_max: ${.max_length} 15 | shuffle: true # set this as default! 
16 | # these are used to find the right attributes automatically for each dataset 17 | dummy_mouse_enhancers_ensembl: 18 | train_len: 1210 19 | classes: 2 20 | max_length: 1024 21 | demo_coding_vs_intergenomic_seqs: 22 | train_len: 100_000 23 | classes: 2 24 | max_length: 200 25 | demo_human_or_worm: 26 | train_len: 100_000 27 | classes: 2 28 | max_length: 200 29 | human_enhancers_cohn: 30 | train_len: 27791 31 | classes: 2 32 | max_length: 500 33 | human_enhancers_ensembl: 34 | train_len: 154842 35 | classes: 2 36 | max_length: 512 37 | human_ensembl_regulatory: 38 | train_len: 289061 39 | classes: 3 40 | max_length: 512 41 | human_nontata_promoters: 42 | train_len: 36131 43 | classes: 2 44 | max_length: 251 45 | human_ocr_ensembl: 46 | train_len: 174756 47 | classes: 2 48 | max_length: 512 49 | 50 | # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings 51 | # name num_seqs num_classes median len std 52 | # dummy_mouse_enhancers_ensembl 1210 2 2381 984.4 53 | # demo_coding_vs_intergenomic_seqs 100_000 2 200 0 54 | # demo_human_or_worm 100_000 2 200 0 55 | # human_enhancers_cohn 27791 2 500 0 56 | # human_enhancers_ensembl 154842 2 269 122.6 57 | # human_ensembl_regulatory 289061 3 401 184.3 58 | # human_nontata_promoters 36131 2 251 0 59 | # human_ocr_ensembl 174756 2 315 108.1 60 | -------------------------------------------------------------------------------- /caduceus/configuration_caduceus.py: -------------------------------------------------------------------------------- 1 | """Caduceus config for Hugging Face. 2 | 3 | """ 4 | 5 | from typing import Optional, Union 6 | 7 | from transformers import PretrainedConfig 8 | 9 | 10 | class CaduceusConfig(PretrainedConfig): 11 | """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance.""" 12 | model_type = "caduceus" 13 | 14 | def __init__( 15 | self, 16 | # From original MambaConfig 17 | d_model: int = 2560, 18 | n_layer: int = 64, 19 | vocab_size: int = 50277, 20 | ssm_cfg: Optional[dict] = None, 21 | rms_norm: bool = True, 22 | residual_in_fp32: bool = True, 23 | fused_add_norm: bool = True, 24 | pad_vocab_size_multiple: int = 8, 25 | 26 | # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm 27 | norm_epsilon: float = 1e-5, 28 | 29 | # Used in init_weights 30 | initializer_cfg: Optional[dict] = None, 31 | 32 | # Caduceus-specific params 33 | bidirectional: bool = True, 34 | bidirectional_strategy: Union[str, None] = "add", 35 | bidirectional_weight_tie: bool = True, 36 | rcps: bool = False, 37 | complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead 38 | gradient_checkpointing: bool = False, 39 | **kwargs, 40 | ): 41 | super().__init__(**kwargs) 42 | self.d_model = d_model 43 | self.n_layer = n_layer 44 | self.vocab_size = vocab_size 45 | self.ssm_cfg = ssm_cfg 46 | self.rms_norm = rms_norm 47 | self.residual_in_fp32 = residual_in_fp32 48 | self.fused_add_norm = fused_add_norm 49 | self.pad_vocab_size_multiple = pad_vocab_size_multiple 50 | self.norm_epsilon = norm_epsilon 51 | self.initializer_cfg = initializer_cfg 52 | self.bidirectional = bidirectional 53 | self.bidirectional_strategy = bidirectional_strategy 54 | self.bidirectional_weight_tie = bidirectional_weight_tie 55 | self.rcps = rcps 56 | self.complement_map = complement_map 57 | 58 | self.gradient_checkpointing = gradient_checkpointing 59 | 
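A minimal sketch instantiating the CaduceusConfig defined above with the values used in configs/model/caduceus.yaml. Building the full model would additionally require caduceus/modeling_caduceus.py (which caduceus/__init__.py imports) and its mamba_ssm dependencies, so only the config object is constructed here:

from caduceus.configuration_caduceus import CaduceusConfig

config = CaduceusConfig(
    d_model=128,
    n_layer=2,
    vocab_size=12,
    ssm_cfg=dict(d_state=16, d_conv=4, expand=2),
    rms_norm=True,
    fused_add_norm=True,
    residual_in_fp32=False,
    pad_vocab_size_multiple=8,
    norm_epsilon=1e-5,
    bidirectional=True,
    bidirectional_strategy="add",
    bidirectional_weight_tie=True,
    rcps=False,  # set True (with a complement_map) for the RC-equivariant variant
)
print(config.model_type, config.d_model, config.bidirectional_strategy)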
-------------------------------------------------------------------------------- /evals/evaluate_auroc_janus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | from sklearn.metrics import roc_auc_score 5 | from sklearn.metrics import average_precision_score 6 | 7 | # automatically go through all the output for all cell types, and 8 | 9 | def fprint(*args, **kwargs): 10 | """Print to file and stdout""" 11 | print(*args, **kwargs) 12 | conclude_output_file.write(" ".join(map(str, args)) + "\n") 13 | conclude_output_file.flush() 14 | 15 | 16 | CELL_TYPES=[ 17 | "Adipose_Subcutaneous", 18 | "Artery_Tibial", 19 | "Cells_Cultured_fibroblasts", 20 | "Muscle_Skeletal", 21 | "Nerve_Tibial", 22 | "Skin_Not_Sun_Exposed_Suprapubic", 23 | "Skin_Sun_Exposed_Lower_leg", 24 | "Thyroid", 25 | "Whole_Blood", 26 | ] 27 | 28 | WATCH_FOLDER_BASE_PATH = "TO THE LOG FOLDER PATH" 29 | MODEL_NAME = "janusdna_len-131k_d_model-128_inter_dim-512_n_layer-8_lr-8e-3_step-50K_moeloss-true_1head_onlymoe" 30 | END_SUFFIX = "lr-5e-5_ftepoch-3_cjtrain_false_cjtest_true_batch_2_withpretrainedweight_bf16mix_output.log" 31 | 32 | 33 | CONCLUDE_OUTPUT_FILE_DIR_PATH = os.path.join(WATCH_FOLDER_BASE_PATH, MODEL_NAME) 34 | conclude_output_file_path = os.path.join(CONCLUDE_OUTPUT_FILE_DIR_PATH, END_SUFFIX) 35 | print("conclude_output_file_path", conclude_output_file_path) 36 | 37 | conclude_output_file = open(conclude_output_file_path, "w", encoding="utf-8") 38 | 39 | # print("CONCLUDE_FILE_PATH", os.path.join(CONCLUDE_OUTPUT_FILE_DIR_PATH, END_SUFFIX)) 40 | fprint("original_watch_folder: ", os.path.join(WATCH_FOLDER_BASE_PATH, MODEL_NAME)) 41 | fprint("MODEL_NAME: ", MODEL_NAME) 42 | fprint("end_suffix: ", END_SUFFIX) 43 | 44 | fprint("\n") 45 | fprint("metric_value", "cell_type", "average_precision_score", "metric_type") 46 | for cell_type in CELL_TYPES: 47 | file_path = os.path.join(WATCH_FOLDER_BASE_PATH, MODEL_NAME, cell_type, cell_type + "_" + END_SUFFIX) 48 | if not os.path.exists(file_path): 49 | fprint(f"File not found: {file_path}") 50 | continue 51 | preds, targets = [], [] 52 | lines = open(file_path).readlines() 53 | for idx, line in enumerate(lines): 54 | items = line.strip().split() 55 | preds.append(float(items[0])) 56 | targets.append(float(items[1])) 57 | 58 | # fprint() 59 | 60 | fprint(roc_auc_score(targets, preds), cell_type, ": ", average_precision_score(targets, preds), "auroc") 61 | 62 | -------------------------------------------------------------------------------- /configs/experiment/hg38/genomic_benchmark_cnn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /model: genomics_benchmark_cnn 4 | - /pipeline: genomic_benchmark 5 | - override /scheduler: cosine_warmup_timm 6 | 7 | # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings 8 | # name num_seqs num_classes median len std 9 | # dummy_mouse_enhancers_ensembl 1210 2 2381 984.4 10 | # demo_coding_vs_intergenomic_seqs 100_000 2 200 0 11 | # demo_human_or_worm 100_000 2 200 0 12 | # human_enhancers_cohn 27791 2 500 0 13 | # human_enhancers_ensembl 154842 2 269 122.6 14 | # human_ensembl_regulatory 289061 3 401 184.3 15 | # human_nontata_promoters 36131 2 251 0 16 | # human_ocr_ensembl 174756 2 315 108.1 17 | 18 | task: 19 | loss: 20 | _name_: cross_entropy 21 | 22 | trainer: 23 | accelerator: gpu 24 | devices: 1 25 | num_nodes: 1 26 | 
accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 27 | max_epochs: 100 28 | precision: 16 # bf16 only a100 29 | gradient_clip_val: 1.0 30 | 31 | encoder: id 32 | decoder: id 33 | 34 | dataset: 35 | tokenizer_name: char 36 | rc_aug: false # reverse complement augmentation 37 | 38 | scheduler: 39 | t_in_epochs: False 40 | t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 41 | warmup_lr_init: 1e-6 42 | warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 43 | lr_min: ${eval:0.1 * ${optimizer.lr}} 44 | 45 | 46 | optimizer: 47 | lr: 6e-4 48 | weight_decay: 0.1 49 | 50 | train: 51 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 52 | seed: 2222 53 | global_batch_size: ${dataset.batch_size} 54 | cross_validation: true 55 | remove_test_loader_in_eval: true 56 | pretrained_model_strict_load: false # false allows encoder/decoder to be used if new model uses it 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data.tar.gz 2 | *.tsf 3 | *.ckpt 4 | .ipynb_checkpoints 5 | */.ipynb_checkpoints/* 6 | *.lprof 7 | 8 | .DS_Store 9 | .idea/ 10 | outputs/ 11 | multirun/ 12 | 13 | # slurm log files 14 | watch_folder/ 15 | runable_scripts/ 16 | 17 | data/ 18 | 19 | # Created by https://www.gitignore.io/api/python 20 | # Edit at https://www.gitignore.io/?templates=python 21 | 22 | ### Python ### 23 | # Byte-compiled / optimized / DLL files 24 | __pycache__/ 25 | *.py[cod] 26 | *$py.class 27 | 28 | # C extensions 29 | *.so 30 | 31 | # Distribution / packaging 32 | .Python 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | pip-wheel-metadata/ 46 | share/python-wheels/ 47 | *.egg-info/ 48 | .installed.cfg 49 | *.egg 50 | MANIFEST 51 | 52 | # PyInstaller 53 | # Usually these files are written by a python script from a template 54 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 55 | *.manifest 56 | *.spec 57 | 58 | # Installer logs 59 | pip-log.txt 60 | pip-delete-this-directory.txt 61 | 62 | # Unit test / coverage reports 63 | htmlcov/ 64 | .tox/ 65 | .nox/ 66 | .coverage 67 | .coverage.* 68 | .cache 69 | nosetests.xml 70 | coverage.xml 71 | *.cover 72 | .hypothesis/ 73 | .pytest_cache/ 74 | 75 | # Translations 76 | *.mo 77 | *.pot 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 
96 | #Pipfile.lock 97 | 98 | # celery beat schedule file 99 | celerybeat-schedule 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Spyder project settings 105 | .spyderproject 106 | .spyproject 107 | 108 | # Rope project settings 109 | .ropeproject 110 | 111 | # Mr Developer 112 | .mr.developer.cfg 113 | .project 114 | .pydevproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | # End of https://www.gitignore.io/api/python 128 | -------------------------------------------------------------------------------- /evals/evaluate_auroc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | from sklearn.metrics import roc_auc_score 5 | from sklearn.metrics import average_precision_score 6 | 7 | # automatically go through all the output for all cell types, and 8 | 9 | def fprint(*args, **kwargs): 10 | """Print to file and stdout""" 11 | print(*args, **kwargs) 12 | conclude_output_file.write(" ".join(map(str, args)) + "\n") 13 | conclude_output_file.flush() 14 | 15 | 16 | CELL_TYPES=[ 17 | "Adipose_Subcutaneous", 18 | "Artery_Tibial", 19 | "Cells_Cultured_fibroblasts", 20 | "Muscle_Skeletal", 21 | "Nerve_Tibial", 22 | "Skin_Not_Sun_Exposed_Suprapubic", 23 | "Skin_Sun_Exposed_Lower_leg", 24 | "Thyroid", 25 | "Whole_Blood", 26 | ] 27 | # Replace the following with the actual path to your watch folder 28 | WATCH_FOLDER_BASE_PATH = "" # e.g., ".../janusdna/watch_folder/DNALong/eQTL/caduceus-ph_seqlen-131k_d_model-256_n_layer-16" 29 | MODEL_NAME = "janusdna_len-131k_d_model-128_inter_dim-512_n_layer-8_lr-8e-3_step-50K_moeloss-true_1head" 30 | END_SUFFIX = "lr-1e-3_ftepoch-1_cjtrain_false_cjtest_true_batch_2_withpretrainedweight_bf16mix_output.log" 31 | 32 | 33 | 34 | CONCLUDE_OUTPUT_FILE_DIR_PATH = os.path.join(WATCH_FOLDER_BASE_PATH, MODEL_NAME) 35 | conclude_output_file_path = os.path.join(CONCLUDE_OUTPUT_FILE_DIR_PATH, END_SUFFIX) 36 | print("conclude_output_file_path", conclude_output_file_path) 37 | 38 | conclude_output_file = open(conclude_output_file_path, "w", encoding="utf-8") 39 | 40 | # print("CONCLUDE_FILE_PATH", os.path.join(CONCLUDE_OUTPUT_FILE_DIR_PATH, END_SUFFIX)) 41 | fprint("original_watch_folder: ", os.path.join(WATCH_FOLDER_BASE_PATH, MODEL_NAME)) 42 | fprint("MODEL_NAME: ", MODEL_NAME) 43 | fprint("end_suffix: ", END_SUFFIX) 44 | 45 | fprint("\n") 46 | fprint("metric_value", "cell_type", "average_precision_score", "metric_type") 47 | for cell_type in CELL_TYPES: 48 | file_path = os.path.join(WATCH_FOLDER_BASE_PATH, MODEL_NAME, cell_type, cell_type + "_" + END_SUFFIX) 49 | if not os.path.exists(file_path): 50 | fprint(f"File not found: {file_path}") 51 | continue 52 | preds, targets = [], [] 53 | lines = open(file_path).readlines() 54 | for idx, line in enumerate(lines): 55 | items = line.strip().split() 56 | preds.append(float(items[0])) 57 | targets.append(float(items[1])) 58 | 59 | # fprint() 60 | 61 | fprint(roc_auc_score(targets, preds), cell_type, ": ", average_precision_score(targets, preds), "auroc") 62 | 63 | -------------------------------------------------------------------------------- /configs/experiment/hg38/genomic_benchmark.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: genomic_benchmark 4 | - /model: ??? 
5 | - override /scheduler: cosine_warmup_timm 6 | 7 | # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings 8 | # name num_seqs num_classes median len std 9 | # dummy_mouse_enhancers_ensembl 1210 2 2381 984.4 10 | # demo_coding_vs_intergenomic_seqs 100_000 2 200 0 11 | # demo_human_or_worm 100_000 2 200 0 12 | # human_enhancers_cohn 27791 2 500 0 13 | # human_enhancers_ensembl 154842 2 269 122.6 14 | # human_ensembl_regulatory 289061 3 401 184.3 15 | # human_nontata_promoters 36131 2 251 0 16 | # human_ocr_ensembl 174756 2 315 108.1 17 | 18 | task: 19 | loss: 20 | _name_: cross_entropy 21 | 22 | trainer: 23 | accelerator: gpu 24 | devices: 1 25 | num_nodes: 1 26 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 27 | max_epochs: 100 28 | precision: 16 # bf16 only a100 29 | gradient_clip_val: 1.0 30 | 31 | model: 32 | _name_: dna_embedding 33 | 34 | dataset: 35 | # optional, default is max_length 36 | tokenizer_name: char 37 | rc_aug: false # reverse complement augmentation 38 | 39 | scheduler: 40 | # COSINE TIMM 41 | t_in_epochs: False 42 | t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 43 | warmup_lr_init: 1e-6 44 | warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 45 | lr_min: ${eval:0.1 * ${optimizer.lr}} 46 | 47 | 48 | optimizer: 49 | lr: 6e-4 50 | weight_decay: 0.1 51 | 52 | train: 53 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 54 | seed: 2222 55 | global_batch_size: ${dataset.batch_size} 56 | cross_validation: true 57 | remove_test_loader_in_eval: true # test only at the end of training 58 | pretrained_model_strict_load: false # false allows encoder/decoder to be used if new model uses it 59 | # for loading backbone and not head, requires both of these flags below 60 | pretrained_model_path: ??? 61 | pretrained_model_state_hook: 62 | _name_: load_backbone 63 | freeze_backbone: false 64 | -------------------------------------------------------------------------------- /src/tasks/encoders.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | import src.models.nn.utils as U 4 | import src.utils as utils 5 | 6 | 7 | class Encoder(nn.Module): 8 | """Encoder abstraction 9 | 10 | Accepts a tensor and optional kwargs. Other than the main tensor, all other arguments should be kwargs. 11 | Returns a tensor and optional kwargs. 12 | Encoders are combined via U.PassthroughSequential which passes these kwargs through in a pipeline. The resulting 13 | kwargs are accumulated and passed into the model backbone. 14 | """ 15 | 16 | def forward(self, x, **kwargs): 17 | """ 18 | x: input tensor 19 | *args: additional info from the dataset (e.g. 
sequence lengths) 20 | 21 | Returns: 22 | y: output tensor 23 | *args: other arguments to pass into the model backbone 24 | """ 25 | return x, {} 26 | 27 | 28 | # For every type of encoder/decoder, specify: 29 | # - constructor class 30 | # - list of attributes to grab from dataset 31 | # - list of attributes to grab from model 32 | 33 | registry = { 34 | "stop": Encoder, 35 | "id": nn.Identity, 36 | "embedding": nn.Embedding, 37 | "linear": nn.Linear, 38 | } 39 | 40 | dataset_attrs = { 41 | "embedding": ["n_tokens"], 42 | "linear": ["d_input"], # TODO make this d_data? 43 | "class": ["n_classes"], 44 | "time": ["n_tokens_time"], 45 | "onehot": ["n_tokens"], 46 | "conv1d": ["d_input"], 47 | "patch2d": ["d_input"], 48 | } 49 | 50 | model_attrs = { 51 | "embedding": ["d_model"], 52 | "linear": ["d_model"], 53 | "position": ["d_model"], 54 | "class": ["d_model"], 55 | "time": ["d_model"], 56 | "onehot": ["d_model"], 57 | "conv1d": ["d_model"], 58 | "patch2d": ["d_model"], 59 | "timestamp_embedding": ["d_model"], 60 | "layer": ["d_model"], 61 | } 62 | 63 | 64 | def _instantiate(encoder, dataset=None, model=None): 65 | """Instantiate a single encoder""" 66 | if encoder is None: 67 | return None 68 | if isinstance(encoder, str): 69 | name = encoder 70 | else: 71 | name = encoder["_name_"] 72 | 73 | # Extract dataset/model arguments from attribute names 74 | dataset_args = utils.config.extract_attrs_from_obj( 75 | dataset, *dataset_attrs.get(name, []) 76 | ) 77 | model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, [])) 78 | 79 | # Instantiate encoder 80 | obj = utils.instantiate(registry, encoder, *dataset_args, *model_args) 81 | return obj 82 | 83 | 84 | def instantiate(encoder, dataset=None, model=None): 85 | encoder = utils.to_list(encoder) 86 | return U.PassthroughSequential( 87 | *[_instantiate(e, dataset=dataset, model=model) for e in encoder] 88 | ) 89 | -------------------------------------------------------------------------------- /configs/model/janusdna.yaml: -------------------------------------------------------------------------------- 1 | # Use open-source version of Mamba 2 | _name_: janusdna_lm 3 | config: 4 | _target_: janusdna.configuration_janusdna.JanusDNAConfig 5 | # From original MambaConfig 6 | hidden_size: 256 7 | flex_attn_n_embd: 256 8 | num_hidden_layers: 15 9 | vocab_size: 12 10 | ssm_cfg: 11 | d_state: 16 12 | d_conv: 4 13 | expand: 2 14 | dt_rank: "auto" 15 | dt_min: 0.001 16 | dt_max: 0.1 17 | dt_init: "random" 18 | dt_scale: 1.0 19 | dt_init_floor: 1e-4 20 | conv_bias: true 21 | bias: false 22 | use_fast_path: true 23 | # rms_norm: true 24 | # fused_add_norm: true 25 | # residual_in_fp32: false 26 | # pad_vocab_size_multiple: 8 27 | # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm 28 | rms_norm_eps: 1e-6 29 | 30 | # Used in init_weights 31 | # initializer_cfg: 32 | initializer_range: 0.02 33 | # rescale_prenorm_residual: true 34 | # n_residuals_per_layer: 1 35 | 36 | # Caduceus-specific params 37 | bidirectional: true, 38 | bidirectional_strategy: "add" 39 | bidirectional_weight_tie: true 40 | # rcps: false 41 | 42 | # Used for RCPSEmbedding / RCPSLMHead (will be filled in during model instantiation using info from tokenizer) 43 | # complement_map: null 44 | 45 | # added specifically for caduceus_ffn_moe_attn 46 | # output_hidden_states: false 47 | # return_dict: true 48 | 49 | # # expert 50 | # expert_layer_period: 4 51 | # expert_layer_offset: 2 52 | # num_experts: 4 53 | # 
intermediate_factor: 4 54 | # # attn 55 | # num_attention_heads: 4 56 | # is_causal: False 57 | # attention_dropout: 0.0 58 | 59 | # attn_layer_period: 4 60 | # attn_layer_offset: 100 61 | return_dict: true 62 | 63 | 64 | # moe 65 | num_experts: 16 66 | num_experts_per_tok: 2 67 | expert_layer_period: 4 # 100 means no moe 68 | expert_layer_offset: 2 # layer_name % expert_layer_period == expert_layer_offset, would be a MOE layer 69 | output_hidden_states: False 70 | intermediate_factor: 4 71 | 72 | # bidirectional: True 73 | # # bidirectional_strategy: "add" # todo: # should be "add", "concat", "ew_multiply" and "final layer transformer" 74 | # bidirectional_weight_tie: False 75 | # key params for autoregressive training diagram 76 | layer_fusion: False # if layer_fusion, bi-directional output of each layer would be fused. If not, will just concat and fuse at last decode layer. 77 | 78 | # attn 79 | num_attention_heads: 4 80 | attn_implementation: "flash_attention_2" ## acutally is flex attention 81 | attn_layer_period: 4 # every 8 layers is an jamba block 82 | attn_layer_offset: 2 # every 4 layers, there is an attention layer 83 | 84 | # final 85 | final_attention: true 86 | mid_single_direction_attention: true 87 | layer_fusion_strategy: "pool" # "pool" or "None" 88 | final_attention_class: "flex_attention" 89 | bidirectional_attn_tie: false 90 | 91 | gradient_checkpointing: true -------------------------------------------------------------------------------- /src/models/nn/activation.py: -------------------------------------------------------------------------------- 1 | """Utilities for activation functions.""" 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def Activation(activation=None, size=None, dim=-1): 10 | """Returns a PyTorch activation module.""" 11 | if activation in [None, 'id', 'identity', 'linear', 'none']: 12 | return nn.Identity() 13 | elif activation == 'tanh': 14 | return nn.Tanh() 15 | elif activation == 'relu': 16 | return nn.ReLU() 17 | elif activation == 'gelu': 18 | return nn.GELU() 19 | elif activation == 'elu': 20 | return nn.ELU() 21 | elif activation in ['swish', 'silu']: 22 | return nn.SiLU() 23 | elif activation == 'glu': 24 | return nn.GLU(dim=dim) 25 | elif activation.startswith('glu-'): 26 | return GLU(dim=dim, activation=activation[4:]) 27 | elif activation == 'sigmoid': 28 | return nn.Sigmoid() 29 | elif activation == 'softplus': 30 | return nn.Softplus() 31 | elif activation == 'modrelu': 32 | return ModReLU(size) 33 | elif activation in ['sqrelu', 'relu2']: 34 | return SquaredReLU() 35 | elif activation == 'laplace': 36 | return Laplace() 37 | # Earlier experimentation with a LN in the middle of the block instead of activation 38 | # IIRC ConvNext does something like this? 
39 | # elif activation == 'ln': 40 | # return TransposedLN(dim) 41 | else: 42 | raise NotImplementedError("hidden activation '{}' is not implemented".format(activation)) 43 | 44 | 45 | class GLU(nn.Module): 46 | def __init__(self, dim=-1, activation='sigmoid'): 47 | super().__init__() 48 | assert not activation.startswith('glu') 49 | self.dim = dim 50 | self.activation_fn = Activation(activation) 51 | 52 | def forward(self, x): 53 | x, g = torch.split(x, x.size(self.dim) // 2, dim=self.dim) 54 | return x * self.activation_fn(g) 55 | 56 | 57 | class ModReLU(nn.Module): 58 | # Adapted from https://github.com/Lezcano/expRNN 59 | 60 | def __init__(self, features): 61 | # For now we just support square layers 62 | super().__init__() 63 | self.features = features 64 | self.b = nn.Parameter(torch.Tensor(self.features)) 65 | self.reset_parameters() 66 | 67 | def reset_parameters(self): 68 | self.b.data.uniform_(-0.01, 0.01) 69 | 70 | def forward(self, inputs): 71 | norm = torch.abs(inputs) 72 | biased_norm = norm + self.b 73 | magnitude = F.relu(biased_norm) 74 | phase = torch.sign(inputs) 75 | 76 | return phase * magnitude 77 | 78 | 79 | class SquaredReLU(nn.Module): 80 | def forward(self, x): 81 | # return F.relu(x)**2 82 | return torch.square(F.relu(x)) # Could this be faster? 83 | 84 | 85 | def laplace(x, mu=0.707107, sigma=0.282095): 86 | x = (x - mu).div(sigma * math.sqrt(2.0)) 87 | return 0.5 * (1.0 + torch.erf(x)) 88 | 89 | 90 | class Laplace(nn.Module): 91 | def __init__(self, mu=0.707107, sigma=0.282095): 92 | super().__init__() 93 | self.mu = mu 94 | self.sigma = sigma 95 | 96 | def forward(self, x): 97 | return laplace(x, mu=self.mu, sigma=self.sigma) 98 | -------------------------------------------------------------------------------- /src/utils/registry.py: -------------------------------------------------------------------------------- 1 | """Class registry for models, layers, optimizers, and schedulers. 
2 | 3 | """ 4 | 5 | optimizer = { 6 | "adam": "torch.optim.Adam", 7 | "adamw": "torch.optim.AdamW", 8 | "rmsprop": "torch.optim.RMSprop", 9 | "sgd": "torch.optim.SGD", 10 | "lamb": "src.utils.optim.lamb.JITLamb", 11 | } 12 | 13 | scheduler = { 14 | "constant": "transformers.get_constant_schedule", 15 | "plateau": "torch.optim.lr_scheduler.ReduceLROnPlateau", 16 | "step": "torch.optim.lr_scheduler.StepLR", 17 | "multistep": "torch.optim.lr_scheduler.MultiStepLR", 18 | "cosine": "torch.optim.lr_scheduler.CosineAnnealingLR", 19 | "constant_warmup": "transformers.get_constant_schedule_with_warmup", 20 | "linear_warmup": "transformers.get_linear_schedule_with_warmup", 21 | "cosine_warmup": "transformers.get_cosine_schedule_with_warmup", 22 | "cosine_warmup_timm": "src.utils.optim.schedulers.TimmCosineLRScheduler", 23 | } 24 | 25 | model = { 26 | # Pre-training LM head models 27 | "hyena_lm": "src.models.sequence.long_conv_lm.ConvLMHeadModel", 28 | "mamba_lm": "mamba_ssm.models.mixer_seq_simple.MambaLMHeadModel", 29 | "caduceus_lm": "caduceus.modeling_caduceus.CaduceusForMaskedLM", 30 | "janusdna_lm": "janusdna.modeling_janusdna.JanusDNAForCausalLM", 31 | 32 | # Downstream task embedding backbones 33 | "dna_embedding": "src.models.sequence.dna_embedding.DNAEmbeddingModel", 34 | "dna_embedding_mamba": "src.models.sequence.dna_embedding.DNAEmbeddingModelMamba", 35 | "dna_embedding_caduceus": "src.models.sequence.dna_embedding.DNAEmbeddingModelCaduceus", 36 | "dna_embedding_hf_caduceus": "src.models.sequence.dna_embedding.DNAEmbeddingModelHFCaduceus", 37 | "dna_embedding_janusdna": "src.models.sequence.dna_embedding.DNAEmbeddingModelJanusDNA", 38 | 39 | # Baseline for genomics benchmark 40 | "genomics_benchmark_cnn": "src.models.baseline.genomics_benchmark_cnn.GenomicsBenchmarkCNN", 41 | } 42 | 43 | layer = { 44 | "id": "src.models.sequence.base.SequenceIdentity", 45 | "ff": "src.models.sequence.ff.FF", 46 | "hyena": "src.models.sequence.hyena.HyenaOperator", 47 | "hyena-filter": "src.models.sequence.hyena.HyenaFilter", 48 | } 49 | 50 | callbacks = { 51 | "learning_rate_monitor": "pytorch_lightning.callbacks.LearningRateMonitor", 52 | "model_checkpoint": "pytorch_lightning.callbacks.ModelCheckpoint", 53 | "model_checkpoint_every_n_steps": "pytorch_lightning.callbacks.ModelCheckpoint", 54 | "model_checkpoint_every_epoch": "pytorch_lightning.callbacks.ModelCheckpoint", 55 | "early_stopping": "pytorch_lightning.callbacks.EarlyStopping", 56 | "swa": "pytorch_lightning.callbacks.StochasticWeightAveraging", 57 | "rich_model_summary": "pytorch_lightning.callbacks.RichModelSummary", 58 | "rich_progress_bar": "pytorch_lightning.callbacks.RichProgressBar", 59 | "params": "src.callbacks.params.ParamsLog", 60 | "timer": "src.callbacks.timer.Timer", 61 | "val_every_n_global_steps": "src.callbacks.validation.ValEveryNGlobalSteps", 62 | "lr_monitor": "pytorch_lightning.callbacks.LearningRateMonitor", 63 | } 64 | 65 | model_state_hook = { 66 | 'load_backbone': 'src.models.sequence.dna_embedding.load_backbone', 67 | 'load_backbone_caduceus_huggingface': 'src.models.sequence.dna_embedding.load_backbone_caduceus_huggingface', 68 | } 69 | -------------------------------------------------------------------------------- /configs/dataset/nucleotide_transformer.yaml: -------------------------------------------------------------------------------- 1 | _name_: nucleotide_transformer # this links to the overall SequenceDataset of all nucleotide transformer datasets 2 | train_val_split_seed: ${train.seed} # Used for 
train/validation splitting 3 | dataset_name: enhancers # this specifies which dataset in nuc trx 4 | dest_path: null # path to overall nuc trx datasets 5 | max_length: ${.${.dataset_name}.max_length} 6 | d_output: ${.${.dataset_name}.classes} 7 | use_padding: True 8 | padding_side: left 9 | add_eos: False 10 | batch_size: 256 11 | train_len: ${.${.dataset_name}.train_len} 12 | __l_max: ${.max_length} 13 | shuffle: true # set this as default! 14 | metric: ${.${.dataset_name}.metric} 15 | # these are used to find the right attributes automatically for each dataset 16 | enhancers: 17 | train_len: 14968 18 | classes: 2 19 | max_length: 200 20 | metric: mcc 21 | enhancers_types: 22 | train_len: 14968 23 | classes: 3 24 | max_length: 200 25 | metric: mcc 26 | H3: 27 | train_len: 13468 28 | classes: 2 29 | max_length: 500 30 | metric: mcc 31 | H3K4me1: 32 | train_len: 28509 33 | classes: 2 34 | max_length: 500 35 | metric: mcc 36 | H3K4me2: 37 | train_len: 27614 38 | classes: 2 39 | max_length: 500 40 | metric: mcc 41 | H3K4me3: 42 | train_len: 33119 43 | classes: 2 44 | max_length: 500 45 | metric: mcc 46 | H3K9ac: 47 | train_len: 25003 48 | classes: 2 49 | max_length: 500 50 | metric: mcc 51 | H3K14ac: 52 | train_len: 29743 53 | classes: 2 54 | max_length: 500 55 | metric: mcc 56 | H3K36me3: 57 | train_len: 31392 58 | classes: 2 59 | max_length: 500 60 | metric: mcc 61 | H3K79me3: 62 | train_len: 25953 63 | classes: 2 64 | max_length: 500 65 | metric: mcc 66 | H4: 67 | train_len: 13140 68 | classes: 2 69 | max_length: 500 70 | metric: mcc 71 | H4ac: 72 | train_len: 30685 73 | classes: 2 74 | max_length: 500 75 | metric: mcc 76 | promoter_all: 77 | train_len: 53276 78 | classes: 2 79 | max_length: 300 80 | metric: f1_binary 81 | promoter_no_tata: 82 | train_len: 47767 83 | classes: 2 84 | max_length: 300 85 | metric: f1_binary 86 | promoter_tata: 87 | train_len: 5517 88 | classes: 2 89 | max_length: 300 90 | metric: f1_binary 91 | splice_sites_acceptors: 92 | train_len: 19961 93 | classes: 2 94 | max_length: 600 95 | metric: f1_binary 96 | splice_sites_all: 97 | train_len: 27000 98 | classes: 3 99 | max_length: 400 100 | metric: accuracy 101 | splice_sites_donors: 102 | train_len: 19775 103 | classes: 2 104 | max_length: 600 105 | metric: f1_binary 106 | 107 | # name maxlen classes samples metric 108 | 109 | # enhancers 200 2 14968 MCC 110 | # enhancers_types 200 3 14968 MCC 111 | # H3 500 2 13468 MCC 112 | # H3K4me1 500 2 28509 MCC 113 | # H3K4me2 500 2 27614 MCC 114 | # H3K4me3 500 2 33119 MCC 115 | # H3K9ac 500 2 25003 MCC 116 | # H3K14ac 500 2 29743 MCC 117 | # H3K36me3 500 2 31392 MCC 118 | # H3K79me3 500 2 25953 MCC 119 | # H4 500 2 13140 MCC 120 | # H4ac 500 2 30685 MCC 121 | # promoter_all 300 2 53276 F1 122 | # promoter_no_tata 300 2 47759 F1 123 | # promoter_tata 300 2 5517 F1 124 | # splice_sites_acceptor 600 2 19961 F1 125 | # splice_sites_all 400 2 27000 F1 126 | # splice_sites_donor 600 2 19775 F1 127 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - experiment: ??? 5 | # - model: ??? # Model backbone 6 | # - pipeline: ??? # Specifies collection of configs, equivalent to next 5 lines 7 | # Pipelines should specify /loader, /dataset, /task, /encoder, /decoder (ideally in that order) 8 | # # - loader: default # Dataloader (e.g. 
handles batches) 9 | # # - dataset: cifar # Defines the data (x and y pairs) 10 | # # - task: multiclass_classification # Defines loss and metrics 11 | # # - encoder: null # Interface between data and model 12 | # # - decoder: null # Interface between model and targets 13 | 14 | # Additional arguments used to configure the training loop 15 | # Most of these set combinations of options in the PL trainer, add callbacks, or add features to the optimizer 16 | train: 17 | seed: 0 18 | # These three options are used by callbacks (checkpoint, monitor) and scheduler 19 | # Most of them are task dependent and are set by the pipeline 20 | interval: ??? # Should be specified by scheduler. Also used by LR monitor 21 | monitor: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer 22 | mode: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer 23 | ema: 0.0 # Moving average model for validation 24 | test: True # Test after training 25 | debug: False # Special settings to make debugging more convenient 26 | ignore_warnings: False # Disable python warnings 27 | 28 | optimizer_param_grouping: 29 | bias_weight_decay: False 30 | normalization_weight_decay: False 31 | 32 | # These control state passing between batches 33 | state: 34 | mode: null # [ None | 'none' | 'reset' | 'bptt' | 'tbptt' ] 35 | n_context: 0 # How many steps to use as memory context. Must be >= 0 or None (null), meaning infinite context 36 | n_context_eval: ${.n_context} # Context at evaluation time 37 | # Convenience keys to allow grouping runs 38 | 39 | ckpt: checkpoints/last.ckpt # Resume training 40 | 41 | disable_dataset: False # Disable dataset loading 42 | validate_at_start: false 43 | 44 | pretrained_model_path: null # Path to pretrained model 45 | pretrained_model_strict_load: true # Whether to load the pretrained model even if the model is not compatible 46 | pretrained_model_state_hook: # Hook called on the loaded model's state_dict 47 | _name_: null 48 | post_init_hook: # After initializing model, call method on model 49 | _name_: null 50 | 51 | layer_decay: # Used for ImageNet finetuning 52 | _name_: null 53 | decay: 0.7 54 | 55 | # We primarily use wandb so this is moved to top level in the config for convenience 56 | # Set `~wandb` or `wandb=null` or `wandb.mode=disabled` to disable logging 57 | # If other loggers are added, it would make sense to put this one level lower under train/ or logger/ 58 | wandb: 59 | project: dna 60 | group: "" 61 | job_type: training 62 | mode: online # choices=['online', 'offline', 'disabled'] 63 | name: null 64 | save_dir: "." 65 | id: ${.name} # pass correct id to resume experiment! 
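  # For example (illustrative commands; `experiment=hg38/hg38` is the pre-training experiment used by the scripts below):
  #   python -m train experiment=hg38/hg38 wandb.mode=disabled   # turn W&B logging off entirely
  #   python -m train experiment=hg38/hg38 wandb.name=my_run     # the run name also becomes the resume id via `id: ${.name}`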
66 | # Below options should not need to be specified 67 | # entity: "" # set to name of your wandb team or just remove it 68 | # log_model: False 69 | # prefix: "" 70 | # job_type: "train" 71 | # tags: [] 72 | 73 | hydra: 74 | run: 75 | dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S-%f} 76 | job: 77 | chdir: true 78 | -------------------------------------------------------------------------------- /scripts/benchmark/dnalong/eqtl_evaluation_janus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=eqtl_eval 3 | #SBATCH --partition=gpu 4 | #SBATCH --nodes=1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=16 8 | #SBATCH --nodelist=s-sc-gpu[002-028] 9 | #SBATCH --mem=128G 10 | #SBATCH --requeue 11 | #SBATCH --time=2-00:00:00 12 | #SBATCH --mail-type=ALL 13 | #SBATCH --mail-user= 14 | #SBATCH --output=/my_job_%j_%t.out 15 | #SBATCH --error=/my_job_%j_%t.err 16 | 17 | # env params 18 | export LD_LIBRARY_PATH=/janusdna/lib/python3.11/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH 19 | export HYDRA_FULL_ERROR=1 20 | 21 | source /etc/profile.d/conda.sh 22 | 23 | conda activate janusdna 24 | 25 | # work dir 26 | cd 27 | 28 | ls train.py || echo "train.py not found!" 29 | 30 | # mission params 31 | CELL_TYPE="Whole_Blood" 32 | LR="4e-4" # same as the pre-trained to locate the pre-trained dir 33 | BATCH_SIZE=8 34 | SEED=1 35 | NUM_GPUS=1 36 | FINETUNED_EPOCH=3 37 | OUTPUT_ROUTER_LOGITS="true" 38 | 39 | # model params 40 | MODEL="janusdna" 41 | MODEL_NAME="dna_embedding_janusdna" 42 | RC_AUG="false" 43 | CONJOIN_TRAIN_DECODER="false" 44 | CONJOIN_TEST="true" 45 | FREEZE_BACKBONE="false" 46 | 47 | MODEL_PRETRAINED_DIRNAME="janusdna_len-131k_d_model-144_inter_dim-576_n_layer-8_lr-8e-3_step-50K_moeloss-true_1head_onlymoe" 48 | 49 | 50 | WANDB_NAME="${CELL_TYPE}_lr-${LR}_cjtrain_${CONJOIN_TRAIN_DECODER}_batch_${BATCH_SIZE}_seed_${SEED}" 51 | ROOT_DIR="" 52 | HYDRA_RUN_DIR="${ROOT_DIR}/outputs/downstream/longrange_benchmark/eqtl/${MODEL_PRETRAINED_DIRNAME}/${CELL_TYPE}/${WANDB_NAME}_cjtest_${CONJOIN_TEST}" 53 | LOG_BASE_DIR="${ROOT_DIR}/watch_folder/DNALong/eQTL" 54 | LOG_DIR="${LOG_BASE_DIR}/${MODEL_PRETRAINED_DIRNAME}/${CELL_TYPE}" 55 | 56 | LOG_FILE="${LOG_DIR}/${WANDB_NAME}_cjtest_${CONJOIN_TEST}.log" 57 | EVAL_OUTPUT_FILE="${LOG_DIR}/${WANDB_NAME}_cjtest_${CONJOIN_TEST}_output.txt" 58 | 59 | mkdir -p "${HYDRA_RUN_DIR}" 60 | mkdir -p "${LOG_DIR}" 61 | 62 | FINETUNE_BASE_DIR="${ROOT_DIR}/outputs/downstream/longrange_benchmark/eqtl/${MODEL_PRETRAINED_DIRNAME}/${CELL_TYPE}/${WANDB_NAME}" 63 | CONFIG_PATH="${FINETUNE_BASE_DIR}/model_config.json" 64 | PRETRAINED_WEIGHT_PATH="${FINETUNE_BASE_DIR}/checkpoints/last.ckpt" 65 | 66 | srun python -m evaluation wandb=null experiment=hg38/eqtl \ 67 | dataset.batch_size=1 \ 68 | dataset.cell_type=${CELL_TYPE} \ 69 | dataset.dest_path="${ROOT_DIR}/data" \ 70 | +dataset.conjoin_test="${CONJOIN_TEST}" \ 71 | model="${MODEL}" \ 72 | model._name_="${MODEL_NAME}" \ 73 | +model.config_path=${CONFIG_PATH} \ 74 | +model.conjoin_test="${CONJOIN_TEST}" \ 75 | +model.config.output_router_logits="${OUTPUT_ROUTER_LOGITS}" \ 76 | +model.config.router_aux_loss_coef=0.02 \ 77 | decoder.mode="pool" \ 78 | train.pretrained_model_path=${PRETRAINED_WEIGHT_PATH} \ 79 | train.pretrained_model_strict_load=True \ 80 | +train.eval_log_path="${EVAL_OUTPUT_FILE}" \ 81 | train.pretrained_model_state_hook._name_=null \ 82 | train.test=True \ 83 | +train.remove_val_loader_in_eval=True \ 
84 | train.remove_test_loader_in_eval=False \ 85 | trainer.precision=32 \ 86 | +decoder.conjoin_test="${CONJOIN_TEST}" \ 87 | hydra.run.dir="${HYDRA_RUN_DIR}" \ 88 | > ${LOG_FILE} 2>&1 89 | -------------------------------------------------------------------------------- /configs/experiment/hg38/eqtl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: eqtl_benchmark 4 | - override /scheduler: cosine_warmup_timm 5 | 6 | # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings 7 | # name num_seqs num_classes median len std 8 | # dummy_mouse_enhancers_ensembl 1210 2 2381 984.4 9 | # demo_coding_vs_intergenomic_seqs 100_000 2 200 0 10 | # demo_human_or_worm 100_000 2 200 0 11 | # human_enhancers_cohn 27791 2 500 0 12 | # human_enhancers_ensembl 154842 2 269 122.6 13 | # human_ensembl_regulatory 289061 3 401 184.3 14 | # human_nontata_promoters 36131 2 251 0 15 | # human_ocr_ensembl 174756 2 315 108.1 16 | 17 | 18 | model: 19 | _name_: dna_embedding 20 | 21 | 22 | # new task, allows you to pass a mask (or not), and will only average over those tokens 23 | task: 24 | _name_: eqtl 25 | loss: cross_entropy 26 | metrics: 27 | - cross_entropy 28 | torchmetrics: null 29 | 30 | trainer: 31 | accelerator: gpu 32 | devices: 1 33 | num_nodes: 1 34 | accumulate_grad_batches: 1 # ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 35 | max_epochs: 10 36 | precision: bf16 # bf16 only a100 37 | gradient_clip_val: 1.0 38 | # strategy: null 39 | 40 | 41 | dataset: 42 | # batch_size: 32 # Per GPU 43 | batch_size: 1 44 | # max_length: 256 # 262144, 524288 45 | # optional, default is max_length 46 | # max_length_val: ${dataset.max_length} 47 | # max_length_test: ${dataset.max_length} 48 | tokenizer_name: char 49 | add_eos: false 50 | rc_aug: false # reverse complement augmentation 51 | return_mask: false 52 | padding_side: left # right is ok too, depending on what you want to do 53 | 54 | # scheduler: 55 | # t_in_epochs: False 56 | # t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 57 | # warmup_lr_init: 1e-6 58 | # warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 59 | # lr_min: ${eval:0.1 * ${optimizer.lr}} 60 | scheduler: 61 | # _name_: cosine_annealing # or use the epoch mode of "cosine_warmup_timm" 62 | t_in_epochs: true # key: enable epoch mode 63 | t_initial: ${trainer.max_epochs} # total number of epochs 64 | lr_min: ${eval:0.1 * ${optimizer.lr}} # minimum learning rate 65 | warmup_lr_init: 1e-7 # initial warmup learning rate 66 | warmup_t: 0.3 # warmup takes 30% of the epochs (i.e., 0.9 epochs) 67 | 68 | 69 | optimizer: 70 | lr: 5e-6 71 | weight_decay: 0.1 72 | 73 | train: 74 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 75 | seed: 2222 76 | global_batch_size: ${dataset.batch_size} 77 | remove_test_loader_in_eval: true # no test set in this benchmark 78 | pretrained_model_strict_load: False # false allows encoder/decoder to be used if new model uses it 79 | # for loading backbone and not head, requires both of these flags below 80 | pretrained_model_path: null # pretrained_models/weights.ckpt 81 | pretrained_model_state_hook: # !for training, this needs to be set to load_backbone; for evaluation, it needs to be disabled.
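# Illustrative example (paths are placeholders): the evaluation script scripts/benchmark/dnalong/eqtl_evaluation_janus.sh
# disables this hook and points directly at a fine-tuned checkpoint via CLI overrides, e.g.:
#   python -m evaluation experiment=hg38/eqtl \
#     train.pretrained_model_path=/path/to/finetune_run/checkpoints/last.ckpt \
#     train.pretrained_model_state_hook._name_=null \
#     train.test=True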
82 | _name_: load_backbone 83 | freeze_backbone: false # seems to work much better if false (ie finetune entire model) 84 | -------------------------------------------------------------------------------- /src/utils/optim/schedulers.py: -------------------------------------------------------------------------------- 1 | """Custom learning rate schedulers""" 2 | 3 | import math 4 | import warnings 5 | import torch 6 | 7 | from timm.scheduler import CosineLRScheduler 8 | 9 | 10 | # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html 11 | class CosineWarmup(torch.optim.lr_scheduler.CosineAnnealingLR): 12 | 13 | def __init__(self, optimizer, T_max, eta_min=0, warmup_step=0, **kwargs): 14 | self.warmup_step = warmup_step 15 | super().__init__(optimizer, T_max - warmup_step, eta_min, **kwargs) 16 | 17 | # Copied from CosineAnnealingLR, but adding warmup and changing self.last_epoch to 18 | # self.last_epoch - self.warmup_step. 19 | def get_lr(self): 20 | if not self._get_lr_called_within_step: 21 | warnings.warn("To get the last learning rate computed by the scheduler, " 22 | "please use `get_last_lr()`.", UserWarning) 23 | 24 | if self.last_epoch == self.warmup_step: # also covers the case where both are 0 25 | return self.base_lrs 26 | elif self.last_epoch < self.warmup_step: 27 | return [base_lr * (self.last_epoch + 1) / self.warmup_step for base_lr in self.base_lrs] 28 | elif (self.last_epoch - self.warmup_step - 1 - self.T_max) % (2 * self.T_max) == 0: 29 | return [group['lr'] + (base_lr - self.eta_min) * 30 | (1 - math.cos(math.pi / self.T_max)) / 2 31 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)] 32 | return [(1 + math.cos(math.pi * (self.last_epoch - self.warmup_step) / self.T_max)) / 33 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_step - 1) / self.T_max)) * 34 | (group['lr'] - self.eta_min) + self.eta_min 35 | for group in self.optimizer.param_groups] 36 | 37 | _get_closed_form_lr = None 38 | 39 | 40 | def InvSqrt(optimizer, warmup_step): 41 | """ Originally used for Transformer (in Attention is all you need) 42 | """ 43 | 44 | def lr_lambda(step): 45 | # return a multiplier instead of a learning rate 46 | if step == warmup_step: # also covers the case where both are 0 47 | return 1. 48 | else: 49 | return 1. / (step ** 0.5) if step > warmup_step else (step + 1) / (warmup_step ** 1.5) 50 | 51 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 52 | 53 | 54 | def Constant(optimizer, warmup_step): 55 | 56 | def lr_lambda(step): 57 | if step == warmup_step: # also covers the case where both are 0 58 | return 1. 59 | else: 60 | return 1. if step > warmup_step else (step + 1) / warmup_step 61 | 62 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 63 | 64 | 65 | class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler): 66 | """ Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch. 67 | It supports resuming as well. 68 | """ 69 | 70 | def __init__(self, *args, **kwargs): 71 | super().__init__(*args, **kwargs) 72 | self._last_epoch = -1 73 | self.step(epoch=0) 74 | 75 | def step(self, epoch=None): 76 | if epoch is None: 77 | self._last_epoch += 1 78 | else: 79 | self._last_epoch = epoch 80 | # We call either step or step_update, depending on whether we're using the scheduler every 81 | # epoch or every step.
82 | # Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set 83 | # scheduler interval to "step", then the learning rate update will be wrong. 84 | if self.t_in_epochs: 85 | super().step(epoch=self._last_epoch) 86 | else: 87 | super().step_update(num_updates=self._last_epoch) 88 | -------------------------------------------------------------------------------- /src/callbacks/timer.py: -------------------------------------------------------------------------------- 1 | """Callback to monitor the speed of each step and each epoch. 2 | 3 | https://github.com/HazyResearch/transformers/blob/master/src/callbacks/speed_monitor.py 4 | Adapted from: 5 | https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor 6 | """ 7 | 8 | # We only need the speed monitoring, not the GPU monitoring 9 | import time 10 | from typing import Any 11 | 12 | from pytorch_lightning import Callback, Trainer, LightningModule 13 | from pytorch_lightning.utilities import rank_zero_only 14 | from pytorch_lightning.utilities.parsing import AttributeDict 15 | from pytorch_lightning.utilities.types import STEP_OUTPUT 16 | 17 | 18 | class Timer(Callback): 19 | """Monitor the speed of each step and each epoch. 20 | """ 21 | def __init__( 22 | self, 23 | step: bool = True, 24 | inter_step: bool = True, 25 | epoch: bool = True, 26 | val: bool = True, 27 | ): 28 | super().__init__() 29 | self._log_stats = AttributeDict( { 30 | 'step_time': step, 31 | 'inter_step_time': inter_step, 32 | 'epoch_time': epoch, 33 | 'val_time': val, 34 | }) 35 | 36 | def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 37 | self._snap_epoch_time = None 38 | 39 | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 40 | self._snap_step_time = None 41 | self._snap_inter_step_time = None 42 | self._snap_epoch_time = time.time() 43 | 44 | def on_train_batch_start( 45 | self, 46 | trainer: Trainer, 47 | pl_module: LightningModule, 48 | batch: Any, 49 | batch_idx: int, 50 | ) -> None: 51 | if self._log_stats.step_time: 52 | self._snap_step_time = time.time() 53 | 54 | if not self._should_log(trainer): 55 | return 56 | 57 | logs = {} 58 | if self._log_stats.inter_step_time and self._snap_inter_step_time: 59 | # First log at beginning of second step 60 | logs["timer/inter_step"] = (time.time() - self._snap_inter_step_time) # * 1000 61 | 62 | if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) 63 | 64 | @rank_zero_only 65 | def on_train_batch_end( 66 | self, 67 | trainer: Trainer, 68 | pl_module: LightningModule, 69 | outputs: STEP_OUTPUT, 70 | batch: Any, 71 | batch_idx: int, 72 | ) -> None: 73 | if self._log_stats.inter_step_time: 74 | self._snap_inter_step_time = time.time() 75 | 76 | if not self._should_log(trainer): 77 | return 78 | 79 | logs = {} 80 | if self._log_stats.step_time and self._snap_step_time: 81 | logs["timer/step"] = (time.time() - self._snap_step_time) # * 1000 82 | 83 | if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) 84 | 85 | @rank_zero_only 86 | def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None: 87 | logs = {} 88 | if self._log_stats.epoch_time and self._snap_epoch_time: 89 | logs["timer/epoch"] = time.time() - self._snap_epoch_time 90 | if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) 91 | 92 | def on_validation_epoch_start(self, trainer: Trainer, pl_module: 
LightningModule) -> None: 93 | self._snap_val_time = time.time() 94 | 95 | @rank_zero_only 96 | def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None: 97 | logs = {} 98 | if self._log_stats.val_time and self._snap_val_time: 99 | logs["timer/validation"] = time.time() - self._snap_val_time 100 | if trainer.logger: trainer.logger.log_metrics(logs) # , step=trainer.global_step) 101 | 102 | @staticmethod 103 | def _should_log(trainer) -> bool: 104 | return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop 105 | -------------------------------------------------------------------------------- /src/utils/config.py: -------------------------------------------------------------------------------- 1 | """Utilities for dealing with collection objects (lists, dicts) and configs. 2 | 3 | """ 4 | 5 | import functools 6 | from typing import Sequence, Mapping, Callable 7 | 8 | import hydra 9 | from omegaconf import ListConfig, DictConfig 10 | 11 | 12 | # TODO this is usually used in a pattern where it's turned into a list, so can just do that here 13 | def is_list(x): 14 | return isinstance(x, Sequence) and not isinstance(x, str) 15 | 16 | 17 | def is_dict(x): 18 | return isinstance(x, Mapping) 19 | 20 | 21 | def to_dict(x, recursive=True): 22 | """Convert Sequence or Mapping object to dict 23 | 24 | lists get converted to {0: x[0], 1: x[1], ...} 25 | """ 26 | if is_list(x): 27 | x = {i: v for i, v in enumerate(x)} 28 | if is_dict(x): 29 | if recursive: 30 | return {k: to_dict(v, recursive=recursive) for k, v in x.items()} 31 | else: 32 | return dict(x) 33 | else: 34 | return x 35 | 36 | 37 | def to_list(x, recursive=False): 38 | """Convert an object to list. 39 | 40 | If Sequence (e.g. list, tuple, Listconfig): just return it 41 | 42 | Special case: If non-recursive and not a list, wrap in list 43 | """ 44 | if is_list(x): 45 | if recursive: 46 | return [to_list(_x) for _x in x] 47 | else: 48 | return list(x) 49 | else: 50 | if recursive: 51 | return x 52 | else: 53 | return [x] 54 | 55 | 56 | def extract_attrs_from_obj(obj, *attrs): 57 | if obj is None: 58 | assert len(attrs) == 0 59 | return [] 60 | return [getattr(obj, attr, None) for attr in attrs] 61 | 62 | 63 | def auto_assign_attrs(cls, **kwargs): 64 | for k, v in kwargs.items(): 65 | setattr(cls, k, v) 66 | 67 | 68 | def instantiate(registry, config, *args, partial=False, wrap=None, **kwargs): 69 | """ 70 | registry: Dictionary mapping names to functions or target paths (e.g. {'model': 'models.SequenceModel'}) 71 | config: Dictionary with a '_name_' key indicating which element of the registry to grab, and kwargs to be passed into the target constructor 72 | wrap: wrap the target class (e.g. 
ema optimizer or tasks.wrap) 73 | *args, **kwargs: additional arguments to override the config to pass into the target constructor 74 | """ 75 | 76 | # Case 1: no config 77 | if config is None: 78 | return None 79 | # Case 2a: string means _name_ was overloaded 80 | if isinstance(config, str): 81 | _name_ = None 82 | _target_ = registry[config] 83 | config = {} 84 | # Case 2b: grab the desired callable from name 85 | else: 86 | _name_ = config.pop("_name_") 87 | _target_ = registry[_name_] 88 | 89 | # Retrieve the right constructor automatically based on type 90 | if isinstance(_target_, str): 91 | fn = hydra.utils.get_method(path=_target_) 92 | elif isinstance(_target_, Callable): 93 | fn = _target_ 94 | else: 95 | raise NotImplementedError("instantiate target must be string or callable") 96 | 97 | # Instantiate object 98 | if wrap is not None: 99 | fn = wrap(fn) 100 | obj = functools.partial(fn, *args, **config, **kwargs) 101 | 102 | # Restore _name_ 103 | if _name_ is not None: 104 | config["_name_"] = _name_ 105 | 106 | if partial: 107 | return obj 108 | else: 109 | return obj() 110 | 111 | 112 | def get_class(registry, _name_): 113 | return hydra.utils.get_class(path=registry[_name_]) 114 | 115 | 116 | def omegaconf_filter_keys(d, fn=None): 117 | """Only keep keys where fn(key) is True. Support nested DictConfig. 118 | # TODO can make this inplace? 119 | """ 120 | if fn is None: 121 | fn = lambda _: True 122 | if is_list(d): 123 | return ListConfig([omegaconf_filter_keys(v, fn) for v in d]) 124 | elif is_dict(d): 125 | return DictConfig( 126 | {k: omegaconf_filter_keys(v, fn) for k, v in d.items() if fn(k)} 127 | ) 128 | else: 129 | return d 130 | -------------------------------------------------------------------------------- /scripts/pre_train/slurm_JanusDNA_w_midattn_32dim.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ####SBATCH --job-name=janus_w_midattn # Specify job name 3 | ####SBATCH --partition=pgpu # Specify partition name 4 | ####SBATCH --nodes=1 # Specify number of nodes 5 | ####SBATCH --gres=gpu:8 # Generic resources; 1 GPU 6 | ####SBATCH --ntasks-per-node=8 # each gpu is a task, a 4 gpu mission requires 4 tasks 7 | ####SBATCH --cpus-per-task=12 8 | ####SBATCH --mem=128G # Request memory 9 | ####SBATCH --exclusive 10 | ####SBATCH --requeue 11 | ####SBATCH --time=2-00:00:00 # Set a limit on the total run time 12 | ####SBATCH --nodelist=s-sc-pgpu[01-08] 13 | ####SBATCH --mail-type=ALL # Notify user by email in case of job failure 14 | ####SBATCH --mail-user= 15 | #### SBATCH --output=/my_job%j # File name for standard output 16 | #### SBATCH --error=/my_job%j # File name for standard error output 17 | 18 | 19 | 20 | 21 | 22 | export LD_LIBRARY_PATH=/janusdna/lib/python3.11/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH 23 | 24 | source /etc/profile.d/conda.sh 25 | conda activate janusdna 26 | cd 27 | 28 | full_path_to_root="" 29 | 30 | 31 | export HYDRA_FULL_ERROR=1 32 | 33 | NUM_DEVICES=8 34 | 35 | # Run script 36 | SEQLEN=1024 37 | MAX_STEPS=10000 38 | GRADIENT_UPDATE_NUM="$((MAX_STEPS / 1000))K" 39 | 40 | D_MODEL=32 41 | FLEX_ATTN_MODEL=64 # should be multiple of 2, and 64 is the minimum. 
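# (Illustrative note) These knobs are forwarded below as model.config.hidden_size,
# model.config.flex_attn_n_embd and model.config.intermediate_size. With SEQLEN=1024,
# BATCH_SIZE=$(( 1048576 / SEQLEN )) works out to 1024 sequences (~1M tokens) per global
# batch, i.e. 1024 / NUM_DEVICES = 128 sequences per GPU.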
42 | INTER_FFN_MODEL=128 # 4x is the best 43 | 44 | N_LAYER=8 45 | LR="8e-3" 46 | 47 | RCPS="false" 48 | RC_AUG="false" 49 | BIDIRECTIONAL_WEIGHT_TIE="true" 50 | BIDIRECTIONAL_ATTN_TIE="false" 51 | ROUTER_AUX_LOSS_COEF=0.2 52 | OUTPUT_ROUTER_LOGITS="true" 53 | 54 | BATCH_SIZE=$(( 1048576 / SEQLEN )) 55 | 56 | SEQLEN_DIS="$((SEQLEN / 1000))k" 57 | WANDB_NAME="janusdna_len-${SEQLEN_DIS}_d_model-${D_MODEL}_inter_dim-${INTER_FFN_MODEL}_n_layer-${N_LAYER}_lr-${LR}_step-${GRADIENT_UPDATE_NUM}_moeloss-${OUTPUT_ROUTER_LOGITS}_1head_midattn" 58 | HYDRA_RUN_DIR="${full_path_to_root}/outputs/pretrain/hg38/${WANDB_NAME}" 59 | WATCH_DIR="${full_path_to_root}/watch_folder/pretrain" 60 | 61 | export WANDBID=$(python -c "import wandb; print(wandb.util.generate_id())") 62 | 63 | mkdir -p "${HYDRA_RUN_DIR}" 64 | mkdir -p "${WATCH_DIR}" 65 | srun python -m train \ 66 | experiment=hg38/hg38 \ 67 | callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \ 68 | dataset.max_length=${SEQLEN} \ 69 | dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \ 70 | dataset.batch_size_eval=$(( BATCH_SIZE / NUM_DEVICES )) \ 71 | dataset.mlm=False \ 72 | dataset.rc_aug="${RC_AUG}" \ 73 | dataset.add_eos=false \ 74 | loader.num_workers=0 \ 75 | model="janusdna" \ 76 | +model.config.output_router_logits="${OUTPUT_ROUTER_LOGITS}" \ 77 | +model.config.router_aux_loss_coef="${ROUTER_AUX_LOSS_COEF}" \ 78 | model.config.bidirectional_weight_tie="${BIDIRECTIONAL_WEIGHT_TIE}" \ 79 | model.config.bidirectional_attn_tie="${BIDIRECTIONAL_ATTN_TIE}" \ 80 | model.config.num_hidden_layers=${N_LAYER} \ 81 | model.config.hidden_size=${D_MODEL} \ 82 | model.config.flex_attn_n_embd=${FLEX_ATTN_MODEL} \ 83 | +model.config.intermediate_size=${INTER_FFN_MODEL} \ 84 | model.config.expert_layer_period=2 \ 85 | model.config.expert_layer_offset=1 \ 86 | model.config.intermediate_factor=4 \ 87 | model.config.num_attention_heads=4 \ 88 | model.config.attn_implementation="flash_attention_2" \ 89 | model.config.attn_layer_period=8 \ 90 | model.config.attn_layer_offset=4 \ 91 | optimizer.lr="${LR}" \ 92 | train.global_batch_size=${BATCH_SIZE} \ 93 | trainer.max_steps=${MAX_STEPS} \ 94 | trainer.precision=bf16-mixed \ 95 | trainer.devices=${NUM_DEVICES} \ 96 | +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \ 97 | +trainer.strategy="ddp_find_unused_parameters_true" \ 98 | wandb.group=pretrain_hg38 \ 99 | wandb.name="${WANDB_NAME}" \ 100 | wandb.mode=online \ 101 | wandb.id=${WANDBID} \ 102 | hydra.run.dir="${HYDRA_RUN_DIR}" \ 103 | > ${WATCH_DIR}/${WANDB_NAME}.log 2>&1 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /scripts/pre_train/slurm_JanusDNA_w_midattn_72dim.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ####SBATCH --job-name=janus_w_midattn # Specify job name 3 | ####SBATCH --partition=pgpu # Specify partition name 4 | ####SBATCH --nodes=1 # Specify number of nodes 5 | ####SBATCH --gres=gpu:8 # Generic resources; 1 GPU 6 | ####SBATCH --ntasks-per-node=8 # each gpu is a task, a 4 gpu mission requires 4 tasks 7 | ####SBATCH --cpus-per-task=12 8 | ####SBATCH --mem=128G # Request memory 9 | ####SBATCH --exclusive 10 | ####SBATCH --requeue 11 | ####SBATCH --time=2-00:00:00 # Set a limit on the total run time 12 | ####SBATCH --nodelist=s-sc-pgpu[01-08] 13 | ####SBATCH --mail-type=ALL # Notify user by email in case of job failure 14 | ####SBATCH --mail-user= 15 | #### SBATCH --output=/my_job%j # File name for standard output 16 
| #### SBATCH --error=/my_job%j # File name for standard error output 17 | 18 | 19 | 20 | 21 | 22 | export LD_LIBRARY_PATH=/janusdna/lib/python3.11/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH 23 | 24 | source /etc/profile.d/conda.sh 25 | conda activate janusdna 26 | cd 27 | 28 | full_path_to_root="" 29 | 30 | 31 | export HYDRA_FULL_ERROR=1 32 | 33 | NUM_DEVICES=8 34 | 35 | # Run script 36 | SEQLEN=1024 37 | MAX_STEPS=10000 38 | GRADIENT_UPDATE_NUM="$((MAX_STEPS / 1000))K" 39 | 40 | D_MODEL=72 41 | FLEX_ATTN_MODEL=128 # should be multiple of 2, and 64 is the minimum. 42 | INTER_FFN_MODEL=288 # 4x is the best 43 | 44 | N_LAYER=8 45 | LR="8e-3" 46 | 47 | RCPS="false" 48 | RC_AUG="false" 49 | BIDIRECTIONAL_WEIGHT_TIE="true" 50 | BIDIRECTIONAL_ATTN_TIE="false" 51 | ROUTER_AUX_LOSS_COEF=0.2 52 | OUTPUT_ROUTER_LOGITS="true" 53 | 54 | BATCH_SIZE=$(( 1048576 / SEQLEN )) 55 | 56 | SEQLEN_DIS="$((SEQLEN / 1000))k" 57 | WANDB_NAME="janusdna_len-${SEQLEN_DIS}_d_model-${D_MODEL}_inter_dim-${INTER_FFN_MODEL}_n_layer-${N_LAYER}_lr-${LR}_step-${GRADIENT_UPDATE_NUM}_moeloss-${OUTPUT_ROUTER_LOGITS}_1head_midattn" 58 | HYDRA_RUN_DIR="${full_path_to_root}/outputs/pretrain/hg38/${WANDB_NAME}" 59 | WATCH_DIR="${full_path_to_root}/watch_folder/pretrain" 60 | 61 | export WANDBID=$(python -c "import wandb; print(wandb.util.generate_id())") 62 | 63 | mkdir -p "${HYDRA_RUN_DIR}" 64 | mkdir -p "${WATCH_DIR}" 65 | srun python -m train \ 66 | experiment=hg38/hg38 \ 67 | callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \ 68 | dataset.max_length=${SEQLEN} \ 69 | dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \ 70 | dataset.batch_size_eval=$(( BATCH_SIZE / NUM_DEVICES )) \ 71 | dataset.mlm=False \ 72 | dataset.rc_aug="${RC_AUG}" \ 73 | dataset.add_eos=false \ 74 | loader.num_workers=0 \ 75 | model="janusdna" \ 76 | +model.config.output_router_logits="${OUTPUT_ROUTER_LOGITS}" \ 77 | +model.config.router_aux_loss_coef="${ROUTER_AUX_LOSS_COEF}" \ 78 | model.config.bidirectional_weight_tie="${BIDIRECTIONAL_WEIGHT_TIE}" \ 79 | model.config.bidirectional_attn_tie="${BIDIRECTIONAL_ATTN_TIE}" \ 80 | model.config.num_hidden_layers=${N_LAYER} \ 81 | model.config.hidden_size=${D_MODEL} \ 82 | model.config.flex_attn_n_embd=${FLEX_ATTN_MODEL} \ 83 | +model.config.intermediate_size=${INTER_FFN_MODEL} \ 84 | model.config.expert_layer_period=2 \ 85 | model.config.expert_layer_offset=1 \ 86 | model.config.intermediate_factor=4 \ 87 | model.config.num_attention_heads=4 \ 88 | model.config.attn_implementation="flash_attention_2" \ 89 | model.config.attn_layer_period=8 \ 90 | model.config.attn_layer_offset=4 \ 91 | optimizer.lr="${LR}" \ 92 | train.global_batch_size=${BATCH_SIZE} \ 93 | trainer.max_steps=${MAX_STEPS} \ 94 | trainer.precision=bf16-mixed \ 95 | trainer.devices=${NUM_DEVICES} \ 96 | +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \ 97 | +trainer.strategy="ddp_find_unused_parameters_true" \ 98 | wandb.group=pretrain_hg38 \ 99 | wandb.name="${WANDB_NAME}" \ 100 | wandb.mode=online \ 101 | wandb.id=${WANDBID} \ 102 | hydra.run.dir="${HYDRA_RUN_DIR}" \ 103 | > ${WATCH_DIR}/${WANDB_NAME}.log 2>&1 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /scripts/pre_train/slurm_JanusDNA_wo_midattn_144dim.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ####SBATCH --job-name=janus_w/o_midattn # Specify job name 3 | ####SBATCH --partition=pgpu # Specify partition name 4 | ####SBATCH --nodes=1 # 
Specify number of nodes 5 | ####SBATCH --gres=gpu:8 # Generic resources; 1 GPU 6 | ####SBATCH --ntasks-per-node=8 # each gpu is a task, a 4 gpu mission requires 4 tasks 7 | ####SBATCH --cpus-per-task=12 8 | ####SBATCH --mem=128G # Request memory 9 | ####SBATCH --exclusive 10 | ####SBATCH --requeue 11 | ####SBATCH --time=2-00:00:00 # Set a limit on the total run time 12 | ####SBATCH --nodelist=s-sc-pgpu[01-08] 13 | ####SBATCH --mail-type=ALL # Notify user by email in case of job failure 14 | ####SBATCH --mail-user= 15 | #### SBATCH --output=/my_job%j # File name for standard output 16 | #### SBATCH --error=/my_job%j # File name for standard error output 17 | 18 | 19 | 20 | 21 | 22 | export LD_LIBRARY_PATH=/janusdna/lib/python3.11/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH 23 | 24 | source /etc/profile.d/conda.sh 25 | conda activate janusdna 26 | cd 27 | 28 | full_path_to_root="" 29 | 30 | 31 | export HYDRA_FULL_ERROR=1 32 | 33 | NUM_DEVICES=8 34 | 35 | # Run script 36 | SEQLEN=131072 37 | MAX_STEPS=50000 38 | GRADIENT_UPDATE_NUM="$((MAX_STEPS / 1000))K" 39 | 40 | D_MODEL=144 41 | FLEX_ATTN_MODEL=256 # should be multiple of 2, and 64 is the minimum. 42 | INTER_FFN_MODEL=576 # 4x is the best 43 | 44 | N_LAYER=8 45 | LR="8e-3" 46 | 47 | RCPS="false" 48 | RC_AUG="false" 49 | BIDIRECTIONAL_WEIGHT_TIE="true" 50 | BIDIRECTIONAL_ATTN_TIE="false" 51 | ROUTER_AUX_LOSS_COEF=0.2 52 | OUTPUT_ROUTER_LOGITS="true" 53 | 54 | BATCH_SIZE=$(( 1048576 / SEQLEN )) 55 | 56 | SEQLEN_DIS="$((SEQLEN / 1000))k" 57 | WANDB_NAME="janusdna_len-${SEQLEN_DIS}_d_model-${D_MODEL}_inter_dim-${INTER_FFN_MODEL}_n_layer-${N_LAYER}_lr-${LR}_step-${GRADIENT_UPDATE_NUM}_moeloss-${OUTPUT_ROUTER_LOGITS}_1head_onlymoe" 58 | HYDRA_RUN_DIR="${full_path_to_root}/outputs/pretrain/hg38/${WANDB_NAME}" 59 | WATCH_DIR="${full_path_to_root}/watch_folder/pretrain" 60 | 61 | export WANDBID=$(python -c "import wandb; print(wandb.util.generate_id())") 62 | 63 | mkdir -p "${HYDRA_RUN_DIR}" 64 | mkdir -p "${WATCH_DIR}" 65 | srun python -m train \ 66 | experiment=hg38/hg38 \ 67 | callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \ 68 | dataset.max_length=${SEQLEN} \ 69 | dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \ 70 | dataset.batch_size_eval=$(( BATCH_SIZE / NUM_DEVICES )) \ 71 | dataset.mlm=False \ 72 | dataset.rc_aug="${RC_AUG}" \ 73 | dataset.add_eos=false \ 74 | loader.num_workers=0 \ 75 | model="janusdna" \ 76 | +model.config.output_router_logits="${OUTPUT_ROUTER_LOGITS}" \ 77 | +model.config.router_aux_loss_coef="${ROUTER_AUX_LOSS_COEF}" \ 78 | model.config.bidirectional_weight_tie="${BIDIRECTIONAL_WEIGHT_TIE}" \ 79 | model.config.bidirectional_attn_tie="${BIDIRECTIONAL_ATTN_TIE}" \ 80 | model.config.num_hidden_layers=${N_LAYER} \ 81 | model.config.hidden_size=${D_MODEL} \ 82 | model.config.flex_attn_n_embd=${FLEX_ATTN_MODEL} \ 83 | +model.config.intermediate_size=${INTER_FFN_MODEL} \ 84 | model.config.expert_layer_period=2 \ 85 | model.config.expert_layer_offset=1 \ 86 | model.config.intermediate_factor=4 \ 87 | model.config.num_attention_heads=4 \ 88 | model.config.attn_implementation="flash_attention_2" \ 89 | model.config.attn_layer_period=8 \ 90 | model.config.attn_layer_offset=100 \ 91 | optimizer.lr="${LR}" \ 92 | train.global_batch_size=${BATCH_SIZE} \ 93 | trainer.max_steps=${MAX_STEPS} \ 94 | trainer.precision=bf16-mixed \ 95 | trainer.devices=${NUM_DEVICES} \ 96 | +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \ 97 | +trainer.strategy="ddp_find_unused_parameters_true" \ 98 | 
wandb.group=pretrain_hg38 \ 99 | wandb.name="${WANDB_NAME}" \ 100 | wandb.mode=online \ 101 | wandb.id=${WANDBID} \ 102 | hydra.run.dir="${HYDRA_RUN_DIR}" \ 103 | > ${WATCH_DIR}/${WANDB_NAME}.log 2>&1 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /scripts/pre_train/slurm_JanusDNA_wo_midattn_32dim.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ####SBATCH --job-name=janus_w/o_midattn # Specify job name 3 | ####SBATCH --partition=pgpu # Specify partition name 4 | ####SBATCH --nodes=1 # Specify number of nodes 5 | ####SBATCH --gres=gpu:8 # Generic resources; 1 GPU 6 | ####SBATCH --ntasks-per-node=8 # each gpu is a task, a 4 gpu mission requires 4 tasks 7 | ####SBATCH --cpus-per-task=12 8 | ####SBATCH --mem=128G # Request memory 9 | ####SBATCH --exclusive 10 | ####SBATCH --requeue 11 | ####SBATCH --time=2-00:00:00 # Set a limit on the total run time 12 | ####SBATCH --nodelist=s-sc-pgpu[01-08] 13 | ####SBATCH --mail-type=ALL # Notify user by email in case of job failure 14 | ####SBATCH --mail-user= 15 | #### SBATCH --output=/my_job%j # File name for standard output 16 | #### SBATCH --error=/my_job%j # File name for standard error output 17 | 18 | 19 | 20 | 21 | 22 | export LD_LIBRARY_PATH=/janusdna/lib/python3.11/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH 23 | 24 | source /etc/profile.d/conda.sh 25 | conda activate janusdna 26 | cd 27 | 28 | full_path_to_root="" 29 | 30 | 31 | export HYDRA_FULL_ERROR=1 32 | 33 | NUM_DEVICES=8 34 | 35 | # Run script 36 | SEQLEN=1024 37 | MAX_STEPS=10000 38 | GRADIENT_UPDATE_NUM="$((MAX_STEPS / 1000))K" 39 | 40 | D_MODEL=32 41 | FLEX_ATTN_MODEL=64 # should be multiple of 2, and 64 is the minimum. 
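# NOTE (illustrative, inferred from the configs rather than stated anywhere in the script):
# the expert FFN width on the next line follows the "4x" rule of thumb,
# INTER_FFN_MODEL = 4 * D_MODEL = 4 * 32 = 128, and BATCH_SIZE further down is chosen
# so that each optimizer step sees a fixed token budget:
#   BATCH_SIZE * SEQLEN = (1048576 / SEQLEN) * SEQLEN = 2^20 tokens per step,
#   i.e. for SEQLEN=1024 a global batch of 1024 sequences, or 1024 / NUM_DEVICES = 128 per GPU.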
42 | INTER_FFN_MODEL=128 # 4x is the best 43 | 44 | N_LAYER=8 45 | LR="8e-3" 46 | 47 | RCPS="false" 48 | RC_AUG="false" 49 | BIDIRECTIONAL_WEIGHT_TIE="true" 50 | BIDIRECTIONAL_ATTN_TIE="false" 51 | ROUTER_AUX_LOSS_COEF=0.2 52 | OUTPUT_ROUTER_LOGITS="true" 53 | 54 | BATCH_SIZE=$(( 1048576 / SEQLEN )) 55 | 56 | SEQLEN_DIS="$((SEQLEN / 1000))k" 57 | WANDB_NAME="janusdna_len-${SEQLEN_DIS}_d_model-${D_MODEL}_inter_dim-${INTER_FFN_MODEL}_n_layer-${N_LAYER}_lr-${LR}_step-${GRADIENT_UPDATE_NUM}_moeloss-${OUTPUT_ROUTER_LOGITS}_1head_onlymoe" 58 | HYDRA_RUN_DIR="${full_path_to_root}/outputs/pretrain/hg38/${WANDB_NAME}" 59 | WATCH_DIR="${full_path_to_root}/watch_folder/pretrain" 60 | 61 | export WANDBID=$(python -c "import wandb; print(wandb.util.generate_id())") 62 | 63 | mkdir -p "${HYDRA_RUN_DIR}" 64 | mkdir -p "${WATCH_DIR}" 65 | srun python -m train \ 66 | experiment=hg38/hg38 \ 67 | callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \ 68 | dataset.max_length=${SEQLEN} \ 69 | dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \ 70 | dataset.batch_size_eval=$(( BATCH_SIZE / NUM_DEVICES )) \ 71 | dataset.mlm=False \ 72 | dataset.rc_aug="${RC_AUG}" \ 73 | dataset.add_eos=false \ 74 | loader.num_workers=0 \ 75 | model="janusdna" \ 76 | +model.config.output_router_logits="${OUTPUT_ROUTER_LOGITS}" \ 77 | +model.config.router_aux_loss_coef="${ROUTER_AUX_LOSS_COEF}" \ 78 | model.config.bidirectional_weight_tie="${BIDIRECTIONAL_WEIGHT_TIE}" \ 79 | model.config.bidirectional_attn_tie="${BIDIRECTIONAL_ATTN_TIE}" \ 80 | model.config.num_hidden_layers=${N_LAYER} \ 81 | model.config.hidden_size=${D_MODEL} \ 82 | model.config.flex_attn_n_embd=${FLEX_ATTN_MODEL} \ 83 | +model.config.intermediate_size=${INTER_FFN_MODEL} \ 84 | model.config.expert_layer_period=2 \ 85 | model.config.expert_layer_offset=1 \ 86 | model.config.intermediate_factor=4 \ 87 | model.config.num_attention_heads=4 \ 88 | model.config.attn_implementation="flash_attention_2" \ 89 | model.config.attn_layer_period=8 \ 90 | model.config.attn_layer_offset=100 \ 91 | optimizer.lr="${LR}" \ 92 | train.global_batch_size=${BATCH_SIZE} \ 93 | trainer.max_steps=${MAX_STEPS} \ 94 | trainer.precision=bf16-mixed \ 95 | trainer.devices=${NUM_DEVICES} \ 96 | +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \ 97 | +trainer.strategy="ddp_find_unused_parameters_true" \ 98 | wandb.group=pretrain_hg38 \ 99 | wandb.name="${WANDB_NAME}" \ 100 | wandb.mode=online \ 101 | wandb.id=${WANDBID} \ 102 | hydra.run.dir="${HYDRA_RUN_DIR}" \ 103 | > ${WATCH_DIR}/${WANDB_NAME}.log 2>&1 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /scripts/pre_train/slurm_JanusDNA_wo_midattn_72dim.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ####SBATCH --job-name=janus_w/o_midattn # Specify job name 3 | ####SBATCH --partition=pgpu # Specify partition name 4 | ####SBATCH --nodes=1 # Specify number of nodes 5 | ####SBATCH --gres=gpu:8 # Generic resources; 1 GPU 6 | ####SBATCH --ntasks-per-node=8 # each gpu is a task, a 4 gpu mission requires 4 tasks 7 | ####SBATCH --cpus-per-task=12 8 | ####SBATCH --mem=128G # Request memory 9 | ####SBATCH --exclusive 10 | ####SBATCH --requeue 11 | ####SBATCH --time=2-00:00:00 # Set a limit on the total run time 12 | ####SBATCH --nodelist=s-sc-pgpu[01-08] 13 | ####SBATCH --mail-type=ALL # Notify user by email in case of job failure 14 | ####SBATCH --mail-user= 15 | #### SBATCH --output=/my_job%j # File name for standard 
output 16 | #### SBATCH --error=/my_job%j # File name for standard error output 17 | 18 | 19 | 20 | 21 | 22 | export LD_LIBRARY_PATH=/janusdna/lib/python3.11/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH 23 | 24 | source /etc/profile.d/conda.sh 25 | conda activate janusdna 26 | cd 27 | 28 | full_path_to_root="" 29 | 30 | 31 | export HYDRA_FULL_ERROR=1 32 | 33 | NUM_DEVICES=8 34 | 35 | # Run script 36 | SEQLEN=1024 37 | MAX_STEPS=10000 38 | GRADIENT_UPDATE_NUM="$((MAX_STEPS / 1000))K" 39 | 40 | D_MODEL=72 41 | FLEX_ATTN_MODEL=128 # should be multiple of 2, and 64 is the minimum. 42 | INTER_FFN_MODEL=288 # 4x is the best 43 | 44 | N_LAYER=8 45 | LR="8e-3" 46 | 47 | RCPS="false" 48 | RC_AUG="false" 49 | BIDIRECTIONAL_WEIGHT_TIE="true" 50 | BIDIRECTIONAL_ATTN_TIE="false" 51 | ROUTER_AUX_LOSS_COEF=0.2 52 | OUTPUT_ROUTER_LOGITS="true" 53 | 54 | BATCH_SIZE=$(( 1048576 / SEQLEN )) 55 | 56 | SEQLEN_DIS="$((SEQLEN / 1000))k" 57 | WANDB_NAME="janusdna_len-${SEQLEN_DIS}_d_model-${D_MODEL}_inter_dim-${INTER_FFN_MODEL}_n_layer-${N_LAYER}_lr-${LR}_step-${GRADIENT_UPDATE_NUM}_moeloss-${OUTPUT_ROUTER_LOGITS}_1head_onlymoe" 58 | HYDRA_RUN_DIR="${full_path_to_root}/outputs/pretrain/hg38/${WANDB_NAME}" 59 | WATCH_DIR="${full_path_to_root}/watch_folder/pretrain" 60 | 61 | export WANDBID=$(python -c "import wandb; print(wandb.util.generate_id())") 62 | 63 | mkdir -p "${HYDRA_RUN_DIR}" 64 | mkdir -p "${WATCH_DIR}" 65 | srun python -m train \ 66 | experiment=hg38/hg38 \ 67 | callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \ 68 | dataset.max_length=${SEQLEN} \ 69 | dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \ 70 | dataset.batch_size_eval=$(( BATCH_SIZE / NUM_DEVICES )) \ 71 | dataset.mlm=False \ 72 | dataset.rc_aug="${RC_AUG}" \ 73 | dataset.add_eos=false \ 74 | loader.num_workers=0 \ 75 | model="janusdna" \ 76 | +model.config.output_router_logits="${OUTPUT_ROUTER_LOGITS}" \ 77 | +model.config.router_aux_loss_coef="${ROUTER_AUX_LOSS_COEF}" \ 78 | model.config.bidirectional_weight_tie="${BIDIRECTIONAL_WEIGHT_TIE}" \ 79 | model.config.bidirectional_attn_tie="${BIDIRECTIONAL_ATTN_TIE}" \ 80 | model.config.num_hidden_layers=${N_LAYER} \ 81 | model.config.hidden_size=${D_MODEL} \ 82 | model.config.flex_attn_n_embd=${FLEX_ATTN_MODEL} \ 83 | +model.config.intermediate_size=${INTER_FFN_MODEL} \ 84 | model.config.expert_layer_period=2 \ 85 | model.config.expert_layer_offset=1 \ 86 | model.config.intermediate_factor=4 \ 87 | model.config.num_attention_heads=4 \ 88 | model.config.attn_implementation="flash_attention_2" \ 89 | model.config.attn_layer_period=8 \ 90 | model.config.attn_layer_offset=100 \ 91 | optimizer.lr="${LR}" \ 92 | train.global_batch_size=${BATCH_SIZE} \ 93 | trainer.max_steps=${MAX_STEPS} \ 94 | trainer.precision=bf16-mixed \ 95 | trainer.devices=${NUM_DEVICES} \ 96 | +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \ 97 | +trainer.strategy="ddp_find_unused_parameters_true" \ 98 | wandb.group=pretrain_hg38 \ 99 | wandb.name="${WANDB_NAME}" \ 100 | wandb.mode=online \ 101 | wandb.id=${WANDBID} \ 102 | hydra.run.dir="${HYDRA_RUN_DIR}" \ 103 | > ${WATCH_DIR}/${WANDB_NAME}.log 2>&1 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /src/models/nn/utils.py: -------------------------------------------------------------------------------- 1 | """ Utility wrappers around modules to let them handle Args and extra arguments """ 2 | 3 | import inspect 4 | from functools import wraps 5 | import torch 6 | from torch import nn 7 
| 8 | def wrap_kwargs(f): 9 | """ 10 | Given a callable f that can consume some named arguments, 11 | wrap it with a kwargs that passes back any unused args 12 | 13 | EXAMPLES 14 | -------- 15 | 16 | Basic usage: 17 | def foo(x, y=None): 18 | return x 19 | 20 | wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2}) 21 | 22 | -------- 23 | 24 | The wrapped function can return its own argument dictionary, 25 | which gets merged with the new kwargs. 26 | def foo(x, y=None): 27 | return x, {} 28 | wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2}) 29 | 30 | def foo(x, y=None): 31 | return x, {"y": y, "z": None} 32 | wrap_kwargs(foo)(0, y=1, z=2) == (0, {'y': 1, 'z': 2}) 33 | 34 | -------- 35 | 36 | The wrapped function can have its own kwargs parameter: 37 | def foo(x, y=None, **kw_args): 38 | return x, {} 39 | wrap_kwargs(foo)(0, y=1, z=2) == (0, {}) 40 | 41 | -------- 42 | 43 | Partial functions and modules work automatically: 44 | class Module: 45 | def forward(self, x, y=0): 46 | return x, {"y": y+1} 47 | 48 | m = Module() 49 | 50 | wrap_kwargs(m.forward)(0, y=1, z=2) == (0, {'y': 2, 'z': 2}) 51 | 52 | """ 53 | sig = inspect.signature(f) 54 | # Check if f already has kwargs 55 | has_kwargs = any([ 56 | param.kind == inspect.Parameter.VAR_KEYWORD 57 | for param in sig.parameters.values() 58 | ]) 59 | if has_kwargs: 60 | @wraps(f) 61 | def f_kwargs(*args, **kwargs): 62 | y = f(*args, **kwargs) 63 | if isinstance(y, tuple) and isinstance(y[-1], dict): 64 | return y 65 | else: 66 | return y, {} 67 | else: 68 | param_kwargs = inspect.Parameter("kwargs", kind=inspect.Parameter.VAR_KEYWORD) 69 | sig_kwargs = inspect.Signature(parameters=list(sig.parameters.values())+[param_kwargs]) 70 | @wraps(f) 71 | def f_kwargs(*args, **kwargs): 72 | bound = sig_kwargs.bind(*args, **kwargs) 73 | if "kwargs" in bound.arguments: 74 | kwargs = bound.arguments.pop("kwargs") 75 | else: 76 | kwargs = {} 77 | y = f(**bound.arguments) 78 | if isinstance(y, tuple) and isinstance(y[-1], dict): 79 | return *y[:-1], {**y[-1], **kwargs} 80 | else: 81 | return y, kwargs 82 | return f_kwargs 83 | 84 | def discard_kwargs(f): 85 | if f is None: return None 86 | f_kwargs = wrap_kwargs(f) 87 | @wraps(f) 88 | def f_(*args, **kwargs): 89 | return f_kwargs(*args, **kwargs)[0] 90 | return f_ 91 | 92 | def PassthroughSequential(*modules): 93 | """Special Sequential module that chains kwargs. 
94 | 95 | Semantics are the same as nn.Sequential, with extra convenience features: 96 | - Discard None modules 97 | - Flatten inner Sequential modules 98 | - In case with 0 or 1 Module, rename the class for ease of inspection 99 | """ 100 | def flatten(module): 101 | if isinstance(module, nn.Sequential): 102 | return sum([flatten(m) for m in module], []) 103 | else: 104 | return [module] 105 | 106 | modules = flatten(nn.Sequential(*modules)) 107 | modules = [module for module in modules if module if not None] 108 | 109 | class Sequential(nn.Sequential): 110 | def forward(self, x, **kwargs): 111 | for layer in self: 112 | x, kwargs = wrap_kwargs(layer.forward)(x, **kwargs) 113 | return x, kwargs 114 | 115 | def step(self, x, **kwargs): 116 | for layer in self: 117 | fn = getattr(layer, "step", layer.forward) 118 | x, kwargs = wrap_kwargs(fn)(x, **kwargs) 119 | return x, kwargs 120 | 121 | if len(modules) == 0: 122 | Sequential.__name__ = "Identity" 123 | elif len(modules) == 1: 124 | Sequential.__name__ = type(modules[0]).__name__ 125 | return Sequential(*modules) 126 | -------------------------------------------------------------------------------- /src/dataloaders/datasets/nucleotide_transformer_dataset.py: -------------------------------------------------------------------------------- 1 | """Nucleotide Transformer Benchmarks Dataset. 2 | 3 | From: https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks 4 | """ 5 | 6 | import torch 7 | from datasets import load_dataset 8 | 9 | from src.dataloaders.utils.rc import coin_flip, string_reverse_complement 10 | 11 | 12 | class NucleotideTransformerDataset(torch.utils.data.Dataset): 13 | 14 | """ 15 | Loop through fasta file for sequence. 16 | Returns a generator that retrieves the sequence. 
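    Note: sequences are actually loaded from the InstaDeepAI
    nucleotide_transformer_downstream_tasks dataset on the Hugging Face Hub via
    `load_dataset` (see __init__ below), not from a local FASTA file.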
17 | """ 18 | 19 | def __init__( 20 | self, 21 | split, 22 | max_length, 23 | dataset_name=None, 24 | d_output=2, # default binary classification 25 | tokenizer=None, 26 | tokenizer_name=None, 27 | use_padding=None, 28 | add_eos=False, 29 | rc_aug=False, 30 | conjoin_train=False, 31 | conjoin_test=False, 32 | return_augs=False 33 | ): 34 | 35 | self.max_length = max_length 36 | self.use_padding = use_padding 37 | self.tokenizer_name = tokenizer_name 38 | self.tokenizer = tokenizer 39 | self.return_augs = return_augs 40 | self.add_eos = add_eos 41 | self.d_output = d_output # needed for decoder to grab 42 | assert not (conjoin_train and conjoin_test), "conjoin_train and conjoin_test cannot both be True" 43 | if (conjoin_train or conjoin_test) and rc_aug: 44 | print("When using conjoin, we turn off rc_aug.") 45 | rc_aug = False 46 | self.rc_aug = rc_aug 47 | self.conjoin_train = conjoin_train 48 | self.conjoin_test = conjoin_test 49 | 50 | self.split = split 51 | 52 | # For NT tasks, we use data from InstaDeepAI/nucleotide_transformer_downstream_tasks 53 | # self.seqs = load_dataset( 54 | # "InstaDeepAI/nucleotide_transformer_downstream_tasks", 55 | # name=dataset_name, 56 | # split=split 57 | # ) 58 | 59 | self.seqs = load_dataset( 60 | "InstaDeepAI/nucleotide_transformer_downstream_tasks", 61 | name=dataset_name, 62 | split=split, 63 | trust_remote_code=True 64 | ) 65 | 66 | def __len__(self): 67 | return len(self.seqs) 68 | 69 | def __getitem__(self, idx): 70 | x = self.seqs[idx]["sequence"] # only one sequence 71 | y = self.seqs[idx]["label"] 72 | 73 | if (self.rc_aug or (self.conjoin_test and self.split == "train")) and coin_flip(): 74 | x = string_reverse_complement(x) 75 | 76 | seq = self.tokenizer( 77 | x, 78 | add_special_tokens=False, 79 | padding="max_length" if self.use_padding else None, 80 | max_length=self.max_length, 81 | truncation=True, 82 | ) 83 | seq_ids = seq["input_ids"] # get input_ids 84 | 85 | # need to handle eos here 86 | if self.add_eos: 87 | # append list seems to be faster than append tensor 88 | seq_ids.append(self.tokenizer.sep_token_id) 89 | 90 | if self.conjoin_train or (self.conjoin_test and self.split != "train"): 91 | x_rc = string_reverse_complement(x) 92 | seq_rc = self.tokenizer( 93 | x_rc, 94 | add_special_tokens=False, 95 | padding="max_length" if self.use_padding else None, 96 | max_length=self.max_length, 97 | truncation=True, 98 | ) 99 | seq_rc_ids = seq_rc["input_ids"] # get input_ids 100 | # need to handle eos here 101 | if self.add_eos: 102 | # append list seems to be faster than append tensor 103 | seq_rc_ids.append(self.tokenizer.sep_token_id) 104 | seq_ids = torch.stack((torch.LongTensor(seq_ids), torch.LongTensor(seq_rc_ids)), dim=1) 105 | 106 | else: 107 | # convert to tensor 108 | seq_ids = torch.LongTensor(seq_ids) 109 | 110 | # need to wrap in list 111 | target = torch.LongTensor([y]) 112 | 113 | # `seq` has shape: 114 | # - (seq_len,) if not conjoining 115 | # - (seq_len, 2) for conjoining 116 | return seq_ids, target 117 | -------------------------------------------------------------------------------- /evals/hg38_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import yaml 7 | from tqdm import tqdm 8 | import json 9 | 10 | sys.path.append(os.environ.get("SAFARI_PATH", ".")) 11 | 12 | from src.models.sequence.long_conv_lm import ConvLMHeadModel 13 | 14 | # from transformers import AutoTokenizer, GPT2LMHeadModel 15 | 
# from spacy.lang.en.stop_words import STOP_WORDS 16 | from src.dataloaders.datasets.hg38_char_tokenizer import CharacterTokenizer 17 | 18 | try: 19 | from tokenizers import Tokenizer 20 | except: 21 | pass 22 | 23 | # https://github.com/openai/gpt-2/issues/131#issuecomment-492786058 24 | # def preprocess(text): 25 | # text = text.replace("“", '"') 26 | # text = text.replace("”", '"') 27 | # return '\n'+text.strip() 28 | 29 | 30 | class HG38Encoder: 31 | "Encoder inference for HG38 sequences" 32 | def __init__(self, model_cfg, ckpt_path, max_seq_len): 33 | self.max_seq_len = max_seq_len 34 | self.model, self.tokenizer = self.load_model(model_cfg, ckpt_path) 35 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 36 | self.model = self.model.to(self.device) 37 | 38 | def encode(self, seqs): 39 | 40 | results = [] 41 | 42 | # sample code to loop thru each sample and tokenize first (char level) 43 | for seq in tqdm(seqs): 44 | 45 | if isinstance(self.tokenizer, Tokenizer): 46 | tokenized_seq = self.tokenizer.encode(seq).ids 47 | else: 48 | tokenized_seq = self.tokenizer.encode(seq) 49 | 50 | # can accept a batch, shape [B, seq_len, hidden_dim] 51 | logits, __ = self.model(torch.tensor([tokenized_seq]).to(device=self.device)) 52 | 53 | # Using head, so just have logits 54 | results.append(logits) 55 | 56 | return results 57 | 58 | 59 | def load_model(self, model_cfg, ckpt_path): 60 | config = yaml.load(open(model_cfg, 'r'), Loader=yaml.FullLoader) 61 | model = ConvLMHeadModel(**config['model_config']) 62 | 63 | state_dict = torch.load(ckpt_path, map_location='cpu') 64 | 65 | # loads model from ddp by removing prexix to single if necessary 66 | torch.nn.modules.utils.consume_prefix_in_state_dict_if_present( 67 | state_dict["state_dict"], "model." 68 | ) 69 | 70 | model_state_dict = state_dict["state_dict"] 71 | 72 | # need to remove torchmetrics. 
to remove keys, need to convert to list first 73 | for key in list(model_state_dict.keys()): 74 | if "torchmetrics" in key: 75 | model_state_dict.pop(key) 76 | 77 | model.load_state_dict(state_dict["state_dict"]) 78 | 79 | # setup tokenizer 80 | if config['tokenizer_name'] == 'char': 81 | print("**Using Char-level tokenizer**") 82 | 83 | # add to vocab 84 | tokenizer = CharacterTokenizer( 85 | characters=['A', 'C', 'G', 'T', 'N'], 86 | model_max_length=self.max_seq_len + 2, # add 2 since default adds eos/eos tokens, crop later 87 | add_special_tokens=False, 88 | ) 89 | print(tokenizer._vocab_str_to_int) 90 | else: 91 | raise NotImplementedError("You need to provide a custom tokenizer!") 92 | 93 | return model, tokenizer 94 | 95 | 96 | if __name__ == "__main__": 97 | 98 | SAFARI_PATH = os.getenv('SAFARI_PATH', '.') 99 | 100 | parser = argparse.ArgumentParser() 101 | 102 | parser.add_argument( 103 | "--model_cfg", 104 | default=f"{SAFARI_PATH}/configs/evals/hyena_small_150b.yaml", 105 | ) 106 | 107 | parser.add_argument( 108 | "--ckpt_path", 109 | default=f"", 110 | help="Path to model state dict checkpoint" 111 | ) 112 | 113 | args = parser.parse_args() 114 | 115 | task = HG38Encoder(args.model_cfg, args.ckpt_path, max_seq_len=1024) 116 | 117 | # sample sequence, can pass a list of seqs (themselves a list of chars) 118 | seqs = ["ACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGT"] 119 | 120 | logits = task.encode(seqs) 121 | print(logits) 122 | print(logits[0].logits.shape) 123 | 124 | breakpoint() 125 | 126 | -------------------------------------------------------------------------------- /scripts/benchmark/dnalong/eqtl_train_janus_8gpu.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | # SBATCH --job-name=AS_w/o_midattn 3 | # SBATCH --partition=pgpu 4 | # SBATCH --nodes=1 5 | # SBATCH --ntasks-per-node=8 6 | # SBATCH --gres=gpu:8 7 | # SBATCH --cpus-per-task=12 8 | # SBATCH --mem=128G 9 | # SBATCH --exclusive 10 | # SBATCH --requeue 11 | # SBATCH --time=2-00:00:00 12 | # SBATCH --nodelist=s-sc-pgpu[08],s-sc-dgx[01-02] 13 | # SBATCH --mail-type=ALL 14 | # SBATCH --mail-user= 15 | # SBATCH --output=/my_job%j.out 16 | # SBATCH --error=/my_job%j.err 17 | 18 | 19 | export LD_LIBRARY_PATH=/janusdna/lib/python3.11/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH 20 | export HYDRA_FULL_ERROR=1 21 | 22 | source /etc/profile.d/conda.sh 23 | 24 | conda activate janusdna 25 | 26 | # cd working directory 27 | cd 28 | 29 | 30 | 31 | # make sure train.py exists 32 | ls train.py || echo "train.py not found!" 
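# Illustrative note (not part of the original launcher): this script fine-tunes a
# single cell type per run; the commented-out CELL_TYPES array below lists the
# available eQTL cell types. A hypothetical sweep, assuming the script were
# parameterized by $1, could look like:
#   for ct in "${CELL_TYPES[@]}"; do sbatch eqtl_train_janus_8gpu.sh "$ct"; done
# With BATCH_SIZE=8 and NUM_GPUS=8 (set further down), dataset.batch_size resolves
# to 8 / 8 = 1 sequence per GPU while train.global_batch_size stays at 8.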
33 | 34 | # CELL_TYPES=( 35 | # # "Adipose_Subcutaneous" 36 | # "Artery_Tibial" 37 | # "Cells_Cultured_fibroblasts" 38 | # "Muscle_Skeletal" 39 | # "Nerve_Tibial" 40 | # "Skin_Not_Sun_Exposed_Suprapubic" 41 | # "Skin_Sun_Exposed_Lower_leg" 42 | # "Thyroid" 43 | # "Whole_Blood" 44 | # ) 45 | 46 | 47 | # task params 48 | CELL_TYPE="Whole_Blood" # pich one from above list, e.g., Cells_Cultured_fibroblasts, Adipose_Subcutaneous 49 | LR="4e-4" 50 | BATCH_SIZE=8 51 | SEED=1 52 | NUM_GPUS=8 53 | FINETUNED_EPOCH=3 54 | OUTPUT_ROUTER_LOGITS="true" # define missing variable 55 | 56 | # model config 57 | MODEL="janusdna" 58 | MODEL_NAME="dna_embedding_janusdna" 59 | RC_AUG="false" 60 | CONJOIN_TRAIN_DECODER="false" 61 | CONJOIN_TEST="false" 62 | FREEZE_BACKBONE="false" 63 | 64 | # pretrained model path 65 | ROOT_DIR="" 66 | PRETRAINED_DIR="${ROOT_DIR}/outputs/pretrain/hg38" 67 | MODEL_PRETRAINED_DIRNAME="janusdna_len-131k_d_model-144_inter_dim-576_n_layer-8_lr-8e-3_step-50K_moeloss-true_1head_onlymoe" # copy the pre-trained model directory name 68 | PRETRAINED_CONFIG_PATH="${PRETRAINED_DIR}/${MODEL_PRETRAINED_DIRNAME}/model_config.json" 69 | PRETRAINED_WEIGHT_PATH="${PRETRAINED_DIR}/${MODEL_PRETRAINED_DIRNAME}/checkpoints/last.ckpt" 70 | 71 | # finetuned model path 72 | FINETUNED_BASE_DIR="${ROOT_DIR}/outputs/downstream/longrange_benchmark/eqtl" 73 | MODEL_FINETUNED_DIRNAME=${MODEL_PRETRAINED_DIRNAME} 74 | 75 | # log path 76 | LOG_BASE_DIR="${ROOT_DIR}/watch_folder/DNALong/eQTL" 77 | 78 | # name and log setting 79 | WANDB_NAME="${CELL_TYPE}_lr-${LR}_cjtrain_${CONJOIN_TRAIN_DECODER}_batch_${BATCH_SIZE}_seed_${SEED}" 80 | HYDRA_RUN_DIR="${FINETUNED_BASE_DIR}/${MODEL_FINETUNED_DIRNAME}/${CELL_TYPE}/${WANDB_NAME}" 81 | LOG_DIR="${LOG_BASE_DIR}/${MODEL_FINETUNED_DIRNAME}/${CELL_TYPE}" 82 | LOG_FILE="${LOG_DIR}/${WANDB_NAME}.log" 83 | 84 | mkdir -p "${HYDRA_RUN_DIR}" 85 | mkdir -p "${LOG_DIR}" 86 | 87 | echo "Running cell_type: ${CELL_TYPE}, Pretrained_model: ${MODEL_PRETRAINED_DIRNAME}, LR: ${LR}, BATCH_SIZE: ${BATCH_SIZE}, SEED: ${SEED}" 88 | echo "Logging to: ${LOG_FILE}" 89 | 90 | # wandb id 91 | export WANDBID=$(python -c "import wandb; print(wandb.util.generate_id())" 2>/dev/null || echo "wandb_id_error") 92 | 93 | srun python -m train \ 94 | experiment=hg38/eqtl \ 95 | dataset.batch_size=$((BATCH_SIZE / NUM_GPUS)) \ 96 | dataset.dest_path="${ROOT_DIR}/data" \ 97 | dataset.cell_type=${CELL_TYPE} \ 98 | loader.num_workers=0 \ 99 | model="${MODEL}" \ 100 | model._name_="${MODEL_NAME}" \ 101 | +model.config_path=${PRETRAINED_CONFIG_PATH} \ 102 | +model.config.output_router_logits="${OUTPUT_ROUTER_LOGITS}" \ 103 | +model.config.router_aux_loss_coef=0.02 \ 104 | decoder.mode="pool" \ 105 | optimizer.lr=${LR} \ 106 | train.pretrained_model_path="${PRETRAINED_WEIGHT_PATH}" \ 107 | train.pretrained_model_state_hook.freeze_backbone=${FREEZE_BACKBONE} \ 108 | train.monitor=val/main_loss_epoch \ 109 | train.global_batch_size=${BATCH_SIZE} \ 110 | trainer.num_sanity_val_steps=1 \ 111 | trainer.max_epochs=${FINETUNED_EPOCH} \ 112 | trainer.precision=32 \ 113 | trainer.devices=${NUM_GPUS} \ 114 | +trainer.strategy="ddp_find_unused_parameters_true" \ 115 | wandb.mode=online \ 116 | wandb.name="${WANDB_NAME}" \ 117 | wandb.id=${WANDBID} \ 118 | wandb.group=DNALong/eQTL/janus-onlymoe/131k_50k \ 119 | hydra.run.dir=${HYDRA_RUN_DIR} \ 120 | > "${LOG_FILE}" 2>&1 121 | 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /src/ops/fftconv.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from einops import rearrange 7 | 8 | from fftconv import fftconv_fwd, fftconv_bwd 9 | 10 | @torch.jit.script 11 | def _mul_sum(y, q): 12 | return (y * q).sum(dim=1) 13 | 14 | # reference convolution with residual connection 15 | def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None): 16 | seqlen = u.shape[-1] 17 | fft_size = 2 * seqlen 18 | k_f = torch.fft.rfft(k, n=fft_size) / fft_size 19 | if k_rev is not None: 20 | k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size 21 | k_f = k_f + k_rev_f.conj() 22 | u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) 23 | 24 | if len(u.shape) > 3: k_f = k_f.unsqueeze(1) 25 | 26 | y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen] 27 | 28 | out = y + u * D.unsqueeze(-1) 29 | if gelu: 30 | out = F.gelu(out) 31 | if dropout_mask is not None: 32 | return (out * rearrange(dropout_mask, 'b H -> b H 1')).to(dtype=u.dtype) 33 | else: 34 | return out.to(dtype=u.dtype) 35 | 36 | 37 | # reference H3 forward pass 38 | def fftconv_h3_ref(k, ssm_kernel, D, q, v, head_dim=1, ssm_kernel_rev=None): 39 | seqlen = k.shape[-1] 40 | fft_size = 2 * seqlen 41 | kv = (rearrange(k, 'b (h d1) l -> b d1 1 h l', d1=head_dim) 42 | * rearrange(v, 'b (h d2) l -> b 1 d2 h l', d2=head_dim)) # b d1 d2 h l 43 | kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size 44 | ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size) # h L+1 45 | if ssm_kernel_rev is not None: 46 | ssm_kernel_rev_f = torch.fft.rfft(ssm_kernel_rev, n=fft_size) # h L+1 47 | ssm_kernel_f = ssm_kernel_f + ssm_kernel_rev_f.conj() 48 | y = torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm='forward')[..., :seqlen] # b d1 d2 h l 49 | out = y + kv * D.unsqueeze(-1) # b d1 d2 h l 50 | q = rearrange(q, 'b (h d1) l -> b d1 1 h l', d1=head_dim) 51 | if head_dim > 1: 52 | out = _mul_sum(out, q) 53 | return rearrange(out, 'b d2 h l -> b (h d2) l').to(dtype=k.dtype) 54 | else: 55 | return rearrange(out * q, 'b 1 1 h l -> b h l').to(dtype=k.dtype) 56 | 57 | 58 | class FFTConvFunc(torch.autograd.Function): 59 | 60 | @staticmethod 61 | def forward(ctx, u, k, D, dropout_mask=None, gelu=True, force_fp16_output=False, 62 | output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None): 63 | seqlen = u.shape[-1] 64 | fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16) 65 | k_f = torch.fft.rfft(k, n=fft_size) 66 | if k_rev is not None: 67 | k_f = k_f + torch.fft.rfft(k_rev, n=fft_size).conj() 68 | if u.stride(-1) != 1: 69 | u = u.contiguous() 70 | k_f = k_f.contiguous() 71 | D = D.contiguous() 72 | if v is not None and v.stride(-1) != 1: 73 | v = v.contiguous() 74 | if q is not None and q.stride(-1) != 1: 75 | q = q.contiguous() 76 | if dropout_mask is not None: 77 | dropout_mask = dropout_mask.contiguous() 78 | ctx.save_for_backward(u, k_f, D, dropout_mask, v, q) 79 | ctx.output_hbl_layout = output_hbl_layout 80 | ctx.head_dim = head_dim 81 | ctx.gelu = gelu 82 | ctx.fftfp16 = fftfp16 83 | ctx.has_k_rev = k_rev is not None 84 | out = fftconv_fwd(u, k_f, D, v, head_dim, q, dropout_mask, gelu, False, False, fft_size, force_fp16_output, output_hbl_layout, fftfp16) 85 | return out 86 | 87 | @staticmethod 88 | def backward(ctx, dout): 89 | if ctx.output_hbl_layout: 90 | dout = rearrange(rearrange(dout, 'b h l -> h b l').contiguous(), 'h b l -> b h l') 91 | else: 92 | dout = dout.contiguous() 93 
| u, k_f, D, dropout_mask, v, q = ctx.saved_tensors 94 | seqlen = u.shape[-1] 95 | fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16) 96 | du, dk_f, dD, dv, dq = fftconv_bwd(dout, u, k_f, D, v, ctx.head_dim, q, dropout_mask, ctx.gelu, False, False, fft_size, 97 | ctx.output_hbl_layout, ctx.fftfp16) 98 | dk = torch.fft.irfft(dk_f, n=fft_size, norm='forward')[..., :seqlen] 99 | dk_rev = (None if not ctx.has_k_rev 100 | else torch.fft.irfft(dk_f.conj(), n=fft_size, norm='forward')[..., :seqlen]) 101 | if v is not None: 102 | dv = dv.to(dtype=v.dtype) # We do atomicAdd in fp32 so might need to convert to fp16 103 | return du, dk, dD, None, None, None, None, dv if v is not None else None, None, dq if q is not None else None, None, dk_rev 104 | 105 | def fftconv_func(u, k, D, dropout_mask=None, gelu=True, force_fp16_output=False, 106 | output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None): 107 | return FFTConvFunc.apply(u, k, D, dropout_mask, gelu, force_fp16_output, 108 | output_hbl_layout, v, head_dim, q, fftfp16, k_rev) 109 | -------------------------------------------------------------------------------- /src/tasks/torchmetrics.py: -------------------------------------------------------------------------------- 1 | # Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py 2 | # But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll)) 3 | # Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py 4 | # But we pass in the loss to avoid recomputation 5 | 6 | from typing import Any, Dict, Optional 7 | 8 | import torch 9 | from torch import Tensor 10 | 11 | from torchmetrics import Metric 12 | 13 | try: 14 | from flash_attn.losses.cross_entropy import CrossEntropyLoss 15 | except ImportError: 16 | CrossEntropyLoss = torch.nn.CrossEntropyLoss 17 | 18 | try: 19 | from apex.transformer import parallel_state 20 | except ImportError: 21 | parallel_state = None 22 | 23 | 24 | class Perplexity(Metric): 25 | r""" 26 | Perplexity measures how well a language model predicts a text sample. It's calculated as the average number of bits 27 | per word a model needs to represent the sample. 28 | Args: 29 | kwargs: 30 | Additional keyword arguments, see :ref:`Metric kwargs` for more info. 31 | Examples: 32 | >>> import torch 33 | >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22)) 34 | >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22)) 35 | >>> target[0, 6:] = -100 36 | >>> metric = Perplexity(ignore_index=-100) 37 | >>> metric(preds, target) 38 | tensor(5.2545) 39 | """ 40 | is_differentiable = True 41 | higher_is_better = False 42 | full_state_update = False 43 | total_log_probs: Tensor 44 | count: Tensor 45 | 46 | def __init__(self, **kwargs: Dict[str, Any]): 47 | super().__init__(**kwargs) 48 | self.add_state("total_log_probs", default=torch.tensor(0.0, dtype=torch.float64), 49 | dist_reduce_fx="sum") 50 | self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum") 51 | 52 | self.loss_fn = CrossEntropyLoss() 53 | 54 | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore 55 | """Compute and store intermediate statistics for Perplexity. 56 | Args: 57 | preds: 58 | Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size]. 59 | target: 60 | Ground truth values with a shape [batch_size, seq_len]. 
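            loss:
                Optional pre-computed mean cross-entropy (average NLL per token) for this
                batch; if omitted, it is recomputed from ``preds`` and ``target`` using the
                class's CrossEntropyLoss.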
61 | """ 62 | count = target.numel() 63 | if loss is None: 64 | loss = self.loss_fn(preds, target) 65 | self.total_log_probs += loss.double() * count # ! here is the multiple of count and single-batch average loss 66 | self.count += count 67 | 68 | def compute(self) -> Tensor: 69 | """Compute the Perplexity. 70 | Returns: 71 | Perplexity 72 | """ 73 | return torch.exp(self.total_log_probs / self.count) 74 | 75 | class NumTokens(Metric): 76 | """Keep track of how many tokens we've seen. 77 | """ 78 | # TODO: how do we prevent the reset between the epochs? The reset happens on the 1st batch 79 | # of the next epoch. 80 | # Right now the hack is that we override reset(), which would mess up the forward method. 81 | # We then override forward to do the right thing. 82 | 83 | is_differentiable = False 84 | higher_is_better = False 85 | full_state_update = False 86 | count: Tensor 87 | 88 | def __init__(self, **kwargs: Dict[str, Any]): 89 | super().__init__(**kwargs) 90 | self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum", 91 | persistent=True) # We want the count to be saved to state-dict 92 | if parallel_state is not None and not parallel_state.is_unitialized(): 93 | self.tensor_parallel_world_size = parallel_state.get_tensor_model_parallel_world_size() 94 | else: 95 | self.tensor_parallel_world_size = 1 96 | 97 | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore 98 | self.count += target.numel() // self.tensor_parallel_world_size 99 | 100 | def compute(self) -> Tensor: 101 | return self.count 102 | 103 | def reset(self): 104 | count = self.count 105 | super().reset() 106 | self.count = count 107 | 108 | # Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py 109 | def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: 110 | """forward computation using single call to `update` to calculate the metric value on the current batch and 111 | accumulate global state. 112 | This can be done when the global metric state is a sinple reduction of batch states. 113 | """ 114 | self.update(*args, **kwargs) 115 | return self.compute() 116 | 117 | torchmetric_fns = { 118 | "perplexity": Perplexity, 119 | "num_tokens": NumTokens, 120 | } 121 | -------------------------------------------------------------------------------- /src/dataloaders/fault_tolerant_sampler.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397 2 | from typing import Iterator 3 | import math 4 | 5 | import torch 6 | from torch.utils.data import RandomSampler, DistributedSampler 7 | 8 | 9 | class RandomFaultTolerantSampler(RandomSampler): 10 | 11 | def __init__(self, *args, generator=None, **kwargs): 12 | # generator = torch.Generator().manual_seed(seed) 13 | # super().__init__(*args, generator=generator, **kwargs) 14 | # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed, 15 | # which should be reproducible if pl.seed_everything was called before hand. 16 | # This means that changing the seed of the experiment will also change the 17 | # sampling order. 
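        # Fault tolerance: state_dict()/load_state_dict() below checkpoint the
        # generator's RNG state together with a counter of indices already yielded;
        # after a restart, __iter__ regenerates the same permutation from the saved
        # RNG state and skips the first `counter` indices.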
18 | if generator is None: 19 | seed = int(torch.empty((), dtype=torch.int64).random_().item()) 20 | generator = torch.Generator().manual_seed(seed) 21 | super().__init__(*args, generator=generator, **kwargs) 22 | self.counter = 0 23 | # self.start_counter = 0 24 | self.restarting = False 25 | 26 | def state_dict(self): 27 | return {"random_state": self.state, "counter": self.counter} 28 | 29 | def load_state_dict(self, state_dict): 30 | self.generator.set_state(state_dict.get("random_state")) 31 | self.counter = state_dict["counter"] 32 | # self.start_counter = self.counter 33 | self.restarting = True 34 | 35 | # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per 36 | # epoch, and subsequent epoch will have very few batches. 37 | # def __len__(self): 38 | # # We need a separate self.start_counter because PL seems to call len repeatedly. 39 | # # If we use len(self.data_source) - self.counter then PL will think the epoch ends 40 | # # when we're only half way through. 41 | # return len(self.data_source) - self.start_counter 42 | 43 | def __iter__(self) -> Iterator[int]: 44 | n = len(self.data_source) 45 | 46 | self.state = self.generator.get_state() 47 | indices = torch.randperm(n, generator=self.generator).tolist() 48 | 49 | if not self.restarting: 50 | self.counter = 0 51 | else: 52 | indices = indices[self.counter:] 53 | self.restarting = False 54 | # self.start_counter = self.counter 55 | 56 | for index in indices: 57 | self.counter += 1 58 | yield index 59 | 60 | self.counter = 0 61 | # self.start_counter = self.counter 62 | 63 | 64 | class FaultTolerantDistributedSampler(DistributedSampler): 65 | 66 | def __init__(self, *args, **kwargs): 67 | super().__init__(*args, **kwargs) 68 | self.counter = 0 69 | # self.start_counter = 0 70 | self.restarting = False 71 | 72 | def state_dict(self): 73 | return {"epoch": self.epoch, "counter": self.counter} 74 | 75 | def load_state_dict(self, state_dict): 76 | self.epoch = state_dict["epoch"] 77 | self.counter = state_dict["counter"] 78 | # self.start_counter = self.counter 79 | self.restarting = True 80 | 81 | # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per 82 | # epoch, and subsequent epoch will have very few batches. 83 | # def __len__(self) -> int: 84 | # return self.num_samples - self.start_counter 85 | 86 | def __iter__(self): 87 | if self.shuffle: 88 | # deterministically shuffle based on epoch and seed 89 | g = torch.Generator() 90 | g.manual_seed(self.seed + self.epoch) 91 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] 92 | else: 93 | indices = list(range(len(self.dataset))) # type: ignore[arg-type] 94 | 95 | if not self.drop_last: 96 | # add extra samples to make it evenly divisible 97 | padding_size = self.total_size - len(indices) 98 | if padding_size <= len(indices): 99 | indices += indices[:padding_size] 100 | else: 101 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] 102 | else: 103 | # remove tail of data to make it evenly divisible. 
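            # (total_size, set by the parent DistributedSampler, equals
            # num_samples * num_replicas, so the strided subsampling below yields
            # exactly num_samples indices per rank.)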
104 | indices = indices[:self.total_size] 105 | assert len(indices) == self.total_size 106 | 107 | # subsample 108 | indices = indices[self.rank:self.total_size:self.num_replicas] 109 | assert len(indices) == self.num_samples 110 | 111 | if not self.restarting: 112 | self.counter = 0 113 | else: 114 | indices = indices[self.counter:] 115 | self.restarting = False 116 | # self.start_counter = self.counter 117 | 118 | for index in indices: 119 | self.counter += 1 120 | yield index 121 | 122 | self.counter = 0 123 | # self.start_counter = self.counter -------------------------------------------------------------------------------- /src/dataloaders/datasets/genomic_bench_dataset.py: -------------------------------------------------------------------------------- 1 | """Genomic Benchmarks Dataset. 2 | 3 | From: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks 4 | """ 5 | 6 | from pathlib import Path 7 | 8 | import torch 9 | from genomic_benchmarks.data_check import is_downloaded 10 | from genomic_benchmarks.loc2seq import download_dataset 11 | 12 | from src.dataloaders.utils.rc import coin_flip, string_reverse_complement 13 | 14 | 15 | class GenomicBenchmarkDataset(torch.utils.data.Dataset): 16 | """ 17 | Loop through bed file, retrieve (chr, start, end), query fasta file for sequence. 18 | Returns a generator that retrieves the sequence. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | split, 24 | max_length, 25 | dataset_name="human_nontata_promoters", 26 | d_output=2, # default binary classification 27 | dest_path=None, 28 | tokenizer=None, 29 | tokenizer_name=None, 30 | use_padding=None, 31 | add_eos=False, 32 | rc_aug=False, 33 | conjoin_train=False, 34 | conjoin_test=False, 35 | return_augs=False, 36 | return_mask=False, 37 | ): 38 | 39 | self.max_length = max_length 40 | self.use_padding = use_padding 41 | self.tokenizer_name = tokenizer_name 42 | self.tokenizer = tokenizer 43 | self.return_augs = return_augs 44 | self.add_eos = add_eos 45 | self.d_output = d_output # needed for decoder to grab 46 | assert not (conjoin_train and conjoin_test), "conjoin_train and conjoin_test cannot both be True" 47 | if (conjoin_train or conjoin_test) and rc_aug: 48 | print("When using conjoin, we turn off rc_aug.") 49 | rc_aug = False 50 | self.rc_aug = rc_aug 51 | self.conjoin_train = conjoin_train 52 | self.conjoin_test = conjoin_test 53 | self.return_mask = return_mask 54 | 55 | if not is_downloaded(dataset_name, cache_path=dest_path): 56 | print("downloading {} to {}".format(dataset_name, dest_path)) 57 | download_dataset(dataset_name, version=0, dest_path=dest_path) 58 | else: 59 | print("already downloaded {}-{}".format(split, dataset_name)) 60 | 61 | self.split = split 62 | 63 | # use Path object 64 | base_path = Path(dest_path) / dataset_name / split 65 | 66 | self.all_seqs = [] 67 | self.all_labels = [] 68 | label_mapper = {} 69 | 70 | for i, x in enumerate(sorted(base_path.iterdir())): 71 | label_mapper[x.stem] = i 72 | 73 | for label_type in label_mapper.keys(): 74 | for path in (base_path / label_type).iterdir(): 75 | with open(path, "r") as f: 76 | content = f.read() 77 | self.all_seqs.append(content) 78 | self.all_labels.append(label_mapper[label_type]) 79 | 80 | def __len__(self): 81 | return len(self.all_labels) 82 | 83 | def __getitem__(self, idx): 84 | x = self.all_seqs[idx] 85 | y = self.all_labels[idx] 86 | 87 | if (self.rc_aug or (self.conjoin_test and self.split == "train")) and coin_flip(): 88 | x = string_reverse_complement(x) # attach reverse complement to the head of 
original sequence 89 | 90 | seq = self.tokenizer( 91 | x, 92 | add_special_tokens=False, 93 | padding="max_length" if self.use_padding else None, 94 | max_length=self.max_length, 95 | truncation=True, 96 | ) 97 | seq_ids = seq["input_ids"] # get input_ids 98 | 99 | # need to handle eos here 100 | if self.add_eos: 101 | # append list seems to be faster than append tensor 102 | seq_ids.append(self.tokenizer.sep_token_id) 103 | 104 | if self.conjoin_train or (self.conjoin_test and self.split != "train"): 105 | x_rc = string_reverse_complement(x) 106 | seq_rc = self.tokenizer( 107 | x_rc, 108 | add_special_tokens=False, 109 | padding="max_length" if self.use_padding else None, 110 | max_length=self.max_length, 111 | truncation=True, 112 | ) 113 | seq_rc_ids = seq_rc["input_ids"] # get input_ids 114 | # need to handle eos here 115 | if self.add_eos: 116 | # append list seems to be faster than append tensor 117 | seq_rc_ids.append(self.tokenizer.sep_token_id) 118 | seq_ids = torch.stack((torch.LongTensor(seq_ids), torch.LongTensor(seq_rc_ids)), dim=1) 119 | 120 | else: 121 | # convert to tensor 122 | seq_ids = torch.LongTensor(seq_ids) 123 | 124 | # need to wrap in list 125 | target = torch.LongTensor([y]) 126 | 127 | # `seq` has shape: 128 | # - (seq_len,) if not conjoining 129 | # - (seq_len, 2) for conjoining 130 | if self.return_mask: 131 | return seq_ids, target, {"mask": torch.BoolTensor(seq["attention_mask"])} 132 | else: 133 | return seq_ids, target 134 | -------------------------------------------------------------------------------- /scripts/benchmark/gb/gb_janusdna.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # environment variables 4 | export HYDRA_FULL_ERROR=1 5 | export LD_LIBRARY_PATH=/janusdna/lib/python3.11/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH 6 | 7 | source /etc/profile.d/conda.sh 8 | conda activate janusdna 9 | 10 | cd 11 | 12 | PROJECT_ROOT_DIR="" 13 | 14 | # datasets 15 | DATASETS=( 16 | "dummy_mouse_enhancers_ensembl" 17 | "demo_coding_vs_intergenomic_seqs" 18 | "demo_human_or_worm" 19 | "human_enhancers_cohn" 20 | "human_enhancers_ensembl" 21 | "human_ensembl_regulatory" 22 | "human_nontata_promoters" 23 | "human_ocr_ensembl" 24 | ) 25 | 26 | LRS=("1e-3" "2e-3") 27 | BATCH_SIZES=(128 256 512) 28 | SEEDS=(1 2 3 4 5) 29 | NUM_GPUS=8 30 | 31 | # pretrained model path and config 32 | CONFIG_PATH=$(realpath "${PROJECT_ROOT_DIR}/outputs/pretrain/hg38/janusdna_len-1k_d_model-32_inter_dim-128_n_layer-8_lr-8e-3_step-10K_moeloss-true_1head_midattn/model_config.json") 33 | PRETRAINED_PATH=$(realpath "${PROJECT_ROOT_DIR}/outputs/pretrain/hg38/janusdna_len-1k_d_model-32_inter_dim-128_n_layer-8_lr-8e-3_step-10K_moeloss-true_1head_midattn/checkpoints/last.ckpt") 34 | 35 | # model parameters 36 | # name should be same as the pre-trained one 37 | DISPLAY_NAME="janusdna_len-1k_d_model-32_inter_dim-128_n_layer-8_lr-8e-3_step-10K_moeloss-true_1head_midattn" 38 | MODEL="janusdna" 39 | MODEL_NAME="dna_embedding_janusdna" 40 | RC_AUG="false" 41 | CONJOIN_TRAIN_DECODER="false" 42 | CONJOIN_TEST="true" 43 | 44 | HYDRA_RUN_DIR="${PROJECT_ROOT_DIR}/outputs/gb" 45 | 46 | LOG_DIR="${PROJECT_ROOT_DIR}/watch_folder/gb/${DISPLAY_NAME}" 47 | mkdir -p "${LOG_DIR}" 48 | 49 | # task queue 50 | declare -A GPU_TASKS 51 | 52 | 53 | run_task() { 54 | local task=$1 55 | local gpu_id=$2 56 | local lr=$3 57 | local batch_size=$4 58 | local seed=$5 59 | 60 | local WANDB_NAME="${DISPLAY_NAME}_LR-${lr}_BATCH_SIZE-${batch_size}" 61 | local 
hydra_run_dir="${HYDRA_RUN_DIR}/${WANDB_NAME}/${task}/seed-${seed}" 62 | mkdir -p "${hydra_run_dir}" 63 | 64 | local LOG_FILE="${LOG_DIR}/${task}_gpu-${gpu_id}_lr-${lr}_batch-${batch_size}_seed-${seed}.log" 65 | 66 | local WANDBID=$(python -c "import wandb; print(wandb.util.generate_id())") 67 | 68 | echo "Running Task: ${task}, GPU: ${gpu_id}, LR: ${lr}, BATCH_SIZE: ${batch_size}, SEED: ${seed}" 69 | echo "Logging to: ${LOG_FILE}" 70 | 71 | CUDA_VISIBLE_DEVICES=${gpu_id} nohup python -m train \ 72 | experiment=hg38/genomic_benchmark \ 73 | dataset.dataset_name="${task}" \ 74 | dataset.train_val_split_seed="${seed}" \ 75 | dataset.batch_size=${batch_size} \ 76 | dataset.rc_aug="${RC_AUG}" \ 77 | +dataset.conjoin_test="${CONJOIN_TEST}" \ 78 | optimizer.lr="${lr}" \ 79 | model="${MODEL}" \ 80 | model._name_="${MODEL_NAME}" \ 81 | +model.config_path="${CONFIG_PATH}" \ 82 | +model.conjoin_test="${CONJOIN_TEST}" \ 83 | +decoder.conjoin_train="${CONJOIN_TRAIN_DECODER}" \ 84 | +decoder.conjoin_test="${CONJOIN_TEST}" \ 85 | train.pretrained_model_path="${PRETRAINED_PATH}" \ 86 | trainer.max_epochs=10 \ 87 | trainer.devices=1 \ 88 | trainer.precision=bf16 \ 89 | wandb.mode="offline" \ 90 | wandb.group="downstream/gb/${task}" \ 91 | wandb.job_type="${task}" \ 92 | wandb.name="${WANDB_NAME}" \ 93 | wandb.id="${WANDBID}" \ 94 | +wandb.tags=\["seed-${seed}"\] \ 95 | hydra.run.dir="${hydra_run_dir}" \ 96 | > "${LOG_FILE}" 2>&1 & 97 | 98 | GPU_TASKS[${gpu_id}]=$! 99 | echo "Started task on GPU ${gpu_id}, PID: ${GPU_TASKS[${gpu_id}]}" 100 | sleep 5 # Reduce contention 101 | } 102 | 103 | 104 | schedule_tasks() { 105 | local tasks=() 106 | 107 | for dataset in "${DATASETS[@]}"; do 108 | for lr in "${LRS[@]}"; do 109 | for batch_size in "${BATCH_SIZES[@]}"; do 110 | for seed in "${SEEDS[@]}"; do 111 | tasks+=("${dataset} ${lr} ${batch_size} ${seed}") 112 | done 113 | done 114 | done 115 | done 116 | 117 | local total_tasks=${#tasks[@]} 118 | local task_index=0 119 | 120 | # initially launch tasks on available GPUs 121 | for ((gpu_id=0; gpu_id/dev/null; then 131 | echo "GPU ${gpu_id} finished, launching new task..." 132 | IFS=' ' read -r dataset lr batch_size seed <<< "${tasks[$task_index]}" 133 | run_task "${dataset}" "${gpu_id}" "${lr}" "${batch_size}" "${seed}" 134 | ((task_index++)) 135 | fi 136 | done 137 | sleep 5 # polling interval 138 | done 139 | 140 | wait 141 | echo "All tasks completed." 142 | } 143 | 144 | # Start task scheduling 145 | schedule_tasks 146 | 147 | 148 | -------------------------------------------------------------------------------- /caduceus/tokenization_caduceus.py: -------------------------------------------------------------------------------- 1 | """Character tokenizer for Hugging Face. 2 | 3 | """ 4 | 5 | from typing import List, Optional, Dict, Sequence, Tuple 6 | 7 | from transformers import PreTrainedTokenizer 8 | 9 | 10 | class CaduceusTokenizer(PreTrainedTokenizer): 11 | model_input_names = ["input_ids"] 12 | 13 | def __init__(self, 14 | model_max_length: int, 15 | characters: Sequence[str] = ("A", "C", "G", "T", "N"), 16 | complement_map=None, 17 | bos_token="[BOS]", 18 | eos_token="[SEP]", 19 | sep_token="[SEP]", 20 | cls_token="[CLS]", 21 | pad_token="[PAD]", 22 | mask_token="[MASK]", 23 | unk_token="[UNK]", 24 | **kwargs): 25 | """Character tokenizer for Hugging Face transformers. 26 | 27 | Adapted from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen-hf/blob/main/tokenization_hyena.py 28 | Args: 29 | model_max_length (int): Model maximum sequence length. 
30 | characters (Sequence[str]): List of desired characters. Any character which 31 | is not included in this list will be replaced by a special token called 32 | [UNK] with id=6. Following is a list of the special tokens with 33 | their corresponding ids: 34 | "[CLS]": 0 35 | "[SEP]": 1 36 | "[BOS]": 2 37 | "[MASK]": 3 38 | "[PAD]": 4 39 | "[RESERVED]": 5 40 | "[UNK]": 6 41 | an id (starting at 7) will be assigned to each character. 42 | complement_map (Optional[Dict[str, str]]): Dictionary with string complements for each character. 43 | """ 44 | if complement_map is None: 45 | complement_map = {"A": "T", "C": "G", "G": "C", "T": "A", "N": "N"} 46 | self.characters = characters 47 | self.model_max_length = model_max_length 48 | 49 | self._vocab_str_to_int = { 50 | "[CLS]": 0, 51 | "[SEP]": 1, 52 | "[BOS]": 2, 53 | "[MASK]": 3, 54 | "[PAD]": 4, 55 | "[RESERVED]": 5, 56 | "[UNK]": 6, 57 | **{ch: i + 7 for i, ch in enumerate(self.characters)}, 58 | } 59 | self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()} 60 | add_prefix_space = kwargs.pop("add_prefix_space", False) 61 | padding_side = kwargs.pop("padding_side", "left") 62 | 63 | self._complement_map = {} 64 | for k, v in self._vocab_str_to_int.items(): 65 | complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map.keys() else v 66 | self._complement_map[self._vocab_str_to_int[k]] = complement_id 67 | 68 | super().__init__( 69 | bos_token=bos_token, 70 | eos_token=eos_token, 71 | sep_token=sep_token, 72 | cls_token=cls_token, 73 | pad_token=pad_token, 74 | mask_token=mask_token, 75 | unk_token=unk_token, 76 | add_prefix_space=add_prefix_space, 77 | model_max_length=model_max_length, 78 | padding_side=padding_side, 79 | **kwargs, 80 | ) 81 | 82 | @property 83 | def vocab_size(self) -> int: 84 | return len(self._vocab_str_to_int) 85 | 86 | @property 87 | def complement_map(self) -> Dict[int, int]: 88 | return self._complement_map 89 | 90 | def _tokenize(self, text: str, **kwargs) -> List[str]: 91 | return list(text.upper()) # Convert all base pairs to uppercase 92 | 93 | def _convert_token_to_id(self, token: str) -> int: 94 | return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"]) 95 | 96 | def _convert_id_to_token(self, index: int) -> str: 97 | return self._vocab_int_to_str[index] 98 | 99 | def convert_tokens_to_string(self, tokens): 100 | return "".join(tokens) # Note: this operation has lost info about which base pairs were originally lowercase 101 | 102 | def get_special_tokens_mask( 103 | self, 104 | token_ids_0: List[int], 105 | token_ids_1: Optional[List[int]] = None, 106 | already_has_special_tokens: bool = False, 107 | ) -> List[int]: 108 | if already_has_special_tokens: 109 | return super().get_special_tokens_mask( 110 | token_ids_0=token_ids_0, 111 | token_ids_1=token_ids_1, 112 | already_has_special_tokens=True, 113 | ) 114 | 115 | result = ([0] * len(token_ids_0)) + [1] 116 | if token_ids_1 is not None: 117 | result += ([0] * len(token_ids_1)) + [1] 118 | return result 119 | 120 | def build_inputs_with_special_tokens( 121 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 122 | ) -> List[int]: 123 | sep = [self.sep_token_id] 124 | # cls = [self.cls_token_id] 125 | result = token_ids_0 + sep 126 | if token_ids_1 is not None: 127 | result += token_ids_1 + sep 128 | return result 129 | 130 | def get_vocab(self) -> Dict[str, int]: 131 | return self._vocab_str_to_int 132 | 133 | # Fixed vocabulary with no vocab file 134 | def save_vocabulary(self, 
save_directory: str, filename_prefix: Optional[str] = None) -> Tuple: 135 | return () 136 | -------------------------------------------------------------------------------- /scripts/benchmark/nt/nt_janusdna.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export HYDRA_FULL_ERROR=1 4 | export LD_LIBRARY_PATH=/janusdna/lib/python3.11/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH 5 | 6 | source /etc/profile.d/conda.sh 7 | conda activate janusdna 8 | 9 | cd 10 | 11 | PROJECT_ROOT_DIR="" 12 | 13 | 14 | # datasets 15 | DATASETS=( 16 | "H3K79me3" 17 | "enhancers" 18 | "enhancers_types" 19 | "H3" 20 | "H3K4me1" 21 | "H3K4me2" 22 | "H3K4me3" 23 | "H3K9ac" 24 | "H3K14ac" 25 | "H3K36me3" 26 | "H4" 27 | "H4ac" 28 | "promoter_all" 29 | "promoter_no_tata" 30 | "promoter_tata" 31 | "splice_sites_acceptors" 32 | "splice_sites_all" 33 | "splice_sites_donors" 34 | ) 35 | 36 | LRS=("1e-3" "2e-3") 37 | BATCH_SIZES=(128 256 512) 38 | SEEDS=(1 2 3 4 5 6 7 8 9 10) 39 | NUM_GPUS=8 40 | 41 | # pretrained model path and config 42 | CONFIG_PATH=$(realpath "${PROJECT_ROOT_DIR}/outputs/pretrain/hg38/janusdna_len-1k_d_model-32_inter_dim-128_n_layer-8_lr-8e-3_step-10K_moeloss-true_1head_midattn/model_config.json") 43 | PRETRAINED_PATH=$(realpath "${PROJECT_ROOT_DIR}/outputs/pretrain/hg38/janusdna_len-1k_d_model-32_inter_dim-128_n_layer-8_lr-8e-3_step-10K_moeloss-true_1head_midattn/checkpoints/last.ckpt") 44 | 45 | # model parameters 46 | # name should be same as the pre-trained one 47 | DISPLAY_NAME="janusdna_len-1k_d_model-32_inter_dim-128_n_layer-8_lr-8e-3_step-10K_moeloss-true_1head_midattn" 48 | MODEL="janusdna" 49 | MODEL_NAME="dna_embedding_janusdna" 50 | RC_AUG="false" 51 | CONJOIN_TRAIN_DECODER="false" 52 | CONJOIN_TEST="true" 53 | 54 | HYDRA_RUN_DIR="${PROJECT_ROOT_DIR}/outputs/nt" 55 | 56 | LOG_DIR="${PROJECT_ROOT_DIR}/watch_folder/nt/${DISPLAY_NAME}" 57 | mkdir -p "${LOG_DIR}" 58 | 59 | # task queue 60 | declare -A GPU_TASKS 61 | 62 | 63 | run_task() { 64 | local task=$1 65 | local gpu_id=$2 66 | local lr=$3 67 | local batch_size=$4 68 | local seed=$5 69 | 70 | local WANDB_NAME="${DISPLAY_NAME}_LR-${lr}_BATCH_SIZE-${batch_size}" 71 | local hydra_run_dir="${HYDRA_RUN_DIR}/${WANDB_NAME}/${task}/seed-${seed}" 72 | mkdir -p "${hydra_run_dir}" 73 | 74 | local LOG_FILE="${LOG_DIR}/${task}_gpu-${gpu_id}_lr-${lr}_batch-${batch_size}_seed-${seed}.log" 75 | 76 | local WANDBID=$(python -c "import wandb; print(wandb.util.generate_id())") 77 | 78 | echo "Running Task: ${task}, GPU: ${gpu_id}, LR: ${lr}, BATCH_SIZE: ${batch_size}, SEED: ${seed}" 79 | echo "Logging to: ${LOG_FILE}" 80 | 81 | CUDA_VISIBLE_DEVICES=${gpu_id} nohup python -m train \ 82 | experiment=hg38/nucleotide_transformer \ 83 | dataset.dataset_name="${task}" \ 84 | dataset.train_val_split_seed="${seed}" \ 85 | dataset.batch_size=${batch_size} \ 86 | dataset.rc_aug="${RC_AUG}" \ 87 | +dataset.conjoin_test="${CONJOIN_TEST}" \ 88 | optimizer.lr="${lr}" \ 89 | model="${MODEL}" \ 90 | model._name_="${MODEL_NAME}" \ 91 | +model.config_path="${CONFIG_PATH}" \ 92 | +model.conjoin_test="${CONJOIN_TEST}" \ 93 | +decoder.conjoin_train="${CONJOIN_TRAIN_DECODER}" \ 94 | +decoder.conjoin_test="${CONJOIN_TEST}" \ 95 | train.pretrained_model_path="${PRETRAINED_PATH}" \ 96 | trainer.max_epochs=20 \ 97 | trainer.devices=1 \ 98 | trainer.precision=bf16 \ 99 | wandb.mode="offline" \ 100 | wandb.group="downstream/nt/${task}" \ 101 | wandb.job_type="${task}" \ 102 | wandb.name="${WANDB_NAME}" \ 103 | 
wandb.id="${WANDBID}" \ 104 | +wandb.tags=\["seed-${seed}"\] \ 105 | +wandb.entity="cardiors" \ 106 | hydra.run.dir="${hydra_run_dir}" \ 107 | > "${LOG_FILE}" 2>&1 & 108 | 109 | GPU_TASKS[${gpu_id}]=$! 110 | echo "Started task on GPU ${gpu_id}, PID: ${GPU_TASKS[${gpu_id}]}" 111 | sleep 5 # Reduce contention 112 | } 113 | 114 | 115 | schedule_tasks() { 116 | local tasks=() 117 | 118 | for dataset in "${DATASETS[@]}"; do 119 | for lr in "${LRS[@]}"; do 120 | for batch_size in "${BATCH_SIZES[@]}"; do 121 | for seed in "${SEEDS[@]}"; do 122 | tasks+=("${dataset} ${lr} ${batch_size} ${seed}") 123 | done 124 | done 125 | done 126 | done 127 | 128 | local total_tasks=${#tasks[@]} 129 | local task_index=0 130 | 131 | # initially launch tasks on available GPUs 132 | for ((gpu_id=0; gpu_id/dev/null; then 142 | echo "GPU ${gpu_id} finished, launching new task..." 143 | IFS=' ' read -r dataset lr batch_size seed <<< "${tasks[$task_index]}" 144 | run_task "${dataset}" "${gpu_id}" "${lr}" "${batch_size}" "${seed}" 145 | ((task_index++)) 146 | fi 147 | done 148 | sleep 5 149 | done 150 | 151 | wait 152 | echo "All tasks completed." 153 | } 154 | 155 | # start the task scheduling 156 | schedule_tasks 157 | -------------------------------------------------------------------------------- /src/dataloaders/datasets/hg38_char_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | From: https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py 3 | 4 | CharacterTokenizer for Hugging Face Transformers. 5 | This is heavily inspired from CanineTokenizer in transformers package. 6 | """ 7 | import json 8 | import os 9 | from pathlib import Path 10 | from typing import Dict, List, Optional, Sequence, Union 11 | 12 | from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer 13 | 14 | 15 | class CharacterTokenizer(PreTrainedTokenizer): 16 | def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str = 'left', **kwargs): 17 | """Character tokenizer for Hugging Face transformers. 18 | Args: 19 | characters (Sequence[str]): List of desired characters. Any character which 20 | is not included in this list will be replaced by a special token called 21 | [UNK] with id=6. Following is the list of all the special tokens with 22 | their corresponding ids: 23 | "[CLS]": 0 24 | "[SEP]": 1 25 | "[BOS]": 2 26 | "[MASK]": 3 27 | "[PAD]": 4 28 | "[RESERVED]": 5 29 | "[UNK]": 6 30 | an id (starting at 7) will be assigned to each character. 31 | model_max_length (int): Model maximum sequence length. 
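        Example (illustrative): with characters=['A', 'C', 'G', 'T', 'N'], the string
            "ACGTN" maps to ids [7, 8, 9, 10, 11], and any character outside the list
            maps to 6 ([UNK]).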
32 | """ 33 | self.characters = characters 34 | self.model_max_length = model_max_length 35 | bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False) 36 | eos_token = AddedToken("[EOS]", lstrip=False, rstrip=False) 37 | sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False) 38 | cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False) 39 | pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False) 40 | unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False) 41 | 42 | mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False) 43 | 44 | self._vocab_str_to_int = { 45 | "[CLS]": 0, 46 | "[SEP]": 1, 47 | "[BOS]": 2, 48 | "[MASK]": 3, 49 | "[PAD]": 4, 50 | "[RESERVED]": 5, 51 | "[UNK]": 6, 52 | **{ch: i + 7 for i, ch in enumerate(characters)}, 53 | } 54 | self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()} 55 | 56 | # TODO: This should be a parameter passed to __init__ 57 | complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"} 58 | self.complement_map = {} 59 | for k, v in self._vocab_str_to_int.items(): 60 | complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map.keys() else v 61 | self.complement_map[self._vocab_str_to_int[k]] = complement_id 62 | super().__init__( 63 | bos_token=bos_token, 64 | eos_token=pad_token, 65 | sep_token=sep_token, 66 | cls_token=cls_token, 67 | pad_token=pad_token, 68 | mask_token=mask_token, 69 | unk_token=unk_token, 70 | add_prefix_space=False, 71 | model_max_length=model_max_length, 72 | padding_side=padding_side, 73 | **kwargs, 74 | ) 75 | 76 | @property 77 | def vocab_size(self) -> int: 78 | return len(self._vocab_str_to_int) 79 | 80 | def _tokenize(self, text: str) -> List[str]: 81 | return list(text) 82 | 83 | def _convert_token_to_id(self, token: str) -> int: 84 | return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"]) 85 | 86 | def _convert_id_to_token(self, index: int) -> str: 87 | return self._vocab_int_to_str[index] 88 | 89 | def convert_tokens_to_string(self, tokens): 90 | return "".join(tokens) 91 | 92 | def build_inputs_with_special_tokens( 93 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 94 | ) -> List[int]: 95 | sep = [self.sep_token_id] 96 | cls = [self.cls_token_id] 97 | result = cls + token_ids_0 + sep 98 | if token_ids_1 is not None: 99 | result += token_ids_1 + sep 100 | return result 101 | 102 | def get_special_tokens_mask( 103 | self, 104 | token_ids_0: List[int], 105 | token_ids_1: Optional[List[int]] = None, 106 | already_has_special_tokens: bool = False, 107 | ) -> List[int]: 108 | if already_has_special_tokens: 109 | return super().get_special_tokens_mask( 110 | token_ids_0=token_ids_0, 111 | token_ids_1=token_ids_1, 112 | already_has_special_tokens=True, 113 | ) 114 | 115 | result = [1] + ([0] * len(token_ids_0)) + [1] 116 | if token_ids_1 is not None: 117 | result += ([0] * len(token_ids_1)) + [1] 118 | return result 119 | 120 | def get_vocab(self) -> Dict[str, int]: 121 | return self._vocab_str_to_int 122 | 123 | def create_token_type_ids_from_sequences( 124 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 125 | ) -> List[int]: 126 | sep = [self.sep_token_id] 127 | cls = [self.cls_token_id] 128 | 129 | result = len(cls + token_ids_0 + sep) * [0] 130 | if token_ids_1 is not None: 131 | result += len(token_ids_1 + sep) * [1] 132 | return result 133 | 134 | def get_config(self) -> Dict: 135 | return { 136 | "char_ords": [ord(ch) for ch in self.characters], 137 | "model_max_length": 
self.model_max_length, 138 | } 139 | 140 | @classmethod 141 | def from_config(cls, config: Dict) -> "CharacterTokenizer": 142 | cfg = {} 143 | cfg["characters"] = [chr(i) for i in config["char_ords"]] 144 | cfg["model_max_length"] = config["model_max_length"] 145 | return cls(**cfg) 146 | 147 | def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): 148 | cfg_file = Path(save_directory) / "tokenizer_config.json" 149 | cfg = self.get_config() 150 | with open(cfg_file, "w") as f: 151 | json.dump(cfg, f, indent=4) 152 | 153 | @classmethod 154 | def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs): 155 | cfg_file = Path(save_directory) / "tokenizer_config.json" 156 | with open(cfg_file) as f: 157 | cfg = json.load(f) 158 | return cls.from_config(cfg) 159 | -------------------------------------------------------------------------------- /src/utils/train.py: -------------------------------------------------------------------------------- 1 | """ Utils for the training loop. 2 | 3 | Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py 4 | """ 5 | 6 | import json 7 | import logging 8 | import warnings 9 | 10 | import rich.syntax 11 | import rich.tree 12 | import torch.nn as nn 13 | from omegaconf import DictConfig, OmegaConf 14 | from pytorch_lightning.utilities import rank_zero_only 15 | 16 | from src.utils.config import omegaconf_filter_keys 17 | 18 | 19 | # Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging 20 | class LoggingContext: 21 | def __init__(self, logger, level=None, handler=None, close=True): 22 | self.logger = logger 23 | self.level = level 24 | self.handler = handler 25 | self.close = close 26 | 27 | def __enter__(self): 28 | if self.level is not None: 29 | self.old_level = self.logger.level 30 | self.logger.setLevel(self.level) 31 | if self.handler: 32 | self.logger.addHandler(self.handler) 33 | 34 | def __exit__(self, et, ev, tb): 35 | if self.level is not None: 36 | self.logger.setLevel(self.old_level) 37 | if self.handler: 38 | self.logger.removeHandler(self.handler) 39 | if self.handler and self.close: 40 | self.handler.close() 41 | # implicit return of None => don't swallow exceptions 42 | 43 | 44 | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: 45 | """Initializes multi-GPU-friendly python logger.x""" 46 | 47 | logger = logging.getLogger(name) 48 | logger.setLevel(level) 49 | 50 | # this ensures all logging levels get marked with the rank zero decorator 51 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 52 | for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): 53 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 54 | 55 | return logger 56 | 57 | 58 | def process_config(config: DictConfig) -> DictConfig: # TODO because of filter_keys, this is no longer in place 59 | """A couple of optional utilities, controlled by main config file: 60 | - disabling warnings 61 | - easier access to debug mode 62 | - forcing debug friendly configuration 63 | Modifies DictConfig in place. 64 | Args: 65 | config (DictConfig): Configuration composed by Hydra. 
66 | """ 67 | log = get_logger() 68 | 69 | # Filter out keys that were used just for interpolation 70 | config = omegaconf_filter_keys(config, lambda k: not k.startswith('__')) 71 | 72 | # enable adding new keys to config 73 | OmegaConf.set_struct(config, False) 74 | 75 | # disable python warnings if 76 | if config.get("ignore_warnings"): 77 | log.info("Disabling python warnings! ") 78 | warnings.filterwarnings("ignore") 79 | 80 | if config.get("debug"): 81 | log.info("Running in debug mode! ") 82 | config.trainer.fast_dev_run = True 83 | 84 | # force debugger friendly configuration 85 | log.info("Forcing debugger friendly configuration! ") 86 | # Debuggers don't like GPUs or multiprocessing 87 | if config.trainer.get("gpus"): 88 | config.trainer.gpus = 0 89 | if config.loader.get("pin_memory"): 90 | config.loader.pin_memory = False 91 | if config.loader.get("num_workers"): 92 | config.loader.num_workers = 0 93 | 94 | # disable adding new keys to config 95 | # OmegaConf.set_struct(config, True) # [21-09-17 AG] I need this for .pop(_name_) pattern among other things 96 | 97 | return config 98 | 99 | 100 | @rank_zero_only 101 | def print_config( 102 | config: DictConfig, 103 | resolve: bool = True, 104 | save_cfg=True, 105 | ) -> None: 106 | """Prints content of DictConfig using Rich library and its tree structure. 107 | Args: 108 | config (DictConfig): Configuration composed by Hydra. 109 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 110 | save_cfg (bool, optional): Whether to save the config to a file. 111 | """ 112 | 113 | style = "dim" 114 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 115 | 116 | fields = config.keys() 117 | for field in fields: 118 | branch = tree.add(field, style=style, guide_style=style) 119 | 120 | config_section = config.get(field) 121 | branch_content = str(config_section) 122 | if isinstance(config_section, DictConfig): 123 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 124 | 125 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 126 | 127 | rich.print(tree) 128 | 129 | if save_cfg: 130 | with open("config_tree.txt", "w") as fp: 131 | rich.print(tree, file=fp) 132 | with open("model_config.json", "w") as fp: # Save config / model config for use in fine-tuning or testing 133 | model_config = { 134 | k: v 135 | for k, v in OmegaConf.to_container(config.model, resolve=True).items() 136 | if not k.startswith("_") or k == "config_path" 137 | } 138 | json.dump(model_config, fp, indent=4) 139 | with open("config.json", "w") as fp: 140 | json.dump(OmegaConf.to_container(config, resolve=True), fp, indent=4) 141 | 142 | 143 | def log_optimizer(logger, optimizer, keys): 144 | """ Log values of particular keys from the optimizers param groups """ 145 | keys = sorted(keys) 146 | for i, g in enumerate(optimizer.param_groups): 147 | group_hps = {k: g.get(k, None) for k in keys} 148 | logger.info(' | '.join([ 149 | f"Optimizer group {i}", 150 | f"{len(g['params'])} tensors", 151 | ] + [f"{k} {v}" for k, v in group_hps.items()])) 152 | 153 | 154 | class OptimModule(nn.Module): 155 | """ Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """ 156 | 157 | def register(self, name, tensor, lr=None, wd=0.0): 158 | """Register a tensor with a configurable learning rate and 0 weight decay""" 159 | 160 | if lr == 0.0: 161 | self.register_buffer(name, tensor) 162 | else: 163 | self.register_parameter(name, nn.Parameter(tensor)) 164 | 165 | optim = {} 166 
| if lr is not None: 167 | optim["lr"] = lr 168 | if wd is not None: 169 | optim["weight_decay"] = wd 170 | setattr(getattr(self, name), "_optim", optim) 171 | -------------------------------------------------------------------------------- /janusdna.yml: -------------------------------------------------------------------------------- 1 | name: janusdna 2 | channels: 3 | - pytorch 4 | - nvidia 5 | # - defaults due to limited license of conda 6 | - conda-forge 7 | dependencies: 8 | # system 9 | - pip=25.0 10 | - cuda-nvcc=11.7.99=0 11 | - python=3.11.9 12 | 13 | # torch 14 | - pytorch=2.5.0 15 | - torchvision=0.20.0 16 | - torchaudio=2.5.0 17 | - pytorch-cuda=12.1 18 | - ruamel.yaml 19 | 20 | - pip: 21 | - numpy==1.26.0 22 | - absl-py==2.1.0 23 | - aiohappyeyeballs==2.4.6 24 | - aiohttp==3.11.12 25 | - aiosignal==1.3.2 26 | - annotated-types==0.7.0 27 | - antlr4-python3-runtime==4.9.3 28 | - anyio==4.6.2.post1 29 | - argcomplete==3.5.1 30 | - argh==0.31.3 31 | - argon2-cffi==23.1.0 32 | - argon2-cffi-bindings==21.2.0 33 | - arrow==1.3.0 34 | - asttokens==3.0.0 35 | - astunparse==1.6.3 36 | - async-lru==2.0.4 37 | - attrs 38 | - autocommand==2.2.2 39 | - babel==2.16.0 40 | - backcall==0.2.0 41 | - backports-tarfile==1.2.0 42 | - beautifulsoup4==4.13.3 43 | - binaryornot==0.4.4 44 | - biopython==1.85 45 | - bleach==6.2.0 46 | - cached-property==2.0.1 47 | - cachetools==5.5.0 48 | - causal-conv1d==1.5.0.post8 49 | - certifi==2025.1.31 50 | - cffi==1.17.1 51 | - chardet==5.2.0 52 | - charset-normalizer==3.4.0 53 | - clang==5.0 54 | - click==8.1.8 55 | - colorlog==6.9.0 56 | - comm==0.2.2 57 | - contourpy==1.3.1 58 | - cookiecutter==2.6.0 59 | - cycler==0.12.1 60 | - cython==3.0.11 61 | - datasets==3.4.1 62 | - debugpy==1.8.8 63 | - decorator==5.1.1 64 | - defusedxml==0.7.1 65 | - deprecation==2.1.0 66 | - dill==0.3.8 67 | - docker-pycreds==0.4.0 68 | - docopt==0.6.2 69 | - einops==0.8.0 70 | - exceptiongroup==1.2.2 71 | - executing==2.1.0 72 | - filelock==3.17.0 73 | - flash-attn==2.7.4.post1 74 | - flatbuffers==24.3.25 75 | - fonttools==4.56.0 76 | - fqdn==1.5.1 77 | - frozenlist==1.5.0 78 | - fsspec==2024.12.0 79 | - future==1.0.0 80 | - gast==0.6.0 81 | - gdown==5.2.0 82 | - genomic-benchmarks==0.0.9 83 | - gffutils==0.13 84 | - gitdb==4.0.12 85 | - gitpython==3.1.44 86 | - gmpy2==2.2.1 87 | - google-auth==2.36.0 88 | - google-auth-oauthlib==1.0.0 89 | - google-pasta==0.2.0 90 | - grpcio==1.68.0 91 | - h11==0.14.0 92 | - h5py==3.13.0 93 | - httpcore==1.0.7 94 | - httpx==0.27.2 95 | - huggingface-hub==0.28.1 96 | - hydra-core==1.3.2 97 | - idna==3.10 98 | - importlib-metadata==8.6.1 99 | - importlib-resources==6.5.2 100 | - inflect==7.3.1 101 | - ipykernel==6.29.5 102 | - ipython==8.12.3 103 | - ipywidgets==8.1.5 104 | - isoduration==20.11.0 105 | - jaraco-collections==5.1.0 106 | - jaraco-context==5.3.0 107 | - jaraco-functools==4.0.1 108 | - jaraco-text==3.12.1 109 | - jax==0.4.30 110 | - jaxlib==0.4.30 111 | - jedi==0.19.2 112 | - jinja2==3.1.6 113 | - joblib==1.4.2 114 | - json5==0.9.28 115 | - jsonpointer==3.0.0 116 | - keras==3.9.0 117 | - keras-nightly==2.5.0.dev2021032900 118 | - keras-preprocessing==1.1.2 119 | - kipoi 120 | - kipoi-conda 121 | - kipoi-utils 122 | - kipoiseq 123 | - kiwisolver==1.4.8 124 | - libclang==18.1.1 125 | - liftover==1.3.2 126 | - lightning-utilities==0.11.8 127 | - markdown==3.7 128 | - markdown-it-py==3.0.0 129 | - markupsafe==3.0.2 130 | - matplotlib==3.10.1 131 | - matplotlib-inline==0.1.7 132 | - mdurl==0.1.2 133 | - mistune==3.1.2 134 | - 
ml-dtypes==0.5.1 135 | - more-itertools==10.3.0 136 | - mpmath==1.3.0 137 | - multidict==6.1.0 138 | - multiprocess==0.70.16 139 | - namex==0.0.8 140 | - natsort==8.4.0 141 | - nbclient==0.10.0 142 | - nbconvert==7.16.6 143 | - ncls==0.0.68 144 | - nest-asyncio==1.6.0 145 | - networkx==3.4.2 146 | - ninja==1.11.1.3 147 | - notebook==7.2.2 148 | - notebook-shim==0.2.4 149 | - oauthlib==3.2.2 150 | - omegaconf==2.3.0 151 | - opt-einsum==3.4.0 152 | - optree==0.14.1 153 | - overrides==7.7.0 154 | - packaging==24.2 155 | - pandas==2.1.4 156 | - pandocfilters==1.5.1 157 | - parso==0.8.4 158 | - patsy==1.0.1 159 | - pexpect==4.9.0 160 | - pickleshare==0.7.5 161 | - pillow==11.1.0 162 | - pipreqs==0.5.0 163 | - platformdirs==4.2.2 164 | - plotly==6.0.1 165 | - polars==1.25.2 166 | - prometheus-client==0.21.0 167 | - prompt-toolkit==3.0.50 168 | - propcache==0.2.1 169 | - protobuf==3.20.3 170 | - psutil==6.1.0 171 | - ptyprocess==0.7.0 172 | - pure-eval==0.2.3 173 | - pyarrow==19.0.0 174 | - pyasn1==0.6.1 175 | - pyasn1-modules==0.4.1 176 | - pybigwig==0.3.23 177 | - pycparser==2.22 178 | - pydantic==2.10.6 179 | - pydantic-core==2.27.2 180 | - pyfaidx==0.8.1.3 181 | - pygments==2.19.1 182 | - pyparsing==3.2.1 183 | - pyranges==0.1.4 184 | - pysocks==1.7.1 185 | - pytabix==0.1 186 | - python-dateutil==2.9.0.post0 187 | - python-json-logger==2.0.7 188 | - python-slugify==8.0.4 189 | - pytorch-lightning==2.5.0.post0 190 | - pytz==2025.1 191 | - pyvcf3==1.0.3 192 | - pyyaml==6.0.2 193 | - pyzmq==26.2.0 194 | - referencing 195 | - regex==2024.11.6 196 | - related==0.7.3 197 | - requests==2.32.3 198 | - requests-oauthlib==2.0.0 199 | - rfc3339-validator==0.1.4 200 | - rfc3986-validator==0.1.1 201 | - rich==13.9.4 202 | - rpds-py==0.23.1 203 | - rsa==4.9 204 | - safetensors==0.5.2 205 | - scikit-learn==1.6.1 206 | - scipy==1.15.2 207 | - seaborn==0.13.2 208 | - send2trash==1.8.3 209 | - sentry-sdk==2.20.0 210 | - setproctitle==1.3.4 211 | - simplejson==3.20.1 212 | - six==1.17.0 213 | - smmap==5.0.2 214 | - sniffio==1.3.1 215 | - sorted-nearest==0.0.39 216 | - soupsieve==2.6 217 | - stack-data==0.6.3 218 | - statsmodels==0.14.4 219 | - sympy==1.13.1 220 | - tabulate==0.9.0 221 | - tenacity==9.0.0 222 | - tensorboard==2.19.0 223 | - tensorboard-data-server==0.7.2 224 | - tensorflow==2.19.0 225 | - tensorflow-io-gcs-filesystem==0.37.1 226 | - termcolor==1.1.0 227 | - terminado==0.18.1 228 | - text-unidecode==1.3 229 | - threadpoolctl==3.5.0 230 | - timm==1.0.14 231 | - tinycss2==1.4.0 232 | - tinydb==4.8.2 233 | - tokenizers==0.21.0 234 | - tomli==2.0.1 235 | - torch==2.5.0 236 | - torchaudio==2.5.0 237 | - torchmetrics==1.6.1 238 | - torchvision==0.20.0 239 | - tornado==6.4.1 240 | - tqdm==4.67.1 241 | - traitlets==5.14.3 242 | - transformers==4.49.0 243 | - triton==3.1.0 244 | - typeguard==4.3.0 245 | - types-python-dateutil==2.9.0.20241003 246 | - typing-extensions==4.12.2 247 | - tzdata==2025.1 248 | - unicodedata2==16.0.0 249 | - uri-template==1.3.0 250 | - urllib3==2.3.0 251 | - wandb==0.19.6 252 | - wcwidth==0.2.13 253 | - webcolors==24.11.1 254 | - webencodings==0.5.1 255 | - websocket-client==1.8.0 256 | - werkzeug==3.1.3 257 | - widgetsnbextension==4.0.13 258 | - wrapt==1.14.0 259 | - xxhash==3.5.0 260 | - yarg==0.1.9 261 | - yarl==1.18.3 262 | - zipp==3.21.0 -------------------------------------------------------------------------------- /src/utils/optim_groups.py: -------------------------------------------------------------------------------- 1 | """Utilities for special optimizer 
hyperparameters. 2 | 3 | group_parameters_for_optimizer is a modification of timm's optimizer logic, which is currently unused 4 | add_optimizer_hooks is an improved version that uses this codebase's _optim dictionary 5 | """ 6 | 7 | import inspect 8 | 9 | import torch.nn as nn 10 | 11 | import hydra 12 | 13 | 14 | def add_optimizer_hooks( 15 | model, 16 | bias_weight_decay=False, 17 | normalization_weight_decay=False, 18 | ): 19 | """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with 20 | attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for 21 | normalization parameters if normalization_weight_decay==False 22 | """ 23 | 24 | # Separate out all parameters to those that will and won't experience regularizing weight decay 25 | blacklist_weight_modules = (nn.Embedding, ) 26 | if not normalization_weight_decay: 27 | blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, 28 | # Not compatible with Pytorch 1.8.1 29 | # nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, 30 | nn.GroupNorm, nn.SyncBatchNorm, 31 | nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, 32 | nn.LayerNorm, nn.LocalResponseNorm) 33 | for mn, m in model.named_modules(): 34 | for pn, p in m.named_parameters(): 35 | if (not bias_weight_decay and pn.endswith('bias')) \ 36 | or getattr(p, '_no_weight_decay', False) \ 37 | or isinstance(m, blacklist_weight_modules): 38 | setattr(p, "_optim", {"weight_decay": 0.0}) 39 | 40 | 41 | def group_parameters_for_optimizer( 42 | model, 43 | optimizer_cfg, 44 | bias_weight_decay=False, 45 | normalization_weight_decay=False, 46 | ): 47 | """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with 48 | attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for 49 | normalization parameters if normalization_weight_decay==False 50 | """ 51 | # Get the weight decay from the config, or from the default value of the optimizer constructor 52 | # if it's not specified in the config. 53 | if 'weight_decay' in optimizer_cfg: 54 | weight_decay = optimizer_cfg.weight_decay 55 | else: 56 | # https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value 57 | signature = inspect.signature(hydra.utils.get_class(optimizer_cfg._target_)) 58 | if 'weight_decay' in signature.parameters: 59 | weight_decay = signature.parameters['weight_decay'].default 60 | if weight_decay is inspect.Parameter.empty: 61 | weight_decay = 0.0 62 | else: 63 | weight_decay = 0.0 64 | 65 | # If none of the parameters have weight decay anyway, and there are no parameters with special 66 | # optimization params 67 | if weight_decay == 0.0 and not any(hasattr(p, '_optim') for p in model.parameters()): 68 | return model.parameters() 69 | 70 | skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set() 71 | skip_keywords = (model.no_weight_decay_keywords() if hasattr(model, 'no_weight_decay_keywords') 72 | else set()) 73 | 74 | # Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134 75 | """ 76 | This long function is unfortunately doing something very simple and is being very defensive: 77 | We are separating out all parameters of the model into two buckets: those that will experience 78 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 79 | We are then returning the PyTorch optimizer object. 
80 | """ 81 | 82 | # separate out all parameters to those that will and won't experience regularizing weight decay 83 | decay = set() 84 | no_decay = set() 85 | special = set() 86 | whitelist_weight_modules = (nn.Linear, ) 87 | blacklist_weight_modules = (nn.Embedding, ) 88 | if not normalization_weight_decay: 89 | blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, 90 | # Not compatible with Pytorch 1.8.1 91 | # nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, 92 | nn.GroupNorm, nn.SyncBatchNorm, 93 | nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, 94 | nn.LayerNorm, nn.LocalResponseNorm) 95 | for mn, m in model.named_modules(): 96 | for pn, p in m.named_parameters(): 97 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 98 | if not p.requires_grad: 99 | continue # frozen weights 100 | if hasattr(p, '_optim'): 101 | special.add(fpn) 102 | elif fpn in skip or any(skip_keyword in fpn for skip_keyword in skip_keywords): 103 | no_decay.add(fpn) 104 | elif getattr(p, '_no_weight_decay', False): 105 | no_decay.add(fpn) 106 | elif not bias_weight_decay and pn.endswith('bias'): 107 | no_decay.add(fpn) 108 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 109 | # weights of whitelist modules will be weight decayed 110 | decay.add(fpn) 111 | elif isinstance(m, blacklist_weight_modules): 112 | # weights of blacklist modules will NOT be weight decayed 113 | no_decay.add(fpn) 114 | 115 | param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} 116 | # special case the position embedding parameter in the root GPT module as not decayed 117 | if 'pos_emb' in param_dict: 118 | no_decay.add('pos_emb') 119 | 120 | # In case of parameter sharing, some parameters show up in decay but are not in param_dict.keys() 121 | decay &= param_dict.keys() 122 | decay |= (param_dict.keys() - no_decay - special) 123 | # validate that we considered every parameter 124 | inter_params = decay & no_decay 125 | union_params = decay | no_decay 126 | assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both decay/no_decay sets!" 127 | assert len(param_dict.keys() - special - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" 128 | 129 | if weight_decay == 0.0 or not no_decay: 130 | param_groups = [{"params": [param_dict[pn] for pn in sorted(list(no_decay | decay))], 131 | "weight_decay": weight_decay}] 132 | else: 133 | param_groups = [ 134 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, 135 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 136 | ] 137 | # Add parameters with special hyperparameters 138 | # Unique dicts 139 | hps = [dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special)] 140 | for hp in hps: 141 | params = [param_dict[pn] for pn in sorted(list(special)) if param_dict[pn]._optim == hp] 142 | param_groups.append({"params": params, **hp}) 143 | 144 | return param_groups 145 | -------------------------------------------------------------------------------- /src/dataloaders/base.py: -------------------------------------------------------------------------------- 1 | """ Datasets for core experimental results. 
2 | 3 | """ 4 | 5 | import os 6 | from functools import partial 7 | from pathlib import Path 8 | 9 | import torch 10 | 11 | 12 | # Default data path is environment variable or /data 13 | if (default_data_path := os.getenv("DATA_PATH")) is None: 14 | default_data_path = Path(__file__).parent.parent.parent.absolute() 15 | default_data_path = default_data_path / "data" 16 | else: 17 | default_data_path = Path(default_data_path).absolute() 18 | 19 | 20 | class DefaultCollateMixin: 21 | """Controls collating in the DataLoader 22 | 23 | The CollateMixin classes instantiate a dataloader by separating collate arguments with the rest of the dataloader 24 | arguments. Instantiations of this class should modify the callback functions as desired, and modify the collate_args 25 | list. The class then defines a _dataloader() method which takes in a DataLoader constructor and arguments, 26 | constructs a collate_fn based on the collate_args, and passes the rest of the arguments into the constructor. 27 | """ 28 | 29 | @classmethod 30 | def _collate_callback(cls, x, *args, **kwargs): 31 | """ 32 | Modify the behavior of the default _collate method. 33 | """ 34 | return x 35 | 36 | _collate_arg_names = [] 37 | 38 | @classmethod 39 | def _return_callback(cls, return_value, *args, **kwargs): 40 | """ 41 | Modify the return value of the collate_fn. 42 | Assign a name to each element of the returned tuple beyond the (x, y) pairs 43 | See InformerSequenceDataset for an example of this being used 44 | """ 45 | x, y, *z = return_value 46 | assert len(z) == len(cls._collate_arg_names), "Specify a name for each auxiliary data item returned by dataset" 47 | return x, y, {k: v for k, v in zip(cls._collate_arg_names, z)} 48 | 49 | @classmethod 50 | def _collate(cls, batch, *args, **kwargs): 51 | # From https://github.com/pyforch/pytorch/blob/master/torch/utils/data/_utils/collate.py 52 | elem = batch[0] 53 | if isinstance(elem, torch.Tensor): 54 | out = None 55 | if torch.utils.data.get_worker_info() is not None: 56 | # If we're in a background process, concatenate directly into a 57 | # shared memory tensor to avoid an extra copy 58 | numel = sum(x.numel() for x in batch) 59 | storage = elem.storage()._new_shared(numel) 60 | out = elem.new(storage) 61 | x = torch.stack(batch, dim=0, out=out) 62 | 63 | # Insert custom functionality into the collate_fn 64 | x = cls._collate_callback(x, *args, **kwargs) 65 | 66 | return x 67 | else: 68 | return torch.tensor(batch) 69 | 70 | @classmethod 71 | def _collate_fn(cls, batch, *args, **kwargs): 72 | """ 73 | Default collate function. 
74 | Generally accessed by the dataloader() methods to pass into torch DataLoader 75 | 76 | Arguments: 77 | batch: list of (x, y) pairs 78 | args, kwargs: extra arguments that get passed into the _collate_callback and _return_callback 79 | """ 80 | x, y, *z = zip(*batch) 81 | 82 | x = cls._collate(x, *args, **kwargs) 83 | y = cls._collate(y) 84 | z = [cls._collate(z_) for z_ in z] 85 | 86 | return_value = (x, y, *z) 87 | return cls._return_callback(return_value, *args, **kwargs) 88 | 89 | # List of loader arguments to pass into collate_fn 90 | collate_args = [] 91 | 92 | def _dataloader(self, dataset, **loader_args): 93 | collate_args = {k: loader_args[k] for k in loader_args if k in self.collate_args} 94 | loader_args = {k: loader_args[k] for k in loader_args if k not in self.collate_args} 95 | loader_cls = loader_registry[loader_args.pop("_name_", None)] 96 | return loader_cls( 97 | dataset=dataset, 98 | collate_fn=partial(self._collate_fn, **collate_args), 99 | **loader_args, 100 | ) 101 | 102 | 103 | # class SequenceDataset(LightningDataModule): 104 | # [21-09-10 AG] Subclassing LightningDataModule fails due to trying to access _has_setup_fit. No idea why. So we just 105 | # provide our own class with the same core methods as LightningDataModule (e.g. setup) 106 | class SequenceDataset(DefaultCollateMixin): 107 | registry = {} 108 | _name_ = NotImplementedError("Dataset must have shorthand name") 109 | 110 | # Since subclasses do not specify __init__ which is instead handled by this class 111 | # Subclasses can provide a list of default arguments which are automatically registered as attributes 112 | # TODO it might be possible to write this as a @dataclass, but it seems tricky to separate from the other features 113 | # of this class such as the _name_ and d_input/d_output 114 | @property 115 | def init_defaults(self): 116 | return {} 117 | 118 | # https://www.python.org/dev/peps/pep-0487/#subclass-registration 119 | def __init_subclass__(cls, **kwargs): 120 | super().__init_subclass__(**kwargs) 121 | cls.registry[cls._name_] = cls 122 | 123 | def __init__(self, _name_, data_dir=None, **dataset_cfg): 124 | assert _name_ == self._name_ 125 | self.data_dir = Path(data_dir).absolute() if data_dir is not None else None 126 | 127 | # Add all arguments to self 128 | init_args = self.init_defaults.copy() 129 | init_args.update(dataset_cfg) 130 | for k, v in init_args.items(): 131 | setattr(self, k, v) 132 | 133 | # The train, val, test datasets must be set by `setup()` 134 | self.dataset_train = self.dataset_val = self.dataset_test = None 135 | 136 | self.init() 137 | 138 | def init(self): 139 | """Hook called at end of __init__, override this instead of __init__""" 140 | pass 141 | 142 | def setup(self): 143 | """This method should set self.dataset_train, self.dataset_val, and self.dataset_test.""" 144 | raise NotImplementedError 145 | 146 | def split_train_val(self, val_split): 147 | """ 148 | Randomly split self.dataset_train into a new (self.dataset_train, self.dataset_val) pair. 
149 | """ 150 | train_len = int(len(self.dataset_train) * (1.0 - val_split)) 151 | self.dataset_train, self.dataset_val = torch.utils.data.random_split( 152 | self.dataset_train, 153 | (train_len, len(self.dataset_train) - train_len), 154 | generator=torch.Generator().manual_seed( 155 | getattr(self, "seed", 42) 156 | ), # PL is supposed to have a way to handle seeds properly, but doesn't seem to work for us 157 | ) 158 | 159 | def train_dataloader(self, **kwargs): 160 | """Return a DataLoader for the training dataset.""" 161 | return self._train_dataloader(self.dataset_train, **kwargs) 162 | 163 | def _train_dataloader(self, dataset, **kwargs): 164 | if dataset is None: 165 | return 166 | kwargs['shuffle'] = 'sampler' not in kwargs # shuffle cant be True if we have custom sampler 167 | return self._dataloader(dataset, **kwargs) 168 | 169 | def val_dataloader(self, **kwargs): 170 | """Return a DataLoader for the validation dataset.""" 171 | return self._eval_dataloader(self.dataset_val, **kwargs) 172 | 173 | def test_dataloader(self, **kwargs): 174 | """Return a DataLoader for the test dataset.""" 175 | return self._eval_dataloader(self.dataset_test, **kwargs) 176 | 177 | def _eval_dataloader(self, dataset, **kwargs): 178 | if dataset is None: 179 | return 180 | # Note that shuffle=False by default 181 | return self._dataloader(dataset, **kwargs) 182 | 183 | def __str__(self): 184 | return self._name_ 185 | 186 | 187 | # Registry for dataloader class 188 | loader_registry = { 189 | None: torch.utils.data.DataLoader, # default case 190 | } 191 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | <!-- banner image (alt text: "Caduceus") --> 3 | 

4 | 5 | # JanusDNA: A Powerful Bi-directional Hybrid DNA Foundation Model 6 | 7 | ## Update 8 | 9 | 10 | - **2025-08-29**: Pre-training weights for models with final MLPs are available [here](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi%3A10.7910%2FDVN%2FHDT0RN&version=DRAFT). 11 | - **2025-08-01**: Attaching two MLP layers after Janus fusion layer significantly improves model performance with minimal parameter scaling. 12 | 13 | `JanusDNA_mlp` are the latest ones. 14 | 15 | - Nucleotide Transformer (NT) 16 | 17 | 18 | 19 | | | JanusDNA | JanusDNA | JanusDNA_mlp | JanusDNA_mlp | 20 | | --- | --- | --- | --- | --- | 21 | | | w/ midattn | w/o midattn | w/ midattn; w/ rc | w/o midattn; w/ rc | 22 | | size(M) | 1.980 | 1.988 | 2.001 | 2.009 | 23 | | H3 | 0.821±0.021 | 0.824±0.012 | **0.835±0.009** | 0.831±0.023 | 24 | | H3K14ac | 0.665 ± 0.034 | 0.685±0.016 | **0.729±0.022** | 0.718±0.026 | 25 | | H3K36me3 | 0.658 ± 0.024 | 0.670±0.012 | **0.702±0.015** | 0.699±0.025 | 26 | | H3K4me1 | 0.563 ± 0.041 | 0.571±0.018 | 0.615±0.035 | **0.616±0.018** | 27 | | H3K4me2 | 0.509 ± 0.056 | 0.548±0.022 | **0.589±0.023** | 0.586±0.019 | 28 | | H3K4me3 | 0.605 ± 0.030 | 0.629±0.022 | **0.688±0.026** | 0.675±0.014 | 29 | | H3K79me3 | 0.716 ± 0.017 | 0.727±0.023  | **0.747±0.013** | 0.743±0.009 | 30 | | H3K9ac | 0.641 ± 0.024 | 0.639±0.019 | **0.673±0.014** | 0.661±0.027 | 31 | | H4 | 0.809 ± 0.021 | **0.816±0.008** | 0.812±0.011 | 0.813±0.013 | 32 | | H4ac | 0.637±0.060 | 0.653±0.034 | 0.698±0.013 | **0.705±0.023** | 33 | | enhancers | 0.564 ± 0.022 | 0.535±0.036 | **0.559±0.042** | 0.542±0.044 | 34 | | EnhancersTypes | 0.462±0.049 | 0.470±0.025 | **0.503±0.038** | 0.492±0.096 | 35 | | PromoterAll | 0.969±0.002 | **0.971±0.002** | 0.970±0.002 | 0.970±0.003 | 36 | | PromoterNoTata | 0.971±0.003 | **0.971±0.002** | 0.971±0.004 | 0.971±0.003 | 37 | | PromoterTata | 0.956±0.010 | 0.958±0.008 | 0.958±0.007 | **0.960±0.008** | 38 | | SpliceSitesAll | 0.963±0.022 | 0.960±0.009 | **0.967±0.005** | 0.943±0.020 | 39 | | SpliceSitesAcceptors | 0.949±0.020 | 0.939±0.022 | 0.957±0.012 | **0.961±0.009** | 40 | | SpliceSitesDonors | 0.947±0.015 | 0.936±0.014 | **0.948±0.008** | 0.935±0.016 | 41 | 42 | 43 | - DNALONGBENCH 44 | 45 | 46 | | | Caduceus-PH | JanusDNA; w/o midattn | JanusDNA_mlp; w/o midattn| 47 | | --- | --- | --- | --- | 48 | | size | 7.7M | 7.662 M | 7.745 M | 49 | | AT | 0.690 | 0.802 | **0.851** | 50 | | AS | 0.759 | 0.740 | **0.768** | 51 | | CCF | 0.689 | 0.770 | **0.801** | 52 | | MS | 0.789 | 0.803 | **0.864** | 53 | | NT | 0.841 | 0.877 | **0.913** | 54 | | SNSES | 0.812 | 0.874 | **0.903** | 55 | | SSELL | 0.691 | 0.706 | **0.845** | 56 | | Thyroid | 0.703 | 0.752 | **0.792** | 57 | | WB | 0.768 | 0.794 | **0.821** | 58 | 59 | 60 | 61 | - **2025-06-28**: Pretrained JanusDNA weights are now available for download [here](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi%3A10.7910%2FDVN%2FHDT0RN&version=DRAFT). 
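After downloading a pretrained checkpoint, you can sanity-check it before fine-tuning. This is a minimal sketch only, assuming the download is a standard PyTorch Lightning `.ckpt` file (the training stack here is PyTorch Lightning); `path/to/last.ckpt` is a placeholder, and key names may vary by Lightning version:

```python
# Minimal sketch: inspect a downloaded pretrained checkpoint before fine-tuning.
# Assumptions: standard PyTorch Lightning .ckpt layout; the path below is a placeholder.
import torch

ckpt = torch.load("path/to/last.ckpt", map_location="cpu")
print("epoch:", ckpt.get("epoch"), "| global_step:", ckpt.get("global_step"))

state_dict = ckpt["state_dict"]  # Lightning stores the model weights under this key
print(len(state_dict), "parameter tensors")
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))
```

Fine-tuning then points `train.pretrained_model_path` at the `.ckpt` file, as in the benchmark scripts (e.g. `scripts/benchmark/nt/nt_janusdna.sh`).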
62 | 63 | 64 | ## Getting Started 65 | 66 | 67 | To begin, create a conda environment with the required dependencies: 68 | 69 | ```bash 70 | conda env create -f janusdna.yml 71 | ``` 72 | 73 | Activate the environment: 74 | 75 | ```bash 76 | conda activate janusdna 77 | ``` 78 | 79 | Install Mamba: 80 | 81 | ```bash 82 | wget https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl 83 | pip install mamba_ssm-2.2.4+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl 84 | 85 | pip install selene-sdk --no-deps 86 | ``` 87 | 88 | ## Reproducing Experiments 89 | As described in the paper, there are four main experimental components: 90 | 1. Pretraining JanusDNA on the Human Reference Genome 91 | 2. GenomicBenchmarks 92 | 3. Nucleotide Transformer Datasets 93 | 4. DNALongBench 94 | 95 | ### Pretraining on the Human Reference Genome 96 | (Data downloading instructions adapted from the [HyenaDNA](https://github.com/HazyResearch/hyena-dna?tab=readme-ov-file#pretraining-on-human-reference-genome)) 97 | 98 | First, download the Human Reference Genome data, which consists of two files: a `.fasta` file containing all sequences, and a `.bed` file specifying the intervals used. 99 | 100 | The directory structure should be: 101 | 102 | ``` 103 | data 104 | |-- hg38/ 105 | |-- hg38.ml.fa 106 | |-- human-sequences.bed 107 | ``` 108 | 109 | Download the fasta (.fa) file for the entire human genome into `./data/hg38`. The genome contains approximately 24 chromosomes, merged into a single file. Then, download the .bed file with sequence intervals (chromosome name, start, end, split), which allows retrieval from the fasta file. 110 | 111 | ```bash 112 | mkdir -p data/hg38/ 113 | curl https://storage.googleapis.com/basenji_barnyard2/hg38.ml.fa.gz > data/hg38/hg38.ml.fa.gz 114 | gunzip data/hg38/hg38.ml.fa.gz # unzip the fasta file 115 | curl https://storage.googleapis.com/basenji_barnyard2/sequences_human.bed > data/hg38/human-sequences.bed 116 | ``` 117 | 118 | Run a pre-training script from [`scripts/pre_train/`](./scripts/pre_train): 119 | 120 | ``` 121 | |-- scripts 122 | |--pre_train 123 | |-- slurm_JanusDNA_w_midattn_32dim.sh 124 | |-- slurm_JanusDNA_w_midattn_72dim.sh 125 | |-- slurm_JanusDNA_wo_midattn_32dim.sh 126 | |-- slurm_JanusDNA_wo_midattn_72dim.sh 127 | |-- slurm_JanusDNA_wo_midattn_144dim.sh 128 | ``` 129 | 130 | For example: 131 | 132 | ```bash 133 | cd scripts/pre_train/ 134 | sbatch slurm_JanusDNA_w_midattn_32dim.sh 135 | ``` 136 | 137 | ### GenomicBenchmarks 138 | 139 | 140 | The [GenomicBenchmarks](https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks) suite, as presented in [Grešová et al. (2023)](https://bmcgenomdata.biomedcentral.com/articles/10.1186/s12863-023-01123-8), comprises eight classification tasks. 141 | 142 | You can launch fine-tuning with 5-fold cross-validation using [`gb_janusdna.sh`](./scripts/benchmark/gb/gb_janusdna.sh): 143 | 144 | ```bash 145 | bash scripts/benchmark/gb/gb_janusdna.sh 146 | ``` 147 | 148 | ### Nucleotide Transformer Datasets 149 | 150 | 151 | The Nucleotide Transformer suite of tasks was introduced in [Dalla-Torre et al. (2023)](https://www.biorxiv.org/content/10.1101/2023.01.11.523679v1). The data is available on HuggingFace: [InstaDeepAI/nucleotide_transformer_downstream_tasks](https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks). 
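If you want a quick look at the raw task data before launching fine-tuning, the tasks can be pulled directly with the `datasets` library that is already part of the environment. This is an illustrative sketch only (the fine-tuning pipeline loads the data itself); `"H3"` is one of the task names from the `DATASETS` list in `nt_janusdna.sh`, and depending on your `datasets` version you may need to pass `trust_remote_code=True`:

```python
# Quick inspection of one Nucleotide Transformer downstream task (sketch; not used by the training code).
from datasets import load_dataset

ds = load_dataset(
    "InstaDeepAI/nucleotide_transformer_downstream_tasks",
    "H3",  # any task name from the DATASETS list in nt_janusdna.sh
)
print(ds)                    # available splits and their sizes
print(ds["train"].features)  # column names and types
print(ds["train"][0])        # one raw example
```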
152 | 153 | You can launch fine-tuning with 10-fold cross-validation using [`nt_janusdna.sh`](./scripts/benchmark/nt/nt_janusdna.sh): 154 | 155 | ```bash 156 | bash scripts/benchmark/nt/nt_janusdna.sh 157 | ``` 158 | 159 | ### DNALongBench 160 | Download the dataset from the [dataset website](https://dataverse.harvard.edu/privateurl.xhtml?token=93d446a5-9c75-44bf-be1c-7622563c48d0) following the instructions in the [DNALongBench repository](https://github.com/wenduocheng/DNALongBench?tab=readme-ov-file). 161 | 162 | Place the eQTL dataset zip file in the [`data`](./data) directory and unzip it (replace the `<...>` placeholders with your local paths): 163 | 164 | ```bash 165 | mkdir -p data/dnalongbench 166 | mv <path-to-eQTL-dataset-zip> data/dnalongbench 167 | unzip data/dnalongbench/<eQTL-dataset-zip> -d data/dnalongbench 168 | 169 | cd data/dnalongbench/eQTL/seqs 170 | gunzip hg38.fa.gz 171 | ``` 172 | 173 | You can fine-tune on a specific cell-type dataset using [`eqtl_train_janus_8gpu.sh`](./scripts/benchmark/dnalong/eqtl_train_janus_8gpu.sh): 174 | 175 | ```bash 176 | sbatch scripts/benchmark/dnalong/eqtl_train_janus_8gpu.sh 177 | ``` 178 | 179 | After fine-tuning, evaluate the results on the corresponding test dataset using [`eqtl_evaluation_janus.sh`](./scripts/benchmark/dnalong/eqtl_evaluation_janus.sh): 180 | 181 | ```bash 182 | sbatch scripts/benchmark/dnalong/eqtl_evaluation_janus.sh 183 | ``` 184 | 185 | Evaluation output and all log files for fine-tuning and evaluation will be stored in the `watch_folder` directory. For example: 186 | 187 | ``` 188 | watch_folder 189 | |-- eQTL 190 | |-- janusdna_len-131k_d_model-144_inter_dim-576_n_layer-8_lr-8e-3_step-50K_moeloss-true_1head_onlymoe 191 | |-- Whole_Blood_lr-4e-4_cjtrain_false_batch_8_seed_1.log 192 | (fine-tuning log) 193 | |-- Whole_Blood_lr-4e-4_cjtrain_false_batch_8_seed_1_cjtest_true.log 194 | (evaluation log) 195 | |-- Whole_Blood_lr-4e-4_cjtrain_false_batch_4_seed_1_cjtest_true_output.txt 196 | (evaluation output) 197 | ``` 198 | 199 | To calculate AUROC based on the evaluation output, use the script [`auroc.py`](./evals/auroc.py). 200 | 201 | A script is also provided to calculate AUROC for all cell-type datasets at once, [`evaluate_auroc_janus.py`](./evals/evaluate_auroc_janus.py); a minimal reference sketch of the AUROC computation is included at the end of this document. 202 | 203 | 204 | ## Acknowledgements 205 | This repository is adapted from the [Caduceus](https://github.com/kuleshov-group/caduceus) repository and leverages much of the training, data loading, and logging infrastructure defined there. Caduceus was originally derived from the [HyenaDNA](https://github.com/HazyResearch/hyena-dna) codebase. 206 | We also acknowledge the contributions of [`Jamba-v0.1`](https://huggingface.co/ai21labs/Jamba-v0.1/tree/main), which provided the initial codebase for hybrid architectures. 207 | --------------------------------------------------------------------------------
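For reference, the AUROC metric mentioned in the DNALongBench section above can be computed with scikit-learn's `roc_auc_score` (scikit-learn is already part of the environment). This is a sketch only, not the repository's script: `evals/auroc.py` handles parsing the actual `*_output.txt` evaluation files, and the labels and scores below are illustrative placeholders.

```python
# Reference AUROC computation (sketch). evals/auroc.py parses the real evaluation output;
# y_true / y_score below are illustrative placeholders.
from sklearn.metrics import roc_auc_score

y_true = [0, 1, 1, 0, 1]                  # ground-truth binary labels
y_score = [0.10, 0.82, 0.64, 0.31, 0.90]  # predicted probability of the positive class
print(f"AUROC: {roc_auc_score(y_true, y_score):.3f}")
```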