├── LICENSE ├── README.md ├── RF2na-linux.yml ├── SE3Transformer ├── Dockerfile ├── LICENSE ├── NOTICE ├── README.md ├── images │ └── se3-transformer.png ├── requirements.txt ├── scripts │ ├── benchmark_inference.sh │ ├── benchmark_train.sh │ ├── benchmark_train_multi_gpu.sh │ ├── predict.sh │ ├── train.sh │ └── train_multi_gpu.sh ├── se3_transformer │ ├── __init__.py │ ├── data_loading │ │ ├── __init__.py │ │ ├── data_module.py │ │ └── qm9.py │ ├── model │ │ ├── __init__.py │ │ ├── basis.py │ │ ├── fiber.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── convolution.py │ │ │ ├── linear.py │ │ │ ├── norm.py │ │ │ └── pooling.py │ │ └── transformer.py │ └── runtime │ │ ├── __init__.py │ │ ├── arguments.py │ │ ├── callbacks.py │ │ ├── gpu_affinity.py │ │ ├── inference.py │ │ ├── loggers.py │ │ ├── metrics.py │ │ ├── training.py │ │ └── utils.py ├── setup.py └── tests │ ├── __init__.py │ ├── test_equivariance.py │ └── utils.py ├── example ├── RNA.fa ├── dna_binding_protein.fa └── rna_binding_protein.fa ├── input_prep ├── make_protein_msa.sh ├── make_rna_msa.sh ├── merge_msa_prot_rna.py └── reprocess_rnac.pl ├── network ├── Attention_module.py ├── AuxiliaryPredictor.py ├── Embeddings.py ├── RoseTTAFoldModel.py ├── SE3_network.py ├── Track_module.py ├── arguments.py ├── chemical.py ├── coords6d.py ├── data_loader.py ├── ffindex.py ├── kinematics.py ├── loss.py ├── models.json ├── parsers.py ├── predict.py ├── resnet.py ├── scheduler.py ├── scoring.py ├── util.py └── util_module.py └── run_RF2NA.sh /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Institute for Protein Design 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RF2NA 2 | GitHub repo for RoseTTAFold2 with nucleic acids 3 | 4 | **New: April 13, 2023 v0.2** 5 | * Updated weights (https://files.ipd.uw.edu/dimaio/RF2NA_apr23.tgz) for better prediction of homodimer:DNA interactions and better DNA-specific sequence recognition 6 | * Bugfixes in MSA generation pipeline 7 | * Support for paired protein/RNA MSAs 8 | 9 | ## Installation 10 | 11 | 1. Clone the package 12 | ``` 13 | git clone https://github.com/uw-ipd/RoseTTAFold2NA.git 14 | cd RoseTTAFold2NA 15 | ``` 16 | 17 | 2. Create conda environment 18 | All external dependencies are contained in `RF2na-linux.yml` 19 | ``` 20 | # create conda environment for RoseTTAFold2NA 21 | conda env create -f RF2na-linux.yml 22 | ``` 23 | You also need to install NVIDIA's SE(3)-Transformer (**please use SE3Transformer in this repo to install**). 24 | ``` 25 | conda activate RF2NA 26 | cd SE3Transformer 27 | pip install --no-cache-dir -r requirements.txt 28 | python setup.py install 29 | cd .. 
30 | ``` 31 | 32 | 3. Download pre-trained weights under network directory 33 | ``` 34 | cd network 35 | wget https://files.ipd.uw.edu/dimaio/RF2NA_apr23.tgz 36 | tar xvfz RF2NA_apr23.tgz 37 | ls weights/ # it should contain a 1.1GB weights file 38 | cd .. 39 | ``` 40 | 41 | 4. Download sequence and structure databases 42 | ``` 43 | # uniref30 [46G] 44 | wget http://wwwuser.gwdg.de/~compbiol/uniclust/2020_06/UniRef30_2020_06_hhsuite.tar.gz 45 | mkdir -p UniRef30_2020_06 46 | tar xfz UniRef30_2020_06_hhsuite.tar.gz -C ./UniRef30_2020_06 47 | 48 | # BFD [272G] 49 | wget https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz 50 | mkdir -p bfd 51 | tar xfz bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz -C ./bfd 52 | 53 | # structure templates (including *_a3m.ffdata, *_a3m.ffindex) 54 | wget https://files.ipd.uw.edu/pub/RoseTTAFold/pdb100_2021Mar03.tar.gz 55 | tar xfz pdb100_2021Mar03.tar.gz 56 | 57 | # RNA databases 58 | mkdir -p RNA 59 | cd RNA 60 | 61 | # Rfam [300M] 62 | wget ftp://ftp.ebi.ac.uk/pub/databases/Rfam/CURRENT/Rfam.full_region.gz 63 | wget ftp://ftp.ebi.ac.uk/pub/databases/Rfam/CURRENT/Rfam.cm.gz 64 | gunzip Rfam.cm.gz 65 | cmpress Rfam.cm 66 | 67 | # RNAcentral [12G] 68 | wget ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/rfam/rfam_annotations.tsv.gz 69 | wget ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/id_mapping/id_mapping.tsv.gz 70 | wget ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/sequences/rnacentral_species_specific_ids.fasta.gz 71 | ../input_prep/reprocess_rnac.pl id_mapping.tsv.gz rfam_annotations.tsv.gz # ~8 minutes 72 | gunzip -c rnacentral_species_specific_ids.fasta.gz | makeblastdb -in - -dbtype nucl -parse_seqids -out rnacentral.fasta -title "RNACentral" 73 | 74 | # nt [151G] 75 | update_blastdb.pl --decompress nt 76 | cd .. 
77 | ``` 78 | 79 | ## Usage 80 | ``` 81 | conda activate RF2NA 82 | cd example 83 | # run Protein/RNA prediction 84 | ../run_RF2NA.sh rna_pred rna_binding_protein.fa R:RNA.fa 85 | # run Protein/DNA prediction 86 | ../run_RF2NA.sh dna_pred dna_binding_protein.fa D:DNA.fa 87 | ``` 88 | ### Inputs 89 | * The first argument to the script is the output folder 90 | * The remaining arguments are fasta files for individual chains in the structure. Use the tags `P:xxx.fa` `R:xxx.fa` `D:xxx.fa` `S:xxx.fa` to specify protein, RNA, double-stranded DNA, and single-stranded DNA, respectively. Use the tag `PR:xxx.fa` to specify paired protein/RNA. Each chain is a separate file; 'D' will automatically generate a complementary DNA strand to the input strand. 91 | 92 | ### Expected outputs 93 | * Outputs are written to the folder provided as the first argument (`dna_pred` and `rna_pred`). 94 | * Model outputs are placed in a subfolder, `models` (e.g., `dna_pred/models`) 95 | * You will get a predicted structure with estimated per-residue LDDT in the B-factor column (`models/model_00.pdb`) 96 | * You will get a numpy `.npz` file (`models/model_00.npz`).
This can be read with `numpy.load` and contains three tables (L=complex length): 97 | - dist (L x L x 37) - the predicted distogram 98 | - lddt (L) - the per-residue predicted lddt 99 | - pae (L x L) - the per-residue pair predicted error 100 | -------------------------------------------------------------------------------- /RF2na-linux.yml: -------------------------------------------------------------------------------- 1 | name: RF2NA 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | - conda-forge 7 | dependencies: 8 | - python=3.10 9 | - pip 10 | - pytorch 11 | - requests 12 | - pytorch-cuda=11.7 13 | - dglteam/label/cu117::dgl 14 | - pyg::pyg 15 | - bioconda::mafft 16 | - bioconda::hhsuite 17 | - bioconda::blast 18 | - bioconda::hmmer>=3.3 19 | - bioconda::infernal 20 | - bioconda::cd-hit 21 | - bioconda::csblast 22 | - pip: 23 | - psutil 24 | - tqdm -------------------------------------------------------------------------------- /SE3Transformer/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 20 | # 21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES 22 | # SPDX-License-Identifier: MIT 23 | 24 | # run docker daemon with --default-runtime=nvidia for GPU detection during build 25 | # multistage build for DGL with CUDA and FP16 26 | 27 | ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.07-py3 28 | 29 | FROM ${FROM_IMAGE_NAME} AS dgl_builder 30 | 31 | ENV DEBIAN_FRONTEND=noninteractive 32 | RUN apt-get update \ 33 | && apt-get install -y git build-essential python3-dev make cmake \ 34 | && rm -rf /var/lib/apt/lists/* 35 | WORKDIR /dgl 36 | RUN git clone --branch v0.7.0 --recurse-submodules --depth 1 https://github.com/dmlc/dgl.git . 37 | RUN sed -i 's/"35 50 60 70"/"60 70 80"/g' cmake/modules/CUDA.cmake 38 | WORKDIR build 39 | RUN cmake -DUSE_CUDA=ON -DUSE_FP16=ON .. 40 | RUN make -j8 41 | 42 | 43 | FROM ${FROM_IMAGE_NAME} 44 | 45 | RUN rm -rf /workspace/* 46 | WORKDIR /workspace/se3-transformer 47 | 48 | # copy built DGL and install it 49 | COPY --from=dgl_builder /dgl ./dgl 50 | RUN cd dgl/python && python setup.py install && cd ../.. && rm -rf dgl 51 | 52 | ADD requirements.txt . 53 | RUN pip install --no-cache-dir --upgrade --pre pip 54 | RUN pip install --no-cache-dir -r requirements.txt 55 | ADD . . 
56 | 57 | ENV DGLBACKEND=pytorch 58 | ENV OMP_NUM_THREADS=1 59 | -------------------------------------------------------------------------------- /SE3Transformer/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2021 NVIDIA CORPORATION & AFFILIATES 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /SE3Transformer/NOTICE: -------------------------------------------------------------------------------- 1 | SE(3)-Transformer PyTorch 2 | 3 | This repository includes software from https://github.com/FabianFuchsML/se3-transformer-public 4 | licensed under the MIT License. 5 | 6 | This repository includes software from https://github.com/lucidrains/se3-transformer-pytorch 7 | licensed under the MIT License. 
8 | -------------------------------------------------------------------------------- /SE3Transformer/images/se3-transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uw-ipd/RoseTTAFold2NA/f761af286729ea08a6ddab149023c1b73458fbe2/SE3Transformer/images/se3-transformer.png -------------------------------------------------------------------------------- /SE3Transformer/requirements.txt: -------------------------------------------------------------------------------- 1 | e3nn==0.3.3 2 | wandb==0.12.0 3 | pynvml==11.0.0 4 | git+https://github.com/NVIDIA/dllogger#egg=dllogger 5 | -------------------------------------------------------------------------------- /SE3Transformer/scripts/benchmark_inference.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Script to benchmark inference performance, without bases precomputation 3 | 4 | # CLI args with defaults 5 | BATCH_SIZE=${1:-240} 6 | AMP=${2:-true} 7 | 8 | CUDA_VISIBLE_DEVICES=0 python -m se3_transformer.runtime.inference \ 9 | --amp "$AMP" \ 10 | --batch_size "$BATCH_SIZE" \ 11 | --use_layer_norm \ 12 | --norm \ 13 | --task homo \ 14 | --seed 42 \ 15 | --benchmark 16 | -------------------------------------------------------------------------------- /SE3Transformer/scripts/benchmark_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Script to benchmark single-GPU training performance, with bases precomputation 3 | 4 | # CLI args with defaults 5 | BATCH_SIZE=${1:-240} 6 | AMP=${2:-true} 7 | 8 | CUDA_VISIBLE_DEVICES=0 python -m se3_transformer.runtime.training \ 9 | --amp "$AMP" \ 10 | --batch_size "$BATCH_SIZE" \ 11 | --epochs 6 \ 12 | --use_layer_norm \ 13 | --norm \ 14 | --save_ckpt_path model_qm9.pth \ 15 | --task homo \ 16 | --precompute_bases \ 17 | --seed 42 \ 18 | --benchmark 19 | 
-------------------------------------------------------------------------------- /SE3Transformer/scripts/benchmark_train_multi_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Script to benchmark multi-GPU training performance, with bases precomputation 3 | 4 | # CLI args with defaults 5 | BATCH_SIZE=${1:-240} 6 | AMP=${2:-true} 7 | 8 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \ 9 | se3_transformer.runtime.training \ 10 | --amp "$AMP" \ 11 | --batch_size "$BATCH_SIZE" \ 12 | --epochs 6 \ 13 | --use_layer_norm \ 14 | --norm \ 15 | --save_ckpt_path model_qm9.pth \ 16 | --task homo \ 17 | --precompute_bases \ 18 | --seed 42 \ 19 | --benchmark 20 | -------------------------------------------------------------------------------- /SE3Transformer/scripts/predict.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # CLI args with defaults 4 | BATCH_SIZE=${1:-240} 5 | AMP=${2:-true} 6 | 7 | 8 | # choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 9 | # 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C' 10 | TASK=homo 11 | 12 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \ 13 | se3_transformer.runtime.inference \ 14 | --amp "$AMP" \ 15 | --batch_size "$BATCH_SIZE" \ 16 | --use_layer_norm \ 17 | --norm \ 18 | --load_ckpt_path model_qm9.pth \ 19 | --task "$TASK" 20 | -------------------------------------------------------------------------------- /SE3Transformer/scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # CLI args with defaults 4 | BATCH_SIZE=${1:-240} 5 | AMP=${2:-true} 6 | NUM_EPOCHS=${3:-100} 7 | LEARNING_RATE=${4:-0.002} 8 | WEIGHT_DECAY=${5:-0.1} 9 | 10 | # choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 
'U0', 'U', 'H', 'G', 'Cv', 11 | # 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C' 12 | TASK=homo 13 | 14 | python -m se3_transformer.runtime.training \ 15 | --amp "$AMP" \ 16 | --batch_size "$BATCH_SIZE" \ 17 | --epochs "$NUM_EPOCHS" \ 18 | --lr "$LEARNING_RATE" \ 19 | --weight_decay "$WEIGHT_DECAY" \ 20 | --use_layer_norm \ 21 | --norm \ 22 | --save_ckpt_path model_qm9.pth \ 23 | --precompute_bases \ 24 | --seed 42 \ 25 | --task "$TASK" -------------------------------------------------------------------------------- /SE3Transformer/scripts/train_multi_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # CLI args with defaults 4 | BATCH_SIZE=${1:-240} 5 | AMP=${2:-true} 6 | NUM_EPOCHS=${3:-130} 7 | LEARNING_RATE=${4:-0.01} 8 | WEIGHT_DECAY=${5:-0.1} 9 | 10 | # choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 11 | # 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C' 12 | TASK=homo 13 | 14 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \ 15 | se3_transformer.runtime.training \ 16 | --amp "$AMP" \ 17 | --batch_size "$BATCH_SIZE" \ 18 | --epochs "$NUM_EPOCHS" \ 19 | --lr "$LEARNING_RATE" \ 20 | --min_lr 0.00001 \ 21 | --weight_decay "$WEIGHT_DECAY" \ 22 | --use_layer_norm \ 23 | --norm \ 24 | --save_ckpt_path model_qm9.pth \ 25 | --precompute_bases \ 26 | --seed 42 \ 27 | --task "$TASK" -------------------------------------------------------------------------------- /SE3Transformer/se3_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uw-ipd/RoseTTAFold2NA/f761af286729ea08a6ddab149023c1b73458fbe2/SE3Transformer/se3_transformer/__init__.py -------------------------------------------------------------------------------- /SE3Transformer/se3_transformer/data_loading/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .qm9 import QM9DataModule 2 | -------------------------------------------------------------------------------- /SE3Transformer/se3_transformer/data_loading/data_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 
20 | # 21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES 22 | # SPDX-License-Identifier: MIT 23 | 24 | import torch.distributed as dist 25 | from abc import ABC 26 | from torch.utils.data import DataLoader, DistributedSampler, Dataset 27 | 28 | from se3_transformer.runtime.utils import get_local_rank 29 | 30 | 31 | def _get_dataloader(dataset: Dataset, shuffle: bool, **kwargs) -> DataLoader: 32 | # Classic or distributed dataloader depending on the context 33 | sampler = DistributedSampler(dataset, shuffle=shuffle) if dist.is_initialized() else None 34 | return DataLoader(dataset, shuffle=(shuffle and sampler is None), sampler=sampler, **kwargs) 35 | 36 | 37 | class DataModule(ABC): 38 | """ Abstract DataModule. Children must define self.ds_{train | val | test}. """ 39 | 40 | def __init__(self, **dataloader_kwargs): 41 | super().__init__() 42 | if get_local_rank() == 0: 43 | self.prepare_data() 44 | 45 | # Wait until rank zero has prepared the data (download, preprocessing, ...) 46 | if dist.is_initialized(): 47 | dist.barrier(device_ids=[get_local_rank()]) 48 | 49 | self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs} 50 | self.ds_train, self.ds_val, self.ds_test = None, None, None 51 | 52 | def prepare_data(self): 53 | """ Method called only once per node. 
Put here any downloading or preprocessing """ 54 | pass 55 | 56 | def train_dataloader(self) -> DataLoader: 57 | return _get_dataloader(self.ds_train, shuffle=True, **self.dataloader_kwargs) 58 | 59 | def val_dataloader(self) -> DataLoader: 60 | return _get_dataloader(self.ds_val, shuffle=False, **self.dataloader_kwargs) 61 | 62 | def test_dataloader(self) -> DataLoader: 63 | return _get_dataloader(self.ds_test, shuffle=False, **self.dataloader_kwargs) 64 | -------------------------------------------------------------------------------- /SE3Transformer/se3_transformer/data_loading/qm9.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 
20 | # 21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES 22 | # SPDX-License-Identifier: MIT 23 | from typing import Tuple 24 | 25 | import dgl 26 | import pathlib 27 | import torch 28 | from dgl.data import QM9EdgeDataset 29 | from dgl import DGLGraph 30 | from torch import Tensor 31 | from torch.utils.data import random_split, DataLoader, Dataset 32 | from tqdm import tqdm 33 | 34 | from se3_transformer.data_loading.data_module import DataModule 35 | from se3_transformer.model.basis import get_basis 36 | from se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores 37 | 38 | 39 | def _get_relative_pos(qm9_graph: DGLGraph) -> Tensor: 40 | x = qm9_graph.ndata['pos'] 41 | src, dst = qm9_graph.edges() 42 | rel_pos = x[dst] - x[src] 43 | return rel_pos 44 | 45 | 46 | def _get_split_sizes(full_dataset: Dataset) -> Tuple[int, int, int]: 47 | len_full = len(full_dataset) 48 | len_train = 100_000 49 | len_test = int(0.1 * len_full) 50 | len_val = len_full - len_train - len_test 51 | return len_train, len_val, len_test 52 | 53 | 54 | class QM9DataModule(DataModule): 55 | """ 56 | Datamodule wrapping https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset 57 | Training set is 100k molecules. Test set is 10% of the dataset. Validation set is the rest. 58 | This includes all the molecules from QM9 except the ones that are uncharacterized. 
59 | """ 60 | 61 | NODE_FEATURE_DIM = 6 62 | EDGE_FEATURE_DIM = 4 63 | 64 | def __init__(self, 65 | data_dir: pathlib.Path, 66 | task: str = 'homo', 67 | batch_size: int = 240, 68 | num_workers: int = 8, 69 | num_degrees: int = 4, 70 | amp: bool = False, 71 | precompute_bases: bool = False, 72 | **kwargs): 73 | self.data_dir = data_dir # This needs to be before __init__ so that prepare_data has access to it 74 | super().__init__(batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate) 75 | self.amp = amp 76 | self.task = task 77 | self.batch_size = batch_size 78 | self.num_degrees = num_degrees 79 | 80 | qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir)) 81 | if precompute_bases: 82 | bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp) 83 | full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size, 84 | num_workers=num_workers, **qm9_kwargs) 85 | else: 86 | full_dataset = QM9EdgeDataset(**qm9_kwargs) 87 | 88 | self.ds_train, self.ds_val, self.ds_test = random_split(full_dataset, _get_split_sizes(full_dataset), 89 | generator=torch.Generator().manual_seed(0)) 90 | 91 | train_targets = full_dataset.targets[self.ds_train.indices, full_dataset.label_keys[0]] 92 | self.targets_mean = train_targets.mean() 93 | self.targets_std = train_targets.std() 94 | 95 | def prepare_data(self): 96 | # Download the QM9 preprocessed data 97 | QM9EdgeDataset(verbose=True, raw_dir=str(self.data_dir)) 98 | 99 | def _collate(self, samples): 100 | graphs, y, *bases = map(list, zip(*samples)) 101 | batched_graph = dgl.batch(graphs) 102 | edge_feats = {'0': batched_graph.edata['edge_attr'][..., None]} 103 | batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph) 104 | # get node features 105 | node_feats = {'0': batched_graph.ndata['attr'][:, :6, None]} 106 | targets = (torch.cat(y) - self.targets_mean) / self.targets_std 107 | 108 | if bases: 109 | # collate 
bases 110 | all_bases = { 111 | key: torch.cat([b[key] for b in bases[0]], dim=0) 112 | for key in bases[0][0].keys() 113 | } 114 | 115 | return batched_graph, node_feats, edge_feats, all_bases, targets 116 | else: 117 | return batched_graph, node_feats, edge_feats, targets 118 | 119 | @staticmethod 120 | def add_argparse_args(parent_parser): 121 | parser = parent_parser.add_argument_group("QM9 dataset") 122 | parser.add_argument('--task', type=str, default='homo', const='homo', nargs='?', 123 | choices=['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 124 | 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'], 125 | help='Regression task to train on') 126 | parser.add_argument('--precompute_bases', type=str2bool, nargs='?', const=True, default=False, 127 | help='Precompute bases at the beginning of the script during dataset initialization,' 128 | ' instead of computing them at the beginning of each forward pass.') 129 | return parent_parser 130 | 131 | def __repr__(self): 132 | return f'QM9({self.task})' 133 | 134 | 135 | class CachedBasesQM9EdgeDataset(QM9EdgeDataset): 136 | """ Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """ 137 | 138 | def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs): 139 | """ 140 | :param bases_kwargs: Arguments to feed the bases computation function 141 | :param batch_size: Batch size to use when iterating over the dataset for computing bases 142 | """ 143 | self.bases_kwargs = bases_kwargs 144 | self.batch_size = batch_size 145 | self.bases = None 146 | self.num_workers = num_workers 147 | super().__init__(*args, **kwargs) 148 | 149 | def load(self): 150 | super().load() 151 | # Iterate through the dataset and compute bases (pairwise only) 152 | # Potential improvement: use multi-GPU and gather 153 | dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, 154 | 
collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples])) 155 | bases = [] 156 | for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases', 157 | disable=get_local_rank() != 0): 158 | rel_pos = _get_relative_pos(graph) 159 | # Compute the bases with the GPU but convert the result to CPU to store in RAM 160 | bases.append({k: v.cpu() for k, v in get_basis(rel_pos.cuda(), **self.bases_kwargs).items()}) 161 | self.bases = bases # Assign at the end so that __getitem__ isn't confused 162 | 163 | def __getitem__(self, idx: int): 164 | graph, label = super().__getitem__(idx) 165 | 166 | if self.bases: 167 | bases_idx = idx // self.batch_size 168 | bases_cumsum_idx = self.ne_cumsum[idx] - self.ne_cumsum[bases_idx * self.batch_size] 169 | bases_cumsum_next_idx = self.ne_cumsum[idx + 1] - self.ne_cumsum[bases_idx * self.batch_size] 170 | return graph, label, {key: basis[bases_cumsum_idx:bases_cumsum_next_idx] for key, basis in 171 | self.bases[bases_idx].items()} 172 | else: 173 | return graph, label 174 | -------------------------------------------------------------------------------- /SE3Transformer/se3_transformer/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import SE3Transformer, SE3TransformerPooled 2 | from .fiber import Fiber 3 | -------------------------------------------------------------------------------- /SE3Transformer/se3_transformer/model/basis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 
@lru_cache(maxsize=None)
def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor:
    """
    Get the (cached) Q^{d_out,d_in}_J matrices from equation (8)

    The tensor is produced by e3nn's ``wigner_3j(J, d_in, d_out)`` in float64 on ``device``
    and its three axes are reversed with ``permute(2, 1, 0)`` before being returned.

    :param J: First degree handed to ``wigner_3j``
    :param d_in: Input degree
    :param d_out: Output degree
    :param device: Device on which the tensor is created; it is part of the cache key,
                   so it has to be hashable
    :return: The permuted Wigner 3j tensor (float64)
    """
    return o3.wigner_3j(J, d_in, d_out, dtype=torch.float64, device=device).permute(2, 1, 0)


@lru_cache(maxsize=None)
def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
    """
    Collect the (cached) Clebsch-Gordon tensors for every (d_in, d_out) pair up to max_degree.

    :param max_degree: Largest input/output degree to cover (inclusive)
    :param device: Device forwarded to get_clebsch_gordon
    :return: Flat list with one entry per (d_in, d_out) pair, d_in being the outer loop, so the
             pair (d_in, d_out) sits at index ``d_in * (max_degree + 1) + d_out``.
             Each entry lists the tensors for J = |d_in - d_out| .. d_in + d_out, in that order.
    """
    all_cb = []
    for d_in in range(max_degree + 1):
        for d_out in range(max_degree + 1):
            K_Js = []
            for J in range(abs(d_in - d_out), d_in + d_out + 1):
                K_Js.append(get_clebsch_gordon(J, d_in, d_out, device))
            all_cb.append(K_Js)
    return all_cb


def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
    """
    Evaluate the spherical harmonics of degrees 0 .. 2 * max_degree on the relative positions.

    :param relative_pos: Positions handed to e3nn's ``spherical_harmonics``
                         (presumably one 3D vector per edge - TODO confirm against the caller)
    :param max_degree: Maximum feature degree; harmonics are computed up to twice this degree
    :return: One chunk per degree, obtained by splitting the e3nn output along dim 1 into pieces
             of size ``degree_to_dim(d)``
    """
    all_degrees = list(range(2 * max_degree + 1))
    with nvtx_range('spherical harmonics'):
        sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
        return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)


@torch.jit.script
def get_basis_script(max_degree: int,
                     use_pad_trick: bool,
                     spherical_harmonics: List[Tensor],
                     clebsch_gordon: List[List[Tensor]],
                     amp: bool) -> Dict[str, Tensor]:
    """
    Compute pairwise bases matrices for degrees up to max_degree
    :param max_degree: Maximum input or output degree
    :param use_pad_trick: Pad some of the odd dimensions for a better use of Tensor Cores
    :param spherical_harmonics: List of computed spherical harmonics
    :param clebsch_gordon: List of computed CB-coefficients
    :param amp: When true, return bases in FP16 precision
    :return: Dict keyed by '{d_in},{d_out}' holding the stacked basis tensor of each degree pair
    """
    basis = {}
    # idx walks the flat clebsch_gordon list in the same (d_in, d_out) order used to build it
    idx = 0
    # Double for loop instead of product() because of JIT script
    for d_in in range(max_degree + 1):
        for d_out in range(max_degree + 1):
            key = f'{d_in},{d_out}'
            K_Js = []
            for freq_idx, J in enumerate(range(abs(d_in - d_out), d_in + d_out + 1)):
                Q_J = clebsch_gordon[idx][freq_idx]
                # Both operands are cast to FP32 before the contraction over the shared axis f
                K_Js.append(torch.einsum('n f, k l f -> n l k', spherical_harmonics[J].float(), Q_J.float()))

            basis[key] = torch.stack(K_Js, 2)  # Stack on second dim so order is n l f k
            if amp:
                basis[key] = basis[key].half()
            if use_pad_trick:
                basis[key] = F.pad(basis[key], (0, 1))  # Pad the k dimension, that can be sliced later

            idx += 1

    return basis


@torch.jit.script
def update_basis_with_fused(basis: Dict[str, Tensor],
                            max_degree: int,
                            use_pad_trick: bool,
                            fully_fused: bool) -> Dict[str, Tensor]:
    """
    Update the basis dict with partially and optionally fully fused bases

    The dict is modified in place and also returned. It gains one 'out{d}_fused' and one
    'in{d}_fused' entry per degree and, when requested, a 'fully_fused' entry; the '0,0'
    entry is removed at the end.

    :param basis: Dict of per-pair bases keyed by '{d_in},{d_out}'; must contain '0,0', which
                  supplies the number of edges, the device and the dtype of the fused tensors
    :param max_degree: Maximum input or output degree present in ``basis``
    :param use_pad_trick: When true, the 'out{d}_fused' tensors get one extra element on the last dim
    :param fully_fused: When true, also build the 'fully_fused' entry
    :return: The same dict object, updated
    """
    num_edges = basis['0,0'].shape[0]
    device = basis['0,0'].device
    dtype = basis['0,0'].dtype
    sum_dim = sum([degree_to_dim(d) for d in range(max_degree + 1)])

    # Fused per output degree
    for d_out in range(max_degree + 1):
        sum_freq = sum([degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)])
        basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, degree_to_dim(d_out) + int(use_pad_trick),
                                  device=device, dtype=dtype)
        # acc_d / acc_f are running offsets along dims 1 and 2 of the fused tensor
        acc_d, acc_f = 0, 0
        for d_in in range(max_degree + 1):
            # Only the first degree_to_dim(d_out) entries of the last dim are copied,
            # so a padded element (pad trick) stays zero
            basis_fused[:, acc_d:acc_d + degree_to_dim(d_in), acc_f:acc_f + degree_to_dim(min(d_out, d_in)),
            :degree_to_dim(d_out)] = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]

            acc_d += degree_to_dim(d_in)
            acc_f += degree_to_dim(min(d_out, d_in))

        basis[f'out{d_out}_fused'] = basis_fused

    # Fused per input degree
    for d_in in range(max_degree + 1):
        sum_freq = sum([degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)])
        basis_fused = torch.zeros(num_edges, degree_to_dim(d_in), sum_freq, sum_dim,
                                  device=device, dtype=dtype)
        # Here acc_d is the running offset along the last dim and acc_f along dim 2
        acc_d, acc_f = 0, 0
        for d_out in range(max_degree + 1):
            basis_fused[:, :, acc_f:acc_f + degree_to_dim(min(d_out, d_in)), acc_d:acc_d + degree_to_dim(d_out)] \
                = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]

            acc_d += degree_to_dim(d_out)
            acc_f += degree_to_dim(min(d_out, d_in))

        basis[f'in{d_in}_fused'] = basis_fused

    if fully_fused:
        # Fully fused
        # Double sum this way because of JIT script
        sum_freq = sum([
            sum([degree_to_dim(min(d_in, d_out)) for d_in in range(max_degree + 1)]) for d_out in range(max_degree + 1)
        ])
        basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, sum_dim, device=device, dtype=dtype)

        acc_d, acc_f = 0, 0
        for d_out in range(max_degree + 1):
            # Built from the per-output-degree fused tensors created above
            b = basis[f'out{d_out}_fused']
            basis_fused[:, :, acc_f:acc_f + b.shape[2], acc_d:acc_d + degree_to_dim(d_out)] = b[:, :, :,
                                                                                              :degree_to_dim(d_out)]
            acc_f += b.shape[2]
            acc_d += degree_to_dim(d_out)

        basis['fully_fused'] = basis_fused

    del basis['0,0']  # We know that the basis for l = k = 0 is filled with a constant
    return basis


def get_basis(relative_pos: Tensor,
              max_degree: int = 4,
              compute_gradients: bool = False,
              use_pad_trick: bool = False,
              amp: bool = False) -> Dict[str, Tensor]:
    """
    Compute the per-pair bases for the given relative positions.

    Spherical harmonics and the cached Clebsch-Gordon tensors are obtained first; only the
    final contraction in get_basis_script runs under the ``compute_gradients`` autograd setting.
    The fused entries are not added here (update_basis_with_fused is not called).

    :param relative_pos: Relative positions; their device is also used for the CB-coefficients
    :param max_degree: Maximum input or output degree
    :param compute_gradients: Whether autograd is enabled while the bases are assembled
    :param use_pad_trick: Forwarded to get_basis_script (pads the last dim by one element)
    :param amp: Forwarded to get_basis_script (return bases in FP16 precision)
    :return: Dict keyed by '{d_in},{d_out}' as produced by get_basis_script
    """
    with nvtx_range('spherical harmonics'):
        spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree)
    with nvtx_range('CB coefficients'):
        clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device)

    with torch.autograd.set_grad_enabled(compute_gradients):
        with nvtx_range('bases'):
            basis = get_basis_script(max_degree=max_degree,
                                     use_pad_trick=use_pad_trick,
                                     spherical_harmonics=spherical_harmonics,
                                     clebsch_gordon=clebsch_gordon,
                                     amp=amp)
            return basis
FiberEl = namedtuple('FiberEl', ['degree', 'channels'])


class Fiber(dict):
    """
    Describes the structure of some set of features.
    Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
    Type-0 features: invariant scalars
    Type-1 features: equivariant 3D vectors
    Type-2 features: equivariant symmetric traceless matrices
    ...

    As inputs to a SE3 layer, there can be many features of the same types, and many features of different types.
    The 'multiplicity' or 'number of channels' is the number of features of a given type.
    This class puts together all the degrees and their multiplicities in order to describe
    the inputs, outputs or hidden features of SE3 layers.
    """

    def __init__(self, structure):
        """
        :param structure: Either a dict {degree: channels}, a sequence of (degree, channels) pairs,
                          or a sequence of FiberEl. An empty dict or empty sequence gives an empty fiber.
        """
        # NOTE(review): both branches order the elements by channel count (x[1]), not by degree.
        # `degrees` re-sorts by degree, but plain iteration follows this order. Kept as is because
        # other layers may depend on the iteration order - confirm before changing.
        if isinstance(structure, dict):
            structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
        elif structure and not isinstance(structure[0], FiberEl):
            # The emptiness check avoids an IndexError on structure[0] for an empty sequence
            structure = [FiberEl(*t) for t in sorted(structure, key=lambda x: x[1])]
        self.structure = structure
        super().__init__({d: m for d, m in self.structure})

    @property
    def degrees(self):
        """ Degrees present in the fiber, in increasing order """
        return sorted(t.degree for t in self.structure)

    @property
    def channels(self):
        """ Multiplicities, listed in the same order as `degrees` """
        return [self[d] for d in self.degrees]

    @property
    def num_features(self):
        """ Size of the resulting tensor if all features were concatenated together """
        return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)

    @staticmethod
    def create(num_degrees: int, num_channels: int):
        """ Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
        return Fiber([(degree, num_channels) for degree in range(num_degrees)])

    @staticmethod
    def from_features(feats: Dict[str, Tensor]):
        """ Infer the Fiber structure from a feature dict """
        structure = {}
        for k, v in feats.items():
            degree = int(k)
            assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
            assert v.shape[-1] == degree_to_dim(degree)
            structure[degree] = v.shape[-2]
        return Fiber(structure)

    def __getitem__(self, degree: int):
        """ fiber[degree] returns the multiplicity for this degree (0 when the degree is absent) """
        return dict(self.structure).get(degree, 0)

    def __iter__(self):
        """ Iterate over namedtuples (degree, channels) """
        return iter(self.structure)

    def __mul__(self, other):
        """
        If other is an int, multiplies all the multiplicities by other.
        If other is a fiber, returns the cartesian product.
        Any other operand type is rejected (TypeError via NotImplemented).
        """
        if isinstance(other, Fiber):
            return product(self.structure, other.structure)
        if isinstance(other, int):
            return Fiber({t.degree: t.channels * other for t in self.structure})
        # Returning NotImplemented instead of an implicit None surfaces misuse immediately
        return NotImplemented

    def __add__(self, other):
        """
        If other is an int, add other to all the multiplicities.
        If other is a fiber, add the multiplicities of the fibers together.
        Only the degrees of this fiber are kept: degrees present solely in `other` are dropped.
        Any other operand type is rejected (TypeError via NotImplemented).
        """
        if isinstance(other, Fiber):
            return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
        if isinstance(other, int):
            return Fiber({t.degree: t.channels + other for t in self.structure})
        return NotImplemented

    def __repr__(self):
        return str(self.structure)

    @staticmethod
    def combine_max(f1, f2):
        """ Combine two fiber by taking the maximum multiplicity for each degree in both fibers """
        new_dict = dict(f1.structure)
        for k, m in f2.structure:
            new_dict[k] = max(new_dict.get(k, 0), m)

        return Fiber(list(new_dict.items()))

    @staticmethod
    def combine_selectively(f1, f2):
        """ Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """
        # only use orders which occur in fiber f1
        new_dict = dict(f1.structure)
        f2_degrees = f2.degrees  # hoisted: the property re-sorts on every access
        for k in f1.degrees:
            if k in f2_degrees:
                new_dict[k] += f2[k]
        return Fiber(list(new_dict.items()))

    def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
        """
        Split the channels of every degree into heads and concatenate all degrees on the last dim.

        :param tensors: Feature dict keyed by str(degree); one entry per degree of this fiber is read
        :param num_heads: Number of attention heads
        :return: Single tensor, dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
        """
        fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
                  self.degrees]
        fibers = torch.cat(fibers, -1)
        return fibers
