├── DiffBindFR ├── __init__.py ├── app │ └── predict.py ├── common │ ├── __init__.py │ ├── args.py │ ├── dataframe.py │ ├── engines.py │ └── inference_dataset.py ├── configs │ └── diffbindfr_ts.py ├── evaluation │ ├── __init__.py │ ├── eval.py │ ├── export.py │ ├── file_utils.py │ ├── pb.py │ └── reporter.py ├── metrics │ ├── __init__.py │ ├── angbin.py │ ├── centroid.py │ ├── lrmsd.py │ ├── rdmol.py │ └── scrmsd.py ├── relax │ └── pl.py ├── scoring │ ├── __init__.py │ ├── architecture │ │ ├── Angle_ResNet.py │ │ ├── EGNN_Block.py │ │ ├── GVP_Block.py │ │ ├── Gate_Block.py │ │ ├── GraphTransformer_Block.py │ │ ├── KarmaDock_sc.py │ │ └── MDN_Block.py │ ├── dataset │ │ ├── dataloader.py │ │ ├── inference.py │ │ ├── ligand_feature.py │ │ ├── lmdbdataset.py │ │ ├── pipeline.py │ │ └── protein_feature.py │ └── utils │ │ └── early_stop.py └── utils │ ├── __init__.py │ ├── apo_holo.py │ ├── blast.py │ ├── io.py │ ├── logger.py │ ├── pocket.py │ ├── uniprot.py │ └── vinafr_remodel.py ├── INSTALL_OPENFF.sh ├── LICENSE ├── README.md ├── druglib ├── LICENSE ├── __init__.py ├── alerts │ ├── __init__.py │ ├── check.py │ ├── errors.py │ └── molerror.py ├── apis │ ├── __init__.py │ └── nn │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── norm.py │ │ └── utils │ │ ├── __init__.py │ │ └── weight_init.py ├── core │ ├── __init__.py │ ├── runner │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── base_runner.py │ │ ├── builder.py │ │ ├── checkpoint.py │ │ ├── default_RunnerBuilder.py │ │ ├── dist_utils.py │ │ ├── engine │ │ │ ├── __init__.py │ │ │ ├── defaults.py │ │ │ └── test_utils.py │ │ ├── epoch_based_runner.py │ │ ├── fp16_utils.py │ │ ├── hooks │ │ │ ├── __init__.py │ │ │ ├── checkpoint.py │ │ │ ├── closure.py │ │ │ ├── ema.py │ │ │ ├── evaluation.py │ │ │ ├── hook.py │ │ │ ├── iter_timer.py │ │ │ ├── logger │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── clearml.py │ │ │ │ ├── dvclive.py │ │ │ │ ├── mlflow.py │ │ │ │ ├── neptune.py │ │ │ │ ├── pavi.py │ │ │ │ ├── 
segmind.py │ │ │ │ ├── tensorboard.py │ │ │ │ ├── text.py │ │ │ │ └── wandb.py │ │ │ ├── lr_updater.py │ │ │ ├── memory.py │ │ │ ├── momentum_updater.py │ │ │ ├── optimizer.py │ │ │ ├── profiler.py │ │ │ ├── sampler_seed.py │ │ │ └── sync_buffer.py │ │ ├── iter_based_runner.py │ │ ├── log_buffer.py │ │ ├── optimizer │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── default_OptBuilder.py │ │ │ └── optimizers.py │ │ ├── parallel │ │ │ ├── __init__.py │ │ │ ├── _functions.py │ │ │ ├── collate.py │ │ │ ├── data_parallel.py │ │ │ ├── distributed.py │ │ │ ├── registry.py │ │ │ ├── scatter_gather.py │ │ │ └── utils.py │ │ ├── priority.py │ │ └── utils.py │ └── trainer │ │ ├── __init__.py │ │ └── base_trainer.py ├── data │ ├── __init__.py │ ├── batch.py │ ├── collate.py │ ├── cv_store.py │ ├── data.py │ ├── data_container.py │ ├── dataloader_collate.py │ ├── feature_store.py │ ├── graph_store.py │ ├── hetero_data.py │ ├── mappingview.py │ ├── mixin.py │ ├── separate.py │ ├── storage.py │ ├── torchsparse_patcher.py │ └── typing.py ├── datasets │ ├── Docking │ │ ├── __init__.py │ │ ├── formatting.py │ │ ├── loading.py │ │ ├── mol_pipeline.py │ │ ├── pocket_pipeline.py │ │ ├── struct_init.py │ │ └── utils.py │ ├── __init__.py │ ├── base_pipelines │ │ ├── __init__.py │ │ ├── compose.py │ │ └── formatting.py │ ├── builder.py │ ├── custom_dataset.py │ ├── lmdbdataset.py │ └── samplers │ │ ├── __init__.py │ │ ├── distributed_sampler.py │ │ ├── graph_learning_sampler.py │ │ ├── grouped_batch_sampler.py │ │ └── iteration_based_sampler.py ├── models │ ├── Base │ │ ├── __init__.py │ │ └── diffusion │ │ │ ├── __init__.py │ │ │ └── time_emb.py │ ├── Docking │ │ ├── __init__.py │ │ ├── base.py │ │ ├── default_MLDockBuilder.py │ │ ├── encoder │ │ │ ├── __init__.py │ │ │ └── equibind_encoder.py │ │ ├── interaction │ │ │ ├── __init__.py │ │ │ ├── schnet.py │ │ │ └── tpscore.py │ │ └── scFlex.py │ ├── __init__.py │ ├── base_model_builder.py │ └── builder.py ├── ops │ ├── __init__.py │ 
├── dssp │ │ ├── __init__.py │ │ └── mkdssp │ ├── msms │ │ ├── __init__.py │ │ └── msms │ ├── pymol │ │ ├── geom.py │ │ └── tmalign.py │ ├── schrodinger │ │ ├── __init__.py │ │ └── align.py │ ├── smina │ │ ├── __init__.py │ │ └── smina.static │ └── utils │ │ └── which.py ├── resources │ ├── __init__.py │ ├── bond_length.txt │ └── stereo_chemical_props.txt ├── utils │ ├── __init__.py │ ├── bio_utils │ │ ├── __init__.py │ │ ├── box_utils.py │ │ ├── compute_mol_charges.py │ │ ├── conformer_utils.py │ │ ├── fix_protein.py │ │ ├── mol_attrs.py │ │ ├── nxmol.py │ │ ├── pdbqt_utils.py │ │ ├── read_mol.py │ │ ├── select_pocket.py │ │ └── visualization.py │ ├── config.py │ ├── config_utils.py │ ├── deprecation.py │ ├── file.py │ ├── geometry_utils │ │ ├── __init__.py │ │ ├── aaframe.py │ │ ├── io.py │ │ ├── so3.py │ │ ├── superimposition.py │ │ ├── torus.py │ │ └── utils.py │ ├── google_drive_download.py │ ├── handlers │ │ ├── __init__.py │ │ ├── base.py │ │ ├── json_handler.py │ │ ├── pickle_handler.py │ │ └── yaml_handler.py │ ├── hub.py │ ├── io.py │ ├── logger.py │ ├── misc.py │ ├── obj │ │ ├── __init__.py │ │ ├── complex.py │ │ ├── ligand.py │ │ ├── ligand_constants.py │ │ ├── ligand_math.py │ │ ├── prot_fn.py │ │ ├── prot_math.py │ │ ├── protein.py │ │ └── protein_constants.py │ ├── parrots_jit.py │ ├── parrots_wrapper.py │ ├── path.py │ ├── progressbar.py │ ├── registry.py │ ├── testing.py │ ├── timer.py │ ├── torch_utils │ │ ├── __init__.py │ │ ├── graph.py │ │ ├── isom_graph.py │ │ ├── msc.py │ │ └── tensor_extension.py │ ├── trace.py │ └── version_utils.py └── version.py ├── env.yaml ├── examples ├── AF2 │ ├── 2zec.pdb │ ├── Q15661_AF2.pdb │ ├── Q15661_AF2_crystal.mol2 │ ├── Q15661_AF2_crystal.sdf │ ├── ligand.mol2 │ └── ligand.sdf ├── forward │ ├── 3dbs_protein.pdb │ ├── 3dbs_protein_crystal.mol2 │ ├── 3dbs_protein_crystal.sdf │ └── mols │ │ ├── BDB12915.sdf │ │ ├── BDB35585.sdf │ │ ├── ZINC01921759.sdf │ │ ├── ZINC01963302.sdf │ │ ├── ZINC01971864.sdf │ │ ├── 
ZINC01993838.sdf │ │ ├── ZINC02029177.sdf │ │ ├── ZINC02097618.sdf │ │ ├── ZINC02113508.sdf │ │ ├── ZINC04090693.sdf │ │ ├── ZINC04104043.sdf │ │ ├── ZINC04121338.sdf │ │ ├── ZINC04159315.sdf │ │ ├── ZINC04165102.sdf │ │ └── ZINC04181650.sdf └── reverse │ ├── ligand_1.sdf │ ├── ligand_2.sdf │ └── receptors │ ├── 2src_protein.pdb │ ├── 2src_protein_crystal.mol2 │ ├── 2src_protein_crystal.sdf │ ├── 3mhw_protein.pdb │ ├── 3mhw_protein_crystal.mol2 │ ├── 3mhw_protein_crystal.sdf │ ├── 3pp0_protein.pdb │ ├── 3pp0_protein_crystal.mol2 │ └── 3pp0_protein_crystal.sdf ├── images └── arch.png ├── notebooks ├── AF2_model_docking.ipynb └── DiffBindFR_demo_colab.ipynb ├── openfold ├── LICENSE ├── __init__.py ├── config.py ├── data │ ├── __init__.py │ └── data_transforms.py ├── np │ ├── __init__.py │ ├── protein.py │ └── residue_constants.py ├── resources │ ├── __init__.py │ └── stereo_chemical_props.txt └── utils │ ├── __init__.py │ ├── rigid_utils.py │ └── tensor_utils.py ├── requirements ├── optional.txt ├── requirements.txt ├── requirements_reference.txt └── runtime.txt └── setup.py /DiffBindFR/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | 3 | 4 | import os.path as osp 5 | HOME = osp.dirname(osp.abspath(__file__)) 6 | ROOT = HOME 7 | HOME = osp.abspath(osp.join(HOME, '..')) -------------------------------------------------------------------------------- /DiffBindFR/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 
2 | from .args import ( 3 | parse_args, 4 | benchmark_parse_args, 5 | report_args, 6 | ) 7 | from .dataframe import ( 8 | make_inference_jobs, 9 | JobSlice, 10 | ) 11 | from .inference_dataset import ( 12 | InferenceDataset, 13 | add_center_pos, 14 | ) 15 | from .engines import ( 16 | ec_tag, 17 | load_cfg, 18 | load_dataloader, 19 | load_model, 20 | model_run, 21 | inferencer, 22 | error_corrector, 23 | Scorer, 24 | ) 25 | 26 | -------------------------------------------------------------------------------- /DiffBindFR/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | from .export import ( 3 | rmsd_to_str, 4 | get_traj_id, 5 | update_complex_pos, 6 | export_xtc, 7 | mol2sdf, 8 | complex_modeling, 9 | ) 10 | from .reporter import ( 11 | report_enrichment, 12 | report_performance, 13 | report_pb, 14 | ) -------------------------------------------------------------------------------- /DiffBindFR/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | from .centroid import calc_lig_centroid 3 | from .lrmsd import ( 4 | calc_rmsd_nx, 5 | get_symmetry_rmsd, 6 | CalcLigRMSD, 7 | symm_rmsd, 8 | calc_rmsd, 9 | ) 10 | from .angbin import chi_differ 11 | from .scrmsd import sidechain_rmsd 12 | from .rdmol import caltestset_cdist, caltestset_rmsd 13 | 14 | 15 | __all__ = [ 16 | 'calc_lig_centroid', 'calc_rmsd_nx', 'get_symmetry_rmsd', 17 | 'CalcLigRMSD', 'symm_rmsd', 'calc_rmsd', 'chi_differ', 'sidechain_rmsd', 18 | 'caltestset_cdist', 'caltestset_rmsd', 19 | ] -------------------------------------------------------------------------------- /DiffBindFR/metrics/angbin.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 
# ---- DiffBindFR/metrics/angbin.py ----
# The functions below rely on names imported at the top of angbin.py:
# torch, Tensor, atom37_to_torsion_angles (openfold.data.data_transforms),
# batched_gather (druglib.utils.torch_utils), and the druglib.utils.obj
# modules protein_constants (as pc) and prot_math (as pm).


def angular_difference(
        target_angles: Tensor,
        predicted_angles: Tensor,
):
    r"""
    Absolute angular distance between predicted and target angles.

    Args:
        target_angles: Tensor. sin \theta, cos \theta, shape (..., N, 2)
        predicted_angles: Tensor. sin \theta, cos \theta, shape (..., N, 2)

    Returns:
        Tensor: shape (..., N,), values in [0, pi].
    """
    target_radians = torch.atan2(target_angles[..., 0], target_angles[..., 1])
    predicted_radians = torch.atan2(predicted_angles[..., 0], predicted_angles[..., 1])
    # Wrap the signed difference into [-pi, pi). torch.remainder follows the
    # sign of the divisor, so a raw difference below -pi (e.g. -3 rad vs 3 rad)
    # wraps to the short way round; torch.fmod keeps the sign of the dividend
    # and would leave such values outside the interval.
    diff_radians = torch.remainder(
        predicted_radians - target_radians + torch.pi, 2. * torch.pi
    ) - torch.pi
    abs_diff_radians = torch.abs(diff_radians)
    # Guard against round-off pushing the result marginally past pi.
    clipped_diff_radians = torch.clamp(abs_diff_radians, min = 0, max = torch.pi)

    return clipped_diff_radians


def expand_font_fn(
        tensor: Tensor,
        exp_n: int,
) -> Tensor:
    """Return a view of `tensor` with `exp_n` singleton axes prepended."""
    return tensor.view(*((1, ) * exp_n + tensor.shape))


def expand_font_dim(
        tensor: Tensor,
        ref_tensor: Tensor,
        last_n: int = 3,
) -> Tensor:
    """
    Tile `tensor` along new leading axes copied from `ref_tensor`.

    Args:
        tensor: Tensor to expand.
        ref_tensor: Tensor whose leading axes (all but the last `last_n`)
            give the sizes of the axes that are prepended.
        last_n: Number of trailing axes of `ref_tensor` to ignore.

    Returns:
        Tensor of shape ref_tensor.shape[:-last_n] + tensor.shape.
    """
    ndim = tensor.dim()
    exp_n = ref_tensor.dim() - last_n
    tensor = expand_font_fn(tensor, exp_n)
    # Repeat only along the freshly added leading axes; the original axes
    # keep a repeat factor of 1.
    tensor = tensor.repeat(ref_tensor.shape[:-last_n] + (1,) * ndim)
    return tensor


def chi_differ(
        pred_atom14: Tensor, # (..., N, 14, 3)
        target_atom14: Tensor, # (N, 14, 3)
        target_atom14_mask: Tensor, # (N, 14)
        sequence: Tensor, # (N,)
):
    """
    Angular error of the last four torsion angles of a prediction.

    Both structures are converted from the atom14 to the atom37 layout and
    passed through `atom37_to_torsion_angles`; only the last four torsions
    are kept (the side-chain chi angles, judging by the function name).
    For every torsion the smaller of the errors against the target and the
    alternative target torsions is reported.

    Args:
        pred_atom14: Predicted coordinates, shape (..., N, 14, 3).
        target_atom14: Target coordinates, shape (N, 14, 3).
        target_atom14_mask: Target atom mask, shape (N, 14).
        sequence: Residue type indices, shape (N,).

    Returns:
        Tuple of (pred_chi_differ, torsion_angles_mask): the masked absolute
        angular error, and the target torsion mask with singleton axes
        prepended to match the leading axes of `pred_atom14`.
    """
    # Build the atom37 view of the target and its existence mask.
    mapper = pc.atoms14_to_atoms37_mapper[sequence]
    mapper = torch.LongTensor(mapper).to(pred_atom14.device)
    atom37_exists = pc.restype_atom37_mask[sequence]
    atom37_exists = target_atom14_mask.new_tensor(atom37_exists)
    target_atom37_mask = batched_gather(
        target_atom14_mask,
        mapper,
        dim = -1,
        batch_ndims = len(target_atom14_mask.shape[:-1])
    )
    target_atom37_mask = target_atom37_mask * atom37_exists
    target_atom37 = pm.atom14_to_atom37(target_atom14, mapper, target_atom37_mask)
    target_data = atom37_to_torsion_angles()(
        {
            'aatype': sequence,
            'all_atom_positions': target_atom37,
            'all_atom_mask': target_atom37_mask,
        }
    )
    target_tor_sin_cos = target_data['torsion_angles_sin_cos'][..., -4:, :]
    target_alt_tor_sin_cos = target_data['alt_torsion_angles_sin_cos'][..., -4:, :]
    torsion_angles_mask = target_data['torsion_angles_mask'][..., -4:]

    # Tile the per-residue inputs over the leading axes of the prediction
    # before converting the prediction itself.
    mapper = expand_font_dim(mapper, pred_atom14, 3)
    target_atom37_mask = expand_font_dim(target_atom37_mask, pred_atom14, 3)
    pred_atom37 = pm.atom14_to_atom37(pred_atom14, mapper, target_atom37_mask)
    sequence = expand_font_dim(sequence, pred_atom14, 3)
    pred_tor_sin_cos = atom37_to_torsion_angles()(
        {
            'aatype': sequence,
            'all_atom_positions': pred_atom37,
            'all_atom_mask': target_atom37_mask,
        }
    )['torsion_angles_sin_cos'][..., -4:, :]
    # Target quantities only need singleton leading axes; broadcasting does
    # the rest.
    exp_n = pred_atom14.dim() - 3
    target_tor_sin_cos = expand_font_fn(target_tor_sin_cos, exp_n)
    target_alt_tor_sin_cos = expand_font_fn(target_alt_tor_sin_cos, exp_n)
    torsion_angles_mask = expand_font_fn(torsion_angles_mask, exp_n)

    pred_chi_differ = angular_difference(pred_tor_sin_cos, target_tor_sin_cos)
    aln_pred_chi_differ = angular_difference(pred_tor_sin_cos, target_alt_tor_sin_cos)
    pred_chi_differ = torch.minimum(pred_chi_differ, aln_pred_chi_differ)
    pred_chi_differ = pred_chi_differ * torsion_angles_mask

    return pred_chi_differ, torsion_angles_mask
# ---- DiffBindFR/metrics/centroid.py ----
# Relies on names imported at the top of centroid.py: torch, Tensor.


def calc_lig_centroid(
        pred_pos: Tensor, # (N_pose, N_traj, N_node, 3)
        target_pos: Tensor, # (N_node, 3)
) -> Tensor:
    """Distance from every predicted pose centroid to the target centroid.

    The node axis (second to last) is averaged away; the result keeps the
    leading axes of `pred_pos`.
    """
    pred_centroid = pred_pos.mean(dim = -2)
    target_centroid = target_pos.mean(dim = -2)
    # Prepend singleton axes so the target centroid lines up with the
    # leading axes of the predictions.
    lead_axes = (1, ) * (pred_pos.dim() - 2)
    target_centroid = target_centroid.view(*(lead_axes + target_centroid.shape))
    offset = pred_centroid - target_centroid
    return offset.norm(dim = -1)


# ---- DiffBindFR/metrics/rdmol.py ----
# Relies on names imported at the top of rdmol.py: torch, Tensor,
# Chem (rdkit), calc_rmsd (.lrmsd) and read_mol (..utils).


def _read_heavy_atom_mol(mol_file):
    """Read `mol_file` and strip all hydrogens without sanitizing."""
    mol = read_mol(mol_file)
    assert mol is not None, mol_file
    return Chem.RemoveAllHs(mol, sanitize=False)


def caltestset_rmsd(
        mol_pred_file,
        mol_true_file,
) -> float:
    """RMSD between the heavy atoms of a predicted and a reference molecule file."""
    predicted = _read_heavy_atom_mol(mol_pred_file)
    reference = _read_heavy_atom_mol(mol_true_file)
    return calc_rmsd(predicted, reference)


def calc_lig_centroid(
        pred_pos: Tensor,
        target_pos: Tensor,
) -> float:
    """Centroid distance between two coordinate sets, as a Python float."""
    centroid_gap = pred_pos.mean(dim = -2) - target_pos.mean(dim = -2)
    return centroid_gap.norm(dim = -1).item()


def caltestset_cdist(
        mol_pred_file,
        mol_true_file,
) -> float:
    """Centroid distance between the heavy atoms of two molecule files.

    Coordinates are taken from conformer 0 of each molecule.
    """
    predicted = _read_heavy_atom_mol(mol_pred_file)
    reference = _read_heavy_atom_mol(mol_true_file)
    pred_xyz = torch.from_numpy(predicted.GetConformer(0).GetPositions())
    true_xyz = torch.from_numpy(reference.GetConformer(0).GetPositions())
    return calc_lig_centroid(pred_xyz, true_xyz)


# ---- DiffBindFR/metrics/scrmsd.py ----
# Relies on names imported at the top of scrmsd.py: torch, Tensor and
# druglib.utils.obj.protein_constants (as pc).


def make_altern_atom14(
        atom14_pos: Tensor,
        atom14_mask: Tensor,
        sequence: Tensor,
):
    """Build the alternative atom14 coordinates and mask.

    Atom naming is ambiguous for some amino acids, so a second ground truth
    is produced in which the ambiguous atom pairs listed in
    `pc.residue_atom_renaming_swaps` trade places.

    Returns:
        Tuple of (alternative_atom14_pos, alternative_atom14_mask).
    """
    device = atom14_pos.device
    residue_names = [pc.restype_1to3[code] for code in pc.restypes] + ["UNK"]

    # Every residue type starts from the identity permutation.
    permutations = {
        name: torch.eye(14, dtype = atom14_pos.dtype, device = device)
        for name in residue_names
    }
    rows = torch.arange(14, device = device)
    for resname, swaps in pc.residue_atom_renaming_swaps.items():
        atom_names = pc.restype_name_to_atom14_names[resname]
        partner = torch.arange(14, device = device)
        for first_atom, second_atom in swaps.items():
            first_index = atom_names.index(first_atom)
            second_index = atom_names.index(second_atom)
            partner[first_index] = second_index
            partner[second_index] = first_index
        # One-hot rows turn the index permutation into a 14 x 14 matrix.
        swap_matrix = atom14_pos.new_zeros((14, 14))
        swap_matrix[rows, partner] = 1.0
        permutations[resname] = swap_matrix

    stacked = torch.stack([permutations[name] for name in residue_names])

    # Select one matrix per residue, shape (num_res, 14, 14), then prepend
    # singleton axes to match any leading axes of the coordinates.
    per_residue = stacked[sequence]
    lead_axes = (1, ) * (atom14_pos.dim() - 3)
    per_residue = per_residue.view(*(lead_axes + per_residue.shape))

    alternative_atom14_pos = torch.einsum(
        "...rac,...rab->...rbc", atom14_pos, per_residue
    )
    alternative_atom14_mask = torch.einsum(
        "...ra,...rab->...rb", atom14_mask.float(), per_residue
    )
    return alternative_atom14_pos, alternative_atom14_mask


def sidechain_rmsd(
        pred_atom14: Tensor, # (..., N, 14, 3)
        target_atom14: Tensor, # (N, 14, 3)
        target_atom14_mask: Tensor, # (N, 14)
        sequence: Tensor, # (N,)
        eps: float = 1e-6,
):
    """Mean per-residue side-chain RMSD of a prediction against the target.

    For every residue the squared error is taken against both the target
    and its alternative (ambiguous atoms swapped) and the smaller one is
    used. Residues without any masked-in side-chain atom contribute nothing.
    """
    lead_axes = (1, ) * (pred_atom14.dim() - 3)
    target_atom14 = target_atom14.view(*(lead_axes + target_atom14.shape))
    target_atom14_mask = target_atom14_mask.view(*(lead_axes + target_atom14_mask.shape))

    alt_target_atom14, alt_target_mask = make_altern_atom14(
        target_atom14, target_atom14_mask, sequence,
    )

    # Atom14 slots 5 onward are the ones treated as side-chain atoms here.
    sc_mask = target_atom14_mask[..., 5:]
    alt_sc_mask = alt_target_mask[..., 5:]
    sc_pred = pred_atom14[..., 5:, :] * sc_mask[..., None]
    sc_target = target_atom14[..., 5:, :] * sc_mask[..., None]
    alt_sc_target = alt_target_atom14[..., 5:, :] * alt_sc_mask[..., None]

    sq_err = ((sc_target - sc_pred) ** 2).sum(dim = (-2, -1))
    alt_sq_err = ((alt_sc_target - sc_pred) ** 2).sum(dim = (-2, -1))
    best_sq_err = torch.minimum(sq_err, alt_sq_err)

    has_sidechain = sc_mask.any(dim = -1)
    atom_count = sc_mask.sum(dim = -1)
    per_residue = torch.sqrt(best_sq_err / (atom_count + eps)) * has_sidechain
    return per_residue.sum(dim = -1) / has_sidechain.sum(dim = -1)
# ---- DiffBindFR/scoring/__init__.py ----
# Package initialiser: re-exports Early_stopper, PassNoneDataLoader,
# KarmaDock and InferenceScoringDataset_chunk from the sub-modules.


# ---- DiffBindFR/scoring/architecture/Angle_ResNet.py ----
# Relies on names imported at the top of Angle_ResNet.py: torch, nn.


class AngleResnetBlock(nn.Module):
    """Two-layer residual MLP block with ReLU pre-activations."""

    def __init__(self, c_hidden):
        """
        Args:
            c_hidden:
                Hidden channel dimension
        """
        super(AngleResnetBlock, self).__init__()

        self.c_hidden = c_hidden

        self.linear_1 = nn.Linear(self.c_hidden, self.c_hidden)
        self.linear_2 = nn.Linear(self.c_hidden, self.c_hidden)

        self.relu = nn.ReLU()

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        """Return `a` plus the output of ReLU -> Linear -> ReLU -> Linear."""
        s_initial = a

        a = self.relu(a)
        a = self.linear_1(a)
        a = self.relu(a)
        a = self.linear_2(a)

        return a + s_initial


