├── RosENet
    ├── __init__.py
    ├── network
    │   ├── evaluate.py
    │   ├── predict.py
    │   ├── network.py
    │   ├── input.py
    │   └── utils.py
    ├── rosetta
    │   ├── __init__.py
    │   └── rosetta.py
    ├── __main__.py
    ├── storage
    │   ├── .ropeproject
    │   │   ├── history
    │   │   ├── objectdb
    │   │   ├── globalnames
    │   │   └── config.py
    │   └── storage.py
    ├── preprocessing
    │   ├── .ropeproject
    │   │   ├── history
    │   │   ├── objectdb
    │   │   ├── globalnames
    │   │   └── config.py
    │   ├── __init__.py
    │   ├── make_protein_pdb.py
    │   ├── make_pdbqt.py
    │   ├── make_ligand_params_pdb.py
    │   ├── make_ligand_mol2.py
    │   ├── make_ligand_mol2_renamed.py
    │   ├── step.py
    │   ├── minimize_rosetta.py
    │   ├── make_complex_pdb.py
    │   ├── preprocessing.py
    │   ├── preprocess_vina.py
    │   └── compute_rosetta_energy.py
    ├── postprocessing
    │   ├── __init__.py
    │   ├── postprocessing.py
    │   └── .ropeproject
    │       └── config.py
    ├── voxelization
    │   ├── apbs.sh
    │   ├── __init__.py
    │   ├── apbs.in
    │   ├── interpolation.py
    │   ├── filter.py
    │   ├── utils.py
    │   ├── apbs.py
    │   └── voxelizers.py
    ├── utils.py
    ├── settings.py
    ├── static
    │   ├── flags_relax.txt
    │   ├── relax.xml
    │   ├── dock.xml
    │   ├── dock_relax.xml
    │   └── dock_relax2.xml
    ├── objects
    │   ├── dataset.py
    │   ├── file.py
    │   ├── pdb.py
    │   └── model.py
    ├── models
    │   ├── kdeep.py
    │   ├── large_kdeep.py
    │   └── resnet.py
    ├── constants.py
    └── clui.py
├── test_dataset
    ├── labels
    └── 10gs
        └── 10gs_ligand.mol2
├── .gitignore
├── logo.png
├── used_pdbs.pdf
├── pip-requirements.txt
├── install.sh
├── clear_dataset.sh
├── conda-requirements.txt
└── README.md


