├── convnova
    ├── src
    │   ├── models
    │   │   ├── LegNet
    │   │   │   └── __init__.py
    │   │   ├── NTV2
    │   │   │   ├── __init__.py
    │   │   │   └── ntv2.py
    │   │   ├── nn
    │   │   │   ├── __init__.py
    │   │   │   ├── residual.py
    │   │   │   ├── utils.py
    │   │   │   └── gate.py
    │   │   ├── sequence
    │   │   │   ├── __init__.py
    │   │   │   ├── ff.py
    │   │   │   ├── long_conv_kernel.py
    │   │   │   ├── block.py
    │   │   │   ├── base.py
    │   │   │   └── model.py
    │   │   └── basenji2
    │   │   │   ├── params_human.json
    │   │   │   └── basenji2.py
    │   ├── dataloaders
    │   │   ├── __init__.py
    │   │   ├── datasets
    │   │   │   ├── lm_dataset.py
    │   │   │   ├── hg38_char_tokenizer.py
    │   │   │   └── icl_genomics_dataset.py
    │   │   └── fault_tolerant_sampler.py
    │   ├── utils
    │   │   ├── __init__.py
    │   │   ├── registry.py
    │   │   ├── optim
    │   │   │   └── schedulers.py
    │   │   ├── config.py
    │   │   ├── profiling.py
    │   │   ├── distributed.py
    │   │   └── train.py
    │   ├── callbacks
    │   │   ├── params.py
    │   │   ├── norms.py
    │   │   ├── gpu_affinity.py
    │   │   ├── timer.py
    │   │   └── progressive_resizing.py
    │   ├── tasks
    │   │   └── torchmetrics.py
    │   └── ops
    │   │   ├── fftconv.py
    │   │   ├── vandermonde.py
    │   │   ├── toeplitz.py
    │   │   └── krylov.py
    ├── configs
    │   ├── model
    │   │   ├── layer
    │   │   │   ├── id.yaml
    │   │   │   ├── ff.yaml
    │   │   │   ├── h3-conv.yaml
    │   │   │   ├── mha_dna.yaml
    │   │   │   ├── mha.yaml
    │   │   │   ├── hyena.yaml
    │   │   │   ├── hyena-filter.yaml
    │   │   │   └── hyena_dna.yaml
    │   │   └── transformer.yaml
    │   ├── task
    │   │   ├── lm.yaml
    │   │   ├── regression.yaml
    │   │   ├── masked_multiclass_classification.yaml
    │   │   ├── multiclass_classification.yaml
    │   │   └── multilabel_classification.yaml
    │   ├── callbacks
    │   │   ├── gpu_affinity.yaml
    │   │   ├── rich.yaml
    │   │   ├── base.yaml
    │   │   ├── wandb.yaml
    │   │   └── checkpoint.yaml
    │   ├── loader
    │   │   └── default.yaml
    │   ├── scheduler
    │   │   ├── constant.yaml
    │   │   ├── step.yaml
    │   │   ├── multistep.yaml
    │   │   ├── cosine_warmup.yaml
    │   │   ├── linear_warmup.yaml
    │   │   ├── constant_warmup.yaml
    │   │   ├── cosine.yaml
    │   │   ├── cosine_warmup_timm.yaml
    │   │   └── plateau.yaml
    │   ├── optimizer
    │   │   ├── adamw.yaml
    │   │   ├── sgd.yaml
    │   │   └── adam.yaml
    │   ├── experiment
    │   │   ├── base.yaml
    │   │   ├── hg38-pretrain
    │   │   │   ├── convNext.yaml
    │   │   │   ├── convnova.yaml
    │   │   │   ├── transformer.yaml
    │   │   │   ├── bert_hg38_hyena.yaml
    │   │   │   └── mamba.yaml
    │   │   ├── nt-benchmark
    │   │   │   ├── legnet.yaml
    │   │   │   ├── convnova.yaml
    │   │   │   └── hyena1.6M.yaml
    │   │   └── genomic-benchmark
    │   │   │   ├── legnet.yaml
    │   │   │   ├── basenji.yaml
    │   │   │   ├── convnova.yaml
    │   │   │   └── hyena.yaml
    │   ├── dataset
    │   │   ├── hg38_fixed_test.yaml
    │   │   ├── hg38.yaml
    │   │   ├── bert_hg38.yaml
    │   │   ├── dnabert2_pretrain.yaml
    │   │   ├── chromatin_profile.yaml
    │   │   ├── icl_hg38.yaml
    │   │   ├── deepsea.yaml
    │   │   ├── genomic_benchmark.yaml
    │   │   └── nucleotide_transformer.yaml
    │   ├── pipeline
    │   │   ├── hg38.yaml
    │   │   ├── bert_hg38.yaml
    │   │   ├── dnabert2_pretrain.yaml
    │   │   ├── species.yaml
    │   │   ├── chromatin_profile.yaml
    │   │   ├── genomic_benchmark.yaml
    │   │   ├── gue.yaml
    │   │   ├── deepstarr.yaml
    │   │   └── nucleotide_transformer.yaml
    │   ├── trainer
    │   │   ├── debug.yaml
    │   │   ├── default.yaml
    │   │   ├── lm.yaml
    │   │   └── full.yaml
    │   ├── evals
    │   │   ├── hyena_dna_512ksl.yaml
    │   │   ├── hg38.yaml
    │   │   ├── icl_genomics.yaml
    │   │   ├── instruction_tuned_genomics.yaml
    │   │   ├── soft_prompting_genomics.yaml
    │   │   └── hg38_decoder.yaml
    │   └── config.yaml
    └── requirements.txt
├── logo.jpg
└── README.md

/convnova/src/models/LegNet/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/convnova/src/models/NTV2/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
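The config files that follow select components with `_name_` keys, which are typically resolved against the tables in `src/utils/registry.py` by the `instantiate` helper exported from `src/utils/config.py` (see the utils `__init__.py` below). A minimal sketch of that pattern, using a toy registry with placeholder classes rather than the repo's actual mappings:

```python
# Sketch only: a toy name->class registry and an instantiate() helper in the
# style of src/utils/config.py; the real registry maps names to the classes
# under src/models, not to these placeholder torch modules.
import torch.nn as nn

layer_registry = {
    "mha": nn.MultiheadAttention,  # placeholder target, for illustration only
    "ff": nn.Linear,               # placeholder target, for illustration only
}

def instantiate(registry, config, **extra_kwargs):
    """Pop `_name_`, look up the target class, pass the remaining keys as kwargs."""
    kwargs = dict(config)
    cls = registry[kwargs.pop("_name_")]
    return cls(**kwargs, **extra_kwargs)

# A config with `_name_: mha` would then be built roughly like this:
layer = instantiate(layer_registry, {"_name_": "mha", "embed_dim": 128, "num_heads": 8})
```

The same `_name_` convention recurs in the task, optimizer, scheduler, and callback configs below.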
/convnova/configs/model/layer/id.yaml: -------------------------------------------------------------------------------- 1 | _name_: id 2 | -------------------------------------------------------------------------------- /logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aim-uofa/ConvNova/HEAD/logo.jpg -------------------------------------------------------------------------------- /convnova/src/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from . import et, genomics 2 | from .base import SequenceDataset 3 | -------------------------------------------------------------------------------- /convnova/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import is_list, is_dict, to_list, to_dict, get_class, instantiate 2 | -------------------------------------------------------------------------------- /convnova/src/models/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .components import LinearActivation, Activation, Normalization, DropoutNd 2 | -------------------------------------------------------------------------------- /convnova/configs/task/lm.yaml: -------------------------------------------------------------------------------- 1 | _name_: lm 2 | # loss: cross_entropy # Handled by task: cross entropy loss 3 | metrics: ppl 4 | -------------------------------------------------------------------------------- /convnova/configs/callbacks/gpu_affinity.yaml: -------------------------------------------------------------------------------- 1 | gpu_affinity: 2 | _name_: gpu_affinity 3 | # _target_: src.callbacks.gpu_affinity.GpuAffinity -------------------------------------------------------------------------------- /convnova/configs/task/regression.yaml: -------------------------------------------------------------------------------- 1 | # _target_: tasks.tasks.BaseTask 2 | _name_: base 3 | loss: mse 4 | metrics: mse 5 | torchmetrics: null 6 | -------------------------------------------------------------------------------- /convnova/src/models/sequence/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SequenceModule, TransposedModule 2 | from .model import SequenceModel 3 | from .ff import FF 4 | -------------------------------------------------------------------------------- /convnova/configs/task/masked_multiclass_classification.yaml: -------------------------------------------------------------------------------- 1 | _name_: masked_multiclass 2 | loss: cross_entropy 3 | metrics: 4 | - accuracy 5 | torchmetrics: null -------------------------------------------------------------------------------- /convnova/configs/model/layer/ff.yaml: -------------------------------------------------------------------------------- 1 | _name_: ff 2 | expand: 4 3 | dropout: null 4 | transposed: False 5 | dropout: 0.0 6 | tie_dropout: ${model.tie_dropout,null} 7 | -------------------------------------------------------------------------------- /convnova/configs/model/layer/h3-conv.yaml: -------------------------------------------------------------------------------- 1 | _name_: h3-conv 2 | head_dim: 1 3 | learning_rate: ${eval:"min(0.001, ${optimizer.lr})"} 4 | kernel_dropout: 0.2 5 | lam: 0.003 -------------------------------------------------------------------------------- 
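The interpolations used throughout these configs, such as `${eval:"min(0.001, ${optimizer.lr})"}` in h3-conv.yaml above and `${div_up:1_000_000_000, ${.max_length}}` in the dataset configs below, rely on custom OmegaConf resolvers that the training entrypoint is expected to register before the config is resolved. A minimal sketch of that registration (assuming OmegaConf >= 2.1, and treating `div_up` as ceiling division, inferred from usage rather than taken from the repo's own utilities):

```python
# Sketch: registering the custom resolvers these configs assume.
# `div_up` as ceiling division is an assumption based on how it is used here.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("eval", eval)
OmegaConf.register_new_resolver("div_up", lambda x, y: (x + y - 1) // y)

cfg = OmegaConf.create({
    "max_length": 1024,
    "train_len": "${div_up:1000000000, ${max_length}}",
})
print(cfg.train_len)  # 976563
```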
/convnova/configs/loader/default.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 50 2 | num_workers: 4 3 | pin_memory: True 4 | drop_last: True # We set this to true because of the recurrent state mechanism -------------------------------------------------------------------------------- /convnova/configs/scheduler/constant.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: epoch 4 | scheduler: 5 | # _target_: transformers.get_constant_schedule 6 | _name_: constant 7 | -------------------------------------------------------------------------------- /convnova/configs/model/layer/mha_dna.yaml: -------------------------------------------------------------------------------- 1 | # _name_: mha_dna 2 | num_heads: 1 3 | causal: True 4 | use_flash_attn: True 5 | fused_bias_fc: True 6 | # device: 'cuda' 7 | # dtype: tensor.float16 -------------------------------------------------------------------------------- /convnova/configs/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | # _target_: torch.optim.AdamW 2 | _name_: adamw 3 | lr: 0.001 # Initial learning rate 4 | weight_decay: 0.00 # Weight decay 5 | betas: [0.9, 0.999] 6 | -------------------------------------------------------------------------------- /convnova/configs/optimizer/sgd.yaml: -------------------------------------------------------------------------------- 1 | # _target_: torch.optim.SGD 2 | _name_: sgd 3 | lr: 0.001 # Initial learning rate 4 | momentum: 0.9 5 | weight_decay: 0.0 # Weight decay for adam|lamb 6 | -------------------------------------------------------------------------------- /convnova/configs/task/multiclass_classification.yaml: -------------------------------------------------------------------------------- 1 | # _target_: tasks.tasks.MultiClass 2 | _name_: multiclass 3 | loss: cross_entropy 4 | metrics: 5 | - accuracy 6 | torchmetrics: null 7 | -------------------------------------------------------------------------------- /convnova/configs/model/layer/mha.yaml: -------------------------------------------------------------------------------- 1 | _name_: mha 2 | causal: true 3 | n_heads: 8 4 | dropout: null 5 | bias: True 6 | add_bias_kv: False 7 | add_zero_attn: False 8 | kdim: null 9 | vdim: null 10 | -------------------------------------------------------------------------------- /convnova/configs/scheduler/step.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: torch.optim.lr_scheduler.StepLR 6 | _name_: step 7 | step_size: 1 8 | gamma: 0.99 9 | -------------------------------------------------------------------------------- /convnova/configs/task/multilabel_classification.yaml: -------------------------------------------------------------------------------- 1 | # _target_: 2 | _name_: base 3 | loss: binary_cross_entropy 4 | metrics: null 5 | torchmetrics: 6 | - AUROC 7 | - Precision 8 | - Recall 9 | - F1 10 | -------------------------------------------------------------------------------- /convnova/configs/model/transformer.yaml: -------------------------------------------------------------------------------- 1 | # Large Transformer model used as baseline for WikiText-103 2 | defaults: 3 | - base 4 | - override layer: transformer 5 | 6 | encoder: 7 | _name_: position 8 | dropout: ${..dropout} 9 | 
-------------------------------------------------------------------------------- /convnova/configs/scheduler/multistep.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: epoch 4 | # _target_: torch.optim.lr_scheduler.MultiStepLR 5 | scheduler: 6 | _name_: multistep 7 | milestones: [80,140,180] 8 | gamma: 0.2 9 | -------------------------------------------------------------------------------- /convnova/configs/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | # _target_: torch.optim.Adam 2 | _name_: adam 3 | lr: 0.001 # Initial learning rate 4 | # weight_decay: 0.0 # Weight decay for adam|lamb; should use AdamW instead if desired 5 | betas: [0.9, 0.999] 6 | -------------------------------------------------------------------------------- /convnova/configs/experiment/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: mnist 4 | - /model: long-conv 5 | 6 | # This file is a bare bones config for an experiment for illustration, consisting of a pipeline and model backbone 7 | -------------------------------------------------------------------------------- /convnova/configs/callbacks/rich.yaml: -------------------------------------------------------------------------------- 1 | rich_model_summary: 2 | # _target_: pytorch_lightning.callbacks.RichModelSummary 3 | max_depth: 2 4 | 5 | rich_progress_bar: 6 | # _target_: pytorch_lightning.callbacks.RichProgressBar 7 | refresh_rate_per_second: 1.0 8 | -------------------------------------------------------------------------------- /convnova/configs/scheduler/cosine_warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_cosine_schedule_with_warmup 6 | _name_: cosine_warmup 7 | num_warmup_steps: 1000 8 | num_training_steps: 40000 9 | -------------------------------------------------------------------------------- /convnova/configs/scheduler/linear_warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_linear_schedule_with_warmup 6 | _name_: linear_warmup 7 | num_warmup_steps: 1000 8 | num_training_steps: 40000 9 | -------------------------------------------------------------------------------- /convnova/configs/scheduler/constant_warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_constant_schedule_with_warmup 6 | _name_: constant_warmup 7 | num_warmup_steps: 1000 # Number of iterations for LR warmup 8 | -------------------------------------------------------------------------------- /convnova/configs/scheduler/cosine.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: epoch 4 | scheduler: 5 | # _target_: torch.optim.lr_scheduler.CosineAnnealingLR 6 | _name_: cosine 7 | T_max: 100 # Max number of epochs steps for LR scheduler 8 | eta_min: 1e-6 # Min learning rate for cosine scheduler 9 | -------------------------------------------------------------------------------- /convnova/configs/scheduler/cosine_warmup_timm.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_cosine_schedule_with_warmup 6 | _name_: cosine_warmup_timm 7 | t_in_epochs: False 8 | t_initial: 300 9 | lr_min: 1e-5 10 | warmup_lr_init: 1e-6 11 | warmup_t: 10 12 | -------------------------------------------------------------------------------- /convnova/configs/dataset/hg38_fixed_test.yaml: -------------------------------------------------------------------------------- 1 | _name_: hg38_fixed 2 | fasta_file: null 3 | chr_ranges: null 4 | pad_max_length: ${.max_length} 5 | max_length: 1024 6 | add_eos: True 7 | batch_size: 8 # per GPU 8 | num_workers: 4 # For preprocessing only 9 | shuffle: False 10 | pin_memory: True 11 | __train_len: ${div_up:1_000_000_000, ${.max_length}} 12 | __l_max: ${.max_length} -------------------------------------------------------------------------------- /convnova/configs/model/layer/hyena.yaml: -------------------------------------------------------------------------------- 1 | _name_: hyena 2 | l_max: 1024 3 | order: 2 4 | filter_order: 64 5 | num_heads: 1 6 | inner_factor: 1 7 | num_blocks: 1 8 | fused_bias_fc: false 9 | outer_mixing: false 10 | dropout: 0.0 11 | filter_dropout: 0.0 12 | filter_cls: 'hyena-filter' 13 | post_order_ffn: false 14 | jit_filter: false 15 | short_filter_order: 3 16 | activation: "id" -------------------------------------------------------------------------------- /convnova/configs/callbacks/base.yaml: -------------------------------------------------------------------------------- 1 | learning_rate_monitor: 2 | # _target_: pytorch_lightning.callbacks.LearningRateMonitor 3 | logging_interval: ${train.interval} 4 | 5 | timer: 6 | # _target_: callbacks.timer.Timer 7 | step: True 8 | inter_step: False 9 | epoch: True 10 | val: True 11 | 12 | params: 13 | # _target_: callbacks.params.ParamsLog 14 | total: True 15 | trainable: True 16 | fixed: True 17 | -------------------------------------------------------------------------------- /convnova/configs/scheduler/plateau.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: epoch 4 | monitor: ??? 
# must be specified 5 | scheduler: 6 | # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 7 | _name_: plateau 8 | mode: ${train.mode} # Which metric to monitor 9 | factor: 0.2 # Decay factor when ReduceLROnPlateau is used 10 | patience: 20 11 | min_lr: 0.0 # Minimum learning rate during annealing 12 | -------------------------------------------------------------------------------- /convnova/configs/dataset/hg38.yaml: -------------------------------------------------------------------------------- 1 | _name_: hg38 2 | bed_file: null 3 | fasta_file: null 4 | dataset_name: hg38 5 | tokenizer_name: null 6 | cache_dir: null 7 | max_length: 1024 8 | add_eos: True 9 | batch_size: 8 # per GPU 10 | batch_size_eval: ${eval:${.batch_size} * 2} 11 | num_workers: 4 # For preprocessing only 12 | shuffle: True 13 | pin_memory: True 14 | __train_len: ${div_up:1_000_000_000, ${.max_length}} 15 | __l_max: ${.max_length} -------------------------------------------------------------------------------- /convnova/configs/model/layer/hyena-filter.yaml: -------------------------------------------------------------------------------- 1 | _name_: hyena-filter 2 | emb_dim: 3 # dim of input to MLP, augments with positional encoding 3 | order: 16 # width of the implicit MLP 4 | fused_fft_conv: false 5 | # seq_len: ${dataset.__l_max} 6 | lr: 1e-3 7 | lr_pos_emb: 1e-5 8 | dropout: 0.0 9 | w: 1 # frequency of periodic activations 10 | wd: 0 # weight decay of kernel parameters 11 | bias: true 12 | normalized: False 13 | num_inner_mlps: 2 -------------------------------------------------------------------------------- /convnova/configs/dataset/bert_hg38.yaml: -------------------------------------------------------------------------------- 1 | _name_: bert_hg38 2 | bed_file: null 3 | fasta_file: null 4 | dataset_name: hg38 5 | tokenizer_name: null 6 | cache_dir: null 7 | max_length: 1024 8 | add_eos: True 9 | batch_size: 8 # per GPU 10 | batch_size_eval: ${eval:${.batch_size} * 2} 11 | num_workers: 4 # For preprocessing only 12 | shuffle: True 13 | pin_memory: True 14 | __train_len: ${div_up:1_000_000_000, ${.max_length}} 15 | __l_max: ${.max_length} -------------------------------------------------------------------------------- /convnova/configs/dataset/dnabert2_pretrain.yaml: -------------------------------------------------------------------------------- 1 | _name_: dnabert2_pretrain 2 | text_file: null 3 | dataset_name: dnabert2_pretrain 4 | tokenizer_name: null 5 | cache_dir: null 6 | max_length: 1024 7 | add_eos: True 8 | batch_size: 8 # per GPU 9 | batch_size_eval: ${eval:${.batch_size} * 2} 10 | num_workers: 4 # For preprocessing only 11 | shuffle: True 12 | pin_memory: True 13 | __train_len: ${div_up:1_000_000_000, ${.max_length}} 14 | __l_max: ${.max_length} -------------------------------------------------------------------------------- /convnova/configs/pipeline/hg38.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: hg38 6 | - /optimizer: adamw 7 | - /scheduler: cosine_warmup 8 | - /callbacks: [base, checkpoint] 9 | 10 | train: 11 | monitor: test/loss 12 | mode: min 13 | 14 | task: 15 | _name_: lm 16 | loss: cross_entropy 17 | torchmetrics: ['perplexity', 'num_tokens'] 18 | 19 | encoder: null 20 | decoder: null 21 | -------------------------------------------------------------------------------- /convnova/configs/trainer/debug.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | gpus: 1 5 | min_epochs: 1 6 | max_epochs: 10 7 | 8 | # prints 9 | progress_bar_refresh_rate: null 10 | weights_summary: full 11 | profiler: null 12 | 13 | # debugs 14 | fast_dev_run: False 15 | num_sanity_val_steps: 2 16 | overfit_batches: 0 17 | limit_train_batches: 0.1 18 | limit_val_batches: 0.1 19 | limit_test_batches: 0.1 20 | track_grad_norm: -1 21 | terminate_on_nan: False 22 | -------------------------------------------------------------------------------- /convnova/configs/model/layer/hyena_dna.yaml: -------------------------------------------------------------------------------- 1 | _name_: hyena 2 | l_max: 1024 3 | order: 3 4 | filter_order: 64 5 | num_heads: 1 6 | inner_factor: 1 7 | num_blocks: 1 8 | fused_bias_fc: false 9 | outer_mixing: false 10 | dropout: 0.0 11 | filter_dropout: 0.0 12 | filter_cls: 'hyena-filter' 13 | post_order_ffn: false 14 | jit_filter: false 15 | short_filter_order: 3 16 | activation: "id" 17 | filter_args: 18 | emb_dim: 3 19 | order: 16 20 | seq_len: ${..l_max} -------------------------------------------------------------------------------- /convnova/configs/pipeline/bert_hg38.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: bert_hg38 6 | - /optimizer: adamw 7 | - /scheduler: cosine_warmup 8 | - /callbacks: [base, checkpoint] 9 | 10 | train: 11 | monitor: test/loss 12 | mode: min 13 | 14 | task: 15 | _name_: lm 16 | loss: cross_entropy 17 | torchmetrics: ['perplexity', 'num_tokens'] 18 | 19 | encoder: null 20 | decoder: null 21 | -------------------------------------------------------------------------------- /convnova/configs/pipeline/dnabert2_pretrain.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: dnabert2_pretrain 6 | - /optimizer: adamw 7 | - /scheduler: cosine_warmup 8 | - /callbacks: [base, checkpoint] 9 | 10 | train: 11 | monitor: test/loss 12 | mode: min 13 | 14 | task: 15 | _name_: lm 16 | loss: cross_entropy 17 | torchmetrics: ['perplexity', 'num_tokens'] 18 | 19 | encoder: null 20 | decoder: null 21 | -------------------------------------------------------------------------------- /convnova/configs/dataset/chromatin_profile.yaml: -------------------------------------------------------------------------------- 1 | _name_: chromatin_profile 2 | dataset_name: chromatin_profile 3 | ref_genome_path: '/home/callum/private/hyena/build-deepsea-training-dataset/hg19/hg19.fa' 4 | ref_genome_version: 'hg19' 5 | data_path: '/home/callum/private/hyena/safari-internal-inf/data/chromatin_profile/' 6 | max_length: 1024 7 | d_output: 919 8 | use_padding: True 9 | padding_side: 'left' 10 | add_eos: False 11 | __l_max: ${.max_length} 12 | shuffle: true # set this as default! 
13 | -------------------------------------------------------------------------------- /convnova/configs/pipeline/species.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: species 6 | - /task: multiclass_classification 7 | - /optimizer: adamw 8 | - /scheduler: plateau 9 | - /callbacks: [base, checkpoint] 10 | 11 | train: 12 | monitor: val/accuracy # Needed for plateau scheduler 13 | mode: max 14 | 15 | encoder: id 16 | 17 | # we need this for classification! 18 | decoder: 19 | _name_: sequence 20 | mode: pool 21 | -------------------------------------------------------------------------------- /convnova/configs/pipeline/chromatin_profile.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: chromatin_profile 6 | - /task: multilabel_classification 7 | - /optimizer: adamw 8 | - /scheduler: plateau 9 | - /callbacks: [base, checkpoint] 10 | 11 | train: 12 | monitor: val/accuracy # Needed for plateau scheduler 13 | mode: max 14 | 15 | encoder: id 16 | 17 | # we need this for classification! 18 | decoder: 19 | _name_: sequence 20 | mode: pool -------------------------------------------------------------------------------- /convnova/configs/pipeline/genomic_benchmark.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: genomic_benchmark 6 | - /task: multiclass_classification 7 | - /optimizer: adamw 8 | - /scheduler: plateau 9 | - /callbacks: [base, checkpoint] 10 | 11 | train: 12 | monitor: val/accuracy # Needed for plateau scheduler 13 | mode: max 14 | 15 | encoder: id 16 | 17 | # we need this for classification! 18 | decoder: 19 | _name_: sequence 20 | mode: pool 21 | -------------------------------------------------------------------------------- /convnova/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | devices: 1 4 | accelerator: gpu 5 | accumulate_grad_batches: 1 # Gradient accumulation every n batches 6 | max_epochs: 200 7 | # accelerator: ddp # Automatically set if gpus > 1 8 | gradient_clip_val: 0.0 9 | log_every_n_steps: 10 10 | limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run 11 | limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run 12 | -------------------------------------------------------------------------------- /convnova/configs/pipeline/gue.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: gue 6 | - /task: multiclass_classification 7 | - /optimizer: adamw 8 | - /scheduler: plateau 9 | - /callbacks: [base, checkpoint] 10 | 11 | task: 12 | loss: 13 | _name_: cross_entropy 14 | metrics: 15 | - ${dataset.metric} 16 | 17 | train: 18 | monitor: val/${dataset.metric} 19 | mode: max 20 | 21 | encoder: id 22 | 23 | # we need this for classification! 
24 | decoder: 25 | _name_: sequence 26 | mode: pool -------------------------------------------------------------------------------- /convnova/configs/pipeline/deepstarr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: deepstarr 6 | - /task: regression 7 | - /optimizer: adamw 8 | - /scheduler: plateau 9 | - /callbacks: [base, checkpoint] 10 | 11 | train: 12 | monitor: test/pearsonr_mean 13 | mode: max 14 | 15 | task: 16 | loss: customMSE 17 | # _name_: cross_entropy 18 | metrics: 19 | - ${dataset.metric} 20 | 21 | encoder: id 22 | 23 | # we need this for classification! 24 | decoder: 25 | _name_: sequence 26 | mode: pool -------------------------------------------------------------------------------- /convnova/configs/pipeline/nucleotide_transformer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: nucleotide_transformer 6 | - /task: multiclass_classification 7 | - /optimizer: adamw 8 | - /scheduler: plateau 9 | - /callbacks: [base, checkpoint] 10 | 11 | task: 12 | loss: 13 | _name_: cross_entropy 14 | metrics: 15 | - ${dataset.metric} 16 | 17 | train: 18 | monitor: val/${dataset.metric} 19 | mode: max 20 | 21 | encoder: id 22 | 23 | # we need this for classification! 24 | decoder: 25 | _name_: sequence 26 | mode: pool -------------------------------------------------------------------------------- /convnova/configs/evals/hyena_dna_512ksl.yaml: -------------------------------------------------------------------------------- 1 | model_name: hyena-small 2 | tokenizer_name: gpt2 3 | model_config: 4 | _name_: lm 5 | d_model: 256 6 | d_inner: 1024 7 | n_layer: 8 8 | vocab_size: 4 9 | embed_dropout: 0.0 10 | layer: 11 | _name_: hyena 12 | emb_dim: 33 13 | linear_mixer: False 14 | filter_order: 64 15 | local_order: 3 16 | l_max: 400_000 # 524288 17 | modulate: False 18 | w: 14 19 | fused_mlp: False 20 | fused_dropout_add_ln: False 21 | residual_in_fp32: True 22 | checkpoint_mixer: True 23 | checkpoint_mlp: True 24 | pad_vocab_size_multiple: 8 -------------------------------------------------------------------------------- /convnova/configs/dataset/icl_hg38.yaml: -------------------------------------------------------------------------------- 1 | _name_: icl_hg38 2 | bed_file: '/home/workspace/eric/safari-internal/data/hg38/human-sequences.bed' 3 | fasta_file: '/home/workspace/eric/safari-internal/data/hg38/hg38.ml.fa' 4 | dataset_name: icl_hg38 5 | tokenizer_name: null 6 | cache_dir: null 7 | min_length: 128 8 | max_length: 1024 9 | variable_length: True 10 | add_eos: False 11 | batch_size: 8 # per GPU 12 | batch_size_eval: ${eval:${.batch_size} * 2} 13 | num_workers: 4 # For preprocessing only 14 | shuffle: True 15 | pin_memory: True 16 | __train_len: ${div_up:1_000_000_000, ${.max_length}} 17 | __l_max: ${.max_length} -------------------------------------------------------------------------------- /convnova/configs/trainer/lm.yaml: -------------------------------------------------------------------------------- 1 | accumulate_grad_batches: 1 2 | # accelerator: null # set to 'ddp' for distributed 3 | # amp_backend: native # 'native' | 'apex' 4 | gpus: 8 5 | max_epochs: 50 6 | gradient_clip_val: 0.0 # Gradient clipping 7 | log_every_n_steps: 10 8 | precision: 16 9 | progress_bar_refresh_rate: 1 10 | weights_summary: top # 
Set to 'full' to see every layer 11 | track_grad_norm: -1 # Set to 2 to track norms of gradients 12 | limit_train_batches: 1.0 13 | limit_val_batches: 1.0 14 | # We use the dataloader from Transformer-XL to ensure adjacent minibatches 15 | # are from text that are next to each other. 16 | # So that dataloader has to deal with DDP, and we don't want PL to handle 17 | # that. 18 | replace_sampler_ddp: False 19 | -------------------------------------------------------------------------------- /convnova/configs/evals/hg38.yaml: -------------------------------------------------------------------------------- 1 | model_name: hyena-small 2 | tokenizer_name: char 3 | model_config: 4 | _name_: lm 5 | d_model: 256 6 | n_layer: 8 7 | d_inner: 1024 # ${eval:4 * ${.d_model}} 8 | vocab_size: 12 9 | resid_dropout: 0.0 10 | embed_dropout: 0.1 11 | fused_mlp: False # figure out how to use fused MLP, maybe only with bf16 + a100 12 | fused_dropout_add_ln: True 13 | residual_in_fp32: True 14 | pad_vocab_size_multiple: 8 15 | return_hidden_state: True # in 2nd position of output tuple (1st is logits) 16 | layer: 17 | _name_: hyena 18 | emb_dim: 5 19 | filter_order: 64 20 | local_order: 3 21 | l_max: 1026 # add 2 for ckpt 22 | modulate: True 23 | w: 10 24 | lr: 6e-4 25 | wd: 0.0 26 | lr_pos_emb: 0.0 27 | -------------------------------------------------------------------------------- /convnova/configs/callbacks/wandb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | watch_model: 5 | _target_: src.callbacks.wandb_callbacks.WatchModel 6 | log: "all" 7 | log_freq: 100 8 | 9 | upload_code_as_artifact: 10 | _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact 11 | code_dir: ${work_dir}/src 12 | 13 | upload_ckpts_as_artifact: 14 | _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact 15 | ckpt_dir: "checkpoints/" 16 | upload_best_only: True 17 | 18 | log_f1_precision_recall_heatmap: 19 | _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap 20 | 21 | log_confusion_matrix: 22 | _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix 23 | 24 | log_image_predictions: 25 | _target_: src.callbacks.wandb_callbacks.LogImagePredictions 26 | num_samples: 8 27 | -------------------------------------------------------------------------------- /convnova/configs/dataset/deepsea.yaml: -------------------------------------------------------------------------------- 1 | _name_: deepsea 2 | dataset_name: null 3 | dest_path: # project root/data/DeepSea/deepsea_train/ 4 | # datasets are: train.csv.gz test.csv.gz val.csv.gz 5 | max_length: 1024 6 | d_output: 919 # binary classification task 7 | use_padding: True 8 | padding_side: 'left' 9 | add_eos: False 10 | batch_size: 32 11 | train_len: 4400000 12 | __l_max: ${.max_length} 13 | shuffle: true # set this as default! 14 | tokenizer_name: char 15 | cache_dir: null 16 | batch_size_eval: ${eval:${.batch_size} * 2} 17 | pin_memory: True 18 | rc_aug: False 19 | metric: roc 20 | # https://www.nature.com/articles/nmeth.3547#Sec2 21 | # orignal train code at DeepSea/deepsea_train/4_train.lua 22 | #For evaluating performance on the test set, we used area under the receiver operating characteristic curve (AUC). 
-------------------------------------------------------------------------------- /convnova/configs/evals/icl_genomics.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _name_: lm 3 | d_model: 256 4 | n_layer: 8 5 | d_inner: 1024 # ${eval:4 * ${.d_model}} 6 | vocab_size: 12 7 | resid_dropout: 0.0 8 | embed_dropout: 0.1 9 | fused_mlp: False # figure out how to use fused MLP, maybe only with bf16 + a100 10 | fused_dropout_add_ln: True 11 | residual_in_fp32: True 12 | pad_vocab_size_multiple: 8 13 | return_hidden_state: True # in 2nd position of output tuple (1st is logits) 14 | checkpoint_mixer: False 15 | checkpoint_mlp: False 16 | layer: 17 | _name_: hyena 18 | emb_dim: 5 19 | filter_order: 64 20 | local_order: 3 21 | l_max: 160_002 # add 2 for ckpt 22 | modulate: True 23 | w: 10 24 | lr: 6e-4 25 | wd: 0.0 26 | lr_pos_emb: 0.0 27 | 28 | train: 29 | d_output: 2 # number of classes 30 | 31 | dataset: 32 | dataset_name: human_nontata_promoters 33 | max_length: 256 34 | d_output: 2 # num classes 35 | train_len: 36131 36 | tokenizer_name: char 37 | batch_size: 128 # Per GPU 38 | rc_aug: false 39 | shots: 5 40 | 41 | # human_nontata_promoters 36131 2 251 0 -------------------------------------------------------------------------------- /convnova/configs/callbacks/checkpoint.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | # _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: ${train.monitor} # name of the logged metric which determines when model is improving 4 | mode: ${train.mode} # can be "max" or "min" 5 | save_top_k: 1 # save k best models (determined by above metric) 6 | save_last: True # additionaly always save model from last epoch 7 | dirpath: "checkpoints/" 8 | # this saves an annoying "epoch=12" which makes it annoying to pass on the command line; epochs are being logged through the verbose flag and logger anyways 9 | # seems like you can override the '.format_checkpoint_name' method of ModelCheckpoint to change this, but not worth 10 | # filename: "{epoch:02d}", 11 | filename: ${train.monitor} 12 | auto_insert_metric_name: False 13 | verbose: True 14 | 15 | # early_stopping: 16 | # _target_: pytorch_lightning.callbacks.EarlyStopping 17 | # monitor: "val/acc" # name of the logged metric which determines when model is improving 18 | # mode: "max" # can be "max" or "min" 19 | # patience: 100 # how many epochs of not improving until training stops 20 | # min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 21 | -------------------------------------------------------------------------------- /convnova/configs/evals/instruction_tuned_genomics.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _name_: lm 3 | d_model: 256 4 | n_layer: 8 5 | d_inner: 1024 # ${eval:4 * ${.d_model}} 6 | vocab_size: 12 7 | resid_dropout: 0.0 8 | embed_dropout: 0.1 9 | fused_mlp: False # figure out how to use fused MLP, maybe only with bf16 + a100 10 | fused_dropout_add_ln: False 11 | residual_in_fp32: True 12 | pad_vocab_size_multiple: 8 13 | return_hidden_state: True # in 2nd position of output tuple (1st is logits) 14 | checkpoint_mixer: False 15 | checkpoint_mlp: False 16 | layer: 17 | _name_: hyena 18 | emb_dim: 5 19 | filter_order: 64 20 | local_order: 3 21 | l_max: 160_002 # add 2 for ckpt 22 | modulate: True 23 | w: 10 24 | lr: 6e-4 25 | wd: 0.0 26 | lr_pos_emb: 0.0 27 | 28 | tuning: 29 | 
tuning_samples: 30 | - 2 31 | - 16 32 | - 64 33 | - 256 34 | batch_size: 2 # for tuning 35 | max_epochs: 1 36 | lr: 1e-4 37 | weight_decay: 0 38 | gradient_clip_val: 1.0 39 | accumulate_grad_batches: 1 40 | # ema_decay: 0.98 41 | 42 | dataset: 43 | tokenizer_name: char 44 | batch_size: 4 # used for data loading & evaluation 45 | rc_aug: False 46 | shots: 47 | - 0 48 | - 2 49 | - 16 50 | - 32 51 | num_workers: 10 52 | 53 | seed: 12345 -------------------------------------------------------------------------------- /convnova/configs/trainer/full.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # default values for all trainer parameters 4 | checkpoint_callback: True 5 | default_root_dir: null 6 | gradient_clip_val: 0.0 7 | process_position: 0 8 | num_nodes: 1 9 | num_processes: 1 10 | gpus: null 11 | auto_select_gpus: False 12 | tpu_cores: null 13 | log_gpu_memory: null 14 | overfit_batches: 0.0 15 | track_grad_norm: -1 16 | check_val_every_n_epoch: 1 17 | fast_dev_run: False 18 | accumulate_grad_batches: 1 19 | max_epochs: 1 20 | min_epochs: 1 21 | max_steps: null 22 | min_steps: null 23 | limit_train_batches: 1.0 24 | limit_val_batches: 1.0 25 | limit_test_batches: 1.0 26 | val_check_interval: 1.0 27 | flush_logs_every_n_steps: 100 28 | log_every_n_steps: 50 29 | accelerator: null 30 | sync_batchnorm: False 31 | precision: 32 32 | weights_summary: "top" 33 | weights_save_path: null 34 | num_sanity_val_steps: 2 35 | truncated_bptt_steps: null 36 | resume_from_checkpoint: null 37 | profiler: null 38 | benchmark: False 39 | deterministic: False 40 | reload_dataloaders_every_epoch: False 41 | auto_lr_find: False 42 | replace_sampler_ddp: True 43 | terminate_on_nan: False 44 | auto_scale_batch_size: False 45 | prepare_data_per_node: True 46 | plugins: null 47 | amp_backend: "native" 48 | amp_level: "O2" 49 | move_metrics_to_cpu: False 50 | -------------------------------------------------------------------------------- /convnova/src/dataloaders/datasets/lm_dataset.py: -------------------------------------------------------------------------------- 1 | # Inspired by https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt/datasets.py 2 | # Except we don't pad the last block and don't use overlapping eval 3 | # And we return both the input and the target 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | 9 | 10 | class LMDataset(torch.utils.data.Dataset): 11 | 12 | def __init__(self, tokens, seq_len, drop_last=True): 13 | """tokens should be a numpy array 14 | """ 15 | self.seq_len = seq_len 16 | ntokens = len(tokens) 17 | if drop_last: 18 | ntokens = ((ntokens - 1) // seq_len) * seq_len + 1 19 | self.ntokens = ntokens 20 | # We're careful not to slice tokens, since it could be a memmap'ed array or H5 dataset, 21 | # and slicing would load it to memory. 
22 | self.tokens = tokens 23 | self.total_sequences = math.ceil((self.ntokens - 1) / self.seq_len) 24 | 25 | def __len__(self): 26 | return self.total_sequences 27 | 28 | def __getitem__(self, idx): 29 | start_idx = idx * self.seq_len 30 | seq_len = min(self.seq_len, self.ntokens - 1 - start_idx) 31 | data = torch.as_tensor(self.tokens[start_idx:(start_idx + seq_len + 1)].astype(np.int64)) 32 | return data[:-1], data[1:].clone() -------------------------------------------------------------------------------- /convnova/configs/evals/soft_prompting_genomics.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _name_: lm 3 | d_model: 256 4 | n_layer: 8 5 | d_inner: 1024 # ${eval:4 * ${.d_model}} 6 | vocab_size: 12 7 | resid_dropout: 0.0 8 | embed_dropout: 0.1 9 | fused_mlp: False # figure out how to use fused MLP, maybe only with bf16 + a100 10 | fused_dropout_add_ln: False 11 | residual_in_fp32: True 12 | pad_vocab_size_multiple: 8 13 | return_hidden_state: True # in 2nd position of output tuple (1st is logits) 14 | checkpoint_mixer: False 15 | checkpoint_mlp: False 16 | layer: 17 | _name_: hyena 18 | emb_dim: 5 19 | filter_order: 64 20 | local_order: 3 21 | l_max: 160_002 # add 2 for ckpt 22 | modulate: True 23 | w: 10 24 | lr: 6e-4 25 | wd: 0.0 26 | lr_pos_emb: 0.0 27 | 28 | tuning: 29 | soft_tokens: 30 | - 0 31 | - 2 32 | - 32 33 | - 128 34 | - 512 35 | - 2048 36 | - 8192 37 | - 32768 38 | soft_token_pdrop: 0.1 # dropout probability for soft tokens 39 | max_epochs: 20 40 | lr: 1e-3 41 | weight_decay: 0. 42 | gradient_clip_val: 1.0 43 | accumulate_grad_batches: 8 # number of batches to accumulate before gradient update 44 | ema_decay: 0.9 # decay rate for updates of expected moving average of trained models 45 | 46 | dataset: 47 | tokenizer_name: char 48 | batch_size: 2 49 | rc_aug: True # no augmentation 50 | shots: 51 | # - 0 52 | - 2 53 | - 32 54 | - 128 55 | num_workers: 10 56 | 57 | seed: 12345 -------------------------------------------------------------------------------- /convnova/src/callbacks/params.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.utilities import rank_zero_only 5 | from pytorch_lightning.utilities.parsing import AttributeDict 6 | 7 | 8 | class ParamsLog(pl.Callback): 9 | """ Log the number of parameters of the model """ 10 | def __init__( 11 | self, 12 | total: bool = True, 13 | trainable: bool = True, 14 | fixed: bool = True, 15 | ): 16 | super().__init__() 17 | self._log_stats = AttributeDict( 18 | { 19 | 'total_params_log': total, 20 | 'trainable_params_log': trainable, 21 | 'non_trainable_params_log': fixed, 22 | } 23 | ) 24 | 25 | @rank_zero_only 26 | def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 27 | logs = {} 28 | if self._log_stats.total_params_log: 29 | logs["params/total"] = sum(p.numel() for p in pl_module.parameters()) 30 | if self._log_stats.trainable_params_log: 31 | logs["params/trainable"] = sum(p.numel() for p in pl_module.parameters() 32 | if p.requires_grad) 33 | if self._log_stats.non_trainable_params_log: 34 | logs["params/fixed"] = sum(p.numel() for p in pl_module.parameters() 35 | if not p.requires_grad) 36 | if trainer.logger: 37 | trainer.logger.log_hyperparams(logs) 38 | -------------------------------------------------------------------------------- /convnova/src/callbacks/norms.py: 
-------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pytorch_lightning.utilities import rank_zero_only 3 | from pytorch_lightning.utilities.parsing import AttributeDict 4 | from omegaconf import OmegaConf 5 | 6 | class TrackNorms(pl.Callback): 7 | 8 | # TODO do callbacks happen before or after the method in the main LightningModule? 9 | # @rank_zero_only # needed? 10 | def on_after_training_step(self, batch, batch_idx, trainer: pl.Trainer, pl_module: pl.LightningModule): 11 | # Log extra metrics 12 | metrics = {} 13 | 14 | if hasattr(pl_module, "_grad_norms"): 15 | metrics.update(pl_module._grad_norms) 16 | 17 | self.log_dict( 18 | metrics, 19 | on_step=True, 20 | on_epoch=False, 21 | prog_bar=False, 22 | add_dataloader_idx=False, 23 | sync_dist=True, 24 | ) 25 | 26 | 27 | def on_after_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 28 | # example to inspect gradient information in tensorboard 29 | if OmegaConf.select(trainer.hparams, 'trainer.track_grad_norms'): # TODO dot notation should work with omegaconf? 30 | norms = {} 31 | for name, p in pl_module.named_parameters(): 32 | if p.grad is None: 33 | continue 34 | 35 | # param_norm = float(p.grad.data.norm(norm_type)) 36 | param_norm = torch.mean(p.grad.data ** 2) 37 | norms[f"grad_norm.{name}"] = param_norm 38 | pl_module._grad_norms = norms 39 | 40 | -------------------------------------------------------------------------------- /convnova/src/models/basenji2/params_human.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "batch_size": 4, 4 | "optimizer": "sgd", 5 | "learning_rate": 0.15, 6 | "momentum": 0.99, 7 | "patience": 16, 8 | "clip_norm": 2 9 | }, 10 | "model": { 11 | "seq_length": 131072, 12 | "target_length": 1024, 13 | "activation": "gelu", 14 | "norm_type": "batch", 15 | "bn_momentum": 0.9, 16 | "trunk": [ 17 | { 18 | "name": "conv_block", 19 | "filters": 288, 20 | "kernel_size": 15, 21 | "pool_size": 2 22 | }, 23 | { 24 | "name": "conv_tower", 25 | "filters_init": 339, 26 | "filters_mult": 1.1776, 27 | "kernel_size": 5, 28 | "pool_size": 2, 29 | "repeat": 6 30 | }, 31 | { 32 | "name": "dilated_residual", 33 | "filters": 384, 34 | "rate_mult": 1.5, 35 | "repeat": 11, 36 | "dropout": 0.3, 37 | "round": true 38 | }, 39 | { 40 | "name": "Cropping1D", 41 | "cropping": 64 42 | }, 43 | { 44 | "name": "conv_block", 45 | "filters": 1536, 46 | "dropout": 0.05 47 | } 48 | ], 49 | "head_human": { 50 | "name": "final", 51 | "units": 5313, 52 | "activation": "softplus" 53 | } 54 | } 55 | } -------------------------------------------------------------------------------- /convnova/src/callbacks/gpu_affinity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_lightning import Callback, Trainer, LightningModule 4 | 5 | import logging 6 | 7 | log = logging.getLogger(__name__) # We want a logger for each process, not just the rank 0 8 | 9 | 10 | def l2_promote(): 11 | import ctypes 12 | _libcudart = ctypes.CDLL('libcudart.so') 13 | # Set device limit on the current device 14 | # cudaLimitMaxL2FetchGranularity = 0x05 15 | pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int)) 16 | _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) 17 | _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05)) 18 | assert pValue.contents.value == 128 19 | 20 | 21 | def set_affinity(trainer): 22 | try: 23 | 
from src.utils.gpu_affinity import set_affinity 24 | nproc_per_node = torch.cuda.device_count() 25 | affinity = set_affinity(trainer.local_rank, nproc_per_node, 'socket_unique_continuous') 26 | log.info(f'{trainer.local_rank}: thread affinity: {affinity}') 27 | # TD [2022-05-07] Somehow calling this causes GPU 0 to allocate extra ~800MB of memory per 28 | # number of GPUs (e.g., 6.4GB of extra memory in a 8-GPU setup). H/t Dan. 29 | # l2_promote() 30 | except: 31 | pass 32 | 33 | 34 | class GpuAffinity(Callback): 35 | """Set GPU affinity and increase the L2 fetch granularity. 36 | Adapted from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/Transformer-XL 37 | """ 38 | 39 | def setup(self, trainer: Trainer, pl_module: LightningModule, stage=None) -> None: 40 | set_affinity(trainer) 41 | -------------------------------------------------------------------------------- /convnova/src/models/sequence/ff.py: -------------------------------------------------------------------------------- 1 | """ Implementation of FFN block in the style of Transformers """ 2 | 3 | from functools import partial 4 | from torch import nn 5 | from src.models.sequence.base import SequenceModule 6 | from src.models.nn import LinearActivation, DropoutNd 7 | 8 | class FF(SequenceModule): 9 | def __init__(self, d_input, expand=2, d_output=None, transposed=False, activation='gelu', initializer=None, dropout=0.0, tie_dropout=False): 10 | super().__init__() 11 | self.d_output = d_input if d_output is None else d_output 12 | self.transposed = transposed 13 | d_inner = expand * d_input 14 | 15 | linear1 = LinearActivation( 16 | d_input, d_inner, 17 | transposed=transposed, 18 | activation=activation, 19 | initializer=initializer, 20 | activate=True, 21 | ) 22 | dropout_cls = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout 23 | # dropout_cls = nn.Dropout2d if self.transposed else nn.Dropout 24 | drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity() 25 | 26 | linear2 = LinearActivation( 27 | d_inner, self.d_output, 28 | transposed=transposed, 29 | activation=None, 30 | initializer=initializer, 31 | activate=False, 32 | ) 33 | 34 | self.ff = nn.Sequential( 35 | linear1, 36 | drop, 37 | linear2, 38 | ) 39 | 40 | def forward(self, x, *args, **kwargs): 41 | return self.ff(x), None 42 | 43 | def step(self, x, state, **kwargs): 44 | # x: [batch, d_input] 45 | if self.transposed: 46 | # expects: [batch, d_input, seq_len] 47 | return self.ff(x.unsqueeze(-1)).squeeze(-1), state 48 | else: 49 | return self.ff(x), state 50 | 51 | -------------------------------------------------------------------------------- /convnova/src/models/NTV2/ntv2.py: -------------------------------------------------------------------------------- 1 | from .modeling_esm import EsmForSequenceClassification 2 | from torch import nn 3 | from .esm_config import EsmConfig 4 | from peft import get_peft_model, LoraConfig, TaskType 5 | from omegaconf import OmegaConf 6 | 7 | class NTV2(nn.Module): 8 | def __init__(self, config, **kwargs): 9 | super().__init__() 10 | 11 | # Convert DictConfig to dictionary 12 | config_dict = OmegaConf.to_container(config, resolve=True) 13 | 14 | # Create EsmConfig 15 | esm_config = EsmConfig.from_dict(config_dict) 16 | 17 | # Load the pretrained model 18 | self.esm = EsmForSequenceClassification.from_pretrained(config_dict["_name_or_path"], config=esm_config, **kwargs) 19 | 20 | # Initialize LoRA configuration 21 | lora_config = LoraConfig( 22 | 
task_type=TaskType.SEQ_CLS, # or whatever task type you're using 23 | r=8, 24 | lora_alpha=16, 25 | lora_dropout=0.05, 26 | target_modules=[ 27 | "query", 28 | "key", 29 | "value", 30 | "dense" 31 | ] 32 | ) 33 | 34 | # Apply LoRA to the model 35 | self.esm = get_peft_model(self.esm, lora_config) 36 | 37 | self.d_model = esm_config.hidden_size 38 | 39 | def forward(self, input_ids, position_ids=None, inference_params=None): 40 | outputs = self.esm( 41 | input_ids, 42 | position_ids=position_ids, 43 | ) 44 | return outputs[0].logits, None 45 | 46 | @property 47 | def d_output(self): 48 | """Model /embedding dimension, used for decoder mapping. 49 | 50 | """ 51 | if getattr(self, "d_model", None) is None: 52 | raise NotImplementedError("SequenceModule instantiation must set d_output") 53 | return self.d_model -------------------------------------------------------------------------------- /convnova/configs/dataset/genomic_benchmark.yaml: -------------------------------------------------------------------------------- 1 | _name_: genomic_benchmark 2 | dataset_name: dummy_mouse_enhancers_ensembl 3 | dest_path: null 4 | max_length: 1024 5 | d_output: ${.${.dataset_name}.classes} 6 | use_padding: True 7 | padding_side: 'left' 8 | add_eos: False 9 | batch_size: 32 10 | train_len: ${.${.dataset_name}.train_len} 11 | __l_max: ${.max_length} 12 | shuffle: true # set this as default! 13 | # these are used to find the right attributes automatically for each dataset 14 | dummy_mouse_enhancers_ensembl: 15 | train_len: 1210 16 | classes: 2 17 | demo_coding_vs_intergenomic_seqs: 18 | train_len: 100_000 19 | classes: 2 20 | demo_human_or_worm: 21 | train_len: 100_000 22 | classes: 2 23 | human_enhancers_cohn: 24 | train_len: 27791 25 | classes: 2 26 | human_enhancers_ensembl: 27 | train_len: 154842 28 | classes: 2 29 | human_ensembl_regulatory: 30 | train_len: 289061 31 | classes: 3 32 | human_nontata_promoters: 33 | train_len: 36131 34 | classes: 2 35 | human_ocr_ensembl: 36 | train_len: 174756 37 | classes: 2 38 | 39 | # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings 40 | # name num_seqs num_classes median len std 41 | # dummy_mouse_enhancers_ensembl 1210 2 2381 984.4 42 | # demo_coding_vs_intergenomic_seqs 100_000 2 200 0 43 | # demo_human_or_worm 100_000 2 200 0 44 | # human_enhancers_cohn 27791 2 500 0 45 | # human_enhancers_ensembl 154842 2 269 122.6 46 | # human_ensembl_regulatory 289061 3 401 184.3 47 | # human_nontata_promoters 36131 2 251 0 48 | # human_ocr_ensembl 174756 2 315 108.1 49 | -------------------------------------------------------------------------------- /convnova/configs/experiment/hg38-pretrain/convNext.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: hg38 4 | # - default model/layer: mha 5 | - override /scheduler: cosine_warmup_timm 6 | 7 | model: 8 | _name_: convnext 9 | d_model: 128 10 | max_length: ${dataset.max_length} 11 | vocab_size: 12 12 | pad_vocab_size_multiple: 8 13 | k_size: 5 14 | 15 | task: 16 | # 2 options for soft_cross_entropy (for mixup) 17 | loss: 18 | # soft_cross_entropy for pytorch 1.10+, which takes in label_smoothing here 19 | _name_: cross_entropy 20 | 21 | trainer: 22 | accelerator: gpu 23 | devices: 4 24 | num_nodes: 1 25 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 26 | max_epochs: 2000 27 | precision: 16 # bf16 only a100 28 | 
gradient_clip_val: 1.0 29 | strategy: null 30 | 31 | dataset: 32 | batch_size: 16 # Per GPU 33 | # batch_size: 256 34 | max_length: 8193 # 262144, 524288 35 | # optional, default is max_length 36 | max_length_val: ${dataset.max_length} 37 | max_length_test: ${dataset.max_length} 38 | tokenizer_name: char 39 | pad_max_length: null # needed for bpe tokenizer 40 | add_eos: true 41 | rc_aug: false 42 | num_workers: 12 43 | use_fixed_len_val: false 44 | 45 | scheduler: 46 | t_in_epochs: True 47 | # t_initial: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} # num steps for 1 cycle 48 | t_initial: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * 1} 49 | cycle_mul: 2 50 | warmup_lr_init: 1e-6 # starting point 51 | # warmup_t: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.005} # time for ramp up 52 | warmup_t: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * 0.1} 53 | # lr_min: ${eval:0.1 * ${optimizer.lr}} # flatlines with this 54 | lr_min: 1e-6 55 | cycle_decay: 0.6 56 | cycle_limit: 10000 57 | 58 | optimizer: 59 | lr: 2e-3 # peak 60 | weight_decay: 0.1 61 | 62 | train: 63 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 64 | seed: 2222 65 | global_batch_size: ${eval:${dataset.batch_size}*${trainer.devices}} 66 | -------------------------------------------------------------------------------- /convnova/configs/experiment/hg38-pretrain/convnova.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: bert_hg38 4 | - override /scheduler: cosine_warmup_timm 5 | 6 | model: 7 | _name_: convnova 8 | for_representation: false 9 | alphabet_size: 5 10 | d_model: 128 11 | pretrain: true 12 | kernel_size: 9 13 | final_conv: False 14 | dilation: 4 15 | num_conv1d: 5 16 | d_inner: 2 17 | ffn: true 18 | args: 19 | hidden_dim: 128 # same as d_model 20 | num_cnn_stacks: 1 21 | dropout: 0.0 22 | 23 | task: 24 | # _name_: lm 25 | _name_: hg38 # equivalent to lm task, plus allows extra metrics to be calculated 26 | loss: bert_cross_entropy 27 | 28 | trainer: 29 | accelerator: gpu 30 | devices: 6 31 | num_nodes: 1 32 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 33 | max_epochs: 2000 34 | precision: 32 # bf16 only a100 35 | gradient_clip_val: 1.0 36 | # strategy: null 37 | 38 | dataset: 39 | batch_size: 128 # Per GPU 40 | max_length: 1024 # 262144, 524288 41 | # optional, default is max_length 42 | max_length_val: ${dataset.max_length} 43 | max_length_test: ${dataset.max_length} 44 | tokenizer_name: char 45 | use_tokenizer: False 46 | pad_max_length: null # needed for bpe tokenizer 47 | add_eos: true 48 | rc_aug: false 49 | num_workers: 12 50 | use_fixed_len_val: false # placing a fixed length val here, but it's really the test 51 | replace_N_token: false # replace N (uncertain token) with pad tokens in dataloader 52 | pad_interval: false # handle uncertain tokens within the FastaInteral class 53 | 54 | scheduler: 55 | t_in_epochs: False 56 | t_initial: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 57 | warmup_lr_init: 1e-6 58 | warmup_t: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} 
* 0.01} 59 | lr_min: ${eval:0.1 * ${optimizer.lr}} 60 | 61 | optimizer: 62 | lr: 1e-3 63 | weight_decay: 0.0 64 | 65 | train: 66 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 67 | seed: 2222 68 | global_batch_size: ${eval:${trainer.devices}*${dataset.batch_size}} # effects the scheduler, need to set properly 69 | -------------------------------------------------------------------------------- /convnova/configs/evals/hg38_decoder.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _name_: lm 3 | d_model: 128 4 | n_layer: 2 5 | d_inner: 512 # ${eval:4 * ${.d_model}} 6 | vocab_size: 12 7 | resid_dropout: 0.0 8 | embed_dropout: 0.1 9 | fused_mlp: False # figure out how to use fused MLP, maybe only with bf16 + a100 10 | fused_dropout_add_ln: True 11 | residual_in_fp32: True 12 | pad_vocab_size_multiple: 8 13 | return_hidden_state: True # in 2nd position of output tuple (1st is logits) 14 | layer: 15 | _name_: hyena 16 | emb_dim: 5 17 | filter_order: 64 18 | local_order: 3 19 | l_max: 1026 # add 2 for ckpt 20 | modulate: True 21 | w: 10 22 | lr: 6e-4 23 | wd: 0.0 24 | lr_pos_emb: 0.0 25 | 26 | train: 27 | d_output: 2 # number of classes 28 | 29 | dataset: 30 | dataset_name: enhancer # human_enhancers_cohn 31 | max_length: 500 32 | d_output: 2 # num classes 33 | train_len: 14968 34 | tokenizer_name: char 35 | batch_size: 128 # Per GPU 36 | 37 | 38 | 39 | # we need to set the correct config for the dataset 40 | 41 | # Genomic Benchmark 42 | # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings 43 | # name num_seqs num_classes median len std 44 | # dummy_mouse_enhancers_ensembl 1210 2 2381 984.4 45 | # demo_coding_vs_intergenomic_seqs 100_000 2 200 0 46 | # demo_human_or_worm 100_000 2 200 0 47 | # human_enhancers_cohn 27791 2 500 0 48 | # human_enhancers_ensembl 154842 2 269 122.6 49 | # human_ensembl_regulatory 289061 3 401 184.3 50 | # human_nontata_promoters 36131 2 251 0 51 | # human_ocr_ensembl 174756 2 315 108.1 52 | 53 | # Nucleotide Transformer 54 | # name, max_len, d_output (classes), train_len 55 | # enhancer 200 2 14968 # binary 56 | # enhancer_types 200 3 14968 57 | # H3 500 2 13468 58 | # H3K4me1 500 2 28509 59 | # H3K4me2 500 2 27614 60 | # H3K4me3 500 2 33119 61 | # H3K9ac 500 2 25003 62 | # H3K14ac 500 2 29743 63 | # H3K36me3 500 2 31392 64 | # H3K79me3 500 2 25953 65 | # H4 500 2 13140 66 | # H4ac 500 2 30685 67 | # promoter_all 300 2 53276 68 | # promoter_non_tata 300 2 47759 69 | # promoter_tata 300 2 5517 70 | # splice_sites_acceptor 600 2 19961 71 | # splice_sites_donor 600 2 19775 -------------------------------------------------------------------------------- /convnova/src/models/basenji2/basenji2.py: -------------------------------------------------------------------------------- 1 | from .model import Basenji2 as bsj2 2 | import json 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class Basenji2(nn.Module): 8 | def __init__(self, params, d_output, seq_length, d_model=512, repeat_conv_tower=6, repeat_dilation=11, use_cropping=True, **kwargs): 9 | super().__init__() 10 | with open(params) as params_open: 11 | model_params = json.load(params_open)['model'] 12 | model_params["head_human"]["units"] = d_output 13 | model_params["seq_length"] = seq_length 14 | model_params["target_length"] = seq_length 15 | 
model_params['trunk'][1]['repeat'] = repeat_conv_tower 16 | model_params['trunk'][2]['repeat'] = repeat_dilation 17 | if not use_cropping: 18 | model_params['trunk'].pop(3) 19 | # model_params['trunk'][2]['in_channels'] = int(model_params['trunk'][1]['repeat']['in_channels']**repeat_conv_tower) 20 | # model_params['trunk'][4]['in_channels'] = int(model_params['trunk'][1]['repeat']['in_channels']**repeat_conv_tower) 21 | self.d_model = d_model 22 | model_params['trunk'][-1]['filters'] = d_model 23 | model_params['head_human']['in_features'] = d_model 24 | 25 | 26 | self.basenji2 = bsj2(model_params) 27 | 28 | def forward(self, input_ids, position_ids=None, inference_params=None, state=None): # state for the repo interface 29 | if isinstance(input_ids, list): 30 | input_ids_tensor = input_ids[0] 31 | attention_mask = input_ids[1] 32 | else: 33 | input_ids_tensor = torch.tensor(input_ids) 34 | attention_mask = None 35 | if position_ids is not None: 36 | position_ids_tensor = position_ids 37 | else: 38 | position_ids_tensor = None 39 | 40 | x = F.one_hot(input_ids, num_classes=5).float() 41 | 42 | outputs = self.basenji2( 43 | x=x.permute(0,2,1), 44 | return_only_embeddings=True, 45 | inputs_embeds=None, 46 | output_hidden_states=None, 47 | return_dict=None, 48 | ) 49 | hidden_states = outputs.permute(0,2,1) 50 | return hidden_states, None 51 | 52 | @property 53 | def d_output(self): 54 | """Model /embedding dimension, used for decoder mapping. 55 | """ 56 | if getattr(self, "d_model", None) is None: 57 | raise NotImplementedError("SequenceModule instantiation must set d_output") 58 | return self.d_model -------------------------------------------------------------------------------- /convnova/configs/experiment/hg38-pretrain/transformer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: hg38 4 | - override /scheduler: cosine_warmup_timm 5 | 6 | 7 | # model: 8 | # _name_: lm 9 | # d_model: 128 10 | # n_layer: 2 11 | # d_inner: ${eval:4 * ${.d_model}} 12 | # vocab_size: 12 13 | # resid_dropout: 0.0 14 | # embed_dropout: 0.1 15 | # fused_mlp: False # figure out how to use fused MLP, maybe only with bf16 + a100 16 | # fused_dropout_add_ln: False 17 | # residual_in_fp32: True 18 | # pad_vocab_size_multiple: 8 19 | # layer: 20 | # _name_: "nt" 21 | # n_heads: 8 22 | # return_state: False 23 | model: 24 | _name_: nt 25 | d_model: 128 # no use, just to test nt's compatibility with hyena 26 | alphabet_size: 12 27 | pad_token_id: 4 28 | mask_token_id: 3 29 | num_layers: 2 30 | max_positions: 513 31 | pad_vocab_size_multiple: 8 32 | 33 | task: 34 | _name_: lm 35 | 36 | trainer: 37 | accelerator: gpu 38 | devices: 8 39 | num_nodes: 1 40 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 41 | max_epochs: 100 42 | precision: 16 # bf16 only a100 43 | gradient_clip_val: 1.0 44 | # strategy: null 45 | 46 | dataset: 47 | batch_size: 32 # Per GPU 48 | # batch_size: 256 49 | max_length: ${model.max_positions} # 262144, 524288 50 | # optional, default is max_length 51 | max_length_val: ${dataset.max_length} 52 | max_length_test: ${dataset.max_length} 53 | tokenizer_name: char 54 | pad_max_length: null # needed for bpe tokenizer 55 | add_eos: true 56 | rc_aug: false 57 | num_workers: 12 58 | use_fixed_len_val: false # placing a fixed length val here, but it's really the test 59 | replace_N_token: false # replace N (uncertain token) with 
pad tokens in dataloader 60 | pad_interval: false # handle uncertain tokens within the FastaInteral class 61 | 62 | scheduler: 63 | t_in_epochs: False 64 | t_initial: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 65 | warmup_lr_init: 1e-1 66 | warmup_t: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 67 | lr_min: ${eval:0.1 * ${optimizer.lr}} 68 | 69 | optimizer: 70 | lr: 6e-4 71 | weight_decay: 0.1 72 | 73 | train: 74 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 75 | seed: 2222 76 | global_batch_size: 256 # effects the scheduler, need to set properly 77 | -------------------------------------------------------------------------------- /convnova/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.32.1 3 | aiohttp==3.9.5 4 | aiosignal==1.3.1 5 | annotated-types==0.7.0 6 | antlr4-python3-runtime==4.9.3 7 | async-timeout==4.0.3 8 | attrs==23.2.0 9 | beautifulsoup4==4.12.3 10 | biopython==1.83 11 | bleach==6.1.0 12 | cachetools==5.3.3 13 | charset-normalizer==3.3.2 14 | click==8.1.7 15 | cmake==3.29.6 16 | datasets==2.16.0 17 | deepspeed==0.14.4 18 | defusedxml==0.7.1 19 | dill==0.3.7 20 | docker-pycreds==0.4.0 21 | docopt==0.6.2 22 | einops==0.8.0 23 | fastjsonschema==2.20.0 24 | filelock==3.13.1 25 | fire==0.6.0 26 | frozenlist==1.4.1 27 | fsspec==2023.10.0 28 | gdown==5.2.0 29 | genomic_benchmarks==0.0.9 30 | gitdb==4.0.11 31 | GitPython==3.1.43 32 | google-auth==2.30.0 33 | google-auth-oauthlib==1.0.0 34 | grpcio==1.64.1 35 | hjson==3.1.0 36 | huggingface-hub==0.23.4 37 | hydra-core==1.3.2 38 | idna==3.7 39 | importlib_resources==6.4.0 40 | Jinja2==3.1.3 41 | joblib==1.4.2 42 | jsonschema==4.23.0 43 | jsonschema-specifications==2023.12.1 44 | jupyterlab_pygments==0.3.0 45 | liftover==1.1.18 46 | lit==18.1.8 47 | Markdown==3.6 48 | markdown-it-py==3.0.0 49 | MarkupSafe==2.1.5 50 | mdurl==0.1.2 51 | mistune==3.0.2 52 | mpmath==1.3.0 53 | multidict==6.0.5 54 | multiprocess==0.70.15 55 | nbclient==0.10.0 56 | nbformat==5.10.4 57 | networkx==3.0 58 | ninja==1.11.1.1 59 | numpy==1.24.1 60 | oauthlib==3.2.2 61 | omegaconf==2.3.0 62 | opt-einsum==3.3.0 63 | pandas==2.0.3 64 | pandocfilters==1.5.1 65 | peft==0.11.1 66 | pillow==10.2.0 67 | pipreqs==0.5.0 68 | pkgutil_resolve_name==1.3.10 69 | polars==0.20.13 70 | protobuf==5.27.2 71 | psutil==6.0.0 72 | py-cpuinfo==9.0.0 73 | pyarrow==16.1.0 74 | pyarrow-hotfix==0.6 75 | pyasn1==0.6.0 76 | pyasn1_modules==0.4.0 77 | pydantic==2.8.2 78 | pydantic_core==2.20.1 79 | pyfaidx==0.8.1.1 80 | pygments==2.17.1 81 | pynvml==11.5.0 82 | PySocks==1.7.1 83 | python-dateutil==2.8.2 84 | pytz==2024.1 85 | PyYAML==6.0.1 86 | referencing==0.35.1 87 | regex==2024.5.15 88 | requests==2.32.3 89 | requests-oauthlib==2.0.0 90 | rich==13.7.1 91 | rpds-py==0.20.0 92 | rsa==4.9 93 | safetensors==0.4.3 94 | scikit-learn==1.3.2 95 | scipy==1.10.1 96 | sentry-sdk==2.7.1 97 | setproctitle==1.3.3 98 | six==1.16.0 99 | smmap==5.0.1 100 | soupsieve==2.5 101 | sympy==1.12 102 | tensorboard==2.14.0 103 | tensorboard-data-server==0.7.2 104 | termcolor==2.4.0 105 | threadpoolctl==3.5.0 106 | timm==0.9.16 107 | tinycss2==1.3.0 108 | tokenizers==0.13.3 109 | tqdm==4.66.4 110 | transformers==4.28.0 111 | triton==2.0.0 112 | tzdata==2024.1 113 | urllib3==2.2.2 114 | 
wandb==0.17.3 115 | webencodings==0.5.1 116 | Werkzeug==3.0.3 117 | xxhash==3.4.1 118 | yarg==0.1.9 119 | yarl==1.9.4 120 | -------------------------------------------------------------------------------- /convnova/configs/experiment/nt-benchmark/legnet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: nucleotide_transformer 4 | - override /scheduler: cosine_warmup_timm 5 | 6 | model: 7 | _name_: legnet 8 | d_output: ${dataset.d_output} 9 | 10 | decoder: null 11 | 12 | task: 13 | # 2 options for soft_cross_entropy (for mixup) 14 | _name_: masked_multiclass 15 | loss: cross_entropy 16 | 17 | trainer: 18 | accelerator: gpu 19 | devices: 2 20 | num_nodes: 1 21 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 22 | max_epochs: 20 23 | precision: 32 # bf16 only a100 24 | gradient_clip_val: 1.0 25 | # strategy: 26 | # _target_: pytorch_lightning.strategies.DeepSpeedStrategy 27 | # stage: 1 28 | # logging_batch_size_per_gpu: 500 29 | 30 | # enhancer 200 2 14968 MCC 31 | # enhancer_types 200 3 14968 MCC 32 | # H3 500 2 13468 MCC 33 | # H3K4me1 500 2 28509 MCC 34 | # H3K4me2 500 2 27614 MCC 35 | # H3K4me3 500 2 33119 MCC 36 | # H3K9ac 500 2 25003 MCC 37 | # H3K14ac 500 2 29743 MCC 38 | # H3K36me3 500 2 31392 MCC 39 | # H3K79me3 500 2 25953 MCC 40 | # H4 500 2 13140 MCC 41 | # H4ac 500 2 30685 MCC 42 | # promoter_all 300 2 53276 F1 43 | # promoter_non_tata 300 2 47759 F1 44 | # promoter_tata 300 2 5517 F1 45 | # splice_sites_acceptor 600 2 19961 F1 46 | # splice_sites_donor 600 2 19775 F1 47 | 48 | 49 | dataset: 50 | # batch_size: 32 # Per GPU 51 | batch_size: 128 52 | # max_length: 515 # select max that you want for this dataset 53 | # dataset_name: 'human_nontata_promoters' 54 | dataset_name: 'H3K4me1' 55 | # dest_path: '/mnt/nas/share2/home/by/hyena-dna/data/genomic_benchmark/' 56 | # d_output: 3 # binary classification by default 57 | # use_padding: True 58 | # padding_side: 'left' 59 | # add_eos: False 60 | # train_len: 289061 # update this according to above table 61 | # __l_max: ${.max_length} 62 | tokenizer_name: char 63 | use_tokenizer: false 64 | add_eos: false 65 | rc_aug: false # reverse complement augmentation 66 | return_mask: false 67 | padding_side: left 68 | 69 | scheduler: 70 | t_in_epochs: False 71 | t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 72 | warmup_lr_init: 1e-6 73 | warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 74 | lr_min: ${eval:0.1 * ${optimizer.lr}} 75 | 76 | optimizer: 77 | lr: 6e-4 78 | weight_decay: 1e-5 79 | 80 | train: 81 | remove_test_loader_in_eval: false # no test set in this benchmark 82 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 83 | seed: 2222 84 | global_batch_size: ${eval:${trainer.devices}*${dataset.batch_size}} # effects the scheduler, need to set properly 85 | # pretrained_model_path: /gpfs/gibbs/pi/gerstein/xt86/by/hyena-dna/outputs/rope_last.ckpt 86 | # pretrained_model_strict_load: false -------------------------------------------------------------------------------- /convnova/src/models/sequence/long_conv_kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
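# Shape sketch for the LongConvKernel module defined below (illustrative only;
# the sizes here are made-up examples, not values taken from any config in this repo):
#   kernel = LongConvKernel(H=128, L=1024, channels=1, kernel_dropout=0.1)
#   k, _ = kernel()   # k has shape (channels, H, L) -> (1, 128, 1024)
# With causal=False the stored kernel length is doubled to 2*L internally.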
import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from src.utils.train import OptimModule 7 | 8 | class LongConvKernel(OptimModule): 9 | def __init__( 10 | self, 11 | H, 12 | L, 13 | channels=1, 14 | learning_rate=None, 15 | lam=0.1, 16 | causal=True, 17 | kernel_dropout=0, 18 | weight_init="random", 19 | use_ma_smoothing = False, 20 | ma_window_len = 7, 21 | smooth_freq = False, 22 | **kwargs 23 | ): 24 | super().__init__() 25 | 26 | self.drop = torch.nn.Dropout(p=kernel_dropout) 27 | self.H = H 28 | self.weight_init = weight_init 29 | self.causal = causal 30 | self.L = L*2 if not causal else L 31 | 32 | self.channels = channels 33 | self.lam = lam 34 | self.kernel = torch.nn.Parameter(self._parameter_initialization()) #(c,H,L) 35 | 36 | self.register("kernel", self.kernel, learning_rate) 37 | 38 | self.use_ma_smoothing=use_ma_smoothing 39 | self.smooth_freq = smooth_freq 40 | self.ma_window_len = ma_window_len 41 | if self.use_ma_smoothing: 42 | if smooth_freq: 43 | weight = torch.arange(ma_window_len, dtype = self.kernel.dtype) 44 | weight = torch.exp(-0.5 * torch.abs(weight - ma_window_len // 2) ** 2) 45 | weight = repeat(weight, 'l -> h1 h2 l', h1 = self.H, h2 = 1) 46 | weight = weight.type(torch.fft.rfft(self.kernel).dtype) 47 | self.smooth_weight = weight 48 | else: 49 | self.ma_window_len = ma_window_len 50 | assert self.ma_window_len%2!=0, "window size must be odd" 51 | padding = (self.ma_window_len//2) 52 | self.smooth = torch.nn.AvgPool1d(kernel_size=self.ma_window_len,stride=1,padding=padding) 53 | 54 | def _parameter_initialization(self): 55 | if self.weight_init=="random": 56 | return torch.randn(self.channels, self.H, self.L) * 0.002 57 | elif self.weight_init=="double_exp": 58 | K = torch.randn(self.channels, self.H, self.L,dtype=torch.float32) * 0.02 59 | double_exp = torch.zeros((self.H,self.L),dtype=torch.float32) 60 | for i in range(self.H): 61 | for j in range(self.L): 62 | double_exp[i,j] = torch.exp(-(j/self.L)*torch.pow(torch.tensor(int(self.H/2)),torch.tensor(i/self.H))) 63 | K = torch.einsum("c h l, h l -> c h l",K,double_exp) 64 | return K 65 | else: raise NotImplementedError(f"{self.weight_init} is not valid") 66 | 67 | def forward(self, **kwargs): 68 | k = self.kernel 69 | if self.use_ma_smoothing: 70 | if self.smooth_freq: 71 | k_f = torch.fft.rfft(k, dim=-1) 72 | k_f = F.conv1d(k_f, self.smooth_weight.to(k_f.device), padding='same', groups=self.H) 73 | k = torch.fft.irfft(k_f, dim=-1) 74 | else: 75 | k = self.smooth(k) 76 | k = F.relu(torch.abs(k)-self.lam)*torch.sign(k) 77 | k = self.drop(k) 78 | return k, None 79 | 80 | @property 81 | def d_output(self): 82 | return self.H -------------------------------------------------------------------------------- /convnova/configs/dataset/nucleotide_transformer.yaml: -------------------------------------------------------------------------------- 1 | _name_: nucleotide_transformer # this links to the overall SequenceDataset of all nucleotide transformer datasets 2 | dataset_name: enhancer # this specifies which dataset in nuc trx 3 | dest_path: null # path to overall nuc trx datasets 4 | max_length: ${.${.dataset_name}.max_length} 5 | d_output: ${.${.dataset_name}.classes} 6 | use_padding: True 7 | padding_side: left 8 | add_eos: False 9 | batch_size: 32 10 | train_len: ${.${.dataset_name}.train_len} 11 | __l_max: ${.max_length} 12 | shuffle: true # set this as default! 
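# How the self-referencing interpolations in this file resolve (illustrative
# note, not part of the original config): ${.dataset_name} evaluates to the
# chosen dataset key, so with the default `dataset_name: enhancer` the derived
# fields become max_length: 200, d_output: 2, train_len: 14968 and metric: mcc,
# all read from the `enhancer` block below. Overriding e.g.
# `dataset.dataset_name=H3K4me1` at launch switches all of them at once.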
13 | metric: ${.${.dataset_name}.metric} 14 | # these are used to find the right attributes automatically for each dataset 15 | enhancer: 16 | train_len: 14968 17 | classes: 2 18 | max_length: 200 19 | metric: mcc 20 | enhancer_types: 21 | train_len: 14968 22 | classes: 3 23 | max_length: 200 24 | metric: mcc 25 | H3: 26 | train_len: 13468 27 | classes: 2 28 | max_length: 500 29 | metric: mcc 30 | H3K4me1: 31 | train_len: 28509 32 | classes: 2 33 | max_length: 500 34 | metric: mcc 35 | H3K4me2: 36 | train_len: 27614 37 | classes: 2 38 | max_length: 500 39 | metric: mcc 40 | H3K4me3: 41 | train_len: 33119 42 | classes: 2 43 | max_length: 500 44 | metric: mcc 45 | H3K9ac: 46 | train_len: 25003 47 | classes: 2 48 | max_length: 500 49 | metric: mcc 50 | H3K14ac: 51 | train_len: 29743 52 | classes: 2 53 | max_length: 500 54 | metric: mcc 55 | H3K36me3: 56 | train_len: 31392 57 | classes: 2 58 | max_length: 500 59 | metric: mcc 60 | H3K79me3: 61 | train_len: 25953 62 | classes: 2 63 | max_length: 500 64 | metric: mcc 65 | H4: 66 | train_len: 13140 67 | classes: 2 68 | max_length: 500 69 | metric: mcc 70 | H4ac: 71 | train_len: 30685 72 | classes: 2 73 | max_length: 500 74 | metric: mcc 75 | promoter_all: 76 | train_len: 53276 77 | classes: 2 78 | max_length: 300 79 | metric: f1_macro 80 | promoter_non_tata: 81 | train_len: 47759 82 | classes: 2 83 | max_length: 300 84 | metric: f1_macro 85 | promoter_tata: 86 | train_len: 5517 87 | classes: 2 88 | max_length: 300 89 | metric: f1_macro 90 | splice_sites_acceptor: 91 | train_len: 19961 92 | classes: 2 93 | max_length: 600 94 | metric: f1_macro 95 | splice_sites_donor: 96 | train_len: 19775 97 | classes: 2 98 | max_length: 600 99 | metric: f1_macro 100 | #test 101 | splice_sites_all: 102 | train_len: 27000 #data/nucleotide_transformer/splice_sites_all/splice_sites_all_train.fasta.fai 103 | classes: 3 #3分类 不是2分类 104 | max_length: 600 105 | metric: f1_macro 106 | # splice_sites_all: 107 | # train_len: 27000 108 | # classes: 3 109 | # max_length: 600 110 | # metric: f1_macro 111 | 112 | # name maxlen classes samples metric 113 | 114 | # enhancer 200 2 14968 MCC 115 | # enhancer_types 200 3 14968 MCC 116 | # H3 500 2 13468 MCC 117 | # H3K4me1 500 2 28509 MCC 118 | # H3K4me2 500 2 27614 MCC 119 | # H3K4me3 500 2 33119 MCC 120 | # H3K9ac 500 2 25003 MCC 121 | # H3K14ac 500 2 29743 MCC 122 | # H3K36me3 500 2 31392 MCC 123 | # H3K79me3 500 2 25953 MCC 124 | # H4 500 2 13140 MCC 125 | # H4ac 500 2 30685 MCC 126 | # promoter_all 300 2 53276 F1 127 | # promoter_non_tata 300 2 47759 F1 128 | # promoter_tata 300 2 5517 F1 129 | # splice_sites_acceptor 600 2 19961 F1 130 | # splice_sites_donor 600 2 19775 F1 131 | -------------------------------------------------------------------------------- /convnova/configs/experiment/hg38-pretrain/bert_hg38_hyena.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: bert_hg38 4 | - override /scheduler: cosine_warmup_timm 5 | 6 | model: 7 | _name_: blm 8 | d_model: 256 9 | n_layer: 2 10 | d_inner: ${eval:4 * ${.d_model}} 11 | vocab_size: 12 12 | resid_dropout: 0.0 13 | embed_dropout: 0.1 14 | fused_mlp: False # figure out how to use fused MLP, maybe only with bf16 + a100 15 | fused_dropout_add_ln: False 16 | checkpoint_mixer: False # set true for memory reduction 17 | checkpoint_mlp: False # set true for memory reduction 18 | residual_in_fp32: True 19 | pad_vocab_size_multiple: 8 20 | layer: 21 | _name_: hyena 22 | emb_dim: 5 23 | 
filter_order: 64 24 | short_filter_order: 3 25 | l_max: ${dataset.max_length} 26 | modulate: True 27 | w: 10 28 | lr: ${optimizer.lr} 29 | wd: 0.0 30 | lr_pos_emb: 0.0 31 | bidirectional: True 32 | 33 | task: 34 | # _name_: lm 35 | _name_: hg38 # equivalent to lm task, plus allows extra metrics to be calculated 36 | loss: bert_cross_entropy 37 | 38 | trainer: 39 | accelerator: gpu 40 | devices: 2 41 | num_nodes: 1 42 | # accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 43 | max_epochs: 1000 44 | precision: 16 # bf16 only a100 45 | gradient_clip_val: 1.0 46 | strategy: null 47 | 48 | # callbacks: 49 | # seqlen_warmup_reload: 50 | # # epochs refers to how long to run at that stage (not cummulative!) 51 | # # this is just a sample 52 | # stage_params: 53 | # - epochs: 2 # means run this stage for 2 epochs (0, and 1) 54 | # seq_len: 1024 55 | # batch_size: 256 # grad accum = 1, since train.global_batch_size=256 56 | # - epochs: 2 # run for 2 epochs (2 and 3) 57 | # seq_len: 2048 58 | # batch_size: 128 59 | # - epochs: 2 # run for epochs 4, 5 60 | # seq_len: 4096 # 61 | # batch_size: 64 62 | # - epochs: 2 # epoch 6, 7 63 | # seq_len: 8192 64 | # batch_size: 32 65 | # - epochs: 4 # epoch 8, 9, 10, 11 66 | # seq_len: 16_384 # 67 | # batch_size: 16 68 | # - epochs: 4 # epoch 12, 13, 14, 15 69 | # seq_len: 32_768 70 | # batch_size: 8 71 | 72 | dataset: 73 | batch_size: 128 # Per GPU 74 | # batch_size: 8 # this is the test batch size (and final train batch size) 75 | max_length: 1024 # note this is the test max length (and the final train max_length) + 2 76 | # optional, default is max_length 77 | max_length_val: ${dataset.max_length} 78 | max_length_test: ${dataset.max_length} 79 | tokenizer_name: char 80 | pad_max_length: null # needed for bpe tokenizer only 81 | add_eos: true 82 | rc_aug: false 83 | num_workers: 12 84 | 85 | scheduler: 86 | t_in_epochs: False 87 | t_initial: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 88 | warmup_lr_init: 1e-6 89 | warmup_t: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 90 | lr_min: ${eval:0.1 * ${optimizer.lr}} 91 | 92 | optimizer: 93 | lr: 6e-4 94 | weight_decay: 0.1 95 | 96 | train: 97 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 98 | seed: 2222 99 | global_batch_size: 256 # effects the scheduler, need to set properly -------------------------------------------------------------------------------- /convnova/configs/experiment/hg38-pretrain/mamba.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: hg38 4 | - override /scheduler: cosine_warmup_timm 5 | 6 | model: 7 | _name_: lm 8 | d_model: 128 9 | n_layer: 2 10 | d_inner: ${eval:4 * ${.d_model}} 11 | vocab_size: 12 12 | resid_dropout: 0.0 13 | embed_dropout: 0.1 14 | fused_mlp: False # figure out how to use fused MLP, maybe only with bf16 + a100 15 | fused_dropout_add_ln: False 16 | checkpoint_mixer: False # set true for memory reduction 17 | checkpoint_mlp: False # set true for memory reduction 18 | residual_in_fp32: True 19 | pad_vocab_size_multiple: 8 20 | layer: 21 | _name_: ssm 22 | d_state: 16 23 | d_conv: 4 24 | expand: 2 25 | dt_rank: "auto" 26 | dt_min: 0.001 27 | dt_max: 0.1 28 | dt_init: 
"random" 29 | dt_scale: 1.0 30 | dt_init_floor: 1e-4 31 | conv_bias: True 32 | bias: False 33 | use_fast_path: True 34 | return_last_state: False 35 | # layer_idx: null 36 | # device: null 37 | # dtype: null 38 | 39 | 40 | task: 41 | _name_: lm 42 | 43 | trainer: 44 | accelerator: gpu 45 | devices: 5 46 | num_nodes: 1 47 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 48 | max_epochs: 2000 49 | precision: 16 # bf16 only a100 50 | gradient_clip_val: 1.0 51 | # strategy: null 52 | 53 | dataset: 54 | batch_size: 32 # Per GPU 55 | # batch_size: 256 56 | max_length: 513 # 262144, 524288 57 | # optional, default is max_length 58 | max_length_val: ${dataset.max_length} 59 | max_length_test: ${dataset.max_length} 60 | tokenizer_name: char 61 | pad_max_length: null # needed for bpe tokenizer 62 | add_eos: true 63 | rc_aug: false 64 | num_workers: 12 65 | use_fixed_len_val: false # placing a fixed length val here, but it's really the test 66 | replace_N_token: false # replace N (uncertain token) with pad tokens in dataloader 67 | pad_interval: false # handle uncertain tokens within the FastaInteral class 68 | 69 | # scheduler: 70 | # t_in_epochs: False 71 | # t_initial: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 72 | # warmup_lr_init: 1e-6 73 | # warmup_t: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 74 | # lr_min: ${eval:0.1 * ${optimizer.lr}} 75 | 76 | scheduler: 77 | t_in_epochs: True 78 | # t_initial: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} # num steps for 1 cycle 79 | t_initial: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * 1} 80 | cycle_mul: 2 81 | warmup_lr_init: 1e-6 # starting point 82 | # warmup_t: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.005} # time for ramp up 83 | warmup_t: ${eval:${div_up:${dataset.__train_len}, ${train.global_batch_size}} * 0.1} 84 | # lr_min: ${eval:0.1 * ${optimizer.lr}} # flatlines with this 85 | lr_min: 1e-6 86 | cycle_decay: 0.4 87 | cycle_limit: 10000 88 | 89 | optimizer: 90 | lr: 2e-3 91 | weight_decay: 0.1 92 | 93 | train: 94 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 95 | seed: 2222 96 | global_batch_size: ${eval:${dataset.batch_size}*${trainer.devices}} # effects the scheduler, need to set properly -------------------------------------------------------------------------------- /convnova/configs/experiment/genomic-benchmark/legnet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: genomic_benchmark 4 | - override /scheduler: cosine_warmup_timm 5 | 6 | model: 7 | _name_: legnet 8 | d_output: ${dataset.d_output} 9 | 10 | decoder: null 11 | 12 | task: 13 | # 2 options for soft_cross_entropy (for mixup) 14 | loss: 15 | # soft_cross_entropy for pytorch 1.10+, which takes in label_smoothing here 16 | _name_: cross_entropy 17 | # label_smoothing: 0.1 18 | # pass in list of k's 19 | # last_k_ppl: null 20 | 21 | trainer: 22 | accelerator: gpu 23 | devices: 4 24 | num_nodes: 1 25 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 26 
| max_epochs: 10 27 | precision: 16 # bf16 only a100 28 | gradient_clip_val: 1.0 29 | # strategy: null 30 | 31 | # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings 32 | # name num_seqs num_classes median len std 33 | # dummy_mouse_enhancers_ensembl 1210 2 2381 984.4 34 | # demo_coding_vs_intergenomic_seqs 100_000 2 200 0 35 | # demo_human_or_worm 100_000 2 200 0 36 | # human_enhancers_cohn 27791 2 500 0 37 | # human_enhancers_ensembl 154842 2 269 122.6 38 | # human_ensembl_regulatory 289061 3 401 184.3 39 | # human_nontata_promoters 36131 2 251 0 40 | # human_ocr_ensembl 174756 2 315 108.1 41 | 42 | # decoder: null 43 | 44 | dataset: 45 | # batch_size: 32 # Per GPU 46 | batch_size: 64 47 | max_length: 400 # select max that you want for this dataset 48 | # dataset_name: 'human_nontata_promoters' 49 | dataset_name: 'human_ocr_ensembl' 50 | dest_path: # project root'data/genomic_benchmark/' 51 | # d_output: 2 # binary classification by default 52 | use_padding: True 53 | padding_side: 'left' 54 | add_eos: False 55 | # train_len: 174756 # update this according to above table 56 | # __l_max: ${.max_length} 57 | tokenizer_name: char 58 | use_tokenizer: False 59 | # rc_aug: true # reverse complement augmentation. Didn't seem to help for human_nontata_promoters, but could be wrong 60 | 61 | scheduler: 62 | t_in_epochs: False 63 | t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 64 | warmup_lr_init: 1e-6 65 | warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 66 | lr_min: ${eval:0.1 * ${optimizer.lr}} 67 | 68 | optimizer: 69 | lr: 6e-4 70 | weight_decay: 0.1 71 | 72 | train: 73 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 74 | seed: 43 75 | global_batch_size: ${eval:${trainer.devices}*${dataset.batch_size}} 76 | remove_test_loader_in_eval: true # no test set in this benchmark 77 | pretrained_model_strict_load: False # false allows encoder/decoder to be used if new model uses it 78 | pretrained_model_state_hook: 79 | _name_: load_backbone 80 | freeze_backbone: false # seems to work much better if false (ie finetune entire model) 81 | -------------------------------------------------------------------------------- /convnova/configs/experiment/nt-benchmark/convnova.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: nucleotide_transformer 4 | - override /scheduler: cosine_warmup_timm 5 | 6 | model: 7 | _name_: convnova 8 | for_representation: true 9 | alphabet_size: 5 10 | d_model: 128 11 | kernel_size: 9 12 | dilation: 4 13 | pretrain: False 14 | num_conv1d: 5 15 | final_conv: False 16 | d_inner: 2 17 | ffn: false 18 | args: 19 | hidden_dim: 128 # same as d_model 20 | num_cnn_stacks: 1 21 | dropout: 0.0 22 | 23 | # decoder: null # decoder in cnn 24 | 25 | task: 26 | # 2 options for soft_cross_entropy (for mixup) 27 | _name_: masked_multiclass 28 | loss: cross_entropy 29 | # label_smoothing: 0.1 30 | # pass in list of k's 31 | # last_k_ppl: null 32 | torchmetrics: null 33 | 34 | trainer: 35 | accelerator: gpu 36 | devices: 2 37 | num_nodes: 1 38 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 39 | max_epochs: 20 40 | precision: 32 # bf16 
only a100 41 | gradient_clip_val: 1.0 42 | # strategy: null 43 | 44 | # name maxlen classes samples metric 45 | 46 | # enhancer 200 2 14968 MCC 47 | # enhancer_types 200 3 14968 MCC 48 | # H3 500 2 13468 MCC 49 | # H3K4me1 500 2 28509 MCC 50 | # H3K4me2 500 2 27614 MCC 51 | # H3K4me3 500 2 33119 MCC 52 | # H3K9ac 500 2 25003 MCC 53 | # H3K14ac 500 2 29743 MCC 54 | # H3K36me3 500 2 31392 MCC 55 | # H3K79me3 500 2 25953 MCC 56 | # H4 500 2 13140 MCC 57 | # H4ac 500 2 30685 MCC 58 | # promoter_all 300 2 53276 F1 59 | # promoter_non_tata 300 2 47759 F1 60 | # promoter_tata 300 2 5517 F1 61 | # splice_sites_acceptor 600 2 19961 F1 62 | # splice_sites_donor 600 2 19775 F1 63 | 64 | 65 | dataset: 66 | batch_size: 128 67 | dataset_name: 'H3K36me3' 68 | tokenizer_name: char 69 | use_tokenizer: False 70 | add_eos: false 71 | rc_aug: false # reverse complement augmentation 72 | return_mask: false 73 | padding_side: left 74 | # num_workers: 1 75 | # rc_aug: true # reverse complement augmentation. Didn't seem to help for human_nontata_promoters, but could be wrong 76 | 77 | scheduler: 78 | t_in_epochs: False 79 | t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 80 | warmup_lr_init: 1e-6 81 | warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 82 | lr_min: ${eval:0.1 * ${optimizer.lr}} 83 | 84 | # constant 85 | # t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 86 | # warmup_t: 0 87 | # lr_min: ${optimizer.lr} 88 | 89 | 90 | 91 | optimizer: 92 | lr: 1e-3 93 | weight_decay: 0.1 94 | 95 | train: 96 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 97 | seed: 48 98 | global_batch_size: ${eval:${trainer.devices}*${dataset.batch_size}} 99 | remove_test_loader_in_eval: true # no test set in this benchmark 100 | pretrained_model_strict_load: False # false allows encoder/decoder to be used if new model uses it 101 | # for loading backbone and not head, requires both of these flags below 102 | # pretrained_model_path: 103 | pretrained_model_state_hook: 104 | _name_: load_backbone 105 | freeze_backbone: false # seems to work much better if false (ie finetune entire model) 106 | 107 | -------------------------------------------------------------------------------- /convnova/configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - experiment: base # Specifies model and pipeline, equivalent to next two lines 5 | # - model: s4 # Model backbone 6 | # - pipeline: cifar # Specifies collection of configs, equivalent to next 5 lines 7 | # Pipelines should specify /loader, /dataset, /task, /encoder, /decoder (ideally in that order) 8 | # # - loader: default # Dataloader (e.g. 
handles batches) 9 | # # - dataset: cifar # Defines the data (x and y pairs) 10 | # # - task: multiclass_classification # Defines loss and metrics 11 | # # - encoder: null # Interface between data and model 12 | # # - decoder: null # Interface between model and targets 13 | 14 | # Additional arguments used to configure the training loop 15 | # Most of these set combinations of options in the PL trainer, add callbacks, or add features to the optimizer 16 | train: 17 | seed: 0 18 | # These three options are used by callbacks (checkpoint, monitor) and scheduler 19 | # Most of them are task dependent and are set by the pipeline 20 | interval: ??? # Should be specified by scheduler. Also used by LR monitor 21 | monitor: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer 22 | mode: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer 23 | ema: 0.0 # Moving average model for validation # TODO move into callback 24 | test: False # Test after training 25 | debug: False # Special settings to make debugging more convenient 26 | ignore_warnings: False # Disable python warnings 27 | 28 | # These control state passing between batches 29 | state: 30 | mode: null # [ None | 'none' | 'reset' | 'bptt' | 'tbptt' ] 31 | n_context: 0 # How many steps to use as memory context. Must be >= 0 or None (null), meaning infinite context 32 | n_context_eval: ${.n_context} # Context at evaluation time 33 | # Convenience keys to allow grouping runs 34 | 35 | ckpt: null # Resume training 36 | 37 | disable_dataset: False # Disable dataset loading 38 | validate_at_start: false 39 | 40 | # pretrained_model_path: /mnt/nas/share2/home/by/hyena-dna/outputs/2024-03-15/12-52-24-903986/checkpoints/test/loss.ckpt # Path to pretrained model 41 | # pretrained_model_path: /mnt/nas/share2/home/by/hyena-dna/outputs/2024-03-24/22-43-40-435805/checkpoints/test/loss.ckpt 42 | pretrained_model_path: null 43 | 44 | pretrained_model_strict_load: true # Whether to load the pretrained model even if the model is not compatible 45 | pretrained_model_state_hook: # Hook called on the loaded model's state_dict 46 | _name_: null 47 | post_init_hook: # After initializing model, call method on model 48 | _name_: null 49 | 50 | layer_decay: # Used for ImageNet finetuning 51 | _name_: null 52 | decay: 0.7 53 | 54 | tolerance: # fault tolerance for training on preemptible machines 55 | logdir: ./resume 56 | id: null # must be set to resume training on preemption 57 | 58 | # We primarily use wandb so this is moved to top level in the config for convenience 59 | # Set `~wandb` or `wandb=null` or `wandb.mode=disabled` to disable logging 60 | # If other loggers are added, it would make sense to put this one level lower under train/ or logger/ 61 | wandb: 62 | project: dna 63 | group: "" 64 | job_type: training 65 | mode: online # choices=['online', 'offline', 'disabled'] 66 | name: null 67 | save_dir: "." 68 | id: ${.name} # pass correct id to resume experiment! 
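# Example command-line overrides built on the keys above (illustrative only;
# the entry-point name `train.py` is an assumption, the override keys are real):
#   python train.py experiment=hg38-pretrain/convnova wandb.mode=disabled
#   python train.py experiment=nt-benchmark/convnova wandb.name=h3k36me3_run train.pretrained_model_path=/path/to/last.ckpt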
69 | # Below options should not need to be specified 70 | # entity: "" # set to name of your wandb team or just remove it 71 | # log_model: False 72 | # prefix: "" 73 | # job_type: "train" 74 | # tags: [] 75 | 76 | hydra: 77 | run: 78 | dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S-%f} 79 | -------------------------------------------------------------------------------- /convnova/configs/experiment/genomic-benchmark/basenji.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: genomic_benchmark 4 | - override /scheduler: cosine_warmup_timm 5 | 6 | model: 7 | _name_: basenji2 8 | params: # project root/src/models/basenji2/params_human.json 9 | seq_length: ${dataset.max_length} 10 | d_output: ${dataset.d_output} 11 | repeat_conv_tower: 3 12 | repeat_dilation: 7 13 | use_cropping: false 14 | d_model: 512 15 | 16 | # decoder: null 17 | 18 | task: 19 | # 2 options for soft_cross_entropy (for mixup) 20 | loss: 21 | # soft_cross_entropy for pytorch 1.10+, which takes in label_smoothing here 22 | _name_: cross_entropy 23 | # label_smoothing: 0.1 24 | # pass in list of k's 25 | # last_k_ppl: null 26 | 27 | trainer: 28 | accelerator: gpu 29 | devices: 4 30 | num_nodes: 1 31 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 32 | max_epochs: 10 33 | precision: 16 # bf16 only a100 34 | gradient_clip_val: 1.0 35 | # strategy: null 36 | 37 | # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings 38 | # name num_seqs num_classes median len std 39 | # dummy_mouse_enhancers_ensembl 1210 2 2381 984.4 40 | # demo_coding_vs_intergenomic_seqs 100_000 2 200 0 41 | # demo_human_or_worm 100_000 2 200 0 42 | # human_enhancers_cohn 27791 2 500 0 43 | # human_enhancers_ensembl 154842 2 269 122.6 44 | # human_ensembl_regulatory 289061 3 401 184.3 45 | # human_nontata_promoters 36131 2 251 0 46 | # human_ocr_ensembl 174756 2 315 108.1 47 | 48 | # decoder: null 49 | 50 | dataset: 51 | # batch_size: 32 # Per GPU 52 | batch_size: 64 53 | max_length: 400 # select max that you want for this dataset 54 | # dataset_name: 'human_nontata_promoters' 55 | dataset_name: 'human_ocr_ensembl' 56 | dest_path: # project root/data/genomic_benchmark/' 57 | # d_output: 2 # binary classification by default 58 | use_padding: True 59 | padding_side: 'left' 60 | add_eos: False 61 | # train_len: 174756 # update this according to above table 62 | # __l_max: ${.max_length} 63 | tokenizer_name: char 64 | use_tokenizer: False 65 | # rc_aug: true # reverse complement augmentation. 
Didn't seem to help for human_nontata_promoters, but could be wrong 66 | 67 | scheduler: 68 | t_in_epochs: False 69 | t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 70 | warmup_lr_init: 1e-6 71 | warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 72 | lr_min: ${eval:0.1 * ${optimizer.lr}} 73 | 74 | optimizer: 75 | lr: 6e-4 76 | weight_decay: 0.1 77 | 78 | train: 79 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 80 | seed: 43 81 | global_batch_size: ${eval:${trainer.devices}*${dataset.batch_size}} 82 | remove_test_loader_in_eval: true # no test set in this benchmark 83 | pretrained_model_strict_load: False # false allows encoder/decoder to be used if new model uses it 84 | # for loading backbone and not head, requires both of these flags below 85 | pretrained_model_state_hook: 86 | _name_: load_backbone 87 | freeze_backbone: false # seems to work much better if false (ie finetune entire model) 88 | -------------------------------------------------------------------------------- /convnova/src/utils/registry.py: -------------------------------------------------------------------------------- 1 | optimizer = { 2 | "adam": "torch.optim.Adam", 3 | "adamw": "torch.optim.AdamW", 4 | "rmsprop": "torch.optim.RMSprop", 5 | "sgd": "torch.optim.SGD", 6 | "lamb": "src.utils.optim.lamb.JITLamb", 7 | } 8 | 9 | scheduler = { 10 | "constant": "transformers.get_constant_schedule", 11 | "plateau": "torch.optim.lr_scheduler.ReduceLROnPlateau", 12 | "step": "torch.optim.lr_scheduler.StepLR", 13 | "multistep": "torch.optim.lr_scheduler.MultiStepLR", 14 | "cosine": "torch.optim.lr_scheduler.CosineAnnealingLR", 15 | "constant_warmup": "transformers.get_constant_schedule_with_warmup", 16 | "linear_warmup": "transformers.get_linear_schedule_with_warmup", 17 | "cosine_warmup": "transformers.get_cosine_schedule_with_warmup", 18 | "cosine_warmup_timm": "src.utils.optim.schedulers.TimmCosineLRScheduler", 19 | } 20 | 21 | model = { 22 | # Backbones from this repo 23 | "model": "src.models.sequence.SequenceModel", 24 | "lm": "src.models.sequence.long_conv_lm.ConvLMHeadModel", 25 | "blm": "src.models.sequence.long_conv_lm.BertLMHeadModel", 26 | "lm_simple": "src.models.sequence.simple_lm.SimpleLMHeadModel", 27 | "vit_b_16": "src.models.baselines.vit_all.vit_base_patch16_224", 28 | "dna_embedding": "src.models.sequence.dna_embedding.DNAEmbeddingModel", 29 | "bpnet": "src.models.sequence.hyena_bpnet.HyenaBPNet", 30 | # "convnext": "src.models.sequence.convNext.ConvNeXt", 31 | "convnova": "src.models.ConvNova.convnova.CNNModel", 32 | # "nconvnext": "src.models.sequence.convNext.NConvNeXt", 33 | # "dna_bert2": "src.models.DNABERT2.DNABERT2CustomModel", 34 | # "caduceus": "src.models.Caduceus.caduceus.Caduceus", 35 | # "visualizer": "src.models.sequence.visualizer.CNNModel", 36 | "ntv2": "src.models.NTV2.ntv2.NTV2", 37 | "legnet": "src.models.LegNet.LegNet.LegNet", 38 | "basenji2": "src.models.basenji2.basenji2.Basenji2", 39 | } 40 | 41 | layer = { 42 | "id": "src.models.sequence.base.SequenceIdentity", 43 | "ff": "src.models.sequence.ff.FF", 44 | "mha": "src.models.sequence.mha.MultiheadAttention", 45 | "s4d": "src.models.sequence.ssm.s4d.S4D", 46 | "s4_simple": "src.models.sequence.ssm.s4_simple.SimpleS4Wrapper", 47 | "long-conv": "src.models.sequence.long_conv.LongConv", 48 | "h3": 
"src.models.sequence.h3.H3", 49 | "h3-conv": "src.models.sequence.h3_conv.H3Conv", 50 | "hyena": "src.models.sequence.hyena.HyenaOperator", 51 | "hyena-filter": "src.models.sequence.hyena.HyenaFilter", 52 | "vit": "src.models.sequence.mha.VitAttention", 53 | "ssm": "src.models.sequence.pyramid.Mamba", 54 | "pyramid": "src.models.sequence.mha.MultiheadAttention", 55 | "bert": "src.models.sequence.pyramid.BertLayer" 56 | } 57 | 58 | layer_config = { 59 | "nt": "src.models.sequence.pyramid.NucleotideTransformerConfig" 60 | } 61 | 62 | callbacks = { 63 | "timer": "src.callbacks.timer.Timer", 64 | "params": "src.callbacks.params.ParamsLog", 65 | "learning_rate_monitor": "pytorch_lightning.callbacks.LearningRateMonitor", 66 | "model_checkpoint": "pytorch_lightning.callbacks.ModelCheckpoint", 67 | "early_stopping": "pytorch_lightning.callbacks.EarlyStopping", 68 | "swa": "pytorch_lightning.callbacks.StochasticWeightAveraging", 69 | "rich_model_summary": "pytorch_lightning.callbacks.RichModelSummary", 70 | "rich_progress_bar": "pytorch_lightning.callbacks.RichProgressBar", 71 | "progressive_resizing": "src.callbacks.progressive_resizing.ProgressiveResizing", 72 | "seqlen_warmup": "src.callbacks.seqlen_warmup.SeqlenWarmup", 73 | "seqlen_warmup_reload": "src.callbacks.seqlen_warmup_reload.SeqlenWarmupReload", 74 | "gpu_affinity": "src.callbacks.gpu_affinity.GpuAffinity" 75 | } 76 | 77 | model_state_hook = { 78 | 'load_backbone': 'src.models.sequence.long_conv_lm.load_backbone', 79 | } 80 | -------------------------------------------------------------------------------- /convnova/src/utils/optim/schedulers.py: -------------------------------------------------------------------------------- 1 | """Custom learning rate schedulers""" 2 | 3 | import math 4 | import warnings 5 | import torch 6 | 7 | from timm.scheduler import CosineLRScheduler 8 | 9 | 10 | # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html 11 | class CosineWarmup(torch.optim.lr_scheduler.CosineAnnealingLR): 12 | 13 | def __init__(self, optimizer, T_max, eta_min=0, warmup_step=0, **kwargs): 14 | self.warmup_step = warmup_step 15 | super().__init__(optimizer, T_max - warmup_step, eta_min, *kwargs) 16 | 17 | # Copied from CosineAnnealingLR, but adding warmup and changing self.last_epoch to 18 | # self.last_epoch - self.warmup_step. 
19 | def get_lr(self): 20 | if not self._get_lr_called_within_step: 21 | warnings.warn("To get the last learning rate computed by the scheduler, " 22 | "please use `get_last_lr()`.", UserWarning) 23 | 24 | if self.last_epoch == self.warmup_step: # also covers the case where both are 0 25 | return self.base_lrs 26 | elif self.last_epoch < self.warmup_step: 27 | return [base_lr * (self.last_epoch + 1) / self.warmup_step for base_lr in self.base_lrs] 28 | elif (self.last_epoch - self.warmup_step - 1 - self.T_max) % (2 * self.T_max) == 0: 29 | return [group['lr'] + (base_lr - self.eta_min) * 30 | (1 - math.cos(math.pi / self.T_max)) / 2 31 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)] 32 | return [(1 + math.cos(math.pi * (self.last_epoch - self.warmup_step) / self.T_max)) / 33 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_step - 1) / self.T_max)) * 34 | (group['lr'] - self.eta_min) + self.eta_min 35 | for group in self.optimizer.param_groups] 36 | 37 | _get_closed_form_lr = None 38 | 39 | 40 | def InvSqrt(optimizer, warmup_step): 41 | """ Originally used for Transformer (in Attention is all you need) 42 | """ 43 | 44 | def lr_lambda(step): 45 | # return a multiplier instead of a learning rate 46 | if step == warmup_step: # also covers the case where both are 0 47 | return 1. 48 | else: 49 | return 1. / (step ** 0.5) if step > warmup_step else (step + 1) / (warmup_step ** 1.5) 50 | 51 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 52 | 53 | 54 | def Constant(optimizer, warmup_step): 55 | 56 | def lr_lambda(step): 57 | if step == warmup_step: # also covers the case where both are 0 58 | return 1. 59 | else: 60 | return 1. if step > warmup_step else (step + 1) / warmup_step 61 | 62 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 63 | 64 | 65 | class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler): 66 | """ Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch. 67 | It supports resuming as well. 68 | """ 69 | 70 | def __init__(self, *args, **kwargs): 71 | super().__init__(*args, **kwargs) 72 | self._last_epoch = -1 73 | self.step(epoch=0) 74 | 75 | def step(self, epoch=None): 76 | if epoch is None: 77 | self._last_epoch += 1 78 | else: 79 | self._last_epoch = epoch 80 | # We call either step or step_update, depending on whether we're using the scheduler every 81 | # epoch or every step. 82 | # Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set 83 | # scheduler interval to "step", then the learning rate update will be wrong. 
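# Concrete sizing example (illustrative arithmetic based on the nt-benchmark
# configs above, and assuming div_up is ceiling division): for H3K4me1 with
# train_len=28509 and global_batch_size=2*128=256,
#   steps_per_epoch = div_up(28509, 256) = 112
#   t_initial       = 112 * 20 epochs    = 2240 optimizer steps
#   warmup_t        = 2240 * 0.01        = 22.4 steps
# Because t_in_epochs is False there, each Lightning step() call is routed to
# step_update() below, so the schedule advances once per optimizer step.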
84 | if self.t_in_epochs: 85 | super().step(epoch=self._last_epoch) 86 | else: 87 | super().step_update(num_updates=self._last_epoch) 88 | -------------------------------------------------------------------------------- /convnova/configs/experiment/genomic-benchmark/convnova.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: genomic_benchmark 4 | - override /scheduler: cosine_warmup_timm 5 | 6 | model: 7 | _name_: convnova 8 | for_representation: true 9 | alphabet_size: 5 10 | d_model: 64 11 | kernel_size: 5 12 | dilation: 4 13 | pretrain: False 14 | num_conv1d: 5 15 | d_inner: 4 16 | final_conv: False 17 | ffn: false 18 | args: 19 | hidden_dim: 64 # same as d_model 20 | num_cnn_stacks: 1 21 | dropout: 0.0 22 | 23 | # decoder: null 24 | 25 | task: 26 | # 2 options for soft_cross_entropy (for mixup) 27 | loss: 28 | # soft_cross_entropy for pytorch 1.10+, which takes in label_smoothing here 29 | _name_: cross_entropy 30 | # label_smoothing: 0.1 31 | # pass in list of k's 32 | # last_k_ppl: null 33 | 34 | trainer: 35 | accelerator: gpu 36 | devices: 4 37 | num_nodes: 1 38 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 39 | max_epochs: 10 40 | precision: 16 # bf16 only a100 41 | gradient_clip_val: 1.0 42 | # strategy: null 43 | 44 | # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings 45 | # name num_seqs num_classes median len std 46 | # dummy_mouse_enhancers_ensembl 1210 2 2381 984.4 47 | # demo_coding_vs_intergenomic_seqs 100_000 2 200 0 48 | # demo_human_or_worm 100_000 2 200 0 49 | # human_enhancers_cohn 27791 2 500 0 50 | # human_enhancers_ensembl 154842 2 269 122.6 51 | # human_ensembl_regulatory 289061 3 401 184.3 52 | # human_nontata_promoters 36131 2 251 0 53 | # human_ocr_ensembl 174756 2 315 108.1 54 | 55 | # decoder: null 56 | 57 | dataset: 58 | batch_size: 32 59 | max_length: 2500 # select max that you want for this dataset 60 | dataset_name: 'dummy_mouse_enhancers_ensembl' 61 | dest_path: /mnt/nas/share2/home/by/ConvNova/convnova/data/genomic_benchmark # 'project_root/data/genomic_benchmark/' 62 | d_output: 2 # binary classification by default 63 | use_padding: True 64 | padding_side: 'left' 65 | add_eos: False 66 | train_len: 1210 # update this according to above table 67 | __l_max: ${.max_length} 68 | tokenizer_name: char 69 | use_tokenizer: False 70 | # rc_aug: true # reverse complement augmentation. 
Didn't seem to help for human_nontata_promoters, but could be wrong 71 | 72 | scheduler: 73 | t_in_epochs: False 74 | t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 75 | warmup_lr_init: 1e-6 76 | warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 77 | lr_min: ${eval:0.1 * ${optimizer.lr}} 78 | 79 | optimizer: 80 | lr: 1e-3 81 | weight_decay: 0.1 82 | 83 | train: 84 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 85 | seed: 43 86 | global_batch_size: ${eval:${trainer.devices}*${dataset.batch_size}} 87 | remove_test_loader_in_eval: true # no test set in this benchmark 88 | pretrained_model_strict_load: False # false allows encoder/decoder to be used if new model uses it 89 | # for loading backbone and not head, requires both of these flags below 90 | # pretrained_model_path: 91 | pretrained_model_state_hook: 92 | _name_: load_backbone 93 | freeze_backbone: false # seems to work much better if false (ie finetune entire model) 94 | -------------------------------------------------------------------------------- /convnova/src/models/nn/residual.py: -------------------------------------------------------------------------------- 1 | """ Implementations of different types of residual functions. """ 2 | 3 | import torch 4 | from torch import nn 5 | 6 | class Residual(nn.Module): 7 | """ Residual connection with constant affine weights. Can simulate standard residual, no residual, and "constant gates". """ 8 | 9 | def __init__(self, i_layer, d_input, d_model, alpha=1.0, beta=1.0): 10 | # print("ConstantResidual extra kwargs", kwargs) 11 | super().__init__() 12 | assert (d_input == d_model) or alpha == 0.0 13 | self.i_layer = i_layer 14 | self.d_input = d_input 15 | self.d_model = d_model 16 | self.alpha = alpha 17 | self.beta = beta 18 | 19 | @property 20 | def d_output(self): 21 | return self.d_model 22 | 23 | def forward(self, x, y, transposed): # TODO documentation of transposed 24 | y = self.beta*y if self.beta != 1.0 else y 25 | return self.alpha * x + y if self.alpha else y 26 | 27 | class Affine(Residual): 28 | """ Residual connection with learnable scalar multipliers on the main branch 29 | scalar: Single scalar multiplier, or one per dimension 30 | scale, power: Initialize to scale * layer_num**(-power) 31 | """ 32 | 33 | def __init__(self, *args, scalar=True, gamma=0.0, **kwargs): 34 | # print("ConstantResidual extra kwargs", kwargs) 35 | super().__init__(*args, **kwargs) 36 | self.scalar = scalar 37 | self.gamma = gamma 38 | 39 | c = self.beta * self.i_layer ** (-self.gamma) 40 | d = 1 if self.scalar else self.d_input 41 | self.affine = nn.Parameter(c * torch.ones(d)) 42 | 43 | def forward(self, x, y, transposed): # TODO documentation of transposed 44 | c = self.affine 45 | if transposed: c = c.unsqueeze(-1) 46 | return self.alpha * x + c * y 47 | 48 | 49 | class Feedforward(Residual): 50 | def __init__(self, *args): 51 | # print("Feedforward extra kwargs", kwargs) 52 | super().__init__(*args, alpha=0.0, beta=1.0) 53 | 54 | 55 | class Highway(Residual): 56 | def __init__(self, *args, scaling_correction=False, elemwise=False): 57 | super().__init__(*args) 58 | self.scaling_correction = 1.732 if scaling_correction else 1.0 # TODO 59 | self.elemwise = elemwise 60 | self.Wx = nn.Linear(self.d_input, self.d_input) 61 | if self.elemwise: 62 | self.Wy 
= nn.Parameter(torch.randn(self.d_input)) 63 | else: 64 | self.Wy = nn.Linear(self.d_input, self.d_input) 65 | 66 | def forward(self, x, y, transposed=False): # TODO handle this case 67 | if self.elemwise: 68 | y = self.Wy * y 69 | else: 70 | y = self.Wy(y) 71 | r = torch.sigmoid(self.Wx(x) + y) 72 | z = self.scaling_correction * (1.-r) * x + r * y 73 | return z 74 | 75 | 76 | class DecayResidual(Residual): 77 | """ Residual connection that can decay the linear combination depending on depth. """ 78 | 79 | def __init__(self, *args, power=0.5, l2=True): 80 | # print("DecayResidual extra kwargs", kwargs) 81 | super().__init__(*args) 82 | self.power = power 83 | self.l2 = l2 84 | 85 | def forward(self, x, y, transposed): 86 | beta = self.i_layer ** (-self.power) 87 | if self.l2: 88 | alpha = (1. - beta**2)**0.5 89 | else: 90 | alpha = 1. - beta 91 | 92 | return alpha * x + beta * y 93 | 94 | registry = { 95 | 'F': Feedforward, 96 | 'N': Feedforward, 97 | 'R': Residual, 98 | 'H': Highway, 99 | 'D': DecayResidual, 100 | 'A': Affine, 101 | 'none': Feedforward, 102 | 'ff': Feedforward, 103 | 'feedforward': Feedforward, 104 | 'residual': Residual, 105 | 'highway': Highway, 106 | 'decay': DecayResidual, 107 | 'affine': Affine, 108 | } 109 | -------------------------------------------------------------------------------- /convnova/src/callbacks/timer.py: -------------------------------------------------------------------------------- 1 | ### https://github.com/HazyResearch/transformers/blob/master/src/callbacks/speed_monitor.py 2 | 3 | # Adapted from https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor 4 | # We only need the speed monitoring, not the GPU monitoring 5 | import time 6 | from typing import Any 7 | 8 | from pytorch_lightning import Callback, Trainer, LightningModule 9 | from pytorch_lightning.utilities import rank_zero_only 10 | from pytorch_lightning.utilities.parsing import AttributeDict 11 | from pytorch_lightning.utilities.types import STEP_OUTPUT 12 | 13 | 14 | class Timer(Callback): 15 | """Monitor the speed of each step and each epoch. 
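Logged keys are timer/step, timer/inter_step, timer/epoch and timer/validation.
Minimal wiring sketch (illustrative only, not taken from this repo's training setup):
    trainer = Trainer(callbacks=[Timer(step=True, inter_step=False, epoch=True, val=True)])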
16 | """ 17 | def __init__( 18 | self, 19 | step: bool = True, 20 | inter_step: bool = True, 21 | epoch: bool = True, 22 | val: bool = True, 23 | ): 24 | super().__init__() 25 | self._log_stats = AttributeDict( { 26 | 'step_time': step, 27 | 'inter_step_time': inter_step, 28 | 'epoch_time': epoch, 29 | 'val_time': val, 30 | }) 31 | 32 | def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 33 | self._snap_epoch_time = None 34 | 35 | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 36 | self._snap_step_time = None 37 | self._snap_inter_step_time = None 38 | self._snap_epoch_time = time.time() 39 | 40 | def on_train_batch_start( 41 | self, 42 | trainer: Trainer, 43 | pl_module: LightningModule, 44 | batch: Any, 45 | batch_idx: int, 46 | ) -> None: 47 | if self._log_stats.step_time: 48 | self._snap_step_time = time.time() 49 | 50 | if not self._should_log(trainer): 51 | return 52 | 53 | logs = {} 54 | if self._log_stats.inter_step_time and self._snap_inter_step_time: 55 | # First log at beginning of second step 56 | logs["timer/inter_step"] = (time.time() - self._snap_inter_step_time) # * 1000 57 | 58 | if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) 59 | 60 | @rank_zero_only 61 | def on_train_batch_end( 62 | self, 63 | trainer: Trainer, 64 | pl_module: LightningModule, 65 | outputs: STEP_OUTPUT, 66 | batch: Any, 67 | batch_idx: int, 68 | ) -> None: 69 | if self._log_stats.inter_step_time: 70 | self._snap_inter_step_time = time.time() 71 | 72 | if not self._should_log(trainer): 73 | return 74 | 75 | logs = {} 76 | if self._log_stats.step_time and self._snap_step_time: 77 | logs["timer/step"] = (time.time() - self._snap_step_time) # * 1000 78 | 79 | if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) 80 | 81 | @rank_zero_only 82 | def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None: 83 | logs = {} 84 | if self._log_stats.epoch_time and self._snap_epoch_time: 85 | logs["timer/epoch"] = time.time() - self._snap_epoch_time 86 | if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) 87 | 88 | def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 89 | self._snap_val_time = time.time() 90 | 91 | @rank_zero_only 92 | def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None: 93 | logs = {} 94 | if self._log_stats.val_time and self._snap_val_time: 95 | logs["timer/validation"] = time.time() - self._snap_val_time 96 | if trainer.logger: trainer.logger.log_metrics(logs) # , step=trainer.global_step) 97 | 98 | @staticmethod 99 | def _should_log(trainer) -> bool: 100 | return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop 101 | -------------------------------------------------------------------------------- /convnova/src/utils/config.py: -------------------------------------------------------------------------------- 1 | """ Utilities for dealing with collection objects (lists, dicts) and configs """ 2 | from typing import Sequence, Mapping, Optional, Callable 3 | import functools 4 | import hydra 5 | from omegaconf import ListConfig, DictConfig 6 | 7 | # TODO this is usually used in a pattern where it's turned into a list, so can just do that here 8 | def is_list(x): 9 | return isinstance(x, Sequence) and not isinstance(x, str) 10 | 11 | 12 | def is_dict(x): 13 | return isinstance(x, Mapping) 14 | 15 | 16 | def to_dict(x, 
recursive=True): 17 | """Convert Sequence or Mapping object to dict 18 | 19 | lists get converted to {0: x[0], 1: x[1], ...} 20 | """ 21 | if is_list(x): 22 | x = {i: v for i, v in enumerate(x)} 23 | if is_dict(x): 24 | if recursive: 25 | return {k: to_dict(v, recursive=recursive) for k, v in x.items()} 26 | else: 27 | return dict(x) 28 | else: 29 | return x 30 | 31 | 32 | def to_list(x, recursive=False): 33 | """Convert an object to list. 34 | 35 | If Sequence (e.g. list, tuple, Listconfig): just return it 36 | 37 | Special case: If non-recursive and not a list, wrap in list 38 | """ 39 | if is_list(x): 40 | if recursive: 41 | return [to_list(_x) for _x in x] 42 | else: 43 | return list(x) 44 | else: 45 | if recursive: 46 | return x 47 | else: 48 | return [x] 49 | 50 | 51 | def extract_attrs_from_obj(obj, *attrs): 52 | if obj is None: 53 | assert len(attrs) == 0 54 | return [] 55 | return [getattr(obj, attr, None) for attr in attrs] 56 | 57 | 58 | def auto_assign_attrs(cls, **kwargs): 59 | for k, v in kwargs.items(): 60 | setattr(cls, k, v) 61 | 62 | 63 | def instantiate(registry, config, *args, partial=False, wrap=None, **kwargs): 64 | """ 65 | registry: Dictionary mapping names to functions or target paths (e.g. {'model': 'models.SequenceModel'}) 66 | config: Dictionary with a '_name_' key indicating which element of the registry to grab, and kwargs to be passed into the target constructor 67 | wrap: wrap the target class (e.g. ema optimizer or tasks.wrap) 68 | *args, **kwargs: additional arguments to override the config to pass into the target constructor 69 | """ 70 | 71 | # Case 1: no config 72 | if config is None: 73 | return None 74 | # Case 2a: string means _name_ was overloaded 75 | if isinstance(config, str): 76 | _name_ = None 77 | _target_ = registry[config] 78 | config = {} 79 | # Case 2b: grab the desired callable from name 80 | else: 81 | _name_ = config.pop("_name_") 82 | _target_ = registry[_name_] 83 | 84 | # Retrieve the right constructor automatically based on type 85 | if isinstance(_target_, str): 86 | fn = hydra.utils.get_method(path=_target_) 87 | elif isinstance(_target_, Callable): 88 | fn = _target_ 89 | else: 90 | raise NotImplementedError("instantiate target must be string or callable") 91 | 92 | # Instantiate object 93 | if wrap is not None: 94 | fn = wrap(fn) 95 | obj = functools.partial(fn, *args, **config, **kwargs) 96 | 97 | # Restore _name_ 98 | if _name_ is not None: 99 | config["_name_"] = _name_ 100 | 101 | if partial: 102 | return obj 103 | else: 104 | return obj() 105 | 106 | 107 | def get_class(registry, _name_): 108 | return hydra.utils.get_class(path=registry[_name_]) 109 | 110 | 111 | def omegaconf_filter_keys(d, fn=None): 112 | """Only keep keys where fn(key) is True. Support nested DictConfig. 113 | # TODO can make this inplace? 
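    Example (illustrative): omegaconf_filter_keys(config, lambda k: not k.startswith('__')) returns a copy of config with every key starting with '__' dropped, at all nesting levels.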
114 | """ 115 | if fn is None: 116 | fn = lambda _: True 117 | if is_list(d): 118 | return ListConfig([omegaconf_filter_keys(v, fn) for v in d]) 119 | elif is_dict(d): 120 | return DictConfig( 121 | {k: omegaconf_filter_keys(v, fn) for k, v in d.items() if fn(k)} 122 | ) 123 | else: 124 | return d 125 | 126 | 127 | -------------------------------------------------------------------------------- /convnova/src/models/nn/utils.py: -------------------------------------------------------------------------------- 1 | """ Utility wrappers around modules to let them handle Args and extra arguments """ 2 | 3 | import inspect 4 | from functools import wraps 5 | import torch 6 | from torch import nn 7 | 8 | def wrap_kwargs(f): 9 | """ 10 | Given a callable f that can consume some named arguments, 11 | wrap it with a kwargs that passes back any unused args 12 | 13 | EXAMPLES 14 | -------- 15 | 16 | Basic usage: 17 | def foo(x, y=None): 18 | return x 19 | 20 | wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2}) 21 | 22 | -------- 23 | 24 | The wrapped function can return its own argument dictionary, 25 | which gets merged with the new kwargs. 26 | def foo(x, y=None): 27 | return x, {} 28 | wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2}) 29 | 30 | def foo(x, y=None): 31 | return x, {"y": y, "z": None} 32 | wrap_kwargs(foo)(0, y=1, z=2) == (0, {'y': 1, 'z': 2}) 33 | 34 | -------- 35 | 36 | The wrapped function can have its own kwargs parameter: 37 | def foo(x, y=None, **kw_args): 38 | return x, {} 39 | wrap_kwargs(foo)(0, y=1, z=2) == (0, {}) 40 | 41 | -------- 42 | 43 | Partial functions and modules work automatically: 44 | class Module: 45 | def forward(self, x, y=0): 46 | return x, {"y": y+1} 47 | 48 | m = Module() 49 | 50 | wrap_kwargs(m.forward)(0, y=1, z=2) == (0, {'y': 2, 'z': 2}) 51 | 52 | """ 53 | sig = inspect.signature(f) 54 | # Check if f already has kwargs 55 | has_kwargs = any([ 56 | param.kind == inspect.Parameter.VAR_KEYWORD 57 | for param in sig.parameters.values() 58 | ]) 59 | if has_kwargs: 60 | @wraps(f) 61 | def f_kwargs(*args, **kwargs): 62 | y = f(*args, **kwargs) 63 | if isinstance(y, tuple) and isinstance(y[-1], dict): 64 | return y 65 | else: 66 | return y, {} 67 | else: 68 | param_kwargs = inspect.Parameter("kwargs", kind=inspect.Parameter.VAR_KEYWORD) 69 | sig_kwargs = inspect.Signature(parameters=list(sig.parameters.values())+[param_kwargs]) 70 | @wraps(f) 71 | def f_kwargs(*args, **kwargs): 72 | bound = sig_kwargs.bind(*args, **kwargs) 73 | if "kwargs" in bound.arguments: 74 | kwargs = bound.arguments.pop("kwargs") 75 | else: 76 | kwargs = {} 77 | y = f(**bound.arguments) 78 | if isinstance(y, tuple) and isinstance(y[-1], dict): 79 | return *y[:-1], {**y[-1], **kwargs} 80 | else: 81 | return y, kwargs 82 | return f_kwargs 83 | 84 | def discard_kwargs(f): 85 | if f is None: return None 86 | f_kwargs = wrap_kwargs(f) 87 | @wraps(f) 88 | def f_(*args, **kwargs): 89 | return f_kwargs(*args, **kwargs)[0] 90 | return f_ 91 | 92 | def PassthroughSequential(*modules): 93 | """Special Sequential module that chains kwargs. 
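    (Usage sketch: PassthroughSequential(norm, layer, None, dropout) silently drops the None entry and passes any unused keyword arguments from one module's forward on to the next.)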
94 | 95 | Semantics are the same as nn.Sequential, with extra convenience features: 96 | - Discard None modules 97 | - Flatten inner Sequential modules 98 | - In case with 0 or 1 Module, rename the class for ease of inspection 99 | """ 100 | def flatten(module): 101 | if isinstance(module, nn.Sequential): 102 | return sum([flatten(m) for m in module], []) 103 | else: 104 | return [module] 105 | 106 | modules = flatten(nn.Sequential(*modules)) 107 | modules = [module for module in modules if module if not None] 108 | 109 | class Sequential(nn.Sequential): 110 | def forward(self, x, **kwargs): 111 | for layer in self: 112 | x, kwargs = wrap_kwargs(layer.forward)(x, **kwargs) 113 | return x, kwargs 114 | 115 | def step(self, x, **kwargs): 116 | for layer in self: 117 | fn = getattr(layer, "step", layer.forward) 118 | x, kwargs = wrap_kwargs(fn)(x, **kwargs) 119 | return x, kwargs 120 | 121 | if len(modules) == 0: 122 | Sequential.__name__ = "Identity" 123 | elif len(modules) == 1: 124 | Sequential.__name__ = type(modules[0]).__name__ 125 | return Sequential(*modules) 126 | -------------------------------------------------------------------------------- /convnova/src/utils/profiling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.benchmark as benchmark 3 | 4 | 5 | def _get_gpu_mem(synchronize=True, empty_cache=True): 6 | return torch.cuda.memory_allocated() / ( 7 | (2**20) * 1000 8 | ), torch.cuda.memory_cached() / ((2**20) * 1000) 9 | 10 | 11 | def _generate_mem_hook(handle_ref, mem, idx, hook_type, exp): 12 | def hook(self, *args): 13 | if len(mem) == 0 or mem[-1]["exp"] != exp: 14 | call_idx = 0 15 | else: 16 | call_idx = mem[-1]["call_idx"] + 1 17 | 18 | mem_all, mem_cached = _get_gpu_mem() 19 | torch.cuda.synchronize() 20 | mem.append( 21 | { 22 | "layer_idx": idx, 23 | "call_idx": call_idx, 24 | "layer_type": type(self).__name__, 25 | "exp": exp, 26 | "hook_type": hook_type, 27 | "mem_all": mem_all, 28 | "mem_cached": mem_cached, 29 | } 30 | ) 31 | 32 | return hook 33 | 34 | 35 | def _add_memory_hooks(idx, model, mem_log, exp, hr): 36 | h = model.register_forward_pre_hook( 37 | _generate_mem_hook(hr, mem_log, idx, "pre", exp) 38 | ) 39 | hr.append(h) 40 | 41 | h = model.register_forward_hook(_generate_mem_hook(hr, mem_log, idx, "fwd", exp)) 42 | hr.append(h) 43 | 44 | h = model.register_backward_hook(_generate_mem_hook(hr, mem_log, idx, "bwd", exp)) 45 | hr.append(h) 46 | 47 | 48 | def log_memory(model, inp, mem_log=None, exp=None): 49 | mem_log = mem_log or [] 50 | exp = exp or f"exp_{len(mem_log)}" 51 | hr = [] 52 | for idx, module in enumerate(model.modules()): 53 | _add_memory_hooks(idx, module, mem_log, exp, hr) 54 | 55 | out = model(inp) 56 | if type(out) == tuple: 57 | out = out[0].logits 58 | loss = out.sum() 59 | loss.backward() 60 | [h.remove() for h in hr] 61 | return mem_log 62 | 63 | 64 | def benchmark_forward( 65 | fn, *inputs, min_run_time=0.2, repeats=10, desc="", verbose=True, **kwinputs 66 | ): 67 | """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" 68 | if verbose: 69 | print(desc, "- Forward pass") 70 | t = benchmark.Timer( 71 | stmt="fn(*inputs, **kwinputs)", 72 | globals={"fn": fn, "inputs": inputs, "kwinputs": kwinputs}, 73 | num_threads=torch.get_num_threads(), 74 | ) 75 | m = t.timeit(repeats) 76 | if verbose: 77 | print(m) 78 | return t, m 79 | 80 | 81 | def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs): 82 | torch.cuda.empty_cache() 83 | 
torch.cuda.reset_peak_memory_stats() 84 | torch.cuda.synchronize() 85 | fn(*inputs, **kwinputs) 86 | torch.cuda.synchronize() 87 | mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000) 88 | if verbose: 89 | print(f"{desc} max memory: {mem}GB") 90 | torch.cuda.empty_cache() 91 | return mem 92 | 93 | 94 | def benchmark_memory_bwd(fn, *inputs, desc="", verbose=True, **kwinputs): 95 | torch.cuda.empty_cache() 96 | torch.cuda.reset_peak_memory_stats() 97 | for input in inputs: 98 | input = input.requires_grad_(True) 99 | torch.cuda.synchronize() 100 | y = fn(*inputs, **kwinputs) 101 | y.sum().backward() 102 | torch.cuda.synchronize() 103 | mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000) 104 | if verbose: 105 | print(f"{desc} max memory: {mem}GB") 106 | torch.cuda.empty_cache() 107 | return mem 108 | 109 | 110 | def benchmark_backward( 111 | fn, *inputs, grad=None, repeats=10, desc="", verbose=True, **kwinputs 112 | ): 113 | """Use Pytorch Benchmark on the backward pass of an arbitrary function.""" 114 | if verbose: 115 | print(desc, "- Backward pass") 116 | y = fn(*inputs, **kwinputs) 117 | if not hasattr(y, "shape"): 118 | y = y[0] 119 | if grad is None: 120 | grad = torch.randn_like(y) 121 | else: 122 | if grad.shape != y.shape: 123 | raise RuntimeError("Grad shape does not match output shape") 124 | t = benchmark.Timer( 125 | stmt="y.backward(grad, retain_graph=True)", 126 | globals={"y": y, "grad": grad}, 127 | num_threads=torch.get_num_threads(), 128 | ) 129 | m = t.timeit(repeats) 130 | if verbose: 131 | print(m) 132 | return t, m 133 | -------------------------------------------------------------------------------- /convnova/configs/experiment/genomic-benchmark/hyena.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: genomic_benchmark 4 | - override /scheduler: cosine_warmup_timm 5 | 6 | model: 7 | _name_: dna_embedding 8 | d_model: 128 9 | n_layer: 2 10 | d_inner: ${eval:4 * ${.d_model}} 11 | vocab_size: 12 12 | resid_dropout: 0.1 13 | embed_dropout: 0.2 14 | fused_mlp: False # figure out how to use fused MLP, maybe only with bf16 + a100 15 | fused_dropout_add_ln: True 16 | residual_in_fp32: True 17 | pad_vocab_size_multiple: 8 18 | # attn_layer_idx: [0,1,2,3,4,5,6,7,8,9,10,11] # if passing these attn flags, then MHA auto used 19 | # attn_cfg: 20 | # num_heads: 8 21 | # use_flash_attn: True # figure out how to use 22 | # fused_bias_fc: False # this doesn't work for some reason, loss not going down 23 | # dropout: 0.1 24 | # rotary_emb_dim: 16 25 | layer: 26 | _name_: hyena 27 | emb_dim: 5 28 | filter_order: 64 29 | short_filter_order: 3 30 | l_max: 1026 # required to be set the same as the pretrained model if using, don't forget the +2! 
${eval:${dataset.max_length}+2} 31 | modulate: True 32 | w: 10 33 | lr: ${optimizer.lr} 34 | wd: 0.0 35 | lr_pos_emb: 0.0 36 | 37 | task: 38 | # 2 options for soft_cross_entropy (for mixup) 39 | loss: 40 | # soft_cross_entropy for pytorch 1.10+, which takes in label_smoothing here 41 | _name_: cross_entropy 42 | # label_smoothing: 0.1 43 | # pass in list of k's 44 | # last_k_ppl: null 45 | 46 | trainer: 47 | accelerator: gpu 48 | devices: 1 49 | num_nodes: 1 50 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 51 | max_epochs: 10 52 | precision: 16 # bf16 only a100 53 | gradient_clip_val: 1.0 54 | # strategy: null 55 | 56 | # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings 57 | # name num_seqs num_classes median len std 58 | # dummy_mouse_enhancers_ensembl 1210 2 2381 984.4 59 | # demo_coding_vs_intergenomic_seqs 100_000 2 200 0 60 | # demo_human_or_worm 100_000 2 200 0 61 | # human_enhancers_cohn 27791 2 500 0 62 | # human_enhancers_ensembl 154842 2 269 122.6 63 | # human_ensembl_regulatory 289061 3 401 184.3 64 | # human_nontata_promoters 36131 2 251 0 65 | # human_ocr_ensembl 174756 2 315 108.1 66 | 67 | # decoder: null 68 | 69 | dataset: 70 | # batch_size: 32 # Per GPU 71 | batch_size: 64 72 | max_length: 200 # select max that you want for this dataset 73 | # dataset_name: 'human_nontata_promoters' 74 | dataset_name: 'demo_coding_vs_intergenomic_seqs' 75 | dest_path: # 'project_root/data/genomic_benchmark/' 76 | d_output: 2 # binary classification by default 77 | use_padding: True 78 | padding_side: 'left' 79 | add_eos: False 80 | train_len: 100_000 # update this according to above table 81 | __l_max: ${.max_length} 82 | tokenizer_name: char 83 | # use_tokenizer: False 84 | # rc_aug: true # reverse complement augmentation. Didn't seem to help for human_nontata_promoters, but could be wrong 85 | 86 | scheduler: 87 | t_in_epochs: False 88 | t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 89 | warmup_lr_init: 1e-6 90 | warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 91 | lr_min: ${eval:0.1 * ${optimizer.lr}} 92 | 93 | optimizer: 94 | lr: 6e-4 95 | weight_decay: 0.1 96 | 97 | train: 98 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 99 | seed: 2222 100 | global_batch_size: 64 101 | remove_test_loader_in_eval: true # no test set in this benchmark 102 | pretrained_model_strict_load: False # false allows encoder/decoder to be used if new model uses it 103 | # for loading backbone and not head, requires both of these flags below 104 | # pretrained_model_path: 105 | pretrained_model_state_hook: 106 | _name_: load_backbone 107 | freeze_backbone: false # seems to work much better if false (ie finetune entire model) 108 | -------------------------------------------------------------------------------- /convnova/src/models/sequence/block.py: -------------------------------------------------------------------------------- 1 | """ Implements a full residual block around a black box layer 2 | 3 | Configurable options include: 4 | normalization position: prenorm or postnorm 5 | normalization type: batchnorm, layernorm etc. 
6 | subsampling/pooling 7 | residual options: feedforward, residual, affine scalars, depth-dependent scaling, etc. 8 | """ 9 | 10 | from torch import nn 11 | 12 | from functools import partial 13 | import src.utils as utils 14 | from src.models.nn.components import Normalization, StochasticDepth, DropoutNd 15 | from src.models.sequence import SequenceModule 16 | from src.models.sequence.pool import registry as pool_registry 17 | from src.models.nn.residual import registry as residual_registry 18 | import src.utils.registry as registry 19 | 20 | 21 | class SequenceResidualBlock(SequenceModule): 22 | def __init__( 23 | self, 24 | d_input, 25 | i_layer=None, # Only needs to be passed into certain residuals like Decay 26 | prenorm=True, 27 | dropout=0.0, 28 | tie_dropout=False, 29 | transposed=False, 30 | layer=None, # Config for black box module 31 | residual=None, # Config for residual function 32 | norm=None, # Config for normalization layer 33 | pool=None, 34 | drop_path=0., 35 | ): 36 | super().__init__() 37 | 38 | self.i_layer = i_layer 39 | self.d_input = d_input 40 | self.layer = utils.instantiate(registry.layer, layer, d_input) 41 | self.prenorm = prenorm 42 | self.transposed = transposed 43 | 44 | # Residual 45 | # d_residual is the output dimension after residual 46 | if residual is None: 47 | self.residual = None 48 | self.d_residual = self.layer.d_output 49 | else: 50 | self.residual = utils.instantiate(residual_registry, residual, i_layer, d_input, self.layer.d_output) 51 | self.d_residual = self.residual.d_output 52 | 53 | # Normalization 54 | d_norm = d_input if self.prenorm else self.d_residual 55 | # We don't use config to directly instantiate since Normalization has some special cases 56 | if norm is None: 57 | self.norm = None 58 | elif isinstance(norm, str): 59 | self.norm = Normalization(d_norm, transposed=self.transposed, _name_=norm) 60 | else: 61 | self.norm = Normalization(d_norm, transposed=self.transposed, **norm) 62 | 63 | # Pool 64 | self.pool = utils.instantiate(pool_registry, pool, self.d_residual, transposed=self.transposed) 65 | 66 | # Dropout 67 | dropout_cls = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout 68 | self.drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity() 69 | 70 | # Stochastic depth 71 | self.drop_path = StochasticDepth(drop_path, mode='row') if drop_path > 0.0 else nn.Identity() 72 | 73 | 74 | @property 75 | def d_output(self): 76 | return self.pool.d_output if self.pool is not None else self.d_residual 77 | 78 | @property 79 | def d_state(self): 80 | return self.layer.d_state 81 | 82 | @property 83 | def state_to_tensor(self): 84 | return self.layer.state_to_tensor 85 | 86 | def default_state(self, *args, **kwargs): 87 | return self.layer.default_state(*args, **kwargs) 88 | 89 | def forward(self, x, state=None, **kwargs): 90 | y = x 91 | 92 | # Pre-norm 93 | if self.norm is not None and self.prenorm: y = self.norm(y) 94 | 95 | # Black box layer 96 | y, state = self.layer(y, state=state, **kwargs) 97 | 98 | # Residual 99 | if self.residual is not None: y = self.residual(x, self.drop_path(self.drop(y)), self.transposed) 100 | 101 | # Post-norm 102 | if self.norm is not None and not self.prenorm: y = self.norm(y) 103 | 104 | # Pool 105 | if self.pool is not None: y, _ = self.pool(y) 106 | 107 | return y, state 108 | 109 | def step(self, x, state, **kwargs): 110 | y = x 111 | 112 | # Pre-norm 113 | if self.norm is not None and self.prenorm: 114 | y = self.norm.step(y) 115 | 116 | # Black box layer 117 | 
y, state = self.layer.step(y, state, **kwargs) 118 | 119 | # Residual 120 | if self.residual is not None: y = self.residual(x, y, transposed=False) # NOTE this would not work with concat residual function (catformer) 121 | 122 | # Post-norm 123 | if self.norm is not None and not self.prenorm: 124 | y = self.norm.step(y) 125 | 126 | # Pool 127 | if self.pool is not None: y, _ = self.pool(y) 128 | 129 | return y, state 130 | -------------------------------------------------------------------------------- /convnova/src/tasks/torchmetrics.py: -------------------------------------------------------------------------------- 1 | # Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py 2 | # But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll)) 3 | # Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py 4 | # But we pass in the loss to avoid recomputation 5 | 6 | from typing import Any, Dict, Optional 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import Tensor 11 | from torchmetrics import Metric 12 | 13 | try: 14 | from flash_attn.losses.cross_entropy import CrossEntropyLoss 15 | except ImportError: 16 | CrossEntropyLoss = torch.nn.CrossEntropyLoss 17 | 18 | try: 19 | from apex.transformer import parallel_state 20 | except ImportError: 21 | parallel_state = None 22 | 23 | 24 | class Perplexity(Metric): 25 | r""" 26 | Perplexity measures how well a language model predicts a text sample. It's calculated as the average number of bits 27 | per word a model needs to represent the sample. 28 | Args: 29 | kwargs: 30 | Additional keyword arguments, see :ref:`Metric kwargs` for more info. 31 | Examples: 32 | >>> import torch 33 | >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22)) 34 | >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22)) 35 | >>> target[0, 6:] = -100 36 | >>> metric = Perplexity(ignore_index=-100) 37 | >>> metric(preds, target) 38 | tensor(5.2545) 39 | """ 40 | is_differentiable = True 41 | higher_is_better = False 42 | full_state_update = False 43 | total_log_probs: Tensor 44 | count: Tensor 45 | 46 | def __init__(self, **kwargs: Dict[str, Any]): 47 | super().__init__(**kwargs) 48 | self.add_state("total_log_probs", default=torch.tensor(0.0, dtype=torch.float64), 49 | dist_reduce_fx="sum") 50 | self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum") 51 | 52 | self.loss_fn = CrossEntropyLoss() 53 | 54 | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore 55 | """Compute and store intermediate statistics for Perplexity. 56 | Args: 57 | preds: 58 | Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size]. 59 | target: 60 | Ground truth values with a shape [batch_size, seq_len]. 61 | """ 62 | count = target.numel() 63 | if loss is None: 64 | loss = self.loss_fn(preds, target) 65 | self.total_log_probs += loss.double() * count 66 | self.count += count 67 | 68 | def compute(self) -> Tensor: 69 | """Compute the Perplexity. 70 | Returns: 71 | Perplexity 72 | """ 73 | return torch.exp(self.total_log_probs / self.count) 74 | 75 | class NumTokens(Metric): 76 | """Keep track of how many tokens we've seen. 77 | """ 78 | # TODO: how do we prevent the reset between the epochs? The reset happens on the 1st batch 79 | # of the next epoch. 
80 | # Right now the hack is that we override reset(), which would mess up the forward method. 81 | # We then override forward to do the right thing. 82 | 83 | is_differentiable = False 84 | higher_is_better = False 85 | full_state_update = False 86 | count: Tensor 87 | 88 | def __init__(self, **kwargs: Dict[str, Any]): 89 | super().__init__(**kwargs) 90 | self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum", 91 | persistent=True) # We want the count to be saved to state-dict 92 | if parallel_state is not None and not parallel_state.is_unitialized(): 93 | self.tensor_parallel_world_size = parallel_state.get_tensor_model_parallel_world_size() 94 | else: 95 | self.tensor_parallel_world_size = 1 96 | 97 | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore 98 | self.count += target.numel() // self.tensor_parallel_world_size 99 | 100 | def compute(self) -> Tensor: 101 | return self.count 102 | 103 | def reset(self): 104 | count = self.count 105 | super().reset() 106 | self.count = count 107 | 108 | # Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py 109 | def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: 110 | """forward computation using single call to `update` to calculate the metric value on the current batch and 111 | accumulate global state. 112 | This can be done when the global metric state is a sinple reduction of batch states. 113 | """ 114 | self.update(*args, **kwargs) 115 | return self.compute() 116 | 117 | torchmetric_fns = { 118 | "perplexity": Perplexity, 119 | "num_tokens": NumTokens, 120 | } -------------------------------------------------------------------------------- /convnova/src/ops/fftconv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from einops import rearrange 7 | 8 | from fftconv import fftconv_fwd, fftconv_bwd 9 | 10 | @torch.jit.script 11 | def _mul_sum(y, q): 12 | return (y * q).sum(dim=1) 13 | 14 | # reference convolution with residual connection 15 | def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None): 16 | seqlen = u.shape[-1] 17 | fft_size = 2 * seqlen 18 | k_f = torch.fft.rfft(k, n=fft_size) / fft_size 19 | if k_rev is not None: 20 | k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size 21 | k_f = k_f + k_rev_f.conj() 22 | u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) 23 | 24 | if len(u.shape) > 3: k_f = k_f.unsqueeze(1) 25 | 26 | y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen] 27 | 28 | out = y + u * D.unsqueeze(-1) 29 | if gelu: 30 | out = F.gelu(out) 31 | if dropout_mask is not None: 32 | return (out * rearrange(dropout_mask, 'b H -> b H 1')).to(dtype=u.dtype) 33 | else: 34 | return out.to(dtype=u.dtype) 35 | 36 | 37 | # reference H3 forward pass 38 | def fftconv_h3_ref(k, ssm_kernel, D, q, v, head_dim=1, ssm_kernel_rev=None): 39 | seqlen = k.shape[-1] 40 | fft_size = 2 * seqlen 41 | kv = (rearrange(k, 'b (h d1) l -> b d1 1 h l', d1=head_dim) 42 | * rearrange(v, 'b (h d2) l -> b 1 d2 h l', d2=head_dim)) # b d1 d2 h l 43 | kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size 44 | ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size) # h L+1 45 | if ssm_kernel_rev is not None: 46 | ssm_kernel_rev_f = torch.fft.rfft(ssm_kernel_rev, n=fft_size) # h L+1 47 | ssm_kernel_f = ssm_kernel_f + ssm_kernel_rev_f.conj() 48 | y = 
torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm='forward')[..., :seqlen] # b d1 d2 h l 49 | out = y + kv * D.unsqueeze(-1) # b d1 d2 h l 50 | q = rearrange(q, 'b (h d1) l -> b d1 1 h l', d1=head_dim) 51 | if head_dim > 1: 52 | out = _mul_sum(out, q) 53 | return rearrange(out, 'b d2 h l -> b (h d2) l').to(dtype=k.dtype) 54 | else: 55 | return rearrange(out * q, 'b 1 1 h l -> b h l').to(dtype=k.dtype) 56 | 57 | 58 | class FFTConvFunc(torch.autograd.Function): 59 | 60 | @staticmethod 61 | def forward(ctx, u, k, D, dropout_mask=None, gelu=True, force_fp16_output=False, 62 | output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None): 63 | seqlen = u.shape[-1] 64 | fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16) 65 | k_f = torch.fft.rfft(k, n=fft_size) 66 | if k_rev is not None: 67 | k_f = k_f + torch.fft.rfft(k_rev, n=fft_size).conj() 68 | if u.stride(-1) != 1: 69 | u = u.contiguous() 70 | k_f = k_f.contiguous() 71 | D = D.contiguous() 72 | if v is not None and v.stride(-1) != 1: 73 | v = v.contiguous() 74 | if q is not None and q.stride(-1) != 1: 75 | q = q.contiguous() 76 | if dropout_mask is not None: 77 | dropout_mask = dropout_mask.contiguous() 78 | ctx.save_for_backward(u, k_f, D, dropout_mask, v, q) 79 | ctx.output_hbl_layout = output_hbl_layout 80 | ctx.head_dim = head_dim 81 | ctx.gelu = gelu 82 | ctx.fftfp16 = fftfp16 83 | ctx.has_k_rev = k_rev is not None 84 | out = fftconv_fwd(u, k_f, D, v, head_dim, q, dropout_mask, gelu, False, False, fft_size, force_fp16_output, output_hbl_layout, fftfp16) 85 | return out 86 | 87 | @staticmethod 88 | def backward(ctx, dout): 89 | if ctx.output_hbl_layout: 90 | dout = rearrange(rearrange(dout, 'b h l -> h b l').contiguous(), 'h b l -> b h l') 91 | else: 92 | dout = dout.contiguous() 93 | u, k_f, D, dropout_mask, v, q = ctx.saved_tensors 94 | seqlen = u.shape[-1] 95 | fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16) 96 | du, dk_f, dD, dv, dq = fftconv_bwd(dout, u, k_f, D, v, ctx.head_dim, q, dropout_mask, ctx.gelu, False, False, fft_size, 97 | ctx.output_hbl_layout, ctx.fftfp16) 98 | dk = torch.fft.irfft(dk_f, n=fft_size, norm='forward')[..., :seqlen] 99 | dk_rev = (None if not ctx.has_k_rev 100 | else torch.fft.irfft(dk_f.conj(), n=fft_size, norm='forward')[..., :seqlen]) 101 | if v is not None: 102 | dv = dv.to(dtype=v.dtype) # We do atomicAdd in fp32 so might need to convert to fp16 103 | return du, dk, dD, None, None, None, None, dv if v is not None else None, None, dq if q is not None else None, None, dk_rev 104 | 105 | def fftconv_func(u, k, D, dropout_mask=None, gelu=True, force_fp16_output=False, 106 | output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None): 107 | return FFTConvFunc.apply(u, k, D, dropout_mask, gelu, force_fp16_output, 108 | output_hbl_layout, v, head_dim, q, fftfp16, k_rev) 109 | -------------------------------------------------------------------------------- /convnova/src/utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from contextlib import contextmanager 17 | 18 | import torch 19 | 20 | 21 | def init_distributed(cuda): 22 | """ 23 | Initializes distributed backend. 24 | 25 | :param cuda: (bool) if True initializes nccl backend, if False initializes 26 | gloo backend 27 | """ 28 | world_size = int(os.environ.get('WORLD_SIZE', 1)) 29 | distributed = (world_size > 1) 30 | if distributed: 31 | backend = 'nccl' if cuda else 'gloo' 32 | torch.distributed.init_process_group(backend=backend, 33 | init_method='env://') 34 | assert torch.distributed.is_initialized() 35 | return distributed 36 | 37 | 38 | def barrier(): 39 | """ 40 | Call torch.distributed.barrier() if distritubed is in use 41 | """ 42 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 43 | torch.distributed.barrier() 44 | 45 | 46 | def get_rank(): 47 | """ 48 | Gets distributed rank or returns zero if distributed is not initialized. 49 | """ 50 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 51 | rank = torch.distributed.get_rank() 52 | else: 53 | rank = 0 54 | return rank 55 | 56 | 57 | def get_world_size(): 58 | """ 59 | Gets total number of distributed workers or returns one if distributed is 60 | not initialized. 61 | """ 62 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 63 | world_size = torch.distributed.get_world_size() 64 | else: 65 | world_size = 1 66 | return world_size 67 | 68 | 69 | def all_reduce_item(value, op='sum'): 70 | """ 71 | All-reduces single scalar value if distributed is in use 72 | """ 73 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 74 | if op == 'sum' or op == 'mean': 75 | dop = torch.distributed.ReduceOp.SUM 76 | elif op == 'min': 77 | dop = torch.distributed.ReduceOp.MIN 78 | elif op == 'max': 79 | dop = torch.distributed.ReduceOp.MAX 80 | elif op == 'product': 81 | dop = torch.distributed.ReduceOp.PRODUCT 82 | else: 83 | raise RuntimeError('Unsupported reduce op') 84 | 85 | backend = torch.distributed.get_backend() 86 | if backend == torch.distributed.Backend.NCCL: 87 | device = torch.device('cuda') 88 | elif backend == torch.distributed.Backend.GLOO: 89 | device = torch.device('cpu') 90 | else: 91 | raise RuntimeError('Unsupported distributed backend') 92 | 93 | tensor = torch.tensor(value, device=device) 94 | torch.distributed.all_reduce(tensor, dop) 95 | if op == 'mean': 96 | tensor /= get_world_size() 97 | ret = tensor.item() 98 | else: 99 | ret = value 100 | return ret 101 | 102 | 103 | def all_reduce_tensor(value, op='sum'): 104 | """ 105 | All-reduces single scalar value if distributed is in use 106 | """ 107 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 108 | if op == 'sum' or op == 'mean': 109 | dop = torch.distributed.ReduceOp.SUM 110 | elif op == 'min': 111 | dop = torch.distributed.ReduceOp.MIN 112 | elif op == 'max': 113 | dop = torch.distributed.ReduceOp.MAX 114 | elif op == 'product': 115 | dop = torch.distributed.ReduceOp.PRODUCT 116 | else: 117 | raise RuntimeError('Unsupported reduce 
op') 118 | 119 | backend = torch.distributed.get_backend() 120 | if backend == torch.distributed.Backend.NCCL: 121 | device = torch.device('cuda') 122 | elif backend == torch.distributed.Backend.GLOO: 123 | device = torch.device('cpu') 124 | else: 125 | raise RuntimeError('Unsupported distributed backend') 126 | 127 | tensor = value 128 | torch.distributed.all_reduce(tensor, dop) 129 | if op == 'mean': 130 | tensor /= get_world_size() 131 | ret = tensor 132 | else: 133 | ret = value 134 | return ret 135 | 136 | 137 | @contextmanager 138 | def sync_workers(): 139 | """ 140 | Yields distributed rank and synchronizes all workers on exit. 141 | """ 142 | rank = get_rank() 143 | yield rank 144 | barrier() 145 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
5 | OpenReview | 6 | arXiv | 7 | GitHub | 8 | HuggingFace 🤗(coming soon) 9 | 10 |
11 | 12 |13 | ConvNova demonstrates that, if carefully designed, a pure CNN can serve as a DNA foundation model that surpasses Transformer and SSM-inspired architectures, while retaining the classic convolutional advantages of stronger locality bias, lower memory footprint, and markedly faster training and inference. 14 |
15 | 16 | 17 | --- 18 | 19 | ## 🚩 Plan 20 | - [x] Scripts for Pretraining, NT & Genomic Benchmarks. 21 | - [x] Paper Released. 22 | - [ ] Pretrained Weights of ConvNova. 23 | - [ ] Source Code and Pretrained Weights on transformers. 24 | - [ ] Scripts for DeepSEA & Bend-gene-finding. 25 | 26 | --- 27 | 28 |32 | git clone git@github.com:aim-uofa/ConvNova.git 33 | cd ConvNova/convnova 34 |35 | 36 | Prepare conda env. 37 |
38 | conda create -n convnova python==3.10 39 | conda activate convnova 40 | pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 41 | pip install -r requirements.txt --no-deps 42 | pip install pytorch-lightning==1.8.6 --no-deps 43 | pip install packaging --no-deps 44 | 45 | pip install lightning_utilities --no-deps 46 | pip install torchmetrics 47 | pip install tensorboardX 48 | 49 | 50 | Download the data (for pretraining). 51 |
52 | mkdir data 53 | mkdir -p data/hg38/ 54 | curl https://storage.googleapis.com/basenji_barnyard2/hg38.ml.fa.gz > data/hg38/hg38.ml.fa.gz 55 | gunzip data/hg38/hg38.ml.fa.gz # unzip the fasta file 56 | curl https://storage.googleapis.com/basenji_barnyard2/sequences_human.bed > data/hg38/human-sequences.bed 57 | 58 | 59 | You can check out the Nucleotide Transformer and Genomic Benchmarks papers for how to download and process the NT benchmark & Genomic Benchmark datasets. 60 | 61 | The final file structure (data directory) should look like 62 | 63 |
64 | |____bert_hg38 65 | | |____hg38.ml.fa 66 | | |____hg38.ml.fa.fai 67 | | |____human-sequences.bed 68 | |____nucleotide_transformer 69 | | |____H3K36me3 70 | | |____...... 71 | |____genomic_benchmark 72 | | |____dummy_mouse_enhancers_ensembl 73 | | |____.... 74 |75 | 76 | --- 77 | 78 |
Coming Soon
80 | 81 | --- 82 | 83 | 88 | python train.py experiment='hg38-pretrain/convnova' 89 | 90 | You can adjust the hyperparameters from the command line as in the example below; the detailed hyperparameter settings can be found in configs/experiment/xxx/xxx.yaml 92 |
93 | python train.py experiment='hg38-pretrain/convnova' wandb=null trainer.devices=4 94 |95 | 96 |
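For example, to change the per-GPU batch size, the number of epochs, or the learning rate (these keys mirror the ones defined in the experiment YAML; the values below are only illustrative):
python train.py experiment='hg38-pretrain/convnova' dataset.batch_size=32 trainer.max_epochs=20 optimizer.lr=1e-3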
GenomicBenchmarks provides 8 binary- and multi-class tasks packaged as a Python library.
98 | 99 | Remember to adjust the settings for each dataset, such as the max sequence length and the pretrained checkpoint (coming soon). 100 | 101 | python train.py experiment='genomic-benchmark/convnova' with-some-arguments 102 | 103 | 104 |
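For instance, to run on human_nontata_promoters (36131 sequences, 2 classes, length 251 in the dataset table inside the experiment config), overrides along the following lines should work; treat the exact values as an illustration rather than a tuned recipe:
python train.py experiment='genomic-benchmark/convnova' dataset.dataset_name='human_nontata_promoters' dataset.max_length=251 dataset.d_output=2 dataset.train_len=36131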
Datasets are hosted on the Hub as InstaDeepAI/nucleotide_transformer_downstream_tasks.
108 | python train.py experiment='nt-benchmark/convnova' with-some-arguments 109 | 110 | 111 | --- 112 | 113 |
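To select a specific task and fine-tune from a pretrained checkpoint, overrides of this form can be used (the task table and the train.pretrained_model_path key live in the nt-benchmark experiment configs; the checkpoint path below is a placeholder):
python train.py experiment='nt-benchmark/convnova' dataset.dataset_name='H3K4me3' train.pretrained_model_path=/path/to/convnova.ckpt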
116 | @inproceedings{bo2025convnova,
117 | title = {Revisiting Convolution Architecture in the Realm of DNA Foundation Models},
118 | author = {Yu Bo and Weian Mao and Yanjun Shao and Weiqiang Bai and Peng Ye
119 | and Xinzhu Ma and Junbo Zhao and Hao Chen and Chunhua Shen},
120 | booktitle = {International Conference on Learning Representations (ICLR)},
121 | year = {2025}
122 | }
123 |
124 |
125 | ---
126 |
127 | ConvNova builds on the training, logging and data-loading scaffolds of HyenaDNA and Caduceus, and evaluates on Genomic Benchmarks, Nucleotide Transformer tasks, and the Long-Range Benchmark. We thank the maintainers of these open resources for making rigorous comparison possible.
129 | -------------------------------------------------------------------------------- /convnova/src/dataloaders/fault_tolerant_sampler.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397 2 | from typing import Iterator 3 | import math 4 | 5 | import torch 6 | from torch.utils.data import RandomSampler, DistributedSampler 7 | 8 | 9 | class RandomFaultTolerantSampler(RandomSampler): 10 | 11 | def __init__(self, *args, generator=None, **kwargs): 12 | # generator = torch.Generator().manual_seed(seed) 13 | # super().__init__(*args, generator=generator, **kwargs) 14 | # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed, 15 | # which should be reproducible if pl.seed_everything was called before hand. 16 | # This means that changing the seed of the experiment will also change the 17 | # sampling order. 18 | if generator is None: 19 | seed = int(torch.empty((), dtype=torch.int64).random_().item()) 20 | generator = torch.Generator().manual_seed(seed) 21 | super().__init__(*args, generator=generator, **kwargs) 22 | self.counter = 0 23 | # self.start_counter = 0 24 | self.restarting = False 25 | 26 | def state_dict(self): 27 | return {"random_state": self.state, "counter": self.counter} 28 | 29 | def load_state_dict(self, state_dict): 30 | self.generator.set_state(state_dict.get("random_state")) 31 | self.counter = state_dict["counter"] 32 | # self.start_counter = self.counter 33 | self.restarting = True 34 | 35 | # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per 36 | # epoch, and subsequent epoch will have very few batches. 37 | # def __len__(self): 38 | # # We need a separate self.start_counter because PL seems to call len repeatedly. 39 | # # If we use len(self.data_source) - self.counter then PL will think the epoch ends 40 | # # when we're only half way through. 41 | # return len(self.data_source) - self.start_counter 42 | 43 | def __iter__(self) -> Iterator[int]: 44 | n = len(self.data_source) 45 | 46 | self.state = self.generator.get_state() 47 | indices = torch.randperm(n, generator=self.generator).tolist() 48 | 49 | if not self.restarting: 50 | self.counter = 0 51 | else: 52 | indices = indices[self.counter:] 53 | self.restarting = False 54 | # self.start_counter = self.counter 55 | 56 | for index in indices: 57 | self.counter += 1 58 | yield index 59 | 60 | self.counter = 0 61 | # self.start_counter = self.counter 62 | 63 | 64 | class FaultTolerantDistributedSampler(DistributedSampler): 65 | 66 | def __init__(self, *args, **kwargs): 67 | super().__init__(*args, **kwargs) 68 | self.counter = 0 69 | # self.start_counter = 0 70 | self.restarting = False 71 | 72 | def state_dict(self): 73 | return {"epoch": self.epoch, "counter": self.counter} 74 | 75 | def load_state_dict(self, state_dict): 76 | self.epoch = state_dict["epoch"] 77 | self.counter = state_dict["counter"] 78 | # self.start_counter = self.counter 79 | self.restarting = True 80 | 81 | # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per 82 | # epoch, and subsequent epoch will have very few batches. 
83 | # def __len__(self) -> int: 84 | # return self.num_samples - self.start_counter 85 | 86 | def __iter__(self): 87 | if self.shuffle: 88 | # deterministically shuffle based on epoch and seed 89 | g = torch.Generator() 90 | g.manual_seed(self.seed + self.epoch) 91 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] 92 | else: 93 | indices = list(range(len(self.dataset))) # type: ignore[arg-type] 94 | 95 | if not self.drop_last: 96 | # add extra samples to make it evenly divisible 97 | padding_size = self.total_size - len(indices) 98 | if padding_size <= len(indices): 99 | indices += indices[:padding_size] 100 | else: 101 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] 102 | else: 103 | # remove tail of data to make it evenly divisible. 104 | indices = indices[:self.total_size] 105 | assert len(indices) == self.total_size 106 | 107 | # subsample 108 | indices = indices[self.rank:self.total_size:self.num_replicas] 109 | assert len(indices) == self.num_samples 110 | 111 | if not self.restarting: 112 | self.counter = 0 113 | else: 114 | indices = indices[self.counter:] 115 | self.restarting = False 116 | # self.start_counter = self.counter 117 | 118 | for index in indices: 119 | self.counter += 1 120 | yield index 121 | 122 | self.counter = 0 123 | # self.start_counter = self.counter -------------------------------------------------------------------------------- /convnova/src/models/nn/gate.py: -------------------------------------------------------------------------------- 1 | """ Defines flexible gating mechanisms based on ideas from LSSL paper and UR-LSTM paper https://arxiv.org/abs/1910.09890 """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class Gate(nn.Module): 7 | """ Implements gating mechanisms. 
TODO update this with more detailed description with reference to LSSL paper when it's on arxiv 8 | 9 | Mechanisms: 10 | N - No gate 11 | G - Standard sigmoid gate 12 | UR - Uniform refine gates 13 | R - Refine gate 14 | 15 | FS - Forward discretization, Sigmoid activation [equivalent to G] 16 | BE - Backward discretization, Exp activation [equivalent to G] 17 | BR - Backward discretization, Relu activation 18 | TE - Trapezoid discretization, Exp activation 19 | TR - Trapezoid discretization, Relu activation 20 | TS - Trapezoid discretization, Sigmoid activation (0 to 2) 21 | """ 22 | def __init__(self, size, preact_ctor, preact_args, mechanism='N'): 23 | super().__init__() 24 | self.size = size 25 | self.mechanism = mechanism 26 | 27 | if self.mechanism == 'N': 28 | pass 29 | elif self.mechanism in ['G', 'FS', 'BE', 'BR', 'TE', 'TR', 'TS', 'ZE', 'ZR', 'ZS']: 30 | self.W_g = preact_ctor(*preact_args) 31 | elif self.mechanism in ['U', 'UT']: 32 | self.W_g = preact_ctor(*preact_args) 33 | b_g_unif = torch.empty(size) 34 | torch.nn.init.uniform_(b_g_unif, 1./self.size, 1.-1./self.size) 35 | self.b_g = nn.Parameter(torch.log(1./b_g_unif-1.).detach(), requires_grad=False) 36 | elif self.mechanism == 'UR': 37 | self.W_g = preact_ctor(*preact_args) 38 | self.W_r = preact_ctor(*preact_args) 39 | 40 | b_g_unif = torch.empty(size) 41 | torch.nn.init.uniform_(b_g_unif, 1./self.size, 1.-1./self.size) 42 | self.b_g = nn.Parameter(torch.log(1./b_g_unif-1.).detach(), requires_grad=False) 43 | elif self.mechanism == 'R': 44 | self.W_g = preact_ctor(*preact_args) 45 | self.W_r = preact_ctor(*preact_args) 46 | elif self.mechanism in ['GT']: 47 | self.W_g = preact_ctor(*preact_args) 48 | else: 49 | assert False, f'Gating type {self.mechanism} is not supported.' 50 | 51 | def forward(self, *inputs): 52 | if self.mechanism == 'N': 53 | return 1.0 54 | 55 | if self.mechanism == 'G': 56 | g_preact = self.W_g(*inputs) 57 | g = torch.sigmoid(g_preact) 58 | if self.mechanism == 'U': 59 | g_preact = self.W_g(*inputs) + self.b_g 60 | g = torch.sigmoid(g_preact) 61 | elif self.mechanism == 'UR': 62 | g_preact = self.W_g(*inputs) + self.b_g 63 | g = torch.sigmoid(g_preact) 64 | r = torch.sigmoid(self.W_r(*inputs)) 65 | g = (1-2*r)*g**2 + 2*r*g 66 | elif self.mechanism == 'R': 67 | g_preact = self.W_g(*inputs) 68 | g = torch.sigmoid(g_preact) 69 | r = torch.sigmoid(self.W_r(*inputs)) 70 | g = (1-2*r)*g**2 + 2*r*g 71 | elif self.mechanism == 'UT': 72 | g_preact = self.W_g(*inputs) + self.b_g 73 | g = torch.sigmoid(g_preact) 74 | r = g 75 | g = (1-2*r)*g**2 + 2*r*g 76 | elif self.mechanism == 'GT': 77 | g_preact = self.W_g(*inputs) 78 | g = torch.sigmoid(g_preact) 79 | r = g 80 | g = (1-2*r)*g**2 + 2*r*g 81 | else: 82 | g_preact = self.W_g(*inputs) 83 | # if self.mechanism[1] == 'S': 84 | # g = torch.sigmoid(g_preact) 85 | # elif self.mechanism[1] == 'E': 86 | # g = torch.exp(g_preact) 87 | # elif self.mechanism[1] == 'R': 88 | # g = torch.relu(g_preact) 89 | if self.mechanism == 'FS': 90 | g = torch.sigmoid(g_preact) 91 | g = self.forward_diff(g) 92 | elif self.mechanism == 'BE': 93 | g = torch.exp(g_preact) 94 | g = self.backward_diff(g) 95 | elif self.mechanism == 'BR': 96 | g = torch.relu(g_preact) 97 | g = self.backward_diff(g) 98 | elif self.mechanism == 'TS': 99 | g = 2 * torch.sigmoid(g_preact) 100 | g = self.trapezoid(g) 101 | elif self.mechanism == 'TE': 102 | g = torch.exp(g_preact) 103 | g = self.trapezoid(g) 104 | elif self.mechanism == 'TR': 105 | g = torch.relu(g_preact) 106 | g = self.trapezoid(g) 107 | 
elif self.mechanism == 'ZE': 108 | g = torch.exp(g_preact) 109 | g = self.zoh(g) 110 | elif self.mechanism == 'ZR': 111 | g = torch.relu(g_preact) 112 | g = self.zoh(g) 113 | elif self.mechanism == 'ZS': 114 | g = torch.sigmoid(g_preact) 115 | g = self.zoh(g) 116 | return g 117 | 118 | def forward_diff(self, x): 119 | return x 120 | 121 | def backward_diff(self, x): 122 | return x / (1+x) 123 | 124 | def trapezoid(self, x): 125 | return x / (1 + x/2) 126 | 127 | def zoh(self, x): 128 | return 1 - torch.exp(-x) 129 | -------------------------------------------------------------------------------- /convnova/configs/experiment/nt-benchmark/hyena1.6M.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: nucleotide_transformer 4 | - override /scheduler: cosine_warmup_timm 5 | 6 | model: 7 | _name_: dna_embedding 8 | d_model: 256 9 | n_layer: 2 10 | d_inner: ${eval:4 * ${.d_model}} 11 | vocab_size: 12 12 | resid_dropout: 0.1 13 | embed_dropout: 0.2 14 | fused_mlp: False # figure out how to use fused MLP, maybe only with bf16 + a100 15 | fused_dropout_add_ln: True 16 | residual_in_fp32: True 17 | pad_vocab_size_multiple: 8 18 | # attn_layer_idx: [0,1,2,3,4,5,6,7,8,9,10,11] # if passing these attn flags, then MHA auto used 19 | # attn_cfg: 20 | # num_heads: 8 21 | # use_flash_attn: True # figure out how to use 22 | # fused_bias_fc: False # this doesn't work for some reason, loss not going down 23 | # dropout: 0.1 24 | # rotary_emb_dim: 16 25 | layer: 26 | _name_: hyena 27 | emb_dim: 5 28 | filter_order: 64 29 | short_filter_order: 3 30 | l_max: 1026 # required to be set the same as the pretrained model if using, don't forget the +2! ${eval:${dataset.max_length}+2} 31 | modulate: True 32 | w: 10 33 | lr: ${optimizer.lr} 34 | wd: 0.0 35 | lr_pos_emb: 0.0 36 | # model: 37 | # # _name_: convnext 38 | # # d_model: 128 39 | # # # max_length: ${dataset.max_length} 40 | # # max_length: 8193 41 | # # vocab_size: 12 42 | # # pad_vocab_size_multiple: 8 43 | # # k_size: 5 44 | # _name_: lm 45 | # d_model: 128 46 | # n_layer: 8 47 | # d_inner: 512 48 | # vocab_size: 12 49 | # resid_dropout: 0.0 50 | # embed_dropout: 0.1 51 | # fused_mlp: false 52 | # fused_dropout_add_ln: false 53 | # checkpoint_mixer: false 54 | # checkpoint_mlp: false 55 | # residual_in_fp32: true 56 | # pad_vocab_size_multiple: 8 57 | # layer: 58 | # _name_: hyena 59 | # emb_dim: 5 60 | # filter_order: 64 61 | # short_filter_order: 3 62 | # l_max: 8195 63 | # modulate: true 64 | # w: 10 65 | # lr: 0.002 66 | # wd: 0.0 67 | # lr_pos_emb: 0.0 68 | 69 | task: 70 | _name_: masked_multiclass 71 | loss: cross_entropy 72 | # metrics: 73 | # - accuracy 74 | torchmetrics: null 75 | 76 | trainer: 77 | accelerator: gpu 78 | devices: 2 79 | num_nodes: 1 80 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 81 | max_epochs: 20 82 | precision: 32 # bf16 only a100 83 | gradient_clip_val: 1.0 84 | # strategy: null 85 | 86 | # name maxlen classes samples metric 87 | 88 | # enhancer 200 2 14968 MCC 89 | # enhancer_types 200 3 14968 MCC 90 | # H3 500 2 13468 MCC 91 | # H3K4me1 500 2 28509 MCC 92 | # H3K4me2 500 2 27614 MCC 93 | # H3K4me3 500 2 33119 MCC 94 | # H3K9ac 500 2 25003 MCC 95 | # H3K14ac 500 2 29743 MCC 96 | # H3K36me3 500 2 31392 MCC 97 | # H3K79me3 500 2 25953 MCC 98 | # H4 500 2 13140 MCC 99 | # H4ac 500 2 30685 MCC 100 | # promoter_all 300 2 53276 F1 101 | # 
promoter_non_tata 300 2 47759 F1 102 | # promoter_tata 300 2 5517 F1 103 | # splice_sites_acceptor 600 2 19961 F1 104 | # splice_sites_donor 600 2 19775 F1 105 | 106 | 107 | dataset: 108 | # batch_size: 32 # Per GPU 109 | batch_size: 128 110 | # max_length: 515 # select max that you want for this dataset 111 | # dataset_name: 'human_nontata_promoters' 112 | dataset_name: 'H4ac' 113 | # dest_path: '/mnt/nas/share2/home/by/hyena-dna/data/genomic_benchmark/' 114 | # d_output: 3 # binary classification by default 115 | # use_padding: True 116 | # padding_side: 'left' 117 | # add_eos: False 118 | # train_len: 289061 # update this according to above table 119 | # __l_max: ${.max_length} 120 | tokenizer_name: char 121 | add_eos: false 122 | rc_aug: false # reverse complement augmentation 123 | return_mask: false 124 | padding_side: left 125 | # num_workers: 1 126 | # rc_aug: true # reverse complement augmentation. Didn't seem to help for human_nontata_promoters, but could be wrong 127 | 128 | scheduler: 129 | t_in_epochs: False 130 | t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} 131 | warmup_lr_init: 1e-6 132 | warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} 133 | lr_min: ${eval:0.1 * ${optimizer.lr}} 134 | 135 | optimizer: 136 | lr: 6e-4 137 | weight_decay: 0.1 138 | 139 | train: 140 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 141 | seed: 2222 142 | global_batch_size: ${eval:${trainer.devices}*${dataset.batch_size}} 143 | remove_test_loader_in_eval: true # no test set in this benchmark 144 | pretrained_model_strict_load: False # false allows encoder/decoder to be used if new model uses it 145 | # for loading backbone and not head, requires both of these flags below 146 | # pretrained_model_path: /mnt/nas/share2/home/by/hyena-dna/outputs/2024-04-16/15-28-14-651739/checkpoints/test/loss.ckpt 147 | pretrained_model_path: /mnt/nas/share2/home/by/hyena-dna/outputs/2-256-1k.ckpt #/mnt/nas/share2/home/by/hyena-dna/outputs/weights.ckpt 148 | pretrained_model_state_hook: 149 | _name_: load_backbone 150 | freeze_backbone: false # seems to work much better if false (ie finetune entire model) 151 | -------------------------------------------------------------------------------- /convnova/src/ops/vandermonde.py: -------------------------------------------------------------------------------- 1 | """pykeops implementations of the Vandermonde matrix multiplication kernel used in the S4D kernel.""" 2 | import math 3 | import torch 4 | 5 | from einops import rearrange, repeat 6 | from opt_einsum import contract 7 | 8 | import os 9 | 10 | try: 11 | import pykeops 12 | from pykeops.torch import LazyTensor, Genred 13 | except: 14 | pass 15 | 16 | try: 17 | from cauchy_mult import vand_log_mult_sym_fwd, vand_log_mult_sym_bwd 18 | except: 19 | vand_log_mult_sym_fwd, vand_log_mult_sym_bwd = None, None 20 | 21 | _conj = lambda x: torch.cat([x, x.conj()], dim=-1) 22 | def _broadcast_dims(*tensors): 23 | max_dim = max([len(tensor.shape) for tensor in tensors]) 24 | tensors = [tensor.view((1,)*(max_dim-len(tensor.shape))+tensor.shape) for tensor in tensors] 25 | return tensors 26 | 27 | def _c2r(x): return torch.view_as_real(x) 28 | def _r2c(x): return torch.view_as_complex(x) 29 | 30 | def vandermonde_naive(v, x, L, conj=True): 31 | """ 32 | v: (..., N) 33 | x: (..., N) 
34 | returns: (..., L) \sum v x^l 35 | """ 36 | if conj: 37 | x = _conj(x) 38 | v = _conj(v) 39 | vandermonde_matrix = x.unsqueeze(-1) ** torch.arange(L).to(x) # (... N L) 40 | vandermonde_prod = torch.sum(v.unsqueeze(-1) * vandermonde_matrix, dim=-2) # (... L) 41 | return vandermonde_prod 42 | 43 | def log_vandermonde_naive(v, x, L, conj=True): 44 | """ 45 | v: (..., N) 46 | x: (..., N) 47 | returns: (..., L) \sum v x^l 48 | """ 49 | vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) 50 | vandermonde_prod = contract('... n, ... n l -> ... l', v, vandermonde_matrix) # (... L) 51 | if conj: 52 | return 2*vandermonde_prod.real 53 | else: 54 | return vandermonde_prod 55 | 56 | def log_vandermonde_lazy(v, x, L, conj=True): 57 | if conj: 58 | v = _conj(v) 59 | x = _conj(x) 60 | l = torch.arange(L).to(x) 61 | v, x, l = _broadcast_dims(v, x, l) 62 | v_l = LazyTensor(rearrange(v, '... N -> ... N 1 1')) 63 | x_l = LazyTensor(rearrange(x, '... N -> ... N 1 1')) 64 | l_l = LazyTensor(rearrange(l, '... L -> ... 1 L 1')) 65 | # exp 66 | vand = (x_l * l_l).exp() 67 | s = (v_l*vand).sum(dim=len(v_l.shape)-2) 68 | return s.squeeze(-1) 69 | 70 | def log_vandermonde(v, x, L, conj=True): 71 | expr = 'ComplexMult(v, ComplexExp(ComplexMult(x, l)))' 72 | vandermonde_mult = Genred( 73 | expr, 74 | [ 75 | 'v = Vj(2)', 76 | 'x = Vj(2)', 77 | 'l = Vi(2)', 78 | ], 79 | reduction_op='Sum', 80 | axis=1, 81 | ) 82 | 83 | l = torch.arange(L).to(x) 84 | v, x, l = _broadcast_dims(v, x, l) 85 | v = _c2r(v) 86 | x = _c2r(x) 87 | l = _c2r(l) 88 | 89 | r = vandermonde_mult(v, x, l, backend='GPU') 90 | if conj: 91 | return 2*_r2c(r).real 92 | else: 93 | return _r2c(r) 94 | 95 | def log_vandermonde_transpose_naive(u, v, x, L): 96 | vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) 97 | vandermonde_prod = contract('... l, ... n, ... n l -> ... n', u.to(x), v.to(x), vandermonde_matrix) # (... L) 98 | return vandermonde_prod 99 | 100 | def log_vandermonde_transpose(u, v, x, L): 101 | """ 102 | u: ... H L 103 | v: ... H N 104 | x: ... H N 105 | Returns: ... H N 106 | 107 | V = Vandermonde(a, L) : (H N L) 108 | contract_L(V * u * v) 109 | """ 110 | expr = 'ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))' 111 | vandermonde_mult = Genred( 112 | expr, 113 | [ 114 | 'u = Vj(2)', 115 | 'v = Vi(2)', 116 | 'x = Vi(2)', 117 | 'l = Vj(2)', 118 | ], 119 | reduction_op='Sum', 120 | axis=1, 121 | ) 122 | 123 | l = torch.arange(L).to(x) 124 | u, v, x, l = _broadcast_dims(u, v, x, l) 125 | u = _c2r(u) 126 | v = _c2r(v) 127 | x = _c2r(x) 128 | l = _c2r(l) 129 | 130 | r = vandermonde_mult(u, v, x, l, backend='GPU') 131 | return _r2c(r) 132 | 133 | def _log_vandermonde_matmul(x, L): 134 | vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... 
N L) 135 | return vandermonde_matrix 136 | 137 | def log_vandermonde_matmul(v, K): 138 | prod = contract('...n, ...nl -> ...l', v, K) 139 | return 2*prod.real 140 | 141 | class LogVandMultiplySymmetric(torch.autograd.Function): 142 | 143 | @staticmethod 144 | def forward(ctx, v, x, L): 145 | batch, N = v.shape 146 | supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] 147 | if not N in supported_N_values: 148 | raise NotImplementedError(f'Only support N values in {supported_N_values}') 149 | max_L_value = 32 * 1024 * 64 * 1024 150 | if L > max_L_value: 151 | raise NotImplementedError(f'Only support L values <= {max_L_value}') 152 | if not v.is_cuda and x.is_cuda: 153 | raise NotImplementedError(f'Only support CUDA tensors') 154 | ctx.save_for_backward(v, x) 155 | return vand_log_mult_sym_fwd(v, x, L) 156 | 157 | @staticmethod 158 | def backward(ctx, dout): 159 | v, x = ctx.saved_tensors 160 | dv, dx = vand_log_mult_sym_bwd(v, x, dout) 161 | return dv, dx, None 162 | 163 | 164 | if vand_log_mult_sym_fwd and vand_log_mult_sym_bwd is not None: 165 | log_vandermonde_fast = LogVandMultiplySymmetric.apply 166 | else: 167 | log_vandermonde_fast = None -------------------------------------------------------------------------------- /convnova/src/callbacks/progressive_resizing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_lightning.callbacks import Callback 3 | 4 | import src.utils as utils 5 | from src.utils import registry 6 | 7 | 8 | class ProgressiveResizing(Callback): 9 | 10 | def __init__(self, stage_params: list): 11 | """ 12 | stage_params is a list of dicts 13 | e.g. stage_params = [ 14 | {'resolution': 4, 'epochs': 50}, # 32 x 32 15 | {'resolution': 2, 'epochs': 30}, # 64 x 64 16 | {'resolution': 1, 'epochs': 20}, # 128 x 128 17 | ] 18 | """ 19 | super().__init__() 20 | assert len(stage_params) > 0, 'No stages specified' 21 | assert all([{'resolution', 'epochs'} <= set(stage.keys()) for stage in stage_params]), \ 22 | 'stage_params must contain keys: resolution and epochs' 23 | 24 | self.stage_params = stage_params 25 | self.stage_epochs_cume = np.cumsum([stage['epochs'] for stage in stage_params]) 26 | 27 | self._current_stage = 0 28 | 29 | def _verify_stages(self, trainer, model): 30 | # Double-check that stage parameters are correct, otherwise we'll fail in the middle of training 31 | for stage in self.stage_params: 32 | if hasattr(stage, 'scheduler'): 33 | # Verify that we can actually create the scheduler when we need to update it in each stage 34 | scheduler = utils.instantiate(registry.scheduler, {**model.hparams.scheduler, **stage['scheduler']}, trainer.optimizers[0]) 35 | del scheduler 36 | 37 | def on_train_start(self, trainer, model) -> None: 38 | # Verify all the stage parameters are correct 39 | self._verify_stages(trainer, model) 40 | 41 | print(f"Training starts at {trainer.current_epoch}") 42 | if trainer.current_epoch == 0: 43 | # Update the model to the first stage 44 | self._update_to_current_stage(trainer, model) 45 | else: 46 | # Preemption or resumption of progressive resizing 47 | # Update the stage to the current one 48 | self._current_stage = int(np.searchsorted(self.stage_epochs_cume - 1, trainer.current_epoch)) 49 | self._starting_stage = np.any(trainer.current_epoch == self.stage_epochs_cume) 50 | 51 | print("Progressive Resizing: Restarting at Stage {}".format(self._current_stage)) 52 | if self._starting_stage: 53 | self._update_lr_scheduler(trainer, model) 
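            # Worked example (editorial comment, not in the original source), using the
            # stage_params shown in the class docstring: epochs [50, 30, 20] give
            # stage_epochs_cume = [50, 80, 100]. Resuming at current_epoch == 50:
            #   np.searchsorted([49, 79, 99], 50) -> 1, so _current_stage == 1, and
            #   _starting_stage is True (epoch 50 is the first epoch of stage 1),
            #   hence the scheduler is re-created above for the new stage.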
54 | 55 | # Set the dataloader and model 56 | self._update_dataloaders(trainer, model) 57 | self._update_model(trainer, model) 58 | 59 | return super().on_train_start(trainer, model) 60 | 61 | def _update_lr_scheduler(self, trainer, model): 62 | if not hasattr(self.stage_params[self._current_stage], 'scheduler'): 63 | # No scheduler specified, so don't update the current scheduler 64 | return 65 | 66 | assert len(trainer.lr_schedulers) == 1 67 | # Reinitialize the scheduler 68 | # We don't need to carry over information from the last scheduler e.g. the last_epoch property, 69 | # because that will mess with the new scheduler when we step it 70 | hparams = {**model.hparams.scheduler, **self.stage_params[self._current_stage]['scheduler']} 71 | 72 | # Note that passing in the optimizer below is okay: the scheduler will be reinitialized and doesn't seem to inherit any current lr info from the optimizer 73 | trainer.lr_schedulers[0]['scheduler'] = utils.instantiate(registry.scheduler, hparams, trainer.optimizers[0]) 74 | 75 | print("\tChanged scheduler to {}".format(hparams)) 76 | 77 | def _update_dataloaders(self, trainer, model): 78 | # Set the train resolution and reset the dataloader 79 | model.hparams.loader.train_resolution = self.stage_params[self._current_stage]['resolution'] 80 | trainer.reset_train_dataloader(model) 81 | 82 | print('\tChanged resolution to {}'.format(self.stage_params[self._current_stage]['resolution'])) 83 | 84 | def _update_model(self, trainer, model): 85 | if not hasattr(self.stage_params[self._current_stage], 'bandlimit'): 86 | return 87 | 88 | # Update the bandlimit value for the model: this is a hack to make sure the model is updated 89 | # Iterate over all the modules 90 | for module in model.modules(): 91 | if hasattr(module, 'bandlimit'): 92 | module.bandlimit = self.stage_params[self._current_stage]['bandlimit'] 93 | 94 | print('\tChanged bandlimit to {}'.format(self.stage_params[self._current_stage]['bandlimit'])) 95 | 96 | def _update_to_current_stage(self, trainer, model): 97 | print("Progressive Resizing: Moving to Stage {}".format(self._current_stage)) 98 | # Update the train dataloader, model and scheduler 99 | self._update_dataloaders(trainer, model) 100 | self._update_model(trainer, model) 101 | self._update_lr_scheduler(trainer, model) 102 | 103 | 104 | def on_train_epoch_end(self, trainer, model): 105 | """ 106 | Check to see if new stage is reached for the next epoch, and if so, prepare the new stage by 107 | changing the dataloader. 108 | 109 | (We do next epoch so that the dataloader is prepared before the next epoch) 110 | """ 111 | next_epoch = trainer.current_epoch + 1 112 | 113 | # Check if stage should be increased 114 | if next_epoch >= self.stage_epochs_cume[self._current_stage] and self._current_stage < len(self.stage_params) - 1: 115 | self._current_stage += 1 116 | self._update_to_current_stage(trainer, model) 117 | 118 | return super().on_train_epoch_end(trainer, model) 119 | -------------------------------------------------------------------------------- /convnova/src/models/sequence/base.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import functools 3 | 4 | class SequenceModule(nn.Module): 5 | """Abstract sequence model class. 
All models must adhere to this interface 6 | 7 | A SequenceModule is generally a model that transforms an input of shape 8 | (n_batch, l_sequence, d_model) to (n_batch, l_sequence, d_output) 9 | 10 | REQUIRED methods and attributes 11 | forward, d_model, d_output: controls standard forward pass, a sequence-to-sequence transformation 12 | __init__ should also satisfy the following interface; see SequenceIdentity for an example 13 | def __init__(self, d_model, transposed=False, **kwargs) 14 | 15 | OPTIONAL methods 16 | default_state, step: allows stepping the model recurrently with a hidden state 17 | state_to_tensor, d_state: allows decoding from hidden state 18 | """ 19 | 20 | @property 21 | def d_model(self): 22 | """Model dimension (generally same as input dimension). 23 | 24 | This attribute is required for all SequenceModule instantiations. 25 | It is used by the rest of the pipeline (e.g. model backbone, encoder) to track the internal shapes of the full model. 26 | """ 27 | if getattr(self, "_d_model", None) is None: 28 | raise NotImplementedError("SequenceModule instantiation must set d_model") 29 | return self._d_model 30 | 31 | @d_model.setter 32 | def d_model(self, d): 33 | self._d_model = d 34 | 35 | @property 36 | def d_output(self): 37 | """Output dimension of model. 38 | 39 | This attribute is required for all SequenceModule instantiations. 40 | It is used by the rest of the pipeline (e.g. model backbone, decoder) to track the internal shapes of the full model. 41 | """ 42 | if getattr(self, "_d_output", None) is None: 43 | raise NotImplementedError("SequenceModule instantiation must specify d_output for decoder") 44 | return self._d_output 45 | 46 | @d_output.setter 47 | def d_output(self, d): 48 | self._d_output = d 49 | 50 | def forward(self, x, state=None, **kwargs): 51 | """Forward pass of sequence model, a sequence-to-sequence transformation with an optional state. 52 | 53 | Generally, this should map a tensor of shape (batch, length, self.d_model) to (batch, length, self.d_output) 54 | 55 | Additionally, it returns a "state" which can be any additional information 56 | For example, RNN and SSM layers may return their hidden state, 57 | while some types of transformer layers (e.g. Transformer-XL) may want to pass a state as well 58 | """ 59 | return x, None 60 | 61 | @property 62 | def state_to_tensor(self): 63 | """Returns a function mapping a state to a single tensor. 64 | 65 | This method should be implemented if one wants to use the hidden state instead of the output sequence for final prediction. 66 | Currently only used with the StateDecoder. 67 | """ 68 | return lambda _: None 69 | 70 | @property 71 | def d_state(self): 72 | """ Returns dimension of output of self.state_to_tensor """ 73 | return None 74 | 75 | 76 | def default_state(self, *batch_shape, device=None): 77 | """Create initial state for a batch of inputs.""" 78 | 79 | return None 80 | 81 | def step(self, x, state=None, **kwargs): 82 | """Step the model recurrently for one step of the input sequence. 83 | 84 | For example, this should correspond to unrolling an RNN for one step. 85 | If the forward pass has signature (B, L, H1) -> (B, L, H2), 86 | this method should generally have signature (B, H1) -> (B, H2) with an optional recurrent state. 
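
        Illustrative sketch (editorial addition, assuming a batch-first nn.GRU as
        the inner module; not part of the original interface):

            def step(self, x, state=None, **kwargs):
                # x: (B, H1) -> treat it as a length-1 sequence
                y, next_state = self.rnn(x.unsqueeze(1), state)
                return y.squeeze(1), next_state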
87 | """ 88 | raise NotImplementedError 89 | 90 | def TransposedModule(module): 91 | """Wrap a SequenceModule class to accept transposed parameter, handle state, absorb kwargs""" 92 | # https://stackoverflow.com/a/65470430/1980685 93 | @functools.wraps(module, updated=()) 94 | class TransposedModule(module): 95 | def __init__(self, *args, transposed=False, **kwargs): 96 | super().__init__(*args, **kwargs) 97 | self.transposed = transposed 98 | 99 | def forward(self, x, state=None, **kwargs): 100 | if self.transposed: x = x.transpose(-1, -2) 101 | if self.return_state: 102 | x, next_state = super().forward(x, state) # Don't use kwarg because nn.LSTM 103 | next_state = None if state is None else next_state 104 | if self.transposed: x = x.transpose(-1,-2) 105 | return x, next_state 106 | else: 107 | x = super().forward(x, state) # Don't use kwarg because nn.LSTM 108 | if self.transposed: x = x.transpose(-1,-2) 109 | return x 110 | # https://stackoverflow.com/questions/5352781/how-to-set-class-names-dynamically 111 | # TransposedModule.__name__ = module.__name__ # functools wraps is better solution 112 | return TransposedModule 113 | 114 | @TransposedModule 115 | class SequenceIdentity(SequenceModule): 116 | """Simple SequenceModule for testing purposes""" 117 | 118 | def __init__(self, d_model, dropout=0.0, **kwargs): 119 | """Default interface for SequenceModule 120 | 121 | d_model: input dimension (sometimes denoted H for hidden dimension) 122 | transposed: if True, inputs have axis ordering (B, H, L) instead of (B, H, L) 123 | """ 124 | super().__init__() 125 | self.d_model = d_model 126 | self.d_output = d_model 127 | 128 | 129 | def forward(self, x, state=None): 130 | return x, state 131 | 132 | def default_state(self, *batch_shape, device=None): 133 | return None 134 | 135 | def step(self, x, state=None, **kwargs): 136 | return x, state 137 | -------------------------------------------------------------------------------- /convnova/src/ops/toeplitz.py: -------------------------------------------------------------------------------- 1 | """ Utilities for computing convolutions. 2 | 3 | There are 3 equivalent views: 4 | 1. causal convolution 5 | 2. multiplication of (lower) triangular Toeplitz matrices 6 | 3. polynomial multiplication (mod x^N) 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | def construct_toeplitz(v, f=0.0): 15 | """Explicit construction of Krylov matrix [v A @ v A^2 @ v ... A^{n-1} @ v] 16 | where A = Z_f. This uses vectorized indexing and cumprod so it's much 17 | faster than using the Krylov function. 18 | Parameters: 19 | v: the starting vector of size n or (rank, n). 20 | f: real number 21 | Returns: 22 | K: Krylov matrix of size (n, n) or (rank, n, n). 23 | """ 24 | n = v.shape[-1] 25 | a = torch.arange(n, device=v.device) 26 | b = -a 27 | indices = a[:, None] + b[None] 28 | K = v[..., indices] 29 | K[..., indices < 0] *= f 30 | return K 31 | 32 | def triangular_toeplitz_multiply_(u, v, sum=None): 33 | n = u.shape[-1] 34 | u_expand = F.pad(u, (0, n)) 35 | v_expand = F.pad(v, (0, n)) 36 | u_f = torch.fft.rfft(u_expand, n=2*n, dim=-1) 37 | v_f = torch.fft.rfft(v_expand, n=2*n, dim=-1) 38 | uv_f = u_f * v_f 39 | if sum is not None: 40 | uv_f = uv_f.sum(dim=sum) 41 | output = torch.fft.irfft(uv_f, n=2*n, dim=-1)[..., :n] 42 | return output 43 | 44 | def triangular_toeplitz_multiply_padded_(u, v): 45 | """ Same as triangular_toeplitz_multiply but inputs and output assume to be 0-padded already. 
""" 46 | n = u.shape[-1] 47 | assert n % 2 == 0 48 | u_f = torch.fft.rfft(u, n=n, dim=-1) 49 | v_f = torch.fft.rfft(v, n=n, dim=-1) 50 | uv_f = u_f * v_f 51 | output = torch.fft.irfft(uv_f, n=n, dim=-1) 52 | output[..., n:] = 0 53 | return output 54 | 55 | class TriangularToeplitzMult(torch.autograd.Function): 56 | @staticmethod 57 | def forward(ctx, u, v): 58 | ctx.save_for_backward(u, v) 59 | return triangular_toeplitz_multiply_(u, v) 60 | 61 | @staticmethod 62 | def backward(ctx, grad): 63 | u, v = ctx.saved_tensors 64 | d_u = triangular_toeplitz_multiply_(grad.flip(-1), v).flip(-1) 65 | d_v = triangular_toeplitz_multiply_(grad.flip(-1), u).flip(-1) 66 | return d_u, d_v 67 | 68 | class TriangularToeplitzMultFast(torch.autograd.Function): 69 | @staticmethod 70 | def forward(ctx, u, v): 71 | n = u.shape[-1] 72 | u_expand = F.pad(u, (0, n)) 73 | v_expand = F.pad(v, (0, n)) 74 | u_f = torch.fft.rfft(u_expand, n=2*n, dim=-1) 75 | v_f = torch.fft.rfft(v_expand, n=2*n, dim=-1) 76 | 77 | ctx.save_for_backward(u_f, v_f) 78 | 79 | uv_f = u_f * v_f 80 | output = torch.fft.irfft(uv_f, n=2*n, dim=-1)[..., :n] 81 | return output 82 | 83 | @staticmethod 84 | def backward(ctx, grad): 85 | u_f, v_f = ctx.saved_tensors 86 | n = grad.shape[-1] 87 | g_expand = F.pad(grad.flip(-1), (0, n)) 88 | g_f = torch.fft.rfft(g_expand, n=2*n, dim=-1) 89 | gu_f = g_f * u_f 90 | gv_f = g_f * v_f 91 | d_u = torch.fft.irfft(gv_f, n=2*n, dim=-1)[..., :n] 92 | d_v = torch.fft.irfft(gu_f, n=2*n, dim=-1)[..., :n] 93 | d_u = d_u.flip(-1) 94 | d_v = d_v.flip(-1) 95 | return d_u, d_v 96 | 97 | class TriangularToeplitzMultPadded(torch.autograd.Function): 98 | @staticmethod 99 | def forward(ctx, u, v): 100 | ctx.save_for_backward(u, v) 101 | output = triangular_toeplitz_multiply_(u, v) 102 | return output 103 | 104 | @staticmethod 105 | def backward(ctx, grad): 106 | u, v = ctx.saved_tensors 107 | d_u = triangular_toeplitz_multiply_padded_(grad.flip(-1), v).flip(-1) 108 | d_v = triangular_toeplitz_multiply_padded_(grad.flip(-1), u).flip(-1) 109 | return d_u, d_v 110 | 111 | class TriangularToeplitzMultPaddedFast(torch.autograd.Function): 112 | """ Trade off speed (20-25% faster) for more memory (20-25%) """ 113 | 114 | @staticmethod 115 | def forward(ctx, u, v): 116 | n = u.shape[-1] 117 | u_f = torch.fft.rfft(u, n=n, dim=-1) 118 | v_f = torch.fft.rfft(v, n=n, dim=-1) 119 | 120 | ctx.save_for_backward(u_f, v_f) 121 | 122 | uv_f = u_f * v_f 123 | output = torch.fft.irfft(uv_f, n=n, dim=-1) 124 | output[..., n//2:].zero_() 125 | return output 126 | 127 | @staticmethod 128 | def backward(ctx, grad): 129 | u_f, v_f = ctx.saved_tensors 130 | n = grad.shape[-1] 131 | g_expand = F.pad(grad[..., :n//2].flip(-1), (0, n//2)) 132 | g_f = torch.fft.rfft(g_expand, n=n, dim=-1) 133 | gu_f = g_f * u_f 134 | gv_f = g_f * v_f 135 | d_u = torch.fft.irfft(gv_f, n=n, dim=-1) 136 | d_v = torch.fft.irfft(gu_f, n=n, dim=-1) 137 | d_u[..., n//2:].zero_() 138 | d_v[..., n//2:].zero_() 139 | d_u[..., :n//2] = d_u[..., :n//2].flip(-1) # TODO 140 | d_v[..., :n//2] = d_v[..., :n//2].flip(-1) # TODO 141 | return d_u, d_v 142 | 143 | # triangular_toeplitz_multiply = triangular_toeplitz_multiply_ 144 | triangular_toeplitz_multiply = TriangularToeplitzMult.apply 145 | triangular_toeplitz_multiply_fast = TriangularToeplitzMultFast.apply 146 | triangular_toeplitz_multiply_padded = TriangularToeplitzMultPadded.apply 147 | triangular_toeplitz_multiply_padded_fast = TriangularToeplitzMultPaddedFast.apply 148 | 149 | def causal_convolution(u, v, fast=True, pad=False): 
150 | if not pad and not fast: 151 | return triangular_toeplitz_multiply(u, v) 152 | if not pad and fast: 153 | return triangular_toeplitz_multiply_fast(u, v) 154 | if pad and not fast: 155 | return triangular_toeplitz_multiply_padded(u, v) 156 | if pad and fast: 157 | return triangular_toeplitz_multiply_padded_fast(u, v) 158 | -------------------------------------------------------------------------------- /convnova/src/dataloaders/datasets/hg38_char_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | From: https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py 3 | 4 | CharacterTokenzier for Hugging Face Transformers. 5 | This is heavily inspired from CanineTokenizer in transformers package. 6 | """ 7 | import json 8 | import os 9 | from pathlib import Path 10 | from typing import Dict, List, Optional, Sequence, Union 11 | 12 | from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer 13 | 14 | 15 | class CharacterTokenizer(PreTrainedTokenizer): 16 | def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str='left', **kwargs): 17 | """Character tokenizer for Hugging Face transformers. 18 | Args: 19 | characters (Sequence[str]): List of desired characters. Any character which 20 | is not included in this list will be replaced by a special token called 21 | [UNK] with id=6. Following are list of all of the special tokens with 22 | their corresponding ids: 23 | "[CLS]": 0 24 | "[SEP]": 1 25 | "[BOS]": 2 26 | "[MASK]": 3 27 | "[PAD]": 4 28 | "[RESERVED]": 5 29 | "[UNK]": 6 30 | an id (starting at 7) will be assigned to each character. 31 | model_max_length (int): Model maximum sequence length. 32 | """ 33 | self.characters = characters 34 | self.model_max_length = model_max_length 35 | bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False) 36 | eos_token = AddedToken("[SEP]", lstrip=False, rstrip=False) 37 | sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False) 38 | cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False) 39 | pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False) 40 | unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False) 41 | 42 | mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False) 43 | 44 | super().__init__( 45 | bos_token=bos_token, 46 | eos_token=sep_token, 47 | sep_token=sep_token, 48 | cls_token=cls_token, 49 | pad_token=pad_token, 50 | mask_token=mask_token, 51 | unk_token=unk_token, 52 | add_prefix_space=False, 53 | model_max_length=model_max_length, 54 | padding_side=padding_side, 55 | **kwargs, 56 | ) 57 | 58 | self._vocab_str_to_int = { 59 | "[CLS]": 0, 60 | "[SEP]": 1, 61 | "[BOS]": 2, 62 | "[MASK]": 3, 63 | "[PAD]": 4, 64 | "[RESERVED]": 5, 65 | "[UNK]": 6, 66 | **{ch: i + 7 for i, ch in enumerate(characters)}, 67 | } 68 | self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()} 69 | 70 | @property 71 | def vocab_size(self) -> int: 72 | return len(self._vocab_str_to_int) 73 | 74 | def _tokenize(self, text: str) -> List[str]: 75 | return list(text) 76 | 77 | def _convert_token_to_id(self, token: str) -> int: 78 | return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"]) 79 | 80 | def _convert_id_to_token(self, index: int) -> str: 81 | return self._vocab_int_to_str[index] 82 | 83 | def convert_tokens_to_string(self, tokens): 84 | return "".join(tokens) 85 | 86 | def build_inputs_with_special_tokens( 87 | self, token_ids_0: List[int], token_ids_1: 
Optional[List[int]] = None 88 | ) -> List[int]: 89 | sep = [self.sep_token_id] 90 | # cls = [self.cls_token_id] 91 | result = token_ids_0 + sep 92 | if token_ids_1 is not None: 93 | result += token_ids_1 + sep 94 | return result 95 | 96 | def get_special_tokens_mask( 97 | self, 98 | token_ids_0: List[int], 99 | token_ids_1: Optional[List[int]] = None, 100 | already_has_special_tokens: bool = False, 101 | ) -> List[int]: 102 | if already_has_special_tokens: 103 | return super().get_special_tokens_mask( 104 | token_ids_0=token_ids_0, 105 | token_ids_1=token_ids_1, 106 | already_has_special_tokens=True, 107 | ) 108 | 109 | result = ([0] * len(token_ids_0)) + [1] 110 | if token_ids_1 is not None: 111 | result += ([0] * len(token_ids_1)) + [1] 112 | return result 113 | 114 | def create_token_type_ids_from_sequences( 115 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 116 | ) -> List[int]: 117 | sep = [self.sep_token_id] 118 | cls = [self.cls_token_id] 119 | 120 | result = len(cls + token_ids_0 + sep) * [0] 121 | if token_ids_1 is not None: 122 | result += len(token_ids_1 + sep) * [1] 123 | return result 124 | 125 | def get_config(self) -> Dict: 126 | return { 127 | "char_ords": [ord(ch) for ch in self.characters], 128 | "model_max_length": self.model_max_length, 129 | } 130 | 131 | @classmethod 132 | def from_config(cls, config: Dict) -> "CharacterTokenizer": 133 | cfg = {} 134 | cfg["characters"] = [chr(i) for i in config["char_ords"]] 135 | cfg["model_max_length"] = config["model_max_length"] 136 | return cls(**cfg) 137 | 138 | def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): 139 | cfg_file = Path(save_directory) / "tokenizer_config.json" 140 | cfg = self.get_config() 141 | with open(cfg_file, "w") as f: 142 | json.dump(cfg, f, indent=4) 143 | 144 | @classmethod 145 | def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs): 146 | cfg_file = Path(save_directory) / "tokenizer_config.json" 147 | with open(cfg_file) as f: 148 | cfg = json.load(f) 149 | return cls.from_config(cfg) -------------------------------------------------------------------------------- /convnova/src/models/sequence/model.py: -------------------------------------------------------------------------------- 1 | """ Isotropic deep sequence model backbone, in the style of ResNets / Transformers. 2 | 3 | The SequenceModel class implements a generic (batch, length, d_input) -> (batch, length, d_output) transformation 4 | """ 5 | 6 | from functools import partial 7 | 8 | import torch 9 | import torch.nn as nn 10 | from einops import rearrange 11 | 12 | from src.utils.config import to_list, to_dict 13 | from src.models.sequence.block import SequenceResidualBlock 14 | from src.models.sequence.base import SequenceModule 15 | from src.models.nn.components import Normalization, DropoutNd 16 | 17 | 18 | class SequenceModel(SequenceModule): 19 | def __init__( 20 | self, 21 | d_model, # Resize input (useful for deep models with residuals) 22 | n_layers=1, # Number of layers 23 | transposed=False, # Transpose inputs so each layer receives (batch, dim, length) 24 | dropout=0.0, # Dropout parameter applied on every residual and every layer 25 | tie_dropout=False, # Tie dropout mask across sequence like nn.Dropout1d/nn.Dropout2d 26 | prenorm=True, # Pre-norm vs. 
post-norm 27 | n_repeat=1, # Each layer is repeated n times per stage before applying pooling 28 | layer=None, # Layer config, must be specified 29 | residual=None, # Residual config 30 | norm=None, # Normalization config (e.g. layer vs batch) 31 | pool=None, # Config for pooling layer per stage 32 | track_norms=True, # Log norms of each layer output 33 | dropinp=0.0, # Input dropout 34 | ): 35 | super().__init__() 36 | # Save arguments needed for forward pass 37 | self.d_model = d_model 38 | self.transposed = transposed 39 | self.track_norms = track_norms 40 | 41 | # Input dropout (not really used) 42 | dropout_fn = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout 43 | self.drop = dropout_fn(dropinp) if dropinp > 0.0 else nn.Identity() 44 | 45 | layer = to_list(layer, recursive=False) 46 | 47 | # Some special arguments are passed into each layer 48 | for _layer in layer: 49 | # If layers don't specify dropout, add it 50 | if _layer.get('dropout', None) is None: 51 | _layer['dropout'] = dropout 52 | # Ensure all layers are shaped the same way 53 | _layer['transposed'] = transposed 54 | 55 | # Duplicate layers 56 | layers = layer * n_layers * n_repeat 57 | 58 | # Instantiate layers 59 | _layers = [] 60 | d = d_model 61 | for l, layer in enumerate(layers): 62 | # Pool at the end of every n_repeat blocks 63 | pool_cfg = pool if (l+1) % n_repeat == 0 else None 64 | block = SequenceResidualBlock(d, l+1, prenorm=prenorm, dropout=dropout, tie_dropout=tie_dropout, transposed=transposed, layer=layer, residual=residual, norm=norm, pool=pool_cfg) 65 | _layers.append(block) 66 | d = block.d_output 67 | 68 | self.d_output = d 69 | self.layers = nn.ModuleList(_layers) 70 | if prenorm: 71 | if norm is None: 72 | self.norm = None 73 | elif isinstance(norm, str): 74 | self.norm = Normalization(self.d_output, transposed=self.transposed, _name_=norm) 75 | else: 76 | self.norm = Normalization(self.d_output, transposed=self.transposed, **norm) 77 | else: 78 | self.norm = nn.Identity() 79 | 80 | def forward(self, inputs, *args, state=None, **kwargs): 81 | """ Inputs assumed to be (batch, sequence, dim) """ 82 | if self.transposed: inputs = rearrange(inputs, 'b ... d -> b d ...') 83 | inputs = self.drop(inputs) 84 | 85 | # Track norms 86 | if self.track_norms: output_norms = [torch.mean(inputs.detach() ** 2)] 87 | 88 | # Apply layers 89 | outputs = inputs 90 | prev_states = [None] * len(self.layers) if state is None else state 91 | next_states = [] 92 | for layer, prev_state in zip(self.layers, prev_states): 93 | outputs, state = layer(outputs, *args, state=prev_state, **kwargs) 94 | next_states.append(state) 95 | if self.track_norms: output_norms.append(torch.mean(outputs.detach() ** 2)) 96 | if self.norm is not None: outputs = self.norm(outputs) 97 | 98 | if self.transposed: outputs = rearrange(outputs, 'b d ... -> b ... 
d') 99 | 100 | if self.track_norms: 101 | metrics = to_dict(output_norms, recursive=False) 102 | self.metrics = {f'norm/{i}': v for i, v in metrics.items()} 103 | 104 | return outputs, next_states 105 | 106 | @property 107 | def d_state(self): 108 | d_states = [layer.d_state for layer in self.layers] 109 | return sum([d for d in d_states if d is not None]) 110 | 111 | @property 112 | def state_to_tensor(self): 113 | # Slightly hacky way to implement this in a curried manner (so that the function can be extracted from an instance) 114 | # Somewhat more sound may be to turn this into a @staticmethod and grab subclasses using hydra.utils.get_class 115 | def fn(state): 116 | x = [_layer.state_to_tensor(_state) for (_layer, _state) in zip(self.layers, state)] 117 | x = [_x for _x in x if _x is not None] 118 | return torch.cat( x, dim=-1) 119 | return fn 120 | 121 | def default_state(self, *batch_shape, device=None): 122 | return [layer.default_state(*batch_shape, device=device) for layer in self.layers] 123 | 124 | def step(self, x, state, **kwargs): 125 | # Apply layers 126 | prev_states = [None] * len(self.layers) if state is None else state 127 | next_states = [] 128 | for layer, prev_state in zip(self.layers, prev_states): 129 | x, state = layer.step(x, state=prev_state, **kwargs) 130 | next_states.append(state) 131 | 132 | x = self.norm(x) 133 | 134 | return x, next_states 135 | -------------------------------------------------------------------------------- /convnova/src/ops/krylov.py: -------------------------------------------------------------------------------- 1 | """ Compute a Krylov function efficiently. (S4 renames the Krylov function to a "state space kernel") 2 | 3 | A : (N, N) 4 | b : (N,) 5 | c : (N,) 6 | Return: [c^T A^i b for i in [L]] 7 | """ 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from einops import rearrange, repeat 12 | 13 | from src.ops.toeplitz import causal_convolution 14 | 15 | def krylov_sequential(L, A, b, c=None): 16 | """ Constant matrix A 17 | 18 | A : (..., N, N) 19 | b : (..., N) 20 | c : (..., N) 21 | 22 | Returns 23 | if c: 24 | x : (..., L) 25 | x[i, l] = c[i] @ A^l @ b[i] 26 | 27 | else: 28 | x : (..., N, L) 29 | x[i, l] = A^l @ b[i] 30 | """ 31 | 32 | # Check which of dim b and c is smaller to save memory 33 | if c is not None and c.numel() < b.numel(): 34 | return krylov_sequential(L, A.transpose(-1, -2), c, b) 35 | 36 | b_ = b 37 | x = [] 38 | for _ in range(L): 39 | if c is not None: 40 | x_ = torch.sum(c*b_, dim=-1) # (...) # could be faster with matmul or einsum? 41 | else: 42 | x_ = b_ 43 | x.append(x_) 44 | b_ = (A @ b_.unsqueeze(-1)).squeeze(-1) 45 | 46 | x = torch.stack(x, dim=-1) 47 | return x 48 | 49 | 50 | def krylov(L, A, b, c=None, return_power=False): 51 | """ 52 | Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick. 53 | 54 | If return_power=True, return A^{L-1} as well 55 | """ 56 | # TODO There is an edge case if L=1 where output doesn't get broadcasted, which might be an issue if caller is expecting broadcasting semantics... 
can deal with it if it arises 57 | 58 | x = b.unsqueeze(-1) # (..., N, 1) 59 | A_ = A 60 | 61 | AL = None 62 | if return_power: 63 | AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device) 64 | _L = L-1 65 | 66 | done = L == 1 67 | # loop invariant: _L represents how many indices left to compute 68 | while not done: 69 | if return_power: 70 | if _L % 2 == 1: AL = A_ @ AL 71 | _L //= 2 72 | 73 | # Save memory on last iteration 74 | l = x.shape[-1] 75 | if L - l <= l: 76 | done = True 77 | _x = x[..., :L-l] 78 | else: _x = x 79 | 80 | _x = A_ @ _x 81 | x = torch.cat([x, _x], dim=-1) # there might be a more efficient way of ordering axes 82 | if not done: A_ = A_ @ A_ 83 | 84 | assert x.shape[-1] == L 85 | 86 | if c is not None: 87 | x = torch.einsum('...nl, ...n -> ...l', x, c) 88 | x = x.contiguous() # WOW!! 89 | if return_power: 90 | return x, AL 91 | else: 92 | return x 93 | 94 | @torch.no_grad() 95 | def power(L, A, v=None): 96 | """ Compute A^L and the scan sum_i A^i v_i 97 | 98 | A: (..., N, N) 99 | v: (..., N, L) 100 | """ 101 | 102 | I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device) 103 | 104 | powers = [A] 105 | l = 1 106 | while True: 107 | if L % 2 == 1: I = powers[-1] @ I 108 | L //= 2 109 | if L == 0: break 110 | l *= 2 111 | if v is None: 112 | powers = [powers[-1] @ powers[-1]] 113 | else: 114 | powers.append(powers[-1] @ powers[-1]) 115 | 116 | if v is None: return I 117 | 118 | # Invariants: 119 | # powers[-1] := A^l 120 | # l := largest po2 at most L 121 | 122 | # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A 123 | # We do this reverse divide-and-conquer for efficiency reasons: 124 | # 1) it involves fewer padding steps for non-po2 L 125 | # 2) it involves more contiguous arrays 126 | 127 | # Take care of edge case for non-po2 arrays 128 | # Note that this initial step is a no-op for the case of power of 2 (l == L) 129 | k = v.size(-1) - l 130 | v_ = powers.pop() @ v[..., l:] 131 | v = v[..., :l] 132 | v[..., :k] = v[..., :k] + v_ 133 | 134 | # Handle reduction for power of 2 135 | while v.size(-1) > 1: 136 | v = rearrange(v, '... (z l) -> ... z l', z=2) 137 | v = v[..., 0, :] + powers.pop() @ v[..., 1, :] 138 | return I, v.squeeze(-1) 139 | 140 | def krylov_toeplitz(L, A, b, c=None): 141 | """ Specializes to lower triangular Toeplitz matrix A represented by its diagonals 142 | 143 | A : (..., N) 144 | b : (..., N) 145 | c : (..., N) 146 | 147 | Returns 148 | x : (..., N, L) 149 | x[i, l] = A^l @ b[i] 150 | """ 151 | x = b.unsqueeze(0) # (1, ..., N) 152 | A_ = A 153 | while x.shape[0] < L: 154 | xx = causal_convolution(A_, x) 155 | x = torch.cat([x, xx], dim=0) # there might be a more efficient way of ordering axes 156 | A_ = causal_convolution(A_, A_) 157 | x = x[:L, ...] # (L, ..., N) 158 | if c is not None: 159 | x = torch.einsum('l...n, ...n -> ...l', x, c) 160 | else: 161 | x = rearrange(x, 'l ... n -> ... 
n l') 162 | x = x.contiguous() 163 | return x 164 | 165 | def krylov_toeplitz_(L, A, b, c=None): 166 | """ Padded version of krylov_toeplitz that saves some fft's 167 | 168 | TODO currently not faster than original version, not sure why 169 | """ 170 | N = A.shape[-1] 171 | 172 | x = b.unsqueeze(0) # (1, ..., N) 173 | x = F.pad(x, (0, N)) 174 | A = F.pad(A, (0, N)) 175 | done = L == 1 176 | while not done: 177 | l = x.shape[0] 178 | # Save memory on last iteration 179 | if L - l <= l: 180 | done = True 181 | _x = x[:L-l] 182 | else: _x = x 183 | Af = torch.fft.rfft(A, n=2*N, dim=-1) 184 | xf = torch.fft.rfft(_x, n=2*N, dim=-1) 185 | xf_ = Af * xf 186 | x_ = torch.fft.irfft(xf_, n=2*N, dim=-1) 187 | x_[..., N:] = 0 188 | x = torch.cat([x, x_], dim=0) # there might be a more efficient way of ordering axes 189 | if not done: 190 | A = torch.fft.irfft(Af*Af, n=2*N, dim=-1) 191 | A[..., N:] = 0 192 | x = x[:L, ..., :N] # (L, ..., N) 193 | if c is not None: 194 | x = torch.einsum('l...n, ...n -> ...l', x, c) 195 | else: 196 | x = rearrange(x, 'l ... n -> ... n l') 197 | x = x.contiguous() 198 | return x 199 | -------------------------------------------------------------------------------- /convnova/src/utils/train.py: -------------------------------------------------------------------------------- 1 | """ Utils for the training loop. Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py """ 2 | import logging 3 | import os 4 | import warnings 5 | from typing import List, Sequence 6 | 7 | import torch.nn as nn 8 | import pytorch_lightning as pl 9 | import rich.syntax 10 | import rich.tree 11 | from omegaconf import DictConfig, OmegaConf 12 | from pytorch_lightning.utilities import rank_zero_only 13 | 14 | from src.utils.config import omegaconf_filter_keys 15 | 16 | 17 | # Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging 18 | class LoggingContext: 19 | def __init__(self, logger, level=None, handler=None, close=True): 20 | self.logger = logger 21 | self.level = level 22 | self.handler = handler 23 | self.close = close 24 | 25 | def __enter__(self): 26 | if self.level is not None: 27 | self.old_level = self.logger.level 28 | self.logger.setLevel(self.level) 29 | if self.handler: 30 | self.logger.addHandler(self.handler) 31 | 32 | def __exit__(self, et, ev, tb): 33 | if self.level is not None: 34 | self.logger.setLevel(self.old_level) 35 | if self.handler: 36 | self.logger.removeHandler(self.handler) 37 | if self.handler and self.close: 38 | self.handler.close() 39 | # implicit return of None => don't swallow exceptions 40 | 41 | 42 | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: 43 | """Initializes multi-GPU-friendly python logger.""" 44 | 45 | logger = logging.getLogger(name) 46 | logger.setLevel(level) 47 | 48 | # this ensures all logging levels get marked with the rank zero decorator 49 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 50 | for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): 51 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 52 | 53 | return logger 54 | 55 | 56 | def process_config(config: DictConfig) -> DictConfig: # TODO because of filter_keys, this is no longer in place 57 | """A couple of optional utilities, controlled by main config file: 58 | - disabling warnings 59 | - easier access to debug mode 60 | - forcing debug friendly configuration 61 | Modifies DictConfig in place. 
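    (Editorial note, not in the original docstring: per the TODO above, callers
    should use the returned config rather than rely on in-place mutation, because
    omegaconf_filter_keys below returns a filtered config with every key starting
    with "__" (interpolation-only helpers such as __l_max in the dataset configs)
    dropped.)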
62 | Args: 63 | config (DictConfig): Configuration composed by Hydra. 64 | """ 65 | log = get_logger() 66 | 67 | # Filter out keys that were used just for interpolation 68 | # config = dictconfig_filter_keys(config, lambda k: not k.startswith('__')) 69 | config = omegaconf_filter_keys(config, lambda k: not k.startswith('__')) 70 | 71 | # enable adding new keys to config 72 | OmegaConf.set_struct(config, False) 73 | 74 | # disable python warnings if