├── .gitignore ├── README.md ├── asset ├── predictor.png ├── retriever.png └── training.png ├── config ├── esm_fold.yaml ├── predictor │ ├── esm_aav.yaml │ ├── esm_beta.yaml │ ├── esm_binloc.yaml │ ├── esm_ec.yaml │ ├── esm_fluorescence.yaml │ ├── esm_gb1.yaml │ ├── esm_go.yaml │ ├── esm_stability.yaml │ ├── esm_subloc.yaml │ └── esm_thermo.yaml └── retriever │ ├── esm_ec.yaml │ └── esm_go.yaml ├── esm_s ├── dataset.py └── task.py ├── script ├── retrieve.py └── run.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Structure-Informed Protein Language Model
2 | 
3 | This is the official codebase of the paper
4 | 
5 | **Structure-Informed Protein Language Model**
6 | [[ArXiv](https://arxiv.org/abs/2402.05856)]
7 | 
8 | [Zuobai Zhang](https://oxer11.github.io/), [Jiarui Lu](https://lujiarui.github.io/), [Vijil Chenthamarakshan](https://researcher.watson.ibm.com/researcher/view.php?person=us-ecvijil), [Aurelie Lozano](https://researcher.watson.ibm.com/researcher/view.php?person=us-aclozano), [Payel Das](https://researcher.watson.ibm.com/researcher/view.php?person=us-daspa), [Jian Tang](https://jian-tang.com/)
9 | 
10 | 
11 | ## Overview
12 | 
13 | Protein language models are a powerful tool for learning protein representations. However, traditional protein language models lack explicit structural supervision. Recent studies have developed models that combine large-scale pre-training on protein sequences with structural information as input, *e.g.*, [ESM-GearNet](https://arxiv.org/abs/2303.06275). However, their reliance on protein structures as input limits their application to proteins without known structures.
14 | 
15 | To address this issue, in this work, we integrate remote homology detection to **distill structural information into protein language models
16 | without requiring explicit protein structures as input**.
17 | 
18 | ![Training](./asset/training.png)
19 | 
20 | We take the [ESM](https://github.com/facebookresearch/esm) models as examples and train them on the remote homology detection task, *a.k.a.*, fold classification.
21 | The model weights for structure-informed ESM, *i.e.*, ESM-S, can be found [here](https://huggingface.co/Oxer11/ESM-S).
22 | 
23 | ## Installation
24 | 
25 | You may install the dependencies via either conda or pip. Generally, ESM-S works
26 | with Python 3.7/3.8 and PyTorch >= 1.12.0.
27 | Please make sure the latest version of torchdrug is installed.
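After installing via either route below, a minimal snippet like this can confirm the environment is usable (illustrative only; the exact version strings will differ):

```python
import torch
import torchdrug

print(torch.__version__)         # this README targets 1.12.x
print(torchdrug.__version__)
print(torch.cuda.is_available())
```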
28 | 
29 | ### From Conda
30 | 
31 | ```bash
32 | conda install torchdrug pytorch=1.12.1 cudatoolkit=11.6 -c milagraph -c pytorch-lts -c pyg -c conda-forge
33 | conda install easydict pyyaml -c conda-forge
34 | ```
35 | 
36 | ### From Pip
37 | 
38 | ```bash
39 | pip install torch==1.12.1+cu116 -f https://download.pytorch.org/whl/lts/1.12/torch_lts.html
40 | pip install torchdrug
41 | pip install easydict pyyaml
42 | ```
43 | 
44 | ## Reproduction
45 | 
46 | ### Download Datasets and Model Weights
47 | 
48 | Define the environment variables `DATADIR` and `MODELDIR`, then download the datasets and model weights into the corresponding directories.
49 | The datasets and model weights can be downloaded from [Oxer11/ESM-S](https://huggingface.co/Oxer11/ESM-S) and [Oxer11/Protein-Function-Annotation](https://huggingface.co/datasets/Oxer11/Protein-Function-Annotation).
50 | All datasets other than EC, GO and Fold will be downloaded automatically by TorchDrug on first loading.
51 | 
52 | ```bash
53 | DATADIR=./data
54 | MODELDIR=./model
55 | 
56 | mkdir $DATADIR
57 | cd $DATADIR
58 | # Download remote homology detection dataset
59 | wget https://huggingface.co/datasets/Oxer11/Protein-Function-Annotation/resolve/main/fold.tar.gz
60 | tar -xvf fold.tar.gz
61 | # Download Enzyme Commission dataset
62 | wget https://huggingface.co/datasets/Oxer11/Protein-Function-Annotation/resolve/main/ec.tar.gz
63 | tar -xvf ec.tar.gz
64 | # Download Gene Ontology dataset
65 | wget https://huggingface.co/datasets/Oxer11/Protein-Function-Annotation/resolve/main/go.tar.gz
66 | tar -xvf go.tar.gz
67 | 
68 | cd ..
69 | mkdir $MODELDIR
70 | cd $MODELDIR
71 | # Download ESM-2-650M model weight
72 | wget https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt
73 | # Download ESM-2-650M-S model weight
74 | wget https://huggingface.co/Oxer11/ESM-S/resolve/main/esm_650m_s.pth
75 | ```
76 | 
77 | ### Load Trained Model Weight
78 | Here we show how to load the structure-informed PLM weights `esm_650m_s.pth` into the `torchdrug.models.EvolutionaryScaleModeling` module.
79 | By default, the model weights are saved as a state dict.
80 | ```python
81 | import os
82 | import torch
83 | from torchdrug import models
84 | model_dir = "./model"  # Set the path to your model dir
85 | esm = models.EvolutionaryScaleModeling(model_dir, model="ESM-2-650M", readout="mean")
86 | 
87 | # Load ESM-2-650M-S
88 | model_dict = torch.load(os.path.join(model_dir, "esm_650m_s.pth"), map_location=torch.device("cpu"))
89 | esm.load_state_dict(model_dict)
90 | ```
91 | 
92 | ### Structure-Informed Training
93 | 
94 | To reproduce the training of structure-informed protein language models, we need to train a base protein language model on the remote homology detection task, *i.e.*, fold classification.
95 | You can run on 4 GPUs by setting the `gpus` parameter in the config files.
96 | 
97 | ```bash
98 | # Train ESM-2-650M on the fold classification dataset
99 | python script/run.py -c config/esm_fold.yaml --datadir $DATADIR/fold --modeldir $MODELDIR --model ESM-2-650M
100 | 
101 | # Train ESM-2-650M with 4 GPUs
102 | # Remember to change `gpus` in the config file to [0, 1, 2, 3]
103 | python -m torch.distributed.launch --nproc_per_node=4 script/run.py -c config/esm_fold.yaml --datadir $DATADIR/fold --modeldir $MODELDIR --model ESM-2-650M
104 | ```
105 | 
106 | ### Predictor-Based Methods
107 | 
108 | To test the effect of structure-informed training, we compare the results of feeding ESM's and ESM-S's representations into a 2-layer MLP predictor.
109 | The 2-layer MLP is fine-tuned on downstream function prediction datasets.
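Conceptually, the predictor-based setup reduces to the sketch below (plain PyTorch, for illustration only; the actual head is the `MLP` in `esm_s/task.py`, which also adds batch norm and dropout, and the class count here is a placeholder):

```python
import torch
from torch import nn

embed_dim, num_class = 1280, 538  # ESM-2-650M hidden size; placeholder class count

# The PLM is kept frozen (lr_ratio: 0 in the predictor configs); only the head is tuned.
head = nn.Sequential(
    nn.Linear(embed_dim, embed_dim),
    nn.ReLU(),
    nn.Linear(embed_dim, num_class),
)

graph_feature = torch.randn(8, embed_dim)  # stand-in for mean-pooled ESM(-S) features
logits = head(graph_feature)               # shape (8, num_class)
```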
110 | 
111 | ```bash
112 | # Tune a 2-layer MLP on ESM's representations on EC
113 | python script/run.py -c config/predictor/esm_ec.yaml --datadir $DATADIR/ec --modeldir $MODELDIR --model ESM-2-650M --ckpt null
114 | 
115 | # Tune a 2-layer MLP on ESM-S's representations on EC
116 | python script/run.py -c config/predictor/esm_ec.yaml --datadir $DATADIR/ec --modeldir $MODELDIR --model ESM-2-650M --ckpt $MODELDIR/esm_650m_s.pth
117 | 
118 | # Tune a 2-layer MLP on ESM-S's representations on GO-BP
119 | python script/run.py -c config/predictor/esm_go.yaml --datadir $DATADIR/go --level bp --modeldir $MODELDIR --model ESM-2-650M --ckpt $MODELDIR/esm_650m_s.pth
120 | 
121 | # Tune a 2-layer MLP on ESM-S's representations on Beta Lactamase
122 | # The dataset will be downloaded automatically.
123 | python script/run.py -c config/predictor/esm_beta.yaml --datadir $DATADIR/ --modeldir $MODELDIR --model ESM-2-650M --ckpt $MODELDIR/esm_650m_s.pth
124 | ```
125 | 
126 | You can also replace `ESM-2-650M` with other ESM model sizes.
127 | ```bash
128 | # Tune a 2-layer MLP on ESM-2-150M-S's representations on EC
129 | # Remember to download esm_150m_s.pth from the link above
130 | python script/run.py -c config/predictor/esm_ec.yaml --datadir $DATADIR/ec --modeldir $MODELDIR --model ESM-2-150M --ckpt $MODELDIR/esm_150m_s.pth
131 | ```
132 | 
133 | After fine-tuning, you are expected to obtain the following results.
134 | ![Predictor](./asset/predictor.png)
135 | 
136 | 
137 | ### Retriever-Based Methods
138 | Besides predictor-based methods, we also use ESM's and ESM-S's representations to measure protein similarity.
139 | Based on these similarities, we can annotate proteins in the test set with the function labels of their nearest training-set neighbors.
140 | 
141 | ```bash
142 | # Run retriever with ESM-S's representations on EC
143 | python script/retrieve.py -c config/retriever/esm_ec.yaml --datadir $DATADIR/ec --modeldir $MODELDIR --model ESM-2-650M --ckpt $MODELDIR/esm_650m_s.pth
144 | 
145 | # Run retriever with ESM-S's representations on GO-BP
146 | python script/retrieve.py -c config/retriever/esm_go.yaml --datadir $DATADIR/go --level bp --modeldir $MODELDIR --model ESM-2-650M --ckpt $MODELDIR/esm_650m_s.pth
147 | 
148 | # Run retriever with ESM-S's representations on GO-MF
149 | python script/retrieve.py -c config/retriever/esm_go.yaml --datadir $DATADIR/go --level mf --modeldir $MODELDIR --model ESM-2-650M --ckpt $MODELDIR/esm_650m_s.pth
150 | 
151 | # Run retriever with ESM-S's representations on GO-CC
152 | python script/retrieve.py -c config/retriever/esm_go.yaml --datadir $DATADIR/go --level cc --modeldir $MODELDIR --model ESM-2-650M --ckpt $MODELDIR/esm_650m_s.pth
153 | ```
154 | 
155 | You are expected to obtain the following results.
156 | ![Retriever](./asset/retriever.png)
157 | 
158 | 
159 | ## Citation
160 | If you find this codebase useful in your research, please cite the following paper.
161 | 162 | ```bibtex 163 | @article{zhang2024structureplm, 164 | title={Structure-Informed Protein Language Model}, 165 | author={Zhang, Zuobai and Lu, Jiarui and Chenthamarakshan, Vijil and Lozano, Aurelie and Das, Payel and Tang, Jian}, 166 | journal={arXiv preprint arXiv:2402.05856}, 167 | year={2024} 168 | } 169 | ``` 170 | -------------------------------------------------------------------------------- /asset/predictor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepGraphLearning/esm-s/9fc3b9bd24876b8b1d4b5bb8859f74f9e0954e75/asset/predictor.png -------------------------------------------------------------------------------- /asset/retriever.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepGraphLearning/esm-s/9fc3b9bd24876b8b1d4b5bb8859f74f9e0954e75/asset/retriever.png -------------------------------------------------------------------------------- /asset/training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepGraphLearning/esm-s/9fc3b9bd24876b8b1d4b5bb8859f74f9e0954e75/asset/training.png -------------------------------------------------------------------------------- /config/esm_fold.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ~/scratch/protein_output 2 | 3 | dataset: 4 | class: MyFold 5 | path: {{ datadir }} 6 | transform: 7 | class: Compose 8 | transforms: 9 | - class: ProteinView 10 | view: residue 11 | - class: TruncateProtein 12 | max_length: 550 13 | 14 | task: 15 | class: FoldClassification 16 | model: 17 | class: ESM 18 | path: {{ modeldir }} 19 | model: {{ model }} 20 | mlp_batch_norm: True 21 | mlp_dropout: 0.2 22 | 23 | optimizer: 24 | class: Adam 25 | lr: 1.0e-4 26 | 27 | scheduler: 28 | class: ReduceLROnPlateau 29 | factor: 0.6 30 | patience: 5 31 | 32 | engine: 33 | gpus: [0] #, 1, 2, 3] 34 | batch_size: 8 35 | log_interval: 1000 36 | 37 | lr_ratio: 0.1 38 | 39 | eval_metric: accuracy 40 | 41 | train: 42 | num_epoch: 50 -------------------------------------------------------------------------------- /config/predictor/esm_aav.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ~/scratch/protein_output 2 | 3 | dataset: 4 | class: AAV 5 | path: {{ datadir }} 6 | atom_feature: null 7 | bond_feature: null 8 | keep_mutation_region: True 9 | transform: 10 | class: Compose 11 | transforms: 12 | - class: ProteinView 13 | view: residue 14 | 15 | task: 16 | class: PropertyPrediction 17 | model: 18 | class: ESM 19 | path: {{ modeldir }} 20 | model: {{ model }} 21 | readout: mean 22 | criterion: mse 23 | metric: ["mae", "rmse", "spearmanr"] 24 | normalization: False 25 | num_mlp_layer: 2 26 | 27 | optimizer: 28 | class: Adam 29 | lr: 5.0e-5 30 | 31 | engine: 32 | gpus: [0] #, 1, 2, 3] 33 | batch_size: 32 34 | gradient_interval: 4 35 | 36 | eval_metric: spearmanr 37 | lr_ratio: 0 38 | 39 | model_checkpoint: {{ ckpt }} 40 | 41 | train: 42 | num_epoch: 100 43 | -------------------------------------------------------------------------------- /config/predictor/esm_beta.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ~/scratch/protein_output 2 | 3 | dataset: 4 | class: BetaLactamase 5 | path: {{ datadir }} 6 | atom_feature: null 7 | bond_feature: null 8 | transform: 9 | class: Compose 10 | transforms: 
11 | - class: ProteinView 12 | view: residue 13 | 14 | task: 15 | class: PropertyPrediction 16 | model: 17 | class: ESM 18 | path: {{ modeldir }} 19 | model: {{ model }} 20 | readout: mean 21 | criterion: mse 22 | metric: ["mae", "rmse", "spearmanr"] 23 | normalization: False 24 | num_mlp_layer: 2 25 | 26 | optimizer: 27 | class: Adam 28 | lr: 5.0e-5 29 | 30 | engine: 31 | gpus: [0] #, 1, 2, 3] 32 | batch_size: 32 33 | gradient_interval: 4 34 | 35 | eval_metric: spearmanr 36 | lr_ratio: 0 37 | 38 | model_checkpoint: {{ ckpt }} 39 | 40 | train: 41 | num_epoch: 100 -------------------------------------------------------------------------------- /config/predictor/esm_binloc.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ~/scratch/protein_output 2 | 3 | dataset: 4 | class: BinaryLocalization 5 | path: {{ datadir }} 6 | atom_feature: null 7 | bond_feature: null 8 | transform: 9 | class: Compose 10 | transforms: 11 | - class: ProteinView 12 | view: residue 13 | 14 | task: 15 | class: PropertyPrediction 16 | model: 17 | class: ESM 18 | path: {{ modeldir }} 19 | model: {{ model }} 20 | readout: mean 21 | criterion: ce 22 | metric: ["acc", "mcc"] 23 | num_mlp_layer: 2 24 | num_class: 2 25 | 26 | optimizer: 27 | class: Adam 28 | lr: 5.0e-5 29 | 30 | engine: 31 | gpus: [0] #, 1, 2, 3] 32 | batch_size: 32 33 | gradient_interval: 4 34 | 35 | eval_metric: accuracy 36 | lr_ratio: 0 37 | 38 | model_checkpoint: {{ ckpt }} 39 | 40 | train: 41 | num_epoch: 100 -------------------------------------------------------------------------------- /config/predictor/esm_ec.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ~/scratch/protein_output 2 | 3 | dataset: 4 | class: EC 5 | path: {{ datadir }} 6 | percent: 95 7 | transform: 8 | class: Compose 9 | transforms: 10 | - class: ProteinView 11 | view: residue 12 | - class: TruncateProtein 13 | max_length: 550 14 | 15 | task: 16 | class: FunctionAnnotation 17 | model: 18 | class: ESM 19 | path: {{ modeldir }} 20 | model: {{ model }} 21 | mlp_batch_norm: True 22 | mlp_dropout: 0.2 23 | metric: ['auprc@micro', 'f1_max'] 24 | 25 | optimizer: 26 | class: Adam 27 | lr: 1.0e-4 28 | 29 | scheduler: 30 | class: ReduceLROnPlateau 31 | factor: 0.6 32 | patience: 5 33 | 34 | engine: 35 | gpus: [0] #, 1, 2, 3] 36 | batch_size: 8 37 | log_interval: 1000 38 | 39 | eval_metric: f1_max 40 | lr_ratio: 0.0 41 | 42 | model_checkpoint: {{ ckpt }} 43 | 44 | train: 45 | num_epoch: 50 -------------------------------------------------------------------------------- /config/predictor/esm_fluorescence.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ~/scratch/protein_output 2 | 3 | dataset: 4 | class: Fluorescence 5 | path: {{ datadir }} 6 | atom_feature: null 7 | bond_feature: null 8 | transform: 9 | class: Compose 10 | transforms: 11 | - class: ProteinView 12 | view: residue 13 | 14 | task: 15 | class: PropertyPrediction 16 | model: 17 | class: ESM 18 | path: {{ modeldir }} 19 | model: {{ model }} 20 | readout: mean 21 | criterion: mse 22 | metric: ["mae", "rmse", "spearmanr"] 23 | normalization: False 24 | num_mlp_layer: 2 25 | 26 | optimizer: 27 | class: Adam 28 | lr: 5.0e-5 29 | 30 | engine: 31 | gpus: [0] #, 1, 2, 3] 32 | batch_size: 32 33 | gradient_interval: 4 34 | 35 | eval_metric: spearmanr 36 | lr_ratio: 0 37 | 38 | model_checkpoint: {{ ckpt }} 39 | 40 | train: 41 | num_epoch: 100 
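The `{{ ... }}` placeholders in these configs are Jinja2 template variables; `util.load_config` renders them from command-line flags before parsing the YAML. A minimal sketch of that step (paths and values illustrative):

```python
import jinja2
import yaml
import easydict

with open("config/predictor/esm_ec.yaml") as fin:
    raw = fin.read()

# --datadir, --modeldir, --model and --ckpt on the command line become this context.
context = {"datadir": "./data/ec", "modeldir": "./model", "model": "ESM-2-650M", "ckpt": "null"}
cfg = easydict.EasyDict(yaml.safe_load(jinja2.Template(raw).render(context)))

print(cfg.dataset.path)      # "./data/ec"
print(cfg.task.model.model)  # "ESM-2-650M"
print(cfg.model_checkpoint)  # None -- the string "null" parses to YAML null
```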
--------------------------------------------------------------------------------
/config/predictor/esm_gb1.yaml:
--------------------------------------------------------------------------------
1 | output_dir: ~/scratch/protein_output
2 | 
3 | dataset:
4 |   class: GB1
5 |   path: {{ datadir }}
6 |   atom_feature: null
7 |   bond_feature: null
8 |   transform:
9 |     class: Compose
10 |     transforms:
11 |       - class: ProteinView
12 |         view: residue
13 | 
14 | task:
15 |   class: PropertyPrediction
16 |   model:
17 |     class: ESM
18 |     path: {{ modeldir }}
19 |     model: {{ model }}
20 |   readout: mean
21 |   criterion: mse
22 |   metric: ["mae", "rmse", "spearmanr"]
23 |   normalization: False
24 |   num_mlp_layer: 2
25 | 
26 | optimizer:
27 |   class: Adam
28 |   lr: 5.0e-5
29 | 
30 | engine:
31 |   gpus: [0] #, 1, 2, 3]
32 |   batch_size: 32
33 |   gradient_interval: 4
34 | 
35 | eval_metric: spearmanr
36 | lr_ratio: 0
37 | 
38 | model_checkpoint: {{ ckpt }}
39 | 
40 | train:
41 |   num_epoch: 100
42 | 
--------------------------------------------------------------------------------
/config/predictor/esm_go.yaml:
--------------------------------------------------------------------------------
1 | output_dir: ~/scratch/protein_output
2 | 
3 | dataset:
4 |   class: GO
5 |   path: {{ datadir }}
6 |   level: {{ level }}
7 |   percent: 95
8 |   transform:
9 |     class: Compose
10 |     transforms:
11 |       - class: ProteinView
12 |         view: residue
13 |       - class: TruncateProtein
14 |         max_length: 550
15 | 
16 | task:
17 |   class: FunctionAnnotation
18 |   model:
19 |     class: ESM
20 |     path: {{ modeldir }}
21 |     model: {{ model }}
22 |   mlp_batch_norm: True
23 |   mlp_dropout: 0.2
24 |   metric: ['auprc@micro', 'f1_max']
25 | 
26 | optimizer:
27 |   class: Adam
28 |   lr: 1.0e-4
29 | 
30 | scheduler:
31 |   class: ReduceLROnPlateau
32 |   factor: 0.6
33 |   patience: 5
34 | 
35 | engine:
36 |   gpus: [0] #, 1, 2, 3]
37 |   batch_size: 8
38 |   log_interval: 1000
39 | 
40 | eval_metric: f1_max
41 | lr_ratio: 0.0
42 | 
43 | model_checkpoint: {{ ckpt }}
44 | 
45 | train:
46 |   num_epoch: 50
--------------------------------------------------------------------------------
/config/predictor/esm_stability.yaml:
--------------------------------------------------------------------------------
1 | output_dir: ~/scratch/protein_output
2 | 
3 | dataset:
4 |   class: Stability
5 |   path: {{ datadir }}
6 |   atom_feature: null
7 |   bond_feature: null
8 |   transform:
9 |     class: Compose
10 |     transforms:
11 |       - class: ProteinView
12 |         view: residue
13 | 
14 | task:
15 |   class: PropertyPrediction
16 |   model:
17 |     class: ESM
18 |     path: {{ modeldir }}
19 |     model: {{ model }}
20 |   readout: mean
21 |   criterion: mse
22 |   metric: ["mae", "rmse", "spearmanr"]
23 |   normalization: False
24 |   num_mlp_layer: 2
25 | 
26 | optimizer:
27 |   class: Adam
28 |   lr: 5.0e-5
29 | 
30 | engine:
31 |   gpus: [0] #, 1, 2, 3]
32 |   batch_size: 32
33 |   gradient_interval: 4
34 | 
35 | eval_metric: spearmanr
36 | lr_ratio: 0
37 | 
38 | model_checkpoint: {{ ckpt }}
39 | 
40 | train:
41 |   num_epoch: 100
--------------------------------------------------------------------------------
/config/predictor/esm_subloc.yaml:
--------------------------------------------------------------------------------
1 | output_dir: ~/scratch/protein_output
2 | 
3 | dataset:
4 |   class: SubcellularLocalization
5 |   path: {{ datadir }}
6 |   atom_feature: null
7 |   bond_feature: null
8 |   transform:
9 |     class: Compose
10 |     transforms:
11 |       - class: ProteinView
12 |         view: residue
13 | 
14 | task:
15 |   class: PropertyPrediction
16 |   model:
17 |     class: ESM
18 |     path: {{ modeldir }}
19 |     model: {{ model }}
20 |   readout: mean
21 |   criterion: ce
22 |   metric: ["acc", "mcc"]
23 |   num_mlp_layer: 2
24 |   num_class: 10
25 | 
26 | optimizer:
27 |   class: Adam
28 |   lr: 5.0e-5
29 | 
30 | engine:
31 |   gpus: [0] #, 1, 2, 3]
32 |   batch_size: 32
33 |   gradient_interval: 4
34 | 
35 | eval_metric: accuracy
36 | lr_ratio: 0
37 | 
38 | model_checkpoint: {{ ckpt }}
39 | 
40 | train:
41 |   num_epoch: 100
--------------------------------------------------------------------------------
/config/predictor/esm_thermo.yaml:
--------------------------------------------------------------------------------
1 | output_dir: ~/scratch/protein_output
2 | 
3 | dataset:
4 |   class: Thermostability
5 |   path: {{ datadir }}
6 |   atom_feature: null
7 |   bond_feature: null
8 |   transform:
9 |     class: Compose
10 |     transforms:
11 |       - class: ProteinView
12 |         view: residue
13 | 
14 | task:
15 |   class: PropertyPrediction
16 |   model:
17 |     class: ESM
18 |     path: {{ modeldir }}
19 |     model: {{ model }}
20 |   readout: mean
21 |   criterion: mse
22 |   metric: ["mae", "rmse", "spearmanr"]
23 |   normalization: False
24 |   num_mlp_layer: 2
25 | 
26 | optimizer:
27 |   class: Adam
28 |   lr: 5.0e-5
29 | 
30 | engine:
31 |   gpus: [0] #, 1, 2, 3]
32 |   batch_size: 8
33 |   log_interval: 1000
34 | 
35 | eval_metric: spearmanr
36 | lr_ratio: 0
37 | 
38 | model_checkpoint: {{ ckpt }}
39 | 
40 | train:
41 |   num_epoch: 100
42 | 
--------------------------------------------------------------------------------
/config/retriever/esm_ec.yaml:
--------------------------------------------------------------------------------
1 | output_dir: ~/scratch/protein_output
2 | 
3 | dataset:
4 |   class: EC
5 |   path: {{ datadir }}
6 |   percent: 95
7 |   transform:
8 |     class: Compose
9 |     transforms:
10 |       - class: ProteinView
11 |         view: residue
12 |       - class: TruncateProtein
13 |         max_length: 550
14 | 
15 | task:
16 |   class: FunctionAnnotation
17 |   model:
18 |     class: ESM
19 |     path: {{ modeldir }}
20 |     model: {{ model }}
21 |   mlp_batch_norm: True
22 |   mlp_dropout: 0.2
23 |   metric: ['auprc@micro', 'f1_max']
24 | 
25 | gpus: [0]
26 | batch_size: 8
27 | knn: 5
28 | weighted: exp
29 | temp: 0.03
30 | 
31 | model_checkpoint: {{ ckpt }}
32 | 
--------------------------------------------------------------------------------
/config/retriever/esm_go.yaml:
--------------------------------------------------------------------------------
1 | output_dir: ~/scratch/protein_output
2 | 
3 | dataset:
4 |   class: GO
5 |   path: {{ datadir }}
6 |   level: {{ level }}
7 |   percent: 95
8 |   transform:
9 |     class: Compose
10 |     transforms:
11 |       - class: ProteinView
12 |         view: residue
13 |       - class: TruncateProtein
14 |         max_length: 550
15 | 
16 | task:
17 |   class: FunctionAnnotation
18 |   model:
19 |     class: ESM
20 |     path: {{ modeldir }}
21 |     model: {{ model }}
22 |   mlp_batch_norm: True
23 |   mlp_dropout: 0.2
24 |   metric: ['auprc@micro', 'f1_max']
25 | 
26 | gpus: [0]
27 | batch_size: 8
28 | knn: 5
29 | weighted: exp
30 | temp: 0.03
31 | 
32 | model_checkpoint: {{ ckpt }}
33 | 
--------------------------------------------------------------------------------
/esm_s/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import copy
4 | import random
5 | import pickle
6 | import glob
7 | import lmdb
8 | import zipfile
9 | import warnings
10 | import joblib
11 | import math  # required by FLIPDataset.load_csv (math.nan)
12 | from tqdm import tqdm
13 | import numpy as np
14 | from collections import defaultdict
15 | 
16 | import torch
17 | from torch.utils import data as torch_data
18 | from torch.nn import functional as F
19 | 
20 | from torchdrug import
datasets, data, utils, core 21 | from torchdrug.layers import functional 22 | from torchdrug.core import Registry as R 23 | 24 | 25 | @R.register("transforms.ProteinViewList") 26 | class ProteinViewList(core.Configurable): 27 | 28 | def __init__(self, view, keys="graph"): 29 | self.view = view 30 | if isinstance(keys, str): 31 | keys = [keys] 32 | self.keys = keys 33 | 34 | def __call__(self, item): 35 | item = item.copy() 36 | for key in self.keys: 37 | graphs = copy.copy(item[key]) 38 | if isinstance(graphs, list): 39 | for graph in graphs: 40 | graph.view = self.view 41 | else: 42 | graphs.view = self.view 43 | item[key] = graphs 44 | return item 45 | 46 | 47 | def load_protein(seq, pos): 48 | residue_type = torch.as_tensor(seq) 49 | num_residue = len(seq) 50 | residue_feature = torch.zeros((num_residue, 1), dtype=torch.float) 51 | residue_number = torch.arange(num_residue) 52 | num_atom = num_residue 53 | atom2residue = torch.arange(num_residue) 54 | node_position = torch.as_tensor(pos) 55 | atom_type = torch.as_tensor([6 for _ in range(num_atom)]) 56 | atom_name = torch.as_tensor([data.Protein.atom_name2id["CA"] for _ in range(num_atom)]) 57 | 58 | edge_list = torch.as_tensor([[0, 0, 0]]) 59 | bond_type = torch.as_tensor([0]) 60 | 61 | protein = data.Protein(edge_list, atom_type, bond_type, num_node=num_atom, num_residue=num_residue, 62 | node_position=node_position, atom_name=atom_name, 63 | atom2residue=atom2residue, residue_feature=residue_feature, 64 | residue_type=residue_type, residue_number=residue_number) 65 | return protein 66 | 67 | 68 | @R.register("datasets.MyFold") 69 | class MyFold(data.ProteinDataset): 70 | 71 | def __init__(self, path, split="training", transform=None): 72 | path = os.path.expanduser(path) 73 | self.path = path 74 | self.split = split 75 | self.transform = transform 76 | npy_dir = os.path.join(path, 'coordinates', split) 77 | fasta_file = os.path.join(path, split+'.fasta') 78 | 79 | protein_seqs = [] 80 | with open(fasta_file, 'r') as f: 81 | protein_name = '' 82 | for line in f: 83 | if line.startswith('>'): 84 | protein_name = line.rstrip()[1:] 85 | else: 86 | amino_chain = line.rstrip() 87 | amino_ids = [] 88 | for amino in amino_chain: 89 | amino_ids.append(data.Protein.residue_symbol2id.get(amino, 0)) 90 | protein_seqs.append((protein_name, np.array(amino_ids))) 91 | 92 | fold_classes = {} 93 | with open(os.path.join(path, 'class_map.txt'), 'r') as f: 94 | for line in f: 95 | arr = line.rstrip().split('\t') 96 | fold_classes[arr[0]] = int(arr[1]) 97 | 98 | protein_folds = {} 99 | with open(os.path.join(path, split+'.txt'), 'r') as f: 100 | for line in f: 101 | arr = line.rstrip().split('\t') 102 | protein_folds[arr[0]] = fold_classes[arr[-1]] 103 | 104 | self.data = [] 105 | self.labels = [] 106 | for protein_name, amino_ids in protein_seqs: 107 | pos = np.load(os.path.join(npy_dir, protein_name+".npy")) 108 | center = np.sum(a=pos, axis=0, keepdims=True)/pos.shape[0] 109 | pos = pos - center 110 | protein = load_protein(amino_ids.astype(int), pos) 111 | self.data.append((protein_name, protein)) 112 | self.labels.append(protein_folds[protein_name]) 113 | 114 | self.num_classes = max(self.labels) + 1 115 | 116 | @property 117 | def tasks(self,): 118 | return ["targets"] 119 | 120 | def get_item(self, idx): 121 | protein_name, protein = self.data[idx] 122 | label = torch.as_tensor(self.labels[idx]) 123 | 124 | item = {"graph": protein, "targets": label} 125 | if self.transform: 126 | item = self.transform(item) 127 | 128 | return item 129 | 130 | 
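# Usage sketch (paths illustrative): MyFold expects the layout shipped in
# fold.tar.gz, i.e. <path>/<split>.fasta, <path>/<split>.txt,
# <path>/class_map.txt and <path>/coordinates/<split>/<name>.npy.
#
#   from torchdrug import transforms
#   transform = transforms.Compose([transforms.ProteinView(view="residue")])
#   fold_set = MyFold("./data/fold", split="training", transform=transform)
#   item = fold_set.get_item(0)  # {"graph": data.Protein, "targets": tensor(fold_label)}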
131 | @R.register("datasets.EC")
132 | class EC(data.ProteinDataset):
133 | 
134 |     def __init__(self, path, percent=30, split="train", transform=None, **kwargs):
135 |         path = os.path.expanduser(path)
136 |         self.path = path
137 |         self.percent = percent
138 |         self.split = split
139 |         self.transform = transform
140 |         npy_dir = os.path.join(path, 'coordinates')
141 |         fasta_file = os.path.join(path, split+'.fasta')
142 | 
143 |         test_set = set()
144 |         if split == "test":
145 |             with open(os.path.join(path, "nrPDB-EC_test.csv"), 'r') as f:
146 |                 head = True
147 |                 for line in f:
148 |                     if head:
149 |                         head = False
150 |                         continue
151 |                     arr = line.rstrip().split(',')
152 |                     if percent == 30 and arr[1] == '1':
153 |                         test_set.add(arr[0])
154 |                     elif percent == 40 and arr[2] == '1':
155 |                         test_set.add(arr[0])
156 |                     elif percent == 50 and arr[3] == '1':
157 |                         test_set.add(arr[0])
158 |                     elif percent == 70 and arr[4] == '1':
159 |                         test_set.add(arr[0])
160 |                     elif percent == 95 and arr[5] == '1':
161 |                         test_set.add(arr[0])
162 |                     else:
163 |                         pass
164 | 
165 |         protein_seqs = []
166 |         with open(fasta_file, 'r') as f:
167 |             protein_name = ''
168 |             for line in f:
169 |                 if line.startswith('>'):
170 |                     protein_name = line.rstrip()[1:]
171 |                 else:
172 |                     if split == "test" and (protein_name not in test_set):
173 |                         continue
174 |                     amino_chain = line.rstrip()
175 |                     amino_ids = []
176 |                     for amino in amino_chain:
177 |                         amino_ids.append(data.Protein.residue_symbol2id.get(amino, 0))
178 |                     protein_seqs.append((protein_name, np.array(amino_ids)))
179 | 
180 |         self.data = []
181 |         for protein_name, amino_ids in protein_seqs:
182 |             pos = np.load(os.path.join(npy_dir, protein_name+".npy"))
183 |             center = np.sum(a=pos, axis=0, keepdims=True)/pos.shape[0]
184 |             pos = pos - center
185 |             protein = load_protein(amino_ids.astype(int), pos)
186 |             self.data.append((protein_name, protein))
187 | 
188 |         level_idx = 1
189 |         ec_cnt = 0
190 |         ec_num = {}
191 |         ec_annotations = {}
192 |         self.labels = {}
193 | 
194 |         with open(os.path.join(path, 'nrPDB-EC_annot.tsv'), 'r') as f:
195 |             for idx, line in enumerate(f):
196 |                 if idx == 1:
197 |                     arr = line.rstrip().split('\t')
198 |                     for ec in arr:
199 |                         ec_annotations[ec] = ec_cnt
200 |                         ec_num[ec] = 0
201 |                         ec_cnt += 1
202 | 
203 |                 elif idx > 2:
204 |                     arr = line.rstrip().split('\t')
205 |                     protein_labels = []
206 |                     if len(arr) > level_idx:
207 |                         protein_ec_list = arr[level_idx]
208 |                         protein_ec_list = protein_ec_list.split(',')
209 |                         for ec in protein_ec_list:
210 |                             if len(ec) > 0:
211 |                                 protein_labels.append(ec_annotations[ec])
212 |                                 ec_num[ec] += 1
213 |                     self.labels[arr[0]] = np.array(protein_labels)
214 | 
215 |         self.num_classes = len(ec_annotations)
216 |         self.weights = np.zeros((ec_cnt,), dtype=np.float32)
217 |         for ec, idx in ec_annotations.items():
218 |             self.weights[idx] = len(self.labels)/ec_num[ec]
219 | 
220 |     @property
221 |     def tasks(self,):
222 |         return ["targets"]
223 | 
224 |     def get_item(self, idx):
225 |         protein_name, protein = self.data[idx]
226 |         label = np.zeros((self.num_classes,)).astype(np.float32)
227 |         if len(self.labels[protein_name]) > 0:
228 |             label[self.labels[protein_name]] = 1.0
229 |         label = torch.as_tensor(label)
230 | 
231 |         item = {"graph": protein, "targets": label}
232 | 
233 |         if self.transform:
234 |             item = self.transform(item)
235 | 
236 |         return item
237 | 
238 | 
239 | @R.register("datasets.GO")
240 | class GO(EC):
241 | 
242 |     def __init__(self, path, level="mf", percent=30, split="train", transform=None):
243 |         path = os.path.expanduser(path)
244 |         self.path = path
245 | self.percent = percent 246 | self.split = split 247 | self.transform = transform 248 | npy_dir = os.path.join(path, 'coordinates') 249 | fasta_file = os.path.join(path, split+'.fasta') 250 | 251 | test_set = set() 252 | if split == "test": 253 | with open(os.path.join(path, "nrPDB-GO_2019.06.18_test.csv"), 'r') as f: 254 | head = True 255 | for line in f: 256 | if head: 257 | head = False 258 | continue 259 | arr = line.rstrip().split(',') 260 | if percent == 30 and arr[1] == '1': 261 | test_set.add(arr[0]) 262 | elif percent == 40 and arr[2] == '1': 263 | test_set.add(arr[0]) 264 | elif percent == 50 and arr[3] == '1': 265 | test_set.add(arr[0]) 266 | elif percent == 70 and arr[4] == '1': 267 | test_set.add(arr[0]) 268 | elif percent == 95 and arr[5] == '1': 269 | test_set.add(arr[0]) 270 | else: 271 | pass 272 | 273 | protein_seqs = [] 274 | with open(fasta_file, 'r') as f: 275 | protein_name = '' 276 | for line in f: 277 | if line.startswith('>'): 278 | protein_name = line.rstrip()[1:] 279 | else: 280 | if split == "test" and (protein_name not in test_set): 281 | continue 282 | amino_chain = line.rstrip() 283 | amino_ids = [] 284 | for amino in amino_chain: 285 | amino_ids.append(data.Protein.residue_symbol2id.get(amino, 0)) 286 | protein_seqs.append((protein_name, np.array(amino_ids))) 287 | 288 | self.data = [] 289 | for protein_name, amino_ids in protein_seqs: 290 | pos = np.load(os.path.join(npy_dir, protein_name+".npy")) 291 | center = np.sum(a=pos, axis=0, keepdims=True)/pos.shape[0] 292 | pos = pos - center 293 | protein = load_protein(amino_ids.astype(int), pos) 294 | self.data.append((protein_name, protein)) 295 | 296 | level_idx = 0 297 | go_cnt = 0 298 | go_num = {} 299 | go_annotations = {} 300 | self.labels = {} 301 | with open(os.path.join(path, 'nrPDB-GO_2019.06.18_annot.tsv'), 'r') as f: 302 | for idx, line in enumerate(f): 303 | if idx == 1 and level == "mf": 304 | level_idx = 1 305 | arr = line.rstrip().split('\t') 306 | for go in arr: 307 | go_annotations[go] = go_cnt 308 | go_num[go] = 0 309 | go_cnt += 1 310 | elif idx == 5 and level == "bp": 311 | level_idx = 2 312 | arr = line.rstrip().split('\t') 313 | for go in arr: 314 | go_annotations[go] = go_cnt 315 | go_num[go] = 0 316 | go_cnt += 1 317 | elif idx == 9 and level == "cc": 318 | level_idx = 3 319 | arr = line.rstrip().split('\t') 320 | for go in arr: 321 | go_annotations[go] = go_cnt 322 | go_num[go] = 0 323 | go_cnt += 1 324 | elif idx > 12: 325 | arr = line.rstrip().split('\t') 326 | protein_labels = [] 327 | if len(arr) > level_idx: 328 | protein_go_list = arr[level_idx] 329 | protein_go_list = protein_go_list.split(',') 330 | for go in protein_go_list: 331 | if len(go) > 0: 332 | protein_labels.append(go_annotations[go]) 333 | go_num[go] += 1 334 | self.labels[arr[0]] = np.array(protein_labels) 335 | 336 | self.num_classes = len(go_annotations) 337 | 338 | self.weights = np.zeros((go_cnt,), dtype=np.float32) 339 | for go, idx in go_annotations.items(): 340 | self.weights[idx] = len(self.labels)/go_num[go] 341 | 342 | 343 | class FLIPDataset(data.ProteinDataset): 344 | 345 | def load_csv(self, csv_file, sequence_field="sequence", target_fields=None, verbose=0, **kwargs): 346 | if target_fields is not None: 347 | target_fields = set(target_fields) 348 | 349 | with open(csv_file, "r") as fin: 350 | reader = csv.reader(fin) 351 | if verbose: 352 | reader = iter(tqdm(reader, "Loading %s" % csv_file, utils.get_line_count(csv_file))) 353 | fields = next(reader) 354 | train, valid, test = [], [], [] 355 | 
_sequences = [] 356 | _targets = defaultdict(list) 357 | for i, values in enumerate(reader): 358 | for field, value in zip(fields, values): 359 | if field == sequence_field: 360 | _sequences.append(value) 361 | elif target_fields is None or field in target_fields: 362 | value = utils.literal_eval(value) 363 | if value == "": 364 | value = math.nan 365 | _targets[field].append(value) 366 | elif field == "set": 367 | if value == "train": 368 | train.append(i) 369 | elif value == "test": 370 | test.append(i) 371 | elif field == "validation": 372 | if value == "True": 373 | valid.append(i) 374 | 375 | valid_set = set(valid) 376 | sequences = [_sequences[i] for i in train if i not in valid_set] \ 377 | + [_sequences[i] for i in valid] \ 378 | + [_sequences[i] for i in test] 379 | targets = defaultdict(list) 380 | for key, value in _targets.items(): 381 | targets[key] = [value[i] for i in train if i not in valid_set] \ 382 | + [value[i] for i in valid] \ 383 | + [value[i] for i in test] 384 | self.load_sequence(sequences, targets, verbose=verbose, **kwargs) 385 | self.num_samples = [len(train) - len(valid), len(valid), len(test)] 386 | 387 | 388 | @R.register("datasets.AAV") 389 | class AAV(FLIPDataset): 390 | 391 | url = "https://github.com/J-SNACKKB/FLIP/raw/d5c35cc716ca93c3c74a0b43eef5b60cbf88521f/splits/aav/splits.zip" 392 | md5 = "cabdd41f3386f4949b32ca220db55c58" 393 | splits = ["train", "valid", "test"] 394 | target_fields = ["target"] 395 | region = slice(474, 674) 396 | 397 | def __init__(self, path, split="two_vs_many", keep_mutation_region=False, verbose=1, **kwargs): 398 | path = os.path.expanduser(path) 399 | path = os.path.join(path, 'aav') 400 | if not os.path.exists(path): 401 | os.makedirs(path) 402 | self.path = path 403 | assert split in ['des_mut', 'low_vs_high', 'mut_des', 'one_vs_many', 'sampled', 'seven_vs_many', 'two_vs_many'] 404 | 405 | zip_file = utils.download(self.url, path, md5=self.md5) 406 | data_path = utils.extract(zip_file) 407 | csv_file = os.path.join(data_path, "splits/%s.csv" % split) 408 | 409 | self.load_csv(csv_file, target_fields=self.target_fields, verbose=verbose, **kwargs) 410 | if keep_mutation_region: 411 | for i in range(len(self.data)): 412 | self.data[i] = self.data[i][self.region] 413 | self.sequences[i] = self.sequences[i][self.region] 414 | 415 | def split(self): 416 | offset = 0 417 | splits = [] 418 | for num_sample in self.num_samples: 419 | split = torch_data.Subset(self, range(offset, offset + num_sample)) 420 | splits.append(split) 421 | offset += num_sample 422 | return splits 423 | 424 | 425 | @R.register("datasets.GB1") 426 | class GB1(FLIPDataset): 427 | 428 | url = "https://github.com/J-SNACKKB/FLIP/raw/d5c35cc716ca93c3c74a0b43eef5b60cbf88521f/splits/gb1/splits.zip" 429 | md5 = "14216947834e6db551967c2537332a12" 430 | splits = ["train", "valid", "test"] 431 | target_fields = ["target"] 432 | 433 | def __init__(self, path, split="two_vs_rest", verbose=1, **kwargs): 434 | path = os.path.expanduser(path) 435 | path = os.path.join(path, 'gb1') 436 | if not os.path.exists(path): 437 | os.makedirs(path) 438 | self.path = path 439 | assert split in ['one_vs_rest', 'two_vs_rest', 'three_vs_rest', 'low_vs_high', 'sampled'] 440 | 441 | zip_file = utils.download(self.url, path, md5=self.md5) 442 | data_path = utils.extract(zip_file) 443 | csv_file = os.path.join(data_path, "splits/%s.csv" % split) 444 | 445 | self.load_csv(csv_file, target_fields=self.target_fields, verbose=verbose, **kwargs) 446 | 447 | def split(self): 448 | offset = 0 449 
| splits = [] 450 | for num_sample in self.num_samples: 451 | split = torch_data.Subset(self, range(offset, offset + num_sample)) 452 | splits.append(split) 453 | offset += num_sample 454 | return splits 455 | 456 | 457 | @R.register("datasets.Thermostability") 458 | class Thermostability(FLIPDataset): 459 | 460 | url = "https://github.com/J-SNACKKB/FLIP/raw/d5c35cc716ca93c3c74a0b43eef5b60cbf88521f/splits/meltome/splits.zip" 461 | md5 = "0f8b1e848568f7566713d53594c0ca90" 462 | splits = ["train", "valid", "test"] 463 | target_fields = ["target"] 464 | 465 | def __init__(self, path, split="human_cell", verbose=1, **kwargs): 466 | path = os.path.expanduser(path) 467 | path = os.path.join(path, 'thermostability') 468 | if not os.path.exists(path): 469 | os.makedirs(path) 470 | self.path = path 471 | assert split in ['human', 'human_cell', 'mixed_split'] 472 | 473 | zip_file = utils.download(self.url, path, md5=self.md5) 474 | data_path = utils.extract(zip_file) 475 | csv_file = os.path.join(data_path, "splits/%s.csv" % split) 476 | 477 | self.load_csv(csv_file, target_fields=self.target_fields, verbose=verbose, **kwargs) 478 | 479 | def split(self): 480 | offset = 0 481 | splits = [] 482 | for num_sample in self.num_samples: 483 | split = torch_data.Subset(self, range(offset, offset + num_sample)) 484 | splits.append(split) 485 | offset += num_sample 486 | return splits -------------------------------------------------------------------------------- /esm_s/task.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import time 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from torch_scatter import scatter_add 10 | 11 | from torchdrug import core, layers, tasks, metrics, data 12 | from torchdrug.core import Registry as R 13 | 14 | 15 | class MLP(nn.Module): 16 | def __init__(self, 17 | in_channels: int, 18 | mid_channels: int, 19 | out_channels: int, 20 | batch_norm: bool, 21 | dropout: float = 0.0, 22 | bias: bool = True, 23 | leakyrelu_negative_slope: float = 0.2, 24 | momentum: float = 0.2) -> nn.Module: 25 | super(MLP, self).__init__() 26 | 27 | module = [] 28 | if batch_norm: 29 | module.append(nn.BatchNorm1d(in_channels, momentum=momentum)) 30 | module.append(nn.LeakyReLU(leakyrelu_negative_slope)) 31 | module.append(nn.Dropout(dropout)) 32 | if mid_channels is None: 33 | module.append(nn.Linear(in_channels, out_channels, bias = bias)) 34 | else: 35 | module.append(nn.Linear(in_channels, mid_channels, bias = bias)) 36 | if batch_norm: 37 | if mid_channels is None: 38 | module.append(nn.BatchNorm1d(out_channels, momentum=momentum)) 39 | else: 40 | module.append(nn.BatchNorm1d(mid_channels, momentum=momentum)) 41 | module.append(nn.LeakyReLU(leakyrelu_negative_slope)) 42 | if mid_channels is None: 43 | module.append(nn.Dropout(dropout)) 44 | else: 45 | module.append(nn.Linear(mid_channels, out_channels, bias = bias)) 46 | 47 | self.module = nn.Sequential(*module) 48 | 49 | def forward(self, input): 50 | return self.module(input) 51 | 52 | 53 | @R.register("tasks.FunctionAnnotation") 54 | class FunctionAnnotation(tasks.Task, core.Configurable): 55 | 56 | eps = 1e-10 57 | _option_members = {"metric"} 58 | 59 | def __init__(self, model, num_class=1, metric=('auprc@micro', 'f1_max'), weight=None, graph_construction_model=None, 60 | mlp_batch_norm=False, mlp_dropout=0, verbose=0): 61 | super(FunctionAnnotation, self).__init__() 62 | self.model = model 63 | if weight is 
None:
64 |             weight = torch.ones((num_class,), dtype=torch.float)
65 |         self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float))
66 |         self.metric = metric
67 |         self.graph_construction_model = graph_construction_model
68 |         self.verbose = verbose
69 | 
70 |         self.mlp = MLP(in_channels=self.model.output_dim,
71 |                        mid_channels=self.model.output_dim,
72 |                        out_channels=num_class,
73 |                        batch_norm=mlp_batch_norm,
74 |                        dropout=mlp_dropout)
75 | 
76 |     def forward(self, batch):
77 |         all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
78 |         metric = {}
79 | 
80 |         pred = self.predict(batch, all_loss, metric)
81 |         target = self.target(batch)
82 | 
83 |         loss_fn = torch.nn.BCELoss(weight=torch.as_tensor(self.weight))
84 |         loss = loss_fn(pred.sigmoid(), target)
85 | 
86 |         name = tasks._get_criterion_name("bce")
87 |         metric[name] = loss
88 |         all_loss += loss
89 | 
90 |         return all_loss, metric
91 | 
92 |     def predict(self, batch, all_loss=None, metric=None):
93 |         graph = batch["graph"]
94 |         if self.graph_construction_model:
95 |             graph = self.graph_construction_model(graph)
96 |         output = self.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
97 |         graph_feature = output["graph_feature"]
98 |         pred = self.mlp(graph_feature)
99 |         return pred
100 | 
101 |     def target(self, batch):
102 |         return batch["targets"]
103 | 
104 |     def evaluate(self, pred, target):
105 |         metric = {}
106 |         for _metric in self.metric:
107 |             if _metric == "auroc@micro":
108 |                 score = metrics.area_under_roc(pred.flatten(), target.long().flatten())
109 |             elif _metric == "auprc@micro":
110 |                 score = metrics.area_under_prc(pred.flatten(), target.long().flatten())
111 |             elif _metric == "f1_max":
112 |                 score = metrics.f1_max(pred, target)
113 |             else:
114 |                 raise ValueError("Unknown criterion `%s`" % _metric)
115 | 
116 |             name = tasks._get_metric_name(_metric)
117 |             metric[name] = score
118 | 
119 |         return metric
120 | 
121 | 
122 | @R.register("tasks.FoldClassification")
123 | class FoldClassification(tasks.Task, core.Configurable):
124 | 
125 |     eps = 1e-10
126 |     _option_members = {"metric"}
127 | 
128 |     def __init__(self, model, num_class=1, metric=('acc',), graph_construction_model=None,
129 |                  mlp_batch_norm=False, mlp_dropout=0, verbose=0):
130 |         super(FoldClassification, self).__init__()
131 |         self.model = model
132 |         self.metric = metric
133 |         self.graph_construction_model = graph_construction_model
134 |         self.verbose = verbose
135 | 
136 |         self.mlp = MLP(in_channels=self.model.output_dim,
137 |                        mid_channels=self.model.output_dim,
138 |                        out_channels=num_class,
139 |                        batch_norm=mlp_batch_norm,
140 |                        dropout=mlp_dropout)
141 | 
142 |     def forward(self, batch):
143 |         all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
144 |         metric = {}
145 | 
146 |         pred = self.predict(batch, all_loss, metric)
147 |         target = self.target(batch)
148 | 
149 |         loss_fn = torch.nn.CrossEntropyLoss()
150 |         loss = loss_fn(pred, target)
151 | 
152 |         name = tasks._get_criterion_name("ce")
153 |         metric[name] = loss
154 |         all_loss += loss
155 | 
156 |         return all_loss, metric
157 | 
158 |     def predict(self, batch, all_loss=None, metric=None):
159 |         graph = batch["graph"]
160 |         if self.graph_construction_model:
161 |             graph = self.graph_construction_model(graph)
162 |         output = self.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
163 |         graph_feature = output["graph_feature"]
164 |         pred = self.mlp(graph_feature)
165 |         return pred
166 | 
167 |     def target(self, batch):
168 |         return batch["targets"]
169 | 
170 |     def
evaluate(self, pred, target): 171 | metric = {} 172 | for _metric in self.metric: 173 | if _metric == "acc": 174 | score = metrics.accuracy(pred.squeeze(-1), target.long()) 175 | else: 176 | raise ValueError("Unknown criterion `%s`" % _metric) 177 | 178 | name = tasks._get_metric_name(_metric) 179 | metric[name] = score 180 | 181 | return metric -------------------------------------------------------------------------------- /script/retrieve.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import pprint 5 | import pickle 6 | import random 7 | 8 | from tqdm import tqdm 9 | 10 | import numpy as np 11 | 12 | import torch 13 | from torch import nn 14 | 15 | from torchdrug import core, tasks, datasets, utils, metrics, data 16 | from torchdrug.utils import comm 17 | 18 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 19 | import util 20 | from esm_s import dataset, task 21 | 22 | 23 | @torch.no_grad() 24 | def dump(cfg, dataset, task): 25 | dataloader = data.DataLoader(dataset, cfg.batch_size, shuffle=False, num_workers=0) 26 | device = torch.device(cfg.gpus[0]) 27 | task = task.cuda(device) 28 | task.eval() 29 | preds = [] 30 | target = [] 31 | for batch in tqdm(dataloader): 32 | batch = utils.cuda(batch, device=device) 33 | graph = batch["graph"] 34 | if task.graph_construction_model: 35 | graph = task.graph_construction_model(graph) 36 | output = task.model(graph, graph.node_feature.float()) 37 | preds.append(output["graph_feature"].detach()) 38 | target.append(batch["targets"].detach()) 39 | pred = torch.cat(preds, dim=0) 40 | target = torch.cat(target, dim=0) 41 | return pred, target 42 | 43 | 44 | def auprc(pred, target): 45 | """ 46 | Area under precision-recall curve (PRC). 47 | 48 | Parameters: 49 | pred (Tensor): predictions of shape :math:`(n,)` 50 | target (Tensor): binary targets of shape :math:`(n,)` 51 | """ 52 | pred, order = torch.sort(pred, descending=True, stable=True) 53 | target = target[order] 54 | is_not_equal = torch.ones_like(pred) 55 | is_not_equal[:-1] = (pred[1:] != pred[:-1]).long() 56 | boundary = is_not_equal.nonzero()[:, 0] 57 | real_precision_index = torch.bucketize(torch.arange(len(target), device=target.device), boundary) 58 | real_precision_index = boundary[real_precision_index] 59 | precision = target.cumsum(0) / torch.arange(1, len(target) + 1, device=target.device) 60 | precision = precision[real_precision_index] 61 | auprc = precision[target == 1].sum() / ((target == 1).sum() + 1e-10) 62 | return auprc 63 | 64 | 65 | def f1_max(pred, target): 66 | """ 67 | F1 score with the optimal threshold. 68 | 69 | This function first enumerates all possible thresholds for deciding positive and negative 70 | samples, and then pick the threshold with the maximal F1 score. 
71 | 72 | Parameters: 73 | pred (Tensor): predictions of shape :math:`(B, N)` 74 | target (Tensor): binary targets of shape :math:`(B, N)` 75 | """ 76 | order = torch.sort(pred, descending=True, dim=1, stable=True)[1] 77 | target = target.gather(1, order) 78 | precision = target.cumsum(1) / torch.ones_like(target).cumsum(1) 79 | recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10) 80 | is_start = torch.zeros_like(target).bool() 81 | is_start[:, 0] = 1 82 | is_start = torch.scatter(is_start, 1, order, is_start) 83 | 84 | _pred, all_order = torch.sort(pred.flatten(), descending=True, stable=True) 85 | order = order + torch.arange(order.shape[0], device=order.device).unsqueeze(1) * order.shape[1] 86 | order = order.flatten() 87 | inv_order = torch.zeros_like(order) 88 | inv_order[order] = torch.arange(order.shape[0], device=order.device) 89 | is_start = is_start.flatten()[all_order] 90 | all_order = inv_order[all_order] 91 | precision = precision.flatten() 92 | recall = recall.flatten() 93 | all_precision = precision[all_order] - \ 94 | torch.where(is_start, torch.zeros_like(precision), precision[all_order - 1]) 95 | all_precision = all_precision.cumsum(0) / is_start.cumsum(0) 96 | all_recall = recall[all_order] - \ 97 | torch.where(is_start, torch.zeros_like(recall), recall[all_order - 1]) 98 | all_recall = all_recall.cumsum(0) / pred.shape[0] 99 | 100 | # Consider equal thresholds 101 | is_not_equal = torch.ones_like(_pred.flatten()) 102 | is_not_equal[:-1] = (_pred[1:] != _pred[:-1]).long() 103 | boundary = is_not_equal.nonzero()[:, 0] 104 | real_index = torch.bucketize(torch.arange(len(_pred), device=target.device), boundary) 105 | real_index = boundary[real_index] 106 | all_precision = all_precision[real_index] 107 | all_recall = all_recall[real_index] 108 | 109 | all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10) 110 | return all_f1.max() 111 | 112 | 113 | @torch.no_grad() 114 | def evaluate(pred, target): 115 | return { 116 | "auprc@micro": auprc(pred.flatten(), target.long().flatten()), 117 | "f1_max": f1_max(pred, target), 118 | } 119 | 120 | 121 | @torch.no_grad() 122 | def retrieve(cfg, train_keys, train_targets, test_keys, test_targets): 123 | cos_sim = nn.CosineSimilarity(dim=1) 124 | preds = [] 125 | for i in tqdm(range(0, len(test_keys), cfg.batch_size)): 126 | test_key = test_keys[i:i+cfg.batch_size] 127 | sim = -cos_sim(test_key.unsqueeze(-1), train_keys.transpose(0, 1).unsqueeze(0)) # (batch_size, num_train) 128 | retrieval_items = sim.argsort(dim=1)[:, :cfg.knn] # (num_test, k) 129 | pred = train_targets[retrieval_items] # (num_test, k, num_tasks) 130 | if cfg.weighted == "linear": 131 | _sim = -torch.gather(sim, 1, retrieval_items).unsqueeze(-1) 132 | elif cfg.weighted == "exp": 133 | _sim = ((-torch.gather(sim, 1, retrieval_items).unsqueeze(-1)) / cfg.temp).exp() 134 | else: 135 | _sim = torch.ones((pred.shape[0], pred.shape[1], 1), device=pred.device, dtype=torch.float) 136 | pred = (pred * _sim).sum(dim=1) / _sim.sum(dim=1) 137 | preds.append(pred) 138 | pred = torch.cat(preds, dim=0) 139 | return evaluate(pred, test_targets) 140 | 141 | 142 | if __name__ == "__main__": 143 | args, vars = util.parse_args() 144 | cfg = util.load_config(args.config, context=vars) 145 | 146 | seed = args.seed 147 | torch.manual_seed(seed + comm.get_rank()) 148 | os.environ['PYTHONHASHSEED'] = str(seed) 149 | random.seed(seed) 150 | np.random.seed(seed) 151 | torch.manual_seed(seed) 152 | if torch.cuda.is_available(): 153 | 
torch.cuda.manual_seed(seed) 154 | torch.cuda.manual_seed_all(seed) 155 | torch.backends.cudnn.deterministic = True 156 | torch.backends.cudnn.benchmark = False 157 | 158 | logger = util.get_root_logger(file=False) 159 | if comm.get_rank() == 0: 160 | logger.warning("Config file: %s" % args.config) 161 | logger.warning(pprint.pformat(cfg)) 162 | 163 | assert cfg.dataset["class"] in ["EC", "GO"] 164 | cfg.dataset.split = "train" 165 | train_set = core.Configurable.load_config_dict(cfg.dataset) 166 | cfg.dataset.split = "valid" 167 | valid_set = core.Configurable.load_config_dict(cfg.dataset) 168 | cfg.dataset.split = "test" 169 | cfg.dataset.percent = 95 170 | test_set95 = core.Configurable.load_config_dict(cfg.dataset) 171 | print(test_set95) 172 | cfg.dataset.percent = 50 173 | test_set50 = core.Configurable.load_config_dict(cfg.dataset) 174 | print(test_set50) 175 | cfg.dataset.percent = 30 176 | test_set30 = core.Configurable.load_config_dict(cfg.dataset) 177 | print(test_set30) 178 | dataset = (train_set, valid_set, test_set95) 179 | 180 | cfg.task.num_class = valid_set.num_classes 181 | task = core.Configurable.load_config_dict(cfg.task) 182 | 183 | if cfg.get("model_checkpoint") is not None: 184 | if comm.get_rank() == 0: 185 | logger.warning("Load checkpoint from %s" % cfg.model_checkpoint) 186 | cfg.model_checkpoint = os.path.expanduser(cfg.model_checkpoint) 187 | model_dict = torch.load(cfg.model_checkpoint, map_location=torch.device('cpu')) 188 | task.model.load_state_dict(model_dict) 189 | 190 | train_keys, train_targets = dump(cfg, train_set, task) 191 | valid_keys, valid_targets = dump(cfg, valid_set, task) 192 | test95_keys, test95_targets = dump(cfg, test_set95, task) 193 | test50_keys, test50_targets = dump(cfg, test_set50, task) 194 | test30_keys, test30_targets = dump(cfg, test_set30, task) 195 | 196 | valid_metric = retrieve(cfg, train_keys, train_targets, valid_keys, valid_targets) 197 | print("Metrics on valid set:", valid_metric) 198 | test95_metric = retrieve(cfg, train_keys, train_targets, test95_keys, test95_targets) 199 | print("Metrics on test set with 0.95 cutoff:", test95_metric) 200 | test50_metric = retrieve(cfg, train_keys, train_targets, test50_keys, test50_targets) 201 | print("Metrics on test set with 0.5 cutoff:", test50_metric) 202 | test30_metric = retrieve(cfg, train_keys, train_targets, test30_keys, test30_targets) 203 | print("Metrics on test set with 0.3 cutoff:", test30_metric) -------------------------------------------------------------------------------- /script/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import pprint 5 | import random 6 | 7 | import numpy as np 8 | 9 | import torch 10 | from torch.optim import lr_scheduler 11 | 12 | from torchdrug import core, models, tasks, datasets, utils 13 | from torchdrug.utils import comm 14 | 15 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 16 | import util 17 | from esm_s import dataset, task 18 | 19 | 20 | def train_and_validate(cfg, solver, scheduler): 21 | if cfg.train.num_epoch == 0: 22 | return 23 | 24 | step = math.ceil(cfg.train.num_epoch / 500) 25 | best_result = float("-inf") 26 | best_epoch = -1 27 | 28 | for i in range(0, cfg.train.num_epoch, step): 29 | kwargs = cfg.train.copy() 30 | kwargs["num_epoch"] = min(step, cfg.train.num_epoch - i) 31 | solver.train(**kwargs) 32 | metric = solver.evaluate("valid") 33 | for k, v in metric.items(): 34 | if k.startswith(cfg.eval_metric): 35 | 
/script/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import pprint
5 | import random
6 | 
7 | import numpy as np
8 | 
9 | import torch
10 | from torch.optim import lr_scheduler
11 | 
12 | from torchdrug import core, models, tasks, datasets, utils
13 | from torchdrug.utils import comm
14 | 
15 | sys.path.append(os.path.dirname(os.path.dirname(__file__)))
16 | import util
17 | # Imported for their registration side effects: these modules register the
18 | # custom datasets and tasks with torchdrug's configuration system.
19 | from esm_s import dataset, task  # noqa: F401
20 | 
21 | 
22 | def train_and_validate(cfg, solver, scheduler):
23 |     if cfg.train.num_epoch == 0:
24 |         return
25 | 
26 |     # Evaluate on the validation set every `step` epochs.
27 |     step = math.ceil(cfg.train.num_epoch / 500)
28 |     best_result = float("-inf")
29 |     best_epoch = -1
30 | 
31 |     for i in range(0, cfg.train.num_epoch, step):
32 |         kwargs = cfg.train.copy()
33 |         kwargs["num_epoch"] = min(step, cfg.train.num_epoch - i)
34 |         solver.train(**kwargs)
35 |         metric = solver.evaluate("valid")
36 |         result = None
37 |         for k, v in metric.items():
38 |             if k.startswith(cfg.eval_metric):
39 |                 result = v
40 |         if result is None:
41 |             raise ValueError("Cannot find metric `%s` in the validation results" % cfg.eval_metric)
42 |         if result > best_result:
43 |             best_result = result
44 |             best_epoch = solver.epoch
45 |         solver.save("model_epoch_%d.pth" % solver.epoch)
46 |         if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
47 |             scheduler.step(result)
48 | 
49 |     # Reload the checkpoint with the best validation metric.
50 |     solver.load("model_epoch_%d.pth" % best_epoch)
51 |     return solver
52 | 
53 | 
54 | if __name__ == "__main__":
55 |     args, vars = util.parse_args()
56 |     cfg = util.load_config(args.config, context=vars)
57 |     working_dir = util.create_working_directory(cfg)
58 | 
59 |     # Seed everything; each rank gets a distinct torch seed so that data
60 |     # order differs across distributed workers.
61 |     seed = args.seed
62 |     torch.manual_seed(seed + comm.get_rank())
63 |     os.environ['PYTHONHASHSEED'] = str(seed)
64 |     random.seed(seed)
65 |     np.random.seed(seed)
66 |     if torch.cuda.is_available():
67 |         torch.cuda.manual_seed(seed)
68 |         torch.cuda.manual_seed_all(seed)
69 |     torch.backends.cudnn.deterministic = True
70 |     torch.backends.cudnn.benchmark = False
71 | 
72 |     logger = util.get_root_logger()
73 |     if comm.get_rank() == 0:
74 |         logger.warning("Config file: %s" % args.config)
75 |         logger.warning(pprint.pformat(cfg))
76 | 
77 |     # EC, GO and MyFold ship predefined splits; other datasets are split later.
78 |     if cfg.dataset["class"] in ["EC", "GO", "MyFold"]:
79 |         cfg.dataset.split = "training" if cfg.dataset["class"] == "MyFold" else "train"
80 |         train_set = core.Configurable.load_config_dict(cfg.dataset)
81 |         cfg.dataset.split = "validation" if cfg.dataset["class"] == "MyFold" else "valid"
82 |         valid_set = core.Configurable.load_config_dict(cfg.dataset)
83 |         cfg.dataset.split = "test_fold" if cfg.dataset["class"] == "MyFold" else "test"
84 |         test_set = core.Configurable.load_config_dict(cfg.dataset)
85 |         splits = (train_set, valid_set, test_set)
86 |     else:
87 |         splits = core.Configurable.load_config_dict(cfg.dataset)
88 |     solver, scheduler = util.build_downstream_solver(cfg, splits)
89 | 
90 |     train_and_validate(cfg, solver, scheduler)
91 |     # Save only the fine-tuned backbone: solver.model is the task wrapper and
92 |     # solver.model.model is the underlying protein language model.
93 |     torch.save(solver.model.model.state_dict(), "esm_s.pth")
94 | 
95 |     logger.warning("Testing on the test set with sequence identity 95%")
96 |     solver.evaluate("test")
97 | 
98 |     cfg.dataset.split = "test"
99 |     cfg.dataset.percent = 50
100 |     test_set50 = core.Configurable.load_config_dict(cfg.dataset)
101 |     solver.test_set = test_set50
102 |     logger.warning("Testing on the test set with sequence identity 50%")
103 |     solver.evaluate("test")
104 | 
105 |     cfg.dataset.percent = 30
106 |     test_set30 = core.Configurable.load_config_dict(cfg.dataset)
107 |     solver.test_set = test_set30
108 |     logger.warning("Testing on the test set with sequence identity 30%")
109 |     solver.evaluate("test")
--------------------------------------------------------------------------------
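`script/run.py` produces two kinds of artifacts in the working directory: `solver.save`/`solver.load` round-trip the full engine state for epoch selection, while the final `torch.save` writes only the backbone's `state_dict` to `esm_s.pth`. That bare state dict is what `model_checkpoint` entries in the configs, and hence `script/retrieve.py`, consume. A minimal sketch of inspecting such a file (the path is illustrative):

```python
import torch

# A bare state dict maps parameter names to tensors; no task head is included.
state_dict = torch.load("esm_s.pth", map_location=torch.device("cpu"))
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))
```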
/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | import argparse
5 | 
6 | import yaml
7 | import jinja2
8 | from jinja2 import meta
9 | import easydict
10 | 
11 | import torch
12 | from torch import distributed as dist
13 | from torch.optim import lr_scheduler
14 | 
15 | from torchdrug import core, utils, datasets, models, tasks
16 | from torchdrug.utils import comm
17 | 
18 | 
19 | logger = logging.getLogger(__file__)
20 | 
21 | 
22 | def get_root_logger(file=True):
23 |     logger = logging.getLogger("")
24 |     logger.setLevel(logging.INFO)
25 |     formatter = logging.Formatter("%(asctime)-10s %(message)s", "%H:%M:%S")
26 | 
27 |     if file:
28 |         handler = logging.FileHandler("log.txt")
29 |         handler.setFormatter(formatter)
30 |         logger.addHandler(handler)
31 | 
32 |     return logger
33 | 
34 | 
35 | def create_working_directory(cfg):
36 |     file_name = "working_dir.tmp"
37 |     world_size = comm.get_world_size()
38 |     if world_size > 1 and not dist.is_initialized():
39 |         comm.init_process_group("nccl", init_method="env://")
40 | 
41 |     working_dir = os.path.join(os.path.expanduser(cfg.output_dir),
42 |                                cfg.task["class"], cfg.dataset["class"] + cfg.dataset.get("level", ""),
43 |                                cfg.task.get("model", cfg.get("p_model"))["class"],
44 |                                time.strftime("%Y-%m-%d-%H-%M-%S"))
45 | 
46 |     # synchronize working directory
47 |     if comm.get_rank() == 0:
48 |         with open(file_name, "w") as fout:
49 |             fout.write(working_dir)
50 |         os.makedirs(working_dir)
51 |     comm.synchronize()
52 |     if comm.get_rank() != 0:
53 |         with open(file_name, "r") as fin:
54 |             working_dir = fin.read()
55 |     comm.synchronize()
56 |     if comm.get_rank() == 0:
57 |         os.remove(file_name)
58 | 
59 |     os.chdir(working_dir)
60 |     return working_dir
61 | 
62 | 
63 | def detect_variables(cfg_file):
64 |     with open(cfg_file, "r") as fin:
65 |         raw = fin.read()
66 |     env = jinja2.Environment()
67 |     ast = env.parse(raw)
68 |     vars = meta.find_undeclared_variables(ast)
69 |     return vars
70 | 
71 | 
72 | def load_config(cfg_file, context=None):
73 |     with open(cfg_file, "r") as fin:
74 |         raw = fin.read()
75 |     template = jinja2.Template(raw)
76 |     instance = template.render(context)
77 |     cfg = yaml.safe_load(instance)
78 |     cfg = easydict.EasyDict(cfg)
79 |     return cfg
80 | 
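81 | # Example of the templating round trip implemented by detect_variables and
82 | # load_config above. Given a config with a Jinja placeholder (this excerpt is
83 | # illustrative, not copied from the repo's configs):
84 | #
85 | #     optimizer:
86 | #       class: Adam
87 | #       lr: {{ learning_rate }}
88 | #
89 | # detect_variables returns {"learning_rate"}, parse_args (below) exposes it as
90 | # a --learning_rate flag, and load_config(..., context={"learning_rate": 1e-4})
91 | # renders the template before the YAML is parsed.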
92 | 
93 | 
94 | def build_downstream_solver(cfg, dataset):
95 |     # `dataset` is either a (train, valid, test) tuple of pre-split sets or a
96 |     # torchdrug dataset that still needs to be split.
97 |     if isinstance(dataset, tuple):
98 |         train_set, valid_set, test_set = dataset
99 |         num_classes = train_set.num_classes
100 |         weights = getattr(train_set, "weights", torch.ones((train_set.num_classes,), dtype=torch.float))
101 |     else:
102 |         train_set, valid_set, test_set = dataset.split()
103 |         num_classes = len(dataset.targets)
104 |         weights = getattr(dataset, "weights", torch.ones((len(dataset.targets),), dtype=torch.float))
105 |     if comm.get_rank() == 0:
106 |         logger.warning(dataset)
107 |         logger.warning("#train: %d, #valid: %d, #test: %d" % (len(train_set), len(valid_set), len(test_set)))
108 | 
109 |     if cfg.task['class'] == "FunctionAnnotation":
110 |         cfg.task.num_class = num_classes
111 |         cfg.task.weight = weights
112 |     elif cfg.task['class'] == "FoldClassification":
113 |         cfg.task.num_class = num_classes
114 |     else:
115 |         cfg.task.task = dataset.tasks
116 |     task = core.Configurable.load_config_dict(cfg.task)
117 | 
118 |     cfg.optimizer.params = task.parameters()
119 |     optimizer = core.Configurable.load_config_dict(cfg.optimizer)
120 | 
121 |     if "scheduler" not in cfg:
122 |         scheduler = None
123 |     elif cfg.scheduler["class"] == "ReduceLROnPlateau":
124 |         # ReduceLROnPlateau is stepped manually with the validation metric in
125 |         # script/run.py, so it is not registered with the engine.
126 |         cfg.scheduler.pop("class")
127 |         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, **cfg.scheduler)
128 |     else:
129 |         cfg.scheduler.optimizer = optimizer
130 |         scheduler = core.Configurable.load_config_dict(cfg.scheduler)
131 |         cfg.engine.scheduler = scheduler
132 | 
133 |     solver = core.Engine(task, train_set, valid_set, test_set, optimizer, **cfg.engine)
134 | 
135 |     # lr_ratio > 0 trains the backbone with a scaled-down learning rate;
136 |     # lr_ratio <= 0 freezes the backbone and trains only the MLP head.
137 |     if "lr_ratio" in cfg:
138 |         if cfg.lr_ratio > 0:
139 |             cfg.optimizer.params = [
140 |                 {'params': solver.model.model.parameters(), 'lr': cfg.optimizer.lr * cfg.lr_ratio},
141 |                 {'params': solver.model.mlp.parameters(), 'lr': cfg.optimizer.lr}
142 |             ]
143 |         else:
144 |             for p in solver.model.model.parameters():
145 |                 p.requires_grad = False
146 |             cfg.optimizer.params = [{'params': solver.model.mlp.parameters(), 'lr': cfg.optimizer.lr}]
147 |         optimizer = core.Configurable.load_config_dict(cfg.optimizer)
148 |         solver.optimizer = optimizer
149 | 
150 |         # rebuild the scheduler around the new optimizer
151 |         if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
152 |             scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, **cfg.scheduler)
153 |         elif scheduler is not None:
154 |             cfg.scheduler.optimizer = optimizer
155 |             scheduler = core.Configurable.load_config_dict(cfg.scheduler)
156 |         solver.scheduler = scheduler
157 | 
158 |     # `checkpoint` resumes a full solver checkpoint, while `model_checkpoint`
159 |     # only loads backbone weights (e.g., the esm_s.pth saved by script/run.py).
160 |     if cfg.get("checkpoint") is not None:
161 |         solver.load(cfg.checkpoint)
162 | 
163 |     if cfg.get("model_checkpoint") is not None:
164 |         if comm.get_rank() == 0:
165 |             logger.warning("Load checkpoint from %s" % cfg.model_checkpoint)
166 |         cfg.model_checkpoint = os.path.expanduser(cfg.model_checkpoint)
167 |         model_dict = torch.load(cfg.model_checkpoint, map_location=torch.device('cpu'))
168 |         task.model.load_state_dict(model_dict)
169 | 
170 |     return solver, scheduler
171 | 
172 | 
173 | def parse_args():
174 |     parser = argparse.ArgumentParser()
175 |     parser.add_argument("-c", "--config", help="yaml configuration file", required=True)
176 |     parser.add_argument("-s", "--seed", help="random seed for PyTorch", type=int, default=1024)
177 | 
178 |     args, unparsed = parser.parse_known_args()
179 |     # expose the Jinja variables found in the config file as extra CLI flags
180 |     vars = detect_variables(args.config)
181 |     parser = argparse.ArgumentParser()
182 |     for var in vars:
183 |         parser.add_argument("--%s" % var, default="null")
184 |     vars = parser.parse_known_args(unparsed)[0]
185 |     vars = {k: utils.literal_eval(v) for k, v in vars._get_kwargs()}
186 | 
187 |     return args, vars
--------------------------------------------------------------------------------
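The utilities above can also be driven programmatically. The following sketch mirrors the EC/GO branch of `script/run.py`, assuming it runs from the repository root and that the chosen config declares no Jinja variables (otherwise, pass them via `context`):

```python
import util
from torchdrug import core
from esm_s import dataset, task  # noqa: F401 -- registers the custom classes

cfg = util.load_config("config/predictor/esm_ec.yaml", context={})
working_dir = util.create_working_directory(cfg)  # also chdirs into a fresh run directory

splits = []
for split in ["train", "valid", "test"]:
    cfg.dataset.split = split
    splits.append(core.Configurable.load_config_dict(cfg.dataset))

solver, scheduler = util.build_downstream_solver(cfg, tuple(splits))
solver.evaluate("valid")
```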