import NormSE3 3 | from .pooling import GPooling 4 | from .convolution import ConvSE3 5 | from .attention import AttentionBlockSE3 -------------------------------------------------------------------------------- /SE3Transformer/se3_transformer/model/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 
class AttentionSE3(nn.Module):
    """ Multi-headed sparse graph self-attention (SE(3)-equivariant) """

    def __init__(
            self,
            num_heads: int,
            key_fiber: Fiber,
            value_fiber: Fiber
    ):
        """
        :param num_heads: Number of attention heads
        :param key_fiber: Fiber for the keys (and also for the queries)
        :param value_fiber: Fiber for the values
        """
        super().__init__()
        self.num_heads = num_heads
        self.key_fiber = key_fiber
        self.value_fiber = value_fiber

    def forward(
            self,
            value: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
            key: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
            query: Dict[str, Tensor],  # node features
            graph: DGLGraph
    ):
        """
        Attend over the edges of ``graph`` and sum the weighted values.

        :param value: Edge values, either one fused tensor or a dict keyed by str(degree)
        :param key: Edge keys, either one fused tensor or a dict keyed by str(degree)
        :param query: Node queries keyed by str(degree)
        :param graph: Graph providing the edge structure for the dot product, softmax and sum
        :return: Dict of attended features keyed by str(degree)
        """
        with nvtx_range('AttentionSE3'):
            with nvtx_range('reshape keys and queries'):
                if isinstance(key, Tensor):
                    # case where features of all types are fused
                    key = key.reshape(key.shape[0], self.num_heads, -1)
                    # need to reshape queries that way to keep the same layout as keys
                    out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
                    query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
                else:
                    # features are not fused, need to fuse and reshape them
                    key = self.key_fiber.to_attention_heads(key, self.num_heads)
                    query = self.key_fiber.to_attention_heads(query, self.num_heads)

            with nvtx_range('attention dot product + softmax'):
                # Compute attention weights (softmax of inner product between key and query)
                edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
                # Scale by the square root of the total key feature size before the softmax
                edge_weights /= np.sqrt(self.key_fiber.num_features)
                edge_weights = edge_softmax(graph, edge_weights)
                # Two trailing singleton dims so the weights broadcast against the values below
                edge_weights = edge_weights[..., None, None]

            with nvtx_range('weighted sum'):
                if isinstance(value, Tensor):
                    # features of all types are fused
                    v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
                    weights = edge_weights * v
                    feat_out = dgl.ops.copy_e_sum(graph, weights)
                    feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1])  # merge heads
                    out = unfuse_features(feat_out, self.value_fiber.degrees)
                else:
                    out = {}
                    for degree, channels in self.value_fiber:
                        # Integer division: channels are split evenly across the heads
                        v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
                                                    degree_to_dim(degree))
                        weights = edge_weights * v
                        res = dgl.ops.copy_e_sum(graph, weights)
                        out[str(degree)] = res.view(-1, channels, degree_to_dim(degree))  # merge heads

                return out


class AttentionBlockSE3(nn.Module):
    """ Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """

    def __init__(
            self,
            fiber_in: Fiber,
            fiber_out: Fiber,
            fiber_edge: Optional[Fiber] = None,
            num_heads: int = 4,
            channels_div: Optional[Dict[str,int]] = None,
            use_layer_norm: bool = False,
            max_degree: int = 4,
            fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
            **kwargs
    ):
        """
        :param fiber_in: Fiber describing the input features
        :param fiber_out: Fiber describing the output features
        :param fiber_edge: Fiber describing the edge features (node distances excluded);
                           an empty fiber is used when None
        :param num_heads: Number of attention heads
        :param channels_div: Per-degree divisors keyed by str(degree): the value channels of each degree
                             are fiber_out's channels integer-divided by the matching entry.
                             When None, the channels are used undivided.
        :param use_layer_norm: Apply layer normalization between MLP layers
        :param max_degree: Maximum degree used in the bases computation
        :param fuse_level: Maximum fuse level to use in TFN convolutions
        """
        super().__init__()
        if fiber_edge is None:
            fiber_edge = Fiber({})
        self.fiber_in = fiber_in
        # value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
        if channels_div is not None:
            value_fiber = Fiber([(degree, channels // channels_div[str(degree)]) for degree, channels in fiber_out])
        else:
            value_fiber = Fiber([(degree, channels) for degree, channels in fiber_out])

        # key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
        # (queries are merely projected, hence degrees have to match input)
        key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])

        # A single convolution produces values and keys together; they are split apart in forward
        self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
                                    use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
                                    allow_fused_output=True)
        self.to_query = LinearSE3(fiber_in, key_query_fiber)
        self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
        self.project = LinearSE3(value_fiber + fiber_in, fiber_out)

    def forward(
            self,
            node_features: Dict[str, Tensor],
            edge_features: Dict[str, Tensor],
            graph: DGLGraph,
            basis: Dict[str, Tensor]
    ):
        """
        Run the attention block: keys/values from a convolution, queries from a linear projection,
        attention, concatenation with the input features, then the output projection.

        :param node_features: Node features keyed by str(degree)
        :param edge_features: Edge features keyed by str(degree)
        :param graph: Graph the attention runs on
        :param basis: Precomputed bases forwarded to the key/value convolution
        :return: Projected output features keyed by str(degree)
        """
        with nvtx_range('AttentionBlockSE3'):
            with nvtx_range('keys / values'):
                fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
                key, value = self._get_key_value_from_fused(fused_key_value)

            with nvtx_range('queries'):
                # Autocast is switched off for the query projection only
                with torch.cuda.amp.autocast(False):
                    query = self.to_query(node_features)

            z = self.attention(value, key, query, graph)
            z_concat = aggregate_residual(node_features, z, 'cat')
            output = self.project(z_concat)
            return output

    def _get_key_value_from_fused(self, fused_key_value):
        """
        Split the output of the key/value convolution into keys and values.

        :param fused_key_value: Either one fused tensor or a dict keyed by str(degree)
        :return: Tuple (key, value), each of the same kind (tensor or dict) as the input
        """
        # Extract keys and values features from fused features
        if isinstance(fused_key_value, Tensor):
            # Previous layer was a fully fused convolution
            # Two equal halves along dim -2: the first is the value, the second the key
            value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
        else:
            key, value = {}, {}
            for degree, feat in fused_key_value.items():
                if int(degree) in self.fiber_in.degrees:
                    value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
                else:
                    # Degrees absent from the input have no key: the whole feature is the value
                    value[degree] = feat

        return key, value
class LinearSE3(nn.Module):
    """
    Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
    Maps a fiber to a fiber with the same degrees (channels may be different).
    No interaction between degrees, but interaction between channels.

    type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels)
    type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels)
                                                 :
    type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels)
    """

    def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
        """
        :param fiber_in: Fiber describing the input features; gives the input channel count per degree
        :param fiber_out: Fiber describing the output features; one weight matrix is created per degree in it
        """
        super().__init__()
        # One (channels_out, channels_in) matrix per output degree, keyed by str(degree);
        # random normal values scaled by 1 / sqrt(channels_in)
        self.weights = nn.ParameterDict({
            str(degree_out): nn.Parameter(
                torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
            for degree_out, channels_out in fiber_out
        })

    def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
        """
        Apply the per-degree weight matrix to the features of the same degree.

        :param features: Features keyed by str(degree); must hold every degree this layer has a weight for
        :return: Dict keyed by the same degree strings with ``weight @ features[degree]``.
                 Extra positional and keyword arguments are accepted and ignored.
        """
        # Use the weight yielded by items() directly rather than indexing the ParameterDict again
        return {
            degree: weight @ features[degree]
            for degree, weight in self.weights.items()
        }
class NormSE3(nn.Module):
    """
    Norm-based SE(3)-equivariant nonlinearity.

                 ┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐
    feature_in ──┤                                              * ──> feature_out
                 └──> feature_phase ────────────────────────────┘
    """

    NORM_CLAMP = 2 ** -24  # Minimum positive subnormal for FP16

    def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
        """
        :param fiber: Fiber describing the features to normalize
        :param nonlinearity: Module applied to the normalized norms
        """
        super().__init__()
        self.fiber = fiber
        self.nonlinearity = nonlinearity

        distinct_widths = set(fiber.channels)
        if len(distinct_widths) == 1:
            # Every degree has the same channel count: one group normalization
            # (one group per degree) replaces the per-degree layer normalizations
            self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
        else:
            # Channel counts differ: keep one layer normalization per degree
            per_degree = nn.ModuleDict()
            for degree, channels in fiber:
                per_degree[str(degree)] = nn.LayerNorm(channels)
            self.layer_norms = per_degree

    def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
        """
        Rescale each feature so its norm becomes nonlinearity(norm_layer(norm)).

        :param features: Features keyed by str(degree)
        :return: Rescaled features keyed by str(degree)
        """
        with nvtx_range('NormSE3'):
            if hasattr(self, 'group_norm'):
                degrees = self.fiber.degrees

                # Per-degree norms over the last dim, clamped away from zero
                norms = []
                for d in degrees:
                    norms.append(features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP))
                stacked = torch.cat(norms, dim=-2)

                # Only the norms go through normalization and nonlinearity
                transformed = self.group_norm(stacked.squeeze(-1))
                transformed = self.nonlinearity(transformed).unsqueeze(-1)
                per_degree_norms = torch.chunk(transformed, chunks=len(degrees), dim=-2)

                # Divide out the old norm, multiply by the new one
                return {
                    str(d): features[str(d)] / old_norm * fresh_norm
                    for old_norm, fresh_norm, d in zip(norms, per_degree_norms, degrees)
                }

            output = {}
            for degree, feat in features.items():
                old_norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
                normalized = self.layer_norms[degree](old_norm.squeeze(-1)).unsqueeze(-1)
                fresh_norm = self.nonlinearity(normalized)
                output[degree] = fresh_norm * feat / old_norm
            return output