/RosENet/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/RosENet/network/evaluate.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/RosENet/network/predict.py:
--------------------------------------------------------------------------------
1 | 
-------------------------------------------------------------------------------- /RosENet/rosetta/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_dataset/labels: -------------------------------------------------------------------------------- 1 | 10gs 1 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/*.pyc 2 | htmd/ 3 | pyrosetta/ 4 | __pycache__/ 5 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/logo.png -------------------------------------------------------------------------------- /used_pdbs.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/used_pdbs.pdf -------------------------------------------------------------------------------- /RosENet/__main__.py: -------------------------------------------------------------------------------- 1 | #from . 
import testing 2 | from .clui import parse_arguments 3 | 4 | parse_arguments() 5 | -------------------------------------------------------------------------------- /RosENet/storage/.ropeproject/history: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/RosENet/storage/.ropeproject/history -------------------------------------------------------------------------------- /RosENet/storage/.ropeproject/objectdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/RosENet/storage/.ropeproject/objectdb -------------------------------------------------------------------------------- /RosENet/storage/.ropeproject/globalnames: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/RosENet/storage/.ropeproject/globalnames -------------------------------------------------------------------------------- /RosENet/preprocessing/.ropeproject/history: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/RosENet/preprocessing/.ropeproject/history -------------------------------------------------------------------------------- /RosENet/preprocessing/.ropeproject/objectdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/RosENet/preprocessing/.ropeproject/objectdb -------------------------------------------------------------------------------- /RosENet/preprocessing/.ropeproject/globalnames: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/RosENet/preprocessing/.ropeproject/globalnames -------------------------------------------------------------------------------- 
/RosENet/postprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .postprocessing import generate_tfrecords 2 | 3 | def postprocess(dataset_object): 4 | generate_tfrecords(dataset_object) 5 | -------------------------------------------------------------------------------- /RosENet/voxelization/apbs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | module load gcc/5.2.0 3 | OLD=$LD_LIBRARY_PATH 4 | LD_LIBRARY_PATH="$SCRATCH/apbs/lib:$LD_LIBRARY_PATH" 5 | $SCRATCH/apbs/bin/apbs "$@" > /dev/null 2>&1 6 | 7 | LD_LIBRARY_PATH=$OLD 8 | -------------------------------------------------------------------------------- /RosENet/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .compute_rosetta_energy import ComputeRosettaEnergy 2 | from .make_pdbqt import MakePDBQT 3 | 4 | def preprocess(pdb_object): 5 | ComputeRosettaEnergy.run_until(pdb_object,callbacks=[lambda x: print(x)]) 6 | MakePDBQT.run_until(pdb_object,callbacks=[lambda x: print(x)]) 7 | -------------------------------------------------------------------------------- /RosENet/utils.py: -------------------------------------------------------------------------------- 1 | def message_callbacks(callbacks, message): 2 | """Utility function to broadcast messages to a list of callbacks. 
3 | 4 | callbacks : list of callable 5 | List of callbacks 6 | message : various 7 | Message to broadcast 8 | """ 9 | for callback in callbacks: 10 | callback(message) 11 | -------------------------------------------------------------------------------- /pip-requirements.txt: -------------------------------------------------------------------------------- 1 | colorama==0.4.3 2 | networkx==2.4 3 | mdtraj==1.9.4 4 | periodictable==1.5.2 5 | tqdm==4.47.0 6 | htmd-pdb2pqr==0.0.2 7 | pyfiglet==0.8.post1 8 | natsort==7.0.1 9 | python-dateutil==2.8.1 10 | decorator==4.4.2 11 | mendeleev==0.6.0 12 | pandas==0.25.3 13 | propka==3.2.0 14 | sqlalchemy==1.3.18 15 | pytz==2020.1 16 | numba==0.50.1 17 | prody==1.10.8 18 | moleculekit==0.3.2 19 | llvmlite==0.33.0 20 | biopython==1.77 21 | -------------------------------------------------------------------------------- /RosENet/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Changeable settings for the project. 
3 | """ 4 | from RosENet.voxelization.filter import exp_12 5 | size = 25 6 | voxelization_type = "filter" 7 | voxelization_fn = "exp_12" 8 | voxelization = ("filter", exp_12) 9 | options = f"_{voxelization_type}_{voxelization_fn}_{size}" 10 | max_epochs = 300 11 | parallel_calls = 20 12 | shuffle_buffer_size = 1000 13 | batch_size = 128 14 | prefetch_buffer_size = 10 15 | rotate = True 16 | 17 | -------------------------------------------------------------------------------- /RosENet/voxelization/__init__.py: -------------------------------------------------------------------------------- 1 | from .voxelizers import VoxelizeHTMD, VoxelizeRosetta, VoxelizeElectronegativity 2 | from RosENet.postprocessing.postprocessing import combine_maps 3 | from RosENet import settings 4 | 5 | def voxelize(pdb_object): 6 | try: 7 | VoxelizeHTMD(pdb_object, settings.size) 8 | VoxelizeRosetta(pdb_object, settings.voxelization, settings.size) 9 | VoxelizeElectronegativity(pdb_object, settings.voxelization, settings.size) 10 | combine_maps(pdb_object) 11 | except: 12 | pass 13 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | SCRIPTPATH="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" 2 | cd $SCRIPTPATH 3 | 4 | conda remove --name rosenet || true 5 | conda create --name rosenet python=3.7 --file conda-requirements.txt -y 6 | conda init bash 7 | conda activate rosenet && (yes | pip install -r pip-requirements.txt) 8 | 9 | rm -r ./htmd 10 | wget -qO- https://anaconda.org/acellera/HTMD/1.13.10/download/linux-64/htmd-1.13.10-py36_0.tar.bz2 | tar -xvj lib/python3.6/site-packages/htmd 11 | mv lib/python3.6/site-packages/htmd/ . 
12 | rm -r lib 13 | 14 | echo "Copy the pyrosetta folder in $SCRIPTPATH or add pyrosetta to PYTHONPATH" 15 | -------------------------------------------------------------------------------- /RosENet/static/flags_relax.txt: -------------------------------------------------------------------------------- 1 | -in 2 | -file 3 | -s ./$complex 4 | -extra_res_fa ./$params 5 | -packing 6 | -ex1 7 | -ex1aro 8 | -ex2 9 | -mute core.util.prof ## dont show timing info 10 | -mute core.io.database 11 | #-restore_pre_talaris_2013_behavior 12 | #-mute all 13 | #-unmute protocols.jd2.JobDistributor 14 | -constraints:cst_fa_file ./constraints 15 | -score:set_weights atom_pair_constraint 1.0 16 | -in:auto_setup_metals 17 | -in:metals_angle_constraint_multiplier 3.0 18 | -in:metals_distance_constraint_multiplier 3.0 19 | -in:ignore_waters false 20 | -ignore_zero_occupancy false 21 | -keep_input_protonation_state true 22 | -nstruct 50 23 | -overwrite 24 | -------------------------------------------------------------------------------- /clear_dataset.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | DATASET_PATH=$1 4 | 5 | find $DATASET_PATH -type f -name "metadata.json" | xargs rm 6 | find $DATASET_PATH -type f -name "*_complex*" | xargs rm 7 | find $DATASET_PATH -type f -name "*_ligand_*.*" | xargs rm 8 | find $DATASET_PATH -type f -name "*_protein_*.*" | xargs rm 9 | find $DATASET_PATH -type f -name "*_ligand.pdb" | xargs rm 10 | find $DATASET_PATH -type f -name "*_ligand.params" | xargs rm 11 | find $DATASET_PATH -type f -name "constraints" | xargs rm 12 | find $DATASET_PATH -type f -name "flags_relax.txt" | xargs rm 13 | find $DATASET_PATH -type f -name "score.sc" | xargs rm 14 | find $DATASET_PATH -type d -name "other_complexes" | xargs rm -r 15 | 16 | -------------------------------------------------------------------------------- /RosENet/static/relax.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /RosENet/preprocessing/make_protein_pdb.py: -------------------------------------------------------------------------------- 1 | from RosENet.preprocessing.step import Step 2 | from RosENet.preprocessing.minimize_rosetta import MinimizeRosetta 3 | import RosENet.constants as constants 4 | import RosENet.rosetta.rosetta as rosetta 5 | 6 | class MakeProteinPDB(metaclass=Step, requirements=[MinimizeRosetta]): 7 | @classmethod 8 | def files(cls, pdb_object): 9 | """List of files being created 10 | 11 | pdb_object : PDBObject 12 | PDB structure being handled 13 | """ 14 | return [pdb_object.minimized.protein.pdb] 15 | 16 | @classmethod 17 | def _run(cls, pdb_object): 18 | """Inner function for the preprocessing step. 
19 | 20 | pdb_object : PDBObject 21 | PDB structure being handled 22 | """ 23 | complex = pdb_object.minimized.complex.pdb.read() 24 | protein = complex.select(constants.protein_selector) 25 | pdb_object.minimized.protein.pdb.write(protein) 26 | 27 | -------------------------------------------------------------------------------- /RosENet/preprocessing/make_pdbqt.py: -------------------------------------------------------------------------------- 1 | from RosENet.preprocessing.step import Step 2 | from RosENet.preprocessing.make_protein_pdb import MakeProteinPDB 3 | from RosENet.preprocessing.make_ligand_mol2 import MakeLigandMOL2 4 | import RosENet.constants as constants 5 | import subprocess 6 | 7 | class MakePDBQT(metaclass=Step, requirements=[MakeProteinPDB, MakeLigandMOL2]): 8 | @classmethod 9 | def files(cls, pdb_object): 10 | """List of files being created 11 | 12 | pdb_object : PDBObject 13 | PDB structure being handled 14 | """ 15 | return [pdb_object.minimized.ligand.pdbqt, 16 | pdb_object.minimized.protein.pdbqt] 17 | 18 | @classmethod 19 | def _run(cls, pdb_object): 20 | """Inner function for the preprocessing step. 
21 | 22 | pdb_object : PDBObject 23 | PDB structure being handled 24 | """ 25 | subprocess.run([constants.mgl_python_path, 26 | constants.preprocess_vina_path, 27 | pdb_object.minimized.protein.pdb.path, 28 | pdb_object.minimized.ligand.mol2.path]) 29 | 30 | -------------------------------------------------------------------------------- /RosENet/preprocessing/make_ligand_params_pdb.py: -------------------------------------------------------------------------------- 1 | from RosENet.preprocessing.step import Step 2 | import RosENet.rosetta.rosetta as rosetta 3 | 4 | class MakeLigandParamsPDB(metaclass=Step): 5 | """Preprocessing step that create the ligand.params and ligand.pdb files at 6 | the beginning of the pipeline.""" 7 | 8 | @classmethod 9 | def files(cls, pdb_object): 10 | """List of files being created 11 | 12 | pdb_object : PDBObject 13 | PDB structure being handled 14 | """ 15 | return [pdb_object.ligand.params, 16 | pdb_object.ligand.pdb] 17 | 18 | @classmethod 19 | def _run(cls, pdb_object): 20 | """Inner function for the preprocessing step. 
21 | 22 | pdb_object : PDBObject 23 | PDB structure being handled 24 | """ 25 | ligand_mol2_path = pdb_object.ligand.mol2.path 26 | params_filename = ligand_mol2_path.stem 27 | working_directory = ligand_mol2_path.parent 28 | return rosetta.molfile_to_params( 29 | working_directory = working_directory, 30 | output_path = params_filename, 31 | input_path = ligand_mol2_path) 32 | 33 | -------------------------------------------------------------------------------- /RosENet/preprocessing/make_ligand_mol2.py: -------------------------------------------------------------------------------- 1 | from RosENet.preprocessing.step import Step 2 | import RosENet.rosetta.rosetta as rosetta 3 | from RosENet.preprocessing.make_ligand_mol2_renamed import MakeLigandMOL2Renamed 4 | from RosENet.preprocessing.minimize_rosetta import MinimizeRosetta 5 | 6 | class MakeLigandMOL2(metaclass=Step,requirements=[MakeLigandMOL2Renamed, MinimizeRosetta]): 7 | @classmethod 8 | def files(cls, pdb_object): 9 | """List of files being created 10 | 11 | pdb_object : PDBObject 12 | PDB structure being handled 13 | """ 14 | return [pdb_object.minimized.ligand.mol2] 15 | 16 | 17 | @classmethod 18 | def _run(cls, pdb_object): 19 | """Inner function for the preprocessing step. 
20 | 21 | pdb_object : PDBObject 22 | PDB structure being handled 23 | """ 24 | complex_path = pdb_object.minimized.complex.pdb.path 25 | renamed_mol2_path = pdb_object.ligand.renamed_mol2.path 26 | ligand_path = pdb_object.minimized.ligand.mol2.path 27 | rosetta.pdb_to_molfile( 28 | mol = renamed_mol2_path, 29 | complex_path = complex_path, 30 | output_path = ligand_path) 31 | -------------------------------------------------------------------------------- /RosENet/voxelization/apbs.in: -------------------------------------------------------------------------------- 1 | # READ IN MOLECULES 2 | read 3 | mol pqr XXX.pqr 4 | end 5 | 6 | 7 | elec # Electrostatics calculation on the solvated state 8 | mg-auto # Specify the mode for APBS to run 9 | dime GRID_DIM # The grid dimensions 10 | grid GRID_SPACE # Grid spacing 11 | gcent INH_CENTER # Center the grid 12 | cglen CG_LEN 13 | cgcent INH_CENTER 14 | fglen FG_LEN 15 | fgcent INH_CENTER 16 | mol 1 # Perform the calculation on molecule 1 17 | lpbe # Solve the linearized Poisson-Boltzmann 18 | # equation 19 | bcfl mdh # Use all multipole moments when calculating the 20 | # potential 21 | ion 1 0.150 2.0 22 | ion -1 0.150 2.0 23 | pdie 1.0 # Solute dielectric 24 | sdie 78.54 # Solvent dielectric 25 | chgm spl2 # Spline-based discretization of the delta 26 | # functions 27 | srfm smol # Molecular surface definition 28 | srad 1.4 # Solvent probe radius (for molecular surface) 29 | swin 0.3 # Solvent surface spline window (not used here) 30 | sdens 10.0 # Sphere density of accessibility object 31 | temp 298.15 # Temperature 32 | calcenergy no # Calculate energies 33 | calcforce no # Do not calculate forces 34 | write pot dx XXX_potential # Write out the potential 35 | end 36 | quit 37 | -------------------------------------------------------------------------------- /RosENet/voxelization/interpolation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from 
scipy.interpolate import Rbf 3 | 4 | def voxel_interpolation(int_type, structure, targets): 5 | """Method to apply an interpolation method to a set of atom coordinates with 6 | assigned values, and distribute them to a set of target points. 7 | 8 | Parameters 9 | ---------- 10 | int_type : str or callable 11 | Interpolation name or function handle following scipy.interpolate.Rbf. 12 | points : numpy.ndarray 13 | Atom coordinates. 14 | values : numpy.ndarray 15 | Atom values to spacially distribute. 16 | targets : numpy.ndarray 17 | 3D voxel positions to compute the distributed values at. 18 | """ 19 | values = structure.values 20 | points = structure.coordinates 21 | mask = np.linalg.norm(points, axis=-1) <= 12.5*np.sqrt(3) 22 | points = points[mask] 23 | values = values[mask, :] 24 | shape = targets.shape 25 | targets = targets.reshape((-1,3)) 26 | points_x, points_y, points_z = [c.flatten() 27 | for c in np.split(points, 3, axis=1)] 28 | targets_x, targets_y, targets_z = [ 29 | c.flatten() for c in np.split(targets, 3, axis=1)] 30 | res = np.stack([Rbf(points_x, points_y, points_z, 31 | values[..., i], function=int_type)(targets_x, targets_y, targets_z) 32 | for i in range(values.shape[-1])], axis=-1) 33 | res = res.reshape(targets.shape[:-1] + (-1,)) 34 | return res 35 | -------------------------------------------------------------------------------- /RosENet/objects/dataset.py: -------------------------------------------------------------------------------- 1 | import RosENet.storage.storage as storage 2 | from RosENet.objects.pdb import PDBObject 3 | from RosENet.objects.file import TFRecordsFile, File 4 | import RosENet.settings as settings 5 | 6 | class _Dataset: 7 | """Inner dataset class. 
Represents a dataset (which stores PDB structure folders inside).""" 8 | _instance_dict = {} 9 | def __init__(self, path): 10 | self.path = path 11 | #self.metadata = storage.read_json(self.path / "metadata.json") 12 | 13 | def __getitem__(self, key): 14 | return PDBObject(self.path / key) 15 | 16 | def list(self): 17 | return [x.name for x in self.path.iterdir() if x.is_dir() and x.name[0] != "_"] 18 | 19 | @property 20 | def name(self): 21 | return f'_{self.path.stem}_{settings.voxelization_type}_{settings.voxelization_fn}_{settings.size}' 22 | 23 | @property 24 | def tfrecords(self): 25 | return self.path / self.name 26 | 27 | def tfrecord(self, i): 28 | return TFRecordsFile(f"{i}.tfrecords",self.tfrecords) 29 | 30 | @property 31 | def images(self): 32 | return [x for x in (self.path/settings.options).iterdir()] 33 | 34 | @property 35 | def labels(self): 36 | return File("labels", self.path) 37 | 38 | def model(self, model_object, channels, seed): 39 | return self.path / f"_{model_object.name}_{channels}_{seed}" 40 | 41 | def DatasetObject(path): 42 | if str(path.absolute()) not in _Dataset._instance_dict: 43 | _Dataset._instance_dict[str(path.absolute())] = _Dataset(path.absolute()) 44 | return _Dataset._instance_dict[str(path.absolute())] 45 | 46 | -------------------------------------------------------------------------------- /conda-requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _tflow_select=2.3.0=mkl 6 | absl-py=0.9.0=py37_0 7 | astor=0.8.0=py37_0 8 | biopython=1.77=py37h7b6447c_0 9 | blas=1.0=mkl 10 | c-ares=1.15.0=h7b6447c_1001 11 | ca-certificates=2020.6.24=0 12 | certifi=2020.6.20=py37_0 13 | gast=0.3.3=py_0 14 | grpcio=1.27.2=py37hf8bcb03_0 15 | h5py=2.10.0=py37hd6299e0_1 16 | hdf5=1.10.6=hb1b8bf9_0 17 | intel-openmp=2020.1=217 18 | 
keras-applications=1.0.8=py_0 19 | keras-preprocessing=1.1.0=py_1 20 | ld_impl_linux-64=2.33.1=h53a641e_7 21 | libedit=3.1.20191231=h7b6447c_0 22 | libffi=3.3=he6710b0_1 23 | libgcc-ng=9.1.0=hdf63c60_0 24 | libgfortran-ng=7.3.0=hdf63c60_0 25 | libprotobuf=3.12.3=hd408876_0 26 | libstdcxx-ng=9.1.0=hdf63c60_0 27 | markdown=3.1.1=py37_0 28 | mkl=2020.1=217 29 | mkl-service=2.3.0=py37he904b0f_0 30 | mkl_fft=1.1.0=py37h23d657b_0 31 | mkl_random=1.1.1=py37h0573a6f_0 32 | mock=4.0.2=py_0 33 | ncurses=6.2=he6710b0_1 34 | numpy=1.18.5=py37ha1c710e_0 35 | numpy-base=1.18.5=py37hde5b4d6_0 36 | openssl=1.1.1g=h7b6447c_0 37 | pip=20.1.1=py37_1 38 | protobuf=3.12.3=py37he6710b0_0 39 | pyparsing=2.4.7=py_0 40 | python=3.7.7=hcff3b4d_5 41 | readline=8.0=h7b6447c_0 42 | scipy=1.5.0=py37h0b6359f_0 43 | setuptools=47.3.1=py37_0 44 | six=1.15.0=py_0 45 | sqlite=3.32.3=h62c20be_0 46 | tensorboard=1.13.1=py37hf484d3e_0 47 | tensorflow=1.13.1=mkl_py37h54b294f_0 48 | tensorflow-base=1.13.1=mkl_py37h7ce6ba3_0 49 | tensorflow-estimator=1.13.0=py_0 50 | termcolor=1.1.0=py37_1 51 | tk=8.6.10=hbc83047_0 52 | werkzeug=1.0.1=py_0 53 | wheel=0.34.2=py37_0 54 | xz=5.2.5=h7b6447c_0 55 | zlib=1.2.11=h7b6447c_3 56 | -------------------------------------------------------------------------------- /RosENet/static/dock.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /RosENet/rosetta/rosetta.py: -------------------------------------------------------------------------------- 1 | """ Module wrapping the function calls to the Rosetta Commons software 2 | """ 3 | 4 | import subprocess 5 | from collections import OrderedDict 6 | import RosENet.constants as constants 7 | from 
io import StringIO 8 | import os 9 | import pandas as pd 10 | 11 | 12 | def molfile_to_params(input_path, output_path, working_directory): 13 | new_env = os.environ.copy() 14 | # This should solve the relative import problem in Rosetta's script 15 | new_env["PYTHONPATH"] = f"{constants.rosetta.root}/main/source/scripts/python/public:{new_env.get('PYTHONPATH','')}" 16 | subprocess.run(["python2", 17 | constants.rosetta.molfile_to_params, 18 | "-n", constants.ligand_resname, 19 | "-p", output_path, 20 | "--conformers-in-one-file", 21 | "--keep-names", 22 | "--clobber", 23 | input_path], 24 | env=new_env, 25 | cwd=str(working_directory)) 26 | 27 | def minimize(working_directory): 28 | subprocess.run([constants.rosetta.minimize, 29 | f"@{constants.flags_filename}", 30 | f"-parser:protocol", 31 | f'"{constants.relax_path}"', 32 | "-database", constants.rosetta.database], 33 | cwd=str(working_directory)) 34 | 35 | def pdb_to_molfile(mol, complex_path, output_path): 36 | new_env = os.environ.copy() 37 | new_env["PYTHONPATH"] = f"{constants.rosetta.root}/main/source/scripts/python/public:{new_env.get('PYTHONPATH','')}" 38 | subprocess.run(" ".join(["python2", 39 | str(constants.rosetta.pdb_to_molfile), 40 | str(mol), 41 | str(complex_path), 42 | ">", 43 | str(output_path)]), 44 | cwd=str(constants.rosetta.py_wd), 45 | env=new_env, 46 | shell=True) 47 | 48 | def parse_scores(scores_text): 49 | """Parse a score.sc file to a dictionary of id to score value ordered in increasing order. 
50 | 51 | scores_text : string 52 | Content of the score.sc file 53 | """ 54 | scores = StringIO(scores_text) 55 | csv = pd.read_csv(scores, header=1, sep=r"\s+", usecols=['total_score', 'description']) 56 | csv.sort_values("total_score", inplace=True) 57 | return OrderedDict(zip([name.split("_")[-1] for name in csv.description], csv.total_score)) 58 | 59 | -------------------------------------------------------------------------------- /RosENet/models/kdeep.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | LEARNING_RATE = 0.0001 4 | 5 | def fire_module(net, squeeze, expand, training): 6 | net = tf.layers.conv3d(net, squeeze, [1,1,1], activation=tf.nn.relu) 7 | net1 = tf.layers.conv3d(net, expand, [1,1,1], activation=tf.nn.relu) 8 | net2 = tf.layers.conv3d(net, expand, [3,3,3], padding='same', activation=tf.nn.relu) 9 | return tf.concat(axis=-1, values=[net1,net2]) 10 | 11 | def conv_net(X, reuse, training): 12 | with tf.variable_scope('SqueezeNet', reuse=reuse): 13 | net = tf.layers.conv3d(X, 96, 1, 2, padding='same', activation=tf.nn.relu) 14 | net = fire_module(net, 16, 64,training=training) 15 | net = fire_module(net, 16, 64,training=training) 16 | net = fire_module(net, 32, 128,training=training) 17 | net = tf.layers.max_pooling3d(net, 3, 2) 18 | net = fire_module(net, 32, 128,training=training) 19 | net = fire_module(net, 48, 192,training=training) 20 | net = fire_module(net, 48, 192,training=training) 21 | net = fire_module(net, 64, 256,training=training) 22 | net = tf.layers.average_pooling3d(net, 3, 2) 23 | net = tf.layers.flatten(net) 24 | net = tf.layers.dense(net, 1) 25 | return net 26 | 27 | def model_fn(features, labels, mode, rotate): 28 | training = mode == tf.estimator.ModeKeys.TRAIN 29 | predictions = conv_net(features, reuse=tf.AUTO_REUSE, training=training) 30 | loss = None 31 | train_op = None 32 | if training: 33 | loss = tf.losses.mean_squared_error(labels=labels, 
predictions=predictions) 34 | optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE) 35 | train_op = optimizer.minimize(loss,global_step=tf.train.get_global_step()) 36 | else: 37 | if rotate: 38 | #When evaluating or predicting, average over the 24 90º rotations if rotate is enabled. 39 | predictions = tf.reduce_mean(tf.reshape(predictions,(-1,24,1)),axis=1) 40 | loss = tf.losses.mean_squared_error(labels,predictions) 41 | else: 42 | loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions) 43 | 44 | return tf.estimator.EstimatorSpec( 45 | mode=mode, 46 | predictions=predictions, 47 | loss=loss, 48 | train_op=train_op) 49 | -------------------------------------------------------------------------------- /RosENet/storage/storage.py: -------------------------------------------------------------------------------- 1 | """ 2 | Storage module to centralize and abstract most of the IO operations in the project. 3 | """ 4 | 5 | from prody import parsePDB, writePDB 6 | import shutil 7 | import h5py 8 | import json 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | def delete(path): 13 | if path.exists(): 14 | if path.is_dir(): 15 | shutil.rmtree(path) 16 | else: 17 | path.unlink() 18 | 19 | def clear_directory(path, no_fail=True): 20 | if path.exists() and not no_fail: 21 | shutil.rmtree(path) 22 | make_directory(path) 23 | 24 | def make_directory(path, no_fail=True): 25 | path.mkdir(exist_ok = no_fail, parents=True) 26 | 27 | def move(origin, destination, no_fail=False): 28 | if origin.exists() or not no_fail: 29 | if destination.is_dir(): 30 | destination = destination / origin.name 31 | origin.rename(destination) 32 | 33 | def read_image(image_path): 34 | with h5py.File(str(image_path), "r") as f: 35 | return np.array(f["grid"]) 36 | 37 | def write_image(image_path, image): 38 | with h5py.File(str(image_path), "w", libver='latest') as f: 39 | f.create_dataset("grid", dtype='f4', data=image) 40 | 41 | def read_json(json_path): 42 | 
with json_path.open("r") as f: 43 | data = json.load(f) 44 | return data 45 | 46 | def write_json(json_path, data): 47 | with json_path.open("w") as f: 48 | json.dump(data, f) 49 | 50 | def read_pdb(pdb_path): 51 | return parsePDB(str(pdb_path)) 52 | 53 | def write_pdb(pdb_path, pdb): 54 | return writePDB(str(pdb_path), pdb) 55 | 56 | def write_tfrecords(path, data): 57 | with tf.python_io.TFRecordWriter(str(path)) as writer: 58 | for datum in data: 59 | writer.write(datum) 60 | 61 | def read_attributes(attr_path): 62 | return np.load(str(attr_path).strip(".npz")+".npz") 63 | 64 | def write_attributes(attr_path, attributes): 65 | return np.savez(str(attr_path), attributes) 66 | 67 | def read_plain(file_path): 68 | with open(file_path, "r") as f: 69 | data = f.read() 70 | return data 71 | 72 | def write_plain(file_path, text): 73 | with open(file_path, "w") as f: 74 | if isinstance(text, str): 75 | f.write(text) 76 | elif isinstance(text, list): 77 | for line in text: 78 | f.write(line.strip("\n")+"\n") 79 | else: 80 | f.write(str(text)) 81 | -------------------------------------------------------------------------------- /RosENet/preprocessing/make_ligand_mol2_renamed.py: -------------------------------------------------------------------------------- 1 | import RosENet.storage.storage as storage 2 | from RosENet.preprocessing.step import Step 3 | from RosENet.preprocessing.make_ligand_params_pdb import MakeLigandParamsPDB 4 | 5 | def read_pdb(pdb_path): 6 | """Read .pdb file and relate atom numbers to their names 7 | 8 | pdb_path : pathlib.Path 9 | Path of the pdb file 10 | """ 11 | lines = storage.read_plain(pdb_path).splitlines() 12 | hetatm = filter(lambda x: x.startswith('HETATM'), lines) 13 | atom_num_name = map(lambda x: x.split()[1:3], hetatm) 14 | return dict(atom_num_name) 15 | 16 | def read_mol2(mol2_path, name_map): 17 | """Read .mol2 file and change names to the original .pdb names 18 | 19 | mol2_path : pathlib.Path 20 | Path to the mol2 file 21 | 
name_map : dict 22 | Atom number to name dictionary 23 | """ 24 | lines = storage.read_plain(mol2_path).splitlines() 25 | mode = "search" 26 | for i, line in enumerate(lines): 27 | if mode == "search": 28 | if line.startswith('@ATOM'): 29 | mode = "rename" 30 | elif mode == "rename": 31 | if line.startswith('@BOND'): 32 | mode = "end" 33 | else: 34 | atom_num, atom_name = line.split()[0:2] 35 | new_name = name_map[atom_num].ljust(len(atom_name)) 36 | position = line.find(atom_name) 37 | end = position + len(new_name) 38 | new_line = line[0:position] + new_name + line[end:] 39 | lines[i] = new_line 40 | elif mode == "end": 41 | return lines 42 | 43 | class MakeLigandMOL2Renamed(metaclass=Step,requirements=[MakeLigandParamsPDB]): 44 | @classmethod 45 | def files(cls, pdb_object): 46 | """List of files being created 47 | 48 | pdb_object : PDBObject 49 | PDB structure being handled 50 | """ 51 | return [pdb_object.ligand.renamed_mol2] 52 | 53 | @classmethod 54 | def _run(cls, pdb_object): 55 | """Inner function for the preprocessing step. 
56 | 57 | pdb_object : PDBObject 58 | PDB structure being handled 59 | """ 60 | ligand_pdb_path = pdb_object.ligand.pdb.path 61 | ligand_mol2_path = pdb_object.ligand.mol2.path 62 | output_path = pdb_object.ligand.renamed_mol2.path 63 | name_map = read_pdb(ligand_pdb_path) 64 | storage.write_plain(output_path, read_mol2(ligand_mol2_path, name_map)) 65 | -------------------------------------------------------------------------------- /RosENet/models/large_kdeep.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | LEARNING_RATE = 0.0001 4 | 5 | def fire_module(net, squeeze, expand, training): 6 | net = tf.layers.conv3d(net, squeeze, [1,1,1], activation=tf.nn.relu) 7 | net1 = tf.layers.conv3d(net, expand, [1,1,1], activation=tf.nn.relu) 8 | net2 = tf.layers.conv3d(net, expand, [3,3,3], padding='same', activation=tf.nn.relu) 9 | return tf.concat(axis=-1, values=[net1,net2]) 10 | 11 | def conv_net(X, reuse, training): 12 | with tf.variable_scope('SqueezeNet', reuse=reuse): 13 | net = tf.layers.conv3d(X, 96, 7, 2, padding='same', activation=tf.nn.relu) 14 | net = fire_module(net, 16, 64,training=training) 15 | net = fire_module(net, 16, 64,training=training) 16 | net = fire_module(net, 32, 128,training=training) 17 | net = tf.layers.max_pooling3d(net, 3, 2) 18 | net = fire_module(net, 32, 128,training=training) 19 | net = fire_module(net, 48, 192,training=training) 20 | net = fire_module(net, 48, 192,training=training) 21 | net = fire_module(net, 64, 256,training=training) 22 | net = tf.layers.average_pooling3d(net, 3, 2) 23 | net = tf.layers.flatten(net) 24 | net = tf.layers.dense(net, 1) 25 | return net 26 | 27 | def model_fn(features, labels, mode, rotate): 28 | training = mode == tf.estimator.ModeKeys.TRAIN 29 | predictions = conv_net(features, reuse=tf.AUTO_REUSE, training=training) 30 | 31 | loss = None 32 | train_op = None 33 | if training: 34 | loss = tf.losses.mean_squared_error(labels=labels, 
predictions=predictions) 35 | optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE) 36 | train_op = optimizer.minimize(loss,global_step=tf.train.get_global_step()) 37 | else: 38 | #When evaluating or predicting, average over the 24 90º rotations if rotations are enabled. 39 | if rotate: 40 | predictions = tf.reduce_mean(tf.reshape(predictions,(-1,24,1)),axis=1) 41 | loss = tf.losses.mean_squared_error(labels,predictions) 42 | else: 43 | loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions) 44 | 45 | return tf.estimator.EstimatorSpec( 46 | mode=mode, 47 | predictions=predictions, 48 | loss=loss, 49 | train_op=train_op, 50 | eval_metric_ops={ 51 | 'R': tf.metrics.mean(r), 52 | 'ro_2': tf.metrics.mean(r_squared) 53 | }) 54 | -------------------------------------------------------------------------------- /RosENet/voxelization/filter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import cdist 3 | 4 | 5 | def exp_12(r, rvdw): 6 | """Exponential filter as given by Jimenez et al. (2018) 7 | 10.1021/acs.jcim.7b00650 8 | 9 | Parameters 10 | ---------- 11 | r : numpy.ndarray 12 | Array of distances between each atom and all voxels of the 3D image. 13 | rvdw : numpy.ndarray 14 | Array of Van der Waals radii for each atom. 15 | 16 | """ 17 | rvdw = rvdw.reshape((-1,)) 18 | rr = rvdw[:, None]/r 19 | ret = np.where(r == 0, 1, 1 - np.exp(-(rr)**12)) 20 | return ret 21 | 22 | 23 | def gaussian(r, rvdw): 24 | """Gaussian-like filter, constants removed as everything will be normalized. 25 | 26 | Parameters 27 | ---------- 28 | r : numpy.ndarray 29 | Array of distances between each atom and all voxels of the 3D image. 30 | rvdw : numpy.ndarray 31 | Array of Van der Waals radii for each atom. 
32 | 33 | """ 34 | rvdw = rvdw.reshape((-1,)) 35 | rr = r/rvdw[:, None] 36 | ret = np.exp(-(rr)**2) 37 | return ret 38 | 39 | 40 | def voxel_filter(filter_type, structure, targets): 41 | """Method to apply a filter to a set of atom coordinates with assigned 42 | values, and distribute them to a set of target points. Atom radii are 43 | taken into consideration to adjust strength of contributions. 44 | 45 | Parameters 46 | ---------- 47 | filter_type : callable 48 | Function handle implementing the filter to be applied. 49 | points : numpy.ndarray 50 | Atom coordinates. 51 | values : numpy.ndarray 52 | Atom values to spacially distribute. 53 | targets : numpy.ndarray 54 | 3D voxel positions to compute the distributed values at. 55 | radii : numpy.ndarray 56 | Van der Waals radii of the atoms. 57 | """ 58 | values = structure.values 59 | points = structure.coordinates 60 | radii = structure.radii 61 | shape = targets.shape 62 | targets = targets.reshape((-1,3)) 63 | mask = np.linalg.norm(points, axis=-1) <= 12.5*np.sqrt(3) 64 | points = points[mask] 65 | values = values[mask, :] 66 | radii = radii[mask] 67 | dists = cdist(points, targets) 68 | aux = np.where(dists < 5, filter_type(dists, radii), 0) 69 | del dists 70 | result = np.array([values[np.argmax(aux, axis=0), i] * np.max(aux, axis=0) 71 | for i in range(values.shape[-1])]) 72 | result = np.swapaxes(result, 0, 1) 73 | result = result.reshape(shape[:-1] + (-1,)) 74 | return result 75 | -------------------------------------------------------------------------------- /RosENet/voxelization/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def grid_around(center, size, spacing=1.0): 4 | """Generate an array of 3D positions for a voxel cube of the given size and 5 | coarseness. 6 | 7 | center : numpy.ndarray 8 | 3D coordinates of the cube's center. 9 | size : numpy.ndarray 10 | Number of voxels for each of the cube's dimensions. 
11 | spacing : float 12 | Distance between voxel centers. 13 | """ 14 | size_ang = ((size - 1) / 2.) * spacing 15 | ex_min = center - size_ang 16 | ex_max = center + size_ang 17 | 18 | x = np.linspace(ex_min[0], ex_max[0], size) 19 | y = np.linspace(ex_min[1], ex_max[1], size) 20 | z = np.linspace(ex_min[2], ex_max[2], size) 21 | return np.stack(np.meshgrid(x, y, z, indexing='ij'), axis=-1) 22 | 23 | 24 | def clip1d(value, upp_limit): 25 | """Clip a single array by a limit and zero. If the limit is negative, 26 | use it as lower bound, otherwise use it as upper bound. 27 | 28 | value : numpy.ndarray 29 | Array to be clipped. 30 | upp_limit : float 31 | Threshold to use as bound. 32 | """ 33 | low_limit = 0 34 | if upp_limit < 0: 35 | upp_limit, low_limit = low_limit, upp_limit 36 | return np.clip(value, low_limit, upp_limit) / (low_limit + upp_limit) 37 | 38 | 39 | def clip(values, limits): 40 | """Clip a list of arrays by a list of limits 41 | 42 | values : list of numpy.ndarray 43 | List of arrays to be clipped. 44 | limits : list of float 45 | List of thresholds to use as bounds. 46 | """ 47 | if values.shape[-1] == 1: 48 | return clip1d(values, limits[0]) 49 | return np.stack((clip1d(values[...,i], limits[i]) for i in range(values.shape[-1])), axis=-1) 50 | 51 | 52 | def get_keys(pdb): 53 | """Get the unique key names to the atoms of the structure. 54 | 55 | pdb : prody.AtomGroup 56 | Structure of atoms. 57 | """ 58 | resnums = pdb.getResnums() 59 | chids = pdb.getChids() 60 | names = pdb.getNames() 61 | keys = np.char.add(np.char.mod( 62 | '%s-', np.char.replace(np.char.add(np.char.mod('%s ', resnums), chids), ' ', '-')), names) 63 | return keys 64 | 65 | 66 | def save_grid(saving_path, image): 67 | """Save 3D image to HDF5 format. 68 | 69 | saving_path : str or os.PathLike 70 | Destination path to save the image. 71 | image : numpy.ndarray 72 | 3D image to be saved. 
73 | """ 74 | with h5py.File(str(saving_path), "w", libver='latest') as f: 75 | f.create_dataset("grid", dtype='f4', data=image) 76 | 77 | -------------------------------------------------------------------------------- /RosENet/static/dock_relax.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /RosENet/static/dock_relax2.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /RosENet/models/resnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | LEARNING_RATE = 0.0001 4 | 5 | 6 | def start_block(net, channels, training): 7 | input_net = tf.layers.conv3d(net, 4*channels, 1, 2, padding='same') 8 | net = conv3bn(net, channels, 1, training,stride=2) 9 | net = conv3bn(net, channels, 3, training) 10 | net = conv3bn(net, 4*channels, 1, training) 11 | output_net = net + input_net 12 | return output_net 13 | 14 | def inner_block(net, channels, training): 15 | input_net = tf.layers.conv3d(net, 4*channels, 1, padding='same') 16 | net = conv3bn(net, channels, 1, training) 17 | net = conv3bn(net, channels, 3, training) 18 | net = conv3bn(net, 4*channels, 1, training) 19 | output_net = net + input_net 20 | return output_net 21 | 22 | 23 | def conv3bn(net, channels, filt, training, stride=1): 24 | net = tf.nn.relu(net) 25 | net = tf.layers.conv3d(net, 
channels, filt, stride, padding='same') 26 | return net 27 | 28 | def conv_net(X, reuse, training): 29 | with tf.variable_scope('ResNet', reuse=reuse): 30 | layers = [3,4,23,3] 31 | k = 64 32 | net = tf.layers.conv3d(X, k, 7, 2, padding='same') 33 | net = tf.nn.relu(net) 34 | net = tf.layers.max_pooling3d(net, 3, 2) 35 | for i in range(0,layers[0]): 36 | net = inner_block(net, k, training) 37 | for i, l in enumerate(layers[1:],1): 38 | net = start_block(net, k*(2**i), training) 39 | for j in range(0,l-1): 40 | net = inner_block(net, k*(2**i), training) 41 | net = tf.reduce_mean(net, axis=(1,2,3)) 42 | net = tf.layers.flatten(net) 43 | net = tf.layers.dense(net, 1) 44 | return net 45 | 46 | def model_fn(features, labels, mode, rotate): 47 | training = mode == tf.estimator.ModeKeys.TRAIN 48 | predictions = conv_net(features, reuse=tf.AUTO_REUSE, training=training) 49 | 50 | loss = None 51 | train_op = None 52 | if training: 53 | loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions) 54 | optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE) 55 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 56 | with tf.control_dependencies(update_ops): 57 | train_op = optimizer.minimize(loss,global_step=tf.train.get_global_step()) 58 | else: 59 | if rotate: 60 | #When evaluating or predicting, average over the 24 90º rotations if rotate is enabled. 61 | predictions = tf.reduce_mean(tf.reshape(predictions,(-1,24,1)), axis=1) 62 | loss = tf.losses.mean_squared_error(labels, predictions) 63 | else: 64 | loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions) 65 | 66 | return tf.estimator.EstimatorSpec( 67 | mode=mode, 68 | predictions=predictions, 69 | loss=loss, 70 | train_op=train_op) 71 | -------------------------------------------------------------------------------- /RosENet/constants.py: -------------------------------------------------------------------------------- 1 | """Constants file. 
Here one only needs to change the location of the different 2 | tools used by the project. 3 | """ 4 | 5 | import os 6 | from types import SimpleNamespace 7 | from pathlib import Path 8 | 9 | source_path = os.path.dirname(os.path.abspath(__file__)) 10 | flags_filename = "flags_relax.txt" 11 | flags_relax_path = os.path.join(source_path, "static", flags_filename) 12 | relax_filename = "dock_relax.xml" 13 | relax_path = os.path.join(source_path, "static", relax_filename) 14 | metal_selector = "chain Z" 15 | ligand_resname = "WER" 16 | close_ligand_selector = f"resname {ligand_resname} and (not (element H or element C)) and within 2.5 of t" 17 | protein_selector = f"not resname {ligand_resname}" 18 | mgl_python_path = "/home/hussein/Repositories/Own/mgltools_x86_64Linux2_1.5.6/bin/pythonsh" 19 | preprocess_vina_path = os.path.join(source_path, "preprocessing", "preprocess_vina.py") 20 | nonstd2stdresidues = {'HOH':'WAT', 21 | 'CYX':'CYS', 22 | 'CYM':'CYS', 23 | 'HIE':'HIS', 24 | 'HID':'HIS', 25 | 'HSD':'HIS', 26 | 'HIP':'HIS', 27 | 'HIY':'HIS', 28 | 'ALB':'ALA', 29 | 'ASM':'ASN', 30 | 'DIC':'ASP', 31 | 'GLV':'GLU', 32 | 'GLO':'GLN', 33 | 'HIZ':'HIS', 34 | 'LEV':'LEU', 35 | 'SEM':'SER', 36 | 'TYM':'TYR', 37 | 'TRQ':'TRP', 38 | 'KCX':'LYS', 39 | 'LLP':'LYS', 40 | 'ARN':'ARG', 41 | 'ASH':'ASP', 42 | 'DID':'ASP', 43 | 'ASZ':'ASP', 44 | 'CYT':'CYS', 45 | 'GLH':'GLU', 46 | 'LYN':'LYS', 47 | 'AR0':'ARG', 48 | 'PCA':'GLU', 49 | 'HSE':'SER'} 50 | ligand_chid = "X" 51 | water_residue = "WAT" 52 | water_chain = "W" 53 | accepted_metals = ["MN", "MG", "ZN", "CA", "NA"] 54 | metal_chain = "Z" 55 | ligand_selector = f"not protein and same residue as ((resname {water_residue} and within 3 of resname {ligand_resname} and within 3 of protein) or (resname {' '.join(accepted_metals)} and within 5 of resname {ligand_resname}) or resname {ligand_resname})" 56 | 57 | class classproperty(object): 58 | def __init__(self, getter): 59 | self.getter = getter 60 | def __get__(self, instance, 
owner): 61 | return self.getter(owner) 62 | 63 | 64 | class rosetta: 65 | root = Path("/home/hussein/Repositories/Own/rosetta") 66 | 67 | @classproperty 68 | def molfile_to_params(cls): 69 | return cls.root / "main/source/scripts/python/public/molfile_to_params.py" 70 | 71 | @classproperty 72 | def minimize(cls): 73 | return cls.root / "main/source/bin/rosetta_scripts.static.linuxgccrelease" 74 | 75 | @classproperty 76 | def pdb_to_molfile(cls): 77 | return cls.root / "main/source/src/apps/public/ligand_docking/pdb_to_molfile.py" 78 | 79 | @classproperty 80 | def database(cls): 81 | return cls.root / "main/database/" 82 | 83 | @classproperty 84 | def py_wd(cls): 85 | return cls.root / "main/source/scripts/python/public/" 86 | -------------------------------------------------------------------------------- /RosENet/objects/file.py: -------------------------------------------------------------------------------- 1 | from RosENet.storage.storage import * 2 | from RosENet.objects.pdb import PDBObject 3 | import RosENet.storage.storage as storage 4 | import RosENet.rosetta.rosetta as rosetta 5 | 6 | class File: 7 | """Base class for file management. Represents any file, which name may be 8 | generic. 
It allows for read/write access and templates the name of the file 9 | according to the PDB code(root folder name) and the number of the file (for 10 | when it represents the result of the Rosetta minimization)""" 11 | def __init__(self, name, root): 12 | self.root = root 13 | self.code = root.name 14 | self.name = name.format(code=self.code, number="{number}") 15 | if "{number}" in name: 16 | self.multiple = True 17 | else: 18 | self.multiple = None 19 | 20 | @property 21 | def path(self): 22 | return self.resolve_path() 23 | 24 | def resolve_path(self, number=None): 25 | if number: 26 | return self.root / self.name.format(number=number) 27 | else: 28 | if self.multiple is True: 29 | scores = rosetta.parse_scores(PDBObject(self.root).minimized.scores.read()) 30 | number = (list(scores.keys()))[0].split("_")[-1] 31 | self.multiple = self.name.format(number=number) 32 | return self.root / self.multiple 33 | elif self.multiple: 34 | return self.root / self.multiple 35 | else: 36 | return self.root / self.name 37 | 38 | def read(self): 39 | return storage.read_plain(self.path) 40 | 41 | def write(self, data): 42 | return storage.write_plain(self.path, data) 43 | 44 | def delete(self): 45 | return storage.delete(self.path) 46 | 47 | def __getitem__(self, key): 48 | if self.multiple: 49 | return File.create(self.name.format(number=key), self.root) 50 | else: 51 | return self 52 | 53 | @staticmethod 54 | def create(path, root): 55 | if ".pdbqt" in path: 56 | return PDBQTFile(path, root) 57 | elif ".pdb" in path: 58 | return PDBFile(path, root) 59 | elif ".img" in path: 60 | return ImageFile(path, root) 61 | elif ".attr" in path: 62 | return AttributeFile(path, root) 63 | elif ".tfrecords" in path: 64 | return TFRecordsFile(path, root) 65 | else: 66 | return File(path, root) 67 | 68 | 69 | class PDBFile(File): 70 | def read(self): 71 | return read_pdb(self.path) 72 | 73 | def write(self, data): 74 | write_pdb(self.path, data) 75 | 76 | 77 | class TFRecordsFile(File): 78 | 
def read(self): 79 | return read_tfrecords(self.path) 80 | 81 | def write(self, data): 82 | write_tfrecords(self.path, data) 83 | 84 | class ImageFile(File): 85 | def read(self): 86 | return read_image(self.path) 87 | 88 | def write(self, data): 89 | write_image(self.path, data) 90 | 91 | class AttributeFile(File): 92 | def read(self): 93 | return read_attributes(self.path) 94 | 95 | def write(self, data): 96 | write_attributes(self.path, data) 97 | 98 | class PDBQTFile(File): 99 | def read(self): 100 | return read_pdb(self.resolve_path()) 101 | 102 | def write(self, data): 103 | write_pdb(self.resolve_path(), data) 104 | -------------------------------------------------------------------------------- /RosENet/postprocessing/postprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from RosENet.storage import storage 4 | 5 | def combine_maps(pdb_object): 6 | """Postprocessing step for combining the feature map images into one image 7 | 8 | Parameters 9 | ---------- 10 | pdb_object : PDBObject 11 | PDB object which images will be combined 12 | """ 13 | features = [pdb_object.image.htmd.read(), 14 | pdb_object.image.electronegativity.read(), 15 | pdb_object.image.rosetta.read()] 16 | grid = np.concatenate(features, axis=-1) 17 | storage.make_directory(pdb_object.image.combined.path.parent) 18 | pdb_object.image.combined.write(grid) 19 | 20 | 21 | def serialize_file(file, target, type): 22 | """Serialize image file to Tensorflow serialized format 23 | 24 | Parameters 25 | ---------- 26 | file : pathlib.Path 27 | Path for the image file. 28 | target : float 29 | Ground truth binding affinity for the image. 
30 | """ 31 | datapoint = storage.read_image(file) 32 | features = datapoint.flatten() 33 | label = bytes(file.stem,"utf-8") 34 | type = 0 if type == "Kd" else 1 35 | example = tf.train.Example(features=tf.train.Features(feature={ 36 | 'id' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])), 37 | 'X' : tf.train.Feature(float_list=tf.train.FloatList(value=features)), 38 | 'y' : tf.train.Feature(float_list=tf.train.FloatList(value=np.array([target]))), 39 | 'type' : tf.train.Feature(int64_list=tf.train.Int64List(value=[type])) 40 | })) 41 | return example.SerializeToString() 42 | 43 | def write_tfrecords(files, dataset_object, number, labels): 44 | """Serialize the given images and store them in a TFRecords file, together with their binding affinity values. 45 | 46 | files : list of pathlib.Path 47 | List of files to be written inside the TFRecords file. 48 | dataset_object : DatasetObject 49 | Dataset of the images. 50 | number : int 51 | ID number of the TFRecords file. 52 | labels : dict 53 | Dictionary relating image names to their binding affinities. 54 | """ 55 | output = [serialize_file(file, float(labels[file.stem][0]), "Kd") for file in files] 56 | dataset_object.tfrecord(number).write(output) 57 | 58 | def chunk_by_size(files, recommended_tf_size=float(100*(2**20))): 59 | """Split the images into chunks according to the TFRecords recommended size, which is 100MB. 60 | 61 | files : list of pathlib.Path 62 | List of image paths to be split. 63 | recommended_tf_size : float 64 | Maximum size of each chunk, 100MB by default. 65 | """ 66 | average_size = sum(map(lambda x: x.stat().st_size, files)) / float(len(files)) 67 | files_per_chunk = np.ceil(recommended_tf_size / average_size) 68 | chunks = int(np.ceil(float(len(files)) / files_per_chunk)) 69 | return np.array_split(np.array(files), chunks) 70 | 71 | def generate_tfrecords(dataset_object): 72 | """Generate TFRecords from a dataset's combined images. 
73 | 74 | dataset_object : DatasetObject 75 | Dataset of the images. 76 | """ 77 | files = dataset_object.images 78 | chunks= chunk_by_size(files) 79 | storage.clear_directory(dataset_object.tfrecords, no_fail=True) 80 | storage.make_directory(dataset_object.tfrecords, no_fail=True) 81 | lines = dataset_object.labels.read().splitlines() 82 | lines = [line.split(" ") for line in lines] 83 | pdb_labels = dict([(line[0],line[1:]) for line in lines]) 84 | for i, chunk in enumerate(chunks): 85 | write_tfrecords(chunk, dataset_object, i, pdb_labels) 86 | 87 | -------------------------------------------------------------------------------- /RosENet/preprocessing/step.py: -------------------------------------------------------------------------------- 1 | import RosENet.utils as utils 2 | 3 | class Step(type): 4 | """Metaclass for preprocessing steps. Implements functions to run the steps 5 | according to their dependencies, clean the outputed data and check for 6 | readiness of execution.""" 7 | 8 | def __new__(cls, name, bases, dct, requirements=[]): 9 | x = super(Step, cls).__new__(cls,name, bases, dct) 10 | x.requirements = requirements 11 | x.successors = [] 12 | x.name = name 13 | for requirement in requirements: 14 | requirement.successors.append(x) 15 | return x 16 | 17 | 18 | def clean(cls, pdb_object): 19 | """Delete files and folders created by the step and its successor steps. 20 | 21 | pdb_object : PDBObject 22 | PDB structure to be handled 23 | """ 24 | if not cls.computed(pdb_object): 25 | return 26 | for successor in cls.successors: 27 | successor.clean(pdb_object) 28 | pdb_object.uncomplete(cls.name) 29 | for file in cls.files(pdb_object): 30 | file.delete() 31 | 32 | def run(cls, pdb_object, callbacks=[], force=False, surrogate=False): 33 | """Run the preprocessing step for the given PDB structure. 
34 | 35 | pdb_object : PDBObject 36 | PDB structure to be handled 37 | callbacks : list of callable 38 | List of callbacks for logging 39 | force : bool 40 | Force execution even if the step was already computed? Default False 41 | surrogate : bool 42 | Run step as consequence of another step. Default False 43 | """ 44 | if cls.computed(pdb_object) and not force: 45 | if not surrogate: 46 | utils.message_callbacks(callbacks, (cls.name, pdb_object.id, "already_done")) 47 | return True 48 | elif not cls.ready(pdb_object): 49 | if not surrogate: 50 | utils.message_callbacks(callbacks, (cls.name, pdb_object.id, "not_ready")) 51 | return False 52 | utils.message_callbacks(callbacks, (cls.name, pdb_object.id, "start")) 53 | try: 54 | cls._run(pdb_object) 55 | except Exception as e: 56 | utils.message_callbacks(callbacks, (cls.name, pdb_object.id, "error", e)) 57 | return False 58 | pdb_object.complete(cls.name) 59 | utils.message_callbacks(callbacks, (cls.name, pdb_object.id, "end")) 60 | return True 61 | 62 | def computed(cls, pdb_object): 63 | """Check if the step was already computed. 64 | 65 | pdb_object : PDBObject 66 | PDB structure to be handled 67 | """ 68 | return cls.name in pdb_object.completed_steps 69 | 70 | 71 | def ready(cls, pdb_object): 72 | """Check if the dependencies of the step have been fulfilled. 73 | 74 | pdb_object : PDBObject 75 | PDB structure to be handled 76 | """ 77 | return all(map(lambda x: x.computed(pdb_object), cls.requirements)) 78 | 79 | def run_until(cls, pdb_object, callbacks=[], surrogate=False): 80 | """Run step and all dependencies necessary for its execution. 81 | 82 | pdb_object : PDBObject 83 | PDB structure to be handled 84 | callbacks : list of callable 85 | List of callbacks for logging 86 | surrogate : bool 87 | Run step as consequence of another step. 
Default False 88 | """ 89 | for requirement in cls.requirements: 90 | if not requirement.run_until(pdb_object, callbacks, surrogate=True): 91 | return False 92 | return cls.run(pdb_object,callbacks, force=False, surrogate=surrogate) 93 | 94 | 95 | -------------------------------------------------------------------------------- /RosENet/objects/pdb.py: -------------------------------------------------------------------------------- 1 | from types import SimpleNamespace 2 | from RosENet.storage.storage import * 3 | import RosENet.storage.storage as storage 4 | import RosENet.rosetta.rosetta as rosetta 5 | from RosENet import settings 6 | 7 | class _PDB: 8 | """Inner PDB class. Represents a PDB structure (a folder with a protein.pdb and ligand.mol2 files).""" 9 | 10 | _instance_dict = {} 11 | _property_tree = { "flags_relax" : "flags_relax.txt", 12 | "constraints" : "constraints", 13 | "ligand" : { 14 | "pdb" : "{code}_ligand.pdb", 15 | "mol2" : "{code}_ligand.mol2", 16 | "renamed_mol2" : "{code}_ligand_renamed.mol2", 17 | "pdbqt" : "{code}_ligand.pdbqt", 18 | "params" : "{code}_ligand.params" 19 | }, 20 | "protein" : { 21 | "pdb" : "{code}_protein.pdb", 22 | "pdbqt" : "{code}_protein.pdbqt" 23 | }, 24 | "complex" : { 25 | "pdb" : "{code}_complex.pdb" 26 | }, 27 | "minimized" : { 28 | "scores" : "score.sc", 29 | "hidden_complexes" : "other_complexes", 30 | "ligand" : { 31 | "mol2" : "{code}_ligand_{number}.mol2", 32 | "pdbqt": "{code}_ligand_{number}.pdbqt" 33 | }, 34 | "protein" : { 35 | "pdb" : "{code}_protein_{number}.pdb", 36 | "mol2" : "{code}_protein_{number}.mol2", 37 | "pdbqt": "{code}_protein_{number}.pdbqt" 38 | }, 39 | "complex" : { 40 | "pdb": "{code}_complex_{number}.pdb", 41 | "pdbqt": "{code}_complex_{number}.pdbqt", 42 | "attr": "{code}_complex_{number}.attr" 43 | } 44 | }, 45 | "image" : { 46 | "rosetta" : "{code}_rosetta.img", 47 | "htmd" : "{code}_htmd.img", 48 | "electronegativity" : "{code}_electroneg.img" 49 | } 50 | } 51 | 52 | def 
__init__(self, path): 53 | self.path = path 54 | self.id = path.name 55 | self.metadata_path = path / "metadata.json" 56 | if not self.metadata_path.exists(): 57 | storage.write_json(self.metadata_path, { "completed_steps" : []}) 58 | self.metadata = storage.read_json(self.metadata_path) 59 | _create(_PDB._property_tree, path, self.__dict__) 60 | from RosENet.objects.file import File 61 | self.image.combined = File.create(f"{self.id}.img", path.parent / settings.options) 62 | 63 | @property 64 | def completed_steps(self): 65 | return self.metadata["completed_steps"] 66 | 67 | def complete(self, step): 68 | self.metadata["completed_steps"].append(step) 69 | self.metadata["completed_steps"] = list(set(self.metadata["completed_steps"])) 70 | storage.write_json(self.metadata_path, self.metadata) 71 | 72 | def uncomplete(self, step): 73 | try: 74 | self.metadata["completed_steps"].remove(step) 75 | storage.write_json(self.metadata_path, self.metadata) 76 | except ValueError: 77 | pass 78 | 79 | def PDBObject(path): 80 | if str(path.absolute()) not in _PDB._instance_dict: 81 | _PDB._instance_dict[str(path.absolute())] = _PDB(path) 82 | return _PDB._instance_dict[str(path.absolute())] 83 | 84 | def _create(tree, path, result=None): 85 | from RosENet.objects.file import File 86 | root = True 87 | if not result: 88 | root = False 89 | result = {} 90 | for key, value in tree.items(): 91 | if isinstance(value, dict): 92 | result[key] = _create(value, path) 93 | else: 94 | result[key] = File.create(value, path) 95 | if root: 96 | return result 97 | return SimpleNamespace(**result) 98 | 99 | -------------------------------------------------------------------------------- /RosENet/preprocessing/minimize_rosetta.py: -------------------------------------------------------------------------------- 1 | from RosENet.preprocessing.step import Step 2 | from RosENet.preprocessing.make_complex_pdb import MakeComplexPDB 3 | from string import Template 4 | import RosENet.constants as 
constants 5 | import RosENet.rosetta.rosetta as rosetta 6 | import RosENet.storage.storage as storage 7 | 8 | def generate_minimization_flags_file(pdb_object): 9 | """Generate minimization flags file for Rosetta minimization. 10 | 11 | pdb_object : PDBObject 12 | PDB structure to be handled 13 | """ 14 | complex_path = pdb_object.complex.pdb.path 15 | complex = complex_path.name 16 | name = complex_path.stem 17 | params = pdb_object.ligand.params.path.name 18 | template = Template(storage.read_plain(constants.flags_relax_path)) 19 | substitution = {'complex' : complex, 20 | 'name' : name, 21 | 'params' : params} 22 | output = template.substitute(substitution) 23 | pdb_object.flags_relax.write(output) 24 | 25 | def generate_constraint_file(pdb_object): 26 | """Generate constraint file for Rosetta minimization. 27 | Constraints are added to maintain contact ions close to their original positions. 28 | 29 | pdb_object : PDBObject 30 | PDB structure to be handled 31 | """ 32 | complex = pdb_object.complex.pdb.read() 33 | output_path = pdb_object.constraints.path 34 | metals = complex.select(constants.metal_selector) 35 | results = [] 36 | if metals: 37 | for atom in metals: 38 | pos = atom.getCoords() 39 | close_ligand = complex.select(constants.close_ligand_selector, t=pos) 40 | if close_ligand: 41 | for close in close_ligand: 42 | results.append((atom.getName(), 43 | str(atom.getResnum())+atom.getChid(), 44 | close.getName(), 45 | str(close.getResnum())+close.getChid())) 46 | pdb_object.constraints.write( 47 | [f"AtomPair {r[0]} {r[1]} {r[2]} {r[3]} SQUARE_WELL 2.5 -2000\n" for r in results]) 48 | 49 | def hide_non_minimal_complexes(pdb_object): 50 | """Generate constraint file for Rosetta minimization. 51 | Constraints are added to maintain contact ions close to their original positions. 
52 | 53 | pdb_object : PDBObject 54 | PDB structure to be handled 55 | """ 56 | scores = rosetta.parse_scores(pdb_object.minimized.scores.read()) 57 | hidden_folder = pdb_object.minimized.hidden_complexes.path 58 | storage.make_directory(hidden_folder) 59 | for number in list(scores.keys())[1:]: 60 | complex_path = pdb_object.minimized.complex.pdb[number].path 61 | print("Hiding ", complex_path) 62 | storage.move(complex_path, hidden_folder, no_fail=True) 63 | 64 | 65 | 66 | class MinimizeRosetta(metaclass=Step, requirements=[MakeComplexPDB]): 67 | @classmethod 68 | def files(cls, pdb_object): 69 | """List of files being created 70 | 71 | pdb_object : PDBObject 72 | PDB structure being handled 73 | """ 74 | return [pdb_object.flags_relax, 75 | pdb_object.constraints, 76 | pdb_object.minimized.hidden_complexes, 77 | pdb_object.minimized.complex.pdb, 78 | pdb_object.minimized.scores] 79 | 80 | @classmethod 81 | def _run(cls, pdb_object): 82 | """Inner function for the preprocessing step. 83 | 84 | pdb_object : PDBObject 85 | PDB structure being handled 86 | """ 87 | generate_minimization_flags_file(pdb_object) 88 | generate_constraint_file(pdb_object) 89 | if pdb_object.minimized.scores.path.exists(): 90 | hide_non_minimal_complexes(pdb_object) 91 | rosetta.minimize(working_directory = pdb_object.path) 92 | hide_non_minimal_complexes(pdb_object) 93 | 94 | -------------------------------------------------------------------------------- /RosENet/clui.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 4 | import sys 5 | from argparse import ArgumentParser 6 | from pathlib import Path 7 | from RosENet.objects.model import ModelObject 8 | from RosENet.objects.dataset import DatasetObject 9 | from RosENet.preprocessing import preprocess 10 | from RosENet.voxelization import voxelize 11 | from RosENet.postprocessing import postprocess 12 | from 
RosENet.preprocessing.minimize_rosetta import hide_non_minimal_complexes 13 | 14 | def parse_arguments(): 15 | """Parse arguments and perform the selected action.""" 16 | arguments = sys.argv[1:] 17 | optional_parser = ArgumentParser() 18 | optional_parser.add_argument("--njobs", default=None) 19 | optional_parser.add_argument("--gpu", default=None) 20 | optional_parser.add_argument("extra", nargs="*") 21 | optional = optional_parser.parse_args(arguments) 22 | njobs = optional.njobs 23 | gpu = optional.gpu 24 | if njobs is None: 25 | njobs = multiprocessing.cpu_count()-1 26 | else: 27 | njobs = int(njobs) 28 | print("#####",njobs) 29 | print(optional.extra) 30 | arguments = optional.extra 31 | os.environ["CUDA_VISIBLE_DEVICES"]="" 32 | if gpu is not None: 33 | os.environ["CUDA_VISIBLE_DEVICES"]=gpu 34 | from RosENet.network.network import train, evaluate, predict 35 | parser = ArgumentParser(description="RosENet Tool", 36 | usage='''tool [] 37 | 38 | preprocess Compute steps previous to voxelization 39 | voxelize Compute 3D image voxelization 40 | postprocess Format image for TensorFlow 41 | train Train neural network 42 | evaluate Evaluate neural network 43 | predict Predict binding affinity 44 | ''') 45 | parser.add_argument("action", help='Command to run') 46 | parser.add_argument("dataset", help='Dataset path') 47 | args = parser.parse_args(arguments[0:2]) 48 | action = args.action 49 | dataset = DatasetObject(Path(args.dataset)) 50 | if action in ["train", "evaluate", "predict"]: 51 | parser = ArgumentParser(usage= 52 | '''{train,evaluate,predict} second_dataset network channels [seed] 53 | 54 | second_dataset has the following meanings: 55 | If action is train: second_dataset is the validation dataset 56 | If action is evaluate: second_dataset is the evaluation dataset 57 | If action is predict: second_dataset is the prediction dataset 58 | 59 | In training if no seed is given, it will be randomly chosen 60 | The seed is necesary for evaluation and prediction 
to chose the 61 | correct instance of the trained model. 62 | ''') 63 | parser.add_argument("validation_dataset") 64 | parser.add_argument("network") 65 | parser.add_argument("channels") 66 | parser.add_argument("seed", nargs="?", default=None) 67 | args = parser.parse_args(arguments[2:6]) 68 | other_dataset= DatasetObject(Path(args.validation_dataset)) 69 | network = ModelObject(Path(args.network)) 70 | channels = args.channels 71 | seed = int(args.seed) 72 | if action in ["preprocess", "voxelize"]: 73 | import random 74 | pdbs = dataset.list() 75 | random.shuffle(pdbs) 76 | pdbs = map(dataset.__getitem__, pdbs) 77 | p = multiprocessing.Pool(gpu) 78 | if action == "preprocess": 79 | #p.map(hide_non_minimal_complexes, pdbs) 80 | p.map(preprocess, pdbs) 81 | elif action == "voxelize": 82 | p.map(voxelize, pdbs) 83 | elif action == "postprocess": 84 | postprocess(dataset) 85 | if action == "train": 86 | from RosENet.network.network import train 87 | train(dataset, other_dataset, network, seed, channels) 88 | elif action == "evaluate": 89 | from RosENet.network.network import evaluate 90 | evaluate(dataset, other_dataset, network, seed, channels) 91 | elif action == "predict": 92 | from RosENet.network.network import predict 93 | predict(dataset, other_dataset, network, seed, channels) 94 | -------------------------------------------------------------------------------- /RosENet/network/network.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 3 | import tensorflow as tf 4 | import random 5 | import RosENet.settings as settings 6 | from RosENet.network.utils import save_results 7 | import time 8 | 9 | def train(dataset_train, dataset_evaluate, model_object, seed=None, channels=""): 10 | """Training method 11 | 12 | dataset_train : DatasetObject 13 | Dataset to be used as training set 14 | dataset_evaluate : DatasetObject 15 | Dataset to be used as validation set 16 | 
model_object : ModelObject 17 | CNN model to train 18 | seed : int 19 | Seed for initializing randomness. If None, a random seed will be used. 20 | channels : string 21 | Channels selectors for choosing feature subsets 22 | """ 23 | if seed is None: 24 | seed = random.randint(0,2147483647) 25 | tf.set_random_seed(seed) 26 | model_train_object = model_object.train_object(dataset_train,channels,seed) 27 | model_valid_object = model_object.evaluate_object(dataset_evaluate,channels,seed) 28 | with tf.Session() as sess: 29 | sess.run(tf.global_variables_initializer()) 30 | min_validation = float('inf') 31 | epoch_min_validation = -1 32 | for num_epochs in range(settings.max_epochs): 33 | t1 = time.time() 34 | tr_loss = model_train_object.do(sess) 35 | t2 = time.time() - t1 36 | print(f"Train loss: {tr_loss} Elapsed: {t2}") 37 | t1 = time.time() 38 | va_loss = model_valid_object.do(sess) 39 | t2 = time.time() - t1 40 | print(f"Validation loss: {va_loss} Elapsed: {t2}") 41 | if min_validation > va_loss: 42 | min_validation = va_loss 43 | epoch_min_validation = num_epochs 44 | model_train_object.save(sess) 45 | print(f"Best validation: {min_validation} Epoch: {epoch_min_validation}") 46 | model_train_object.save(sess) 47 | save_results(model_train_object, channels, min_validation, epoch_min_validation, seed) 48 | 49 | def evaluate(dataset_train, dataset_evaluate, model_object, seed, channels=""): 50 | """Evaluation method 51 | 52 | dataset_train : DatasetObject 53 | Dataset used during the training phase 54 | dataset_evaluate : DatasetObject 55 | Dataset to be evaluated 56 | model_object : ModelObject 57 | CNN model used during the training phase 58 | seed : int 59 | Seed for initializing randomness. If None, a random seed will be used. 
60 | channels : string 61 | Channels selectors for choosing feature subsets 62 | """ 63 | model_train_object = model_object.train_object(dataset_train, channels, seed) 64 | model_valid_object = model_object.evaluate_object(dataset_evaluate, channels, seed) 65 | with tf.Session() as sess: 66 | sess.run(tf.global_variables_initializer()) 67 | model_train_object.load(sess) 68 | val_loss = model_valid_object.do(sess) 69 | print(val_loss) 70 | return val_loss 71 | 72 | def predict(dataset_train, dataset_predict, model_object, seed, channels=""): 73 | """Predict method 74 | 75 | dataset_train : DatasetObject 76 | Dataset used during the training phase 77 | dataset_evaluate : DatasetObject 78 | Dataset to be predicted 79 | model_object : ModelObject 80 | CNN model used during the training phase 81 | seed : int 82 | Seed for initializing randomness. If None, a random seed will be used. 83 | channels : string 84 | Channels selectors for choosing feature subsets 85 | """ 86 | model_train_object = model_object.train_object(dataset_train, channels, seed) 87 | model_pred_object = model_object.predict_object(dataset_predict, channels, seed) 88 | with tf.Session() as sess: 89 | sess.run(tf.global_variables_initializer()) 90 | model_train_object.load(sess) 91 | results = model_pred_object.do(sess) 92 | print(results) 93 | return results 94 | -------------------------------------------------------------------------------- /RosENet/voxelization/apbs.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from subprocess import call 3 | import numpy as np 4 | from prody import writePQR 5 | from utils import isfloat 6 | 7 | 8 | class APBS: 9 | _APBS_BIN_PATH = '/cluster/home/hhussein/apbs.sh' 10 | _TEMPLATE_FILE = '/cluster/home/hhussein/apbs.in' 11 | @staticmethod 12 | def run( 13 | output_path, 14 | name, 15 | pose, 16 | selection, 17 | grid_dim, 18 | grid_space, 19 | center, 20 | cglen, 21 | fglen): 22 | apbs_bin_path = 
APBS._APBS_BIN_PATH 23 | apbs_template_file = Path(APBS._TEMPLATE_FILE) 24 | apbs_input_file = output_path / f'apbs_{name}.in' 25 | apbs_output_file = output_path / f'{name}_potential.dx' 26 | if apbs_input_file.exists(): 27 | apbs_input_file.unlink() 28 | if apbs_output_file.exists(): 29 | apbs_output_file.unlink() 30 | writePQR(f'{output_path/name}.pqr', pose.select(selection)) 31 | with apbs_template_file.open('r') as f: 32 | file_data = f.read() 33 | file_data = APBS._replace_apbs( 34 | file_data, str(output_path/name), grid_dim, grid_space, center, cglen, fglen) 35 | with apbs_input_file.open('w') as f: 36 | f.write(file_data) 37 | call([apbs_bin_path, 38 | f'{apbs_input_file.absolute()}'], 39 | cwd=str(output_path)) 40 | apbs_input_file.unlink() 41 | o, d, potential = APBS._import_dx(apbs_output_file) 42 | apbs_output_file.unlink() 43 | return potential 44 | 45 | @staticmethod 46 | def _replace_apbs( 47 | filedata, 48 | xxx, 49 | grid_dim, 50 | grid_space, 51 | center, 52 | cglen, 53 | fglen): 54 | filedata = filedata \ 55 | .replace('XXX', xxx) \ 56 | .replace('GRID_DIM', grid_dim) \ 57 | .replace('GRID_SPACE', grid_space) \ 58 | .replace('INH_CENTER', center) \ 59 | .replace('CG_LEN', cglen) \ 60 | .replace('FG_LEN', fglen) 61 | return filedata 62 | 63 | @staticmethod 64 | def _import_dx(filename): 65 | origin = delta = data = dims = None 66 | counter = 0 67 | with open(filename, 'r') as dxfile: 68 | for row in dxfile: 69 | row = row.strip().split() 70 | if not row: 71 | continue 72 | if row[0] == '#': 73 | continue 74 | elif row[0] == 'origin': 75 | origin = np.array(row[1:], dtype=float) 76 | elif row[0] == 'delta': 77 | delta = np.array(row[2:], dtype=float) 78 | elif row[0] == 'object': 79 | if row[1] == '1': 80 | dims = np.array(row[-3:], dtype=int) 81 | data = np.empty(np.prod(dims)) 82 | elif isfloat(row[0]): 83 | data[3 * counter:min(3 * (counter + 1), len(data)) 84 | ] = np.array(row, dtype=float) 85 | counter += 1 86 | data = data.reshape(dims) 87 
| return origin, delta, data 88 | 89 | @staticmethod 90 | def _export_dx(filename, density, origin, delta): 91 | nx, ny, nz = density.shape 92 | with open(filename, 'w') as dxfile: 93 | dxfile.write( 94 | f'object 1 class gridpositions counts {nx} {ny} {nz}\n') 95 | dxfile.write(f'origin {origin[0]} {origin[1]} {origin[2]}\n') 96 | dxfile.write(f'delta {delta} 0.0 0.0\n') 97 | dxfile.write(f'delta 0.0 {delta} 0.0\n') 98 | dxfile.write(f'delta 0.0 0.0 {delta}\n') 99 | dxfile.write( 100 | f'object 2 class gridconnections counts {nx}, {ny}, {nz}\n') 101 | dxfile.write( 102 | f'object 3 class array type double rank 0 items {nx * ny * nz} data follows\n') 103 | i = 1 104 | for d in density.flatten(order='C'): 105 | if i % 3: 106 | dxfile.write('{} '.format(d)) 107 | else: 108 | dxfile.write('{}\n'.format(d)) 109 | i += 1 110 | 111 | dxfile.write('\n') 112 | -------------------------------------------------------------------------------- /RosENet/preprocessing/make_complex_pdb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from RosENet.preprocessing.make_ligand_params_pdb import MakeLigandParamsPDB 3 | from RosENet.preprocessing.step import Step 4 | from prody import parsePDB, writePDB 5 | import RosENet.constants as constants 6 | from htmd.molecule.molecule import Molecule 7 | from htmd.molecule.voxeldescriptors import getVoxelDescriptors 8 | from htmd.builder.preparation import proteinPrepare 9 | 10 | def standardize_residues(structure): 11 | """Rename the residues of a structure to their standard names. 
12 | 13 | structure : prody.AtomSelection 14 | Structure to be handled 15 | """ 16 | residue_names = structure.getResnames() 17 | for nonstd, std in constants.nonstd2stdresidues.items(): 18 | residue_names[residue_names==nonstd] = std 19 | structure.setResnames(residue_names) 20 | 21 | def fix_chains(structure): 22 | """Rename chain names to reserve special names 23 | (W for water, X for ligand, Z for metals) 24 | 25 | structure : prody.AtomSelection 26 | Structure to be handled 27 | """ 28 | chains = list(set(structure.getChids())) 29 | valid_chains = "ABCDEFGHIJKLMNOPQRSTUVwxYz0123456789abcdefghijklmnopqestuvy" 30 | chain_dict = dict(zip(chains, valid_chains[:len(chains)])) 31 | structure.setChids(np.vectorize(chain_dict.get)(structure.getChids())) 32 | 33 | def fix_ligand_names(structure): 34 | """Rename ligand residue names and chain names to their reserved names. 35 | 36 | structure : prody.AtomSelection 37 | Structure to be handled 38 | """ 39 | n_atoms = structure.numAtoms() 40 | structure.setResnames(np.array([constants.ligand_resname]*n_atoms)) 41 | structure.setChids(np.array([constants.ligand_chid]*n_atoms)) 42 | 43 | def fix_water_chains(structure): 44 | """Rename water molecule chain to its reserved name. 45 | 46 | structure : prody.AtomSelection 47 | Structure to be handled 48 | """ 49 | chains = structure.getChids() 50 | residues = structure.getResnames() 51 | chains[residues==constants.water_residue] = constants.water_chain 52 | structure.setChids(chains) 53 | 54 | def fix_metal_chains(structure): 55 | """Rename metal atom chain to its reserved name. 56 | 57 | structure : prody.AtomSelection 58 | Structure to be handled 59 | """ 60 | chains = structure.getChids() 61 | residues = structure.getResnames() 62 | for metal in constants.accepted_metals: 63 | chains[residues==metal] = constants.metal_chain 64 | structure.setChids(chains) 65 | 66 | def cleanup_and_merge(pdb_object): 67 | """Apply fixes to protein and ligand structures and write complex. 
68 | 69 | pdb_object : PDBObject 70 | PDB structure to be handled 71 | """ 72 | protein = pdb_object.protein.pdb.read() 73 | ligand = pdb_object.ligand.pdb.read() 74 | fix_chains(protein) 75 | fix_ligand_names(ligand) 76 | complex = protein + ligand 77 | standardize_residues(complex) 78 | fix_water_chains(complex) 79 | fix_metal_chains(complex) 80 | pdb_object.complex.pdb.write(complex) 81 | 82 | def protein_optimization(complex_path): 83 | """Optimize protein inside complex file and rewrite complex. 84 | 85 | complex_path : pathlib.Path 86 | Path to complex file 87 | """ 88 | complex = Molecule(str(complex_path)) 89 | prot = complex.copy(); prot.filter("protein") 90 | lig = complex.copy(); lig.filter(constants.ligand_selector) 91 | #prot = proteinPrepare(prot, pH=7.0) 92 | mol = Molecule(name="complex") 93 | mol.append(prot) 94 | mol.append(lig) 95 | mol.write(str(complex_path)) 96 | 97 | class MakeComplexPDB(metaclass=Step,requirements=[MakeLigandParamsPDB]): 98 | @classmethod 99 | def files(cls, pdb_object): 100 | """List of files being created 101 | 102 | pdb_object : PDBObject 103 | PDB structure being handled 104 | """ 105 | return [pdb_object.complex.pdb] 106 | 107 | @classmethod 108 | def _run(cls, pdb_object): 109 | """Inner function for the preprocessing step. 
110 | 111 | pdb_object : PDBObject 112 | PDB structure being handled 113 | """ 114 | complex_path = pdb_object.complex.pdb.path 115 | cleanup_and_merge(pdb_object) 116 | protein_optimization(complex_path) 117 | complex = pdb_object.complex.pdb.read() 118 | standardize_residues(complex) 119 | pdb_object.complex.pdb.write(complex) 120 | 121 | -------------------------------------------------------------------------------- /RosENet/preprocessing/preprocessing.py: -------------------------------------------------------------------------------- 1 | from string import Template 2 | from .make_complex_pdb import make_complex_pdb as _make_complex_pdb 3 | from .make_ligand_mol2_renamed import make_ligand_mol2_renamed as _make_ligand_mol2_renamed 4 | from .compute_rosetta_energy import compute_rosetta_energy as _compute_rosetta_energy 5 | import RosENet.rosetta.rosetta as rosetta 6 | import RosENet.storage.storage as storage 7 | import RosENet.constants as constants 8 | import subprocess 9 | 10 | def make_ligand_params_pdb(pdb_object): 11 | ligand_mol2_path = pdb_object.ligand.mol2.path 12 | params_filename = ligand_mol2_path.stem 13 | working_directory = ligand_mol2_path.parent 14 | return rosetta.molfile_to_params( 15 | working_directory = working_directory, 16 | output_path = params_filename, 17 | input_path = ligand_mol2_path) 18 | 19 | def make_complex_pdb(pdb_object): 20 | protein_path = pdb_object.protein.pdb.path 21 | ligand_path = pdb_object.ligand.pdb.path 22 | complex_path = pdb_object.complex.pdb.path 23 | return _make_complex_pdb(protein_path, ligand_path, complex_path) 24 | 25 | def minimize_rosetta(pdb_object): 26 | generate_minimization_flags_file(pdb_object) 27 | generate_constraint_file(pdb_object) 28 | working_directory = pdb_object.path 29 | pdb_object.minimized.scores.delete() 30 | rosetta.minimize(working_directory = working_directory) 31 | hide_non_minimal_complexes(pdb_object) 32 | 33 | def generate_minimization_flags_file(pdb_object): 34 | 
complex_path = pdb_object.complex.pdb.path 35 | complex = complex_path.name 36 | name = complex_path.stem 37 | params = pdb_object.ligand.params.path.name 38 | template = Template(storage.read_plain(constants.flags_relax_path)) 39 | substitution = {'complex' : complex, 40 | 'name' : name, 41 | 'params' : params} 42 | output = template.substitute(substitution) 43 | pdb_object.flags_relax.write(output) 44 | 45 | def generate_constraint_file(pdb_object): 46 | complex = pdb_object.complex.pdb.read() 47 | output_path = pdb_object.constraints.path 48 | metals = complex.select(constants.metal_selector) 49 | results = [] 50 | if metals: 51 | for atom in metals: 52 | pos = atom.getCoords() 53 | close_ligand = complex.select(constants.close_ligand_selector, t=pos) 54 | if close_ligand: 55 | for close in close_ligand: 56 | results.append((atom.getName(), 57 | str(atom.getResnum())+atom.getChid(), 58 | close.getName(), 59 | str(close.getResnum())+close.getChid())) 60 | pdb_object.constraints.write( 61 | [f"AtomPair {r[0]} {r[1]} {r[2]} {r[3]} SQUARE_WELL 2.5 -2000\n" for r in results]) 62 | 63 | def hide_non_minimal_complexes(pdb_object): 64 | scores = rosetta.parse_scores(pdb_object.minimized.scores.read()) 65 | hidden_folder = pdb_object.minimized.hidden_complexes.path 66 | hidden_folder.delete() 67 | storage.make_directory(hidden_folder) 68 | for number in list(scores.keys())[1:]: 69 | complex_path = pdb_object.minimized.complex.pdb[number].path 70 | storage.move(complex_path, hidden_folder, no_fail=True) 71 | 72 | def make_protein_pdb(pdb_object): 73 | complex = pdb_object.minimized.complex.pdb.read() 74 | protein = complex.select(constants.protein_selector) 75 | pdb_object.minimized.protein.pdb.write(protein) 76 | 77 | def make_ligand_mol2(pdb_object): 78 | complex_path = pdb_object.minimized.complex.pdb.path 79 | renamed_mol2_path = pdb_object.ligand.renamed_mol2.path 80 | ligand_path = pdb_object.minimized.ligand.mol2.path 81 | rosetta.pdb_to_molfile( 82 | mol = 
renamed_mol2_path, 83 | complex_path = complex_path, 84 | output_path = ligand_path) 85 | 86 | def make_pdbqt(pdb_object): 87 | path = pdb_object.path 88 | subprocess.run([constants.mgl_python_path, 89 | constants.preprocess_vina_path, 90 | path]) 91 | 92 | def compute_rosetta_energy(pdb_object): 93 | complex_path = pdb_object.minimized.complex.pdb.path 94 | params = pdb_object.ligand.params.path 95 | output_path = pdb_object.minimized.complex.attr.path 96 | _compute_rosetta_energy(complex_path, params, output_path) 97 | 98 | def make_ligand_mol2_renamed(pdb_object): 99 | ligand_pdb_path = pdb_object.ligand.pdb.path 100 | ligand_mol2_path = pdb_object.ligand.mol2.path 101 | output_path = pdb_object.ligand.renamed_mol2.path 102 | _make_ligand_mol2_renamed(ligand_pdb_path, ligand_mol2_path, output_path) 103 | 104 | -------------------------------------------------------------------------------- /RosENet/storage/.ropeproject/config.py: -------------------------------------------------------------------------------- 1 | # The default ``config.py`` 2 | # flake8: noqa 3 | 4 | 5 | def set_prefs(prefs): 6 | """This function is called before opening the project""" 7 | 8 | # Specify which files and folders to ignore in the project. 9 | # Changes to ignored resources are not added to the history and 10 | # VCSs. Also they are not returned in `Project.get_files()`. 11 | # Note that ``?`` and ``*`` match all characters but slashes. 12 | # '*.pyc': matches 'test.pyc' and 'pkg/test.pyc' 13 | # 'mod*.pyc': matches 'test/mod1.pyc' but not 'mod/1.pyc' 14 | # '.svn': matches 'pkg/.svn' and all of its children 15 | # 'build/*.o': matches 'build/lib.o' but not 'build/sub/lib.o' 16 | # 'build//*.o': matches 'build/lib.o' and 'build/sub/lib.o' 17 | prefs['ignored_resources'] = ['*.pyc', '*~', '.ropeproject', 18 | '.hg', '.svn', '_svn', '.git', '.tox'] 19 | 20 | # Specifies which files should be considered python files. It is 21 | # useful when you have scripts inside your project. 
Only files 22 | # ending with ``.py`` are considered to be python files by 23 | # default. 24 | #prefs['python_files'] = ['*.py'] 25 | 26 | # Custom source folders: By default rope searches the project 27 | # for finding source folders (folders that should be searched 28 | # for finding modules). You can add paths to that list. Note 29 | # that rope guesses project source folders correctly most of the 30 | # time; use this if you have any problems. 31 | # The folders should be relative to project root and use '/' for 32 | # separating folders regardless of the platform rope is running on. 33 | # 'src/my_source_folder' for instance. 34 | #prefs.add('source_folders', 'src') 35 | 36 | # You can extend python path for looking up modules 37 | #prefs.add('python_path', '~/python/') 38 | 39 | # Should rope save object information or not. 40 | prefs['save_objectdb'] = True 41 | prefs['compress_objectdb'] = False 42 | 43 | # If `True`, rope analyzes each module when it is being saved. 44 | prefs['automatic_soa'] = True 45 | # The depth of calls to follow in static object analysis 46 | prefs['soa_followed_calls'] = 0 47 | 48 | # If `False` when running modules or unit tests "dynamic object 49 | # analysis" is turned off. This makes them much faster. 50 | prefs['perform_doa'] = True 51 | 52 | # Rope can check the validity of its object DB when running. 53 | prefs['validate_objectdb'] = True 54 | 55 | # How many undos to hold? 56 | prefs['max_history_items'] = 32 57 | 58 | # Shows whether to save history across sessions. 59 | prefs['save_history'] = True 60 | prefs['compress_history'] = False 61 | 62 | # Set the number of spaces used for indenting. According to 63 | # :PEP:`8`, it is best to use 4 spaces. Since most of rope's 64 | # unit-tests use 4 spaces it is more reliable, too. 65 | prefs['indent_size'] = 4 66 | 67 | # Builtin and c-extension modules that are allowed to be imported 68 | # and inspected by rope.
69 | prefs['extension_modules'] = [] 70 | 71 | # Add all standard c-extensions to extension_modules list. 72 | prefs['import_dynload_stdmods'] = True 73 | 74 | # If `True` modules with syntax errors are considered to be empty. 75 | # The default value is `False`; when `False` syntax errors raise 76 | # `rope.base.exceptions.ModuleSyntaxError` exception. 77 | prefs['ignore_syntax_errors'] = False 78 | 79 | # If `True`, rope ignores unresolvable imports. Otherwise, they 80 | # appear in the importing namespace. 81 | prefs['ignore_bad_imports'] = False 82 | 83 | # If `True`, rope will insert new module imports as 84 | # `from package import module` by default. 85 | prefs['prefer_module_from_imports'] = False 86 | 87 | # If `True`, rope will transform a comma list of imports into 88 | # multiple separate import statements when organizing 89 | # imports. 90 | prefs['split_imports'] = False 91 | 92 | # If `True`, rope will remove all top-level import statements and 93 | # reinsert them at the top of the module when making changes. 94 | prefs['pull_imports_to_top'] = True 95 | 96 | # If `True`, rope will sort imports alphabetically by module name instead of 97 | # alphabetically by import statement, with from imports after normal 98 | # imports. 99 | prefs['sort_imports_alphabetically'] = False 100 | 101 | # Location of implementation of rope.base.oi.type_hinting.interfaces.ITypeHintingFactory 102 | # In the general case, you don't have to change this value, unless you're a rope expert. 103 | # Change this value to inject your own implementations of interfaces 104 | # listed in module rope.base.oi.type_hinting.providers.interfaces 105 | # For example, you can add your own providers for Django Models, or disable the search 106 | # type-hinting in a class hierarchy, etc.
107 | prefs['type_hinting_factory'] = 'rope.base.oi.type_hinting.factory.default_type_hinting_factory' 108 | 109 | 110 | def project_opened(project): 111 | """This function is called after opening the project""" 112 | # Do whatever you like here! 113 | -------------------------------------------------------------------------------- /RosENet/postprocessing/.ropeproject/config.py: -------------------------------------------------------------------------------- 1 | # The default ``config.py`` 2 | # flake8: noqa 3 | 4 | 5 | def set_prefs(prefs): 6 | """This function is called before opening the project""" 7 | 8 | # Specify which files and folders to ignore in the project. 9 | # Changes to ignored resources are not added to the history and 10 | # VCSs. Also they are not returned in `Project.get_files()`. 11 | # Note that ``?`` and ``*`` match all characters but slashes. 12 | # '*.pyc': matches 'test.pyc' and 'pkg/test.pyc' 13 | # 'mod*.pyc': matches 'test/mod1.pyc' but not 'mod/1.pyc' 14 | # '.svn': matches 'pkg/.svn' and all of its children 15 | # 'build/*.o': matches 'build/lib.o' but not 'build/sub/lib.o' 16 | # 'build//*.o': matches 'build/lib.o' and 'build/sub/lib.o' 17 | prefs['ignored_resources'] = ['*.pyc', '*~', '.ropeproject', 18 | '.hg', '.svn', '_svn', '.git', '.tox'] 19 | 20 | # Specifies which files should be considered python files. It is 21 | # useful when you have scripts inside your project. Only files 22 | # ending with ``.py`` are considered to be python files by 23 | # default. 24 | #prefs['python_files'] = ['*.py'] 25 | 26 | # Custom source folders: By default rope searches the project 27 | # for finding source folders (folders that should be searched 28 | # for finding modules). You can add paths to that list. Note 29 | # that rope guesses project source folders correctly most of the 30 | # time; use this if you have any problems.
31 | # The folders should be relative to project root and use '/' for 32 | # separating folders regardless of the platform rope is running on. 33 | # 'src/my_source_folder' for instance. 34 | #prefs.add('source_folders', 'src') 35 | 36 | # You can extend python path for looking up modules 37 | #prefs.add('python_path', '~/python/') 38 | 39 | # Should rope save object information or not. 40 | prefs['save_objectdb'] = True 41 | prefs['compress_objectdb'] = False 42 | 43 | # If `True`, rope analyzes each module when it is being saved. 44 | prefs['automatic_soa'] = True 45 | # The depth of calls to follow in static object analysis 46 | prefs['soa_followed_calls'] = 0 47 | 48 | # If `False` when running modules or unit tests "dynamic object 49 | # analysis" is turned off. This makes them much faster. 50 | prefs['perform_doa'] = True 51 | 52 | # Rope can check the validity of its object DB when running. 53 | prefs['validate_objectdb'] = True 54 | 55 | # How many undos to hold? 56 | prefs['max_history_items'] = 32 57 | 58 | # Shows whether to save history across sessions. 59 | prefs['save_history'] = True 60 | prefs['compress_history'] = False 61 | 62 | # Set the number of spaces used for indenting. According to 63 | # :PEP:`8`, it is best to use 4 spaces. Since most of rope's 64 | # unit-tests use 4 spaces it is more reliable, too. 65 | prefs['indent_size'] = 4 66 | 67 | # Builtin and c-extension modules that are allowed to be imported 68 | # and inspected by rope. 69 | prefs['extension_modules'] = [] 70 | 71 | # Add all standard c-extensions to extension_modules list. 72 | prefs['import_dynload_stdmods'] = True 73 | 74 | # If `True` modules with syntax errors are considered to be empty. 75 | # The default value is `False`; when `False` syntax errors raise 76 | # `rope.base.exceptions.ModuleSyntaxError` exception. 77 | prefs['ignore_syntax_errors'] = False 78 | 79 | # If `True`, rope ignores unresolvable imports. Otherwise, they 80 | # appear in the importing namespace.
81 | prefs['ignore_bad_imports'] = False 82 | 83 | # If `True`, rope will insert new module imports as 84 | # `from package import module` by default. 85 | prefs['prefer_module_from_imports'] = False 86 | 87 | # If `True`, rope will transform a comma list of imports into 88 | # multiple separate import statements when organizing 89 | # imports. 90 | prefs['split_imports'] = False 91 | 92 | # If `True`, rope will remove all top-level import statements and 93 | # reinsert them at the top of the module when making changes. 94 | prefs['pull_imports_to_top'] = True 95 | 96 | # If `True`, rope will sort imports alphabetically by module name instead of 97 | # alphabetically by import statement, with from imports after normal 98 | # imports. 99 | prefs['sort_imports_alphabetically'] = False 100 | 101 | # Location of implementation of rope.base.oi.type_hinting.interfaces.ITypeHintingFactory 102 | # In the general case, you don't have to change this value, unless you're a rope expert. 103 | # Change this value to inject your own implementations of interfaces 104 | # listed in module rope.base.oi.type_hinting.providers.interfaces 105 | # For example, you can add your own providers for Django Models, or disable the search 106 | # type-hinting in a class hierarchy, etc. 107 | prefs['type_hinting_factory'] = 'rope.base.oi.type_hinting.factory.default_type_hinting_factory' 108 | 109 | 110 | def project_opened(project): 111 | """This function is called after opening the project""" 112 | # Do whatever you like here! 113 | -------------------------------------------------------------------------------- /RosENet/preprocessing/.ropeproject/config.py: -------------------------------------------------------------------------------- 1 | # The default ``config.py`` 2 | # flake8: noqa 3 | 4 | 5 | def set_prefs(prefs): 6 | """This function is called before opening the project""" 7 | 8 | # Specify which files and folders to ignore in the project.
9 | # Changes to ignored resources are not added to the history and 10 | # VCSs. Also they are not returned in `Project.get_files()`. 11 | # Note that ``?`` and ``*`` match all characters but slashes. 12 | # '*.pyc': matches 'test.pyc' and 'pkg/test.pyc' 13 | # 'mod*.pyc': matches 'test/mod1.pyc' but not 'mod/1.pyc' 14 | # '.svn': matches 'pkg/.svn' and all of its children 15 | # 'build/*.o': matches 'build/lib.o' but not 'build/sub/lib.o' 16 | # 'build//*.o': matches 'build/lib.o' and 'build/sub/lib.o' 17 | prefs['ignored_resources'] = ['*.pyc', '*~', '.ropeproject', 18 | '.hg', '.svn', '_svn', '.git', '.tox'] 19 | 20 | # Specifies which files should be considered python files. It is 21 | # useful when you have scripts inside your project. Only files 22 | # ending with ``.py`` are considered to be python files by 23 | # default. 24 | #prefs['python_files'] = ['*.py'] 25 | 26 | # Custom source folders: By default rope searches the project 27 | # for finding source folders (folders that should be searched 28 | # for finding modules). You can add paths to that list. Note 29 | # that rope guesses project source folders correctly most of the 30 | # time; use this if you have any problems. 31 | # The folders should be relative to project root and use '/' for 32 | # separating folders regardless of the platform rope is running on. 33 | # 'src/my_source_folder' for instance. 34 | #prefs.add('source_folders', 'src') 35 | 36 | # You can extend python path for looking up modules 37 | #prefs.add('python_path', '~/python/') 38 | 39 | # Should rope save object information or not. 40 | prefs['save_objectdb'] = True 41 | prefs['compress_objectdb'] = False 42 | 43 | # If `True`, rope analyzes each module when it is being saved. 44 | prefs['automatic_soa'] = True 45 | # The depth of calls to follow in static object analysis 46 | prefs['soa_followed_calls'] = 0 47 | 48 | # If `False` when running modules or unit tests "dynamic object 49 | # analysis" is turned off.
This makes them much faster. 50 | prefs['perform_doa'] = True 51 | 52 | # Rope can check the validity of its object DB when running. 53 | prefs['validate_objectdb'] = True 54 | 55 | # How many undos to hold? 56 | prefs['max_history_items'] = 32 57 | 58 | # Shows whether to save history across sessions. 59 | prefs['save_history'] = True 60 | prefs['compress_history'] = False 61 | 62 | # Set the number of spaces used for indenting. According to 63 | # :PEP:`8`, it is best to use 4 spaces. Since most of rope's 64 | # unit-tests use 4 spaces it is more reliable, too. 65 | prefs['indent_size'] = 4 66 | 67 | # Builtin and c-extension modules that are allowed to be imported 68 | # and inspected by rope. 69 | prefs['extension_modules'] = [] 70 | 71 | # Add all standard c-extensions to extension_modules list. 72 | prefs['import_dynload_stdmods'] = True 73 | 74 | # If `True` modules with syntax errors are considered to be empty. 75 | # The default value is `False`; when `False` syntax errors raise 76 | # `rope.base.exceptions.ModuleSyntaxError` exception. 77 | prefs['ignore_syntax_errors'] = False 78 | 79 | # If `True`, rope ignores unresolvable imports. Otherwise, they 80 | # appear in the importing namespace. 81 | prefs['ignore_bad_imports'] = False 82 | 83 | # If `True`, rope will insert new module imports as 84 | # `from package import module` by default. 85 | prefs['prefer_module_from_imports'] = False 86 | 87 | # If `True`, rope will transform a comma list of imports into 88 | # multiple separate import statements when organizing 89 | # imports. 90 | prefs['split_imports'] = False 91 | 92 | # If `True`, rope will remove all top-level import statements and 93 | # reinsert them at the top of the module when making changes. 94 | prefs['pull_imports_to_top'] = True 95 | 96 | # If `True`, rope will sort imports alphabetically by module name instead of 97 | # alphabetically by import statement, with from imports after normal 98 | # imports.
99 | prefs['sort_imports_alphabetically'] = False 100 | 101 | # Location of implementation of rope.base.oi.type_hinting.interfaces.ITypeHintingFactory 102 | # In the general case, you don't have to change this value, unless you're a rope expert. 103 | # Change this value to inject your own implementations of interfaces 104 | # listed in module rope.base.oi.type_hinting.providers.interfaces 105 | # For example, you can add your own providers for Django Models, or disable the search 106 | # type-hinting in a class hierarchy, etc. 107 | prefs['type_hinting_factory'] = 'rope.base.oi.type_hinting.factory.default_type_hinting_factory' 108 | 109 | 110 | def project_opened(project): 111 | """This function is called after opening the project""" 112 | # Do whatever you like here! 113 | -------------------------------------------------------------------------------- /RosENet/preprocessing/preprocess_vina.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module extracted from prepare_receptor4.py and prepare_ligand4.py from AutoDockTools scripts.
3 | http://autodock.scripps.edu/faqs-help/how-to/how-to-prepare-a-receptor-file-for-autodock4 4 | http://autodock.scripps.edu/faqs-help/how-to/how-to-prepare-a-ligand-file-for-autodock4 5 | """ 6 | import os 7 | from MolKit import Read 8 | import MolKit.molecule 9 | import MolKit.protein 10 | from AutoDockTools.MoleculePreparation import AD4ReceptorPreparation, AD4LigandPreparation 11 | import sys 12 | import getopt 13 | 14 | 15 | def preprocess_receptor(receptor_filename, outputfilename): 16 | repairs = '' 17 | charges_to_add = 'gasteiger' 18 | preserve_charge_types=None 19 | cleanup = "" 20 | mode = "automatic" 21 | delete_single_nonstd_residues = False 22 | dictionary = None 23 | 24 | mols = Read(receptor_filename) 25 | mol = mols[0] 26 | preserved = {} 27 | if charges_to_add is not None and preserve_charge_types is not None: 28 | preserved_types = preserve_charge_types.split(',') 29 | for t in preserved_types: 30 | if not len(t): continue 31 | ats = mol.allAtoms.get(lambda x: x.autodock_element==t) 32 | for a in ats: 33 | if a.chargeSet is not None: 34 | preserved[a] = [a.chargeSet, a.charge] 35 | 36 | if len(mols)>1: 37 | ctr = 1 38 | for m in mols[1:]: 39 | ctr += 1 40 | if len(m.allAtoms)>len(mol.allAtoms): 41 | mol = m 42 | mol.buildBondsByDistance() 43 | 44 | RPO = AD4ReceptorPreparation(mol, mode, repairs, charges_to_add, 45 | cleanup, outputfilename=outputfilename, 46 | preserved=preserved, 47 | delete_single_nonstd_residues=delete_single_nonstd_residues, 48 | dict=dictionary) 49 | 50 | if charges_to_add is not None: 51 | for atom, chargeList in preserved.items(): 52 | atom._charges[chargeList[0]] = chargeList[1] 53 | atom.chargeSet = chargeList[0] 54 | 55 | def preprocess_ligand(ligand_filename, outputfilename): 56 | verbose = None 57 | repairs = "" #"hydrogens_bonds" 58 | charges_to_add = 'gasteiger' 59 | preserve_charge_types='' 60 | cleanup = "" 61 | allowed_bonds = "backbone" 62 | root = 'auto' 63 | check_for_fragments = True 64 | 
bonds_to_inactivate = "" 65 | inactivate_all_torsions = True 66 | attach_nonbonded_fragments = True 67 | attach_singletons = True 68 | mode = "automatic" 69 | dict = None 70 | 71 | mols = Read(ligand_filename) 72 | if verbose: print 'read ', ligand_filename 73 | mol = mols[0] 74 | if len(mols)>1: 75 | ctr = 1 76 | for m in mols[1:]: 77 | ctr += 1 78 | if len(m.allAtoms)>len(mol.allAtoms): 79 | mol = m 80 | coord_dict = {} 81 | for a in mol.allAtoms: coord_dict[a] = a.coords 82 | 83 | mol.buildBondsByDistance() 84 | if charges_to_add is not None: 85 | preserved = {} 86 | preserved_types = preserve_charge_types.split(',') 87 | for t in preserved_types: 88 | if not len(t): continue 89 | ats = mol.allAtoms.get(lambda x: x.autodock_element==t) 90 | for a in ats: 91 | if a.chargeSet is not None: 92 | preserved[a] = [a.chargeSet, a.charge] 93 | 94 | 95 | 96 | LPO = AD4LigandPreparation(mol, mode, repairs, charges_to_add, 97 | cleanup, allowed_bonds, root, 98 | outputfilename=outputfilename, 99 | dict=dict, check_for_fragments=check_for_fragments, 100 | bonds_to_inactivate=bonds_to_inactivate, 101 | inactivate_all_torsions=inactivate_all_torsions, 102 | attach_nonbonded_fragments=attach_nonbonded_fragments, 103 | attach_singletons=attach_singletons) 104 | if charges_to_add is not None: 105 | for atom, chargeList in preserved.items(): 106 | atom._charges[chargeList[0]] = chargeList[1] 107 | atom.chargeSet = chargeList[0] 108 | bad_list = [] 109 | for a in mol.allAtoms: 110 | if a in coord_dict.keys() and a.coords!=coord_dict[a]: 111 | bad_list.append(a) 112 | if len(bad_list): 113 | print len(bad_list), ' atom coordinates changed!' 
114 | for a in bad_list: 115 | print a.name, ":", coord_dict[a], ' -> ', a.coords 116 | else: 117 | if verbose: print "No change in atomic coordinates" 118 | if mol.returnCode!=0: 119 | sys.stderr.write(mol.returnMsg+"\n") 120 | 121 | def process_folder(receptor, ligand): 122 | try: 123 | preprocess_receptor(receptor, receptor.replace('.pdb','.pdbqt')) 124 | except Exception, e: 125 | print 'Protein', receptor 126 | raise e 127 | try: 128 | preprocess_ligand(ligand, ligand.replace('.mol2','.pdbqt')) 129 | except Exception, e: 130 | print 'Ligand', ligand 131 | raise e 132 | 133 | if __name__ == "__main__": 134 | process_folder(sys.argv[1], sys.argv[2]) 135 | -------------------------------------------------------------------------------- /RosENet/objects/model.py: -------------------------------------------------------------------------------- 1 | import RosENet.storage.storage as storage 2 | import RosENet.network.input as input 3 | import RosENet.settings as settings 4 | import tensorflow as tf 5 | import importlib.util 6 | import numpy as np 7 | from RosENet.objects.file import File 8 | 9 | class _ModelAction: 10 | """Model action base class. 
Represents an action (training, evaluation or prediction) 11 | of a dataset using a model(neural network)""" 12 | def __init__(self, model_object, dataset_object, channels, action, seed): 13 | self.dataset_object = dataset_object 14 | if not self.results.path.exists(): 15 | self.results.write("") 16 | self.model_object = model_object 17 | self.channels = channels 18 | input_fn = input.load_fn(dataset_object, channels, settings.rotate, action) 19 | self.iterator = input_fn().make_initializable_iterator() 20 | model_fn = model_object.load_fn() 21 | self.id, self.X, self.y = self.iterator.get_next() 22 | self.shape = tf.shape(self.y)[0] 23 | model = model_fn(self.X, self.y, action, settings.rotate) 24 | self.op = model.train_op 25 | self.loss = model.loss 26 | self.predictions = model.predictions 27 | checkpoint_folder = self.dataset_object.model(self.model_object, self.channels, seed) 28 | storage.make_directory(checkpoint_folder) 29 | self.save_path = checkpoint_folder / "model.ckpt" 30 | self.saver = tf.train.Saver() 31 | 32 | def do(self, sess = None): 33 | if sess is None: 34 | sess = tf.get_default_session() 35 | try: 36 | sess.run(self.iterator.initializer) 37 | self.start_epoch() 38 | while True: 39 | self.do_batch(sess) 40 | except (KeyboardInterrupt, SystemExit): 41 | raise 42 | except Exception as e: 43 | pass 44 | return self.end_epoch() 45 | 46 | def save(self, sess = None): 47 | if sess is None: 48 | sess = tf.get_default_session() 49 | self.saver.save(sess, str(self.save_path.absolute())) 50 | 51 | @property 52 | def dataset_name(self): 53 | return self.dataset_object.name 54 | 55 | @property 56 | def model_name(self): 57 | return self.model_object.name 58 | 59 | @property 60 | def results(self): 61 | return File.create(f"{self.dataset_name}.txt",self.dataset_object.path) 62 | 63 | def load(self, sess = None): 64 | if sess is None: 65 | sess = tf.get_default_session() 66 | self.saver.restore(sess, str(self.save_path)) 67 | 68 | class 
_ModelTrain(_ModelAction): 69 | def __init__(self, model_path, dataset_path, channels,seed): 70 | super(_ModelTrain, self).__init__(model_path, 71 | dataset_path, 72 | channels, 73 | tf.estimator.ModeKeys.TRAIN, 74 | seed) 75 | def start_epoch(self): 76 | self.size = 0 77 | self.loss_value = 0 78 | 79 | def do_batch(self, sess): 80 | _, loss, batch_shape = sess.run([self.op, self.loss, self.shape]) 81 | self.size += batch_shape 82 | self.loss_value = batch_shape*loss 83 | 84 | def end_epoch(self): 85 | return np.sqrt(self.loss_value/self.size) 86 | 87 | class _ModelEvaluate(_ModelAction): 88 | def __init__(self, model_path, dataset_path, channels, seed): 89 | super(_ModelEvaluate, self).__init__(model_path, 90 | dataset_path, 91 | channels, 92 | tf.estimator.ModeKeys.EVAL, 93 | seed) 94 | def start_epoch(self): 95 | self.size = 0 96 | self.loss_value = 0 97 | 98 | def do_batch(self, sess): 99 | loss, batch_shape = sess.run([self.loss, self.shape]) 100 | self.size += batch_shape 101 | self.loss_value = batch_shape*loss 102 | 103 | def end_epoch(self): 104 | return np.sqrt(self.loss_value/self.size) 105 | 106 | class _ModelPredict(_ModelAction): 107 | def __init__(self, model_path, dataset_path, channels, seed): 108 | super(_ModelPredict, self).__init__(model_path, 109 | dataset_path, 110 | channels, 111 | tf.estimator.ModeKeys.EVAL, 112 | seed) 113 | def start_epoch(self): 114 | self.prediction_dict = {} 115 | 116 | def do_batch(self, sess): 117 | names, predictions = sess.run([self.id, self.predictions]) 118 | for name, prediction in zip(names, predictions): 119 | self.prediction_dict[str(name)] = prediction 120 | 121 | def end_epoch(self): 122 | return self.prediction_dict 123 | 124 | class _Model: 125 | """Inner Model class, represents a CNN.""" 126 | _instance_dict = {} 127 | def __init__(self, path): 128 | self.path = path 129 | 130 | def train_object(self, dataset_object, channels, seed): 131 | return _ModelTrain(self, dataset_object, channels, seed) 132 | 133 
| def evaluate_object(self, dataset_object, channels, seed): 134 | return _ModelEvaluate(self, dataset_object, channels, seed) 135 | 136 | def predict_object(self, dataset_object, channels, seed): 137 | return _ModelPredict(self, dataset_object, channels, seed) 138 | 139 | @property 140 | def name(self): 141 | return self.path.stem 142 | 143 | def load_fn(self): 144 | spec = importlib.util.spec_from_file_location("aux_module", str(self.path.absolute())) 145 | foo = importlib.util.module_from_spec(spec) 146 | spec.loader.exec_module(foo) 147 | return foo.model_fn 148 | #return __import__(str(self.path.absolute()), fromlist=["model_fn"]) 149 | 150 | def ModelObject(path): 151 | if str(path.absolute()) not in _Model._instance_dict: 152 | _Model._instance_dict[str(path.absolute())] = _Model(path) 153 | return _Model._instance_dict[str(path.absolute())] 154 | 155 | -------------------------------------------------------------------------------- /RosENet/network/input.py: -------------------------------------------------------------------------------- 1 | import random 2 | import tensorflow as tf 3 | import numpy as np 4 | from RosENet.network.utils import random_rot, rots_90, random_rot_90, all_rot_90 5 | import RosENet.settings as settings 6 | 7 | def parse_fn(shape): 8 | """Parse function for the TensorFlow pipeline. 9 | The inner function outputs three tensors: the name of the complex, 10 | the image and the target binding affinity. 
11 | 12 | shape : list of int 13 | 4D Shape of the image inputted 14 | """ 15 | def fn(example): 16 | example_fmt = { 17 | "id": tf.VarLenFeature(tf.string), 18 | "X": tf.FixedLenFeature((int(np.prod(shape)),), tf.float32), 19 | "y": tf.FixedLenFeature((1,), tf.float32) 20 | } 21 | parsed = tf.parse_single_example(example, example_fmt) 22 | X = tf.reshape(parsed['X'], shape=list(shape)) 23 | y = parsed['y'] 24 | id = parsed['id'].values 25 | return id, X, y 26 | return fn 27 | 28 | def rotate_fn(training, shape): 29 | """Rotation function for the TensorFlow pipeline. 30 | The inner function rotates the images. If training, the image will be rotated randomly to one of its 24 90º rotations. Otherwise, each image will be expanded to 24 images representing all possible 90º rotations. 31 | 32 | training : bool 33 | True if training 34 | shape : list of int 35 | 4D Shape of the images inputted 36 | """ 37 | def fn(id, X, y): 38 | if training: 39 | X = random_rot_90(X,list((-1,) + tuple(shape))) 40 | else: 41 | X = all_rot_90(X, list((-1,) + tuple(shape))) 42 | return id, X, y 43 | return fn 44 | 45 | def take_channels(ch, channel_order): 46 | """Extraction function for the TensorFlow pipeline. 47 | Given a list of channels to be extracted and a list of total channels, 48 | the inner function will output a new image tensor with only the 49 | selected channels. 50 | 51 | ch : list of string 52 | List of channels to be extracted 53 | channel_order : list of string 54 | Ordered list of the channels in the current input tensor 55 | """ 56 | idx = [i for i,x in enumerate(channel_order) if x in ch] 57 | def f(id, X, y): 58 | ret = tf.gather(X,idx,axis=-1) 59 | return id, ret, y 60 | return ch, f 61 | 62 | def make_input_fn(input_path, shape, training, rot, merge): 63 | """Create TensorFlow input pipeline. 
64 | 65 | input_path : pathlib.Path 66 | Globbed path for the .tfrecord files 67 | shape : list of int 68 | 4D Shape of the images inputted 69 | training : bool 70 | True if training 71 | rot : bool 72 | True if rotations are enabled 73 | param merge : 74 | List of channel selectors to extract channels 75 | """ 76 | def in_fn(): 77 | channel_order = ['htmd_hydrophobic', 78 | 'htmd_aromatic', 79 | 'htmd_hbond_acceptor', 80 | 'htmd_hbond_donor', 81 | 'htmd_positive_ionizable', 82 | 'htmd_negative_ionizable', 83 | 'htmd_metal', 84 | 'htmd_occupancies', 85 | 'htmd_hydrophobic', 86 | 'htmd_aromatic', 87 | 'htmd_hbond_acceptor', 88 | 'htmd_hbond_donor', 89 | 'htmd_positive_ionizable', 90 | 'htmd_negative_ionizable', 91 | 'htmd_metal', 92 | 'htmd_occupancies', 93 | 'elec_p', 94 | 'elec_l', 95 | 'rosetta_atr_p', 96 | 'rosetta_rep_p', 97 | 'rosetta_sol_p_pos', 98 | 'rosetta_elec_p_pos', 99 | 'rosetta_sol_p_neg', 100 | 'rosetta_elec_p_neg', 101 | 'rosetta_atr_l', 102 | 'rosetta_atr_p', 103 | 'rosetta_sol_l_pos', 104 | 'rosetta_elec_l_pos', 105 | 'rosetta_sol_l_neg', 106 | 'rosetta_elec_l_neg'] 107 | files = tf.data.Dataset.list_files(str(input_path)) 108 | dataset = files.apply( 109 | tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, 110 | settings.parallel_calls)) 111 | dataset = dataset.map(parse_fn(shape), settings.parallel_calls) 112 | take = list(channel_order) 113 | filts = list(map(lambda m: lambda x: m in x, merge)) 114 | take = [x for x in take if any([f(x) for f in filts])] 115 | channel_order, take_fn = take_channels(take, channel_order) 116 | dataset = dataset.map(take_fn, settings.parallel_calls) 117 | seed_t = tf.py_func(lambda : random.randint(0,214748364), [], tf.int64) 118 | dataset = dataset.shuffle(settings.shuffle_buffer_size, seed=seed_t) 119 | shape[-1] = len(channel_order) 120 | if not training and rot: 121 | dataset = dataset.batch(20) 122 | else: 123 | dataset = dataset.batch(settings.batch_size) 124 | dataset = 
dataset.prefetch(buffer_size=settings.prefetch_buffer_size) 125 | if rot: 126 | dataset = dataset.map(rotate_fn(training,shape),settings.parallel_calls) 127 | return dataset 128 | return in_fn 129 | 130 | def load_fn(dataset_object, channels, rotate, training): 131 | """Create input pipeline function. 132 | 133 | dataset_object : DatasetObject 134 | Dataset for the images being used 135 | channels : list of string 136 | Channel selectors to use 137 | rotate : bool 138 | True if rotations are enabled 139 | training : bool 140 | True if training 141 | """ 142 | name = dataset_object.name 143 | shape = [settings.size]*3 + [30] 144 | training = (training == tf.estimator.ModeKeys.TRAIN) 145 | dataset_files = dataset_object.tfrecords / '*.tfrecords' 146 | return make_input_fn(dataset_files, shape, training, rotate, channels) 147 | 148 | -------------------------------------------------------------------------------- /RosENet/preprocessing/compute_rosetta_energy.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from RosENet.preprocessing.step import Step 3 | import argparse 4 | import os 5 | import h5py 6 | from pyrosetta.rosetta.protocols.scoring import Interface 7 | from pyrosetta.rosetta import * 8 | from pyrosetta import * 9 | from pathlib import Path 10 | import numpy as np 11 | from multiprocessing import Pool, cpu_count 12 | from collections import defaultdict 13 | from pyrosetta.toolbox.atom_pair_energy import print_residue_pair_energies 14 | from RosENet.preprocessing.minimize_rosetta import MinimizeRosetta 15 | init('-in:auto_setup_metals') #-mute core.conformation.Conformation') 16 | 17 | def compute_atom_pair_energy(pdb_filename, ligand_params, interface_cutoff = 21.0): 18 | """Compute pairwise energies and aggregate them to atom-wise values. 
19 | 20 | pdb_filename : pathlib.Path 21 | Path for the complex pdb file 22 | ligand_params : pathlib.Path 23 | Path for the ligand params file 24 | interface_cutoff : float 25 | Distance for which to check pairwise interactions between atoms 26 | """ 27 | if type(ligand_params) is str: 28 | ligand_params = [ligand_params] 29 | ligand_params = Vector1([str(ligand_params)]) 30 | 31 | pose = Pose() 32 | res_set = pose.conformation().modifiable_residue_type_set_for_conf() 33 | res_set.read_files_for_base_residue_types( ligand_params ) 34 | 35 | pose.conformation().reset_residue_type_set_for_conf( res_set ) 36 | pose_from_file(pose, str(pdb_filename)) 37 | scorefxn = create_score_function('ref2015') 38 | pose_score = scorefxn(pose) 39 | 40 | #detect interface 41 | fold_tree = pose.fold_tree() 42 | for jump in range(1, pose.num_jump()+1): 43 | name = pose.residue(fold_tree.downstream_jump_residue(jump)).name() 44 | if name == 'WER': 45 | break 46 | interface = Interface(jump) 47 | interface.distance(interface_cutoff) 48 | interface.calculate(pose) 49 | 50 | energies = [] 51 | en = defaultdict(lambda:np.zeros((1,4))) 52 | keys = [] 53 | for rnum1 in range(1, pose.total_residue() + 1): 54 | if interface.is_interface(rnum1): 55 | r1 = pose.residue(rnum1) 56 | for a1 in range(1, len(r1.atoms()) + 1): 57 | seq1 = pose.pdb_info().pose2pdb(rnum1).strip().replace(' ','-') 58 | at1 = r1.atom_name(a1).strip() 59 | key1 = seq1 + '-' + at1 60 | for rnum2 in range(rnum1+1, pose.total_residue() + 1): 61 | if interface.is_interface(rnum2): 62 | r2 = pose.residue(rnum2) 63 | for a2 in range(1, len(r2.atoms())+1): 64 | seq2 = pose.pdb_info().pose2pdb(rnum2).strip().replace(' ','-') 65 | at2 = r2.atom_name(a2).strip() 66 | key2 = seq2 + '-' + at2 67 | ee = etable_atom_pair_energies(r1, a1, r2, a2, scorefxn) 68 | if all(e == 0.0 for e in ee): 69 | continue 70 | en[key1] += np.array(ee) 71 | en[key2] += np.array(ee) 72 | energy_matrix = np.array([v for v in en.values()]) 73 | return 
list(en.keys()), energy_matrix 74 | 75 | def get_radii_and_charges(pdb_filename, ligand_params): 76 | """Compute radii and charges for the atoms of the complex. 77 | 78 | pdb_filename : pathlib.Path 79 | Path for the complex pdb file 80 | ligand_params : pathlib.Path 81 | Path for the ligand params file 82 | """ 83 | keys = [] 84 | charges = [] 85 | radii = [] 86 | 87 | if type(ligand_params) is str: 88 | ligand_params = [ligand_params] 89 | ligand_params = Vector1([str(ligand_params)]) 90 | 91 | pose = Pose() 92 | res_set = pose.conformation().modifiable_residue_type_set_for_conf() 93 | res_set.read_files_for_base_residue_types(ligand_params) 94 | 95 | pose.conformation().reset_residue_type_set_for_conf(res_set) 96 | pose_from_file(pose, str(pdb_filename)) 97 | for rnum1 in range(1, pose.total_residue() + 1): 98 | r1 = pose.residue(rnum1) 99 | for a1 in range(1, len(r1.atoms()) + 1): 100 | seq1 = pose.pdb_info().pose2pdb(rnum1).strip().replace(' ','-') 101 | at1 = r1.atom_name(a1).strip() 102 | key1 = seq1 + '-' + at1 103 | charges.append(r1.atomic_charge(a1)) 104 | radii.append(r1.atom_type(a1).lj_radius()) 105 | keys.append(key1) 106 | 107 | return keys, charges, radii 108 | 109 | class ComputeRosettaEnergy(metaclass=Step,requirements=[MinimizeRosetta]): 110 | 111 | @classmethod 112 | def files(cls, pdb_object): 113 | """List of files being created 114 | 115 | pdb_object : PDBObject 116 | PDB structure being handled 117 | """ 118 | return [pdb_object.minimized.complex.attr] 119 | 120 | @classmethod 121 | def _run(cls, pdb_object): 122 | """Inner function for the preprocessing step. 
123 | 124 | pdb_object : PDBObject 125 | PDB structure being handled 126 | """ 127 | pdb_file = pdb_object.minimized.complex.pdb.path 128 | folder = pdb_file.parent 129 | pdb_code = folder.stem 130 | ligand_params = pdb_object.ligand.params.path 131 | output_file = pdb_object.minimized.complex.attr.path 132 | try: 133 | e_keys, e_values = compute_atom_pair_energy(pdb_file, ligand_params) 134 | rc_keys, charges, radii = get_radii_and_charges(pdb_file, ligand_params) 135 | except Exception as e: 136 | print("Error at ", pdb_file) 137 | print(e) 138 | return 139 | energy_keys = np.array(e_keys) 140 | energy_values = np.array(e_values) 141 | rc_keys = np.array(rc_keys) 142 | radius_values = np.array(radii) 143 | charge_values = np.array(charges) 144 | np.savez_compressed(str(output_file), 145 | energy_keys=e_keys, 146 | energy_values=energy_values, 147 | rc_keys=rc_keys, 148 | radius_values=radius_values, 149 | charge_values=charge_values) 150 | 151 | -------------------------------------------------------------------------------- /test_dataset/10gs/10gs_ligand.mol2: -------------------------------------------------------------------------------- 1 | ### 2 | ### Created by X-TOOL on Mon Sep 10 21:12:46 2018 3 | ### 4 | 5 | @MOLECULE 6 | 10gs_ligand 7 | 59 60 1 0 0 8 | SMALL 9 | GAST_HUCK 10 | 11 | 12 | @ATOM 13 | 1 N 15.0880 10.7980 23.5470 N.4 1 VWW 0.2328 14 | 2 CA 15.0100 9.9870 24.7920 C.3 1 VWW 0.0304 15 | 3 C 16.1150 8.9240 24.8300 C.2 1 VWW 0.0846 16 | 4 O 16.5200 8.5150 25.9400 O.co2 1 VWW -0.5643 17 | 5 CB 13.6350 9.3270 24.9080 C.3 1 VWW 0.0193 18 | 6 CG 13.3940 8.7080 26.2710 C.3 1 VWW 0.0441 19 | 7 CD 12.0450 8.0460 26.4020 C.2 1 VWW 0.1785 20 | 8 OE1 11.2930 7.9360 25.4350 O.2 1 VWW -0.3969 21 | 9 OXT 16.5780 8.5240 23.7440 O.co2 1 VWW -0.5643 22 | 10 N 11.7260 7.6420 27.6280 N.am 1 VWW -0.2648 23 | 11 CA 10.4720 6.9670 27.9340 C.3 1 VWW 0.1400 24 | 12 CB 10.7260 5.4840 28.2060 C.3 1 VWW 0.0361 25 | 13 SG 11.2910 4.5240 26.8100 S.3 1 VWW -0.1422 26 | 14 CD 
9.7290 3.8040 26.2620 C.3 1 VWW 0.0276 27 | 15 CE 8.9300 3.1710 27.3700 C.ar 1 VWW -0.0332 28 | 16 CZ1 7.6400 3.6140 27.6500 C.ar 1 VWW -0.0595 29 | 17 CZ2 9.4640 2.1350 28.1330 C.ar 1 VWW -0.0595 30 | 18 CT1 6.8930 3.0370 28.6730 C.ar 1 VWW -0.0685 31 | 19 CT2 8.7230 1.5500 29.1610 C.ar 1 VWW -0.0685 32 | 20 CH 7.4370 2.0010 29.4300 C.ar 1 VWW -0.0687 33 | 21 C 9.8340 7.5500 29.1800 C.2 1 VWW 0.2059 34 | 22 O 10.5220 8.0230 30.0840 O.2 1 VWW -0.3942 35 | 23 N 8.5120 7.4680 29.2290 N.am 1 VWW -0.2587 36 | 24 CA 7.7400 7.9330 30.3660 C.3 1 VWW 0.1278 37 | 25 CB 6.5550 7.0620 30.6330 C.ar 1 VWW -0.0049 38 | 26 CG1 5.3300 7.3150 30.0270 C.ar 1 VWW -0.0564 39 | 27 CG2 6.6830 5.9410 31.4410 C.ar 1 VWW -0.0564 40 | 28 CD1 4.2500 6.4590 30.2200 C.ar 1 VWW -0.0682 41 | 29 CD2 5.6110 5.0810 31.6400 C.ar 1 VWW -0.0682 42 | 30 CE 4.3920 5.3390 31.0270 C.ar 1 VWW -0.0685 43 | 31 C 7.4520 9.4330 30.3540 C.2 1 VWW 0.0723 44 | 32 O 7.1160 9.9570 31.4330 O.co2 1 VWW -0.5642 45 | 33 OXT 7.5690 10.0680 29.2840 O.co2 1 VWW -0.5642 46 | 34 H1 14.3522 11.4870 23.5482 H 1 VWW 0.2010 47 | 35 H2 14.9824 10.1951 22.7461 H 1 VWW 0.2010 48 | 36 H3 15.9821 11.2614 23.5033 H 1 VWW 0.2010 49 | 37 H4 15.1478 10.6593 25.6517 H 1 VWW 0.1026 50 | 38 H5 13.5582 8.5382 24.1452 H 1 VWW 0.0363 51 | 39 H6 12.8628 10.0891 24.7265 H 1 VWW 0.0363 52 | 40 H7 13.4683 9.4998 27.0309 H 1 VWW 0.0505 53 | 41 H8 14.1719 7.9516 26.4522 H 1 VWW 0.0505 54 | 42 H9 12.3741 7.8085 28.3711 H 1 VWW 0.1883 55 | 43 H10 9.7864 7.0776 27.0810 H 1 VWW 0.0808 56 | 44 H11 9.7847 5.0400 28.5622 H 1 VWW 0.0426 57 | 45 H12 11.4878 5.4113 28.9961 H 1 VWW 0.0426 58 | 46 H13 9.9478 3.0333 25.5083 H 1 VWW 0.0538 59 | 47 H14 9.1210 4.5998 25.8070 H 1 VWW 0.0538 60 | 48 H15 7.2118 4.4189 27.0636 H 1 VWW 0.0557 61 | 49 H16 10.4667 1.7790 27.9259 H 1 VWW 0.0557 62 | 50 H17 5.8903 3.3927 28.8804 H 1 VWW 0.0599 63 | 51 H18 9.1502 0.7453 29.7483 H 1 VWW 0.0599 64 | 52 H19 6.8580 1.5482 30.2268 H 1 VWW 0.0559 65 | 53 H20 8.0270 7.0694 28.4506 
H 1 VWW 0.1901 66 | 54 H21 8.3977 7.7875 31.2356 H 1 VWW 0.0887 67 | 55 H22 5.2146 8.1893 29.3967 H 1 VWW 0.0560 68 | 56 H23 7.6317 5.7341 31.9229 H 1 VWW 0.0560 69 | 57 H24 3.2999 6.6666 29.7411 H 1 VWW 0.0600 70 | 58 H25 5.7257 4.2088 32.2735 H 1 VWW 0.0600 71 | 59 H26 3.5543 4.6679 31.1783 H 1 VWW 0.0560 72 | @BOND 73 | 1 2 1 1 74 | 2 2 5 1 75 | 3 2 3 1 76 | 4 3 4 ar 77 | 5 3 9 ar 78 | 6 5 6 1 79 | 7 6 7 1 80 | 8 7 8 2 81 | 9 7 10 am 82 | 10 10 11 1 83 | 11 11 12 1 84 | 12 11 21 1 85 | 13 12 13 1 86 | 14 13 14 1 87 | 15 14 15 1 88 | 16 15 16 ar 89 | 17 15 17 ar 90 | 18 16 18 ar 91 | 19 17 19 ar 92 | 20 18 20 ar 93 | 21 19 20 ar 94 | 22 21 22 2 95 | 23 21 23 am 96 | 24 23 24 1 97 | 25 24 25 1 98 | 26 24 31 1 99 | 27 25 26 ar 100 | 28 25 27 ar 101 | 29 26 28 ar 102 | 30 27 29 ar 103 | 31 28 30 ar 104 | 32 29 30 ar 105 | 33 31 32 ar 106 | 34 31 33 ar 107 | 35 1 34 1 108 | 36 1 35 1 109 | 37 1 36 1 110 | 38 2 37 1 111 | 39 5 38 1 112 | 40 5 39 1 113 | 41 6 40 1 114 | 42 6 41 1 115 | 43 10 42 1 116 | 44 11 43 1 117 | 45 12 44 1 118 | 46 12 45 1 119 | 47 14 46 1 120 | 48 14 47 1 121 | 49 16 48 1 122 | 50 17 49 1 123 | 51 18 50 1 124 | 52 19 51 1 125 | 53 20 52 1 126 | 54 23 53 1 127 | 55 24 54 1 128 | 56 26 55 1 129 | 57 27 56 1 130 | 58 28 57 1 131 | 59 29 58 1 132 | 60 30 59 1 133 | @SUBSTRUCTURE 134 | 1 VWW 1 135 | 136 | -------------------------------------------------------------------------------- /RosENet/network/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def save_results(model_train_object, channels="", *args): 5 | """Save information about a trained model in csv format. 
6 | 7 | model_train_object : _ModelTrain 8 | ModelTrain object representing the instantiation of the training 9 | channels : string 10 | Channel selectors used for the training 11 | args : list of various 12 | Data to be written in the file 13 | """ 14 | result_line = ",".join([model_train_object.dataset_name, model_train_object.model_name] + [str(a) for a in args]) + "\n" 15 | model_train_object.results.write(model_train_object.results.read()+result_line) 16 | 17 | def random_rotation_matrix(): 18 | """Create a random 3D matrix in TensorFlow""" 19 | A = tf.random_normal((3,3)) 20 | Q, _ = tf.linalg.qr(A) 21 | Q = tf.convert_to_tensor([[tf.sign(tf.linalg.det(Q)),0,0], 22 | [0,1,0], 23 | [0,0,1]],dtype=tf.float32) @ Q 24 | return Q 25 | 26 | def line3D(o,t,step): 27 | l0 = tf.linspace(o[0],t[0],step) 28 | l1 = tf.linspace(o[1],t[1],step) 29 | l2 = tf.linspace(o[2],t[2],step) 30 | return tf.stack((l0,l1,l2),axis=1) 31 | 32 | def outer_sum_diag(a,b): 33 | return tf.expand_dims(a,axis=-2) + tf.expand_dims(b,axis=0) 34 | 35 | def random_rot(X, input_shape, output_shape): 36 | output_step = output_shape[0] 37 | batch_size = tf.shape(X)[0] 38 | m = (np.array(input_shape[:-1]) - 1)/2 39 | o = -(np.array(output_shape)-1)/2 40 | corners = np.array([[0,0,0], 41 | [output_step-1,0,0], 42 | [0,output_step-1,0], 43 | [0,0,output_step-1]]) + o 44 | corners = tf.convert_to_tensor(corners, dtype=tf.float32) 45 | rotation_matrix = random_rotation_matrix() 46 | rotated_corners = corners @ rotation_matrix 47 | rotated_corners = rotated_corners + tf.convert_to_tensor(m, dtype=tf.float32) 48 | o = rotated_corners[0] 49 | a = rotated_corners[1] 50 | b = rotated_corners[2] 51 | c = rotated_corners[3] 52 | line_oa = line3D(o,a,output_step) 53 | vector_ob = line3D(o,b,output_step) - o 54 | vector_oc = line3D(o,c,output_step) - o 55 | plane_oab = outer_sum_diag(vector_ob,line_oa) 56 | rotated_query_points = outer_sum_diag(plane_oab, vector_oc) 57 | rotated_query_points = 
tf.cast(tf.round(rotated_query_points),tf.int64) 58 | rotated_query_points = tf.expand_dims(rotated_query_points, axis=0) 59 | rotated_query_points = tf.reshape(rotated_query_points, [1,-1,3]) 60 | rotated_query_points = tf.tile(rotated_query_points, [batch_size, 1, 1]) 61 | batch_column = tf.cast(tf.reshape(tf.tile(tf.reshape(tf.range(batch_size),(-1,1,1)), [1, 1, np.prod(output_shape)]),(-1,np.prod(output_shape),1)),tf.int64) 62 | rotated_query_points = tf.concat((batch_column,rotated_query_points),axis=2) 63 | X_output = tf.gather_nd(X, rotated_query_points) 64 | X_output = tf.reshape(X_output, [batch_size] + output_shape + [input_shape[-1]]) 65 | return X_output 66 | 67 | def basic_rot(X, axis, ccw=True): 68 | rot_axes = [x for x in [0,1,2] if x != axis] 69 | axis_1, axis_2 = rot_axes 70 | perm = [x if x not in rot_axes else axis_1 if x == axis_2 else axis_2 for x in [0,1,2,3]] 71 | rev_axis = axis_1 if ccw else axis_2 72 | return tf.reverse(tf.transpose(X,perm),[rev_axis]) 73 | 74 | def basic_rot_5D(X, axis, ccw=True): 75 | axis = axis + 1 76 | rot_axes = [x for x in [1,2,3] if x != axis] 77 | axis_1, axis_2 = rot_axes 78 | perm = [x if x not in rot_axes else axis_1 if x == axis_2 else axis_2 for x in [0,1,2,3,4]] 79 | rev_axis = axis_1 if ccw else axis_2 80 | return tf.reverse(tf.transpose(X,perm),[rev_axis]) 81 | 82 | 83 | def rots_90(X): 84 | x_90 = basic_rot(X, 0) 85 | x_180 = tf.reverse(X, [1,2]) 86 | x_270 = basic_rot(X, 0, ccw=False) 87 | maps = [X, x_90, x_180, x_270] 88 | extended_maps = list(maps) 89 | for m in maps: 90 | z_90 = basic_rot(m, 2) 91 | z_180 = tf.reverse(m, [0,1]) 92 | z_270 = basic_rot(m, 2, ccw=False) 93 | y_90 = basic_rot(m, 1) 94 | y_270 = basic_rot(m, 1, ccw=False) 95 | extended_maps += [z_90, z_180, z_270, y_90, y_270] 96 | return tf.stack(extended_maps) 97 | 98 | def rots_90_5D(X): 99 | x_90 = basic_rot_5D(X, 0) 100 | x_180 = tf.reverse(X, [1,2]) 101 | x_270 = basic_rot_5D(X, 0, ccw=False) 102 | maps = [X, x_90, x_180, x_270] 103 
| extended_maps = list(maps) 104 | for m in maps: 105 | z_90 = basic_rot_5D(m, 2) 106 | z_180 = tf.reverse(m, [0,1]) 107 | z_270 = basic_rot_5D(m, 2, ccw=False) 108 | y_90 = basic_rot_5D(m, 1) 109 | y_270 = basic_rot_5D(m, 1, ccw=False) 110 | extended_maps += [z_90, z_180, z_270, y_90, y_270] 111 | return tf.stack(extended_maps) 112 | 113 | def random_rot_90(X,shape): 114 | def my_func(X): 115 | # x will be a numpy array with the contents of the placeholder below 116 | anti = np.random.randint(0,2,1) 117 | if anti == 0: 118 | shift = np.random.randint(0,3,1) 119 | x, y, z = ((np.array([0,1,2]) + shift) % 3) + 1 120 | flip = np.random.choice([None,(1,2),(1,3),(2,3)],1)[0] 121 | X = X.transpose((0,x,y,z,4)) 122 | if flip is not None: 123 | a,b = flip 124 | X = np.flip(X,(a,b)) 125 | else: 126 | shift = np.random.randint(0,3,1) 127 | x, y, z = ((np.array([1,0,2]) + shift) % 3) + 1 128 | flip = np.random.choice([1,2,3,None],1)[0] 129 | X = X.transpose((0,x,y,z,4)) 130 | if flip is not None: 131 | X = np.flip(X,flip) 132 | else: 133 | X = np.flip(X,(1,2,3)) 134 | return X 135 | return tf.reshape(tf.py_func(my_func, [X], tf.float32),shape) 136 | 137 | def all_rot_90(X,shape): 138 | def my_func(X): 139 | # x will be a numpy array with the contents of the placeholder below 140 | results = [] 141 | for point in X: 142 | for shift in range(3): 143 | x, y, z = ((np.array([0,1,2]) + shift) % 3) 144 | for flip in [None,(0,1),(0,2),(1,2)]: 145 | aux = point.transpose((x,y,z,3)) 146 | if flip is not None: 147 | a,b = flip 148 | results.append(np.flip(aux,(a,b))) 149 | else: 150 | results.append(aux) 151 | for shift in range(3): 152 | x, y, z = ((np.array([1,0,2]) + shift) % 3) 153 | for flip in [0,1,2,None]: 154 | aux = point.transpose((x,y,z,3)) 155 | if flip is not None: 156 | results.append(np.flip(aux,flip)) 157 | else: 158 | results.append(np.flip(aux,(0,1,2))) 159 | ret = np.stack(results,axis=0) 160 | return ret 161 | return tf.reshape(tf.py_func(my_func, [X], 
tf.float32),shape) 162 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # RosENet 4 | 5 | This is the repository for the RosENet project. 6 | 7 | RosENet: A 3D Convolutional Neural Network for predicting absolute binding affinity using molecular mechanics energies 8 | 9 | Hussein Hassan-Harrirou¹, Ce Zhang¹, Thomas Lemmin 1,2,* 10 | 11 | ¹DS3Lab, System Group, Department of Computer Sciences, ETH Zurich, CH-8092 Zurich, Switzerland 12 | 13 | ²Institute of Medical Virology, University of Zurich (UZH), CH-8057 Zurich, Switzerland 14 | 15 | \*corresponding author: thomas.lemmin@inf.ethz.ch 16 | 17 | ## Prerequisites 18 | 19 | A script `install.sh` is included, which creates a Conda environment and install the necessary requirements. 20 | 21 | Rosetta and Pyrosetta must be installed manually due to the required license. 22 | 23 | Pyrosetta can be copied next to the RosENet folder or by adding it to the `PYTHONPATH`. 24 | 25 | You must set the path to Rosetta's main folder in the attribute `rosetta.root` in file `RosENet/constants.py` [LINK](RosENet/constants.py#L65) 26 | 27 | ## How to run 28 | 29 | This repository is built as a Python module. To execute it, one should go above the root folder and run the following command: 30 | 31 | ``` 32 | python3 -m RosENet -- arguments 33 | ``` 34 | 35 | Currently, six actions are implemented: 36 | 37 | 1. preprocess 38 | 2. voxelize 39 | 3. postprocess 40 | 4. train 41 | 5. evaluate 42 | 6. predict 43 | 44 | These actions are also thought to be executed in this order. 45 | Preprocessing will compute the file transformations and energy minimization on the text files. 46 | Voxelization allows to create the 3D images for HTMD, Rosetta and electronegativity features and combine them. 
47 | Postprocessing will transform the 3D images into a TFRecords format to be used as input for the neural networks. 48 | Training, evaluation and prediction have the usual meanings. 49 | 50 | Substeps in `preprocess` are cached. If some steps fail for some reason (usually faulty installations), the correct procedure after fixing the original problem is to clean the cached dataset. To do this, run `bash clear_dataset.sh <dataset>`, where `<dataset>` is the path to the dataset (e.g. `test_dataset`). 51 | 52 | ## Folder structure 53 | 54 | To ease the setup, there are some requirements on the structure of the inputs. 55 | The coarse unit of work is the dataset. A dataset is a folder with pdb folders and a file called labels. 56 | The next unit of work is the pdb folder. A pdb folder stores a {code}_protein.pdb file and a {code}_ligand.mol2 file. 57 | The labels file stores one line per pdb folder: the pdb code and its binding affinity, separated by a space. 58 | 59 | 60 | A valid folder structure: 61 | ``` 62 | test_dataset\ 63 | 10gs\ 64 | 10gs_protein.pdb 65 | 10gs_ligand.mol2 66 | labels 67 | ``` 68 | 69 | A valid `labels` file: 70 | ``` 71 | 10gs 12 72 | ``` 73 | 74 | ## Neural networks 75 | 76 | There are three CNNs implemented in the module, but any other implementation can be used. The only requirement is to follow TensorFlow's Estimator API. 77 | It is easy to replicate the same structure as the ones given here. 78 | 79 | After training a neural network, the module will save the trained network in the dataset folder, with the essential parameters also written in the folder name. Additionally, a line with the minimum validation error and the epoch when it was achieved will be saved in a results file in the dataset folder. The random seed used during training will also be stored in this line. This seed will be necessary to load the network afterwards. 80 | 81 | ## Selecting feature subsets 82 | 83 | To select feature subsets, we can use the channel selectors.
Each selector matches every feature whose name contains it as a substring. Multiple channel selectors can be combined with underscores. 84 | 85 | The feature names are: 86 | 87 | * `htmd_hydrophobic` 88 | * `htmd_aromatic` 89 | * `htmd_hbond_acceptor` 90 | * `htmd_hbond_donor` 91 | * `htmd_positive_ionizable` 92 | * `htmd_negative_ionizable` 93 | * `htmd_metal` 94 | * `htmd_occupancies` 95 | * `elec_p` 96 | * `elec_l` 97 | * `rosetta_atr_p` 98 | * `rosetta_rep_p` 99 | * `rosetta_sol_p_pos` 100 | * `rosetta_elec_p_pos` 101 | * `rosetta_sol_p_neg` 102 | * `rosetta_elec_p_neg` 103 | * `rosetta_atr_l` 104 | * `rosetta_rep_l` 105 | * `rosetta_sol_l_pos` 106 | * `rosetta_elec_l_pos` 107 | * `rosetta_sol_l_neg` 108 | * `rosetta_elec_l_neg` 109 | 110 | Each `htmd` feature name covers both the protein and the ligand feature, so it effectively selects two channels. 111 | 112 | For example, using `htmd_rosetta` will include all the HTMD features and all the Rosetta features. 113 | 114 | The combination used in the paper release is `aromatic_acceptor_ion_rosetta`, adding up to 20 feature maps. 115 | 116 | 117 | ## Setting the randomness 118 | 119 | The neural network actions use a seed parameter. For training, this seed is optional and can be randomly chosen by the module. 120 | 121 | The seed is used to identify different training runs of the same model/data. When evaluating and predicting, the seed must be specified. The seed can be found in the name of the trained model folder, located under the training dataset folder with name format `___`. 122 | 123 | ## Settings 124 | 125 | There are a few parameters that may be modified to change the behavior of the module. The file `settings.py` stores these options. 126 | There we can change the voxelization methods and the size of the voxel image. We can also change the parameters for TensorFlow's input pipeline.
127 | 128 | ## Running instructions 129 | 130 | To run the actions mentioned above with the example dataset, we need to run the following commands: 131 | 132 | ### Data preparation 133 | 134 | For the example dataset, substitute `<dataset>` with `test_dataset` 135 | 136 | 1. preprocess 137 | ``` 138 | python3 -m RosENet -- preprocess <dataset> 139 | ``` 140 | 2. voxelize 141 | ``` 142 | python3 -m RosENet -- voxelize <dataset> 143 | ``` 144 | 3. postprocess 145 | ``` 146 | python3 -m RosENet -- postprocess <dataset> 147 | ``` 148 | 149 | ### Network training/evaluation 150 | 151 | Substitute `<*_dataset>` with the paths to the respective datasets. 152 | 153 | Substitute `` for the path to a network model file. Some examples are located under `RosENet/models` (e.g. `RosENet/models/resnet.py`) 154 | 155 | Substitute `` for an underscore-separated string of channel selectors. 156 | 157 | Substitute `` for a non-negative integer to be used as seed. 158 | 159 | 4. train 160 | ``` 161 | python3 -m RosENet -- train [] 162 | ``` 163 | 5. evaluate 164 | ``` 165 | python3 -m RosENet -- evaluate 166 | ``` 167 | 6.
predict 168 | ``` 169 | python3 -m RosENet -- predict 170 | ``` 171 | 172 | ### Preprocessed datasets 173 | 174 | The datasets for training and validation of RosENet have been published and are accesible in https://doi.org/10.5281/zenodo.4625486 175 | 176 | -------------------------------------------------------------------------------- /RosENet/voxelization/voxelizers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from htmd.molecule.molecule import Molecule 3 | from htmd.molecule.voxeldescriptors import getVoxelDescriptors 4 | from operator import itemgetter 5 | from itertools import compress 6 | from prody import parsePDB, calcCenter, moveAtoms 7 | from mendeleev import element 8 | from types import SimpleNamespace 9 | from RosENet.preprocessing.make_pdbqt import MakePDBQT 10 | from RosENet.preprocessing.compute_rosetta_energy import ComputeRosettaEnergy 11 | from .utils import * 12 | from .filter import voxel_filter 13 | from .interpolation import voxel_interpolation 14 | 15 | 16 | 17 | class Voxelizer: 18 | """Voxelizer abstract class. 19 | Voxelizers implement the algorithms that generate 3D voxelized images of 20 | structural attributes. 21 | 22 | Parameters 23 | ---------- 24 | size : int 25 | Size of each side of the 3D image, represented as a cube of voxels. 26 | 27 | Attributes 28 | ---------- 29 | size : int 30 | Size of each side of the 3D image, represented as a cube of voxels. 31 | protein : SimpleNamespace 32 | Storage object for protein-specific values used during computation and 33 | protein-specific images. 34 | ligand : SimpleNamespace 35 | Storage object for ligand-specific values used during computation and 36 | ligand-specific images. 37 | image : numpy.ndarray 38 | 3D image of the structure obtained as a result of the voxelization. 
39 | 40 | """ 41 | 42 | def __init__(self, size): 43 | self.protein = type('', (), {}) 44 | self.ligand = type('', (), {}) 45 | self.image = None 46 | self.size = size 47 | 48 | def voxelize(self): 49 | """Compute the voxelized 3D image and store it in self.image""" 50 | pass 51 | 52 | 53 | class PointwiseVoxelizer(Voxelizer): 54 | """Base voxelizer class implementing the methods to generate 3D images given 55 | pointwise-values assigned to each atom. 56 | 57 | See Also 58 | -------- 59 | RosettaVoxelizer : Voxelizer for the Rosetta energy-function features. 60 | ElectronegativityVoxelizer : Voxelizer for the electronegativity features. 61 | 62 | Parameters 63 | ---------- 64 | complex_path : str or os.PathLike 65 | Path to the complex structure, a .pdb file. 66 | size : int 67 | Size of each side of the 3D image, represented as a cube of voxels. 68 | method_type : str 69 | Name of the method to voxelize the pointwise-values. Accepted values are 70 | "filter" or "interpolation". 71 | method_fn : str or callable 72 | If method_type == "filter", a function handle with two parameters that 73 | implements the filter to be applied. If method_type == "interpolation", 74 | a string or function handle suitable for scipy.interpolate.Rbf's 75 | function. 76 | """ 77 | 78 | def __init__(self, complex_path, data_object, size, method_type, method_fn): 79 | super(PointwiseVoxelizer, self).__init__(size) 80 | self.path = str(complex_path) 81 | self.data_object = data_object 82 | self.method_type = method_type 83 | self.method_fn = method_fn 84 | 85 | def _prepare_points(self): 86 | """Load structures and compute the location of the points of the 87 | 3D image to be generated. 
88 | """ 89 | self.complex = parsePDB(self.path) 90 | protein = self.complex.select("not (resname WER or water)") 91 | ligand = self.complex.select("resname WER") 92 | center = calcCenter(ligand.getCoords()) 93 | moveAtoms(self.complex, by=-center) 94 | center = calcCenter(self.complex.select("resname WER").getCoords()) 95 | self.protein.structure = protein 96 | self.ligand.structure = ligand 97 | self.points = grid_around(center, self.size, spacing=24/(self.size-1)) 98 | 99 | def _load_attributes(self): 100 | """Load the computed energies, radii and charges for the atoms in the 101 | complex. These may not include all atoms, but only the ones around 20 A 102 | around the center of mass of the ligand.""" 103 | data = self.data_object.read() 104 | self.radii = dict(zip(data['rc_keys'], data['radius_values'])) 105 | self.charges = dict(zip(data['rc_keys'], data['charge_values'])) 106 | self.energies = data['energy_values'].squeeze() 107 | self.energy_keys = data['energy_keys'] 108 | self.energy_dict = dict(zip(self.energy_keys, self.energies)) 109 | 110 | def _apply_filter(self): 111 | """Apply the filter to generate the voxelized images, for both protein 112 | and ligand separately.""" 113 | self.protein.image = voxel_filter( 114 | self.method_fn, self.protein, self.points)\ 115 | .reshape(3*(self.size,) + (-1,)) 116 | self.ligand.image = voxel_filter( 117 | self.method_fn, self.ligand, self.points)\ 118 | .reshape(3*(self.size,) + (-1,)) 119 | 120 | def _apply_interpolation(self): 121 | """Apply the interpolation method to generate the voxelized images, for 122 | both protein and ligand separately.""" 123 | self.protein.image = voxel_interpolation( 124 | self.method_fn, self.protein, self.points)\ 125 | .reshape(3*(self.size,) + (-1,)) 126 | self.ligand.image = voxel_interpolation( 127 | self.method_fn, self.ligand, self.points)\ 128 | .reshape(3*(self.size,) + (-1,)) 129 | 130 | def _prepare_attributes(self, obj): 131 | """Special method for loading and masking 
the atom coordinates, names, 132 | radii and charges, masking the selection to only the atoms that have 133 | energies computed. 134 | 135 | Notes 136 | ----- 137 | 138 | This method is only used to load the attributes of self.protein 139 | and self.ligand. 140 | 141 | Parameters 142 | ---------- 143 | obj : SimpleNamespace 144 | """ 145 | try: 146 | obj.keys = get_keys(obj.structure) 147 | obj.atoms_in_scope = [x in self.energy_keys for x in obj.keys] 148 | obj.coordinates = np.compress(obj.atoms_in_scope, obj.structure.getCoords(), axis=0).reshape((-1, 3)) 149 | obj.keys = list(compress(obj.keys, obj.atoms_in_scope)) 150 | obj.radii = np.array(itemgetter(*obj.keys) 151 | (self.radii)).reshape((-1, 1)) 152 | obj.charges = np.array(itemgetter(*obj.keys) 153 | (self.charges)).reshape((-1, 1)) 154 | except Exception as e: 155 | print(self.path) 156 | print("#"*81) 157 | raise e 158 | 159 | def _prepare_values(self): 160 | """Voxelizer specific value loading.""" 161 | pass 162 | 163 | def _normalize(self): 164 | """Voxelizer specific after-voxelization normalization.""" 165 | pass 166 | 167 | def _merge(self): 168 | """Merges protein and ligand 3D images into the resulting self.image.""" 169 | self.image = np.concatenate( 170 | (self.protein.image, self.ligand.image), axis=-1) 171 | 172 | def voxelize(self): 173 | """Main voxelization method, overrides Voxelizer's voxelize method. 
174 | Implements the steps to perform a pointwise-valued voxelization.""" 175 | self._load_attributes() 176 | self._prepare_points() 177 | self._prepare_attributes(self.protein) 178 | self._prepare_attributes(self.ligand) 179 | self._prepare_values() 180 | if self.method_type == "filter": 181 | self._apply_filter() 182 | elif self.method_type == "interpolation": 183 | self._apply_interpolation() 184 | self._normalize() 185 | self._merge() 186 | 187 | class ElectronegativityVoxelizer(PointwiseVoxelizer): 188 | """Voxelizer class to generate 3D images given per-atom electronegativity 189 | values. 190 | 191 | See Also 192 | -------- 193 | RosettaVoxelizer : Voxelizer for the Rosetta energy-function features. 194 | PointwiseVoxelizer : Base voxelizer for pointwise features. 195 | """ 196 | def _prepare_values(self): 197 | """Sets up the electronegativities as values for the voxelization, 198 | for both protein and ligand.""" 199 | elements = set(self.complex.getElements()) 200 | el_dict = {name: element( 201 | name.capitalize()).en_pauling for name in elements} 202 | self.protein.values = np.array(list(compress((el_dict[name] for name in self.protein.structure.getElements()), self.protein.atoms_in_scope)))\ 203 | .reshape((-1, 1)) 204 | self.ligand.values = np.array(list(compress([el_dict[name] for name in self.ligand.structure.getElements()], self.ligand.atoms_in_scope)))\ 205 | .reshape((-1, 1)) 206 | 207 | 208 | def _normalize(self): 209 | """Normalizes the electronegativities by dividing by the highest values 210 | in the protein's and ligand's electronegativies.""" 211 | self.protein.image = self.protein.image / 3.44 212 | self.ligand.image = self.ligand.image / 3.98 213 | 214 | 215 | class RosettaVoxelizer(PointwiseVoxelizer): 216 | """Voxelizer class to generate 3D images given per-atom Rosetta 217 | energy-function values. 218 | 219 | See Also 220 | -------- 221 | ElectronegativityVoxelizer : Voxelizer for electronegativity features. 
222 | PointwiseVoxelizer : Base voxelizer for pointwise features. 223 | """ 224 | def _prepare_values(self): 225 | """Prepare energy values for voxelization""" 226 | self.protein.values = np.array(itemgetter( 227 | *self.protein.keys)(self.energy_dict)).reshape((-1, 4)) 228 | self.ligand.values = np.array(itemgetter( 229 | *self.ligand.keys)(self.energy_dict)).reshape((-1, 4)) 230 | 231 | def _normalize(self): 232 | """Normalize the energy maps to 0-1 range and split positive and 233 | negative maps.""" 234 | protein_limits = [-2.27475644, 4.11150004e+02, 235 | 4.01129569e+00, 1.28136470e+00, 236 | -0.37756078, -3.79981632] 237 | ligand_limits = [-1.7244696, 0.59311066, 2.67294434, 238 | 0.40072521, -0.44943017, -2.00621753] 239 | self.protein.image = np.concatenate((self.protein.image, 240 | self.protein.image[...,2:]),axis=-1) 241 | self.ligand.image = np.concatenate((self.ligand.image, 242 | self.ligand.image[...,2:]),axis=-1) 243 | self.protein.image = clip(self.protein.image, protein_limits) 244 | self.ligand.image = clip(self.ligand.image, ligand_limits) 245 | 246 | 247 | class HTMDVoxelizer(Voxelizer): 248 | """Base voxelizer class implementing the methods to generate 3D images given 249 | pointwise-values assigned to each atom. 250 | 251 | Parameters 252 | ---------- 253 | protein_path : str or os.PathLike 254 | Path to the protein structure, a .pdb file. 255 | ligand_path : str or os.PathLike 256 | Path to the ligand structure, a .pdb file. 257 | size : int 258 | Size of each side of the 3D image, represented as a cube of voxels. 259 | """ 260 | def __init__(self, protein_path, ligand_path, size): 261 | super(HTMDVoxelizer, self).__init__(size) 262 | self.protein.path = protein_path 263 | self.ligand.path = ligand_path 264 | 265 | def _prepare_points(self): 266 | """Load structures and compute the location of the points of the 267 | 3D image to be generated. 
268 | """ 269 | protein = Molecule(str(self.protein.path)) 270 | ligand = Molecule(str(self.ligand.path)) 271 | protein.filter( 272 | 'not (water or name CO or name NI or name CU or name NA)') 273 | center = np.mean(ligand.get('coords'), axis=0) 274 | ligand.moveBy(-center) 275 | protein.moveBy(-center) 276 | center = np.mean(ligand.get('coords'), axis=0) 277 | self.protein.structure = protein 278 | self.ligand.structure = ligand 279 | self.points = grid_around(center, self.size, spacing=24/(self.size-1)).reshape((-1,3)) 280 | 281 | 282 | def voxelize(self): 283 | """Compute the voxelized 3D image and store it in self.image""" 284 | self._prepare_points() 285 | self.protein.image = getVoxelDescriptors( 286 | self.protein.structure, usercenters=self.points)[0]\ 287 | .reshape(3*(self.size,) + (-1,)) 288 | self.ligand.image = getVoxelDescriptors( 289 | self.ligand.structure, usercenters=self.points)[0]\ 290 | .reshape(3*(self.size,) + (-1,)) 291 | self.image = np.concatenate((self.protein.image, self.ligand.image), axis=-1) 292 | 293 | def VoxelizeRosetta(pdb_object, method, size): 294 | """Wrapper function for voxelizing and storing Rosetta features. 295 | 296 | pdb_object : PDBObject 297 | PDB structure to voxelize. 298 | method : tuple of (string, string or callable) 299 | Voxelization method and function. See underlying function for more information. 300 | size : int 301 | Size of voxel cube side. 
302 | """ 303 | if not ComputeRosettaEnergy.computed(pdb_object): 304 | return False 305 | complex_path = pdb_object.minimized.complex.pdb.path 306 | data_object = pdb_object.minimized.complex.attr 307 | output_path = pdb_object.image.rosetta.path 308 | if output_path.exists(): 309 | return True 310 | method_type, method_fn = method 311 | voxelizer = RosettaVoxelizer(complex_path, data_object, size, method_type, method_fn) 312 | voxelizer.voxelize() 313 | pdb_object.image.rosetta.write(voxelizer.image) 314 | return True 315 | 316 | def VoxelizeElectronegativity(pdb_object, method, size): 317 | """Wrapper function for voxelizing and storing electronegativity features. 318 | 319 | pdb_object : PDBObject 320 | PDB structure to voxelize. 321 | method : tuple of (string, string or callable) 322 | Voxelization method and function. See underlying function for more information. 323 | size : int 324 | Size of voxel cube side. 325 | """ 326 | if not ComputeRosettaEnergy.computed(pdb_object): 327 | return False 328 | complex_path = pdb_object.minimized.complex.pdb.path 329 | data_object = pdb_object.minimized.complex.attr 330 | output_path = pdb_object.image.electronegativity.path 331 | if output_path.exists(): 332 | return True 333 | method_type, method_fn = method 334 | voxelizer = ElectronegativityVoxelizer(complex_path, data_object, size, method_type, method_fn) 335 | voxelizer.voxelize() 336 | pdb_object.image.electronegativity.write(voxelizer.image) 337 | 338 | def VoxelizeHTMD(pdb_object, size): 339 | """Wrapper function for voxelizing and storing HTMD features. 340 | 341 | pdb_object : PDBObject 342 | PDB structure to voxelize. 343 | size : int 344 | Size of voxel cube side. 
345 | """ 346 | if not MakePDBQT.computed(pdb_object): 347 | return False 348 | protein_path = pdb_object.minimized.protein.pdbqt.path 349 | ligand_path = pdb_object.minimized.ligand.pdbqt.path 350 | output_path = pdb_object.image.htmd.path 351 | if output_path.exists(): 352 | return True 353 | voxelizer = HTMDVoxelizer(protein_path, ligand_path, size) 354 | voxelizer.voxelize() 355 | pdb_object.image.htmd.write(voxelizer.image) 356 | --------------------------------------------------------------------------------