├── .gitignore ├── .project-root ├── Makefile ├── README.md ├── configs ├── __init__.py ├── callbacks │ ├── default.yaml │ ├── early_stopping.yaml │ ├── model_checkpoint.yaml │ ├── model_summary.yaml │ ├── none.yaml │ ├── retrobridge.yaml │ └── rich_progress_bar.yaml ├── data │ ├── enhancer.yaml │ ├── promoter_design.yaml │ ├── qm9.yaml │ └── toy_dfm.yaml ├── debug │ ├── default.yaml │ ├── fdr.yaml │ ├── limit.yaml │ ├── overfit.yaml │ └── profiler.yaml ├── eval.yaml ├── experiment │ ├── enhancer_fly_sfm_cnn.yaml │ ├── enhancer_mel_sfm_cnn.yaml │ ├── enhancer_sfm_bmlp.yaml │ ├── promoter_dfm.yaml │ ├── promoter_rfm_tmlp.yaml │ ├── promoter_sfm_cnn.yaml │ ├── promoter_sfm_promdfm.yaml │ ├── promoter_sfm_tmlp.yaml │ ├── promoter_sfm_unet1d.yaml │ ├── qm_clean_sfm.yaml │ ├── qm_euclid.yaml │ ├── qm_simplex.yaml │ ├── qm_simplex_boost.yaml │ ├── qm_simplex_pushing.yaml │ ├── qm_sphere.yaml │ ├── qm_sphere_pushing.yaml │ ├── qm_vecfield_sfm.yaml │ ├── toy_dfm_bmlp.yaml │ ├── toy_dfm_cnn.yaml │ ├── toy_dfm_sfm_cnn.yaml │ ├── toy_dfm_sfm_tmlp.yaml │ ├── toy_dfm_sfm_unet1d.yaml │ ├── toy_dfm_temb.yaml │ ├── toy_rfm_cnn.yaml │ └── toy_rfm_tmlp.yaml ├── extras │ └── default.yaml ├── hparams_search │ └── mnist_optuna.yaml ├── hydra │ └── default.yaml ├── local │ └── .gitkeep ├── logger │ ├── aim.yaml │ ├── comet.yaml │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── mlflow.yaml │ ├── neptune.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── misc │ └── qm9_sfm.yaml ├── model │ ├── benhancer_mlp.yaml │ ├── bsignal_mlp.yaml │ ├── mel_cnn.yaml │ ├── molecule_module.yaml │ ├── promoter_dfm.yaml │ ├── promoter_model.yaml │ ├── promoter_sfm_unet1d.yaml │ ├── qm_vecfield_sfm.yaml │ ├── sfm_tmlp.yaml │ ├── sfm_tmlp_signal.yaml │ ├── toy_bmlp.yaml │ ├── toy_cnn_model.yaml │ ├── toy_cnn_sfm.yaml │ ├── toy_temb.yaml │ └── toy_unet1d.yaml ├── paths │ └── default.yaml ├── train.yaml └── trainer │ ├── cpu.yaml │ ├── ddp.yaml │ ├── ddp_sim.yaml │ ├── default.yaml │ └── gpu.yaml ├── environment.yaml ├── process_qm9.py ├── pyproject.toml ├── requirements.txt ├── script ├── submit.sh ├── submit_bmlp_sfm.sh ├── submit_bmlp_sfm_lin.sh ├── submit_bmlp_sfm_noot.sh ├── submit_bmlp_sfm_simplex_noot.sh ├── submit_bmlp_sfm_simplex_ot.sh ├── submit_cnn_dfm_toy.sh ├── submit_cnn_sfm_toy.sh ├── submit_cnn_sfm_toy_lsmooth.sh ├── submit_enhancer_bmlp.sh ├── submit_enhancer_cnn_sfm.sh ├── submit_enhancer_cnn_sfm_dna.sh ├── submit_eval_fly.sh ├── submit_eval_mel.sh ├── submit_eval_promoter.sh ├── submit_promoter_sfm_bmlp.sh ├── submit_promoter_sfm_promdfm.sh ├── submit_qm_clean.sh ├── submit_qm_sfm.sh ├── submit_qm_sfm_single.sh ├── submit_rfm_cnn_toy.sh ├── submit_tmlp_sfm_promoter.sh ├── submit_unet1d_toy.sh ├── submit_unet_sfm_promoter.sh ├── toy_all_bmlp_sfm.sh ├── toy_all_bmlp_sfm_noot.sh ├── toy_all_bmlp_sfm_simplex_noot.sh ├── toy_all_bmlp_sfm_simplex_ot.sh ├── toy_all_cnn_dfm.sh ├── toy_all_cnn_sfm.sh ├── toy_bmlp_lin.sh └── toy_cnn_sfm_lsmooth.sh ├── setup.py ├── src ├── __init__.py ├── data │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ ├── fbd.py │ │ ├── ff_energy.py │ │ ├── molecule.py │ │ ├── molecule_builder.py │ │ ├── molecule_prior.py │ │ ├── promoter_back.py │ │ ├── promoter_eval.py │ │ ├── qm_utils.py │ │ ├── sample_analyzer.py │ │ └── sei.py │ ├── dna_enhancer_datamodule.py │ ├── promoter_datamodule.py │ ├── qm9_datamodule.py │ └── toy_dfm_datamodule.py ├── dfm │ ├── __init__.py │ └── flow_utils.py ├── eval.py ├── models │ ├── __init__.py │ ├── dfm_module.py │ ├── molecule_module.py │ ├── net │ │ ├── 
__init__.py │ │ ├── gvp.py │ │ ├── interpolant_scheduler.py │ │ ├── model.py │ │ ├── promoter_model.py │ │ └── vector_field.py │ └── sfm_module.py ├── sfm │ ├── __init__.py │ ├── distribution.py │ ├── manifold.py │ ├── maths.py │ ├── plot.py │ ├── sampler.py │ └── train.py ├── train.py └── utils │ ├── __init__.py │ ├── instantiators.py │ ├── logging_utils.py │ ├── pylogger.py │ ├── rich_utils.py │ └── utils.py └── tests ├── __init__.py ├── conftest.py ├── helpers ├── __init__.py ├── package_available.py ├── run_if.py └── run_sh_command.py ├── test_configs.py ├── test_datamodules.py ├── test_eval.py ├── test_manifold.py ├── test_maths.py ├── test_sweeps.py └── test_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | out/ 3 | *.out 4 | wandb/ 5 | slurm/ 6 | ./data/ 7 | .pytest_cache/ 8 | logs/ 9 | *.csv 10 | *.pt 11 | pytabix_fix/ 12 | copy2server.sh 13 | cache/ 14 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # Empty 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | help: ## Show help 3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 4 | 5 | clean: ## Clean autogenerated files 6 | rm -rf dist 7 | find . -type f -name "*.DS_Store" -ls -delete 8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 9 | find . | grep -E ".pytest_cache" | xargs rm -rf 10 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf 11 | rm -f .coverage 12 | 13 | clean-logs: ## Clean logs 14 | rm -rf logs/** 15 | 16 | format: ## Run pre-commit hooks 17 | pre-commit run -a 18 | 19 | sync: ## Merge changes from main branch to your current branch 20 | git pull 21 | git pull origin main 22 | 23 | test: ## Run not slow tests 24 | pytest -k "not slow" 25 | 26 | test-full: ## Run all tests 27 | pytest 28 | 29 | train: ## Train the model 30 | python src/train.py 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fisher Flow Matching 2 | All our dependencies are listed in `environment.yaml` (for Conda) and `requirements.txt` (for `pip`). Please also install `DGL` separately: 3 | ```bash 4 | pip install -r requirements.txt 5 | pip install dgl -f https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html 6 | ``` 7 | Our code contains parts of [FlowMol](https://github.com/Dunni3/FlowMol/tree/main) by Dunn and Koes [1] (most of the QM9 experiments), [Riemannian-FM](https://github.com/facebookresearch/riemannian-fm) by Chen et al. [2], and, for the baselines, [DFM](https://github.com/HannesStark/dirichlet-flow-matching/tree/main) by Stark et al. [3]. 8 | 9 | ## Toy Experiment 10 | For the DFM toy experiment, our code can be run with the following command: 11 | ```bash 12 | python -m src.train experiment=toy_dfm_bmlp data.dim=100 trainer=gpu trainer.max_epochs=500 13 | ``` 14 | The dimension argument can be varied, and the configuration files allow changing the manifold (`"simplex"` or `"sphere"`) and turning OT on or off (`"exact"` or `"None"`).
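For example, the manifold, the OT coupling, and the sequence dimension can all be set as Hydra overrides on the command line; the override paths below follow the fields defined in `configs/experiment/toy_dfm_bmlp.yaml` and `configs/data/toy_dfm.yaml`, and the particular values are only illustrative:
```bash
# Sketch: toy DFM run on the simplex without OT, at a different sequence dimension.
# model.manifold, model.ot_method and data.dim are the fields from the experiment/data configs.
python -m src.train experiment=toy_dfm_bmlp \
    data.dim=160 \
    model.manifold=simplex \
    model.ot_method=None \
    trainer=gpu trainer.max_epochs=500
```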
15 | 16 | ## Promoter and Enhancer DNA Experiment 17 | To download the datasets, it suffices to follow the steps of [Stark et al.](https://github.com/HannesStark/dirichlet-flow-matching/). To evaluate the FBD, it is also necessary to download their weights from their `workdir.zip`. To run the promoter dataset experiment, the following command can be used: 18 | 19 | ```bash 20 | python -m src.train experiment=promoter_sfm_promdfm trainer.max_epochs=200 trainer=gpu data.batch_size=128 21 | ``` 22 | 23 | For the enhancer MEL2 experiment, the following command is available: 24 | 25 | ```bash 26 | python -m src.train experiment=enhancer_mel_sfm_cnn trainer.max_epochs=800 trainer=gpu 27 | ``` 28 | 29 | and for the FlyBrain DNA enhancer experiment: 30 | ```bash 31 | python -m src.train experiment=enhancer_fly_sfm_cnn trainer.max_epochs=800 trainer=gpu 32 | ``` 33 | 34 | ## QM9 experiment 35 | To set up the QM9 dataset, we have included `process_qm9.py` from FlowMol; it suffices to follow the steps indicated in their [README](https://github.com/Dunni3/FlowMol/tree/main). The experiment can then be run with: 36 | 37 | ```bash 38 | python -m src.train experiment=qm_clean_sfm trainer=gpu 39 | ``` 40 | 41 | ## References 42 | - [1]: [Dunn and Koes: Mixed Continuous and Categorical Flow Matching for 3D De Novo Molecule Generation](https://arxiv.org/abs/2404.19739). 43 | - [2]: [Chen et al.: Flow Matching on General Geometries](https://arxiv.org/pdf/2302.03660). 44 | - [3]: [Stark et al.: Dirichlet Flow Matching with Applications to DNA Sequence Design](https://arxiv.org/abs/2402.05841). 45 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is needed here to include configs when building project as a package 2 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint 3 | - early_stopping 4 | - model_summary 5 | - rich_progress_bar 6 | - _self_ 7 | 8 | model_checkpoint: 9 | dirpath: ${paths.output_dir}/checkpoints 10 | filename: "epoch_{epoch:03d}" 11 | monitor: "val/loss" 12 | mode: "min" 13 | save_last: True 14 | auto_insert_metric_name: False 15 | 16 | early_stopping: 17 | monitor: "val/loss" 18 | patience: 500 19 | mode: "max" 20 | 21 | model_summary: 22 | max_depth: -1 23 | -------------------------------------------------------------------------------- /configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html 2 | 3 | early_stopping: 4 | _target_: lightning.pytorch.callbacks.EarlyStopping 5 | monitor: ??? # quantity to be monitored, must be specified !!! 6 | min_delta: 0.
# minimum change in the monitored quantity to qualify as an improvement 7 | patience: 3 # number of checks with no improvement after which training will be stopped 8 | verbose: False # verbosity mode 9 | mode: "min" # "max" means higher metric value is better, can be also "min" 10 | strict: True # whether to crash the training if monitor is not found in the validation metrics 11 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 12 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 13 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 14 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch 15 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 16 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | model_checkpoint: 4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 5 | dirpath: null # directory to save the model file 6 | filename: null # checkpoint filename 7 | monitor: "val/loss" # name of the logged metric which determines when model is improving 8 | verbose: False # verbosity mode 9 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt 10 | save_top_k: 1 # save k best models (determined by above metric) 11 | mode: "min" # "max" means higher metric value is better, can be also "min" 12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 13 | save_weights_only: False # if True, then only the model’s weights will be saved 14 | every_n_train_steps: null # number of training steps between checkpoints 15 | train_time_interval: null # checkpoints are monitored at the specified time interval 16 | every_n_epochs: null # number of epochs between checkpoints 17 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 18 | -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olsdavis/fisher-flow/8102f853e3b4c0f29f3a700959ec86e426fa86e8/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /configs/callbacks/retrobridge.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint 3 | - early_stopping 4 | - model_summary 5 | - rich_progress_bar 6 | - _self_ 7 | 8 | model_checkpoint: 9 | dirpath: ${paths.output_dir}/checkpoints 10 | filename: "epoch_{epoch:03d}" 11 | monitor: "val_loss/batch_CE" 
12 | mode: "max" 13 | save_last: True 14 | auto_insert_metric_name: False 15 | 16 | early_stopping: 17 | monitor: "val_loss/batch_CE" 18 | patience: 500 19 | mode: "max" 20 | 21 | model_summary: 22 | max_depth: -1 23 | -------------------------------------------------------------------------------- /configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | -------------------------------------------------------------------------------- /configs/data/enhancer.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.DNAEnhancerDataModule 2 | batch_size: 64 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) 3 | num_workers: 32 4 | pin_memory: True 5 | -------------------------------------------------------------------------------- /configs/data/promoter_design.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.PromoterDesignDataModule 2 | data_dir: ${paths.data_dir}/promoter_design 3 | batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) 4 | num_workers: 4 5 | pin_memory: False 6 | sep_x_y: True 7 | -------------------------------------------------------------------------------- /configs/data/qm9.yaml: -------------------------------------------------------------------------------- 1 | # see trained_models.qm9_gaussian/config.yaml for an example in the original simplex flow repo: https://github.com/Dunni3/FlowMol/tree/main 2 | _target_: src.data.MoleculeDataModule 3 | dataset_config: xxx 4 | dm_prior_config: xxx 5 | batch_size: 384 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) 6 | num_workers: 32 7 | distributed: false 8 | max_num_edges: 4000 9 | -------------------------------------------------------------------------------- /configs/data/toy_dfm.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.ToyDFMDataModule 2 | data_dir: ${paths.data_dir} 3 | batch_size: 512 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) 4 | train_val_test_split: [100_000, 5_000, 10_000] 5 | num_workers: 4 6 | pin_memory: False 7 | k: 4 8 | dim: 100 9 | -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | callbacks: null 11 | logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | 
detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | data: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: "simple" 11 | # profiler: "advanced" 12 | # profiler: "pytorch" 13 | -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - data: mnist # choose datamodule with `test_dataloader()` for evaluation 6 | - model: mnist 7 | - logger: null 8 | - trainer: default 9 | - paths: default 10 | - extras: default 11 | - hydra: default 12 | 13 | task_name: "eval" 14 | 15 | tags: ["dev"] 16 | 17 | # passing checkpoint path is necessary for evaluation 18 | ckpt_path: ??? 
19 | -------------------------------------------------------------------------------- /configs/experiment/enhancer_fly_sfm_cnn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: enhancer 8 | - override /model: toy_cnn_sfm 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["enhancer", "cnn"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 800 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | optimizer: 26 | lr: 0.001 27 | net: 28 | k: 500 29 | dim: 4 30 | activation: lrelu 31 | depth: 4 32 | hidden: 128 33 | dropout: 0.0 34 | compile: false 35 | manifold: sphere 36 | ot_method: exact 37 | fbd_every: 20 38 | eval_fbd: true 39 | eval_ppl: false 40 | mel_or_dna: false # DNA 41 | fbd_classifier_path: "workdir/clsDNAclean_cnn_1stack_2023-12-30_15-01-30/epoch=15-step=10480.ckpt" 42 | inference_steps: 100 43 | 44 | data: 45 | batch_size: 512 46 | dataset: MEL2 47 | 48 | logger: 49 | wandb: 50 | tags: ${tags} 51 | name: enhancer_fly_s${seed} 52 | group: "enhancer" 53 | project: enhancer 54 | aim: 55 | experiment: "enhancer_mel" 56 | -------------------------------------------------------------------------------- /configs/experiment/enhancer_mel_sfm_cnn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: enhancer 8 | - override /model: toy_cnn_sfm 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["enhancer", "cnn"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 800 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | optimizer: 26 | lr: 0.001 27 | net: 28 | k: 500 29 | dim: 4 30 | activation: lrelu 31 | depth: 4 32 | hidden: 128 33 | dropout: 0.0 34 | compile: false 35 | manifold: sphere 36 | ot_method: None 37 | fbd_every: 10 38 | eval_ppl: false 39 | normalize_loglikelihood: true 40 | eval_fbd: true 41 | mel_or_dna: true # MEL 42 | fbd_classifier_path: "workdir/clsMELclean_cnn_dropout02_2023-12-31_12-26-28/epoch=9-step=5540.ckpt" 43 | inference_steps: 100 44 | 45 | data: 46 | batch_size: 512 47 | dataset: MEL2 48 | 49 | logger: 50 | wandb: 51 | tags: ${tags} 52 | name: enhancer_mel_s${seed}_lr${model.optimizer.lr} 53 | group: "enhancer" 54 | project: enhancer-mel 55 | aim: 56 | experiment: "enhancer_mel" 57 | -------------------------------------------------------------------------------- /configs/experiment/enhancer_sfm_bmlp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: enhancer 8 | - override /model: benhancer_mlp 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to 
overwrite only specified parameters 14 | 15 | tags: ["enhancer", "cnn"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 1000 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | optimizer: 26 | lr: 0.001 27 | net: 28 | k: 500 29 | dim: 4 30 | activation: lrelu 31 | compile: false 32 | manifold: sphere 33 | ema: true 34 | eval_fbd: true 35 | mel_or_dna: true # MEL 36 | ot_method: exact 37 | fbd_classifier_path: "workdir/clsMELclean_cnn_dropout02_2023-12-31_12-26-28/epoch=9-step=5540.ckpt" 38 | 39 | data: 40 | batch_size: 128 41 | dataset: MEL2 42 | 43 | logger: 44 | wandb: 45 | tags: ${tags} 46 | name: enhancer_mel_s${seed} 47 | group: "enhancer" 48 | project: enhancer 49 | aim: 50 | experiment: "enhancer" 51 | -------------------------------------------------------------------------------- /configs/experiment/promoter_dfm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: promoter_design 8 | - override /model: promoter_dfm 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["promoter_design", "promoter_model"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | max_steps: 450000 21 | limit_train_batches: null 22 | gradient_clip_val: 1.0 23 | accelerator: 'gpu' 24 | devices: [0] 25 | 26 | model: 27 | compile: false 28 | net: # TODO: not used remove 29 | model: 30 | mode: dirichlet 31 | embed_dim: 256 32 | time_dependent_weights: null 33 | time_step: 0.01 34 | 35 | logger: 36 | wandb: 37 | project: sfm 38 | tags: ${tags} 39 | group: promoter_design 40 | name: promoter_design_promdfm_B${data.batch_size}_s${seed} 41 | aim: 42 | experiment: "promoter_design" 43 | -------------------------------------------------------------------------------- /configs/experiment/promoter_rfm_tmlp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: promoter_design 8 | - override /model: rfm_tmlp_signal 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["promoter_design", "promoter_model"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 1000 22 | gradient_clip_val: 1.0 23 | num_sanity_val_steps: 0 24 | 25 | model: 26 | compile: false 27 | promoter_eval: true 28 | net: 29 | k: 1024 30 | dim: 4 31 | atol: 1e-7 32 | rtol: 1e-7 33 | 34 | logger: 35 | wandb: 36 | tags: ${tags} 37 | group: "promoter_design" 38 | name: promoter_design_promdfm_B${data.batch_size} 39 | aim: 40 | experiment: "promoter_design" 41 | -------------------------------------------------------------------------------- /configs/experiment/promoter_sfm_cnn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: promoter_design 8 | - override /model: bsignal_mlp 9 | - override 
/callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["promoter_design", "sfm_cnn"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 1000 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | optimizer: 26 | lr: 0.001 27 | net: 28 | k: 1024 29 | dim: 4 30 | manifold: sphere 31 | compile: false 32 | promoter_eval: true 33 | 34 | logger: 35 | wandb: 36 | tags: ${tags} 37 | group: "promoter_design" 38 | name: promoter_design_sfm_bmlp_B${data.batch_size} 39 | aim: 40 | experiment: "promoter_design" 41 | -------------------------------------------------------------------------------- /configs/experiment/promoter_sfm_promdfm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: promoter_design 8 | - override /model: promoter_model 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["promoter_design", "promoter_model"] 16 | 17 | seed: 12345 18 | 19 | callbacks: 20 | model_checkpoint: 21 | monitor: "val/sp-mse" 22 | 23 | trainer: 24 | min_epochs: 1 25 | max_epochs: 1000 26 | gradient_clip_val: 1.0 27 | 28 | model: 29 | compile: false 30 | promoter_eval: true 31 | net: 32 | mode: sfm # just ignores the custom stuff 33 | ot_method: None 34 | inference_steps: 100 35 | manifold: sphere 36 | eval_ppl: true 37 | 38 | logger: 39 | wandb: 40 | tags: ${tags} 41 | group: "promoter_design" 42 | name: promoter_design_promdfm_B${data.batch_size} 43 | aim: 44 | experiment: "promoter_design" 45 | -------------------------------------------------------------------------------- /configs/experiment/promoter_sfm_tmlp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: promoter_design 8 | - override /model: sfm_tmlp_signal 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["promoter_design", "sfm_cnn"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 1000 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | optimizer: 26 | lr: 0.001 27 | net: 28 | k: 1024 29 | dim: 4 30 | manifold: sphere 31 | compile: false 32 | promoter_eval: true 33 | 34 | logger: 35 | wandb: 36 | tags: ${tags} 37 | group: "promoter_design" 38 | name: promoter_design_sfm_tmlp 39 | aim: 40 | experiment: "promoter_design" 41 | -------------------------------------------------------------------------------- /configs/experiment/promoter_sfm_unet1d.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: promoter_design 8 | - override /model: promoter_sfm_unet1d 9 | - override /callbacks: default 10 | - 
override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["promoter_design", "sfm_unet1d"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 200 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | optimizer: 26 | lr: 0.001 27 | net: 28 | k: 1024 29 | dim: 4 30 | sig_emb: 128 31 | batch_norm: false 32 | time_emb_size: 32 33 | depth: 5 34 | filters: 128 35 | manifold: sphere 36 | # label_smoothing: 0.81 37 | compile: false 38 | promoter_eval: true 39 | inference_steps: 100 40 | ot_method: None 41 | 42 | logger: 43 | wandb: 44 | tags: ${tags} 45 | group: "promoter_design" 46 | name: promoter_design_sfm_unet1d_B${data.batch_size} 47 | aim: 48 | experiment: "promoter_design" 49 | -------------------------------------------------------------------------------- /configs/experiment/qm_clean_sfm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: qm9 8 | - override /model: molecule_module 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["qm9", "vecfield"] 16 | 17 | seed: 12345 18 | 19 | data: 20 | batch_size: 512 21 | 22 | trainer: 23 | min_epochs: 500 24 | max_epochs: 500 25 | gradient_clip_val: 1.0 26 | 27 | model: 28 | compile: false 29 | conditional: false 30 | inference_steps: 100 31 | features_manifolds: 32 | x: euclidean 33 | a: simplex 34 | c: simplex 35 | e: simplex 36 | features_priors: 37 | x: centered-gaussian 38 | a: uniform 39 | e: uniform 40 | c: uniform 41 | loss_weights: 42 | x: 3.0 43 | a: 0.4 44 | c: 1.0 45 | e: 2.0 46 | atom_type_map: ['C', 'H', 'N', 'O', 'F'] 47 | eval_mols_every: 5 48 | n_eval_mols: 128 49 | time_weighted_loss: true 50 | net: 51 | n_atom_types: 5 52 | features_manifolds: ${model.features_manifolds} 53 | interpolant_scheduler: linear 54 | n_recycles: 3 55 | 56 | logger: 57 | wandb: 58 | tags: ${tags} 59 | group: "qm9" 60 | name: s${seed}_deep_tw_${model.features_manifolds.a} 61 | project: sfm-qm9 62 | aim: 63 | experiment: "qm9" 64 | -------------------------------------------------------------------------------- /configs/experiment/qm_euclid.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: qm9 8 | - override /model: molecule_module 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["qm9", "vecfield"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 500 21 | max_epochs: 500 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | compile: false 26 | conditional: false 27 | inference_steps: 100 28 | features_manifolds: 29 | x: euclidean 30 | a: euclidean 31 | c: euclidean 32 | e: euclidean 33 | features_priors: 34 | x: centered-gaussian 35 | a: uniform 36 | e: uniform 37 | c: uniform 38 | loss_weights: 39 | x: 3.0 40 | a: 0.4 41 | c: 
1.0 42 | e: 2.0 43 | atom_type_map: ['C', 'H', 'N', 'O', 'F'] 44 | eval_mols_every: 1 45 | n_eval_mols: 128 46 | net: 47 | n_atom_types: 5 48 | features_manifolds: ${model.features_manifolds} 49 | interpolant_scheduler: linear 50 | 51 | logger: 52 | wandb: 53 | tags: ${tags} 54 | group: "qm9" 55 | name: qm9_s${seed}_linear_euclid 56 | project: sfm-qm9 57 | aim: 58 | experiment: "qm9" 59 | -------------------------------------------------------------------------------- /configs/experiment/qm_simplex.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: qm9 8 | - override /model: molecule_module 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["qm9", "vecfield"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 400 21 | max_epochs: 400 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | compile: false 26 | conditional: false 27 | inference_steps: 200 28 | features_manifolds: 29 | x: euclidean 30 | a: simplex 31 | c: simplex 32 | e: simplex 33 | features_priors: 34 | x: centered-gaussian 35 | a: uniform 36 | e: uniform 37 | c: uniform 38 | loss_weights: 39 | x: 1.0 40 | a: 1.0 41 | c: 1.0 42 | e: 1.0 43 | atom_type_map: ['C', 'H', 'N', 'O', 'F'] 44 | eval_mols_every: 10 45 | n_eval_mols: 128 46 | time_weighted_loss: true 47 | net: 48 | n_atom_types: 5 49 | features_manifolds: ${model.features_manifolds} 50 | interpolant_scheduler: linear 51 | 52 | logger: 53 | wandb: 54 | tags: ${tags} 55 | group: "qm9" 56 | name: qm9_s${seed}_linear_simplex 57 | project: sfm-qm9 58 | aim: 59 | experiment: "qm9" 60 | -------------------------------------------------------------------------------- /configs/experiment/qm_simplex_boost.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: qm9 8 | - override /model: molecule_module 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["qm9", "vecfield"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 500 21 | max_epochs: 500 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | compile: false 26 | conditional: false 27 | inference_steps: 100 28 | features_manifolds: 29 | x: euclidean 30 | a: simplex 31 | c: simplex 32 | e: simplex 33 | features_priors: 34 | x: centered-gaussian 35 | a: uniform 36 | e: uniform 37 | c: uniform 38 | loss_weights: 39 | x: 1.0 40 | a: 1.0 41 | c: 1.0 42 | e: 1.0 43 | atom_type_map: ['C', 'H', 'N', 'O', 'F'] 44 | eval_mols_every: 1 45 | n_eval_mols: 128 46 | inference_scaling: 10.0 47 | net: 48 | n_atom_types: 5 49 | features_manifolds: ${model.features_manifolds} 50 | interpolant_scheduler: linear 51 | 52 | logger: 53 | wandb: 54 | tags: ${tags} 55 | group: "qm9" 56 | name: qm9_s${seed}_linear_simplex_unweighted_boosted 57 | project: sfm-qm9 58 | aim: 59 | experiment: "qm9" 60 | -------------------------------------------------------------------------------- 
/configs/experiment/qm_simplex_pushing.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: qm9 8 | - override /model: molecule_module 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["qm9", "vecfield"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 500 21 | max_epochs: 500 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | compile: false 26 | conditional: false 27 | inference_steps: 100 28 | features_manifolds: 29 | x: euclidean 30 | a: simplex 31 | c: simplex 32 | e: simplex 33 | features_priors: 34 | x: centered-gaussian 35 | a: pushing-normal 36 | e: pushing-normal 37 | c: pushing-normal 38 | loss_weights: 39 | x: 3.0 40 | a: 0.4 41 | c: 1.0 42 | e: 2.0 43 | atom_type_map: ['C', 'H', 'N', 'O', 'F'] 44 | eval_mols_every: 1 45 | n_eval_mols: 128 46 | net: 47 | n_atom_types: 5 48 | features_manifolds: ${model.features_manifolds} 49 | interpolant_scheduler: linear 50 | 51 | logger: 52 | wandb: 53 | tags: ${tags} 54 | group: "qm9" 55 | name: qm9_s${seed}_linear_sphere_pushing 56 | project: sfm-qm9 57 | aim: 58 | experiment: "qm9" 59 | -------------------------------------------------------------------------------- /configs/experiment/qm_sphere.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: qm9 8 | - override /model: molecule_module 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["qm9", "vecfield"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 500 21 | max_epochs: 500 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | compile: false 26 | conditional: false 27 | inference_steps: 200 28 | features_manifolds: 29 | x: euclidean 30 | a: sphere 31 | c: sphere 32 | e: sphere 33 | features_priors: 34 | x: centered-gaussian 35 | a: uniform 36 | e: uniform 37 | c: uniform 38 | loss_weights: 39 | x: 1.0 40 | a: 1.0 41 | c: 1.0 42 | e: 1.0 43 | atom_type_map: ['C', 'H', 'N', 'O', 'F'] 44 | eval_mols_every: 10 45 | n_eval_mols: 128 46 | time_weighted_loss: true 47 | net: 48 | n_atom_types: 5 49 | features_manifolds: ${model.features_manifolds} 50 | interpolant_scheduler: linear 51 | 52 | logger: 53 | wandb: 54 | tags: ${tags} 55 | group: "qm9" 56 | name: qm9_s${seed}_${model.features_manifolds.a} 57 | project: sfm-qm9 58 | aim: 59 | experiment: "qm9" 60 | -------------------------------------------------------------------------------- /configs/experiment/qm_sphere_pushing.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: qm9 8 | - override /model: molecule_module 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this 
allows you to overwrite only specified parameters 14 | 15 | tags: ["qm9", "vecfield"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 500 21 | max_epochs: 500 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | compile: false 26 | conditional: false 27 | inference_steps: 100 28 | features_manifolds: 29 | x: euclidean 30 | a: sphere 31 | c: sphere 32 | e: sphere 33 | features_priors: 34 | x: centered-gaussian 35 | a: pushing-normal 36 | e: pushing-normal 37 | c: pushing-normal 38 | loss_weights: 39 | x: 3.0 40 | a: 0.4 41 | c: 1.0 42 | e: 2.0 43 | atom_type_map: ['C', 'H', 'N', 'O', 'F'] 44 | eval_mols_every: 1 45 | n_eval_mols: 128 46 | net: 47 | n_atom_types: 5 48 | features_manifolds: ${model.features_manifolds} 49 | interpolant_scheduler: linear 50 | 51 | logger: 52 | wandb: 53 | tags: ${tags} 54 | group: "qm9" 55 | name: qm9_s${seed}_linear_sphere_pushing 56 | project: sfm-qm9 57 | aim: 58 | experiment: "qm9" 59 | -------------------------------------------------------------------------------- /configs/experiment/qm_vecfield_sfm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: qm9 8 | - override /model: qm_vecfield_sfm 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["qm9", "vecfield"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 500 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | compile: false 26 | ot_method: None 27 | net: 28 | n_atom_types: 5 29 | inference_steps: 100 30 | manifold: sphere 31 | closed_form_drv: false 32 | tangent_wrapper: false 33 | eval_unconditional_mols: true 34 | eval_n_mols: 64 35 | eval_unconditional_mols_every: 5 36 | fast_matmul: true 37 | predict_mol: true 38 | 39 | logger: 40 | wandb: 41 | tags: ${tags} 42 | group: "qm9" 43 | name: qm9_s${seed} 44 | project: sfm-qm9 45 | aim: 46 | experiment: "qm9" 47 | -------------------------------------------------------------------------------- /configs/experiment/toy_dfm_bmlp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: toy_dfm 8 | - override /model: toy_bmlp 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["toy_dfm", "bmlp"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 1000 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | optimizer: 26 | lr: 0.001 27 | net: 28 | k: ${data.k} 29 | dim: ${data.dim} 30 | missing_coordinate: false 31 | compile: false 32 | manifold: "sphere" 33 | kl_eval: true 34 | ot_method: exact 35 | closed_form_drv: false 36 | 37 | data: 38 | batch_size: 512 39 | k: 4 40 | dim: 100 41 | 42 | logger: 43 | wandb: 44 | tags: ${tags} 45 | group: "toy_dfm" 46 | name: toy_${data.k}_${data.dim}_${model.manifold}_${model.ot_method} 47 | aim: 48 | experiment: "toy_dfm" 49 | -------------------------------------------------------------------------------- 
/configs/experiment/toy_dfm_cnn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: toy_dfm 8 | - override /model: toy_cnn_model 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["toy_dfm", "cnn"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 1000 22 | gradient_clip_val: 1.0 23 | num_sanity_val_steps: 0 24 | 25 | model: 26 | optimizer: 27 | lr: 0.0001 28 | net: 29 | k: ${data.k} 30 | dim: ${data.dim} 31 | mode: "dirichlet" 32 | compile: false 33 | mode: "dirichlet" 34 | 35 | data: 36 | batch_size: 512 37 | k: 4 38 | dim: 100 39 | 40 | logger: 41 | wandb: 42 | tags: ${tags} 43 | group: "toy_dfm" 44 | aim: 45 | experiment: "toy_dfm" 46 | -------------------------------------------------------------------------------- /configs/experiment/toy_dfm_sfm_cnn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: toy_dfm 8 | - override /model: toy_cnn_sfm 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["toy_dfm", "cnn"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 1000 22 | gradient_clip_val: 1.0 23 | accelerator: 'gpu' 24 | devices: [0] 25 | 26 | model: 27 | optimizer: 28 | lr: 0.001 29 | net: 30 | k: ${data.k} 31 | dim: ${data.dim} 32 | activation: lrelu 33 | compile: false 34 | manifold: sphere 35 | kl_eval: true 36 | ema: true 37 | # label_smoothing: 0.81 38 | 39 | data: 40 | batch_size: 512 41 | k: 4 42 | dim: 100 43 | 44 | logger: 45 | wandb: 46 | project: sfm 47 | tags: ${tags} 48 | group: "toy_dfm" 49 | name: "toy_cnn_sfm" 50 | aim: 51 | experiment: "toy_dfm" 52 | -------------------------------------------------------------------------------- /configs/experiment/toy_dfm_sfm_tmlp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: toy_dfm 8 | - override /model: sfm_tmlp 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["toy_dfm", "cnn"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 1000 22 | gradient_clip_val: 1.0 23 | num_sanity_val_steps: 0 24 | 25 | model: 26 | optimizer: 27 | lr: 0.001 28 | net: 29 | k: ${data.k} 30 | dim: ${data.dim} 31 | compile: false 32 | kl_eval: true 33 | # ema: true 34 | # label_smoothing: 0.81 35 | 36 | data: 37 | batch_size: 512 38 | k: 4 39 | dim: 100 40 | 41 | logger: 42 | wandb: 43 | tags: ${tags} 44 | group: "toy_dfm" 45 | name: rfm_toy_dfm_${data.k}_${data.dim} 46 | aim: 47 | experiment: "toy_dfm" 48 | 
-------------------------------------------------------------------------------- /configs/experiment/toy_dfm_sfm_unet1d.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: toy_dfm 8 | - override /model: toy_unet1d 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["toy_dfm", "cnn"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 1000 22 | gradient_clip_val: 1.0 23 | num_sanity_val_steps: 0 24 | 25 | model: 26 | optimizer: 27 | lr: 0.001 28 | net: 29 | k: ${data.k} 30 | dim: ${data.dim} 31 | compile: false 32 | manifold: sphere 33 | kl_eval: true 34 | 35 | data: 36 | batch_size: 512 37 | k: 4 38 | dim: 100 39 | 40 | logger: 41 | wandb: 42 | tags: ${tags} 43 | group: "toy_dfm" 44 | aim: 45 | experiment: "toy_dfm" 46 | -------------------------------------------------------------------------------- /configs/experiment/toy_dfm_temb.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: toy_dfm 8 | - override /model: toy_temb 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["toy_dfm", "cnn"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 1000 22 | gradient_clip_val: 1.0 23 | 24 | model: 25 | optimizer: 26 | lr: 0.001 27 | net: 28 | k: ${data.k} 29 | dim: ${data.dim} 30 | compile: false 31 | manifold: sphere 32 | kl_eval: true 33 | 34 | data: 35 | batch_size: 512 36 | k: 4 37 | dim: 100 38 | 39 | logger: 40 | wandb: 41 | tags: ${tags} 42 | group: "toy_dfm" 43 | aim: 44 | experiment: "toy_dfm" 45 | -------------------------------------------------------------------------------- /configs/experiment/toy_rfm_cnn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: toy_dfm 8 | - override /model: rfm_cnn_model 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["toy_dfm", "cnn"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 1000 22 | gradient_clip_val: 1.0 23 | num_sanity_val_steps: 0 24 | 25 | model: 26 | optimizer: 27 | lr: 0.001 28 | net: 29 | k: ${data.k} 30 | dim: ${data.dim} 31 | activation: lrelu 32 | compile: false 33 | kl_eval: true 34 | # ema: true 35 | # label_smoothing: 0.81 36 | 37 | data: 38 | batch_size: 512 39 | k: 4 40 | dim: 100 41 | 42 | logger: 43 | wandb: 44 | tags: ${tags} 45 | group: "toy_dfm" 46 | name: rfm_toy_dfm_${data.k}_${data.dim} 47 | aim: 48 | experiment: "toy_dfm" 49 | -------------------------------------------------------------------------------- 
/configs/experiment/toy_rfm_tmlp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: toy_dfm 8 | - override /model: rfm_tmlp 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["toy_dfm", "cnn"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 1000 22 | gradient_clip_val: 1.0 23 | num_sanity_val_steps: 0 24 | 25 | model: 26 | optimizer: 27 | lr: 0.001 28 | net: 29 | k: ${data.k} 30 | dim: ${data.dim} 31 | compile: false 32 | kl_eval: true 33 | atol: 1e-7 34 | rtol: 1e-7 35 | # ema: true 36 | # label_smoothing: 0.81 37 | 38 | data: 39 | batch_size: 512 40 | k: 4 41 | dim: 100 42 | 43 | logger: 44 | wandb: 45 | tags: ${tags} 46 | group: "toy_dfm" 47 | name: rfm_toy_dfm_${data.k}_${data.dim} 48 | aim: 49 | experiment: "toy_dfm" 50 | -------------------------------------------------------------------------------- /configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /configs/hparams_search/mnist_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 
11 | optimized_metric: "val/loss" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: maximize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | data.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | 15 | job_logging: 16 | handlers: 17 | file: 18 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 19 | filename: ${hydra.runtime.output_dir}/${task_name}.log 20 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olsdavis/fisher-flow/8102f853e3b4c0f29f3a700959ec86e426fa86e8/configs/local/.gitkeep -------------------------------------------------------------------------------- /configs/logger/aim.yaml: -------------------------------------------------------------------------------- 1 | # https://aimstack.io/ 2 | 3 | # example usage in lightning module: 4 | # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py 5 | 6 | # open the Aim UI with the following command (run in the folder containing the `.aim` folder): 7 | # `aim up` 8 | 9 | aim: 10 | _target_: aim.pytorch_lightning.AimLogger 11 | repo: ${paths.root_dir} # .aim folder will be created here 12 | # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# 13 | 14 | # aim allows to group runs under experiment name 15 | 
experiment: null # any string, set to "default" if not specified 16 | 17 | train_metric_prefix: "train/" 18 | val_metric_prefix: "val/" 19 | test_metric_prefix: "test/" 20 | 21 | # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) 22 | system_tracking_interval: 10 # set to null to disable system metrics tracking 23 | 24 | # enable/disable logging of system params such as installed packages, git info, env vars, etc. 25 | log_system_params: true 26 | 27 | # enable/disable tracking console logs (default value is true) 28 | capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 29 | -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: lightning.pytorch.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet 5 | - csv 6 | # - mlflow 7 | # - neptune 8 | - tensorboard 9 | - wandb 10 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: lightning.pytorch.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 
| -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | name: "${data.k}_${data.dim}_${model._target_}" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "simplex-flow" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /configs/misc/qm9_sfm.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | output_dir: runs_paper_ed/ 3 | batch_size: 64 4 | num_workers: 8 5 | max_num_edges: 4.0e+5 6 | trainer_args: 7 | max_epochs: 1000 8 | accelerator: gpu 9 | devices: 1 10 | num_nodes: 1 11 | strategy: auto 12 | accumulate_grad_batches: 1 13 | limit_val_batches: 0.1 14 | gradient_clip_val: 1.0 15 | gradient_clip_algorithm: 'value' 16 | 17 | evaluation: 18 | mols_to_sample: 128 # how many molecules to sample during evaluation 19 | sample_interval: 0.2 # how often to sample molecules during training, measured in epochs 20 | val_loss_interval: 0.2 # how often to compute validation set loss during training, measured in epochs 21 | 22 | wandb: 23 | project: mol-fm 24 | group: 25 | name: qm-dp-ep 26 | mode: online # can be disabled, online, offline 27 | 28 | lr_scheduler: 29 | # to turn off warmup and restarts, set both warmup_length and restart_interval to 0 30 | base_lr: 1.0e-4 31 | warmup_length: 1.0 32 | restart_interval: 0 # 0 means no restart 33 | restart_type: 'linear' 34 | weight_decay: 1.0e-12 35 | 36 | dataset: 37 | raw_data_dir: data/qm9_raw 38 | processed_data_dir: data/qm9 39 | atom_map: ['C', 'H', 'N', 'O', 'F',] 40 | dataset_name: qm9 # must be qm9 or geom 41 | dataset_size: 42 | 43 | checkpointing: 44 | save_last: True 45 | save_top_k: 3 46 | monitor: 'val_total_loss' 47 | every_n_epochs: 1 48 | 49 | mol_fm: 50 | parameterization: endpoint # can be "endpoint" or "vector-field" 51 | weight_ae: False 52 | target_blur: 0.0 53 | time_scaled_loss: True 54 | total_loss_weights: 55 | x: 3.0 56 | a: 0.4 57 | c: 1.0 58 | e: 2.0 59 | prior_config: 60 | x: 61 | align: True 62 | type: 'centered-normal' 63 | kwargs: {std: 1.0} 64 | a: 65 | align: False 66 | type: 'marginal' 67 | kwargs: {blur: 0.15} 68 | c: 69 | align: False 70 | type: 'c-given-a' 71 | kwargs: {blur: 0.15} 72 | e: 73 | align: False 74 | type: 'marginal' 75 | kwargs: {blur: 0.15} 76 | 77 | vector_field: 78 | n_vec_channels: 16 79 | update_edge_w_distance: True 80 | n_hidden_scalars: 256 81 | n_hidden_edge_feats: 128 82 | n_recycles: 1 83 | separate_mol_updaters: True 84 | n_molecule_updates: 8 85 | convs_per_update: 1 86 | n_cp_feats: 4 87 | n_message_gvps: 3 88 | n_update_gvps: 3 89 | message_norm: 100 90 | rbf_dmax: 14 91 | rbf_dim: 16 92 | 93 | interpolant_scheduler: 94 | schedule_type: 95 | x: 'cosine' 96 | a: 'cosine' 97 | c: 'cosine' 98 | e: 'cosine' 99 | cosine_params: 100 | x: 1 101 | a: 2 102 | c: 2 103 | e: 1.5 104 | -------------------------------------------------------------------------------- /configs/model/benhancer_mlp.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: src.models.SFMModule 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 11 | _partial_: true 12 | mode: min 13 | factor: 0.1 14 | patience: 10 15 | 16 | net: 17 | _target_: src.models.net.BestEnhancerMLP 18 | hidden: 512 19 | depth: 4 20 | activation: lrelu 21 | emb_size: 64 22 | 23 | 24 | compile: false 25 | -------------------------------------------------------------------------------- /configs/model/bsignal_mlp.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.SFMModule 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 11 | _partial_: true 12 | mode: min 13 | factor: 0.1 14 | patience: 10 15 | 16 | net: 17 | _target_: src.models.net.BestSignalMLP 18 | hidden: 512 19 | depth: 4 20 | activation: lrelu 21 | emb_size: 64 22 | 23 | 24 | compile: false 25 | -------------------------------------------------------------------------------- /configs/model/mel_cnn.yaml: -------------------------------------------------------------------------------- 1 | net: 2 | k: 500 3 | dim: 4 4 | activation: lrelu 5 | depth: 4 6 | hidden: 128 7 | dropout: 0.0 8 | -------------------------------------------------------------------------------- /configs/model/molecule_module.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.MoleculeModule 2 | 3 | optimizer: 4 | _target_: torch.optim.Adam 5 | _partial_: true 6 | lr: 0.0001 7 | weight_decay: 1.0e-12 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.CosineAnnealingLR 11 | _partial_: true 12 | # mode: min 13 | # factor: 0.1 14 | # patience: 10 15 | T_max: 1000 16 | eta_min: 0.00008 17 | 18 | net: 19 | _target_: src.models.net.EndpointVectorField 20 | n_vec_channels: 16 21 | update_edge_w_distance: True 22 | n_hidden_scalars: 256 23 | n_hidden_edge_feats: 128 24 | n_recycles: 1 25 | separate_mol_updaters: True 26 | n_molecule_updates: 8 27 | convs_per_update: 1 28 | n_cp_feats: 4 29 | n_message_gvps: 3 30 | n_update_gvps: 3 31 | message_norm: 100 32 | rbf_dmax: 14 33 | rbf_dim: 16 34 | 35 | compile: false 36 | -------------------------------------------------------------------------------- /configs/model/promoter_dfm.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.PromoterModule 2 | 3 | optimizer: 4 | _target_: torch.optim.Adam 5 | _partial_: true 6 | lr: 0.0005 7 | weight_decay: 0.0 8 | 9 | scheduler: null 10 | 11 | model: 12 | _target_: src.models.net.PromoterModel 13 | mode: dirichlet 14 | embed_dim: 256 15 | time_dependent_weights: null 16 | time_step: 0.01 17 | 18 | compile: false 19 | 20 | mode: dirichlet 21 | 22 | # validate 23 | validate: false 24 | 25 | # run settings 26 | distill_ckpt: null # cls model for evaluation purposes 27 | distill_ckpt_hparams: null 28 | 29 | # model 30 | fix_alpha: null 31 | alpha_scale: 2 32 | alpha_max: 8 33 | prior_pseudocount: 2 34 | flow_temp: 1.0 35 | num_integration_steps: 100 36 | 37 | # logging 38 | print_freq: 100 39 | 40 | # misc. 
41 | ckpt_iterations: null -------------------------------------------------------------------------------- /configs/model/promoter_model.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.SFMModule 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.CosineAnnealingLR 11 | _partial_: true 12 | # mode: min 13 | # factor: 0.1 14 | # patience: 10 15 | T_max: 1000 16 | eta_min: 0.00008 17 | 18 | net: 19 | _target_: src.models.net.PromoterModel 20 | mode: dirichlet 21 | embed_dim: 256 22 | time_step: 0.01 23 | 24 | compile: false 25 | -------------------------------------------------------------------------------- /configs/model/promoter_sfm_unet1d.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.SFMModule 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 11 | _partial_: true 12 | mode: min 13 | factor: 0.1 14 | patience: 10 15 | 16 | net: 17 | _target_: src.models.net.UNet1DSignal 18 | activation: swish 19 | depth: 3 20 | filters: 128 21 | 22 | 23 | compile: false 24 | -------------------------------------------------------------------------------- /configs/model/qm_vecfield_sfm.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.SFMModule 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.CosineAnnealingLR 11 | _partial_: true 12 | # mode: min 13 | # factor: 0.1 14 | # patience: 10 15 | T_max: 1000 16 | eta_min: 0.00008 17 | 18 | net: 19 | _target_: src.models.net.EndpointVectorField 20 | n_vec_channels: 16 21 | update_edge_w_distance: True 22 | n_hidden_scalars: 256 23 | n_hidden_edge_feats: 128 24 | n_recycles: 1 25 | separate_mol_updaters: True 26 | n_molecule_updates: 8 27 | convs_per_update: 1 28 | n_cp_feats: 4 29 | n_message_gvps: 3 30 | n_update_gvps: 3 31 | message_norm: 100 32 | rbf_dmax: 14 33 | rbf_dim: 16 34 | 35 | compile: false 36 | -------------------------------------------------------------------------------- /configs/model/sfm_tmlp.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.SFMModule 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 11 | _partial_: true 12 | mode: min 13 | factor: 0.1 14 | patience: 10 15 | # _target_: torch.optim.lr_scheduler.CosineAnnealingLR 16 | # T_max: 1000 17 | # eta_min: 0.00008 18 | 19 | net: 20 | _target_: src.models.net.TMLP 21 | activation: swish 22 | hidden: 512 23 | depth: 8 24 | fourier: null 25 | 26 | 27 | compile: false 28 | -------------------------------------------------------------------------------- /configs/model/sfm_tmlp_signal.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.SFMModule 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 11 | _partial_: true 12 | # mode: min 13 | # factor: 0.1 14 | # 
patience: 10 15 | _target_: torch.optim.lr_scheduler.CosineAnnealingLR 16 | T_max: 1000 17 | eta_min: 0.00008 18 | 19 | net: 20 | _target_: src.models.net.TMLPSignal 21 | activation: swish 22 | hidden: 512 23 | depth: 8 24 | fourier: null 25 | 26 | 27 | compile: false 28 | -------------------------------------------------------------------------------- /configs/model/toy_bmlp.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.SFMModule 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 11 | _partial_: true 12 | mode: min 13 | factor: 0.1 14 | patience: 10 15 | 16 | net: 17 | _target_: src.models.net.BestMLP 18 | hidden: 512 19 | depth: 4 20 | activation: lrelu 21 | emb_size: 64 22 | batch_norm: false 23 | 24 | 25 | compile: false 26 | -------------------------------------------------------------------------------- /configs/model/toy_cnn_model.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.DNAModule 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 11 | _partial_: true 12 | mode: min 13 | factor: 0.1 14 | patience: 10 15 | 16 | net: 17 | _target_: src.models.net.CNNModel 18 | hidden: 128 19 | mode: simplex 20 | num_cls: 3 21 | depth: 1 22 | dropout: 0.0 23 | prior_pseudocount: 2.0 24 | cls_expanded_simplex: False 25 | clean_data: False 26 | classifier: False 27 | classifier_free_guidance: False 28 | activation: relu 29 | 30 | 31 | compile: false 32 | -------------------------------------------------------------------------------- /configs/model/toy_cnn_sfm.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.SFMModule 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 11 | _partial_: true 12 | mode: min 13 | factor: 0.1 14 | patience: 10 15 | # _target_: torch.optim.lr_scheduler.CosineAnnealingLR 16 | # T_max: 1000 17 | # eta_min: 0.00008 18 | 19 | net: 20 | _target_: src.models.net.CNNModel 21 | hidden: 128 22 | mode: simplex 23 | num_cls: 3 24 | depth: 1 25 | dropout: 0.0 26 | prior_pseudocount: 2.0 27 | cls_expanded_simplex: False 28 | clean_data: False 29 | classifier: False 30 | classifier_free_guidance: False 31 | activation: relu 32 | 33 | 34 | compile: false 35 | -------------------------------------------------------------------------------- /configs/model/toy_temb.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.SFMModule 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 11 | _partial_: true 12 | mode: min 13 | factor: 0.1 14 | patience: 10 15 | 16 | net: 17 | _target_: src.models.net.TembMLP 18 | hidden: 512 19 | depth: 3 20 | emb_size: 128 21 | time_emb: sinusoidal 22 | input_emb: sinusoidal 23 | add_t_emb: false 24 | concat_t_emb: false 25 | activation: lrelu 26 | 27 | 28 | compile: false 29 | -------------------------------------------------------------------------------- /configs/model/toy_unet1d.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: src.models.SFMModule 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 11 | _partial_: true 12 | mode: min 13 | factor: 0.1 14 | patience: 10 15 | 16 | net: 17 | _target_: src.models.net.UNet1DModel 18 | activation: gelu 19 | 20 | 21 | compile: false 22 | -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: toy_dfm 8 | - model: toy_bmlp 9 | - callbacks: default 10 | - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default 12 | - paths: default 13 | - extras: default 14 | - hydra: default 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default 26 | 27 | # debugging config (enable through command line, e.g. 
`python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | # tags to help you identify your experiments 34 | # you can overwrite this in experiment configs 35 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 36 | tags: ["dev"] 37 | 38 | # set False to skip model training 39 | train: True 40 | 41 | # evaluate on test set, using best model weights achieved during training 42 | # lightning chooses best weights based on the metric specified in checkpoint callback 43 | test: True 44 | 45 | # simply provide checkpoint path to resume training 46 | ckpt_path: null 47 | 48 | # seed for random number generators in pytorch, numpy and python.random 49 | seed: null 50 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | strategy: ddp 5 | 6 | accelerator: gpu 7 | devices: 4 8 | num_nodes: 1 9 | sync_batchnorm: True 10 | -------------------------------------------------------------------------------- /configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 1 # prevents early stopping 6 | max_epochs: 10 7 | 8 | accelerator: cpu 9 | devices: 1 10 | 11 | # mixed precision for extra speed-up 12 | # precision: 16 13 | 14 | # perform a validation loop every N training epochs 15 | check_val_every_n_epoch: 1 16 | 17 | # set True to to ensure deterministic results 18 | # makes training slower but gives more reproducibility than just setting seeds 19 | deterministic: False 20 | 21 | num_sanity_val_steps: 0 22 | -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | # reasons you might want to use `environment.yaml` instead of `requirements.txt`: 2 | # - pip installs packages in a loop, without ensuring dependencies across all packages 3 | # are fulfilled simultaneously, but conda achieves proper dependency control across 4 | # all packages 5 | # - conda allows for installing packages without requiring certain compilers or 6 | # libraries to be available in the system, since it installs precompiled binaries 7 | 8 | name: myenv 9 | 10 | channels: 11 | - pytorch 12 | - conda-forge 13 | - nvidia 14 | - defaults 15 | 16 | # it is strongly recommended to specify versions of packages installed 
through conda 17 | # to avoid situation when version-unspecified packages install their latest major 18 | # versions which can sometimes break things 19 | 20 | # current approach below keeps the dependencies in the same major versions across all 21 | # users, but allows for different minor and patch versions of packages where backwards 22 | # compatibility is usually guaranteed 23 | 24 | dependencies: 25 | - python=3.10 26 | - pytorch=2.* 27 | - torchvision=0.* 28 | - lightning=2.* 29 | - torchmetrics=0.* 30 | - hydra-core=1.* 31 | - rich=13.* 32 | - pre-commit=3.* 33 | - pytest=7.* 34 | # ours 35 | - pytorch-cuda=12.1 36 | - matplotlib=3.7.1 37 | - seaborn=0.12.2 38 | - pip>=23 39 | - pip: 40 | - torch==2.1.0 41 | - POT==0.9.3 42 | - geoopt==0.5.0 43 | - torch-ema==0.3 44 | - dirichlet==0.9 45 | - einops==0.7.0 46 | - wandb==0.16.5 47 | - geomstats==2.7.0 48 | - ipdb==0.13.13 49 | - rdkit==2023.9.5 50 | - torch_geometric==2.5.3 51 | - selene-sdk==0.4.4 52 | - pyBigWig==0.3.22 53 | - pyranges==0.0.129 54 | - cooler==0.9.3 55 | - cooltools==0.7.0 56 | - torchdiffeq==0.2.3 57 | - transformers==4.40.2 58 | - schedulefree==1.2.5 59 | - dgl==https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html 60 | - pydantic==2.7.1 61 | - hydra-optuna-sweeper 62 | - hydra-colorlog 63 | - rootutils 64 | - torch_ema==0.3 65 | 66 | # --------- loggers --------- # 67 | # - wandb 68 | # - neptune-client 69 | # - mlflow 70 | # - comet-ml 71 | # - aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 72 | 73 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | addopts = [ 3 | "--color=yes", 4 | "--durations=0", 5 | "--strict-markers", 6 | "--doctest-modules", 7 | ] 8 | filterwarnings = [ 9 | "ignore::DeprecationWarning", 10 | "ignore::UserWarning", 11 | ] 12 | log_cli = "True" 13 | markers = [ 14 | "slow: slow tests", 15 | ] 16 | minversion = "6.0" 17 | testpaths = "tests/" 18 | 19 | [tool.coverage.report] 20 | exclude_lines = [ 21 | "pragma: nocover", 22 | "raise NotImplementedError", 23 | "raise NotImplementedError()", 24 | "if __name__ == .__main__.:", 25 | ] 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch==2.1.0 3 | torchvision>=0.15.0 4 | lightning>=2.0.0 5 | torchmetrics>=0.11.4 6 | 7 | # --------- hydra --------- # 8 | hydra-core==1.3.2 9 | hydra-colorlog==1.2.0 10 | hydra-optuna-sweeper==1.2.0 11 | 12 | # --------- loggers --------- # 13 | wandb==0.16.5 14 | # neptune-client 15 | # mlflow 16 | # comet-ml 17 | # aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 18 | 19 | # --------- others --------- # 20 | rootutils # standardizing the project root setup 21 | pre-commit # hooks for applying linters on commit 22 | rich # beautiful text formatting in terminal 23 | pytest # tests 24 | # sh # for running bash commands in some tests (linux/macos only) 25 | 26 | 27 | # Our dependencies: 28 | POT==0.9.3 29 | dirichlet==0.9 30 | einops==0.7.0 31 | selene-sdk==0.4.4 32 | biopython==1.83 33 | pyBigWig==0.3.22 34 | pyranges==0.0.129 35 | cooler==0.9.3 36 | cooltools==0.7.0 37 | rdkit==2023.9.5 38 | torch_geometric==2.5.3 39 | torchdiffeq==0.2.3 40 | schedulefree==1.2.5 41 | pydantic==2.7.1 42 | torch_ema==0.3.0 43 | 
-------------------------------------------------------------------------------- /script/submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=23:00:00 7 | #SBATCH --mem=36000 8 | #SBATCH --qos=long 9 | #SBATCH --gres=gpu:1 10 | conda activate sfm 11 | # srun -u python main.py -e dfm_toy -c config/toy_dfm/bmlp.yml -m sphere --wandb 12 | srun -u python -m src.train experiment=toy_dfm_bmlp trainer=gpu data.dim=$1 data.batch_size=1024 logger=wandb 13 | -------------------------------------------------------------------------------- /script/submit_bmlp_sfm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=23:00:00 7 | #SBATCH --qos=medium 8 | #SBATCH --gres=gpu:1 9 | echo "Here: dim = $1; seed = $2" 10 | conda activate sfm 11 | srun -u python -m src.train experiment=toy_dfm_bmlp data.dim=$1 seed=$2 trainer=gpu trainer.max_epochs=500 logger=wandb 12 | -------------------------------------------------------------------------------- /script/submit_bmlp_sfm_lin.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=3:00:00 7 | #SBATCH --qos=short 8 | #SBATCH --gres=gpu:1 9 | echo "Here: dim = $1; seed = $2" 10 | conda activate sfm 11 | srun -u python -m src.train experiment=toy_dfm_lin data.dim=$1 seed=$2 model.closed_form_drv=true logger.wandb.project=sfm-synth-lin trainer=gpu trainer.max_epochs=500 logger=wandb 12 | -------------------------------------------------------------------------------- /script/submit_bmlp_sfm_noot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=3:00:00 7 | #SBATCH --qos=short 8 | #SBATCH --gres=gpu:1 9 | echo "Here: dim = $1; seed = $2" 10 | conda activate sfm 11 | srun -u python -m src.train experiment=toy_dfm_bmlp data.dim=$1 seed=$2 model.ot_method=None logger.wandb.project=sfm-synth-ablation trainer=gpu trainer.max_epochs=500 logger=wandb 12 | -------------------------------------------------------------------------------- /script/submit_bmlp_sfm_simplex_noot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=3:00:00 7 | #SBATCH --qos=short 8 | #SBATCH --gres=gpu:1 9 | echo "Here: dim = $1; seed = $2" 10 | conda activate sfm 11 | srun -u python -m src.train experiment=toy_dfm_bmlp data.dim=$1 seed=$2 model.closed_form_drv=true model.manifold=simplex model.ot_method=None logger.wandb.project=sfm-synth-ablation trainer=gpu trainer.max_epochs=500 logger=wandb 12 | -------------------------------------------------------------------------------- /script/submit_bmlp_sfm_simplex_ot.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=3:00:00 7 | #SBATCH --qos=short 8 | #SBATCH --gres=gpu:1 9 | echo "Here: dim = $1; seed = $2" 10 | conda activate sfm 11 | srun -u python -m src.train experiment=toy_dfm_bmlp data.dim=$1 seed=$2 model.closed_form_drv=true model.manifold=simplex model.ot_method=exact logger.wandb.project=sfm-synth-ablation trainer=gpu trainer.max_epochs=500 logger=wandb 12 | -------------------------------------------------------------------------------- /script/submit_cnn_dfm_toy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=23:00:00 7 | #SBATCH --qos=medium 8 | #SBATCH --gres=gpu:1 9 | echo "Here: dim = $1" 10 | conda activate sfm 11 | srun -u python -m src.train experiment=toy_dfm_cnn data.dim=$1 trainer=gpu trainer.max_epochs=1000 logger=wandb 12 | -------------------------------------------------------------------------------- /script/submit_cnn_sfm_toy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=23:00:00 7 | #SBATCH --qos=medium 8 | #SBATCH --gres=gpu:1 9 | echo "Here: dim = $1" 10 | conda activate sfm 11 | srun -u python -m src.train experiment=toy_dfm_sfm_cnn data.dim=$1 trainer=gpu trainer.max_epochs=500 logger=wandb data.batch_size=512 12 | -------------------------------------------------------------------------------- /script/submit_cnn_sfm_toy_lsmooth.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=10:00:00 7 | #SBATCH --qos=medium 8 | #SBATCH --gres=gpu:1 9 | echo "Here: dim = $1" 10 | conda activate sfm 11 | srun -u python -m src.train experiment=toy_dfm_sfm_cnn data.dim=$1 model.label_smoothing=$2 trainer=gpu trainer.max_epochs=250 logger=wandb data.batch_size=1024 12 | -------------------------------------------------------------------------------- /script/submit_enhancer_bmlp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=23:00:00 7 | #SBATCH --qos=medium 8 | #SBATCH --gres=gpu:1 9 | conda activate sfm 10 | srun -u python -m src.train experiment=enhancer_sfm_bmlp trainer.max_epochs=800 trainer=gpu seed=$1 logger=wandb -------------------------------------------------------------------------------- /script/submit_enhancer_cnn_sfm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=23:00:00 7 | #SBATCH --qos=medium 8 | #SBATCH --gres=gpu:1 9 | conda activate sfm 10 | srun -u python -m 
src.train experiment=enhancer_mel_sfm_cnn trainer.max_epochs=800 trainer=gpu seed=$1 logger=wandb -------------------------------------------------------------------------------- /script/submit_enhancer_cnn_sfm_dna.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=23:00:00 7 | #SBATCH --qos=medium 8 | #SBATCH --gres=gpu:1 9 | conda activate sfm 10 | srun -u python -m src.train experiment=enhancer_fly_sfm_cnn trainer.max_epochs=800 trainer=gpu seed=$1 logger=wandb -------------------------------------------------------------------------------- /script/submit_eval_fly.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm_eval 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=00:45:00 7 | #SBATCH --qos=short 8 | #SBATCH --gres=gpu:1 9 | conda activate sfm 10 | srun -u python -m src.train experiment=enhancer_fly_sfm_cnn model.eval_fbd=true model.eval_ppl=true trainer=gpu logger=wandb data.batch_size=512 trainer.max_epochs=$2 ckpt_path=$1 -------------------------------------------------------------------------------- /script/submit_eval_mel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm_eval 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=00:45:00 7 | #SBATCH --qos=short 8 | #SBATCH --gres=gpu:1 9 | conda activate sfm 10 | srun -u python -m src.train experiment=enhancer_mel_sfm_cnn model.eval_fbd=true model.eval_ppl=true trainer=gpu logger=wandb data.batch_size=512 trainer.max_epochs=$2 ckpt_path=$1 -------------------------------------------------------------------------------- /script/submit_eval_promoter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=01:00:00 7 | #SBATCH --qos=short 8 | #SBATCH --gres=gpu:1 9 | conda activate sfm 10 | srun -u python -m src.train experiment=promoter_sfm_promdfm ckpt_path=$1 trainer=gpu trainer.max_epochs=$2 logger=wandb -------------------------------------------------------------------------------- /script/submit_promoter_sfm_bmlp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=23:00:00 7 | #SBATCH --qos=medium 8 | #SBATCH --gres=gpu:1 9 | conda activate sfm 10 | srun -u python -m src.train experiment=promoter_sfm_cnn trainer=gpu logger=wandb data.batch_size=512 trainer.max_epochs=200 -------------------------------------------------------------------------------- /script/submit_promoter_sfm_promdfm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=23:00:00 7 | #SBATCH --mem=36000 8 | #SBATCH --qos=medium 9 
| #SBATCH --gres=gpu:4 10 | conda activate sfm 11 | # srun -u python main.py -e dfm_toy -c config/toy_dfm/bmlp.yml -m sphere --wandb 12 | python -m src.train experiment=promoter_sfm_promdfm seed=$1 trainer.max_epochs=200 trainer=ddp trainer.devices=4 logger=wandb data.batch_size=512 13 | -------------------------------------------------------------------------------- /script/submit_qm_clean.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=24:00:00 7 | #SBATCH --mem=36000 8 | #SBATCH --qos=medium 9 | #SBATCH --gres=gpu:1 10 | conda activate sfm 11 | srun -u python -m src.train experiment=$2 trainer=gpu logger=wandb seed=$1 12 | -------------------------------------------------------------------------------- /script/submit_qm_sfm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=01:00:00 7 | #SBATCH --mem=36000 8 | #SBATCH --qos=medium 9 | #SBATCH --gres=gpu:2 10 | conda activate sfm 11 | python -m src.train experiment=qm_vecfield_sfm seed=$1 data.batch_size=384 trainer.max_epochs=500 trainer=ddp trainer.devices=2 logger=wandb 12 | -------------------------------------------------------------------------------- /script/submit_qm_sfm_single.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=01:00:00 7 | #SBATCH --mem=36000 8 | #SBATCH --qos=medium 9 | #SBATCH --gres=gpu:1 10 | conda activate sfm 11 | python -m src.train experiment=qm_vecfield_sfm seed=$1 data.batch_size=256 trainer.max_epochs=500 trainer=gpu logger=wandb 12 | -------------------------------------------------------------------------------- /script/submit_rfm_cnn_toy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=23:00:00 7 | #SBATCH --qos=medium 8 | #SBATCH --gres=gpu:1 9 | echo "Here: dim = $1" 10 | conda activate sfm 11 | srun -u python -m src.train experiment=toy_rfm_cnn data.dim=$1 trainer=gpu logger=wandb -------------------------------------------------------------------------------- /script/submit_tmlp_sfm_promoter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=23:00:00 7 | #SBATCH --qos=medium 8 | #SBATCH --gres=gpu:1 9 | conda activate sfm 10 | python -m src.train experiment=promoter_sfm_tmlp trainer.max_epochs=200 trainer=gpu logger=wandb data.batch_size=64 -------------------------------------------------------------------------------- /script/submit_unet1d_toy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH 
--error=slurm/slurm_%j.err 6 | #SBATCH --time=23:00:00 7 | #SBATCH --qos=medium 8 | #SBATCH --gres=gpu:1 9 | echo "Here: dim = $1" 10 | conda activate sfm 11 | srun -u python -m src.train experiment=toy_dfm_sfm_unet1d data.dim=$1 trainer=gpu trainer.max_epochs=1000 logger=wandb data.batch_size=1024 12 | -------------------------------------------------------------------------------- /script/submit_unet_sfm_promoter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 # nodes requested 3 | #SBATCH --job-name=sfm 4 | #SBATCH --output=slurm/slurm_%j.out 5 | #SBATCH --error=slurm/slurm_%j.err 6 | #SBATCH --time=23:00:00 7 | #SBATCH --qos=medium 8 | #SBATCH --gres=gpu:4 9 | conda activate sfm 10 | python -m src.train experiment=promoter_sfm_unet1d model.net.depth=$1 model.net.filters=$2 trainer=ddp trainer.devices=4 data.batch_size=2048 logger=wandb -------------------------------------------------------------------------------- /script/toy_all_bmlp_sfm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dims=(5 10 20 40 60 80 100 120 140 160) 3 | for dim in ${dims[@]}; do 4 | sbatch ./script/submit_bmlp_sfm.sh $dim $1 5 | done 6 | -------------------------------------------------------------------------------- /script/toy_all_bmlp_sfm_noot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dims=(5 10 20 40 60 80 100 120 140 160) 3 | for dim in ${dims[@]}; do 4 | sbatch ./script/submit_bmlp_sfm_noot.sh $dim $1 5 | done 6 | -------------------------------------------------------------------------------- /script/toy_all_bmlp_sfm_simplex_noot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dims=(5 10 20 40 60 80 100 120 140 160) 3 | for dim in ${dims[@]}; do 4 | sbatch ./script/submit_bmlp_sfm_simplex_noot.sh $dim $1 5 | done 6 | -------------------------------------------------------------------------------- /script/toy_all_bmlp_sfm_simplex_ot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dims=(5 10 20 40 60 80 100 120 140 160) 3 | for dim in ${dims[@]}; do 4 | sbatch ./script/submit_bmlp_sfm_simplex_ot.sh $dim $1 5 | done 6 | -------------------------------------------------------------------------------- /script/toy_all_cnn_dfm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dims=(5 10 20 40 60 80 100 120 140 160) 3 | for dim in ${dims[@]}; do 4 | sbatch ./script/submit_cnn_dfm_toy.sh $dim 5 | done 6 | -------------------------------------------------------------------------------- /script/toy_all_cnn_sfm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # dims=(5 10 20 40 60 80 100 120 140 160) 3 | dims=(80 100 120 140 160) 4 | for dim in ${dims[@]}; do 5 | sbatch ./script/submit_cnn_sfm_toy.sh $dim 6 | done 7 | -------------------------------------------------------------------------------- /script/toy_bmlp_lin.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dims=(5 10 20 40 60 80 100 120 140 160) 3 | for dim in ${dims[@]}; do 4 | sbatch ./script/submit_bmlp_sfm_lin.sh $dim $1 5 | done 6 | -------------------------------------------------------------------------------- /script/toy_cnn_sfm_lsmooth.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dims=(0 0.7 0.8 0.9 0.95 0.999 0.9999) 3 | for dim in ${dims[@]}; do 4 | sbatch ./script/submit_cnn_sfm_toy_lsmooth.sh 100 $dim 5 | done 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | setup( 6 | name="fisher-flow", 7 | version="1.0.0", 8 | description="Fisher Flow Matching official implementation", 9 | author="Oscar Davis, Samuel Kessler, Joey Bose", 10 | author_email="oscar.davis@cs.ox.ac.uk", 11 | install_requires=["lightning", "hydra-core"], 12 | packages=find_packages(), 13 | # use this to customize global commands available in the terminal after installing the package 14 | entry_points={ 15 | "console_scripts": [ 16 | "train_command = src.train:main", 17 | "eval_command = src.eval:main", 18 | ] 19 | }, 20 | ) 21 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olsdavis/fisher-flow/8102f853e3b4c0f29f3a700959ec86e426fa86e8/src/__init__.py -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .toy_dfm_datamodule import * 2 | from .dna_enhancer_datamodule import * 3 | from .promoter_datamodule import * 4 | from .qm9_datamodule import * 5 | -------------------------------------------------------------------------------- /src/data/components/__init__.py: -------------------------------------------------------------------------------- 1 | from .promoter_back import * 2 | from .promoter_eval import * 3 | from .sei import * 4 | from .molecule import * 5 | from .molecule_prior import * 6 | from .qm_utils import * 7 | from .molecule_builder import * 8 | -------------------------------------------------------------------------------- /src/data/components/fbd.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | import torch 5 | from torch import Tensor 6 | import numpy as np 7 | 8 | 9 | from src.sfm import get_wasserstein_dist 10 | from src.models.net import CNNModel 11 | 12 | 13 | def _upgrade_state_dict(state_dict, prefixes=["encoder.sentence_encoder.", "encoder."]): 14 | """Removes prefixes 'model.encoder.sentence_encoder.' and 'model.encoder.'.""" 15 | pattern = re.compile("^" + "|".join(prefixes)) 16 | state_dict = {pattern.sub("", name): param for name, param in state_dict.items()} 17 | return state_dict 18 | 19 | 20 | class FBD(torch.nn.Module): 21 | """ 22 | Class to evaluate the Fréchet Biological Distance (FBD). 
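    A minimal usage sketch (the constructor values and the checkpoint path below
    are hypothetical placeholders, and the inputs are assumed to be one-hot
    encoded sequence batches of matching shape):

        fbd = FBD(dim=4, k=500, hidden=128, num_cls=47, depth=1,
                  ckpt_path="path/to/enhancer_classifier.ckpt")   # hypothetical checkpoint
        # compares classifier embeddings of generated vs. ground-truth sequences;
        # batch_index=None disables the ground-truth embedding cache
        score = fbd(x=generated_one_hot, gt=ground_truth_one_hot, batch_index=None)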
23 | """ 24 | 25 | def __init__( 26 | self, 27 | dim: int, 28 | k: int, 29 | hidden: int, 30 | num_cls: int, 31 | depth: int, 32 | ckpt_path: str, 33 | ): 34 | super().__init__() 35 | self.cls_model = CNNModel( 36 | dim=dim, 37 | k=k, 38 | num_cls=num_cls, 39 | hidden=hidden, 40 | mode="", # ignored 41 | depth=depth, 42 | dropout=0.0, 43 | prior_pseudocount=2.0, # unused 44 | cls_expanded_simplex=False, 45 | clean_data=True, 46 | classifier=True, 47 | classifier_free_guidance=False, 48 | ) 49 | self.cls_model.load_state_dict( 50 | _upgrade_state_dict( 51 | torch.load(ckpt_path, map_location="cpu")['state_dict'], prefixes=['model.'], 52 | ) 53 | ) 54 | self.cache = {} 55 | 56 | @torch.inference_mode() 57 | def forward(self, x: Tensor, gt: Tensor, batch_index: int | None) -> np.ndarray | float: 58 | self.cls_model.eval() 59 | if batch_index is not None and batch_index in self.cache: 60 | gt_embeddings = self.cache[batch_index] 61 | else: 62 | gt_embeddings = self.cls_model(gt, t=None, return_embedding=True)[1].cpu().numpy() 63 | if batch_index is not None: 64 | self.cache[batch_index] = gt_embeddings 65 | embeddings = self.cls_model(x, t=None, return_embedding=True)[1].cpu().numpy() 66 | return get_wasserstein_dist(embeddings, gt_embeddings) 67 | -------------------------------------------------------------------------------- /src/data/components/ff_energy.py: -------------------------------------------------------------------------------- 1 | from rdkit.Chem import AllChem as Chem 2 | from typing import List 3 | 4 | def compute_uff_energy(mol): 5 | ff = Chem.UFFGetMoleculeForceField(mol, ignoreInterfragInteractions=False) 6 | return ff.CalcEnergy() 7 | 8 | def compute_mmff_energy(mol): 9 | try: 10 | ff = Chem.MMFFGetMoleculeForceField(mol, Chem.MMFFGetMoleculeProperties(mol), ignoreInterfragInteractions=False) 11 | except Exception as e: 12 | print(e) 13 | print('Failed to get force-field object') 14 | return None 15 | if ff: 16 | return ff.CalcEnergy() 17 | return None 18 | -------------------------------------------------------------------------------- /src/data/components/molecule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | import dgl 4 | from torch.nn.functional import one_hot 5 | from .molecule_prior import align_prior 6 | 7 | # create a function named collate that takes a list of samples from the dataset and combines them into a batch 8 | # this might not be necessary. 
I think we can pass the argument collate_fn=dgl.batch to the DataLoader 9 | def collate(graphs): 10 | return dgl.batch(graphs) 11 | 12 | class MoleculeDataset(torch.utils.data.Dataset): 13 | 14 | def __init__(self, split: str, dataset_config: dict, prior_config: dict): 15 | super(MoleculeDataset, self).__init__() 16 | 17 | # unpack some configs regarding the prior 18 | self.prior_config = prior_config 19 | self.dataset_config = dataset_config 20 | 21 | # get the processed data directory 22 | processed_data_dir: Path = Path(dataset_config['processed_data_dir']) 23 | 24 | # load the marginal distributions of atom types and the conditional distribution of charges given atom type 25 | marginal_dists_file = processed_data_dir / 'train_data_marginal_dists.pt' 26 | p_a, p_c, p_e, p_c_given_a = torch.load(marginal_dists_file) 27 | 28 | # add the marginal distributions as arguments to the prior sampling functions 29 | if self.prior_config['a']['type'] == 'marginal': 30 | self.prior_config['a']['kwargs']['p'] = p_a 31 | 32 | if self.prior_config['e']['type'] == 'marginal': 33 | self.prior_config['e']['kwargs']['p'] = p_e 34 | 35 | if self.prior_config['c']['type'] == 'marginal': 36 | self.prior_config['c']['kwargs']['p'] = p_c 37 | 38 | if self.prior_config['c']['type'] == 'c-given-a': 39 | self.prior_config['c']['kwargs']['p_c_given_a'] = p_c_given_a 40 | 41 | if dataset_config['dataset_name'] in ['geom', 'qm9', 'geom_5conf']: 42 | data_file = processed_data_dir / f'{split}_data_processed.pt' 43 | else: 44 | raise NotImplementedError('unsupported dataset_name') 45 | 46 | # load data from processed data directory 47 | data_dict = torch.load(data_file) 48 | 49 | self.positions = data_dict['positions'] 50 | self.atom_types = data_dict['atom_types'] 51 | self.atom_charges = data_dict['atom_charges'] 52 | self.bond_types = data_dict['bond_types'] 53 | self.bond_idxs = data_dict['bond_idxs'] 54 | self.node_idx_array = data_dict['node_idx_array'] 55 | self.edge_idx_array = data_dict['edge_idx_array'] 56 | 57 | def __len__(self): 58 | return self.node_idx_array.shape[0] 59 | 60 | def __getitem__(self, idx): 61 | node_start_idx = self.node_idx_array[idx, 0] 62 | node_end_idx = self.node_idx_array[idx, 1] 63 | edge_start_idx = self.edge_idx_array[idx, 0] 64 | edge_end_idx = self.edge_idx_array[idx, 1] 65 | 66 | # get data pertaining to nodes for this molecule 67 | positions = self.positions[node_start_idx:node_end_idx] 68 | atom_types = self.atom_types[node_start_idx:node_end_idx].float() 69 | atom_charges = self.atom_charges[node_start_idx:node_end_idx].long() 70 | 71 | # remove COM from positions 72 | positions = positions - positions.mean(dim=0, keepdim=True) 73 | 74 | # get data pertaining to edges for this molecule 75 | bond_types = self.bond_types[edge_start_idx:edge_end_idx].int() 76 | bond_idxs = self.bond_idxs[edge_start_idx:edge_end_idx].long() 77 | 78 | # reconstruct adjacency matrix 79 | n_atoms = positions.shape[0] 80 | adj = torch.zeros((n_atoms, n_atoms), dtype=torch.int32) 81 | 82 | # fill in the values of the adjacency matrix specified by bond_idxs 83 | adj[bond_idxs[:,0], bond_idxs[:,1]] = bond_types 84 | 85 | # get upper triangle of adjacency matrix 86 | upper_edge_idxs = torch.triu_indices(n_atoms, n_atoms, offset=1) # has shape (2, n_upper_edges) 87 | upper_edge_labels = adj[upper_edge_idxs[0], upper_edge_idxs[1]] 88 | 89 | # get lower triangle edges by swapping source and destination of upper_edge_idxs 90 | lower_edge_idxs = torch.stack((upper_edge_idxs[1], upper_edge_idxs[0])) 91 | 92 
| edges = torch.cat((upper_edge_idxs, lower_edge_idxs), dim=1) 93 | edge_labels = torch.cat((upper_edge_labels, upper_edge_labels)) 94 | 95 | # one-hot encode edge labels and atom charges 96 | edge_labels = one_hot(edge_labels.to(torch.int64), num_classes=5).float() # hard-coded assumption of 5 bond types 97 | try: 98 | atom_charges = one_hot(atom_charges + 2, num_classes=6).float() # hard-coded assumption that charges are in range [-2, 3] 99 | except Exception as e: 100 | print('an atom charge outside of the expected range was encountered') 101 | print(f'max atom charge: {atom_charges.max()}, min atom charge: {atom_charges.min()}') 102 | raise e 103 | 104 | # create a dgl graph 105 | g = dgl.graph((edges[0], edges[1]), num_nodes=n_atoms) 106 | 107 | # add edge features 108 | g.edata['e_1_true'] = edge_labels 109 | 110 | # add node features 111 | g.ndata['x_1_true'] = positions 112 | g.ndata['a_1_true'] = atom_types 113 | g.ndata['c_1_true'] = atom_charges 114 | 115 | # sample prior for node features, coupled to the destination features 116 | """dst_dict = { 117 | 'x': positions, 118 | 'a': atom_types, 119 | 'c': atom_charges 120 | }""" 121 | # prior_node_feats = coupled_node_prior(dst_dict=dst_dict, prior_config=self.prior_config) 122 | # for feat in prior_node_feats: 123 | #  g.ndata[f'{feat}_0'] = prior_node_feats[feat] 124 | g.ndata["x_0"] = align_prior( 125 | torch.randn_like(positions), 126 | positions, 127 | permutation=True, 128 | rigid_body=True, 129 | ) 130 | 131 | return g 132 | -------------------------------------------------------------------------------- /src/data/components/promoter_eval.py: -------------------------------------------------------------------------------- 1 | # Adaptation from: https://github.com/HannesStark/dirichlet-flow-matching/blob/main/lightning_modules/promoter_module.py 2 | import re 3 | import pandas as pd 4 | import torch 5 | from torch import Tensor 6 | from selene_sdk.utils import NonStrandSpecific 7 | from .sei import Sei 8 | 9 | 10 | def upgrade_state_dict(state_dict, prefixes=["encoder.sentence_encoder.", "encoder."]): 11 | """Removes prefixes 'model.encoder.sentence_encoder.' and 'model.encoder.'.""" 12 | pattern = re.compile("^" + "|".join(prefixes)) 13 | state_dict = {pattern.sub("", name): param for name, param in state_dict.items()} 14 | return state_dict 15 | 16 | 17 | class SeiEval: 18 | """Singleton class for SEI evaluation.""" 19 | _instance = None 20 | 21 | def __new__(cls): 22 | if cls._instance is None: 23 | cls._instance = super().__new__(cls) 24 | cls._instance._sei = NonStrandSpecific(Sei(4096, 21907)) 25 | cls._instance._sei_features = ( 26 | pd.read_csv('data/promoter_design/target.sei.names', sep='|', header=None) 27 | ) 28 | cls._instance._sei_loaded = False 29 | cls._instance._sei_cache = {} 30 | return cls._instance 31 | 32 | @torch.no_grad() 33 | def get_sei_profile(self, seq_one_hot: Tensor) -> Tensor: 34 | """ 35 | Get the SEI profile from the one-hot encoded sequence. 36 | 37 | Parameters: 38 | - `seq_one_hot`: The one-hot encoded sequence tensor. 39 | 40 | Returns: 41 | The SEI profile tensor. 
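        Example (a hedged sketch; the batch size and sequence length are
        illustrative assumptions, chosen so that the 1,536-bp flanks added
        below pad the input to SEI's 4,096-bp window):

            seq = torch.nn.functional.one_hot(
                torch.randint(0, 4, (8, 1024)), num_classes=4
            ).float()                                  # (batch, length, 4) one-hot DNA
            profile = SeiEval().get_sei_profile(seq)   # one mean H3K4me3 score per sequence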
42 | """ 43 | if not self._sei_loaded: 44 | self._sei.load_state_dict( 45 | upgrade_state_dict( 46 | torch.load('data/promoter_design/best.sei.model.pth.tar', map_location="cpu")['state_dict'], 47 | prefixes=["module."], 48 | ) 49 | ) 50 | self._sei.to(seq_one_hot.device) 51 | self._sei.eval() 52 | self._sei_loaded = True 53 | B, _, _ = seq_one_hot.shape 54 | sei_inp = torch.cat([torch.ones((B, 4, 1536), device=seq_one_hot.device) * 0.25, 55 | seq_one_hot.transpose(1, 2), 56 | torch.ones((B, 4, 1536), device=seq_one_hot.device) * 0.25], 2) # batchsize x 4 x 4,096 57 | sei_out = self._sei(sei_inp).cpu().detach().numpy() # batchsize x 21,907 58 | sei_out = sei_out[:, self._sei_features[1].str.strip().values == 'H3K4me3'] # batchsize x 2,350 59 | predh3k4me3 = sei_out.mean(axis=1) # batchsize 60 | return predh3k4me3 61 | 62 | def eval_sp_mse(self, seq_one_hot: Tensor, target: Tensor, b_index: int | None = None) -> Tensor: 63 | """ 64 | Evaluate the mean squared error of the SEI profile prediction. 65 | 66 | Parameters: 67 | - `seq_one_hot`: The one-hot encoded sequence tensor. 68 | - `target`: The target tensor; 69 | - `b_index`: The batch index of the target Tensor; avoids recalculating 70 | the profile all the time; if `None` always calculates profile (useful 71 | for testing). 72 | 73 | Returns: 74 | The mean squared error tensor. 75 | """ 76 | if b_index is not None and b_index in self._sei_cache: 77 | target_prof = self._sei_cache[b_index] 78 | else: 79 | target_prof = self.get_sei_profile(target) 80 | self._sei_cache[b_index] = target_prof 81 | pred_prof = self.get_sei_profile(seq_one_hot) 82 | return (pred_prof - target_prof) ** 2 83 | -------------------------------------------------------------------------------- /src/data/components/qm_utils.py: -------------------------------------------------------------------------------- 1 | """All utilities taken from https://github.com/Dunni3/FlowMol.""" 2 | import torch 3 | import dgl 4 | 5 | def build_edge_idxs(n_atoms: int): 6 | """Builds an array of edge indices for a molecule with n_atoms. 7 | 8 | The edge indicies are constructed such that the upper-triangle of the adjacency matrix is traversed before the lower triangle. 9 | Much of our infrastructure relies on this particular ordering of edge indicies within our graph objects. 
10 | """ 11 | # get upper triangle of adjacency matrix 12 | upper_edge_idxs = torch.triu_indices(n_atoms, n_atoms, offset=1) 13 | 14 | # get lower triangle edges by swapping source and destination of upper_edge_idxs 15 | lower_edge_idxs = torch.stack((upper_edge_idxs[1], upper_edge_idxs[0])) 16 | 17 | edges = torch.cat((upper_edge_idxs, lower_edge_idxs), dim=1) 18 | return edges 19 | 20 | def get_upper_edge_mask(g: dgl.DGLGraph): 21 | """Returns a boolean mask for the edges that lie in the upper triangle of the adjacency matrix for each molecule in the batch.""" 22 | # this algorithm assumes that the edges are ordered such that the upper triangle edges come first, followed by the lower triangle edges for each graph in the batch 23 | # and then those graph-wise edges are concatenated together 24 | # you can see that this is indeed how the edges are constructed by inspecting data_processing.dataset.MoleculeDataset.__getitem__ 25 | edges_per_mol = g.batch_num_edges() 26 | ul_pattern = torch.tensor([1,0]).repeat(g.batch_size).to(g.device) 27 | n_edges_pattern = (edges_per_mol/2).int().repeat_interleave(2) 28 | upper_edge_mask = ul_pattern.repeat_interleave(n_edges_pattern).bool() 29 | return upper_edge_mask 30 | 31 | def get_node_batch_idxs(g: dgl.DGLGraph): 32 | """Returns a tensor of integers indicating which molecule each node belongs to.""" 33 | node_batch_idx = torch.arange(g.batch_size, device=g.device) 34 | node_batch_idx = node_batch_idx.repeat_interleave(g.batch_num_nodes()) 35 | return node_batch_idx 36 | 37 | def get_edge_batch_idxs(g: dgl.DGLGraph): 38 | """Returns a tensor of integers indicating which molecule each edge belongs to.""" 39 | edge_batch_idx = torch.arange(g.batch_size, device=g.device) 40 | edge_batch_idx = edge_batch_idx.repeat_interleave(g.batch_num_edges()) 41 | return edge_batch_idx 42 | 43 | def get_batch_idxs(g: dgl.DGLGraph): 44 | """Returns two tensors of integers indicating which molecule each node and edge belongs to.""" 45 | node_batch_idx = get_node_batch_idxs(g) 46 | edge_batch_idx = get_edge_batch_idxs(g) 47 | return node_batch_idx, edge_batch_idx 48 | -------------------------------------------------------------------------------- /src/data/components/sample_analyzer.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from typing import List 3 | from .molecule_builder import SampledMolecule 4 | from .ff_energy import compute_mmff_energy 5 | from rdkit import Chem 6 | 7 | 8 | allowed_bonds = {'H': {0: 1, 1: 0, -1: 0}, 9 | 'C': {0: [3, 4], 1: 3, -1: 3}, 10 | 'N': {0: [2, 3], 1: [2, 3, 4], -1: 2}, # In QM9, N+ seems to be present in the form NH+ and NH2+ 11 | 'O': {0: 2, 1: 3, -1: 1}, 12 | 'F': {0: 1, -1: 0}, 13 | 'B': 3, 'Al': 3, 'Si': 4, 14 | 'P': {0: [3, 5], 1: 4}, 15 | 'S': {0: [2, 6], 1: [2, 3], 2: 4, 3: 5, -1: 3}, 16 | 'Cl': 1, 'As': 3, 17 | 'Br': {0: 1, 1: 2}, 'I': 1, 'Hg': [1, 2], 'Bi': [3, 5], 'Se': [2, 4, 6]} 18 | bond_dict = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, 19 | Chem.rdchem.BondType.AROMATIC] 20 | 21 | 22 | class SampleAnalyzer(): 23 | 24 | def __init__(self, processed_data_dir: str = None): 25 | 26 | self.processed_data_dir = processed_data_dir 27 | if self.processed_data_dir is not None: 28 | energy_dist_file = self.processed_data_dir / 'energy_dist.npz' 29 | self.energy_div_calculator = DivergenceCalculator(energy_dist_file) 30 | 31 | 32 | def analyze(self, sampled_molecules: List[SampledMolecule], 
return_counts: bool = False): 33 | print("Molecules: ", "\n".join([str(m) for m in sampled_molecules])) 34 | 35 | # compute the atom-level stabiltiy of a molecule. this is the number of atoms that have valid valencies. 36 | # note that since is computed at the atom level, even if the entire molecule is unstable, we can still get an idea 37 | # of how close the molecule is to being stable. 38 | n_atoms = 0 39 | n_stable_atoms = 0 40 | n_stable_molecules = 0 41 | n_molecules = len(sampled_molecules) 42 | for molecule in sampled_molecules: 43 | n_atoms += molecule.num_atoms 44 | n_stable_atoms_this_mol, mol_stable = check_stability(molecule) 45 | n_stable_atoms += n_stable_atoms_this_mol 46 | n_stable_molecules += int(mol_stable) 47 | 48 | frac_atoms_stable = n_stable_atoms / n_atoms # the fraction of generated atoms that have valid valencies 49 | frac_mols_stable_valence = n_stable_molecules / n_molecules # the fraction of generated molecules whose atoms all have valid valencies 50 | 51 | # compute validity as determined by rdkit, and the average size of the largest fragment, and the average number of fragments 52 | validity_result = self.compute_validity(sampled_molecules, return_counts=return_counts) 53 | if return_counts: 54 | frac_valid_mols, avg_frag_frac, avg_num_components, n_valid, sum_frag_fracs, n_frag_fracs, sum_num_components, n_num_components = validity_result 55 | else: 56 | frac_valid_mols, avg_frag_frac, avg_num_components = validity_result 57 | 58 | metrics_dict = { 59 | 'frac_atoms_stable': frac_atoms_stable, 60 | 'frac_mols_stable_valence': frac_mols_stable_valence, 61 | 'frac_valid_mols': frac_valid_mols, 62 | 'avg_frag_frac': avg_frag_frac, 63 | 'avg_num_components': avg_num_components 64 | } 65 | 66 | if return_counts: 67 | counts_dict = {} 68 | counts_dict['n_stable_atoms'] = n_stable_atoms 69 | counts_dict['n_atoms'] = n_atoms 70 | counts_dict['n_stable_molecules'] = n_stable_molecules 71 | counts_dict['n_molecules'] = n_molecules 72 | counts_dict['n_valid'] = n_valid 73 | counts_dict['sum_frag_fracs'] = sum_frag_fracs 74 | counts_dict['n_frag_fracs'] = n_frag_fracs 75 | counts_dict['sum_num_components'] = sum_num_components 76 | counts_dict['n_num_components'] = n_num_components 77 | return counts_dict 78 | 79 | 80 | return metrics_dict 81 | 82 | # this function taken from MiDi molecular_metrics.py script 83 | def compute_validity(self, sampled_molecules: List[SampledMolecule], return_counts: bool = False): 84 | """ generated: list of couples (positions, atom_types)""" 85 | n_valid = 0 86 | num_components = [] 87 | frag_fracs = [] 88 | error_message = Counter() 89 | for mol in sampled_molecules: 90 | rdmol = mol.rdkit_mol 91 | if rdmol is not None: 92 | try: 93 | mol_frags = Chem.rdmolops.GetMolFrags(rdmol, asMols=True, sanitizeFrags=False) 94 | num_components.append(len(mol_frags)) 95 | if len(mol_frags) > 1: 96 | error_message[4] += 1 97 | largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms()) 98 | largest_mol_n_atoms = largest_mol.GetNumAtoms() 99 | largest_frag_frac = largest_mol_n_atoms / mol.num_atoms 100 | frag_fracs.append(largest_frag_frac) 101 | Chem.SanitizeMol(largest_mol) 102 | smiles = Chem.MolToSmiles(largest_mol) 103 | n_valid += 1 104 | error_message[-1] += 1 105 | except Chem.rdchem.AtomValenceException: 106 | error_message[1] += 1 107 | # print("Valence error in GetmolFrags") 108 | except Chem.rdchem.KekulizeException: 109 | error_message[2] += 1 110 | # print("Can't kekulize molecule") 111 | except 
Chem.rdchem.AtomKekulizeException or ValueError: 112 | error_message[3] += 1 113 | print(f"Error messages: AtomValence {error_message[1]}, Kekulize {error_message[2]}, other {error_message[3]}, " 114 | f" -- No error {error_message[-1]}") 115 | 116 | 117 | frac_valid_mols = n_valid / len(sampled_molecules) 118 | avg_frag_frac = sum(frag_fracs) / len(frag_fracs) 119 | avg_num_components = sum(num_components) / len(num_components) 120 | 121 | if return_counts: 122 | return frac_valid_mols, avg_frag_frac, avg_num_components, n_valid, sum(frag_fracs), len(frag_fracs), sum(num_components), len(num_components) 123 | 124 | return frac_valid_mols, avg_frag_frac, avg_num_components 125 | 126 | def compute_sample_energy(self, samples: List[SampledMolecule]): 127 | """ samples: list of SampledMolecule objects. """ 128 | energies = [] 129 | for sample in samples: 130 | rdmol = sample.rdkit_mol 131 | if rdmol is not None: 132 | try: 133 | Chem.SanitizeMol(rdmol) 134 | except: 135 | continue 136 | energy = compute_mmff_energy(rdmol) 137 | if energy is not None: 138 | energies.append(energy) 139 | 140 | return energies 141 | 142 | def compute_energy_divergence(self, samples: List[SampledMolecule]): 143 | 144 | if self.processed_data_dir is None: 145 | raise ValueError('You must specify processed_data_dir upon initialization to compute energy divergences') 146 | 147 | # compute the FF energy of each molecule 148 | energies = self.compute_sample_energy(samples) 149 | 150 | # compute the Jensen-Shannon divergence between the energy distribution of the samples and the training set 151 | js_div = self.energy_div_calculator.js_divergence(energies) 152 | 153 | return js_div 154 | 155 | 156 | def check_stability(molecule: SampledMolecule): 157 | """ molecule: Molecule object. """ 158 | atom_types = molecule.atom_types 159 | # edge_types = molecule.bond_types 160 | 161 | valencies = molecule.valencies 162 | 163 | n_stable_atoms = 0 164 | mol_stable = True 165 | for i, (atom_type, valency, charge) in enumerate(zip(atom_types, valencies, molecule.atom_charges)): 166 | valency = int(valency) 167 | charge = int(charge) 168 | possible_bonds = allowed_bonds[atom_type] 169 | if type(possible_bonds) == int: 170 | is_stable = possible_bonds == valency 171 | elif type(possible_bonds) == dict: 172 | expected_bonds = possible_bonds[charge] if charge in possible_bonds.keys() else possible_bonds[0] 173 | is_stable = expected_bonds == valency if type(expected_bonds) == int else valency in expected_bonds 174 | else: 175 | is_stable = valency in possible_bonds 176 | if not is_stable: 177 | mol_stable = False 178 | n_stable_atoms += int(is_stable) 179 | 180 | return n_stable_atoms, mol_stable 181 | -------------------------------------------------------------------------------- /src/data/dna_enhancer_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import pickle 3 | import os 4 | import torch 5 | from torch.utils.data import DataLoader, Dataset, TensorDataset 6 | from lightning import LightningDataModule 7 | 8 | 9 | class DNAEnhancerDataModule(LightningDataModule): 10 | """ 11 | DNA Enhancer data module. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | dataset: str = "MEL2", 17 | data_dir: str = "data/enhancer/the_code/General/data/Deep", 18 | train_val_test_split: tuple[int, int, int] = (55_000, 5_000, 10_000), 19 | batch_size: int = 64, 20 | num_workers: int = 0, 21 | pin_memory: bool = False, 22 | ): 23 | """Initialize a `DNAEnhancerDataModule`. 
24 | 25 | :param dataset: The dataset to use, choices: "MEL2", "FlyBrain". Defaults to `"MEL2"`. 26 | :param data_dir: The data directory. 27 | :param train_val_test_split: Not used. The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`. 28 | :param batch_size: The batch size. Defaults to `64`. 29 | :param num_workers: The number of workers. Defaults to `0`. 30 | :param pin_memory: Whether to pin memory. Defaults to `False`. 31 | """ 32 | super().__init__() 33 | assert dataset in ["MEL2", "FlyBrain"], f"Invalid dataset '{dataset}'. Choose from 'MEL2', 'FlyBrain'." 34 | self.dataset = dataset 35 | 36 | # this line allows to access init params with 'self.hparams' attribute 37 | # also ensures init params will be stored in ckpt 38 | self.save_hyperparameters(logger=False) 39 | 40 | self.data_train: Dataset | None = None 41 | self.data_val: Dataset | None = None 42 | self.data_test: Dataset | None = None 43 | 44 | self.batch_size_per_device = batch_size 45 | 46 | def prepare_data(self): 47 | """Prepare data.""" 48 | 49 | def setup(self, stage: str | None = None) -> None: 50 | """ 51 | Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 52 | """ 53 | all_data = pickle.load( 54 | open(f"{self.hparams.data_dir}{self.dataset}_data.pkl", "rb") 55 | ) 56 | for split in ["train", "valid", "test"]: 57 | data = torch.nn.functional.one_hot( 58 | torch.from_numpy(all_data[f"{split}_data"]).argmax(dim=-1), 59 | num_classes=4, 60 | ).float() 61 | clss = torch.from_numpy(all_data[f"y_{split}"]).argmax(dim=-1).float() 62 | print(data.shape, clss.shape) 63 | dataset = TensorDataset(data, clss) 64 | setattr( 65 | self, 66 | f"data_{split}", 67 | dataset 68 | ) 69 | # Divide batch size by the number of devices. 70 | if self.trainer is not None: 71 | if self.hparams.batch_size % self.trainer.world_size != 0: 72 | raise RuntimeError( 73 | f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})." 74 | ) 75 | self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size 76 | 77 | def train_dataloader(self) -> DataLoader[Any]: 78 | """Create and return the train dataloader. 79 | 80 | :return: The train dataloader. 81 | """ 82 | assert self.data_train 83 | return DataLoader( 84 | dataset=self.data_train, 85 | batch_size=self.batch_size_per_device, 86 | num_workers=self.hparams.num_workers, 87 | pin_memory=self.hparams.pin_memory, 88 | ) 89 | 90 | def val_dataloader(self) -> DataLoader[Any]: 91 | """Create and return the validation dataloader. 92 | 93 | :return: The validation dataloader. 94 | """ 95 | assert self.data_valid 96 | return DataLoader( 97 | dataset=self.data_valid, 98 | batch_size=self.batch_size_per_device, 99 | num_workers=self.hparams.num_workers, 100 | pin_memory=self.hparams.pin_memory, 101 | shuffle=False, 102 | ) 103 | 104 | def test_dataloader(self) -> DataLoader[Any]: 105 | """Create and return the test dataloader. 106 | 107 | :return: The test dataloader. 108 | """ 109 | assert self.data_test 110 | return DataLoader( 111 | dataset=self.data_test, 112 | batch_size=self.batch_size_per_device, 113 | num_workers=self.hparams.num_workers, 114 | pin_memory=self.hparams.pin_memory, 115 | shuffle=False, 116 | ) 117 | 118 | def teardown(self, stage: str | None = None): 119 | """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, 120 | `trainer.test()`, and `trainer.predict()`. 121 | 122 | :param stage: The stage being torn down. 
Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 123 | Defaults to ``None``. 124 | """ 125 | 126 | def state_dict(self) -> dict[Any, Any]: 127 | """Called when saving a checkpoint. Implement to generate and save the datamodule state. 128 | 129 | :return: A dictionary containing the datamodule state that you want to save. 130 | """ 131 | return {} 132 | 133 | def load_state_dict(self, state_dict: dict[str, Any]) -> None: 134 | """Called when loading a checkpoint. Implement to reload datamodule state given datamodule 135 | `state_dict()`. 136 | 137 | :param state_dict: The datamodule state returned by `self.state_dict()`. 138 | """ 139 | 140 | 141 | if __name__ == "__main__": 142 | _ = DNAEnhancerDataModule().prepare_data() 143 | -------------------------------------------------------------------------------- /src/data/promoter_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from torch.utils.data import DataLoader, Dataset 3 | from lightning import LightningDataModule 4 | from .components.promoter_back import PromoterDataset 5 | 6 | """ 7 | test module loading: 8 | 9 | python -m src.data.promoter_datamodule 10 | """ 11 | 12 | class PromoterDesignDataModule(LightningDataModule): 13 | """ 14 | Promoter Design data module. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | data_dir: str = "data/promoter/", 20 | train_val_test_split: tuple[int, int, int] = (55_000, 5_000, 10_000), 21 | batch_size: int = 64, 22 | num_workers: int = 0, 23 | pin_memory: bool = False, 24 | sep_x_y: bool = False, 25 | ): 26 | """Initialize a `PromoterDesignDataModule`. 27 | 28 | :param data_dir: The data directory. 29 | :param train_val_test_split: Not used. The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`. 30 | :param batch_size: The batch size. Defaults to `64`. 31 | :param num_workers: The number of workers. Defaults to `0`. 32 | :param pin_memory: Whether to pin memory. Defaults to `False`. 33 | """ 34 | super().__init__() 35 | 36 | # this line allows to access init params with 'self.hparams' attribute 37 | # also ensures init params will be stored in ckpt 38 | self.save_hyperparameters(logger=False) 39 | 40 | self.data_train: Dataset | None = None 41 | self.data_val: Dataset | None = None 42 | self.data_test: Dataset | None = None 43 | 44 | self.batch_size_per_device = batch_size 45 | 46 | self.sep_x_y = sep_x_y # whether to separate x and y in the dataset 47 | 48 | def prepare_data(self): 49 | """Nothing to download.""" 50 | 51 | def setup(self, stage: str | None = None) -> None: 52 | """ 53 | Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 54 | """ 55 | self.data_train = PromoterDataset(n_tsses=100000, rand_offset=10, split="train", sep_x_y=self.sep_x_y) 56 | self.data_val = PromoterDataset(n_tsses=100000, rand_offset=0, split="valid", sep_x_y=self.sep_x_y) 57 | self.data_test = PromoterDataset(n_tsses=100000, rand_offset=0, split="test", sep_x_y=self.sep_x_y) 58 | # Divide batch size by the number of devices. 59 | if self.trainer is not None: 60 | if self.hparams.batch_size % self.trainer.world_size != 0: 61 | raise RuntimeError( 62 | f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})." 63 | ) 64 | self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size 65 | 66 | def train_dataloader(self) -> DataLoader[Any]: 67 | """Create and return the train dataloader. 
68 | 69 | :return: The train dataloader. 70 | """ 71 | assert self.data_train 72 | return DataLoader( 73 | dataset=self.data_train, 74 | batch_size=self.batch_size_per_device, 75 | shuffle=True, 76 | num_workers=self.hparams.num_workers, 77 | pin_memory=self.hparams.pin_memory, 78 | ) 79 | 80 | def val_dataloader(self) -> DataLoader[Any]: 81 | """Create and return the validation dataloader. 82 | 83 | :return: The validation dataloader. 84 | """ 85 | assert self.data_val 86 | return DataLoader( 87 | dataset=self.data_val, 88 | batch_size=self.batch_size_per_device, 89 | num_workers=self.hparams.num_workers, 90 | pin_memory=self.hparams.pin_memory, 91 | shuffle=False, 92 | ) 93 | 94 | def test_dataloader(self) -> DataLoader[Any]: 95 | """Create and return the test dataloader. 96 | 97 | :return: The test dataloader. 98 | """ 99 | assert self.data_test 100 | return DataLoader( 101 | dataset=self.data_test, 102 | batch_size=self.batch_size_per_device, 103 | num_workers=self.hparams.num_workers, 104 | pin_memory=self.hparams.pin_memory, 105 | shuffle=False, 106 | ) 107 | 108 | def teardown(self, stage: str | None = None): 109 | """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, 110 | `trainer.test()`, and `trainer.predict()`. 111 | 112 | :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 113 | Defaults to ``None``. 114 | """ 115 | 116 | def state_dict(self) -> dict[Any, Any]: 117 | """Called when saving a checkpoint. Implement to generate and save the datamodule state. 118 | 119 | :return: A dictionary containing the datamodule state that you want to save. 120 | """ 121 | return {} 122 | 123 | def load_state_dict(self, state_dict: dict[str, Any]) -> None: 124 | """Called when loading a checkpoint. Implement to reload datamodule state given datamodule 125 | `state_dict()`. 126 | 127 | :param state_dict: The datamodule state returned by `self.state_dict()`. 
128 | """ 129 | 130 | 131 | if __name__ == "__main__": 132 | mod = PromoterDesignDataModule() 133 | mod.prepare_data() 134 | mod.setup() 135 | data_loader = mod.train_dataloader() 136 | x = next(iter(data_loader)) 137 | print(type(x)) 138 | print([item.shape for item in x]) 139 | -------------------------------------------------------------------------------- /src/data/qm9_datamodule.py: -------------------------------------------------------------------------------- 1 | from lightning import LightningDataModule 2 | from torch.utils.data import DataLoader 3 | import torch 4 | import dgl 5 | 6 | from src.data.components import MoleculeDataset 7 | 8 | """ 9 | test module loading: 10 | 11 | python -m src.data.qm9_datamodule 12 | """ 13 | 14 | class MoleculeDataModule(LightningDataModule): 15 | 16 | def __init__( 17 | self, 18 | dataset_config: dict, 19 | dm_prior_config: dict, 20 | batch_size: int, 21 | num_workers: int = 0, 22 | distributed: bool = False, 23 | max_num_edges: int = 40000, 24 | dataset: str = "qm9", 25 | ): 26 | super().__init__() 27 | self.distributed = distributed 28 | self.dataset_config = { 29 | 'processed_data_dir': f'data/{dataset}', 30 | 'raw_data_dir': f'data/{dataset}_raw', 31 | 'dataset_name': dataset, 32 | } 33 | self.batch_size = batch_size 34 | self.num_workers = num_workers 35 | self.prior_config = { 36 | 'a': { 37 | 'align': False, 38 | 'kwargs': {}, 39 | 'type': 'gaussian', 40 | }, 41 | 'c': { 42 | 'align': False, 43 | 'kwargs': {}, 44 | 'type': 'gaussian', 45 | }, 46 | 'e': { 47 | 'align': False, 48 | 'kwargs': {}, 49 | 'type': 'gaussian', 50 | }, 51 | 'x': { 52 | 'align': True, 53 | 'kwargs': {}, 54 | 'std': 1.0, 55 | 'type': 'centered-normal', 56 | }, 57 | } 58 | self.max_num_edges = max_num_edges 59 | self.save_hyperparameters() 60 | self.train_dataset = None 61 | self.val_dataset = None 62 | self.test_dataset = None 63 | 64 | def prepare_data(self) -> None: 65 | """Nothing to do""" 66 | 67 | def setup(self, stage: str | None = None) -> None: 68 | self.train_dataset = MoleculeDataset( 69 | 'train', 70 | self.dataset_config, 71 | prior_config=self.prior_config, 72 | ) 73 | 74 | self.val_dataset = MoleculeDataset( 75 | 'val', 76 | self.dataset_config, 77 | prior_config=self.prior_config, 78 | ) 79 | 80 | self.test_dataset = MoleculeDataset( 81 | 'test', 82 | self.dataset_config, 83 | prior_config=self.prior_config, 84 | ) 85 | 86 | def train_dataloader(self): 87 | assert self.train_dataset 88 | dataloader = DataLoader( 89 | self.train_dataset, 90 | batch_size=self.batch_size, 91 | shuffle=True, 92 | collate_fn=dgl.batch, 93 | num_workers=self.num_workers, 94 | ) 95 | return dataloader 96 | 97 | # # i wrote the following code under the assumption that we had to align predictions to the target, but i don't think this is true 98 | # if self.x_subspace == 'se3-quotient': 99 | # # if we are using the se3-quotient subspace, then we need to do same-size sampling so that we can efficiently compute rigid aligments during training 100 | # if self.distributed: 101 | # batch_sampler = SameSizeDistributedMoleculeSampler(self.train_dataset, batch_size=self.batch_size, max_num_edges=self.max_num_edges) 102 | # else: 103 | # batch_sampler = SameSizeMoleculeSampler(self.train_dataset, batch_size=self.batch_size, max_num_edges=self.max_num_edges) 104 | 105 | # dataloader = DataLoader(dataset=self.train_dataset, batch_sampler=batch_sampler, collate_fn=dgl.batch, num_workers=self.num_workers) 106 | 107 | # elif self.x_subspace == 'com-free': 108 | # # if we are using the 
com-free subspace, then we don't need to do same-size sampling - and life is easier! 109 | # dataloader = DataLoader(self.train_dataset, 110 | # batch_size=self.batch_size, 111 | # shuffle=True, 112 | # collate_fn=dgl.batch, 113 | # num_workers=self.num_workers) 114 | 115 | 116 | # return dataloader 117 | 118 | def test_dataloader(self): 119 | assert self.test_dataset 120 | dataloader = DataLoader( 121 | self.test_dataset, 122 | batch_size=self.batch_size*2, 123 | shuffle=True, 124 | collate_fn=dgl.batch, 125 | num_workers=self.num_workers, 126 | ) 127 | return dataloader 128 | 129 | def val_dataloader(self): 130 | assert self.val_dataset 131 | dataloader = DataLoader( 132 | self.val_dataset, 133 | batch_size=self.batch_size*2, 134 | shuffle=True, 135 | collate_fn=dgl.batch, 136 | num_workers=self.num_workers, 137 | ) 138 | return dataloader 139 | 140 | # if self.x_subspace == 'se3-quotient': 141 | # # if we are using the se3-quotient subspace, then we need to do same-size sampling so that we can efficiently compute rigid aligments during training 142 | # if self.distributed: 143 | # batch_sampler = SameSizeDistributedMoleculeSampler(self.train_dataset, batch_size=self.batch_size*2) 144 | # else: 145 | # batch_sampler = SameSizeMoleculeSampler(self.train_dataset, batch_size=self.batch_size*2) 146 | 147 | # dataloader = DataLoader(dataset=self.train_dataset, batch_sampler=batch_sampler, collate_fn=dgl.batch, num_workers=self.num_workers) 148 | 149 | # elif self.x_subspace == 'com-free': 150 | # # if we are using the com-free subspace, then we don't need to do same-size sampling - and life is easier! 151 | # dataloader = DataLoader(self.train_dataset, 152 | # batch_size=self.batch_size*2, 153 | # shuffle=True, 154 | # collate_fn=dgl.batch, 155 | # num_workers=self.num_workers) 156 | # return dataloader 157 | 158 | if __name__ == "__main__": 159 | dataset_config = { 160 | 'processed_data_dir': 'data/qm9', 161 | 'raw_data_dir': 'data/qm9_raw', 162 | 'dataset_name': 'qm9', 163 | } 164 | prior_config = { 165 | 'a': { 166 | 'align': False, 167 | 'kwargs': {}, 168 | 'type': 'gaussian', 169 | }, 170 | 'c': { 171 | 'align': False, 172 | 'kwargs': {}, 173 | 'type': 'gaussian', 174 | }, 175 | 'e': { 176 | 'align': False, 177 | 'kwargs': {}, 178 | 'type': 'gaussian', 179 | }, 180 | 'x': { 181 | 'align': True, 182 | 'kwargs': {}, 183 | 'std': 1.0, 184 | 'type': 'centered-normal', 185 | }, 186 | } 187 | mod = MoleculeDataModule(dataset_config, prior_config, 32, 0, False, 40000) 188 | mod.prepare_data() 189 | mod.setup() 190 | data_loader = mod.train_dataloader() 191 | x = next(iter(data_loader)) 192 | print(type(x)) 193 | print(x) 194 | import ipdb; ipdb.set_trace() 195 | -------------------------------------------------------------------------------- /src/data/toy_dfm_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | from torch.utils.data import DataLoader, Dataset 6 | from lightning import LightningDataModule 7 | 8 | 9 | from src.sfm import manifold_from_name 10 | 11 | 12 | class ToyDataset(torch.utils.data.IterableDataset): 13 | """ 14 | Adapted from `https://github.com/HannesStark/dirichlet-flow-matching/blob/main/utils/dataset.py`. 
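Each iteration yields a one-hot tensor of shape `(toy_seq_len, toy_simplex_dim)`, with the class at every sequence position drawn from the corresponding row of `probs`.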
15 | """ 16 | def __init__(self, manifold, probs: Tensor, toy_seq_len: int, toy_simplex_dim: int, sz: int = 100_000): 17 | super().__init__() 18 | self.m = manifold 19 | self.sz = sz 20 | self.seq_len = toy_seq_len 21 | self.alphabet_size = toy_simplex_dim 22 | self.probs = probs 23 | 24 | def __len__(self) -> int: 25 | return self.sz 26 | 27 | def __iter__(self): 28 | while True: 29 | sample = torch.multinomial(replacement=True, num_samples=1, input=self.probs).squeeze() 30 | one_hot = nn.functional.one_hot(sample, self.alphabet_size).float() 31 | # if there is a need to smooth labels, it is done in the model's training step 32 | yield one_hot.reshape((self.seq_len, self.alphabet_size)) 33 | 34 | 35 | class ToyDFMDataModule(LightningDataModule): 36 | """ 37 | Toy DFM data module. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | k: int = 4, 43 | dim: int = 100, 44 | data_dir: str = "data/", 45 | train_val_test_split: tuple[int, int, int] = (55_000, 5_000, 10_000), 46 | batch_size: int = 64, 47 | num_workers: int = 0, 48 | pin_memory: bool = False, 49 | ): 50 | """Initialize a `MNISTDataModule`. 51 | 52 | :param data_dir: The data directory. Defaults to `"data/"`. 53 | :param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`. 54 | :param batch_size: The batch size. Defaults to `64`. 55 | :param num_workers: The number of workers. Defaults to `0`. 56 | :param pin_memory: Whether to pin memory. Defaults to `False`. 57 | """ 58 | super().__init__() 59 | 60 | # this line allows to access init params with 'self.hparams' attribute 61 | # also ensures init params will be stored in ckpt 62 | self.save_hyperparameters(logger=False) 63 | self.k = k 64 | self.dim = dim 65 | self.probs = torch.softmax(torch.rand(k, dim), dim=-1) 66 | 67 | self.data_train: Dataset | None = None 68 | self.data_val: Dataset | None = None 69 | self.data_test: Dataset | None = None 70 | self.batch_size_per_device = batch_size 71 | 72 | def prepare_data(self): 73 | """Nothing to download.""" 74 | 75 | def setup(self, stage: str | None = None) -> None: 76 | """ 77 | Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 78 | """ 79 | # Divide batch size by the number of devices. 80 | if self.trainer is not None: 81 | if self.hparams.batch_size % self.trainer.world_size != 0: 82 | raise RuntimeError( 83 | f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})." 84 | ) 85 | self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size 86 | 87 | # load and split datasets only if not loaded already 88 | if not self.data_train and not self.data_val and not self.data_test: 89 | manifold = manifold_from_name(self.hparams.get("manifold", "sphere")) 90 | self.data_train, self.data_val, self.data_test = ( 91 | ToyDataset( 92 | manifold, 93 | self.probs, 94 | self.k, 95 | self.dim, 96 | sz, 97 | ) for sz in self.hparams.train_val_test_split 98 | ) 99 | 100 | def train_dataloader(self) -> DataLoader[Any]: 101 | """Create and return the train dataloader. 102 | 103 | :return: The train dataloader. 104 | """ 105 | assert self.data_train 106 | return DataLoader( 107 | dataset=self.data_train, 108 | batch_size=self.batch_size_per_device, 109 | num_workers=self.hparams.num_workers, 110 | pin_memory=self.hparams.pin_memory, 111 | ) 112 | 113 | def val_dataloader(self) -> DataLoader[Any]: 114 | """Create and return the validation dataloader. 115 | 116 | :return: The validation dataloader. 
117 | """ 118 | assert self.data_val 119 | return DataLoader( 120 | dataset=self.data_val, 121 | batch_size=self.batch_size_per_device, 122 | num_workers=self.hparams.num_workers, 123 | pin_memory=self.hparams.pin_memory, 124 | shuffle=False, 125 | ) 126 | 127 | def test_dataloader(self) -> DataLoader[Any]: 128 | """Create and return the test dataloader. 129 | 130 | :return: The test dataloader. 131 | """ 132 | assert self.data_test 133 | return DataLoader( 134 | dataset=self.data_test, 135 | batch_size=self.batch_size_per_device, 136 | num_workers=self.hparams.num_workers, 137 | pin_memory=self.hparams.pin_memory, 138 | shuffle=False, 139 | ) 140 | 141 | def teardown(self, stage: str | None = None): 142 | """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, 143 | `trainer.test()`, and `trainer.predict()`. 144 | 145 | :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 146 | Defaults to ``None``. 147 | """ 148 | 149 | def state_dict(self) -> dict[Any, Any]: 150 | """Called when saving a checkpoint. Implement to generate and save the datamodule state. 151 | 152 | :return: A dictionary containing the datamodule state that you want to save. 153 | """ 154 | return {} 155 | 156 | def load_state_dict(self, state_dict: dict[str, Any]) -> None: 157 | """Called when loading a checkpoint. Implement to reload datamodule state given datamodule 158 | `state_dict()`. 159 | 160 | :param state_dict: The datamodule state returned by `self.state_dict()`. 161 | """ 162 | 163 | 164 | if __name__ == "__main__": 165 | _ = ToyDFMDataModule() 166 | -------------------------------------------------------------------------------- /src/dfm/__init__.py: -------------------------------------------------------------------------------- 1 | from .flow_utils import * 2 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple 2 | 3 | import hydra 4 | import rootutils 5 | from lightning import LightningDataModule, LightningModule, Trainer 6 | from lightning.pytorch.loggers import Logger 7 | from omegaconf import DictConfig 8 | 9 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 10 | # ------------------------------------------------------------------------------------ # 11 | # the setup_root above is equivalent to: 12 | # - adding project root dir to PYTHONPATH 13 | # (so you don't need to force user to install project as a package) 14 | # (necessary before importing any local modules e.g. `from src import utils`) 15 | # - setting up PROJECT_ROOT environment variable 16 | # (which is used as a base for paths in "configs/paths/default.yaml") 17 | # (this way all filepaths are the same no matter where you run the code) 18 | # - loading environment variables from ".env" in root dir 19 | # 20 | # you can remove it if you: 21 | # 1. either install project as a package or move entry files to project root dir 22 | # 2. set `root_dir` to "." 
in "configs/paths/default.yaml" 23 | # 24 | # more info: https://github.com/ashleve/rootutils 25 | # ------------------------------------------------------------------------------------ # 26 | 27 | from src.utils import ( 28 | RankedLogger, 29 | extras, 30 | instantiate_loggers, 31 | log_hyperparameters, 32 | task_wrapper, 33 | ) 34 | 35 | log = RankedLogger(__name__, rank_zero_only=True) 36 | 37 | 38 | @task_wrapper 39 | def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 40 | """Evaluates given checkpoint on a datamodule testset. 41 | 42 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 43 | failure. Useful for multiruns, saving info about the crash, etc. 44 | 45 | :param cfg: DictConfig configuration composed by Hydra. 46 | :return: Tuple[dict, dict] with metrics and dict with all instantiated objects. 47 | """ 48 | assert cfg.ckpt_path 49 | 50 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") 51 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 52 | 53 | log.info(f"Instantiating model <{cfg.model._target_}>") 54 | model: LightningModule = hydra.utils.instantiate(cfg.model) 55 | 56 | log.info("Instantiating loggers...") 57 | logger: List[Logger] = instantiate_loggers(cfg.get("logger")) 58 | 59 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 60 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger) 61 | 62 | object_dict = { 63 | "cfg": cfg, 64 | "datamodule": datamodule, 65 | "model": model, 66 | "logger": logger, 67 | "trainer": trainer, 68 | } 69 | 70 | if logger: 71 | log.info("Logging hyperparameters!") 72 | log_hyperparameters(object_dict) 73 | 74 | log.info("Starting testing!") 75 | trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path) 76 | 77 | # for predictions use trainer.predict(...) 78 | # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path) 79 | 80 | metric_dict = trainer.callback_metrics 81 | 82 | return metric_dict, object_dict 83 | 84 | 85 | @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml") 86 | def main(cfg: DictConfig) -> None: 87 | """Main entry point for evaluation. 88 | 89 | :param cfg: DictConfig configuration composed by Hydra. 90 | """ 91 | # apply extra utilities 92 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
93 | extras(cfg) 94 | 95 | evaluate(cfg) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .sfm_module import * 2 | from .dfm_module import DNAModule, PromoterModule 3 | from .molecule_module import MoleculeModule 4 | -------------------------------------------------------------------------------- /src/models/net/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * 2 | from .promoter_model import * 3 | from .gvp import * 4 | from .interpolant_scheduler import InterpolantScheduler 5 | from .vector_field import * 6 | -------------------------------------------------------------------------------- /src/models/net/interpolant_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Dict, Tuple, Union 4 | 5 | class InterpolantScheduler(nn.Module): 6 | 7 | supported_schedule_types = ['cosine', 'linear'] 8 | 9 | def __init__(self, canonical_feat_order: str, schedule_type: Union[str, Dict[str, str]] = 'cosine', cosine_params: dict = {}): 10 | super().__init__() 11 | 12 | self.feats = canonical_feat_order 13 | self.n_feats = len(self.feats) 14 | 15 | # check that schedule_type is a string or a dictionary 16 | if not isinstance(schedule_type, (str, dict)): 17 | raise ValueError('schedule_type must be a string or a dictionary') 18 | 19 | # if it is a string, assign the same schedule_type to all features 20 | if isinstance(schedule_type, str): 21 | if schedule_type not in self.supported_schedule_types: 22 | raise ValueError(f'unsupported schedule_type: {schedule_type}') 23 | self.schedule_dict = { 24 | feat: schedule_type for feat in self.feats 25 | } 26 | else: 27 | # schedule_type is a dictionary specifying the schedule_type for each feature 28 | for feat in self.feats: 29 | if feat not in schedule_type: 30 | raise ValueError(f'must specify schedule_type for feature {feat}') 31 | 32 | self.schedule_dict = schedule_type 33 | 34 | # if schedule_type == 'cosine': 35 | # self.alpha_t = self.cosine_alpha_t 36 | # self.alpha_t_prime = self.cosine_alpha_t_prime 37 | # elif schedule_type == 'linear': 38 | # self.alpha_t = self.linear_alpha_t 39 | # self.alpha_t_prime = self.linear_alpha_t_prime 40 | # else: 41 | # raise NotImplementedError(f'unsupported schedule_type: {schedule_type}') 42 | 43 | 44 | # for features which have a cosine schedule, check that the parameter "nu" is provided 45 | for feat, schedule_type in self.schedule_dict.items(): 46 | if schedule_type == 'cosine' and feat not in cosine_params: 47 | raise ValueError(f'must specify cosine_params for feature {feat}') 48 | 49 | # get a list of unique schedule types which are used 50 | self.schedule_types = list(set( self.schedule_dict.values() )) 51 | 52 | # if we are using a cosine schedule, convert all of the cosine_params to torch tensors 53 | if 'cosine' in self.schedule_types: 54 | for feat in cosine_params: 55 | cosine_params[feat] = torch.tensor(cosine_params[feat]).unsqueeze(0) 56 | 57 | # save the cosine_params as an attribute 58 | self.cosine_params = cosine_params 59 | 60 | self.device = None 61 | 62 | self.clamp_t = True 63 | 64 | 65 | 66 | def update_device(self, t): 67 | if 'cosine' in self.schedule_types and t.device != self.device: 68 | for key in 
self.cosine_params: 69 | self.cosine_params[key] = self.cosine_params[key].to(t.device) 70 | self.device = t.device 71 | 72 | def interpolant_weights(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 73 | """ 74 | Returns the weights for x_0 and x_1 in the interpolation between x_0 and x_1. 75 | """ 76 | # t has shape (n_timepoints,) 77 | # returns a tuple of 2 tensors of shape (n_timepoints, n_feats) 78 | # the tensor at index 0 is the weight for x_0 79 | # the tensor at index 1 is the weight for x_1 80 | 81 | self.update_device(t) 82 | 83 | alpha_t = self.alpha_t(t) 84 | weights = (1 - alpha_t, alpha_t) 85 | return weights 86 | 87 | def loss_weights(self, t: torch.Tensor): 88 | alpha_t = self.alpha_t(t) 89 | # alpha_t_prime = self.alpha_t_prime(t) 90 | # weights = alpha_t_prime/(1 - alpha_t + 1e-5) 91 | weights = alpha_t/(1 - alpha_t) 92 | 93 | # clamp the weights with a minimum of 0.05 and a maximum of 1.5 94 | weights = torch.clamp(weights, min=0.05, max=1.5) 95 | return weights 96 | 97 | def alpha_t(self, t: torch.Tensor) -> torch.Tensor: 98 | 99 | self.update_device(t) 100 | 101 | per_feat_alpha = [] 102 | for feat in self.feats: 103 | schedule_type = self.schedule_dict[feat] 104 | if schedule_type == 'cosine': 105 | alpha_t = self.cosine_alpha_t(t, nu=self.cosine_params[feat]) 106 | elif schedule_type == 'linear': 107 | alpha_t = self.linear_alpha_t(t) 108 | 109 | per_feat_alpha.append(alpha_t) 110 | 111 | alpha_t = torch.cat(per_feat_alpha, dim=1) 112 | return alpha_t 113 | 114 | def alpha_t_prime(self, t: torch.Tensor) -> torch.Tensor: 115 | self.update_device(t) 116 | 117 | per_feat_alpha_prime = [] 118 | for feat in self.feats: 119 | schedule_type = self.schedule_dict[feat] 120 | if schedule_type == 'cosine': 121 | alpha_t_prime = self.cosine_alpha_t_prime(t, nu=self.cosine_params[feat]) 122 | elif schedule_type == 'linear': 123 | alpha_t_prime = self.linear_alpha_t_prime(t) 124 | 125 | per_feat_alpha_prime.append(alpha_t_prime) 126 | 127 | alpha_t_prime = torch.cat(per_feat_alpha_prime, dim=1) 128 | return alpha_t_prime 129 | 130 | 131 | def cosine_alpha_t(self, t: torch.Tensor, nu: torch.Tensor) -> Dict[str, torch.Tensor]: 132 | # t has shape (n_timepoints,) 133 | # alpha_t has shape (n_timepoints, n_feats) containing the alpha_t for each feature 134 | t = t.unsqueeze(-1) 135 | alpha_t = 1 - torch.cos(torch.pi*0.5*torch.pow(t, nu)).square() 136 | return alpha_t 137 | 138 | def cosine_alpha_t_prime(self, t: torch.Tensor, nu: torch.Tensor) -> torch.Tensor: 139 | 140 | if self.clamp_t: 141 | t = torch.clamp_(t, min=1e-9) 142 | 143 | t = t.unsqueeze(-1) 144 | sin_input = torch.pi*torch.pow(t, nu) 145 | alpha_t_prime = torch.pi*0.5*torch.sin(sin_input)*nu*torch.pow(t, nu-1) 146 | return alpha_t_prime 147 | 148 | def linear_alpha_t(self, t: torch.Tensor) -> Dict[str, torch.Tensor]: 149 | alpha_t = t.unsqueeze(-1) 150 | return alpha_t 151 | 152 | def linear_alpha_t_prime(self, t: torch.Tensor) -> Dict[str, torch.Tensor]: 153 | alpha_t_prime = torch.ones_like(t).unsqueeze(-1) 154 | return alpha_t_prime 155 | -------------------------------------------------------------------------------- /src/models/net/promoter_model.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/HannesStark/dirichlet-flow-matching/blob/main/model/promoter_model.py 2 | from torch import nn 3 | import torch 4 | import numpy as np 5 | from .model import expand_simplex 6 | 7 | 8 | class GaussianFourierProjection(nn.Module): 9 | """ 10 
| Gaussian random features for encoding time steps. 11 | """ 12 | 13 | def __init__(self, embed_dim, scale=30.): 14 | super().__init__() 15 | # Randomly sample weights during initialization. These weights are fixed 16 | # during optimization and are not trainable. 17 | self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False) 18 | 19 | def forward(self, x): 20 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 21 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 22 | 23 | 24 | class Dense(nn.Module): 25 | """ 26 | A fully connected layer that reshapes outputs to feature maps. 27 | """ 28 | 29 | def __init__(self, input_dim, output_dim): 30 | super().__init__() 31 | self.dense = nn.Linear(input_dim, output_dim) 32 | 33 | def forward(self, x): 34 | return self.dense(x)[...] 35 | 36 | 37 | class PromoterModel(nn.Module): 38 | """A time-dependent score-based model built upon U-Net architecture.""" 39 | 40 | def __init__(self, mode, embed_dim=256, time_dependent_weights=None, time_step=0.01): 41 | """Initialize a time-dependent score-based network. 42 | 43 | Args: 44 | marginal_prob_std: A function that takes time t and gives the standard 45 | deviation of the perturbation kernel p_{0t}(x(t) | x(0)). 46 | channels: The number of channels for feature maps of each resolution. 47 | embed_dim: The dimensionality of Gaussian random feature embeddings. 48 | """ 49 | super().__init__() 50 | # Gaussian random feature embedding layer for time 51 | self.alphabet_size = 4 52 | self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim), 53 | nn.Linear(embed_dim, embed_dim)) 54 | n = 256 55 | expanded_simplex_input = (mode == 'dirichlet' or mode == 'riemannian') 56 | # NOTE: change +1 to +2 here, at the end of the line 57 | inp_size = self.alphabet_size * (2 if expanded_simplex_input else 1) + 1 # plus one for signal input 58 | if (mode == 'ardm' or mode == 'lrar'): 59 | inp_size += 1 # plus one for the mask token of these models 60 | self.linear = nn.Conv1d(inp_size, n, kernel_size=9, padding=4) 61 | self.blocks = nn.ModuleList([nn.Conv1d(n, n, kernel_size=9, padding=4), 62 | nn.Conv1d(n, n, kernel_size=9, padding=4), 63 | nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), 64 | nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), 65 | nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256), 66 | nn.Conv1d(n, n, kernel_size=9, padding=4), 67 | nn.Conv1d(n, n, kernel_size=9, padding=4), 68 | nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), 69 | nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), 70 | nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256), 71 | nn.Conv1d(n, n, kernel_size=9, padding=4), 72 | nn.Conv1d(n, n, kernel_size=9, padding=4), 73 | nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), 74 | nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), 75 | nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256), 76 | nn.Conv1d(n, n, kernel_size=9, padding=4), 77 | nn.Conv1d(n, n, kernel_size=9, padding=4), 78 | nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16), 79 | nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64), 80 | nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256)]) 81 | 82 | self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(20)]) 83 | self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(20)]) 84 | 85 | # The swish activation function 86 | self.act = lambda x: x * torch.sigmoid(x) 87 | self.relu = nn.ReLU() 88 | self.softplus = nn.Softplus() 89 | # was unused 90 | 
# self.scale = nn.Parameter(torch.ones(1)) 91 | self.final = nn.Sequential(nn.Conv1d(n, n, kernel_size=1), 92 | nn.GELU(), 93 | nn.Conv1d(n, 4, kernel_size=1)) 94 | self.register_buffer("time_dependent_weights", time_dependent_weights) 95 | self.time_step = time_step 96 | self.mode = mode 97 | 98 | def forward(self, x, signal, t): 99 | # Obtain the Gaussian random feature embedding for t 100 | # embed: [N, embed_dim] 101 | t = t.squeeze() 102 | if t is not None and len(t.shape) == 0: 103 | # odeint is on 104 | t = t[None].expand(x.size(0)) 105 | embed = self.act(self.embed(t / 2)) 106 | 107 | x = torch.cat([x,signal], dim=-1) 108 | 109 | # Encoding path 110 | # x: NLC -> NCL 111 | out = x.permute(0, 2, 1) 112 | out = self.act(self.linear(out)) 113 | 114 | # pos encoding 115 | for block, dense, norm in zip(self.blocks, self.denses, self.norms): 116 | h = self.act(block(norm(out + dense(embed)[:, :, None]))) 117 | if h.shape == out.shape: 118 | out = h + out 119 | else: 120 | out = h 121 | 122 | out = self.final(out) 123 | 124 | out = out.permute(0, 2, 1) 125 | 126 | if self.time_dependent_weights is not None: 127 | t_step = (t / self.time_step) - 1 128 | w0 = self.time_dependent_weights[t_step.long()] 129 | w1 = self.time_dependent_weights[torch.clip(t_step + 1, max=len(self.time_dependent_weights) - 1).long()] 130 | out = out * (w0 + (t_step - t_step.floor()) * (w1 - w0))[:, None, None] 131 | 132 | # NOTE: the following is proj onto simplex tangent space, which we do not want 133 | # we want sphere tangent space (done later in our training) 134 | if self.mode != "sfm": 135 | out = out - out.mean(axis=-1, keepdims=True) 136 | return out 137 | -------------------------------------------------------------------------------- /src/sfm/__init__.py: -------------------------------------------------------------------------------- 1 | """Utils for the entire project.""" 2 | from .maths import ( 3 | fast_dot, 4 | safe_arccos, 5 | usinc, 6 | ) 7 | from .manifold import ( 8 | GeooptSphere, 9 | Manifold, 10 | NSimplex, 11 | NSphere, 12 | manifold_from_name, 13 | str_to_ot_method, 14 | ) 15 | from .distribution import ( 16 | compute_exact_loglikelihood, 17 | estimate_categorical_kl, 18 | get_wasserstein_dist, 19 | set_seeds, 20 | ) 21 | from .plot import ( 22 | define_style, 23 | save_plot, 24 | ) 25 | from .sampler import ( 26 | OTSampler, 27 | ) 28 | from .train import ( 29 | cft_loss_function, 30 | ot_train_step, 31 | ) 32 | -------------------------------------------------------------------------------- /src/sfm/maths.py: -------------------------------------------------------------------------------- 1 | """Some maths utils.""" 2 | import torch 3 | from torch import Tensor 4 | 5 | 6 | @torch.jit.script 7 | def usinc(theta: Tensor) -> Tensor: 8 | """Unnormalized sinc.""" 9 | return torch.sinc(theta / torch.pi) 10 | 11 | 12 | def safe_arccos(x: Tensor) -> Tensor: 13 | """A safe version of `x.arccos()`.""" 14 | return x.clamp(-1.0, 1.0).acos() 15 | 16 | 17 | __f_dot = torch.vmap(torch.vmap(torch.dot)) 18 | 19 | 20 | def fast_dot(u: Tensor, v: Tensor, keepdim: bool = True) -> Tensor: 21 | """A faster and unified version of dot products.""" 22 | # ret = __f_dot(p, q) 23 | # if keepdim: 24 | # ret = ret.unsqueeze(-1) 25 | # return ret 26 | ret = torch.einsum("bnd,bnd->bn", u, v) 27 | if keepdim: 28 | ret = ret.unsqueeze(-1) 29 | return ret 30 | -------------------------------------------------------------------------------- /src/sfm/plot.py: 
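The `src/sfm/maths.py` helpers above are written for batched `(batch, seq, dim)` tensors; a short usage sketch with illustrative shapes:

```python
import torch
from src.sfm import fast_dot, safe_arccos, usinc

u = torch.randn(2, 3, 5)  # (batch, seq, dim) tangent vectors
v = torch.randn(2, 3, 5)
print(fast_dot(u, v).shape)                   # torch.Size([2, 3, 1]); keepdim=False gives (2, 3)
print(safe_arccos(torch.tensor(1.0 + 1e-6)))  # clamped to 1.0, so 0. instead of nan
print(usinc(torch.tensor(0.0)))               # unnormalized sinc, sin(x)/x, which is 1. at x = 0
```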
-------------------------------------------------------------------------------- 1 | """Useful functions for plotting (namely, style).""" 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | 5 | 6 | def define_style(): 7 | """ 8 | Sets the style up for matplotlib. 9 | """ 10 | # plt.rcParams['text.usetex'] = True 11 | # plt.rcParams['text.latex.preamble'] = r"""\usepackage[T1]{fontenc}""" 12 | plt.rc("font", family="serif", weight="normal", size=16) 13 | sns.set_theme() 14 | sns.set_style(style="whitegrid") 15 | sns.set_palette("Paired") 16 | 17 | 18 | def save_plot(loc: str): 19 | """ 20 | Saves the current matplotlib plot at location `loc`. This is useful 21 | to keep the same export parameters for all the plots. 22 | """ 23 | plt.savefig( 24 | loc, bbox_inches="tight", dpi=300, 25 | ) 26 | -------------------------------------------------------------------------------- /src/sfm/sampler.py: -------------------------------------------------------------------------------- 1 | """Defines sampling methods; useful for OT-sampling.""" 2 | import torch 3 | from torch import Tensor 4 | 5 | 6 | from src.sfm import Manifold, str_to_ot_method 7 | 8 | 9 | class OTSampler: 10 | """ 11 | Based on: 12 | `https://github.com/DreamFold/FoldFlow/blob/main/FoldFlow/utils/optimal_transport.py`. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | manifold: Manifold, 18 | method: str, 19 | reg: float = 0.05, 20 | reg_m: float = 1.0, 21 | normalize_cost: bool = False, 22 | ): 23 | """ 24 | Parameters: 25 | - `manifold`: the underlying manifold; useful for the geodesic 26 | distance; 27 | - `method`: the OT method; 28 | - `reg`: parameter for the OT method; 29 | - `reg_m`: parameter for the OT method, can be ignored depending on method. 30 | """ 31 | self.manifold = manifold 32 | self.ot_fn = str_to_ot_method(method, reg, reg_m) 33 | self.normalize_cost = normalize_cost 34 | 35 | @torch.no_grad() 36 | def get_map(self, x0: Tensor, x1: Tensor) -> Tensor: 37 | """ 38 | Compute the OT plan between a source and a target minibatch. 39 | """ 40 | a, b = ( 41 | torch.full((x0.shape[0],), 1.0 / x0.shape[0], device=x1.device), 42 | torch.full((x1.shape[0],), 1.0 / x1.shape[0], device=x1.device), 43 | ) 44 | m = self.manifold.pairwise_geodesic_distance(x0, x1) 45 | if self.normalize_cost: 46 | m = m / m.max() # should not be normalized when using minibatches 47 | p = self.ot_fn(a, b, m) 48 | # if not torch.all(torch.isfinite(p)): 49 | #  print("ERROR: p is not finite") 50 | #  print(p) 51 | #  print("Cost mean, max", m.mean(), m.max()) 52 | #  print(x0, x1) 53 | #  raise ValueError("p is not finite") 54 | return p 55 | 56 | def sample_map(self, pi: Tensor, batch_size: int): 57 | """ 58 | Draw source and target samples from `pi`, $(x,z) \sim \pi$. 
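Returns the sampled row indices (into `x0`) and column indices (into `x1`), each of length `batch_size`.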
59 | """ 60 | p = pi.flatten() 61 | p = p / p.sum() 62 | # choices = np.random.choice(pi.shape[0] * pi.shape[1], p=p, size=batch_size) 63 | # return np.divmod(choices, pi.shape[1]) 64 | choices = torch.multinomial( 65 | p, num_samples=batch_size, replacement=True, 66 | ).long() 67 | return torch.floor_divide(choices, pi.shape[1]), torch.remainder(choices, pi.shape[1]) 68 | 69 | def sample_plan(self, x0: Tensor, x1: Tensor) -> tuple[Tensor, Tensor]: 70 | """ 71 | Compute the OT plan $\pi$ (wrt squared Euclidean cost) between a source and a target 72 | minibatch and draw source and target samples from `pi` $(x,z) \sim \pi$ 73 | """ 74 | pi = self.get_map(x0, x1) 75 | i, j = self.sample_map(pi, x0.shape[0]) 76 | return x0[i], x1[j] 77 | -------------------------------------------------------------------------------- /src/sfm/train.py: -------------------------------------------------------------------------------- 1 | """Loss utils.""" 2 | import torch 3 | from torch import Tensor, nn, vmap 4 | from torch.func import jvp 5 | from src.sfm import Manifold, OTSampler 6 | 7 | 8 | def geodesic(manifold, start_point, end_point): 9 | # https://github.com/facebookresearch/riemannian-fm/blob/main/manifm/manifolds/utils.py#L6 10 | shooting_tangent_vec = manifold.logmap(start_point, end_point) 11 | 12 | def path(t): 13 | """Generate parameterized function for geodesic curve. 14 | Parameters 15 | ---------- 16 | t : array-like, shape=[n_points,] 17 | Times at which to compute points of the geodesics. 18 | """ 19 | tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec) 20 | points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs) 21 | return points_at_time_t 22 | 23 | return path 24 | 25 | 26 | def ot_train_step( 27 | x_1: Tensor, 28 | m: Manifold, 29 | model: nn.Module, 30 | sampler: OTSampler | None, 31 | extra_args: dict[str, Tensor] | None = None, 32 | closed_form_drv: bool = False, 33 | ) -> tuple[Tensor, Tensor, Tensor]: 34 | """ 35 | Returns the loss for a single (OT-)CFT training step along with the 36 | model's output and the target vector. 37 | 38 | Parameters: 39 | - `x_1`: batch of data points; 40 | - `m`: manifold; 41 | - `model`: the model to apply; 42 | - `sampler` (optional): the sampler for the OT plan; 43 | - `time_eps`: "guard" for sampling the time; 44 | - `signal` (optional): extra signal for some datasets; 45 | - `closed_form_drv`: whether to use the closed-form derivative; 46 | if `False`, uses autograd; 47 | - `stochastic`: whether to train for an SDE. 48 | """ 49 | b = x_1.size(0) 50 | k = x_1.size(1) 51 | d = x_1.size(-1) 52 | t = torch.rand((b, 1), device=x_1.device) 53 | x_0 = m.uniform_prior(b, k, d).to(x_1.device) 54 | return cft_loss_function( 55 | x_0, x_1, t, m, model, sampler, extra_args=extra_args, closed_form_drv=closed_form_drv, 56 | ) 57 | 58 | 59 | def cft_loss_function( 60 | x_0: Tensor, 61 | x_1: Tensor, 62 | t: Tensor, 63 | m: Manifold, 64 | model: nn.Module, 65 | sampler: OTSampler | None, 66 | extra_args: dict[str, Tensor] | None = None, 67 | closed_form_drv: bool = False, 68 | ) -> tuple[Tensor, Tensor, Tensor]: 69 | """ 70 | Our CFT loss function. If `sampler` is provided, OT-CFT loss is calculated. 
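The regression target is the geodesic velocity from `x_0` towards `x_1` evaluated at the interpolant `x_t` (either in closed form via the log map and parallel transport, or by differentiating the geodesic with `jvp`); the loss is the squared norm of the difference between the model output and this target, summed over sequence and simplex dimensions and averaged over the batch.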
71 |  72 | Parameters: 73 | - `x_0`: starting point (drawn from prior); 74 | - `x_1`: end point (drawn from data); 75 | - `t`: the times; 76 | - `m`: the manifold; 77 | - `model`: the model to apply; 78 | - `sampler` (optional): the sampler for the OT plan; 79 | - `extra_args` (optional): extra arguments passed to the model 80 | (e.g., a conditioning signal for some datasets); 81 | - `closed_form_drv`: whether to use the closed-form derivative; 82 | if `False`, uses autograd. 83 |  84 | Returns: 85 | The loss tensor, the model output, and the target vector. 86 | """ 87 | if sampler: 88 | x_0, x_1 = sampler.sample_plan(x_0, x_1) 89 | if closed_form_drv: 90 | x_t = m.geodesic_interpolant(x_0, x_1, t) 91 | target = m.log_map(x_0, x_1) 92 | target = m.parallel_transport(x_0, x_t, target) 93 | # target = m.log_map(x_t, x_1) / (1.0 - t.unsqueeze(-1) + 1e-7) 94 | else: 95 | with torch.inference_mode(False): 96 | # https://github.com/facebookresearch/riemannian-fm/blob/main/manifm/model_pl.py 97 | def cond_u(x0, x1, t): 98 | path = geodesic(m.sphere, x0, x1) 99 | x_t, u_t = jvp(path, (t,), (torch.ones_like(t).to(t),)) 100 | return x_t, u_t 101 | x_t, target = vmap(cond_u)(x_0, x_1, t) 102 | x_t = x_t.squeeze() 103 | target = target.squeeze() 104 | assert m.all_belong_tangent(x_t, target) 105 | 106 | # now calculate diffs 107 | out = model(x=x_t, t=t, **(extra_args or {})) 108 | 109 | diff = out - target 110 | return diff.square().sum(dim=(-1, -2)).mean(), out, target 111 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import hydra 4 | import lightning as L 5 | import rootutils 6 | import torch 7 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 8 | from lightning.pytorch.loggers import Logger 9 | from omegaconf import DictConfig 10 | 11 | 12 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 13 | # ------------------------------------------------------------------------------------ # 14 | # the setup_root above is equivalent to: 15 | # - adding project root dir to PYTHONPATH 16 | # (so you don't need to force user to install project as a package) 17 | # (necessary before importing any local modules e.g. `from src import utils`) 18 | # - setting up PROJECT_ROOT environment variable 19 | # (which is used as a base for paths in "configs/paths/default.yaml") 20 | # (this way all filepaths are the same no matter where you run the code) 21 | # - loading environment variables from ".env" in root dir 22 | # 23 | # you can remove it if you: 24 | # 1. either install project as a package or move entry files to project root dir 25 | # 2. set `root_dir` to "." in "configs/paths/default.yaml" 26 | # 27 | # more info: https://github.com/ashleve/rootutils 28 | # ------------------------------------------------------------------------------------ # 29 | 30 | from src.utils import ( 31 | RankedLogger, 32 | extras, 33 | get_metric_value, 34 | instantiate_callbacks, 35 | instantiate_loggers, 36 | log_hyperparameters, 37 | task_wrapper, 38 | ) 39 | 40 | log = RankedLogger(__name__, rank_zero_only=True) 41 | 42 | 43 | @task_wrapper 44 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 45 | """Trains the model. Can additionally evaluate on a test set, using best weights obtained during 46 | training.
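For instance, with one of the experiment configs provided under `configs/experiment`, a plausible invocation is `python -m src.train experiment=promoter_sfm_cnn trainer=gpu`; any other option can be overridden on the command line in the same way.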
47 | 48 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 49 | failure. Useful for multiruns, saving info about the crash, etc. 50 | 51 | :param cfg: A DictConfig configuration composed by Hydra. 52 | :return: A tuple with metrics and dict with all instantiated objects. 53 | """ 54 | # set seed for random number generators in pytorch, numpy and python.random 55 | if cfg.get("seed"): 56 | L.seed_everything(cfg.seed, workers=True) 57 | 58 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") 59 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 60 | 61 | log.info(f"Instantiating model <{cfg.model._target_}>") 62 | model: LightningModule = hydra.utils.instantiate(cfg.model) 63 | 64 | log.info("Instantiating callbacks...") 65 | callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) 66 | 67 | log.info("Instantiating loggers...") 68 | logger: List[Logger] = instantiate_loggers(cfg.get("logger")) 69 | 70 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 71 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) 72 | 73 | object_dict = { 74 | "cfg": cfg, 75 | "datamodule": datamodule, 76 | "model": model, 77 | "callbacks": callbacks, 78 | "logger": logger, 79 | "trainer": trainer, 80 | } 81 | 82 | if logger: 83 | log.info("Logging hyperparameters!") 84 | log_hyperparameters(object_dict) 85 | 86 | if cfg.get("train"): 87 | log.info("Starting training!") 88 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 89 | 90 | train_metrics = trainer.callback_metrics 91 | 92 | if cfg.get("test"): 93 | log.info("Starting testing!") 94 | ckpt_path = trainer.checkpoint_callback.best_model_path 95 | if ckpt_path == "": 96 | log.warning("Best ckpt not found! Using current weights for testing...") 97 | ckpt_path = None 98 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 99 | log.info(f"Best ckpt path: {ckpt_path}") 100 | 101 | test_metrics = trainer.callback_metrics 102 | 103 | # merge train and test metrics 104 | metric_dict = {**train_metrics, **test_metrics} 105 | 106 | return metric_dict, object_dict 107 | 108 | 109 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") 110 | def main(cfg: DictConfig) -> Optional[float]: 111 | """Main entry point for training. 112 | 113 | :param cfg: DictConfig configuration composed by Hydra. 114 | :return: Optional[float] with optimized metric value. 115 | """ 116 | # apply extra utilities 117 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
118 | extras(cfg) 119 | 120 | # train the model 121 | metric_dict, _ = train(cfg) 122 | 123 | # safely retrieve metric value for hydra-based hyperparameter optimization 124 | metric_value = get_metric_value( 125 | metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") 126 | ) 127 | 128 | # return optimized metric 129 | return metric_value 130 | 131 | 132 | if __name__ == "__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from src.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from src.utils.logging_utils import log_hyperparameters 3 | from src.utils.pylogger import RankedLogger 4 | from src.utils.rich_utils import enforce_tags, print_config_tree 5 | from src.utils.utils import extras, get_metric_value, task_wrapper 6 | -------------------------------------------------------------------------------- /src/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from lightning import Callback 5 | from lightning.pytorch.loggers import Logger 6 | from omegaconf import DictConfig 7 | 8 | from src.utils import pylogger 9 | 10 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 11 | 12 | 13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config. 15 | 16 | :param callbacks_cfg: A DictConfig object containing callback configurations. 17 | :return: A list of instantiated callbacks. 18 | """ 19 | callbacks: List[Callback] = [] 20 | 21 | if not callbacks_cfg: 22 | log.warning("No callback configs found! Skipping..") 23 | return callbacks 24 | 25 | if not isinstance(callbacks_cfg, DictConfig): 26 | raise TypeError("Callbacks config must be a DictConfig!") 27 | 28 | for _, cb_conf in callbacks_cfg.items(): 29 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 30 | log.info(f"Instantiating callback <{cb_conf._target_}>") 31 | callbacks.append(hydra.utils.instantiate(cb_conf)) 32 | 33 | return callbacks 34 | 35 | 36 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 37 | """Instantiates loggers from config. 38 | 39 | :param logger_cfg: A DictConfig object containing logger configurations. 40 | :return: A list of instantiated loggers. 41 | """ 42 | logger: List[Logger] = [] 43 | 44 | if not logger_cfg: 45 | log.warning("No logger configs found! 
Skipping...") 46 | return logger 47 | 48 | if not isinstance(logger_cfg, DictConfig): 49 | raise TypeError("Logger config must be a DictConfig!") 50 | 51 | for _, lg_conf in logger_cfg.items(): 52 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 53 | log.info(f"Instantiating logger <{lg_conf._target_}>") 54 | logger.append(hydra.utils.instantiate(lg_conf)) 55 | 56 | return logger 57 | -------------------------------------------------------------------------------- /src/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from lightning_utilities.core.rank_zero import rank_zero_only 4 | from omegaconf import OmegaConf 5 | 6 | from src.utils import pylogger 7 | 8 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 9 | 10 | 11 | @rank_zero_only 12 | def log_hyperparameters(object_dict: Dict[str, Any]) -> None: 13 | """Controls which config parts are saved by Lightning loggers. 14 | 15 | Additionally saves: 16 | - Number of model parameters 17 | 18 | :param object_dict: A dictionary containing the following objects: 19 | - `"cfg"`: A DictConfig object containing the main config. 20 | - `"model"`: The Lightning model. 21 | - `"trainer"`: The Lightning trainer. 22 | """ 23 | hparams = {} 24 | 25 | cfg = OmegaConf.to_container(object_dict["cfg"]) 26 | model = object_dict["model"] 27 | trainer = object_dict["trainer"] 28 | 29 | if not trainer.logger: 30 | log.warning("Logger not found! Skipping hyperparameter logging...") 31 | return 32 | 33 | hparams["model"] = cfg["model"] 34 | 35 | # save number of model parameters 36 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 37 | hparams["model/params/trainable"] = sum( 38 | p.numel() for p in model.parameters() if p.requires_grad 39 | ) 40 | hparams["model/params/non_trainable"] = sum( 41 | p.numel() for p in model.parameters() if not p.requires_grad 42 | ) 43 | 44 | hparams["data"] = cfg["data"] 45 | hparams["trainer"] = cfg["trainer"] 46 | 47 | hparams["callbacks"] = cfg.get("callbacks") 48 | hparams["extras"] = cfg.get("extras") 49 | 50 | hparams["task_name"] = cfg.get("task_name") 51 | hparams["tags"] = cfg.get("tags") 52 | hparams["ckpt_path"] = cfg.get("ckpt_path") 53 | hparams["seed"] = cfg.get("seed") 54 | 55 | # send hparams to all loggers 56 | for logger in trainer.loggers: 57 | logger.log_hyperparams(hparams) 58 | -------------------------------------------------------------------------------- /src/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Mapping, Optional 3 | 4 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only 5 | 6 | 7 | class RankedLogger(logging.LoggerAdapter): 8 | """A multi-GPU-friendly python command line logger.""" 9 | 10 | def __init__( 11 | self, 12 | name: str = __name__, 13 | rank_zero_only: bool = False, 14 | extra: Optional[Mapping[str, object]] = None, 15 | ) -> None: 16 | """Initializes a multi-GPU-friendly python command line logger that logs on all processes 17 | with their rank prefixed in the log message. 18 | 19 | :param name: The name of the logger. Default is ``__name__``. 20 | :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. 21 | :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. 
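Example, as used throughout this project: `log = RankedLogger(__name__, rank_zero_only=True)`; messages logged through it are prefixed with the calling process's rank and, with `rank_zero_only=True`, are emitted on rank zero only.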
22 | """ 23 | logger = logging.getLogger(name) 24 | super().__init__(logger=logger, extra=extra) 25 | self.rank_zero_only = rank_zero_only 26 | 27 | def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None: 28 | """Delegate a log call to the underlying logger, after prefixing its message with the rank 29 | of the process it's being logged from. If `'rank'` is provided, then the log will only 30 | occur on that rank/process. 31 | 32 | :param level: The level to log at. Look at `logging.__init__.py` for more information. 33 | :param msg: The message to log. 34 | :param rank: The rank to log at. 35 | :param args: Additional args to pass to the underlying logging function. 36 | :param kwargs: Any additional keyword args to pass to the underlying logging function. 37 | """ 38 | if self.isEnabledFor(level): 39 | msg, kwargs = self.process(msg, kwargs) 40 | current_rank = getattr(rank_zero_only, "rank", None) 41 | if current_rank is None: 42 | raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") 43 | msg = rank_prefixed_message(msg, current_rank) 44 | if self.rank_zero_only: 45 | if current_rank == 0: 46 | self.logger.log(level, msg, *args, **kwargs) 47 | else: 48 | if rank is None: 49 | self.logger.log(level, msg, *args, **kwargs) 50 | elif current_rank == rank: 51 | self.logger.log(level, msg, *args, **kwargs) 52 | -------------------------------------------------------------------------------- /src/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning_utilities.core.rank_zero import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from src.utils import pylogger 13 | 14 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "data", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints the contents of a DictConfig as a tree structure using the Rich library. 33 | 34 | :param cfg: A DictConfig composed by Hydra. 35 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model", 36 | "callbacks", "logger", "trainer", "paths", "extras")``. 37 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 38 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. 39 | """ 40 | style = "dim" 41 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 42 | 43 | queue = [] 44 | 45 | # add fields from `print_order` to queue 46 | for field in print_order: 47 | queue.append(field) if field in cfg else log.warning( 48 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 
49 | ) 50 | 51 | # add all the other fields to queue (not specified in `print_order`) 52 | for field in cfg: 53 | if field not in queue: 54 | queue.append(field) 55 | 56 | # generate config tree from queue 57 | for field in queue: 58 | branch = tree.add(field, style=style, guide_style=style) 59 | 60 | config_group = cfg[field] 61 | if isinstance(config_group, DictConfig): 62 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 63 | else: 64 | branch_content = str(config_group) 65 | 66 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 67 | 68 | # print config tree 69 | rich.print(tree) 70 | 71 | # save config tree to file 72 | if save_to_file: 73 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 74 | rich.print(tree, file=file) 75 | 76 | 77 | @rank_zero_only 78 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 79 | """Prompts user to input tags from command line if no tags are provided in config. 80 | 81 | :param cfg: A DictConfig composed by Hydra. 82 | :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 83 | """ 84 | if not cfg.get("tags"): 85 | if "id" in HydraConfig().cfg.hydra.job: 86 | raise ValueError("Specify tags before launching a multirun!") 87 | 88 | log.warning("No tags provided in config. Prompting user to input tags...") 89 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 90 | tags = [t.strip() for t in tags.split(",") if t != ""] 91 | 92 | with open_dict(cfg): 93 | cfg.tags = tags 94 | 95 | log.info(f"Tags: {cfg.tags}") 96 | 97 | if save_to_file: 98 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 99 | rich.print(cfg.tags, file=file) 100 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from importlib.util import find_spec 3 | from typing import Any, Callable, Dict, Optional, Tuple 4 | 5 | from omegaconf import DictConfig 6 | 7 | from src.utils import pylogger, rich_utils 8 | 9 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 10 | 11 | 12 | def extras(cfg: DictConfig) -> None: 13 | """Applies optional utilities before the task is started. 14 | 15 | Utilities: 16 | - Ignoring python warnings 17 | - Setting tags from command line 18 | - Rich config printing 19 | 20 | :param cfg: A DictConfig object containing the config tree. 21 | """ 22 | # return if no `extras` config 23 | if not cfg.get("extras"): 24 | log.warning("Extras config not found! ") 25 | return 26 | 27 | # disable python warnings 28 | if cfg.extras.get("ignore_warnings"): 29 | log.info("Disabling python warnings! ") 30 | warnings.filterwarnings("ignore") 31 | 32 | # prompt user to input tags from command line if none are provided in the config 33 | if cfg.extras.get("enforce_tags"): 34 | log.info("Enforcing tags! ") 35 | rich_utils.enforce_tags(cfg, save_to_file=True) 36 | 37 | # pretty print config tree using Rich library 38 | if cfg.extras.get("print_config"): 39 | log.info("Printing config tree with Rich! ") 40 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) 41 | 42 | 43 | def task_wrapper(task_func: Callable) -> Callable: 44 | """Optional decorator that controls the failure behavior when executing the task function. 
45 | 46 | This wrapper can be used to: 47 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) 48 | - save the exception to a `.log` file 49 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) 50 | - etc. (adjust depending on your needs) 51 | 52 | Example: 53 | ``` 54 | @utils.task_wrapper 55 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 56 | ... 57 | return metric_dict, object_dict 58 | ``` 59 | 60 | :param task_func: The task function to be wrapped. 61 | 62 | :return: The wrapped task function. 63 | """ 64 | 65 | def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 66 | # execute the task 67 | try: 68 | metric_dict, object_dict = task_func(cfg=cfg) 69 | 70 | # things to do if exception occurs 71 | except Exception as ex: 72 | # save exception to `.log` file 73 | log.exception("") 74 | 75 | # some hyperparameter combinations might be invalid or cause out-of-memory errors 76 | # so when using hparam search plugins like Optuna, you might want to disable 77 | # raising the below exception to avoid multirun failure 78 | raise ex 79 | 80 | # things to always do after either success or exception 81 | finally: 82 | # display output dir path in terminal 83 | log.info(f"Output dir: {cfg.paths.output_dir}") 84 | 85 | # always close wandb run (even if exception occurs so multirun won't fail) 86 | if find_spec("wandb"): # check if wandb is installed 87 | import wandb 88 | 89 | if wandb.run: 90 | log.info("Closing wandb!") 91 | wandb.finish() 92 | 93 | return metric_dict, object_dict 94 | 95 | return wrap 96 | 97 | 98 | def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]: 99 | """Safely retrieves value of the metric logged in LightningModule. 100 | 101 | :param metric_dict: A dict containing metric values. 102 | :param metric_name: If provided, the name of the metric to retrieve. 103 | :return: If a metric name was provided, the value of the metric. 104 | """ 105 | if not metric_name: 106 | log.info("Metric name is None! Skipping metric value retrieval...") 107 | return None 108 | 109 | if metric_name not in metric_dict: 110 | raise Exception( 111 | f"Metric value not found! \n" 112 | "Make sure metric name logged in LightningModule is correct!\n" 113 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 114 | ) 115 | 116 | metric_value = metric_dict[metric_name].item() 117 | log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") 118 | 119 | return metric_value 120 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olsdavis/fisher-flow/8102f853e3b4c0f29f3a700959ec86e426fa86e8/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """This file prepares config fixtures for other tests.""" 2 | 3 | from pathlib import Path 4 | 5 | import pytest 6 | import rootutils 7 | from hydra import compose, initialize 8 | from hydra.core.global_hydra import GlobalHydra 9 | from omegaconf import DictConfig, open_dict 10 | 11 | 12 | @pytest.fixture(scope="package") 13 | def cfg_train_global() -> DictConfig: 14 | """A pytest fixture for setting up a default Hydra DictConfig for training. 15 | 16 | :return: A DictConfig object containing a default Hydra configuration for training. 17 | """ 18 | with initialize(version_base="1.3", config_path="../configs"): 19 | cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[]) 20 | 21 | # set defaults for all tests 22 | with open_dict(cfg): 23 | cfg.paths.root_dir = str(rootutils.find_root(indicator=".project-root")) 24 | cfg.trainer.max_epochs = 1 25 | cfg.trainer.limit_train_batches = 0.01 26 | cfg.trainer.limit_val_batches = 0.1 27 | cfg.trainer.limit_test_batches = 0.1 28 | cfg.trainer.accelerator = "cpu" 29 | cfg.trainer.devices = 1 30 | cfg.data.num_workers = 0 31 | cfg.data.pin_memory = False 32 | cfg.extras.print_config = False 33 | cfg.extras.enforce_tags = False 34 | cfg.logger = None 35 | 36 | return cfg 37 | 38 | 39 | @pytest.fixture(scope="package") 40 | def cfg_eval_global() -> DictConfig: 41 | """A pytest fixture for setting up a default Hydra DictConfig for evaluation. 42 | 43 | :return: A DictConfig containing a default Hydra configuration for evaluation. 44 | """ 45 | with initialize(version_base="1.3", config_path="../configs"): 46 | cfg = compose(config_name="eval.yaml", return_hydra_config=True, overrides=["ckpt_path=."]) 47 | 48 | # set defaults for all tests 49 | with open_dict(cfg): 50 | cfg.paths.root_dir = str(rootutils.find_root(indicator=".project-root")) 51 | cfg.trainer.max_epochs = 1 52 | cfg.trainer.limit_test_batches = 0.1 53 | cfg.trainer.accelerator = "cpu" 54 | cfg.trainer.devices = 1 55 | cfg.data.num_workers = 0 56 | cfg.data.pin_memory = False 57 | cfg.extras.print_config = False 58 | cfg.extras.enforce_tags = False 59 | cfg.logger = None 60 | 61 | return cfg 62 | 63 | 64 | @pytest.fixture(scope="function") 65 | def cfg_train(cfg_train_global: DictConfig, tmp_path: Path) -> DictConfig: 66 | """A pytest fixture built on top of the `cfg_train_global()` fixture, which accepts a temporary 67 | logging path `tmp_path` for generating a temporary logging path. 68 | 69 | This is called by each test which uses the `cfg_train` arg. Each test generates its own temporary logging path. 70 | 71 | :param cfg_train_global: The input DictConfig object to be modified. 72 | :param tmp_path: The temporary logging path. 73 | 74 | :return: A DictConfig with updated output and log directories corresponding to `tmp_path`. 
75 | """ 76 | cfg = cfg_train_global.copy() 77 | 78 | with open_dict(cfg): 79 | cfg.paths.output_dir = str(tmp_path) 80 | cfg.paths.log_dir = str(tmp_path) 81 | 82 | yield cfg 83 | 84 | GlobalHydra.instance().clear() 85 | 86 | 87 | @pytest.fixture(scope="function") 88 | def cfg_eval(cfg_eval_global: DictConfig, tmp_path: Path) -> DictConfig: 89 | """A pytest fixture built on top of the `cfg_eval_global()` fixture, which accepts a temporary 90 | logging path `tmp_path` for generating a temporary logging path. 91 | 92 | This is called by each test which uses the `cfg_eval` arg. Each test generates its own temporary logging path. 93 | 94 | :param cfg_train_global: The input DictConfig object to be modified. 95 | :param tmp_path: The temporary logging path. 96 | 97 | :return: A DictConfig with updated output and log directories corresponding to `tmp_path`. 98 | """ 99 | cfg = cfg_eval_global.copy() 100 | 101 | with open_dict(cfg): 102 | cfg.paths.output_dir = str(tmp_path) 103 | cfg.paths.log_dir = str(tmp_path) 104 | 105 | yield cfg 106 | 107 | GlobalHydra.instance().clear() 108 | -------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olsdavis/fisher-flow/8102f853e3b4c0f29f3a700959ec86e426fa86e8/tests/helpers/__init__.py -------------------------------------------------------------------------------- /tests/helpers/package_available.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | import pkg_resources 4 | from lightning.fabric.accelerators import TPUAccelerator 5 | 6 | 7 | def _package_available(package_name: str) -> bool: 8 | """Check if a package is available in your environment. 9 | 10 | :param package_name: The name of the package to be checked. 11 | 12 | :return: `True` if the package is available. `False` otherwise. 
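For example, `_package_available("wandb")` evaluates to `True` only when the `wandb` distribution can be resolved by `pkg_resources`.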
13 | """ 14 | try: 15 | return pkg_resources.require(package_name) is not None 16 | except pkg_resources.DistributionNotFound: 17 | return False 18 | 19 | 20 | _TPU_AVAILABLE = TPUAccelerator.is_available() 21 | 22 | _IS_WINDOWS = platform.system() == "Windows" 23 | 24 | _SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh") 25 | 26 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed") 27 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale") 28 | 29 | _WANDB_AVAILABLE = _package_available("wandb") 30 | _NEPTUNE_AVAILABLE = _package_available("neptune") 31 | _COMET_AVAILABLE = _package_available("comet_ml") 32 | _MLFLOW_AVAILABLE = _package_available("mlflow") 33 | -------------------------------------------------------------------------------- /tests/helpers/run_if.py: -------------------------------------------------------------------------------- 1 | """Adapted from: 2 | 3 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py 4 | """ 5 | 6 | import sys 7 | from typing import Any, Dict, Optional 8 | 9 | import pytest 10 | import torch 11 | from packaging.version import Version 12 | from pkg_resources import get_distribution 13 | from pytest import MarkDecorator 14 | 15 | from tests.helpers.package_available import ( 16 | _COMET_AVAILABLE, 17 | _DEEPSPEED_AVAILABLE, 18 | _FAIRSCALE_AVAILABLE, 19 | _IS_WINDOWS, 20 | _MLFLOW_AVAILABLE, 21 | _NEPTUNE_AVAILABLE, 22 | _SH_AVAILABLE, 23 | _TPU_AVAILABLE, 24 | _WANDB_AVAILABLE, 25 | ) 26 | 27 | 28 | class RunIf: 29 | """RunIf wrapper for conditional skipping of tests. 30 | 31 | Fully compatible with `@pytest.mark`. 32 | 33 | Example: 34 | 35 | ```python 36 | @RunIf(min_torch="1.8") 37 | @pytest.mark.parametrize("arg1", [1.0, 2.0]) 38 | def test_wrapper(arg1): 39 | assert arg1 > 0 40 | ``` 41 | """ 42 | 43 | def __new__( 44 | cls, 45 | min_gpus: int = 0, 46 | min_torch: Optional[str] = None, 47 | max_torch: Optional[str] = None, 48 | min_python: Optional[str] = None, 49 | skip_windows: bool = False, 50 | sh: bool = False, 51 | tpu: bool = False, 52 | fairscale: bool = False, 53 | deepspeed: bool = False, 54 | wandb: bool = False, 55 | neptune: bool = False, 56 | comet: bool = False, 57 | mlflow: bool = False, 58 | **kwargs: Dict[Any, Any], 59 | ) -> MarkDecorator: 60 | """Creates a new `@RunIf` `MarkDecorator` decorator. 61 | 62 | :param min_gpus: Min number of GPUs required to run test. 63 | :param min_torch: Minimum pytorch version to run test. 64 | :param max_torch: Maximum pytorch version to run test. 65 | :param min_python: Minimum python version required to run test. 66 | :param skip_windows: Skip test for Windows platform. 67 | :param tpu: If TPU is available. 68 | :param sh: If `sh` module is required to run the test. 69 | :param fairscale: If `fairscale` module is required to run the test. 70 | :param deepspeed: If `deepspeed` module is required to run the test. 71 | :param wandb: If `wandb` module is required to run the test. 72 | :param neptune: If `neptune` module is required to run the test. 73 | :param comet: If `comet` module is required to run the test. 74 | :param mlflow: If `mlflow` module is required to run the test. 75 | :param kwargs: Native `pytest.mark.skipif` keyword arguments. 
76 | """ 77 | conditions = [] 78 | reasons = [] 79 | 80 | if min_gpus: 81 | conditions.append(torch.cuda.device_count() < min_gpus) 82 | reasons.append(f"GPUs>={min_gpus}") 83 | 84 | if min_torch: 85 | torch_version = get_distribution("torch").version 86 | conditions.append(Version(torch_version) < Version(min_torch)) 87 | reasons.append(f"torch>={min_torch}") 88 | 89 | if max_torch: 90 | torch_version = get_distribution("torch").version 91 | conditions.append(Version(torch_version) >= Version(max_torch)) 92 | reasons.append(f"torch<{max_torch}") 93 | 94 | if min_python: 95 | py_version = ( 96 | f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" 97 | ) 98 | conditions.append(Version(py_version) < Version(min_python)) 99 | reasons.append(f"python>={min_python}") 100 | 101 | if skip_windows: 102 | conditions.append(_IS_WINDOWS) 103 | reasons.append("does not run on Windows") 104 | 105 | if tpu: 106 | conditions.append(not _TPU_AVAILABLE) 107 | reasons.append("TPU") 108 | 109 | if sh: 110 | conditions.append(not _SH_AVAILABLE) 111 | reasons.append("sh") 112 | 113 | if fairscale: 114 | conditions.append(not _FAIRSCALE_AVAILABLE) 115 | reasons.append("fairscale") 116 | 117 | if deepspeed: 118 | conditions.append(not _DEEPSPEED_AVAILABLE) 119 | reasons.append("deepspeed") 120 | 121 | if wandb: 122 | conditions.append(not _WANDB_AVAILABLE) 123 | reasons.append("wandb") 124 | 125 | if neptune: 126 | conditions.append(not _NEPTUNE_AVAILABLE) 127 | reasons.append("neptune") 128 | 129 | if comet: 130 | conditions.append(not _COMET_AVAILABLE) 131 | reasons.append("comet") 132 | 133 | if mlflow: 134 | conditions.append(not _MLFLOW_AVAILABLE) 135 | reasons.append("mlflow") 136 | 137 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond] 138 | return pytest.mark.skipif( 139 | condition=any(conditions), 140 | reason=f"Requires: [{' + '.join(reasons)}]", 141 | **kwargs, 142 | ) 143 | -------------------------------------------------------------------------------- /tests/helpers/run_sh_command.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | 5 | from tests.helpers.package_available import _SH_AVAILABLE 6 | 7 | if _SH_AVAILABLE: 8 | import sh 9 | 10 | 11 | def run_sh_command(command: List[str]) -> None: 12 | """Default method for executing shell commands with `pytest` and `sh` package. 13 | 14 | :param command: A list of shell commands as strings. 15 | """ 16 | msg = None 17 | try: 18 | sh.python(command) 19 | except sh.ErrorReturnCode as e: 20 | msg = e.stderr.decode() 21 | if msg: 22 | pytest.fail(msg=msg) 23 | -------------------------------------------------------------------------------- /tests/test_configs.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.core.hydra_config import HydraConfig 3 | from omegaconf import DictConfig 4 | 5 | 6 | def test_train_config(cfg_train: DictConfig) -> None: 7 | """Tests the training configuration provided by the `cfg_train` pytest fixture. 8 | 9 | :param cfg_train: A DictConfig containing a valid training configuration. 
10 | """ 11 | assert cfg_train 12 | assert cfg_train.data 13 | assert cfg_train.model 14 | assert cfg_train.trainer 15 | 16 | HydraConfig().set_config(cfg_train) 17 | 18 | hydra.utils.instantiate(cfg_train.data) 19 | hydra.utils.instantiate(cfg_train.model) 20 | hydra.utils.instantiate(cfg_train.trainer) 21 | 22 | 23 | def test_eval_config(cfg_eval: DictConfig) -> None: 24 | """Tests the evaluation configuration provided by the `cfg_eval` pytest fixture. 25 | 26 | :param cfg_train: A DictConfig containing a valid evaluation configuration. 27 | """ 28 | assert cfg_eval 29 | assert cfg_eval.data 30 | assert cfg_eval.model 31 | assert cfg_eval.trainer 32 | 33 | HydraConfig().set_config(cfg_eval) 34 | 35 | hydra.utils.instantiate(cfg_eval.data) 36 | hydra.utils.instantiate(cfg_eval.model) 37 | hydra.utils.instantiate(cfg_eval.trainer) 38 | -------------------------------------------------------------------------------- /tests/test_datamodules.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | 6 | from src.data.mnist_datamodule import MNISTDataModule 7 | 8 | 9 | @pytest.mark.parametrize("batch_size", [32, 128]) 10 | def test_mnist_datamodule(batch_size: int) -> None: 11 | """Tests `MNISTDataModule` to verify that it can be downloaded correctly, that the necessary 12 | attributes were created (e.g., the dataloader objects), and that dtypes and batch sizes 13 | correctly match. 14 | 15 | :param batch_size: Batch size of the data to be loaded by the dataloader. 16 | """ 17 | data_dir = "data/" 18 | 19 | dm = MNISTDataModule(data_dir=data_dir, batch_size=batch_size) 20 | dm.prepare_data() 21 | 22 | assert not dm.data_train and not dm.data_val and not dm.data_test 23 | assert Path(data_dir, "MNIST").exists() 24 | assert Path(data_dir, "MNIST", "raw").exists() 25 | 26 | dm.setup() 27 | assert dm.data_train and dm.data_val and dm.data_test 28 | assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader() 29 | 30 | num_datapoints = len(dm.data_train) + len(dm.data_val) + len(dm.data_test) 31 | assert num_datapoints == 70_000 32 | 33 | batch = next(iter(dm.train_dataloader())) 34 | x, y = batch 35 | assert len(x) == batch_size 36 | assert len(y) == batch_size 37 | assert x.dtype == torch.float32 38 | assert y.dtype == torch.int64 39 | -------------------------------------------------------------------------------- /tests/test_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import pytest 5 | from hydra.core.hydra_config import HydraConfig 6 | from omegaconf import DictConfig, open_dict 7 | 8 | from src.eval import evaluate 9 | from src.train import train 10 | 11 | 12 | @pytest.mark.slow 13 | def test_train_eval(tmp_path: Path, cfg_train: DictConfig, cfg_eval: DictConfig) -> None: 14 | """Tests training and evaluation by training for 1 epoch with `train.py` then evaluating with 15 | `eval.py`. 16 | 17 | :param tmp_path: The temporary logging path. 18 | :param cfg_train: A DictConfig containing a valid training configuration. 19 | :param cfg_eval: A DictConfig containing a valid evaluation configuration. 
20 | """ 21 | assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir 22 |  23 | with open_dict(cfg_train): 24 | cfg_train.trainer.max_epochs = 1 25 | cfg_train.test = True 26 |  27 | HydraConfig().set_config(cfg_train) 28 | train_metric_dict, _ = train(cfg_train) 29 |  30 | assert "last.ckpt" in os.listdir(tmp_path / "checkpoints") 31 |  32 | with open_dict(cfg_eval): 33 | cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") 34 |  35 | HydraConfig().set_config(cfg_eval) 36 | test_metric_dict, _ = evaluate(cfg_eval) 37 |  38 | assert test_metric_dict["test/acc"] > 0.0 39 | assert abs(train_metric_dict["test/acc"].item() - test_metric_dict["test/acc"].item()) < 0.001 40 | -------------------------------------------------------------------------------- /tests/test_maths.py: -------------------------------------------------------------------------------- 1 | """Testing maths functions.""" 2 | import unittest 3 | import torch 4 | from src.sfm import safe_arccos, usinc 5 |  6 |  7 | class TestUSinc(unittest.TestCase): 8 | """Tests usinc.""" 9 |  10 | @torch.no_grad() 11 | def test_usinc(self): 12 | """Tests whether the usinc function is not ill-defined.""" 13 | self.assertTrue( 14 | torch.allclose(usinc(torch.Tensor([0.0])), torch.Tensor([1.0]), atol=1e-7), 15 | "sinc(0) = 1", 16 | ) 17 | self.assertTrue( 18 | torch.allclose(usinc(torch.Tensor([torch.pi])), torch.Tensor([0.0]), atol=1e-7), 19 | "sinc(pi) = 0", 20 | ) 21 | self.assertTrue( 22 | torch.allclose(usinc(torch.Tensor([torch.pi / 2.0])), torch.Tensor([2.0 / torch.pi]), atol=1e-7), 23 | "sinc(pi / 2) = 2 / pi", 24 | ) 25 |  26 |  27 | class TestSafeArccos(unittest.TestCase): 28 | """Tests arccos.""" 29 |  30 | @torch.no_grad() 31 | def test_safe_arccos(self): 32 | """Tests whether the safe arccos function is not ill-defined.""" 33 | self.assertTrue( 34 | torch.allclose(safe_arccos(torch.tensor(1.0)), torch.tensor(0.0)), 35 | "arccos(1) = 0", 36 | ) 37 | self.assertTrue( 38 | torch.allclose(safe_arccos(torch.tensor(-1.0)), torch.tensor(torch.pi)), 39 | "arccos(-1) = pi", 40 | ) 41 | self.assertTrue( 42 | torch.allclose(safe_arccos(torch.tensor(0.0)), torch.tensor(torch.pi / 2.0)), 43 | "arccos(0) = pi / 2", 44 | ) 45 |  46 |  47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /tests/test_sweeps.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 |  3 | import pytest 4 |  5 | from tests.helpers.run_if import RunIf 6 | from tests.helpers.run_sh_command import run_sh_command 7 |  8 | startfile = "src/train.py" 9 | overrides = ["logger=[]"] 10 |  11 |  12 | @RunIf(sh=True) 13 | @pytest.mark.slow 14 | def test_experiments(tmp_path: Path) -> None: 15 | """Test running all available experiment configs with `fast_dev_run=True`. 16 |  17 | :param tmp_path: The temporary logging path. 18 | """ 19 | command = [ 20 | startfile, 21 | "-m", 22 | "experiment=glob(*)", 23 | "hydra.sweep.dir=" + str(tmp_path), 24 | "++trainer.fast_dev_run=true", 25 | ] + overrides 26 | run_sh_command(command) 27 |  28 |  29 | @RunIf(sh=True) 30 | @pytest.mark.slow 31 | def test_hydra_sweep(tmp_path: Path) -> None: 32 | """Test default hydra sweep. 33 |  34 | :param tmp_path: The temporary logging path.
35 | """ 36 | command = [ 37 | startfile, 38 | "-m", 39 | "hydra.sweep.dir=" + str(tmp_path), 40 | "model.optimizer.lr=0.005,0.01", 41 | "++trainer.fast_dev_run=true", 42 | ] + overrides 43 | 44 | run_sh_command(command) 45 | 46 | 47 | @RunIf(sh=True) 48 | @pytest.mark.slow 49 | def test_hydra_sweep_ddp_sim(tmp_path: Path) -> None: 50 | """Test default hydra sweep with ddp sim. 51 | 52 | :param tmp_path: The temporary logging path. 53 | """ 54 | command = [ 55 | startfile, 56 | "-m", 57 | "hydra.sweep.dir=" + str(tmp_path), 58 | "trainer=ddp_sim", 59 | "trainer.max_epochs=3", 60 | "+trainer.limit_train_batches=0.01", 61 | "+trainer.limit_val_batches=0.1", 62 | "+trainer.limit_test_batches=0.1", 63 | "model.optimizer.lr=0.005,0.01,0.02", 64 | ] + overrides 65 | run_sh_command(command) 66 | 67 | 68 | @RunIf(sh=True) 69 | @pytest.mark.slow 70 | def test_optuna_sweep(tmp_path: Path) -> None: 71 | """Test Optuna hyperparam sweeping. 72 | 73 | :param tmp_path: The temporary logging path. 74 | """ 75 | command = [ 76 | startfile, 77 | "-m", 78 | "hparams_search=mnist_optuna", 79 | "hydra.sweep.dir=" + str(tmp_path), 80 | "hydra.sweeper.n_trials=10", 81 | "hydra.sweeper.sampler.n_startup_trials=5", 82 | "++trainer.fast_dev_run=true", 83 | ] + overrides 84 | run_sh_command(command) 85 | 86 | 87 | @RunIf(wandb=True, sh=True) 88 | @pytest.mark.slow 89 | def test_optuna_sweep_ddp_sim_wandb(tmp_path: Path) -> None: 90 | """Test Optuna sweep with wandb logging and ddp sim. 91 | 92 | :param tmp_path: The temporary logging path. 93 | """ 94 | command = [ 95 | startfile, 96 | "-m", 97 | "hparams_search=mnist_optuna", 98 | "hydra.sweep.dir=" + str(tmp_path), 99 | "hydra.sweeper.n_trials=5", 100 | "trainer=ddp_sim", 101 | "trainer.max_epochs=3", 102 | "+trainer.limit_train_batches=0.01", 103 | "+trainer.limit_val_batches=0.1", 104 | "+trainer.limit_test_batches=0.1", 105 | "logger=wandb", 106 | ] 107 | run_sh_command(command) 108 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import pytest 5 | from hydra.core.hydra_config import HydraConfig 6 | from omegaconf import DictConfig, open_dict 7 | 8 | from src.train import train 9 | from tests.helpers.run_if import RunIf 10 | 11 | 12 | def test_train_fast_dev_run(cfg_train: DictConfig) -> None: 13 | """Run for 1 train, val and test step. 14 | 15 | :param cfg_train: A DictConfig containing a valid training configuration. 16 | """ 17 | HydraConfig().set_config(cfg_train) 18 | with open_dict(cfg_train): 19 | cfg_train.trainer.fast_dev_run = True 20 | cfg_train.trainer.accelerator = "cpu" 21 | train(cfg_train) 22 | 23 | 24 | @RunIf(min_gpus=1) 25 | def test_train_fast_dev_run_gpu(cfg_train: DictConfig) -> None: 26 | """Run for 1 train, val and test step on GPU. 27 | 28 | :param cfg_train: A DictConfig containing a valid training configuration. 29 | """ 30 | HydraConfig().set_config(cfg_train) 31 | with open_dict(cfg_train): 32 | cfg_train.trainer.fast_dev_run = True 33 | cfg_train.trainer.accelerator = "gpu" 34 | train(cfg_train) 35 | 36 | 37 | @RunIf(min_gpus=1) 38 | @pytest.mark.slow 39 | def test_train_epoch_gpu_amp(cfg_train: DictConfig) -> None: 40 | """Train 1 epoch on GPU with mixed-precision. 41 | 42 | :param cfg_train: A DictConfig containing a valid training configuration. 
43 | """ 44 | HydraConfig().set_config(cfg_train) 45 | with open_dict(cfg_train): 46 | cfg_train.trainer.max_epochs = 1 47 | cfg_train.trainer.accelerator = "gpu" 48 | cfg_train.trainer.precision = 16 49 | train(cfg_train) 50 | 51 | 52 | @pytest.mark.slow 53 | def test_train_epoch_double_val_loop(cfg_train: DictConfig) -> None: 54 | """Train 1 epoch with validation loop twice per epoch. 55 | 56 | :param cfg_train: A DictConfig containing a valid training configuration. 57 | """ 58 | HydraConfig().set_config(cfg_train) 59 | with open_dict(cfg_train): 60 | cfg_train.trainer.max_epochs = 1 61 | cfg_train.trainer.val_check_interval = 0.5 62 | train(cfg_train) 63 | 64 | 65 | @pytest.mark.slow 66 | def test_train_ddp_sim(cfg_train: DictConfig) -> None: 67 | """Simulate DDP (Distributed Data Parallel) on 2 CPU processes. 68 | 69 | :param cfg_train: A DictConfig containing a valid training configuration. 70 | """ 71 | HydraConfig().set_config(cfg_train) 72 | with open_dict(cfg_train): 73 | cfg_train.trainer.max_epochs = 2 74 | cfg_train.trainer.accelerator = "cpu" 75 | cfg_train.trainer.devices = 2 76 | cfg_train.trainer.strategy = "ddp_spawn" 77 | train(cfg_train) 78 | 79 | 80 | @pytest.mark.slow 81 | def test_train_resume(tmp_path: Path, cfg_train: DictConfig) -> None: 82 | """Run 1 epoch, finish, and resume for another epoch. 83 | 84 | :param tmp_path: The temporary logging path. 85 | :param cfg_train: A DictConfig containing a valid training configuration. 86 | """ 87 | with open_dict(cfg_train): 88 | cfg_train.trainer.max_epochs = 1 89 | 90 | HydraConfig().set_config(cfg_train) 91 | metric_dict_1, _ = train(cfg_train) 92 | 93 | files = os.listdir(tmp_path / "checkpoints") 94 | assert "last.ckpt" in files 95 | assert "epoch_000.ckpt" in files 96 | 97 | with open_dict(cfg_train): 98 | cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") 99 | cfg_train.trainer.max_epochs = 2 100 | 101 | metric_dict_2, _ = train(cfg_train) 102 | 103 | files = os.listdir(tmp_path / "checkpoints") 104 | assert "epoch_001.ckpt" in files 105 | assert "epoch_002.ckpt" not in files 106 | 107 | assert metric_dict_1["train/acc"] < metric_dict_2["train/acc"] 108 | assert metric_dict_1["val/acc"] < metric_dict_2["val/acc"] 109 | --------------------------------------------------------------------------------