class AngleResnet(nn.Module):
    """
    Implements Algorithm 20, lines 11-14
    """

    def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
            no_blocks:
                Number of resnet blocks
            no_angles:
                Number of torsion angles to generate
            epsilon:
                Small constant for normalization
        """
        super(AngleResnet, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_blocks = no_blocks
        self.no_angles = no_angles
        self.eps = epsilon

        self.linear_in = nn.Linear(self.c_in, self.c_hidden)
        self.linear_initial = nn.Linear(self.c_in, self.c_hidden)

        self.layers = nn.ModuleList()
        for _ in range(self.no_blocks):
            layer = AngleResnetBlock(c_hidden=self.c_hidden)
            self.layers.append(layer)

        self.linear_out = nn.Linear(self.c_hidden, self.no_angles * 2)

        self.relu = nn.ReLU()

    def forward(
        self, s: torch.Tensor, s_initial: torch.Tensor
    ):
        """
        Args:
            s:
                [*, C_in] single embedding
            s_initial:
                [*, C_in] single embedding as of the start of the
                StructureModule
        Returns:
            Tuple of two [*, no_angles, 2] tensors: the raw predicted
            angles and the same values scaled to unit norm along the
            last axis.
        """
        # NOTE: The ReLU's applied to the inputs are absent from the supplement
        # pseudocode but present in the source. For maximal compatibility with
        # the pretrained weights, I'm going with the source.

        # [*, C_hidden]
        s_initial = self.relu(s_initial)
        s_initial = self.linear_initial(s_initial)
        s = self.relu(s)
        s = self.linear_in(s)
        s = s + s_initial

        for l in self.layers:
            s = l(s)

        s = self.relu(s)

        # [*, no_angles * 2]
        s = self.linear_out(s)

        # [*, no_angles, 2]
        s = s.view(s.shape[:-1] + (-1, 2))

        unnormalized_s = s
        # Clamp the squared norm so the division below never hits zero.
        norm_denom = torch.sqrt(
            torch.clamp(
                torch.sum(s ** 2, dim=-1, keepdim=True),
                min=self.eps,
            )
        )
        s = s / norm_denom

        return unnormalized_s, s


# ---- DiffBindFR/scoring/architecture/EGNN_Block.py ----
# Relies on names imported at the top of EGNN_Block.py: math, torch, nn,
# GraphNorm (torch_geometric.nn), softmax (torch_geometric.utils) and
# scatter (torch_scatter).


class EGNN(nn.Module):
    """Attention-style message-passing layer that can also move coordinates.

    Node features are updated from gated, attention-weighted neighbour
    messages; edge features are re-derived from the per-edge attention
    tensor; coordinates are optionally shifted by `coords_update`.
    """

    def __init__(self, dim_in, dim_tmp, edge_in, edge_out, num_head=8, drop_rate=0.15):
        super().__init__()
        assert dim_tmp % num_head == 0
        self.edge_dim = edge_in
        self.num_head = num_head  # KarmaDock builds this layer with 4 heads
        self.dh = dim_tmp // num_head  # channels per head (32 for 128 / 4)
        self.dim_tmp = dim_tmp  # hidden width (128 in KarmaDock)
        self.q_layer = nn.Linear(dim_in, dim_tmp)
        self.k_layer = nn.Linear(dim_in, dim_tmp)
        self.v_layer = nn.Linear(dim_in, dim_tmp)
        # +1 input channel: the scaled pairwise distance appended in forward.
        self.m_layer = nn.Sequential(
            nn.Linear(edge_in+1, dim_tmp),
            nn.Dropout(p=drop_rate),
            nn.LeakyReLU(),
            nn.Linear(dim_tmp, dim_tmp)
        )
        self.m2f_layer = nn.Sequential(
            nn.Linear(dim_tmp, dim_tmp),
            nn.Dropout(p=drop_rate))
        self.e_layer = nn.Sequential(
            nn.Linear(dim_tmp, edge_out),
            nn.Dropout(p=drop_rate))
        self.gate_layer = nn.Sequential(
            nn.Linear(3*dim_tmp, dim_tmp),
            nn.Dropout(p=drop_rate))
        self.layer_norm_1 = GraphNorm(dim_tmp)
        self.layer_norm_2 = GraphNorm(dim_tmp)
        self.fin_layer = nn.Sequential(
            nn.Linear(dim_tmp, dim_tmp),
            nn.Dropout(p=drop_rate),
            nn.LeakyReLU(),
            nn.Linear(dim_tmp, dim_tmp)
        )
        self.update_layer = coords_update(dim_dh=self.dh, num_head=num_head, drop_rate=drop_rate)

    def forward(self, node_s, edge_s, edge_index, total_pos, pro_nodes, batch, update_pos=True):
        """Run one layer.

        Returns:
            Tuple (node_s_new, edge_s_new, edge_index, total_pos). When
            `update_pos` is true, `total_pos` is the tensor returned by
            `coords_update`, which modifies its input in place.
        """
        q_ = self.q_layer(node_s)
        k_ = self.k_layer(node_s)
        v_ = self.v_layer(node_s)
        # message passing: edge features plus the (scaled) edge length
        edge_length = torch.pairwise_distance(
            total_pos[edge_index[0]], total_pos[edge_index[1]]
        ).unsqueeze(dim=-1)*0.1
        m_ij = torch.cat([edge_s, edge_length], dim=-1)
        m_ij = self.m_layer(m_ij)
        k_ij = k_[edge_index[1]] * m_ij
        a_ij = ((q_[edge_index[0]] * k_ij)/math.sqrt(self.dh)).view((-1, self.num_head, self.dh))
        # per-head weights: softmax of the L1 norm over edges sharing edge_index[0]
        w_ij = softmax(torch.norm(a_ij, p=1, dim=2), index=edge_index[0]).unsqueeze(dim=-1)
        # update node and edge embeddings
        # dim_size keeps one output row per node even when the highest-index
        # nodes have no entry in edge_index[0]; without it the gate
        # concatenation below fails on a row-count mismatch.
        aggregated = scatter(
            w_ij*v_[edge_index[1]].view((-1, self.num_head, self.dh)),
            index=edge_index[0],
            reduce='sum',
            dim=0,
            dim_size=node_s.size(0),
        ).view((-1, self.dim_tmp))
        node_s_new = self.m2f_layer(aggregated)
        edge_s_new = self.e_layer(a_ij.view((-1, self.dim_tmp)))
        g = torch.sigmoid(self.gate_layer(torch.cat([node_s_new, node_s, node_s_new-node_s], dim=-1)))
        node_s_new = self.layer_norm_1(g*node_s_new+node_s, batch)
        node_s_new = self.layer_norm_2(g*self.fin_layer(node_s_new)+node_s_new, batch)
        # update coords
        if update_pos:
            total_pos = self.update_layer(a_ij, total_pos, edge_index, pro_nodes)
        return node_s_new, edge_s_new, edge_index, total_pos


class coords_update(nn.Module):
    """Shift node coordinates along edge directions, weighted by attention."""

    def __init__(self, dim_dh, num_head, drop_rate=0.15):
        super().__init__()
        self.num_head = num_head
        self.attention2deltax = nn.Sequential(
            nn.Linear(dim_dh, dim_dh//2),
            nn.Dropout(p=drop_rate),
            nn.LeakyReLU(),
            nn.Linear(dim_dh//2, 1)
        )
        self.weighted_head_layer = nn.Linear(num_head, 1, bias=False)

    def forward(self, a_ij, pos, edge_index, pro_nodes):
        """Update `pos` in place and return it.

        Only edges whose first endpoint index is >= `pro_nodes` contribute,
        so nodes with a smaller index are never moved.
        """
        edge_index_mask = edge_index[0] >= pro_nodes
        i, j = edge_index[:, edge_index_mask]
        delta_x = pos[i] - pos[j]
        # unit direction of every selected edge
        delta_x = delta_x/(torch.norm(delta_x, p=2, dim=-1).unsqueeze(dim=-1) + 1e-6 )
        # one scalar step length per edge, mixed over the attention heads
        delta_x = delta_x*self.weighted_head_layer(self.attention2deltax(a_ij[edge_index_mask]).squeeze(dim=2))
        # dim_size pins the result to one row per node. Without it scatter
        # returns max(i) + 1 rows, which breaks the in-place add whenever
        # the last node has no selected edge or no edge is selected at all.
        delta_x = scatter(delta_x, index=i, reduce='sum', dim=0, dim_size=pos.size(0))
        pos += delta_x
        return pos


# ---- DiffBindFR/scoring/architecture/Gate_Block.py ----
# Relies on names imported at the top of Gate_Block.py: torch, nn and
# GraphNorm (torch_geometric.nn).


class Gate_Block(nn.Module):
    """Gated residual fusion of two feature tensors followed by GraphNorm."""

    def __init__(self, dim_tmp, drop_rate=0.15):
        super().__init__()
        self.gate_layer = nn.Sequential(
            nn.Linear(3*dim_tmp, dim_tmp),
            nn.Dropout(p=drop_rate))
        self.norm = GraphNorm(dim_tmp)

    def forward(self, f1, f2):
        """Return norm(g * f2 + f1), with gate g computed from f2, f1 and f2 - f1."""
        g = torch.sigmoid(self.gate_layer(torch.cat([f2, f1, f2-f1], dim=-1)))
        # NOTE(review): unlike EGNN, no batch vector is passed to GraphNorm
        # here -- confirm that normalising over all rows together is intended.
        f2 = self.norm(g*f2+f1)
        return f2


# ---- DiffBindFR/scoring/architecture/KarmaDock_sc.py ----
# Relies on names imported at the top of KarmaDock_sc.py: torch, nn,
# scatter (torch_scatter), GraphNorm (torch_geometric.nn) and the sibling
# modules' GVP_embedding, GraghTransformer, MDN_Block, EGNN, Gate_Block and
# AngleResnet.


class KarmaDock(nn.Module):
    """Protein-ligand scoring model built from ligand and protein encoders
    and a mixture-density scoring head.

    Only the encoders and `mdn_layer` are used by the methods of this
    class; the remaining sub-modules are constructed but not referenced
    here (presumably kept so that existing checkpoints still load --
    confirm before removing any of them).
    """

    def __init__(self):
        super(KarmaDock, self).__init__()
        # encoders
        self.lig_encoder = GraghTransformer(
            in_channels=89,
            edge_features=20,
            num_hidden_channels=128,
            activ_fn=torch.nn.SiLU(),
            transformer_residual=True,
            num_attention_heads=4,
            norm_to_apply='batch',
            dropout_rate=0.15,
            num_layers=6
        )
        self.pro_encoder = GVP_embedding(
            (9, 3), (128, 16), (21, 1), (32, 1), seq_in=True)
        self.gn = GraphNorm(128)
        # pose prediction
        self.egnn_layers = nn.ModuleList(
            [EGNN(dim_in=128, dim_tmp=128, edge_in=128, edge_out=128, num_head=4, drop_rate=0.15) for _ in range(8)]
        )
        self.edge_init_layer = nn.Linear(6, 128)
        self.node_gate_layer = Gate_Block(dim_tmp=128,
                                          drop_rate=0.15
                                          )
        self.edge_gate_layer = Gate_Block(dim_tmp=128,
                                          drop_rate=0.15
                                          )
        # scoring
        self.mdn_layer = MDN_Block(
            hidden_dim=128,
            n_gaussians=10,
            dropout_rate=0.10,
            dist_threhold=7.
        )
        self.torsion_sin_cos_layer = AngleResnet(
            c_in=128,
            c_hidden=32,
            no_blocks=2,
            no_angles=4,
            epsilon=1e-6
        )

    def forward(self, data):
        """Return one score per complex in the batch.

        The batch size is derived from the last entry of the ligand batch
        vector, and pairs farther apart than 5. are excluded from the score.
        """
        batch_size = data['ligand'].batch[-1] + 1
        pro_node_s, lig_node_s = self.encoding(data)
        lig_pos = data['ligand'].xyz
        mdn_score_pred = self.scoring(
            lig_s=lig_node_s,
            lig_pos=lig_pos,
            pro_s=pro_node_s,
            data=data,
            dist_threhold=5.,
            batch_size=batch_size,
        )
        return mdn_score_pred

    def encoding(self, data):
        """get ligand & protein embeddings"""
        pro_node_s = self.pro_encoder(
            (
                data['protein']['node_s'],
                data['protein']['node_v']
            ),
            data[("protein", "p2p", "protein")]["edge_index"],
            (
                data[("protein", "p2p", "protein")]["edge_s"],
                data[("protein", "p2p", "protein")]["edge_v"]
            ),
            data['protein'].seq
        )
        # The ligand encoder only sees the edges selected by cov_edge_mask.
        cov_edge_mask = data['ligand'].cov_edge_mask
        lig_node_s = self.lig_encoder(
            data['ligand'].node_s.to(torch.float32),
            data['ligand', 'l2l', 'ligand'].edge_s[cov_edge_mask].to(torch.float32),
            data['ligand', 'l2l', 'ligand'].edge_index[:, cov_edge_mask],
        )
        return pro_node_s, lig_node_s

    def scoring(self, lig_s, lig_pos, pro_s, data, dist_threhold, batch_size):
        """scoring the protein-ligand binding strength"""
        pi, sigma, mu, dist, c_batch, _, _ = self.mdn_layer(
            lig_s=lig_s,
            lig_pos=lig_pos,
            lig_batch=data['ligand'].batch,
            pro_s=pro_s,
            pro_pos=data['protein'].xyz_full,
            pro_batch=data['protein'].batch,
            edge_index=data['ligand', 'l2l', 'ligand'].edge_index[:, data['ligand'].cov_edge_mask]
        )
        mdn_score = self.mdn_layer.calculate_probablity(pi, sigma, mu, dist)
        # Pairs beyond the distance cut-off do not contribute to the score.
        mdn_score[torch.where(dist > dist_threhold)[0]] = 0.
        # Sum the remaining pair scores per complex.
        mdn_score = scatter(mdn_score, index=c_batch, dim=0, reduce='sum', dim_size=batch_size).float()
        return mdn_score
to_dense_batch(lig_pos, lig_batch, fill_value=0) 25 | h_t_pos, _ = to_dense_batch(pro_pos, pro_batch, fill_value=0) 26 | 27 | assert h_l_x.size(0) == h_t_x.size(0), 'Encountered unequal batch-sizes' 28 | (B, N_l, C_out), N_t = h_l_x.size(), h_t_x.size(1) 29 | self.B = B 30 | self.N_l = N_l 31 | self.N_t = N_t 32 | # Combine and mask 33 | h_l_x = h_l_x.unsqueeze(-2) 34 | h_l_x = h_l_x.repeat(1, 1, N_t, 1) # [B, N_l, N_t, C_out] 35 | 36 | h_t_x = h_t_x.unsqueeze(-3) 37 | h_t_x = h_t_x.repeat(1, N_l, 1, 1) # [B, N_l, N_t, C_out] 38 | 39 | C = torch.cat((h_l_x, h_t_x), -1) 40 | self.C_mask = C_mask = l_mask.view(B, N_l, 1) & t_mask.view(B, 1, N_t) 41 | self.C = C = C[C_mask] 42 | C = self.MLP(C) 43 | 44 | # Get batch indexes for ligand-target combined features 45 | C_batch = torch.tensor(range(B)).unsqueeze(-1).unsqueeze(-1).to(lig_s.device) 46 | C_batch = C_batch.repeat(1, N_l, N_t)[C_mask] 47 | 48 | # Outputs 49 | pi = F.softmax(self.z_pi(C), -1) 50 | sigma = F.elu(self.z_sigma(C))+1.1 51 | mu = F.elu(self.z_mu(C))+1 52 | dist = self.compute_euclidean_distances_matrix(h_l_pos, h_t_pos.view(h_t_pos.size(0), -1, 3))[C_mask] 53 | atom_types = self.atom_types(lig_s) 54 | bond_types = self.bond_types(torch.cat([lig_s[edge_index[0]],lig_s[edge_index[1]]], axis=1)) 55 | return pi, sigma, mu, dist.unsqueeze(1).detach(), C_batch, atom_types, bond_types 56 | 57 | def compute_euclidean_distances_matrix(self, X, Y): 58 | # Based on: https://medium.com/@souravdey/l2-distance-matrix-vectorization-trick-26aa3247ac6c 59 | # (X-Y)^2 = X^2 + Y^2 -2XY 60 | X = X.double() 61 | Y = Y.double() 62 | 63 | dists = -2 * torch.bmm(X, Y.permute(0, 2, 1)) + torch.sum(Y**2, axis=-1).unsqueeze(1) + torch.sum(X**2, axis=-1).unsqueeze(-1) 64 | # return dists**0.5 65 | return torch.nan_to_num((dists**0.5).view(self.B, self.N_l,-1,14),10000).min(axis=-1)[0] 66 | 67 | 68 | def mdn_loss_fn(self, pi, sigma, mu, y): 69 | normal = Normal(mu, sigma) 70 | loglik = normal.log_prob(y.expand_as(normal.loc)) 71 
| loss = -torch.logsumexp(torch.log(pi) + loglik, dim=1) 72 | return loss 73 | 74 | def calculate_probablity(self, pi, sigma, mu, y): 75 | normal = Normal(mu, sigma) 76 | logprob = normal.log_prob(y.expand_as(normal.loc)) 77 | logprob += torch.log(pi) 78 | prob = logprob.exp().sum(1) 79 | return prob 80 | -------------------------------------------------------------------------------- /DiffBindFR/scoring/dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections.abc import Mapping, Sequence 3 | from typing import Optional, Union, List 4 | import torch.utils.data 5 | from torch.utils.data.dataloader import default_collate 6 | from torch_geometric.data import Batch, Dataset 7 | from torch_geometric.data.data import BaseData 8 | 9 | 10 | 11 | class PassNoneCollater: 12 | def __init__(self, follow_batch, exclude_keys): 13 | self.follow_batch = follow_batch 14 | self.exclude_keys = exclude_keys 15 | 16 | def __call__(self, batch): 17 | batch = list(filter(lambda x:x is not None, batch)) 18 | elem = batch[0] 19 | if isinstance(elem, BaseData): 20 | return Batch.from_data_list(batch, self.follow_batch, 21 | self.exclude_keys) 22 | elif isinstance(elem, torch.Tensor): 23 | return default_collate(batch) 24 | elif isinstance(elem, float): 25 | return torch.tensor(batch, dtype=torch.float) 26 | elif isinstance(elem, int): 27 | return torch.tensor(batch) 28 | elif isinstance(elem, str): 29 | return batch 30 | elif isinstance(elem, Mapping): 31 | return {key: self([data[key] for data in batch]) for key in elem} 32 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): 33 | return type(elem)(*(self(s) for s in zip(*batch))) 34 | elif isinstance(elem, Sequence) and not isinstance(elem, str): 35 | return [self(s) for s in zip(*batch)] 36 | 37 | raise TypeError(f'DataLoader found invalid type: {type(elem)}') 38 | 39 | def collate(self, batch): # Deprecated... 
class PassNoneDataLoader(torch.utils.data.DataLoader):
    r"""Mini-batch loader for :class:`torch_geometric.data.Dataset` objects
    that always collates with :class:`PassNoneCollater`.

    Samples may be :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData` instances.  Any ``collate_fn``
    supplied by the caller is discarded in favour of the built-in collater.

    Args:
        dataset (Dataset): Source of the samples.
        batch_size (int, optional): Number of samples per mini-batch.
            (default: :obj:`1`)
        shuffle (bool, optional): Reshuffle the data at every epoch when
            :obj:`True`. (default: :obj:`False`)
        follow_batch (List[str], optional): Keys for which assignment batch
            vectors are created. (default: :obj:`None`)
        exclude_keys (List[str], optional): Keys left out of the batch.
            (default: :obj:`None`)
        **kwargs (optional): Forwarded to
            :class:`torch.utils.data.DataLoader`.
    """
    def __init__(
            self,
            dataset: Union[Dataset, List[BaseData]],
            batch_size: int = 1,
            shuffle: bool = False,
            follow_batch: Optional[List[str]] = None,
            exclude_keys: Optional[List[str]] = None,
            **kwargs,
    ):
        # The collater is fixed by this class; ignore a user-provided one.
        kwargs.pop('collate_fn', None)

        # Save for PyTorch Lightning < 1.6:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        collater = PassNoneCollater(follow_batch, exclude_keys)
        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=collater,
            **kwargs,
        )
class LMDBLoader:
    """Read-only accessor for pickled records stored in an LMDB database.

    The key list is read once at construction time.  The environment used
    for lookups is opened lazily on the first ``__getitem__`` call and is
    closed at interpreter exit.

    Args:
        db_path: path of the LMDB file or directory; must exist.
        map_gb: LMDB map size in GiB.
        strict_get: when True a lookup of an unknown key raises
            ``ValueError``; when False it returns ``None``.
        _exclude_key: decoded key names hidden from the loader.  Defaults to
            ``['KEYS']``.

    Raises:
        ValueError: if ``db_path`` does not exist.
    """

    def __init__(
            self,
            db_path: str,
            map_gb: float = 10000.0,
            strict_get: bool = True,
            _exclude_key: Optional[List[str]] = None,
    ):
        if not os.path.exists(db_path):
            raise ValueError("{} does not exists.".format(db_path))

        # A fresh list per instance instead of a shared mutable default.
        if _exclude_key is None:
            _exclude_key = ['KEYS']

        self.db_path = db_path
        self.map_gb = map_gb
        self.strict_get = strict_get
        self._exclude_key = _exclude_key

        excluded = set(_exclude_key)
        env = self._connect_db(self.db_path)
        try:
            with env.begin() as txn:
                self._keys = [
                    k for k in txn.cursor().iternext(values=False)
                    if k.decode() not in excluded
                ]
        finally:
            # This environment is only needed to list the keys; release it
            # deterministically instead of relying on garbage collection.
            env.close()
        # Hash-based mirror of `_keys` so membership tests are O(1).
        self._key_set = frozenset(self._keys)

        import atexit
        # Register the bound method itself so it is actually *called* at
        # exit (a lambda returning the method would never close anything).
        atexit.register(self._close_db)

    def _connect_db(
            self,
            lmdb_path: str,
            attach: bool = False,
    ) -> Optional['lmdb.Environment']:
        """Open ``lmdb_path`` read-only.

        Returns the environment when ``attach`` is False; otherwise stores
        it on ``self._env`` and returns ``None``.
        """
        assert getattr(self, '_env', None) is None, 'A connection has already been opened.'
        env = lmdb.open(
            lmdb_path,
            map_size=int(self.map_gb * (1024 * 1024 * 1024)),
            create=False,
            subdir=os.path.isdir(lmdb_path),
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        if not attach:
            return env
        self._env = env
        return None

    def __len__(self) -> int:
        return len(self._keys)

    def __contains__(self, key: str):
        # Same key normalisation as __getitem__.
        return str(key).encode("ascii") in self._key_set

    # NOTE(review): lru_cache on a method keys on `self` and keeps the
    # instance (and up to 16 unpickled records) alive; cached records are
    # returned by reference.  Kept to preserve existing caching behaviour.
    @lru_cache(maxsize=16)
    def __getitem__(self, idx) -> Optional[Any]:
        """Return the unpickled record stored under ``str(idx)``.

        Raises:
            ValueError: if the key is unknown and ``strict_get`` is True.
        """
        key = str(idx).encode("ascii")
        if key not in self._key_set:
            if self.strict_get:
                raise ValueError(f'query index {key.decode()} not in lmdb.')
            return None

        # `_env` is missing before the first lookup and is None after
        # `_close_db`; (re)connect in both cases.
        if getattr(self, '_env', None) is None:
            self._connect_db(self.db_path, attach=True)

        with self._env.begin() as txn:
            with txn.cursor() as curs:
                datapoint_pickled = curs.get(key)
                # SECURITY: pickle executes arbitrary code on load; only
                # open databases from trusted sources.
                data = pickle.loads(datapoint_pickled)

        return data

    def _close_db(self):
        """Close the attached environment, if any."""
        env = getattr(self, '_env', None)
        if isinstance(env, lmdb.Environment):
            env.close()
            self._env = None
+ k[6:] 36 | params_[k_] = v 37 | del params 38 | model_obj.load_state_dict(params_, strict=strict) 39 | 40 | def save_model(self, model_obj): 41 | '''Saves model when the metric on the validation set gets improved.''' 42 | torch.save({'model_state_dict': model_obj.state_dict()}, self.model_file) 43 | 44 | def step(self, score, model_obj): 45 | if self.best_score is None: 46 | self.best_score = score 47 | self.save_model(model_obj) 48 | elif self._check(score, self.best_score): 49 | self.best_score = score 50 | self.save_model(model_obj) 51 | self.counter = 0 52 | else: 53 | self.counter += 1 54 | print(f'# EarlyStopping counter: {self.counter} out of {self.patience}') 55 | if self.counter >= self.patience: 56 | self.early_stop = True 57 | print(f'# Current best performance {float(self.best_score):.3f}') 58 | return self.early_stop 59 | -------------------------------------------------------------------------------- /DiffBindFR/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 
2 | from .io import ( 3 | filename, 4 | write_fasta, 5 | JSONEncoder, 6 | exists_or_assert, 7 | mkdir_or_exists, 8 | read_mol, 9 | to_complex_block, 10 | read_molblock, 11 | update_mol_pose, 12 | ) 13 | from .logger import get_logger 14 | from .uniprot import get_seq_from_uniprot, pdb2uniprot 15 | from .blast import ( 16 | PDBBlastRecord_Local, 17 | blastp_local, 18 | PDBBlastRecord, 19 | blastp_prody, 20 | ) 21 | from .apo_holo import ApoHoloBS, pair_spatial_metrics 22 | from .pocket import ( 23 | temp_pdb_file, 24 | PDBPocketResidues, 25 | get_ligand_code, 26 | sdf2prody, 27 | show_pocket_ligand, 28 | get_pocket_resnums_nv_str, 29 | get_pocket_resnums_prody_str, 30 | get_pocket_resnums_bsalign_str, 31 | ) 32 | from .vinafr_remodel import build_vinafr_protein 33 | 34 | 35 | __all__ = [ 36 | 'filename', 'write_fasta', 'JSONEncoder', 'exists_or_assert', 'mkdir_or_exists', 37 | 'read_mol', 'to_complex_block', 'read_molblock', 'update_mol_pose', 38 | 'PDBBlastRecord_Local', 'blastp_local', 'PDBBlastRecord', 'blastp_prody', 39 | 'get_logger', 'get_seq_from_uniprot', 'pdb2uniprot', 'ApoHoloBS', 'pair_spatial_metrics', 40 | 'temp_pdb_file', 'PDBPocketResidues', 'get_ligand_code', 'sdf2prody', 'show_pocket_ligand', 41 | 'get_pocket_resnums_nv_str', 'get_pocket_resnums_prody_str', 'get_pocket_resnums_bsalign_str', 42 | 'build_vinafr_protein', 43 | ] -------------------------------------------------------------------------------- /DiffBindFR/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 
def get_logger(
        name: str = 'DiffBindFR',
        log_file: Optional[str] = None,
        log_level: int = logging.INFO,
        io_mode: str = 'w'
) -> logging.Logger:
    """Return the logger called ``name``, configuring it on first use.

    On the first call for a given name a console handler (and a file
    handler when ``log_file`` is given) is attached and the name is recorded
    in ``logger_initialized``.  Later calls with the same name, or with a
    dotted descendant of an already configured name (``"a.b"`` after
    ``"a"``), return the logger without adding handlers.

    Args:
        name: logger name.
        log_file: optional path of a log file.
        log_level: level applied to the logger and to the new handlers.
        io_mode: mode used to open ``log_file``.

    Returns:
        The (possibly freshly configured) ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    # Handle hierarchical names: a descendant "a.b" of a configured logger
    # "a" is skipped.  The separator is part of the test so that an
    # unrelated name that merely shares a prefix ("ab" vs "a") is still
    # configured instead of being left without handlers.
    if any(name.startswith(parent + '.') for parent in logger_initialized):
        return logger

    # Plain StreamHandlers on the root logger are limited to ERROR;
    # presumably to avoid duplicated console lines.  The exact-type check is
    # deliberate: subclasses such as FileHandler are left untouched.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    handlers = [logging.StreamHandler()]
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file, io_mode))

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)
    logger_initialized[name] = True

    return logger
def get_seq_from_uniprot(
        uniprot_id: str,
        output_dir: str = './',
) -> str:
    """
    Saves and returns the fasta sequence of a protein given
    its UNIPROT accession number

    Args:
        uniprot_id: UniProt accession, e.g. ``"P00533"``.
        output_dir: directory in which ``<uniprot_id>.fasta`` is written.

    Returns:
        The protein sequence as a plain string.

    Raises:
        ValueError: if the download does not succeed, or if the downloaded
            file does not contain exactly one FASTA record.
    """
    import os

    URL = "https://www.uniprot.org/uniprot/"
    response = requests.get(URL + uniprot_id + ".fasta")
    # Fail before touching the disk so an HTTP error page is never saved
    # (and later mis-parsed) as a FASTA file.
    if not response.ok:
        raise ValueError(
            f'Failed to download fasta of {uniprot_id} '
            f'(HTTP status {response.status_code}).'
        )

    # os.path.join works whether or not `output_dir` ends with a separator.
    file_name_fasta = os.path.join(output_dir, uniprot_id + '.fasta')
    with open(file_name_fasta, 'wb') as handle:
        handle.write(response.content)

    # Read the protein sequence
    with open(file_name_fasta) as handle:
        fasta_prot = SeqIO.read(handle, 'fasta')
    return str(fasta_prot.seq)
23 | 24 | echo "Install openff-utilities" 25 | 26 | version="0.1.8" 27 | wget https://github.com/openforcefield/openff-utilities/archive/refs/tags/v${version}.zip 28 | 29 | unzip v${version}.zip && cd openff-utilities-${version}/ && python -m pip install . && cd .. 30 | 31 | echo "Install openff-forcefields" 32 | 33 | version="2023.11.0" 34 | wget https://github.com/openforcefield/openff-forcefields/archive/refs/tags/${version}.zip 35 | 36 | unzip ${version}.zip && cd openff-forcefields-${version}/ && python -m pip install . && cd .. 37 | 38 | cd ${WorkDir} 39 | 40 | echo "All Done!" 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The Clear BSD License 2 | 3 | Copyright (c) 2024, MDLDrugLib from Peking University 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 10 | * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /druglib/LICENSE: -------------------------------------------------------------------------------- 1 | The Clear BSD License 2 | 3 | Copyright (c) 2024, Peking University 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 10 | * Neither the name of Peking University nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /druglib/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import * 2 | from .utils import * 3 | from .apis import * 4 | from .core import * 5 | from .ops import * 6 | from .data import * 7 | from .datasets import * 8 | 9 | import os.path as osp 10 | HOME = osp.dirname(osp.abspath(__file__)) 11 | HOME = osp.abspath(osp.join(HOME, '..')) -------------------------------------------------------------------------------- /druglib/alerts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | from .check import * 3 | from .errors import * 4 | from . import molerror 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /druglib/alerts/check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 
def check_inf_nan_np(
        array: Union[np.ndarray],
) -> bool:
    """Return True when ``array`` contains neither NaN nor +/-inf."""
    # NaN is tested first; the infinity scan only runs when no NaN is found.
    if np.isnan(array).any():
        return False
    return not np.isinf(array).any()
2 | from .activations import get_activation 3 | from .norm import get_norm 4 | from .utils import (initialize, update_init_info, INITIALIZERS, constant_init, 5 | xavier_init, normal_init, trunc_normal_init, uniform_init, 6 | kaiming_init, caffe2_xavier_init, bias_init_with_prob, 7 | ConstantInit, XavierInit, NormalInit, UniformInit, 8 | TruncNormalInit, KaimingInit, Caffe2XavierInit, PretrainedInit, 9 | glorot_init, glorot_orthogonal_init, he_orthogonal_init, 10 | kaiming_uniform_init 11 | ) 12 | 13 | __all__ = [ 14 | 'get_activation', 'get_norm', 'initialize', 'update_init_info', 'INITIALIZERS', 'constant_init', 'xavier_init', 15 | 'normal_init', 'trunc_normal_init', 'uniform_init', 'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 16 | 'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'TruncNormalInit', 'KaimingInit', 'Caffe2XavierInit', 17 | 'PretrainedInit', 'glorot_init', 'glorot_orthogonal_init', 'he_orthogonal_init', 'kaiming_uniform_init' 18 | ] -------------------------------------------------------------------------------- /druglib/apis/nn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 
2 | from .weight_init import (initialize, update_init_info, INITIALIZERS, constant_init, 3 | xavier_init, normal_init, trunc_normal_init, uniform_init, 4 | kaiming_init, caffe2_xavier_init, bias_init_with_prob, 5 | ConstantInit, XavierInit, NormalInit, UniformInit, 6 | TruncNormalInit, KaimingInit, Caffe2XavierInit, PretrainedInit, 7 | glorot_init, glorot_orthogonal_init, he_orthogonal_init, 8 | kaiming_uniform_init, 9 | ) 10 | 11 | __all__ = [ 12 | 'initialize', 'update_init_info', 'INITIALIZERS', 'constant_init', 'xavier_init', 'normal_init', 13 | 'trunc_normal_init', 'uniform_init', 'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 14 | 'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'TruncNormalInit', 'KaimingInit', 'Caffe2XavierInit', 15 | 'PretrainedInit', 'glorot_init', 'glorot_orthogonal_init', 'he_orthogonal_init', 'kaiming_uniform_init' 16 | ] -------------------------------------------------------------------------------- /druglib/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | from .runner import * -------------------------------------------------------------------------------- /druglib/core/runner/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 
@RUNNER_BUILDERS.register_module()
class DefaultRunnerBuilder:
    """
    Default builder for runners.
    It simply builds the runner described by ``runner_cfg`` from the
    ``RUNNERS`` registry.  Register your own builder in ``RUNNER_BUILDERS``
    to customize an existing `Runner` such as `EpochBasedRunner`, e.g. to
    inject new properties and functions.
    E.g.
        >>> from druglib.core.runner.builder import RUNNER_BUILDERS, build_runner, RUNNERS
        >>> # Define a new runner builder
        >>> @RUNNER_BUILDERS.register_module()
        >>> class MyRunnerBuilder:
        ...     def __init__(self, runner_cfg, default_args = None):
        ...         if not isinstance(runner_cfg, dict):
        ...             raise TypeError('runner_cfg should be a dict',
        ...                             f'but got {type(runner_cfg)}')
        ...         self.runner_cfg = runner_cfg
        ...         self.default_args = default_args
        ...
        ...     def __call__(self):
        ...         runner = RUNNERS.build(self.runner_cfg,
        ...                                default_args=self.default_args)
        ...         # Add new properties for existing runner
        ...         runner.my_name = 'my_runner'
        ...         runner.my_function = lambda self: print(self.my_name)
        ...         return runner
        >>> # `build_runner` reads the builder type from the `RunnerBuilder` key
        >>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
        ...                   RunnerBuilder='MyRunnerBuilder')
        >>> runner = build_runner(runner_cfg)

    Args:
        runner_cfg: config dict handed to ``RUNNERS.build``.
        default_args: optional defaults forwarded to ``RUNNERS.build``.

    Raises:
        TypeError: if ``runner_cfg`` is not a dict.
    """
    def __init__(
            self,
            runner_cfg:dict,
            default_args = None,
    ):
        cfg_is_dict = isinstance(runner_cfg, dict)
        if not cfg_is_dict:
            raise TypeError(
                f'`runner_cfg` must be a dict, but got {type(runner_cfg)}.'
            )
        self.default_args = default_args
        self.runner_cfg = runner_cfg

    def __call__(self):
        """Build and return the configured runner."""
        runner = RUNNERS.build(
            self.runner_cfg,
            default_args = self.default_args
        )
        return runner
def default_argument_parser():
    """Create the command-line parser used by the base trainer.

    Returns:
        An ``argparse.ArgumentParser`` with one positional argument
        (``config``) and the optional training flags defined below.  The
        parser is returned unparsed so callers can add further arguments.
    """
    parser = argparse.ArgumentParser(description="Base Trainer in MDLDrugLib")
    # Paths: config file, output directory and checkpoint to resume from.
    parser.add_argument('config', help = 'Train Config File Path.')
    parser.add_argument('--work-dir', help = 'The Dir to Save logs and model.')
    parser.add_argument('--resume-from', help = 'The Checkpoint File to Resume From.')

    # Run-control switches and reproducibility options.
    parser.add_argument('--auto-resume', action = 'store_true', help = 'Resume From The Latest Checkpoint Automatically.')
    parser.add_argument('--no-validate', action='store_true', help='Whether or Not to Evaluate The Checkpoint During Training.')
    parser.add_argument('--gpu-id', type = int, default = 0, help = 'ID of GPU to Use (Only Applicable to Non-distributed Training).')
    parser.add_argument('--seed', type = int, default = None, help = 'Random Seed for Experiment Reproduction.')
    parser.add_argument('--diff-seed', action='store_true',
                        help='Whether or Not to Set Different Seeds For Different Ranks.')
    parser.add_argument('--deterministic', action='store_true',
                        help='Whether to Set Deterministic Options For CUDNN Backend.')
    # key=value overrides, parsed by the project's DictAction.
    parser.add_argument('--cfg-options', nargs = '+', action = DictAction,
                        help = 'override some settings in the used config, the key-value pair '
                               'in xxx=yyy format will be merged into config file. If the value to '
                               'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
                               'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
                               'Note that the quotation marks are necessary and that no white space '
                               'is allowed.')
    # Launcher selection and the local rank argument.
    parser.add_argument(
        '--launcher',
        choices = ['none', 'pytorch', 'slurm', 'mpi'],
        default = 'none',
        help = 'Job Launcher')
    parser.add_argument('--local_rank', type=int, default=0)

    return parser
'LrUpdaterHook', 32 | 'FixedLrUpdaterHook', 'StepLrUpdaterHook', 'ExpLrUpdaterHook', 33 | 'PolyLrUpdaterHook', 'InvLrUpdaterHook', 'AnnealingLrUpdaterHook', 34 | 'FlatCosineAnnealingLrUpdaterHook', 'CosineRestartLrUpdaterHook', 35 | 'CyclicLrUpdaterHook', 'OneCycleLrUpdaterHook', 'OptimizerHook', 36 | 'Fp16OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 37 | 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', 38 | 'TextLoggerHook', 'TensorboardLoggerHook', 'NeptuneLoggerHook', 39 | 'WandbLoggerHook', 'DvcliveLoggerHook', 'MomentumUpdaterHook', 40 | 'StepMomentumUpdaterHook', 'AnnealingMomentumUpdaterHook', 41 | 'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook', 42 | 'SyncBuffersHook', 'EMAHook', 'ProfilerHook', 'GradientCumulativeOptimizerHook', 43 | 'GradientCumulativeFp16OptimizerHook', 'SegmindLoggerHook', 44 | 'ClearMLLoggerHook', 'EvalHook', 'DistEvalHook', 'ExpMomentumEMAHook', 45 | 'LinearMomentumEMAHook' 46 | ] -------------------------------------------------------------------------------- /druglib/core/runner/hooks/closure.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | from .hook import HOOKS, Hook 3 | 4 | # add some method for Hook 5 | @HOOKS.register_module() 6 | class ClosureHook(Hook): 7 | 8 | def __init__(self, fn_name, fn): 9 | assert hasattr(self, fn_name) 10 | assert callable(fn) 11 | setattr(self, fn_name, fn) 12 | 13 | @HOOKS.register_module() 14 | class MultiClosureHook(Hook): 15 | 16 | def __init__(self, fnmapping:dict): 17 | for fn_name, fn in fnmapping.items(): 18 | assert hasattr(self, fn_name) 19 | assert callable(fn) 20 | setattr(self, fn_name, fn) -------------------------------------------------------------------------------- /druglib/core/runner/hooks/hook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 
from druglib.utils import Registry, is_method_overridden

HOOKS = Registry('hook')


# Basic Hook
class Hook:
    """Base class for runner hooks.

    The stage-specific methods (``before_train_epoch``, ``after_val_iter``
    and so on) delegate to the generic ``before_epoch`` / ``after_epoch`` /
    ``before_iter`` / ``after_iter`` methods, which do nothing by default.
    """

    # Stage names, in the order reported by ``get_triggered_stages``.
    stages = (
        'before_run', 'before_train_epoch', 'before_train_iter',
        'after_train_iter', 'after_train_epoch', 'before_val_epoch',
        'before_val_iter', 'after_val_iter', 'after_val_epoch',
        'after_run'
    )

    # ---- generic callbacks (no-ops) ----
    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    # ---- stage callbacks, forwarding to the generic ones ----
    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    # ---- scheduling helpers ----
    def every_n_epochs(self, runner, n):
        """Return True when ``runner.epoch + 1`` is a multiple of ``n``.

        Always False when ``n`` is not positive.
        """
        if n > 0:
            return (runner.epoch + 1) % n == 0
        return False

    def every_n_inner_iters(self, runner, n):
        """Return True when ``runner.inner_iter + 1`` is a multiple of ``n``.

        ``inner_iter`` counts iterations inside the current epoch for an
        epoch-based runner. Always False when ``n`` is not positive.
        """
        if n > 0:
            return (runner.inner_iter + 1) % n == 0
        return False

    def every_n_iters(self, runner, n):
        """Return True when ``runner.iter + 1`` is a multiple of ``n``.

        ``iter`` counts iterations across all epochs run so far. Always
        False when ``n`` is not positive.
        """
        if n > 0:
            return (runner.iter + 1) % n == 0
        return False

    def end_of_epoch(self, runner):
        """Return True on the last iteration of the current data loader.

        The epoch is considered finished once the loader is exhausted,
        i.e. ``runner.inner_iter + 1 == len(runner.data_loader)``.
        """
        return runner.inner_iter + 1 == len(runner.data_loader)

    def is_last_epoch(self, runner):
        """Return True when the current epoch is the runner's final one."""
        return runner.epoch + 1 == runner._max_epochs

    def is_last_iter(self, runner):
        """Return True when the current iteration is the runner's final one."""
        return runner.iter + 1 == runner._max_iters

    def get_triggered_stages(self):
        """List the stages this hook overrides, ordered as in ``Hook.stages``."""
        # Stages whose own method is overridden.
        triggered = {
            name for name in Hook.stages
            if is_method_overridden(name, Hook, self)
        }

        # Overriding a generic method triggers every stage that forwards
        # to it.
        generic_to_stages = {
            'before_epoch': ['before_train_epoch', 'before_val_epoch'],
            'after_epoch': ['after_train_epoch', 'after_val_epoch'],
            'before_iter': ['before_train_iter', 'before_val_iter'],
            'after_iter': ['after_train_iter', 'after_val_iter'],
        }
        for generic, fanned_out in generic_to_stages.items():
            if is_method_overridden(generic, Hook, self):
                triggered.update(fanned_out)

        return [name for name in Hook.stages if name in triggered]

# --------------------------------------------------------------------------------
# /druglib/core/runner/hooks/iter_timer.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
import time
from .hook import HOOKS, Hook


@HOOKS.register_module()
class IterTimerHook(Hook):
    """Record per-iteration timings into ``runner.log_buffer``.

    ``data_time`` is the time from the end of the previous iteration (or
    the start of the epoch) to the start of the current one; ``time`` is
    the time from that same reference point to the end of the iteration.
    """

    def before_epoch(self, runner):
        self.t = time.time()

    def before_iter(self, runner):
        waited = time.time() - self.t
        runner.log_buffer.update({'data_time': waited})

    def after_iter(self, runner):
        elapsed = time.time() - self.t
        runner.log_buffer.update({'time': elapsed})
        # Restart the clock for the next iteration.
        self.t = time.time()

# --------------------------------------------------------------------------------
# /druglib/core/runner/hooks/logger/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
from .base import LoggerHook
from .clearml import ClearMLLoggerHook
from .dvclive import DvcliveLoggerHook
from .mlflow import MlflowLoggerHook
from .neptune import NeptuneLoggerHook
from .pavi import PaviLoggerHook
from .segmind import SegmindLoggerHook
from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook
from .wandb import WandbLoggerHook

__all__ = [
    'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
    'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook',
    'NeptuneLoggerHook', 'DvcliveLoggerHook', 'SegmindLoggerHook',
    'ClearMLLoggerHook'
]

# --------------------------------------------------------------------------------
# /druglib/core/runner/hooks/logger/clearml.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
from typing import Dict, Optional

from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class ClearMLLoggerHook(LoggerHook):
    """Logger hook that reports scalar metrics to ClearML.

    The `clearml`_ package must be installed.

    Args:
        init_kwargs (dict): Keyword arguments forwarded to
            `clearml.Task.init`. See `taskinit`_ for the accepted keys.
        interval (int): Logging interval (every k iterations). Default 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`. Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: False.
        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.

    .. _clearml:
        https://clear.ml/docs/latest/docs/
    .. _taskinit:
        https://clear.ml/docs/latest/docs/references/sdk/task/#taskinit
    """

    def __init__(self,
                 init_kwargs: Optional[Dict] = None,
                 interval: int = 10,
                 ignore_last: bool = True,
                 reset_flag: bool = False,
                 by_epoch: bool = True):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.import_clearml()
        self.init_kwargs = init_kwargs

    def import_clearml(self):
        """Import ``clearml`` lazily and keep the module on the hook."""
        try:
            import clearml
        except ImportError:
            raise ImportError(
                'Please run "pip install clearml" to install clearml')
        self.clearml = clearml

    @master_only
    def before_run(self, runner) -> None:
        super().before_run(runner)
        # A missing or empty ``init_kwargs`` means "use ClearML defaults".
        task_kwargs = self.init_kwargs or {}
        self.task = self.clearml.Task.init(**task_kwargs)
        self.task_logger = self.task.get_logger()

    @master_only
    def log(self, runner) -> None:
        loggable = self.get_loggable_tags(runner)
        for name, value in loggable.items():
            # The tag name is used as both the plot title and the series.
            self.task_logger.report_scalar(
                name, name, value, self.get_iter(runner))

# --------------------------------------------------------------------------------
# /druglib/core/runner/hooks/logger/dvclive.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
from pathlib import Path
from typing import Optional

from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class DvcliveLoggerHook(LoggerHook):
    """Class to log metrics with dvclive.

    It requires `dvclive`_ to be installed.

    Args:
        model_file (str): Default None. If not None, after each epoch the
            model will be saved to {model_file}.
        interval (int): Logging interval (every k iterations). Default 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`. Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: False.
        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
        kwargs: Arguments for instantiating `Live`_.

    .. _dvclive:
        https://dvc.org/doc/dvclive

    .. _Live:
        https://dvc.org/doc/dvclive/api-reference/live#parameters
    """

    def __init__(self,
                 model_file: Optional[str] = None,
                 interval: int = 10,
                 ignore_last: bool = True,
                 reset_flag: bool = False,
                 by_epoch: bool = True,
                 **kwargs):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.model_file = model_file
        self.import_dvclive(**kwargs)

    def import_dvclive(self, **kwargs) -> None:
        """Import ``dvclive`` lazily and create the ``Live`` instance."""
        try:
            from dvclive import Live
        except ImportError:
            raise ImportError(
                'Please run "pip install dvclive" to install dvclive')
        self.dvclive = Live(**kwargs)

    @master_only
    def log(self, runner) -> None:
        tags = self.get_loggable_tags(runner)
        if tags:
            self.dvclive.set_step(self.get_iter(runner))
            for k, v in tags.items():
                self.dvclive.log(k, v)

    @master_only
    def after_train_epoch(self, runner) -> None:
        super().after_train_epoch(runner)
        if self.model_file is not None:
            # Directory and file name are both derived from ``model_file``.
            runner.save_checkpoint(
                Path(self.model_file).parent,
                filename_tmpl=Path(self.model_file).name,
                create_symlink=False,
            )

# --------------------------------------------------------------------------------
# /druglib/core/runner/hooks/logger/mlflow.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
from typing import Dict, Optional

from druglib.utils import TORCH_VERSION
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class MlflowLoggerHook(LoggerHook):
    """Class to log metrics and (optionally) a trained model to MLflow.

    It requires `MLflow`_ to be installed.

    Args:
        exp_name (str, optional): Name of the experiment to be used.
            Default None. If not None, set the active experiment.
            If experiment does not exist, an experiment with provided name
            will be created.
        tags (Dict[str], optional): Tags for the current run.
            Default None. If not None, set tags for the current run.
        log_model (bool, optional): Whether to log an MLflow artifact.
            Default True. If True, log runner.model as an MLflow artifact
            for the current run.
        interval (int): Logging interval (every k iterations). Default: 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`. Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: False.
        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.

    .. _MLflow:
        https://www.mlflow.org/docs/latest/index.html
    """

    def __init__(self,
                 exp_name: Optional[str] = None,
                 tags: Optional[Dict] = None,
                 log_model: bool = True,
                 interval: int = 10,
                 ignore_last: bool = True,
                 reset_flag: bool = False,
                 by_epoch: bool = True):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.import_mlflow()
        self.exp_name = exp_name
        self.tags = tags
        self.log_model = log_model

    def import_mlflow(self) -> None:
        """Import ``mlflow`` and ``mlflow.pytorch`` lazily."""
        try:
            import mlflow
            import mlflow.pytorch as mlflow_pytorch
        except ImportError:
            raise ImportError(
                'Please run "pip install mlflow" to install mlflow')
        self.mlflow = mlflow
        self.mlflow_pytorch = mlflow_pytorch

    @master_only
    def before_run(self, runner) -> None:
        super().before_run(runner)
        if self.exp_name is not None:
            self.mlflow.set_experiment(self.exp_name)
        if self.tags is not None:
            self.mlflow.set_tags(self.tags)

    @master_only
    def log(self, runner) -> None:
        tags = self.get_loggable_tags(runner)
        if tags:
            self.mlflow.log_metrics(tags, step=self.get_iter(runner))

    @master_only
    def after_run(self, runner) -> None:
        if self.log_model:
            # Pin the torch version recorded with the logged model.
            self.mlflow_pytorch.log_model(
                runner.model,
                'models',
                pip_requirements=[f'torch=={TORCH_VERSION}'])

# --------------------------------------------------------------------------------
# /druglib/core/runner/hooks/logger/neptune.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
from typing import Dict, Optional

from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class NeptuneLoggerHook(LoggerHook):
    """Class to log metrics to NeptuneAI.

    It requires `Neptune`_ to be installed.

    Args:
        init_kwargs (dict): a dict contains the initialization keys as below:

            - project (str): Name of a project in a form of
              namespace/project_name. If None, the value of NEPTUNE_PROJECT
              environment variable will be taken.
            - api_token (str): User’s API token. If None, the value of
              NEPTUNE_API_TOKEN environment variable will be taken. Note: It is
              strongly recommended to use NEPTUNE_API_TOKEN environment
              variable rather than placing your API token in plain text in your
              source code.
            - name (str, optional, default is 'Untitled'): Editable name of the
              run. Name is displayed in the run's Details and in Runs table as
              a column.

            Check https://docs.neptune.ai/api-reference/neptune#init for more
            init arguments.
        interval (int): Logging interval (every k iterations). Default: 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than ``interval``. Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: True.
        with_step (bool): If True, the step will be logged from
            ``self.get_iter``. Otherwise, step will not be logged.
            Default: True.
        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.

    .. _Neptune:
        https://docs.neptune.ai
    """

    def __init__(self,
                 init_kwargs: Optional[Dict] = None,
                 interval: int = 10,
                 ignore_last: bool = True,
                 reset_flag: bool = True,
                 with_step: bool = True,
                 by_epoch: bool = True):

        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.import_neptune()
        self.init_kwargs = init_kwargs
        self.with_step = with_step

    def import_neptune(self) -> None:
        """Import ``neptune.new`` lazily; the run is created in ``before_run``."""
        try:
            import neptune.new as neptune
        except ImportError:
            raise ImportError(
                'Please run "pip install neptune-client" to install neptune')
        self.neptune = neptune
        self.run = None

    @master_only
    def before_run(self, runner) -> None:
        if self.init_kwargs:
            self.run = self.neptune.init(**self.init_kwargs)
        else:
            self.run = self.neptune.init()

    @master_only
    def log(self, runner) -> None:
        tags = self.get_loggable_tags(runner)
        if tags:
            step = self.get_iter(runner)
            if self.with_step:
                for tag_name, tag_value in tags.items():
                    self.run[tag_name].log(  # type: ignore
                        tag_value, step=step)
            else:
                # Build the payload once on a copy. Inserting 'global_step'
                # into ``tags`` while iterating it raised "dictionary
                # changed size during iteration".
                payload = dict(tags)
                payload['global_step'] = step
                for tag_name in tags:
                    self.run[tag_name].log(payload)  # type: ignore

    @master_only
    def after_run(self, runner) -> None:
        self.run.stop()  # type: ignore

# --------------------------------------------------------------------------------
# /druglib/core/runner/hooks/logger/segmind.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class SegmindLoggerHook(LoggerHook):
    """Class to log metrics to Segmind.

    It requires `Segmind`_ to be installed.

    Args:
        interval (int): Logging interval (every k iterations). Default: 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`. Default True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default False.
        by_epoch (bool): Whether EpochBasedRunner is used. Default True.

    .. _Segmind:
        https://docs.segmind.com/python-library
    """

    def __init__(self,
                 interval: int = 10,
                 ignore_last: bool = True,
                 reset_flag: bool = False,
                 by_epoch=True):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.import_segmind()

    def import_segmind(self) -> None:
        """Import ``segmind`` lazily and keep its logging helpers."""
        try:
            import segmind
        except ImportError:
            raise ImportError(
                "Please run 'pip install segmind' to install segmind")
        self.log_metrics = segmind.tracking.fluent.log_metrics
        self.mlflow_log = segmind.utils.logging_utils.try_mlflow_log

    @master_only
    def log(self, runner) -> None:
        tags = self.get_loggable_tags(runner)
        if tags:
            # logging metrics to segmind
            self.mlflow_log(
                self.log_metrics, tags, step=runner.epoch, epoch=runner.epoch)

# --------------------------------------------------------------------------------
# /druglib/core/runner/hooks/logger/tensorboard.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
import os.path as osp
from typing import Optional

from druglib.utils import TORCH_VERSION, digit_version
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class TensorboardLoggerHook(LoggerHook):
    """Class to log metrics to Tensorboard.

    Args:
        log_dir (string): Save directory location. Default: None. If default
            values are used, directory location is ``runner.work_dir``/tf_logs.
        interval (int): Logging interval (every k iterations). Default: 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`. Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: False.
        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
    """

    def __init__(self,
                 log_dir: Optional[str] = None,
                 interval: int = 10,
                 ignore_last: bool = True,
                 reset_flag: bool = False,
                 by_epoch: bool = True):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.log_dir = log_dir

    @master_only
    def before_run(self, runner) -> None:
        super().before_run(runner)
        # torch.utils.tensorboard exists from PyTorch 1.1; older versions
        # and parrots fall back to tensorboardX.
        if (TORCH_VERSION == 'parrots'
                or digit_version(TORCH_VERSION) < digit_version('1.1')):
            try:
                from tensorboardX import SummaryWriter
            except ImportError:
                raise ImportError('Please install tensorboardX to use '
                                  'TensorboardLoggerHook.')
        else:
            try:
                from torch.utils.tensorboard import SummaryWriter
            except ImportError:
                raise ImportError(
                    'Please run "pip install future tensorboard" to install '
                    'the dependencies to use torch.utils.tensorboard '
                    '(applicable to PyTorch 1.1 or higher)')

        if self.log_dir is None:
            self.log_dir = osp.join(runner.work_dir, 'tf_logs')
        self.writer = SummaryWriter(self.log_dir)

    @master_only
    def log(self, runner) -> None:
        tags = self.get_loggable_tags(runner, allow_text=True)
        for tag, val in tags.items():
            # String values go to the text panel, everything else is a scalar.
            if isinstance(val, str):
                self.writer.add_text(tag, val, self.get_iter(runner))
            else:
                self.writer.add_scalar(tag, val, self.get_iter(runner))

    @master_only
    def after_run(self, runner) -> None:
        self.writer.close()

# --------------------------------------------------------------------------------
# /druglib/core/runner/hooks/memory.py:
# --------------------------------------------------------------------------------
#
# Copyright (c) MDLDrugLib. All rights reserved.
import torch

from .hook import HOOKS, Hook


@HOOKS.register_module()
class EmptyCacheHook(Hook):
    """Call ``torch.cuda.empty_cache()`` at the selected points.

    Args:
        before_epoch (bool): empty the cache before each epoch.
            Default: False.
        after_epoch (bool): empty the cache after each epoch.
            Default: True.
        after_iter (bool): empty the cache after each iteration.
            Default: False.
    """

    def __init__(
            self,
            before_epoch: bool = False,
            after_epoch: bool = True,
            after_iter: bool = False,
    ):
        # Stored under private names because the public names are the
        # hook methods themselves.
        self._before_epoch = before_epoch
        self._after_epoch = after_epoch
        self._after_iter = after_iter

    def after_iter(self, runner):
        if not self._after_iter:
            return
        torch.cuda.empty_cache()

    def before_epoch(self, runner):
        if not self._before_epoch:
            return
        torch.cuda.empty_cache()

    def after_epoch(self, runner):
        if not self._after_epoch:
            return
        torch.cuda.empty_cache()

# --------------------------------------------------------------------------------
# /druglib/core/runner/hooks/sampler_seed.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
from .hook import HOOKS, Hook


@HOOKS.register_module()
class DistSamplerSeedHook(Hook):
    """
    Forward the current epoch to the data-loading sampler.

    For distributed training this is only useful together with
    :obj:`EpochBasedRunner`; :obj:`IterBasedRunner` achieves the same
    purpose with :obj:`IterLoader`.
    """
    def before_epoch(self, runner):
        loader = runner.data_loader
        if hasattr(loader.sampler, 'set_epoch'):
            # The loader's own sampler accepts the epoch.
            loader.sampler.set_epoch(runner.epoch)
        elif hasattr(loader.batch_sampler.sampler, 'set_epoch'):
            # Otherwise try the sampler wrapped by the batch sampler.
            loader.batch_sampler.sampler.set_epoch(runner.epoch)

# --------------------------------------------------------------------------------
# /druglib/core/runner/hooks/sync_buffer.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib.
# All rights reserved.
from ..dist_utils import allreduce_params
from .hook import HOOKS, Hook


@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """
    Synchronize model buffers (for example BN ``running_mean`` and
    ``running_var``) at the end of each epoch.

    Args:
        distributed (bool): Whether distributed training is used. The hook
            only acts when this is True. Defaults to True.
    """
    def __init__(self, distributed:bool = True):
        self.distributed = distributed

    def after_epoch(self, runner):
        if not self.distributed:
            return
        allreduce_params(runner.model.buffers())

# --------------------------------------------------------------------------------
# /druglib/core/runner/log_buffer.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
from collections import OrderedDict

import numpy as np


class LogBuffer:
    """Accumulate logged values and expose their weighted averages.

    ``val_history`` and ``n_history`` hold, per key, the recorded values
    and their weights. ``output`` holds the averages computed by
    ``average``, and ``ready`` is True once ``output`` has been filled.
    """

    def __init__(self):
        self.val_history = OrderedDict()
        self.n_history = OrderedDict()
        self.output = OrderedDict()
        self.ready = False

    def clear(self):
        """Drop the full history and the computed output."""
        self.val_history.clear()
        self.n_history.clear()
        self.clear_output()

    def clear_output(self):
        """Drop only the computed averages and mark the buffer not ready."""
        self.output.clear()
        self.ready = False

    def update(
            self,
            vars:dict,
            count:int = 1,
    ):
        """Append each value in ``vars`` to its history with weight ``count``."""
        assert isinstance(vars, dict)
        for key, value in vars.items():
            if key not in self.val_history:
                self.val_history[key], self.n_history[key] = [], []
            self.val_history[key].append(value)
            self.n_history[key].append(count)

    def average(
            self,
            n:int = 0
    ):
        """
        Average the latest ``n`` values of every key, weighted by their
        counts; ``n == 0`` averages the whole history.
        """
        assert n >= 0
        for key in self.val_history:
            # ``[-0:]`` is the full list, which is how n == 0 means "all".
            values = np.array(self.val_history[key][-n:])
            counts = np.array(self.n_history[key][-n:])
            weighted_total = np.sum(values * counts)
            self.output[key] = weighted_total / np.sum(counts)
        self.ready = True
# --------------------------------------------------------------------------------
# /druglib/core/runner/optimizer/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
from .builder import OPTIMIZERS, OPTIMIZERS_BUILDERS, builder_optimizer
from .default_OptBuilder import DefaultOptimizerBuilder
from .optimizers import Lion

__all__ = [
    'OPTIMIZERS', 'OPTIMIZERS_BUILDERS', 'builder_optimizer', 'DefaultOptimizerBuilder',
    'Lion',
]

# --------------------------------------------------------------------------------
# /druglib/core/runner/optimizer/builder.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
import copy
import inspect
import torch
from typing import Dict, List

from ....utils import Registry, build_from_cfg

OPTIMIZERS = Registry('optimizer')
OPTIMIZERS_BUILDERS = Registry('optimizer builder')


def register_torch_optimizers() -> List[str]:
    """Register every optimizer class exposed by ``torch.optim``.

    Returns:
        List[str]: names of the registered optimizer classes.
    """
    registered = []
    for name in dir(torch.optim):
        # Names containing an underscore are skipped. In torch 1.9.0 this
        # keeps ['ASGD', 'Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax',
        # 'LBFGS', 'RMSprop', 'Rprop', 'SGD', 'SparseAdam'].
        if "_" in name:
            continue
        candidate = getattr(torch.optim, name)
        if not inspect.isclass(candidate):
            continue
        if not issubclass(candidate, torch.optim.Optimizer):
            continue
        OPTIMIZERS.register_module()(candidate)
        registered.append(name)
    return registered


TORCH_OPTIMIZERS = register_torch_optimizers()


def build_optimizer_builder(cfg: Dict):
    """Instantiate an optimizer builder from ``cfg`` via ``OPTIMIZERS_BUILDERS``."""
    return build_from_cfg(
        cfg,
        OPTIMIZERS_BUILDERS,
    )


def builder_optimizer(
        model,
        cfg: Dict,
):
    """Build the optimizer for ``model`` described by ``cfg``.

    ``cfg`` is deep-copied, so the caller's dict is left untouched. The
    optional keys ``OptimizerBuilder`` (builder type, default
    ``"DefaultOptimizerBuilder"``) and ``paramwise_cfg`` (default None) are
    removed from the copy; what remains is passed on as ``optimizer_cfg``.

    Returns:
        The object returned by calling the optimizer builder on ``model``.
    """
    remaining_cfg = copy.deepcopy(cfg)
    builder_type = remaining_cfg.pop("OptimizerBuilder", "DefaultOptimizerBuilder")
    paramwise_cfg = remaining_cfg.pop("paramwise_cfg", None)

    optim_builder = build_optimizer_builder({
        "type": builder_type,
        "optimizer_cfg": remaining_cfg,
        "paramwise_cfg": paramwise_cfg,
    })
    return optim_builder(model)

# --------------------------------------------------------------------------------
# /druglib/core/runner/parallel/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
from .registry import MODULE_WRAPPERS
from .scatter_gather import scatter, scatter_kwargs
from .data_parallel import MDLDataParallel
from .distributed import MDLDistributedDataParallel
from .collate import collate
from .utils import is_module_wrapper, is_mlu_available, get_device, build_dp, build_ddp



__all__ = [
    'MODULE_WRAPPERS', 'is_module_wrapper', 'scatter', 'scatter_kwargs',
    'MDLDataParallel', 'MDLDistributedDataParallel', 'collate', 'is_mlu_available',
    'get_device', 'build_dp', 'build_ddp',
]

# --------------------------------------------------------------------------------
# /druglib/core/runner/parallel/_functions.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
from typing import Union, Optional, List, Tuple

import torch
from torch import Tensor
from torch.nn.parallel._functions import _get_stream
from druglib.data import BaseData
from druglib.utils import TORCH_VERSION, digit_version


def scatter(
        input: Union[List, Tensor],
        devices: List,
        streams: Optional[List] = None,
) -> Union[List, Tensor]:
    """Copy ``input`` to the given devices.

    Args:
        input: a Tensor/BaseData, or a (possibly nested) list of them.
        devices: target device ids; ``[-1]`` means "stay on CPU".
        streams: optional CUDA streams, one per device.

    Returns:
        The moved object, or a list of moved objects for list input.

    Raises:
        Exception: if ``input`` is neither a list nor a Tensor/BaseData.
    """
    if streams is None:
        streams = [None] * len(devices)
    if isinstance(input, list):
        # Ceil division: consecutive items share a device in chunks.
        chunk_size = (len(input) - 1) // len(devices) + 1
        # scatter: distribute data to different gpus
        outputs = [scatter(
            input[i], [devices[i // chunk_size]], [streams[i // chunk_size]]
        ) for i in range(len(input))]
        return outputs
    elif isinstance(input, (Tensor, BaseData)):
        output = input.contiguous()
        stream = streams[0] if output.numel() > 0 else None

        if devices != [-1]:
            # torch.cuda.stream(None) is a no-op context, so a None stream
            # performs the copy on the current stream.
            with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
                output = output.cuda(devices[0], non_blocking=False)
        return output
    else:
        raise Exception(f"Unknown type {type(input)}")


def get_input_device(
        input: Union[List, Tensor, BaseData],
) -> int:
    """Return the CUDA device id of ``input``, or -1 when it is on CPU.

    For a list, the first non-CPU device found among its items is returned.

    Raises:
        Exception: if ``input`` is neither a list nor a Tensor/BaseData.
    """
    if isinstance(input, List):
        for t in input:
            device = get_input_device(t)
            if device != -1:
                return device
        return -1
    elif isinstance(input, (Tensor, BaseData)):
        # ``Tensor.is_cuda`` is a bool attribute, so calling it raised
        # TypeError. A container type may still expose it as a method
        # (TODO confirm for BaseData), so both spellings are accepted.
        on_cuda = input.is_cuda
        if callable(on_cuda):
            on_cuda = on_cuda()
        return input.get_device() if on_cuda else -1
    else:
        raise Exception(f"Unknown type {type(input)}")


def synchronize_stream(
        output: Union[List, Tensor, BaseData],
        devices: List,
        streams: List,
) -> None:
    """Make each device's current stream wait for its copy stream.

    Raises:
        Exception: if ``output`` is neither a list nor a Tensor/BaseData.
    """
    if isinstance(output, List):
        chunk_size = len(output) // len(devices)
        for d in range(len(devices)):
            for c in range(chunk_size):
                synchronize_stream(
                    output[ d * chunk_size + c ],
                    [devices[d]],
                    [streams[d]]
                )
    elif isinstance(output, (Tensor, BaseData)):
        if output.numel() > 0:
            with torch.cuda.device(devices[0]):
                main_stream = torch.cuda.current_stream()
                main_stream.wait_stream(streams[0])
                # Tell the allocator the memory is in use on main_stream.
                output.record_stream(main_stream)
    else:
        raise Exception(f"Unknown type {type(output)}")


class Scatter:
    """Scatter entry point that also handles CPU targets (``[-1]``)."""

    @staticmethod
    def forward(
            target_gpus: List[int],
            input: Union[Tensor, List, BaseData],
    ) -> Tuple:
        """Scatter ``input`` to ``target_gpus`` and return the pieces as a tuple."""
        device = get_input_device(input)
        streams = None
        if device == -1 and target_gpus != [-1]:
            # Perform CPU to GPU copies in a background stream
            if digit_version(TORCH_VERSION) < digit_version('2.0.0'):
                streams = [_get_stream(gpu) for gpu in target_gpus]
            else:
                # From torch 2.0 the helper is given a torch.device.
                streams = [_get_stream(torch.device(f'cuda:{gpu}')) for gpu in target_gpus]
        outputs = scatter(input, target_gpus, streams)
        if streams is not None:
            synchronize_stream(outputs, target_gpus, streams)

        return tuple(outputs) if isinstance(outputs, list) else (outputs, )

# --------------------------------------------------------------------------------
# /druglib/core/runner/parallel/registry.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
import torch.nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

from druglib.utils import Registry

MODULE_WRAPPERS = Registry('module wrapper')
MODULE_WRAPPERS.register_module(module = DataParallel)
MODULE_WRAPPERS.register_module(module = DistributedDataParallel)

# --------------------------------------------------------------------------------
# /druglib/core/runner/parallel/scatter_gather.py:
# --------------------------------------------------------------------------------
# Copyright (c) MDLDrugLib. All rights reserved.
def scatter(
    inputs: ScatterInputs,
    target_gpus: List[int],
    dim: int = 0
) -> list:
    """
    Scatter inputs to target gpus.

    The only difference from original :func:`scatter` is to add support for
    :type:`DataContainer`.

    Args:
        inputs: A tensor, ``BaseData``, ``DataContainer``, or an arbitrarily
            nested tuple/list/dict of them.
        target_gpus: Target device ids; ``[-1]`` selects the CPU code path.
        dim: Dimension handed to the original torch ``Scatter`` for objects
            that are scattered on GPU.

    Returns:
        The result of the recursive ``scatter_map``; for container inputs
        this is a list with one entry per scattered piece, each entry having
        the same container structure as ``inputs``.
    """
    def scatter_map(obj):
        if isinstance(obj, (Tensor, BaseData)):
            if target_gpus != [-1]:
                return OrigScatter.apply(target_gpus, None, dim, obj)
            else:
                # for CPU inference we use self-implement scatter
                return Scatter.forward(target_gpus, obj)
        if isinstance(obj, DataContainer):
            # Leaf case 1: unwrap the container.
            if obj.cpu_only:
                # cpu_only payloads are returned untouched (no device copy).
                return obj.data
            else:
                return Scatter.forward(target_gpus, obj.data)
        if isinstance(obj, tuple) and len(obj) > 0:
            # Scatter each element, then transpose with zip so the outer
            # list is indexed by scattered piece and each entry is a tuple
            # of that piece's elements.
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            # Same transposition as for tuples, but every entry is a list.
            out = list(map(list, zip(*map(scatter_map, obj))))
            return out
        if isinstance(obj, dict) and len(obj) > 0:
            # obj.items() yields (key, value) tuples, which take the tuple
            # branch above; transposing the results gives, per piece, a
            # sequence of (key, scattered_value) pairs, which `type(obj)`
            # turns back into a mapping of the original type.
            out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
            return out
        # Leaf case 2: any other object (and empty containers) is repeated
        # once per target gpu.
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(
    inputs: ScatterInputs,
    kwargs: ScatterInputs,
    target_gpus: List[int],
    dim: int = 0,
) -> Tuple[tuple, tuple]:
    """
    Scatter with support for kwargs dictionary.

    Positional ``inputs`` and keyword ``kwargs`` are scattered independently
    (a falsy value is treated as "nothing to scatter"), then the shorter
    result is padded with empty tuples / empty dicts so both have the same
    length.

    Returns:
        A pair ``(inputs, kwargs)`` of equally long tuples.
    """
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        # Pad missing positional arguments with empty tuples.
        length = len(kwargs) - len(inputs)
        inputs.extend([() for _ in range(length)])
    elif len(kwargs) < len(inputs):
        # Pad missing keyword arguments with empty dicts.
        length = len(inputs) - len(kwargs)
        kwargs.extend([{} for _ in range(length)])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
def get_priority(
    priority:Union[int, str, Priority],
) -> int:
    """
    Get priority value.

    Args:
        priority:Union[int, str, Priority]: Priority, given as an integer in
            ``[0, 100]``, a :class:`Priority` member, or the (case-insensitive)
            name of a :class:`Priority` member.

    Returns:
        int: The priority value.

    Raises:
        ValueError: If an integer priority lies outside ``[0, 100]``.
        KeyError: If a string does not name a :class:`Priority` member.
        TypeError: If ``priority`` is none of the supported types.
    """
    if isinstance(priority, int):
        if priority < 0 or priority > 100:
            raise ValueError('`priority` must be between 0 and 100.')
        return priority
    elif isinstance(priority, Priority):
        return priority.value
    elif isinstance(priority, str):
        # Member names are upper-case, so 'normal' and 'NORMAL' both resolve.
        return Priority[priority.upper()].value
    else:
        raise TypeError(
            '`priority` must be an integer, a string or a Priority enum value.'
        )
@dataclass
class CVTensorAttr(TensorAttr):
    """Attribute class for CV Data, whose `group_name` is 'cv'."""
    def __init__(
        self,
        attr_name = _field_status.UNSET,
        index = _field_status.UNSET,
    ):
        # The first positional argument of the parent constructor is fixed to
        # "cv" (presumably its `group_name` -- confirm against TensorAttr),
        # so callers only supply `attr_name` and `index`.
        super().__init__("cv", attr_name, index)
def collate(
    batch: Sequence,
    samples_per_gpu: int = 1,
    follow_batch: Optional[Union[List[str]]] = None,
    exclude_keys: Optional[Union[List[str]]] = None,
):
    """
    Collate a sequence of samples into per-GPU batches.

    Dispatch is on the type of the first element of ``batch``:

    * ``BaseData``: consecutive chunks of ``samples_per_gpu`` samples are
      merged with ``Batch.from_data_list`` and the resulting list of batches
      is wrapped in a single ``DataContainer``.
    * sequence (other than ``str``/``bytes``): the samples are transposed
      and every field is collated recursively; a list is returned.
    * mapping: every key is collated recursively; a dict is returned.
    * anything else: delegated to torch's ``default_collate``.

    Args:
        batch: Samples produced by the dataset.
        samples_per_gpu: Number of samples merged into one per-GPU batch.
        follow_batch: Forwarded to ``Batch.from_data_list``.
        exclude_keys: Forwarded to ``Batch.from_data_list``.

    Raises:
        TypeError: If ``batch`` is not a sequence.
    """
    if not isinstance(batch, Sequence):
        # `batch.dtype` does not exist on arbitrary objects; report the
        # type so the intended TypeError is actually raised.
        raise TypeError(f'{type(batch)} is not supported.')

    batch_ele = batch[0]
    if isinstance(batch_ele, BaseData):
        # TODO: this collate function is better suited to cv tasks with
        # image data and PyG graph learning tasks; other tasks may require
        # a new collate function.
        batch_list = [
            Batch.from_data_list(
                batch[i:i + samples_per_gpu],
                follow_batch = follow_batch,
                exclude_keys = exclude_keys,
            )
            for i in range(0, len(batch), samples_per_gpu)
        ]
        return DataContainer(
            data = batch_list,
        )

    elif isinstance(batch_ele, Sequence) and not isinstance(batch_ele, (str, bytes)):
        # str/bytes are sequences of themselves: transposing them would
        # recurse forever, so they fall through to default_collate.
        transposed = zip(*batch)
        return [collate(
            samples,
            samples_per_gpu,
            follow_batch,
            exclude_keys
        ) for samples in transposed]

    elif isinstance(batch_ele, Mapping):
        return {
            key: collate(
                [b[key] for b in batch],
                samples_per_gpu,
                follow_batch,
                exclude_keys
            ) for key in batch_ele
        }

    else:
        return default_collate(batch)
class MappingView(object):
    """
    Read-only view over a mapping, optionally restricted to selected keys.

    With no extra positional arguments the view covers every key of the
    mapping; otherwise it covers the given keys, in the given order, that
    are present in the mapping.
    """
    # NOTE(review): the attribute name below is not `__class_getitem__`, so
    # it does not enable `MappingView[...]` subscription -- confirm intent.
    __class__getitem__ = classmethod(list)

    def __init__(
        self,
        _mapping: Mapping,
        *args: List[str],
    ):
        self._args = args
        self._mapping = _mapping

    def _keys(self) -> Iterable:
        """Return the keys exposed by this view."""
        if not self._args:
            return self._mapping.keys()
        selected = []
        for name in self._args:
            if name in self._mapping:
                selected.append(name)
        return selected

    def __len__(self) -> int:
        return len(self._keys())

    def __repr__(self) -> str:
        shown = dict((name, self._mapping[name]) for name in self._keys())
        cls_name = self.__class__.__name__
        return f"{cls_name}({shown})"


class KeysView(MappingView):
    """Iterate over the keys exposed by the view."""
    def __iter__(self) -> Iterable:
        for name in self._keys():
            yield name


class ValuesView(MappingView):
    """Iterate over the values of the keys exposed by the view."""
    def __iter__(self) -> Iterable:
        yield from (self._mapping[name] for name in self._keys())


class ItemsView(MappingView):
    """Iterate over ``(key, value)`` pairs of the keys exposed by the view."""
    def __iter__(self) -> Iterable:
        yield from ((name, self._mapping[name]) for name in self._keys())
class CastMixin:
    """Copied from https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/utils/mixin.py"""
    @classmethod
    def cast(cls, *args, **kwargs):
        """
        Build an instance of ``cls`` from loosely typed arguments.

        A single positional argument is interpreted by type: ``None`` and
        existing ``CastMixin`` instances are returned unchanged, a tuple or
        list is unpacked as positional arguments, a dict as keyword
        arguments. Every other call is forwarded to the constructor as is.
        """
        # Anything but exactly one bare positional argument goes straight
        # to the constructor.
        if len(args) != 1 or kwargs:
            return cls(*args, **kwargs)

        (single,) = args
        if single is None or isinstance(single, CastMixin):
            return single
        if isinstance(single, (tuple, list)):
            return cls(*single)
        if isinstance(single, dict):
            return cls(**single)
        return cls(single)
"bbox", "cvmask", "cvlabel", "proposal", "seg"] 44 | 45 | # key words to search the other data, so-called `relative` words 46 | NODEKEYS = ["node", "pos"] 47 | EDGEKEYS = ["edge", "adj", "face"] 48 | CVKEYS = ["img", "bbox", "cvmask", "cvlabel", "proposal", "seg", "imgcls"] 49 | 50 | class DataType(Enum): 51 | CV = 'cv' 52 | NODE = 'node' 53 | EDGE = 'edge' 54 | GRAPH = "graph" 55 | META = "meta" 56 | 57 | class EdgeLayout(Enum): 58 | COO = 'coo' 59 | CSC = 'csc' 60 | CSR = 'csr' 61 | 62 | # typing that is missed in importing PyTorch Geometric but has released in the github 63 | # https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/typing.py 64 | # A representation of a feature tensor 65 | FeatureTensorType = Union[Tensor, SparseTensor, DataContainer] 66 | 67 | # A representation of an edge index, following the possible formats: 68 | # * COO: (row, col) 69 | # * CSC: (row, colptr) 70 | # * CSR: (rowptr, col) 71 | EdgeTensorType = Tuple[Tensor, Tensor] 72 | 73 | # Types for sampling 74 | InputNodes = Union[OptTensor, NodeType, Tuple[NodeType, OptTensor]] 75 | InputEdges = Union[OptTensor, EdgeType, Tuple[EdgeType, OptTensor]] 76 | NumNeighbors = Union[List[int], Dict[EdgeType, List[int]]] 77 | 78 | # CV types 79 | CVType = str 80 | 81 | # Data input datatype 82 | OptInput = Union[None, Tensor, DataContainer, SparseTensor] -------------------------------------------------------------------------------- /druglib/datasets/Docking/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 
2 | from .loading import LoadLigand, LoadProtein 3 | from .mol_pipeline import ( 4 | LigandFeaturizer, TorsionFactory, LigandGrapher 5 | ) 6 | from .pocket_pipeline import ( 7 | PocketFinderDefault, SCPocketFinderDefault, 8 | PocketGraphBuilder, PocketFeaturizer, Decentration, 9 | ) 10 | from .struct_init import ( 11 | LigInit, SCFixer, SCProtInit, 12 | ) 13 | from .formatting import Atom14ToAllAtomsRepr, ToPLData 14 | 15 | 16 | __all__ = [ 17 | 'LoadLigand', 'LoadProtein', 18 | 'LigandFeaturizer', 'TorsionFactory', 'LigandGrapher', 19 | 'PocketFinderDefault', 'SCPocketFinderDefault', 20 | 'PocketGraphBuilder', 'PocketFeaturizer', 'Decentration', 21 | 'LigInit', 'SCFixer', 'SCProtInit', 22 | 'Atom14ToAllAtomsRepr', 'ToPLData', 23 | ] 24 | 25 | -------------------------------------------------------------------------------- /druglib/datasets/Docking/formatting.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 
class PLData(Data):
    """``Data`` subclass with custom batching rules for two edge-index keys."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __cat_dim__(self, key: str, value: DataContainer, *args, **kwargs) -> Any:
        """Concatenate ``torsion_edge_index`` along dim 0; defer otherwise."""
        if key == 'torsion_edge_index':
            return 0
        return super().__cat_dim__(key, value, *args, **kwargs)

    def __inc__(self, key: str, value: DataContainer, *args, **kwargs) -> Any:
        """
        Return the index offset applied to ``key`` when samples are batched.

        ``lig_edge_index`` is shifted by the number of rows of ``lig_node``.
        ``torsion_edge_index`` is shifted by the number of rows of
        ``rec_atm_pos`` when present, else by the number of entries of
        ``atom14_position`` selected by ``atom14_mask``. Everything else,
        including ``torsion_edge_index`` with neither key present, is left
        to the parent class.
        """
        if key == 'lig_edge_index':
            return self['lig_node'].size(0)

        if key == 'torsion_edge_index':
            # consider missing atoms
            if 'rec_atm_pos' in self:
                return self['rec_atm_pos'].size(0)
            if 'atom14_position' in self:
                masked = self['atom14_position'][self['atom14_mask']]
                return masked.size(0)

        return super().__inc__(key, value, *args, **kwargs)