class GPooling(nn.Module):
    """
    Graph max/average pooling on a given feature type.
    The average can be taken for any feature type, and equivariance will be maintained.
    The maximum can only be taken for invariant features (type 0).
    If you want max-pooling for type > 0 features, look into Vector Neurons.
    """

    def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
        """
        :param feat_type: Feature type to pool
        :param pool: Type of pooling: max or avg
        """
        super().__init__()
        assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
        assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
        self.feat_type = feat_type
        if pool == 'max':
            self.pool = MaxPooling()
        else:
            self.pool = AvgPooling()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
        """
        Pool the features of the configured type over the graph.

        :param features: Features keyed by str(degree)
        :param graph: Graph handed to the pooling module
        :return: Pooled tensor with the last dim squeezed
        """
        selected = features[str(self.feat_type)]
        return self.pool(graph, selected).squeeze(dim=-1)
# IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import logging
from typing import Optional, Literal, Dict

import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor

from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.layers.attention import AttentionBlockSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool
from se3_transformer.model.fiber import Fiber


class Sequential(nn.Sequential):
    """ Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """

    def forward(self, input, *args, **kwargs):
        """ Feed ``input`` through every sub-module in order, forwarding the extra args/kwargs to each one. """
        for module in self:
            input = module(input, *args, **kwargs)
        return input


def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
    """
    Add relative positions to existing edge features

    :param relative_pos: Relative position vectors; their norm over the last dimension is the appended feature
    :param edge_features: Optional existing edge features keyed by degree; the dict is copied, never mutated
    :return: New dict whose '0' entry holds the existing type-0 features (if any) followed by the distance
    """
    edge_features = edge_features.copy() if edge_features else {}
    r = relative_pos.norm(dim=-1, keepdim=True)
    if '0' in edge_features:
        edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
    else:
        edge_features['0'] = r[..., None]

    return edge_features


class SE3Transformer(nn.Module):
    """ Stack of SE(3)-equivariant attention blocks followed by a configurable final layer and optional pooling. """

    def __init__(self,
                 num_layers: int,
                 fiber_in: Fiber,
                 fiber_hidden: Fiber,
                 fiber_out: Fiber,
                 num_heads: int,
                 channels_div: int,
                 fiber_edge: Fiber = Fiber({}),
                 return_type: Optional[int] = None,
                 pooling: Optional[Literal['avg', 'max']] = None,
                 final_layer: Optional[Literal['conv', 'lin', 'att']] = 'conv',
                 norm: bool = True,
                 use_layer_norm: bool = True,
                 tensor_cores: bool = False,
                 low_memory: bool = False,
                 populate_edge: Optional[Literal['lin', 'arcsin', 'log', 'zero']] = 'lin',
                 sum_over_edge: bool = True,
                 **kwargs):
        """
        :param num_layers:          Number of attention layers
        :param fiber_in:            Input fiber description
        :param fiber_hidden:        Hidden fiber description
        :param fiber_out:           Output fiber description
        :param fiber_edge:          Input edge fiber description
        :param num_heads:           Number of attention heads
        :param channels_div:        Channels division before feeding to attention layer
        :param return_type:         Return only features of this type
        :param pooling:             'avg' or 'max' graph pooling before MLP layers
        :param final_layer:         Last graph module: 'conv' appends a ConvSE3, 'lin' a LinearSE3,
                                    any other value a single-head AttentionBlockSE3
        :param norm:                Apply a normalization layer after each attention block
        :param use_layer_norm:      Apply layer normalization between MLP layers
        :param tensor_cores:        True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
        :param low_memory:          If True, will use slower ops that use less memory
        :param populate_edge:       How the edge distance is appended to the type-0 edge features in forward():
                                    'lin' uses the raw distance, 'arcsin' uses arcsinh(max(d, 4) - 4) / 3,
                                    'log' uses log(1 + d), any other value appends a zero channel
        :param sum_over_edge:       Forwarded to the final ConvSE3 (only used when final_layer == 'conv')
        """
        super().__init__()
        self.num_layers = num_layers
        self.fiber_edge = fiber_edge
        self.num_heads = num_heads
        self.channels_div = channels_div
        self.return_type = return_type
        self.pooling = pooling
        self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
        self.tensor_cores = tensor_cores
        self.low_memory = low_memory
        self.populate_edge = populate_edge

        if low_memory and not tensor_cores:
            logging.warning('Low memory mode will have no effect with no Tensor Cores')

        # Fully fused convolutions when using Tensor Cores (and not low memory mode)
        fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL

        # Per-degree channel divisors: the hidden blocks divide every degree by channels_div,
        # the final attention block only divides the degree-0 channels.
        div = {str(degree): channels_div for degree in range(self.max_degree + 1)}
        div_fin = {str(degree): 1 for degree in range(self.max_degree + 1)}
        div_fin['0'] = channels_div

        graph_modules = []
        for _ in range(num_layers):
            graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
                                                   fiber_out=fiber_hidden,
                                                   fiber_edge=fiber_edge,
                                                   num_heads=num_heads,
                                                   channels_div=div,
                                                   use_layer_norm=use_layer_norm,
                                                   max_degree=self.max_degree,
                                                   fuse_level=fuse_level))
            if norm:
                graph_modules.append(NormSE3(fiber_hidden))
            # every block after the first consumes the hidden fiber
            fiber_in = fiber_hidden

        if final_layer == 'conv':
            graph_modules.append(ConvSE3(fiber_in=fiber_in,
                                         fiber_out=fiber_out,
                                         fiber_edge=fiber_edge,
                                         self_interaction=True,
                                         sum_over_edge=sum_over_edge,
                                         use_layer_norm=use_layer_norm,
                                         max_degree=self.max_degree))
        elif final_layer == "lin":
            graph_modules.append(LinearSE3(fiber_in=fiber_in,
                                           fiber_out=fiber_out))
        else:
            graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
                                                   fiber_out=fiber_out,
                                                   fiber_edge=fiber_edge,
                                                   num_heads=1,
                                                   channels_div=div_fin,
                                                   use_layer_norm=use_layer_norm,
                                                   max_degree=self.max_degree,
                                                   fuse_level=fuse_level))
        self.graph_modules = Sequential(*graph_modules)

        if pooling is not None:
            assert return_type is not None, 'return_type must be specified when pooling'
            self.pooling_module = GPooling(pool=pooling, feat_type=return_type)

    def _populate_edge_features(self, relative_pos: Tensor,
                                edge_feats: Optional[Dict[str, Tensor]] = None) -> Dict[str, Tensor]:
        """
        Return a copy of ``edge_feats`` with one extra type-0 channel chosen by ``self.populate_edge``.

        The caller's dict is never modified, and a missing dict / missing '0' entry is handled the same
        way get_populated_edge_features() handles it (the new channel becomes the whole '0' entry).
        """
        if self.populate_edge == 'lin':
            return get_populated_edge_features(relative_pos, edge_feats)

        # Shallow copy: the tensors are shared, only the mapping is new.
        populated = dict(edge_feats) if edge_feats else {}
        existing = populated.get('0')

        if self.populate_edge == 'arcsin':
            r = relative_pos.norm(dim=-1, keepdim=True)
            # distances below 4.0 map to 0, larger ones are compressed with arcsinh
            r = torch.maximum(r, torch.zeros_like(r) + 4.0) - 4.0
            extra = (torch.arcsinh(r) / 3.0)[..., None]
        elif self.populate_edge == 'log':
            # fd - replace with log(1+x)
            extra = torch.log(1 + relative_pos.norm(dim=-1, keepdim=True))[..., None]
        elif existing is not None:
            extra = torch.zeros_like(existing[:, :1, :])
        else:
            # no existing features to take the shape from: mirror the shape the other modes produce
            extra = torch.zeros_like(relative_pos.norm(dim=-1, keepdim=True))[..., None]

        populated['0'] = extra if existing is None else torch.cat([existing, extra], dim=1)
        return populated

    def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
                edge_feats: Optional[Dict[str, Tensor]] = None,
                basis: Optional[Dict[str, Tensor]] = None):
        """
        :param graph:      DGL graph; ``graph.edata['rel_pos']`` must hold the relative positions of the edges
        :param node_feats: Node features keyed by degree (as a string)
        :param edge_feats: Optional edge features keyed by degree; not modified by this call
        :param basis:      Optional precomputed bases; computed from ``rel_pos`` when missing or empty
        :return: Pooled features when pooling is enabled, else the features of ``return_type`` when it is set,
                 else the full dict of output node features
        """
        # Compute bases in case they weren't precomputed as part of the data loading
        basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
                                   use_pad_trick=self.tensor_cores and not self.low_memory,
                                   amp=torch.is_autocast_enabled())

        # Add fused bases (per output degree, per input degree, and fully fused) to the dict
        basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
                                        fully_fused=self.tensor_cores and not self.low_memory)

        edge_feats = self._populate_edge_features(graph.edata['rel_pos'], edge_feats)

        node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)

        if self.pooling is not None:
            return self.pooling_module(node_feats, graph=graph)

        if self.return_type is not None:
            return node_feats[str(self.return_type)]

        return node_feats

    @staticmethod
    def add_argparse_args(parser):
        """ Register the SE3Transformer command-line options on ``parser`` and return it. """
        parser.add_argument('--num_layers', type=int, default=7,
                            help='Number of stacked Transformer layers')
        parser.add_argument('--num_heads', type=int, default=8,
                            help='Number of heads in self-attention')
        parser.add_argument('--channels_div', type=int, default=2,
                            help='Channels division before feeding to attention layer')
        parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
                            help='Type of graph pooling')
        parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply a normalization layer after each attention block')
        parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply layer normalization between MLP layers')
        parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
                            help='If true, will use fused ops that are slower but that use less memory '
                                 '(expect 25 percent less memory). '
                                 'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')

        return parser


class SE3TransformerPooled(nn.Module):
    """ SE3Transformer returning pooled type-0 features, followed by a two-layer MLP head. """

    def __init__(self,
                 fiber_in: Fiber,
                 fiber_out: Fiber,
                 fiber_edge: Fiber,
                 num_degrees: int,
                 num_channels: int,
                 output_dim: int,
                 **kwargs):
        """
        :param fiber_in:     Input fiber description
        :param fiber_out:    Output fiber description
        :param fiber_edge:   Input edge fiber description
        :param num_degrees:  Number of degrees of the hidden fiber
        :param num_channels: Number of channels of the hidden fiber
        :param output_dim:   Output size of the MLP head
        :param kwargs:       Forwarded to SE3Transformer; 'pooling' defaults to 'max' when missing or None
        """
        super().__init__()
        # .get() so that callers who do not pass 'pooling' at all get the default instead of a KeyError
        kwargs['pooling'] = kwargs.get('pooling') or 'max'
        self.transformer = SE3Transformer(
            fiber_in=fiber_in,
            fiber_hidden=Fiber.create(num_degrees, num_channels),
            fiber_out=fiber_out,
            fiber_edge=fiber_edge,
            return_type=0,
            **kwargs
        )

        n_out_features = fiber_out.num_features
        self.mlp = nn.Sequential(
            nn.Linear(n_out_features, n_out_features),
            nn.ReLU(),
            nn.Linear(n_out_features, output_dim)
        )

    def forward(self, graph, node_feats, edge_feats, basis=None):
        """ Run the transformer, drop the trailing singleton dimension and apply the MLP head. """
        feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
        y = self.mlp(feats).squeeze(-1)
        return y

    @staticmethod
    def add_argparse_args(parent_parser):
        """ Add a "Model architecture" argument group to ``parent_parser`` and return the parent parser. """
        parser = parent_parser.add_argument_group("Model architecture")
        SE3Transformer.add_argparse_args(parser)
        parser.add_argument('--num_degrees',
                            help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
                            type=int, default=4)
        parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
        return parent_parser

# --------------------------------------------------------------------------------
# /SE3Transformer/se3_transformer/runtime/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/uw-ipd/RoseTTAFold2NA/f761af286729ea08a6ddab149023c1b73458fbe2/SE3Transformer/se3_transformer/runtime/__init__.py
# --------------------------------------------------------------------------------
# /SE3Transformer/se3_transformer/runtime/arguments.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import argparse
import pathlib

from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.runtime.utils import str2bool


def _register(target, specs):
    """Add every ``(flags, options)`` pair of *specs* to *target*, preserving the given order."""
    for flags, options in specs:
        target.add_argument(*flags, **options)


# Options shared by every "--flag [true|false]" style boolean switch (off unless given).
_BOOL_SWITCH = dict(type=str2bool, nargs='?', const=True, default=False)

PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')

# --- filesystem locations -------------------------------------------------
paths = PARSER.add_argument_group('Paths')
_register(paths, [
    (('--data_dir',), dict(type=pathlib.Path, default=pathlib.Path('./data'),
                           help='Directory where the data is located or should be downloaded')),
    (('--log_dir',), dict(type=pathlib.Path, default=pathlib.Path('/results'),
                          help='Directory where the results logs should be saved')),
    (('--dllogger_name',), dict(type=str, default='dllogger_results.json',
                                help='Name for the resulting DLLogger JSON file')),
    (('--save_ckpt_path',), dict(type=pathlib.Path, default=None,
                                 help='File where the checkpoint should be saved')),
    (('--load_ckpt_path',), dict(type=pathlib.Path, default=None,
                                 help='File of the checkpoint to be loaded')),
])

# --- optimizer hyper-parameters -------------------------------------------
optimizer = PARSER.add_argument_group('Optimizer')
_register(optimizer, [
    (('--optimizer',), dict(choices=['adam', 'sgd', 'lamb'], default='adam')),
    (('--learning_rate', '--lr'), dict(dest='learning_rate', type=float, default=0.002)),
    (('--min_learning_rate', '--min_lr'), dict(dest='min_learning_rate', type=float, default=None)),
    (('--momentum',), dict(type=float, default=0.9)),
    (('--weight_decay',), dict(type=float, default=0.1)),
])

# --- general training / runtime options (registered directly on the parser) ---
_register(PARSER, [
    (('--epochs',), dict(type=int, default=100, help='Number of training epochs')),
    (('--batch_size',), dict(type=int, default=240, help='Batch size')),
    (('--seed',), dict(type=int, default=None, help='Set a seed globally')),
    (('--num_workers',), dict(type=int, default=8, help='Number of dataloading workers')),

    (('--amp',), dict(_BOOL_SWITCH, help='Use Automatic Mixed Precision')),
    (('--gradient_clip',), dict(type=float, default=None, help='Clipping of the gradient norms')),
    (('--accumulate_grad_batches',), dict(type=int, default=1, help='Gradient accumulation')),
    (('--ckpt_interval',), dict(type=int, default=-1, help='Save a checkpoint every N epochs')),
    (('--eval_interval',), dict(dest='eval_interval', type=int, default=1,
                                help='Do an evaluation round every N epochs')),
    (('--silent',), dict(_BOOL_SWITCH, help='Minimize stdout output')),

    (('--benchmark',), dict(_BOOL_SWITCH, help='Benchmark mode')),
])

# Dataset- and model-specific options are contributed by the classes that consume them.
QM9DataModule.add_argparse_args(PARSER)
SE3TransformerPooled.add_argparse_args(PARSER)

# --------------------------------------------------------------------------------
# /SE3Transformer/se3_transformer/runtime/callbacks.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import logging
import time
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np
import torch

from se3_transformer.runtime.loggers import Logger
from se3_transformer.runtime.metrics import MeanAbsoluteError


class BaseCallback(ABC):
    """ Base class for callbacks: every hook is a no-op, subclasses override the ones they need. """

    def on_fit_start(self, optimizer, args):
        pass

    def on_fit_end(self):
        pass

    def on_epoch_end(self):
        pass

    def on_batch_start(self):
        pass

    def on_validation_step(self, input, target, pred):
        pass

    def on_validation_end(self, epoch=None):
        pass

    def on_checkpoint_load(self, checkpoint):
        pass

    def on_checkpoint_save(self, checkpoint):
        pass


class LRSchedulerCallback(BaseCallback):
    """ Owns a learning-rate scheduler: creates it at fit start, steps it every epoch, (de)serializes it. """

    def __init__(self, logger: Optional[Logger] = None):
        self.logger = logger
        self.scheduler = None  # created in on_fit_start()

    @abstractmethod
    def get_scheduler(self, optimizer, args):
        """ Build and return the scheduler for ``optimizer``. """
        pass

    def on_fit_start(self, optimizer, args):
        self.scheduler = self.get_scheduler(optimizer, args)

    def on_checkpoint_load(self, checkpoint):
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    def on_checkpoint_save(self, checkpoint):
        checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

    def on_epoch_end(self):
        # log the rate that was used during the epoch, then advance the schedule
        if self.logger is not None:
            self.logger.log_metrics({'learning rate': self.scheduler.get_last_lr()[0]}, step=self.scheduler.last_epoch)
        self.scheduler.step()


class QM9MetricCallback(BaseCallback):
    """ Logs the rescaled mean absolute error for QM9 regression tasks """

    def __init__(self, logger, targets_std, prefix=''):
        self.mae = MeanAbsoluteError()
        self.logger = logger
        self.targets_std = targets_std
        self.prefix = prefix
        self.best_mae = float('inf')

    def on_validation_step(self, input, target, pred):
        self.mae(pred.detach(), target.detach())

    def on_validation_end(self, epoch=None):
        # the metric is multiplied by targets_std before being reported
        mae = self.mae.compute() * self.targets_std
        logging.info(f'{self.prefix} MAE: {mae}')
        self.logger.log_metrics({f'{self.prefix} MAE': mae}, epoch)
        self.best_mae = min(self.best_mae, mae)

    def on_fit_end(self):
        # nothing to report if validation never ran
        if self.best_mae != float('inf'):
            self.logger.log_metrics({f'{self.prefix} best MAE': self.best_mae})


class QM9LRSchedulerCallback(LRSchedulerCallback):
    """ Cosine annealing with warm restarts; the restart period is ``epochs``. """

    def __init__(self, logger, epochs):
        super().__init__(logger)
        self.epochs = epochs

    def get_scheduler(self, optimizer, args):
        # `is not None` rather than truthiness: an explicit minimum of 0.0 must be honored,
        # only an unset value falls back to a tenth of the base learning rate.
        if args.min_learning_rate is not None:
            min_lr = args.min_learning_rate
        else:
            min_lr = args.learning_rate / 10.0
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, self.epochs, eta_min=min_lr)


class PerformanceCallback(BaseCallback):
    """ Records a timestamp at every batch start (after warm-up) and reports throughput/latency at fit end. """

    def __init__(self, logger, batch_size: int, warmup_epochs: int = 1, mode: str = 'train'):
        self.batch_size = batch_size
        self.warmup_epochs = warmup_epochs
        self.epoch = 0
        self.timestamps = []  # milliseconds, one entry per batch start
        self.mode = mode
        self.logger = logger

    def on_batch_start(self):
        if self.epoch >= self.warmup_epochs:
            self.timestamps.append(time.time() * 1000.0)

    def _log_perf(self):
        stats = self.process_performance_stats()
        if not stats:
            logging.warning('performance: fewer than two batches were timed, nothing to report')
            return

        for k, v in stats.items():
            logging.info(f'performance {k}: {v}')

        self.logger.log_metrics(stats)

    def on_epoch_end(self):
        self.epoch += 1

    def on_fit_end(self):
        if self.epoch > self.warmup_epochs:
            self._log_perf()
        self.timestamps = []

    def process_performance_stats(self):
        """
        Compute throughput and latency statistics from the recorded batch-start timestamps.

        :return: Dict of statistics keyed by name (suffixed with the mode), or an empty dict when fewer
                 than two timestamps were recorded, since no interval between batches can be measured then
        """
        timestamps = np.asarray(self.timestamps)
        if timestamps.size < 2:
            return {}

        deltas = np.diff(timestamps)
        throughput = (self.batch_size / deltas).mean()
        stats = {
            f"throughput_{self.mode}": throughput,
            f"latency_{self.mode}_mean": deltas.mean(),
            f"total_time_{self.mode}": timestamps[-1] - timestamps[0],
        }
        for level in [90, 95, 99]:
            stats.update({f"latency_{self.mode}_{level}": np.percentile(deltas, level)})

        return stats

# --------------------------------------------------------------------------------
# /SE3Transformer/se3_transformer/runtime/gpu_affinity.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import collections
import itertools
import math
import os
import pathlib
import re

import pynvml


class Device:
    """ Thin wrapper around an NVML device handle. """

    # assumes nvml returns list of 64 bit ints
    # os.cpu_count() may return None when the count cannot be determined; fall back to one element
    _nvml_affinity_elements = math.ceil((os.cpu_count() or 1) / 64)

    def __init__(self, device_idx):
        super().__init__()
        self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)

    def get_name(self):
        return pynvml.nvmlDeviceGetName(self.handle)

    def get_uuid(self):
        return pynvml.nvmlDeviceGetUUID(self.handle)

    def get_cpu_affinity(self):
        """ Return the indices of the CPU cores set in the affinity mask NVML reports for this device. """
        affinity_string = ""
        for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
            # assume nvml returns list of 64 bit ints
            affinity_string = "{:064b}".format(j) + affinity_string

        affinity_list = [int(x) for x in affinity_string]
        affinity_list.reverse()  # so core 0 is in 0th element of list

        ret = [i for i, e in enumerate(affinity_list) if e != 0]
        return ret


def get_thread_siblings_list():
    """
    Returns a list of 2-element integer tuples representing pairs of
    hyperthreading cores.
    """
    path = "/sys/devices/system/cpu/cpu*/topology/thread_siblings_list"
    thread_siblings_list = []
    pattern = re.compile(r"(\d+)\D(\d+)")
    # path[0] is the root "/", the remainder is used as a glob pattern relative to it
    for fname in pathlib.Path(path[0]).glob(path[1:]):
        with open(fname) as f:
            content = f.read().strip()
            res = pattern.findall(content)
            if res:
                pair = tuple(map(int, res[0]))
                thread_siblings_list.append(pair)
    return thread_siblings_list


def check_socket_affinities(socket_affinities):
    """ Raise RuntimeError unless every pair of core lists is either identical or disjoint. """
    # sets of cores should be either identical or disjoint
    for i, j in itertools.product(socket_affinities, socket_affinities):
        if not set(i) == set(j) and not set(i).isdisjoint(set(j)):
            raise RuntimeError(f"Sets of cores should be either identical or disjoint, " f"but got {i} and {j}.")


def get_socket_affinities(nproc_per_node, exclude_unavailable_cores=True):
    """
    Return, for each of the first ``nproc_per_node`` GPUs, the list of CPU cores NVML associates with it.

    Args:
        nproc_per_node: number of GPUs to query (device indices 0 .. nproc_per_node - 1)
        exclude_unavailable_cores: drop cores the current process is not allowed to run on
    """
    devices = [Device(i) for i in range(nproc_per_node)]
    socket_affinities = [dev.get_cpu_affinity() for dev in devices]

    if exclude_unavailable_cores:
        available_cores = os.sched_getaffinity(0)
        # sorted(): set iteration order is unspecified, callers rely on ascending core order
        socket_affinities = [sorted(set(affinity) & available_cores) for affinity in socket_affinities]

    check_socket_affinities(socket_affinities)

    return socket_affinities


def set_socket_affinity(gpu_id):
    """
    The process is assigned with all available logical CPU cores from the CPU
    socket connected to the GPU with a given id.

    Args:
        gpu_id: index of a GPU
    """
    dev = Device(gpu_id)
    affinity = dev.get_cpu_affinity()
    os.sched_setaffinity(0, affinity)


def set_single_affinity(gpu_id):
    """
    The process is assigned with the first available logical CPU core from the
    list of all CPU cores from the CPU socket connected to the GPU with a given
    id.

    Args:
        gpu_id: index of a GPU
    """
    dev = Device(gpu_id)
    affinity = dev.get_cpu_affinity()

    # exclude unavailable cores
    available_cores = os.sched_getaffinity(0)
    # sorted() so that "first" really is the lowest-numbered available core
    affinity = sorted(set(affinity) & available_cores)
    os.sched_setaffinity(0, affinity[:1])


