├── configs ├── local │ └── .gitkeep ├── callbacks │ ├── none.yaml │ ├── rich_progress_bar.yaml │ ├── model_summary.yaml │ ├── default.yaml │ ├── early_stopping.yaml │ └── model_checkpoint.yaml ├── trainer │ ├── cpu.yaml │ ├── gpu.yaml │ ├── mps.yaml │ ├── ddp.yaml │ ├── ddp_sim.yaml │ └── default.yaml ├── __init__.py ├── debug │ ├── fdr.yaml │ ├── profiler.yaml │ ├── overfit.yaml │ ├── limit.yaml │ └── default.yaml ├── logger │ ├── many_loggers.yaml │ ├── csv.yaml │ ├── tensorboard.yaml │ ├── neptune.yaml │ ├── mlflow.yaml │ ├── comet.yaml │ ├── wandb.yaml │ └── aim.yaml ├── data │ ├── batched_GFP.yaml │ ├── batched_protein_engineering.yaml │ └── batched_proteingym.yaml ├── extras │ └── default.yaml ├── model │ └── DePLM.yaml ├── hydra │ └── default.yaml ├── paths │ └── default.yaml ├── experiment │ └── example.yaml └── DePLM.yaml ├── .project-root ├── scripts ├── run.sh └── schedule.sh ├── data ├── GFP │ ├── ppluGFP.txt │ ├── cgreGFP.txt │ ├── amacGFP.txt │ ├── avGFP.txt │ ├── avGFP_mt.txt │ └── process.py └── fluorescence │ └── wildtype.txt ├── src ├── utils │ ├── __init__.py │ ├── logging_utils.py │ ├── instantiators.py │ ├── pylogger.py │ ├── rich_utils.py │ └── utils.py ├── models │ ├── DePLM_components │ │ ├── sort.py │ │ └── modules.py │ └── DePLM_module.py ├── train.py └── data │ ├── batched_protein_engineering_datamodule.py │ ├── batched_GFP_datamodule.py │ └── batched_proteingym_substitution_datamodule.py ├── pyproject.toml ├── setup.py ├── makefile ├── README.md ├── .gitignore ├── .pre-commit-config.yaml └── environment.yml /configs/local/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete 3 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is needed here to include configs when building project as a package 2 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python src/train.py data=batched_GFP model.model=esm2 task_name='avGFP' -------------------------------------------------------------------------------- 
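For reference, `run.sh` is a thin wrapper around `src/train.py`; any value under `configs/` can be overridden on the command line via Hydra (the root config `configs/DePLM.yaml` itself notes `python train.py logger=tensorboard` as an example). The commands below are illustrative sketches, not scripts shipped with the repository; they only use config groups and fields that appear in `configs/` (data, trainer, logger, debug):

```
# fluorescence benchmark on a single GPU with TensorBoard logging
CUDA_VISIBLE_DEVICES=0 python src/train.py data=batched_protein_engineering data.task_name=fluorescence model.model=esm2 trainer=gpu logger=tensorboard task_name='fluorescence'

# quick sanity check: debug=fdr runs 1 train, 1 validation and 1 test step
python src/train.py data=batched_GFP model.model=esm2 debug=fdr
```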
/configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | strategy: ddp 5 | 6 | accelerator: gpu 7 | devices: 4 8 | num_nodes: 1 9 | sync_batchnorm: True 10 | -------------------------------------------------------------------------------- /configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet 5 | - csv 6 | # - mlflow 7 | # - neptune 8 | - tensorboard 9 | - wandb 10 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /configs/data/batched_GFP.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.batched_GFP_datamodule.BatchedGFPDataModule 2 | data_dir: ${paths.data_dir} 3 | num_workers: 0 4 | pin_memory: False 5 | 6 | task_name: avGFP 7 | support_name: [] -------------------------------------------------------------------------------- /configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | -------------------------------------------------------------------------------- /data/GFP/ppluGFP.txt: -------------------------------------------------------------------------------- 1 | MPAMKIECRITGTLNGVEFELVGGGEGTPEQGRMTNKMKSTKGALTFSPYLLSHVMGYGFYHFGTYPSGYENPFLHAINNGGYTNTRIEKYEDGGVLHVSFSYRYEAGRVIGDFKVVGTGFPEDSVIFTDKIIRSNATVEHLHPMGDNVLVGSFARTFSLRDGGYYSFVVDSHMHFKSAIHPSILQNGGPMFAFRRVEELHSNTELGIVEYQHAFKTPIAFA -------------------------------------------------------------------------------- /data/GFP/cgreGFP.txt: -------------------------------------------------------------------------------- 1 | MTALTEGAKLFEKEIPYITELEGDVEGMKFIIKGEGTGDATTGTIKAKYICTTGDLPVPWATILSSLSYGVFCFAKYPRHIADFFKSTQPDGYSQDRIISFDNDGQYDVKAKVTYENGTLYNRVTVKGTGFKSNGNILGMRVLYHSPPHAVYILPDRKNGGMKIEYNKAFDVMGGGHQMARHAQFNKPLGAWEEDYPLYHHLTVWTSFGKDPDDDETDHLTIVEVIKAVDLETYR -------------------------------------------------------------------------------- /data/GFP/amacGFP.txt: -------------------------------------------------------------------------------- 1 | 
MSKGEELFTGIVPVLIELDGDVHGHKFSVRGEGEGDADYGKLEIKFICTTGKLPVPWPTLVTTLSYGILCFARYPEHMKMNDFFKSAMPEGYIQERTIFFQDDGKYKTRGEVKFEGDTLVNRIELKGMDFKEDGNILGHKLEYNFNSHNVYIMPDKANNGLKVNFKIRHNIEGGGVQLADHYQTNVPLGDGPVLIPINHYLSCQTAISKDRNETRDHMVFLEFFSACGHTHGMDELYK 2 | -------------------------------------------------------------------------------- /data/GFP/avGFP.txt: -------------------------------------------------------------------------------- 1 | MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK 2 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: "simple" 11 | # profiler: "advanced" 12 | # profiler: "pytorch" 13 | -------------------------------------------------------------------------------- /data/GFP/avGFP_mt.txt: -------------------------------------------------------------------------------- 1 | MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDASYGRLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGSYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK 2 | -------------------------------------------------------------------------------- /configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /configs/data/batched_protein_engineering.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.batched_protein_engineering_datamodule.BatchedProteinEngineeringDataModule 2 | data_dir: ${paths.data_dir} 3 | batch_size: 1 4 | num_workers: 0 5 | pin_memory: False 6 | 7 | task_name: fluorescence # beta_lactamase | gb1 | fluorescence -------------------------------------------------------------------------------- /configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # 
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /configs/data/batched_proteingym.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.batched_proteingym_substitution_datamodule.BatchedProteinGymSubstitutionDataModule 2 | data_dir: ${paths.data_dir} 3 | batch_size: 1 4 | num_workers: 0 5 | pin_memory: False 6 | 7 | assay_index: 46 8 | split_type: random 9 | split_index: 0 10 | support_assay_num: 40 -------------------------------------------------------------------------------- /scripts/schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Schedule execution of many runs 3 | # Run from root folder with: bash scripts/schedule.sh 4 | 5 | CUDA_VISIBLE_DEVICES=0 python src/train.py data=batched_proteingym data.assay_index=196 data.split_index=0 data.split_type=random task_name=test model.model=esm2 model.diff_total_step=3 debug=default -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: lightning.pytorch.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from src.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from src.utils.logging_utils import log_hyperparameters 3 | from src.utils.pylogger import RankedLogger 4 | from src.utils.rich_utils import enforce_tags, print_config_tree 5 | from src.utils.utils import extras, get_metric_value, task_wrapper 6 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | 
comet: 4 | _target_: lightning.pytorch.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /configs/model/DePLM.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.DePLM_module.DePLM4ProteinEngineeringModule 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: True 6 | lr: 1e-4 7 | weight_decay: 0.005 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 11 | _partial_: True 12 | mode: min 13 | factor: 0.9 14 | patience: 10 15 | 16 | diff_total_step: 3 17 | model: esm2 18 | 19 | # compile model for faster training with pytorch 2.0 20 | compile: false 21 | -------------------------------------------------------------------------------- /data/fluorescence/wildtype.txt: -------------------------------------------------------------------------------- 1 | SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK 2 | 3 | SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFIRTTSKLPVP-PTLVTTLSYGVQCFSRYP-YHMKQHDFFKPAMPEGYVQERTIFFKDDGNCKTRAEVKFEGDTLVSRIELKGIDFKEDGNILGHLEYSYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPELLPNNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint 3 | - early_stopping 4 | - model_summary 5 | - rich_progress_bar 6 | - _self_ 7 | 8 | model_checkpoint: 9 | dirpath: ${paths.output_dir}/checkpoints 10 | filename: "epoch_{epoch:03d}" 11 | monitor: "val/spearman" 12 | mode: "max" 13 | auto_insert_metric_name: False 14 | every_n_epochs: 50 15 | save_top_k: 0 16 | 17 | early_stopping: 18 | monitor: "val/spearman" 19 | patience: 10000 20 | mode: "max" 21 | 22 | model_summary: 23 | max_depth: -1 24 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 10 # prevents early stopping 6 | max_epochs: 100 7 | 8 | accelerator: gpu 9 | devices: 1 10 | 11 | # mixed precision for extra speed-up [32|16-mixed] 12 | precision: 32 13 | 14 | # perform a validation loop every N training epochs 15 | check_val_every_n_epoch: 1 16 | 17 | # set True to to ensure deterministic results 18 | # makes training slower but gives more reproducibility than just setting seeds 19 | deterministic: False 20 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | addopts = [ 3 | "--color=yes", 4 | "--durations=0", 5 | "--strict-markers", 6 | "--doctest-modules", 7 | ] 8 | filterwarnings = [ 9 | "ignore::DeprecationWarning", 10 | "ignore::UserWarning", 11 | ] 12 | log_cli 
= "True" 13 | markers = [ 14 | "slow: slow tests", 15 | ] 16 | minversion = "6.0" 17 | testpaths = "tests/" 18 | 19 | [tool.coverage.report] 20 | exclude_lines = [ 21 | "pragma: nocover", 22 | "raise NotImplementedError", 23 | "raise NotImplementedError()", 24 | "if __name__ == .__main__.:", 25 | ] 26 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "lightning-hydra-template" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | setup( 6 | name="src", 7 | version="0.0.1", 8 | description="Describe Your Cool Project", 9 | author="", 10 | author_email="", 11 | url="https://github.com/user/project", 12 | install_requires=["lightning", "hydra-core"], 13 | packages=find_packages(), 14 | # use this to customize global commands available in the terminal after installing the package 15 | entry_points={ 16 | "console_scripts": [ 17 | "train_command = src.train:main", 18 | "eval_command = src.eval:main", 19 | ] 20 | }, 21 | ) -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | 15 | job_logging: 16 | handlers: 17 | file: 18 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 19 | filename: ${hydra.runtime.output_dir}/${task_name}.log 20 | -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." 
if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | 2 | help: ## Show help 3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 4 | 5 | clean: ## Clean autogenerated files 6 | rm -rf dist 7 | find . -type f -name "*.DS_Store" -ls -delete 8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 9 | find . | grep -E ".pytest_cache" | xargs rm -rf 10 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf 11 | rm -f .coverage 12 | 13 | clean-logs: ## Clean logs 14 | rm -rf logs/** 15 | 16 | format: ## Run pre-commit hooks 17 | pre-commit run -a 18 | 19 | sync: ## Merge changes from main branch to your current branch 20 | git pull 21 | git pull origin main 22 | 23 | test: ## Run not slow tests 24 | pytest -k "not slow" 25 | 26 | test-full: ## Run all tests 27 | pytest 28 | 29 | train: ## Train the model 30 | python src/train.py -------------------------------------------------------------------------------- /configs/experiment/example.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: mnist 8 | - override /model: mnist 9 | - override /callbacks: default 10 | - override /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["mnist", "simple_dense_net"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 10 21 | max_epochs: 10 22 | gradient_clip_val: 0.5 23 | 24 | model: 25 | optimizer: 26 | lr: 0.002 27 | net: 28 | lin1_size: 128 29 | lin2_size: 256 30 | lin3_size: 64 31 | compile: false 32 | 33 | data: 34 | batch_size: 64 35 | 36 | logger: 37 | wandb: 38 | tags: ${tags} 39 | group: "mnist" 40 | aim: 41 | experiment: "mnist" 42 | -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | callbacks: null 11 | logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 
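# the trainer and data overrides below keep debug runs short and single-process;
# switch accelerator to cpu if an interactive debugger does not cope well with CUDA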
27 | trainer: 28 | max_epochs: 1 29 | accelerator: gpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | data: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html 2 | 3 | early_stopping: 4 | _target_: lightning.pytorch.callbacks.EarlyStopping 5 | monitor: ??? # quantity to be monitored, must be specified !!! 6 | min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement 7 | patience: 3 # number of checks with no improvement after which training will be stopped 8 | verbose: False # verbosity mode 9 | mode: "min" # "max" means higher metric value is better, can be also "min" 10 | strict: True # whether to crash the training if monitor is not found in the validation metrics 11 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 12 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 13 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 14 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch 15 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 16 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | model_checkpoint: 4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 5 | dirpath: null # directory to save the model file 6 | filename: null # checkpoint filename 7 | monitor: null # name of the logged metric which determines when model is improving 8 | verbose: False # verbosity mode 9 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt 10 | save_top_k: 1 # save k best models (determined by above metric) 11 | mode: "min" # "max" means higher metric value is better, can be also "min" 12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 13 | save_weights_only: False # if True, then only the model’s weights will be saved 14 | every_n_train_steps: null # number of training steps between checkpoints 15 | train_time_interval: null # checkpoints are monitored at the specified time interval 16 | every_n_epochs: null # number of epochs between checkpoints 17 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 18 | -------------------------------------------------------------------------------- /configs/logger/aim.yaml: -------------------------------------------------------------------------------- 1 | # https://aimstack.io/ 2 | 3 | # example usage in lightning module: 4 | # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py 5 | 6 | # open the Aim UI with the following command (run in the folder containing the 
`.aim` folder): 7 | # `aim up` 8 | 9 | aim: 10 | _target_: aim.pytorch_lightning.AimLogger 11 | repo: ${paths.root_dir} # .aim folder will be created here 12 | # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# 13 | 14 | # aim allows to group runs under experiment name 15 | experiment: null # any string, set to "default" if not specified 16 | 17 | train_metric_prefix: "train/" 18 | val_metric_prefix: "val/" 19 | test_metric_prefix: "test/" 20 | 21 | # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) 22 | system_tracking_interval: 10 # set to null to disable system metrics tracking 23 | 24 | # enable/disable logging of system params such as installed packages, git info, env vars, etc. 25 | log_system_params: true 26 | 27 | # enable/disable tracking console logs (default value is true) 28 | capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 29 | -------------------------------------------------------------------------------- /data/GFP/process.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | protein_name = 'avGFP' 4 | 5 | with open(f'{protein_name}.txt', 'r') as f_in: 6 | sequence = f_in.readline().strip() 7 | 8 | scores = [] 9 | with open(f'{protein_name}_score.txt', 'r') as f_in: 10 | for line in f_in: 11 | scores.append(float(line.strip())) 12 | 13 | mutants = [] 14 | with open(f'{protein_name}_mutant.txt', 'r') as f_in: 15 | for line in f_in: 16 | mutants.append(line.strip()) 17 | 18 | with open(f'{protein_name}.csv', 'w') as f_out: 19 | f_out.write('mutant,mutated_sequence,score,split\n') 20 | for idx, (score, mutant) in enumerate(zip(scores, mutants)): 21 | mutant_sequence = copy.deepcopy(sequence) 22 | output_mutant = [] 23 | if '*' in mutant or '.' in mutant or 'WT' in mutant: 24 | continue 25 | for m in mutant.split(':'): 26 | wt_aa, mt_aa, pos = m[0], m[-1], int(m[1:-1]) 27 | if mutant_sequence[pos] == wt_aa: 28 | mutant_sequence = mutant_sequence[:pos] + mt_aa + mutant_sequence[pos+1:] 29 | else: 30 | import ipdb; ipdb.set_trace() 31 | output_mutant.append(wt_aa + str(pos+1) + mt_aa) 32 | if idx % 10 == 0: 33 | f_out.write(f'{":".join(output_mutant)},{mutant_sequence},{score},{2}\n') 34 | else: 35 | f_out.write(f'{":".join(output_mutant)},{mutant_sequence},{score},{0}\n') 36 | 37 | -------------------------------------------------------------------------------- /configs/DePLM.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: batched_protein_engineering 8 | - model: DePLM 9 | - callbacks: default 10 | - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default 12 | - paths: default 13 | - extras: default 14 | - hydra: default 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. 
best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default 26 | 27 | # debugging config (enable through command line, e.g. `python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "wildtype" 32 | 33 | # tags to help you identify your experiments 34 | # you can overwrite this in experiment configs 35 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 36 | tags: ["dev"] 37 | 38 | # set False to skip model training 39 | train: True 40 | 41 | # evaluate on test set, using best model weights achieved during training 42 | # lightning chooses best weights based on the metric specified in checkpoint callback 43 | test: False 44 | 45 | # simply provide checkpoint path to resume training 46 | ckpt_path: null 47 | 48 | # seed for random number generators in pytorch, numpy and python.random 49 | seed: 2024 50 | -------------------------------------------------------------------------------- /src/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from lightning_utilities.core.rank_zero import rank_zero_only 4 | from omegaconf import OmegaConf 5 | 6 | from src.utils import pylogger 7 | 8 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 9 | 10 | 11 | @rank_zero_only 12 | def log_hyperparameters(object_dict: Dict[str, Any]) -> None: 13 | """Controls which config parts are saved by Lightning loggers. 14 | 15 | Additionally saves: 16 | - Number of model parameters 17 | 18 | :param object_dict: A dictionary containing the following objects: 19 | - `"cfg"`: A DictConfig object containing the main config. 20 | - `"model"`: The Lightning model. 21 | - `"trainer"`: The Lightning trainer. 22 | """ 23 | hparams = {} 24 | 25 | cfg = OmegaConf.to_container(object_dict["cfg"]) 26 | model = object_dict["model"] 27 | trainer = object_dict["trainer"] 28 | 29 | if not trainer.logger: 30 | log.warning("Logger not found! 
Skipping hyperparameter logging...") 31 | return 32 | 33 | hparams["model"] = cfg["model"] 34 | 35 | # save number of model parameters 36 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 37 | hparams["model/params/trainable"] = sum( 38 | p.numel() for p in model.parameters() if p.requires_grad 39 | ) 40 | hparams["model/params/non_trainable"] = sum( 41 | p.numel() for p in model.parameters() if not p.requires_grad 42 | ) 43 | 44 | hparams["data"] = cfg["data"] 45 | hparams["trainer"] = cfg["trainer"] 46 | 47 | hparams["callbacks"] = cfg.get("callbacks") 48 | hparams["extras"] = cfg.get("extras") 49 | 50 | hparams["task_name"] = cfg.get("task_name") 51 | hparams["tags"] = cfg.get("tags") 52 | hparams["ckpt_path"] = cfg.get("ckpt_path") 53 | hparams["seed"] = cfg.get("seed") 54 | 55 | # send hparams to all loggers 56 | for logger in trainer.loggers: 57 | logger.log_hyperparams(hparams) 58 | -------------------------------------------------------------------------------- /src/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from lightning import Callback 5 | from lightning.pytorch.loggers import Logger 6 | from omegaconf import DictConfig 7 | 8 | from src.utils import pylogger 9 | 10 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 11 | 12 | 13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config. 15 | 16 | :param callbacks_cfg: A DictConfig object containing callback configurations. 17 | :return: A list of instantiated callbacks. 18 | """ 19 | callbacks: List[Callback] = [] 20 | 21 | if not callbacks_cfg: 22 | log.warning("No callback configs found! Skipping..") 23 | return callbacks 24 | 25 | if not isinstance(callbacks_cfg, DictConfig): 26 | raise TypeError("Callbacks config must be a DictConfig!") 27 | 28 | for _, cb_conf in callbacks_cfg.items(): 29 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 30 | log.info(f"Instantiating callback <{cb_conf._target_}>") 31 | callbacks.append(hydra.utils.instantiate(cb_conf)) 32 | 33 | return callbacks 34 | 35 | 36 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 37 | """Instantiates loggers from config. 38 | 39 | :param logger_cfg: A DictConfig object containing logger configurations. 40 | :return: A list of instantiated loggers. 41 | """ 42 | logger: List[Logger] = [] 43 | 44 | if not logger_cfg: 45 | log.warning("No logger configs found! Skipping...") 46 | return logger 47 | 48 | if not isinstance(logger_cfg, DictConfig): 49 | raise TypeError("Logger config must be a DictConfig!") 50 | 51 | for _, lg_conf in logger_cfg.items(): 52 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 53 | log.info(f"Instantiating logger <{lg_conf._target_}>") 54 | logger.append(hydra.utils.instantiate(lg_conf)) 55 | 56 | return logger 57 | -------------------------------------------------------------------------------- /src/models/DePLM_components/sort.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import numpy as np 4 | 5 | def _find_repeats(data: Tensor) -> Tensor: 6 | """Find and return values which have repeats i.e. 
the same value are more than once in the tensor.""" 7 | temp = data.detach().clone() 8 | temp = temp.sort()[0] 9 | 10 | change = torch.cat([torch.tensor([True], device=temp.device), temp[1:] != temp[:-1]]) 11 | unique = temp[change] 12 | change_idx = torch.cat([torch.nonzero(change), torch.tensor([[temp.numel()]], device=temp.device)]).flatten() 13 | freq = change_idx[1:] - change_idx[:-1] 14 | atleast2 = freq > 1 15 | return unique[atleast2] 16 | 17 | def _rank_data(data: Tensor) -> Tensor: 18 | """Calculate the rank for each element of a tensor. 19 | 20 | The rank refers to the indices of an element in the corresponding sorted tensor (starting from 1). Duplicates of the 21 | same value will be assigned the mean of their rank. 22 | 23 | Adopted from `Rank of element tensor`_ 24 | 25 | """ 26 | n = data.numel() 27 | rank = torch.empty_like(data) 28 | idx = data.argsort() 29 | rank[idx[:n]] = torch.arange(1, n + 1, dtype=data.dtype, device=data.device) 30 | 31 | repeats = _find_repeats(data) 32 | for r in repeats: 33 | condition = data == r 34 | rank[condition] = rank[condition].mean() 35 | return rank 36 | 37 | def quick_sort(arr): 38 | partial_sorted_arr = [] 39 | if len(arr) < 2: return arr 40 | stack = [] 41 | stack.append([0, len(arr)-1]) 42 | 43 | while stack: 44 | tmp_stack = [] 45 | while stack: 46 | l, r = stack.pop() 47 | index = partition(arr, l, r) 48 | if l < index - 1: 49 | tmp_stack.append([l, index-1]) 50 | if r > index + 1: 51 | tmp_stack.append([index+1, r]) 52 | stack = tmp_stack 53 | partial_sorted_arr.append(arr.copy()) 54 | return partial_sorted_arr 55 | 56 | 57 | def partition(arr, s, t): 58 | pivot = np.random.randint(s, t+1) 59 | arr[s], arr[pivot] = arr[pivot], arr[s] 60 | tmp = arr[s] 61 | while s < t: 62 | while s < t and arr[t] >= tmp: t -= 1 63 | arr[s] = arr[t] 64 | while s < t and arr[s] <= tmp: s += 1 65 | arr[t] = arr[s] 66 | arr[s] = tmp 67 | return s -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DePLM: Denoising Protein Language Models for Property Optimization 2 | 3 | This repository is the official model and benchmark proposed in a paper: [DePLM: Denoising Protein Language Models for Property Optimization](https://neurips.cc/virtual/2024/poster/95517). 4 | 5 | [![Openbayes Demo](https://img.shields.io/static/v1?label=Demo&message=OpenBayes%E8%B4%9D%E5%BC%8F%E8%AE%A1%E7%AE%97&color=green)](https://openbayes.com/console/public/tutorials/tAf7dtY7k9n) 6 | [![license](https://img.shields.io/badge/License-MIT-blue.svg?labelColor=grey)](https://github.com/ashleve/lightning-hydra-template#license) 7 | ![](https://img.shields.io/github/last-commit/HICAI-ZJU/DePLM?color=blue) 8 | 9 | 10 | 11 | ## Description 12 | 13 | The central concept of DePLM revolves around perceiving the EI captured by PLMs as a blend of property-relevant and irrelevant information, with the latter akin to “noise” for the targeted property, necessitating its elimination. To achieve this, drawing inspiration from denoising diffusion models that refine noisy inputs to generate desired outputs, we devise a rank-based forward process to extend the diffusion model for denoising EI. 14 | 15 | ## Installation 16 | 17 | ``` 18 | >> git clone https://github.com/HICAI-ZJU/DePLM 19 | >> cd DePLM 20 | >> conda env create --file environment.yml 21 | ``` 22 | 23 | ## Quick Start 24 | 25 | We can train and test DePLM as follows. 
26 | 27 | ``` 28 | >> bash ./scripts/schedule.sh 29 | ``` 30 | 31 | Here we use a deep mutational scanning (DMS) dataset - TAT_HV1BR_Fernandes_2016 - as an example. The program will run a training process with the default parameters. 32 | 33 | To train on your own dataset, you need to provide DMS and structure data, place them in `./data`, and modify the data configuration file `./configs/data`. 34 | 35 | ## Citation 36 | 37 | Please consider citing our paper if you find the code useful for your project. 38 | 39 | ``` 40 | @inproceedings{ 41 | wang2024deplm, 42 | title={De{PLM}: Denoising Protein Language Models for Property Optimization}, 43 | author={Zeyuan Wang and Keyan Ding and Ming Qin and Xiaotong Li and Xiang Zhuang and Yu Zhao and Jianhua Yao and Qiang Zhang and Huajun Chen}, 44 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, 45 | year={2024}, 46 | url={https://openreview.net/forum?id=MU27zjHBcW} 47 | } 48 | ``` 49 | -------------------------------------------------------------------------------- /src/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Mapping, Optional 3 | 4 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only 5 | 6 | 7 | class RankedLogger(logging.LoggerAdapter): 8 | """A multi-GPU-friendly python command line logger.""" 9 | 10 | def __init__( 11 | self, 12 | name: str = __name__, 13 | rank_zero_only: bool = False, 14 | extra: Optional[Mapping[str, object]] = None, 15 | ) -> None: 16 | """Initializes a multi-GPU-friendly python command line logger that logs on all processes 17 | with their rank prefixed in the log message. 18 | 19 | :param name: The name of the logger. Default is ``__name__``. 20 | :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. 21 | :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. 22 | """ 23 | logger = logging.getLogger(name) 24 | super().__init__(logger=logger, extra=extra) 25 | self.rank_zero_only = rank_zero_only 26 | 27 | def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None: 28 | """Delegate a log call to the underlying logger, after prefixing its message with the rank 29 | of the process it's being logged from. If `'rank'` is provided, then the log will only 30 | occur on that rank/process. 31 | 32 | :param level: The level to log at. Look at `logging.__init__.py` for more information. 33 | :param msg: The message to log. 34 | :param rank: The rank to log at. 35 | :param args: Additional args to pass to the underlying logging function. 36 | :param kwargs: Any additional keyword args to pass to the underlying logging function. 
37 | """ 38 | if self.isEnabledFor(level): 39 | msg, kwargs = self.process(msg, kwargs) 40 | current_rank = getattr(rank_zero_only, "rank", None) 41 | if current_rank is None: 42 | raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") 43 | msg = rank_prefixed_message(msg, current_rank) 44 | if self.rank_zero_only: 45 | if current_rank == 0: 46 | self.logger.log(level, msg, *args, **kwargs) 47 | else: 48 | if rank is None: 49 | self.logger.log(level, msg, *args, **kwargs) 50 | elif current_rank == rank: 51 | self.logger.log(level, msg, *args, **kwargs) 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Data & Models 143 | *.h5 144 | *.tar 145 | *.tar.gz 146 | 147 | # Lightning-Hydra-Template 148 | configs/local/default.yaml 149 | /logs/ 150 | .env 151 | 152 | # Aim logging 153 | .aim -------------------------------------------------------------------------------- /src/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning_utilities.core.rank_zero import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from src.utils import pylogger 13 | 14 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "data", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints the contents of a DictConfig as a tree structure using the Rich library. 33 | 34 | :param cfg: A DictConfig composed by Hydra. 35 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model", 36 | "callbacks", "logger", "trainer", "paths", "extras")``. 37 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 38 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. 39 | """ 40 | style = "dim" 41 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 42 | 43 | queue = [] 44 | 45 | # add fields from `print_order` to queue 46 | for field in print_order: 47 | queue.append(field) if field in cfg else log.warning( 48 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 
49 | ) 50 | 51 | # add all the other fields to queue (not specified in `print_order`) 52 | for field in cfg: 53 | if field not in queue: 54 | queue.append(field) 55 | 56 | # generate config tree from queue 57 | for field in queue: 58 | branch = tree.add(field, style=style, guide_style=style) 59 | 60 | config_group = cfg[field] 61 | if isinstance(config_group, DictConfig): 62 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 63 | else: 64 | branch_content = str(config_group) 65 | 66 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 67 | 68 | # print config tree 69 | rich.print(tree) 70 | 71 | # save config tree to file 72 | if save_to_file: 73 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 74 | rich.print(tree, file=file) 75 | 76 | 77 | @rank_zero_only 78 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 79 | """Prompts user to input tags from command line if no tags are provided in config. 80 | 81 | :param cfg: A DictConfig composed by Hydra. 82 | :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 83 | """ 84 | if not cfg.get("tags"): 85 | if "id" in HydraConfig().cfg.hydra.job: 86 | raise ValueError("Specify tags before launching a multirun!") 87 | 88 | log.warning("No tags provided in config. Prompting user to input tags...") 89 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 90 | tags = [t.strip() for t in tags.split(",") if t != ""] 91 | 92 | with open_dict(cfg): 93 | cfg.tags = tags 94 | 95 | log.info(f"Tags: {cfg.tags}") 96 | 97 | if save_to_file: 98 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 99 | rich.print(cfg.tags, file=file) 100 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.4.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-executables-have-shebangs 16 | - id: check-toml 17 | - id: check-case-conflict 18 | - id: check-added-large-files 19 | 20 | # python code formatting 21 | - repo: https://github.com/psf/black 22 | rev: 23.1.0 23 | hooks: 24 | - id: black 25 | args: [--line-length, "99"] 26 | 27 | # python import sorting 28 | - repo: https://github.com/PyCQA/isort 29 | rev: 5.12.0 30 | hooks: 31 | - id: isort 32 | args: ["--profile", "black", "--filter-files"] 33 | 34 | # python upgrading syntax to newer version 35 | - repo: https://github.com/asottile/pyupgrade 36 | rev: v3.3.1 37 | hooks: 38 | - id: pyupgrade 39 | args: [--py38-plus] 40 | 41 | # python docstring formatting 42 | - repo: https://github.com/myint/docformatter 43 | rev: v1.7.4 44 | hooks: 45 | - id: docformatter 46 | args: 47 | [ 48 | --in-place, 49 | --wrap-summaries=99, 50 | --wrap-descriptions=99, 51 | --style=sphinx, 52 | --black, 53 | ] 54 | 55 | # python docstring coverage checking 56 | - repo: https://github.com/econchick/interrogate 57 | rev: 1.5.0 # or master if you're bold 58 | hooks: 59 | - id: interrogate 60 | args: 61 | [ 62 | --verbose, 63 | --fail-under=80, 64 | --ignore-init-module, 65 | --ignore-init-method, 66 | --ignore-module, 67 | 
--ignore-nested-functions, 68 | -vv, 69 | ] 70 | 71 | # python check (PEP8), programming errors and code complexity 72 | - repo: https://github.com/PyCQA/flake8 73 | rev: 6.0.0 74 | hooks: 75 | - id: flake8 76 | args: 77 | [ 78 | "--extend-ignore", 79 | "E203,E402,E501,F401,F841,RST2,RST301", 80 | "--exclude", 81 | "logs/*,data/*", 82 | ] 83 | additional_dependencies: [flake8-rst-docstrings==0.3.0] 84 | 85 | # python security linter 86 | - repo: https://github.com/PyCQA/bandit 87 | rev: "1.7.5" 88 | hooks: 89 | - id: bandit 90 | args: ["-s", "B101"] 91 | 92 | # yaml formatting 93 | - repo: https://github.com/pre-commit/mirrors-prettier 94 | rev: v3.0.0-alpha.6 95 | hooks: 96 | - id: prettier 97 | types: [yaml] 98 | exclude: "environment.yaml" 99 | 100 | # shell scripts linter 101 | - repo: https://github.com/shellcheck-py/shellcheck-py 102 | rev: v0.9.0.2 103 | hooks: 104 | - id: shellcheck 105 | 106 | # md formatting 107 | - repo: https://github.com/executablebooks/mdformat 108 | rev: 0.7.16 109 | hooks: 110 | - id: mdformat 111 | args: ["--number"] 112 | additional_dependencies: 113 | - mdformat-gfm 114 | - mdformat-tables 115 | - mdformat_frontmatter 116 | # - mdformat-toc 117 | # - mdformat-black 118 | 119 | # word spelling linter 120 | - repo: https://github.com/codespell-project/codespell 121 | rev: v2.2.4 122 | hooks: 123 | - id: codespell 124 | args: 125 | - --skip=logs/**,data/**,*.ipynb 126 | # - --ignore-words-list=abc,def 127 | 128 | # jupyter notebook cell output clearing 129 | - repo: https://github.com/kynan/nbstripout 130 | rev: 0.6.1 131 | hooks: 132 | - id: nbstripout 133 | 134 | # jupyter notebook linting 135 | - repo: https://github.com/nbQA-dev/nbQA 136 | rev: 1.6.3 137 | hooks: 138 | - id: nbqa-black 139 | args: ["--line-length=99"] 140 | - id: nbqa-isort 141 | args: ["--profile=black"] 142 | - id: nbqa-flake8 143 | args: 144 | [ 145 | "--extend-ignore=E203,E402,E501,F401,F841", 146 | "--exclude=logs/*,data/*", 147 | ] -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from importlib.util import find_spec 3 | from typing import Any, Callable, Dict, Optional, Tuple 4 | 5 | from omegaconf import DictConfig 6 | 7 | from src.utils import pylogger, rich_utils 8 | 9 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 10 | 11 | 12 | def extras(cfg: DictConfig) -> None: 13 | """Applies optional utilities before the task is started. 14 | 15 | Utilities: 16 | - Ignoring python warnings 17 | - Setting tags from command line 18 | - Rich config printing 19 | 20 | :param cfg: A DictConfig object containing the config tree. 21 | """ 22 | # return if no `extras` config 23 | if not cfg.get("extras"): 24 | log.warning("Extras config not found! ") 25 | return 26 | 27 | # disable python warnings 28 | if cfg.extras.get("ignore_warnings"): 29 | log.info("Disabling python warnings! ") 30 | warnings.filterwarnings("ignore") 31 | 32 | # prompt user to input tags from command line if none are provided in the config 33 | if cfg.extras.get("enforce_tags"): 34 | log.info("Enforcing tags! ") 35 | rich_utils.enforce_tags(cfg, save_to_file=True) 36 | 37 | # pretty print config tree using Rich library 38 | if cfg.extras.get("print_config"): 39 | log.info("Printing config tree with Rich! 
") 40 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) 41 | 42 | 43 | def task_wrapper(task_func: Callable) -> Callable: 44 | """Optional decorator that controls the failure behavior when executing the task function. 45 | 46 | This wrapper can be used to: 47 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) 48 | - save the exception to a `.log` file 49 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) 50 | - etc. (adjust depending on your needs) 51 | 52 | Example: 53 | ``` 54 | @utils.task_wrapper 55 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 56 | ... 57 | return metric_dict, object_dict 58 | ``` 59 | 60 | :param task_func: The task function to be wrapped. 61 | 62 | :return: The wrapped task function. 63 | """ 64 | 65 | def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 66 | # execute the task 67 | try: 68 | metric_dict, object_dict = task_func(cfg=cfg) 69 | 70 | # things to do if exception occurs 71 | except Exception as ex: 72 | # save exception to `.log` file 73 | log.exception("") 74 | 75 | # some hyperparameter combinations might be invalid or cause out-of-memory errors 76 | # so when using hparam search plugins like Optuna, you might want to disable 77 | # raising the below exception to avoid multirun failure 78 | raise ex 79 | 80 | # things to always do after either success or exception 81 | finally: 82 | # display output dir path in terminal 83 | log.info(f"Output dir: {cfg.paths.output_dir}") 84 | 85 | # always close wandb run (even if exception occurs so multirun won't fail) 86 | if find_spec("wandb"): # check if wandb is installed 87 | import wandb 88 | 89 | if wandb.run: 90 | log.info("Closing wandb!") 91 | wandb.finish() 92 | 93 | return metric_dict, object_dict 94 | 95 | return wrap 96 | 97 | 98 | def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]: 99 | """Safely retrieves value of the metric logged in LightningModule. 100 | 101 | :param metric_dict: A dict containing metric values. 102 | :param metric_name: If provided, the name of the metric to retrieve. 103 | :return: If a metric name was provided, the value of the metric. 104 | """ 105 | if not metric_name: 106 | log.info("Metric name is None! Skipping metric value retrieval...") 107 | return None 108 | 109 | if metric_name not in metric_dict: 110 | raise Exception( 111 | f"Metric value not found! \n" 112 | "Make sure metric name logged in LightningModule is correct!\n" 113 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 114 | ) 115 | 116 | metric_value = metric_dict[metric_name].item() 117 | log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") 118 | 119 | return metric_value 120 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import hydra 4 | import rootutils 5 | from omegaconf import DictConfig 6 | 7 | import torch 8 | import lightning as L 9 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 10 | from lightning.pytorch.loggers import Logger 11 | 12 | 13 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 14 | # ------------------------------------------------------------------------------------ # 15 | # the setup_root above is equivalent to: 16 | # - adding project root dir to PYTHONPATH 17 | # (so you don't need to force user to install project as a package) 18 | # (necessary before importing any local modules e.g. `from src import utils`) 19 | # - setting up PROJECT_ROOT environment variable 20 | # (which is used as a base for paths in "configs/paths/default.yaml") 21 | # (this way all filepaths are the same no matter where you run the code) 22 | # - loading environment variables from ".env" in root dir 23 | # 24 | # you can remove it if you: 25 | # 1. either install project as a package or move entry files to project root dir 26 | # 2. set `root_dir` to "." in "configs/paths/default.yaml" 27 | # 28 | # more info: https://github.com/ashleve/rootutils 29 | # ------------------------------------------------------------------------------------ # 30 | 31 | from src.utils import ( 32 | RankedLogger, 33 | extras, 34 | get_metric_value, 35 | instantiate_callbacks, 36 | instantiate_loggers, 37 | log_hyperparameters, 38 | task_wrapper, 39 | ) 40 | 41 | log = RankedLogger(__name__, rank_zero_only=True) 42 | 43 | 44 | 45 | @task_wrapper 46 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 47 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 48 | training. 49 | 50 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 51 | failure. Useful for multiruns, saving info about the crash, etc. 52 | 53 | :param cfg: A DictConfig configuration composed by Hydra. 54 | :return: A tuple with metrics and dict with all instantiated objects. 
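When `test=True`, evaluation reuses the best checkpoint tracked by the ModelCheckpoint callback, falling back to the current weights if no checkpoint was saved.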
55 | """ 56 | # set seed for random number generators in pytorch, numpy and python.random 57 | if cfg.get("seed"): 58 | L.seed_everything(cfg.seed, workers=True) 59 | 60 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") 61 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 62 | 63 | log.info(f"Instantiating model <{cfg.model._target_}>") 64 | model: LightningModule = hydra.utils.instantiate(cfg.model) 65 | 66 | log.info("Instantiating callbacks...") 67 | callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) 68 | 69 | log.info("Instantiating loggers...") 70 | logger: List[Logger] = instantiate_loggers(cfg.get("logger")) 71 | 72 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 73 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) 74 | 75 | object_dict = { 76 | "cfg": cfg, 77 | "datamodule": datamodule, 78 | "model": model, 79 | "callbacks": callbacks, 80 | "logger": logger, 81 | "trainer": trainer, 82 | } 83 | 84 | if logger: 85 | log.info("Logging hyperparameters!") 86 | log_hyperparameters(object_dict) 87 | 88 | if cfg.get("train"): 89 | log.info("Starting training!") 90 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 91 | 92 | train_metrics = trainer.callback_metrics 93 | 94 | if cfg.get("test"): 95 | log.info("Starting testing!") 96 | ckpt_path = trainer.checkpoint_callback.best_model_path 97 | if ckpt_path == "": 98 | log.warning("Best ckpt not found! Using current weights for testing...") 99 | ckpt_path = None 100 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 101 | log.info(f"Best ckpt path: {ckpt_path}") 102 | 103 | test_metrics = trainer.callback_metrics 104 | 105 | # merge train and test metrics 106 | metric_dict = {**train_metrics, **test_metrics} 107 | 108 | return metric_dict, object_dict 109 | 110 | 111 | 112 | @hydra.main(version_base="1.3", config_path="../configs", config_name="DePLM.yaml") 113 | def main(cfg: DictConfig) -> Optional[float]: 114 | """Main entry point for training. 115 | 116 | :param cfg: DictConfig configuration composed by Hydra. 117 | :return: Optional[float] with optimized metric value. 118 | """ 119 | # apply extra utilities 120 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
121 | extras(cfg) 122 | 123 | # train the model 124 | metric_dict, _ = train(cfg) 125 | 126 | # safely retrieve metric value for hydra-based hyperparameter optimization 127 | metric_value = get_metric_value( 128 | metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") 129 | ) 130 | 131 | return metric_value 132 | 133 | 134 | 135 | if __name__ == '__main__': 136 | main() 137 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: deplm 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - nvidia 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - blas=1.0=mkl 11 | - brotli-python=1.0.9=py39h6a678d5_8 12 | - bzip2=1.0.8=h5eee18b_6 13 | - ca-certificates=2024.11.26=h06a4308_0 14 | - certifi=2024.8.30=py39h06a4308_0 15 | - charset-normalizer=3.3.2=pyhd3eb1b0_0 16 | - cuda-cudart=11.8.89=0 17 | - cuda-cupti=11.8.87=0 18 | - cuda-libraries=11.8.0=0 19 | - cuda-nvrtc=11.8.89=0 20 | - cuda-nvtx=11.8.86=0 21 | - cuda-runtime=11.8.0=0 22 | - cuda-version=11.8=hcce14f8_3 23 | - cudatoolkit=11.8.0=h6a678d5_0 24 | - ffmpeg=4.3=hf484d3e_0 25 | - filelock=3.13.1=py39h06a4308_0 26 | - freetype=2.12.1=h4a9f257_0 27 | - gmp=6.2.1=h295c915_3 28 | - gmpy2=2.1.2=py39heeb90bb_0 29 | - gnutls=3.6.15=he1e5248_0 30 | - idna=3.7=py39h06a4308_0 31 | - intel-openmp=2023.1.0=hdb19cb5_46306 32 | - jinja2=3.1.4=py39h06a4308_1 33 | - jpeg=9e=h5eee18b_3 34 | - lame=3.100=h7b6447c_0 35 | - lcms2=2.12=h3be6417_0 36 | - ld_impl_linux-64=2.40=h12ee557_0 37 | - lerc=3.0=h295c915_0 38 | - libcublas=11.11.3.6=0 39 | - libcufft=10.9.0.58=0 40 | - libcufile=1.9.1.3=0 41 | - libcurand=10.3.5.147=0 42 | - libcusolver=11.4.1.48=0 43 | - libcusparse=11.7.5.86=0 44 | - libdeflate=1.17=h5eee18b_1 45 | - libffi=3.4.4=h6a678d5_1 46 | - libgcc-ng=11.2.0=h1234567_1 47 | - libgomp=11.2.0=h1234567_1 48 | - libiconv=1.16=h5eee18b_3 49 | - libidn2=2.3.4=h5eee18b_0 50 | - libjpeg-turbo=2.0.0=h9bf148f_0 51 | - libnpp=11.8.0.86=0 52 | - libnvjpeg=11.9.0.86=0 53 | - libpng=1.6.39=h5eee18b_0 54 | - libstdcxx-ng=11.2.0=h1234567_1 55 | - libtasn1=4.19.0=h5eee18b_0 56 | - libtiff=4.5.1=h6a678d5_0 57 | - libunistring=0.9.10=h27cfd23_0 58 | - libwebp-base=1.3.2=h5eee18b_1 59 | - llvm-openmp=14.0.6=h9e868ea_0 60 | - lz4-c=1.9.4=h6a678d5_1 61 | - markupsafe=2.1.3=py39h5eee18b_0 62 | - mkl=2023.1.0=h213fc3f_46344 63 | - mkl-service=2.4.0=py39h5eee18b_1 64 | - mkl_fft=1.3.11=py39h5eee18b_0 65 | - mkl_random=1.2.8=py39h1128e8f_0 66 | - mpc=1.1.0=h10f8cd9_1 67 | - mpfr=4.0.2=hb69a4c5_1 68 | - mpmath=1.3.0=py39h06a4308_0 69 | - ncurses=6.4=h6a678d5_0 70 | - nettle=3.7.3=hbbd107a_1 71 | - networkx=3.2.1=py39h06a4308_0 72 | - openh264=2.1.1=h4ff587b_0 73 | - openjpeg=2.5.2=he7f1fd0_0 74 | - openssl=3.0.15=h5eee18b_0 75 | - pillow=11.0.0=py39hfdbf927_0 76 | - pip=24.2=py39h06a4308_0 77 | - pysocks=1.7.1=py39h06a4308_0 78 | - python=3.9.20=he870216_1 79 | - pytorch=2.4.1=py3.9_cuda11.8_cudnn9.1.0_0 80 | - pytorch-cuda=11.8=h7e8668a_6 81 | - pytorch-mutex=1.0=cuda 82 | - pyyaml=6.0.2=py39h5eee18b_0 83 | - readline=8.2=h5eee18b_0 84 | - requests=2.32.3=py39h06a4308_1 85 | - setuptools=75.1.0=py39h06a4308_0 86 | - sqlite=3.45.3=h5eee18b_0 87 | - sympy=1.13.2=py39h06a4308_0 88 | - tbb=2021.8.0=hdb19cb5_0 89 | - tk=8.6.14=h39e8969_0 90 | - torchaudio=2.4.1=py39_cu118 91 | - torchtriton=3.0.0=py39 92 | - torchvision=0.19.1=py39_cu118 93 | - typing_extensions=4.11.0=py39h06a4308_0 94 | - 
urllib3=2.2.3=py39h06a4308_0 95 | - wheel=0.44.0=py39h06a4308_0 96 | - xz=5.4.6=h5eee18b_1 97 | - yaml=0.2.5=h7b6447c_0 98 | - zlib=1.2.13=h5eee18b_1 99 | - zstd=1.5.6=hc292b87_0 100 | - pip: 101 | - aiohappyeyeballs==2.4.4 102 | - aiohttp==3.11.9 103 | - aiosignal==1.3.1 104 | - alembic==1.14.0 105 | - antlr4-python3-runtime==4.9.3 106 | - asttokens==3.0.0 107 | - async-timeout==5.0.1 108 | - attrs==24.2.0 109 | - autopage==0.5.2 110 | - biotite==0.40.0 111 | - cfgv==3.4.0 112 | - cliff==4.8.0 113 | - cmaes==0.11.1 114 | - cmd2==2.5.7 115 | - colorlog==6.9.0 116 | - decorator==5.1.1 117 | - distlib==0.3.9 118 | - einops==0.8.0 119 | - exceptiongroup==1.2.2 120 | - executing==2.1.0 121 | - fair-esm==2.0.0 122 | - frozenlist==1.5.0 123 | - fsspec==2024.10.0 124 | - greenlet==3.1.1 125 | - hydra-colorlog==1.2.0 126 | - hydra-core==1.3.2 127 | - hydra-optuna-sweeper==1.2.0 128 | - identify==2.6.3 129 | - importlib-metadata==8.5.0 130 | - iniconfig==2.0.0 131 | - ipdb==0.13.13 132 | - ipython==8.18.1 133 | - jedi==0.19.2 134 | - lightning==2.4.0 135 | - lightning-utilities==0.11.9 136 | - mako==1.3.6 137 | - markdown-it-py==3.0.0 138 | - matplotlib-inline==0.1.7 139 | - mdurl==0.1.2 140 | - msgpack==1.1.0 141 | - multidict==6.1.0 142 | - nodeenv==1.9.1 143 | - numpy==1.26.1 144 | - omegaconf==2.3.0 145 | - openfold==2.0.0 146 | - optuna==2.10.1 147 | - packaging==24.2 148 | - pandas==2.2.3 149 | - parso==0.8.4 150 | - pbr==6.1.0 151 | - pexpect==4.9.0 152 | - platformdirs==4.3.6 153 | - pluggy==1.5.0 154 | - pre-commit==4.0.1 155 | - prettytable==3.12.0 156 | - prompt-toolkit==3.0.48 157 | - propcache==0.2.1 158 | - psutil==6.1.0 159 | - ptyprocess==0.7.0 160 | - pure-eval==0.2.3 161 | - pygments==2.18.0 162 | - pyparsing==3.2.0 163 | - pyperclip==1.9.0 164 | - pytest==8.3.4 165 | - python-dateutil==2.9.0.post0 166 | - python-dotenv==1.0.1 167 | - pytorch-lightning==2.4.0 168 | - pytz==2024.2 169 | - rich==13.9.4 170 | - rootutils==1.0.7 171 | - scipy==1.13.1 172 | - six==1.16.0 173 | - sqlalchemy==2.0.36 174 | - stack-data==0.6.3 175 | - stevedore==5.4.0 176 | - tomli==2.2.1 177 | - torch-cluster==1.6.3+pt24cu118 178 | - torch-geometric==2.6.1 179 | - torch-scatter==2.1.2+pt24cu118 180 | - torch-sparse==0.6.18+pt24cu118 181 | - torchmetrics==1.6.0 182 | - torchsort==0.1.9 183 | - tqdm==4.67.1 184 | - traitlets==5.14.3 185 | - tzdata==2024.2 186 | - virtualenv==20.28.0 187 | - wcwidth==0.2.13 188 | - yarl==1.18.3 189 | - zipp==3.21.0 190 | -------------------------------------------------------------------------------- /src/data/batched_protein_engineering_datamodule.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, Optional 3 | import pandas as pd 4 | import logging 5 | import torch 6 | import esm 7 | from lightning import LightningDataModule 8 | from torch.utils.data import Dataset, DataLoader 9 | 10 | LOG = logging.getLogger(__name__) 11 | 12 | 13 | class BatchedProteinEngineeringDataset(Dataset): 14 | def __init__(self, name, wt_sequence, coords, train_data, valid_data): 15 | _, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D() 16 | self.batch_converter = self.alphabet.get_batch_converter() 17 | _, self.structure_alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50() 18 | self.coord_converter = esm.inverse_folding.util.CoordBatchConverter(self.structure_alphabet) 19 | self.name = name 20 | self.coords = [] 21 | for coord in coords: 22 | batch = [(coord, None, None)] 23 | coord, confidence, 
_, _, padding_mask = self.coord_converter(batch) 24 | self.coords.append((coord, padding_mask, confidence)) 25 | 26 | self.batch_tokens = [] 27 | for sequence in wt_sequence: 28 | batch_labels, batch_strs, batch_tokens = self.batch_converter([('protein', sequence)]) 29 | self.batch_tokens.append(batch_tokens) 30 | 31 | self.train_labels = [] 32 | for assay_idx, dms_data in enumerate(train_data): 33 | train_label = [] 34 | for index, data in dms_data.iterrows(): 35 | mutants = data['mutant'].split(':') 36 | mutant_list = [] 37 | for mutant in mutants: 38 | location = int(mutant[1:-1]) 39 | mutant_list.append((location, torch.tensor(self.alphabet.tok_to_idx[mutant[:1]], dtype=torch.long), torch.tensor(self.alphabet.tok_to_idx[mutant[-1:]], dtype=torch.long))) 40 | train_label.append((torch.tensor(data['score'], dtype=torch.float32), mutant_list)) 41 | self.train_labels.append(train_label) 42 | 43 | self.valid_labels = [] 44 | for assay_idx, dms_data in enumerate(valid_data): 45 | valid_label = [] 46 | for index, data in dms_data.iterrows(): 47 | mutants = data['mutant'].split(':') 48 | mutant_list = [] 49 | for mutant in mutants: 50 | location = int(mutant[1:-1]) 51 | mutant_list.append((location, torch.tensor(self.alphabet.tok_to_idx[mutant[:1]], dtype=torch.long), torch.tensor(self.alphabet.tok_to_idx[mutant[-1:]], dtype=torch.long))) 52 | valid_label.append((torch.tensor(data['score'], dtype=torch.float32), mutant_list)) 53 | self.valid_labels.append(valid_label) 54 | 55 | def __getitem__(self, index): 56 | return self.name, self.batch_tokens, self.coords, self.train_labels, self.valid_labels 57 | 58 | def __len__(self): 59 | return 1 60 | 61 | 62 | 63 | class BatchedProteinEngineeringDataModule(LightningDataModule): 64 | """`LightningDataModule` for protein engineering assay datasets. 65 | Read the docs: 66 | https://lightning.ai/docs/pytorch/latest/data/datamodule.html 67 | """ 68 | 69 | def __init__( 70 | self, 71 | task_name: str, 72 | data_dir: str = "data/", 73 | batch_size: int = 64, 74 | num_workers: int = 0, 75 | pin_memory: bool = False, 76 | ) -> None: 77 | super().__init__() 78 | self.save_hyperparameters(logger=False) 79 | 80 | self.data_train: Optional[Dataset] = None 81 | self.data_val: Optional[Dataset] = None 82 | _, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D() 83 | self.batch_size_per_device = batch_size 84 | 85 | @property 86 | def num_classes(self) -> int: 87 | return 1 88 | 89 | def prepare_data(self) -> None: 90 | pass 91 | 92 | def setup(self, stage: Optional[str] = None) -> None: 93 | """Load data. Set variables: `self.data_train` and `self.data_val`. 94 | """ 95 | 96 | if self.trainer is not None: 97 | if self.hparams.batch_size % self.trainer.world_size != 0: 98 | raise RuntimeError( 99 | f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
100 | ) 101 | self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size 102 | 103 | # load and split datasets only if not loaded already 104 | if not self.data_train and not self.data_val: 105 | self.hparams.data_dir = Path(self.hparams.data_dir) / self.hparams.task_name 106 | with open(self.hparams.data_dir/'wildtype.txt', 'r') as f: 107 | wt_sequence = f.readline().strip() 108 | assay_data = pd.read_csv(self.hparams.data_dir/f'{self.hparams.task_name}.csv') 109 | structure = esm.inverse_folding.util.load_structure(str(self.hparams.data_dir/f'{self.hparams.task_name}.pdb'), 'A') 110 | coord, _ = esm.inverse_folding.util.extract_coords_from_structure(structure) 111 | 112 | self.data_train = BatchedProteinEngineeringDataset([self.hparams.task_name], [wt_sequence], [coord], 113 | [assay_data[assay_data["split"] == 0].reset_index()], 114 | [assay_data[assay_data["split"] == 2].reset_index()]) 115 | self.data_val = BatchedProteinEngineeringDataset([self.hparams.task_name], [wt_sequence], [coord], 116 | [assay_data[assay_data["split"] == 0].reset_index()], 117 | [assay_data[assay_data["split"] == 2].reset_index()]) 118 | 119 | def collator(self, raw_batch): 120 | return raw_batch[0] 121 | 122 | def train_dataloader(self) -> DataLoader[Any]: 123 | return DataLoader( 124 | dataset=self.data_train, 125 | batch_size=self.batch_size_per_device, 126 | num_workers=self.hparams.num_workers, 127 | pin_memory=self.hparams.pin_memory, 128 | collate_fn=self.collator, 129 | shuffle=True 130 | ) 131 | 132 | def val_dataloader(self) -> DataLoader[Any]: 133 | return DataLoader( 134 | dataset=self.data_val, 135 | batch_size=self.batch_size_per_device, 136 | num_workers=self.hparams.num_workers, 137 | pin_memory=self.hparams.pin_memory, 138 | collate_fn=self.collator, 139 | shuffle=False 140 | ) 141 | 142 | def teardown(self, stage: Optional[str] = None) -> None: 143 | """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, 144 | `trainer.test()`, and `trainer.predict()`. 145 | 146 | :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 147 | Defaults to ``None``. 148 | """ 149 | pass 150 | 151 | def state_dict(self) -> Dict[Any, Any]: 152 | """Called when saving a checkpoint. Implement to generate and save the datamodule state. 153 | 154 | :return: A dictionary containing the datamodule state that you want to save. 155 | """ 156 | return {} 157 | 158 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 159 | """Called when loading a checkpoint. Implement to reload datamodule state given datamodule 160 | `state_dict()`. 161 | 162 | :param state_dict: The datamodule state returned by `self.state_dict()`. 
163 | """ 164 | pass 165 | -------------------------------------------------------------------------------- /src/data/batched_GFP_datamodule.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, Optional, List 3 | import esm.inverse_folding 4 | import pandas as pd 5 | import logging 6 | import torch 7 | import esm 8 | from lightning import LightningDataModule 9 | from torch.utils.data import Dataset, DataLoader 10 | 11 | 12 | LOG = logging.getLogger(__name__) 13 | 14 | 15 | def mutant_process(mutants, start_mutants=None): 16 | if start_mutants is None: 17 | return mutants 18 | return_mutants = [] 19 | start_mutant_dict = {int(m[1:-1]): m for m in start_mutants} 20 | for mutant in mutants: 21 | location = int(mutant[1:-1]) 22 | if location not in start_mutant_dict.keys(): 23 | return_mutants.append(mutant) 24 | continue 25 | wt = start_mutant_dict[location][0] 26 | mt = mutant[-1] 27 | if wt == mt: 28 | continue 29 | return_mutants.append(f'{wt}{location}{mt}') 30 | del start_mutant_dict[location] 31 | return_mutants = return_mutants + list(start_mutant_dict.values()) 32 | return return_mutants 33 | 34 | class BatchedGFPDataset(Dataset): 35 | def __init__(self, name, wt_sequence, coords, train_data, val_data): 36 | _, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D() 37 | self.batch_converter = self.alphabet.get_batch_converter() 38 | _, self.structure_alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50() 39 | self.coord_converter = esm.inverse_folding.util.CoordBatchConverter(self.structure_alphabet) 40 | 41 | self.name = name 42 | self.coords = [] 43 | for coord in coords: 44 | batch = [(coord, None, None)] 45 | coord, confidence, _, _, padding_mask = self.coord_converter(batch) 46 | self.coords.append((coord, padding_mask, confidence)) 47 | 48 | self.batch_tokens = [] 49 | for sequence in wt_sequence: 50 | batch_labels, batch_strs, batch_tokens = self.batch_converter([('protein', sequence)]) 51 | self.batch_tokens.append(batch_tokens) 52 | 53 | self.train_labels = [] 54 | for assay_idx, dms_data in enumerate(train_data): 55 | train_label = [] 56 | for index, data in dms_data.iterrows(): 57 | mutants = data['mutant'].split(':') 58 | mutants = mutant_process(mutants, ['S38T', 'R41K', 'S105N']) 59 | mutant_list = [] 60 | for mutant in mutants: 61 | location = int(mutant[1:-1]) 62 | mutant_list.append((location, torch.tensor(self.alphabet.tok_to_idx[mutant[:1]], dtype=torch.long), torch.tensor(self.alphabet.tok_to_idx[mutant[-1:]], dtype=torch.long))) 63 | train_label.append((torch.tensor(data['score'], dtype=torch.float32), mutant_list)) 64 | self.train_labels.append(train_label) 65 | 66 | self.val_labels = [] 67 | for assay_idx, dms_data in enumerate(val_data): 68 | val_label = [] 69 | for index, data in dms_data.iterrows(): 70 | mutants = data['mutant'].split(':') 71 | if len(mutants) > 3: 72 | continue 73 | mutant_list = [] 74 | for mutant in mutants: 75 | location = int(mutant[1:-1]) 76 | mutant_list.append((location, torch.tensor(self.alphabet.tok_to_idx[mutant[:1]], dtype=torch.long), torch.tensor(self.alphabet.tok_to_idx[mutant[-1:]], dtype=torch.long))) 77 | val_label.append((torch.tensor(data['score'], dtype=torch.float32), mutant_list)) 78 | self.val_labels.append(val_label) 79 | LOG.info(f'train data: {len(self.train_labels)}; val data: {len(self.val_labels)}') 80 | 81 | def __getitem__(self, index): 82 | return self.name, self.batch_tokens, self.coords, self.train_labels, 
self.val_labels 83 | 84 | def __len__(self): 85 | return 1 86 | 87 | 88 | class BatchedGFPDataModule(LightningDataModule): 89 | """`LightningDataModule` for the GFP dataset. 90 | Read the docs: 91 | https://lightning.ai/docs/pytorch/latest/data/datamodule.html 92 | """ 93 | 94 | def __init__( 95 | self, 96 | task_name: str, 97 | support_name: List, 98 | data_dir: str = "data/", 99 | num_workers: int = 0, 100 | pin_memory: bool = False, 101 | ) -> None: 102 | super().__init__() 103 | self.save_hyperparameters(logger=False) 104 | 105 | self.data_train: Optional[Dataset] = None 106 | self.val: Optional[Dataset] = None 107 | _, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D() 108 | 109 | def setup(self, stage: Optional[str] = None) -> None: 110 | if not self.data_train: 111 | self.hparams.data_dir = Path(self.hparams.data_dir) / 'GFP' 112 | with open(self.hparams.data_dir / f'{self.hparams.task_name}.txt', 'r') as f_in: 113 | wt_sequence = f_in.readline().strip() 114 | wt_assay_data = pd.read_csv(self.hparams.data_dir/f'{self.hparams.task_name}.csv') 115 | wt_structure = esm.inverse_folding.util.load_structure(str(self.hparams.data_dir/f'{self.hparams.task_name}.pdb'), 'A') 116 | wt_coord, _ = esm.inverse_folding.util.extract_coords_from_structure(wt_structure) 117 | 118 | support_name = [] 119 | support_wt_sequence = [] 120 | support_assay_data = [] 121 | support_coord = [] 122 | for name in self.hparams.support_name: 123 | support_name.append(name) 124 | with open(self.hparams.data_dir / f'{name}.txt', 'r') as f_in: 125 | support_wt_sequence.append(f_in.readline().strip()) 126 | support_assay_data.append(pd.read_csv(self.hparams.data_dir/f'{name}_mt.csv')) 127 | structure = esm.inverse_folding.util.load_structure(str(self.hparams.data_dir/f'{name}_mt.pdb'), 'A') 128 | coord, _ = esm.inverse_folding.util.extract_coords_from_structure(structure) 129 | support_coord.append(coord) 130 | 131 | self.data_train = BatchedGFPDataset([self.hparams.task_name] + support_name, [wt_sequence] + support_wt_sequence, [wt_coord] + support_coord, 132 | [wt_assay_data[wt_assay_data["split"] == 0].reset_index()] + support_assay_data, 133 | [wt_assay_data[wt_assay_data["split"] == 2].reset_index()]) 134 | 135 | self.data_val = BatchedGFPDataset([self.hparams.task_name], [wt_sequence], [wt_coord], 136 | [wt_assay_data[wt_assay_data["split"] == 0].reset_index()], 137 | [wt_assay_data[wt_assay_data["split"] == 2].reset_index()]) 138 | 139 | def collator(self, raw_batch): 140 | return raw_batch[0] 141 | 142 | def train_dataloader(self) -> DataLoader[Any]: 143 | return DataLoader( 144 | dataset=self.data_train, 145 | batch_size=1, 146 | num_workers=self.hparams.num_workers, 147 | pin_memory=self.hparams.pin_memory, 148 | collate_fn=self.collator, 149 | shuffle=True 150 | ) 151 | 152 | def val_dataloader(self) -> DataLoader[Any]: 153 | return DataLoader( 154 | dataset=self.data_val, 155 | batch_size=1, 156 | num_workers=self.hparams.num_workers, 157 | pin_memory=self.hparams.pin_memory, 158 | collate_fn=self.collator, 159 | shuffle=False 160 | ) 161 | 162 | def teardown(self, stage: Optional[str] = None) -> None: 163 | """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, 164 | `trainer.test()`, and `trainer.predict()`. 165 | 166 | :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 167 | Defaults to ``None``. 168 | """ 169 | pass 170 | 171 | def state_dict(self) -> Dict[Any, Any]: 172 | """Called when saving a checkpoint. 
Implement to generate and save the datamodule state. 173 | 174 | :return: A dictionary containing the datamodule state that you want to save. 175 | """ 176 | return {} 177 | 178 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 179 | """Called when loading a checkpoint. Implement to reload datamodule state given datamodule 180 | `state_dict()`. 181 | 182 | :param state_dict: The datamodule state returned by `self.state_dict()`. 183 | """ 184 | pass -------------------------------------------------------------------------------- /src/models/DePLM_components/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import typing as T 5 | import math 6 | from torch.nn import LayerNorm as ESM1bLayerNorm 7 | 8 | import numpy as np 9 | from einops import rearrange, repeat 10 | 11 | from openfold.model.triangular_attention import ( 12 | TriangleAttentionEndingNode, 13 | TriangleAttentionStartingNode, 14 | ) 15 | from openfold.model.triangular_multiplicative_update import ( 16 | TriangleMultiplicationIncoming, 17 | TriangleMultiplicationOutgoing, 18 | ) 19 | 20 | def gelu(x): 21 | """Implementation of the gelu activation function. 22 | 23 | For information: OpenAI GPT's gelu is slightly different 24 | (and gives slightly different results): 25 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 26 | """ 27 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 28 | 29 | class SequenceToPair(nn.Module): 30 | def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim): 31 | super().__init__() 32 | 33 | self.layernorm = nn.LayerNorm(sequence_state_dim) 34 | self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True) 35 | self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True) 36 | 37 | torch.nn.init.zeros_(self.proj.bias) 38 | torch.nn.init.zeros_(self.o_proj.bias) 39 | 40 | def forward(self, sequence_state): 41 | """ 42 | Inputs: 43 | sequence_state: B x L x sequence_state_dim 44 | 45 | Output: 46 | pairwise_state: B x L x L x pairwise_state_dim 47 | 48 | Intermediate state: 49 | B x L x L x 2*inner_dim 50 | """ 51 | 52 | assert len(sequence_state.shape) == 3 53 | 54 | s = self.layernorm(sequence_state) 55 | s = self.proj(s) 56 | q, k = s.chunk(2, dim=-1) 57 | 58 | prod = q[:, None, :, :] * k[:, :, None, :] 59 | diff = q[:, None, :, :] - k[:, :, None, :] 60 | 61 | x = torch.cat([prod, diff], dim=-1) 62 | x = self.o_proj(x) 63 | 64 | return x 65 | 66 | 67 | class PairToSequence(nn.Module): 68 | def __init__(self, pairwise_state_dim, num_heads): 69 | super().__init__() 70 | 71 | self.layernorm = nn.LayerNorm(pairwise_state_dim) 72 | self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False) 73 | 74 | def forward(self, pairwise_state): 75 | """ 76 | Inputs: 77 | pairwise_state: B x L x L x pairwise_state_dim 78 | 79 | Output: 80 | pairwise_bias: B x L x L x num_heads 81 | """ 82 | assert len(pairwise_state.shape) == 4 83 | z = self.layernorm(pairwise_state) 84 | pairwise_bias = self.linear(z) 85 | return pairwise_bias 86 | 87 | 88 | class ResidueMLP(nn.Module): 89 | def __init__(self, embed_dim, inner_dim, norm=nn.LayerNorm, dropout=0): 90 | super().__init__() 91 | 92 | self.mlp = nn.Sequential( 93 | norm(embed_dim), 94 | nn.Linear(embed_dim, inner_dim), 95 | nn.ReLU(), 96 | nn.Linear(inner_dim, embed_dim), 97 | nn.Dropout(dropout), 98 | ) 99 | 100 | def forward(self, x): 101 | return x + 
self.mlp(x) 102 | 103 | 104 | class Dropout(nn.Module): 105 | """ 106 | Implementation of dropout with the ability to share the dropout mask 107 | along a particular dimension. 108 | """ 109 | 110 | def __init__(self, r: float, batch_dim: T.Union[int, T.List[int]]): 111 | super(Dropout, self).__init__() 112 | 113 | self.r = r 114 | if type(batch_dim) == int: 115 | batch_dim = [batch_dim] 116 | self.batch_dim = batch_dim 117 | self.dropout = nn.Dropout(self.r) 118 | 119 | def forward(self, x: torch.Tensor) -> torch.Tensor: 120 | shape = list(x.shape) 121 | if self.batch_dim is not None: 122 | for bd in self.batch_dim: 123 | shape[bd] = 1 124 | return x * self.dropout(x.new_ones(shape)) 125 | 126 | 127 | class Attention(nn.Module): 128 | def __init__(self, embed_dim, num_heads, head_width, gated=False): 129 | super().__init__() 130 | assert embed_dim == num_heads * head_width 131 | 132 | self.embed_dim = embed_dim 133 | self.num_heads = num_heads 134 | self.head_width = head_width 135 | 136 | self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False) 137 | self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True) 138 | self.gated = gated 139 | if gated: 140 | self.g_proj = nn.Linear(embed_dim, embed_dim) 141 | torch.nn.init.zeros_(self.g_proj.weight) 142 | torch.nn.init.ones_(self.g_proj.bias) 143 | 144 | self.rescale_factor = self.head_width**-0.5 145 | 146 | torch.nn.init.zeros_(self.o_proj.bias) 147 | 148 | def forward(self, x, mask=None, bias=None, indices=None): 149 | """ 150 | Basic self attention with optional mask and external pairwise bias. 151 | To handle sequences of different lengths, use mask. 152 | 153 | Inputs: 154 | x: batch of input sequneces (.. x L x C) 155 | mask: batch of boolean masks where 1=valid, 0=padding position (.. x L_k). optional. 156 | bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads). optional. 157 | 158 | Outputs: 159 | sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads) 160 | """ 161 | 162 | t = rearrange(self.proj(x), "... l (h c) -> ... h l c", h=self.num_heads) 163 | q, k, v = t.chunk(3, dim=-1) 164 | 165 | q = self.rescale_factor * q 166 | a = torch.einsum("...qc,...kc->...qk", q, k) 167 | 168 | # Add external attention bias. 169 | if bias is not None: 170 | a = a + rearrange(bias, "... lq lk h -> ... h lq lk") 171 | 172 | # Do not attend to padding tokens. 173 | if mask is not None: 174 | mask = repeat( 175 | mask, "... lk -> ... h lq lk", h=self.num_heads, lq=q.shape[-2] 176 | ) 177 | a = a.masked_fill(mask == False, -np.inf) 178 | 179 | a = F.softmax(a, dim=-1) 180 | 181 | y = torch.einsum("...hqk,...hkc->...qhc", a, v) 182 | y = rearrange(y, "... h c -> ... (h c)", h=self.num_heads) 183 | 184 | if self.gated: 185 | y = self.g_proj(x).sigmoid() * y 186 | y = self.o_proj(y) 187 | 188 | return y, rearrange(a, "... lq lk h -> ... 
h lq lk") 189 | 190 | 191 | class TriangularSelfAttentionBlock(nn.Module): 192 | def __init__( 193 | self, 194 | sequence_state_dim, 195 | pairwise_state_dim, 196 | sequence_head_width, 197 | pairwise_head_width, 198 | dropout=0, 199 | **__kwargs, 200 | ): 201 | super().__init__() 202 | 203 | assert sequence_state_dim % sequence_head_width == 0 204 | assert pairwise_state_dim % pairwise_head_width == 0 205 | sequence_num_heads = sequence_state_dim // sequence_head_width 206 | pairwise_num_heads = pairwise_state_dim // pairwise_head_width 207 | assert sequence_state_dim == sequence_num_heads * sequence_head_width 208 | assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width 209 | assert pairwise_state_dim % 2 == 0 210 | 211 | self.sequence_state_dim = sequence_state_dim 212 | self.pairwise_state_dim = pairwise_state_dim 213 | 214 | self.layernorm_1 = nn.LayerNorm(sequence_state_dim) 215 | 216 | self.sequence_to_pair = SequenceToPair( 217 | sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim 218 | ) 219 | self.pair_to_sequence = PairToSequence(pairwise_state_dim, sequence_num_heads) 220 | 221 | self.seq_attention = Attention( 222 | sequence_state_dim, sequence_num_heads, sequence_head_width, gated=True 223 | ) 224 | 225 | self.mlp_seq = ResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=dropout) 226 | self.mlp_pair = ResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout) 227 | 228 | assert dropout < 0.4 229 | self.drop = nn.Dropout(dropout) 230 | 231 | torch.nn.init.zeros_(self.sequence_to_pair.o_proj.weight) 232 | torch.nn.init.zeros_(self.sequence_to_pair.o_proj.bias) 233 | torch.nn.init.zeros_(self.pair_to_sequence.linear.weight) 234 | torch.nn.init.zeros_(self.seq_attention.o_proj.weight) 235 | torch.nn.init.zeros_(self.seq_attention.o_proj.bias) 236 | torch.nn.init.zeros_(self.mlp_seq.mlp[-2].weight) 237 | torch.nn.init.zeros_(self.mlp_seq.mlp[-2].bias) 238 | torch.nn.init.zeros_(self.mlp_pair.mlp[-2].weight) 239 | torch.nn.init.zeros_(self.mlp_pair.mlp[-2].bias) 240 | 241 | def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs): 242 | """ 243 | Inputs: 244 | sequence_state: B x L x sequence_state_dim 245 | pairwise_state: B x L x L x pairwise_state_dim 246 | mask: B x L boolean tensor of valid positions 247 | 248 | Output: 249 | sequence_state: B x L x sequence_state_dim 250 | pairwise_state: B x L x L x pairwise_state_dim 251 | """ 252 | # Update sequence state 253 | bias = self.pair_to_sequence(pairwise_state) 254 | 255 | # Self attention with bias + mlp. 
256 | y = self.layernorm_1(sequence_state) 257 | y, _ = self.seq_attention(y, mask=mask, bias=bias) 258 | sequence_state = sequence_state + self.drop(y) 259 | sequence_state = self.mlp_seq(sequence_state) 260 | 261 | # Update pairwise state 262 | pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state) 263 | pairwise_state = self.mlp_pair(pairwise_state) 264 | 265 | return sequence_state, pairwise_state 266 | 267 | 268 | class RobertaLMHead(nn.Module): 269 | """Head for masked language modeling.""" 270 | 271 | def __init__(self, embed_dim, hidden_dim, output_dim, weight): 272 | super().__init__() 273 | self.dense = nn.Linear(embed_dim, hidden_dim) 274 | self.layer_norm = ESM1bLayerNorm(hidden_dim) 275 | self.weight = weight 276 | self.bias = nn.Parameter(torch.zeros(output_dim)) 277 | 278 | def forward(self, features): 279 | x = self.dense(features) 280 | x = gelu(x) 281 | x = self.layer_norm(x) 282 | # project back to size of vocabulary with bias 283 | x = F.linear(x, self.weight) + self.bias 284 | return x -------------------------------------------------------------------------------- /src/data/batched_proteingym_substitution_datamodule.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, Optional 3 | import pandas as pd 4 | import logging 5 | import torch 6 | import esm 7 | import esm.inverse_folding 8 | import ast 9 | from lightning import LightningDataModule 10 | from torch.utils.data import Dataset, DataLoader 11 | 12 | 13 | LOG = logging.getLogger(__name__) 14 | 15 | 16 | class BatchedProteinGymSubstitutionDataset(Dataset): 17 | def __init__(self, assay_name, wt_sequence, coords, train_data, valid_data): 18 | super().__init__() 19 | self.assay_name = assay_name 20 | _, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D() 21 | self.batch_converter = self.alphabet.get_batch_converter() 22 | _, self.structure_alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50() 23 | self.coord_converter = esm.inverse_folding.util.CoordBatchConverter(self.structure_alphabet) 24 | self.coords = [] 25 | for coord in coords: 26 | batch = [(coord, None, None)] 27 | coord, confidence, _, _, padding_mask = self.coord_converter(batch) 28 | self.coords.append((coord, padding_mask, confidence)) 29 | 30 | self.batch_tokens = [] 31 | for sequence in wt_sequence: 32 | batch_labels, batch_strs, batch_tokens = self.batch_converter([('protein', sequence)]) 33 | self.batch_tokens.append(batch_tokens) 34 | 35 | self.train_labels = [] 36 | for assay_idx, dms_data in enumerate(train_data): 37 | train_label = [] 38 | for index, data in dms_data.iterrows(): 39 | mutants = data['mutant'].split(':') 40 | mutant_list = [] 41 | for mutant in mutants: 42 | location = int(mutant[1:-1]) 43 | mutant_list.append((location, torch.tensor(self.alphabet.tok_to_idx[mutant[:1]], dtype=torch.long), torch.tensor(self.alphabet.tok_to_idx[mutant[-1:]], dtype=torch.long))) 44 | train_label.append((torch.tensor(data['DMS_score'], dtype=torch.float32), mutant_list)) 45 | self.train_labels.append(train_label) 46 | 47 | self.valid_labels = [] 48 | for assay_idx, dms_data in enumerate(valid_data): 49 | valid_label = [] 50 | for index, data in dms_data.iterrows(): 51 | mutants = data['mutant'].split(':') 52 | mutant_list = [] 53 | for mutant in mutants: 54 | location = int(mutant[1:-1]) 55 | mutant_list.append((location, torch.tensor(self.alphabet.tok_to_idx[mutant[:1]], dtype=torch.long), 
torch.tensor(self.alphabet.tok_to_idx[mutant[-1:]], dtype=torch.long))) 56 | valid_label.append((torch.tensor(data['DMS_score'], dtype=torch.float32), mutant_list)) 57 | self.valid_labels.append(valid_label) 58 | 59 | def __getitem__(self, index): 60 | return self.assay_name, self.batch_tokens, self.coords, self.train_labels, self.valid_labels 61 | 62 | 63 | def __len__(self): 64 | return 1 65 | 66 | 67 | class BatchedProteinGymSubstitutionDataModule(LightningDataModule): 68 | """`LightningDataModule` for the ProteinGym dataset. 69 | Read the docs: 70 | https://lightning.ai/docs/pytorch/latest/data/datamodule.html 71 | """ 72 | 73 | def __init__( 74 | self, 75 | data_dir: str = "data/", 76 | batch_size: int = 64, 77 | num_workers: int = 0, 78 | pin_memory: bool = False, 79 | assay_index: int = 0, # 0 - 100 80 | split_type: str = "random", # random, modulo, contiguous 81 | split_index: int = 0, # 0 - 4 82 | support_assay_num: int = 40, 83 | ) -> None: 84 | super().__init__() 85 | self.save_hyperparameters(logger=False) 86 | 87 | self.data_train: Optional[Dataset] = None 88 | self.data_val: Optional[Dataset] = None 89 | _, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D() 90 | self.batch_size_per_device = batch_size 91 | 92 | @property 93 | def num_classes(self) -> int: 94 | return 1 95 | 96 | def prepare_data(self) -> None: 97 | pass 98 | 99 | def setup(self, stage: Optional[str] = None) -> None: 100 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 101 | """ 102 | 103 | if self.trainer is not None: 104 | if self.hparams.batch_size % self.trainer.world_size != 0: 105 | raise RuntimeError( 106 | f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})." 107 | ) 108 | self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size 109 | 110 | # load and split datasets only if not loaded already 111 | if not self.data_train and not self.data_val: 112 | self.hparams.data_dir = Path(self.hparams.data_dir) / 'ProteinGym' 113 | assay_reference_file = pd.read_csv(self.hparams.data_dir/'reference'/'DMS_substitutions.csv') 114 | assay_id = assay_reference_file["DMS_id"][self.hparams.assay_index] 115 | assay_file_name = assay_reference_file["DMS_filename"][assay_reference_file["DMS_id"]==assay_id].values[0] 116 | pdb_file_name = assay_reference_file["pdb_file"][assay_reference_file["DMS_id"]==assay_id].values[0] 117 | assay_data = pd.read_csv(self.hparams.data_dir/'substitutions'/assay_file_name) 118 | wt_sequence = assay_reference_file["target_seq"][assay_reference_file["DMS_id"]==assay_id].values[0] 119 | structure = esm.inverse_folding.util.load_structure(str(self.hparams.data_dir/'structure'/pdb_file_name), 'A') 120 | coord, _ = esm.inverse_folding.util.extract_coords_from_structure(structure) 121 | 122 | ### For generalization test 123 | # skip_nums=[0, 24, 25, 29, 30, 31, 58, 86, 103, 104, 128, 130, 175, 184, 185, 207] 124 | # with open(Path(self.hparams.data_dir) / 'ProteinGym' / 'cluster/id_0.5_cov_0.8.txt', 'r') as f_in: 125 | # cluster_set = ast.literal_eval(f_in.readline()) 126 | # test_assay_nums = None 127 | # for cluster in cluster_set: 128 | # if self.hparams.assay_index in cluster: 129 | # test_assay_nums = cluster 130 | # assay_selection_type = assay_reference_file["coarse_selection_type"][assay_reference_file["DMS_id"]==assay_id].values[0] 131 | # support_assay_ids = assay_reference_file["DMS_id"][assay_reference_file['coarse_selection_type']==assay_selection_type] 132 | # 
support_assay_ids = support_assay_ids[support_assay_ids.index != self.hparams.assay_index] 133 | # for skip_num in skip_nums: 134 | # support_assay_ids = support_assay_ids[support_assay_ids.index != skip_num] 135 | # for test_assay_id in test_assay_nums: 136 | # support_assay_ids = support_assay_ids[support_assay_ids.index != test_assay_id] 137 | # support_assay_file_names = [assay_reference_file["DMS_filename"][assay_reference_file["DMS_id"]==support_assay_id].values[0] for support_assay_id in support_assay_ids] 138 | # support_pdb_file_names = [assay_reference_file["pdb_file"][assay_reference_file["DMS_id"]==support_assay_id].values[0] for support_assay_id in support_assay_ids] 139 | # support_assay_data = [pd.read_csv(self.hparams.data_dir/'substitutions'/support_assay_file_name) for support_assay_file_name in support_assay_file_names] 140 | # support_wt_sequences = [assay_reference_file["target_seq"][assay_reference_file["DMS_id"]==support_assay_id].values[0] for support_assay_id in support_assay_ids] 141 | # support_structures = [esm.inverse_folding.util.load_structure(str(self.hparams.data_dir/'structure'/support_pdb_file_name), 'A') for support_pdb_file_name in support_pdb_file_names] 142 | # support_coords = [esm.inverse_folding.util.extract_coords_from_structure(support_structure)[0] for support_structure in support_structures] 143 | # support_assay_indices = list(support_assay_ids.index) 144 | # if len(support_assay_indices) > self.hparams.support_assay_num: 145 | # support_assay_indices, support_wt_sequences, support_coords, support_assay_data = support_assay_indices[:self.hparams.support_assay_num], support_wt_sequences[:self.hparams.support_assay_num], support_coords[:self.hparams.support_assay_num], support_assay_data[:self.hparams.support_assay_num] 146 | 147 | # self.data_train = BatchedProteinGymSubstitutionDataset(support_assay_indices, support_wt_sequences, support_coords, 148 | # support_assay_data, 149 | # []) 150 | # self.data_val = BatchedProteinGymSubstitutionDataset([assay_id], [wt_sequence], [coord], 151 | # [], 152 | # [assay_data]) 153 | 154 | self.data_train = BatchedProteinGymSubstitutionDataset([assay_id], [wt_sequence], [coord], 155 | [assay_data[assay_data[f"fold_{self.hparams.split_type}_5"] != self.hparams.split_index].reset_index()], 156 | [assay_data[assay_data[f"fold_{self.hparams.split_type}_5"] == self.hparams.split_index].reset_index()]) 157 | self.data_val = BatchedProteinGymSubstitutionDataset([assay_id], [wt_sequence], [coord], 158 | [assay_data[assay_data[f"fold_{self.hparams.split_type}_5"] != self.hparams.split_index].reset_index()], 159 | [assay_data[assay_data[f"fold_{self.hparams.split_type}_5"] == self.hparams.split_index].reset_index()]) 160 | LOG.info(f'Target assay {assay_id}; Length: {len(wt_sequence)}') 161 | 162 | def collator(self, raw_batch): 163 | return raw_batch[0] 164 | 165 | def train_dataloader(self) -> DataLoader[Any]: 166 | return DataLoader( 167 | dataset=self.data_train, 168 | batch_size=self.batch_size_per_device, 169 | num_workers=self.hparams.num_workers, 170 | pin_memory=self.hparams.pin_memory, 171 | collate_fn=self.collator, 172 | shuffle=True 173 | ) 174 | 175 | def val_dataloader(self) -> DataLoader[Any]: 176 | return DataLoader( 177 | dataset=self.data_val, 178 | batch_size=self.batch_size_per_device, 179 | num_workers=self.hparams.num_workers, 180 | pin_memory=self.hparams.pin_memory, 181 | collate_fn=self.collator, 182 | shuffle=False 183 | ) 184 | 185 | def teardown(self, stage: Optional[str] = None) -> None: 
186 | """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, 187 | `trainer.test()`, and `trainer.predict()`. 188 | 189 | :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 190 | Defaults to ``None``. 191 | """ 192 | pass 193 | 194 | def state_dict(self) -> Dict[Any, Any]: 195 | """Called when saving a checkpoint. Implement to generate and save the datamodule state. 196 | 197 | :return: A dictionary containing the datamodule state that you want to save. 198 | """ 199 | return {} 200 | 201 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 202 | """Called when loading a checkpoint. Implement to reload datamodule state given datamodule 203 | `state_dict()`. 204 | 205 | :param state_dict: The datamodule state returned by `self.state_dict()`. 206 | """ 207 | pass 208 | -------------------------------------------------------------------------------- /src/models/DePLM_module.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import numpy as np 4 | from scipy import stats 5 | from typing import Any, Dict, Tuple 6 | import random 7 | import torch 8 | import esm 9 | from lightning import LightningModule 10 | from torchmetrics import MaxMetric, MeanMetric, PearsonCorrCoef 11 | import torchsort 12 | 13 | def spearmanr(pred, target, **kw): 14 | pred = torchsort.soft_rank(pred, **kw) 15 | target = torchsort.soft_rank(target, **kw) 16 | pred = pred - pred.mean() 17 | pred = pred / pred.norm() 18 | target = target - target.mean() 19 | target = target / target.norm() 20 | return (pred * target).sum() 21 | 22 | LOG = logging.getLogger(__name__) 23 | 24 | from src.models.DePLM_components.modules import TriangularSelfAttentionBlock 25 | from src.models.DePLM_components.sort import quick_sort, _rank_data 26 | 27 | from torch import nn 28 | from torch.nn import LayerNorm 29 | 30 | 31 | class DePLM4ProteinEngineeringModule(LightningModule): 32 | def __init__( 33 | self, 34 | diff_total_step, 35 | model, 36 | optimizer: torch.optim.Optimizer, 37 | scheduler: torch.optim.lr_scheduler, 38 | compile: bool, 39 | ) -> None: 40 | super().__init__() 41 | 42 | # this line allows to access init params with 'self.hparams' attribute 43 | # also ensures init params will be stored in ckpt 44 | self.save_hyperparameters() 45 | if model == 'esm2': 46 | self.model, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D() 47 | elif model == 'esm1v': 48 | self.model, self.alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1() 49 | self.model.embed_dim = self.model.args.embed_dim 50 | self.model.attention_heads = self.model.args.attention_heads 51 | 52 | self.structure_model, self.structure_alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50() 53 | self.model, self.structure_model = self.model.eval(), self.structure_model.eval() 54 | for name, param in self.model.named_parameters(): 55 | param.requires_grad = False 56 | for name, param in self.structure_model.named_parameters(): 57 | param.requires_grad = False 58 | 59 | self.state = {} 60 | 61 | self.val_spearman = MeanMetric() 62 | self.val_spearman_best = MaxMetric() 63 | self.criterion = PearsonCorrCoef() 64 | 65 | self.diff_total_step = diff_total_step 66 | self.automatic_optimization = False 67 | 68 | self.repr_dim = self.model.embed_dim 69 | self.repr_combine = nn.Parameter(torch.zeros(self.model.num_layers + 1)) 70 | self.repr_mlp = nn.Sequential(LayerNorm(self.repr_dim), nn.Linear(self.repr_dim, self.repr_dim), nn.GELU(), 
nn.Linear(self.repr_dim, self.repr_dim)) 71 | self.structure_repr_mlp = nn.Sequential(LayerNorm(self.structure_model.encoder.args.encoder_embed_dim), nn.Linear(self.structure_model.encoder.args.encoder_embed_dim, self.repr_dim), nn.GELU(), nn.Linear(self.repr_dim, self.repr_dim)) 72 | self.attn_dim = 32 73 | self.attn_num = self.model.num_layers * self.model.attention_heads 74 | self.attn_mlp = nn.Sequential(LayerNorm(self.attn_num), nn.Linear(self.attn_num, self.attn_num), nn.GELU(), nn.Linear(self.attn_num, self.attn_dim)) 75 | 76 | self.num_blocks = 1 77 | self.blocks = nn.ModuleList( 78 | [ 79 | TriangularSelfAttentionBlock( 80 | sequence_state_dim=self.repr_dim, 81 | pairwise_state_dim=self.attn_dim, 82 | sequence_head_width=32, 83 | pairwise_head_width=32, 84 | dropout=0.2 85 | ) for _ in range(self.num_blocks) 86 | ] 87 | ) 88 | self.step_embedding = nn.Embedding(self.diff_total_step, self.repr_dim,) 89 | self.logits_mlp = nn.Sequential(nn.Linear(self.model.alphabet_size, self.repr_dim), nn.GELU(), LayerNorm(self.repr_dim)) 90 | self.conv = nn.Conv1d(self.repr_dim, self.repr_dim, kernel_size=7, stride=1, padding=3) 91 | self.logits_representation_mlp = nn.Sequential(nn.Linear(2 * self.repr_dim, self.repr_dim)) 92 | 93 | def state_setup(self, x_wt, y_label): 94 | y = [] 95 | locations = set() 96 | for score, mutants in y_label: 97 | y.append({ 98 | 'score': score, 99 | 'mutants': [(mutant[0], mutant[1], mutant[2]) for mutant in mutants] 100 | }) 101 | for mutant in mutants: 102 | locations.add(mutant[0]) 103 | 104 | masked_batch_tokens = x_wt.clone() 105 | with torch.no_grad(): 106 | result = self.model(masked_batch_tokens, repr_layers=range(self.model.num_layers+1), need_head_weights=True) 107 | x = {'input': x_wt, 'logits': result['logits'][0], 'representation': torch.stack([v for _, v in sorted(result['representations'].items())], dim=2), 'attention': result['attentions'].permute(0, 4, 3, 1, 2).flatten(3, 4)} 108 | return (x, y) 109 | 110 | def forward(self, x, structure_repr): 111 | return_logits = [] 112 | 113 | logits, representation, attention = x['logits'], x['representation'], x['attention'] 114 | residx = torch.arange(x['input'].shape[1], device=self.device).expand_as(x['input']) 115 | mask = torch.ones_like(x['input']) 116 | 117 | representation = self.repr_mlp((self.repr_combine.softmax(0).unsqueeze(0) @ representation).squeeze(2)) + self.structure_repr_mlp(structure_repr).repeat(representation.shape[0], 1, 1) 118 | attention = self.attn_mlp(attention) 119 | 120 | def trunk_iter(s, z, residx, mask): 121 | for block in self.blocks: 122 | s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=None) 123 | return s, z 124 | 125 | representation, attention = trunk_iter(representation, attention, residx, mask) 126 | logits_recycle = logits 127 | return_logits.append(logits_recycle) 128 | step_embed = self.step_embedding(torch.tensor([i for i in range(self.diff_total_step)], dtype=torch.long).to(representation.device)) 129 | for diff_step in range(self.diff_total_step): 130 | logits_representation = self.logits_mlp(logits_recycle.detach()) + step_embed[diff_step] 131 | logits_representation = self.conv(logits_representation.transpose(0,1)).transpose(0,1) 132 | logits_recycle = logits_recycle - self.model.lm_head(self.logits_representation_mlp(torch.cat([logits_representation, representation[0]], dim=-1))) 133 | return_logits.append(logits_recycle) 134 | 135 | return return_logits, representation 136 | 137 | def loss_compute_and_backward(self, x, y, structure_repr): 138 | 
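# Manual-optimization training for one batch of assays:
#   1. run the trunk to obtain logits after every denoising step (step 0 = the raw language-model logits);
#   2. read out a scalar score per variant as the mean log-odds difference
#      logits[position, mutant_token] - logits[position, wildtype_token] over its mutated positions;
#   3. for each step t, build an intermediate ranking target from the previous step's scores with
#      `intermediate_score_compute` (the final step targets the true experimental ranking) and add
#      1 - soft Spearman between the step-t scores and that target to the loss;
#   4. backpropagate with `manual_backward` and apply a single optimizer step.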
138 | opt = self.optimizers() 139 | x_logits, _ = self.forward(x, structure_repr=structure_repr) 140 | combined_x_logits = [torch.stack([sum([x_logits[step][mutant[0], mutant[2]] - x_logits[step][mutant[0], mutant[1]] for mutant in y[index]['mutants']]) / len(y[index]['mutants']) for index in range(len(y))], dim=-1) for step in range(self.diff_total_step + 1)] 141 | combined_y_scores = torch.stack([y[index]['score'] for index in range(len(y))]) 142 | loss = 0. 143 | for index in range(1, self.diff_total_step+1): 144 | combined_intermediate_ranked_y_scores = self.intermediate_score_compute(combined_x_logits[index-1], combined_y_scores) 145 | if index == self.diff_total_step: 146 | combined_intermediate_ranked_y_scores = combined_intermediate_ranked_y_scores[-1] 147 | else: 148 | combined_intermediate_ranked_y_scores = combined_intermediate_ranked_y_scores[index-1] 149 | 150 | loss += (1 - spearmanr(combined_x_logits[index].unsqueeze(0), combined_intermediate_ranked_y_scores.unsqueeze(0))) 151 | self.manual_backward(loss) 152 | 153 | opt.step() 154 | opt.zero_grad() 155 | spearman = stats.spearmanr(combined_x_logits[-1].detach().cpu(), combined_y_scores.detach().cpu()).statistic 156 | return loss, spearman 157 | 158 | def output_process(self, x_logits, y): 159 | x_logits = torch.stack([sum([x_logits[mutant[0], mutant[2]] - x_logits[mutant[0], mutant[1]] for mutant in y[index]['mutants']]) / len(y[index]['mutants']) for index in range(len(y))], dim=-1) 160 | y_scores = torch.stack([y[index]['score'] for index in range(len(y))]) 161 | return x_logits, y_scores 162 | 163 | def on_train_start(self) -> None: 164 | self.val_spearman.reset() 165 | self.val_spearman_best.reset() 166 | 167 | def training_step( 168 | self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int 169 | ) -> None: 170 | """Perform a single training step on a batch of data from the training set. 171 | 172 | :param batch: A batch of data (a tuple) containing the assay names, wild-type tokens, structure 173 | coordinates and per-variant assay labels. 174 | :param batch_idx: The index of the current batch. 175 | :return: None; optimization is performed manually inside `loss_compute_and_backward`. 176 | """ 177 | assay_names, batch_tokens, coords, train_labels, _ = batch 178 | 179 | for name, x_wt, coord, train_label in zip(assay_names, batch_tokens, coords, train_labels): 180 | if f'{name}-train' not in self.state: 181 | self.state[f'{name}-train'] = self.state_setup(x_wt, train_label) 182 | if f'{name}-structure' not in self.state: 183 | self.state[f'{name}-structure'] = self.structure_model.encoder.forward(*coord)['encoder_out'][0].transpose(0, 1) 184 | x, y = self.state[f'{name}-train'] 185 | loss, spearman = self.loss_compute_and_backward(x, y, structure_repr=self.state[f'{name}-structure']) 186 | LOG.info(f'Training assay {name}: loss {loss}; spearman {spearman}') 187 | 188 | def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: 189 | """Perform a single validation step on a batch of data from the validation set. 190 | 191 | :param batch: A batch of data (a tuple) containing the assay names, wild-type tokens, structure 192 | coordinates and per-variant assay labels. 193 | :param batch_idx: The index of the current batch. 
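For each assay, the fully denoised (final-step) logits are converted to per-variant scores and the exact Spearman correlation against the experimental scores is computed with SciPy and accumulated into `val_spearman`.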
194 | """ 195 | self.val_spearman.reset() 196 | assay_names, batch_tokens, coords, _, valid_labels = batch 197 | 198 | for name, x_wt, coord, valid_label in zip(assay_names, batch_tokens, coords, valid_labels): 199 | if f'{name}-valid' not in self.state: 200 | self.state[f'{name}-valid'] = self.state_setup(x_wt, valid_label) 201 | if f'{name}-structure' not in self.state: 202 | self.state[f'{name}-structure'] = self.structure_model.encoder.forward(*coord)['encoder_out'][0].transpose(0, 1) 203 | x, y = self.state[f'{name}-valid'] 204 | x_logits, _ = self.forward(x, structure_repr=self.state[f'{name}-structure']) 205 | x_logits = x_logits[-1] 206 | x_logits, y_scores = self.output_process(x_logits, y) 207 | spearman = stats.spearmanr(x_logits.detach().cpu(), y_scores.detach().cpu()).statistic 208 | LOG.info(f'Testing assay {name}: spearman {spearman}') 209 | self.val_spearman(spearman) 210 | 211 | def on_validation_epoch_end(self) -> None: 212 | "Lightning hook that is called when a validation epoch ends." 213 | 214 | self.log("val/spearman", self.val_spearman.compute(), sync_dist=True, prog_bar=True) 215 | 216 | spearman = self.val_spearman.compute() # get current val acc 217 | self.val_spearman_best(spearman) # update best so far val acc 218 | # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object 219 | # otherwise metric would be reset by lightning after each epoch 220 | self.log("val/spearman_best", self.val_spearman_best.compute(), sync_dist=True, prog_bar=True) 221 | 222 | def intermediate_score_compute(self, x_logits, y_scores): 223 | ranked_x_logits = (_rank_data(x_logits) - 1) 224 | ranked_y_scores = (_rank_data(y_scores) - 1) 225 | sorted_index = torch.argsort(ranked_y_scores) 226 | begin_state = ranked_x_logits[sorted_index] 227 | end_state = ranked_y_scores[sorted_index] 228 | intermediate_states = torch.tensor(quick_sort(begin_state.tolist()) + [end_state.tolist()], dtype=x_logits.dtype, device=self.device) 229 | intermediate_ranked_y_scores = [] 230 | for state in intermediate_states: 231 | ranked_y_score = torch.empty([state.shape[0]], dtype=x_logits.dtype, device=self.device) 232 | ranked_y_score[sorted_index] = state 233 | intermediate_ranked_y_scores.append(ranked_y_score) 234 | return intermediate_ranked_y_scores 235 | 236 | def setup(self, stage: str) -> None: 237 | """Lightning hook that is called at the beginning of fit (train + validate), validate, 238 | test, or predict. 239 | 240 | This is a good hook when you need to build models dynamically or adjust something about 241 | them. This hook is called on every process when using DDP. 242 | 243 | :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 244 | """ 245 | if self.hparams.compile and stage == "fit": 246 | self.net = torch.compile(self.net) 247 | 248 | def configure_optimizers(self) -> Dict[str, Any]: 249 | """Choose what optimizers and learning-rate schedulers to use in your optimization. 250 | Normally you'd need one. But in the case of GANs or similar you might have multiple. 251 | 252 | Examples: 253 | https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers 254 | 255 | :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. 
256 | """ 257 | optimizer = self.hparams.optimizer(params=self.parameters()) 258 | 259 | if self.hparams.scheduler is not None: 260 | scheduler = self.hparams.scheduler(optimizer=optimizer) 261 | return { 262 | "optimizer": optimizer, 263 | "lr_scheduler": { 264 | "scheduler": scheduler, 265 | "monitor": "val/spearman_best", 266 | "interval": "epoch", 267 | "frequency": 1, 268 | }, 269 | } 270 | return {"optimizer": optimizer} 271 | --------------------------------------------------------------------------------