@PIPELINES.register_module()
class ToPLData:
    """
    Use `PLData` encapsulate data for easy collation.
    """
    def __call__(self, data) -> Data:
        return PLData(**data)

    def __repr__(self):
        return self.__class__.__name__ + '()'


@PIPELINES.register_module()
class Atom14ToAllAtomsRepr:
    """
    Cancel out atom14 repr to all atoms repr,
    so that no worries about mask.
    """
    def __call__(self, data):
        # Keep only the masked atoms, then drop the padded atom14 positions.
        data['rec_atm_pos'] = data['atom14_position'][data['atom14_mask']]
        del data['atom14_position']
        node_feature = data['pocket_node_feature']
        data['pocket_node_feature'] = node_feature[data['atom14_mask']]

        return data
2 | from .lmdbdataset import LMDBLoader 3 | from .custom_dataset import CustomDataset 4 | from .builder import DATASETS, PIPELINES, build_dataset, build_dataloader 5 | 6 | 7 | __all__ = [ 8 | 'LMDBLoader', 'CustomDataset', 9 | 'DATASETS', 'PIPELINES', 'build_dataset', 'build_dataloader', 10 | ] -------------------------------------------------------------------------------- /druglib/datasets/base_pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | from .compose import Compose 3 | from .formatting import (ToTensor, ToSparseTensor, Transpose, 4 | ToData, ToDataContainer, Collect) 5 | 6 | 7 | __all__ = [ 8 | 'Compose', 'ToTensor', 'ToSparseTensor', 'Transpose', 9 | 'ToData', 'ToDataContainer', 'Collect', 10 | ] -------------------------------------------------------------------------------- /druglib/datasets/base_pipelines/compose.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | from typing import Callable, List, Union, Tuple 3 | from collections.abc import Sequence, Mapping 4 | 5 | from druglib.utils import build_from_cfg 6 | from ..builder import PIPELINES 7 | 8 | 9 | @PIPELINES.register_module() 10 | class Compose: 11 | """ 12 | Basic function to compose multiple sequential transforms. 13 | `PyTorch Geometric` and `mmcv` extended libraries compatible. 14 | Args: 15 | transforms: Sequence[Mapping|callable]: a sequence of transform object or 16 | config mapping type to be composed. 
17 | """ 18 | def __init__( 19 | self, 20 | transforms: Union[List[Mapping], List[Callable]], 21 | ): 22 | assert isinstance(transforms, Sequence) 23 | self.transforms = [] 24 | for t in transforms: 25 | if isinstance(t, Mapping): 26 | transform = build_from_cfg(t, PIPELINES) 27 | self.transforms.append(transform) 28 | elif isinstance(t, Callable): 29 | self.transforms.append(t) 30 | else: 31 | raise TypeError(f"The elements of transforms must be dict or callable, but got {type(t)}") 32 | 33 | def __call__( 34 | self, 35 | data: Union[List[Mapping], Tuple[Mapping], Mapping], 36 | ): 37 | if isinstance(data, (list, tuple)): 38 | return [self(d) for d in data] 39 | elif isinstance(data, Mapping): 40 | for t in self.transforms: 41 | data = t(data) 42 | if data is None: 43 | return None 44 | return data 45 | return data 46 | 47 | def __repr__(self): 48 | string_formated = self.__class__.__name__ + "(" 49 | for t in self.transforms: 50 | str_ = t.__repr__() 51 | if "Compose(" in str_: 52 | str_ = str_.replace("\n", "\n ") 53 | string_formated += "\n" 54 | string_formated += f" {str_}" 55 | string_formated += "\n)" 56 | return string_formated 57 | 58 | -------------------------------------------------------------------------------- /druglib/datasets/lmdbdataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 
class LMDBLoader:
    """
    Read-only accessor for pickled records stored in an LMDB database.

    The key list is read once at construction time. The environment used
    for reading records is opened lazily on first access.

    Args:
        db_path: Path of the LMDB database (file or directory).
        map_gb: LMDB map size in GiB.
        strict_get: If True, querying an unknown key raises ``ValueError``;
            otherwise ``None`` is returned.
        _exclude_key: Keys (decoded) hidden from the loader.
    """

    def __init__(
        self,
        db_path: str,
        map_gb: float = 1000.0,
        strict_get: bool = True,
        _exclude_key: List[str] = ['KEYS'],
    ):
        if not os.path.exists(db_path):
            raise ValueError("{} does not exists.".format(db_path))

        self.db_path = db_path
        self.map_gb = map_gb
        self.strict_get = strict_get
        # Copy so the shared default list is never aliased by an instance.
        self._exclude_key = list(_exclude_key)

        # Scan the keys with a short-lived environment and release it
        # afterwards, so it does not stay open next to the lazily attached
        # one used for reading.
        env = self._connect_db(self.db_path)
        try:
            with env.begin() as txn:
                self._keys = [
                    k for k in txn.cursor().iternext(values=False)
                    if k.decode() not in self._exclude_key
                ]
        finally:
            env.close()
        # Hash-based membership tests; `_keys` keeps the database order.
        self._key_set = frozenset(self._keys)

        import atexit
        # Register the bound method itself so the environment is really
        # closed at interpreter exit.
        atexit.register(self._close_db)

    def _connect_db(
        self,
        lmdb_path: str,
        attach: bool = False,
    ) -> Optional[lmdb.Environment]:
        """
        Open the database read-only.

        Returns the environment when ``attach`` is False; otherwise stores
        it on ``self._env`` and returns ``None``.
        """
        assert getattr(self, '_env', None) is None, 'A connection has already been opened.'
        env = lmdb.open(
            lmdb_path,
            map_size=int(self.map_gb * (1024 * 1024 * 1024)),
            create=False,
            subdir=os.path.isdir(lmdb_path),
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        if not attach:
            return env
        else:
            self._env = env

    def __len__(self) -> int:
        return len(self._keys)

    def __contains__(self, key: str):
        return key.encode("ascii") in self._key_set

    # NOTE(review): lru_cache on a method keys on `self` and keeps the
    # instance alive for the cache's lifetime; kept to preserve caching.
    @lru_cache(maxsize=16)
    def __getitem__(self, idx) -> Optional[Any]:
        """Return the unpickled record stored under ``str(idx)``."""
        # `_close_db` resets `_env` to None, so test the value rather than
        # the attribute's existence to reconnect after a close.
        if getattr(self, "_env", None) is None:
            self._connect_db(self.db_path, attach = True)

        idx = str(idx).encode("ascii")
        if idx not in self._key_set:
            if self.strict_get:
                raise ValueError(f'query index {idx.decode()} not in lmdb.')
            else:
                return None

        with self._env.begin() as txn:
            with txn.cursor() as curs:
                datapoint_pickled = curs.get(idx)
                # NOTE: pickle.loads can execute arbitrary code; only open
                # databases from trusted sources.
                data = pickle.loads(datapoint_pickled)

        return data

    def _close_db(self):
        """Close the attached environment, if any."""
        if hasattr(self, '_env') and \
                isinstance(self._env, lmdb.Environment):
            self._env.close()
            self._env = None


def to_lmdb(
    datas: List[Any],
    output_file: str,
) -> None:
    """
    Pickle every element of ``datas`` into a new LMDB file.

    Records are stored under their ASCII-encoded list index. The write
    transaction is committed every 1000 records and once more at the end.
    """
    out_dir = os.path.dirname(output_file)
    # A bare file name has no directory part and os.makedirs('') fails.
    if out_dir:
        os.makedirs(out_dir, exist_ok = True)

    env = lmdb.open(
        output_file,
        map_size=int(1e12),
        create=True,
        subdir=False,
        readonly=False,
    )
    try:
        txn = env.begin(write=True)
        for idx, d in enumerate(datas):
            txn.put(str(idx).encode('ascii'), pickle.dumps(d))
            if idx % 1000 == 0:
                txn.commit()
                txn = env.begin(write=True)
        txn.commit()
    finally:
        # Release the environment even when pickling or writing fails.
        env.close()
def sinusoidal_embedding(
    timesteps: torch.FloatTensor,
    embed_dim: int = 64,
    max_positions: int = 10000,
) -> torch.FloatTensor:
    """
    Sinusoidal embedding of a 1-D tensor of timesteps.

    Args:
        timesteps: Tensor of shape ``(N,)``.
        embed_dim: Size of the embedding; an odd size is zero-padded by one
            trailing column.
        max_positions: Base of the geometric frequency progression.

    Returns:
        Tensor of shape ``(N, embed_dim)``: the sines of all frequencies
        followed by the cosines.
    """
    dtype = timesteps.dtype
    device = timesteps.device
    assert len(timesteps.shape) == 1
    half_dim = embed_dim // 2
    # With a single frequency (embed_dim 2 or 3) `half_dim - 1` is zero;
    # clamp the divisor so the lone frequency is exp(0) = 1 instead of
    # raising ZeroDivisionError. Other sizes are unaffected.
    emb = math.log(max_positions) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype = dtype, device = device) * -emb)
    emb = timesteps[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim = 1)
    if embed_dim % 2 == 1: # zero pad
        emb = F.pad(emb, (0, 1, 0, 0), mode = 'constant')
    assert emb.shape == (timesteps.shape[0], embed_dim)

    return emb