def set_single_unique_affinity(gpu_id, nproc_per_node):
    """
    The process is assigned with a single unique available physical CPU core
    from the list of all CPU cores from the CPU socket connected to the GPU with
    a given id.

    Args:
        gpu_id: index of a GPU
        nproc_per_node: total number of processes per node
    """
    socket_affinities = get_socket_affinities(nproc_per_node)

    siblings_list = get_thread_siblings_list()
    siblings_dict = dict(siblings_list)
    sibling_cores = set(siblings_dict.values())

    # remove siblings (sorted to keep a deterministic, ascending core order)
    for idx, socket_affinity in enumerate(socket_affinities):
        socket_affinities[idx] = sorted(set(socket_affinity) - sibling_cores)

    affinities = []
    assigned = []

    # give each GPU the first core of its socket that no earlier GPU has taken
    for socket_affinity in socket_affinities:
        for core in socket_affinity:
            if core not in assigned:
                affinities.append([core])
                assigned.append(core)
                break
    os.sched_setaffinity(0, affinities[gpu_id])


def set_socket_unique_affinity(gpu_id, nproc_per_node, mode, balanced=True):
    """
    The process is assigned with an unique subset of available physical CPU
    cores from the CPU socket connected to a GPU with a given id.
    Assignment automatically includes hyperthreading siblings (if siblings are
    available).

    Args:
        gpu_id: index of a GPU
        nproc_per_node: total number of processes per node
        mode: mode
        balanced: assign an equal number of physical cores to each process
    """
    socket_affinities = get_socket_affinities(nproc_per_node)

    siblings_list = get_thread_siblings_list()
    siblings_dict = dict(siblings_list)
    sibling_cores = set(siblings_dict.values())

    # remove hyperthreading siblings; sorted so that GPUs sharing a socket produce the
    # same tuple key below and so that the slices below walk the cores in ascending order
    for idx, socket_affinity in enumerate(socket_affinities):
        socket_affinities[idx] = sorted(set(socket_affinity) - sibling_cores)

    socket_affinities_to_device_ids = collections.defaultdict(list)

    for idx, socket_affinity in enumerate(socket_affinities):
        socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)

    # compute minimal number of physical cores per GPU across all GPUs and
    # sockets, code assigns this number of cores per GPU if balanced == True
    min_physical_cores_per_gpu = min(
        len(cores) // len(gpus) for cores, gpus in socket_affinities_to_device_ids.items()
    )

    for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
        devices_per_group = len(device_ids)
        if balanced:
            cores_per_device = min_physical_cores_per_gpu
            socket_affinity = socket_affinity[: devices_per_group * min_physical_cores_per_gpu]
        else:
            cores_per_device = len(socket_affinity) // devices_per_group

        for group_id, device_id in enumerate(device_ids):
            if device_id == gpu_id:

                # In theory there should be no difference in performance between
                # 'interleaved' and 'continuous' pattern on Intel-based DGX-1,
                # but 'continuous' should be better for DGX A100 because on AMD
                # Rome 4 consecutive cores are sharing L3 cache.
                # TODO: code doesn't attempt to automatically detect layout of
                # L3 cache, also external environment may already exclude some
                # cores, this code makes no attempt to detect it and to align
                # mapping to multiples of 4.

                if mode == "interleaved":
                    affinity = list(socket_affinity[group_id::devices_per_group])
                elif mode == "continuous":
                    affinity = list(socket_affinity[group_id * cores_per_device: (group_id + 1) * cores_per_device])
                else:
                    raise RuntimeError("Unknown set_socket_unique_affinity mode")

                # unconditionally reintroduce hyperthreading siblings, this step
                # may result in a different numbers of logical cores assigned to
                # each GPU even if balanced == True (if hyperthreading siblings
                # aren't available for a subset of cores due to some external
                # constraints, siblings are re-added unconditionally, in the
                # worst case unavailable logical core will be ignored by
                # os.sched_setaffinity().
                affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
                os.sched_setaffinity(0, affinity)


def set_affinity(gpu_id, nproc_per_node, mode="socket_unique_continuous", balanced=True):
    """
    The process is assigned with a proper CPU affinity which matches hardware
    architecture on a given platform. Usually it improves and stabilizes
    performance of deep learning training workloads.

    This function assumes that the workload is running in multi-process
    single-device mode (there are multiple training processes and each process
    is running on a single GPU), which is typical for multi-GPU training
    workloads using `torch.nn.parallel.DistributedDataParallel`.

    Available affinity modes:
    * 'socket' - the process is assigned with all available logical CPU cores
    from the CPU socket connected to the GPU with a given id.
    * 'single' - the process is assigned with the first available logical CPU
    core from the list of all CPU cores from the CPU socket connected to the GPU
    with a given id (multiple GPUs could be assigned with the same CPU core).
    * 'single_unique' - the process is assigned with a single unique available
    physical CPU core from the list of all CPU cores from the CPU socket
    connected to the GPU with a given id.
    * 'socket_unique_interleaved' - the process is assigned with an unique
    subset of available physical CPU cores from the CPU socket connected to a
    GPU with a given id, hyperthreading siblings are included automatically,
    cores are assigned with interleaved indexing pattern
    * 'socket_unique_continuous' - (the default) the process is assigned with an
    unique subset of available physical CPU cores from the CPU socket connected
    to a GPU with a given id, hyperthreading siblings are included
    automatically, cores are assigned with continuous indexing pattern

    'socket_unique_continuous' is the recommended mode for deep learning
    training workloads on NVIDIA DGX machines.

    Args:
        gpu_id: integer index of a GPU
        nproc_per_node: number of processes per node
        mode: affinity mode
        balanced: assign an equal number of physical cores to each process,
            affects only 'socket_unique_interleaved' and
            'socket_unique_continuous' affinity modes

    Returns a set of logical CPU cores on which the process is eligible to run.

    Raises RuntimeError if `mode` is not one of the modes listed above.

    Example:

    import argparse
    import os

    import gpu_affinity
    import torch


    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--local_rank',
            type=int,
            default=os.getenv('LOCAL_RANK', 0),
        )
        args = parser.parse_args()

        nproc_per_node = torch.cuda.device_count()

        affinity = gpu_affinity.set_affinity(args.local_rank, nproc_per_node)
        print(f'{args.local_rank}: core affinity: {affinity}')


    if __name__ == "__main__":
        main()

    Launch the example with:
    python -m torch.distributed.launch --nproc_per_node <#GPUs> example.py


    WARNING: On DGX A100 only a half of CPU cores have direct access to GPUs.
    This function restricts execution only to the CPU cores directly connected
    to GPUs, so on DGX A100 it will limit the code to half of CPU cores and half
    of CPU memory bandwidth (which may be fine for many DL models).
    """
    pynvml.nvmlInit()

    if mode == "socket":
        set_socket_affinity(gpu_id)
    elif mode == "single":
        set_single_affinity(gpu_id)
    elif mode == "single_unique":
        set_single_unique_affinity(gpu_id, nproc_per_node)
    elif mode == "socket_unique_interleaved":
        set_socket_unique_affinity(gpu_id, nproc_per_node, "interleaved", balanced)
    elif mode == "socket_unique_continuous":
        set_socket_unique_affinity(gpu_id, nproc_per_node, "continuous", balanced)
    else:
        raise RuntimeError("Unknown affinity mode")

    # read back the affinity the OS actually applied
    affinity = os.sched_getaffinity(0)
    return affinity

# --------------------------------------------------------------------------------
# /SE3Transformer/se3_transformer/runtime/inference.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

from typing import List

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm

from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import BaseCallback
from se3_transformer.runtime.loggers import DLLogger
from se3_transformer.runtime.utils import to_cuda, get_local_rank


@torch.inference_mode()
def evaluate(model: nn.Module,
             dataloader: DataLoader,
             callbacks: List[BaseCallback],
             args):
    """Run one evaluation pass of ``model`` over ``dataloader``.

    The model is switched to eval mode. For every batch, the last element is
    treated as the target and all preceding elements are passed positionally
    to the model. Each callback receives ``on_batch_start()`` before the
    forward pass and ``on_validation_step(inputs, target, pred)`` after it.

    Args:
        model: network to evaluate.
        dataloader: source of batches.
        callbacks: callbacks notified around each forward pass.
        args: parsed arguments; ``args.silent`` and ``args.amp`` are read here.
    """
    model.eval()
    # The progress bar is only shown on local rank 0 and when not silenced.
    show_progress = not (args.silent or get_local_rank() != 0)
    for batch in tqdm(dataloader, total=len(dataloader), unit='batch', desc='Evaluation',
                      leave=False, disable=not show_progress):
        *inputs, target = to_cuda(batch)

        for callback in callbacks:
            callback.on_batch_start()

        with torch.cuda.amp.autocast(enabled=args.amp):
            pred = model(*inputs)

        for callback in callbacks:
            callback.on_validation_step(inputs, target, pred)


if __name__ == '__main__':
    from se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
    from se3_transformer.runtime.utils import init_distributed, seed_everything, using_tensor_cores
    from se3_transformer.model import SE3TransformerPooled, Fiber
    from se3_transformer.data_loading import QM9DataModule
    import torch.distributed as dist
    import logging
    import sys

    is_distributed = init_distributed()
    local_rank = get_local_rank()
    args = PARSER.parse_args()

    # Only local rank 0 logs at INFO level (unless --silent is set).
    logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('| Inference on the test set |')
    logging.info('===============================')

    if not args.benchmark and args.load_ckpt_path is None:
        logging.error('No load_ckpt_path provided, you need to provide a saved model to evaluate')
        sys.exit(1)

    if args.benchmark:
        logging.info('Running benchmark mode with one warmup pass')

    if args.seed is not None:
        seed_everything(args.seed)

    logger = DLLogger(args.log_dir, filename=args.dllogger_name)
    datamodule = QM9DataModule(**vars(args))
    model = SE3TransformerPooled(
        fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
        fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
        output_dim=1,
        # Same predicate the training script uses: (amp and cc >= 7) or cc >= 8.
        tensor_cores=using_tensor_cores(args.amp),  # use Tensor Cores more effectively
        **vars(args)
    )
    callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='test')]

    model.to(device=torch.cuda.current_device())
    if args.load_ckpt_path is not None:
        # Tensors saved from cuda:0 are remapped onto this process's device.
        checkpoint = torch.load(str(args.load_ckpt_path), map_location={'cuda:0': f'cuda:{local_rank}'})
        model.load_state_dict(checkpoint['state_dict'])

    if is_distributed:
        nproc_per_node = torch.cuda.device_count()
        gpu_affinity.set_affinity(local_rank, nproc_per_node)
        model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    # Benchmark mode iterates over the training split instead of the test split.
    test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
    evaluate(model,
             test_dataloader,
             callbacks,
             args)

    for callback in callbacks:
        callback.on_validation_end()

    if args.benchmark:
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        callbacks = [PerformanceCallback(logger, args.batch_size * world_size, warmup_epochs=1, mode='inference')]
        for _ in range(6):
            evaluate(model,
                     test_dataloader,
                     callbacks,
                     args)
            callbacks[0].on_epoch_end()

        callbacks[0].on_fit_end()

# --------------------------------------------------------------------------------
# /SE3Transformer/se3_transformer/runtime/loggers.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import pathlib
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Any, Callable, Optional

import dllogger
import torch.distributed as dist
import wandb
from dllogger import Verbosity

from se3_transformer.runtime.utils import rank_zero_only


class Logger(ABC):
    """Abstract logging interface: hyperparameters plus per-step metrics."""

    @rank_zero_only
    @abstractmethod
    def log_hyperparams(self, params):
        pass

    @rank_zero_only
    @abstractmethod
    def log_metrics(self, metrics, step=None):
        pass

    @staticmethod
    def _sanitize_params(params):
        """Return a copy of ``params`` with values made friendlier to log.

        * Callables are invoked with no arguments; the result is used unless
          it is itself callable, in which case the callable's ``__name__`` is
          used. If anything raises, ``__name__`` (or ``None``) is used.
        * ``pathlib.Path`` and ``Enum`` values are converted with ``str``.
        * Everything else is passed through unchanged.
        """

        def _sanitize(val):
            if callable(val):
                try:
                    _val = val()
                    if callable(_val):
                        return val.__name__
                    return _val
                except Exception:
                    # Deliberate best-effort: fall back to the name, if any.
                    return getattr(val, "__name__", None)
            elif isinstance(val, (pathlib.Path, Enum)):
                return str(val)
            return val

        return {key: _sanitize(val) for key, val in params.items()}


class LoggerCollection(Logger):
    """Fan-out logger that forwards every call to each wrapped logger."""

    def __init__(self, loggers):
        super().__init__()
        self.loggers = loggers

    def __getitem__(self, index):
        # Index into a list copy so the result type does not depend on the
        # container type that was passed to the constructor.
        return list(self.loggers)[index]

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        for logger in self.loggers:
            logger.log_metrics(metrics, step)

    @rank_zero_only
    def log_hyperparams(self, params):
        for logger in self.loggers:
            logger.log_hyperparams(params)


class DLLogger(Logger):
    """Logger backed by ``dllogger`` writing a JSON stream to ``save_dir / filename``."""

    def __init__(self, save_dir: pathlib.Path, filename: str):
        super().__init__()
        # Only the global rank-0 process (or a non-distributed run) creates
        # the directory and initialises the backend.
        if not dist.is_initialized() or dist.get_rank() == 0:
            save_dir.mkdir(parents=True, exist_ok=True)
            dllogger.init(
                backends=[dllogger.JSONStreamBackend(Verbosity.DEFAULT, str(save_dir / filename))])

    @rank_zero_only
    def log_hyperparams(self, params):
        params = self._sanitize_params(params)
        dllogger.log(step="PARAMETER", data=params)

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        if step is None:
            step = tuple()

        dllogger.log(step=step, data=metrics)


class WandbLogger(Logger):
    """Logger backed by a Weights & Biases run."""

    def __init__(
            self,
            name: str,
            save_dir: pathlib.Path,
            id: Optional[str] = None,
            project: Optional[str] = None
    ):
        super().__init__()
        # ``self.experiment`` only exists on the rank-0 process; the logging
        # methods below are rank-zero-only, so other ranks never touch it.
        if not dist.is_initialized() or dist.get_rank() == 0:
            save_dir.mkdir(parents=True, exist_ok=True)
            self.experiment = wandb.init(name=name,
                                         project=project,
                                         id=id,
                                         dir=str(save_dir),
                                         resume='allow',
                                         anonymous='must')

    @rank_zero_only
    def log_hyperparams(self, params: Dict[str, Any]) -> None:
        params = self._sanitize_params(params)
        self.experiment.config.update(params, allow_val_change=True)

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        # When a step is given it is logged under the 'epoch' key.
        if step is not None:
            self.experiment.log({**metrics, 'epoch': step})
        else:
            self.experiment.log(metrics)

# --------------------------------------------------------------------------------
# /SE3Transformer/se3_transformer/runtime/metrics.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

from abc import ABC, abstractmethod
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch import Tensor


class Metric(ABC):
    """ Metric class with synchronization capabilities similar to TorchMetrics """

    def __init__(self):
        # name -> pristine default tensor, used by reset()
        self.states = {}

    def add_state(self, name: str, default: Tensor):
        """Register a state tensor as attribute ``name`` with the given default.

        Both the stored default and the live attribute are independent clones,
        so in-place updates of the state never modify the tensor the caller
        passed in.
        """
        assert name not in self.states
        self.states[name] = default.clone()
        setattr(self, name, default.clone())

    def synchronize(self):
        """Sum every state across all processes when torch.distributed is initialised."""
        if dist.is_available() and dist.is_initialized():
            for state in self.states:
                dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD)

    def __call__(self, *args, **kwargs):
        self.update(*args, **kwargs)

    def reset(self):
        """Restore every state to a fresh clone of its registered default."""
        for name, default in self.states.items():
            setattr(self, name, default.clone())

    def compute(self):
        """Synchronise states, return the metric as a Python number, then reset."""
        self.synchronize()
        value = self._compute().item()
        self.reset()
        return value

    @abstractmethod
    def _compute(self):
        pass

    @abstractmethod
    def update(self, preds: Tensor, targets: Tensor):
        pass


class MeanAbsoluteError(Metric):
    """Accumulates the summed absolute error and the number of samples seen.

    Args:
        device: device on which the state tensors are allocated. When omitted,
            CUDA is used if available (the historical behaviour), otherwise CPU.
    """

    def __init__(self, device: Optional[Union[str, torch.device]] = None):
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.add_state('error', torch.tensor(0, dtype=torch.float32, device=device))
        self.add_state('total', torch.tensor(0, dtype=torch.int32, device=device))

    def update(self, preds: Tensor, targets: Tensor):
        preds = preds.detach()
        n = preds.shape[0]
        # Flatten everything but the leading dimension before comparing.
        error = torch.abs(preds.view(n, -1) - targets.view(n, -1)).sum()
        self.total += n
        self.error += error

    def _compute(self):
        # NOTE: with no samples accumulated this is 0 / 0 and yields NaN.
        return self.error / self.total

# --------------------------------------------------------------------------------
# /SE3Transformer/se3_transformer/runtime/training.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import logging
import pathlib
from typing import List

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
    PerformanceCallback
from se3_transformer.runtime.inference import evaluate
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
    using_tensor_cores, increase_l2_fetch_granularity


def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Saves model, optimizer and epoch states to path (only once per node)

    ``epoch`` is stored verbatim under the 'epoch' key; ``load_state`` returns
    it and ``train`` uses it as the index of the first epoch to run on resume.
    Each callback may add its own entries through ``on_checkpoint_save``.
    """
    if get_local_rank() == 0:
        # Unwrap DDP so the saved keys carry no 'module.' prefix.
        state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
        checkpoint = {
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }
        for callback in callbacks:
            callback.on_checkpoint_save(checkpoint)

        torch.save(checkpoint, str(path))
        logging.info(f'Saved checkpoint to {str(path)}')


def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Loads model, optimizer and epoch states from path

    Returns:
        The value stored under the checkpoint's 'epoch' key.
    """
    # Tensors saved from cuda:0 are remapped onto this process's device.
    checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
    if isinstance(model, DistributedDataParallel):
        model.module.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    for callback in callbacks:
        callback.on_checkpoint_load(checkpoint)

    logging.info(f'Loaded checkpoint from {str(path)}')
    return checkpoint['epoch']


def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
    """Train for one epoch and return the mean of the recorded batch losses.

    Gradients are accumulated over ``args.accumulate_grad_batches`` batches;
    an optimizer step also happens on the last batch of the epoch.

    NOTE: the recorded per-batch loss is the value already divided by
    ``args.accumulate_grad_batches``, so the returned mean is scaled likewise.
    """
    losses = []
    num_batches = len(train_dataloader)
    for i, batch in tqdm(enumerate(train_dataloader), total=num_batches, unit='batch',
                         desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
        *inputs, target = to_cuda(batch)

        for callback in callbacks:
            callback.on_batch_start()

        with torch.cuda.amp.autocast(enabled=args.amp):
            pred = model(*inputs)
            loss = loss_fn(pred, target) / args.accumulate_grad_batches

        grad_scaler.scale(loss).backward()

        # gradient accumulation
        if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == num_batches:
            if args.gradient_clip:
                # Gradients must be unscaled before their norm is clipped.
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)

            grad_scaler.step(optimizer)
            grad_scaler.update()
            optimizer.zero_grad()

        losses.append(loss.item())

    return np.mean(losses)


def train(model: nn.Module,
          loss_fn: _Loss,
          train_dataloader: DataLoader,
          val_dataloader: DataLoader,
          callbacks: List[BaseCallback],
          logger: Logger,
          args):
    """Full training loop with optional resume, checkpointing and validation.

    The model is moved to the current CUDA device and wrapped in
    DistributedDataParallel when torch.distributed is initialised. The
    optimizer is chosen from ``args.optimizer`` ('adam', 'lamb', anything else
    falls back to SGD). Training starts from the epoch stored in
    ``args.load_ckpt_path`` when given, otherwise from 0.
    """
    device = torch.cuda.current_device()
    model.to(device=device)
    local_rank = get_local_rank()
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    if dist.is_initialized():
        model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    model.train()
    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0

    for callback in callbacks:
        callback.on_fit_start(optimizer, args)

    for epoch_idx in range(epoch_start, args.epochs):
        if isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch_idx)

        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks,
                           args)
        if dist.is_initialized():
            # Average the per-process epoch loss across all processes.
            loss = torch.tensor(loss, dtype=torch.float, device=device)
            torch.distributed.all_reduce(loss)
            loss = (loss / world_size).item()

        logging.info(f'Train loss: {loss}')
        logger.log_metrics({'train loss': loss}, epoch_idx)

        for callback in callbacks:
            callback.on_epoch_end()

        if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
                and (epoch_idx + 1) % args.ckpt_interval == 0:
            # Epoch ``epoch_idx`` is complete, so store the index of the NEXT
            # epoch: the stored value becomes ``epoch_start`` on resume, and
            # storing ``epoch_idx`` would run the finished epoch a second time.
            # This matches the final save below, which stores ``args.epochs``.
            save_state(model, optimizer, epoch_idx + 1, args.save_ckpt_path, callbacks)

        if not args.benchmark and args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0:
            evaluate(model, val_dataloader, callbacks, args)
            # evaluate() switches the model to eval mode; switch back.
            model.train()

            for callback in callbacks:
                callback.on_validation_end(epoch_idx)

    if args.save_ckpt_path is not None and not args.benchmark:
        save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)

    for callback in callbacks:
        callback.on_fit_end()


def print_parameters_count(model):
    """Log the total number of elements in parameters that require gradients."""
    num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f'Number of trainable parameters: {num_params_trainable}')


if __name__ == '__main__':
    is_distributed = init_distributed()
    local_rank = get_local_rank()
    args = PARSER.parse_args()

    # Only local rank 0 logs at INFO level (unless --silent is set).
    logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('| Training procedure |')
    logging.info('===============================')

    if args.seed is not None:
        logging.info(f'Using seed {args.seed}')
        seed_everything(args.seed)

    logger = LoggerCollection([
        DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
        WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
    ])

    datamodule = QM9DataModule(**vars(args))
    model = SE3TransformerPooled(
        fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
        fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
        output_dim=1,
        tensor_cores=using_tensor_cores(args.amp),  # use Tensor Cores more effectively
        **vars(args)
    )
    loss_fn = nn.L1Loss()

    if args.benchmark:
        logging.info('Running benchmark mode')
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
    else:
        callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
                     QM9LRSchedulerCallback(logger, epochs=args.epochs)]

    if is_distributed:
        gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())

    print_parameters_count(model)
    logger.log_hyperparams(vars(args))
    increase_l2_fetch_granularity()
    train(model,
          loss_fn,
          datamodule.train_dataloader(),
          datamodule.val_dataloader(),
          callbacks,
          logger,
          args)

    logging.info('Training finished successfully')

# --------------------------------------------------------------------------------
# /SE3Transformer/se3_transformer/runtime/utils.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import argparse
import ctypes
import logging
import os
import random
from functools import wraps
from typing import Union, List, Dict

import numpy as np
import torch
import torch.distributed as dist
from torch import Tensor


def aggregate_residual(feats1, feats2, method: str):
    """ Add or concatenate two fiber features together. If degrees don't match, will use the ones of feats2. """
    if method in ['add', 'sum']:
        return {k: (v + feats1[k]) if k in feats1 else v for k, v in feats2.items()}
    elif method in ['cat', 'concat']:
        # Concatenation is along dim 1, with feats2's tensor first.
        return {k: torch.cat([v, feats1[k]], dim=1) if k in feats1 else v for k, v in feats2.items()}
    else:
        raise ValueError('Method must be add/sum or cat/concat')


