├── .gitignore ├── LICENSE ├── README.md ├── checkpoints └── quickbind_default │ ├── best_checkpoint.pt │ ├── binding_affinity_prediction │ └── ckpt_seed42.pt │ └── config.yaml ├── commons ├── modified_of_modules.py └── utils.py ├── configs └── quickbind_default.yml ├── data ├── timesplit_no_lig_or_rec_overlap_train ├── timesplit_no_lig_or_rec_overlap_val ├── timesplit_no_lig_overlap_train ├── timesplit_no_lig_overlap_val └── timesplit_test ├── dataset ├── dataimporter.py └── process_mols.py ├── inference.py ├── interpretability.ipynb ├── overview.jpg ├── quickbind.py ├── scripts └── process_binding_affinities.py ├── train_binding_affinity.py ├── train_pl.py └── virtual_screening.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | __pycache__/* 3 | */__pycache__/* 4 | wandb/* 5 | configs/ 6 | checkpoints/ 7 | *.log -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 AQ Laboratory 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # QuickBind - A Light-Weight And Interpretable Molecular Docking Model 3 | 4 | ![Overview of QuickBind](overview.jpg) 5 | 6 | This repository contains the code for [QuickBind](https://arxiv.org/abs/2410.16474). 7 | 8 | ## Creating the environment 9 | 10 | You can create the conda environment on a Linux system using the following commands: 11 | ```bash 12 | conda create --name quickbind 13 | conda activate quickbind 14 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch 15 | conda install pytorch-lightning==1.9 -c conda-forge 16 | pip install torch-scatter -f https://data.pyg.org/whl/torch-1.12.1+cu113.html 17 | pip install torch-sparse -f https://data.pyg.org/whl/torch-1.12.1+cu113.html 18 | pip install torch-geometric 19 | conda install -c conda-forge rdkit=2022.03.2 20 | pip install wandb 21 | pip install nvidia-pyindex 22 | pip install nvidia-dllogger 23 | pip install spyrmsd 24 | pip install biopython 25 | ``` 26 | You also need to install OpenFold: 27 | ```bash 28 | git clone https://github.com/aqlaboratory/openfold.git 29 | cd openfold 30 | git checkout efcf80f50e5534cdf9b5ede56ef1cbbd1471775e 31 | pip install ./ 32 | ``` 33 | 34 | ## Downloading training and evaluation data 35 | 36 | Download the PDBBind dataset and place it in the `data/` directory: 37 | ```bash 38 | wget https://www.zenodo.org/record/6408497/files/PDBBind.zip?download=1 39 | unzip 'PDBBind.zip?download=1' 40 | mv PDBBind_processed/ PDBBind/ 41 | mv PDBBind/ data/ 42 | ``` 43 | 
Similarly, the PoseBusters benchmark set can be downloaded from [Zenodo](https://zenodo.org/records/8278563) and should be placed in the `data/` directory as well, so that you have `data/posebusters_benchmark_set`. 44 | 45 | ## Training the model 46 | 47 | QuickBind was trained using gradient accumulation on two 4-GPU nodes, resulting in an effective batch size of 16 (8 GPUs × a per-GPU `batch_size` of 1 × 2 accumulation steps). The number of nodes can be set using the `num_nodes` argument of `pl.Trainer` in [train_pl.py](train_pl.py). The number of iterations over which to accumulate gradients can be set using the `iters_to_accumulate` keyword in the configuration file. To train QuickBind using its default parameters, run: 48 | ```bash 49 | python train_pl.py --config configs/quickbind_default.yml 50 | ``` 51 | For fine-tuning using a larger crop size, add `--finetune True`. To resume training from the most recent checkpoint file, add `--resume True` and pass the Weights & Biases ID via the `--id` flag. After training, copy the configuration file to the checkpoints directory, for example: 52 | ```bash 53 | cp configs/quickbind_default.yml checkpoints/quickbind_default/config.yaml 54 | ``` 55 | Model weights are stored in [checkpoints/quickbind_default](checkpoints/quickbind_default), which also contains model weights for the final QuickBind model. 56 | 57 | ## Running inference 58 | 59 | The following command will run inference on the PDBBind test set using the model weights with the lowest validation loss from `checkpoints/quickbind_default`: 60 | ```bash 61 | python inference.py --name quickbind_default 62 | ``` 63 | Adding `--unseen_only True` will only include proteins that the model has not seen during training. Adding `--save_to_file True` will save the predictions to SD files. Adding `--pb_set True` will run inference on the PoseBusters test set. 
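The inference flags described above can be combined in one call. The following is a minimal sketch that only assembles the documented flags into an argument list. The helper `build_inference_command` is hypothetical and not part of the repository, and whether every flag combination is meaningful is not stated here, so treat that as an assumption.

```python
import shlex


def build_inference_command(name, unseen_only=False, save_to_file=False, pb_set=False):
    """Assemble an inference.py command line from the documented flags.

    Hypothetical helper for illustration; it does not run anything.
    """
    cmd = ["python", "inference.py", "--name", name]
    # Each optional flag is passed with an explicit "True", as in the README.
    for flag, enabled in (
        ("--unseen_only", unseen_only),
        ("--save_to_file", save_to_file),
        ("--pb_set", pb_set),
    ):
        if enabled:
            cmd += [flag, "True"]
    return cmd


if __name__ == "__main__":
    cmd = build_inference_command(
        "quickbind_default", unseen_only=True, save_to_file=True
    )
    # Print a shell-quoted version of the command for copy-pasting.
    print(shlex.join(cmd))
```

Building the command as a list keeps it usable both for printing and for passing to `subprocess.run(cmd, check=True)` from the repository root.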
64 | 65 | To facilitate the reproduction of the results from the paper, we provide processed input files for the PDBBind and PoseBusters test sets, and predictions of the final model, including SD files, on [Zenodo](https://zenodo.org/records/12509123). To use the preprocessed input files, download them and put them in `data/processed/timesplit_test` and `data/processed/posebusters`, respectively: 66 | ```bash 67 | wget https://zenodo.org/records/12509123/files/QuickBind_Data.tar.gz?download=1 68 | tar -xzvf 'QuickBind_Data.tar.gz?download=1' 69 | mv QuickBind_Zenodo/processed_input_files/timesplit_test/ data/processed/ 70 | mv QuickBind_Zenodo/processed_input_files/posebusters/ data/processed/ 71 | ``` 72 | 73 | To evaluate the provided predictions, download them, place them in `checkpoints/quickbind_default`, and run the inference script: 74 | ```bash 75 | mv QuickBind_Zenodo/predictions/* checkpoints/quickbind_default/ 76 | python inference.py --name quickbind_default 77 | ``` 78 | 79 | ## Binding affinity prediction 80 | 81 | Download, extract, and process the raw PDBBind data: 82 | ```bash 83 | wget https://pdbbind.oss-cn-hangzhou.aliyuncs.com/download/PDBbind_v2020_plain_text_index.tar.gz 84 | tar -xf PDBbind_v2020_plain_text_index.tar.gz 85 | python scripts/process_binding_affinities.py 86 | ``` 87 | This will create a pickled dictionary with the binding affinities of all complexes in the PDBBind dataset (`binding_affinity_dict.pkl`). 
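The sketch below shows one way the resulting file might be inspected from Python. It assumes the pickle holds a plain dictionary (for example PDB ID to affinity value) and that it is written to the current working directory. Neither the key/value types nor the output location are specified above, and the helper name `load_binding_affinities` is hypothetical.

```python
import pickle
from pathlib import Path


def load_binding_affinities(path="binding_affinity_dict.pkl"):
    """Load the pickled affinity dictionary.

    Assumption: the file contains a plain dict (e.g. PDB ID -> affinity).
    """
    with Path(path).open("rb") as handle:
        affinities = pickle.load(handle)
    if not isinstance(affinities, dict):
        raise TypeError(f"expected a dict, got {type(affinities).__name__}")
    return affinities


if __name__ == "__main__":
    pkl_path = Path("binding_affinity_dict.pkl")
    if pkl_path.exists():
        affinities = load_binding_affinities(pkl_path)
        print(f"{len(affinities)} complexes loaded")
        # Print a few entries to check the key/value format.
        for key, value in list(affinities.items())[:5]:
            print(key, value)
    else:
        print(f"{pkl_path} not found; run scripts/process_binding_affinities.py first")
```

Only unpickle files you created yourself or obtained from a trusted source, since `pickle.load` can execute arbitrary code.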
88 | 89 | Generate training, validation, and test embeddings: 90 | ```bash 91 | python inference.py --name quickbind_default --train_set True --output_s True 92 | python inference.py --name quickbind_default --val_set True --output_s True 93 | python inference.py --name quickbind_default --output_s True 94 | ``` 95 | Alternatively, we provide extracted embeddings on [Zenodo](https://zenodo.org/records/12509123): 96 | ```bash 97 | mv QuickBind_Zenodo/predictions_w_embeddings/predictions-w-single-rep-curr-train.pt checkpoints/quickbind_default/train_predictions-w-single-rep.pt 98 | mv QuickBind_Zenodo/predictions_w_embeddings/predictions-w-single-rep-curr-val.pt checkpoints/quickbind_default/val_predictions-w-single-rep.pt 99 | mv QuickBind_Zenodo/predictions_w_embeddings/predictions-w-single-rep-curr.pt checkpoints/quickbind_default/predictions-w-single-rep.pt 100 | ``` 101 | 102 | To train and evaluate the binding affinity prediction model, run: 103 | ```bash 104 | python train_binding_affinity.py 105 | ``` 106 | 107 | We provide the weights of the final trained model under [checkpoints/quickbind_default/binding_affinity_prediction/](checkpoints/quickbind_default/binding_affinity_prediction/). To evaluate the binding affinity prediction model using these weights, run: 108 | ```bash 109 | python train_binding_affinity.py --ckpt checkpoints/quickbind_default/binding_affinity_prediction/ckpt_seed42.pt 110 | ``` 111 | 112 | ## Reproduction of additional results 113 | 114 | The pair and single representations used in the interpretability studies, as well as the output files used in the toy virtual screen, can be downloaded from [Zenodo](https://zenodo.org/records/12509123): 115 | ```bash 116 | mv QuickBind_Zenodo/embeddings_interpret/* ./ 117 | mv QuickBind_Zenodo/virtual_screening/ ./ 118 | ``` 119 | To reproduce the results in the paper, follow the notebooks [interpretability.ipynb](interpretability.ipynb) and [virtual_screening.ipynb](virtual_screening.ipynb). 
120 | 121 | -------------------------------------------------------------------------------- /checkpoints/quickbind_default/best_checkpoint.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aqlaboratory/QuickBind/d8c5cf4901b44f233cbbd8e6936d3e31aeebfec2/checkpoints/quickbind_default/best_checkpoint.pt -------------------------------------------------------------------------------- /checkpoints/quickbind_default/binding_affinity_prediction/ckpt_seed42.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aqlaboratory/QuickBind/d8c5cf4901b44f233cbbd8e6936d3e31aeebfec2/checkpoints/quickbind_default/binding_affinity_prediction/ckpt_seed42.pt -------------------------------------------------------------------------------- /checkpoints/quickbind_default/config.yaml: -------------------------------------------------------------------------------- 1 | name: 'quickbind_default' 2 | seed: 1 3 | num_epochs: 500 4 | batch_size: 1 5 | patience: 50 6 | iters_to_accumulate: 2 7 | loss_params: 8 | lig_lig_loss_weight: 1 9 | lig_rec_loss_weight: 1 10 | aux_loss_weight: 1 11 | steric_clash_loss_weight: 1 12 | full_distogram_loss_weight: 0 13 | 14 | train_names: 'data/timesplit_no_lig_overlap_train' 15 | val_names: 'data/timesplit_no_lig_overlap_val' 16 | test_names: 'data/timesplit_test' 17 | num_workers: 4 18 | dataset_params: 19 | chain_radius: 10 # only keep chains that have an atom in this radius around the ligand 20 | remove_h: True 21 | cropping: True 22 | crop_size: 256 23 | binding_site_cropping: True 24 | recenter: True 25 | 26 | optimizer: AdamW 27 | optimizer_params: 28 | lr: 1.0e-4 29 | weight_decay: 1.0e-4 30 | clip_grad: 100 # leave empty for no grad clip 31 | 32 | model_parameters: 33 | recycle: False 34 | recycle_iters: 1 35 | c_emb: 32 36 | c_s: 64 37 | c_z: 64 38 | c_hidden: 16 39 | no_heads: 12 40 | no_qk_points: 4 41 | no_v_points: 
8 42 | num_struct_blocks: 8 43 | dropout_rate: 0.1 44 | no_transition_layers: 1 45 | share_ipa_weights: True 46 | c_hidden_msa_att: 16 # c_s // 4 47 | c_hidden_opm: 16 # c_z // 4 48 | c_hidden_mul: 64 # c_z 49 | c_hidden_pair_att: 16 # c_z // 4 50 | c_s_out: 64 51 | no_heads_msa: 8 52 | no_heads_pair: 4 53 | no_evo_blocks: 12 54 | opm_first: False 55 | transition_n: 4 56 | msa_dropout: 0.15 57 | pair_dropout: 0.25 58 | use_pairwise_dist: True 59 | use_radial_basis: False 60 | use_rel_pos: True 61 | mask_off_diagonal: True 62 | use_op_edge_embed: False 63 | use_gated_ipa: True 64 | communicate: False 65 | one_hot_adj: False 66 | use_full_evo_stack: True 67 | att_update: True 68 | use_multimer_rel_pos: False 69 | use_topological_distance: False 70 | construct_frames: True 71 | 72 | wandb: 73 | project: 'QuickBind' 74 | resume: allow 75 | -------------------------------------------------------------------------------- /commons/utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import math 3 | import numpy as np 4 | from spyrmsd import rmsd, molecule 5 | import sys 6 | import torch 7 | import os 8 | from dataset.process_mols import read_molecule, reorder_atoms 9 | from rdkit.Chem import RemoveHs, SDWriter 10 | from rdkit.Geometry import Point3D 11 | 12 | def log(*args): 13 | print(f'[{datetime.now()}]', *args) 14 | 15 | def get_parameter_value(param, config): 16 | try: 17 | param_value = bool(config[param]) 18 | except Exception: 19 | param_value = False 20 | return param_value 21 | 22 | class Logger(object): 23 | def __init__(self, logpath, syspart=sys.stdout): 24 | self.terminal = syspart 25 | self.log = open(logpath, "a") 26 | 27 | def write(self, message): 28 | self.terminal.write(message) 29 | self.log.write(message) 30 | self.log.flush() 31 | 32 | def flush(self): 33 | pass 34 | 35 | def save_predictions_to_file(results, receptor_dir, out_path): 36 | receptors = 
torch.load(os.path.join(receptor_dir, "rec_input_proc_ids.pt")) 37 | coms = [] 38 | for rec in receptors: 39 | c_alpha_coords = rec['c_alpha_coords'] 40 | if c_alpha_coords.shape[0] > 2000: # inference is currently limited to complexes with less than 2000 residues 41 | continue 42 | c = torch.mean(c_alpha_coords, dim=0) 43 | coms.append(c) 44 | 45 | pred_mols = [] 46 | for prediction, name, com in zip(results['predictions'], results['names'], coms): 47 | if os.path.basename(receptor_dir) == 'posebusters': 48 | lig = read_molecule(os.path.join('data/posebusters_benchmark_set', name, f'{name}_ligand.sdf'), remove_hs=True) 49 | else: 50 | lig = read_molecule(os.path.join('data/PDBBind/', name, f'{name}_ligand.mol2'), remove_hs=True) 51 | if lig is None: 52 | lig = read_molecule(os.path.join('data/PDBBind/', name, f'{name}_ligand.sdf'), remove_hs=True) 53 | lig = RemoveHs(lig) 54 | lig = reorder_atoms(lig) 55 | conf = lig.GetConformer() 56 | p = prediction.squeeze().numpy() 57 | c = com.numpy() 58 | for i in range(lig.GetNumAtoms()): 59 | x, y, z = p[i] + c 60 | conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z))) 61 | pred_mols.append(lig) 62 | 63 | if not os.path.exists(out_path): 64 | os.mkdir(out_path) 65 | 66 | for mol, name in zip(pred_mols,results['names']): 67 | with SDWriter(os.path.join(out_path, f'{name}_pred.sdf')) as w: 68 | w.write(mol) 69 | 70 | def read_strings_from_txt(path): 71 | with open(path) as file: 72 | lines = file.readlines() 73 | return [line.rstrip() for line in lines] 74 | 75 | # This function is taken from EquiBind 76 | # Copyright (c) 2022 Hannes Stärk 77 | # R = 3x3 rotation matrix 78 | # t = 3x1 column vector 79 | # This already takes residue identity into account. 
80 | def rigid_transform_Kabsch_3D(A, B): 81 | assert A.shape[1] == B.shape[1] 82 | num_rows, num_cols = A.shape 83 | if num_rows != 3: 84 | raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}") 85 | num_rows, num_cols = B.shape 86 | if num_rows != 3: 87 | raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}") 88 | 89 | # find mean column wise: 3 x 1 90 | centroid_A = np.mean(A, axis=1, keepdims=True) 91 | centroid_B = np.mean(B, axis=1, keepdims=True) 92 | 93 | # subtract mean 94 | Am = A - centroid_A 95 | Bm = B - centroid_B 96 | 97 | H = Am @ Bm.T 98 | 99 | # find rotation 100 | U, S, Vt = np.linalg.svd(H) 101 | 102 | R = Vt.T @ U.T 103 | 104 | # special reflection case 105 | if np.linalg.det(R) < 0: 106 | # print("det(R) < R, reflection detected!, correcting for it ...") 107 | SS = np.diag([1.,1.,-1.]) 108 | R = (Vt.T @ SS) @ U.T 109 | assert math.fabs(np.linalg.det(R) - 1) < 1e-5 110 | 111 | t = -R @ centroid_A + centroid_B 112 | return R, t 113 | 114 | # This function is taken from DiffDock 115 | # Copyright (c) 2022 Gabriele Corso, Hannes Stärk, Bowen Jing 116 | def get_symmetry_rmsd(mol, coords1, coords2, mol2=None): 117 | mol = molecule.Molecule.from_rdkit(mol) 118 | mol2 = molecule.Molecule.from_rdkit(mol2) if mol2 is not None else mol2 119 | mol2_atomicnums = mol2.atomicnums if mol2 is not None else mol.atomicnums 120 | mol2_adjacency_matrix = mol2.adjacency_matrix if mol2 is not None else mol.adjacency_matrix 121 | RMSD = rmsd.symmrmsd( 122 | coords1, 123 | coords2, 124 | mol.atomicnums, 125 | mol2_atomicnums, 126 | mol.adjacency_matrix, 127 | mol2_adjacency_matrix, 128 | ) 129 | return RMSD 130 | -------------------------------------------------------------------------------- /configs/quickbind_default.yml: -------------------------------------------------------------------------------- 1 | name: 'quickbind_default' 2 | seed: 1 3 | num_epochs: 500 4 | batch_size: 1 5 | patience: 50 6 | iters_to_accumulate: 2 7 | 
loss_params: 8 | lig_lig_loss_weight: 1 9 | lig_rec_loss_weight: 1 10 | aux_loss_weight: 1 11 | steric_clash_loss_weight: 1 12 | full_distogram_loss_weight: 0 13 | 14 | train_names: 'data/timesplit_no_lig_overlap_train' 15 | val_names: 'data/timesplit_no_lig_overlap_val' 16 | test_names: 'data/timesplit_test' 17 | num_workers: 4 18 | dataset_params: 19 | chain_radius: 10 # only keep chains that have an atom in this radius around the ligand 20 | remove_h: True 21 | cropping: True 22 | crop_size: 256 23 | binding_site_cropping: True 24 | recenter: True 25 | 26 | optimizer: AdamW 27 | optimizer_params: 28 | lr: 1.0e-4 29 | weight_decay: 1.0e-4 30 | clip_grad: 100 # leave empty for no grad clip 31 | 32 | model_parameters: 33 | recycle: False 34 | recycle_iters: 1 35 | c_emb: 32 36 | c_s: 64 37 | c_z: 64 38 | c_hidden: 16 39 | no_heads: 12 40 | no_qk_points: 4 41 | no_v_points: 8 42 | num_struct_blocks: 8 43 | dropout_rate: 0.1 44 | no_transition_layers: 1 45 | share_ipa_weights: True 46 | c_hidden_msa_att: 16 # c_s // 4 47 | c_hidden_opm: 16 # c_z // 4 48 | c_hidden_mul: 64 # c_z 49 | c_hidden_pair_att: 16 # c_z // 4 50 | c_s_out: 64 51 | no_heads_msa: 8 52 | no_heads_pair: 4 53 | no_evo_blocks: 12 54 | opm_first: False 55 | transition_n: 4 56 | msa_dropout: 0.15 57 | pair_dropout: 0.25 58 | use_pairwise_dist: True 59 | use_radial_basis: False 60 | use_rel_pos: True 61 | mask_off_diagonal: True 62 | use_op_edge_embed: False 63 | use_gated_ipa: True 64 | communicate: False 65 | one_hot_adj: False 66 | use_full_evo_stack: True 67 | att_update: True 68 | use_multimer_rel_pos: False 69 | use_topological_distance: False 70 | construct_frames: True 71 | 72 | wandb: 73 | project: 'QuickBind' 74 | resume: allow 75 | -------------------------------------------------------------------------------- /data/timesplit_no_lig_or_rec_overlap_val: -------------------------------------------------------------------------------- 1 | 4mi6 2 | 5ylv 3 | 4ozo 4 | 6gip 5 | 3std 6 | 3g2n 7 | 
6ax1 8 | 6h96 9 | 5q0m 10 | 5hh5 11 | 4idz 12 | 6cec 13 | 5wqa 14 | 3k3e 15 | 1ppk 16 | 4og7 17 | 4b5w 18 | 4bgg 19 | 4jgv 20 | 2g9v 21 | 4og3 22 | 5lz2 23 | 6chq 24 | 5aqv 25 | 4k67 26 | 5j3v 27 | 5lz9 28 | 5t2y 29 | 4el5 30 | 2z60 31 | 4zed 32 | 4pks 33 | 5cuu 34 | 5q0x 35 | 4x5z 36 | 3hs9 37 | 3v7s 38 | 2qyn 39 | 5ehn 40 | 1sz0 41 | 4x6i 42 | 4u82 43 | 2vrj 44 | 6g2l 45 | 2bow 46 | 5o9y 47 | 4mji 48 | 6ccs 49 | 1yvm 50 | 3sym 51 | 4fz3 52 | 5tdi 53 | 2ie4 54 | 1n4m 55 | 1c3i 56 | 5eie 57 | 5ye8 58 | 5in9 59 | 4b7z 60 | 5ncz 61 | 5lz4 62 | 6bw4 63 | 3sl4 64 | 4dmy 65 | 4cmt 66 | 5a9u 67 | 1fki 68 | 4mc1 69 | 5o9o 70 | 4b5t 71 | 4m3g 72 | 1pyg 73 | 3b67 74 | 5l8o 75 | 5mkj 76 | 3oyp 77 | 4anq 78 | 5hog 79 | 4de7 80 | 5tkb 81 | 2vle 82 | 3f7i 83 | 3v51 84 | 3l7d 85 | 2v6n 86 | 4qna 87 | 4cd0 88 | 3iog 89 | 4i8w 90 | 2xup 91 | 3t3i 92 | 1db4 93 | 5es1 94 | 5i9i 95 | 6ccm 96 | 2xui 97 | 1q91 98 | 1bgo 99 | 1akt 100 | 1q84 101 | 1yt7 102 | 2l75 103 | 5aac 104 | 6nao 105 | 5iuh 106 | 3oof 107 | 4ona 108 | 5q0n 109 | 3sfi 110 | 2g9q 111 | 1hlf 112 | 5aqj 113 | 1g9c 114 | 1ayu 115 | 6co4 116 | 6bd1 117 | 4yur 118 | 3vw0 119 | 6cea 120 | 4nyf 121 | 3v43 122 | 2ya8 123 | 1b4d 124 | 2ccb 125 | 5q1f 126 | 1fkb 127 | 3bcs 128 | 1h46 129 | 6dgt 130 | 2ftd 131 | 5t2l 132 | 3i7c 133 | 6ckw 134 | 3csl 135 | 1j07 136 | 3omm 137 | 4g17 138 | 3v49 139 | 4fny 140 | 1fkh 141 | 4u0e 142 | 2g9r 143 | 6hd4 144 | 5oss 145 | 2adm 146 | 4b85 147 | 2fm0 148 | 4hgs 149 | 2qn3 150 | 1ddm 151 | 3fal 152 | 5q0y 153 | 5q1a 154 | 3g2k 155 | 2j7b 156 | 6ee2 157 | 2jjr 158 | 5q1i 159 | 6b7a 160 | 5kh7 161 | 6gqm 162 | 3mta 163 | 3g2l 164 | 4i4f 165 | 5iaw 166 | 5q0i 167 | 3h1z 168 | 4jbl 169 | 5lvx 170 | 6chm 171 | 6fdc 172 | 2ax6 173 | 6cnj 174 | 1pwu 175 | 4fht 176 | 3th9 177 | 5db1 178 | 2hvc 179 | 4pp5 180 | 5q15 181 | 3u8h 182 | 4el0 183 | 4jt8 184 | 5std 185 | 2pwd 186 | 5wh6 187 | 3kf4 188 | 1q83 189 | 2xml 190 | 1c2t 191 | 4in9 192 | 2jt5 193 | 1icj 194 | 3oy3 195 | 3g2h 196 | 3qo2 197 | 
3tu9 198 | 3s0j 199 | 1e3g 200 | 6d1x 201 | 6mx8 202 | 3aqt 203 | 6aol 204 | 4m8e 205 | 1ado 206 | 5hh6 207 | 4lkt 208 | 2j78 209 | 5q0o 210 | 1d8m 211 | 3suu 212 | 1gyy 213 | 5aqt 214 | 3oap 215 | 4zs3 216 | 2qn1 217 | 2p98 218 | 4cmu 219 | 4ie0 220 | 1w3j 221 | 5y59 222 | 2mpa 223 | 1akw 224 | 5tyh 225 | 5lb7 226 | 1fkf 227 | 3djv 228 | 5nxq 229 | 1kti 230 | 3mrv 231 | 5t2i 232 | 4uuq 233 | 2gfj 234 | 4poh 235 | 6by8 236 | 6b7e 237 | 4u0f 238 | 5lp1 239 | 4jym 240 | 3suv 241 | 6fse 242 | 5e1s 243 | 2eum 244 | 6hai 245 | 6h7f 246 | 3suw 247 | 2gg7 248 | 3np7 249 | 5l13 250 | 6f8x 251 | 2evc 252 | 1z6q 253 | 5o9r 254 | 4pkt 255 | 1haa 256 | 5jan 257 | 4oiv 258 | 3djq 259 | 3p7i 260 | 2j7h 261 | 3v7c 262 | 2fw6 263 | 3diw 264 | 4dpt 265 | 1tu6 266 | 4pku 267 | 1ggn 268 | 2usn 269 | 6baw 270 | 1zxv 271 | 3bl7 272 | 5vdu 273 | 6fsd 274 | 4u0i 275 | 4j8s 276 | 1g49 277 | 4hni 278 | 6e83 279 | 6f0y 280 | 4mra 281 | 1d5r 282 | 3v3q 283 | 6f8w 284 | 1c7e 285 | 4b5s 286 | 4ara 287 | 4glr 288 | 4cff 289 | 2q93 290 | 5i8p 291 | 1apw 292 | 3sur 293 | 4b84 294 | 4duh 295 | 3l7c 296 | 1j4r 297 | 3mrt 298 | 1l7x 299 | 4dv8 300 | 1iep 301 | 1bsk 302 | 1i8h 303 | 5xih 304 | 5knj 305 | 5ick 306 | 2v3d 307 | 1sln 308 | 4k64 309 | 6cck 310 | 4ajw 311 | 2hb9 312 | 6std 313 | 2w87 314 | 4a23 315 | 1ayv 316 | 4bvb 317 | 1nc1 318 | 5q0p 319 | 3zlv 320 | 3okh 321 | 1pwq 322 | 3pp0 323 | 4pgh 324 | 2xbp 325 | 2qrp 326 | 4mho 327 | 6gpb 328 | 1aku 329 | 6b7f 330 | 3pcu 331 | 1fm9 332 | 5ddc 333 | 5q0u 334 | 1biw 335 | 3ery 336 | 2evo 337 | 4cxx 338 | 4g2l 339 | 5h6v 340 | 5yyz 341 | 4nw2 342 | 1em6 343 | 3shv 344 | 6cz3 345 | 4lh2 346 | 2gfk 347 | 2z78 348 | 4yv8 349 | 2jt6 350 | 3kdt 351 | 6chp 352 | 2xba 353 | 5wbl 354 | 5t2d 355 | 3fqa 356 | 5xii 357 | 2j83 358 | 4x6j 359 | 2q94 360 | 1vsn 361 | 1c8k 362 | 4b81 363 | 5y86 364 | 1caq 365 | 1k3t 366 | 2qrm 367 | 4u68 368 | 1exw 369 | 1nlj 370 | 3kdu 371 | 1axr 372 | 1kkq 373 | 3mrx 374 | 4z2h 375 | 3oki 376 | 6chl 377 | 1ppl 378 | 2q95 379 
| 5q13 380 | 2j7g 381 | 2w2u 382 | 5aqq 383 | 1ms0 384 | 2b2v 385 | 6bfa 386 | 5kew 387 | 2cbu 388 | 1nli 389 | 6bw3 390 | 4h4e 391 | 2ha5 392 | 2aq9 393 | 1g98 394 | 2pri 395 | 1apv 396 | 1gar 397 | 3szm 398 | 8gpb 399 | 1noi 400 | 5foo 401 | 2z4w 402 | 3r93 403 | 2z7i 404 | 5vdr 405 | 5ylu 406 | 6f8r 407 | 3upx 408 | 3zyr 409 | 2jnp 410 | 2nm1 411 | 4kab 412 | 4n8r 413 | 5z4h 414 | 4f9v 415 | 3l7a 416 | 2am9 417 | 2wec 418 | 1h6h 419 | 3nfl 420 | 3wd1 421 | 4mdt 422 | 3sxf 423 | 6ftn 424 | 3gt9 425 | 1oim 426 | 1ywh 427 | 1u9w 428 | 1w70 429 | 3u1i 430 | 2v3e 431 | 5jap 432 | 3b68 433 | 2qnb 434 | 5hbs 435 | 2ama 436 | 2web 437 | 6e4t 438 | 4xm6 439 | 6hd6 440 | 4og4 441 | 3il6 442 | 4zs2 443 | 2z50 444 | 1nki 445 | 4my6 446 | 3vvy 447 | 3nc4 448 | 2z4y 449 | 2euk 450 | 3g4k 451 | 2y2i 452 | 1usn 453 | 1y6r 454 | 3bu6 455 | 4u0c 456 | 2ces 457 | 5ye7 458 | 2cbv 459 | 5twg 460 | 1f40 461 | 5m2q 462 | 4hvs 463 | 4z2i 464 | 3g0e 465 | 3dct 466 | 6bnl 467 | 5v8q 468 | 1au2 469 | 3h0a 470 | 2z52 471 | 4ie5 472 | 2auz 473 | 2qn2 474 | 4lh3 475 | 5lyy 476 | 4z2l 477 | 2bmz 478 | 2evm 479 | 5wvd 480 | 3o8h 481 | 6e4w 482 | 1syo 483 | 1yk7 484 | 5q11 485 | 2ha0 486 | 6ccn 487 | 1mh5 488 | 2ai8 489 | 3rde 490 | 5q17 491 | 4zsh 492 | 6aom 493 | 1pot 494 | 5i5x 495 | 4cqe 496 | 3l3x 497 | 4ie6 498 | 6cnk 499 | 1c12 500 | 1gfz 501 | 4jal 502 | 4ie7 503 | 5tvn 504 | 2qoh 505 | 4z2g 506 | 6b7b 507 | 5kxi 508 | 6f8v 509 | 1xor 510 | 1aqi 511 | 6q73 512 | 5t27 513 | 5q0r 514 | 1y2k 515 | 5q1b 516 | 5ko1 517 | 1a8i 518 | 3g2j 519 | 5gwz 520 | 1c8l 521 | 4o42 522 | 3r2a 523 | 1d7i 524 | 3ovz 525 | 3l3z 526 | 4i32 527 | 2ych 528 | 5aqp 529 | 4xkc 530 | 4og8 531 | 3g8i 532 | 2z7h 533 | 3kfa 534 | 3vfa 535 | 5q0v 536 | 3k41 537 | 2pwg 538 | 1au0 539 | 1g9d 540 | 4i8z 541 | 6cef 542 | 4mhs 543 | 4xm7 544 | 1c3e 545 | 3kx1 546 | 2gj4 547 | 5q16 548 | 4pli 549 | 2j7e 550 | 2j7d 551 | 6q6y 552 | 4gq6 553 | 3k5v 554 | 4j09 555 | 3bl9 556 | 2xuf 557 | 1e1y 558 | 6gi6 559 | 5z4o 560 | 5q18 
561 | 4foc 562 | 6bcy 563 | 4aaw 564 | 2ha6 565 | 4pl5 566 | 3vw1 567 | 6bq0 568 | 2aux 569 | 1mkd 570 | 2q92 571 | 1xon 572 | 3aig 573 | 3oxz 574 | 2r6n 575 | 4mmp 576 | 5hki 577 | 2gg0 578 | 4qip 579 | 3ms2 580 | 5d1t 581 | 2ot1 582 | 2xpc 583 | 4zcs 584 | 5db3 585 | 3t3h 586 | 6cz4 587 | 4h4d 588 | 3h78 589 | 5q10 590 | 4jsr 591 | 4qll 592 | 4e90 593 | 6dry 594 | 5aqu 595 | 5oei 596 | 5hz5 597 | 3kwz 598 | 4m3d 599 | 5xig 600 | 3u3z 601 | 2y2k 602 | 5lz5 603 | 1o8b 604 | 4hlw 605 | 4tq3 606 | 1fkg 607 | 3pkn 608 | 2xb7 609 | 5ov9 610 | 3ggv 611 | 3cke 612 | 4m84 613 | 6cho 614 | 1m6p 615 | 5jal 616 | 1a0q 617 | 4eky 618 | 3vw2 619 | 2y2n 620 | 5q0q 621 | 6gin 622 | 2nmb 623 | 3rik 624 | 1akv 625 | 4m3e 626 | 4kwg 627 | 5yun 628 | 3mqf 629 | 4pkw 630 | 3k3h 631 | 3t3v 632 | 3t1n 633 | 2bdl 634 | 2prj 635 | 5wh5 636 | 4qhc 637 | 4eoy 638 | 2bv4 639 | 1i7g 640 | 5iui 641 | 4og5 642 | 5evk 643 | 2fsa 644 | 3sdg 645 | 3g2i 646 | 1y2d 647 | 1c7f 648 | 1qkn 649 | 2etm 650 | 3o1g 651 | 3t3d 652 | 1bxo 653 | 2z5o 654 | 3bla 655 | 5o9p 656 | 5g3n 657 | 5v1y 658 | 4gq4 659 | 2vvo 660 | 4u0b 661 | 1opi 662 | 3sut 663 | 3wd2 664 | 4xm8 665 | 4kp4 666 | 1hy7 667 | 1g05 668 | 5aaa 669 | 5wmt 670 | 2fj0 671 | 1bxq 672 | 5t2b 673 | 1o6i 674 | 4xdo 675 | 5ez0 676 | 5wqj 677 | 5t8e 678 | 6g22 679 | 3o0u 680 | 2gfd 681 | 5fpp 682 | 1tuf 683 | 4v0i 684 | 4og6 685 | 3g4g 686 | 2std 687 | 1xnz 688 | 2dw7 689 | 4oue 690 | 6ds0 691 | 5jar 692 | 4ibm 693 | 1d5j 694 | 2hrp 695 | 1koj 696 | 1d7j 697 | 4ryl 698 | 2f6j 699 | 4eke 700 | 4btl 701 | 6b7d 702 | 3bwk 703 | 5aqg 704 | 4i80 705 | 1c3x 706 | 2qrq 707 | 1oif 708 | 2p9a 709 | 5f67 710 | 4mc9 711 | 4dpu 712 | 3il5 713 | 6bnk 714 | 4lh7 715 | 6ccl 716 | 4m3b 717 | 6drz 718 | 4ebw 719 | 6et8 720 | 1g9b 721 | 3vvz 722 | 5q12 723 | 1jys 724 | 1g9a 725 | 5q1c 726 | 4mc6 727 | 2gg9 728 | 5t2m 729 | 3gta 730 | 5q0w 731 | 5oa2 732 | 3mt9 733 | 5iql 734 | 5q0t 735 | 2gkl 736 | 1z95 737 | 6c91 738 | 2z4z 739 | 3syr 740 | 4g16 741 | 3qi3 742 | 
1z6p 743 | 3p8o 744 | 1qpl 745 | 2pix 746 | 4crj 747 | 2cet 748 | 4wf6 749 | 4qfr 750 | 1y2c 751 | 4gh6 752 | 1ct8 753 | 3guz 754 | 1oyn 755 | 1d8f 756 | 4x6h 757 | 3gp0 758 | 2srt 759 | 4k63 760 | 1pwp 761 | 4k66 762 | 4ql8 763 | 4ie4 764 | 2fm5 765 | 3g4l 766 | 5ix1 767 | 5d1u 768 | 4y8c 769 | 2evl 770 | 5dde 771 | 5y7w 772 | 6clv 773 | 2fu8 774 | 3hg1 775 | 4xe0 776 | 5k1i 777 | 3c9e 778 | 1gpy 779 | 2gg2 780 | 5vdv 781 | 5eyz 782 | 2wc4 783 | 4qlk 784 | 3t3g 785 | 4xrq 786 | 3v5p 787 | 1exv 788 | 1std 789 | 5jjm 790 | 5cc2 791 | 4f9u 792 | 5jao 793 | 5dda 794 | 3eta 795 | 6f6u 796 | 6cee 797 | 4pl6 798 | 3ms9 799 | 4kwf 800 | 5q0s 801 | 5q1e 802 | 5o83 803 | 5lz7 804 | 5kq5 805 | 5xij 806 | 5kh3 807 | 4qfg 808 | 3ebo 809 | 2zdx 810 | 6q74 811 | 5bpe 812 | 4poj 813 | 4qgi 814 | 2n7b 815 | 1ow7 816 | 3sx9 817 | 2e92 818 | 2amv 819 | 4std 820 | 5ur9 821 | 2jdl 822 | 3ktr 823 | 1ogg 824 | 1onh 825 | 4ad6 826 | 3sl5 827 | 5v8o 828 | 4yrd 829 | 4dce 830 | 3rcd 831 | 3g4i 832 | 3zqt 833 | 3olf 834 | 1j1a 835 | 1aqj 836 | 3fq7 837 | 4cmo 838 | 3b66 839 | 4htp 840 | 5vdw 841 | 3l79 842 | 3usn 843 | 4i4e 844 | 3d27 845 | 2qrh 846 | 2wc3 847 | 4djh 848 | 1jif 849 | 3g58 850 | 3mt7 851 | 4yec 852 | 6b7c 853 | 1y2b 854 | 2v3u 855 | 2qlm 856 | 7std 857 | 2vpe 858 | 2qln 859 | 5wbk 860 | 5ftq 861 | 3fei 862 | 2nsx 863 | 4ebv 864 | 4b82 865 | 4pp3 866 | 5ddd 867 | 1l5r 868 | 4psb 869 | 4cnh 870 | 1azl 871 | 5evb 872 | 5dpx 873 | 4k9y 874 | 1jq3 875 | 2ggb 876 | 2gm9 877 | 4g2j 878 | 5lz8 879 | 2ylq 880 | 5kz0 881 | 5t8j 882 | 4cts 883 | 2j75 884 | 3mt8 885 | 1n3z 886 | 4rme 887 | 2gg8 888 | 4z2k 889 | 2ha7 890 | 4mi3 891 | 4k6i 892 | 3nfk 893 | 4pl4 894 | 6ce8 895 | 6bx6 896 | 1xom 897 | 4u0a 898 | 3djp 899 | 4zeb 900 | 5q0j 901 | 2wos 902 | 5vtb 903 | 1u9x 904 | 1g4k 905 | 1nc3 906 | 4gu9 907 | 6e4u 908 | 4b83 909 | 2p99 910 | 5dtj 911 | 3o8g 912 | 2fwp 913 | 4fod 914 | 3np9 915 | 5o9q 916 | 3v9b 917 | 1i1e 918 | 3jsw 919 | 4gu6 920 | 1snk 921 | 1bm6 922 | 4fnz 923 | 4qfs 924 
| 4nj9 925 | 2ha2 926 | 4ej2 927 | 6ced 928 | 3jsi 929 | 2d1o 930 | 4mc2 931 | 6g2m 932 | 6e86 933 | 1l5q 934 | 3g0f 935 | 2jal 936 | 5y1u 937 | 2ya7 938 | 5jau 939 | 4c4n 940 | 3tmk 941 | 1t46 942 | 5ddb 943 | 5n6s 944 | 1ppm 945 | 5tbn 946 | 2wbg 947 | 6d28 948 | 5q14 949 | 3ik3 950 | 5w99 951 | 5q1h 952 | 4joa 953 | 5ha1 954 | 3m3z 955 | 4pzv 956 | 5dd9 957 | 2e91 958 | 1mem 959 | 1rdt 960 | 5vds 961 | 2xwd 962 | 5k32 963 | 3g4f 964 | 4x5y 965 | 3mtb 966 | 2cc7 967 | 4pkr 968 | 1gyx 969 | 5jas 970 | 1xoq 971 | 1u9v 972 | 3mtd 973 | 3kwb 974 | 5aqn 975 | 4ac3 976 | 2ylp 977 | 3p0g 978 | 3bz3 979 | 1xow 980 | 3ew2 981 | 1akq 982 | 5da3 983 | 4lh6 984 | 1db5 985 | 1g27 986 | 2ao6 987 | 5z9e 988 | 5zun 989 | 4cwb 990 | 2ccc 991 | 5tbp 992 | 1nl6 993 | 4pkv 994 | 2ww2 995 | 3upz 996 | 5aab 997 | 2ha4 998 | 3mss 999 | 1zkn 1000 | 4y87 1001 | 2pyi 1002 | 2yhd 1003 | 3rw9 1004 | 3f7h 1005 | 4q9s 1006 | 2g9u 1007 | 4jt9 1008 | 5twh 1009 | 4bj8 1010 | 4pl3 1011 | 2y2h 1012 | 4mi9 1013 | 5cdh 1014 | 4n5g 1015 | 7gpb 1016 | 2wr8 1017 | 3i7b 1018 | 1q9m 1019 | 1p2g 1020 | 4kao 1021 | 5l8n 1022 | 1bl4 1023 | 3iad 1024 | 1q6k 1025 | 4i31 1026 | 4fob 1027 | 5mlj 1028 | 5hm3 1029 | 2oz7 1030 | 5ehq 1031 | 4u0d 1032 | 6b2q 1033 | 4m3f 1034 | 3tcg 1035 | 6ccq 1036 | 4x0u 1037 | 1y6q 1038 | 3iof 1039 | 5db0 1040 | 1n4k 1041 | 4wht 1042 | 4dpy 1043 | 4cli 1044 | 3msc 1045 | 2ylo 1046 | 4x7q 1047 | 1g2a 1048 | 4arb 1049 | 5ncy 1050 | 1zaj 1051 | 3qt6 1052 | 3npa 1053 | 5aqh 1054 | 5oku 1055 | 1yon 1056 | 3ekn 1057 | 2bb7 1058 | 1akr 1059 | 5h2u 1060 | 4cfe 1061 | 4why 1062 | 3ril 1063 | 5q1d 1064 | 5aqo 1065 | 4cxw 1066 | 5osy 1067 | 4m8h 1068 | 1h5u 1069 | 5yea 1070 | 5t2g 1071 | 1c50 1072 | 5l3j 1073 | 4cxy 1074 | 6cco 1075 | 1ow8 1076 | 4k4j 1077 | 5q19 1078 | 5oxg 1079 | 3sus 1080 | 3kw9 1081 | 5wqk 1082 | 6f8u 1083 | 4i33 1084 | 4z2j 1085 | 1y2e 1086 | 4xpj 1087 | 6h0b 1088 | 2wor 1089 | 3ldq 1090 | 3ebp 1091 | 1bqo 1092 | 3ook 1093 | 3l7b 1094 | 1ow6 1095 | 5ye9 1096 | 2off 
1097 | 1noj 1098 | 2aig 1099 | 1iup 1100 | 5eou 1101 | 5db2 1102 | 4wcu 1103 | 3ewc 1104 | 6ce6 1105 | 5fto 1106 | 4zei 1107 | 4b80 1108 | 3qi4 1109 | 2xi7 1110 | 2bqv 1111 | 5fkj 1112 | 4wj7 1113 | 6ez6 1114 | 1yhm 1115 | 2z92 1116 | 3sz9 1117 | 5ytu 1118 | 6f8t 1119 | 3amv 1120 | 3eyf 1121 | 5iug 1122 | 5d7a 1123 | 4clj 1124 | 5fum 1125 | 3v5t 1126 | 3ms7 1127 | 1yqy 1128 | 3aox 1129 | 4yjn 1130 | 3o4l 1131 | 2ax9 1132 | 5yto 1133 | 2wed 1134 | 3ozj 1135 | 2whp 1136 | 2qrg 1137 | 2gg5 1138 | 1k08 1139 | 2flh 1140 | 1l5s 1141 | 3n51 1142 | 2vpg 1143 | 5jat 1144 | 6drx 1145 | 4ktc 1146 | 4k8a 1147 | 2zof 1148 | 5aa9 1149 | 1kcs 1150 | 1y4z 1151 | 5oa6 1152 | 4du8 1153 | 2xwe 1154 | 3ms4 1155 | 2y2j 1156 | 6chn 1157 | 5q0l 1158 | 5aa8 1159 | 2qdt 1160 | 4a16 1161 | 3u8d 1162 | 5t28 1163 | 4xkb 1164 | 4hgl 1165 | 4l4v 1166 | 2gg3 1167 | 5ddf 1168 | 4ra1 1169 | 3t3u 1170 | 1ciz 1171 | 2j7x 1172 | 1x8d 1173 | 1kvo 1174 | 1b8y 1175 | 4yik 1176 | 1osv 1177 | 2hdx 1178 | 1k06 1179 | 3g1m 1180 | 5aqf 1181 | 1d7x 1182 | 5yf1 1183 | 3b5r 1184 | 3r0h 1185 | 6b41 1186 | 4mic 1187 | 2rin 1188 | 3bpc 1189 | 2e5y 1190 | 1n5r 1191 | 2j77 1192 | 1gag 1193 | 3djo 1194 | 4zec 1195 | 5xwr 1196 | 5d1s 1197 | 1uz1 1198 | 3sl8 1199 | 2j79 1200 | 3r5m 1201 | 3b65 1202 | 2e95 1203 | 3t3e 1204 | 5cj6 1205 | 1nok 1206 | 5wpb 1207 | 1hfs 1208 | 6e5x 1209 | 5evd 1210 | 5ikb 1211 | 5aqr 1212 | 3p8n 1213 | 5q0z 1214 | 1dg9 1215 | 3qt7 1216 | 5jah 1217 | 5ax9 1218 | 2q96 1219 | 2j7f 1220 | 5q1g 1221 | 2y2p 1222 | 5v84 1223 | 4pji 1224 | -------------------------------------------------------------------------------- /data/timesplit_no_lig_overlap_val: -------------------------------------------------------------------------------- 1 | 4lp9 2 | 1me7 3 | 2zv9 4 | 2qo8 5 | 1cw2 6 | 3k5c 7 | 2o65 8 | 4kqq 9 | 3rdv 10 | 1d4w 11 | 1q4l 12 | 4b5w 13 | 4bgg 14 | 4mm5 15 | 3iej 16 | 3ftu 17 | 830c 18 | 2xye 19 | 1olu 20 | 2wk2 21 | 4pxf 22 | 5o0j 23 | 1my2 24 | 5czm 25 | 4jit 26 | 5mb1 27 | 1sqp 28 | 3zlw 
29 | 4xqu 30 | 3hkq 31 | 6fns 32 | 5e0l 33 | 2p8o 34 | 4gzw 35 | 3n87 36 | 1lhc 37 | 4itj 38 | 4m7c 39 | 4olh 40 | 4q1e 41 | 5l7e 42 | 3faa 43 | 5vqx 44 | 3pka 45 | 5x54 46 | 5a9u 47 | 4n9e 48 | 4est 49 | 1il9 50 | 4igr 51 | 3t2t 52 | 6dar 53 | 3gol 54 | 3vbg 55 | 2ydk 56 | 4zpf 57 | 5zo7 58 | 4xnw 59 | 1fpy 60 | 2r1y 61 | 6m8w 62 | 2jds 63 | 5icx 64 | 1hwr 65 | 6bj2 66 | 4b4m 67 | 1zsb 68 | 4do3 69 | 3t3i 70 | 1f8a 71 | 2ke1 72 | 5ezx 73 | 3p78 74 | 4rvm 75 | 3ovn 76 | 5wzv 77 | 4udb 78 | 1okz 79 | 1mpl 80 | 5npc 81 | 5ff6 82 | 1hlf 83 | 1nvq 84 | 4bhf 85 | 4y4g 86 | 5mkz 87 | 2o0u 88 | 3bcs 89 | 1wvc 90 | 4fsl 91 | 3oz1 92 | 6dgt 93 | 1me8 94 | 2puy 95 | 4odp 96 | 1hpx 97 | 4nrq 98 | 1z2b 99 | 3uik 100 | 3mfv 101 | 3vqh 102 | 4w9g 103 | 4xek 104 | 4jok 105 | 2wap 106 | 1g50 107 | 4j0p 108 | 2o9a 109 | 3m94 110 | 4i1c 111 | 5a82 112 | 4i9h 113 | 1k1i 114 | 4uro 115 | 2f7i 116 | 5fpk 117 | 2lgf 118 | 4l7f 119 | 1g3d 120 | 4ir5 121 | 3mta 122 | 3jzg 123 | 5f94 124 | 4nrt 125 | 4yax 126 | 5nhv 127 | 2xtk 128 | 4qh7 129 | 1tok 130 | 4b6p 131 | 3rg2 132 | 3q8d 133 | 3obu 134 | 4awj 135 | 3daj 136 | 2j50 137 | 5l2z 138 | 5bml 139 | 2bba 140 | 5n34 141 | 2xvn 142 | 1dpu 143 | 5fnt 144 | 1jyc 145 | 4zz1 146 | 6hm7 147 | 4rrv 148 | 4rww 149 | 5orv 150 | 3qo2 151 | 3uii 152 | 6d1x 153 | 3juq 154 | 4qk4 155 | 6mr5 156 | 5hjc 157 | 2p4s 158 | 2hnc 159 | 1k4g 160 | 4g0c 161 | 2y5g 162 | 4u3f 163 | 3tv5 164 | 1i3z 165 | 4mw7 166 | 3n2c 167 | 6cvw 168 | 3v66 169 | 3wzp 170 | 3s7m 171 | 5ujv 172 | 1p06 173 | 3ipy 174 | 4wkt 175 | 4ie0 176 | 5fot 177 | 5i59 178 | 5za9 179 | 4gii 180 | 4h2o 181 | 4yrs 182 | 5a6h 183 | 2xo8 184 | 4e3n 185 | 4m5k 186 | 3dga 187 | 6fse 188 | 6ck6 189 | 1sqc 190 | 4x1r 191 | 3dnj 192 | 3rvi 193 | 2a58 194 | 4bf6 195 | 3zlk 196 | 4mbj 197 | 4tpm 198 | 4d8c 199 | 1ejn 200 | 4yt6 201 | 2x7x 202 | 4qp1 203 | 4de3 204 | 5yg4 205 | 1x7b 206 | 5n9s 207 | 2fme 208 | 1ydt 209 | 2bdf 210 | 6baw 211 | 6fsd 212 | 2xn3 213 | 4tk0 214 | 3q4j 215 | 1u9l 216 | 1oqp 
217 | 5htz 218 | 4glr 219 | 5kj0 220 | 5ukl 221 | 3fun 222 | 4wk2 223 | 4ht6 224 | 5hv1 225 | 1uze 226 | 4bcc 227 | 3ff6 228 | 5if6 229 | 1tsm 230 | 2r59 231 | 3iqh 232 | 2v7a 233 | 5d10 234 | 5nvh 235 | 3eqr 236 | 1jq9 237 | 1u1b 238 | 6cer 239 | 5uq9 240 | 1u3s 241 | 5icy 242 | 3exh 243 | 2oqs 244 | 1pzp 245 | 1d4i 246 | 4x6p 247 | 4mb9 248 | 5emk 249 | 1iky 250 | 6b7f 251 | 3chq 252 | 3h5s 253 | 5zmq 254 | 4ib5 255 | 2wej 256 | 6fjm 257 | 5ewa 258 | 2igx 259 | 2z78 260 | 5lpm 261 | 4wet 262 | 3lxl 263 | 2xba 264 | 5wbl 265 | 5zla 266 | 2x6x 267 | 4mw9 268 | 5t2d 269 | 4j3m 270 | 4aqh 271 | 3lbk 272 | 4djp 273 | 4odl 274 | 4x6j 275 | 1ero 276 | 5f3t 277 | 4k3q 278 | 5ta4 279 | 1caq 280 | 2eg7 281 | 1f73 282 | 3rxg 283 | 6ezq 284 | 1qkt 285 | 5l3e 286 | 5c28 287 | 4pp9 288 | 4bgk 289 | 3iaf 290 | 5vrp 291 | 5zz4 292 | 5ur5 293 | 3ft2 294 | 5ech 295 | 4jjq 296 | 5iz6 297 | 5dhr 298 | 4l2g 299 | 4r17 300 | 3wk6 301 | 4h1e 302 | 2aq9 303 | 5g1n 304 | 3zm9 305 | 5c4l 306 | 5mfs 307 | 1fzj 308 | 2ltw 309 | 4x7i 310 | 4c94 311 | 2cfg 312 | 2va5 313 | 3vb6 314 | 2hob 315 | 5ah2 316 | 5syn 317 | 3g6g 318 | 3rwj 319 | 5sz4 320 | 4f9v 321 | 5n2d 322 | 3n9r 323 | 5ldo 324 | 3vb7 325 | 1sqo 326 | 3drg 327 | 5j9y 328 | 6b96 329 | 4yz9 330 | 1vcj 331 | 5epr 332 | 4tx6 333 | 3dz6 334 | 3czv 335 | 5v49 336 | 1ahy 337 | 3wzq 338 | 1bq4 339 | 5u8c 340 | 6bj3 341 | 2qnb 342 | 4a9m 343 | 3d4f 344 | 5oui 345 | 5wmg 346 | 6ma4 347 | 4x5q 348 | 5cbr 349 | 6msy 350 | 5avi 351 | 1g3b 352 | 2wi4 353 | 3kjn 354 | 4dhn 355 | 4o7e 356 | 5kit 357 | 5y5t 358 | 3hfj 359 | 2qd8 360 | 5vsj 361 | 2y2i 362 | 5m0m 363 | 3tcp 364 | 4bhz 365 | 1jd6 366 | 5idn 367 | 4zzx 368 | 4kn4 369 | 2a5c 370 | 6hly 371 | 1au2 372 | 4jbo 373 | 5cgj 374 | 3ske 375 | 3lq2 376 | 4pxm 377 | 2wxg 378 | 5tb6 379 | 2vc7 380 | 3iw4 381 | 5hct 382 | 3skf 383 | 5lyy 384 | 3fmz 385 | 4p5z 386 | 5ktw 387 | 6e4w 388 | 1cx9 389 | 6em7 390 | 4mjr 391 | 4u7t 392 | 3rde 393 | 4ux4 394 | 4i6f 395 | 3l3x 396 | 4ie6 397 | 4j70 398 | 
1jd0 399 | 4iaw 400 | 1szm 401 | 2afw 402 | 3ess 403 | 3sap 404 | 1olx 405 | 1bzh 406 | 5hfb 407 | 4x3h 408 | 5we9 409 | 3zsw 410 | 5ny6 411 | 1hn2 412 | 3l3z 413 | 4qp2 414 | 1d4p 415 | 4xkc 416 | 2is0 417 | 6c7e 418 | 5zku 419 | 4fai 420 | 6g9a 421 | 4xu3 422 | 5dry 423 | 4d8z 424 | 3zcz 425 | 3kbz 426 | 2y59 427 | 4nal 428 | 4rpv 429 | 4yje 430 | 3vf8 431 | 4bqx 432 | 4z9l 433 | 4ep2 434 | 4ylk 435 | 5mme 436 | 4dht 437 | 2uy4 438 | 6mu3 439 | 3kx1 440 | 5o0s 441 | 4bch 442 | 5c4k 443 | 2br1 444 | 4ddh 445 | 2f9k 446 | 2w2i 447 | 4ogn 448 | 4up5 449 | 5o4y 450 | 5hjd 451 | 2qw1 452 | 5y8z 453 | 4kqr 454 | 1o2t 455 | 6e05 456 | 3u7l 457 | 2mip 458 | 3hvg 459 | 2p59 460 | 4d3h 461 | 4pl5 462 | 3tzd 463 | 2vnp 464 | 4e3m 465 | 3vgc 466 | 5bqi 467 | 1b7h 468 | 1lhu 469 | 3rlr 470 | 3h22 471 | 2wnc 472 | 2wot 473 | 5d1t 474 | 3mo0 475 | 4wn5 476 | 3p3u 477 | 1nfs 478 | 4e90 479 | 5aqu 480 | 1bmq 481 | 3kwz 482 | 6f6n 483 | 4rj5 484 | 4omd 485 | 6min 486 | 1ujj 487 | 4ppa 488 | 4uxl 489 | 5y3n 490 | 6df2 491 | 4wvl 492 | 1xt3 493 | 5oaj 494 | 4a9r 495 | 5mli 496 | 4p4e 497 | 3juo 498 | 1z9g 499 | 2ykc 500 | 5a0e 501 | 3g0w 502 | 5t9w 503 | 1sqa 504 | 3wci 505 | 1fkw 506 | 5u4g 507 | 4mfe 508 | 4kpx 509 | 3nti 510 | 3azb 511 | 2xog 512 | 3c3r 513 | 2buc 514 | 1hyz 515 | 4dcd 516 | 6azl 517 | 3t3d 518 | 3q4l 519 | 4few 520 | 1q95 521 | 4u0b 522 | 3b7u 523 | 4bo4 524 | 4o10 525 | 5wmt 526 | 5v9t 527 | 5aok 528 | 1jtq 529 | 5uit 530 | 2vgc 531 | 2gfd 532 | 3mna 533 | 1aqc 534 | 4xtt 535 | 4z0d 536 | 4ty9 537 | 2yiv 538 | 2hrp 539 | 4zh2 540 | 2z4o 541 | 1qku 542 | 2xdw 543 | 4n7j 544 | 4yp1 545 | 3exf 546 | 4c6z 547 | 6ccu 548 | 2wxn 549 | 1bwb 550 | 2gvf 551 | 1hiy 552 | 5c4t 553 | 2za5 554 | 2xkf 555 | 4q18 556 | 1o2p 557 | 5th2 558 | 4dj7 559 | 3eyd 560 | 4j0r 561 | 2m3o 562 | 2b53 563 | 4m3b 564 | 2izl 565 | 2vtr 566 | 2x6d 567 | 2i0a 568 | 5ehg 569 | 6cw4 570 | 4c37 571 | 3cwj 572 | 1azm 573 | 2qci 574 | 5sz0 575 | 2gkl 576 | 2z4z 577 | 6awo 578 | 1v11 579 | 4l53 580 
| 3p55 581 | 2ynn 582 | 2vu3 583 | 4dli 584 | 2bcd 585 | 4l0s 586 | 4uda 587 | 3m37 588 | 5j5t 589 | 2p16 590 | 4gh6 591 | 1mfg 592 | 3s3i 593 | 4j73 594 | 2v5x 595 | 2h4n 596 | 4jsz 597 | 4wk1 598 | 4igt 599 | 4k63 600 | 3qqk 601 | 16pk 602 | 5aom 603 | 1hyv 604 | 5a3w 605 | 3veh 606 | 3g4l 607 | 2ph8 608 | 5mkx 609 | 5c4u 610 | 4gto 611 | 3cj5 612 | 4prj 613 | 2vd7 614 | 5duc 615 | 3odi 616 | 6bg5 617 | 1qwu 618 | 5jn8 619 | 1v1m 620 | 1qpe 621 | 5v3r 622 | 2wc4 623 | 2vte 624 | 1a52 625 | 4dhq 626 | 2qta 627 | 6ccy 628 | 4jog 629 | 4bgy 630 | 5u9i 631 | 3az9 632 | 1gt1 633 | 2jew 634 | 3pdc 635 | 1n3i 636 | 5fyx 637 | 4f49 638 | 4nzn 639 | 6hm2 640 | 4a4l 641 | 5xij 642 | 5vk0 643 | 4xsx 644 | 2aj8 645 | 4odq 646 | 2n7b 647 | 4ygf 648 | 2a4q 649 | 2jc0 650 | 4jsa 651 | 1inq 652 | 3dc3 653 | 5tob 654 | 4urn 655 | 6bik 656 | 4ju4 657 | 5nya 658 | 5oh2 659 | 5znr 660 | 5ct2 661 | 3u4u 662 | 4x7h 663 | 3max 664 | 3rbm 665 | 3krj 666 | 1aj6 667 | 1pmv 668 | 5n0e 669 | 4nhy 670 | 4oem 671 | 6fi4 672 | 4e3j 673 | 1fq4 674 | 5myr 675 | 2hkf 676 | 1os0 677 | 3rqg 678 | 4ivc 679 | 5c7b 680 | 3lq4 681 | 1u6q 682 | 1qxz 683 | 1l5r 684 | 4xxh 685 | 3m40 686 | 5or9 687 | 4okg 688 | 4d89 689 | 2gm9 690 | 5x33 691 | 4de0 692 | 4gr8 693 | 5lz8 694 | 1p93 695 | 2brp 696 | 2gg8 697 | 6fdt 698 | 5cxh 699 | 1jvu 700 | 3wp1 701 | 1fzm 702 | 5cxa 703 | 2gbg 704 | 2g78 705 | 5aml 706 | 2y34 707 | 2qnp 708 | 1v16 709 | 1njj 710 | 2a5u 711 | 4z88 712 | 4wmx 713 | 5vo2 714 | 4fod 715 | 2pou 716 | 3jsw 717 | 2ow2 718 | 5g3m 719 | 3odl 720 | 3o9e 721 | 3eyh 722 | 4ej2 723 | 3c4e 724 | 4b6f 725 | 1pl0 726 | 3pb8 727 | 6fap 728 | 4iax 729 | 2bua 730 | 6fgg 731 | 2o4h 732 | 4uwh 733 | 5wbf 734 | 2yxj 735 | 1ff1 736 | 2giu 737 | 1qbt 738 | 2ovq 739 | 4bak 740 | 2y3p 741 | 2iwu 742 | 3hvi 743 | 2w0x 744 | 3fcl 745 | 1zpa 746 | 5czb 747 | 3t1l 748 | 2cfd 749 | 3k3g 750 | 4cfw 751 | 2e91 752 | 5op8 753 | 3hig 754 | 6h7y 755 | 3mtb 756 | 4eb9 757 | 4lkg 758 | 5ehv 759 | 5ier 760 | 4ode 761 | 1xoq 
762 | 5d6p 763 | 3kwa 764 | 5np8 765 | 5v82 766 | 6ma1 767 | 3bz3 768 | 3myq 769 | 4j0s 770 | 4f4p 771 | 4lh6 772 | 1uef 773 | 4j3d 774 | 4yx4 775 | 4amx 776 | 4ptg 777 | 2c97 778 | 4ec4 779 | 4r1v 780 | 1zc9 781 | 4nuf 782 | 3g2u 783 | 6hlx 784 | 5vij 785 | 2x4o 786 | 6hlz 787 | 4lkj 788 | 3s75 789 | 2gz8 790 | 1gvk 791 | 2yhd 792 | 3hqz 793 | 3pb7 794 | 1thr 795 | 4ris 796 | 5twh 797 | 4gql 798 | 3n3l 799 | 3acx 800 | 5yvx 801 | 3gy2 802 | 1xmu 803 | 5l6p 804 | 5l8n 805 | 4msn 806 | 4rz1 807 | 3f66 808 | 3ucj 809 | 5hcl 810 | 1t1r 811 | 3kce 812 | 3u15 813 | 1wbg 814 | 5khi 815 | 3er5 816 | 4qew 817 | 5mft 818 | 6eqp 819 | 5gsw 820 | 2qd7 821 | 4cli 822 | 3f9w 823 | 3msc 824 | 1jgl 825 | 3kid 826 | 1ymx 827 | 1ui0 828 | 3d1f 829 | 1pxl 830 | 5kos 831 | 3vzd 832 | 5fcz 833 | 3ara 834 | 4li6 835 | 5ks7 836 | 4wym 837 | 5j7q 838 | 4qsh 839 | 2ce9 840 | 5vqz 841 | 3o2m 842 | 4bcm 843 | 5orx 844 | 1i41 845 | 3c5u 846 | 4kai 847 | 6gjy 848 | 4tsz 849 | 5o0e 850 | 6drt 851 | 1y57 852 | 3kqb 853 | 3jup 854 | 5ork 855 | 3ikc 856 | 3gwu 857 | 4wke 858 | 4x7l 859 | 3lp1 860 | 5ivy 861 | 3f16 862 | 4c36 863 | 1w2x 864 | 2d06 865 | 1hbj 866 | 1ols 867 | 1iup 868 | 5aix 869 | 1ydd 870 | 5w4r 871 | 3h23 872 | 3rj7 873 | 4ish 874 | 1ebw 875 | 1fcy 876 | 1d09 877 | 5hdv 878 | 4x1n 879 | 5boj 880 | 2xn7 881 | 4b6s 882 | 3f82 883 | 4clj 884 | 4zzz 885 | 5j5d 886 | 2vts 887 | 1k08 888 | 3u3f 889 | 4jk6 890 | 4csy 891 | 6hth 892 | 2mnz 893 | 2vpg 894 | 2qd6 895 | 4jkw 896 | 3ml5 897 | 1ih0 898 | 4at5 899 | 5dgu 900 | 4g31 901 | 5n0d 902 | 5aa9 903 | 4u4s 904 | 5oa6 905 | 2wzm 906 | 4b4q 907 | 6fi1 908 | 6chn 909 | 1z4u 910 | 5aa8 911 | 1lpk 912 | 3cib 913 | 5d75 914 | 5x4o 915 | 1ydb 916 | 5dhq 917 | 5t28 918 | 4zz0 919 | 3evf 920 | 5vyy 921 | 6eip 922 | 1q63 923 | 3ldw 924 | 5tq4 925 | 5uxf 926 | 2j7x 927 | 4kil 928 | 1yda 929 | 3bc4 930 | 2ew5 931 | 6ee3 932 | 4yrr 933 | 3wax 934 | 3bzf 935 | 5ody 936 | 1k06 937 | 4j84 938 | 5l6h 939 | 5eok 940 | 5nne 941 | 5m6m 942 | 2a4r 943 | 
3p1d 944 | 2ayp 945 | 3iux 946 | 4b0g 947 | 1jr1 948 | 4qo9 949 | 4bh4 950 | 4xt9 951 | 2ok1 952 | 2r7g 953 | 4uib 954 | 5mmn 955 | 5akj 956 | 3hs4 957 | 5wpb 958 | 6e5x 959 | 5vnd 960 | 5evd 961 | 5wlg 962 | 5l4m 963 | 4kiu 964 | 4own 965 | 5oh9 966 | 6arv 967 | 1xr9 968 | 4hv7 969 | -------------------------------------------------------------------------------- /data/timesplit_test: -------------------------------------------------------------------------------- 1 | 6qqw 2 | 6d08 3 | 6jap 4 | 6np2 5 | 6uvp 6 | 6oxq 7 | 6jsn 8 | 6hzb 9 | 6qrc 10 | 6oio 11 | 6jag 12 | 6moa 13 | 6hld 14 | 6i9a 15 | 6e4c 16 | 6g24 17 | 6jb4 18 | 6s55 19 | 6seo 20 | 6dyz 21 | 5zk5 22 | 6jid 23 | 5ze6 24 | 6qlu 25 | 6a6k 26 | 6qgf 27 | 6e3z 28 | 6te6 29 | 6pka 30 | 6g2o 31 | 6jsf 32 | 5zxk 33 | 6qxd 34 | 6n97 35 | 6jt3 36 | 6qtr 37 | 6oy1 38 | 6n96 39 | 6qzh 40 | 6qqz 41 | 6qmt 42 | 6ibx 43 | 6hmt 44 | 5zk7 45 | 6k3l 46 | 6cjs 47 | 6n9l 48 | 6ibz 49 | 6ott 50 | 6gge 51 | 6hot 52 | 6e3p 53 | 6md6 54 | 6hlb 55 | 6fe5 56 | 6uwp 57 | 6npp 58 | 6g2f 59 | 6mo7 60 | 6bqd 61 | 6nsv 62 | 6i76 63 | 6n53 64 | 6g2c 65 | 6eeb 66 | 6n0m 67 | 6uvy 68 | 6ovz 69 | 6olx 70 | 6v5l 71 | 6hhg 72 | 5zcu 73 | 6dz2 74 | 6mjq 75 | 6efk 76 | 6s9w 77 | 6gdy 78 | 6kqi 79 | 6ueg 80 | 6oxt 81 | 6oy0 82 | 6qr7 83 | 6i41 84 | 6cyg 85 | 6qmr 86 | 6g27 87 | 6ggb 88 | 6g3c 89 | 6n4e 90 | 6fcj 91 | 6quv 92 | 6iql 93 | 6i74 94 | 6qr4 95 | 6rnu 96 | 6jib 97 | 6izq 98 | 6qw8 99 | 6qto 100 | 6qrd 101 | 6hza 102 | 6e5s 103 | 6dz3 104 | 6e6w 105 | 6cyh 106 | 5zlf 107 | 6om4 108 | 6gga 109 | 6pgp 110 | 6qqv 111 | 6qtq 112 | 6gj6 113 | 6os5 114 | 6s07 115 | 6i77 116 | 6hhj 117 | 6ahs 118 | 6oxx 119 | 6mjj 120 | 6hor 121 | 6jb0 122 | 6i68 123 | 6pz4 124 | 6mhb 125 | 6uim 126 | 6jsg 127 | 6i78 128 | 6oxy 129 | 6gbw 130 | 6mo0 131 | 6ggf 132 | 6qge 133 | 6cjr 134 | 6oxp 135 | 6d07 136 | 6i63 137 | 6ten 138 | 6uii 139 | 6qlr 140 | 6sen 141 | 6oxv 142 | 6g2b 143 | 5zr3 144 | 6kjf 145 | 6qr9 146 | 6g9f 147 | 6e6v 148 | 5zk9 149 | 
6pnn 150 | 6nri 151 | 6uwv 152 | 6ooz 153 | 6npi 154 | 6oip 155 | 6miv 156 | 6s57 157 | 6p8x 158 | 6hoq 159 | 6qts 160 | 6ggd 161 | 6pnm 162 | 6oy2 163 | 6oi8 164 | 6mhd 165 | 6agt 166 | 6i5p 167 | 6hhr 168 | 6p8z 169 | 6c85 170 | 6g5u 171 | 6j06 172 | 6qsz 173 | 6jbb 174 | 6hhp 175 | 6np5 176 | 6nlj 177 | 6qlp 178 | 6n94 179 | 6e13 180 | 6qls 181 | 6uil 182 | 6st3 183 | 6n92 184 | 6s56 185 | 6hzd 186 | 6uhv 187 | 6k05 188 | 6q36 189 | 6ic0 190 | 6hhi 191 | 6e3m 192 | 6qtx 193 | 6jse 194 | 5zjy 195 | 6o3y 196 | 6rpg 197 | 6rr0 198 | 6gzy 199 | 6qlt 200 | 6ufo 201 | 6o0h 202 | 6o3x 203 | 5zjz 204 | 6i8t 205 | 6ooy 206 | 6oiq 207 | 6od6 208 | 6nrh 209 | 6qra 210 | 6hhh 211 | 6m7h 212 | 6ufn 213 | 6qr0 214 | 6o5u 215 | 6h14 216 | 6jwa 217 | 6ny0 218 | 6jan 219 | 6ftf 220 | 6oxw 221 | 6jon 222 | 6cf7 223 | 6rtn 224 | 6jsz 225 | 6o9c 226 | 6mo8 227 | 6qln 228 | 6qqu 229 | 6i66 230 | 6mja 231 | 6gwe 232 | 6d3z 233 | 6oxr 234 | 6r4k 235 | 6hle 236 | 6h9v 237 | 6hou 238 | 6nv9 239 | 6py0 240 | 6qlq 241 | 6nv7 242 | 6n4b 243 | 6jaq 244 | 6i8m 245 | 6dz0 246 | 6oxs 247 | 6k2n 248 | 6cjj 249 | 6ffg 250 | 6a73 251 | 6qqt 252 | 6a1c 253 | 6oxu 254 | 6qre 255 | 6qtw 256 | 6np4 257 | 6hv2 258 | 6n55 259 | 6e3o 260 | 6kjd 261 | 6sfc 262 | 6qi7 263 | 6hzc 264 | 6k04 265 | 6op0 266 | 6q38 267 | 6n8x 268 | 6np3 269 | 6uvv 270 | 6pgo 271 | 6jbe 272 | 6i75 273 | 6qqq 274 | 6i62 275 | 6j9y 276 | 6g29 277 | 6h7d 278 | 6mo9 279 | 6jao 280 | 6jmf 281 | 6hmy 282 | 6qfe 283 | 5zml 284 | 6i65 285 | 6e7m 286 | 6i61 287 | 6rz6 288 | 6qtm 289 | 6qlo 290 | 6oie 291 | 6miy 292 | 6nrf 293 | 6gj5 294 | 6jad 295 | 6mj4 296 | 6h12 297 | 6d3y 298 | 6qr2 299 | 6qxa 300 | 6o9b 301 | 6ckl 302 | 6oir 303 | 6d40 304 | 6e6j 305 | 6i7a 306 | 6g25 307 | 6oin 308 | 6jam 309 | 6oxz 310 | 6hop 311 | 6rot 312 | 6uhu 313 | 6mji 314 | 6nrj 315 | 6nt2 316 | 6op9 317 | 6pno 318 | 6e4v 319 | 6k1s 320 | 6a87 321 | 6oim 322 | 6cjp 323 | 6pyb 324 | 6h13 325 | 6qrf 326 | 6mhc 327 | 6j9w 328 | 6nrg 329 | 6fff 330 | 6n93 331 
| 6jut 332 | 6g2e 333 | 6nd3 334 | 6os6 335 | 6dql 336 | 6inz 337 | 6i67 338 | 6quw 339 | 6qwi 340 | 6npm 341 | 6i64 342 | 6e3n 343 | 6qrg 344 | 6nxz 345 | 6iby 346 | 6gj7 347 | 6qr3 348 | 6qr1 349 | 6s9x 350 | 6q4q 351 | 6hbn 352 | 6nw3 353 | 6tel 354 | 6p8y 355 | 6d5w 356 | 6t6a 357 | 6o5g 358 | 6r7d 359 | 6pya 360 | 6ffe 361 | 6d3x 362 | 6gj8 363 | 6mo2 364 | -------------------------------------------------------------------------------- /dataset/dataimporter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import torch 4 | from torch.utils.data import Dataset 5 | from tqdm import tqdm 6 | from .process_mols import get_receptor_input, read_molecule, process_ligand 7 | from commons.utils import read_strings_from_txt, log 8 | from torch_geometric.data import Data 9 | from openfold.utils.rigid_utils import Rotation, Rigid 10 | torch.cuda.empty_cache() 11 | 12 | class DataImporter(Dataset): 13 | 14 | def __init__( 15 | self, 16 | complex_names_path='~/ProteinLigandBinding/data/timesplit_no_lig_overlap_train', 17 | chain_radius=10, 18 | remove_h=True, 19 | cropping=True, 20 | crop_size=512, 21 | recenter=True, 22 | binding_site_cropping=True, 23 | spatial_cropping=False, 24 | spatial_and_contig_cropping=False, 25 | blackhole_init=False, 26 | count_repeats=False, 27 | rand_lig_coords=False, 28 | unseen_only=False, 29 | seed=0 30 | ): 31 | """ 32 | complex_names_path: 33 | Path to file with names of the directories containing the input files. 34 | chain_radius: 35 | Maximum distance to a ligand atom for a protein chain to still be included. 36 | remove_h: 37 | Whether or not to remove ligand H atoms (if False, they are considered explicitly). 38 | cropping: 39 | Whether or not to crop the protein. 40 | crop_size: 41 | Number of amino acids to crop the protein to. 42 | recenter: 43 | Whether or not the input protein should be recentred first. 44 | binding_site_cropping: 45 | Whether or not to do binding site cropping.
46 | spatial_cropping: 47 | Whether or not to do spatial cropping. 48 | spatial_and_contig_cropping: 49 | Whether or not to randomly choose between spatial and contiguous cropping. 50 | blackhole_init: 51 | Whether or not to do black hole initialisation of ligand atom coordinates (rather than using an RDKit conformer). 52 | count_repeats: 53 | If using black hole initialisation, whether or not to index atoms that have the same chemical properties. 54 | rand_lig_coords: 55 | Whether or not to randomly rotate the ligand first. 56 | unseen_only: 57 | Whether or not to only load unseen PDBBind proteins. 58 | """ 59 | 60 | self.complex_names_path = complex_names_path 61 | self.chain_radius = chain_radius 62 | self.remove_h = remove_h 63 | self.cropping = cropping 64 | self.crop_size = crop_size 65 | self.recenter = recenter 66 | self.binding_site_cropping = binding_site_cropping 67 | self.spatial_cropping = spatial_cropping 68 | self.spatial_and_contig_cropping = spatial_and_contig_cropping 69 | self.blackhole_init = blackhole_init 70 | self.count_repeats = count_repeats 71 | self.rand_lig_coords = rand_lig_coords 72 | self.unseen_only = unseen_only 73 | self.seed = seed 74 | 75 | # get and set some useful paths 76 | self.directory = pathlib.Path(__file__).parent.resolve() 77 | if not os.path.basename(self.complex_names_path) == 'posebusters': 78 | self.dataset_dir = os.path.join(self.directory, '../data/PDBBind') 79 | else: 80 | self.dataset_dir = os.path.join(self.directory, '../data/posebusters_benchmark_set') 81 | self.processed_dir = os.path.join(self.directory, '../data/processed/', os.path.basename(self.complex_names_path)) 82 | if not os.path.exists(os.path.join(self.directory, '../data/processed/')): 83 | os.mkdir(os.path.join(self.directory, '../data/processed/')) 84 | if not os.path.exists(self.processed_dir): 85 | os.mkdir(self.processed_dir) 86 | 87 | # process data if not already done 88 | if ( 89 | not os.path.exists(os.path.join(self.processed_dir,
'rec_input_chain_ids.pt')) or 90 | (not os.path.exists(os.path.join(self.processed_dir, 'lig_input_framing.pt')) and not self.remove_h) or 91 | (not os.path.exists(os.path.join(self.processed_dir, 'lig_input_framing_noH.pt')) and self.remove_h) 92 | ): 93 | self._process() 94 | if not os.path.exists(os.path.join(self.processed_dir, 'rec_input_proc_ids.pt')): 95 | self._process_chain_ids() 96 | if ( 97 | self.unseen_only and 98 | ( 99 | (not os.path.exists(os.path.join(self.processed_dir, 'unseen_lig_input_framing.pt')) and not self.remove_h) or 100 | (not os.path.exists(os.path.join(self.processed_dir, 'unseen_lig_input_framing_noH.pt')) and self.remove_h) or 101 | not os.path.exists(os.path.join(self.processed_dir, 'unseen_rec_input_proc_ids.pt')) 102 | ) 103 | ): 104 | self._process_unseen() 105 | 106 | # load data into memory 107 | log('Loading data into memory.') 108 | if self.remove_h: 109 | self.ligands = torch.load(os.path.join(self.processed_dir, f"{'unseen_' if self.unseen_only else ''}lig_input_framing_noH.pt")) 110 | else: 111 | self.ligands = torch.load(os.path.join(self.processed_dir, f"{'unseen_' if self.unseen_only else ''}lig_input_framing.pt")) 112 | self.receptors = torch.load(os.path.join(self.processed_dir, f"{'unseen_' if self.unseen_only else ''}rec_input_proc_ids.pt")) 113 | 114 | # recentring and cropping 115 | if self.recenter: 116 | self.receptors, self.ligands = self._recenter_proteins_and_ligands(self.receptors, self.ligands) 117 | if self.cropping: 118 | log(f'Cropping sequences into chunks of {self.crop_size} residues.') 119 | elif to_rem := [ 120 | idx 121 | for idx, rec in enumerate(self.receptors) 122 | if len(rec['c_alpha_coords']) > 2000 123 | ]: 124 | for idx in sorted(to_rem, reverse=True): 125 | removed = self.receptors.pop(idx) 126 | del self.ligands[idx] 127 | removed_name = removed['complex_names'] 128 | log(f'Removed complex {removed_name} because it contains more than 2000 residues.') 129 | log(f'Removed {len(to_rem)} 
complexes in total.') 130 | if self.binding_site_cropping and self.cropping: 131 | log('Finding binding site residues for binding site cropping.') 132 | indices = self._find_binding_site_residues(self.receptors, self.ligands) 133 | log('Cropping proteins.') 134 | self.receptors = self._crop_all_to_size(self.receptors, indices, self.crop_size) 135 | if self.spatial_cropping and self.cropping: 136 | log('Using spatial cropping, finding binding site residues.') 137 | indices = self._find_all_binding_site_residues(self.receptors, self.ligands, self.crop_size) 138 | log('Cropping proteins.') 139 | self.receptors = self._non_contig_crop_to_size(self.receptors, indices, self.crop_size) 140 | if self.spatial_and_contig_cropping and self.cropping: 141 | log('Using spatial and contiguous cropping, finding binding site residues.') 142 | self.indices = self._find_all_binding_site_residues(self.receptors, self.ligands, self.crop_size) 143 | 144 | assert len(self.ligands) == len(self.receptors) 145 | log('Finished loading data into memory.') 146 | torch.cuda.empty_cache() 147 | 148 | def __len__(self): 149 | return len(self.ligands) 150 | 151 | def __getitem__(self, idx): 152 | ligand = self.ligands[idx] 153 | receptor = self.receptors[idx] 154 | seq_length = receptor['seq_length'].item() 155 | 156 | # cropping 157 | if ( 158 | self.cropping and not ( 159 | self.binding_site_cropping or self.spatial_cropping or self.spatial_and_contig_cropping 160 | ) 161 | ): 162 | aatype, c_alpha_coords, n_coords, c_coords, ri, \ 163 | chain_ids_processed, entity_ids_processed, sym_ids_processed = self._random_crop_to_size( 164 | receptor['aatype'], receptor['c_alpha_coords'], 165 | receptor['n_coords'], receptor['c_coords'], 166 | receptor['residue_index'], self.crop_size, 167 | seq_length, receptor['chain_ids_processed'], 168 | receptor['entity_ids_processed'], receptor['sym_ids_processed'], 169 | ) 170 | 171 | elif self.cropping and self.spatial_and_contig_cropping: 172 | 
use_spatial_cropping = bool(torch.randint(0, 2, (1,))) 173 | if use_spatial_cropping: 174 | aatype, c_alpha_coords, n_coords, c_coords, ri, \ 175 | chain_ids_processed, entity_ids_processed, sym_ids_processed = self._crop_to_size( 176 | receptor['aatype'], receptor['c_alpha_coords'], 177 | receptor['n_coords'], receptor['c_coords'], 178 | receptor['residue_index'], self.crop_size, 179 | seq_length, receptor['chain_ids_processed'], 180 | receptor['entity_ids_processed'], receptor['sym_ids_processed'], 181 | self.indices[idx], 182 | ) 183 | 184 | else: 185 | aatype, c_alpha_coords, n_coords, c_coords, ri, \ 186 | chain_ids_processed, entity_ids_processed, sym_ids_processed = self._random_crop_to_size( 187 | receptor['aatype'], receptor['c_alpha_coords'], 188 | receptor['n_coords'], receptor['c_coords'], 189 | receptor['residue_index'], self.crop_size, 190 | seq_length, receptor['chain_ids_processed'], 191 | receptor['entity_ids_processed'], receptor['sym_ids_processed'], 192 | ) 193 | 194 | else: 195 | aatype, c_alpha_coords, n_coords, c_coords, ri, \ 196 | chain_ids_processed, entity_ids_processed, sym_ids_processed =\ 197 | receptor['aatype'], receptor['c_alpha_coords'], receptor['n_coords'], \ 198 | receptor['c_coords'], receptor['residue_index'], receptor['chain_ids_processed'], \ 199 | receptor['entity_ids_processed'], receptor['sym_ids_processed'] 200 | 201 | # random ligand transformation 202 | c_alpha_coords, n_coords, c_coords, \ 203 | lig_atom_coords, pseudo_N, pseudo_C, \ 204 | true_lig_atom_coords, true_pseudo_N, true_pseudo_C = self._random_transform( 205 | c_alpha_coords, n_coords, c_coords, 206 | ligand['atom_coords'], ligand['pseudo_N'], ligand['pseudo_C'], 207 | ligand['true_atom_coords'], ligand['true_pseudo_N'], ligand['true_pseudo_C'] 208 | ) if self.rand_lig_coords else ( 209 | c_alpha_coords, n_coords, c_coords, 210 | ligand['atom_coords'], ligand['pseudo_N'], ligand['pseudo_C'], 211 | ligand['true_atom_coords'], ligand['true_pseudo_N'], 
ligand['true_pseudo_C'] 212 | ) 213 | 214 | # black hole initialisation 215 | lig_atom_features = ligand['atom_features'] 216 | if self.blackhole_init: 217 | lig_atom_coords = torch.zeros_like(lig_atom_coords) 218 | if self.count_repeats: 219 | repeat_idx = self._count_atom_repeats(ligand['atom_features']) 220 | lig_atom_features = torch.cat((lig_atom_features, repeat_idx), -1) 221 | 222 | # return data 223 | data = Data( 224 | complex_name = receptor['complex_names'], 225 | aatype = aatype, residue_index = ri, 226 | chain_ids_processed = chain_ids_processed, 227 | entity_ids_processed = entity_ids_processed, 228 | sym_ids_processed = sym_ids_processed, 229 | c_alpha_coords = c_alpha_coords, 230 | n_coords = n_coords, 231 | c_coords = c_coords, 232 | lig_atom_features = lig_atom_features, 233 | pseudo_N = pseudo_N, 234 | pseudo_C = pseudo_C, 235 | true_pseudo_N = true_pseudo_N, 236 | true_pseudo_C = true_pseudo_C, 237 | adjacency = ligand['adjacency'], 238 | lig_atom_coords = lig_atom_coords, 239 | true_lig_atom_coords = true_lig_atom_coords, 240 | distance_matrix = ligand['distance_matrix'], 241 | adjacency_bo = ligand['adjacency_bo'], 242 | ) 243 | del ligand, receptor 244 | return data 245 | 246 | # helper functions 247 | def _random_crop_to_size(self, aatype, c_alpha_coords, n_coords, c_coords, ri, crop_size, seq_length, chain_ids, entity_ids, sym_ids): 248 | """Crop randomly to `crop_size`, or keep as is if shorter than that.""" 249 | g = torch.Generator(device=aatype.device) 250 | num_res_crop_size = min(int(seq_length), crop_size) 251 | 252 | def _randint(lower, upper): 253 | return int(torch.randint( 254 | lower, 255 | upper + 1, 256 | (1,), 257 | device=aatype.device, 258 | generator=g, 259 | )[0]) 260 | 261 | n = seq_length - num_res_crop_size 262 | crop_start = _randint(0, n) 263 | aatype_sliced = aatype[crop_start:(crop_start+crop_size)] 264 | c_alpha_coords_sliced = c_alpha_coords[crop_start:(crop_start+crop_size)] 265 | n_coords_sliced = 
n_coords[crop_start:(crop_start+crop_size)] 266 | c_coords_sliced = c_coords[crop_start:(crop_start+crop_size)] 267 | ri_sliced = ri[crop_start:(crop_start+crop_size)] 268 | chain_ids_sliced = chain_ids[crop_start:(crop_start+crop_size)] 269 | entity_ids_sliced = entity_ids[crop_start:(crop_start+crop_size)] 270 | sym_ids_sliced = sym_ids[crop_start:(crop_start+crop_size)] 271 | 272 | return aatype_sliced, c_alpha_coords_sliced, n_coords_sliced, c_coords_sliced, ri_sliced, chain_ids_sliced, entity_ids_sliced, sym_ids_sliced 273 | 274 | def _find_binding_site_residues(self, receptors, ligands): 275 | indices = [] 276 | for receptor, ligand in tqdm(zip(receptors, ligands), total=len(receptors)): 277 | c_alpha_coords = receptor['c_alpha_coords'] 278 | lig_coords = ligand['true_atom_coords'] 279 | distances = torch.cdist(c_alpha_coords.to(dtype=torch.float32), lig_coords.to(dtype=torch.float32)) 280 | indices.append(torch.argmin(torch.min(distances, dim=1).values)) 281 | return indices 282 | 283 | def _find_all_binding_site_residues(self, receptors, ligands, crop_size): 284 | indices = [] 285 | for receptor, ligand in tqdm(zip(receptors, ligands), total=len(receptors)): 286 | c_alpha_coords = receptor['c_alpha_coords'] 287 | lig_coords = ligand['true_atom_coords'] 288 | distances = torch.cdist(c_alpha_coords.to(dtype=torch.float32), lig_coords.to(dtype=torch.float32)) 289 | distances_flattened = torch.min(distances, dim=1).values 290 | _, idx = distances_flattened.sort() 291 | indices.append(idx[:crop_size]) 292 | return indices 293 | 294 | def _crop_all_to_size(self, receptors, indices, crop_size): 295 | assert len(receptors) == len(indices) 296 | for receptor, idx in tqdm(zip(receptors, indices), total=len(receptors)): 297 | seq_length = receptor['seq_length'].item() 298 | if seq_length < crop_size: 299 | aatype_sliced = receptor['aatype'] 300 | c_alpha_coords_sliced = receptor['c_alpha_coords'] 301 | n_coords_sliced = receptor['n_coords'] 302 | c_coords_sliced = 
receptor['c_coords'] 303 | ri_sliced = receptor['residue_index'] 304 | else: 305 | # get start and end indices, making sure that we crop to the full crop_size, if possible 306 | start = idx - crop_size/2 307 | end = idx + crop_size/2 308 | if start < 0: 309 | end -= start 310 | start = 0 311 | if end > seq_length: 312 | start -= (end - seq_length) 313 | end = seq_length 314 | start = int(start) 315 | end = int(end) 316 | aatype_sliced = receptor['aatype'][start:end] 317 | c_alpha_coords_sliced = receptor['c_alpha_coords'][start:end] 318 | n_coords_sliced = receptor['n_coords'][start:end] 319 | c_coords_sliced = receptor['c_coords'][start:end] 320 | ri_sliced = receptor['residue_index'][start:end] 321 | 322 | receptor['aatype'] = aatype_sliced 323 | receptor['c_alpha_coords'] = c_alpha_coords_sliced 324 | receptor['n_coords'] = n_coords_sliced 325 | receptor['c_coords'] = c_coords_sliced 326 | receptor['residue_index'] = ri_sliced 327 | 328 | return receptors 329 | 330 | def _crop_to_size(self, aatype, c_alpha_coords, n_coords, c_coords, ri, crop_size, seq_length, chain_ids, entity_ids, sym_ids, indices): 331 | if seq_length < crop_size: 332 | return ( 333 | aatype, 334 | c_alpha_coords, 335 | n_coords, 336 | c_coords, 337 | ri, 338 | chain_ids, 339 | entity_ids, 340 | sym_ids, 341 | ) 342 | else: 343 | return( 344 | aatype[indices], 345 | c_alpha_coords[indices], 346 | n_coords[indices], 347 | c_coords[indices], 348 | ri[indices], 349 | chain_ids[indices], 350 | entity_ids[indices], 351 | sym_ids[indices], 352 | ) 353 | 354 | def _non_contig_crop_to_size(self, receptors, indices, crop_size): 355 | for receptor, idx in tqdm(zip(receptors, indices), total=len(receptors)): 356 | seq_length = receptor['seq_length'].item() 357 | if seq_length < crop_size: 358 | aatype_sliced = receptor['aatype'] 359 | c_alpha_coords_sliced = receptor['c_alpha_coords'] 360 | n_coords_sliced = receptor['n_coords'] 361 | c_coords_sliced = receptor['c_coords'] 362 | ri_sliced = 
receptor['residue_index'] 363 | else: 364 | aatype_sliced = receptor['aatype'][idx] 365 | c_alpha_coords_sliced = receptor['c_alpha_coords'][idx] 366 | n_coords_sliced = receptor['n_coords'][idx] 367 | c_coords_sliced = receptor['c_coords'][idx] 368 | ri_sliced = receptor['residue_index'][idx] 369 | 370 | receptor['aatype'] = aatype_sliced 371 | receptor['c_alpha_coords'] = c_alpha_coords_sliced 372 | receptor['n_coords'] = n_coords_sliced 373 | receptor['c_coords'] = c_coords_sliced 374 | receptor['residue_index'] = ri_sliced 375 | 376 | return receptors 377 | 378 | def _recenter_proteins_and_ligands(self, receptors, ligands): 379 | for rec, lig in zip(receptors, ligands): 380 | centre_of_mass = torch.mean(rec['c_alpha_coords'], dim=0) 381 | new_c_alpha_coords = rec['c_alpha_coords'] - centre_of_mass 382 | new_n_coords = rec['n_coords'] - centre_of_mass 383 | new_c_coords = rec['c_coords'] - centre_of_mass 384 | new_true_atom_coords = lig['true_atom_coords'] - centre_of_mass 385 | new_true_pseudo_N = lig['true_pseudo_N'] - centre_of_mass 386 | new_true_pseudo_C = lig['true_pseudo_C'] - centre_of_mass 387 | rec['c_alpha_coords'] = new_c_alpha_coords 388 | rec['n_coords'] = new_n_coords 389 | rec['c_coords'] = new_c_coords 390 | lig['true_atom_coords'] = new_true_atom_coords 391 | lig['true_pseudo_N'] = new_true_pseudo_N 392 | lig['true_pseudo_C'] = new_true_pseudo_C 393 | 394 | return receptors, ligands 395 | 396 | def _count_atom_repeats(self, atom_features): 397 | repeat_indx = torch.zeros(atom_features.shape[0]).unsqueeze(-1) 398 | count_dict = {} 399 | for i, a_i in enumerate(atom_features): 400 | indx = 1 401 | a_i = a_i[0].item() # atomic number 402 | if a_i in count_dict: 403 | indx = count_dict[a_i] + 1 404 | count_dict[a_i] = indx 405 | repeat_indx[i] = indx 406 | return repeat_indx 407 | 408 | def _get_rand_rotation_matrix(self): 409 | import numpy as np 410 | 411 | randnums = np.random.uniform(size=(3,)) 412 | theta, phi, z = randnums 413 | theta = theta 
* 2.0* np.pi 414 | phi = phi * 2.0 * np.pi 415 | z = z * 2.0 416 | 417 | r = np.sqrt(z) 418 | V = ( 419 | np.sin(phi) * r, 420 | np.cos(phi) * r, 421 | np.sqrt(2.0 - z) 422 | ) 423 | st = np.sin(theta) 424 | ct = np.cos(theta) 425 | 426 | R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1))) 427 | M = (np.outer(V, V) - np.eye(3)).dot(R) 428 | 429 | return torch.tensor(M) 430 | 431 | def _random_transform( 432 | self, c_alpha_coords, n_coords, c_coords, 433 | lig_coords, pseudo_N, pseudo_C, 434 | true_lig_coords, true_pseudo_N, true_pseudo_C 435 | ): 436 | # rot_mats = self._get_rand_rotation_matrix() 437 | # T = Rotation(rot_mats=rot_mats, quats=None) 438 | g = torch.Generator(device=c_alpha_coords.device) 439 | T = Rigid( 440 | Rotation(rot_mats=torch.rand(3, 3, generator=g), quats=None), 25+torch.rand(3, generator=g) 441 | ) 442 | # c_alpha_coords = T.apply(c_alpha_coords) 443 | # n_coords = T.apply(n_coords) 444 | # c_coords = T.apply(c_coords) 445 | lig_coords = T.apply(lig_coords) 446 | pseudo_N = T.apply(pseudo_N) 447 | pseudo_C = T.apply(pseudo_C) 448 | # true_lig_coords = T.apply(true_lig_coords) 449 | # true_pseudo_N = T.apply(true_pseudo_N) 450 | # true_pseudo_C = T.apply(true_pseudo_C) 451 | return ( 452 | c_alpha_coords, n_coords, c_coords, 453 | lig_coords, pseudo_N, pseudo_C, 454 | true_lig_coords, true_pseudo_N, true_pseudo_C 455 | ) 456 | 457 | def get_feature_dimensions(self): 458 | lig, rec = self.ligands[0], self.receptors[0] 459 | lig_feat_dim = lig['atom_features'].size(-1) + 1 if self.count_repeats else lig['atom_features'].size(-1) 460 | rec_feat_dim = rec['aatype'].size(-1) 461 | return lig_feat_dim, rec_feat_dim 462 | 463 | def _process(self): 464 | log(f'Processing complexes from [{self.complex_names_path}] and saving them to [{self.processed_dir}].') 465 | complex_names = read_strings_from_txt(self.complex_names_path) 466 | if '1xn3' in complex_names: 467 | complex_names.remove('1xn3') # corrupt PDB file 468 | log(f'Loading 
{len(complex_names)} complexes.') 469 | 470 | ligs = [] 471 | if os.path.basename(self.complex_names_path) == 'posebusters': 472 | for name in tqdm(complex_names, desc='Loading ligands'): 473 | lig = read_molecule(os.path.join(self.dataset_dir, name, f'{name}_ligand.sdf'), remove_hs=self.remove_h) 474 | ligs.append(lig) 475 | rec_paths = [os.path.join(self.dataset_dir, name, f'{name}_protein.pdb') for name in complex_names] 476 | else: 477 | for name in tqdm(complex_names, desc='Loading ligands'): 478 | lig = read_molecule(os.path.join(self.dataset_dir, name, f'{name}_ligand.mol2'), remove_hs=self.remove_h) 479 | if lig is None: 480 | lig = read_molecule(os.path.join(self.dataset_dir, name, f'{name}_ligand.sdf'), remove_hs=self.remove_h) 481 | ligs.append(lig) 482 | rec_paths = [os.path.join(self.dataset_dir, name, f'{name}_protein_processed.pdb') for name in complex_names] 483 | 484 | # Get receptor input 485 | if not os.path.exists(os.path.join(self.processed_dir, 'rec_input_chain_ids.pt')): 486 | rec_input = [ 487 | get_receptor_input(r, l, c, cutoff=self.chain_radius) 488 | for r, l, c in tqdm( 489 | zip(rec_paths, ligs, complex_names), desc='Getting receptor input', total=len(complex_names) 490 | ) 491 | ] 492 | log('Saving receptor input.') 493 | torch.save(rec_input, os.path.join(self.processed_dir, 'rec_input_chain_ids.pt')) 494 | 495 | if not self.remove_h and not os.path.exists(os.path.join(self.processed_dir, 'lig_input_framing.pt')): 496 | self._process_ligands( 497 | ligs, complex_names, 'lig_input_framing.pt' 498 | ) 499 | if self.remove_h and not os.path.exists( 500 | os.path.join(self.processed_dir, 'lig_input_framing_noH.pt') 501 | ): 502 | self._process_ligands( 503 | ligs, complex_names, 'lig_input_framing_noH.pt' 504 | ) 505 | 506 | def _process_chain_ids(self): 507 | receptors = torch.load(os.path.join(self.processed_dir, 'rec_input_chain_ids.pt')) 508 | for rec in receptors: 509 | chain_ids = rec['chain_ids'] 510 | res_per_id = 
rec['n_res_per_chain'] 511 | total_seq_length = rec['seq_length'] 512 | unique_lengths = list(set(res_per_id.values())) 513 | chain_ids_processed = torch.empty(total_seq_length) 514 | entity_ids_processed = torch.empty(total_seq_length) 515 | sym_ids_processed = torch.empty(total_seq_length) 516 | start_idx = 0 517 | existing_ids = {} 518 | 519 | for id, length in res_per_id.items(): 520 | c = torch.tensor([chain_ids.index(id) for _ in range(length)]) 521 | e = torch.tensor([unique_lengths.index(length) for _ in range(length)]) 522 | if length in existing_ids.keys(): 523 | s = existing_ids[length] + 1 524 | else: 525 | s = 0 526 | chain_ids_processed[start_idx:start_idx+length] = c 527 | entity_ids_processed[start_idx:start_idx+length] = e 528 | sym_ids_processed[start_idx:start_idx+length] = torch.tensor([s for _ in range(length)]) 529 | start_idx += length 530 | existing_ids[length] = s 531 | 532 | rec['chain_ids_processed'] = chain_ids_processed 533 | rec['entity_ids_processed'] = entity_ids_processed 534 | rec['sym_ids_processed'] = sym_ids_processed 535 | 536 | torch.save(receptors, os.path.join(self.processed_dir, 'rec_input_proc_ids.pt')) 537 | 538 | def _process_unseen(self): 539 | unseen_pdb_ids = [ 540 | '6qqw', '6jap', '6np2', '6qrc', '6oio', '6jag', '6i9a', '6jb4', '6seo', '6jid', '5ze6', '6pka', 541 | '6n97', '6qtr', '6n96', '6qzh', '6qqz', '6k3l', '6cjs', '6n9l', '6ott', '6npp', '6nsv', '6n53', 542 | '6eeb', '6n0m', '6ovz', '5zcu', '6mjq', '6efk', '6gdy', '6kqi', '6ueg', '6qr7', '6g3c', '6iql', 543 | '6qr4', '6jib', '6qto', '6qrd', '6e5s', '5zlf', '6om4', '6qqv', '6qtq', '6os5', '6s07', '6mjj', 544 | '6jb0', '6uim', '6mo0', '6cjr', '6uii', '6sen', '6kjf', '6qr9', '6g9f', '6npi', '6oip', '6miv', 545 | '6qts', '6oi8', '6c85', '6qsz', '6jbb', '6np5', '6nlj', '6n94', '6e13', '6uil', '6n92', '6uhv', 546 | '6q36', '6qtx', '6rr0', '6ufo', '6oiq', '6qra', '6m7h', '6ufn', '6qr0', '6o5u', '6ny0', '6jan', 547 | '6ftf', '6jon', '6cf7', '6o9c', '6qqu', '6mja', 
'6r4k', '6h9v', '6py0', '6jaq', '6k2n', '6cjj', 548 | '6a73', '6qqt', '6qre', '6qtw', '6np4', '6n55', '6kjd', '6np3', '6jbe', '6qqq', '6j9y', '6h7d', 549 | '6jao', '6e7m', '6rz6', '6qtm', '6miy', '6jad', '6mj4', '6qr2', '6qxa', '6o9b', '6ckl', '6oir', 550 | '6oin', '6jam', '6uhu', '6mji', '6nt2', '6op9', '6e4v', '6a87', '6cjp', '6qrf', '6j9w', '6n93', 551 | '6nd3', '6os6', '6dql', '6qwi', '6npm', '6qrg', '6nxz', '6qr3', '6qr1', '6o5g', '6r7d', '6mo2' 552 | ] 553 | all_proteins = torch.load(os.path.join(self.processed_dir, 'rec_input_proc_ids.pt')) 554 | all_ligands = torch.load(os.path.join(self.processed_dir, 'lig_input_framing.pt')) 555 | all_ligands_noH = torch.load(os.path.join(self.processed_dir, 'lig_input_framing_noH.pt')) 556 | unseen_proteins, unseen_ligands, unseen_ligands_noH = [], [], [] 557 | for p, l, l_noH in zip(all_proteins, all_ligands, all_ligands_noH): 558 | if p['complex_names'] in unseen_pdb_ids: 559 | unseen_proteins.append(p) 560 | unseen_ligands.append(l) 561 | unseen_ligands_noH.append(l_noH) 562 | torch.save(unseen_proteins, os.path.join(self.processed_dir, 'unseen_rec_input_proc_ids.pt')) 563 | torch.save(unseen_ligands, os.path.join(self.processed_dir, 'unseen_lig_input_framing.pt')) 564 | torch.save(unseen_ligands_noH, os.path.join(self.processed_dir, 'unseen_lig_input_framing_noH.pt')) 565 | 566 | def _process_ligands(self, ligs, complex_names, file_name): 567 | lig_input = [ 568 | process_ligand(l, c, seed=self.seed) 569 | for l, c in tqdm( 570 | zip(ligs, complex_names), desc='Getting ligand input', total=len(complex_names) 571 | ) 572 | ] 573 | log('Saving ligand input.') 574 | torch.save(lig_input, os.path.join(self.processed_dir, file_name)) -------------------------------------------------------------------------------- /dataset/process_mols.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | import torch 4 | from Bio.PDB import PDBParser 5 | from 
Bio.PDB.PDBExceptions import PDBConstructionWarning 6 | from Bio.PDB.Polypeptide import protein_letters_3to1 7 | from rdkit import Chem 8 | from rdkit.Chem import AllChem, GetDistanceMatrix 9 | from rdkit.Chem.rdmolops import GetAdjacencyMatrix 10 | from rdkit.Chem.rdchem import Conformer 11 | from rdkit.Geometry import Point3D 12 | from rdkit import RDLogger 13 | from scipy import spatial 14 | from openfold.np import residue_constants 15 | from itertools import chain 16 | from typing import Optional 17 | 18 | def read_molecule(file: str, sanitize: Optional[bool] = True, remove_hs: Optional[bool] = False): 19 | """Reads in a ligand from an input file. 20 | Args: 21 | file : Path to ``.mol2``, ``.sdf`` or ``.pdb`` input file. 22 | sanitize : Whether to sanitize the molecule. 23 | remove_hs : Whether to remove hydrogens. 24 | Returns: 25 | Ligand parsed from the input file. 26 | """ 27 | # Suppress RDKit warnings 28 | RDLogger.DisableLog('rdApp.*') 29 | if file.endswith('.mol2'): 30 | mol = Chem.MolFromMol2File(file, sanitize=False, removeHs=False) 31 | elif file.endswith('.sdf'): 32 | supplier = Chem.SDMolSupplier(file, sanitize=False, removeHs=False) 33 | mol = supplier[0] 34 | elif file.endswith('.pdb'): 35 | mol = Chem.MolFromPDBFile(file, sanitize=False, removeHs=False) 36 | else: 37 | raise ValueError('Unsupported file format.') 38 | 39 | try: 40 | if sanitize: 41 | Chem.SanitizeMol(mol) 42 | if remove_hs: 43 | mol = Chem.RemoveHs(mol, sanitize=sanitize) 44 | return mol 45 | except Exception: 46 | return None 47 | 48 | # LIGAND PROCESSING 49 | 50 | lig_features = { 51 | 'atomic_number': [1, 6, 7, 8, 9, 15, 16, 17, 35, 53, 'misc'], 52 | 'chirality': [ 53 | 'CHI_UNSPECIFIED', 54 | 'CHI_TETRAHEDRAL_CW', 55 | 'CHI_TETRAHEDRAL_CCW', 56 | 'CHI_OTHER' 57 | ], 58 | 'degree': [1, 2, 3, 4, 'misc'], 59 | 'formal_charge': [-1, 0, 1, 'misc'], 60 | 'numH': [0, 1, 2, 3, 'misc'], 61 | 'hybridisation': [ 62 | 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc' 63 | ], 64 |
'is_aromatic': [False, True], 65 | 'is_in_ring': [False, True] 66 | } 67 | 68 | def reorder_atoms(ligand, scheme = 'canonicalatomrank'): 69 | """ 70 | Reorders ligand atoms based on global molecular information. 71 | Args: 72 | ligand (rdkit.Chem.rdchem.Mol object): 73 | Input ligand 74 | scheme (str): 75 | Scheme to use. Valid options are 'canonicalatomrank' 76 | and 'longestlinearchain'. 77 | Default: 'canonicalatomrank'. 78 | Returns: 79 | reordered_ligand (rdkit.Chem.rdchem.Mol object): 80 | Ligand with reordered atoms. 81 | """ 82 | if scheme == 'canonicalatomrank': 83 | neworder = list(zip(*sorted([(j, i) for i, j in enumerate(Chem.CanonicalRankAtoms(ligand))])))[1] 84 | elif scheme == 'longestlinearchain': 85 | smiles = Chem.MolToSmiles(ligand) 86 | neworder = list(map(int, ligand.GetProp("_smilesAtomOutputOrder")[1:-2].split(","))) 87 | else: 88 | raise ValueError("Invalid renumbering scheme.") 89 | reordered_ligand = Chem.RenumberAtoms(ligand, neworder) 90 | 91 | conf = ligand.GetConformer() 92 | reordered_conf = reordered_ligand.GetConformer() 93 | for i in range(ligand.GetNumAtoms()): 94 | x, y, z = Conformer.GetPositions(conf)[i] 95 | reordered_conf.SetAtomPosition(neworder.index(i), Point3D(float(x), float(y), float(z))) 96 | 97 | Chem.SanitizeMol(reordered_ligand) 98 | return reordered_ligand 99 | 100 | def process_ligand(mol, name, seed=None): 101 | mol = reorder_atoms(mol) 102 | conf = mol.GetConformer() 103 | true_lig_coords = conf.GetPositions() 104 | adjacency = GetAdjacencyMatrix(mol) 105 | adjacency_bo = GetAdjacencyMatrix(mol, useBO=True) 106 | distance_matrix = GetDistanceMatrix(mol) 107 | neighbour_list = get_neighbour_list(adjacency) 108 | try: 109 | lig_coords = get_rdkit_coords(mol, seed).numpy() 110 | except Exception as e: 111 | lig_coords = true_lig_coords 112 | with open('ligand_processing.log', 'a') as f: 113 | f.write(f'Generating RDKit conformer failed for\n{name}\n{str(e)}\n') 114 | f.flush() 115 | print(f'Generating RDKit 
conformer failed for {name}, {e}') 116 | pseudo_N, pseudo_C, true_pseudo_N, true_pseudo_C = get_adjacent_atoms(neighbour_list, lig_coords, true_lig_coords) 117 | 118 | features = {"atom_features": get_lig_atom_features(mol)} 119 | features["atom_coords"] = torch.tensor(lig_coords, dtype=torch.float32) 120 | features["true_atom_coords"] = torch.tensor(true_lig_coords, dtype=torch.float32) 121 | features["adjacency"] = torch.tensor(adjacency, dtype=torch.int) 122 | features["adjacency_bo"] = torch.tensor(adjacency_bo, dtype=torch.int) 123 | features["distance_matrix"] = torch.tensor(distance_matrix, dtype=torch.int) 124 | features["pseudo_N"] = pseudo_N.to(dtype=torch.float32) 125 | features["pseudo_C"] = pseudo_C.to(dtype=torch.float32) 126 | features["true_pseudo_N"] = true_pseudo_N.to(dtype=torch.float32) 127 | features["true_pseudo_C"] = true_pseudo_C.to(dtype=torch.float32) 128 | 129 | return features 130 | 131 | def get_rdkit_coords(mol, seed=None): 132 | ETKDG = AllChem.ETKDGv2() 133 | if seed is not None: 134 | ETKDG.randomSeed = seed 135 | new_conf_id = AllChem.EmbedMolecule(mol, ETKDG) 136 | if new_conf_id == -1: 137 | ETKDG.useRandomCoords = True 138 | AllChem.EmbedMolecule(mol, ETKDG) 139 | AllChem.MMFFOptimizeMolecule(mol, confId=0) 140 | conf = mol.GetConformer() 141 | lig_coords = conf.GetPositions() 142 | return torch.tensor(lig_coords, dtype=torch.float32) 143 | 144 | def get_lig_atom_features(mol): 145 | return torch.tensor( 146 | [ 147 | [ 148 | get_index( 149 | lig_features['atomic_number'], atom.GetAtomicNum() 150 | ), 151 | lig_features['chirality'].index( 152 | str(atom.GetChiralTag()) 153 | ), 154 | get_index( 155 | lig_features['degree'], atom.GetTotalDegree() 156 | ), 157 | get_index( 158 | lig_features['formal_charge'], 159 | atom.GetFormalCharge(), 160 | ), 161 | get_index( 162 | lig_features['numH'], atom.GetTotalNumHs() 163 | ), 164 | get_index( 165 | lig_features['hybridisation'], 166 | str(atom.GetHybridization()), 167 | ), 168 | 
lig_features['is_aromatic'].index( 169 | atom.GetIsAromatic() 170 | ), 171 | get_index( 172 | lig_features['is_in_ring'], 173 | atom.IsInRing(), 174 | ), 175 | ] 176 | for atom in mol.GetAtoms() 177 | ] 178 | ) 179 | 180 | def get_index(l, e): 181 | try: 182 | return l.index(e) 183 | except Exception: 184 | return len(l) - 1 185 | 186 | def get_neighbour_list(adjacency): 187 | neighbour_list = [] 188 | for row in adjacency: 189 | neighbours = [neighbour for neighbour, entry in enumerate(row) if entry == 1] 190 | neighbour_list.append(neighbours) 191 | return neighbour_list 192 | 193 | def get_adjacent_atoms(neighbour_list, lig_coords, true_lig_coords): 194 | """ 195 | For each ligand atom, get the coordinates of the two adjacent ligand atoms with the 196 | lowest indices. These will be used as pseudo N and C atoms, analogously to the N and 197 | C atoms for the protein. 198 | 199 | Procedure: 200 | (1) If a given atom has two or more bonds, then choose the two neighbours with the lowest indices. 201 | If the ligand atoms were reordered using the canonical atom ranking, these would be the two with 202 | the lowest priority. 203 | (2) If it has only one bond, obtain the coordinates of a dummy atom as follows: 204 | (1) Compute the bond vector of the single bond ("x_y"). 205 | (2) Copy that vector ("x_z"). 206 | (3) While keeping the x and y coordinates constant, find the z coordinate such that the dot product of the two bond vectors x_y and x_z is 0. 207 | (4) Subtract the new x_z bond vector from the atom coordinate to get the coordinates of the dummy atom.
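Illustration (not part of the function): the dummy-atom step (2) above can be sketched in isolation with NumPy. The helper name `dummy_atom` and the `eps` parameter are hypothetical and chosen only for this sketch; the clamping of a near-zero z component mirrors the 0.0001 threshold used in the function body.

```python
import numpy as np

def dummy_atom(origin, neighbour, eps=1e-4):
    # Hypothetical helper for illustration; not defined in this module.
    origin = np.asarray(origin, dtype=float)
    neighbour = np.asarray(neighbour, dtype=float)
    # (1) Bond vector of the single bond ("x_y").
    x_y = origin - neighbour
    # Avoid dividing by a (near-)zero z component.
    if abs(x_y[-1]) <= eps:
        x_y[-1] = eps if x_y[-1] > 0 else -eps
    # (2) Copy the bond vector ("x_z").
    x_z = x_y.copy()
    # (3) Keep x and y, solve for z so that x_y . x_z == 0.
    x_z[-1] = -(x_z[:-1] @ x_y[:-1]) / x_y[-1]
    # (4) Subtract x_z from the atom position to place the dummy atom.
    return origin - x_z, x_y, x_z

dummy, x_y, x_z = dummy_atom([1.0, 2.0, 3.0], [0.0, 0.0, 1.0])
```

In this sketch x_z is perpendicular to the bond by construction, so the atom, its single neighbour and the dummy atom are never collinear and can define a local frame.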
208 | """ 209 | 210 | pseudo_N, pseudo_C = torch.empty(0, 3), torch.empty(0, 3) 211 | true_pseudo_N, true_pseudo_C = torch.empty(0, 3), torch.empty(0, 3) 212 | 213 | for i in range(len(neighbour_list)): 214 | pseudo_N = torch.cat( 215 | (pseudo_N, torch.tensor(lig_coords[neighbour_list[i][0]]).unsqueeze(0)), 0 216 | ) 217 | true_pseudo_N = torch.cat( 218 | (true_pseudo_N, torch.tensor(true_lig_coords[neighbour_list[i][0]]).unsqueeze(0)), 0 219 | ) 220 | if len(neighbour_list[i]) >= 2: 221 | pseudo_C = torch.cat( 222 | (pseudo_C, torch.tensor(lig_coords[neighbour_list[i][1]]).unsqueeze(0)), 0 223 | ) 224 | true_pseudo_C = torch.cat( 225 | (true_pseudo_C, torch.tensor(true_lig_coords[neighbour_list[i][1]]).unsqueeze(0)), 0 226 | ) 227 | else: 228 | origin = lig_coords[i] 229 | N = lig_coords[neighbour_list[i][0]] 230 | x_y = origin - N 231 | 232 | if abs(x_y[-1]) <= 0.0001: 233 | x_y[-1] = 0.0001 if x_y[-1] > 0 else -0.0001 234 | x_z = x_y.copy() 235 | 236 | x_z[-1] = -(x_z[:-1] @ x_y[:-1]) / x_y[-1] 237 | dummy_coord = torch.tensor(origin - x_z) 238 | pseudo_C = torch.cat((pseudo_C, dummy_coord.unsqueeze(0)), 0) 239 | 240 | origin_t = true_lig_coords[i] 241 | N_t = true_lig_coords[neighbour_list[i][0]] 242 | x_y_t = origin_t - N_t 243 | 244 | if abs(x_y_t[-1]) <= 0.0001: 245 | x_y_t[-1] = 0.0001 if x_y_t[-1] > 0 else -0.0001 246 | x_z_t = x_y_t.copy() 247 | 248 | x_z_t[-1] = -(x_z_t[:-1] @ x_y_t[:-1]) / x_y_t[-1] 249 | dummy_coord_t = torch.tensor(origin_t - x_z_t) 250 | true_pseudo_C = torch.cat((true_pseudo_C, dummy_coord_t.unsqueeze(0)), 0) 251 | 252 | return pseudo_N, pseudo_C, true_pseudo_N, true_pseudo_C 253 | 254 | # RECEPTOR PROCESSING 255 | 256 | def get_receptor_input(rec_path, lig, complex_name, cutoff): 257 | biopython_parser = PDBParser() 258 | conf = lig.GetConformer() 259 | lig_coords = conf.GetPositions() 260 | with warnings.catch_warnings(): 261 | warnings.filterwarnings("ignore", category=PDBConstructionWarning) 262 | structure = 
biopython_parser.get_structure('random_id', rec_path) 263 | rec = structure[0] 264 | 265 | c_alpha_coords, n_coords, c_coords, three_letter_sequence, n_res_per_chain, valid_chain_ids = get_rec_data(rec, lig_coords, cutoff) 266 | 267 | one_letter_sequence = convert_sequence(three_letter_sequence) 268 | num_res = len(one_letter_sequence) 269 | 270 | features = { 271 | "c_alpha_coords": c_alpha_coords, 272 | "n_coords": n_coords, 273 | "c_coords": c_coords, 274 | "aatype": residue_constants.sequence_to_onehot( 275 | sequence=one_letter_sequence, 276 | mapping=residue_constants.restype_order_with_x, 277 | map_unknown_to_x=True, 278 | ).astype(np.float32), 279 | "residue_index": np.array(range(num_res), dtype=np.int32), 280 | "seq_length": num_res 281 | } 282 | 283 | tensor_dict = { 284 | k: torch.tensor(v) for k, v in features.items() 285 | } 286 | 287 | tensor_dict["complex_names"] = complex_name 288 | tensor_dict["chain_ids"] = valid_chain_ids 289 | tensor_dict["n_res_per_chain"] = n_res_per_chain 290 | sequence = ''.join(one_letter_sequence) 291 | tensor_dict["sequence"] = sequence 292 | 293 | return tensor_dict 294 | 295 | def get_rec_data(receptor, lig_coords, cutoff): 296 | min_distances, c_alpha_coords, n_coords, \ 297 | c_coords, valid_chain_ids, sequence = [], [], [], [], [], [] 298 | 299 | for polypep in receptor: 300 | chain_coords, chain_c_alpha_coords, chain_n_coords, \ 301 | chain_c_coords, invalid_res_ids, chain_sequence = [], [], [], [], [], [] 302 | 303 | for residue in polypep: 304 | if residue.get_resname() == 'HOH': 305 | invalid_res_ids.append(residue.get_id()) 306 | continue 307 | residue_coords = [] 308 | 309 | c_alpha, n, c = None, None, None 310 | for atom in residue: 311 | if atom.name == 'CA': 312 | c_alpha = list(atom.get_vector()) 313 | if atom.name == 'N': 314 | n = list(atom.get_vector()) 315 | if atom.name == 'C': 316 | c = list(atom.get_vector()) 317 | residue_coords.append(list(atom.get_vector())) 318 | 319 | if c_alpha is None or n is 
None or c is None: 320 | invalid_res_ids.append(residue.get_id()) 321 | else: 322 | chain_c_alpha_coords.append(c_alpha) 323 | chain_n_coords.append(n) 324 | chain_c_coords.append(c) 325 | chain_coords.append(np.array(residue_coords)) 326 | chain_sequence.append(residue.get_resname()) 327 | 328 | for res_id in invalid_res_ids: 329 | polypep.detach_child(res_id) 330 | 331 | if chain_coords: 332 | all_chain_coords = np.concatenate(chain_coords, axis=0) 333 | distances = spatial.distance.cdist(lig_coords, all_chain_coords) 334 | min_distance = distances.min() 335 | else: 336 | min_distance = np.inf 337 | 338 | min_distances.append(min_distance) 339 | c_alpha_coords.append(np.array(chain_c_alpha_coords)) 340 | n_coords.append(np.array(chain_n_coords)) 341 | c_coords.append(np.array(chain_c_coords)) 342 | sequence.append(chain_sequence) 343 | if min_distance < cutoff: 344 | valid_chain_ids.append(polypep.get_id()) 345 | 346 | if not valid_chain_ids: 347 | valid_chain_ids.append(np.argmin(np.array(min_distances))) 348 | 349 | valid_c_alpha_coords, valid_n_coords, \ 350 | valid_c_coords, valid_sequence = [], [], [], [] 351 | for i, polypep in enumerate(receptor): 352 | if polypep.get_id() in valid_chain_ids: 353 | valid_c_alpha_coords.append(c_alpha_coords[i]) 354 | valid_n_coords.append(n_coords[i]) 355 | valid_c_coords.append(c_coords[i]) 356 | valid_sequence.append(sequence[i]) 357 | 358 | c_alpha_coords = np.concatenate(valid_c_alpha_coords, axis=0) # [n_residues, 3] 359 | n_coords = np.concatenate(valid_n_coords, axis=0) # [n_residues, 3] 360 | c_coords = np.concatenate(valid_c_coords, axis=0) # [n_residues, 3] 361 | valid_sequence = list(chain.from_iterable(valid_sequence)) 362 | 363 | n_res_per_chain = { 364 | id: len(s) for id, s in zip(valid_chain_ids, valid_sequence) 365 | } 366 | 367 | return c_alpha_coords, n_coords, c_coords, valid_sequence, n_res_per_chain, valid_chain_ids 368 | 369 | def convert_sequence(three_letter_sequence): 370 | sequence = [] 371 | for 
res in three_letter_sequence: 372 | try: 373 | one_letter_code = protein_letters_3to1[res] 374 | except KeyError: 375 | one_letter_code = 'X' 376 | sequence.append(one_letter_code) 377 | return sequence 378 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from copy import deepcopy 4 | import os 5 | from rdkit.Chem import RemoveHs 6 | from rdkit.Geometry import Point3D 7 | from tqdm import tqdm 8 | import torch 9 | import numpy as np 10 | import yaml 11 | import glob 12 | from commons.utils import ( 13 | log, get_parameter_value, rigid_transform_Kabsch_3D, Logger, 14 | get_symmetry_rmsd, save_predictions_to_file 15 | ) 16 | from dataset.process_mols import read_molecule, reorder_atoms 17 | from dataset.dataimporter import DataImporter 18 | from torch.utils.data import DataLoader 19 | from openfold.utils.seed import seed_globally 20 | from quickbind import QuickBind 21 | 22 | def parse_arguments(): 23 | p = argparse.ArgumentParser() 24 | p.add_argument('--name', type=str) 25 | p.add_argument('--unseen_only', type=bool, default=False) 26 | p.add_argument('--save_to_file', type=bool, default=False) 27 | p.add_argument('--pb_set', type=bool, default=False) 28 | p.add_argument('--output_s', type=bool, default=False) 29 | p.add_argument('--val_set', type=bool, default=False) 30 | p.add_argument('--train_set', type=bool, default=False) 31 | return p.parse_args() 32 | 33 | def collate(data, use_topological_distance, one_hot_adj): 34 | assert len(data) == 1 35 | data = data[0] 36 | 37 | aatype = data.aatype.unsqueeze(0).to(device) 38 | lig_atom_features = data.lig_atom_features.unsqueeze(0).to(dtype=torch.float32).to(device) 39 | t_true = data.true_lig_atom_coords.unsqueeze(0).to(device) 40 | rec_mask = torch.ones(data.aatype.shape[0]).unsqueeze(0).to(device) 41 | lig_mask = 
torch.ones(data.lig_atom_features.shape[0]).unsqueeze(0).to(device) 42 | 43 | if use_topological_distance: 44 | adj = torch.clamp(data.distance_matrix.unsqueeze(0), max=7).to(device) 45 | adj = torch.nn.functional.one_hot(adj.long(), num_classes=8).to(dtype=torch.float32) 46 | elif one_hot_adj: 47 | adj = data.adjacency_bo.unsqueeze(0).to(dtype=torch.int64).to(device) 48 | else: 49 | adj = data.adjacency.unsqueeze(0).unsqueeze(-1).to(dtype=torch.float32).to(device) 50 | 51 | ri = data.residue_index.unsqueeze(0).to(dtype=torch.int64).to(device) 52 | chain_id = data.chain_ids_processed.unsqueeze(0).to(dtype=torch.int64).to(device) 53 | entity_id = data.entity_ids_processed.unsqueeze(0).to(dtype=torch.int64).to(device) 54 | sym_id = data.sym_ids_processed.unsqueeze(0).to(dtype=torch.int64).to(device) 55 | id_batch = (ri, chain_id, entity_id, sym_id) 56 | 57 | t_rec = data.c_alpha_coords.unsqueeze(0).to(device) 58 | N = data.n_coords.unsqueeze(0).to(device) 59 | C = data.c_coords.unsqueeze(0).to(device) 60 | t_lig = data.lig_atom_coords.unsqueeze(0).to(device) 61 | 62 | pseudo_N = data.pseudo_N.unsqueeze(0).to(device) 63 | pseudo_C = data.pseudo_C.unsqueeze(0).to(device) 64 | 65 | names = data.complex_name 66 | 67 | return ( 68 | aatype, lig_atom_features, adj, rec_mask, lig_mask, N, t_rec, C, t_lig, id_batch, pseudo_N, pseudo_C 69 | ), t_true, names 70 | 71 | 72 | def run_inference(model, test_loader): 73 | all_ligs_coords_pred, all_ligs_coords, all_masks, all_names = [], [], [], [] 74 | s_pre_struct_lst = [] 75 | for batch, t_true, names in tqdm(test_loader, desc='Generating model predictions'): 76 | _, _, _, rec_mask, lig_mask, _, _, _, t_lig, _, _, _ = batch 77 | with torch.no_grad(): 78 | if aux_head or lig_aux_head: 79 | if args.output_s: 80 | out, _, s_pre_struct = model(*batch) 81 | else: 82 | out, _ = model(*batch) 83 | else: 84 | if args.output_s: 85 | out, s_pre_struct = model(*batch) 86 | else: 87 | out = model(*batch) 88 | outputs = out[-1][:, 
rec_mask.shape[-1]:].get_trans() 89 | all_ligs_coords_pred.append(outputs.detach().cpu()) 90 | all_ligs_coords.append(t_true.detach().cpu()) 91 | all_masks.append(lig_mask.detach().cpu()) 92 | all_names.append(names) 93 | if args.output_s: 94 | s_pre_struct_lst.append(s_pre_struct.detach().cpu()) 95 | return { 96 | 'predictions': all_ligs_coords_pred, 97 | 'targets': all_ligs_coords, 98 | 'masks': all_masks, 99 | 'names': all_names, 100 | 's_pre_struct': s_pre_struct_lst 101 | } 102 | 103 | def print_results(rmsds, centroid_distances, incl_H): 104 | rmsds = np.array(rmsds) 105 | centroid_distances = np.array(centroid_distances) 106 | print('----------------------------------------------------------------------------------------------------') 107 | print( 108 | f'| Test statistics ({"incl." if incl_H else "excl."} hydrogen atoms) |' 109 | ) 110 | print('----------------------------------------------------------------------------------------------------') 111 | print(f'Mean RMSD: {rmsds.mean().__round__(2)} +- {rmsds.std().__round__(2)}') 112 | print('RMSD percentiles: ', np.percentile(rmsds, [25, 50, 75]).round(2)) 113 | print(f'% RMSD below 2: {(100 * (rmsds < 2).sum() / len(rmsds)).__round__(2)}%') 114 | print(f'% RMSD below 5: {(100 * (rmsds < 5).sum() / len(rmsds)).__round__(2)}%') 115 | print( 116 | f'Mean centroid distance: {centroid_distances.mean().__round__(2)} +- {centroid_distances.std().__round__(2)}' 117 | ) 118 | print('Centroid percentiles: ', np.percentile(centroid_distances, [25, 50, 75]).round(2)) 119 | print( 120 | f'% centroid distances below 2: {(100 * (centroid_distances < 2).sum() / len(centroid_distances)).__round__(2)}%' 121 | ) 122 | print( 123 | f'% centroid distances below 5: {(100 * (centroid_distances < 5).sum() / len(centroid_distances)).__round__(2)}%' 124 | ) 125 | 126 | def evaluate_predictions(results): 127 | rmsds_wH, centroid_dists_wH = [], [] 128 | kabsch_rmsds, rmsds, centroid_distances = [], [], [] 129 | for prediction, 
target, mask, name in tqdm(zip( 130 | results['predictions'], results['targets'], results['masks'], results['names']), 131 | desc='Evaluating model predictions', total = len(results['predictions']) 132 | ): 133 | if not remove_h: 134 | # including H atoms 135 | coords_pred = (prediction * mask.unsqueeze(-1)).numpy() 136 | coords_native = target.numpy() 137 | mask = mask.numpy() 138 | rmsd = np.sqrt(np.sum(np.sum((coords_pred - coords_native) ** 2, axis=1)) / np.sum(mask)) 139 | centroid_distance = np.linalg.norm( 140 | np.sum(coords_native, axis=0) / np.sum(mask) - np.sum(coords_pred, axis=0) / np.sum(mask) 141 | ) 142 | centroid_dists_wH.append(centroid_distance) 143 | rmsds_wH.append(rmsd) 144 | 145 | # not including H atoms 146 | if args.pb_set: 147 | lig = read_molecule(os.path.join('data/posebusters_benchmark_set', name, f'{name}_ligand.sdf'), remove_hs=remove_h) 148 | else: 149 | lig = read_molecule(os.path.join('data/PDBBind/', name, f'{name}_ligand.mol2'), remove_hs=remove_h) 150 | if lig is None: 151 | lig = read_molecule(os.path.join('data/PDBBind/', name, f'{name}_ligand.sdf'), remove_hs=remove_h) 152 | lig = RemoveHs(lig) 153 | lig = reorder_atoms(lig) 154 | lig_pred = deepcopy(lig) 155 | conf = lig_pred.GetConformer() 156 | prediction = prediction.squeeze().cpu().numpy() 157 | for i in range(lig_pred.GetNumAtoms()): 158 | x, y, z = prediction[i] 159 | conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z))) 160 | coords_pred = lig_pred.GetConformer().GetPositions() 161 | lig_true = deepcopy(lig) 162 | conf_true = lig_true.GetConformer() 163 | for i in range(lig_true.GetNumAtoms()): 164 | x, y, z = target.squeeze()[i] 165 | conf_true.SetAtomPosition(i, Point3D(float(x), float(y), float(z))) 166 | coords_native = lig_true.GetConformer().GetPositions() 167 | try: 168 | rmsd = get_symmetry_rmsd(lig_true, coords_native, coords_pred, lig_pred) 169 | except Exception as e: 170 | print("Using non corrected RMSD because of the error:", e) 171 | rmsd = 
np.sqrt(np.sum((coords_pred - coords_native) ** 2, axis=1).mean()) 172 | centroid_distance = np.linalg.norm(coords_native.mean(axis=0) - coords_pred.mean(axis=0)) 173 | R, t = rigid_transform_Kabsch_3D(coords_pred.T, coords_native.T) 174 | moved_coords = (R @ (coords_pred).T).T + t.squeeze() 175 | kabsch_rmsd = np.sqrt(np.sum((moved_coords - coords_native) ** 2, axis=1).mean()) 176 | kabsch_rmsds.append(kabsch_rmsd) 177 | rmsds.append(rmsd) 178 | centroid_distances.append(centroid_distance) 179 | 180 | if not remove_h: 181 | print_results( 182 | rmsds_wH, centroid_dists_wH, True 183 | ) 184 | kabsch_rmsds = np.array(kabsch_rmsds) 185 | print_results( 186 | rmsds, centroid_distances, False 187 | ) 188 | print(f'Mean Kabsch RMSD: {kabsch_rmsds.mean().__round__(2)} +- {kabsch_rmsds.std().__round__(2)}') 189 | print(f'Median Kabsch RMSD: {np.median(kabsch_rmsds).__round__(2)} +- {kabsch_rmsds.std().__round__(2)}') 190 | print('Kabsch RMSD percentiles: ', np.percentile(kabsch_rmsds, [25, 50, 75]).round(2)) 191 | 192 | if __name__ == '__main__': 193 | args = parse_arguments() 194 | log(f'Loading model {args.name}.') 195 | global device 196 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 197 | log(f'Using {device}.') 198 | # Find best model checkpoint 199 | checkpoints = glob.glob(f'checkpoints/{args.name}/best_checkpoint*.pt') 200 | best_scores = [] 201 | for checkpoint in checkpoints: 202 | model_state = torch.load(checkpoint) 203 | best_scores.append( 204 | model_state['callbacks'][ 205 | "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}" 206 | ]["current_score"] 207 | ) 208 | best_checkpoint = checkpoints[best_scores.index(min(best_scores))] 209 | log(f'Model achieved a validation loss of {min(best_scores)}.') 210 | model_state = torch.load(best_checkpoint) 211 | state_dict = { 212 | key[6:]: model_state['state_dict'][key] 213 | for key in model_state['state_dict'] 
214 | } 215 | 216 | with open(f'checkpoints/{args.name}/config.yaml', 'r') as arg_file: 217 | checkpoint_dict = yaml.load(arg_file, Loader=yaml.FullLoader) 218 | 219 | sys.stdout = Logger(logpath=f'checkpoints/{checkpoint_dict["name"]}/inference.log', syspart=sys.stdout) 220 | sys.stderr = Logger(logpath=f'checkpoints/{checkpoint_dict["name"]}/inference.log', syspart=sys.stderr) 221 | seed = 0 if checkpoint_dict['seed'] is None else checkpoint_dict['seed'] 222 | seed_globally(seed) 223 | checkpoint_dict['dataset_params']['cropping'] = False # no cropping at inference time 224 | 225 | if args.pb_set: 226 | test_data = DataImporter( 227 | complex_names_path='data/posebusters_benchmark_set/posebusters', 228 | **checkpoint_dict['dataset_params'] 229 | ) 230 | elif args.val_set: 231 | # run inference on the validation set 232 | test_data = DataImporter(complex_names_path=checkpoint_dict['val_names'], **checkpoint_dict['dataset_params']) 233 | elif args.train_set: 234 | # run inference on the training set 235 | test_data = DataImporter(complex_names_path=checkpoint_dict['train_names'], **checkpoint_dict['dataset_params']) 236 | else: 237 | test_data = DataImporter( 238 | complex_names_path=checkpoint_dict['test_names'], 239 | **checkpoint_dict['dataset_params'], unseen_only=args.unseen_only 240 | ) 241 | log(f'Test size: {len(test_data)}.') 242 | lig_feat_dim, rec_feat_dim = test_data.get_feature_dimensions() 243 | model = QuickBind( 244 | aa_feat=rec_feat_dim, lig_atom_feat=lig_feat_dim, **checkpoint_dict['model_parameters'], 245 | chunk_size=2, output_s=args.output_s 246 | ) 247 | model.load_state_dict(state_dict) 248 | model = model.to(device) 249 | model.eval() 250 | 251 | aux_head = get_parameter_value('use_aux_head', checkpoint_dict['model_parameters']) 252 | lig_aux_head = get_parameter_value('use_lig_aux_head', checkpoint_dict['model_parameters']) 253 | one_hot_adj = get_parameter_value('one_hot_adj', checkpoint_dict['model_parameters']) 254 | 
use_topological_distance = get_parameter_value('use_topological_distance', checkpoint_dict['model_parameters']) 255 | remove_h = get_parameter_value('remove_h', checkpoint_dict['dataset_params']) 256 | 257 | test_loader = DataLoader( 258 | test_data, batch_size=1, collate_fn=lambda x: collate( 259 | x, use_topological_distance, one_hot_adj 260 | ) 261 | ) 262 | 263 | out_path = ( 264 | f'checkpoints/{checkpoint_dict["name"]}/' 265 | f'{"unseen_" if args.unseen_only else ""}' 266 | f'{"posebusters_" if args.pb_set else ""}' 267 | f'{"train_" if args.train_set else ""}' 268 | f'{"val_" if args.val_set else ""}' 269 | 'predictions' 270 | f'{"-w-single-rep" if args.output_s else ""}.pt' 271 | ) 272 | if not os.path.exists(out_path): 273 | results = run_inference(model, test_loader) 274 | log(f'Saving predictions to {out_path}.') 275 | torch.save(results, out_path) 276 | else: 277 | log('Loading model predictions.') 278 | results = torch.load(out_path) 279 | 280 | evaluate_predictions(results) 281 | 282 | if args.save_to_file: 283 | log(f'Saving predictions as SD files.') 284 | if args.pb_set: 285 | receptor_dir='data/processed/posebusters' 286 | else: 287 | receptor_dir='data/processed/timesplit_test' 288 | save_predictions_to_file( 289 | results, receptor_dir, 290 | os.path.join(f'checkpoints/{args.name}', f'{"posebusters_" if args.pb_set else ""}sdffiles') 291 | ) 292 | -------------------------------------------------------------------------------- /overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aqlaboratory/QuickBind/d8c5cf4901b44f233cbbd8e6936d3e31aeebfec2/overview.jpg -------------------------------------------------------------------------------- /quickbind.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import pytorch_lightning as pl 4 | from commons.modified_of_modules import ( 5 | InputEmbedder, 
EvoformerStack, StructureModule, 6 | BackboneUpdate, GatedInvariantPointAttention, 7 | FullEvoformerStack 8 | ) 9 | from openfold.model.structure_module import StructureModuleTransition, InvariantPointAttention 10 | from openfold.model.primitives import Linear, LayerNorm 11 | from openfold.utils.rigid_utils import Rigid, Rotation 12 | from functools import partial 13 | from openfold.model.heads import DistogramHead 14 | from openfold.utils.loss import distogram_loss 15 | torch.cuda.empty_cache() 16 | 17 | class QuickBind(nn.Module): 18 | def __init__( 19 | self, 20 | # INPUT EMBEDDINGS # 21 | aa_feat, lig_atom_feat, c_emb, c_s, c_z, use_op_edge_embed, use_pairwise_dist, use_radial_basis, 22 | use_rel_pos, use_multimer_rel_pos, mask_off_diagonal, one_hot_adj, use_topological_distance, 23 | # EVOFORMER # 24 | c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, c_s_out, 25 | no_heads_msa, no_heads_pair, no_evo_blocks, transition_n, msa_dropout, 26 | pair_dropout, opm_first, chunk_size, 27 | # STRUCTURE MODULE # 28 | c_hidden, no_heads, no_qk_points, no_v_points, 29 | num_struct_blocks, dropout_rate, 30 | no_transition_layers, share_ipa_weights, 31 | use_gated_ipa = True, communicate = False, 32 | sum_pool = False, mean_pool = False, att_update = True, 33 | # RECYCLING # 34 | recycle = False, recycle_iters = 1, 35 | # LOSS FUNCTION # 36 | use_aux_head=False, use_lig_aux_head=False, no_dist_bins=64, no_dist_bins_lig=42, 37 | construct_frames=True, 38 | # GLOBAL SETTINGS # 39 | use_full_evo_stack=False, blackhole_init=False, 40 | # OUTPUT EMBEDDING # 41 | output_s=False 42 | ): 43 | super(QuickBind, self).__init__() 44 | self.inputembedder = InputEmbedder( 45 | aa_feat, lig_atom_feat, c_emb, c_s, c_z, use_op_edge_embed, use_pairwise_dist, use_radial_basis, 46 | use_rel_pos, use_multimer_rel_pos, mask_off_diagonal, one_hot_adj, use_topological_distance 47 | ) 48 | 49 | # EVOFORMER # 50 | if no_evo_blocks > 0: 51 | if use_full_evo_stack: 52 | self.evoformer = 
FullEvoformerStack( 53 | c_s, c_z, c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, c_s_out, 54 | no_heads_msa, no_heads_pair, no_evo_blocks, transition_n, msa_dropout, 55 | pair_dropout, opm_first=opm_first 56 | ) 57 | else: 58 | self.evoformer = EvoformerStack( 59 | c_s, c_z, c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, c_s_out, 60 | no_heads_msa, no_heads_pair, no_evo_blocks, transition_n, msa_dropout, 61 | pair_dropout, opm_first=opm_first 62 | ) 63 | self.no_evo_blocks = no_evo_blocks 64 | self.chunk_size = chunk_size 65 | 66 | # STRUCTURE MODULE # 67 | self.layer_norm_s = LayerNorm(c_s_out) 68 | self.layer_norm_z = LayerNorm(c_z) 69 | self.linear_in = Linear(c_s_out, c_s_out) 70 | self.num_struct_blocks = num_struct_blocks 71 | self.share_ipa_weights = share_ipa_weights 72 | if share_ipa_weights: 73 | self.structure_module_block = StructureModule( 74 | c_s_out, c_z, c_hidden, no_heads, no_qk_points, no_v_points, dropout_rate, 75 | no_transition_layers, sum_pool, mean_pool, att_update, use_gated_ipa, construct_frames 76 | ) 77 | else: 78 | if use_gated_ipa: 79 | self.ipa_blocks = nn.ModuleList([ 80 | GatedInvariantPointAttention( 81 | c_s, c_z, c_hidden, no_heads, no_qk_points, no_v_points 82 | ) for _ in range(num_struct_blocks) 83 | ]) 84 | else: 85 | self.ipa_blocks = nn.ModuleList([ 86 | InvariantPointAttention( 87 | c_s, c_z, c_hidden, no_heads, no_qk_points, no_v_points 88 | ) for _ in range(num_struct_blocks) 89 | ]) 90 | self.ipa_dropout = nn.Dropout(dropout_rate) 91 | self.layer_norm_ipa = LayerNorm(c_s) 92 | self.transition = StructureModuleTransition(c_s, no_transition_layers, dropout_rate) 93 | self.bb_update = BackboneUpdate(c_s, sum_pool, mean_pool, att_update, construct_frames) 94 | 95 | # RECYCLING EMBEDDINGS # 96 | self.recycle = recycle 97 | self.recycle_iters = recycle_iters 98 | if recycle: 99 | self.layer_norm_s_recycle = LayerNorm(c_s) 100 | self.layer_norm_z_recycle = LayerNorm(c_z) 101 | 
self.linear_z_recycle = Linear(1, c_z) 102 | 103 | # AUXILIARY HEADS # 104 | self.use_aux_head = use_aux_head 105 | self.use_lig_aux_head = use_lig_aux_head 106 | if self.use_aux_head: 107 | self.distogram = DistogramHead(c_z, no_dist_bins) 108 | if self.use_lig_aux_head: 109 | self.lig_distogram = DistogramHead(c_z, no_dist_bins_lig) 110 | 111 | self.communicate = communicate 112 | if self.communicate: 113 | self.linear_a_i = Linear(c_s, c_z) 114 | self.linear_b_i = Linear(c_s, c_z) 115 | self.linear_dist = Linear(1, c_z) 116 | 117 | self.construct_frames = construct_frames 118 | self.blackhole_init = blackhole_init 119 | self.pooled_update = bool(sum_pool or mean_pool or att_update) 120 | 121 | self.output_s = output_s 122 | 123 | def iteration( 124 | self, aatype, lig_atom_features, adj, s_prev, z_prev, t_prev, ri, mask, edge_mask, 125 | N, t_rec, C, rec_mask, lig_mask, pseudo_N, pseudo_C 126 | ): 127 | # INPUT EMBEDDINGS # 128 | s, z = self.inputembedder(aatype, lig_atom_features, t_prev, edge_mask, adj, ri) 129 | t_lig = t_prev[:, rec_mask.shape[-1]:, :] 130 | if self.construct_frames and not self.blackhole_init: 131 | rigids = Rigid.cat( 132 | [ 133 | Rigid.from_3_points(N, t_rec, C), 134 | Rigid.from_3_points(pseudo_N, t_lig, pseudo_C) 135 | ], dim=1 136 | ) 137 | else: 138 | rigids = Rigid.cat( 139 | [ 140 | Rigid.from_3_points(N, t_rec, C), 141 | Rigid( 142 | rots = Rotation.identity( 143 | shape=t_lig.shape[:-1], dtype = torch.float32, device=t_lig.device, fmt="quat" 144 | ), trans = t_lig 145 | ) 146 | ], dim=1 147 | ) 148 | 149 | # RECYCLING EMBEDDINGS # 150 | if None not in [s_prev, z_prev]: 151 | s_prev = self.layer_norm_s_recycle(s_prev) 152 | pairwise_distance_prev = (torch.cdist(t_prev, t_prev, p=2) * edge_mask).unsqueeze(-1).to(dtype=torch.float32) 153 | z_prev = self.linear_z_recycle(pairwise_distance_prev) + self.layer_norm_z_recycle(z_prev) 154 | s = s + s_prev 155 | z = z + z_prev 156 | 157 | # EVOFORMER # 158 | if self.no_evo_blocks > 0: 159 
| s = s.unsqueeze(-3) 160 | msa_mask = mask.unsqueeze(-2) 161 | s, z = self.evoformer( 162 | s, z, 163 | msa_mask=msa_mask, 164 | pair_mask=edge_mask, 165 | chunk_size=self.chunk_size 166 | ) 167 | if self.recycle: 168 | s_prev, z_prev = s, z 169 | 170 | if self.output_s: 171 | s_pre_struct = s 172 | 173 | # STRUCTURE MODULE # 174 | s = self.layer_norm_s(s) 175 | z = self.layer_norm_z(z) 176 | s = self.linear_in(s) 177 | 178 | out = [] 179 | if self.share_ipa_weights: 180 | blocks = [ 181 | partial( 182 | self.structure_module_block, mask=mask, rec_mask=rec_mask, lig_mask=lig_mask 183 | ) for _ in range(self.num_struct_blocks) 184 | ] 185 | for block in blocks: 186 | s, z, new_trans = block(s, z, rigids) 187 | if not self.pooled_update: 188 | new_trans = new_trans[:, rec_mask.shape[-1]:, :] 189 | new_trans = new_trans * lig_mask.unsqueeze(-1) 190 | if self.construct_frames: 191 | rigids_ligand = rigids[:, rec_mask.shape[-1]:] 192 | rigids_protein = rigids[:, :rec_mask.shape[-1]] 193 | rigids_ligand_updated = rigids_ligand.compose_q_update_vec(new_trans) 194 | updated_rigids = Rigid.cat([rigids_protein, rigids_ligand_updated], dim=1) 195 | else: 196 | update = torch.cat([torch.zeros_like(rigids.get_trans()[:, :rec_mask.shape[-1], :]), new_trans], dim=-2) 197 | updated_rigids = Rigid( 198 | rots = rigids.get_rots(), 199 | trans = rigids.get_trans() + update 200 | ) 201 | rigids = updated_rigids 202 | out.append(updated_rigids) 203 | if self.construct_frames: 204 | rigids = rigids.stop_rot_gradient() 205 | else: 206 | for ipa in self.ipa_blocks: 207 | s = s + ipa(s, z, rigids, mask) 208 | s = self.ipa_dropout(s) 209 | s = self.layer_norm_ipa(s) 210 | s = self.transition(s) 211 | new_trans = self.bb_update(s, rec_mask, lig_mask) 212 | if not self.pooled_update: 213 | new_trans = new_trans[:, rec_mask.shape[-1]:, :] 214 | new_trans = new_trans * lig_mask.unsqueeze(-1) 215 | if self.construct_frames: 216 | rigids_ligand = rigids[:, rec_mask.shape[-1]:] 217 | 
rigids_protein = rigids[:, :rec_mask.shape[-1]] 218 | rigids_ligand_updated = rigids_ligand.compose_q_update_vec(new_trans) 219 | updated_rigids = Rigid.cat([rigids_protein, rigids_ligand_updated], dim=1) 220 | else: 221 | update = torch.cat([torch.zeros_like(rigids.get_trans()[:, :rec_mask.shape[-1], :]), new_trans], dim=-2) 222 | updated_rigids = Rigid( 223 | rots = rigids.get_rots(), 224 | trans = rigids.get_trans() + update 225 | ) 226 | rigids = updated_rigids 227 | out.append(updated_rigids) 228 | if self.communicate: 229 | ti = rigids.get_trans() 230 | a_i = self.linear_a_i(s) 231 | b_i = self.linear_b_i(s) 232 | pair_emb = a_i[..., None, :] + b_i[..., None, :, :] 233 | dist = (torch.cdist(ti, ti, p=2) * edge_mask).unsqueeze(-1).to(dtype=torch.float32) 234 | pairwise_distance = self.linear_dist(dist) 235 | pair_emb = pair_emb + pairwise_distance 236 | z = z + pair_emb 237 | if self.construct_frames: 238 | rigids = rigids.stop_rot_gradient() 239 | 240 | if self.recycle: t_prev = rigids.get_trans() 241 | 242 | if (self.use_aux_head or self.use_lig_aux_head) and self.is_final_iter: 243 | if self.output_s: 244 | return out, s, s_prev, z, t_prev, s_pre_struct 245 | else: 246 | return out, s, s_prev, z, t_prev 247 | else: 248 | if self.output_s: 249 | return out, s, s_prev, z_prev, t_prev, s_pre_struct 250 | else: 251 | return out, s, s_prev, z_prev, t_prev 252 | 253 | def forward(self, aatype, lig_atom_features, adj, rec_mask, lig_mask, N, t_rec, C, t_lig, ri, pseudo_N, pseudo_C): 254 | is_grad_enabled = torch.is_grad_enabled() 255 | # RECYCLING # 256 | s_prev, z_prev = None, None 257 | t_prev = torch.cat([t_rec, t_lig], dim=-2) 258 | mask = torch.cat([rec_mask, lig_mask], dim=-1).to(dtype=torch.float32) 259 | edge_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) 260 | 261 | for iteration in range(self.recycle_iters): 262 | self.is_final_iter = (iteration == (self.recycle_iters-1)) 263 | with torch.set_grad_enabled(is_grad_enabled and self.is_final_iter): 264 | if 
self.is_final_iter and torch.is_autocast_enabled(): # Sidestep AMP bug (PyTorch issue #65766) 265 | torch.clear_autocast_cache() 266 | if self.output_s: 267 | outputs, s, s_prev, z_prev, t_prev, s_pre_struct = self.iteration( 268 | aatype, lig_atom_features, adj, s_prev, z_prev, t_prev, ri, mask, edge_mask, 269 | N, t_rec, C, rec_mask, lig_mask, pseudo_N, pseudo_C 270 | ) 271 | else: 272 | outputs, s, s_prev, z_prev, t_prev = self.iteration( 273 | aatype, lig_atom_features, adj, s_prev, z_prev, t_prev, ri, mask, edge_mask, 274 | N, t_rec, C, rec_mask, lig_mask, pseudo_N, pseudo_C 275 | ) 276 | if not self.is_final_iter: del outputs, s 277 | 278 | if self.use_aux_head and self.use_lig_aux_head: 279 | distogram_logits_full = self.distogram(z_prev) 280 | distogram_logits_lig = self.lig_distogram(z_prev[:, rec_mask.shape[-1]:, rec_mask.shape[-1]:]) 281 | distogram_logits = (distogram_logits_full, distogram_logits_lig) 282 | if self.output_s: 283 | return outputs, distogram_logits, s_pre_struct 284 | else: 285 | return outputs, distogram_logits 286 | elif self.use_aux_head: 287 | distogram_logits = self.distogram(z_prev) 288 | if self.output_s: 289 | return outputs, distogram_logits, s_pre_struct 290 | else: 291 | return outputs, distogram_logits 292 | elif self.use_lig_aux_head: 293 | distogram_logits = self.lig_distogram(z_prev[:, rec_mask.shape[-1]:, rec_mask.shape[-1]:]) 294 | if self.output_s: 295 | return outputs, distogram_logits, s_pre_struct 296 | else: 297 | return outputs, distogram_logits 298 | else: 299 | if self.output_s: 300 | return outputs, s_pre_struct 301 | else: 302 | return outputs 303 | 304 | class QuickBind_PL(pl.LightningModule): 305 | def __init__( 306 | self, 307 | # INPUT EMBEDDINGS # 308 | aa_feat, lig_atom_feat, c_emb, c_s, c_z, use_op_edge_embed, 309 | use_pairwise_dist, use_radial_basis, use_rel_pos, use_multimer_rel_pos, 310 | mask_off_diagonal, one_hot_adj, use_topological_distance, 311 | # EVOFORMER # 312 | c_hidden_msa_att, 
c_hidden_opm, c_hidden_mul, c_hidden_pair_att, c_s_out, 313 | no_heads_msa, no_heads_pair, no_evo_blocks, transition_n, msa_dropout, 314 | pair_dropout, opm_first, chunk_size, 315 | # STRUCTURE MODULE # 316 | c_hidden, no_heads, no_qk_points, no_v_points, 317 | num_struct_blocks, dropout_rate, 318 | no_transition_layers, share_ipa_weights, 319 | use_gated_ipa = False, communicate = False, 320 | sum_pool = False, mean_pool = False, att_update=False, 321 | # RECYCLING # 322 | recycle = False, recycle_iters = 1, 323 | # LOSS FUNCTION # 324 | loss_config = None, 325 | use_aux_head=False, use_lig_aux_head=False, no_dist_bins=64, no_dist_bins_lig=42, 326 | construct_frames=False, 327 | use_full_evo_stack=False, blackhole_init=False, 328 | # LEARNING RATE # 329 | lr=1.0e-5, weight_decay=1.0e-4, 330 | ): 331 | super().__init__() 332 | self.model = QuickBind( 333 | # INPUT EMBEDDINGS # 334 | aa_feat, lig_atom_feat, c_emb, c_s, c_z, use_op_edge_embed, 335 | use_pairwise_dist, use_radial_basis, use_rel_pos, use_multimer_rel_pos, 336 | mask_off_diagonal, one_hot_adj, use_topological_distance, 337 | # EVOFORMER # 338 | c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, c_s_out, 339 | no_heads_msa, no_heads_pair, no_evo_blocks, transition_n, msa_dropout, 340 | pair_dropout, opm_first, chunk_size, 341 | # STRUCTURE MODULE # 342 | c_hidden, no_heads, no_qk_points, no_v_points, 343 | num_struct_blocks, dropout_rate, 344 | no_transition_layers, share_ipa_weights, 345 | use_gated_ipa, communicate, 346 | sum_pool, mean_pool, att_update, 347 | # RECYCLING # 348 | recycle, recycle_iters, 349 | # AUXILIARY HEADS # 350 | use_aux_head, use_lig_aux_head, no_dist_bins, no_dist_bins_lig, 351 | construct_frames, use_full_evo_stack, blackhole_init 352 | ) 353 | 354 | self.loss = QuickBindLoss(**loss_config, use_aux_head=use_aux_head, use_lig_aux_head=use_lig_aux_head) 355 | self.use_aux_head = use_aux_head 356 | self.use_lig_aux_head = use_lig_aux_head 357 | self.lr = lr 358 | 
self.weight_decay = weight_decay 359 | self.save_hyperparameters() 360 | 361 | def forward(self, batch): 362 | return self.model(*batch) 363 | 364 | def configure_optimizers(self): 365 | optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay) 366 | return optimizer 367 | 368 | def training_step(self, batch, idx): 369 | batch, t_true = batch 370 | _, _, _, rec_mask, lig_mask, _, _, _, _, _, _, _ = batch 371 | if self.use_aux_head or self.use_lig_aux_head: 372 | outputs, distogram_logits = self.model(*batch) 373 | else: 374 | outputs = self.model(*batch) 375 | distogram_logits = None 376 | loss, ( 377 | lig_lig_loss, lig_rec_loss, aux_loss, steric_clash_loss, full_distogram_loss 378 | ), rmsd = self.loss(t_true, outputs, lig_mask, rec_mask, distogram_logits) 379 | self.log('train_loss', loss) 380 | self.log('train_lig_lig_loss', lig_lig_loss) 381 | self.log('train_lig_rec_loss', lig_rec_loss) 382 | self.log('train_aux_loss', aux_loss) 383 | self.log('train_steric_clash_loss', steric_clash_loss) 384 | self.log('train_full_distogram_loss', full_distogram_loss) 385 | self.log('train_rmsd', rmsd) 386 | return loss 387 | 388 | def validation_step(self, batch, idx): 389 | batch, t_true = batch 390 | _, _, _, rec_mask, lig_mask, _, _, _, _, _, _, _ = batch 391 | if self.use_aux_head or self.use_lig_aux_head: 392 | outputs, distogram_logits = self.model(*batch) 393 | else: 394 | outputs = self.model(*batch) 395 | distogram_logits = None 396 | loss, ( 397 | lig_lig_loss, lig_rec_loss, aux_loss, steric_clash_loss, full_distogram_loss 398 | ), rmsd = self.loss(t_true, outputs, lig_mask, rec_mask, distogram_logits,) 399 | self.log('val_loss', loss, sync_dist=True) 400 | self.log('val_lig_lig_loss', lig_lig_loss, sync_dist=True) 401 | self.log('val_lig_rec_loss', lig_rec_loss, sync_dist=True) 402 | self.log('val_aux_loss', aux_loss, sync_dist=True) 403 | self.log('val_steric_clash_loss', steric_clash_loss, sync_dist=True) 404 | 
self.log('val_full_distogram_loss', full_distogram_loss, sync_dist=True) 405 | self.log('val_rmsd', rmsd, sync_dist=True) 406 | return loss 407 | 408 | class QuickBindLoss(nn.Module): 409 | def __init__( 410 | self, lig_lig_loss_weight, lig_rec_loss_weight, aux_loss_weight, 411 | steric_clash_loss_weight, full_distogram_loss_weight, clamp_distance = None, eps = 1e-8, 412 | use_aux_head=False, use_lig_aux_head=False, 413 | ): 414 | super().__init__() 415 | self.lig_lig_loss_weight = lig_lig_loss_weight 416 | self.lig_rec_loss_weight = lig_rec_loss_weight 417 | self.aux_loss_weight = aux_loss_weight 418 | self.steric_clash_loss_weight = steric_clash_loss_weight 419 | self.full_distogram_loss_weight = full_distogram_loss_weight 420 | self.eps = eps 421 | self.clamp_distance = clamp_distance 422 | self.use_aux_head = use_aux_head 423 | self.use_lig_aux_head = use_lig_aux_head 424 | 425 | def compute_fape_lig_lig( 426 | self, 427 | pred_frames: Rigid, 428 | target_frames: Rigid, 429 | pred_positions: torch.Tensor, 430 | target_positions: torch.Tensor, 431 | mask: torch.Tensor 432 | ) -> torch.Tensor: 433 | # [*, N_frames, N_frames, 3] 434 | local_pred_pos = pred_frames.invert()[..., None].apply( 435 | pred_positions[..., None, :, :], 436 | ) 437 | local_target_pos = target_frames.invert()[..., None].apply( 438 | target_positions[..., None, :, :], 439 | ) 440 | error = torch.sqrt( 441 | torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + self.eps 442 | ) 443 | edge_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) 444 | error = error * edge_mask 445 | error = torch.sum(torch.sum(error, dim=-1), dim=-1) / torch.sum(mask, dim=-1)**2 446 | return torch.mean(error) 447 | 448 | def compute_fape_lig_rec( 449 | self, 450 | pred_positions: torch.Tensor, 451 | target_positions: torch.Tensor, 452 | protein_frames: Rigid, 453 | lig_mask: torch.Tensor, 454 | rec_mask: torch.Tensor, 455 | clamp_distance = None, 456 | ) -> torch.Tensor: 457 | # [*, N_protein_frames, 
N_lig_frames, 3] 458 | local_pred_pos = protein_frames.invert()[..., None].apply( 459 | pred_positions[..., None, :, :], 460 | ) 461 | local_target_pos = protein_frames.invert()[..., None].apply( 462 | target_positions[..., None, :, :], 463 | ) 464 | error = torch.sqrt( 465 | torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + self.eps 466 | ) 467 | edge_mask = rec_mask.unsqueeze(-1) * lig_mask.unsqueeze(-2) 468 | error = error * edge_mask 469 | if clamp_distance is not None: 470 | error = torch.clamp(error, min=0, max=clamp_distance) 471 | error = torch.sum(torch.sum(error, dim=-1), dim=-1) / (torch.sum(rec_mask, dim=-1) * torch.sum(lig_mask, dim=-1)) 472 | return torch.mean(error) 473 | 474 | def compute_rmsd(self, ti, t_true, mask): 475 | error = (ti - t_true) * mask.unsqueeze(-1) 476 | error = torch.sum(torch.sum(error**2, dim=-1), dim=-1) / (torch.sum(mask, dim=-1)) 477 | return torch.mean(torch.sqrt(error + self.eps)) 478 | 479 | def compute_steric_clash_loss_lig(self, ti, lig_mask): 480 | edge_mask = lig_mask.unsqueeze(-1) * lig_mask.unsqueeze(-2) 481 | pairwise_distances = torch.cdist(ti, ti, p=2) * edge_mask 482 | error = torch.nn.functional.relu(0.5 - pairwise_distances) 483 | error = torch.sum(torch.sum(torch.tril(error, diagonal=-1), dim=-1), dim=-1) 484 | return torch.mean(error) 485 | 486 | def compute_kabsch_rmsd(self, ti_batch, t_true_batch, mask): 487 | transformed_coords = [] 488 | for ti, t_true in zip(ti_batch, t_true_batch): 489 | try: 490 | lig_coords_pred_mean = ti.mean(dim=0, keepdim=True, dtype=torch.float32) # (1,3) 491 | lig_coords_mean = t_true.mean(dim=0, keepdim=True, dtype=torch.float32) # (1,3) 492 | A = ((ti - lig_coords_pred_mean).transpose(0, 1) @ (t_true - lig_coords_mean)).to(dtype=torch.float32) 493 | U, S, Vt = torch.linalg.svd(A) 494 | corr_mat = torch.diag(torch.tensor([1, 1, torch.sign(torch.det(A))], device=ti.device)) 495 | rotation = (U @ corr_mat) @ Vt 496 | translation = lig_coords_pred_mean - 
torch.t(rotation @ lig_coords_mean.t()) # (1,3) 497 | transformed_coords.append((rotation @ t_true.t()).t() + translation) 498 | return self.compute_pos_loss(ti_batch, torch.stack(transformed_coords), mask) 499 | except Exception: 500 | print('Computing Kabsch RMSD failed.') 501 | return torch.zeros(1, requires_grad=True, dtype=torch.float32, device=ti_batch.device) 502 | 503 | def compute_pos_loss(self, ti, t_true, mask): 504 | error = (ti - t_true) * mask.unsqueeze(-1) 505 | error = torch.sum(torch.sum(error**2, dim=-1), dim=-1) / (3*torch.sum(mask, dim=-1)) 506 | return torch.mean(error) 507 | 508 | def forward(self, target_frames, outputs, lig_mask, rec_mask, distogram_logits): 509 | target_frames = target_frames.cuda() 510 | pred_frames = outputs[-1][:, rec_mask.shape[-1]:] 511 | rec_frames = outputs[-1][:, :rec_mask.shape[-1]] 512 | target_positions = target_frames.get_trans() 513 | pred_positions = pred_frames.get_trans() 514 | lig_lig_loss = self.compute_fape_lig_lig(pred_frames, target_frames, pred_positions, target_positions, lig_mask) 515 | lig_rec_loss = self.compute_fape_lig_rec(pred_positions, target_positions, rec_frames, lig_mask, rec_mask, self.clamp_distance) 516 | aux_loss = torch.mean(torch.stack([ 517 | self.compute_fape_lig_rec(pred_frames[:, rec_mask.shape[-1]:].get_trans(), target_positions, rec_frames, lig_mask, rec_mask, self.clamp_distance) for pred_frames in outputs 518 | ])) 519 | steric_clash_loss = self.compute_kabsch_rmsd(pred_positions, target_positions, lig_mask) if self.steric_clash_loss_weight > 0 else 0.0 520 | rmsd = self.compute_rmsd(pred_positions, target_positions, lig_mask) 521 | 522 | if self.use_aux_head and self.use_lig_aux_head: 523 | distogram_logits_full, distogram_logits_lig = distogram_logits 524 | pseudo_beta_mask = torch.cat([rec_mask, lig_mask], dim=-1) 525 | pseudo_beta = torch.cat([rec_frames.get_trans(), pred_positions], dim=-2) 526 | rec_lig_distogram_loss = distogram_loss(distogram_logits_full, pseudo_beta, 
pseudo_beta_mask, min_bin=2.3125, max_bin=21.6875, no_bins=64) 527 | lig_lig_distogram_loss = distogram_loss(distogram_logits_lig, pred_positions, lig_mask, min_bin=1., max_bin=5., no_bins=42) 528 | full_distogram_loss = rec_lig_distogram_loss + lig_lig_distogram_loss 529 | elif self.use_aux_head: 530 | pseudo_beta_mask = torch.cat([rec_mask, lig_mask], dim=-1) 531 | pseudo_beta = torch.cat([rec_frames.get_trans(), pred_positions], dim=-2) 532 | full_distogram_loss = distogram_loss(distogram_logits, pseudo_beta, pseudo_beta_mask, min_bin=2.3125, max_bin=21.6875, no_bins=64) 533 | elif self.use_lig_aux_head: 534 | full_distogram_loss = distogram_loss(distogram_logits, pred_positions, lig_mask, min_bin=1., max_bin=5., no_bins=42) 535 | else: 536 | full_distogram_loss = 0.0 537 | 538 | loss = ( 539 | self.lig_lig_loss_weight * lig_lig_loss + \ 540 | self.lig_rec_loss_weight * lig_rec_loss + \ 541 | self.aux_loss_weight * aux_loss +\ 542 | self.steric_clash_loss_weight * steric_clash_loss +\ 543 | self.full_distogram_loss_weight * full_distogram_loss 544 | ) 545 | 546 | if torch.isnan(loss): 547 | print('Loss is nan, skipping...') 548 | loss = torch.zeros(1, requires_grad=True, dtype=torch.float32, device=lig_lig_loss.device) 549 | 550 | return loss, (lig_lig_loss, lig_rec_loss, aux_loss, steric_clash_loss, full_distogram_loss), rmsd 551 | -------------------------------------------------------------------------------- /scripts/process_binding_affinities.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pickle 3 | 4 | # taken from the TANKBind GitHub repository 5 | # Copyright (c) 2022 Wei Lu, Galixir Technologies 6 | def read_pdbbind_data(fileName): 7 | with open(fileName) as f: 8 | a = f.readlines() 9 | info = [] 10 | for line in a: 11 | if line[0] == '#': 12 | continue 13 | lines, ligand = line.split('//') 14 | pdb, resolution, year, affinity, raw = lines.strip().split(' ') 15 | ligand = 
ligand.strip().split('(')[1].split(')')[0] 16 | info.append([pdb, resolution, year, affinity, raw, ligand]) 17 | info = pd.DataFrame(info, columns=['pdb', 'resolution', 'year', 'affinity', 'raw', 'ligand']) 18 | info.year = info.year.astype(int) 19 | info.affinity = info.affinity.astype(float) 20 | return info 21 | 22 | df_pdb_id = pd.read_csv( 23 | '../index/INDEX_general_PL_name.2020', sep=" ", comment='#', header=None, engine='python', 24 | names=['pdb', 'year', 'uid', 'd', 'e','f','g','h','i','j','k','l','m','n','o'] 25 | ) 26 | df_pdb_id = df_pdb_id[['pdb','uid']] 27 | data = read_pdbbind_data('../index/INDEX_general_PL_data.2020') 28 | data = data.merge(df_pdb_id, on=['pdb']) 29 | 30 | binding_affinity_dict = dict(zip(data.pdb, data.affinity)) 31 | 32 | with open('../data/binding_affinity_dict.pkl', 'wb') as f: 33 | pickle.dump(binding_affinity_dict, f) 34 | -------------------------------------------------------------------------------- /train_binding_affinity.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | import pickle 6 | import torch.optim as optim 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.nn.utils.rnn import pad_sequence 9 | from openfold.model.primitives import Linear, LayerNorm 10 | from commons.utils import log 11 | from scipy.stats import pearsonr, spearmanr 12 | from sklearn.metrics import mean_squared_error, mean_absolute_error 13 | 14 | def parse_arguments(): 15 | p = argparse.ArgumentParser() 16 | p.add_argument('--seed', type=int, default=42) 17 | p.add_argument('--ckpt', type=str, default=None) 18 | return p.parse_args() 19 | 20 | class BindingAffinityPredictor(nn.Module): 21 | def __init__(self, c_s=64): 22 | super().__init__() 23 | self.norm = LayerNorm(c_s).to(device='cuda') 24 | self.affinity_in = nn.Sequential( 25 | Linear(c_s, c_s), 26 | nn.SiLU(), 27 | Linear(c_s, c_s), 28 | 
).to(device='cuda') 29 | self.binding_affinity_head = nn.Sequential( 30 | Linear(c_s, c_s), 31 | nn.ReLU(), 32 | Linear(c_s, c_s//2), 33 | nn.ReLU(), 34 | Linear(c_s//2, 1, init="final"), 35 | ).to(device='cuda') 36 | 37 | def forward(self, s): 38 | mask = torch.ones_like(s) 39 | mask[s == 0] = 0 40 | s = self.norm(s) 41 | s_aff = self.affinity_in(s) 42 | s_aff = torch.sum(s_aff, dim=-2) / torch.sum(mask, dim=-2) 43 | pred_affinity = self.binding_affinity_head(s_aff) 44 | return pred_affinity 45 | 46 | class BindingAffinityData(Dataset): 47 | def __init__(self, data, names, target_dict): 48 | self.data = data 49 | self.target_dict = target_dict 50 | self.names = names 51 | 52 | def __len__(self): 53 | return len(self.data) 54 | 55 | def __getitem__(self, idx): 56 | name = self.names[idx] 57 | x = self.data[idx].to(device='cuda') 58 | y = torch.tensor(self.target_dict[name]).unsqueeze(-1).to(device='cuda') 59 | return x, y 60 | 61 | def train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, early_stopping_patience): 62 | best_loss = float('inf') 63 | epochs_no_improve = 0 64 | 65 | for epoch in range(num_epochs): 66 | model.train() 67 | running_loss = 0.0 68 | for inputs, targets in train_loader: 69 | optimizer.zero_grad() 70 | outputs = model(inputs) 71 | loss = criterion(outputs.to(dtype=float), targets.to(dtype=float)) 72 | loss.backward() 73 | optimizer.step() 74 | running_loss += loss.item() 75 | 76 | model.eval() 77 | val_running_loss = 0.0 78 | with torch.no_grad(): 79 | for inputs, targets in valid_loader: 80 | outputs = model(inputs) 81 | loss = criterion(outputs, targets) 82 | val_running_loss += loss.item() 83 | 84 | train_loss = running_loss / len(train_loader) 85 | val_loss = val_running_loss / len(valid_loader) 86 | log(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss}, Valid Loss: {val_loss}') 87 | 88 | if val_loss < best_loss: 89 | best_loss = val_loss 90 | epochs_no_improve = 0 91 | torch.save(model.state_dict(), 
'curr_ckpt.pt') 92 | else: 93 | epochs_no_improve += 1 94 | if epochs_no_improve >= early_stopping_patience: 95 | log("Early stopping triggered.") 96 | break 97 | 98 | def get_predictions(model, test_loader): 99 | model.eval() 100 | predictions = [] 101 | true_values = [] 102 | with torch.no_grad(): 103 | for inputs, targets in test_loader: 104 | outputs = model(inputs) 105 | predictions.append(outputs.item()) 106 | true_values.append(targets.item()) 107 | return predictions, true_values 108 | 109 | def compute_metrics(true_values, predicted_values): 110 | rmsd = np.sqrt(mean_squared_error(true_values, predicted_values)) 111 | pearson_corr, _ = pearsonr(true_values, predicted_values) 112 | spearman_corr, _ = spearmanr(true_values, predicted_values) 113 | mae = mean_absolute_error(true_values, predicted_values) 114 | return rmsd, pearson_corr, spearman_corr, mae 115 | 116 | if __name__ == '__main__': 117 | args = parse_arguments() 118 | log(f'Using seed {args.seed}.') 119 | torch.manual_seed(args.seed) 120 | g = torch.Generator() 121 | g.manual_seed(args.seed) 122 | np.random.seed(args.seed) 123 | 124 | batch_size = 64 125 | learning_rate = 0.01 126 | num_epochs = 1000 127 | patience = 50 128 | 129 | log('Getting binding affinity data.') 130 | with open('data/binding_affinity_dict.pkl', 'rb') as f: 131 | binding_affinity_dict = pickle.load(f) 132 | 133 | if not args.ckpt: 134 | log('Getting training data.') 135 | train_outputs = torch.load( 136 | 'checkpoints/quickbind_default/train_predictions-w-single-rep.pt' 137 | ) 138 | train_affinities = {k: v for k, v in binding_affinity_dict.items() if k in train_outputs['names']} 139 | train_s= pad_sequence([s.squeeze() for s in train_outputs['s_pre_struct']], batch_first=True) 140 | train_dataset = BindingAffinityData(train_s, train_outputs['names'], train_affinities) 141 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=g) 142 | 143 | log('Getting validation data.') 144 | 
val_outputs = torch.load( 145 | 'checkpoints/quickbind_default/val_predictions-w-single-rep.pt' 146 | ) 147 | valid_affinities = {k: v for k, v in binding_affinity_dict.items() if k in val_outputs['names']} 148 | valid_s= pad_sequence([s.squeeze() for s in val_outputs['s_pre_struct']], batch_first=True) 149 | valid_dataset = BindingAffinityData(valid_s, val_outputs['names'], valid_affinities) 150 | valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False) 151 | 152 | log('Getting test data.') 153 | test_outputs = torch.load( 154 | 'checkpoints/quickbind_default/predictions-w-single-rep.pt' 155 | ) 156 | test_affinities = {k: v for k, v in binding_affinity_dict.items() if k in test_outputs['names']} 157 | test_s= pad_sequence([s.squeeze() for s in test_outputs['s_pre_struct']], batch_first=True) 158 | test_dataset = BindingAffinityData(test_s, test_outputs['names'], test_affinities) 159 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False) 160 | 161 | model = BindingAffinityPredictor(64) 162 | 163 | if not args.ckpt: 164 | criterion = nn.MSELoss() 165 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 166 | log('Starting model training.') 167 | train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, patience) 168 | model.load_state_dict(torch.load('curr_ckpt.pt')) 169 | else: 170 | model.load_state_dict(torch.load(args.ckpt)) 171 | 172 | log('Starting model evaluation.') 173 | predictions, true_values = get_predictions(model, test_loader) 174 | 175 | rmsd, pearson_corr, spearman_corr, mae = compute_metrics(true_values, predictions) 176 | 177 | log(f'RMSD: {rmsd}') 178 | log(f'Pearson Correlation: {pearson_corr}') 179 | log(f'Spearman Correlation: {spearman_corr}') 180 | log(f'MAE: {mae}') 181 | 182 | if not args.ckpt: 183 | torch.save( 184 | model.state_dict(), 185 | f'checkpoints/quickbind_default/binding_affinity_prediction/ckpt_seed{args.seed}.pt' 186 | ) 187 | 
-------------------------------------------------------------------------------- /train_pl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import os 4 | import glob 5 | import torch 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.loggers import WandbLogger 8 | from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint 9 | from pytorch_lightning.strategies.ddp import DDPStrategy 10 | from torch.utils.data import DataLoader 11 | from commons.utils import log, get_parameter_value 12 | from dataset.dataimporter import DataImporter 13 | from openfold.utils.seed import seed_globally 14 | from quickbind import QuickBind_PL 15 | from openfold.utils.rigid_utils import Rigid, Rotation 16 | torch.cuda.empty_cache() 17 | 18 | def parse_arguments(): 19 | p = argparse.ArgumentParser() 20 | p.add_argument('--config', type=argparse.FileType(mode='r')) 21 | p.add_argument('--resume', type=bool, default=False) 22 | p.add_argument('--id', type=str, default=None, help='W&B ID') 23 | p.add_argument('--finetune', type=bool, default=False) 24 | return p.parse_args() 25 | 26 | def collate(data, construct_frames, use_topological_distance, one_hot_adj): 27 | assert len(data) == 1 28 | data = data[0] 29 | aatype = data.aatype.unsqueeze(0) 30 | lig_atom_features = data.lig_atom_features.unsqueeze(0).to(dtype=torch.float32) 31 | t_true = data.true_lig_atom_coords.unsqueeze(0) 32 | rec_mask = torch.ones(data.aatype.shape[0]).unsqueeze(0) 33 | lig_mask = torch.ones(data.lig_atom_features.shape[0]).unsqueeze(0) 34 | 35 | if use_topological_distance: 36 | adj = torch.clamp(data.distance_matrix.unsqueeze(0), max=7) 37 | adj = torch.nn.functional.one_hot(adj.long(), num_classes=8).to(dtype=torch.float32) 38 | elif one_hot_adj: 39 | adj = data.adjacency_bo.unsqueeze(0).to(dtype=torch.int64) 40 | else: 41 | adj = data.adjacency.unsqueeze(0).unsqueeze(-1).to(dtype=torch.float32) 
42 | 43 | ri = data.residue_index.unsqueeze(0).to(dtype=torch.int64) 44 | chain_id = data.chain_ids_processed.unsqueeze(0).to(dtype=torch.int64) 45 | entity_id = data.entity_ids_processed.unsqueeze(0).to(dtype=torch.int64) 46 | sym_id = data.sym_ids_processed.unsqueeze(0).to(dtype=torch.int64) 47 | id_batch = (ri, chain_id, entity_id, sym_id) 48 | 49 | t_rec = data.c_alpha_coords.unsqueeze(0) 50 | N = data.n_coords.unsqueeze(0) 51 | C = data.c_coords.unsqueeze(0) 52 | t_lig = data.lig_atom_coords.unsqueeze(0) 53 | 54 | pseudo_N = data.pseudo_N.unsqueeze(0) 55 | pseudo_C = data.pseudo_C.unsqueeze(0) 56 | true_pseudo_N = data.true_pseudo_N.unsqueeze(0) 57 | true_pseudo_C = data.true_pseudo_C.unsqueeze(0) 58 | 59 | if construct_frames: 60 | t_true = Rigid.from_3_points(true_pseudo_N, t_true, true_pseudo_C) 61 | else: 62 | t_true = Rigid( 63 | rots = Rotation.identity( 64 | shape=t_true.shape[:-1], dtype = torch.float32, fmt="quat" 65 | ), trans = t_true 66 | ) 67 | 68 | return (aatype, lig_atom_features, adj, rec_mask, lig_mask, N, t_rec, C, t_lig, id_batch, pseudo_N, pseudo_C), t_true 69 | 70 | def train(config): 71 | seed = 0 if config['seed'] is None else config['seed'] 72 | seed_globally(seed) 73 | pl.seed_everything(seed, workers=True) 74 | 75 | if args.finetune: 76 | config['dataset_params']['crop_size'] = 512 77 | 78 | log('Getting training data.') 79 | train_data = DataImporter(complex_names_path=config['train_names'], **config['dataset_params']) 80 | log('Getting validation data.') 81 | val_data = DataImporter(complex_names_path=config['val_names'], **config['dataset_params']) 82 | 83 | lig_feat_dim, rec_feat_dim = train_data.get_feature_dimensions() 84 | one_hot_adj = get_parameter_value('one_hot_adj', config['model_parameters']) 85 | use_topological_distance = get_parameter_value('use_topological_distance', config['model_parameters']) 86 | construct_frames = get_parameter_value('construct_frames', config['model_parameters']) 87 | 88 | train_loader = 
DataLoader( 89 | train_data, batch_size=config['batch_size'], shuffle=True, 90 | collate_fn=lambda x: collate(x, construct_frames, use_topological_distance, one_hot_adj), 91 | num_workers=config['num_workers'], prefetch_factor=12 92 | ) 93 | val_loader = DataLoader( 94 | val_data, batch_size=config['batch_size'], 95 | collate_fn=lambda x: collate(x, construct_frames, use_topological_distance, one_hot_adj), 96 | num_workers=config['num_workers'] 97 | ) 98 | 99 | model = QuickBind_PL( 100 | aa_feat=rec_feat_dim, lig_atom_feat=lig_feat_dim, 101 | **config['model_parameters'], loss_config=config['loss_params'], **config['optimizer_params'], 102 | chunk_size=None 103 | ) 104 | early_stopping = EarlyStopping( 105 | monitor='val_loss', 106 | patience=50 107 | ) 108 | checkpoint_callback = ModelCheckpoint( 109 | dirpath=f'checkpoints/{config["name"]}', 110 | filename='best_checkpoint', 111 | monitor='val_loss', 112 | save_top_k=3, 113 | mode="min", 114 | ) 115 | checkpoint_callback.FILE_EXTENSION = ".pt" 116 | lr_monitor = LearningRateMonitor(logging_interval='step') 117 | clip_grad = config['clip_grad'] or 0.0 118 | trainer = pl.Trainer( 119 | max_epochs=config['num_epochs'], 120 | logger=wandb_logger, 121 | accumulate_grad_batches=config['iters_to_accumulate'], # Gradient accumulation 122 | precision=16, # Mixed precision training 123 | gradient_clip_val=clip_grad, 124 | callbacks=[early_stopping, checkpoint_callback, lr_monitor], 125 | default_root_dir=f'checkpoints/{config["name"]}', 126 | deterministic=True, accelerator="gpu", 127 | strategy=DDPStrategy(find_unused_parameters=False), 128 | num_nodes=2 129 | ) 130 | if args.resume: 131 | checkpoints = glob.glob(f'checkpoints/{config["name"]}/best_checkpoint*.pt') 132 | latest_checkpoint = max(checkpoints, key=os.path.getctime) 133 | trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=latest_checkpoint) 134 | elif args.finetune: 135 | model_state = 
torch.load(f'checkpoints/{config["name"]}/best_checkpoint.pt') 136 | model.load_state_dict(model_state['state_dict']) 137 | trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) 138 | else: 139 | trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) 140 | 141 | if __name__ == '__main__': 142 | args = parse_arguments() 143 | torch.set_float32_matmul_precision('high') 144 | torch.backends.cudnn.benchmark = True 145 | torch.backends.cudnn.enabled = True 146 | config = yaml.load(args.config, Loader=yaml.FullLoader) 147 | if args.resume: 148 | assert args.id is not None, log('If you want to resume an experiment, you should provide the W&B ID!') 149 | wandb_logger = WandbLogger(name=config['name'], project=config['wandb']['project'], id=args.id) 150 | else: 151 | wandb_logger = WandbLogger(name=config['name'], project=config['wandb']['project']) 152 | name = config['name'] 153 | # makedirs with exist_ok also creates a missing checkpoints/ and tolerates several DDP ranks starting at once 154 | os.makedirs(f'checkpoints/{name}', exist_ok=True) 155 | train(config) 156 | -------------------------------------------------------------------------------- /virtual_screening.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Virtual Screening" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import torch\n", 17 | "import seaborn as sns\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "from scipy.stats import ranksums" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "Read in the three sets of generated embeddings using the protein structures from the three lowest-affinity binders, first for the decoys, then for the binders." 
27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 27, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "B1MDI3_decoys = torch.load('checkpoints/quickbind_default/virtual_screening/B1MDI3_predictionsw_single_rep.pt')\n", 36 | "B1MDI3_decoys_1 = torch.load('checkpoints/quickbind_default/virtual_screening/decoy_B1MDI3_predictions_1-w-single-rep.pt')\n", 37 | "B1MDI3_decoys_2 = torch.load('checkpoints/quickbind_default/virtual_screening/decoy_B1MDI3_predictions_2-w-single-rep.pt')" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 28, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "P56817_decoys = torch.load('checkpoints/quickbind_default/virtual_screening/P56817_predictionsw_single_rep.pt')\n", 47 | "P56817_decoys_1 = torch.load('checkpoints/quickbind_default/virtual_screening/decoy_P56817_predictions_1-w-single-rep.pt')\n", 48 | "P56817_decoys_2 = torch.load('checkpoints/quickbind_default/virtual_screening/decoy_P56817_predictions_2-w-single-rep.pt')" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 29, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "P17931_decoys = torch.load('checkpoints/quickbind_default/virtual_screening/P17931_predictionsw_single_rep.pt')\n", 58 | "P17931_decoys_1 = torch.load('checkpoints/quickbind_default/virtual_screening/decoy_P17931_predictions_1-w-single-rep.pt')\n", 59 | "P17931_decoys_2 = torch.load('checkpoints/quickbind_default/virtual_screening/decoy_P17931_predictions_2-w-single-rep.pt')" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 30, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "Q8ULI9_decoys = torch.load('checkpoints/quickbind_default/virtual_screening/Q8ULI9_predictionsw_single_rep.pt')\n", 69 | "Q8ULI9_decoys_1 = torch.load('checkpoints/quickbind_default/virtual_screening/decoy_Q8ULI9_predictions_1-w-single-rep.pt')\n", 70 | "Q8ULI9_decoys_2 = 
torch.load('checkpoints/quickbind_default/virtual_screening/decoy_Q8ULI9_predictions_2-w-single-rep.pt')" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 31, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "P01116_decoys = torch.load('checkpoints/quickbind_default/virtual_screening/P01116_predictionsw_single_rep.pt')\n", 80 | "P01116_decoys_1 = torch.load('checkpoints/quickbind_default/virtual_screening/decoy_P01116_predictions_1-w-single-rep.pt')\n", 81 | "P01116_decoys_2 = torch.load('checkpoints/quickbind_default/virtual_screening/decoy_P01116_predictions_2-w-single-rep.pt')" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 38, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "B1MDI3_binders = [\n", 91 | " '6qqt', '6qrf', '6qqw', '6qre', '6qrc', '6qqv', '6qrg', '6qr1', '6qqq', '6qqu', '6qr2', '6qra', '6qr4', '6qr3',\n", 92 | " '6qr9', '6qqz', '6qr0', '6qrd', '6qr7'\n", 93 | "]\n", 94 | "B1MDI3_binders_out = torch.load('checkpoints/quickbind_default/virtual_screening/true_B1MDI3_predictions_w_single_rep.pt')\n", 95 | "B1MDI3_binders_out_1 = torch.load('checkpoints/quickbind_default/virtual_screening/true_B1MDI3_predictions_1-w-single-rep.pt')\n", 96 | "B1MDI3_binders_out_2 = torch.load('checkpoints/quickbind_default/virtual_screening/true_B1MDI3_predictions_2-w-single-rep.pt')" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 39, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "P56817_binders = [\n", 106 | " '6uvp', '6uvv', '6uvy', '6uwp', '6nw3', '6e3z', '6nv7', '6uwv', '6nv9', '6od6', '6jt3', '6jsg',\n", 107 | " '6jsn', '6jsf', '6jse', '6pz4'\n", 108 | "]\n", 109 | "P56817_binders_out = torch.load('checkpoints/quickbind_default/virtual_screening/true_P56817_predictions_w_single_rep.pt')\n", 110 | "P56817_binders_out_1 = torch.load('checkpoints/quickbind_default/virtual_screening/true_P56817_predictions_1-w-single-rep.pt')\n", 111 | 
"P56817_binders_out_2 = torch.load('checkpoints/quickbind_default/virtual_screening/true_P56817_predictions_2-w-single-rep.pt')" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 40, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "P17931_binders = [\n", 121 | " '6qlt', '6qlq', '6i75', '6qlu', '6qln', '6i78', '6qlr', '6i77', '6qlo', '6i76', '6qlp', '6qls',\n", 122 | " '6i74', '6qge', '6qgf'\n", 123 | "]\n", 124 | "P17931_binders_out = torch.load('checkpoints/quickbind_default/virtual_screening/true_P17931_predictions_w_single_rep.pt')\n", 125 | "P17931_binders_out_1 = torch.load('checkpoints/quickbind_default/virtual_screening/true_P17931_predictions_1-w-single-rep.pt')\n", 126 | "P17931_binders_out_2 = torch.load('checkpoints/quickbind_default/virtual_screening/true_P17931_predictions_2-w-single-rep.pt')" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 41, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "Q8ULI9_binders = [\n", 136 | " '6oy0', '6oxy', '6oxz', '6oy2', '6oxt', '6oxs', '6oxx', '6oxu', '6oy1', '6oxr', '6oxw', '6oxv',\n", 137 | " '6oxp', '6oxq'\n", 138 | "]\n", 139 | "Q8ULI9_binders_out = torch.load('checkpoints/quickbind_default/virtual_screening/true_Q8ULI9_predictions_w_single_rep.pt')\n", 140 | "Q8ULI9_binders_out_1 = torch.load('checkpoints/quickbind_default/virtual_screening/true_Q8ULI9_predictions_1-w-single-rep.pt')\n", 141 | "Q8ULI9_binders_out_2 = torch.load('checkpoints/quickbind_default/virtual_screening/true_Q8ULI9_predictions_2-w-single-rep.pt')" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 42, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "P01116_binders = [\n", 151 | " '6quw', '6quv', '6gj5', '6v5l', '6gj6', '6p8z', '6pgo', '6p8x', '6gj8', '6gj7', '6p8y', '6pgp',\n", 152 | " '6oim'\n", 153 | "]\n", 154 | "P01116_binders_out = 
torch.load('checkpoints/quickbind_default/virtual_screening/true_P01116_predictions_w_single_rep.pt')\n", 155 | "P01116_binders_out_1 = torch.load('checkpoints/quickbind_default/virtual_screening/true_P01116_predictions_1-w-single-rep.pt')\n", 156 | "P01116_binders_out_2 = torch.load('checkpoints/quickbind_default/virtual_screening/true_P01116_predictions_2-w-single-rep.pt')" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "Predict binding affinities." 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 59, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "from train_binding_affinity import BindingAffinityPredictor, BindingAffinityData, get_predictions\n", 173 | "from torch.utils.data import DataLoader\n", 174 | "import pickle" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 60, 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "data": { 184 | "text/plain": [ 185 | "" 186 | ] 187 | }, 188 | "execution_count": 60, 189 | "metadata": {}, 190 | "output_type": "execute_result" 191 | } 192 | ], 193 | "source": [ 194 | "model = BindingAffinityPredictor(64)\n", 195 | "model.load_state_dict(torch.load(\n", 196 | " 'checkpoints/quickbind_default/binding_affinity_prediction/ckpt_seed42.pt',\n", 197 | "))" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 61, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "with open('binding_affinity_dict.pkl', 'rb') as f:\n", 207 | " binding_affinity_dict = pickle.load(f)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 62, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "test_proteins = torch.load('data/processed/timesplit_test/rec_input_proc_ids.pt')\n", 217 | "test_protein_names = [p['complex_names'] for p in test_proteins]" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 63, 223 | 
"metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "lists = []\n", 227 | "lists_1 = []\n", 228 | "lists_2 = []" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "B1MDI3_binders_input = [s.squeeze() for s in B1MDI3_binders_out['s_pre_struct']]\n", 238 | "B1MDI3_binders_ordered = [n for n in test_protein_names if n in B1MDI3_binders]\n", 239 | "B1MDI3_binders_dataset = BindingAffinityData(B1MDI3_binders_input, B1MDI3_binders_ordered, binding_affinity_dict)\n", 240 | "B1MDI3_binders_loader = DataLoader(B1MDI3_binders_dataset, batch_size=1, shuffle=False)\n", 241 | "predictions_binders, _ = get_predictions(model, B1MDI3_binders_loader)\n", 242 | "\n", 243 | "B1MDI3_decoys_input = [s.squeeze() for s in B1MDI3_decoys['s_pre_struct']]\n", 244 | "B1MDI3_decoys_dataset = BindingAffinityData(B1MDI3_decoys_input, B1MDI3_decoys['names'], binding_affinity_dict)\n", 245 | "B1MDI3_decoys_loader = DataLoader(B1MDI3_decoys_dataset, batch_size=1, shuffle=False)\n", 246 | "predictions_decoys, _ = get_predictions(model, B1MDI3_decoys_loader)\n", 247 | "\n", 248 | "lists.append([predictions_binders, predictions_decoys])\n", 249 | "\n", 250 | "######################\n", 251 | "\n", 252 | "B1MDI3_binders_input_1 = [s.squeeze() for s in B1MDI3_binders_out_1['s_pre_struct']]\n", 253 | "B1MDI3_binders_dataset = BindingAffinityData(B1MDI3_binders_input_1, B1MDI3_binders_ordered, binding_affinity_dict)\n", 254 | "B1MDI3_binders_loader = DataLoader(B1MDI3_binders_dataset, batch_size=1, shuffle=False)\n", 255 | "predictions_binders, _ = get_predictions(model, B1MDI3_binders_loader)\n", 256 | "\n", 257 | "B1MDI3_decoys_input_1 = [s.squeeze() for s in B1MDI3_decoys_1['s_pre_struct']]\n", 258 | "B1MDI3_decoys_dataset = BindingAffinityData(B1MDI3_decoys_input_1, B1MDI3_decoys_1['names'], binding_affinity_dict)\n", 259 | "B1MDI3_decoys_loader = DataLoader(B1MDI3_decoys_dataset, 
batch_size=1, shuffle=False)\n", 260 | "predictions_decoys, _ = get_predictions(model, B1MDI3_decoys_loader)\n", 261 | "\n", 262 | "lists_1.append([predictions_binders, predictions_decoys])\n", 263 | "\n", 264 | "######################\n", 265 | "\n", 266 | "B1MDI3_binders_input_2 = [s.squeeze() for s in B1MDI3_binders_out_2['s_pre_struct']]\n", 267 | "B1MDI3_binders_dataset = BindingAffinityData(B1MDI3_binders_input_2, B1MDI3_binders_ordered, binding_affinity_dict)\n", 268 | "B1MDI3_binders_loader = DataLoader(B1MDI3_binders_dataset, batch_size=1, shuffle=False)\n", 269 | "\n", 270 | "predictions_binders, _ = get_predictions(model, B1MDI3_binders_loader)\n", 271 | "\n", 272 | "B1MDI3_decoys_input_2 = [s.squeeze() for s in B1MDI3_decoys_2['s_pre_struct']]\n", 273 | "B1MDI3_decoys_dataset = BindingAffinityData(B1MDI3_decoys_input_2, B1MDI3_decoys_2['names'], binding_affinity_dict)\n", 274 | "B1MDI3_decoys_loader = DataLoader(B1MDI3_decoys_dataset, batch_size=1, shuffle=False)\n", 275 | "predictions_decoys, _ = get_predictions(model, B1MDI3_decoys_loader)\n", 276 | "\n", 277 | "lists_2.append([predictions_binders, predictions_decoys])" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "P56817_binders_input = [s.squeeze() for s in P56817_binders_out['s_pre_struct']]\n", 287 | "P56817_binders_ordered = [n for n in test_protein_names if n in P56817_binders]\n", 288 | "P56817_binders_dataset = BindingAffinityData(P56817_binders_input, P56817_binders_ordered, binding_affinity_dict)\n", 289 | "P56817_binders_loader = DataLoader(P56817_binders_dataset, batch_size=1, shuffle=False)\n", 290 | "predictions_binders, _ = get_predictions(model, P56817_binders_loader)\n", 291 | "\n", 292 | "P56817_decoys_input = [s.squeeze() for s in P56817_decoys['s_pre_struct']]\n", 293 | "P56817_decoys_dataset = BindingAffinityData(P56817_decoys_input, P56817_decoys['names'], 
binding_affinity_dict)\n", 294 | "P56817_decoys_loader = DataLoader(P56817_decoys_dataset, batch_size=1, shuffle=False)\n", 295 | "predictions_decoys, _ = get_predictions(model, P56817_decoys_loader)\n", 296 | "\n", 297 | "lists.append([predictions_binders, predictions_decoys])\n", 298 | "\n", 299 | "######################\n", 300 | "\n", 301 | "P56817_binders_input_1 = [s.squeeze() for s in P56817_binders_out_1['s_pre_struct']]\n", 302 | "P56817_binders_dataset = BindingAffinityData(P56817_binders_input_1, P56817_binders_ordered, binding_affinity_dict)\n", 303 | "P56817_binders_loader = DataLoader(P56817_binders_dataset, batch_size=1, shuffle=False)\n", 304 | "predictions_binders, _ = get_predictions(model, P56817_binders_loader)\n", 305 | "\n", 306 | "P56817_decoys_input_1 = [s.squeeze() for s in P56817_decoys_1['s_pre_struct']]\n", 307 | "P56817_decoys_dataset = BindingAffinityData(P56817_decoys_input_1, P56817_decoys_1['names'], binding_affinity_dict)\n", 308 | "P56817_decoys_loader = DataLoader(P56817_decoys_dataset, batch_size=1, shuffle=False)\n", 309 | "predictions_decoys, _ = get_predictions(model, P56817_decoys_loader)\n", 310 | "\n", 311 | "lists_1.append([predictions_binders, predictions_decoys])\n", 312 | "\n", 313 | "######################\n", 314 | "\n", 315 | "P56817_binders_input_2 = [s.squeeze() for s in P56817_binders_out_2['s_pre_struct']]\n", 316 | "P56817_binders_dataset = BindingAffinityData(P56817_binders_input_2, P56817_binders_ordered, binding_affinity_dict)\n", 317 | "P56817_binders_loader = DataLoader(P56817_binders_dataset, batch_size=1, shuffle=False)\n", 318 | "predictions_binders, _ = get_predictions(model, P56817_binders_loader)\n", 319 | "\n", 320 | "P56817_decoys_input_2 = [s.squeeze() for s in P56817_decoys_2['s_pre_struct']]\n", 321 | "P56817_decoys_dataset = BindingAffinityData(P56817_decoys_input_2, P56817_decoys_2['names'], binding_affinity_dict)\n", 322 | "P56817_decoys_loader = DataLoader(P56817_decoys_dataset, 
batch_size=1, shuffle=False)\n", 323 | "predictions_decoys, _ = get_predictions(model, P56817_decoys_loader)\n", 324 | "\n", 325 | "lists_2.append([predictions_binders, predictions_decoys])" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "P17931_binders_input = [s.squeeze() for s in P17931_binders_out['s_pre_struct']]\n", 335 | "P17931_binders_ordered = [n for n in test_protein_names if n in P17931_binders]\n", 336 | "P17931_binders_dataset = BindingAffinityData(P17931_binders_input, P17931_binders_ordered, binding_affinity_dict)\n", 337 | "P17931_binders_loader = DataLoader(P17931_binders_dataset, batch_size=1, shuffle=False)\n", 338 | "predictions_binders, _ = get_predictions(model, P17931_binders_loader)\n", 339 | "\n", 340 | "P17931_decoys_input = [s.squeeze() for s in P17931_decoys['s_pre_struct']]\n", 341 | "P17931_decoys_dataset = BindingAffinityData(P17931_decoys_input, P17931_decoys['names'], binding_affinity_dict)\n", 342 | "P17931_decoys_loader = DataLoader(P17931_decoys_dataset, batch_size=1, shuffle=False)\n", 343 | "predictions_decoys, _ = get_predictions(model, P17931_decoys_loader)\n", 344 | "\n", 345 | "lists.append([predictions_binders, predictions_decoys])\n", 346 | "\n", 347 | "######################\n", 348 | "\n", 349 | "P17931_binders_input_1 = [s.squeeze() for s in P17931_binders_out_1['s_pre_struct']]\n", 350 | "P17931_binders_dataset = BindingAffinityData(P17931_binders_input_1, P17931_binders_ordered, binding_affinity_dict)\n", 351 | "P17931_binders_loader = DataLoader(P17931_binders_dataset, batch_size=1, shuffle=False)\n", 352 | "predictions_binders, _ = get_predictions(model, P17931_binders_loader)\n", 353 | "\n", 354 | "P17931_decoys_input_1 = [s.squeeze() for s in P17931_decoys_1['s_pre_struct']]\n", 355 | "P17931_decoys_dataset = BindingAffinityData(P17931_decoys_input_1, P17931_decoys_1['names'], binding_affinity_dict)\n", 356 | 
"P17931_decoys_loader = DataLoader(P17931_decoys_dataset, batch_size=1, shuffle=False)\n", 357 | "predictions_decoys, _ = get_predictions(model, P17931_decoys_loader)\n", 358 | "\n", 359 | "lists_1.append([predictions_binders, predictions_decoys])\n", 360 | "\n", 361 | "######################\n", 362 | "\n", 363 | "P17931_binders_input_2 = [s.squeeze() for s in P17931_binders_out_2['s_pre_struct']]\n", 364 | "P17931_binders_dataset = BindingAffinityData(P17931_binders_input_2, P17931_binders_ordered, binding_affinity_dict)\n", 365 | "P17931_binders_loader = DataLoader(P17931_binders_dataset, batch_size=1, shuffle=False)\n", 366 | "predictions_binders, _ = get_predictions(model, P17931_binders_loader)\n", 367 | "\n", 368 | "P17931_decoys_input_2 = [s.squeeze() for s in P17931_decoys_2['s_pre_struct']]\n", 369 | "P17931_decoys_dataset = BindingAffinityData(P17931_decoys_input_2, P17931_decoys_2['names'], binding_affinity_dict)\n", 370 | "P17931_decoys_loader = DataLoader(P17931_decoys_dataset, batch_size=1, shuffle=False)\n", 371 | "predictions_decoys, _ = get_predictions(model, P17931_decoys_loader)\n", 372 | "\n", 373 | "lists_2.append([predictions_binders, predictions_decoys])" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "Q8ULI9_binders_input = [s.squeeze() for s in Q8ULI9_binders_out['s_pre_struct']]\n", 383 | "Q8ULI9_binders_ordered = [n for n in test_protein_names if n in Q8ULI9_binders]\n", 384 | "Q8ULI9_binders_dataset = BindingAffinityData(Q8ULI9_binders_input, Q8ULI9_binders_ordered, binding_affinity_dict)\n", 385 | "Q8ULI9_binders_loader = DataLoader(Q8ULI9_binders_dataset, batch_size=1, shuffle=False)\n", 386 | "predictions_binders, _ = get_predictions(model, Q8ULI9_binders_loader)\n", 387 | "\n", 388 | "Q8ULI9_decoys_input = [s.squeeze() for s in Q8ULI9_decoys['s_pre_struct']]\n", 389 | "Q8ULI9_decoys_dataset = 
BindingAffinityData(Q8ULI9_decoys_input, Q8ULI9_decoys['names'], binding_affinity_dict)\n", 390 | "Q8ULI9_decoys_loader = DataLoader(Q8ULI9_decoys_dataset, batch_size=1, shuffle=False)\n", 391 | "predictions_decoys, _ = get_predictions(model, Q8ULI9_decoys_loader)\n", 392 | "\n", 393 | "lists.append([predictions_binders, predictions_decoys])\n", 394 | "\n", 395 | "######################\n", 396 | "\n", 397 | "Q8ULI9_binders_input_1 = [s.squeeze() for s in Q8ULI9_binders_out_1['s_pre_struct']]\n", 398 | "Q8ULI9_binders_dataset = BindingAffinityData(Q8ULI9_binders_input_1, Q8ULI9_binders_ordered, binding_affinity_dict)\n", 399 | "Q8ULI9_binders_loader = DataLoader(Q8ULI9_binders_dataset, batch_size=1, shuffle=False)\n", 400 | "predictions_binders, _ = get_predictions(model, Q8ULI9_binders_loader)\n", 401 | "\n", 402 | "Q8ULI9_decoys_input_1 = [s.squeeze() for s in Q8ULI9_decoys_1['s_pre_struct']]\n", 403 | "Q8ULI9_decoys_dataset = BindingAffinityData(Q8ULI9_decoys_input_1, Q8ULI9_decoys_1['names'], binding_affinity_dict)\n", 404 | "Q8ULI9_decoys_loader = DataLoader(Q8ULI9_decoys_dataset, batch_size=1, shuffle=False)\n", 405 | "predictions_decoys, _ = get_predictions(model, Q8ULI9_decoys_loader)\n", 406 | "\n", 407 | "lists_1.append([predictions_binders, predictions_decoys])\n", 408 | "\n", 409 | "######################\n", 410 | "\n", 411 | "Q8ULI9_binders_input_2 = [s.squeeze() for s in Q8ULI9_binders_out_2['s_pre_struct']]\n", 412 | "Q8ULI9_binders_dataset = BindingAffinityData(Q8ULI9_binders_input_2, Q8ULI9_binders_ordered, binding_affinity_dict)\n", 413 | "Q8ULI9_binders_loader = DataLoader(Q8ULI9_binders_dataset, batch_size=1, shuffle=False)\n", 414 | "predictions_binders, _ = get_predictions(model, Q8ULI9_binders_loader)\n", 415 | "\n", 416 | "Q8ULI9_decoys_input_2 = [s.squeeze() for s in Q8ULI9_decoys_2['s_pre_struct']]\n", 417 | "Q8ULI9_decoys_dataset = BindingAffinityData(Q8ULI9_decoys_input_2, Q8ULI9_decoys_2['names'], binding_affinity_dict)\n", 418 | 
"Q8ULI9_decoys_loader = DataLoader(Q8ULI9_decoys_dataset, batch_size=1, shuffle=False)\n", 419 | "predictions_decoys, _ = get_predictions(model, Q8ULI9_decoys_loader)\n", 420 | "\n", 421 | "lists_2.append([predictions_binders, predictions_decoys])" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "P01116_binders_input = [s.squeeze() for s in P01116_binders_out['s_pre_struct']]\n", 431 | "P01116_binders_ordered = [n for n in test_protein_names if n in P01116_binders]\n", 432 | "P01116_binders_dataset = BindingAffinityData(P01116_binders_input, P01116_binders_ordered, binding_affinity_dict)\n", 433 | "P01116_binders_loader = DataLoader(P01116_binders_dataset, batch_size=1, shuffle=False)\n", 434 | "predictions_binders, _ = get_predictions(model, P01116_binders_loader)\n", 435 | "\n", 436 | "P01116_decoys_input = [s.squeeze() for s in P01116_decoys['s_pre_struct']]\n", 437 | "P01116_decoys_dataset = BindingAffinityData(P01116_decoys_input, P01116_decoys['names'], binding_affinity_dict)\n", 438 | "P01116_decoys_loader = DataLoader(P01116_decoys_dataset, batch_size=1, shuffle=False)\n", 439 | "predictions_decoys, _ = get_predictions(model, P01116_decoys_loader)\n", 440 | "\n", 441 | "lists.append([predictions_binders, predictions_decoys])\n", 442 | "\n", 443 | "######################\n", 444 | "\n", 445 | "P01116_binders_input_1 = [s.squeeze() for s in P01116_binders_out_1['s_pre_struct']]\n", 446 | "P01116_binders_dataset = BindingAffinityData(P01116_binders_input_1, P01116_binders_ordered, binding_affinity_dict)\n", 447 | "P01116_binders_loader = DataLoader(P01116_binders_dataset, batch_size=1, shuffle=False)\n", 448 | "predictions_binders, _ = get_predictions(model, P01116_binders_loader)\n", 449 | "\n", 450 | "P01116_decoys_input_1 = [s.squeeze() for s in P01116_decoys_1['s_pre_struct']]\n", 451 | "P01116_decoys_dataset = BindingAffinityData(P01116_decoys_input_1, 
P01116_decoys_1['names'], binding_affinity_dict)\n", 452 | "P01116_decoys_loader = DataLoader(P01116_decoys_dataset, batch_size=1, shuffle=False)\n", 453 | "predictions_decoys, _ = get_predictions(model, P01116_decoys_loader)\n", 454 | "\n", 455 | "lists_1.append([predictions_binders, predictions_decoys])\n", 456 | "\n", 457 | "######################\n", 458 | "\n", 459 | "P01116_binders_input_2 = [s.squeeze() for s in P01116_binders_out_2['s_pre_struct']]\n", 460 | "P01116_binders_dataset = BindingAffinityData(P01116_binders_input_2, P01116_binders_ordered, binding_affinity_dict)\n", 461 | "P01116_binders_loader = DataLoader(P01116_binders_dataset, batch_size=1, shuffle=False)\n", 462 | "predictions_binders, _ = get_predictions(model, P01116_binders_loader)\n", 463 | "\n", 464 | "P01116_decoys_input_2 = [s.squeeze() for s in P01116_decoys_2['s_pre_struct']]\n", 465 | "P01116_decoys_dataset = BindingAffinityData(P01116_decoys_input_2, P01116_decoys_2['names'], binding_affinity_dict)\n", 466 | "P01116_decoys_loader = DataLoader(P01116_decoys_dataset, batch_size=1, shuffle=False)\n", 467 | "predictions_decoys, _ = get_predictions(model, P01116_decoys_loader)\n", 468 | "\n", 469 | "lists_2.append([predictions_binders, predictions_decoys])" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "metadata": {}, 476 | "outputs": [], 477 | "source": [ 478 | "fig, axs = plt.subplots(1, 5, figsize=(10, 5))\n", 479 | "sns.set(style=\"ticks\")\n", 480 | "names = ['B1MDI3', 'P56817', 'P17931', 'Q8ULI9', 'P01116']\n", 481 | "\n", 482 | "for i, set_of_lists in enumerate(lists):\n", 483 | " name = names[i]\n", 484 | " statistic, p_value = ranksums(set_of_lists[0], set_of_lists[1], alternative='greater')\n", 485 | "\n", 486 | " sns.boxplot(data=[set_of_lists[0], set_of_lists[1]], ax=axs[i])\n", 487 | " axs[i].set_xticks([0, 1])\n", 488 | " axs[i].set_xticklabels(['Binders', 'Decoys'])\n", 489 | " axs[i].set_title(f'{name}\\np-value: 
{p_value:.4f}')\n", 490 | "\n", 491 | "axs[0].set_ylabel('Predicted Binding Affinity')\n", 492 | "plt.tight_layout()\n", 493 | "plt.savefig('binding_affinity_screen.jpg', bbox_inches='tight', format='jpg')\n", 494 | "plt.show()" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": null, 500 | "metadata": {}, 501 | "outputs": [], 502 | "source": [ 503 | "fig, axs = plt.subplots(1, 5, figsize=(10, 5))\n", 504 | "sns.set(style=\"ticks\")\n", 505 | "names = ['B1MDI3', 'P56817', 'P17931', 'Q8ULI9', 'P01116']\n", 506 | "\n", 507 | "for i, set_of_lists in enumerate(lists_1):\n", 508 | " name = names[i]\n", 509 | " statistic, p_value = ranksums(set_of_lists[0], set_of_lists[1], alternative='greater')\n", 510 | "\n", 511 | " sns.boxplot(data=[set_of_lists[0], set_of_lists[1]], ax=axs[i])\n", 512 | " axs[i].set_xticks([0, 1])\n", 513 | " axs[i].set_xticklabels(['Binders', 'Decoys'])\n", 514 | " axs[i].set_title(f'{name}\\np-value: {p_value:.4f}')\n", 515 | "\n", 516 | "axs[0].set_ylabel('Predicted Binding Affinity')\n", 517 | "plt.tight_layout()\n", 518 | "plt.savefig('binding_affinity_screen_1.jpg', bbox_inches='tight', format='jpg')\n", 519 | "plt.show()" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": null, 525 | "metadata": {}, 526 | "outputs": [], 527 | "source": [ 528 | "fig, axs = plt.subplots(1, 5, figsize=(10, 5))\n", 529 | "sns.set(style=\"ticks\")\n", 530 | "names = ['B1MDI3', 'P56817', 'P17931', 'Q8ULI9', 'P01116']\n", 531 | "\n", 532 | "for i, set_of_lists in enumerate(lists_2):\n", 533 | " name = names[i]\n", 534 | " statistic, p_value = ranksums(set_of_lists[0], set_of_lists[1], alternative='greater')\n", 535 | "\n", 536 | " sns.boxplot(data=[set_of_lists[0], set_of_lists[1]], ax=axs[i])\n", 537 | " axs[i].set_xticks([0, 1])\n", 538 | " axs[i].set_xticklabels(['Binders', 'Decoys'])\n", 539 | " axs[i].set_title(f'{name}\\np-value: {p_value:.4f}')\n", 540 | "\n", 541 | "axs[0].set_ylabel('Predicted Binding 
Affinity')\n", 542 | "plt.tight_layout()\n", 543 | "plt.savefig('binding_affinity_screen_2.jpg', bbox_inches='tight', format='jpg')\n", 544 | "plt.show()" 545 | ] 546 | }, 547 | { 548 | "cell_type": "markdown", 549 | "metadata": {}, 550 | "source": [ 551 | "# Cross-docking" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "execution_count": 1, 557 | "metadata": {}, 558 | "outputs": [], 559 | "source": [ 560 | "import torch\n", 561 | "from tqdm import tqdm\n", 562 | "from commons.utils import read_molecule, reorder_atoms, get_symmetry_rmsd, rigid_transform_Kabsch_3D\n", 563 | "from copy import deepcopy\n", 564 | "from rdkit.Chem import RemoveHs\n", 565 | "from rdkit.Geometry import Point3D\n", 566 | "from inference import print_results\n", 567 | "import os\n", 568 | "import numpy as np\n", 569 | "from Bio.Align import PairwiseAligner " 570 | ] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "execution_count": 2, 575 | "metadata": {}, 576 | "outputs": [], 577 | "source": [ 578 | "# functions to find indices to remove from sequences\n", 579 | "def find_extra_indices(reference, input_str):\n", 580 | " extra_indices = []\n", 581 | " i_ref = 0\n", 582 | " i_input = 0\n", 583 | "\n", 584 | " while i_ref < len(reference) and i_input < len(input_str):\n", 585 | " if reference[i_ref] == input_str[i_input]:\n", 586 | " i_ref += 1\n", 587 | " i_input += 1\n", 588 | " else:\n", 589 | " extra_indices.append(i_input)\n", 590 | " i_input += 1\n", 591 | "\n", 592 | " # If there are extra characters at the end of input_str\n", 593 | " while i_input < len(input_str):\n", 594 | " extra_indices.append(i_input)\n", 595 | " i_input += 1\n", 596 | "\n", 597 | " return extra_indices\n", 598 | "\n", 599 | "def find_indices_to_remove(seq1, seq2):\n", 600 | " # Perform sequence alignment\n", 601 | " aligner = PairwiseAligner()\n", 602 | " alignments = aligner.align(seqA=seq1, seqB=seq2)\n", 603 | "\n", 604 | " # Get the aligned sequences\n", 605 | " aligned_seq1 = 
alignments[0][0]\n", 606 | " aligned_seq2 = alignments[0][1]\n", 607 | " \n", 608 | " # Find indices of differing amino acids\n", 609 | " differing_indices = [i for i, (a1, a2) in enumerate(zip(aligned_seq1, aligned_seq2)) if a1 != a2]\n", 610 | "\n", 611 | " aligned_seq1_clean = ''.join([s for i,s in enumerate(aligned_seq1) if i not in differing_indices])\n", 612 | " aligned_seq2_clean = ''.join([s for i,s in enumerate(aligned_seq2) if i not in differing_indices])\n", 613 | "\n", 614 | " assert len(aligned_seq1_clean) <= len(seq1)\n", 615 | " assert len(aligned_seq2_clean) <= len(seq2)\n", 616 | "\n", 617 | " indices_to_remove_seq1 = find_extra_indices(aligned_seq1_clean, seq1)\n", 618 | " indices_to_remove_seq2 = find_extra_indices(aligned_seq2_clean, seq2)\n", 619 | "\n", 620 | " assert len(seq1) - len(indices_to_remove_seq1) == len(seq2) - len(indices_to_remove_seq2)\n", 621 | " \n", 622 | " return indices_to_remove_seq1, indices_to_remove_seq2" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 3, 628 | "metadata": {}, 629 | "outputs": [], 630 | "source": [ 631 | "test_proteins = torch.load('data/processed/timesplit_test/rec_input_proc_ids.pt')\n", 632 | "test_protein_names = [p['complex_names'] for p in test_proteins]\n", 633 | "test_c_alphas = [p['c_alpha_coords'] for p in test_proteins]\n", 634 | "test_sequences = [p['sequence'] for p in test_proteins]\n", 635 | "outputs = torch.load('checkpoints/quickbind_default/predictions-w-single-rep.pt')\n", 636 | "\n", 637 | "coms = []\n", 638 | "for rec in test_proteins:\n", 639 | " c_alpha_coords = rec['c_alpha_coords']\n", 640 | " if c_alpha_coords.shape[0] > 2000: # inference is currently limited to complexes with less than 2000 residues\n", 641 | " continue\n", 642 | " c = torch.mean(c_alpha_coords, dim=0)\n", 643 | " coms.append(c)" 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": 4, 649 | "metadata": {}, 650 | "outputs": [], 651 | "source": [ 652 | "def 
evaluate_crossdocking(binders_list, uid, ref_name, idx=None):\n", 653 | " binders_out = torch.load(\n", 654 | " f'virtual_screening/true_{uid}_predictions{f\"_{idx}\" if idx else \"\"}-w-single-rep.pt'\n", 655 | " )\n", 656 | " ref_index = test_protein_names.index(ref_name)\n", 657 | " binders_ordered = [n for n in test_protein_names if n in binders_list]\n", 658 | " true_c_alphas_ordered = [c for c, n in zip(test_c_alphas, test_protein_names) if n in binders_list]\n", 659 | " true_sequences = [s for s, n in zip(test_sequences, test_protein_names) if n in binders_list]\n", 660 | " true_coms = [c for c, n in zip(coms, test_protein_names) if n in binders_list]\n", 661 | " input_c_alphas = test_c_alphas[ref_index]\n", 662 | " input_sequence = test_sequences[ref_index]\n", 663 | " input_com = coms[ref_index]\n", 664 | " targets = [t for t, n in zip(outputs['targets'], outputs['names']) if n in binders_ordered]\n", 665 | "\n", 666 | " kabsch_rmsds, rmsds, centroid_distances, protein_rmsds = [], [], [], []\n", 667 | " for prediction, target, name, sequence, true_c_alphas, com in tqdm(zip(\n", 668 | " binders_out['predictions'], targets, binders_ordered, true_sequences, true_c_alphas_ordered, true_coms),\n", 669 | " desc='Evaluating model predictions', total = len(binders_out['predictions'])\n", 670 | " ):\n", 671 | " if name == ref_name:\n", 672 | " continue\n", 673 | " assert prediction.shape == target.shape\n", 674 | " lig = read_molecule(os.path.join('data/PDBBind/', name, f'{name}_ligand.mol2'), remove_hs=True)\n", 675 | " if lig is None:\n", 676 | " lig = read_molecule(os.path.join('data/PDBBind/', name, f'{name}_ligand.sdf'), remove_hs=True)\n", 677 | " lig = RemoveHs(lig)\n", 678 | " lig = reorder_atoms(lig)\n", 679 | " \n", 680 | " lig_pred = deepcopy(lig)\n", 681 | " conf = lig_pred.GetConformer()\n", 682 | " prediction = prediction.squeeze().cpu().numpy()\n", 683 | " for i in range(lig_pred.GetNumAtoms()):\n", 684 | " x, y, z = prediction[i]\n", 685 | " 
conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))\n", 686 | " coords_pred = lig_pred.GetConformer().GetPositions()\n", 687 | " \n", 688 | " # transform prediction based on alignment of protein C alpha atoms\n", 689 | " indices_input, indices_true = find_indices_to_remove(input_sequence, sequence)\n", 690 | " keep_input = torch.tensor([i for i in range(len(input_sequence)) if i not in indices_input])\n", 691 | " keep_true = torch.tensor([i for i in range(len(sequence)) if i not in indices_true])\n", 692 | " updated_input = torch.index_select(input_c_alphas, 0, keep_input).numpy() - input_com.numpy()\n", 693 | " updated_true = torch.index_select(true_c_alphas, 0, keep_true).numpy() - com.numpy()\n", 694 | " R, t = rigid_transform_Kabsch_3D(updated_input.T, updated_true.T)\n", 695 | " \n", 696 | " coords_pred = (R @ (coords_pred).T).T + t.squeeze()\n", 697 | " transformed_input = (R @ (updated_input).T).T + t.squeeze()\n", 698 | " protein_rmsds.append(np.sqrt((((transformed_input - updated_true)** 2)*3).mean()))\n", 699 | " \n", 700 | " lig_true = deepcopy(lig)\n", 701 | " conf_true = lig_true.GetConformer()\n", 702 | " for i in range(lig_true.GetNumAtoms()):\n", 703 | " x, y, z = target.squeeze()[i]\n", 704 | " conf_true.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))\n", 705 | " coords_native = lig_true.GetConformer().GetPositions()\n", 706 | " \n", 707 | " try:\n", 708 | " rmsd = get_symmetry_rmsd(lig_true, coords_native, coords_pred, lig_pred)\n", 709 | " except Exception as e:\n", 710 | " print(\"Using non corrected RMSD because of the error:\", e)\n", 711 | " rmsd = np.sqrt(np.sum((coords_pred - coords_native) ** 2, axis=1).mean())\n", 712 | " centroid_distance = np.linalg.norm(coords_native.mean(axis=0) - coords_pred.mean(axis=0))\n", 713 | " R, t = rigid_transform_Kabsch_3D(coords_pred.T, coords_native.T)\n", 714 | " moved_coords = (R @ (coords_pred).T).T + t.squeeze()\n", 715 | " kabsch_rmsd = np.sqrt(np.sum((moved_coords - 
coords_native) ** 2, axis=1).mean())\n", 716 | " kabsch_rmsds.append(kabsch_rmsd)\n", 717 | " rmsds.append(rmsd)\n", 718 | " centroid_distances.append(centroid_distance)\n", 719 | " \n", 720 | " print(uid)\n", 721 | " kabsch_rmsds = np.array(kabsch_rmsds)\n", 722 | " print_results(\n", 723 | " rmsds, centroid_distances, False\n", 724 | " )\n", 725 | " print(f'Mean Kabsch RMSD: {kabsch_rmsds.mean().__round__(2)} +- {kabsch_rmsds.std().__round__(2)}')\n", 726 | " print(f'Median Kabsch RMSD: {np.median(kabsch_rmsds).__round__(2)} +- {kabsch_rmsds.std().__round__(2)}')\n", 727 | " print('Kabsch RMSD percentiles: ', np.percentile(kabsch_rmsds, [25, 50, 75]).round(2))\n", 728 | " print('Average protein RMSD', np.mean(protein_rmsds))" 729 | ] 730 | }, 731 | { 732 | "cell_type": "code", 733 | "execution_count": 5, 734 | "metadata": {}, 735 | "outputs": [ 736 | { 737 | "name": "stderr", 738 | "output_type": "stream", 739 | "text": [ 740 | "Evaluating model predictions: 100%|██████████| 19/19 [00:00<00:00, 140.75it/s]" 741 | ] 742 | }, 743 | { 744 | "name": "stdout", 745 | "output_type": "stream", 746 | "text": [ 747 | "B1MDI3\n", 748 | "----------------------------------------------------------------------------------------------------\n", 749 | "| Test statistics (excl. hydrogen atoms) |\n", 750 | "----------------------------------------------------------------------------------------------------\n", 751 | "Mean RMSD: 7.36 +- 2.02\n", 752 | "RMSD percentiles: [6. 
6.75 8.64]\n", 753 | "% RMSD below 2: 0.0%\n", 754 | "% RMSD below 5: 5.56%\n", 755 | "Mean centroid distance: 5.21 +- 2.67\n", 756 | "Centroid percentiles: [3.53 4.65 6.2 ]\n", 757 | "% centroid distances below 2: 11.11%\n", 758 | "% centroid distances below 5: 55.56%\n", 759 | "Mean Kabsch RMSD: 1.56 +- 0.34\n", 760 | "Median Kabsch RMSD: 1.5 +- 0.34\n", 761 | "Kabsch RMSD percentiles: [1.33 1.5 1.79]\n", 762 | "Average protein RMSD 0.2936543735706902\n" 763 | ] 764 | }, 765 | { 766 | "name": "stderr", 767 | "output_type": "stream", 768 | "text": [ 769 | "\n" 770 | ] 771 | } 772 | ], 773 | "source": [ 774 | "B1MDI3_binders = [\n", 775 | " '6qqt', '6qrf', '6qqw', '6qre', '6qrc', '6qqv', '6qrg', '6qr1', '6qqq', '6qqu', '6qr2', '6qra', '6qr4', '6qr3',\n", 776 | " '6qr9', '6qqz', '6qr0', '6qrd', '6qr7'\n", 777 | "]\n", 778 | "evaluate_crossdocking(B1MDI3_binders, 'B1MDI3', '6qqt')" 779 | ] 780 | }, 781 | { 782 | "cell_type": "code", 783 | "execution_count": 6, 784 | "metadata": {}, 785 | "outputs": [ 786 | { 787 | "name": "stderr", 788 | "output_type": "stream", 789 | "text": [ 790 | "Evaluating model predictions: 100%|██████████| 19/19 [00:00<00:00, 137.22it/s]" 791 | ] 792 | }, 793 | { 794 | "name": "stdout", 795 | "output_type": "stream", 796 | "text": [ 797 | "B1MDI3\n", 798 | "----------------------------------------------------------------------------------------------------\n", 799 | "| Test statistics (excl. 
hydrogen atoms) |\n", 800 | "----------------------------------------------------------------------------------------------------\n", 801 | "Mean RMSD: 11.35 +- 2.4\n", 802 | "RMSD percentiles: [10.41 11.99 13.06]\n", 803 | "% RMSD below 2: 0.0%\n", 804 | "% RMSD below 5: 0.0%\n", 805 | "Mean centroid distance: 10.35 +- 3.06\n", 806 | "Centroid percentiles: [ 9.13 11.44 12.74]\n", 807 | "% centroid distances below 2: 0.0%\n", 808 | "% centroid distances below 5: 5.56%\n", 809 | "Mean Kabsch RMSD: 1.63 +- 0.41\n", 810 | "Median Kabsch RMSD: 1.5 +- 0.41\n", 811 | "Kabsch RMSD percentiles: [1.43 1.5 1.99]\n", 812 | "Average protein RMSD 0.4883690880740943\n" 813 | ] 814 | }, 815 | { 816 | "name": "stderr", 817 | "output_type": "stream", 818 | "text": [ 819 | "\n" 820 | ] 821 | } 822 | ], 823 | "source": [ 824 | "B1MDI3_binders = [\n", 825 | " '6qqt', '6qrf', '6qqw', '6qre', '6qrc', '6qqv', '6qrg', '6qr1', '6qqq', '6qqu', '6qr2', '6qra', '6qr4', '6qr3',\n", 826 | " '6qr9', '6qqz', '6qr0', '6qrd', '6qr7'\n", 827 | "]\n", 828 | "evaluate_crossdocking(B1MDI3_binders, 'B1MDI3', '6qrf', idx=1)" 829 | ] 830 | }, 831 | { 832 | "cell_type": "code", 833 | "execution_count": 7, 834 | "metadata": {}, 835 | "outputs": [ 836 | { 837 | "name": "stderr", 838 | "output_type": "stream", 839 | "text": [ 840 | "Evaluating model predictions: 100%|██████████| 19/19 [00:00<00:00, 97.92it/s]" 841 | ] 842 | }, 843 | { 844 | "name": "stdout", 845 | "output_type": "stream", 846 | "text": [ 847 | "B1MDI3\n", 848 | "----------------------------------------------------------------------------------------------------\n", 849 | "| Test statistics (excl. 
hydrogen atoms) |\n", 850 | "----------------------------------------------------------------------------------------------------\n", 851 | "Mean RMSD: 9.96 +- 2.95\n", 852 | "RMSD percentiles: [ 7.23 9.89 12.89]\n", 853 | "% RMSD below 2: 0.0%\n", 854 | "% RMSD below 5: 0.0%\n", 855 | "Mean centroid distance: 8.4 +- 3.66\n", 856 | "Centroid percentiles: [ 5.07 8.4 12.03]\n", 857 | "% centroid distances below 2: 0.0%\n", 858 | "% centroid distances below 5: 27.78%\n", 859 | "Mean Kabsch RMSD: 1.61 +- 0.38\n", 860 | "Median Kabsch RMSD: 1.6 +- 0.38\n", 861 | "Kabsch RMSD percentiles: [1.35 1.6 1.87]\n", 862 | "Average protein RMSD 0.3002542470502205\n" 863 | ] 864 | }, 865 | { 866 | "name": "stderr", 867 | "output_type": "stream", 868 | "text": [ 869 | "\n" 870 | ] 871 | } 872 | ], 873 | "source": [ 874 | "B1MDI3_binders = [\n", 875 | " '6qqt', '6qrf', '6qqw', '6qre', '6qrc', '6qqv', '6qrg', '6qr1', '6qqq', '6qqu', '6qr2', '6qra', '6qr4', '6qr3',\n", 876 | " '6qr9', '6qqz', '6qr0', '6qrd', '6qr7'\n", 877 | "]\n", 878 | "evaluate_crossdocking(B1MDI3_binders, 'B1MDI3', '6qqw', idx=2)" 879 | ] 880 | }, 881 | { 882 | "cell_type": "code", 883 | "execution_count": 8, 884 | "metadata": {}, 885 | "outputs": [ 886 | { 887 | "name": "stderr", 888 | "output_type": "stream", 889 | "text": [ 890 | "Evaluating model predictions: 100%|██████████| 16/16 [00:00<00:00, 83.94it/s]" 891 | ] 892 | }, 893 | { 894 | "name": "stdout", 895 | "output_type": "stream", 896 | "text": [ 897 | "P56817\n", 898 | "----------------------------------------------------------------------------------------------------\n", 899 | "| Test statistics (excl. 
hydrogen atoms) |\n", 900 | "----------------------------------------------------------------------------------------------------\n", 901 | "Mean RMSD: 2.13 +- 1.57\n", 902 | "RMSD percentiles: [1.5 1.58 1.98]\n", 903 | "% RMSD below 2: 73.33%\n", 904 | "% RMSD below 5: 93.33%\n", 905 | "Mean centroid distance: 1.08 +- 0.51\n", 906 | "Centroid percentiles: [0.69 0.91 1.22]\n", 907 | "% centroid distances below 2: 93.33%\n", 908 | "% centroid distances below 5: 100.0%\n", 909 | "Mean Kabsch RMSD: 1.12 +- 0.39\n", 910 | "Median Kabsch RMSD: 1.19 +- 0.39\n", 911 | "Kabsch RMSD percentiles: [0.9 1.19 1.35]\n", 912 | "Average protein RMSD 1.0027781721654772\n" 913 | ] 914 | }, 915 | { 916 | "name": "stderr", 917 | "output_type": "stream", 918 | "text": [ 919 | "\n" 920 | ] 921 | } 922 | ], 923 | "source": [ 924 | "P56817_binders = [\n", 925 | " '6uvp', '6uvv', '6uvy', '6uwp', '6nw3', '6e3z', '6nv7', '6uwv', '6nv9', '6od6', '6jt3', '6jsg',\n", 926 | " '6jsn', '6jsf', '6jse', '6pz4'\n", 927 | "]\n", 928 | "evaluate_crossdocking(P56817_binders, 'P56817', '6uvp')" 929 | ] 930 | }, 931 | { 932 | "cell_type": "code", 933 | "execution_count": 9, 934 | "metadata": {}, 935 | "outputs": [ 936 | { 937 | "name": "stderr", 938 | "output_type": "stream", 939 | "text": [ 940 | "Evaluating model predictions: 100%|██████████| 16/16 [00:00<00:00, 85.16it/s]" 941 | ] 942 | }, 943 | { 944 | "name": "stdout", 945 | "output_type": "stream", 946 | "text": [ 947 | "P56817\n", 948 | "----------------------------------------------------------------------------------------------------\n", 949 | "| Test statistics (excl. 
hydrogen atoms) |\n", 950 | "----------------------------------------------------------------------------------------------------\n", 951 | "Mean RMSD: 2.09 +- 1.52\n", 952 | "RMSD percentiles: [1.41 1.52 2.06]\n", 953 | "% RMSD below 2: 73.33%\n", 954 | "% RMSD below 5: 93.33%\n", 955 | "Mean centroid distance: 0.93 +- 0.37\n", 956 | "Centroid percentiles: [0.72 0.85 0.98]\n", 957 | "% centroid distances below 2: 100.0%\n", 958 | "% centroid distances below 5: 100.0%\n", 959 | "Mean Kabsch RMSD: 1.18 +- 0.4\n", 960 | "Median Kabsch RMSD: 1.17 +- 0.4\n", 961 | "Kabsch RMSD percentiles: [0.92 1.17 1.44]\n", 962 | "Average protein RMSD 0.9673100870998292\n" 963 | ] 964 | }, 965 | { 966 | "name": "stderr", 967 | "output_type": "stream", 968 | "text": [ 969 | "\n" 970 | ] 971 | } 972 | ], 973 | "source": [ 974 | "P56817_binders = [\n", 975 | " '6uvp', '6uvv', '6uvy', '6uwp', '6nw3', '6e3z', '6nv7', '6uwv', '6nv9', '6od6', '6jt3', '6jsg',\n", 976 | " '6jsn', '6jsf', '6jse', '6pz4'\n", 977 | "]\n", 978 | "evaluate_crossdocking(P56817_binders, 'P56817', '6uvv', idx=1)" 979 | ] 980 | }, 981 | { 982 | "cell_type": "code", 983 | "execution_count": 10, 984 | "metadata": {}, 985 | "outputs": [ 986 | { 987 | "name": "stderr", 988 | "output_type": "stream", 989 | "text": [ 990 | "Evaluating model predictions: 100%|██████████| 16/16 [00:00<00:00, 95.69it/s] " 991 | ] 992 | }, 993 | { 994 | "name": "stdout", 995 | "output_type": "stream", 996 | "text": [ 997 | "P56817\n", 998 | "----------------------------------------------------------------------------------------------------\n", 999 | "| Test statistics (excl. 
hydrogen atoms) |\n", 1000 | "----------------------------------------------------------------------------------------------------\n", 1001 | "Mean RMSD: 2.15 +- 1.64\n", 1002 | "RMSD percentiles: [1.44 1.54 2.09]\n", 1003 | "% RMSD below 2: 73.33%\n", 1004 | "% RMSD below 5: 93.33%\n", 1005 | "Mean centroid distance: 0.9 +- 0.41\n", 1006 | "Centroid percentiles: [0.6 0.83 0.94]\n", 1007 | "% centroid distances below 2: 100.0%\n", 1008 | "% centroid distances below 5: 100.0%\n", 1009 | "Mean Kabsch RMSD: 1.17 +- 0.41\n", 1010 | "Median Kabsch RMSD: 1.18 +- 0.41\n", 1011 | "Kabsch RMSD percentiles: [0.89 1.18 1.4 ]\n", 1012 | "Average protein RMSD 0.9086242446514984\n" 1013 | ] 1014 | }, 1015 | { 1016 | "name": "stderr", 1017 | "output_type": "stream", 1018 | "text": [ 1019 | "\n" 1020 | ] 1021 | } 1022 | ], 1023 | "source": [ 1024 | "P56817_binders = [\n", 1025 | " '6uvp', '6uvv', '6uvy', '6uwp', '6nw3', '6e3z', '6nv7', '6uwv', '6nv9', '6od6', '6jt3', '6jsg',\n", 1026 | " '6jsn', '6jsf', '6jse', '6pz4'\n", 1027 | "]\n", 1028 | "evaluate_crossdocking(P56817_binders, 'P56817', '6uvy', idx=2)" 1029 | ] 1030 | }, 1031 | { 1032 | "cell_type": "code", 1033 | "execution_count": 11, 1034 | "metadata": {}, 1035 | "outputs": [ 1036 | { 1037 | "name": "stderr", 1038 | "output_type": "stream", 1039 | "text": [ 1040 | "Evaluating model predictions: 100%|██████████| 15/15 [00:00<00:00, 89.94it/s]" 1041 | ] 1042 | }, 1043 | { 1044 | "name": "stdout", 1045 | "output_type": "stream", 1046 | "text": [ 1047 | "P17931\n", 1048 | "----------------------------------------------------------------------------------------------------\n", 1049 | "| Test statistics (excl. 
hydrogen atoms) |\n", 1050 | "----------------------------------------------------------------------------------------------------\n", 1051 | "Mean RMSD: 2.19 +- 2.47\n", 1052 | "RMSD percentiles: [1.33 1.46 1.72]\n", 1053 | "% RMSD below 2: 85.71%\n", 1054 | "% RMSD below 5: 92.86%\n", 1055 | "Mean centroid distance: 0.88 +- 0.32\n", 1056 | "Centroid percentiles: [0.64 0.82 0.99]\n", 1057 | "% centroid distances below 2: 100.0%\n", 1058 | "% centroid distances below 5: 100.0%\n", 1059 | "Mean Kabsch RMSD: 1.22 +- 0.47\n", 1060 | "Median Kabsch RMSD: 1.04 +- 0.47\n", 1061 | "Kabsch RMSD percentiles: [1. 1.04 1.22]\n", 1062 | "Average protein RMSD 0.16096047345873035\n" 1063 | ] 1064 | }, 1065 | { 1066 | "name": "stderr", 1067 | "output_type": "stream", 1068 | "text": [ 1069 | "\n" 1070 | ] 1071 | } 1072 | ], 1073 | "source": [ 1074 | "P17931_binders = [\n", 1075 | " '6qlt', '6qlq', '6i75', '6qlu', '6qln', '6i78', '6qlr', '6i77', '6qlo', '6i76', '6qlp', '6qls',\n", 1076 | " '6i74', '6qge', '6qgf'\n", 1077 | "]\n", 1078 | "evaluate_crossdocking(P17931_binders, 'P17931', '6qlt')" 1079 | ] 1080 | }, 1081 | { 1082 | "cell_type": "code", 1083 | "execution_count": 12, 1084 | "metadata": {}, 1085 | "outputs": [ 1086 | { 1087 | "name": "stderr", 1088 | "output_type": "stream", 1089 | "text": [ 1090 | "Evaluating model predictions: 100%|██████████| 15/15 [00:00<00:00, 91.47it/s]" 1091 | ] 1092 | }, 1093 | { 1094 | "name": "stdout", 1095 | "output_type": "stream", 1096 | "text": [ 1097 | "P17931\n", 1098 | "----------------------------------------------------------------------------------------------------\n", 1099 | "| Test statistics (excl. 
hydrogen atoms) |\n", 1100 | "----------------------------------------------------------------------------------------------------\n", 1101 | "Mean RMSD: 2.25 +- 2.44\n", 1102 | "RMSD percentiles: [1.44 1.53 1.73]\n", 1103 | "% RMSD below 2: 85.71%\n", 1104 | "% RMSD below 5: 92.86%\n", 1105 | "Mean centroid distance: 0.96 +- 0.25\n", 1106 | "Centroid percentiles: [0.76 0.94 1.07]\n", 1107 | "% centroid distances below 2: 100.0%\n", 1108 | "% centroid distances below 5: 100.0%\n", 1109 | "Mean Kabsch RMSD: 1.22 +- 0.47\n", 1110 | "Median Kabsch RMSD: 1.05 +- 0.47\n", 1111 | "Kabsch RMSD percentiles: [1. 1.05 1.21]\n", 1112 | "Average protein RMSD 0.166106119457376\n" 1113 | ] 1114 | }, 1115 | { 1116 | "name": "stderr", 1117 | "output_type": "stream", 1118 | "text": [ 1119 | "\n" 1120 | ] 1121 | } 1122 | ], 1123 | "source": [ 1124 | "P17931_binders = [\n", 1125 | " '6qlt', '6qlq', '6i75', '6qlu', '6qln', '6i78', '6qlr', '6i77', '6qlo', '6i76', '6qlp', '6qls',\n", 1126 | " '6i74', '6qge', '6qgf'\n", 1127 | "]\n", 1128 | "evaluate_crossdocking(P17931_binders, 'P17931', '6qlq', idx=1)" 1129 | ] 1130 | }, 1131 | { 1132 | "cell_type": "code", 1133 | "execution_count": 13, 1134 | "metadata": {}, 1135 | "outputs": [ 1136 | { 1137 | "name": "stderr", 1138 | "output_type": "stream", 1139 | "text": [ 1140 | "Evaluating model predictions: 100%|██████████| 15/15 [00:00<00:00, 94.57it/s]" 1141 | ] 1142 | }, 1143 | { 1144 | "name": "stdout", 1145 | "output_type": "stream", 1146 | "text": [ 1147 | "P17931\n", 1148 | "----------------------------------------------------------------------------------------------------\n", 1149 | "| Test statistics (excl. 
hydrogen atoms) |\n", 1150 | "----------------------------------------------------------------------------------------------------\n", 1151 | "Mean RMSD: 2.24 +- 2.46\n", 1152 | "RMSD percentiles: [1.4 1.5 1.75]\n", 1153 | "% RMSD below 2: 85.71%\n", 1154 | "% RMSD below 5: 92.86%\n", 1155 | "Mean centroid distance: 0.89 +- 0.29\n", 1156 | "Centroid percentiles: [0.71 0.86 1.02]\n", 1157 | "% centroid distances below 2: 100.0%\n", 1158 | "% centroid distances below 5: 100.0%\n", 1159 | "Mean Kabsch RMSD: 1.2 +- 0.48\n", 1160 | "Median Kabsch RMSD: 1.03 +- 0.48\n", 1161 | "Kabsch RMSD percentiles: [0.99 1.03 1.17]\n", 1162 | "Average protein RMSD 0.31159448404717477\n" 1163 | ] 1164 | }, 1165 | { 1166 | "name": "stderr", 1167 | "output_type": "stream", 1168 | "text": [ 1169 | "\n" 1170 | ] 1171 | } 1172 | ], 1173 | "source": [ 1174 | "P17931_binders = [\n", 1175 | " '6qlt', '6qlq', '6i75', '6qlu', '6qln', '6i78', '6qlr', '6i77', '6qlo', '6i76', '6qlp', '6qls',\n", 1176 | " '6i74', '6qge', '6qgf'\n", 1177 | "]\n", 1178 | "evaluate_crossdocking(P17931_binders, 'P17931', '6i75', idx=2)" 1179 | ] 1180 | }, 1181 | { 1182 | "cell_type": "code", 1183 | "execution_count": 14, 1184 | "metadata": {}, 1185 | "outputs": [ 1186 | { 1187 | "name": "stderr", 1188 | "output_type": "stream", 1189 | "text": [ 1190 | "Evaluating model predictions: 0%| | 0/14 [00:00