class GaussianFourierProjection(nn.Module):
    """
    Gaussian Fourier embeddings for noise levels.
    from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/models/layerspp.py#L32

    Args:
        embedding_size: Output size; ``embedding_size // 2`` random
            frequencies are drawn once and kept fixed (no gradient).
        scale: Standard deviation multiplier of the random frequencies.
    """
    def __init__(
        self,
        embedding_size: int = 256,
        scale: float = 1.0,
    ):
        super().__init__()
        self.W = nn.Parameter(torch.randn(embedding_size//2) * scale, requires_grad = False)

    def forward(self, x):
        """Map ``x`` of shape ``(N,)`` to sines followed by cosines of ``2*pi*x*W``."""
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim = -1)
        return emb


def get_timestep_embfunc(
    emb_type: str = 'sinusoidal',
    emb_dim: int = 64,
    emb_scale: int = 1000,
):
    """
    Return a callable that embeds timesteps.

    Args:
        emb_type: ``'sinusoidal'`` or ``'fourier'``.
        emb_dim: Embedding size.
        emb_scale: Multiplier applied to the timesteps (sinusoidal) or used
            as the frequency scale (fourier).

    Raises:
        NotImplementedError: For any other ``emb_type``.
    """
    if emb_type == 'sinusoidal':
        emb_func = (lambda x : sinusoidal_embedding(emb_scale * x, emb_dim))
    elif emb_type == 'fourier':
        emb_func = GaussianFourierProjection(embedding_size = emb_dim, scale = emb_scale)
    else:
        raise NotImplementedError('Only support ddpm sinusoidal embedding or score matching Gaussian Fourier Projection.')

    return emb_func
@TASKS_MANAGER.register_module()
class DefaultMLDOCKBuilder(BaseModelBuilder):
    """
    Default Machine Learning Docking model builder.

    Stores the model config together with the optional train/test configs
    and hands them to ``MLDOCK_BUILDER`` when the model is requested.
    """

    def __init__(
        self,
        cfg: Dict,
        train_cfg: Optional[dict] = None,
        test_cfg: Optional[dict] = None,
    ):
        self.cfg = cfg
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def build_model(self):
        """Build the docking model described by ``self.cfg``."""
        extra_args = {
            'train_cfg': self.train_cfg,
            'test_cfg': self.test_cfg,
        }
        return MLDOCK_BUILDER.build(self.cfg, default_args = extra_args)
class AtomEncoder(nn.Module):
    """
    This :cls:`AtomEncoder` is from 'EQUIBIND: Geometric Deep Learning for Drug Binding Structure Prediction'
        from https://arxiv.org/pdf/2202.05146.pdf
    Merged modified :cls:`AtomEncoder` from 'DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking'
        from https://arxiv.org/abs/2210.01776.pdf

    Args:
        emb_dim: Size of the output embedding.
        feature_dims: ``(categorical_sizes, num_scalar)``: the number of
            classes of every categorical feature and the number of scalar
            features.
        sigma_dim: Extra scalar columns added to the scalar features.
        lm_embed_type: Language-model embedding type; only ``'esm'``
            (1280 columns) is supported.
        n_feats_to_use: If set, only this many leading categorical features
            get an embedding table and are used.
        use_bias: Whether the linear layers use a bias.
    """
    def __init__(
        self,
        emb_dim: int,
        feature_dims: Tuple[Tuple[int], int],
        sigma_dim: Optional[int] = None,
        lm_embed_type: Optional[str] = None,
        n_feats_to_use: Optional[int] = None,
        use_bias: bool = False,
    ):
        super(AtomEncoder, self).__init__()
        self.n_feats_to_use = n_feats_to_use

        self.atom_emb_list = nn.ModuleList()
        # the first element of feature_dims tuple is a list
        # with the length of each categorical feature
        # the second is the number of scalar features
        self.num_onehot = len(feature_dims[0])
        self.scalar_dim = feature_dims[1]
        self.sigma_dim = sigma_dim
        if sigma_dim is not None:
            self.scalar_dim = feature_dims[1] + sigma_dim
        self.lm_embed_type = lm_embed_type
        for i, dim in enumerate(feature_dims[0]):
            emb = nn.Embedding(dim, emb_dim)
            self.atom_emb_list.append(emb)
            if i + 1 == self.n_feats_to_use:
                break
        if self.scalar_dim > 0:
            # The scalar features are concatenated with the running
            # embedding, hence the `+ emb_dim` input columns.
            self.scalar_lin = nn.Linear(self.scalar_dim + emb_dim,
                                        emb_dim, bias = use_bias)

        if self.lm_embed_type is not None:
            if self.lm_embed_type == 'esm':
                self.lm_embed_dim = 1280
            else:
                # Format the type into one message; passing it as a second
                # argument made the error render as a tuple.
                raise ValueError(
                    'LM Embedding type was not correctly determined. '
                    f'LM embedding type: {self.lm_embed_type}')
            self.lm_lin = nn.Linear(self.lm_embed_dim + emb_dim,
                                    emb_dim, bias = use_bias)

    def init_weights(self):
        """Xavier-init the embedding tables, Kaiming-init the linear layers."""
        for layer in self.atom_emb_list:
            xavier_init(layer, distribution = 'uniform')
        if self.scalar_dim > 0:
            kaiming_init(self.scalar_lin, distribution = 'uniform')
        if self.lm_embed_type is not None:
            kaiming_init(self.lm_lin, distribution = 'uniform')

    def forward(self, x: Tensor):
        """
        Embed a feature matrix.

        Args:
            x: Tensor of shape ``(N, num_onehot + scalar_dim)``, plus the
                language-model columns when ``lm_embed_type`` is set.
                Columns are ordered categorical indices, scalar features,
                language-model embedding.

        Returns:
            Tensor of shape ``(N, emb_dim)``.
        """
        x_emb = 0
        if self.lm_embed_type is not None:
            assert x.shape[1] == self.num_onehot + self.scalar_dim + self.lm_embed_dim
        else:
            assert x.shape[1] == self.num_onehot + self.scalar_dim

        # Sum the embeddings of the categorical columns.
        for i in range(self.num_onehot):
            x_emb += self.atom_emb_list[i](x[:, i].long())
            if i + 1 == self.n_feats_to_use:
                break

        if self.scalar_dim > 0:
            x_emb += self.scalar_lin(
                torch.cat([x_emb, x[:, self.num_onehot:self.num_onehot + self.scalar_dim]],
                          dim = -1))

        if self.lm_embed_type is not None:
            x_emb += self.lm_lin(
                torch.cat([x_emb, x[:, -self.lm_embed_dim:]],
                          dim = -1))

        return x_emb
class BaseModelBuilder(metaclass=ABCMeta):
    """
    Abstract interface of a task-level model builder.

    Concrete builders must override :meth:`build_model`; the class itself
    cannot be instantiated.
    """

    @abstractmethod
    def build_model(self):
        """Build the task model. The base implementation does nothing and returns ``None``."""
        return None
def build_task_model(
        cfg: Config,
        train_cfg: Optional[dict] = None,
        test_cfg: Optional[dict] = None,
):
    """
    Build the model described by ``cfg`` through its task builder.

    Args:
        cfg: Model config. Must contain a ``task`` entry; may contain a
            ``"<TASK>Builder"`` entry naming the builder type, otherwise
            ``"Default<TASK>Builder"`` is used.
        train_cfg: Training config given outside the model config.
        test_cfg: Testing config given outside the model config.
    Returns:
        Whatever the selected builder's ``build_model()`` returns.
    Raises:
        AssertionError: If ``train_cfg`` / ``test_cfg`` is given both inside
            ``cfg`` and as an argument.
    """
    assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in either outer field or model field"
    assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in either outer field or model field"

    # work on a copy so that popping keys does not alter the caller's config
    remaining_cfg = copy.deepcopy(cfg)
    # upper string required for calling `task builder`
    task_name = remaining_cfg.pop("task").upper()
    builder_type = remaining_cfg.pop(
        f"{task_name}Builder",
        f"Default{task_name}Builder",
    )
    outer_cfgs = dict(train_cfg = train_cfg, test_cfg = test_cfg)
    model_builder: BaseModelBuilder = build_task_builder(
        {"type": builder_type, "cfg": remaining_cfg},
        default_args = outer_cfgs,
    )

    return model_builder.build_model()
import subprocess, re
import os.path as osp

# User defined Bin path
this_file = osp.abspath(__file__)
this_dir = osp.dirname(this_file)
DSSP_bin = osp.join(this_dir, 'mkdssp')
# NOTE: the detection code below rewrites the assignment line above inside
# this very file (located with awk, falling back to line 8). Do not insert
# lines before it and do not repeat its text earlier in the file.

# automatic detection
# Only runs when the path assigned above does not exist.
if DSSP_bin is None or not osp.exists(DSSP_bin):
    DSSP_bin = 'dssp'
    # NOTE(review): DSSP_bin was just set to 'dssp', so only the first
    # branch can run and the executable looked up is always "mkdssp".
    if DSSP_bin == "dssp":
        DSSP_bin = "mkdssp"
    elif DSSP_bin == "mkdssp":
        DSSP_bin = "dssp"
    else:
        raise NotImplementedError(DSSP_bin)
    # ask the shell's `which` for the executable; empty output means not found
    p = subprocess.Popen(
        ["which", DSSP_bin],
        universal_newlines = True,
        stdout = subprocess.PIPE,
        stderr = subprocess.PIPE,
    )
    DSSP_bin, err = p.communicate()
    DSSP_bin = DSSP_bin.strip()

    if DSSP_bin:
        # strip trailing carriage returns from this file in place
        subprocess.run(f'sed -i "s/\r$//" {this_file}', shell=True)
        # find the number of the first line containing the bin assignment
        p = subprocess.Popen(
            "awk '/DSSP_bin = /{print NR; exit}' " + this_file,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )

        try:
            output, error = p.communicate()
            row_number = int(output)
        except:
            row_number = 8 # current scripts number

        # replace that line with the detected absolute path, so the
        # `osp.exists` check above succeeds on the next import
        subprocess.run(
            f'sed -i "{row_number}c DSSP_bin = ' + f"'{DSSP_bin}'" + f'" {this_file}',
            shell=True
        )

# Validate the executable by querying its version; any failure (including a
# missing binary) disables DSSP by setting DSSP_bin to None.
try:
    if not DSSP_bin:
        raise ValueError('No DSSP detected.')

    version_string = subprocess.check_output(
        [DSSP_bin, "--version"], universal_newlines = True
    )
    # first run of digits and dots in the version output
    dssp_version = re.search(r"\s*([\d.]+)", version_string).group(1)
except:
    # probably invalid DSSP executable file
    DSSP_bin = None
    dssp_version = ''

__all__ = ['dssp_version', 'DSSP_bin']
# Copyright (c) MDLDrugLib. All rights reserved.
import subprocess
import os.path as osp

# User defined Bin path
this_file = osp.abspath(__file__)
this_dir = osp.dirname(this_file)
MSMS_bin = osp.join(this_dir, 'msms')
# NOTE: the detection code below rewrites the assignment line above inside
# this very file (located with awk, falling back to line 8). Do not insert
# lines before it and do not repeat its text earlier in the file.


# automatic detection
# Only runs when the path assigned above does not exist.
if MSMS_bin is None or not osp.exists(MSMS_bin):
    MSMS_bin = 'msms'
    # ask the shell's `which` for the executable; empty output means not found
    p = subprocess.Popen(
        ["which", MSMS_bin],
        universal_newlines = True,
        stdout = subprocess.PIPE,
        stderr = subprocess.PIPE
    )
    MSMS_bin, err = p.communicate()
    MSMS_bin = MSMS_bin.strip()

    if MSMS_bin:
        this_file = osp.abspath(__file__)
        # strip trailing carriage returns from this file in place
        subprocess.check_call(f'sed -i "s/\r$//" {this_file}', shell=True)
        # find the number of the first line containing the bin assignment
        p = subprocess.Popen(
            "awk '/MSMS_bin = /{print NR; exit}' " + this_file,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )

        try:
            output, error = p.communicate()
            row_number = int(output)
        except:
            row_number = 8 # current scripts number

        # replace that line with the detected absolute path, so the
        # `osp.exists` check above succeeds on the next import
        subprocess.run(
            f'sed -i "{row_number}c MSMS_bin = ' + f"'{MSMS_bin}'" + f'" {this_file}',
            shell = True,
        )
    else:
        # nothing bundled and nothing on PATH
        MSMS_bin = None

__all__ = ['MSMS_bin']
def isExecutable(path):
    """Returns true if *path* is an executable file.

    Directories are rejected: they usually carry the execute bit but cannot
    be run as a program.
    """

    return (isinstance(path, str) and os.path.isfile(path) and
            os.access(path, os.X_OK))


def which(program):
    """
    Locate an executable, similar to the shell's ``which``.

    This function is based on the example in:
    http://stackoverflow.com/questions/377017/

    Args:
        program: Executable name or path.
    Returns:
        ``program`` itself if it contains a directory part and is executable;
        otherwise the first ``<dir>/<program>`` over the ``PATH`` entries
        that is executable; ``None`` if nothing is found. On Windows
        ``.exe`` is appended to names without an extension before ``PATH``
        is searched.
    """
    fpath, fname = os.path.split(program)
    if fpath and isExecutable(program):
        return program

    if os.name == 'nt' and os.path.splitext(fname)[1] == '':
        program += '.exe'
    # fall back to the platform default search path when PATH is unset
    search_path = os.environ.get("PATH", os.defpath)
    for directory in search_path.split(os.pathsep):
        candidate = os.path.join(directory, program)
        if isExecutable(candidate):
            return candidate
    return None
-------------------------------------------------------------------------------- /druglib/resources/bond_length.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HBioquant/DiffBindFR/b8bb82027fab5e74fc83ce2a44c0f920a9012ad3/druglib/resources/bond_length.txt -------------------------------------------------------------------------------- /druglib/utils/bio_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | from .read_mol import read_mol 3 | from .fix_protein import fix_protein 4 | from .nxmol import nx2mol, mol2nx 5 | from .pdbqt_utils import pdbqt2pdbblock, pdb2pdbqt, pdbqt2sdf 6 | from .compute_mol_charges import compute_mol_charges, get_atom_partial_charge 7 | from .conformer_utils import ( 8 | generate_multiple_conformers, conformer_generation, 9 | fast_conformer_generation, fast_generate_conformers_onebyone, 10 | simple_conformer_generation, remove_all_hs, get_pos_from_mol, 11 | modify_conformer_torsion_angles, modify_conformer, randomize_lig_pos, 12 | randomize_batchlig_pos, randomize_sc_dihedral, get_ligconf, 13 | update_batchlig_pos, 14 | ) 15 | from .mol_attrs import ( 16 | get_rotatable_bonds, 17 | get_angles, 18 | set_dihedral, 19 | get_dihedral, 20 | set_angle, 21 | get_angle, 22 | set_bond_length, 23 | get_bond_length, 24 | get_mol_dihedrals, 25 | get_multi_mols_dihedrals, 26 | get_mol_angles, 27 | get_multi_mols_angles, 28 | get_mol_bonds, 29 | get_multi_mols_bonds, 30 | mol_with_atom_index, 31 | atom_env, 32 | ) 33 | from .select_pocket import ( 34 | select_bs, select_bs_any, 35 | select_bs_atoms, select_bs_centroid, 36 | ) 37 | 38 | 39 | __all__ = [ 40 | 'read_mol', 'compute_mol_charges', 'get_atom_partial_charge', 'generate_multiple_conformers', 'conformer_generation', 41 | 'fast_conformer_generation', 'fast_generate_conformers_onebyone', 'fix_protein', 42 | 
def getbox(
        selection: str = "sele",
        extending: Union[float, int] = 6.0,
        docking_software: str = "vina",
) -> Tuple[dict, ...]:
    """
    Compute a docking box around a PyMOL selection.

    Args:
        selection: PyMOL selection expression whose extent is used.
        extending: Padding added on every side of the extent.
        docking_software: Output flavour. Only 'vina' and 'ledock' are
            implemented.
    Returns:
        For 'vina': ``(center, size)`` with keys ``center_x/y/z`` and
        ``size_x/y/z``. For 'ledock': three dicts holding the min/max bound
        of the x, y and z axis.
    Raises:
        ValueError: If ``docking_software`` is not a known name.
        NotImplementedError: If the name is known but has no output format.
    Note:
        All PyMOL objects are deleted (``cmd.delete("all")``) once the extent
        has been read, also on the ``NotImplementedError`` path.
    """
    # check docking software available
    if docking_software not in ['vina', 'ledock', 'smina', 'linf9', 'rbdock', 'gnina']:
        raise ValueError(f"your desired {docking_software} can not be available in the present, "
                         f"while \'vina\', \'ledock\', \'smina\', "
                         f"\'linf9\', \'rbdock\', \'gnina\' is available now.")

    ([minX, minY, minZ], [maxX, maxY, maxZ]) = cmd.get_extent(selection=selection)

    # expanding the box boundary
    padding = float(extending)
    minX, minY, minZ = minX - padding, minY - padding, minZ - padding
    maxX, maxY, maxZ = maxX + padding, maxY + padding, maxZ + padding

    cmd.delete("all")

    if docking_software == "vina":
        # box centre and edge lengths
        return {
            'center_x': (minX + maxX) / 2,
            'center_y': (minY + maxY) / 2,
            'center_z': (minZ + maxZ) / 2
        }, {
            'size_x' : maxX - minX,
            'size_y' : maxY - minY,
            'size_z' : maxZ - minZ
        }
    elif docking_software == 'ledock':
        # per-axis bounds
        return {
            'minX' : minX,
            'maxX' : maxX
        }, {
            'minY': minY,
            'maxY': maxY
        }, {
            'minZ': minZ,
            'maxZ': maxZ
        }
    else:
        raise NotImplementedError


def compute_protein_bbox(
        prot_coords: np.ndarray,
) -> np.ndarray:
    """
    Compute the protein axis aligned bounding box
    Args:
        prot_coords: np.ndarray. A numpy array of shape `(N, 3)`,
            where `N` is the number of atoms.
    Returns:
        protein_range: np.ndarray. A numpy array of shape `(3,)`,
            the edge lengths along (x,y,z).
    """
    protein_max = np.max(prot_coords, axis=0)
    protein_min = np.min(prot_coords, axis=0)
    return protein_max - protein_min


def compute_protein_bbox_open3d(
        prot_coords: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute the protein axis aligned bounding box (aabb) and
    rotated bounding box (obb)
    Args:
        prot_coords: np.ndarray. A numpy array of shape `(N, 3)`,
            where `N` is the number of atoms.
    Returns:
        ``(aabb, obb)``: the extent of the axis aligned box and the extent
        of the oriented box, as returned by open3d's ``get_extent()``.
    Raises:
        ImportError: If open3d is not installed.
    """
    try:
        # the importable module name is lower-case
        import open3d as o3d
    except ImportError as err:
        raise ImportError(
            "Call :func:`compute_protein_bbox_open3d` needs `pip install open3d`") from err

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(prot_coords)
    # get protein point clouds aabb box as the results from :func:`compute_protein_bbox`
    aabb = pcd.get_axis_aligned_bounding_box()
    aabb = aabb.get_extent()
    # get protein point clouds obb box
    # NOTE(review): confirm that the installed open3d version exposes
    # `get_extent()` on the oriented bounding box as well.
    obb = pcd.get_oriented_bounding_box()
    obb = obb.get_extent()

    return aabb, obb
def mol2nx(
        mol: Chem.rdchem.Mol,
) -> nx.Graph:
    """
    Convert an rdkit molecule into a networkx graph.

    Nodes are keyed by atom index and carry the atom properties needed by
    :func:`nx2mol`; edges carry the rdkit bond type.
    Args:
        mol: rdkit mol.
    Returns:
        networkx graph.
    """
    graph = nx.Graph()
    graph.add_nodes_from(
        (
            atom.GetIdx(),
            dict(
                atomic_num = atom.GetAtomicNum(),
                formal_charge = atom.GetFormalCharge(),
                chiral_tag = atom.GetChiralTag(),
                hybridization = atom.GetHybridization(),
                num_explicit_hs = atom.GetNumExplicitHs(),
                is_aromatic = atom.GetIsAromatic(),
            ),
        )
        for atom in mol.GetAtoms()
    )
    graph.add_edges_from(
        (
            bond.GetBeginAtomIdx(),
            bond.GetEndAtomIdx(),
            dict(bond_type = bond.GetBondType()),
        )
        for bond in mol.GetBonds()
    )

    return graph


def nx2mol(
        G: nx.Graph
) -> Chem.rdchem.Mol:
    """
    Rebuild an rdkit molecule from a graph in the :func:`mol2nx` layout.
    Args:
        G: networkx graph.
    Returns:
        rdkit mol (sanitized).
    """
    rw_mol = Chem.RWMol()
    atomic_num_of = nx.get_node_attributes(G, 'atomic_num')
    chiral_tag_of = nx.get_node_attributes(G, 'chiral_tag')
    formal_charge_of = nx.get_node_attributes(G, 'formal_charge')
    is_aromatic_of = nx.get_node_attributes(G, 'is_aromatic')
    hybridization_of = nx.get_node_attributes(G, 'hybridization')
    num_explicit_hs_of = nx.get_node_attributes(G, 'num_explicit_hs')

    # graph node -> atom index inside the new molecule
    atom_index_of = {}
    for node in G.nodes():
        atom = Chem.Atom(atomic_num_of[node])
        atom.SetChiralTag(chiral_tag_of[node])
        atom.SetFormalCharge(formal_charge_of[node])
        atom.SetIsAromatic(is_aromatic_of[node])
        atom.SetHybridization(hybridization_of[node])
        atom.SetNumExplicitHs(num_explicit_hs_of[node])
        atom_index_of[node] = rw_mol.AddAtom(atom)

    bond_type_of = nx.get_edge_attributes(G, 'bond_type')
    for begin, end in G.edges():
        rw_mol.AddBond(
            atom_index_of[begin],
            atom_index_of[end],
            bond_type_of[begin, end],
        )

    Chem.SanitizeMol(rw_mol)

    return rw_mol
17 | """ 18 | assert osp.exists(pdbqt_file), f"PDBQT File does not exists from {pdbqt_file}" 19 | 20 | with open(pdbqt_file, 'r') as f: 21 | pdbqt = f.readlines() 22 | 23 | pdb_block = '' 24 | for line in pdbqt: 25 | pdb_block += f'{line[:66]}\n' 26 | 27 | return pdb_block 28 | 29 | def pdb2pdbqt( 30 | pdb_file: str, 31 | out_file: str, 32 | ) -> str: 33 | """ 34 | Convert pdb file to pdbqt file. 35 | Write extra the pdbqt terms into the pdb file. 36 | Args: 37 | pdb_file: str. pdb file path. 38 | out_file: str. output pdbqt file path. 39 | Returns: 40 | out_file, str. 41 | **Note that we suggest using vina prepare_receptor and 42 | prepare_ligand to transform pdb file to pdbqt file. 43 | """ 44 | assert osp.exists(pdb_file), f"PDB File does not exists from {pdb_file}" 45 | mkdir_or_exists(osp.dirname(osp.abspath(out_file))) 46 | 47 | mol = Chem.MolFromPDBFile( 48 | pdb_file, 49 | sanitize = True, 50 | removeHs = False, 51 | ) 52 | lines = [line.strip() for line in open(pdb_file).readlines()] 53 | pdbqt_lines = [] 54 | for line in lines: 55 | if 'ROOT' in line or 'ENDROOT' in line or 'TORSDOF' in line: 56 | pdbqt_lines.append(f'{line}\n') 57 | continue 58 | if not line.startswith("ATOM"): 59 | continue 60 | line = line[:66] 61 | atom_index = int(line[6:11]) 62 | atom = mol.GetAtoms()[atom_index - 1] 63 | line = "%s +0.000 %s\n" % (line, atom.GetSymbol().ljust(2)) 64 | pdbqt_lines.append(line) 65 | with open(out_file, 'w') as f: 66 | for line in pdbqt_lines: 67 | f.write(line) 68 | 69 | return out_file 70 | 71 | def pdbqt2sdf( 72 | pdbqt_file: str, 73 | out_file: str, 74 | log_level: int = 0, 75 | ): 76 | """ 77 | A simple implementation about format transformation from pdbqt to sdf file. 78 | Args: 79 | pdbqt_file: str. pdbqt file path. 80 | out_file: str. output sdf file path. 81 | log_level: int. Log level, set 0 to slience warning. 82 | Returns: 83 | out_file, str. 
84 | """ 85 | assert osp.exists(pdbqt_file), f"PDBQT File does not exists from {pdbqt_file}" 86 | mkdir_or_exists(osp.dirname(osp.abspath(out_file))) 87 | 88 | pybel.ob.obErrorLog.SetOutputLevel(log_level) 89 | 90 | results = [m for m in pybel.readfile(format = 'pdbqt', filename = pdbqt_file)] 91 | outfile = pybel.Outputfile( 92 | filename = out_file, 93 | format = 'sdf', 94 | overwrite = True 95 | ) 96 | for pose in results: 97 | pose.data.update( 98 | { 99 | 'Pose' : pose.data['MODEL'], 100 | 'Score':pose.data['REMARK'].split()[2] 101 | }) 102 | del pose.data['MODEL'], pose.data['REMARK'], pose.data['TORSDO'] 103 | 104 | outfile.write(pose) 105 | 106 | outfile.close() 107 | 108 | return out_file -------------------------------------------------------------------------------- /druglib/utils/bio_utils/read_mol.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | from typing import Optional, Union 3 | import os.path as osp 4 | from rdkit import Chem 5 | from .compute_mol_charges import compute_mol_charges 6 | from .pdbqt_utils import pdbqt2pdbblock 7 | from .conformer_utils import fast_generate_conformers_onebyone 8 | 9 | 10 | def read_mol( 11 | mol_file: Union[str, Chem.rdchem.Mol], 12 | sanitize: bool = False, 13 | calc_charges: bool = False, 14 | remove_hs: bool = False, 15 | assign_chirality: bool = False, 16 | emb_multiple_3d: Optional[int] = None, 17 | ) -> Optional[Chem.rdchem.Mol]: 18 | """ 19 | Read a molecule file in the format of .sdf, .mol2, .pdbqt or .pdb. 20 | Args: 21 | mol_file: str or Mol. Molecule file path or rdkit Mol object. 22 | sanitize: bool, optional. rdkit sanitize molecule. 23 | Default to False. 24 | calc_charges: bool, optional. If True, add Gasteiger charges. 25 | Default to False. 26 | Note that when calculating charges, the mol must be sanitized. 27 | remove_hs: bool, optional. If True, remove the hydrogens. 28 | Default to False. 
29 | assign_chirality: bool, optional. If True, inference the chirality. 30 | Default to False. 31 | emb_multiple_3d: int, optional. If int, generate multiple conformers for mol 32 | Returns: 33 | molecule: Chem.rdchem.Mol or None. 34 | """ 35 | if not isinstance(mol_file, Chem.rdchem.Mol): 36 | assert osp.exists(mol_file), f"Ligand file does not exist from {mol_file}." 37 | 38 | if isinstance(mol_file, Chem.rdchem.Mol): 39 | # Here we allow the valid mol input 40 | mol = mol_file 41 | elif mol_file.endswith('.sdf'): 42 | mols = Chem.SDMolSupplier( 43 | mol_file, 44 | sanitize = False, 45 | removeHs = False 46 | ) 47 | # Note that this requires input a single molecule sdf file 48 | # if file saves multiply molecules, it is dangerous for execute 49 | # the next part. 50 | mol = mols[0] 51 | elif mol_file.endswith('.mol2'): 52 | mol = Chem.MolFromMol2File( 53 | mol_file, 54 | sanitize = False, 55 | removeHs = False, 56 | ) 57 | elif mol_file.endswith('.pdb'): 58 | mol = Chem.MolFromPDBFile( 59 | mol_file, 60 | sanitize = False, 61 | removeHs = False, 62 | ) 63 | elif mol_file.endswith('.pdbqt'): 64 | pdbblock = pdbqt2pdbblock(mol_file) 65 | mol = Chem.MolFromPDBBlock( 66 | pdbblock, 67 | sanitize = False, 68 | removeHs = False, 69 | ) 70 | else: 71 | raise ValueError("Current supported mol files include sdf, mol2, pdbqt, pdb, " 72 | f"but got {mol_file.split('.')[-1]}") 73 | 74 | Chem.GetSymmSSSR(mol) 75 | 76 | if emb_multiple_3d is not None: 77 | assert isinstance(emb_multiple_3d, int) and emb_multiple_3d > 0 78 | mol = fast_generate_conformers_onebyone( 79 | mol, num_confs = emb_multiple_3d, 80 | force_field = 'MMFF94s', 81 | ) 82 | 83 | try: 84 | if sanitize: Chem.SanitizeMol(mol) 85 | 86 | if calc_charges: 87 | try: 88 | compute_mol_charges(mol) 89 | except RuntimeError: 90 | pass 91 | 92 | if remove_hs: 93 | mol = Chem.RemoveHs(mol, sanitize = sanitize) 94 | except: 95 | return None 96 | 97 | if assign_chirality: 98 | Chem.AssignStereochemistryFrom3D(mol) 99 | 
100 | return mol 101 | -------------------------------------------------------------------------------- /druglib/utils/deprecation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | import functools 3 | import inspect 4 | import warnings 5 | from typing import Optional 6 | 7 | 8 | def deprecated(details: Optional[str] = None, func_name: Optional[str] = None): 9 | def decorator(func): 10 | name = func_name or func.__name__ 11 | 12 | if inspect.isclass(func): 13 | cls = type(func.__name__, (func, ), {}) 14 | cls.__init__ = deprecated(details, name)(func.__init__) 15 | cls.__doc__ = func.__doc__ 16 | return cls 17 | 18 | @functools.wraps(func) 19 | def wrapper(*args, **kwargs): 20 | out = f"'{name}' is deprecated" 21 | if details is not None: 22 | out += f", {details}" 23 | warnings.warn(out) 24 | return func(*args, **kwargs) 25 | 26 | return wrapper 27 | 28 | return decorator -------------------------------------------------------------------------------- /druglib/utils/geometry_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | from . 
import aaframe 3 | from .utils import ( 4 | radian2sincos, radian2sincos_torch, rot_vec_around_x_axis, parse_xrot_angle, 5 | make_rigid_transformation_4x4, residue_frame, residue_frame_torch, 6 | apply_euclidean, apply_inv_euclidean, calc_euclidean_distance_np, 7 | calc_euclidean_distance_torch, unit_vector_np, angle_between_np, 8 | angle_between_torch, uniform_unit_s2, rots_matmul, rot_vec_matmul, 9 | identity_rot_mats, identity_trans, rot_inv, euler_to_rot, rot_to_euler_angles, 10 | standardize_quaternion, normalised_quaternion, quaternion_to_rot, 11 | identity_quats, rot_to_quaternion, quaternions_multiply, quaternion_vec_multiply, 12 | quaternions_multiply_rot, quaternion_invert, random_quaternions, random_quaternion, 13 | random_rotations, random_rotation, apply_quaternion, apply_rotmat, apply_euler, 14 | axis_angle_to_quaternion, quaternion_to_axis_angle, quaternion_to_euler, 15 | euler_to_quaternion, rot_to_axis_angle, axis_angle_to_rot, check_rotation_matrix, 16 | ) 17 | 18 | __all__ = [ 19 | 'aaframe', 'radian2sincos', 'radian2sincos_torch', 'rot_vec_around_x_axis', 'parse_xrot_angle', 20 | 'make_rigid_transformation_4x4', 'residue_frame', 'residue_frame_torch', 'apply_euclidean', 21 | 'apply_inv_euclidean', 'calc_euclidean_distance_np', 'calc_euclidean_distance_torch', 22 | 'unit_vector_np', 'angle_between_np', 'angle_between_torch', 'uniform_unit_s2', 'rots_matmul', 23 | 'rot_vec_matmul', 'identity_rot_mats', 'identity_trans', 'rot_inv', 'euler_to_rot', 24 | 'rot_to_euler_angles', 'standardize_quaternion', 'normalised_quaternion', 'quaternion_to_rot', 25 | 'identity_quats', 'rot_to_quaternion', 'quaternions_multiply', 'quaternion_vec_multiply', 26 | 'quaternions_multiply_rot', 'quaternion_invert', 'random_quaternions', 'random_quaternion', 27 | 'random_rotations', 'random_rotation', 'apply_quaternion', 'apply_rotmat', 'apply_euler', 28 | 'axis_angle_to_quaternion', 'quaternion_to_axis_angle', 'quaternion_to_euler', 29 | 'euler_to_quaternion', 
'rot_to_axis_angle', 'axis_angle_to_rot', 'check_rotation_matrix', 30 | ] -------------------------------------------------------------------------------- /druglib/utils/geometry_utils/io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | import os 3 | import pickle 4 | import lmdb 5 | from typing import Union, List, Sequence 6 | from pathlib import Path 7 | import numpy as np 8 | 9 | 10 | def _load(path: Union[Path, str]): 11 | path = str(path) 12 | return np.load(path) 13 | 14 | 15 | def _save(path: Union[Path, str], obj): 16 | path = str(path) 17 | return np.save(path, obj) 18 | 19 | 20 | def _load_lmdb( 21 | path: Union[Path, str] 22 | ): 23 | env = lmdb.open( 24 | str(path), 25 | map_size=int(1e12), 26 | create=False, 27 | subdir=os.path.isdir(str(path)), 28 | readonly=True, 29 | lock=False, 30 | readahead=False, 31 | meminit=False, 32 | max_readers=256, 33 | ) 34 | 35 | return env 36 | 37 | 38 | def _save_lmdb( 39 | output_file: Union[Path, str], 40 | data_dict: dict, 41 | ): 42 | env = lmdb.open( 43 | str(output_file), 44 | map_size=int(1e12), 45 | create=True, 46 | subdir=False, 47 | readonly=False, 48 | ) 49 | txn = env.begin(write=True) 50 | for k, d in data_dict.items(): 51 | txn.put(str(k).encode('ascii'), pickle.dumps(d)) 52 | txn.commit() 53 | env.close() 54 | 55 | return 56 | 57 | 58 | def _load_lmdb_data( 59 | env: lmdb.Environment, 60 | keys: Union[str, List[str]], 61 | ): 62 | if isinstance(keys, str) or not isinstance(keys, Sequence): 63 | keys = [keys] 64 | 65 | with env.begin() as txn: 66 | KEYS = [k for k in txn.cursor().iternext(values=False)] 67 | 68 | ds = [] 69 | for key in keys: 70 | key = str(key).encode("ascii") 71 | if key not in KEYS: 72 | raise ValueError(f'query index {key.decode()} not in lmdb.') 73 | 74 | with env.begin() as txn: 75 | with txn.cursor() as curs: 76 | d = pickle.loads(curs.get(key)) 77 | if len(keys) == 1: 78 | return d 79 | 
80 | ds.append(d) 81 | 82 | return ds -------------------------------------------------------------------------------- /druglib/utils/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | from .base import BaseFileHandler 3 | from .json_handler import JsonHandler 4 | from .pickle_handler import PickleHandler 5 | from .yaml_handler import YamlHandler 6 | 7 | __all__ = [ 8 | 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler', 9 | ] -------------------------------------------------------------------------------- /druglib/utils/handlers/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | class BaseFileHandler(metaclass = ABCMeta): 5 | # `str_like` is a flag to indicate whether the type of file object is 6 | # str-like object or bytes-like object. Pickle only processes bytes-like 7 | # objects but json only processes str-like object. If it is str-like 8 | # object, `StringIO` will be used to process the buffer. 9 | str_like = True 10 | 11 | @abstractmethod 12 | def load_from_fileobj(self, file, **kwargs): 13 | pass 14 | 15 | @abstractmethod 16 | def dump_to_fileobj(self, obj, file, **kwargs): 17 | pass 18 | 19 | @abstractmethod 20 | def dump_to_str(self, obj, **kwargs): 21 | pass 22 | 23 | def load_from_path(self, filepath, mode='r', **kwargs): 24 | with open(filepath, mode) as f: 25 | return self.load_from_fileobj(f, **kwargs) 26 | 27 | def dump_to_path(self, obj, filepath, mode='w', **kwargs): 28 | with open(filepath, mode) as f: 29 | self.dump_to_fileobj(obj, f, **kwargs) 30 | -------------------------------------------------------------------------------- /druglib/utils/handlers/json_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. 
def set_default(obj):
    """
    Fallback converter for values the json encoder cannot serialize.

    `np.ndarray` is converted with `tolist()`, `np.generic` scalars (such as
    `np.int32` or `np.float32`) with `item()`, and `set` / `range` objects
    are turned into lists.

    Raises:
        TypeError: If `obj` is none of the supported types.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (set, range)):
        return list(obj)
    raise TypeError(f'{type(obj)} is unsupported for json dump')
class PickleHandler(BaseFileHandler):
    """
    File handler backed by the standard `pickle` module.

    Paths are opened in binary mode and dumping defaults to pickle
    protocol 2 unless the caller passes `protocol` explicitly.
    """

    # Pickle reads and writes bytes, not str.
    str_like = False

    def load_from_fileobj(self, file, **kwargs):
        """Unpickle one object from the binary file object `file`."""
        return pickle.load(file, **kwargs)

    def load_from_path(self, filepath, **kwargs):
        """Unpickle one object from `filepath`, opened with mode 'rb'."""
        return super().load_from_path(filepath, mode='rb', **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Pickle `obj` into the binary file object `file`."""
        if 'protocol' not in kwargs:
            kwargs['protocol'] = 2
        pickle.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Return the pickled representation of `obj` (a bytes object)."""
        if 'protocol' not in kwargs:
            kwargs['protocol'] = 2
        return pickle.dumps(obj, **kwargs)

    def dump_to_path(self, obj, filepath, **kwargs):
        """Pickle `obj` into `filepath`, opened with mode 'wb'."""
        super().dump_to_path(obj, filepath, mode='wb', **kwargs)
def aatype_to_seq(
        aatype: Union[np.ndarray, Tensor],
):
    """
    Build a sequence string from an iterable of residue type indices.

    Every element of `aatype` is used as an index into
    `pc.restypes_with_x` and the looked-up entries are concatenated in
    order.
    NOTE(review): entries are presumably one-letter residue codes; confirm
    against `protein_constants`.
    """
    return ''.join(pc.restypes_with_x[idx] for idx in aatype)
def _search_max_index(
        file_path,
) -> int:
    """
    Find the largest index among already existing indexed copies of a file.

    For ``file_path = 'dir/name.pdb'`` the sibling files are expected to be
    named ``dir/name_<idx>.pdb`` (the naming scheme produced by replacing
    ``.pdb`` with ``_<idx>.pdb``).

    Args:
        file_path: Un-indexed target path.
    Returns:
        The maximum ``<idx>`` found, or 0 when there is no indexed sibling
        or the directory does not exist.
    """
    # `dirname` is '' for a bare file name; fall back to the current
    # directory instead of letting `os.listdir('')` fail.
    _dir = osp.dirname(file_path) or os.curdir
    if not osp.isdir(_dir):
        return 0

    _name = osp.basename(file_path)
    suffix = '.pdb'
    stem = _name[:-len(suffix)] if _name.endswith(suffix) else _name
    # Match on the stem: the full name 'name.pdb' is never a substring of
    # 'name_3.pdb', so testing the full name can never find a sibling.
    # The stem is escaped and the whole file name must match, so unrelated
    # files and names such as 'name_3.pdb.bak' are ignored.
    pattern = re.compile(re.escape(stem) + r'_(\d+)\.pdb')

    idxs = [0]
    for ex in os.listdir(_dir):
        found = pattern.fullmatch(ex)
        if found is not None:
            idxs.append(int(found.group(1)))
    return max(idxs)
80 | """ 81 | save_path = file_path 82 | if not no_indexing: 83 | max_existing_idx = _search_max_index(file_path) if overwrite else 0 84 | save_path = file_path.replace('.pdb', f'_{max_existing_idx + 1}.pdb') 85 | 86 | with open(save_path, 'w') as f: 87 | if pos37_repr.ndim == 4: 88 | for t, pos37 in enumerate(pos37_repr): 89 | atom37_mask = np.sum(np.abs(pos37), axis = -1) > 1e-7 90 | prot = create_full_prot( 91 | pos37, atom37_mask, 92 | aatype = aatype, b_factors = b_factors) 93 | pdb_prot = to_pdb(prot, model = t + 1, add_end = False) 94 | f.write(pdb_prot) 95 | elif pos37_repr.ndim == 3: 96 | atom37_mask = np.sum(np.abs(pos37_repr), axis = -1) > 1e-7 97 | prot = create_full_prot( 98 | pos37_repr, atom37_mask, 99 | aatype=aatype, b_factors=b_factors) 100 | pdb_prot = to_pdb(prot, model = 1, add_end = False) 101 | f.write(pdb_prot) 102 | else: 103 | raise ValueError(f'Invalid positions shape {pos37_repr.shape}. ' 104 | f'(M, N, 37, 3) or (N, 37, 3) are allowed.') 105 | f.write('END\n') 106 | 107 | return save_path 108 | 109 | 110 | -------------------------------------------------------------------------------- /druglib/utils/parrots_jit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MDLDrugLib. All rights reserved. 
def jit(
        func = None
):
    """
    Pass-through replacement for the parrots `jit` decorator.

    Supports both decorator spellings:
      * ``@jit``   -> `func` is given and is returned unchanged.
      * ``@jit()`` -> `func` is None and a decorator is returned, which
        wraps the decorated callable in a plain forwarding function.
    """
    if func is not None:
        return func

    def wrapper(func):
        def wrapper_inner(*args, **kwargs):
            # Forward the call untouched.
            return func(*args, **kwargs)

        return wrapper_inner

    return wrapper
def _get_dataloader():
    """
    Return the `(DataLoader, PoolDataLoader)` pair for the active backend.

    Under parrots both names are imported from `torch.utils.data`;
    otherwise `PoolDataLoader` is simply an alias of `DataLoader`.
    """
    from torch.utils.data import DataLoader
    if TORCH_VERSION == 'parrots':
        from torch.utils.data import PoolDataLoader
    else:
        PoolDataLoader = DataLoader
    return DataLoader, PoolDataLoader
def is_jit_tracing() -> bool:
    """
    Report whether the code is currently running under `torch.jit` tracing.

    For parrots, or for a torch version below 1.6.0, a `UserWarning` is
    emitted and False is returned.
    """
    unsupported = (
        TORCH_VERSION == 'parrots'
        or digit_version(TORCH_VERSION) < digit_version('1.6.0')
    )
    if unsupported:
        warnings.warn(
            'torch.jit.is_tracing is only supported after v1.6.0. '
            'Therefore is_tracing returns False automatically. Please '
            'set on_trace manually if you are using trace.', UserWarning
        )
        return False

    on_trace = torch.jit.is_tracing()
    # In PyTorch 1.6, torch.jit.is_tracing has a bug and may not return a
    # bool. Refer to https://github.com/pytorch/pytorch/issues/42448
    if not isinstance(on_trace, bool):
        return torch._C._is_tracing()
    return on_trace
def digit_version(
        version_str: str,
        length: int = 4,
) -> Tuple[int, ...]:
    """
    Convert a version string into a tuple of integers.
    This method is usually used for comparing two versions.
    For pre-release versions: alpha < beta < rc.
    Args:
        version_str: The version string.
        length: The maximum number of release levels. Defaults to 4.
    Returns:
        Tuple[int, ...]: ``length`` release digits followed by two extra
        fields ``(kind, number)``:
          * pre-release with a marker: ``kind`` is -3 / -2 / -1 for
            a / b / rc (-4 for an unknown marker) and ``number`` is the
            pre-release number;
          * pre-release without a marker: ``(-4, 0)``;
          * post-release: ``(1, post_number)``;
          * final release: ``(0, 0)``.
    """
    assert 'parrots' not in version_str
    version = parse(version_str)
    assert version.release, f"failed to parse version {version}"
    # e.g. '0.24.4rc0' -> version.release == (0, 24, 4)
    release = list(version.release)[:length]
    # Pad with zeros up to `length` (a non-positive count adds nothing).
    release += [0] * (length - len(release))

    if version.is_prerelease:
        mapping = {'a': -3, 'b': -2, 'rc': -1}
        val = -4
        # `version.pre` can be None although the version is a pre-release.
        if version.pre:
            # e.g. '0.24.4rc2' -> version.pre == ('rc', 2)
            if version.pre[0] not in mapping:
                warnings.warn(f'Unknown prerelease version {version.pre[0]}, '
                              f'version checking may go wrong.')
            else:
                val = mapping[version.pre[0]]
            # Record both the pre-release kind and its number; without this
            # the tuple is two fields short and a / b / rc (and rc1 / rc2)
            # cannot be told apart.
            release.extend([val, version.pre[-1]])
        else:
            release.extend([val, 0])

    elif version.is_postrelease:
        # e.g. '0.24.4-2022' -> version.post == 2022
        release.extend([1, version.post])
    else:
        release.extend([0, 0])

    return tuple(release)
def get_git_hash(
        fallback: str = 'unknown',
        digits: Optional[int] = None,
) -> str:
    """
    Get the git hash of the current repo.

    Args:
        fallback: The string returned when the git hash is unavailable,
            i.e. when the git command cannot be executed or prints
            nothing. Defaults to 'unknown'.
        digits: Kept digits of the hash. Defaults to None,
            meaning all digits are kept.

    Returns:
        str: Git commit hash, or `fallback`.

    Raises:
        TypeError: If `digits` is neither None nor an integer.
    """

    if digits is not None and not isinstance(digits, int):
        raise TypeError('digits must be None or an integer')

    try:
        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
    except OSError:
        # The git executable could not be started.
        return fallback

    sha = out.strip().decode('ascii')
    if not sha:
        # git ran but produced no hash on stdout (for instance when the
        # working directory is not a repository): report the fallback
        # rather than an empty string.
        return fallback
    if digits is not None:
        sha = sha[:digits]

    return sha
def parse_version_info(version_str):
    """
    Split a dotted version string into a tuple of components.

    Purely numeric components become integers. A component containing
    ``rc`` contributes two entries: the integer before ``rc`` and the
    string ``'rc<suffix>'``. Any other component is skipped.
    """
    parts = []
    for token in version_str.split('.'):
        if token.isdigit():
            parts.append(int(token))
            continue
        if token.find("rc") == -1:
            continue
        pieces = token.split('rc')
        parts.extend((int(pieces[0]), f'rc{pieces[1]}'))
    return tuple(parts)
numba 51 | - e3nn==0.5.1 52 | - pandas 53 | - prody<=2.4.0 54 | - biopython==1.80 55 | - rdkit==2023.03.2 56 | - MDAnalysis 57 | - meeko==0.5.0 58 | - einops==0.5.0 59 | - py3Dmol==2.0.4 60 | - nglview==3.0.3 61 | - joblib==1.1.0 62 | - pandarallel==1.6.3 63 | - spyrmsd==0.5.2 64 | - deepsmiles 65 | - selfies 66 | - fcd_torch 67 | - tensorboardX 68 | - tensorboard 69 | - torchmetrics 70 | - torchtyping 71 | - typeguard 72 | - triton==2.0.0 73 | - gpustat 74 | - dm-tree 75 | - lmdb==1.4.1 76 | - wandb==0.13.3 77 | - ml_collections 78 | - jupyter-client==8.1.0 79 | - jupyter-core==5.3.0 80 | - ipykernel==6.19.2 81 | - ipython==8.12.0 82 | - filelock==3.9.0 83 | - posebusters 84 | - prefetch_generator 85 | - mdtraj 86 | - docutils==0.17 87 | -------------------------------------------------------------------------------- /examples/forward/mols/BDB12915.sdf: -------------------------------------------------------------------------------- 1 | BDB12915 2 | SciTegic10041214373D 3 | 4 | 40 43 0 0 0 0 999 V2000 5 | 1.1970 -0.6200 0.3500 C 0 0 6 | 1.2740 0.7370 0.0550 C 0 0 7 | 0.1070 1.4600 -0.2050 C 0 0 8 | -1.1600 0.8430 -0.1760 C 0 0 9 | -2.3660 1.6470 -0.4470 C 0 0 10 | -2.5320 2.3060 -1.6770 C 0 0 11 | -3.6640 3.0860 -1.9290 C 0 0 12 | -4.6510 3.2190 -0.9550 C 0 0 13 | -4.5060 2.5720 0.2700 C 0 0 14 | -3.3720 1.7940 0.5220 C 0 0 15 | -0.0520 -1.2470 0.3810 C 0 0 16 | -1.2200 -0.5340 0.1210 C 0 0 17 | -0.1550 -2.6750 0.6910 C 0 0 18 | -2.5850 -2.4960 0.4340 C 0 0 19 | -1.5070 -3.2510 0.7070 C 0 0 20 | -2.4530 -1.1400 0.1300 O 0 0 21 | 0.8280 -3.3650 0.9280 O 0 0 22 | -5.0610 -2.0980 0.1280 C 0 0 23 | -4.3260 -4.3490 0.7410 C 0 0 24 | -5.2090 -4.8940 -0.3890 C 0 0 25 | -5.9150 -2.7270 -0.9790 C 0 0 26 | -3.9040 -2.9690 0.4290 N 0 0 27 | -6.3430 -4.0470 -0.6160 O 0 0 28 | 2.1040 -1.1840 0.5540 H 0 0 29 | 2.2400 1.2360 0.0300 H 0 0 30 | 0.1900 2.5240 -0.4230 H 0 0 31 | -1.7760 2.2120 -2.4550 H 0 0 32 | -3.7740 3.5890 -2.8860 H 0 0 33 | -5.5290 3.8300 -1.1490 H 0 0 34 | 
-5.2720 2.6740 1.0340 H 0 0 35 | -3.2760 1.3040 1.4890 H 0 0 36 | -1.5520 -4.3020 0.9500 H 0 0 37 | -4.7680 -1.0920 -0.1830 H 0 0 38 | -5.6580 -1.9970 1.0430 H 0 0 39 | -3.4880 -5.0340 0.8890 H 0 0 40 | -4.8980 -4.3270 1.6760 H 0 0 41 | -5.5860 -5.8860 -0.1200 H 0 0 42 | -4.6420 -4.9920 -1.3220 H 0 0 43 | -6.8110 -2.1220 -1.1460 H 0 0 44 | -5.3620 -2.7770 -1.9250 H 0 0 45 | 1 2 1 0 46 | 1 11 2 0 47 | 1 24 1 0 48 | 2 3 2 0 49 | 2 25 1 0 50 | 3 4 1 0 51 | 3 26 1 0 52 | 4 5 1 0 53 | 4 12 2 0 54 | 5 6 1 0 55 | 5 10 2 0 56 | 6 7 2 0 57 | 6 27 1 0 58 | 7 8 1 0 59 | 7 28 1 0 60 | 8 9 2 0 61 | 8 29 1 0 62 | 9 10 1 0 63 | 9 30 1 0 64 | 10 31 1 0 65 | 11 12 1 0 66 | 11 13 1 0 67 | 12 16 1 0 68 | 13 15 1 0 69 | 13 17 2 0 70 | 14 15 2 0 71 | 14 16 1 0 72 | 14 22 1 0 73 | 15 32 1 0 74 | 18 21 1 0 75 | 18 22 1 0 76 | 18 33 1 0 77 | 18 34 1 0 78 | 19 20 1 0 79 | 19 22 1 0 80 | 19 35 1 0 81 | 19 36 1 0 82 | 20 23 1 0 83 | 20 37 1 0 84 | 20 38 1 0 85 | 21 23 1 0 86 | 21 39 1 0 87 | 21 40 1 0 88 | M END 89 | > 90 | BDB12915 91 | 92 | > 93 | Phosphoinositide 3-Kinase (PI3K), gamma Chain A 94 | 95 | > 96 | false 97 | 98 | > 99 | P48736 100 | 101 | > 102 | Homo sapiens 103 | 104 | > 105 | 1E7V(96%) 106 | 107 | > 108 | 1600 109 | -------------------------------------------------------------------------------- /examples/forward/mols/ZINC01921759.sdf: -------------------------------------------------------------------------------- 1 | ZINC01921759 2 | SciTegic06141217593D 3 | 4 | 29 30 0 0 0 0 999 V2000 5 | 0.1957 4.4421 0.0238 C 0 0 0 0 0 0 0 0 0 0 0 0 6 | 1.7466 3.5013 0.0037 S 0 0 0 0 0 0 0 0 0 0 0 0 7 | 1.2242 1.8185 0.0001 C 0 0 0 0 0 0 0 0 0 0 0 0 8 | -0.0169 1.3968 0.0097 N 0 0 0 0 0 0 0 0 0 0 0 0 9 | 0.0021 -0.0041 0.0020 N 0 0 0 0 0 0 0 0 0 0 0 0 10 | 1.2535 -0.3922 -0.0120 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | 2.0419 0.7234 -0.0136 N 0 0 0 0 0 0 0 0 0 0 0 0 12 | 3.5067 0.7426 -0.0270 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | 3.9958 0.7559 -1.4524 C 0 0 0 0 0 0 0 0 0 0 0 0 14 | 4.2412 -0.4350 
-2.1100 C 0 0 0 0 0 0 0 0 0 0 0 0 15 | 4.6861 -0.4229 -3.4188 C 0 0 0 0 0 0 0 0 0 0 0 0 16 | 4.8935 0.7803 -4.0671 C 0 0 0 0 0 0 0 0 0 0 0 0 17 | 4.6532 1.9712 -3.4077 C 0 0 0 0 0 0 0 0 0 0 0 0 18 | 4.2046 1.9590 -2.1003 C 0 0 0 0 0 0 0 0 0 0 0 0 19 | 5.3735 0.7933 -5.4673 N 0 0 0 0 0 0 0 0 0 0 0 0 20 | 5.5566 1.8534 -6.0385 O 0 0 0 0 0 0 0 0 0 0 0 0 21 | 5.5852 -0.2560 -6.0483 O 0 0 0 0 0 0 0 0 0 0 0 0 22 | 1.7001 -1.7073 -0.0238 N 0 0 0 0 0 0 0 0 0 0 0 0 23 | -0.3924 4.2518 0.9337 H 0 0 0 0 0 0 0 0 0 0 0 0 24 | -0.4101 4.2611 -0.8763 H 0 0 0 0 0 0 0 0 0 0 0 0 25 | 0.4616 5.5095 0.0267 H 0 0 0 0 0 0 0 0 0 0 0 0 26 | 3.8910 -0.1529 0.4834 H 0 0 0 0 0 0 0 0 0 0 0 0 27 | 3.8665 1.6433 0.4920 H 0 0 0 0 0 0 0 0 0 0 0 0 28 | 4.0825 -1.3923 -1.5919 H 0 0 0 0 0 0 0 0 0 0 0 0 29 | 4.8751 -1.3706 -3.9443 H 0 0 0 0 0 0 0 0 0 0 0 0 30 | 4.8190 2.9286 -3.9234 H 0 0 0 0 0 0 0 0 0 0 0 0 31 | 4.0138 2.9067 -1.5754 H 0 0 0 0 0 0 0 0 0 0 0 0 32 | 0.9672 -2.4311 -0.0211 H 0 0 0 0 0 0 0 0 0 0 0 0 33 | 2.6990 -1.9583 -0.0346 H 0 0 0 0 0 0 0 0 0 0 0 0 34 | 1 2 1 0 0 0 0 35 | 2 3 1 0 0 0 0 36 | 3 7 1 0 0 0 0 37 | 3 4 2 0 0 0 0 38 | 4 5 1 0 0 0 0 39 | 5 6 2 0 0 0 0 40 | 6 7 1 0 0 0 0 41 | 6 18 1 0 0 0 0 42 | 7 8 1 0 0 0 0 43 | 8 9 1 0 0 0 0 44 | 9 14 2 0 0 0 0 45 | 9 10 1 0 0 0 0 46 | 10 11 2 0 0 0 0 47 | 11 12 1 0 0 0 0 48 | 12 13 2 0 0 0 0 49 | 12 15 1 0 0 0 0 50 | 13 14 1 0 0 0 0 51 | 15 16 2 0 0 0 0 52 | 15 17 2 0 0 0 0 53 | 1 19 1 0 0 0 0 54 | 1 20 1 0 0 0 0 55 | 1 21 1 0 0 0 0 56 | 8 22 1 0 0 0 0 57 | 8 23 1 0 0 0 0 58 | 10 24 1 0 0 0 0 59 | 11 25 1 0 0 0 0 60 | 13 26 1 0 0 0 0 61 | 14 27 1 0 0 0 0 62 | 18 28 1 0 0 0 0 63 | 18 29 1 0 0 0 0 64 | M END 65 | > 66 | ZINC01921759 67 | 68 | > 69 | BDB50313061 70 | 71 | > 72 | 0.798 73 | 74 | > 75 | decoy 76 | -------------------------------------------------------------------------------- /examples/forward/mols/ZINC01993838.sdf: -------------------------------------------------------------------------------- 1 | ZINC01993838 2 | 
SciTegic06141217593D 3 | 4 | 39 41 0 0 0 0 999 V2000 5 | 8.1526 6.6142 1.2212 C 0 0 0 0 0 0 0 0 0 0 0 0 6 | 6.9436 5.7761 0.8940 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | 5.6830 6.3506 0.8961 C 0 0 0 0 0 0 0 0 0 0 0 0 8 | 4.5655 5.5983 0.5965 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | 4.6981 4.2504 0.2883 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | 5.9745 3.6726 0.2923 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | 7.0944 4.4371 0.5931 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | 6.0863 2.2261 -0.0248 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | 7.2401 1.6578 -0.2805 N 0 0 0 0 0 0 0 0 0 0 0 0 14 | 7.2976 0.2746 -0.5778 O 0 5 0 0 0 0 0 0 0 0 0 0 15 | 4.8223 1.4672 -0.0300 C 0 0 0 0 0 0 0 0 0 0 0 0 16 | 3.6590 2.1616 -0.0207 C 0 0 0 0 0 0 0 0 0 0 0 0 17 | 3.6136 3.5047 -0.0126 O 0 0 0 0 0 0 0 0 0 0 0 0 18 | 2.3879 1.4113 -0.0131 C 0 0 0 0 0 0 0 0 0 0 0 0 19 | 2.3994 0.0144 -0.0212 C 0 0 0 0 0 0 0 0 0 0 0 0 20 | 1.2112 -0.6847 -0.0136 C 0 0 0 0 0 0 0 0 0 0 0 0 21 | 0.0021 -0.0041 0.0020 C 0 0 0 0 0 0 0 0 0 0 0 0 22 | -0.0168 1.3892 0.0097 C 0 0 0 0 0 0 0 0 0 0 0 0 23 | 1.1705 2.0965 0.0021 C 0 0 0 0 0 0 0 0 0 0 0 0 24 | -1.2037 2.0531 0.0189 O 0 0 0 0 0 0 0 0 0 0 0 0 25 | -1.1483 3.4810 0.0264 C 0 0 0 0 0 0 0 0 0 0 0 0 26 | -1.1664 -0.6979 0.0094 O 0 0 0 0 0 0 0 0 0 0 0 0 27 | 3.2015 6.2390 0.6041 C 0 0 0 0 0 0 0 0 0 0 0 0 28 | 8.8185 6.5078 0.3522 H 0 0 0 0 0 0 0 0 0 0 0 0 29 | 8.6359 6.1020 2.0662 H 0 0 0 0 0 0 0 0 0 0 0 0 30 | 7.9176 7.6652 1.4450 H 0 0 0 0 0 0 0 0 0 0 0 0 31 | 5.5716 7.4176 1.1391 H 0 0 0 0 0 0 0 0 0 0 0 0 32 | 8.0945 3.9790 0.5916 H 0 0 0 0 0 0 0 0 0 0 0 0 33 | 4.8237 0.3673 -0.0411 H 0 0 0 0 0 0 0 0 0 0 0 0 34 | 3.3565 -0.5277 -0.0336 H 0 0 0 0 0 0 0 0 0 0 0 0 35 | 1.2217 -1.7846 -0.0199 H 0 0 0 0 0 0 0 0 0 0 0 0 36 | 1.1570 3.1964 0.0080 H 0 0 0 0 0 0 0 0 0 0 0 0 37 | -0.5719 3.7728 -0.8639 H 0 0 0 0 0 0 0 0 0 0 0 0 38 | -0.5635 3.7631 0.9143 H 0 0 0 0 0 0 0 0 0 0 0 0 39 | -2.1454 3.9455 0.0336 H 0 0 0 0 0 0 0 0 0 0 0 0 40 | -1.1079 -1.2886 0.8017 H 0 0 0 0 0 0 0 0 0 0 0 0 41 | 3.2763 7.0791 -0.1020 H 0 0 0 0 0 0 0 0 0 0 0 0 
42 | 3.0932 6.6690 1.6108 H 0 0 0 0 0 0 0 0 0 0 0 0 43 | 2.3839 5.5480 0.3512 H 0 0 0 0 0 0 0 0 0 0 0 0 44 | 1 2 1 0 0 0 0 45 | 2 7 2 0 0 0 0 46 | 2 3 1 0 0 0 0 47 | 3 4 2 0 0 0 0 48 | 4 5 1 0 0 0 0 49 | 4 23 1 0 0 0 0 50 | 5 13 1 0 0 0 0 51 | 5 6 2 0 0 0 0 52 | 6 7 1 0 0 0 0 53 | 6 8 1 0 0 0 0 54 | 8 9 2 0 0 0 0 55 | 8 11 1 0 0 0 0 56 | 9 10 1 0 0 0 0 57 | 11 12 2 0 0 0 0 58 | 12 13 1 0 0 0 0 59 | 12 14 1 0 0 0 0 60 | 14 19 2 0 0 0 0 61 | 14 15 1 0 0 0 0 62 | 15 16 2 0 0 0 0 63 | 16 17 1 0 0 0 0 64 | 17 18 2 0 0 0 0 65 | 17 22 1 0 0 0 0 66 | 18 19 1 0 0 0 0 67 | 18 20 1 0 0 0 0 68 | 20 21 1 0 0 0 0 69 | 1 24 1 0 0 0 0 70 | 1 25 1 0 0 0 0 71 | 1 26 1 0 0 0 0 72 | 3 27 1 0 0 0 0 73 | 7 28 1 0 0 0 0 74 | 11 29 1 0 0 0 0 75 | 15 30 1 0 0 0 0 76 | 16 31 1 0 0 0 0 77 | 19 32 1 0 0 0 0 78 | 21 33 1 0 0 0 0 79 | 21 34 1 0 0 0 0 80 | 21 35 1 0 0 0 0 81 | 22 36 1 0 0 0 0 82 | 23 37 1 0 0 0 0 83 | 23 38 1 0 0 0 0 84 | 23 39 1 0 0 0 0 85 | M CHG 1 10 -1 86 | M END 87 | > 88 | ZINC01993838 89 | 90 | > 91 | BDB50189752 92 | 93 | > 94 | 0.818 95 | 96 | > 97 | decoy 98 | -------------------------------------------------------------------------------- /examples/forward/mols/ZINC02029177.sdf: -------------------------------------------------------------------------------- 1 | ZINC02029177 2 | SciTegic06141217593D 3 | 4 | 36 38 0 0 0 0 999 V2000 5 | 1.7970 6.2239 1.9878 C 0 0 0 0 0 0 0 0 0 0 0 0 6 | 2.2177 5.7348 0.6004 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | 2.2267 4.2280 0.5804 C 0 0 0 0 0 0 0 0 0 0 0 0 8 | 1.1738 3.5491 0.0361 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | 1.1844 2.1540 0.0182 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | 2.2718 1.4438 0.5571 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | 3.3338 2.1494 1.1065 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | 3.3090 3.5340 1.1166 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | 4.3488 4.2212 1.6543 O 0 5 0 0 0 0 0 0 0 0 0 0 14 | 2.2797 0.0963 0.5395 O 0 0 0 0 0 0 0 0 0 0 0 0 15 | 1.2881 -0.6360 0.0259 C 0 0 0 0 0 0 0 0 0 0 0 0 16 | 0.1792 -0.0799 -0.5269 C 0 0 0 0 0 0 0 0 0 0 0 0 17 | 0.0657 1.3883 -0.5608 
C 0 0 0 0 0 0 0 0 0 0 0 0 18 | -0.9038 1.9432 -1.0485 O 0 0 0 0 0 0 0 0 0 0 0 0 19 | -0.8965 -0.9343 -1.0860 C 0 0 0 0 0 0 0 0 0 0 0 0 20 | -0.6813 -1.6485 -2.2634 C 0 0 0 0 0 0 0 0 0 0 0 0 21 | -1.6838 -2.4458 -2.7769 C 0 0 0 0 0 0 0 0 0 0 0 0 22 | -2.9051 -2.5284 -2.1306 C 0 0 0 0 0 0 0 0 0 0 0 0 23 | -3.1255 -1.8174 -0.9636 C 0 0 0 0 0 0 0 0 0 0 0 0 24 | -2.1279 -1.0214 -0.4390 C 0 0 0 0 0 0 0 0 0 0 0 0 25 | -4.2770 -3.6168 -2.8442 Br 0 0 0 0 0 0 0 0 0 0 0 0 26 | 1.4043 -2.1382 0.0606 C 0 0 1 0 0 0 0 0 0 0 0 0 27 | 1.5604 -2.6240 -1.2420 F 0 0 0 0 0 0 0 0 0 0 0 0 28 | 2.5143 -2.5024 0.8304 F 0 0 0 0 0 0 0 0 0 0 0 0 29 | 0.2460 -2.6825 0.6257 F 0 0 0 0 0 0 0 0 0 0 0 0 30 | 2.5083 5.8522 2.7401 H 0 0 0 0 0 0 0 0 0 0 0 0 31 | 0.7893 5.8488 2.2200 H 0 0 0 0 0 0 0 0 0 0 0 0 32 | 1.7902 7.3238 2.0031 H 0 0 0 0 0 0 0 0 0 0 0 0 33 | 1.5057 6.1067 -0.1511 H 0 0 0 0 0 0 0 0 0 0 0 0 34 | 3.2254 6.1099 0.3681 H 0 0 0 0 0 0 0 0 0 0 0 0 35 | 0.3207 4.1002 -0.3865 H 0 0 0 0 0 0 0 0 0 0 0 0 36 | 4.1926 1.6101 1.5326 H 0 0 0 0 0 0 0 0 0 0 0 0 37 | 0.2859 -1.5768 -2.7824 H 0 0 0 0 0 0 0 0 0 0 0 0 38 | -1.5121 -3.0167 -3.7014 H 0 0 0 0 0 0 0 0 0 0 0 0 39 | -4.0976 -1.8871 -0.4535 H 0 0 0 0 0 0 0 0 0 0 0 0 40 | -2.3042 -0.4565 0.4883 H 0 0 0 0 0 0 0 0 0 0 0 0 41 | 1 2 1 0 0 0 0 42 | 2 3 1 0 0 0 0 43 | 3 8 2 0 0 0 0 44 | 3 4 1 0 0 0 0 45 | 4 5 2 0 0 0 0 46 | 5 13 1 0 0 0 0 47 | 5 6 1 0 0 0 0 48 | 6 7 2 0 0 0 0 49 | 6 10 1 0 0 0 0 50 | 7 8 1 0 0 0 0 51 | 8 9 1 0 0 0 0 52 | 10 11 1 0 0 0 0 53 | 11 12 2 0 0 0 0 54 | 11 22 1 0 0 0 0 55 | 12 13 1 0 0 0 0 56 | 12 15 1 0 0 0 0 57 | 13 14 2 0 0 0 0 58 | 15 20 2 0 0 0 0 59 | 15 16 1 0 0 0 0 60 | 16 17 2 0 0 0 0 61 | 17 18 1 0 0 0 0 62 | 18 19 2 0 0 0 0 63 | 18 21 1 0 0 0 0 64 | 19 20 1 0 0 0 0 65 | 22 23 1 0 0 0 0 66 | 22 24 1 0 0 0 0 67 | 22 25 1 0 0 0 0 68 | 1 26 1 0 0 0 0 69 | 1 27 1 0 0 0 0 70 | 1 28 1 0 0 0 0 71 | 2 29 1 0 0 0 0 72 | 2 30 1 0 0 0 0 73 | 4 31 1 0 0 0 0 74 | 7 32 1 0 0 0 0 75 | 16 33 1 0 0 0 0 76 | 17 34 1 0 0 0 0 77 | 19 
35 1 0 0 0 0 78 | 20 36 1 0 0 0 0 79 | M CHG 1 9 -1 80 | M END 81 | > 82 | ZINC02029177 83 | 84 | > 85 | BDB50262462 86 | 87 | > 88 | 0.806 89 | 90 | > 91 | decoy 92 | -------------------------------------------------------------------------------- /examples/forward/mols/ZINC04165102.sdf: -------------------------------------------------------------------------------- 1 | ZINC04165102 2 | SciTegic06141217593D 3 | 4 | 38 42 0 0 1 0 999 V2000 5 | 2.2229 1.0444 -0.0091 C 0 0 0 0 0 0 0 0 0 0 0 0 6 | 1.3122 2.0789 0.0042 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -0.0233 1.7795 0.0146 C 0 0 0 0 0 0 0 0 0 0 0 0 8 | -0.0001 -0.0057 0.0057 S 0 0 0 0 0 0 0 0 0 0 0 0 9 | 1.7702 -0.2352 -0.0112 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -1.1840 2.6965 0.0244 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | -2.4081 2.4830 -0.6200 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | -3.2098 3.5587 -0.3706 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | -2.5156 4.4182 0.4044 N 0 0 0 0 0 0 0 0 0 0 0 0 14 | -1.2503 3.8606 0.6454 N 0 0 0 0 0 0 0 0 0 0 0 0 15 | -4.4714 3.6603 -0.8586 O 0 0 0 0 0 0 0 0 0 0 0 0 16 | -4.8354 2.8011 -1.8325 C 0 0 0 0 0 0 0 0 0 0 0 0 17 | -4.1306 1.6959 -2.1619 C 0 0 0 0 0 0 0 0 0 0 0 0 18 | -2.8531 1.3125 -1.4579 C 0 0 2 0 0 0 0 0 0 0 0 0 19 | -3.0992 0.1178 -0.5727 C 0 0 0 0 0 0 0 0 0 0 0 0 20 | -4.0479 0.1865 0.4308 C 0 0 0 0 0 0 0 0 0 0 0 0 21 | -4.2759 -0.9100 1.2456 C 0 0 0 0 0 0 0 0 0 0 0 0 22 | -3.5474 -2.0814 1.0517 C 0 0 0 0 0 0 0 0 0 0 0 0 23 | -2.5982 -2.1439 0.0450 C 0 0 0 0 0 0 0 0 0 0 0 0 24 | -2.3795 -1.0468 -0.7694 C 0 0 0 0 0 0 0 0 0 0 0 0 25 | -1.1947 -1.1298 -2.0355 Cl 0 0 0 0 0 0 0 0 0 0 0 0 26 | -3.9531 -3.0048 1.9714 O 0 0 0 0 0 0 0 0 0 0 0 0 27 | -5.2191 -2.5269 2.4635 C 0 0 0 0 0 0 0 0 0 0 0 0 28 | -5.1381 -1.1002 2.2870 O 0 0 0 0 0 0 0 0 0 0 0 0 29 | -4.6143 0.8587 -3.2182 C 0 0 0 0 0 0 0 0 0 0 0 0 30 | -4.9980 0.1945 -4.0562 N 0 0 0 0 0 0 0 0 0 0 0 0 31 | -5.9886 3.0642 -2.5289 N 0 0 0 0 0 0 0 0 0 0 0 0 32 | 3.3029 1.2532 -0.0179 H 0 0 0 0 0 0 0 0 0 0 0 0 33 | 1.6522 3.1250 0.0064 H 0 0 0 0 0 0 0 0 0 0 0 0 
34 | 2.3460 -1.1724 -0.0208 H 0 0 0 0 0 0 0 0 0 0 0 0 35 | -0.5014 4.2863 1.2100 H 0 0 0 0 0 0 0 0 0 0 0 0 36 | -2.0766 1.0536 -2.1928 H 0 0 0 0 0 0 0 0 0 0 0 0 37 | -4.6219 1.1126 0.5821 H 0 0 0 0 0 0 0 0 0 0 0 0 38 | -2.0180 -3.0660 -0.1072 H 0 0 0 0 0 0 0 0 0 0 0 0 39 | -5.3306 -2.7717 3.5301 H 0 0 0 0 0 0 0 0 0 0 0 0 40 | -6.0416 -2.9381 1.8599 H 0 0 0 0 0 0 0 0 0 0 0 0 41 | -6.5646 3.8920 -2.3194 H 0 0 0 0 0 0 0 0 0 0 0 0 42 | -6.2807 2.4179 -3.2758 H 0 0 0 0 0 0 0 0 0 0 0 0 43 | 1 5 2 0 0 0 0 44 | 1 2 1 0 0 0 0 45 | 2 3 2 0 0 0 0 46 | 3 4 1 0 0 0 0 47 | 3 6 1 0 0 0 0 48 | 4 5 1 0 0 0 0 49 | 6 7 2 0 0 0 0 50 | 6 10 1 0 0 0 0 51 | 7 14 1 0 0 0 0 52 | 7 8 1 0 0 0 0 53 | 8 11 1 0 0 0 0 54 | 8 9 2 0 0 0 0 55 | 9 10 1 0 0 0 0 56 | 11 12 1 0 0 0 0 57 | 12 13 2 0 0 0 0 58 | 12 27 1 0 0 0 0 59 | 13 14 1 0 0 0 0 60 | 13 25 1 0 0 0 0 61 | 14 15 1 0 0 0 0 62 | 15 20 2 0 0 0 0 63 | 15 16 1 0 0 0 0 64 | 16 17 2 0 0 0 0 65 | 17 24 1 0 0 0 0 66 | 17 18 1 0 0 0 0 67 | 18 19 2 0 0 0 0 68 | 18 22 1 0 0 0 0 69 | 19 20 1 0 0 0 0 70 | 20 21 1 0 0 0 0 71 | 22 23 1 0 0 0 0 72 | 23 24 1 0 0 0 0 73 | 25 26 3 0 0 0 0 74 | 1 28 1 0 0 0 0 75 | 2 29 1 0 0 0 0 76 | 5 30 1 0 0 0 0 77 | 10 31 1 0 0 0 0 78 | 14 32 1 0 0 0 0 79 | 16 33 1 0 0 0 0 80 | 19 34 1 0 0 0 0 81 | 23 35 1 0 0 0 0 82 | 23 36 1 0 0 0 0 83 | 27 37 1 0 0 0 0 84 | 27 38 1 0 0 0 0 85 | M END 86 | > 87 | ZINC04165102 88 | 89 | > 90 | BDB50315438 91 | 92 | > 93 | 0.788 94 | 95 | > 96 | decoy 97 | -------------------------------------------------------------------------------- /examples/reverse/ligand_1.sdf: -------------------------------------------------------------------------------- 1 | 3fur_ligand 2 | 3 | Created by X-TOOL on Fri Sep 26 17:34:44 2014 4 | 43 46 0 0 0 0 0 0 0 0999 V2000 5 | 2.1080 5.0210 27.2770 Cl 0 0 0 1 0 1 6 | 3.7200 4.9400 28.0550 C 0 0 0 1 0 3 7 | 3.6880 4.9240 29.4520 C 0 0 0 2 0 3 8 | 4.8650 4.8410 30.1890 C 0 0 0 1 0 3 9 | 4.7320 4.8240 31.9820 Cl 0 0 0 1 0 1 10 | 6.0920 4.7980 29.5130 C 0 0 0 2 0 3 
11 | 6.1210 4.8310 28.1070 C 0 0 0 2 0 3 12 | 4.9390 4.8990 27.3460 C 0 0 0 1 0 3 13 | 5.0330 4.8850 25.7120 S 0 0 0 1 0 4 14 | 4.2050 3.8090 25.2180 O 0 0 0 1 0 1 15 | 6.4030 4.6910 25.3230 O 0 0 0 1 0 1 16 | 4.4560 6.2160 25.0500 N 0 0 0 2 0 3 17 | 4.9750 7.4640 25.1130 C 0 0 0 1 0 3 18 | 4.1680 8.4970 24.6460 C 0 0 0 2 0 3 19 | 4.6280 9.7950 24.6660 C 0 0 0 1 0 3 20 | 3.5620 11.0820 24.0410 Cl 0 0 0 1 0 1 21 | 5.8880 10.0820 25.1740 C 0 0 0 1 0 3 22 | 6.6860 9.0610 25.6370 C 0 0 0 1 0 3 23 | 8.2940 9.4900 26.2750 Cl 0 0 0 1 0 1 24 | 6.2420 7.7410 25.6070 C 0 0 0 2 0 3 25 | 6.3780 11.3610 25.2170 O 0 0 0 1 0 2 26 | 6.3470 12.1650 26.3470 C 0 0 0 1 0 3 27 | 6.1810 11.6570 27.6440 C 0 0 0 2 0 3 28 | 6.1820 12.5630 28.7240 C 0 0 0 1 0 3 29 | 6.0270 12.1070 30.0290 C 0 0 0 2 0 3 30 | 6.5120 13.5400 26.1820 C 0 0 0 2 0 3 31 | 6.5030 14.3810 27.2240 N 0 0 0 1 0 2 32 | 6.3420 13.9330 28.4850 C 0 0 0 1 0 3 33 | 6.3440 14.8480 29.5370 C 0 0 0 2 0 3 34 | 6.1800 14.3840 30.8330 C 0 0 0 2 0 3 35 | 6.0270 13.0150 31.0820 C 0 0 0 2 0 3 36 | 2.7356 4.9768 29.9670 H 0 0 0 1 0 1 37 | 7.0181 4.7393 30.0733 H 0 0 0 1 0 1 38 | 7.0772 4.8034 27.5971 H 0 0 0 1 0 1 39 | 3.6101 6.1165 24.5260 H 0 0 0 1 0 1 40 | 3.1763 8.2798 24.2661 H 0 0 0 1 0 1 41 | 6.8784 6.9399 25.9652 H 0 0 0 1 0 1 42 | 6.0550 10.5935 27.8117 H 0 0 0 1 0 1 43 | 5.9066 11.0476 30.2243 H 0 0 0 1 0 1 44 | 6.6512 13.9386 25.1836 H 0 0 0 1 0 1 45 | 6.4720 15.9070 29.3443 H 0 0 0 1 0 1 46 | 6.1702 15.0859 31.6590 H 0 0 0 1 0 1 47 | 5.9081 12.6614 32.0998 H 0 0 0 1 0 1 48 | 1 2 1 0 0 2 49 | 2 3 4 0 0 1 50 | 2 8 4 0 0 1 51 | 3 4 4 0 0 1 52 | 4 5 1 0 0 2 53 | 4 6 4 0 0 1 54 | 6 7 4 0 0 1 55 | 7 8 4 0 0 1 56 | 8 9 1 0 0 2 57 | 9 10 2 0 0 2 58 | 9 11 2 0 0 2 59 | 9 12 1 0 0 2 60 | 12 13 1 0 0 2 61 | 13 14 4 0 0 1 62 | 13 20 4 0 0 1 63 | 14 15 4 0 0 1 64 | 15 16 1 0 0 2 65 | 15 17 4 0 0 1 66 | 17 18 4 0 0 1 67 | 17 21 1 0 0 2 68 | 18 19 1 0 0 2 69 | 18 20 4 0 0 1 70 | 21 22 1 0 0 2 71 | 22 23 4 0 0 1 72 | 22 26 4 0 0 1 73 | 23 
24 4 0 0 1 74 | 24 25 4 0 0 1 75 | 24 28 4 0 0 1 76 | 25 31 4 0 0 1 77 | 26 27 4 0 0 1 78 | 27 28 4 0 0 1 79 | 28 29 4 0 0 1 80 | 29 30 4 0 0 1 81 | 30 31 4 0 0 1 82 | 3 32 1 0 0 2 83 | 6 33 1 0 0 2 84 | 7 34 1 0 0 2 85 | 12 35 1 0 0 2 86 | 14 36 1 0 0 2 87 | 20 37 1 0 0 2 88 | 23 38 1 0 0 2 89 | 25 39 1 0 0 2 90 | 26 40 1 0 0 2 91 | 29 41 1 0 0 2 92 | 30 42 1 0 0 2 93 | 31 43 1 0 0 2 94 | M END 95 | > 96 | C21H12N2O3SCl4 97 | 98 | > 99 | 514.1 100 | 101 | > 102 | 5 103 | 104 | > 105 | 3 106 | 107 | > 108 | 6.74 109 | 110 | $$$$ 111 | -------------------------------------------------------------------------------- /examples/reverse/receptors/3mhw_protein_crystal.mol2: -------------------------------------------------------------------------------- 1 | @MOLECULE 2 | 3mhw_ligand 3 | 16 17 1 4 | SMALL 5 | USER_CHARGES 6 | 7 | 8 | @ATOM 9 | 1 C1 -29.5790 -17.6000 9.0430 C.ar 1 ABV 0.0000 10 | 2 N1 -30.4030 -19.8040 11.6350 N.2 1 ABV 0.0000 11 | 3 S1 -31.6130 -19.1420 9.4990 S.3 1 ABV 0.0000 12 | 4 C2 -28.4240 -17.0880 9.4690 C.ar 1 ABV 0.0000 13 | 5 N2 -32.4300 -20.9140 11.2400 N.pl3 1 ABV 0.0000 14 | 6 C3 -27.8510 -17.4600 10.6270 C.ar 1 ABV 0.0000 15 | 7 C4 -28.4280 -18.4020 11.4040 C.ar 1 ABV 0.0000 16 | 8 C5 -29.5950 -18.9180 11.0100 C.ar 1 ABV 0.0000 17 | 9 C6 -30.1590 -18.4920 9.8400 C.ar 1 ABV 0.0000 18 | 10 C7 -31.4920 -20.0710 10.8950 C.2 1 ABV 0.0000 19 | 11 HC1 -29.9929 -17.2850 8.0965 H 1 ABV 0.0000 20 | 12 HC2 -27.9118 -16.3443 8.8766 H 1 ABV 0.0000 21 | 13 H2_1 -33.2222 -21.0647 10.6319 H 1 ABV 0.0000 22 | 14 H2_2 -32.3624 -21.4145 12.1146 H 1 ABV 0.0000 23 | 15 HC3 -26.9234 -17.0048 10.9412 H 1 ABV 0.0000 24 | 16 HC4 -27.9456 -18.7188 12.3169 H 1 ABV 0.0000 25 | @BOND 26 | 1 1 4 ar 27 | 2 1 9 ar 28 | 3 1 11 1 29 | 4 2 8 1 30 | 5 2 10 2 31 | 6 3 9 1 32 | 7 3 10 1 33 | 8 4 6 ar 34 | 9 4 12 1 35 | 10 5 10 1 36 | 11 5 13 1 37 | 12 5 14 1 38 | 13 6 7 ar 39 | 14 6 15 1 40 | 15 7 8 ar 41 | 16 7 16 1 42 | 17 8 9 ar 43 | @SUBSTRUCTURE 44 | 1 ABV 1 GROUP 0 U **** 
0 ROOT 45 | -------------------------------------------------------------------------------- /examples/reverse/receptors/3mhw_protein_crystal.sdf: -------------------------------------------------------------------------------- 1 | 3mhw_ligand 2 | 3D 3 | Schrodinger Suite 2021-3. 4 | 16 17 0 0 1 0 999 V2000 5 | -29.5790 -17.6000 9.0430 C 0 0 0 0 0 0 6 | -30.4030 -19.8040 11.6350 N 0 0 0 0 0 0 7 | -31.6130 -19.1420 9.4990 S 0 0 0 0 0 0 8 | -28.4240 -17.0880 9.4690 C 0 0 0 0 0 0 9 | -32.4300 -20.9140 11.2400 N 0 0 0 0 0 0 10 | -27.8510 -17.4600 10.6270 C 0 0 0 0 0 0 11 | -28.4280 -18.4020 11.4040 C 0 0 0 0 0 0 12 | -29.5950 -18.9180 11.0100 C 0 0 0 0 0 0 13 | -30.1590 -18.4920 9.8400 C 0 0 0 0 0 0 14 | -31.4920 -20.0710 10.8950 C 0 0 0 0 0 0 15 | -29.9929 -17.2850 8.0965 H 0 0 0 0 0 0 16 | -27.9118 -16.3443 8.8766 H 0 0 0 0 0 0 17 | -33.2222 -21.0647 10.6319 H 0 0 0 0 0 0 18 | -32.3624 -21.4145 12.1146 H 0 0 0 0 0 0 19 | -26.9234 -17.0048 10.9412 H 0 0 0 0 0 0 20 | -27.9456 -18.7188 12.3169 H 0 0 0 0 0 0 21 | 1 4 2 0 0 0 22 | 1 9 1 0 0 0 23 | 1 11 1 0 0 0 24 | 2 8 1 0 0 0 25 | 2 10 2 0 0 0 26 | 3 9 1 0 0 0 27 | 3 10 1 0 0 0 28 | 4 6 1 0 0 0 29 | 4 12 1 0 0 0 30 | 5 10 1 0 0 0 31 | 5 13 1 0 0 0 32 | 5 14 1 0 0 0 33 | 6 7 2 0 0 0 34 | 6 15 1 0 0 0 35 | 7 8 1 0 0 0 36 | 7 16 1 0 0 0 37 | 8 9 2 0 0 0 38 | M END 39 | > 40 | 0 41 | 42 | > 43 | 1039 44 | 45 | > 46 | Structure1039 47 | 48 | > 49 | THE COMPLEX CRYSTAL STRUCTURE OF UROKIANSE AND 2-AMINOBENZOTHIAZOLE 50 | 51 | > 52 | 3MHW 53 | 54 | > 55 | 121.276 56 | 57 | > 58 | 121.276 59 | 60 | > 61 | 43.136 62 | 63 | > 64 | 90 65 | 66 | > 67 | 90 68 | 69 | > 70 | 120 71 | 72 | > 73 | H 3 74 | 75 | > 76 | 9 77 | 78 | > 79 | HYDROLASE 80 | 81 | > 82 | 09-APR-10 83 | 84 | > 85 | 3.30 86 | 87 | > 88 | 0.205 89 | 90 | > 91 | 0.234 92 | 93 | > 94 | 1.45 95 | 96 | > 97 | X-RAY DIFFRACTION 98 | 99 | > 100 | 4.6 101 | 102 | > 103 | U 104 | 105 | > 106 | 1.000000 0.000000 0.000000 0.000000;0.000000 1.000000 0.000000 0.000000;0.000000 
0.000000 1.000000 0.000000 107 | 108 | > 109 | D:\msc\VS\DEKOIS_2.0x\upa\upa_prot 110 | 111 | > 112 | 3mhw.pdb 113 | 114 | > 115 | 1 116 | 117 | > 118 | 1037 119 | 120 | > 121 | 1 122 | 123 | > 124 | 2021-3 125 | 126 | > 127 | 1 128 | 129 | > 130 | 1 131 | 132 | > 133 | 1 134 | 135 | > 136 | 1 137 | 138 | > 139 | eJxdkLESwiAMhl8lx9wqIBXbbk6unqPH0KudRHDAwev13SVQzSED95HkT/4wM8E6uMYjWtVUwE0FiHtCTXggbH/YcEJBKAl3hAoReZ2Xm8q1aQpguSAbuVUZUCTR/xL9rRBkOpsUZL2oSLnkXdIaxdiEaYpE8zXfRMnM0E/8PXwucS8MsksYwgTPyQ02vHOOw30c7PbhbQ+n+ujdDUb/cqED3sM53syY5QOP2ldb 140 | 141 | > 142 | 1 143 | 144 | > 145 | 1 146 | 147 | > 148 | 100 149 | 150 | $$$$ 151 | -------------------------------------------------------------------------------- /images/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HBioquant/DiffBindFR/b8bb82027fab5e74fc83ce2a44c0f920a9012ad3/images/arch.png -------------------------------------------------------------------------------- /openfold/__init__.py: -------------------------------------------------------------------------------- 1 | # from . import model 2 | from . import utils 3 | from . import np 4 | from . 
import resources 5 | 6 | __all__ = ["utils", "np", "data", "resources"] 7 | -------------------------------------------------------------------------------- /openfold/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HBioquant/DiffBindFR/b8bb82027fab5e74fc83ce2a44c0f920a9012ad3/openfold/data/__init__.py -------------------------------------------------------------------------------- /openfold/np/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HBioquant/DiffBindFR/b8bb82027fab5e74fc83ce2a44c0f920a9012ad3/openfold/np/__init__.py -------------------------------------------------------------------------------- /openfold/resources/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HBioquant/DiffBindFR/b8bb82027fab5e74fc83ce2a44c0f920a9012ad3/openfold/resources/__init__.py -------------------------------------------------------------------------------- /openfold/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HBioquant/DiffBindFR/b8bb82027fab5e74fc83ce2a44c0f920a9012ad3/openfold/utils/__init__.py -------------------------------------------------------------------------------- /openfold/utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import logging
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional

import torch
import torch.nn as nn


def add(m1, m2, inplace):
    """Return ``m1 + m2``, optionally accumulating into ``m1`` in place.

    Args:
        m1: Left operand; mutated when ``inplace`` is truthy.
        m2: Right operand.
        inplace: When truthy use ``+=`` (``m1`` itself is updated and
            returned); otherwise a new object is produced by ``m1 + m2``.

    Returns:
        The sum (the same object as ``m1`` when done in place on a tensor).
    """
    # The first operation in a checkpoint can't be in-place, but it's
    # nice to have in-place addition during inference. Thus...
    if not inplace:
        m1 = m1 + m2
    else:
        m1 += m2

    return m1


def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """Permute only the last ``len(inds)`` dimensions of ``tensor``.

    ``inds`` is a permutation expressed relative to the trailing group of
    dimensions (``0`` is the first of the final ``len(inds)`` dims). All
    leading dimensions keep their order.
    """
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


def flatten_final_dims(t: torch.Tensor, no_dims: int):
    """Collapse the last ``no_dims`` dimensions of ``t`` into a single one."""
    return t.reshape(t.shape[:-no_dims] + (-1,))


def masked_mean(mask, value, dim, eps=1e-4):
    """Mean of ``value`` over ``dim`` weighted by ``mask``.

    ``mask`` is expanded to ``value``'s shape. ``eps`` is added to the
    denominator so a fully masked-out slice yields ~0 instead of dividing
    by zero.
    """
    mask = mask.expand(*value.shape)
    return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))


def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    """Bin pairwise Euclidean distances between points.

    Args:
        pts: Tensor whose last dimension holds coordinates and whose
            second-to-last dimension enumerates points.
        min_bin: Lowest bin boundary.
        max_bin: Highest bin boundary.
        no_bins: Number of bins; ``no_bins - 1`` boundaries are created.

    Returns:
        Integer tensor of bin indices in ``[0, no_bins - 1]`` for every
        pair of points.
    """
    boundaries = torch.linspace(
        min_bin, max_bin, no_bins - 1, device=pts.device
    )
    dists = torch.sqrt(
        torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
    )
    return torch.bucketize(dists, boundaries)


def dict_multimap(fn, dicts):
    """Apply ``fn`` to the list of values found under each key of ``dicts``.

    The key set (and nesting structure) is taken from ``dicts[0]``; every
    other dict must contain the same keys. Values that are exactly of type
    ``dict`` are recursed into; anything else is collected into a list
    (one entry per input dict) and passed to ``fn``.
    """
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        # Exact type check kept on purpose: dict subclasses are treated as
        # leaves here, exactly as in the original implementation.
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)

    return new_dict


def one_hot(x, v_bins):
    """One-hot encode each element of ``x`` by its nearest value in ``v_bins``.

    Returns a float tensor of shape ``x.shape + (len(v_bins),)``.
    """
    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
    diffs = x[..., None] - reshaped_bins
    am = torch.argmin(torch.abs(diffs), dim=-1)
    return nn.functional.one_hot(am, num_classes=len(v_bins)).float()


def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """Index ``data`` with ``inds`` along ``dim``, batched over leading dims.

    Args:
        data: Tensor to gather from.
        inds: Integer index tensor; its first ``no_batch_dims`` dimensions
            line up with the leading dimensions of ``data``.
        dim: Dimension of ``data`` to index (may be negative).
        no_batch_dims: Number of leading dimensions treated as batch
            dimensions (each batch element uses its own slice of ``inds``).

    Returns:
        The gathered tensor produced by advanced indexing.
    """
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        # Build the range next to the indices so no cross-device index
        # tensors are mixed in the final indexing expression.
        r = torch.arange(s, device=inds.device)
        # Shape (1, ..., s, ..., 1) so it broadcasts against ``inds``.
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - no_batch_dims)
    ]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple: using a plain list for multidimensional indexing
    # is deprecated in NumPy and recent PyTorch releases.
    return data[tuple(ranges)]


# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
    """Apply ``fn`` to every ``leaf_type`` leaf reachable from ``dic``.

    Returns a new dict with the same keys; nested containers are handled
    by :func:`tree_map`.
    """
    new_dict = {}
    for k, v in dic.items():
        if type(v) is dict:
            new_dict[k] = dict_map(fn, v, leaf_type)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)

    return new_dict


def tree_map(fn, tree, leaf_type):
    """Recursively map ``fn`` over the ``leaf_type`` leaves of ``tree``.

    Dicts, lists and tuples are rebuilt with the same structure.

    Raises:
        ValueError: If a node is neither a dict, list, tuple nor an
            instance of ``leaf_type``.
    """
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple([tree_map(fn, x, leaf_type) for x in tree])
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        # Report the offending type in the exception itself instead of
        # printing it to stdout before raising.
        raise ValueError(f"Not supported: {type(tree)}")


tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)

# NOTE: the remainder of the original physical line (the start of the next
# file in this repository dump) is kept verbatim below so no content is lost:
# -------------------------------------------------------------------------------- /requirements/optional.txt: -------------------------------------------------------------------------------- 1 | python-constraint 2 | pint 3 | deepsmiles 4 | selfies 5 | fcd_torch 6 | meeko==0.5.0 7 | triton==2.0.0 8 | ml_collections 9 |
netcdf4 -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | -r runtime.txt 2 | -r optional.txt -------------------------------------------------------------------------------- /requirements/runtime.txt: -------------------------------------------------------------------------------- 1 | openmm>=7.7 2 | pdbfixer 3 | mpi4py 4 | openmmforcefields 5 | torch 6 | torch-cluster 7 | torch-scatter 8 | torch-sparse 9 | torch-spline-conv 10 | torch-geometric 11 | pymol 12 | yapf==0.32.0 13 | addict==2.4.0 14 | prettytable==3.6.0 15 | easydict==1.10 16 | omegaconf 17 | tqdm 18 | typing-extensions 19 | networkx 20 | opencv-python 21 | matplotlib 22 | seaborn 23 | requests 24 | pandas 25 | numba 26 | scikit-learn>=1.1.3 27 | scipy>=1.12.0 28 | e3nn 29 | prody<=2.4.0 30 | biopython==1.80 31 | rdkit<=2023.03.2 32 | MDAnalysis 33 | einops 34 | py3Dmol 35 | nglview 36 | joblib==1.1.0 37 | pandarallel==1.6.3 38 | spyrmsd==0.5.2 39 | tensorboardX 40 | tensorboard 41 | torchmetrics 42 | torchtyping 43 | typeguard 44 | gpustat 45 | dm-tree 46 | lmdb==1.4.1 47 | wandb==0.13.3 48 | filelock 49 | posebusters 50 | prefetch_generator 51 | mdtraj 52 | docutils --------------------------------------------------------------------------------