def degree_to_dim(degree: int) -> int:
    """Return ``2 * degree + 1``."""
    return 2 * degree + 1


def unfuse_features(features: Tensor, degrees: List[int]) -> Dict[str, Tensor]:
    """Split ``features`` along its last dim into one chunk per degree.

    Chunk sizes are ``degree_to_dim(deg)``; the result is keyed by ``str(deg)``.
    """
    return dict(zip(map(str, degrees), features.split([degree_to_dim(deg) for deg in degrees], dim=-1)))


def str2bool(v: Union[bool, str]) -> bool:
    """Parse a command-line boolean; bools pass through unchanged.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def to_cuda(x):
    """ Try to convert a Tensor, a collection of Tensors or a DGLGraph to CUDA

    Tuples, lists and dicts are converted recursively and returned as the same
    container type. Any other object is moved with ``.to(device=...)``.
    """
    if isinstance(x, Tensor):
        return x.cuda(non_blocking=True)
    elif isinstance(x, tuple):
        # Build a real tuple: a generator expression here would be lazy and
        # single-use, unlike the list and dict branches.
        return tuple(to_cuda(v) for v in x)
    elif isinstance(x, list):
        return [to_cuda(v) for v in x]
    elif isinstance(x, dict):
        return {k: to_cuda(v) for k, v in x.items()}
    else:
        # DGLGraph or other objects
        return x.to(device=torch.cuda.current_device())


def get_local_rank() -> int:
    """Return the ``LOCAL_RANK`` environment variable as an int (0 if unset)."""
    return int(os.environ.get('LOCAL_RANK', 0))


def init_distributed() -> bool:
    """Initialise torch.distributed when ``WORLD_SIZE`` > 1.

    Uses the 'nccl' backend when CUDA is available (and selects the CUDA
    device matching the local rank), otherwise 'gloo'.

    Returns:
        True if a process group was initialised, False otherwise.
    """
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    distributed = world_size > 1
    if distributed:
        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
        dist.init_process_group(backend=backend, init_method='env://')
        if backend == 'nccl':
            torch.cuda.set_device(get_local_rank())
        else:
            logging.warning('Running on CPU only!')
        assert torch.distributed.is_initialized()
    return distributed


def increase_l2_fetch_granularity():
    """Request a 128-byte L2 fetch granularity on the current CUDA device.

    Calls ``cudaDeviceSetLimit`` through ``libcudart.so`` and asserts that the
    limit reads back as 128.
    """
    # maximum fetch granularity of L2: 128 bytes
    _libcudart = ctypes.CDLL('libcudart.so')
    # set device limit on the current device
    # cudaLimitMaxL2FetchGranularity = 0x05
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    assert pValue.contents.value == 128


def seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    seed = int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def rank_zero_only(fn):
    """Decorator: run ``fn`` only on global rank 0 (or when not distributed).

    On other ranks the wrapped call does nothing and returns None.
    """

    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        if not dist.is_initialized() or dist.get_rank() == 0:
            return fn(*args, **kwargs)

    return wrapped_fn


def using_tensor_cores(amp: bool) -> bool:
    """Return True when the device's major compute capability is >= 8, or >= 7 with AMP enabled."""
    major_cc, minor_cc = torch.cuda.get_device_capability()
    return (amp and major_cc >= 7) or major_cc >= 8
-------------------------------------------------------------------------------- /SE3Transformer/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='se3-transformer', 5 | packages=find_packages(), 6 | include_package_data=True, 7 | version='1.0.0', 8 | description='PyTorch + DGL implementation of SE(3)-Transformers', 9 | author='Alexandre Milesi', 10 | author_email='alexandrem@nvidia.com', 11 | ) 12 | -------------------------------------------------------------------------------- /SE3Transformer/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uw-ipd/RoseTTAFold2NA/f761af286729ea08a6ddab149023c1b73458fbe2/SE3Transformer/tests/__init__.py -------------------------------------------------------------------------------- /SE3Transformer/tests/test_equivariance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
# IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import torch

from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber
from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot

# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
TOL = 1e-3
CHANNELS, NODES = 32, 512


def _random_rotation():
    """Return a rotation built from three random angles, on CUDA when available."""
    R = rot(*torch.rand(3))
    if torch.cuda.is_available():
        R = R.cuda()
    return R


def _get_outputs(model, R):
    """Run ``model`` on random inputs and on the same inputs rotated by ``R``.

    ``R`` is expected to already live on the device used here (CUDA when
    available), as produced by ``_random_rotation``.

    Returns:
        ``(out1, out2)``: outputs for the original and the rotated inputs.
    """
    feats0 = torch.randn(NODES, CHANNELS, 1)
    feats1 = torch.randn(NODES, CHANNELS, 3)

    coords = torch.randn(NODES, 3)
    graph = get_random_graph(NODES)
    if torch.cuda.is_available():
        feats0 = feats0.cuda()
        feats1 = feats1.cuda()
        coords = coords.cuda()
        graph = graph.to('cuda')
        model.cuda()

    # assign_relative_pos writes into the graph it is given, so each forward
    # pass must run before the positions are overwritten for the next one.
    graph1 = assign_relative_pos(graph, coords)
    out1 = model(graph1, {'0': feats0, '1': feats1}, {})
    graph2 = assign_relative_pos(graph, coords @ R)
    out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})

    return out1, out2


def _get_model(**kwargs):
    """Build the SE3Transformer under test; ``kwargs`` are forwarded to it."""
    return SE3Transformer(
        num_layers=4,
        fiber_in=Fiber.create(2, CHANNELS),
        fiber_hidden=Fiber.create(3, CHANNELS),
        fiber_out=Fiber.create(2, CHANNELS),
        fiber_edge=Fiber({}),
        num_heads=8,
        channels_div=2,
        **kwargs
    )


def test_equivariance():
    model = _get_model()
    R = _random_rotation()
    out1, out2 = _get_outputs(model, R)

    assert torch.allclose(out2['0'], out1['0'], atol=TOL), \
        f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}'
    assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \
        f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'


def test_equivariance_pooled():
    model = _get_model(pooling='avg', return_type=1)
    R = _random_rotation()
    out1, out2 = _get_outputs(model, R)

    assert torch.allclose(out2, (out1 @ R), atol=TOL), \
        f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}'


def test_invariance_pooled():
    model = _get_model(pooling='avg', return_type=0)
    R = _random_rotation()
    out1, out2 = _get_outputs(model, R)

    assert torch.allclose(out2, out1, atol=TOL), \
        f'type-0 features should be invariant {get_max_diff(out1, out2)}'

# --------------------------------------------------------------------------------
# /SE3Transformer/tests/utils.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
def get_random_graph(N, num_edges_factor=18):
    """Return a random DGL graph with `N` nodes and no self loops.

    `N * num_edges_factor` edges are sampled by `dgl.rand_graph`; self loops
    among them are then removed, so the final edge count can be smaller.
    """
    # Use the top-level `dgl.remove_self_loop`: the `dgl.transform` module path
    # used previously is a legacy location that newer DGL releases no longer
    # provide (the module was renamed to `dgl.transforms`).
    return dgl.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor))


def assign_relative_pos(graph, coords):
    """Store `coords[src] - coords[dst]` as edge feature 'rel_pos' and return the graph.

    The graph is modified in place; the returned object is the one passed in.
    """
    src, dst = graph.edges()
    graph.edata['rel_pos'] = coords[src] - coords[dst]
    return graph


def get_max_diff(a, b):
    """Return the largest absolute element-wise difference between `a` and `b` as a Python number."""
    return (a - b).abs().max().item()


def rot_z(gamma):
    """Return the 3x3 matrix of a rotation by angle `gamma` (a 0-dim tensor) about the z axis."""
    c, s = torch.cos(gamma), torch.sin(gamma)
    return torch.tensor([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1]
    ], dtype=gamma.dtype)


def rot_y(beta):
    """Return the 3x3 matrix of a rotation by angle `beta` (a 0-dim tensor) about the y axis."""
    c, s = torch.cos(beta), torch.sin(beta)
    return torch.tensor([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c]
    ], dtype=beta.dtype)


def rot(alpha, beta, gamma):
    """Compose a rotation from three angles as rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)."""
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
-------------------------------------------------------------------------------- 1 | > ANTENNAPEDIA HOMEODOMAIN|Drosophila melanogaster (7227) 2 | MERKRGRQTYTRYQTLELEKEFHFNRYLTRRRRIEIAHALSLTERQIKIWFQNRRMKWKKEN 3 | -------------------------------------------------------------------------------- /example/rna_binding_protein.fa: -------------------------------------------------------------------------------- 1 | > prot 2 | TRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKM 3 | -------------------------------------------------------------------------------- /input_prep/make_protein_msa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # inputs 4 | in_fasta="$1" 5 | out_dir="$2" 6 | tag="$3" 7 | 8 | # resources 9 | CPU="$4" 10 | MEM="$5" 11 | 12 | # sequence databases 13 | DB_UR30="$PIPEDIR/UniRef30_2020_06/UniRef30_2020_06" 14 | DB_BFD="$PIPEDIR/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt" 15 | 16 | # setup hhblits command 17 | HHBLITS_UR30="hhblits -o /dev/null -mact 0.35 -maxfilt 100000000 -neffmax 20 -cov 25 -cpu $CPU -nodiff -realign_max 100000000 -maxseq 1000000 -maxmem $MEM -n 4 -d $DB_UR30" 18 | HHBLITS_BFD="hhblits -o /dev/null -mact 0.35 -maxfilt 100000000 -neffmax 20 -cov 25 -cpu $CPU -nodiff -realign_max 100000000 -maxseq 1000000 -maxmem $MEM -n 4 -d $DB_BFD" 19 | 20 | mkdir -p $out_dir/hhblits 21 | tmp_dir="$out_dir/hhblits" 22 | out_prefix="$out_dir/$tag" 23 | 24 | echo out_prefix $out_prefix 25 | 26 | # perform iterative searches against UniRef30 27 | prev_a3m="$in_fasta" 28 | for e in 1e-10 1e-6 1e-3 29 | do 30 | echo "Running HHblits against UniRef30 with E-value cutoff $e" 31 | $HHBLITS_UR30 -i $prev_a3m -oa3m $tmp_dir/t000_.$e.a3m -e $e -v 0 32 | hhfilter -id 90 -cov 75 -i $tmp_dir/t000_.$e.a3m -o $tmp_dir/t000_.$e.id90cov75.a3m 33 | hhfilter -id 90 -cov 50 -i $tmp_dir/t000_.$e.a3m -o $tmp_dir/t000_.$e.id90cov50.a3m 34 | 
prev_a3m="$tmp_dir/t000_.$e.id90cov50.a3m" 35 | n75=`grep -c "^>" $tmp_dir/t000_.$e.id90cov75.a3m` 36 | n50=`grep -c "^>" $tmp_dir/t000_.$e.id90cov50.a3m` 37 | 38 | if ((n75>2000)) 39 | then 40 | if [ ! -s ${out_prefix}.msa0.a3m ] 41 | then 42 | cp $tmp_dir/t000_.$e.id90cov75.a3m ${out_prefix}.msa0.a3m 43 | break 44 | fi 45 | elif ((n50>4000)) 46 | then 47 | if [ ! -s ${out_prefix}.msa0.a3m ] 48 | then 49 | cp $tmp_dir/t000_.$e.id90cov50.a3m ${out_prefix}.msa0.a3m 50 | break 51 | fi 52 | else 53 | continue 54 | fi 55 | done 56 | 57 | # perform iterative searches against BFD if it failes to get enough sequences 58 | if [ ! -s ${out_prefix}.msa0.a3m ] 59 | then 60 | e=1e-3 61 | echo "Running HHblits against BFD with E-value cutoff $e" 62 | $HHBLITS_BFD -i $prev_a3m -oa3m $tmp_dir/t000_.$e.bfd.a3m -e $e -v 0 63 | hhfilter -id 90 -cov 75 -i $tmp_dir/t000_.$e.bfd.a3m -o $tmp_dir/t000_.$e.bfd.id90cov75.a3m 64 | hhfilter -id 90 -cov 50 -i $tmp_dir/t000_.$e.bfd.a3m -o $tmp_dir/t000_.$e.bfd.id90cov50.a3m 65 | prev_a3m="$tmp_dir/t000_.$e.bfd.id90cov50.a3m" 66 | n75=`grep -c "^>" $tmp_dir/t000_.$e.bfd.id90cov75.a3m` 67 | n50=`grep -c "^>" $tmp_dir/t000_.$e.bfd.id90cov50.a3m` 68 | 69 | if ((n75>2000)) 70 | then 71 | if [ ! -s ${out_prefix}.msa0.a3m ] 72 | then 73 | cp $tmp_dir/t000_.$e.bfd.id90cov75.a3m ${out_prefix}.msa0.a3m 74 | fi 75 | elif ((n50>4000)) 76 | then 77 | if [ ! -s ${out_prefix}.msa0.a3m ] 78 | then 79 | cp $tmp_dir/t000_.$e.bfd.id90cov50.a3m ${out_prefix}.msa0.a3m 80 | fi 81 | fi 82 | fi 83 | 84 | if [ ! 
#!/bin/bash
# Build an RNA multiple sequence alignment ($out_dir/$out_tag.afa) for one query:
#   1. cmscan against Rfam to find matching families, then pull member
#      sequences of those families from RNAcentral and nt
#   2. blastn of the query against RNAcentral and nt
#   3. cluster all retrieved sequences with cd-hit-est
#   4. realign with nhmmer, add the query with mafft, filter with hhfilter
#
# Usage: make_rna_msa.sh <in_fasta> <out_dir> <out_tag> <CPU> <MEM>
# Requires $PIPEDIR to point at the directory holding the RNA databases.

# inputs
in_fasta="$1"
out_dir="$2"
out_tag="$3"

# With overwrite=true an existing alignment is always regenerated.
overwrite=true
if [ -f $out_dir/$out_tag.afa -a $overwrite = false ]
then
    exit 0
fi

# resources
CPU="$4"
MEM="$5"

RNADBDIR="$PIPEDIR/RNA"

# databases
db0="$RNADBDIR/Rfam.cm";
db1="$RNADBDIR/rnacentral.fasta";
db2="$RNADBDIR/nt";
db0to1="$RNADBDIR/rfam_annotations.tsv.gz";
db0to2="$RNADBDIR/Rfam.full_region.gz";

max_aln_seqs=50000
max_target_seqs=50000
max_split_seqs=5000
max_hhfilter_seqs=5000
max_rfam_num=100

# query length (number of sequence characters, header lines excluded)
Lch=`grep -v '^>' $in_fasta | tr -d '\n' | wc -c`

mkdir -p $out_dir
cp $in_fasta $out_dir
cd $out_dir
in_fasta=`basename $in_fasta`

# retrieveSeq <tabfile> <blast db> <tag>
# Reads "<accession> <start> <end>" rows from <tabfile>, pads each range by
# 6 positions on both sides, fetches the regions from <blast db> with
# blastdbcmd (in batches of $max_split_seqs) and writes them to <tag>.db.
function retrieveSeq {
    tabfile=$1
    db=$2
    tag=$3

    head -n $max_aln_seqs $tabfile | awk '{if ($2<$3) print $1,(($2-6>1)?($2-6):1)"-"($3+6),"plus"; else print $1,(($3-6>1)?($3-6):1)"-"($2+6),"minus"}' > $tag.list
    split -l $max_split_seqs $tag.list $tag.list.split.

    for file in $tag.list.split.*
    do
        suffix=`echo $file | sed 's/.*\.list\.split\.//g'`
        blastdbcmd -db $db -entry_batch $tag.list.split.$suffix -out $tag.db.$suffix -outfmt ">Accession=%a_TaxID=%T @@NEWLINE@@%s" &> /dev/null
        sed -i 's/@@NEWLINE@@/\n/g' $tag.db.$suffix
    done
    cat $tag.db.* | sed 's/_\([0-9]*\)_TaxID=0/_TaxID=\1/' > $tag.db # fix for incorrect taxids
    rm $tag.db.* $tag.list.split.*
}

# cmscan on Rfam
echo "Run cmscan on Rfam"
cmscan --tblout cmscan.tblout -o cmscan.out --noali $db0 $in_fasta
# families becomes an ERE alternation of the hit accessions, e.g. "RF00005|RF01852"
families=`grep -v '^#' cmscan.tblout | head -n $max_rfam_num | uniq | awk '{print $2}' | sed -z 's/\n/|/g;s/|$/\n/'`
echo "Rfam families:" $families
rm cmscan.out cmscan.tblout

# Rfam->RNACentral
# The pattern is passed in double quotes. The previous \'$families\' form put
# literal single-quote characters into the regex, so the first and last family
# (or a lone family) could never match. An empty family list must not reach
# grep either: an empty pattern matches every line.
if [ -n "$families" ]
then
    zcat $db0to1 | grep -E "$families" | awk '{print $1,1+$5,1+$6}' > rfam1.tab
else
    : > rfam1.tab
fi
head -n $max_aln_seqs rfam1.tab > rfam1.tab.tmp; mv rfam1.tab.tmp rfam1.tab
retrieveSeq rfam1.tab $db1 rfam1
rm rfam1.list rfam1.tab

# Rfam->nt
if [ -n "$families" ]
then
    zcat $db0to2 | grep -E "$families" | awk '{print $2,$3,$4}' > rfam2.tab
else
    : > rfam2.tab
fi
head -n $max_aln_seqs rfam2.tab > rfam2.tab.tmp; mv rfam2.tab.tmp rfam2.tab
retrieveSeq rfam2.tab $db2 rfam2
rm rfam2.list rfam2.tab

if [[ -f "rfam1.db" || -f "rfam2.db" ]]
then
    cat rfam1.db rfam2.db > db0
    rm rfam1.db rfam2.db
fi

# blastn on RNACentral
echo "Run blastn on RNACentral"
blastn -num_threads $CPU -query $in_fasta -strand plus -db $db1 -out blastn1.tab -task blastn -max_target_seqs $max_target_seqs -outfmt '6 saccver sstart send evalue bitscore nident staxids'
retrieveSeq blastn1.tab $db1 blastn1
rm blastn1.list blastn1.tab

# blastn on nt
echo "Run blastn on nt"
blastn -num_threads $CPU -query $in_fasta -strand both -db $db2 -out blastn2.tab -task blastn -max_target_seqs $max_target_seqs -outfmt '6 saccver sstart send evalue bitscore nident staxids'
retrieveSeq blastn2.tab $db2 blastn2
rm blastn2.list blastn2.tab

# combine, remove redundant
echo "Cluster sequences"
# sequences shorter than 2/5 of the query length are discarded by cd-hit (-l)
throw_away_sequences=$(( $Lch*2/5 ));
cat db0 blastn*.db > trim.db
rm db0 blastn*.db

# lower the identity cutoff until fewer than $max_aln_seqs clusters remain
for cut in 1.00 0.99 0.95 0.90
do
    cd-hit-est-2d -T $CPU -i $in_fasta -i2 trim.db -c $cut -o cdhitest2d.db -l $throw_away_sequences -M 0 &> /dev/null
    cd-hit-est -T $CPU -i cdhitest2d.db -c $cut -o db -l $throw_away_sequences -M 0 &> /dev/null
    nhits=`grep '^>' db | wc -l`
    if [[ $nhits -lt $max_aln_seqs ]]
    then
        break
    fi
done
rm cdhitest2d.db cdhitest2d.db.clstr db.clstr

# nhmmer on previous hits
echo "Realign all with nhmmer"
# relax the inclusion E-value until more than $max_hhfilter_seqs sequences survive filtering
for e_val in 1e-8 1e-7 1e-6 1e-3 1e-2 1e-1
do
    nhmmer --noali -A nhmmer.a2m --incE $e_val --cpu $CPU --watson $in_fasta db | grep 'no alignment saved'
    esl-reformat --replace=acgt:____ a2m nhmmer.a2m > $out_tag.unfilter.afa
    # add query
    mafft --preservecase --addfull $out_tag.unfilter.afa --keeplength $in_fasta > $out_tag.wquery.unfilt.afa 2> /dev/null
    hhfilter -i $out_tag.wquery.unfilt.afa -id 99 -cov 50 -o $out_tag.afa -M first
    hitnum=`grep '^>' $out_tag.afa | wc -l`
    if [[ $hitnum -gt $max_hhfilter_seqs ]]
    then
        break
    fi
    if [[ $hitnum -eq 0 ]]
    then
        # fall back to a query-only alignment; a later, looser E-value may still replace it
        echo "no hits found"
        cp $in_fasta $out_tag.afa
    fi
done

rm nhmmer.a2m
# Module-level tables used by the encoders below (same values as the
# definitions at the top of the file).
TABLE = str.maketrans(dict.fromkeys(string.ascii_lowercase))  # deletes a-z (a3m insertions)
ALPHABET = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
RNA_ALPHABET = np.array(list("ACGT-"), dtype='|S1').view(np.uint8)


def remove_lower(text):
    """Return `text` with all ASCII lowercase letters (a3m insertion states) removed."""
    return text.translate(TABLE)


def _encode(seq, alphabet, unknown):
    """Map an aligned sequence to a uint8 index array using `alphabet`.

    Lowercase letters are dropped first; every character that is not part of
    `alphabet` is mapped to `unknown`.
    """
    codes = np.frombuffer(seq.translate(TABLE).encode('ascii'), dtype=np.uint8)
    lut = np.full(256, unknown, dtype=np.uint8)
    lut[alphabet] = np.arange(alphabet.shape[0], dtype=np.uint8)
    return lut[codes]


def seq2number(seq):
    """Encode a protein sequence as indices into ALPHABET; unknown residues become 20 (gap)."""
    return _encode(seq, ALPHABET, 20)


def rnaseq2number(seq):
    """Encode a nucleic-acid sequence as indices into RNA_ALPHABET; unknown symbols become 5."""
    return _encode(seq, RNA_ALPHABET, 5)


def calc_seqID(query, cand):
    """Return the fraction of positions at which the encoded `cand` equals the encoded `query`."""
    same = (query == cand).sum()
    return same / float(len(query))


def _open_text(fn):
    """Open `fn` for text reading, transparently handling a `.gz` suffix."""
    if fn.split('.')[-1] == "gz":
        return gzip.open(fn, 'rt')
    return open(fn, 'r')


def _read_best_per_taxid(fn, parse_taxid, encode):
    """Read an alignment and keep one sequence per taxonomy ID.

    The first record is the query. Later records are kept only when their
    header contains "TaxID"; `parse_taxid` receives the header text starting
    at "TaxID" and returns the ID. When several sequences share an ID, the one
    with the highest identity to the query (computed on `encode`d sequences)
    wins.

    Returns (query, {taxid: (seqID, aligned_sequence)}).
    Raises ValueError when the file contains no query sequence.
    """
    query = None
    seq_id = tax_id = None
    skip = True
    by_tax = {}
    with _open_text(fn) as fp:
        for line in fp:
            if not line.strip():
                continue  # tolerate blank lines
            if line[0] == ">":
                if query is None:
                    continue  # header of the query itself
                pos = line.find("TaxID")
                skip = pos < 0
                if skip:
                    continue  # no taxonomy information: cannot be paired
                seq_id = line.split()[0][1:]
                tax_id = parse_taxid(line[pos:])
                by_tax.setdefault(tax_id, [])
            elif query is None:
                query = line.strip()
            elif not skip:
                by_tax[tax_id].append((seq_id, line.strip()))

    if query is None:
        raise ValueError("no query sequence found in %s" % fn)

    query_in_num = encode(query)
    best = {}
    for tax_id, cands in by_tax.items():
        if not cands:
            continue
        if len(cands) == 1:
            best[tax_id] = cands[0]
            continue
        # Get the best sequence only
        scores = [calc_seqID(query_in_num, encode(seq)) for _, seq in cands]
        best[tax_id] = cands[int(np.argmax(scores))]
    return query, best


def read_a3m(fn):
    """Read a protein a3m file; keep the sequence closest to the query for each TaxID.

    Headers are expected to carry "TaxID=<id>" as a whitespace-delimited token.
    Returns (query, {taxid: (seqID, sequence)}).
    """
    return _read_best_per_taxid(
        fn, lambda text: text.split()[0].split('=')[-1], seq2number)


def read_afa(fn):
    """Read an RNA afa file; keep the sequence closest to the query for each TaxID.

    Headers are expected to look like "...TaxID=<id>/<start>-<end>".
    Returns (query, {taxid: (seqID, sequence)}).
    """
    return _read_best_per_taxid(
        fn, lambda text: text.split()[0].split('/')[0].split('=')[-1], rnaseq2number)


def main(fnA, fnB, pair_fn):
    """Merge a protein a3m (`fnA`) and an RNA afa (`fnB`) into a paired MSA written to `pair_fn`.

    Sequences sharing a taxonomy ID are written as paired rows
    ("protein/rna"). Every remaining sequence is written once as a
    'singlerep' row, padded with gaps over the other chain. Prints the number
    of rows written (query excluded) and the output path.
    """
    queryA, a3mA = read_a3m(fnA)
    queryB, a3mB = read_afa(fnB)

    gapA = '-' * len(queryA)
    gapB = '-' * len(queryB)

    paired = ['>query\n', queryA, '/', queryB, "\n"]
    single = []
    wrtlen = 0

    for taxA, (nameA, seqA) in a3mA.items():
        if taxA in a3mB:
            nameB, seqB = a3mB[taxA]
            paired.append(">%s %s\n" % (nameA, nameB))
            paired.append("%s/%s\n" % (remove_lower(seqA), remove_lower(seqB)))
        else:
            single.append(">%s %s\n" % (nameA, 'singlerep'))
            single.append("%s%s\n" % (remove_lower(seqA), gapB))
        wrtlen += 1

    for taxB, (nameB, seqB) in a3mB.items():
        # Skip taxa already emitted as paired rows. The previous check compared
        # the taxonomy ID against a set of sequence names, so it never matched
        # and every paired RNA sequence was written a second time.
        if taxB in a3mA:
            continue
        single.append(">%s %s\n" % (nameB, 'singlerep'))
        single.append("%s%s\n" % (gapA, remove_lower(seqB)))
        wrtlen += 1

    with open(pair_fn, 'wt') as fp:
        fp.write(''.join(paired))
        fp.write(''.join(single))

    print(str(wrtlen) + '\t' + pair_fn)


if __name__ == '__main__':

    # All three paths are required; show the usage instead of failing with IndexError.
    if len(sys.argv) < 4:
        print ("USAGE: python merge_msa_prot_rna.py [a3m for Protein] [afa for RNA] [a3m for output]")
        sys.exit()

    fnA = sys.argv[1]
    fnB = sys.argv[2]
    pair_fn = sys.argv[3]

    main(fnA, fnB, pair_fn)
class DistanceNetwork(nn.Module):
    """Linear heads that turn pair features into distance / orientation logits."""

    def __init__(self, n_feat, p_drop=0.1):
        super().__init__()
        # symmetric head: 37 dist bins + 37 omega bins
        self.proj_symm = nn.Linear(n_feat, 37*2)
        # asymmetric head: 37 theta bins + 19 phi bins
        self.proj_asymm = nn.Linear(n_feat, 37+19)

        self.reset_parameter()

    def reset_parameter(self):
        # initialize linear layer for final logit prediction
        nn.init.zeros_(self.proj_symm.weight)
        nn.init.zeros_(self.proj_asymm.weight)
        nn.init.zeros_(self.proj_symm.bias)
        nn.init.zeros_(self.proj_asymm.bias)

    def forward(self, x):
        """x: pair info (B, L, L, C) -> (dist, omega, theta, phi) logits, each (B, bins, L, L)."""
        # theta, phi are direction dependent (non-symmetric)
        asymm = self.proj_asymm(x)
        logits_theta = asymm[..., :37].permute(0, 3, 1, 2)
        logits_phi = asymm[..., 37:].permute(0, 3, 1, 2)

        # dist, omega: add the transpose over the two residue axes to symmetrize
        symm = self.proj_symm(x)
        symm = symm + symm.permute(0, 2, 1, 3)
        logits_dist = symm[..., :37].permute(0, 3, 1, 2)
        logits_omega = symm[..., 37:].permute(0, 3, 1, 2)

        return logits_dist, logits_omega, logits_theta, logits_phi


class MaskedTokenNetwork(nn.Module):
    """Linear head predicting token logits for every MSA position."""

    def __init__(self, n_feat, p_drop=0.1):
        super().__init__()
        self.proj = nn.Linear(n_feat, NAATOKENS)

        self.reset_parameter()

    def reset_parameter(self):
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        """x: (B, N, L, n_feat) -> logits (B, n_tokens, N*L)."""
        n_batch, n_seq, n_res = x.shape[:3]
        token_logits = self.proj(x).permute(0, 3, 1, 2)
        return token_logits.reshape(n_batch, -1, n_seq * n_res)


class LDDTNetwork(nn.Module):
    """Linear head predicting binned per-residue lDDT logits."""

    def __init__(self, n_feat, n_bin_lddt=50):
        super().__init__()
        self.proj = nn.Linear(n_feat, n_bin_lddt)

        self.reset_parameter()

    def reset_parameter(self):
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        """x: (B, L, n_feat) -> logits (B, n_bin_lddt, L)."""
        return self.proj(x).permute(0, 2, 1)


class PAENetwork(nn.Module):
    """Linear head predicting binned pairwise error logits from pair and state features."""

    def __init__(self, n_feat, n_bin_pae=64):
        super().__init__()
        self.proj = nn.Linear(n_feat, n_bin_pae)
        self.reset_parameter()

    def reset_parameter(self):
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, pair, state):
        """pair: (B, L, L, d_pair), state: (B, L, d_state) -> logits (B, n_bin_pae, L, L).

        The projection input is pair features concatenated with the state of
        residue i and the state of residue j, so n_feat must equal
        d_pair + 2 * d_state.
        """
        n_res = pair.shape[1]
        state_i = state.unsqueeze(2).expand(-1, -1, n_res, -1)
        state_j = state.unsqueeze(1).expand(-1, n_res, -1, -1)

        features = torch.cat((pair, state_i, state_j), dim=-1)
        return self.proj(features).permute(0, 3, 1, 2)


class BinderNetwork(nn.Module):
    """Binary classifier on the mean inter-chain PAE logits."""

    def __init__(self, n_hidden=64, n_bin_pae=64):
        super().__init__()
        self.classify = torch.nn.Linear(n_bin_pae, 1)
        self.reset_parameter()

    def reset_parameter(self):
        nn.init.zeros_(self.classify.weight)
        nn.init.zeros_(self.classify.bias)

    def forward(self, pae, same_chain):
        """pae: (B, n_bin_pae, L, L), same_chain: (B, L, L) -> probability of shape (1,)."""
        per_pair = pae.permute(0, 2, 3, 1)
        inter_chain = per_pair[same_chain == 0]
        # the mean over zero pairs is NaN; nan_to_num turns it into all zeros (single chain)
        pooled = torch.mean(inter_chain, dim=0).nan_to_num()
        return torch.sigmoid(self.classify(pooled))
# Module contains classes and functions to generate initial embeddings

class MSA_emb(nn.Module):
    """Initial embeddings of the seed MSA, the pair representation and the residue state."""

    def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=2*NAATOKENS+2+2,
                 minpos=-32, maxpos=32, p_drop=0.1):
        super().__init__()
        self.emb = nn.Linear(d_init, d_msa)                # general MSA features
        self.emb_q = nn.Embedding(NAATOKENS, d_msa)        # query sequence -> MSA embedding
        self.emb_left = nn.Embedding(NAATOKENS, d_pair)    # query sequence -> pair embedding
        self.emb_right = nn.Embedding(NAATOKENS, d_pair)   # query sequence -> pair embedding
        self.emb_state = nn.Embedding(NAATOKENS, d_state)  # query sequence -> state embedding
        self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos)

        self.reset_parameter()

    def reset_parameter(self):
        for name in ("emb", "emb_q", "emb_left", "emb_right", "emb_state"):
            setattr(self, name, init_lecun_normal(getattr(self, name)))

        nn.init.zeros_(self.emb.bias)

    def forward(self, msa, seq, idx, same_chain):
        """Embed the inputs.

        Inputs:
          - msa: input MSA features (B, N, L, d_init)
          - seq: input sequence tokens (B, L)
          - idx: residue index
          - same_chain: passed to the positional encoding together with idx
        Outputs:
          - msa: initial MSA embedding (B, N, L, d_msa)
          - pair: initial pair embedding (B, L, L, d_pair)
          - state: initial state embedding (B, L, d_state)
        """
        n_seq = msa.shape[1]  # number of sequences in MSA

        # MSA embedding: the query embedding is added to every row
        msa = self.emb(msa)
        query = self.emb_q(seq).unsqueeze(1)  # (B, 1, L, d_msa)
        msa = msa + query.expand(-1, n_seq, -1, -1)

        # pair embedding: outer sum of two sequence embeddings plus relative position
        pair = self.emb_left(seq).unsqueeze(1) + self.emb_right(seq).unsqueeze(2)
        pair = pair + self.pos(idx, same_chain)

        state = self.emb_state(seq)

        return msa, pair, state


class Extra_emb(nn.Module):
    """Initial embedding of the extra MSA."""

    def __init__(self, d_msa=256, d_init=NAATOKENS+1+2, p_drop=0.1):
        super().__init__()
        self.emb = nn.Linear(d_init, d_msa)          # general MSA features
        self.emb_q = nn.Embedding(NAATOKENS, d_msa)  # query sequence

        self.reset_parameter()

    def reset_parameter(self):
        self.emb = init_lecun_normal(self.emb)
        nn.init.zeros_(self.emb.bias)

    def forward(self, msa, seq, idx):
        """Embed the extra MSA.

        Inputs:
          - msa: input MSA features (B, N, L, d_init)
          - seq: input sequence tokens (B, L)
          - idx: residue index (not used here)
        Outputs:
          - msa: initial MSA embedding (B, N, L, d_msa)
        """
        n_seq = msa.shape[1]  # number of sequences in MSA
        query = self.emb_q(seq).unsqueeze(1)  # (B, 1, L, d_msa)
        return self.emb(msa) + query.expand(-1, n_seq, -1, -1)


# TODO: Update template embedding not to use triangles....
# Use input xyz_t with biased attention
class TemplatePairStack(nn.Module):
    """Stack of PairStr2Pair blocks that processes template pair features."""

    def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=16, p_drop=0.25):
        super().__init__()
        self.n_block = n_block
        self.block = nn.ModuleList([
            PairStr2Pair(d_pair=d_templ, n_head=n_head, d_hidden=d_hidden, p_drop=p_drop)
            for _ in range(n_block)
        ])
        self.norm = nn.LayerNorm(d_templ)

    def forward(self, templ, rbf_feat, use_checkpoint=False):
        """templ: (B, T, L, L, d_templ) -> layer-normalized features of the same shape."""
        B, T, L = templ.shape[:3]
        # templates are folded into the batch dimension while the blocks run
        templ = templ.reshape(B*T, L, L, -1)

        for i_block in range(self.n_block):
            layer = self.block[i_block]
            if use_checkpoint:
                templ = checkpoint.checkpoint(create_custom_forward(layer), templ, rbf_feat)
            else:
                templ = layer(templ, rbf_feat)

        return self.norm(templ).reshape(B, T, L, L, -1)


class Templ_emb(nn.Module):
    """Mix template information into the query pair and state features.

    Template features are
      t2d: 37 distogram bins + 6 orientations (43), mask (missing/unaligned) (1)
      t1d: tiled sequence, confidence (1)
    """

    def __init__(self, d_t1d=(NAATOKENS-1)+1, d_t2d=43+1, d_tor=3*NTOTALDOFS, d_pair=128, d_state=32,
                 n_block=2, d_templ=64,
                 n_head=4, d_hidden=16, p_drop=0.25):
        super().__init__()
        # process 2D features
        self.emb = nn.Linear(d_t1d*2+d_t2d, d_templ)
        self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
                                             d_hidden=d_hidden, p_drop=p_drop)

        self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop)

        # process torsion angles
        self.proj_t1d = nn.Linear(d_t1d+d_tor, d_templ)
        self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state, p_drop=p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        self.emb = init_lecun_normal(self.emb)
        nn.init.zeros_(self.emb.bias)

        nn.init.kaiming_normal_(self.proj_t1d.weight, nonlinearity='relu')
        nn.init.zeros_(self.proj_t1d.bias)

    def _get_templ_emb(self, t1d, t2d):
        """Tile the 1D features along both residue axes, join with t2d and embed -> (B, T, L, L, d_templ)."""
        B, T, L, _ = t1d.shape
        tiled_i = t1d.unsqueeze(3).expand(-1, -1, -1, L, -1)
        tiled_j = t1d.unsqueeze(2).expand(-1, -1, L, -1, -1)
        return self.emb(torch.cat((t2d, tiled_i, tiled_j), -1))

    def _get_templ_rbf(self, xyz_t, mask_t):
        """Masked radial-basis encoding of template pairwise distances -> (B*T, L, L, d_rbf)."""
        B, T, L = xyz_t.shape[:3]

        # every template is handled as its own batch entry
        xyz_t = xyz_t.reshape(B*T, L, 3).contiguous()
        mask_t = mask_t.reshape(B*T, L, L)
        assert(xyz_t.is_contiguous())
        return rbf(torch.cdist(xyz_t, xyz_t)) * mask_t[..., None]

    @staticmethod
    def _attend(attn, query, templ, use_checkpoint):
        """Run `attn` with the templates as keys and values, optionally under gradient checkpointing."""
        if use_checkpoint:
            return checkpoint.checkpoint(create_custom_forward(attn), query, templ, templ)
        return attn(query, templ, templ)

    def forward(self, t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=False):
        """Update query pair / state features with template information.

        Input
          - t1d: 1D template info (B, T, L, d_t1d)
          - t2d: 2D template info (B, T, L, L, d_t2d)
          - alpha_t: torsion angle info (B, T, L, d_tor)
          - xyz_t: template coordinates (B, T, L, 3)
          - mask_t: is valid residue pair? (B, T, L, L)
          - pair: query pair features (B, L, L, d_pair)
          - state: query state features (B, L, d_state)
        Output
          - pair, state with the same shapes as the inputs
        """
        B, T, L, _ = t1d.shape

        # process each template pair feature
        templ = self._get_templ_emb(t1d, t2d)
        rbf_feat = self._get_templ_rbf(xyz_t, mask_t)
        templ = self.templ_stack(templ, rbf_feat, use_checkpoint=use_checkpoint)  # (B, T, L, L, d_templ)

        # 1D template features joined with torsion angles
        t1d = self.proj_t1d(torch.cat((t1d, alpha_t), dim=-1))

        # mix template 1D features into the query state: one query per residue over T templates
        state = state.reshape(B*L, 1, -1)
        t1d = t1d.permute(0, 2, 1, 3).reshape(B*L, T, -1)
        out = self._attend(self.attn_tor, state, t1d, use_checkpoint).reshape(B, L, -1)
        state = state.reshape(B, L, -1) + out

        # mix template pair features into the query pair (template pointwise attention)
        pair = pair.reshape(B*L*L, 1, -1)
        templ = templ.permute(0, 2, 3, 1, 4).reshape(B*L*L, T, -1)
        out = self._attend(self.attn, pair, templ, use_checkpoint).reshape(B, L, L, -1)
        pair = pair.reshape(B, L, L, -1) + out

        return pair, state


class Recycling(nn.Module):
    """Feed the outputs of the previous recycle back into the msa / pair / state inputs."""

    def __init__(self, d_msa=256, d_pair=128, d_state_in=32, d_state_out=32, rbf_sigma=1.0):
        super().__init__()
        # input: rbf distance features (64 channels assumed from the layer size) + state of i and j
        self.proj_dist = nn.Linear(64+d_state_in*2, d_pair)
        self.norm_pair = nn.LayerNorm(d_pair)
        self.proj_sctors = nn.Linear(2*NTOTALDOFS, d_msa)
        self.norm_msa = nn.LayerNorm(d_msa)
        self.rbf_sigma = rbf_sigma
        self.norm_state = nn.LayerNorm(d_state_in)

        # a state projection is only needed when the width changes
        self.proj_state = nn.Linear(d_state_in, d_state_out) if d_state_in != d_state_out else None

        self.reset_parameter()

    def reset_parameter(self):
        self.proj_dist = init_lecun_normal(self.proj_dist)
        nn.init.zeros_(self.proj_dist.bias)
        self.proj_sctors = init_lecun_normal(self.proj_sctors)
        nn.init.zeros_(self.proj_sctors.bias)
        if self.proj_state is not None:
            self.proj_state = init_lecun_normal(self.proj_state)
            nn.init.zeros_(self.proj_state.bias)

    def forward(self, msa, pair, xyz, state, sctors):
        """Normalize the recycled features and add structure-derived terms.

        The pair update is built from distances between atom index 1 of each
        residue in `xyz`; the msa update is a projection of `sctors`.
        Returns (msa, pair, state).
        """
        B, L = pair.shape[:2]
        state = self.norm_state(state)
        msa = self.norm_msa(msa)
        pair = self.norm_pair(pair)

        state_i = state.unsqueeze(2).expand(-1, -1, L, -1)
        state_j = state.unsqueeze(1).expand(-1, L, -1, -1)

        Ca_or_P = xyz[:, :, 1].contiguous()

        dist = rbf(torch.cdist(Ca_or_P, Ca_or_P), self.rbf_sigma)
        dist = self.proj_dist(torch.cat((dist, state_i, state_j), dim=-1))
        pair = pair + dist

        sctors = self.proj_sctors(sctors.reshape(B, -1, 2*NTOTALDOFS))
        msa = sctors + msa

        if self.proj_state is not None:
            state = self.proj_state(state)

        return msa, pair, state
n_main_block=8, n_ref_block=4,\ 13 | d_msa=256, d_msa_full=64, d_pair=128, d_templ=64, 14 | n_head_msa=8, n_head_pair=4, n_head_templ=4, 15 | d_hidden=32, d_hidden_templ=64, 16 | p_drop=0.15, 17 | SE3_param_full={}, SE3_param_topk={}, 18 | aamask=None, ljlk_parameters=None, lj_correction_parameters=None, num_bonds=None, lj_lin=0.6, 19 | hbtypes=None, hbbaseatoms=None, hbpolys=None 20 | ): 21 | super(RoseTTAFoldModule, self).__init__() 22 | # 23 | # Input Embeddings 24 | d_state = SE3_param_topk['l0_out_features'] 25 | self.latent_emb = MSA_emb(d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop) 26 | self.full_emb = Extra_emb(d_msa=d_msa_full, d_init=NAATOKENS-1+4, p_drop=p_drop) 27 | self.templ_emb = Templ_emb(d_pair=d_pair, d_templ=d_templ, d_state=d_state, 28 | n_head=n_head_templ, 29 | d_hidden=d_hidden_templ, p_drop=0.25) 30 | # Update inputs with outputs from previous round 31 | self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state_in=d_state, d_state_out=d_state) 32 | 33 | # 34 | self.simulator = IterativeSimulator( 35 | n_extra_block=n_extra_block, 36 | n_main_block=n_main_block, 37 | n_ref_block=n_ref_block, 38 | d_msa=d_msa, d_msa_full=d_msa_full, 39 | d_pair=d_pair, d_hidden=d_hidden, 40 | n_head_msa=n_head_msa, 41 | n_head_pair=n_head_pair, 42 | SE3_param_full=SE3_param_full, 43 | SE3_param_topk=SE3_param_topk, 44 | p_drop=p_drop, 45 | aamask=aamask, 46 | ljlk_parameters=ljlk_parameters, 47 | lj_correction_parameters=lj_correction_parameters, 48 | num_bonds=num_bonds, 49 | lj_lin=lj_lin, 50 | hbtypes=hbtypes, 51 | hbbaseatoms=hbbaseatoms, 52 | hbpolys=hbpolys 53 | ) 54 | 55 | ## 56 | self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop) 57 | self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop) 58 | self.lddt_pred = LDDTNetwork(d_state) 59 | self.pae_pred = PAENetwork(d_pair+2*d_state) 60 | self.bind_pred = BinderNetwork() #fd - expose n_hidden as variable? 
def forward(
    self, msa_latent, msa_full, seq, seq_unmasked, xyz, sctors, idx,
    t1d=None, t2d=None, xyz_t=None, alpha_t=None, mask_t=None, same_chain=None,
    msa_prev=None, pair_prev=None, state_prev=None,
    return_raw=False, return_full=False,
    use_checkpoint=False
):
    """Embed the inputs, mix in recycled/template features and run the simulator.

    Parameters
    ----------
    msa_latent, msa_full, seq, idx, same_chain
        Forwarded to ``self.latent_emb`` / ``self.full_emb``.  The first three
        dimensions of ``msa_latent`` are read as ``(B, N, L)``.
    seq_unmasked
        Forwarded to ``self.simulator``.
    xyz, sctors
        Passed to ``self.recycle``; only ``xyz[:, :, :3]`` (first three atoms
        of each residue) is handed to ``self.simulator``.
    t1d, t2d, xyz_t, alpha_t, mask_t
        Template inputs, forwarded unchanged to ``self.templ_emb``.
    msa_prev, pair_prev, state_prev
        Features recycled from a previous pass.  ``msa_prev=None`` marks the
        first pass: all three are replaced by zeros.  If ``msa_prev`` is given
        but ``pair_prev`` or ``state_prev`` is ``None``, the missing one is
        replaced by zeros instead of being passed on as ``None``.
    return_raw
        If true, skip the auxiliary prediction heads (see Returns).
    return_full
        Accepted for interface compatibility; not read by this method.
    use_checkpoint
        Forwarded to ``self.templ_emb`` and ``self.simulator``.

    Returns
    -------
    tuple
        With ``return_raw``: ``(msa[:, 0], pair, xyz[-1], state, alpha[-1])``.
        Otherwise: ``(logits, logits_aa, logits_pae, p_bind, xyz, alpha,
        xyzallatom, lddt, msa[:, 0], pair, state)``.
    """
    B, N, L = msa_latent.shape[:3]

    # Get embeddings
    msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx, same_chain)
    msa_full = self.full_emb(msa_full, seq, idx)

    # Do recycling.
    # `is None` instead of `== None`: identity test, never dispatched to
    # Tensor.__eq__.
    if msa_prev is None:
        # First pass: every recycled input starts from zeros (any pair/state
        # passed alongside a missing msa_prev is ignored, as before).
        msa_prev = torch.zeros_like(msa_latent[:, 0])
        pair_prev = None
        state_prev = None
    if pair_prev is None:
        pair_prev = torch.zeros_like(pair)
    if state_prev is None:
        state_prev = torch.zeros_like(state)

    msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors)
    # Only the query row (index 0 along the N axis) receives the recycled MSA.
    msa_latent[:, 0] = msa_latent[:, 0] + msa_recycle.reshape(B, L, -1)
    pair = pair + pair_recycle
    state = state + state_recycle

    # add template embedding
    pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=use_checkpoint)

    # Predict coordinates from given inputs
    msa, pair, xyz, alpha, xyzallatom, state = self.simulator(
        seq_unmasked, msa_latent, msa_full, pair, xyz[:, :, :3], state, idx, same_chain,
        use_checkpoint=use_checkpoint)

    if return_raw:
        # get last structure
        return msa[:, 0], pair, xyz[-1], state, alpha[-1]

    # predict masked amino acids
    logits_aa = self.aa_pred(msa)

    # predict distogram & orientograms
    logits = self.c6d_pred(pair)

    # Predict LDDT
    lddt = self.lddt_pred(state)

    # predict PAE
    logits_pae = self.pae_pred(pair, state)

    # predict bind/no-bind
    p_bind = self.bind_pred(logits_pae, same_chain)

    return logits, logits_aa, logits_pae, p_bind, xyz, alpha, xyzallatom, lddt, msa[:, 0], pair, state
/network/SE3_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | #from equivariant_attention.modules import get_basis_and_r, GSE3Res, GNormBias 5 | #from equivariant_attention.modules import GConvSE3, GNormSE3 6 | #from equivariant_attention.fibers import Fiber 7 | 8 | from util_module import init_lecun_normal_param 9 | from se3_transformer.model import SE3Transformer 10 | from se3_transformer.model.fiber import Fiber 11 | 12 | class SE3TransformerWrapper(nn.Module): 13 | """SE(3) equivariant GCN with attention""" 14 | def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4, 15 | l0_in_features=32, l0_out_features=32, 16 | l1_in_features=3, l1_out_features=2, 17 | num_edge_features=32): 18 | super().__init__() 19 | # Build the network 20 | self.l1_in = l1_in_features 21 | self.l1_out = l1_out_features 22 | # 23 | fiber_edge = Fiber({0: num_edge_features}) 24 | if l1_out_features > 0: 25 | if l1_in_features > 0: 26 | fiber_in = Fiber({0: l0_in_features, 1: l1_in_features}) 27 | fiber_hidden = Fiber.create(num_degrees, num_channels) 28 | fiber_out = Fiber({0: l0_out_features, 1: l1_out_features}) 29 | else: 30 | fiber_in = Fiber({0: l0_in_features}) 31 | fiber_hidden = Fiber.create(num_degrees, num_channels) 32 | fiber_out = Fiber({0: l0_out_features, 1: l1_out_features}) 33 | else: 34 | if l1_in_features > 0: 35 | fiber_in = Fiber({0: l0_in_features, 1: l1_in_features}) 36 | fiber_hidden = Fiber.create(num_degrees, num_channels) 37 | fiber_out = Fiber({0: l0_out_features}) 38 | else: 39 | fiber_in = Fiber({0: l0_in_features}) 40 | fiber_hidden = Fiber.create(num_degrees, num_channels) 41 | fiber_out = Fiber({0: l0_out_features}) 42 | 43 | self.se3 = SE3Transformer(num_layers=num_layers, 44 | fiber_in=fiber_in, 45 | fiber_hidden=fiber_hidden, 46 | fiber_out = fiber_out, 47 | num_heads=n_heads, 48 | channels_div=div, 49 | fiber_edge=fiber_edge, 50 | 
#populate_edges=False, 51 | #sum_over_edge=False, 52 | use_layer_norm=True, 53 | tensor_cores=True, 54 | low_memory=True)#, 55 | #populate_edge='log') 56 | 57 | self.reset_parameter() 58 | 59 | def reset_parameter(self): 60 | 61 | # make sure linear layer before ReLu are initialized with kaiming_normal_ 62 | for n, p in self.se3.named_parameters(): 63 | if "bias" in n: 64 | nn.init.zeros_(p) 65 | elif len(p.shape) == 1: 66 | continue 67 | else: 68 | if "radial_func" not in n: 69 | p = init_lecun_normal_param(p) 70 | else: 71 | if "net.6" in n: 72 | nn.init.zeros_(p) 73 | else: 74 | nn.init.kaiming_normal_(p, nonlinearity='relu') 75 | 76 | # make last layers to be zero-initialized 77 | self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0']) 78 | self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1']) 79 | nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0']) 80 | if self.l1_out > 0: 81 | nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1']) 82 | 83 | def forward(self, G, type_0_features, type_1_features=None, edge_features=None): 84 | if self.l1_in > 0: 85 | node_features = {'0': type_0_features, '1': type_1_features} 86 | else: 87 | node_features = {'0': type_0_features} 88 | edge_features = {'0': edge_features} 89 | return self.se3(G, node_features, edge_features) 90 | -------------------------------------------------------------------------------- /network/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import data_loader 3 | import os 4 | 5 | TRUNK_PARAMS = ['n_extra_block', 'n_main_block', 'n_ref_block',\ 6 | 'd_msa', 'd_msa_full', 'd_pair', 'd_templ',\ 7 | 'n_head_msa', 'n_head_pair', 'n_head_templ', 'd_hidden', 'd_hidden_templ', 'p_drop'] 8 | 9 | SE3_PARAMS = ['num_layers', 'num_channels', 'num_degrees', 'n_heads', 'div', 10 | 'l0_in_features', 
'l0_out_features', 'l1_in_features', 'l1_out_features', 'num_edge_features' 11 | ] 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser() 15 | 16 | # training parameters 17 | train_group = parser.add_argument_group("training parameters") 18 | train_group.add_argument("-model_name", default=None, 19 | help="model name for saving") 20 | train_group.add_argument('-batch_size', type=int, default=1, 21 | help="Batch size [1]") 22 | train_group.add_argument('-lr', type=float, default=2.0e-4, 23 | help="Learning rate [5.0e-4]") 24 | train_group.add_argument('-num_epochs', type=int, default=300, 25 | help="Number of epochs [300]") 26 | train_group.add_argument("-step_lr", type=int, default=300, 27 | help="Parameter for Step LR scheduler [300]") 28 | train_group.add_argument("-port", type=int, default=12319, 29 | help="PORT for ddp training, should be randomized [12319]") 30 | train_group.add_argument("-accum", type=int, default=1, 31 | help="Gradient accumulation when it's > 1 [1]") 32 | train_group.add_argument("-eval", action='store_true', default=False, 33 | help="Train structure only") 34 | 35 | # data-loading parameters 36 | data_group = parser.add_argument_group("data loading parameters") 37 | data_group.add_argument('-maxseq', type=int, default=1024, 38 | help="Maximum depth of subsampled MSA [1024]") 39 | data_group.add_argument('-maxtoken', type=int, default=2**18, 40 | help="Maximum depth of subsampled MSA [2**18]") 41 | data_group.add_argument('-maxlat', type=int, default=128, 42 | help="Maximum depth of subsampled MSA [128]") 43 | data_group.add_argument("-crop", type=int, default=260, 44 | help="Upper limit of crop size [260]") 45 | data_group.add_argument("-rescut", type=float, default=4.5, 46 | help="Resolution cutoff [4.5]") 47 | data_group.add_argument("-slice", type=str, default="DISCONT", 48 | help="How to make crops [CONT / DISCONT (default)]") 49 | data_group.add_argument("-subsmp", type=str, default="UNI", 50 | help="How to subsample MSAs 
[UNI (default) / LOG / CONST]") 51 | data_group.add_argument('-mintplt', type=int, default=1, 52 | help="Minimum number of templates to select [1]") 53 | data_group.add_argument('-maxtplt', type=int, default=4, 54 | help="maximum number of templates to select [4]") 55 | data_group.add_argument('-seqid', type=float, default=150.0, 56 | help="maximum sequence identity cutoff for template selection [150.0]") 57 | data_group.add_argument('-maxcycle', type=int, default=4, 58 | help="maximum number of recycle [4]") 59 | 60 | # Trunk module properties 61 | trunk_group = parser.add_argument_group("Trunk module parameters") 62 | trunk_group.add_argument('-n_extra_block', type=int, default=4, 63 | help="Number of iteration blocks for extra sequences [4]") 64 | trunk_group.add_argument('-n_main_block', type=int, default=8, 65 | help="Number of iteration blocks for main sequences [8]") 66 | trunk_group.add_argument('-n_ref_block', type=int, default=4, 67 | help="Number of refinement layers") 68 | trunk_group.add_argument('-d_msa', type=int, default=256, 69 | help="Number of MSA features [256]") 70 | trunk_group.add_argument('-d_msa_full', type=int, default=64, 71 | help="Number of MSA features [64]") 72 | trunk_group.add_argument('-d_pair', type=int, default=128, 73 | help="Number of pair features [128]") 74 | trunk_group.add_argument('-d_templ', type=int, default=64, 75 | help="Number of templ features [64]") 76 | trunk_group.add_argument('-n_head_msa', type=int, default=8, 77 | help="Number of attention heads for MSA2MSA [8]") 78 | trunk_group.add_argument('-n_head_pair', type=int, default=4, 79 | help="Number of attention heads for Pair2Pair [4]") 80 | trunk_group.add_argument('-n_head_templ', type=int, default=4, 81 | help="Number of attention heads for template [4]") 82 | trunk_group.add_argument("-d_hidden", type=int, default=32, 83 | help="Number of hidden features [32]") 84 | trunk_group.add_argument("-d_hidden_templ", type=int, default=64, 85 | help="Number of hidden 
features for templates [64]") 86 | trunk_group.add_argument("-p_drop", type=float, default=0.15, 87 | help="Dropout ratio [0.15]") 88 | 89 | # Structure module properties 90 | str_group = parser.add_argument_group("structure module parameters") 91 | str_group.add_argument('-num_layers', type=int, default=1, 92 | help="Number of equivariant layers in structure module block [1]") 93 | str_group.add_argument('-num_channels', type=int, default=32, 94 | help="Number of channels [32]") 95 | str_group.add_argument('-num_degrees', type=int, default=2, 96 | help="Number of degrees for SE(3) network [2]") 97 | str_group.add_argument('-l0_in_features', type=int, default=64, 98 | help="Number of type 0 input features [64]") 99 | str_group.add_argument('-l0_out_features', type=int, default=64, 100 | help="Number of type 0 output features [64]") 101 | str_group.add_argument('-l1_in_features', type=int, default=3, 102 | help="Number of type 1 input features [3]") 103 | str_group.add_argument('-l1_out_features', type=int, default=2, 104 | help="Number of type 1 output features [2]") 105 | str_group.add_argument('-num_edge_features', type=int, default=64, 106 | help="Number of edge features [64]") 107 | str_group.add_argument('-n_heads', type=int, default=4, 108 | help="Number of attention heads for SE3-Transformer [4]") 109 | str_group.add_argument("-div", type=int, default=4, 110 | help="Div parameter for SE3-Transformer [4]") 111 | str_group.add_argument('-ref_num_layers', type=int, default=2, 112 | help="Number of equivariant layers in structure module block [2]") 113 | str_group.add_argument('-ref_num_channels', type=int, default=32, 114 | help="Number of channels [32]") 115 | str_group.add_argument('-ref_l0_in_features', type=int, default=64, 116 | help="Number of channels [64]") 117 | str_group.add_argument('-ref_l0_out_features', type=int, default=64, 118 | help="Number of channels [64]") 119 | 120 | # Loss function parameters 121 | loss_group = 
parser.add_argument_group("loss parameters") 122 | loss_group.add_argument('-w_dist', type=float, default=1.0, 123 | help="Weight on distd in loss function [1.0]") 124 | loss_group.add_argument('-w_str', type=float, default=10.0, 125 | help="Weight on strd in loss function [10.0]") 126 | loss_group.add_argument('-w_lddt', type=float, default=0.1, 127 | help="Weight on predicted lddt loss [0.1]") 128 | loss_group.add_argument('-w_aa', type=float, default=3.0, 129 | help="Weight on MSA masked token prediction loss [3.0]") 130 | loss_group.add_argument('-w_bond', type=float, default=0.0, 131 | help="Weight on predicted bond loss [0.0]") 132 | loss_group.add_argument('-w_dih', type=float, default=0.0, 133 | help="Weight on pseudodihedral loss [0.0]") 134 | loss_group.add_argument('-w_clash', type=float, default=0.0, 135 | help="Weight on clash loss [0.0]") 136 | loss_group.add_argument('-w_hb', type=float, default=0.0, 137 | help="Weight on clash loss [0.0]") 138 | loss_group.add_argument('-w_pae', type=float, default=0.1, 139 | help="Weight on pae loss [0.1]") 140 | loss_group.add_argument('-w_bind', type=float, default=5.0, 141 | help="Weight on bind v no-bind prediction [5.0]") 142 | loss_group.add_argument('-lj_lin', type=float, default=0.75, 143 | help="linear inflection for lj [0.75]") 144 | 145 | # parse arguments 146 | args = parser.parse_args() 147 | 148 | # Setup dataloader parameters: 149 | loader_param = data_loader.set_data_loader_params(args) 150 | 151 | # make dictionary for each parameters 152 | trunk_param = {} 153 | for param in TRUNK_PARAMS: 154 | trunk_param[param] = getattr(args, param) 155 | SE3_param = {} 156 | for param in SE3_PARAMS: 157 | if hasattr(args, param): 158 | SE3_param[param] = getattr(args, param) 159 | 160 | SE3_ref_param = SE3_param.copy() 161 | 162 | for param in SE3_PARAMS: 163 | if hasattr(args, 'ref_'+param): 164 | SE3_ref_param[param] = getattr(args, 'ref_'+param) 165 | 166 | #print (SE3_param) 167 | #print (SE3_ref_param) 
168 | trunk_param['SE3_param_full'] = SE3_param 169 | trunk_param['SE3_param_topk'] = SE3_ref_param 170 | 171 | loss_param = {} 172 | for param in ['w_dist', 'w_str', 'w_aa', 'w_lddt', 'w_bond', 'w_dih', 'w_clash', 'w_hb', 'w_pae', 'lj_lin']: 173 | loss_param[param] = getattr(args, param) 174 | 175 | return args, trunk_param, loader_param, loss_param 176 | -------------------------------------------------------------------------------- /network/coords6d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import scipy.spatial 4 | from util import generate_Cbeta 5 | 6 | # calculate dihedral angles defined by 4 sets of points 7 | def get_dihedrals(a, b, c, d): 8 | 9 | b0 = -1.0*(b - a) 10 | b1 = c - b 11 | b2 = d - c 12 | 13 | b1 /= np.linalg.norm(b1, axis=-1)[:,None] 14 | 15 | v = b0 - np.sum(b0*b1, axis=-1)[:,None]*b1 16 | w = b2 - np.sum(b2*b1, axis=-1)[:,None]*b1 17 | 18 | x = np.sum(v*w, axis=-1) 19 | y = np.sum(np.cross(b1, v)*w, axis=-1) 20 | 21 | return np.arctan2(y, x) 22 | 23 | # calculate planar angles defined by 3 sets of points 24 | def get_angles(a, b, c): 25 | 26 | v = a - b 27 | v /= np.linalg.norm(v, axis=-1)[:,None] 28 | 29 | w = c - b 30 | w /= np.linalg.norm(w, axis=-1)[:,None] 31 | 32 | x = np.sum(v*w, axis=1) 33 | 34 | #return np.arccos(x) 35 | return np.arccos(np.clip(x, -1.0, 1.0)) 36 | 37 | # get 6d coordinates from x,y,z coords of N,Ca,C atoms 38 | def get_coords6d(xyz, dmax): 39 | 40 | nres = xyz.shape[1] 41 | 42 | # three anchor atoms 43 | N = xyz[0] 44 | Ca = xyz[1] 45 | C = xyz[2] 46 | 47 | # recreate Cb given N,Ca,C 48 | Cb = generate_Cbeta(N,Ca,C) 49 | 50 | # fast neighbors search to collect all 51 | # Cb-Cb pairs within dmax 52 | kdCb = scipy.spatial.cKDTree(Cb) 53 | indices = kdCb.query_ball_tree(kdCb, dmax) 54 | 55 | # indices of contacting residues 56 | idx = np.array([[i,j] for i in range(len(indices)) for j in indices[i] if i != j]).T 57 | idx0 = idx[0] 
58 | idx1 = idx[1] 59 | 60 | # Cb-Cb distance matrix 61 | dist6d = np.full((nres, nres),999.9, dtype=np.float32) 62 | dist6d[idx0,idx1] = np.linalg.norm(Cb[idx1]-Cb[idx0], axis=-1) 63 | 64 | # matrix of Ca-Cb-Cb-Ca dihedrals 65 | omega6d = np.zeros((nres, nres), dtype=np.float32) 66 | omega6d[idx0,idx1] = get_dihedrals(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1]) 67 | 68 | # matrix of polar coord theta 69 | theta6d = np.zeros((nres, nres), dtype=np.float32) 70 | theta6d[idx0,idx1] = get_dihedrals(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1]) 71 | 72 | # matrix of polar coord phi 73 | phi6d = np.zeros((nres, nres), dtype=np.float32) 74 | phi6d[idx0,idx1] = get_angles(Ca[idx0], Cb[idx0], Cb[idx1]) 75 | 76 | mask = np.zeros((nres, nres), dtype=np.float32) 77 | mask[idx0, idx1] = 1.0 78 | return dist6d, omega6d, theta6d, phi6d, mask 79 | -------------------------------------------------------------------------------- /network/ffindex.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # https://raw.githubusercontent.com/ahcm/ffindex/master/python/ffindex.py 3 | 4 | ''' 5 | Created on Apr 30, 2014 6 | 7 | @author: meiermark 8 | ''' 9 | 10 | 11 | import sys 12 | import mmap 13 | from collections import namedtuple 14 | 15 | FFindexEntry = namedtuple("FFindexEntry", "name, offset, length") 16 | 17 | 18 | def read_index(ffindex_filename): 19 | entries = [] 20 | 21 | fh = open(ffindex_filename) 22 | for line in fh: 23 | tokens = line.split("\t") 24 | entries.append(FFindexEntry(tokens[0], int(tokens[1]), int(tokens[2]))) 25 | fh.close() 26 | 27 | return entries 28 | 29 | 30 | def read_data(ffdata_filename): 31 | fh = open(ffdata_filename, "rb") 32 | data = mmap.mmap(fh.fileno(), 0, prot=mmap.PROT_READ) 33 | fh.close() 34 | return data 35 | 36 | 37 | def get_entry_by_name(name, index): 38 | #TODO: bsearch 39 | for entry in index: 40 | if(name == entry.name): 41 | return entry 42 | return None 43 | 44 | 45 | def 
read_entry_lines(entry, data): 46 | lines = data[entry.offset:entry.offset + entry.length - 1].decode("utf-8").split("\n") 47 | return lines 48 | 49 | 50 | def read_entry_data(entry, data): 51 | return data[entry.offset:entry.offset + entry.length - 1] 52 | 53 | 54 | def write_entry(entries, data_fh, entry_name, offset, data): 55 | data_fh.write(data[:-1]) 56 | data_fh.write(bytearray(1)) 57 | 58 | entry = FFindexEntry(entry_name, offset, len(data)) 59 | entries.append(entry) 60 | 61 | return offset + len(data) 62 | 63 | 64 | def write_entry_with_file(entries, data_fh, entry_name, offset, file_name): 65 | with open(file_name, "rb") as fh: 66 | data = bytearray(fh.read()) 67 | return write_entry(entries, data_fh, entry_name, offset, data) 68 | 69 | 70 | def finish_db(entries, ffindex_filename, data_fh): 71 | data_fh.close() 72 | write_entries_to_db(entries, ffindex_filename) 73 | 74 | 75 | def write_entries_to_db(entries, ffindex_filename): 76 | sorted(entries, key=lambda x: x.name) 77 | index_fh = open(ffindex_filename, "w") 78 | 79 | for entry in entries: 80 | index_fh.write("{name:.64}\t{offset}\t{length}\n".format(name=entry.name, offset=entry.offset, length=entry.length)) 81 | 82 | index_fh.close() 83 | 84 | 85 | def write_entry_to_file(entry, data, file): 86 | lines = read_lines(entry, data) 87 | 88 | fh = open(file, "w") 89 | for line in lines: 90 | fh.write(line+"\n") 91 | fh.close() 92 | -------------------------------------------------------------------------------- /network/kinematics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from util import INIT_CRDS, INIT_NA_CRDS, generate_Cbeta, is_nucleic 4 | from chemical import NTOTAL 5 | 6 | PARAMS = { 7 | "DMIN" : 2.0, 8 | "DMAX" : 20.0, 9 | "DBINS" : 36, 10 | "ABINS" : 36, 11 | } 12 | 13 | # ============================================================ 14 | def get_pair_dist(a, b): 15 | """calculate pair distances between two 
sets of points 16 | 17 | Parameters 18 | ---------- 19 | a,b : pytorch tensors of shape [batch,nres,3] 20 | store Cartesian coordinates of two sets of atoms 21 | Returns 22 | ------- 23 | dist : pytorch tensor of shape [batch,nres,nres] 24 | stores paitwise distances between atoms in a and b 25 | """ 26 | 27 | assert(a.is_contiguous()) 28 | assert(b.is_contiguous()) 29 | dist = torch.cdist(a, b, p=2) 30 | return dist 31 | 32 | # ============================================================ 33 | def get_ang(a, b, c, eps=1e-6): 34 | """calculate planar angles for all consecutive triples (a[i],b[i],c[i]) 35 | from Cartesian coordinates of three sets of atoms a,b,c 36 | 37 | Parameters 38 | ---------- 39 | a,b,c : pytorch tensors of shape [batch,nres,3] 40 | store Cartesian coordinates of three sets of atoms 41 | Returns 42 | ------- 43 | ang : pytorch tensor of shape [batch,nres] 44 | stores resulting planar angles 45 | """ 46 | v = a - b 47 | w = c - b 48 | vn = v / (torch.norm(v, dim=-1, keepdim=True)+eps) 49 | wn = w / (torch.norm(w, dim=-1, keepdim=True)+eps) 50 | vw = torch.sum(vn*wn, dim=-1) 51 | 52 | return torch.acos(torch.clamp(vw,-0.999,0.999)) 53 | 54 | # ============================================================ 55 | def get_dih(a, b, c, d, eps=1e-6): 56 | """calculate dihedral angles for all consecutive quadruples (a[i],b[i],c[i],d[i]) 57 | given Cartesian coordinates of four sets of atoms a,b,c,d 58 | 59 | Parameters 60 | ---------- 61 | a,b,c,d : pytorch tensors of shape [batch,nres,3] 62 | store Cartesian coordinates of four sets of atoms 63 | Returns 64 | ------- 65 | dih : pytorch tensor of shape [batch,nres] 66 | stores resulting dihedrals 67 | """ 68 | b0 = a - b 69 | b1 = c - b 70 | b2 = d - c 71 | 72 | b1n = b1 / (torch.norm(b1, dim=-1, keepdim=True) + eps) 73 | 74 | v = b0 - torch.sum(b0*b1n, dim=-1, keepdim=True)*b1n 75 | w = b2 - torch.sum(b2*b1n, dim=-1, keepdim=True)*b1n 76 | 77 | x = torch.sum(v*w, dim=-1) 78 | y = 
torch.sum(torch.cross(b1n,v,dim=-1)*w, dim=-1) 79 | 80 | return torch.atan2(y+eps, x+eps) 81 | 82 | 83 | # ============================================================ 84 | def xyz_to_c6d(xyz, params=PARAMS): 85 | """convert cartesian coordinates into 2d distance 86 | and orientation maps 87 | 88 | Parameters 89 | ---------- 90 | xyz : pytorch tensor of shape [batch,nres,3,3] 91 | stores Cartesian coordinates of backbone N,Ca,C atoms 92 | Returns 93 | ------- 94 | c6d : pytorch tensor of shape [batch,nres,nres,4] 95 | stores stacked dist,omega,theta,phi 2D maps 96 | """ 97 | 98 | batch = xyz.shape[0] 99 | nres = xyz.shape[1] 100 | 101 | N = xyz[:,:,0] 102 | Ca = xyz[:,:,1] 103 | C = xyz[:,:,2] 104 | Cb = generate_Cbeta(N,Ca,C) # note that this doesn't really make sense for NA 105 | 106 | # 6d coordinates order: (dist,omega,theta,phi) 107 | c6d = torch.zeros([batch,nres,nres,4],dtype=xyz.dtype,device=xyz.device) 108 | 109 | dist = get_pair_dist(Cb,Cb) 110 | c6d[...,0] = dist + 999.9*torch.eye(nres,device=xyz.device)[None,...] 
111 | b,i,j = torch.where(c6d[...,0] < params['DMAX']) 112 | 113 | c6d[b,i,j,torch.full_like(b,1)] = get_dih(Ca[b,i], Cb[b,i], Cb[b,j], Ca[b,j]) 114 | c6d[b,i,j,torch.full_like(b,2)] = get_dih(N[b,i], Ca[b,i], Cb[b,i], Cb[b,j]) 115 | c6d[b,i,j,torch.full_like(b,3)] = get_ang(Ca[b,i], Cb[b,i], Cb[b,j]) 116 | 117 | # fix long-range distances 118 | c6d[...,0][c6d[...,0] >= params['DMAX']] = 999.9 119 | c6d = torch.nan_to_num(c6d) 120 | 121 | return c6d 122 | 123 | def xyz_to_t2d(xyz_t, mask, params=PARAMS): 124 | """convert template cartesian coordinates into 2d distance 125 | and orientation maps 126 | 127 | Parameters 128 | ---------- 129 | xyz_t : pytorch tensor of shape [batch,templ,nres,natm,3] 130 | stores Cartesian coordinates of template backbone N,Ca,C atoms 131 | mask: pytorch tensor of shape [batch,templ,nres,nres] 132 | indicates whether valid residue pairs or not 133 | Returns 134 | ------- 135 | t2d : pytorch tensor of shape [batch,nres,nres,37+6+1] 136 | stores stacked dist,omega,theta,phi 2D maps 137 | """ 138 | B, T, L = xyz_t.shape[:3] 139 | c6d = xyz_to_c6d(xyz_t[:,:,:,:3].view(B*T,L,3,3), params=params) 140 | c6d = c6d.view(B, T, L, L, 4) 141 | 142 | # dist to one-hot encoded 143 | mask = mask[...,None] 144 | dist = dist_to_onehot(c6d[...,0], params)*mask 145 | orien = torch.cat((torch.sin(c6d[...,1:]), torch.cos(c6d[...,1:])), dim=-1)*mask # (B, T, L, L, 6) 146 | # 147 | t2d = torch.cat((dist, orien, mask), dim=-1) 148 | return t2d 149 | 150 | def xyz_to_bbtor(xyz, params=PARAMS): 151 | batch = xyz.shape[0] 152 | nres = xyz.shape[1] 153 | 154 | # three anchor atoms 155 | N = xyz[:,:,0] 156 | Ca = xyz[:,:,1] 157 | C = xyz[:,:,2] 158 | 159 | # recreate Cb given N,Ca,C 160 | next_N = torch.roll(N, -1, dims=1) 161 | prev_C = torch.roll(C, 1, dims=1) 162 | phi = get_dih(prev_C, N, Ca, C) 163 | psi = get_dih(N, Ca, C, next_N) 164 | # 165 | phi[:,0] = 0.0 166 | psi[:,-1] = 0.0 167 | # 168 | astep = 2.0*np.pi / params['ABINS'] 169 | phi_bin = torch.round((phi+np.pi-astep/2)/astep) 170 | psi_bin = torch.round((psi+np.pi-astep/2)/astep) 171 | return torch.stack([phi_bin, psi_bin], axis=-1).long() 172 | 173 | # ============================================================ 174 | def dist_to_onehot(dist, params=PARAMS): 175 | dist[torch.isnan(dist)] = 999.9 176 | dstep = (params['DMAX'] -
params['DMIN']) / params['DBINS'] 177 | dbins = torch.linspace(params['DMIN']+dstep, params['DMAX'], params['DBINS'],dtype=dist.dtype,device=dist.device) 178 | db = torch.bucketize(dist.contiguous(),dbins).long() 179 | dist = torch.nn.functional.one_hot(db, num_classes=params['DBINS']+1).float() 180 | return dist 181 | 182 | # ============================================================ 183 | def dist_to_bins(dist,params=PARAMS): 184 | """bin 2d distance maps 185 | """ 186 | 187 | dstep = (params['DMAX'] - params['DMIN']) / params['DBINS'] 188 | db = torch.round((dist-params['DMIN']-dstep/2)/dstep) 189 | 190 | db[db<0] = 0 191 | db[db>params['DBINS']] = params['DBINS'] 192 | 193 | return db.long() 194 | 195 | 196 | # ============================================================ 197 | def c6d_to_bins(c6d, same_chain, negative=False, params=PARAMS): 198 | """bin 2d distance and orientation maps 199 | """ 200 | 201 | dstep = (params['DMAX'] - params['DMIN']) / params['DBINS'] 202 | astep = 2.0*np.pi / params['ABINS'] 203 | 204 | db = torch.round((c6d[...,0]-params['DMIN']-dstep/2)/dstep) 205 | ob = torch.round((c6d[...,1]+np.pi-astep/2)/astep) 206 | tb = torch.round((c6d[...,2]+np.pi-astep/2)/astep) 207 | pb = torch.round((c6d[...,3]-astep/2)/astep) 208 | 209 | # put all d < dmin into one bin 210 | db[db < 0] = 0 211 | 212 | # synchronize no-contact bins 213 | db[db > params['DBINS']] = params['DBINS'] 214 | ob[db==params['DBINS']] = params['ABINS'] 215 | tb[db==params['DBINS']] = params['ABINS'] 216 | pb[db==params['DBINS']] = params['ABINS']//2 217 | 218 | if negative: 219 | db = torch.where(same_chain.bool(), db.long(), params['DBINS']) 220 | ob = torch.where(same_chain.bool(), ob.long(), params['ABINS']) 221 | tb = torch.where(same_chain.bool(), tb.long(), params['ABINS']) 222 | pb = torch.where(same_chain.bool(), pb.long(), params['ABINS']//2) 223 | 224 | return torch.stack([db,ob,tb,pb],axis=-1).long() 225 | -------------------------------------------------------------------------------- /network/models.json:
-------------------------------------------------------------------------------- 1 | { 2 | "full_bigSE3": 3 | { 4 | "description": "deep architecture w/ big SE(3)-Transformer on fully connected graph. Trained on biounit", 5 | "model_param":{ 6 | "n_extra_block" : 4, 7 | "n_main_block" : 32, 8 | "n_ref_block" : 4, 9 | "d_msa" : 256 , 10 | "d_pair" : 128, 11 | "d_templ" : 64, 12 | "n_head_msa" : 8, 13 | "n_head_pair" : 4, 14 | "n_head_templ" : 4, 15 | "d_hidden" : 32, 16 | "d_hidden_templ" : 64, 17 | "p_drop" : 0.0, 18 | "lj_lin" : 0.75, 19 | "SE3_param": { 20 | "num_layers" : 1, 21 | "num_channels" : 32, 22 | "num_degrees" : 2, 23 | "l0_in_features": 64, 24 | "l0_out_features": 64, 25 | "l1_in_features": 3, 26 | "l1_out_features": 2, 27 | "num_edge_features": 64, 28 | "div": 4, 29 | "n_heads": 4 30 | } 31 | }, 32 | "weight_fn": ["full_bigSE3_model1.pt", "full_bigSE3_model2.pt", "full_bigSE3_model3.pt"] 33 | }, 34 | "full_smallSE3": 35 | { 36 | "description": "deep architecture w/ small SE(3)-Transformer on fully connected graph. 
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint


class ResBlock2D_bottleneck(nn.Module):
    """Pre-activation bottleneck residual block for 2D feature maps.

    The residual branch is: norm/ELU -> 1x1 conv down to n_c//2 channels ->
    norm/ELU -> (dilated) kernel x kernel conv -> norm/ELU -> dropout ->
    1x1 conv back up to n_c channels. Its last convolution is zero-initialised,
    so a freshly constructed block is the identity mapping.

    Args:
        n_c: number of input/output channels.
        kernel: size of the central convolution kernel.
        dilation: dilation of the central convolution.
        p_drop: dropout probability.
    """

    def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15):
        super().__init__()
        padding = self._get_same_padding(kernel, dilation)

        n_b = n_c // 2  # bottleneck channel

        # NOTE: the order (and hence the Sequential indices 0..9) must stay
        # as is, otherwise existing state_dict keys no longer match.
        self.layer = nn.Sequential(
            # pre-activation
            nn.InstanceNorm2d(n_c, affine=True, eps=1e-6),
            nn.ELU(inplace=True),
            # project down to n_b
            nn.Conv2d(n_c, n_b, 1, bias=False),
            nn.InstanceNorm2d(n_b, affine=True, eps=1e-6),
            nn.ELU(inplace=True),
            # convolution
            nn.Conv2d(n_b, n_b, kernel, dilation=dilation, padding=padding, bias=False),
            nn.InstanceNorm2d(n_b, affine=True, eps=1e-6),
            nn.ELU(inplace=True),
            # dropout
            nn.Dropout(p_drop),
            # project up
            nn.Conv2d(n_b, n_c, 1, bias=False),
        )

        self.reset_parameter()

    def reset_parameter(self):
        """Zero-initialise the final layer right before the residual connection."""
        nn.init.zeros_(self.layer[-1].weight)

    def _get_same_padding(self, kernel, dilation):
        """Return the padding that keeps the spatial size for a dilated kernel."""
        # effective kernel extent once the dilation gaps are included
        effective_kernel = kernel + (kernel - 1) * (dilation - 1)
        return (effective_kernel - 1) // 2

    def forward(self, x):
        """Return x plus the output of the residual branch."""
        out = self.layer(x)
        return x + out


class ResidualNetwork(nn.Module):
    """Stack of bottleneck residual blocks with optional 1x1 in/out projections.

    Args:
        n_block: number of residual blocks.
        n_feat_in: number of input channels.
        n_feat_block: number of channels inside the residual blocks.
        n_feat_out: number of output channels.
        dilation: sequence of dilations, cycled over the blocks.
        p_drop: dropout probability used in every block.
    """

    def __init__(self, n_block, n_feat_in, n_feat_block, n_feat_out,
                 dilation=(1, 2, 4, 8), p_drop=0.15):
        super().__init__()

        layer_s = []
        # project to n_feat_block (skipped when the width already matches)
        if n_feat_in != n_feat_block:
            layer_s.append(nn.Conv2d(n_feat_in, n_feat_block, 1, bias=False))

        # add resblocks, cycling through the dilation schedule
        for i_block in range(n_block):
            d = dilation[i_block % len(dilation)]
            layer_s.append(
                ResBlock2D_bottleneck(n_feat_block, kernel=3, dilation=d, p_drop=p_drop)
            )

        # project to n_feat_out (skipped when the width already matches)
        if n_feat_out != n_feat_block:
            layer_s.append(nn.Conv2d(n_feat_block, n_feat_out, 1))

        self.layer = nn.Sequential(*layer_s)

    def forward(self, x):
        """Apply the projections and residual blocks to x."""
        return self.layer(x)
import math
import torch
from torch.optim.lr_scheduler import _LRScheduler, LambdaLR


class CosineAnnealingWarmupRestarts(_LRScheduler):
    """Cosine annealing with linear warmup and (optionally growing) restarts.

    Within a cycle the learning rate rises linearly from min_lr to the
    cycle's max_lr during the warmup steps and then follows half a cosine
    back down to min_lr. After every cycle the max_lr is multiplied by gamma
    and the post-warmup length of the cycle by cycle_mult.

        optimizer (Optimizer): Wrapped optimizer.
        first_cycle_steps (int): First cycle step size.
        cycle_mult(float): Cycle steps magnification. Default: 1.
        max_lr(float): First cycle's max learning rate. Default: 0.1.
        min_lr(float): Min learning rate. Default: 0.001.
        warmup_steps(int): Linear warmup step size. Default: 0.
        gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 first_cycle_steps: int,
                 cycle_mult: float = 1.,
                 max_lr: float = 0.1,
                 min_lr: float = 0.001,
                 warmup_steps: int = 0,
                 gamma: float = 1.,
                 last_epoch: int = -1
                 ):
        assert warmup_steps < first_cycle_steps, \
            "warmup_steps must be smaller than first_cycle_steps"

        self.first_cycle_steps = first_cycle_steps  # first cycle step size
        self.cycle_mult = cycle_mult                # cycle steps magnification
        self.base_max_lr = max_lr                   # first max learning rate
        self.max_lr = max_lr                        # max learning rate in the current cycle
        self.min_lr = min_lr                        # min learning rate
        self.warmup_steps = warmup_steps            # warmup step size
        self.gamma = gamma                          # decrease rate of max learning rate by cycle

        self.cur_cycle_steps = first_cycle_steps    # length of the current cycle
        self.cycle = 0                              # cycle count
        self.step_in_cycle = last_epoch             # position inside the current cycle

        # the base-class constructor performs the initial step() call
        super().__init__(optimizer, last_epoch)

        # set learning rate min_lr
        self.init_lr()

    def init_lr(self):
        """Reset every parameter group (and the base lrs) to min_lr."""
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)
        # keep get_last_lr() in sync with what the optimizer really uses
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def get_lr(self):
        """Return the learning rate of every parameter group for the current step."""
        if self.step_in_cycle == -1:
            return self.base_lrs
        if self.step_in_cycle < self.warmup_steps:
            # linear warmup from base_lr towards max_lr
            frac = self.step_in_cycle / self.warmup_steps
            return [(self.max_lr - base_lr)*frac + base_lr for base_lr in self.base_lrs]
        # cosine decay from max_lr back to base_lr over the rest of the cycle
        progress = (self.step_in_cycle - self.warmup_steps) \
            / (self.cur_cycle_steps - self.warmup_steps)
        scale = (1 + math.cos(math.pi * progress)) / 2
        return [base_lr + (self.max_lr - base_lr) * scale for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance by one step, or jump to `epoch` when it is given."""
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # restart: the next cycle keeps the warmup length, only the
                # annealing part is stretched by cycle_mult
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps)
                                           * self.cycle_mult) + self.warmup_steps
        elif epoch >= self.first_cycle_steps:
            if self.cycle_mult == 1.:
                self.step_in_cycle = epoch % self.first_cycle_steps
                self.cycle = epoch // self.first_cycle_steps
            else:
                # closed form of the geometric series of cycle lengths
                n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1),
                                 self.cycle_mult))
                self.cycle = n
                self.step_in_cycle = epoch - int(self.first_cycle_steps
                                                 * (self.cycle_mult ** n - 1)
                                                 / (self.cycle_mult - 1))
                self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
        else:
            # jumping (back) into the first cycle: the cycle counter has to
            # be reset too, otherwise max_lr stays decayed by gamma**cycle
            self.cycle = 0
            self.cur_cycle_steps = self.first_cycle_steps
            self.step_in_cycle = epoch

        self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
        # the base-class step() normally records this; do it here as well so
        # get_last_lr() works with the overridden step()
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_ratio=0.001, last_epoch=-1):
    """
    Create a schedule with a learning rate that increases linearly from 0 to the initial lr set in the optimizer
    during a warmup period and then decreases linearly, never dropping below ``min_ratio`` times the initial lr.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        min_ratio (:obj:`float`, `optional`, defaults to 0.001):
            Lower bound of the lr multiplier after the warmup phase.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        remaining = float(num_training_steps - current_step)
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        return max(min_ratio, remaining / decay_span)

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_stepwise_decay_schedule_with_warmup(optimizer, num_warmup_steps, num_steps_decay, decay_rate, last_epoch=-1):
    """
    Create a schedule with a learning rate that increases linearly from 0 to the initial lr set in the optimizer
    during a warmup period and is afterwards multiplied by ``decay_rate`` once every ``num_steps_decay`` steps.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_steps_decay (:obj:`int`):
            The number of post-warmup steps between two decays; must be positive.
        decay_rate (:obj:`float`):
            The factor applied to the learning rate at every decay.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        ValueError: if ``num_steps_decay`` is not positive.
    """
    # fail early instead of with a ZeroDivisionError once the warmup is over
    if num_steps_decay <= 0:
        raise ValueError(f"num_steps_decay must be positive, got {num_steps_decay}")

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        num_fades = (current_step-num_warmup_steps)//num_steps_decay
        return (decay_rate**num_fades)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
#!/bin/bash
#
# End-to-end RoseTTAFold2NA driver: builds the per-chain inputs (protein MSA +
# template search, RNA MSA, DNA fasta copies) and runs network/predict.py.
#
# Usage: run_RF2NA.sh <working_dir> [TYPE:]<fasta> [[TYPE:]<fasta> ...]
#   TYPE is one of P, R, D, S (case-insensitive); P is assumed when the
#   prefix is omitted. D adds two and S one to the DNA strand count.

# make the script stop when error (non-true exit code) occurs
set -e

############################################################
# >>> conda initialize >>>
# !! Contents within this block are managed by 'conda init' !!
__conda_setup="$('conda' 'shell.bash' 'hook' 2> /dev/null)"
eval "$__conda_setup"
unset __conda_setup
# <<< conda initialize <<<
############################################################

SCRIPT=$(realpath -s "$0")
PIPEDIR=$(dirname "$SCRIPT")
export PIPEDIR
HHDB="$PIPEDIR/pdb100_2021Mar03/pdb100_2021Mar03"

CPU="8"  # number of CPUs to use
MEM="64" # max memory (in GB)

if [ "$#" -lt 2 ]
then
    echo "Usage: $0 <working_dir> [P|R|D|S:]<fasta> [[P|R|D|S:]<fasta> ...]" >&2
    exit 1
fi

WDIR=$(realpath -s "$1") # working folder
mkdir -p "$WDIR/log"

conda activate RF2NA

# process protein (MSA + homology search)
function proteinMSA {
    local seqfile=$1
    local tag=$2

    ############################################################
    # generate MSAs
    ############################################################
    if [ ! -s "$WDIR/$tag.msa0.a3m" ]
    then
        echo "Running HHblits"
        echo " -> Running command: $PIPEDIR/input_prep/make_protein_msa.sh $seqfile $WDIR $tag $CPU $MEM"
        "$PIPEDIR/input_prep/make_protein_msa.sh" "$seqfile" "$WDIR" "$tag" "$CPU" "$MEM" \
            > "$WDIR/log/make_msa.$tag.stdout" 2> "$WDIR/log/make_msa.$tag.stderr"
    fi

    ############################################################
    # search for templates
    ############################################################
    if [ ! -s "$WDIR/$tag.hhr" ]
    then
        echo "Running hhsearch"
        # kept as an array so that paths containing spaces survive expansion
        local hh=(hhsearch -b 50 -B 500 -z 50 -Z 500 -mact 0.05 -cpu "$CPU" -maxmem "$MEM"
                  -aliw 100000 -e 100 -p 5.0 -d "$HHDB")
        # the message shows the same input file that the command really reads
        echo " -> Running command: ${hh[*]} -i $WDIR/$tag.msa0.a3m -o $WDIR/$tag.hhr -atab $WDIR/$tag.atab -v 0"
        "${hh[@]}" -i "$WDIR/$tag.msa0.a3m" -o "$WDIR/$tag.hhr" -atab "$WDIR/$tag.atab" -v 0 \
            > "$WDIR/log/hhsearch.$tag.stdout" 2> "$WDIR/log/hhsearch.$tag.stderr"
    fi
}

# process RNA (MSA)
function RNAMSA {
    local seqfile=$1
    local tag=$2

    ############################################################
    # generate MSAs
    ############################################################
    if [ ! -s "$WDIR/$tag.afa" ]
    then
        echo "Running rMSA (lite)"
        echo " -> Running command: $PIPEDIR/input_prep/make_rna_msa.sh $seqfile $WDIR $tag $CPU $MEM"
        "$PIPEDIR/input_prep/make_rna_msa.sh" "$seqfile" "$WDIR" "$tag" "$CPU" "$MEM" \
            > "$WDIR/log/make_msa.$tag.stdout" 2> "$WDIR/log/make_msa.$tag.stderr"
    fi
}

# copy a fasta file into the working folder; skipped when source and
# destination are the same file (cp would fail and abort under set -e)
function stageFasta {
    local src=$1
    local dst=$2

    if [ ! "$src" -ef "$dst" ]
    then
        cp "$src" "$dst"
    fi
}

# one element per chain, handed to predict.py after -inputs
inputs=()

shift
nP=0
nR=0
nD=0
for i in "$@"
do
    # split "TYPE:path"; arguments without a prefix are proteins
    if [[ "$i" == *:* ]]
    then
        kind="${i%%:*}"
        fasta="${i#*:}"
    else
        kind="P"
        fasta="$i"
    fi
    kind=${kind^^}
    tag=$(basename "$fasta" | sed -E 's/\.fasta$|\.fas$|\.fa$//')

    case "$kind" in
        P)
            proteinMSA "$fasta" "$tag"
            inputs+=("P:$WDIR/$tag.msa0.a3m:$WDIR/$tag.hhr:$WDIR/$tag.atab")
            nP=$((nP+1))
            lastP="$tag"
            ;;
        R)
            RNAMSA "$fasta" "$tag"
            inputs+=("R:$WDIR/$tag.afa")
            nR=$((nR+1))
            lastR="$tag"
            ;;
        D)
            stageFasta "$fasta" "$WDIR/$tag.fa"
            inputs+=("D:$WDIR/$tag.fa")
            nD=$((nD+2))
            ;;
        S)
            stageFasta "$fasta" "$WDIR/$tag.fa"
            inputs+=("S:$WDIR/$tag.fa")
            nD=$((nD+1))
            ;;
        *)
            # such arguments are skipped; say so instead of staying silent
            echo "WARNING: unrecognized input type '$kind' in '$i', skipping" >&2
            ;;
    esac
done

if [ "${#inputs[@]}" -eq 0 ]
then
    echo "ERROR: no usable inputs were given" >&2
    exit 1
fi

############################################################
# Merge MSAs based on taxonomy ID
############################################################
if [ "$nP" -eq 1 ] && [ "$nD" -eq 0 ] && [ "$nR" -eq 1 ]
then
    echo "Creating joint Protein-RNA MSA"
    echo " -> Running command: $PIPEDIR/input_prep/merge_msa_prot_rna.py $WDIR/$lastP.msa0.a3m $WDIR/$lastR.afa $WDIR/$lastP.$lastR.a3m"
    "$PIPEDIR/input_prep/merge_msa_prot_rna.py" "$WDIR/$lastP.msa0.a3m" "$WDIR/$lastR.afa" "$WDIR/$lastP.$lastR.a3m" \
        > "$WDIR/log/make_pMSA.$tag.stdout" 2> "$WDIR/log/make_pMSA.$tag.stderr"
    # the paired MSA replaces the separate protein and RNA inputs
    inputs=("PR:$WDIR/$lastP.$lastR.a3m:$WDIR/$lastP.hhr:$WDIR/$lastP.atab")
fi

############################################################
# end-to-end prediction
############################################################
echo "Running RoseTTAFold2NA to predict structures"
echo " -> Running command: python $PIPEDIR/network/predict.py -inputs ${inputs[*]} -prefix $WDIR/models/model -model $PIPEDIR/network/weights/RF2NA_apr23.pt -db $HHDB"
mkdir -p "$WDIR/models"

python "$PIPEDIR/network/predict.py" \
    -inputs "${inputs[@]}" \
    -prefix "$WDIR/models/model" \
    -model "$PIPEDIR/network/weights/RF2NA_apr23.pt" \
    -db "$HHDB"

echo "Done"