├── RosENet
│   ├── __init__.py
│   ├── network
│   │   ├── evaluate.py
│   │   ├── predict.py
│   │   ├── network.py
│   │   ├── input.py
│   │   └── utils.py
│   ├── rosetta
│   │   ├── __init__.py
│   │   └── rosetta.py
│   ├── __main__.py
│   ├── storage
│   │   ├── .ropeproject
│   │   │   ├── history
│   │   │   ├── objectdb
│   │   │   ├── globalnames
│   │   │   └── config.py
│   │   └── storage.py
│   ├── preprocessing
│   │   ├── .ropeproject
│   │   │   ├── history
│   │   │   ├── objectdb
│   │   │   ├── globalnames
│   │   │   └── config.py
│   │   ├── __init__.py
│   │   ├── make_protein_pdb.py
│   │   ├── make_pdbqt.py
│   │   ├── make_ligand_params_pdb.py
│   │   ├── make_ligand_mol2.py
│   │   ├── make_ligand_mol2_renamed.py
│   │   ├── step.py
│   │   ├── minimize_rosetta.py
│   │   ├── make_complex_pdb.py
│   │   ├── preprocessing.py
│   │   ├── preprocess_vina.py
│   │   └── compute_rosetta_energy.py
│   ├── postprocessing
│   │   ├── __init__.py
│   │   ├── postprocessing.py
│   │   └── .ropeproject
│   │       └── config.py
│   ├── voxelization
│   │   ├── apbs.sh
│   │   ├── __init__.py
│   │   ├── apbs.in
│   │   ├── interpolation.py
│   │   ├── filter.py
│   │   ├── utils.py
│   │   ├── apbs.py
│   │   └── voxelizers.py
│   ├── utils.py
│   ├── settings.py
│   ├── static
│   │   ├── flags_relax.txt
│   │   ├── relax.xml
│   │   ├── dock.xml
│   │   ├── dock_relax.xml
│   │   └── dock_relax2.xml
│   ├── objects
│   │   ├── dataset.py
│   │   ├── file.py
│   │   ├── pdb.py
│   │   └── model.py
│   ├── models
│   │   ├── kdeep.py
│   │   ├── large_kdeep.py
│   │   └── resnet.py
│   ├── constants.py
│   └── clui.py
├── test_dataset
│   ├── labels
│   └── 10gs
│       └── 10gs_ligand.mol2
├── .gitignore
├── logo.png
├── used_pdbs.pdf
├── pip-requirements.txt
├── install.sh
├── clear_dataset.sh
├── conda-requirements.txt
└── README.md
/RosENet/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/RosENet/network/evaluate.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/RosENet/network/predict.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/RosENet/rosetta/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test_dataset/labels:
--------------------------------------------------------------------------------
1 | 10gs 1
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/*.pyc
2 | htmd/
3 | pyrosetta/
4 | __pycache__/
5 |
--------------------------------------------------------------------------------
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/logo.png
--------------------------------------------------------------------------------
/used_pdbs.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/used_pdbs.pdf
--------------------------------------------------------------------------------
/RosENet/__main__.py:
--------------------------------------------------------------------------------
1 | #from . import testing
2 | from .clui import parse_arguments
3 |
4 | parse_arguments()
5 |
--------------------------------------------------------------------------------
/RosENet/storage/.ropeproject/history:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/RosENet/storage/.ropeproject/history
--------------------------------------------------------------------------------
/RosENet/storage/.ropeproject/objectdb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/RosENet/storage/.ropeproject/objectdb
--------------------------------------------------------------------------------
/RosENet/storage/.ropeproject/globalnames:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/RosENet/storage/.ropeproject/globalnames
--------------------------------------------------------------------------------
/RosENet/preprocessing/.ropeproject/history:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/RosENet/preprocessing/.ropeproject/history
--------------------------------------------------------------------------------
/RosENet/preprocessing/.ropeproject/objectdb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/RosENet/preprocessing/.ropeproject/objectdb
--------------------------------------------------------------------------------
/RosENet/preprocessing/.ropeproject/globalnames:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DS3Lab/RosENet/HEAD/RosENet/preprocessing/.ropeproject/globalnames
--------------------------------------------------------------------------------
/RosENet/postprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | from .postprocessing import generate_tfrecords
2 |
3 | def postprocess(dataset_object):
4 | generate_tfrecords(dataset_object)
5 |
--------------------------------------------------------------------------------
/RosENet/voxelization/apbs.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | module load gcc/5.2.0
3 | OLD=$LD_LIBRARY_PATH
4 | export LD_LIBRARY_PATH="$SCRATCH/apbs/lib:$LD_LIBRARY_PATH"
5 | $SCRATCH/apbs/bin/apbs "$@" > /dev/null 2>&1
6 |
7 | export LD_LIBRARY_PATH=$OLD
8 |
--------------------------------------------------------------------------------
/RosENet/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | from .compute_rosetta_energy import ComputeRosettaEnergy
2 | from .make_pdbqt import MakePDBQT
3 |
4 | def preprocess(pdb_object):
5 |     ComputeRosettaEnergy.run_until(pdb_object, callbacks=[print])
6 |     MakePDBQT.run_until(pdb_object, callbacks=[print])
7 |
--------------------------------------------------------------------------------
/RosENet/utils.py:
--------------------------------------------------------------------------------
1 | def message_callbacks(callbacks, message):
2 | """Utility function to broadcast messages to a list of callbacks.
3 |
4 | callbacks : list of callable
5 | List of callbacks
6 | message : various
7 | Message to broadcast
8 | """
9 | for callback in callbacks:
10 | callback(message)
11 |
--------------------------------------------------------------------------------
/pip-requirements.txt:
--------------------------------------------------------------------------------
1 | colorama==0.4.3
2 | networkx==2.4
3 | mdtraj==1.9.4
4 | periodictable==1.5.2
5 | tqdm==4.47.0
6 | htmd-pdb2pqr==0.0.2
7 | pyfiglet==0.8.post1
8 | natsort==7.0.1
9 | python-dateutil==2.8.1
10 | decorator==4.4.2
11 | mendeleev==0.6.0
12 | pandas==0.25.3
13 | propka==3.2.0
14 | sqlalchemy==1.3.18
15 | pytz==2020.1
16 | numba==0.50.1
17 | prody==1.10.8
18 | moleculekit==0.3.2
19 | llvmlite==0.33.0
20 | biopython==1.77
21 |
--------------------------------------------------------------------------------
/RosENet/settings.py:
--------------------------------------------------------------------------------
1 | """
2 | Changeable settings for the project.
3 | """
4 | from RosENet.voxelization.filter import exp_12
5 | size = 25
6 | voxelization_type = "filter"
7 | voxelization_fn = "exp_12"
8 | voxelization = ("filter", exp_12)
9 | options = f"_{voxelization_type}_{voxelization_fn}_{size}"
10 | max_epochs = 300
11 | parallel_calls = 20
12 | shuffle_buffer_size = 1000
13 | batch_size = 128
14 | prefetch_buffer_size = 10
15 | rotate = True
16 |
17 |
--------------------------------------------------------------------------------
/RosENet/voxelization/__init__.py:
--------------------------------------------------------------------------------
1 | from .voxelizers import VoxelizeHTMD, VoxelizeRosetta, VoxelizeElectronegativity
2 | from RosENet.postprocessing.postprocessing import combine_maps
3 | from RosENet import settings
4 |
5 | def voxelize(pdb_object):
6 | try:
7 | VoxelizeHTMD(pdb_object, settings.size)
8 | VoxelizeRosetta(pdb_object, settings.voxelization, settings.size)
9 | VoxelizeElectronegativity(pdb_object, settings.voxelization, settings.size)
10 | combine_maps(pdb_object)
11 |     except Exception as e:
12 |         # Skip structures that fail to voxelize, but report the failure
13 |         # instead of silently swallowing it.
14 |         print(f"voxelization failed for {pdb_object.id}: {e}")
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | SCRIPTPATH="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
3 | cd "$SCRIPTPATH"
4 |
5 | conda remove --name rosenet --all -y || true
6 | conda create --name rosenet python=3.7 --file conda-requirements.txt -y
7 | conda init bash
8 | conda activate rosenet && (yes | pip install -r pip-requirements.txt)
9 |
10 | rm -rf ./htmd
11 | wget -qO- https://anaconda.org/acellera/HTMD/1.13.10/download/linux-64/htmd-1.13.10-py36_0.tar.bz2 | tar -xvjf - lib/python3.6/site-packages/htmd
12 | mv lib/python3.6/site-packages/htmd/ .
13 | rm -r lib
14 |
15 | echo "Copy the pyrosetta folder in $SCRIPTPATH or add pyrosetta to PYTHONPATH"
16 |
--------------------------------------------------------------------------------
/RosENet/static/flags_relax.txt:
--------------------------------------------------------------------------------
1 | -in
2 | -file
3 | -s ./$complex
4 | -extra_res_fa ./$params
5 | -packing
6 | -ex1
7 | -ex1aro
8 | -ex2
9 | -mute core.util.prof ## don't show timing info
10 | -mute core.io.database
11 | #-restore_pre_talaris_2013_behavior
12 | #-mute all
13 | #-unmute protocols.jd2.JobDistributor
14 | -constraints:cst_fa_file ./constraints
15 | -score:set_weights atom_pair_constraint 1.0
16 | -in:auto_setup_metals
17 | -in:metals_angle_constraint_multiplier 3.0
18 | -in:metals_distance_constraint_multiplier 3.0
19 | -in:ignore_waters false
20 | -ignore_zero_occupancy false
21 | -keep_input_protonation_state true
22 | -nstruct 50
23 | -overwrite
24 |
--------------------------------------------------------------------------------
/clear_dataset.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | DATASET_PATH=$1
4 |
5 | find "$DATASET_PATH" -type f -name "metadata.json" | xargs -r rm
6 | find "$DATASET_PATH" -type f -name "*_complex*" | xargs -r rm
7 | find "$DATASET_PATH" -type f -name "*_ligand_*.*" | xargs -r rm
8 | find "$DATASET_PATH" -type f -name "*_protein_*.*" | xargs -r rm
9 | find "$DATASET_PATH" -type f -name "*_ligand.pdb" | xargs -r rm
10 | find "$DATASET_PATH" -type f -name "*_ligand.params" | xargs -r rm
11 | find "$DATASET_PATH" -type f -name "constraints" | xargs -r rm
12 | find "$DATASET_PATH" -type f -name "flags_relax.txt" | xargs -r rm
13 | find "$DATASET_PATH" -type f -name "score.sc" | xargs -r rm
14 | find "$DATASET_PATH" -type d -name "other_complexes" | xargs -r rm -r
15 |
16 |
--------------------------------------------------------------------------------
/RosENet/static/relax.xml:
--------------------------------------------------------------------------------
[content lost in extraction: the RosettaScripts XML tags of relax.xml were stripped from this dump and are not recoverable]
--------------------------------------------------------------------------------
/RosENet/preprocessing/make_protein_pdb.py:
--------------------------------------------------------------------------------
1 | from RosENet.preprocessing.step import Step
2 | from RosENet.preprocessing.minimize_rosetta import MinimizeRosetta
3 | import RosENet.constants as constants
4 | import RosENet.rosetta.rosetta as rosetta
5 |
6 | class MakeProteinPDB(metaclass=Step, requirements=[MinimizeRosetta]):
7 | @classmethod
8 | def files(cls, pdb_object):
9 | """List of files being created
10 |
11 | pdb_object : PDBObject
12 | PDB structure being handled
13 | """
14 | return [pdb_object.minimized.protein.pdb]
15 |
16 | @classmethod
17 | def _run(cls, pdb_object):
18 | """Inner function for the preprocessing step.
19 |
20 | pdb_object : PDBObject
21 | PDB structure being handled
22 | """
23 | complex = pdb_object.minimized.complex.pdb.read()
24 | protein = complex.select(constants.protein_selector)
25 | pdb_object.minimized.protein.pdb.write(protein)
26 |
27 |
--------------------------------------------------------------------------------
/RosENet/preprocessing/make_pdbqt.py:
--------------------------------------------------------------------------------
1 | from RosENet.preprocessing.step import Step
2 | from RosENet.preprocessing.make_protein_pdb import MakeProteinPDB
3 | from RosENet.preprocessing.make_ligand_mol2 import MakeLigandMOL2
4 | import RosENet.constants as constants
5 | import subprocess
6 |
7 | class MakePDBQT(metaclass=Step, requirements=[MakeProteinPDB, MakeLigandMOL2]):
8 | @classmethod
9 | def files(cls, pdb_object):
10 | """List of files being created
11 |
12 | pdb_object : PDBObject
13 | PDB structure being handled
14 | """
15 | return [pdb_object.minimized.ligand.pdbqt,
16 | pdb_object.minimized.protein.pdbqt]
17 |
18 | @classmethod
19 | def _run(cls, pdb_object):
20 | """Inner function for the preprocessing step.
21 |
22 | pdb_object : PDBObject
23 | PDB structure being handled
24 | """
25 | subprocess.run([constants.mgl_python_path,
26 | constants.preprocess_vina_path,
27 | pdb_object.minimized.protein.pdb.path,
28 | pdb_object.minimized.ligand.mol2.path])
29 |
30 |
--------------------------------------------------------------------------------
/RosENet/preprocessing/make_ligand_params_pdb.py:
--------------------------------------------------------------------------------
1 | from RosENet.preprocessing.step import Step
2 | import RosENet.rosetta.rosetta as rosetta
3 |
4 | class MakeLigandParamsPDB(metaclass=Step):
5 | """Preprocessing step that create the ligand.params and ligand.pdb files at
6 | the beginning of the pipeline."""
7 |
8 | @classmethod
9 | def files(cls, pdb_object):
10 | """List of files being created
11 |
12 | pdb_object : PDBObject
13 | PDB structure being handled
14 | """
15 | return [pdb_object.ligand.params,
16 | pdb_object.ligand.pdb]
17 |
18 | @classmethod
19 | def _run(cls, pdb_object):
20 | """Inner function for the preprocessing step.
21 |
22 | pdb_object : PDBObject
23 | PDB structure being handled
24 | """
25 | ligand_mol2_path = pdb_object.ligand.mol2.path
26 | params_filename = ligand_mol2_path.stem
27 | working_directory = ligand_mol2_path.parent
28 | return rosetta.molfile_to_params(
29 | working_directory = working_directory,
30 | output_path = params_filename,
31 | input_path = ligand_mol2_path)
32 |
33 |
--------------------------------------------------------------------------------
/RosENet/preprocessing/make_ligand_mol2.py:
--------------------------------------------------------------------------------
1 | from RosENet.preprocessing.step import Step
2 | import RosENet.rosetta.rosetta as rosetta
3 | from RosENet.preprocessing.make_ligand_mol2_renamed import MakeLigandMOL2Renamed
4 | from RosENet.preprocessing.minimize_rosetta import MinimizeRosetta
5 |
6 | class MakeLigandMOL2(metaclass=Step,requirements=[MakeLigandMOL2Renamed, MinimizeRosetta]):
7 | @classmethod
8 | def files(cls, pdb_object):
9 | """List of files being created
10 |
11 | pdb_object : PDBObject
12 | PDB structure being handled
13 | """
14 | return [pdb_object.minimized.ligand.mol2]
15 |
16 |
17 | @classmethod
18 | def _run(cls, pdb_object):
19 | """Inner function for the preprocessing step.
20 |
21 | pdb_object : PDBObject
22 | PDB structure being handled
23 | """
24 | complex_path = pdb_object.minimized.complex.pdb.path
25 | renamed_mol2_path = pdb_object.ligand.renamed_mol2.path
26 | ligand_path = pdb_object.minimized.ligand.mol2.path
27 | rosetta.pdb_to_molfile(
28 | mol = renamed_mol2_path,
29 | complex_path = complex_path,
30 | output_path = ligand_path)
31 |
--------------------------------------------------------------------------------
/RosENet/voxelization/apbs.in:
--------------------------------------------------------------------------------
1 | # READ IN MOLECULES
2 | read
3 | mol pqr XXX.pqr
4 | end
5 |
6 |
7 | elec # Electrostatics calculation on the solvated state
8 | mg-auto # Specify the mode for APBS to run
9 | dime GRID_DIM # The grid dimensions
10 | grid GRID_SPACE # Grid spacing
11 | gcent INH_CENTER # Center the grid
12 | cglen CG_LEN
13 | cgcent INH_CENTER
14 | fglen FG_LEN
15 | fgcent INH_CENTER
16 | mol 1 # Perform the calculation on molecule 1
17 | lpbe # Solve the linearized Poisson-Boltzmann
18 | # equation
19 | bcfl mdh # Use all multipole moments when calculating the
20 | # potential
21 | ion 1 0.150 2.0
22 | ion -1 0.150 2.0
23 | pdie 1.0 # Solute dielectric
24 | sdie 78.54 # Solvent dielectric
25 | chgm spl2 # Spline-based discretization of the delta
26 | # functions
27 | srfm smol # Molecular surface definition
28 | srad 1.4 # Solvent probe radius (for molecular surface)
29 | swin 0.3 # Solvent surface spline window (not used here)
30 | sdens 10.0 # Sphere density of accessibility object
31 | temp 298.15 # Temperature
32 | calcenergy no # Calculate energies
33 | calcforce no # Do not calculate forces
34 | write pot dx XXX_potential # Write out the potential
35 | end
36 | quit
37 |
--------------------------------------------------------------------------------
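Note (not part of the repository): the uppercase tokens XXX, GRID_DIM, GRID_SPACE, INH_CENTER, CG_LEN and FG_LEN are per-structure placeholders, presumably substituted by apbs.py (not shown in this dump) before APBS is invoked. A minimal sketch of the substitution, with made-up values:

    # Hypothetical example: fill the apbs.in template for one structure.
    template = open("RosENet/voxelization/apbs.in").read()
    filled = (template
              .replace("XXX", "10gs_protein")          # .pqr basename
              .replace("GRID_DIM", "65 65 65")         # dime: grid points per axis
              .replace("GRID_SPACE", "1.0 1.0 1.0")    # grid: spacing in Angstrom
              .replace("INH_CENTER", "12.1 4.3 -8.7")  # gcent/cgcent/fgcent
              .replace("CG_LEN", "80.0 80.0 80.0")     # coarse-grid lengths
              .replace("FG_LEN", "64.0 64.0 64.0"))    # fine-grid lengths
    open("10gs_apbs.in", "w").write(filled)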
/RosENet/voxelization/interpolation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.interpolate import Rbf
3 |
4 | def voxel_interpolation(int_type, structure, targets):
5 | """Method to apply an interpolation method to a set of atom coordinates with
6 | assigned values, and distribute them to a set of target points.
7 |
8 | Parameters
9 | ----------
10 | int_type : str or callable
11 | Interpolation name or function handle following scipy.interpolate.Rbf.
12 |     structure : object
13 |         Object with `coordinates` (atom positions) and `values` (per-atom
14 |         features to spatially distribute) array attributes; atoms farther
15 |         than 12.5*sqrt(3) from the origin are discarded.
16 | targets : numpy.ndarray
17 | 3D voxel positions to compute the distributed values at.
18 | """
19 | values = structure.values
20 | points = structure.coordinates
21 | mask = np.linalg.norm(points, axis=-1) <= 12.5*np.sqrt(3)
22 | points = points[mask]
23 | values = values[mask, :]
24 | shape = targets.shape
25 | targets = targets.reshape((-1,3))
26 | points_x, points_y, points_z = [c.flatten()
27 | for c in np.split(points, 3, axis=1)]
28 | targets_x, targets_y, targets_z = [
29 | c.flatten() for c in np.split(targets, 3, axis=1)]
30 | res = np.stack([Rbf(points_x, points_y, points_z,
31 | values[..., i], function=int_type)(targets_x, targets_y, targets_z)
32 | for i in range(values.shape[-1])], axis=-1)
33 |     res = res.reshape(shape[:-1] + (-1,))  # restore the voxel-grid shape saved before targets was flattened
34 | return res
35 |
--------------------------------------------------------------------------------
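Usage sketch (not part of the repository; the random inputs are made up, and `structure` only needs `coordinates` and `values` array attributes):

    import numpy as np
    from types import SimpleNamespace
    from RosENet.voxelization.interpolation import voxel_interpolation
    from RosENet.voxelization.utils import grid_around

    structure = SimpleNamespace(
        coordinates=np.random.randn(10, 3),  # 10 atoms near the origin
        values=np.random.rand(10, 2))        # two feature channels per atom
    targets = grid_around(np.zeros(3), 25)   # (25, 25, 25, 3) voxel centres
    image = voxel_interpolation("linear", structure, targets)
    assert image.shape == (25, 25, 25, 2)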
/RosENet/objects/dataset.py:
--------------------------------------------------------------------------------
1 | import RosENet.storage.storage as storage
2 | from RosENet.objects.pdb import PDBObject
3 | from RosENet.objects.file import TFRecordsFile, File
4 | import RosENet.settings as settings
5 |
6 | class _Dataset:
7 | """Inner dataset class. Represents a dataset (which stores PDB structure folders inside)."""
8 | _instance_dict = {}
9 | def __init__(self, path):
10 | self.path = path
11 | #self.metadata = storage.read_json(self.path / "metadata.json")
12 |
13 | def __getitem__(self, key):
14 | return PDBObject(self.path / key)
15 |
16 | def list(self):
17 | return [x.name for x in self.path.iterdir() if x.is_dir() and x.name[0] != "_"]
18 |
19 | @property
20 | def name(self):
21 | return f'_{self.path.stem}_{settings.voxelization_type}_{settings.voxelization_fn}_{settings.size}'
22 |
23 | @property
24 | def tfrecords(self):
25 | return self.path / self.name
26 |
27 | def tfrecord(self, i):
28 | return TFRecordsFile(f"{i}.tfrecords",self.tfrecords)
29 |
30 | @property
31 | def images(self):
32 | return [x for x in (self.path/settings.options).iterdir()]
33 |
34 | @property
35 | def labels(self):
36 | return File("labels", self.path)
37 |
38 | def model(self, model_object, channels, seed):
39 | return self.path / f"_{model_object.name}_{channels}_{seed}"
40 |
41 | def DatasetObject(path):
42 | if str(path.absolute()) not in _Dataset._instance_dict:
43 | _Dataset._instance_dict[str(path.absolute())] = _Dataset(path.absolute())
44 | return _Dataset._instance_dict[str(path.absolute())]
45 |
46 |
--------------------------------------------------------------------------------
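Usage sketch (not part of the repository), using the bundled test_dataset:

    from pathlib import Path
    from RosENet.objects.dataset import DatasetObject

    dataset = DatasetObject(Path("test_dataset"))
    names = dataset.list()        # structure folders, e.g. ['10gs']; '_'-prefixed dirs are skipped
    pdb = dataset["10gs"]         # PDBObject for one structure
    text = dataset.labels.read()  # raw contents of the labels file ("10gs 1")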
/conda-requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | _tflow_select=2.3.0=mkl
6 | absl-py=0.9.0=py37_0
7 | astor=0.8.0=py37_0
8 | biopython=1.77=py37h7b6447c_0
9 | blas=1.0=mkl
10 | c-ares=1.15.0=h7b6447c_1001
11 | ca-certificates=2020.6.24=0
12 | certifi=2020.6.20=py37_0
13 | gast=0.3.3=py_0
14 | grpcio=1.27.2=py37hf8bcb03_0
15 | h5py=2.10.0=py37hd6299e0_1
16 | hdf5=1.10.6=hb1b8bf9_0
17 | intel-openmp=2020.1=217
18 | keras-applications=1.0.8=py_0
19 | keras-preprocessing=1.1.0=py_1
20 | ld_impl_linux-64=2.33.1=h53a641e_7
21 | libedit=3.1.20191231=h7b6447c_0
22 | libffi=3.3=he6710b0_1
23 | libgcc-ng=9.1.0=hdf63c60_0
24 | libgfortran-ng=7.3.0=hdf63c60_0
25 | libprotobuf=3.12.3=hd408876_0
26 | libstdcxx-ng=9.1.0=hdf63c60_0
27 | markdown=3.1.1=py37_0
28 | mkl=2020.1=217
29 | mkl-service=2.3.0=py37he904b0f_0
30 | mkl_fft=1.1.0=py37h23d657b_0
31 | mkl_random=1.1.1=py37h0573a6f_0
32 | mock=4.0.2=py_0
33 | ncurses=6.2=he6710b0_1
34 | numpy=1.18.5=py37ha1c710e_0
35 | numpy-base=1.18.5=py37hde5b4d6_0
36 | openssl=1.1.1g=h7b6447c_0
37 | pip=20.1.1=py37_1
38 | protobuf=3.12.3=py37he6710b0_0
39 | pyparsing=2.4.7=py_0
40 | python=3.7.7=hcff3b4d_5
41 | readline=8.0=h7b6447c_0
42 | scipy=1.5.0=py37h0b6359f_0
43 | setuptools=47.3.1=py37_0
44 | six=1.15.0=py_0
45 | sqlite=3.32.3=h62c20be_0
46 | tensorboard=1.13.1=py37hf484d3e_0
47 | tensorflow=1.13.1=mkl_py37h54b294f_0
48 | tensorflow-base=1.13.1=mkl_py37h7ce6ba3_0
49 | tensorflow-estimator=1.13.0=py_0
50 | termcolor=1.1.0=py37_1
51 | tk=8.6.10=hbc83047_0
52 | werkzeug=1.0.1=py_0
53 | wheel=0.34.2=py37_0
54 | xz=5.2.5=h7b6447c_0
55 | zlib=1.2.11=h7b6447c_3
56 |
--------------------------------------------------------------------------------
/RosENet/static/dock.xml:
--------------------------------------------------------------------------------
[content lost in extraction: the RosettaScripts XML tags of dock.xml were stripped from this dump and are not recoverable]
--------------------------------------------------------------------------------
/RosENet/rosetta/rosetta.py:
--------------------------------------------------------------------------------
1 | """ Module wrapping the function calls to the Rosetta Commons software
2 | """
3 |
4 | import subprocess
5 | from collections import OrderedDict
6 | import RosENet.constants as constants
7 | from io import StringIO
8 | import os
9 | import pandas as pd
10 |
11 |
12 | def molfile_to_params(input_path, output_path, working_directory):
13 | new_env = os.environ.copy()
14 | # This should solve the relative import problem in Rosetta's script
15 | new_env["PYTHONPATH"] = f"{constants.rosetta.root}/main/source/scripts/python/public:{new_env.get('PYTHONPATH','')}"
16 | subprocess.run(["python2",
17 | constants.rosetta.molfile_to_params,
18 | "-n", constants.ligand_resname,
19 | "-p", output_path,
20 | "--conformers-in-one-file",
21 | "--keep-names",
22 | "--clobber",
23 | input_path],
24 | env=new_env,
25 | cwd=str(working_directory))
26 |
27 | def minimize(working_directory):
28 | subprocess.run([constants.rosetta.minimize,
29 | f"@{constants.flags_filename}",
30 | f"-parser:protocol",
31 | f'"{constants.relax_path}"',
32 | "-database", constants.rosetta.database],
33 | cwd=str(working_directory))
34 |
35 | def pdb_to_molfile(mol, complex_path, output_path):
36 | new_env = os.environ.copy()
37 | new_env["PYTHONPATH"] = f"{constants.rosetta.root}/main/source/scripts/python/public:{new_env.get('PYTHONPATH','')}"
38 | subprocess.run(" ".join(["python2",
39 | str(constants.rosetta.pdb_to_molfile),
40 | str(mol),
41 | str(complex_path),
42 | ">",
43 | str(output_path)]),
44 | cwd=str(constants.rosetta.py_wd),
45 | env=new_env,
46 | shell=True)
47 |
48 | def parse_scores(scores_text):
49 | """Parse a score.sc file to a dictionary of id to score value ordered in increasing order.
50 |
51 | scores_text : string
52 | Content of the score.sc file
53 | """
54 | scores = StringIO(scores_text)
55 | csv = pd.read_csv(scores, header=1, sep=r"\s+", usecols=['total_score', 'description'])
56 | csv.sort_values("total_score", inplace=True)
57 | return OrderedDict(zip([name.split("_")[-1] for name in csv.description], csv.total_score))
58 |
59 |
--------------------------------------------------------------------------------
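Usage sketch for parse_scores (not part of the repository; the score.sc snippet is hypothetical but follows Rosetta's whitespace-delimited format, whose second line is the header row):

    from RosENet.rosetta.rosetta import parse_scores

    scores_text = (
        "SEQUENCE:\n"
        "SCORE: total_score description\n"
        "SCORE:     -512.3  10gs_complex_0002\n"
        "SCORE:     -498.7  10gs_complex_0001\n")
    scores = parse_scores(scores_text)
    best = next(iter(scores))  # '0002': ids are ordered by increasing total_score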
/RosENet/models/kdeep.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | LEARNING_RATE = 0.0001
4 |
5 | def fire_module(net, squeeze, expand, training):
6 | net = tf.layers.conv3d(net, squeeze, [1,1,1], activation=tf.nn.relu)
7 | net1 = tf.layers.conv3d(net, expand, [1,1,1], activation=tf.nn.relu)
8 | net2 = tf.layers.conv3d(net, expand, [3,3,3], padding='same', activation=tf.nn.relu)
9 | return tf.concat(axis=-1, values=[net1,net2])
10 |
11 | def conv_net(X, reuse, training):
12 | with tf.variable_scope('SqueezeNet', reuse=reuse):
13 | net = tf.layers.conv3d(X, 96, 1, 2, padding='same', activation=tf.nn.relu)
14 | net = fire_module(net, 16, 64,training=training)
15 | net = fire_module(net, 16, 64,training=training)
16 | net = fire_module(net, 32, 128,training=training)
17 | net = tf.layers.max_pooling3d(net, 3, 2)
18 | net = fire_module(net, 32, 128,training=training)
19 | net = fire_module(net, 48, 192,training=training)
20 | net = fire_module(net, 48, 192,training=training)
21 | net = fire_module(net, 64, 256,training=training)
22 | net = tf.layers.average_pooling3d(net, 3, 2)
23 | net = tf.layers.flatten(net)
24 | net = tf.layers.dense(net, 1)
25 | return net
26 |
27 | def model_fn(features, labels, mode, rotate):
28 | training = mode == tf.estimator.ModeKeys.TRAIN
29 | predictions = conv_net(features, reuse=tf.AUTO_REUSE, training=training)
30 | loss = None
31 | train_op = None
32 | if training:
33 | loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
34 | optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
35 | train_op = optimizer.minimize(loss,global_step=tf.train.get_global_step())
36 | else:
37 | if rotate:
38 | #When evaluating or predicting, average over the 24 90º rotations if rotate is enabled.
39 | predictions = tf.reduce_mean(tf.reshape(predictions,(-1,24,1)),axis=1)
40 | loss = tf.losses.mean_squared_error(labels,predictions)
41 | else:
42 | loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
43 |
44 | return tf.estimator.EstimatorSpec(
45 | mode=mode,
46 | predictions=predictions,
47 | loss=loss,
48 | train_op=train_op)
49 |
--------------------------------------------------------------------------------
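Note (not part of the repository): model_fn takes a rotate argument on top of the standard (features, labels, mode) signature, so it presumably gets bound before being handed to the Estimator by the network/ module, which is not shown in this dump. A sketch of one way to wire it, assuming functools.partial:

    import functools
    import tensorflow as tf
    from RosENet.models import kdeep
    from RosENet import settings

    # Bind rotate so the Estimator sees the standard model_fn signature.
    model_fn = functools.partial(kdeep.model_fn, rotate=settings.rotate)
    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="/tmp/kdeep")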
/RosENet/storage/storage.py:
--------------------------------------------------------------------------------
1 | """
2 | Storage module to centralize and abstract most of the IO operations in the project.
3 | """
4 |
5 | from prody import parsePDB, writePDB
6 | import shutil
7 | import h5py
8 | import json
9 | import numpy as np
10 | import tensorflow as tf
11 |
12 | def delete(path):
13 | if path.exists():
14 | if path.is_dir():
15 | shutil.rmtree(path)
16 | else:
17 | path.unlink()
18 |
19 | def clear_directory(path, no_fail=True):
20 |     if path.exists():  # actually clear any previous contents
21 |         shutil.rmtree(path)
22 |     make_directory(path, no_fail)
23 |
24 | def make_directory(path, no_fail=True):
25 | path.mkdir(exist_ok = no_fail, parents=True)
26 |
27 | def move(origin, destination, no_fail=False):
28 | if origin.exists() or not no_fail:
29 | if destination.is_dir():
30 | destination = destination / origin.name
31 | origin.rename(destination)
32 |
33 | def read_image(image_path):
34 | with h5py.File(str(image_path), "r") as f:
35 | return np.array(f["grid"])
36 |
37 | def write_image(image_path, image):
38 | with h5py.File(str(image_path), "w", libver='latest') as f:
39 | f.create_dataset("grid", dtype='f4', data=image)
40 |
41 | def read_json(json_path):
42 | with json_path.open("r") as f:
43 | data = json.load(f)
44 | return data
45 |
46 | def write_json(json_path, data):
47 | with json_path.open("w") as f:
48 | json.dump(data, f)
49 |
50 | def read_pdb(pdb_path):
51 | return parsePDB(str(pdb_path))
52 |
53 | def write_pdb(pdb_path, pdb):
54 | return writePDB(str(pdb_path), pdb)
55 |
56 | def write_tfrecords(path, data):
57 | with tf.python_io.TFRecordWriter(str(path)) as writer:
58 | for datum in data:
59 | writer.write(datum)
60 |
61 | def read_attributes(attr_path):
62 |     return np.load(str(attr_path) if str(attr_path).endswith(".npz") else str(attr_path) + ".npz")  # str.strip(".npz") would remove characters, not the suffix
63 |
64 | def write_attributes(attr_path, attributes):
65 | return np.savez(str(attr_path), attributes)
66 |
67 | def read_plain(file_path):
68 | with open(file_path, "r") as f:
69 | data = f.read()
70 | return data
71 |
72 | def write_plain(file_path, text):
73 | with open(file_path, "w") as f:
74 | if isinstance(text, str):
75 | f.write(text)
76 | elif isinstance(text, list):
77 | for line in text:
78 | f.write(line.strip("\n")+"\n")
79 | else:
80 | f.write(str(text))
81 |
--------------------------------------------------------------------------------
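Round-trip sketch for the image helpers (not part of the repository; the path and shape are made up):

    from pathlib import Path
    import numpy as np
    import RosENet.storage.storage as storage

    image = np.zeros((25, 25, 25, 3), dtype=np.float32)
    storage.write_image(Path("/tmp/example.img"), image)  # HDF5 dataset named "grid"
    assert storage.read_image(Path("/tmp/example.img")).shape == image.shape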
/RosENet/preprocessing/make_ligand_mol2_renamed.py:
--------------------------------------------------------------------------------
1 | import RosENet.storage.storage as storage
2 | from RosENet.preprocessing.step import Step
3 | from RosENet.preprocessing.make_ligand_params_pdb import MakeLigandParamsPDB
4 |
5 | def read_pdb(pdb_path):
6 | """Read .pdb file and relate atom numbers to their names
7 |
8 | pdb_path : pathlib.Path
9 | Path of the pdb file
10 | """
11 | lines = storage.read_plain(pdb_path).splitlines()
12 | hetatm = filter(lambda x: x.startswith('HETATM'), lines)
13 | atom_num_name = map(lambda x: x.split()[1:3], hetatm)
14 | return dict(atom_num_name)
15 |
16 | def read_mol2(mol2_path, name_map):
17 | """Read .mol2 file and change names to the original .pdb names
18 |
19 | mol2_path : pathlib.Path
20 | Path to the mol2 file
21 | name_map : dict
22 | Atom number to name dictionary
23 | """
24 | lines = storage.read_plain(mol2_path).splitlines()
25 | mode = "search"
26 | for i, line in enumerate(lines):
27 | if mode == "search":
28 |             if line.startswith('@<TRIPOS>ATOM'):
29 | mode = "rename"
30 | elif mode == "rename":
31 |             if line.startswith('@<TRIPOS>BOND'):
32 | mode = "end"
33 | else:
34 | atom_num, atom_name = line.split()[0:2]
35 | new_name = name_map[atom_num].ljust(len(atom_name))
36 | position = line.find(atom_name)
37 | end = position + len(new_name)
38 | new_line = line[0:position] + new_name + line[end:]
39 | lines[i] = new_line
40 | elif mode == "end":
41 | return lines
42 |
43 | class MakeLigandMOL2Renamed(metaclass=Step,requirements=[MakeLigandParamsPDB]):
44 | @classmethod
45 | def files(cls, pdb_object):
46 | """List of files being created
47 |
48 | pdb_object : PDBObject
49 | PDB structure being handled
50 | """
51 | return [pdb_object.ligand.renamed_mol2]
52 |
53 | @classmethod
54 | def _run(cls, pdb_object):
55 | """Inner function for the preprocessing step.
56 |
57 | pdb_object : PDBObject
58 | PDB structure being handled
59 | """
60 | ligand_pdb_path = pdb_object.ligand.pdb.path
61 | ligand_mol2_path = pdb_object.ligand.mol2.path
62 | output_path = pdb_object.ligand.renamed_mol2.path
63 | name_map = read_pdb(ligand_pdb_path)
64 | storage.write_plain(output_path, read_mol2(ligand_mol2_path, name_map))
65 |
--------------------------------------------------------------------------------
/RosENet/models/large_kdeep.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | LEARNING_RATE = 0.0001
4 |
5 | def fire_module(net, squeeze, expand, training):
6 | net = tf.layers.conv3d(net, squeeze, [1,1,1], activation=tf.nn.relu)
7 | net1 = tf.layers.conv3d(net, expand, [1,1,1], activation=tf.nn.relu)
8 | net2 = tf.layers.conv3d(net, expand, [3,3,3], padding='same', activation=tf.nn.relu)
9 | return tf.concat(axis=-1, values=[net1,net2])
10 |
11 | def conv_net(X, reuse, training):
12 | with tf.variable_scope('SqueezeNet', reuse=reuse):
13 | net = tf.layers.conv3d(X, 96, 7, 2, padding='same', activation=tf.nn.relu)
14 | net = fire_module(net, 16, 64,training=training)
15 | net = fire_module(net, 16, 64,training=training)
16 | net = fire_module(net, 32, 128,training=training)
17 | net = tf.layers.max_pooling3d(net, 3, 2)
18 | net = fire_module(net, 32, 128,training=training)
19 | net = fire_module(net, 48, 192,training=training)
20 | net = fire_module(net, 48, 192,training=training)
21 | net = fire_module(net, 64, 256,training=training)
22 | net = tf.layers.average_pooling3d(net, 3, 2)
23 | net = tf.layers.flatten(net)
24 | net = tf.layers.dense(net, 1)
25 | return net
26 |
27 | def model_fn(features, labels, mode, rotate):
28 | training = mode == tf.estimator.ModeKeys.TRAIN
29 | predictions = conv_net(features, reuse=tf.AUTO_REUSE, training=training)
30 |
31 | loss = None
32 | train_op = None
33 | if training:
34 | loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
35 | optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
36 | train_op = optimizer.minimize(loss,global_step=tf.train.get_global_step())
37 | else:
38 | #When evaluating or predicting, average over the 24 90º rotations if rotations are enabled.
39 | if rotate:
40 | predictions = tf.reduce_mean(tf.reshape(predictions,(-1,24,1)),axis=1)
41 | loss = tf.losses.mean_squared_error(labels,predictions)
42 | else:
43 | loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
44 |
45 |     return tf.estimator.EstimatorSpec(
46 |         mode=mode,
47 |         predictions=predictions,
48 |         loss=loss,
49 |         train_op=train_op)
50 |
--------------------------------------------------------------------------------
/RosENet/voxelization/filter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.distance import cdist
3 |
4 |
5 | def exp_12(r, rvdw):
6 | """Exponential filter as given by Jimenez et al. (2018)
7 | 10.1021/acs.jcim.7b00650
8 |
9 | Parameters
10 | ----------
11 | r : numpy.ndarray
12 | Array of distances between each atom and all voxels of the 3D image.
13 | rvdw : numpy.ndarray
14 | Array of Van der Waals radii for each atom.
15 |
16 | """
17 | rvdw = rvdw.reshape((-1,))
18 | rr = rvdw[:, None]/r
19 | ret = np.where(r == 0, 1, 1 - np.exp(-(rr)**12))
20 | return ret
21 |
22 |
23 | def gaussian(r, rvdw):
24 | """Gaussian-like filter, constants removed as everything will be normalized.
25 |
26 | Parameters
27 | ----------
28 | r : numpy.ndarray
29 | Array of distances between each atom and all voxels of the 3D image.
30 | rvdw : numpy.ndarray
31 | Array of Van der Waals radii for each atom.
32 |
33 | """
34 | rvdw = rvdw.reshape((-1,))
35 | rr = r/rvdw[:, None]
36 | ret = np.exp(-(rr)**2)
37 | return ret
38 |
39 |
40 | def voxel_filter(filter_type, structure, targets):
41 | """Method to apply a filter to a set of atom coordinates with assigned
42 | values, and distribute them to a set of target points. Atom radii are
43 | taken into consideration to adjust strength of contributions.
44 |
45 | Parameters
46 | ----------
47 | filter_type : callable
48 | Function handle implementing the filter to be applied.
49 |     structure : object
50 |         Object with `coordinates` (atom positions), `values` (per-atom
51 |         features to spatially distribute) and `radii` (Van der Waals
52 |         radii) array attributes; atoms farther than 12.5*sqrt(3) from
53 |         the origin are discarded.
54 |     targets : numpy.ndarray
55 |         3D voxel positions to compute the distributed values at.
56 |         Contributions are cut off beyond 5 Angstrom from each atom.
57 | """
58 | values = structure.values
59 | points = structure.coordinates
60 | radii = structure.radii
61 | shape = targets.shape
62 | targets = targets.reshape((-1,3))
63 | mask = np.linalg.norm(points, axis=-1) <= 12.5*np.sqrt(3)
64 | points = points[mask]
65 | values = values[mask, :]
66 | radii = radii[mask]
67 | dists = cdist(points, targets)
68 | aux = np.where(dists < 5, filter_type(dists, radii), 0)
69 | del dists
70 | result = np.array([values[np.argmax(aux, axis=0), i] * np.max(aux, axis=0)
71 | for i in range(values.shape[-1])])
72 | result = np.swapaxes(result, 0, 1)
73 | result = result.reshape(shape[:-1] + (-1,))
74 | return result
75 |
--------------------------------------------------------------------------------
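Usage sketch for exp_12 (not part of the repository; the distances and radius are made up):

    import numpy as np
    from RosENet.voxelization.filter import exp_12

    r = np.array([[0.0, 1.0, 2.0]])  # one atom, distances to three voxels
    rvdw = np.array([1.7])           # e.g. a carbon Van der Waals radius
    w = exp_12(r, rvdw)              # ~1 at and near the atom, decaying outwards
    # The division by zero at r == 0 is masked by np.where, although NumPy
    # may still emit a RuntimeWarning while evaluating both branches.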
/RosENet/voxelization/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def grid_around(center, size, spacing=1.0):
4 | """Generate an array of 3D positions for a voxel cube of the given size and
5 | coarseness.
6 |
7 | center : numpy.ndarray
8 | 3D coordinates of the cube's center.
9 |     size : int
10 |         Number of voxels along each of the cube's dimensions.
11 | spacing : float
12 | Distance between voxel centers.
13 | """
14 | size_ang = ((size - 1) / 2.) * spacing
15 | ex_min = center - size_ang
16 | ex_max = center + size_ang
17 |
18 | x = np.linspace(ex_min[0], ex_max[0], size)
19 | y = np.linspace(ex_min[1], ex_max[1], size)
20 | z = np.linspace(ex_min[2], ex_max[2], size)
21 | return np.stack(np.meshgrid(x, y, z, indexing='ij'), axis=-1)
22 |
23 |
24 | def clip1d(value, upp_limit):
25 | """Clip a single array by a limit and zero. If the limit is negative,
26 | use it as lower bound, otherwise use it as upper bound.
27 |
28 | value : numpy.ndarray
29 | Array to be clipped.
30 | upp_limit : float
31 | Threshold to use as bound.
32 | """
33 | low_limit = 0
34 | if upp_limit < 0:
35 | upp_limit, low_limit = low_limit, upp_limit
36 | return np.clip(value, low_limit, upp_limit) / (low_limit + upp_limit)
37 |
38 |
39 | def clip(values, limits):
40 | """Clip a list of arrays by a list of limits
41 |
42 | values : list of numpy.ndarray
43 | List of arrays to be clipped.
44 | limits : list of float
45 | List of thresholds to use as bounds.
46 | """
47 | if values.shape[-1] == 1:
48 | return clip1d(values, limits[0])
49 |     return np.stack([clip1d(values[..., i], limits[i]) for i in range(values.shape[-1])], axis=-1)  # np.stack needs a sequence, not a generator
50 |
51 |
52 | def get_keys(pdb):
53 | """Get the unique key names to the atoms of the structure.
54 |
55 | pdb : prody.AtomGroup
56 | Structure of atoms.
57 | """
58 | resnums = pdb.getResnums()
59 | chids = pdb.getChids()
60 | names = pdb.getNames()
61 | keys = np.char.add(np.char.mod(
62 | '%s-', np.char.replace(np.char.add(np.char.mod('%s ', resnums), chids), ' ', '-')), names)
63 | return keys
64 |
65 |
66 | def save_grid(saving_path, image):
67 | """Save 3D image to HDF5 format.
68 |
69 | saving_path : str or os.PathLike
70 | Destination path to save the image.
71 | image : numpy.ndarray
72 | 3D image to be saved.
73 | """
74 |     import h5py  # local import: h5py is not imported at the top of this module
75 |     with h5py.File(str(saving_path), "w", libver='latest') as f:
76 |         f.create_dataset("grid", dtype='f4', data=image)
77 |
--------------------------------------------------------------------------------
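Usage sketch for grid_around (not part of the repository), with size=25 as in settings.py:

    import numpy as np
    from RosENet.voxelization.utils import grid_around

    grid = grid_around(np.zeros(3), 25)  # 25**3 voxel centres, 1 Angstrom apart
    assert grid.shape == (25, 25, 25, 3)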
/RosENet/static/dock_relax.xml:
--------------------------------------------------------------------------------
[content lost in extraction: the RosettaScripts XML tags of dock_relax.xml were stripped from this dump and are not recoverable]
--------------------------------------------------------------------------------
/RosENet/static/dock_relax2.xml:
--------------------------------------------------------------------------------
[content lost in extraction: the RosettaScripts XML tags of dock_relax2.xml were stripped from this dump and are not recoverable]
--------------------------------------------------------------------------------
/RosENet/models/resnet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | LEARNING_RATE = 0.0001
4 |
5 |
6 | def start_block(net, channels, training):
7 | input_net = tf.layers.conv3d(net, 4*channels, 1, 2, padding='same')
8 | net = conv3bn(net, channels, 1, training,stride=2)
9 | net = conv3bn(net, channels, 3, training)
10 | net = conv3bn(net, 4*channels, 1, training)
11 | output_net = net + input_net
12 | return output_net
13 |
14 | def inner_block(net, channels, training):
15 | input_net = tf.layers.conv3d(net, 4*channels, 1, padding='same')
16 | net = conv3bn(net, channels, 1, training)
17 | net = conv3bn(net, channels, 3, training)
18 | net = conv3bn(net, 4*channels, 1, training)
19 | output_net = net + input_net
20 | return output_net
21 |
22 |
23 | def conv3bn(net, channels, filt, training, stride=1):
24 | net = tf.nn.relu(net)
25 | net = tf.layers.conv3d(net, channels, filt, stride, padding='same')
26 | return net
27 |
28 | def conv_net(X, reuse, training):
29 | with tf.variable_scope('ResNet', reuse=reuse):
30 | layers = [3,4,23,3]
31 | k = 64
32 | net = tf.layers.conv3d(X, k, 7, 2, padding='same')
33 | net = tf.nn.relu(net)
34 | net = tf.layers.max_pooling3d(net, 3, 2)
35 | for i in range(0,layers[0]):
36 | net = inner_block(net, k, training)
37 | for i, l in enumerate(layers[1:],1):
38 | net = start_block(net, k*(2**i), training)
39 | for j in range(0,l-1):
40 | net = inner_block(net, k*(2**i), training)
41 | net = tf.reduce_mean(net, axis=(1,2,3))
42 | net = tf.layers.flatten(net)
43 | net = tf.layers.dense(net, 1)
44 | return net
45 |
46 | def model_fn(features, labels, mode, rotate):
47 | training = mode == tf.estimator.ModeKeys.TRAIN
48 | predictions = conv_net(features, reuse=tf.AUTO_REUSE, training=training)
49 |
50 | loss = None
51 | train_op = None
52 | if training:
53 | loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
54 | optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
55 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
56 | with tf.control_dependencies(update_ops):
57 | train_op = optimizer.minimize(loss,global_step=tf.train.get_global_step())
58 | else:
59 | if rotate:
60 | #When evaluating or predicting, average over the 24 90º rotations if rotate is enabled.
61 | predictions = tf.reduce_mean(tf.reshape(predictions,(-1,24,1)), axis=1)
62 | loss = tf.losses.mean_squared_error(labels, predictions)
63 | else:
64 | loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
65 |
66 | return tf.estimator.EstimatorSpec(
67 | mode=mode,
68 | predictions=predictions,
69 | loss=loss,
70 | train_op=train_op)
71 |
--------------------------------------------------------------------------------
/RosENet/constants.py:
--------------------------------------------------------------------------------
1 | """Constants file. Here one only needs to change the location of the different
2 | tools used by the project.
3 | """
4 |
5 | import os
6 | from types import SimpleNamespace
7 | from pathlib import Path
8 |
9 | source_path = os.path.dirname(os.path.abspath(__file__))
10 | flags_filename = "flags_relax.txt"
11 | flags_relax_path = os.path.join(source_path, "static", flags_filename)
12 | relax_filename = "dock_relax.xml"
13 | relax_path = os.path.join(source_path, "static", relax_filename)
14 | metal_selector = "chain Z"
15 | ligand_resname = "WER"
16 | close_ligand_selector = f"resname {ligand_resname} and (not (element H or element C)) and within 2.5 of t"
17 | protein_selector = f"not resname {ligand_resname}"
18 | mgl_python_path = "/home/hussein/Repositories/Own/mgltools_x86_64Linux2_1.5.6/bin/pythonsh"
19 | preprocess_vina_path = os.path.join(source_path, "preprocessing", "preprocess_vina.py")
20 | nonstd2stdresidues = {'HOH':'WAT',
21 | 'CYX':'CYS',
22 | 'CYM':'CYS',
23 | 'HIE':'HIS',
24 | 'HID':'HIS',
25 | 'HSD':'HIS',
26 | 'HIP':'HIS',
27 | 'HIY':'HIS',
28 | 'ALB':'ALA',
29 | 'ASM':'ASN',
30 | 'DIC':'ASP',
31 | 'GLV':'GLU',
32 | 'GLO':'GLN',
33 | 'HIZ':'HIS',
34 | 'LEV':'LEU',
35 | 'SEM':'SER',
36 | 'TYM':'TYR',
37 | 'TRQ':'TRP',
38 | 'KCX':'LYS',
39 | 'LLP':'LYS',
40 | 'ARN':'ARG',
41 | 'ASH':'ASP',
42 | 'DID':'ASP',
43 | 'ASZ':'ASP',
44 | 'CYT':'CYS',
45 | 'GLH':'GLU',
46 | 'LYN':'LYS',
47 | 'AR0':'ARG',
48 | 'PCA':'GLU',
49 | 'HSE':'SER'}
50 | ligand_chid = "X"
51 | water_residue = "WAT"
52 | water_chain = "W"
53 | accepted_metals = ["MN", "MG", "ZN", "CA", "NA"]
54 | metal_chain = "Z"
55 | ligand_selector = f"not protein and same residue as ((resname {water_residue} and within 3 of resname {ligand_resname} and within 3 of protein) or (resname {' '.join(accepted_metals)} and within 5 of resname {ligand_resname}) or resname {ligand_resname})"
56 |
57 | class classproperty(object):
58 | def __init__(self, getter):
59 | self.getter = getter
60 | def __get__(self, instance, owner):
61 | return self.getter(owner)
62 |
63 |
64 | class rosetta:
65 | root = Path("/home/hussein/Repositories/Own/rosetta")
66 |
67 | @classproperty
68 | def molfile_to_params(cls):
69 | return cls.root / "main/source/scripts/python/public/molfile_to_params.py"
70 |
71 | @classproperty
72 | def minimize(cls):
73 | return cls.root / "main/source/bin/rosetta_scripts.static.linuxgccrelease"
74 |
75 | @classproperty
76 | def pdb_to_molfile(cls):
77 | return cls.root / "main/source/src/apps/public/ligand_docking/pdb_to_molfile.py"
78 |
79 | @classproperty
80 | def database(cls):
81 | return cls.root / "main/database/"
82 |
83 | @classproperty
84 | def py_wd(cls):
85 | return cls.root / "main/source/scripts/python/public/"
86 |
--------------------------------------------------------------------------------
/RosENet/objects/file.py:
--------------------------------------------------------------------------------
1 | from RosENet.storage.storage import *
2 | from RosENet.objects.pdb import PDBObject
3 | import RosENet.storage.storage as storage
4 | import RosENet.rosetta.rosetta as rosetta
5 |
6 | class File:
7 | """Base class for file management. Represents any file, which name may be
8 | generic. It allows for read/write access and templates the name of the file
9 | according to the PDB code(root folder name) and the number of the file (for
10 | when it represents the result of the Rosetta minimization)"""
11 | def __init__(self, name, root):
12 | self.root = root
13 | self.code = root.name
14 | self.name = name.format(code=self.code, number="{number}")
15 | if "{number}" in name:
16 | self.multiple = True
17 | else:
18 | self.multiple = None
19 |
20 | @property
21 | def path(self):
22 | return self.resolve_path()
23 |
24 | def resolve_path(self, number=None):
25 | if number:
26 | return self.root / self.name.format(number=number)
27 | else:
28 | if self.multiple is True:
29 | scores = rosetta.parse_scores(PDBObject(self.root).minimized.scores.read())
30 | number = (list(scores.keys()))[0].split("_")[-1]
31 | self.multiple = self.name.format(number=number)
32 | return self.root / self.multiple
33 | elif self.multiple:
34 | return self.root / self.multiple
35 | else:
36 | return self.root / self.name
37 |
38 | def read(self):
39 | return storage.read_plain(self.path)
40 |
41 | def write(self, data):
42 | return storage.write_plain(self.path, data)
43 |
44 | def delete(self):
45 | return storage.delete(self.path)
46 |
47 | def __getitem__(self, key):
48 | if self.multiple:
49 | return File.create(self.name.format(number=key), self.root)
50 | else:
51 | return self
52 |
53 | @staticmethod
54 | def create(path, root):
55 | if ".pdbqt" in path:
56 | return PDBQTFile(path, root)
57 | elif ".pdb" in path:
58 | return PDBFile(path, root)
59 | elif ".img" in path:
60 | return ImageFile(path, root)
61 | elif ".attr" in path:
62 | return AttributeFile(path, root)
63 | elif ".tfrecords" in path:
64 | return TFRecordsFile(path, root)
65 | else:
66 | return File(path, root)
67 |
68 |
69 | class PDBFile(File):
70 | def read(self):
71 | return read_pdb(self.path)
72 |
73 | def write(self, data):
74 | write_pdb(self.path, data)
75 |
76 |
77 | class TFRecordsFile(File):
78 | def read(self):
79 | return read_tfrecords(self.path)
80 |
81 | def write(self, data):
82 | write_tfrecords(self.path, data)
83 |
84 | class ImageFile(File):
85 | def read(self):
86 | return read_image(self.path)
87 |
88 | def write(self, data):
89 | write_image(self.path, data)
90 |
91 | class AttributeFile(File):
92 | def read(self):
93 | return read_attributes(self.path)
94 |
95 | def write(self, data):
96 | write_attributes(self.path, data)
97 |
98 | class PDBQTFile(File):
99 | def read(self):
100 | return read_pdb(self.resolve_path())
101 |
102 | def write(self, data):
103 | write_pdb(self.resolve_path(), data)
104 |
--------------------------------------------------------------------------------
/RosENet/postprocessing/postprocessing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from RosENet.storage import storage
4 |
5 | def combine_maps(pdb_object):
6 | """Postprocessing step for combining the feature map images into one image
7 |
8 | Parameters
9 | ----------
10 | pdb_object : PDBObject
11 | PDB object which images will be combined
12 | """
13 | features = [pdb_object.image.htmd.read(),
14 | pdb_object.image.electronegativity.read(),
15 | pdb_object.image.rosetta.read()]
16 | grid = np.concatenate(features, axis=-1)
17 | storage.make_directory(pdb_object.image.combined.path.parent)
18 | pdb_object.image.combined.write(grid)
19 |
20 |
21 | def serialize_file(file, target, type):
22 | """Serialize image file to Tensorflow serialized format
23 |
24 | Parameters
25 | ----------
26 | file : pathlib.Path
27 | Path for the image file.
28 | target : float
29 | Ground truth binding affinity for the image.
30 | """
31 | datapoint = storage.read_image(file)
32 | features = datapoint.flatten()
33 | label = bytes(file.stem,"utf-8")
34 | type = 0 if type == "Kd" else 1
35 | example = tf.train.Example(features=tf.train.Features(feature={
36 | 'id' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
37 | 'X' : tf.train.Feature(float_list=tf.train.FloatList(value=features)),
38 | 'y' : tf.train.Feature(float_list=tf.train.FloatList(value=np.array([target]))),
39 | 'type' : tf.train.Feature(int64_list=tf.train.Int64List(value=[type]))
40 | }))
41 | return example.SerializeToString()
42 |
43 | def write_tfrecords(files, dataset_object, number, labels):
44 | """Serialize the given images and store them in a TFRecords file, together with their binding affinity values.
45 |
46 | files : list of pathlib.Path
47 | List of files to be written inside the TFRecords file.
48 | dataset_object : DatasetObject
49 | Dataset of the images.
50 | number : int
51 | ID number of the TFRecords file.
52 | labels : dict
53 | Dictionary relating image names to their binding affinities.
54 | """
55 | output = [serialize_file(file, float(labels[file.stem][0]), "Kd") for file in files]
56 | dataset_object.tfrecord(number).write(output)
57 |
58 | def chunk_by_size(files, recommended_tf_size=float(100*(2**20))):
59 | """Split the images into chunks according to the TFRecords recommended size, which is 100MB.
60 |
61 | files : list of pathlib.Path
62 | List of image paths to be split.
63 | recommended_tf_size : float
64 | Maximum size of each chunk, 100MB by default.
65 | """
66 | average_size = sum(map(lambda x: x.stat().st_size, files)) / float(len(files))
67 | files_per_chunk = np.ceil(recommended_tf_size / average_size)
68 | chunks = int(np.ceil(float(len(files)) / files_per_chunk))
69 | return np.array_split(np.array(files), chunks)
70 |
71 | def generate_tfrecords(dataset_object):
72 | """Generate TFRecords from a dataset's combined images.
73 |
74 | dataset_object : DatasetObject
75 | Dataset of the images.
76 | """
77 | files = dataset_object.images
78 |     chunks = chunk_by_size(files)
79 | storage.clear_directory(dataset_object.tfrecords, no_fail=True)
80 | storage.make_directory(dataset_object.tfrecords, no_fail=True)
81 | lines = dataset_object.labels.read().splitlines()
82 | lines = [line.split(" ") for line in lines]
83 | pdb_labels = dict([(line[0],line[1:]) for line in lines])
84 | for i, chunk in enumerate(chunks):
85 | write_tfrecords(chunk, dataset_object, i, pdb_labels)
86 |
87 |
--------------------------------------------------------------------------------
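Note (not part of the repository): records written by serialize_file can be read back with a mirrored feature spec. A sketch for the TF 1.13 API pinned in conda-requirements.txt; the channel count depends on what voxelization produced and is an assumption here:

    import tensorflow as tf

    def parse_example(serialized, voxels=25, channels=1):  # channels=1 is hypothetical
        spec = {
            'id':   tf.FixedLenFeature([], tf.string),
            'X':    tf.FixedLenFeature([voxels**3 * channels], tf.float32),
            'y':    tf.FixedLenFeature([1], tf.float32),
            'type': tf.FixedLenFeature([1], tf.int64),
        }
        parsed = tf.parse_single_example(serialized, spec)
        image = tf.reshape(parsed['X'], (voxels, voxels, voxels, channels))
        return image, parsed['y']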
/RosENet/preprocessing/step.py:
--------------------------------------------------------------------------------
1 | import RosENet.utils as utils
2 |
3 | class Step(type):
4 | """Metaclass for preprocessing steps. Implements functions to run the steps
5 |     according to their dependencies, clean the output data and check for
6 | readiness of execution."""
7 |
8 | def __new__(cls, name, bases, dct, requirements=[]):
9 | x = super(Step, cls).__new__(cls,name, bases, dct)
10 | x.requirements = requirements
11 | x.successors = []
12 | x.name = name
13 | for requirement in requirements:
14 | requirement.successors.append(x)
15 | return x
16 |
17 |
18 | def clean(cls, pdb_object):
19 | """Delete files and folders created by the step and its successor steps.
20 |
21 | pdb_object : PDBObject
22 | PDB structure to be handled
23 | """
24 | if not cls.computed(pdb_object):
25 | return
26 | for successor in cls.successors:
27 | successor.clean(pdb_object)
28 | pdb_object.uncomplete(cls.name)
29 | for file in cls.files(pdb_object):
30 | file.delete()
31 |
32 | def run(cls, pdb_object, callbacks=[], force=False, surrogate=False):
33 | """Run the preprocessing step for the given PDB structure.
34 |
35 | pdb_object : PDBObject
36 | PDB structure to be handled
37 | callbacks : list of callable
38 | List of callbacks for logging
39 | force : bool
40 | Force execution even if the step was already computed? Default False
41 | surrogate : bool
42 | Run step as consequence of another step. Default False
43 | """
44 | if cls.computed(pdb_object) and not force:
45 | if not surrogate:
46 | utils.message_callbacks(callbacks, (cls.name, pdb_object.id, "already_done"))
47 | return True
48 | elif not cls.ready(pdb_object):
49 | if not surrogate:
50 | utils.message_callbacks(callbacks, (cls.name, pdb_object.id, "not_ready"))
51 | return False
52 | utils.message_callbacks(callbacks, (cls.name, pdb_object.id, "start"))
53 | try:
54 | cls._run(pdb_object)
55 | except Exception as e:
56 | utils.message_callbacks(callbacks, (cls.name, pdb_object.id, "error", e))
57 | return False
58 | pdb_object.complete(cls.name)
59 | utils.message_callbacks(callbacks, (cls.name, pdb_object.id, "end"))
60 | return True
61 |
62 | def computed(cls, pdb_object):
63 | """Check if the step was already computed.
64 |
65 | pdb_object : PDBObject
66 | PDB structure to be handled
67 | """
68 | return cls.name in pdb_object.completed_steps
69 |
70 |
71 | def ready(cls, pdb_object):
72 | """Check if the dependencies of the step have been fulfilled.
73 |
74 | pdb_object : PDBObject
75 | PDB structure to be handled
76 | """
77 | return all(map(lambda x: x.computed(pdb_object), cls.requirements))
78 |
79 | def run_until(cls, pdb_object, callbacks=[], surrogate=False):
80 | """Run step and all dependencies necessary for its execution.
81 |
82 | pdb_object : PDBObject
83 | PDB structure to be handled
84 | callbacks : list of callable
85 | List of callbacks for logging
86 | surrogate : bool
87 | Run step as consequence of another step. Default False
88 | """
89 | for requirement in cls.requirements:
90 | if not requirement.run_until(pdb_object, callbacks, surrogate=True):
91 | return False
92 | return cls.run(pdb_object,callbacks, force=False, surrogate=surrogate)
93 |
94 |
95 |
--------------------------------------------------------------------------------
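Illustration (not part of the repository): a new pipeline stage would be declared through the metaclass like the existing steps; DoNothing and its behavior are hypothetical:

    from RosENet.preprocessing.step import Step
    from RosENet.preprocessing.make_ligand_params_pdb import MakeLigandParamsPDB

    class DoNothing(metaclass=Step, requirements=[MakeLigandParamsPDB]):
        @classmethod
        def files(cls, pdb_object):
            # Files this step creates; clean() deletes them to undo the step.
            return []

        @classmethod
        def _run(cls, pdb_object):
            pass  # the actual work goes here

    # DoNothing.run_until(pdb_object) first runs MakeLigandParamsPDB if needed,
    # then this step; completion is recorded in the structure's metadata.json.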
/RosENet/objects/pdb.py:
--------------------------------------------------------------------------------
1 | from types import SimpleNamespace
2 | from RosENet.storage.storage import *
3 | import RosENet.storage.storage as storage
4 | import RosENet.rosetta.rosetta as rosetta
5 | from RosENet import settings
6 |
7 | class _PDB:
8 | """Inner PDB class. Represents a PDB structure (a folder with a protein.pdb and ligand.mol2 files)."""
9 |
10 | _instance_dict = {}
11 | _property_tree = { "flags_relax" : "flags_relax.txt",
12 | "constraints" : "constraints",
13 | "ligand" : {
14 | "pdb" : "{code}_ligand.pdb",
15 | "mol2" : "{code}_ligand.mol2",
16 | "renamed_mol2" : "{code}_ligand_renamed.mol2",
17 | "pdbqt" : "{code}_ligand.pdbqt",
18 | "params" : "{code}_ligand.params"
19 | },
20 | "protein" : {
21 | "pdb" : "{code}_protein.pdb",
22 | "pdbqt" : "{code}_protein.pdbqt"
23 | },
24 | "complex" : {
25 | "pdb" : "{code}_complex.pdb"
26 | },
27 | "minimized" : {
28 | "scores" : "score.sc",
29 | "hidden_complexes" : "other_complexes",
30 | "ligand" : {
31 | "mol2" : "{code}_ligand_{number}.mol2",
32 | "pdbqt": "{code}_ligand_{number}.pdbqt"
33 | },
34 | "protein" : {
35 | "pdb" : "{code}_protein_{number}.pdb",
36 | "mol2" : "{code}_protein_{number}.mol2",
37 | "pdbqt": "{code}_protein_{number}.pdbqt"
38 | },
39 | "complex" : {
40 | "pdb": "{code}_complex_{number}.pdb",
41 | "pdbqt": "{code}_complex_{number}.pdbqt",
42 | "attr": "{code}_complex_{number}.attr"
43 | }
44 | },
45 | "image" : {
46 | "rosetta" : "{code}_rosetta.img",
47 | "htmd" : "{code}_htmd.img",
48 | "electronegativity" : "{code}_electroneg.img"
49 | }
50 | }
51 |
52 | def __init__(self, path):
53 | self.path = path
54 | self.id = path.name
55 | self.metadata_path = path / "metadata.json"
56 | if not self.metadata_path.exists():
57 | storage.write_json(self.metadata_path, { "completed_steps" : []})
58 | self.metadata = storage.read_json(self.metadata_path)
59 | _create(_PDB._property_tree, path, self.__dict__)
60 | from RosENet.objects.file import File
61 | self.image.combined = File.create(f"{self.id}.img", path.parent / settings.options)
62 |
63 | @property
64 | def completed_steps(self):
65 | return self.metadata["completed_steps"]
66 |
67 | def complete(self, step):
68 | self.metadata["completed_steps"].append(step)
69 | self.metadata["completed_steps"] = list(set(self.metadata["completed_steps"]))
70 | storage.write_json(self.metadata_path, self.metadata)
71 |
72 | def uncomplete(self, step):
73 | try:
74 | self.metadata["completed_steps"].remove(step)
75 | storage.write_json(self.metadata_path, self.metadata)
76 | except ValueError:
77 | pass
78 |
79 | def PDBObject(path):
80 | if str(path.absolute()) not in _PDB._instance_dict:
81 | _PDB._instance_dict[str(path.absolute())] = _PDB(path)
82 | return _PDB._instance_dict[str(path.absolute())]
83 |
84 | def _create(tree, path, result=None):
85 | from RosENet.objects.file import File
86 | root = True
87 | if not result:
88 | root = False
89 | result = {}
90 | for key, value in tree.items():
91 | if isinstance(value, dict):
92 | result[key] = _create(value, path)
93 | else:
94 | result[key] = File.create(value, path)
95 | if root:
96 | return result
97 | return SimpleNamespace(**result)
98 |
99 |
--------------------------------------------------------------------------------
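
A short usage sketch of the PDB object, assuming a dataset folder laid out like test_dataset/10gs (the step name "SomeStep" is made up, and the exact resolved filenames depend on how File.create expands the {code} placeholders):

from pathlib import Path
from RosENet.objects.pdb import PDBObject

pdb = PDBObject(Path("test_dataset/10gs"))
print(pdb.id)                 # "10gs"
print(pdb.ligand.mol2.path)   # path ending in 10gs_ligand.mol2
pdb.complete("SomeStep")      # persists the step name to metadata.json
print(pdb.completed_steps)    # ["SomeStep"]
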
/RosENet/preprocessing/minimize_rosetta.py:
--------------------------------------------------------------------------------
1 | from RosENet.preprocessing.step import Step
2 | from RosENet.preprocessing.make_complex_pdb import MakeComplexPDB
3 | from string import Template
4 | import RosENet.constants as constants
5 | import RosENet.rosetta.rosetta as rosetta
6 | import RosENet.storage.storage as storage
7 |
8 | def generate_minimization_flags_file(pdb_object):
9 | """Generate minimization flags file for Rosetta minimization.
10 |
11 | pdb_object : PDBObject
12 | PDB structure to be handled
13 | """
14 | complex_path = pdb_object.complex.pdb.path
15 | complex = complex_path.name
16 | name = complex_path.stem
17 | params = pdb_object.ligand.params.path.name
18 | template = Template(storage.read_plain(constants.flags_relax_path))
19 | substitution = {'complex' : complex,
20 | 'name' : name,
21 | 'params' : params}
22 | output = template.substitute(substitution)
23 | pdb_object.flags_relax.write(output)
24 |
25 | def generate_constraint_file(pdb_object):
26 | """Generate constraint file for Rosetta minimization.
27 | Constraints are added to maintain contact ions close to their original positions.
28 |
29 | pdb_object : PDBObject
30 | PDB structure to be handled
31 | """
32 | complex = pdb_object.complex.pdb.read()
33 | output_path = pdb_object.constraints.path
34 | metals = complex.select(constants.metal_selector)
35 | results = []
36 | if metals:
37 | for atom in metals:
38 | pos = atom.getCoords()
39 | close_ligand = complex.select(constants.close_ligand_selector, t=pos)
40 | if close_ligand:
41 | for close in close_ligand:
42 | results.append((atom.getName(),
43 | str(atom.getResnum())+atom.getChid(),
44 | close.getName(),
45 | str(close.getResnum())+close.getChid()))
46 | pdb_object.constraints.write(
47 | [f"AtomPair {r[0]} {r[1]} {r[2]} {r[3]} SQUARE_WELL 2.5 -2000\n" for r in results])
48 |
49 | def hide_non_minimal_complexes(pdb_object):
50 |     """Move all minimized complexes except the best-scoring one into a hidden folder.
51 |     Only the pose with the minimal score is kept for the later steps.
52 |
53 | pdb_object : PDBObject
54 | PDB structure to be handled
55 | """
56 | scores = rosetta.parse_scores(pdb_object.minimized.scores.read())
57 | hidden_folder = pdb_object.minimized.hidden_complexes.path
58 | storage.make_directory(hidden_folder)
59 | for number in list(scores.keys())[1:]:
60 | complex_path = pdb_object.minimized.complex.pdb[number].path
61 | print("Hiding ", complex_path)
62 | storage.move(complex_path, hidden_folder, no_fail=True)
63 |
64 |
65 |
66 | class MinimizeRosetta(metaclass=Step, requirements=[MakeComplexPDB]):
67 | @classmethod
68 | def files(cls, pdb_object):
69 | """List of files being created
70 |
71 | pdb_object : PDBObject
72 | PDB structure being handled
73 | """
74 | return [pdb_object.flags_relax,
75 | pdb_object.constraints,
76 | pdb_object.minimized.hidden_complexes,
77 | pdb_object.minimized.complex.pdb,
78 | pdb_object.minimized.scores]
79 |
80 | @classmethod
81 | def _run(cls, pdb_object):
82 | """Inner function for the preprocessing step.
83 |
84 | pdb_object : PDBObject
85 | PDB structure being handled
86 | """
87 | generate_minimization_flags_file(pdb_object)
88 | generate_constraint_file(pdb_object)
89 | if pdb_object.minimized.scores.path.exists():
90 | hide_non_minimal_complexes(pdb_object)
91 | rosetta.minimize(working_directory = pdb_object.path)
92 | hide_non_minimal_complexes(pdb_object)
93 |
94 |
--------------------------------------------------------------------------------
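
For reference, generate_constraint_file emits one AtomPair restraint per metal/ligand contact. An illustrative output line (atom names and residue numbers are made up; Z and X are the reserved metal and ligand chains):

AtomPair ZN 301Z O2 1X SQUARE_WELL 2.5 -2000
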
/RosENet/clui.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
4 | import sys
5 | from argparse import ArgumentParser
6 | from pathlib import Path
7 | from RosENet.objects.model import ModelObject
8 | from RosENet.objects.dataset import DatasetObject
9 | from RosENet.preprocessing import preprocess
10 | from RosENet.voxelization import voxelize
11 | from RosENet.postprocessing import postprocess
12 | from RosENet.preprocessing.minimize_rosetta import hide_non_minimal_complexes
13 |
14 | def parse_arguments():
15 | """Parse arguments and perform the selected action."""
16 | arguments = sys.argv[1:]
17 | optional_parser = ArgumentParser()
18 | optional_parser.add_argument("--njobs", default=None)
19 | optional_parser.add_argument("--gpu", default=None)
20 | optional_parser.add_argument("extra", nargs="*")
21 | optional = optional_parser.parse_args(arguments)
22 | njobs = optional.njobs
23 | gpu = optional.gpu
24 | if njobs is None:
25 | njobs = multiprocessing.cpu_count()-1
26 | else:
27 | njobs = int(njobs)
28 |     print("Using", njobs, "parallel jobs")
29 | 
30 | arguments = optional.extra
31 | os.environ["CUDA_VISIBLE_DEVICES"]=""
32 | if gpu is not None:
33 | os.environ["CUDA_VISIBLE_DEVICES"]=gpu
34 | from RosENet.network.network import train, evaluate, predict
35 | parser = ArgumentParser(description="RosENet Tool",
36 |         usage='''tool <command> [<args>]
37 |
38 | preprocess Compute steps previous to voxelization
39 | voxelize Compute 3D image voxelization
40 | postprocess Format image for TensorFlow
41 | train Train neural network
42 | evaluate Evaluate neural network
43 | predict Predict binding affinity
44 | ''')
45 | parser.add_argument("action", help='Command to run')
46 | parser.add_argument("dataset", help='Dataset path')
47 | args = parser.parse_args(arguments[0:2])
48 | action = args.action
49 | dataset = DatasetObject(Path(args.dataset))
50 | if action in ["train", "evaluate", "predict"]:
51 | parser = ArgumentParser(usage=
52 | '''{train,evaluate,predict} second_dataset network channels [seed]
53 |
54 | second_dataset has the following meanings:
55 | If action is train: second_dataset is the validation dataset
56 | If action is evaluate: second_dataset is the evaluation dataset
57 | If action is predict: second_dataset is the prediction dataset
58 |
59 | In training if no seed is given, it will be randomly chosen
60 |     The seed is necessary for evaluation and prediction to choose the
61 | correct instance of the trained model.
62 | ''')
63 | parser.add_argument("validation_dataset")
64 | parser.add_argument("network")
65 | parser.add_argument("channels")
66 | parser.add_argument("seed", nargs="?", default=None)
67 | args = parser.parse_args(arguments[2:6])
68 | other_dataset= DatasetObject(Path(args.validation_dataset))
69 | network = ModelObject(Path(args.network))
70 | channels = args.channels
71 |         seed = int(args.seed) if args.seed is not None else None
72 | if action in ["preprocess", "voxelize"]:
73 | import random
74 | pdbs = dataset.list()
75 | random.shuffle(pdbs)
76 | pdbs = map(dataset.__getitem__, pdbs)
77 |         p = multiprocessing.Pool(njobs)
78 | if action == "preprocess":
79 | #p.map(hide_non_minimal_complexes, pdbs)
80 | p.map(preprocess, pdbs)
81 | elif action == "voxelize":
82 | p.map(voxelize, pdbs)
83 | elif action == "postprocess":
84 | postprocess(dataset)
85 | if action == "train":
86 | from RosENet.network.network import train
87 | train(dataset, other_dataset, network, seed, channels)
88 | elif action == "evaluate":
89 | from RosENet.network.network import evaluate
90 | evaluate(dataset, other_dataset, network, seed, channels)
91 | elif action == "predict":
92 | from RosENet.network.network import predict
93 | predict(dataset, other_dataset, network, seed, channels)
94 |
--------------------------------------------------------------------------------
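
Illustrative invocations matching the parser above (dataset and model paths are placeholders; channels is a selector string matched against the channel names listed in network/input.py):

python -m RosENet --njobs 8 preprocess test_dataset
python -m RosENet --njobs 8 voxelize test_dataset
python -m RosENet postprocess test_dataset
python -m RosENet --gpu 0 train train_set validation_set RosENet/models/resnet.py htmd 1234
python -m RosENet --gpu 0 evaluate train_set test_set RosENet/models/resnet.py htmd 1234
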
/RosENet/network/network.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
3 | import tensorflow as tf
4 | import random
5 | import RosENet.settings as settings
6 | from RosENet.network.utils import save_results
7 | import time
8 |
9 | def train(dataset_train, dataset_evaluate, model_object, seed=None, channels=""):
10 | """Training method
11 |
12 | dataset_train : DatasetObject
13 | Dataset to be used as training set
14 | dataset_evaluate : DatasetObject
15 | Dataset to be used as validation set
16 | model_object : ModelObject
17 | CNN model to train
18 | seed : int
19 | Seed for initializing randomness. If None, a random seed will be used.
20 | channels : string
21 | Channels selectors for choosing feature subsets
22 | """
23 | if seed is None:
24 | seed = random.randint(0,2147483647)
25 | tf.set_random_seed(seed)
26 | model_train_object = model_object.train_object(dataset_train,channels,seed)
27 | model_valid_object = model_object.evaluate_object(dataset_evaluate,channels,seed)
28 | with tf.Session() as sess:
29 | sess.run(tf.global_variables_initializer())
30 | min_validation = float('inf')
31 | epoch_min_validation = -1
32 | for num_epochs in range(settings.max_epochs):
33 | t1 = time.time()
34 | tr_loss = model_train_object.do(sess)
35 | t2 = time.time() - t1
36 | print(f"Train loss: {tr_loss} Elapsed: {t2}")
37 | t1 = time.time()
38 | va_loss = model_valid_object.do(sess)
39 | t2 = time.time() - t1
40 | print(f"Validation loss: {va_loss} Elapsed: {t2}")
41 | if min_validation > va_loss:
42 | min_validation = va_loss
43 | epoch_min_validation = num_epochs
44 | model_train_object.save(sess)
45 | print(f"Best validation: {min_validation} Epoch: {epoch_min_validation}")
46 | model_train_object.save(sess)
47 | save_results(model_train_object, channels, min_validation, epoch_min_validation, seed)
48 |
49 | def evaluate(dataset_train, dataset_evaluate, model_object, seed, channels=""):
50 | """Evaluation method
51 |
52 | dataset_train : DatasetObject
53 | Dataset used during the training phase
54 | dataset_evaluate : DatasetObject
55 | Dataset to be evaluated
56 | model_object : ModelObject
57 | CNN model used during the training phase
58 | seed : int
59 |         Seed used during training; required to locate the trained model instance.
60 | channels : string
61 | Channels selectors for choosing feature subsets
62 | """
63 | model_train_object = model_object.train_object(dataset_train, channels, seed)
64 | model_valid_object = model_object.evaluate_object(dataset_evaluate, channels, seed)
65 | with tf.Session() as sess:
66 | sess.run(tf.global_variables_initializer())
67 | model_train_object.load(sess)
68 | val_loss = model_valid_object.do(sess)
69 | print(val_loss)
70 | return val_loss
71 |
72 | def predict(dataset_train, dataset_predict, model_object, seed, channels=""):
73 | """Predict method
74 |
75 | dataset_train : DatasetObject
76 | Dataset used during the training phase
77 |     dataset_predict : DatasetObject
78 | Dataset to be predicted
79 | model_object : ModelObject
80 | CNN model used during the training phase
81 | seed : int
82 |         Seed used during training; required to locate the trained model instance.
83 | channels : string
84 | Channels selectors for choosing feature subsets
85 | """
86 | model_train_object = model_object.train_object(dataset_train, channels, seed)
87 | model_pred_object = model_object.predict_object(dataset_predict, channels, seed)
88 | with tf.Session() as sess:
89 | sess.run(tf.global_variables_initializer())
90 | model_train_object.load(sess)
91 | results = model_pred_object.do(sess)
92 | print(results)
93 | return results
94 |
--------------------------------------------------------------------------------
/RosENet/voxelization/apbs.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from subprocess import call
3 | import numpy as np
4 | from prody import writePQR
5 | from RosENet.voxelization.utils import isfloat
6 |
7 |
8 | class APBS:
9 | _APBS_BIN_PATH = '/cluster/home/hhussein/apbs.sh'
10 | _TEMPLATE_FILE = '/cluster/home/hhussein/apbs.in'
11 | @staticmethod
12 | def run(
13 | output_path,
14 | name,
15 | pose,
16 | selection,
17 | grid_dim,
18 | grid_space,
19 | center,
20 | cglen,
21 | fglen):
22 | apbs_bin_path = APBS._APBS_BIN_PATH
23 | apbs_template_file = Path(APBS._TEMPLATE_FILE)
24 | apbs_input_file = output_path / f'apbs_{name}.in'
25 | apbs_output_file = output_path / f'{name}_potential.dx'
26 | if apbs_input_file.exists():
27 | apbs_input_file.unlink()
28 | if apbs_output_file.exists():
29 | apbs_output_file.unlink()
30 | writePQR(f'{output_path/name}.pqr', pose.select(selection))
31 | with apbs_template_file.open('r') as f:
32 | file_data = f.read()
33 | file_data = APBS._replace_apbs(
34 | file_data, str(output_path/name), grid_dim, grid_space, center, cglen, fglen)
35 | with apbs_input_file.open('w') as f:
36 | f.write(file_data)
37 | call([apbs_bin_path,
38 | f'{apbs_input_file.absolute()}'],
39 | cwd=str(output_path))
40 | apbs_input_file.unlink()
41 | o, d, potential = APBS._import_dx(apbs_output_file)
42 | apbs_output_file.unlink()
43 | return potential
44 |
45 | @staticmethod
46 | def _replace_apbs(
47 | filedata,
48 | xxx,
49 | grid_dim,
50 | grid_space,
51 | center,
52 | cglen,
53 | fglen):
54 | filedata = filedata \
55 | .replace('XXX', xxx) \
56 | .replace('GRID_DIM', grid_dim) \
57 | .replace('GRID_SPACE', grid_space) \
58 | .replace('INH_CENTER', center) \
59 | .replace('CG_LEN', cglen) \
60 | .replace('FG_LEN', fglen)
61 | return filedata
62 |
63 | @staticmethod
64 | def _import_dx(filename):
65 | origin = delta = data = dims = None
66 | counter = 0
67 | with open(filename, 'r') as dxfile:
68 | for row in dxfile:
69 | row = row.strip().split()
70 | if not row:
71 | continue
72 | if row[0] == '#':
73 | continue
74 | elif row[0] == 'origin':
75 | origin = np.array(row[1:], dtype=float)
76 | elif row[0] == 'delta':
77 |                     delta = np.array(row[1:], dtype=float)
78 | elif row[0] == 'object':
79 | if row[1] == '1':
80 | dims = np.array(row[-3:], dtype=int)
81 | data = np.empty(np.prod(dims))
82 | elif isfloat(row[0]):
83 | data[3 * counter:min(3 * (counter + 1), len(data))
84 | ] = np.array(row, dtype=float)
85 | counter += 1
86 | data = data.reshape(dims)
87 | return origin, delta, data
88 |
89 | @staticmethod
90 | def _export_dx(filename, density, origin, delta):
91 | nx, ny, nz = density.shape
92 | with open(filename, 'w') as dxfile:
93 | dxfile.write(
94 | f'object 1 class gridpositions counts {nx} {ny} {nz}\n')
95 | dxfile.write(f'origin {origin[0]} {origin[1]} {origin[2]}\n')
96 | dxfile.write(f'delta {delta} 0.0 0.0\n')
97 | dxfile.write(f'delta 0.0 {delta} 0.0\n')
98 | dxfile.write(f'delta 0.0 0.0 {delta}\n')
99 | dxfile.write(
100 | f'object 2 class gridconnections counts {nx}, {ny}, {nz}\n')
101 | dxfile.write(
102 | f'object 3 class array type double rank 0 items {nx * ny * nz} data follows\n')
103 | i = 1
104 | for d in density.flatten(order='C'):
105 | if i % 3:
106 | dxfile.write('{} '.format(d))
107 | else:
108 | dxfile.write('{}\n'.format(d))
109 | i += 1
110 |
111 | dxfile.write('\n')
112 |
--------------------------------------------------------------------------------
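
A minimal round-trip sketch for the DX helpers above (the file name is arbitrary):

import numpy as np
from RosENet.voxelization.apbs import APBS

density = np.arange(27, dtype=float).reshape(3, 3, 3)
APBS._export_dx("toy.dx", density, origin=(0.0, 0.0, 0.0), delta=1.0)
origin, delta, data = APBS._import_dx("toy.dx")
assert np.allclose(data, density)  # the grid is recovered unchanged
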
/RosENet/preprocessing/make_complex_pdb.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from RosENet.preprocessing.make_ligand_params_pdb import MakeLigandParamsPDB
3 | from RosENet.preprocessing.step import Step
4 | from prody import parsePDB, writePDB
5 | import RosENet.constants as constants
6 | from htmd.molecule.molecule import Molecule
7 | from htmd.molecule.voxeldescriptors import getVoxelDescriptors
8 | from htmd.builder.preparation import proteinPrepare
9 |
10 | def standardize_residues(structure):
11 | """Rename the residues of a structure to their standard names.
12 |
13 | structure : prody.AtomSelection
14 | Structure to be handled
15 | """
16 | residue_names = structure.getResnames()
17 | for nonstd, std in constants.nonstd2stdresidues.items():
18 | residue_names[residue_names==nonstd] = std
19 | structure.setResnames(residue_names)
20 |
21 | def fix_chains(structure):
22 | """Rename chain names to reserve special names
23 | (W for water, X for ligand, Z for metals)
24 |
25 | structure : prody.AtomSelection
26 | Structure to be handled
27 | """
28 | chains = list(set(structure.getChids()))
29 | valid_chains = "ABCDEFGHIJKLMNOPQRSTUVwxYz0123456789abcdefghijklmnopqestuvy"
30 | chain_dict = dict(zip(chains, valid_chains[:len(chains)]))
31 | structure.setChids(np.vectorize(chain_dict.get)(structure.getChids()))
32 |
33 | def fix_ligand_names(structure):
34 | """Rename ligand residue names and chain names to their reserved names.
35 |
36 | structure : prody.AtomSelection
37 | Structure to be handled
38 | """
39 | n_atoms = structure.numAtoms()
40 | structure.setResnames(np.array([constants.ligand_resname]*n_atoms))
41 | structure.setChids(np.array([constants.ligand_chid]*n_atoms))
42 |
43 | def fix_water_chains(structure):
44 | """Rename water molecule chain to its reserved name.
45 |
46 | structure : prody.AtomSelection
47 | Structure to be handled
48 | """
49 | chains = structure.getChids()
50 | residues = structure.getResnames()
51 | chains[residues==constants.water_residue] = constants.water_chain
52 | structure.setChids(chains)
53 |
54 | def fix_metal_chains(structure):
55 | """Rename metal atom chain to its reserved name.
56 |
57 | structure : prody.AtomSelection
58 | Structure to be handled
59 | """
60 | chains = structure.getChids()
61 | residues = structure.getResnames()
62 | for metal in constants.accepted_metals:
63 | chains[residues==metal] = constants.metal_chain
64 | structure.setChids(chains)
65 |
66 | def cleanup_and_merge(pdb_object):
67 | """Apply fixes to protein and ligand structures and write complex.
68 |
69 | pdb_object : PDBObject
70 | PDB structure to be handled
71 | """
72 | protein = pdb_object.protein.pdb.read()
73 | ligand = pdb_object.ligand.pdb.read()
74 | fix_chains(protein)
75 | fix_ligand_names(ligand)
76 | complex = protein + ligand
77 | standardize_residues(complex)
78 | fix_water_chains(complex)
79 | fix_metal_chains(complex)
80 | pdb_object.complex.pdb.write(complex)
81 |
82 | def protein_optimization(complex_path):
83 |     """Rebuild the complex file from its protein and ligand parts (the proteinPrepare optimization below is currently disabled).
84 |
85 | complex_path : pathlib.Path
86 | Path to complex file
87 | """
88 | complex = Molecule(str(complex_path))
89 | prot = complex.copy(); prot.filter("protein")
90 | lig = complex.copy(); lig.filter(constants.ligand_selector)
91 | #prot = proteinPrepare(prot, pH=7.0)
92 | mol = Molecule(name="complex")
93 | mol.append(prot)
94 | mol.append(lig)
95 | mol.write(str(complex_path))
96 |
97 | class MakeComplexPDB(metaclass=Step,requirements=[MakeLigandParamsPDB]):
98 | @classmethod
99 | def files(cls, pdb_object):
100 | """List of files being created
101 |
102 | pdb_object : PDBObject
103 | PDB structure being handled
104 | """
105 | return [pdb_object.complex.pdb]
106 |
107 | @classmethod
108 | def _run(cls, pdb_object):
109 | """Inner function for the preprocessing step.
110 |
111 | pdb_object : PDBObject
112 | PDB structure being handled
113 | """
114 | complex_path = pdb_object.complex.pdb.path
115 | cleanup_and_merge(pdb_object)
116 | protein_optimization(complex_path)
117 | complex = pdb_object.complex.pdb.read()
118 | standardize_residues(complex)
119 | pdb_object.complex.pdb.write(complex)
120 |
121 |
--------------------------------------------------------------------------------
/RosENet/preprocessing/preprocessing.py:
--------------------------------------------------------------------------------
1 | from string import Template
2 | from .make_complex_pdb import make_complex_pdb as _make_complex_pdb
3 | from .make_ligand_mol2_renamed import make_ligand_mol2_renamed as _make_ligand_mol2_renamed
4 | from .compute_rosetta_energy import compute_rosetta_energy as _compute_rosetta_energy
5 | import RosENet.rosetta.rosetta as rosetta
6 | import RosENet.storage.storage as storage
7 | import RosENet.constants as constants
8 | import subprocess
9 |
10 | def make_ligand_params_pdb(pdb_object):
11 | ligand_mol2_path = pdb_object.ligand.mol2.path
12 | params_filename = ligand_mol2_path.stem
13 | working_directory = ligand_mol2_path.parent
14 | return rosetta.molfile_to_params(
15 | working_directory = working_directory,
16 | output_path = params_filename,
17 | input_path = ligand_mol2_path)
18 |
19 | def make_complex_pdb(pdb_object):
20 | protein_path = pdb_object.protein.pdb.path
21 | ligand_path = pdb_object.ligand.pdb.path
22 | complex_path = pdb_object.complex.pdb.path
23 | return _make_complex_pdb(protein_path, ligand_path, complex_path)
24 |
25 | def minimize_rosetta(pdb_object):
26 | generate_minimization_flags_file(pdb_object)
27 | generate_constraint_file(pdb_object)
28 | working_directory = pdb_object.path
29 | pdb_object.minimized.scores.delete()
30 | rosetta.minimize(working_directory = working_directory)
31 | hide_non_minimal_complexes(pdb_object)
32 |
33 | def generate_minimization_flags_file(pdb_object):
34 | complex_path = pdb_object.complex.pdb.path
35 | complex = complex_path.name
36 | name = complex_path.stem
37 | params = pdb_object.ligand.params.path.name
38 | template = Template(storage.read_plain(constants.flags_relax_path))
39 | substitution = {'complex' : complex,
40 | 'name' : name,
41 | 'params' : params}
42 | output = template.substitute(substitution)
43 | pdb_object.flags_relax.write(output)
44 |
45 | def generate_constraint_file(pdb_object):
46 | complex = pdb_object.complex.pdb.read()
47 | output_path = pdb_object.constraints.path
48 | metals = complex.select(constants.metal_selector)
49 | results = []
50 | if metals:
51 | for atom in metals:
52 | pos = atom.getCoords()
53 | close_ligand = complex.select(constants.close_ligand_selector, t=pos)
54 | if close_ligand:
55 | for close in close_ligand:
56 | results.append((atom.getName(),
57 | str(atom.getResnum())+atom.getChid(),
58 | close.getName(),
59 | str(close.getResnum())+close.getChid()))
60 | pdb_object.constraints.write(
61 | [f"AtomPair {r[0]} {r[1]} {r[2]} {r[3]} SQUARE_WELL 2.5 -2000\n" for r in results])
62 |
63 | def hide_non_minimal_complexes(pdb_object):
64 | scores = rosetta.parse_scores(pdb_object.minimized.scores.read())
65 | hidden_folder = pdb_object.minimized.hidden_complexes.path
66 |     pdb_object.minimized.hidden_complexes.delete()
67 | storage.make_directory(hidden_folder)
68 | for number in list(scores.keys())[1:]:
69 | complex_path = pdb_object.minimized.complex.pdb[number].path
70 | storage.move(complex_path, hidden_folder, no_fail=True)
71 |
72 | def make_protein_pdb(pdb_object):
73 | complex = pdb_object.minimized.complex.pdb.read()
74 | protein = complex.select(constants.protein_selector)
75 | pdb_object.minimized.protein.pdb.write(protein)
76 |
77 | def make_ligand_mol2(pdb_object):
78 | complex_path = pdb_object.minimized.complex.pdb.path
79 | renamed_mol2_path = pdb_object.ligand.renamed_mol2.path
80 | ligand_path = pdb_object.minimized.ligand.mol2.path
81 | rosetta.pdb_to_molfile(
82 | mol = renamed_mol2_path,
83 | complex_path = complex_path,
84 | output_path = ligand_path)
85 |
86 | def make_pdbqt(pdb_object):
87 | path = pdb_object.path
88 | subprocess.run([constants.mgl_python_path,
89 | constants.preprocess_vina_path,
90 | path])
91 |
92 | def compute_rosetta_energy(pdb_object):
93 | complex_path = pdb_object.minimized.complex.pdb.path
94 | params = pdb_object.ligand.params.path
95 | output_path = pdb_object.minimized.complex.attr.path
96 | _compute_rosetta_energy(complex_path, params, output_path)
97 |
98 | def make_ligand_mol2_renamed(pdb_object):
99 | ligand_pdb_path = pdb_object.ligand.pdb.path
100 | ligand_mol2_path = pdb_object.ligand.mol2.path
101 | output_path = pdb_object.ligand.renamed_mol2.path
102 | _make_ligand_mol2_renamed(ligand_pdb_path, ligand_mol2_path, output_path)
103 |
104 |
--------------------------------------------------------------------------------
/RosENet/storage/.ropeproject/config.py:
--------------------------------------------------------------------------------
1 | # The default ``config.py``
2 | # flake8: noqa
3 |
4 |
5 | def set_prefs(prefs):
6 | """This function is called before opening the project"""
7 |
8 | # Specify which files and folders to ignore in the project.
9 | # Changes to ignored resources are not added to the history and
10 | # VCSs. Also they are not returned in `Project.get_files()`.
11 | # Note that ``?`` and ``*`` match all characters but slashes.
12 | # '*.pyc': matches 'test.pyc' and 'pkg/test.pyc'
13 | # 'mod*.pyc': matches 'test/mod1.pyc' but not 'mod/1.pyc'
14 | # '.svn': matches 'pkg/.svn' and all of its children
15 | # 'build/*.o': matches 'build/lib.o' but not 'build/sub/lib.o'
16 | # 'build//*.o': matches 'build/lib.o' and 'build/sub/lib.o'
17 | prefs['ignored_resources'] = ['*.pyc', '*~', '.ropeproject',
18 | '.hg', '.svn', '_svn', '.git', '.tox']
19 |
20 | # Specifies which files should be considered python files. It is
21 | # useful when you have scripts inside your project. Only files
22 | # ending with ``.py`` are considered to be python files by
23 | # default.
24 | #prefs['python_files'] = ['*.py']
25 |
26 | # Custom source folders: By default rope searches the project
27 | # for finding source folders (folders that should be searched
28 | # for finding modules). You can add paths to that list. Note
29 | # that rope guesses project source folders correctly most of the
30 | # time; use this if you have any problems.
31 | # The folders should be relative to project root and use '/' for
32 | # separating folders regardless of the platform rope is running on.
33 | # 'src/my_source_folder' for instance.
34 | #prefs.add('source_folders', 'src')
35 |
36 | # You can extend python path for looking up modules
37 | #prefs.add('python_path', '~/python/')
38 |
39 | # Should rope save object information or not.
40 | prefs['save_objectdb'] = True
41 | prefs['compress_objectdb'] = False
42 |
43 | # If `True`, rope analyzes each module when it is being saved.
44 | prefs['automatic_soa'] = True
45 | # The depth of calls to follow in static object analysis
46 | prefs['soa_followed_calls'] = 0
47 |
48 | # If `False` when running modules or unit tests "dynamic object
49 | # analysis" is turned off. This makes them much faster.
50 | prefs['perform_doa'] = True
51 |
52 | # Rope can check the validity of its object DB when running.
53 | prefs['validate_objectdb'] = True
54 |
55 | # How many undos to hold?
56 | prefs['max_history_items'] = 32
57 |
58 | # Shows whether to save history across sessions.
59 | prefs['save_history'] = True
60 | prefs['compress_history'] = False
61 |
62 | # Set the number spaces used for indenting. According to
63 | # :PEP:`8`, it is best to use 4 spaces. Since most of rope's
64 | # unit-tests use 4 spaces it is more reliable, too.
65 | prefs['indent_size'] = 4
66 |
67 | # Builtin and c-extension modules that are allowed to be imported
68 | # and inspected by rope.
69 | prefs['extension_modules'] = []
70 |
71 | # Add all standard c-extensions to extension_modules list.
72 | prefs['import_dynload_stdmods'] = True
73 |
74 | # If `True` modules with syntax errors are considered to be empty.
75 | # The default value is `False`; When `False` syntax errors raise
76 | # `rope.base.exceptions.ModuleSyntaxError` exception.
77 | prefs['ignore_syntax_errors'] = False
78 |
79 | # If `True`, rope ignores unresolvable imports. Otherwise, they
80 | # appear in the importing namespace.
81 | prefs['ignore_bad_imports'] = False
82 |
83 | # If `True`, rope will insert new module imports as
84 |     # `from <package> import <module>` by default.
85 | prefs['prefer_module_from_imports'] = False
86 |
87 | # If `True`, rope will transform a comma list of imports into
88 | # multiple separate import statements when organizing
89 | # imports.
90 | prefs['split_imports'] = False
91 |
92 | # If `True`, rope will remove all top-level import statements and
93 | # reinsert them at the top of the module when making changes.
94 | prefs['pull_imports_to_top'] = True
95 |
96 | # If `True`, rope will sort imports alphabetically by module name instead of
97 | # alphabetically by import statement, with from imports after normal
98 | # imports.
99 | prefs['sort_imports_alphabetically'] = False
100 |
101 | # Location of implementation of rope.base.oi.type_hinting.interfaces.ITypeHintingFactory
102 | # In general case, you don't have to change this value, unless you're an rope expert.
103 | # Change this value to inject you own implementations of interfaces
104 | # listed in module rope.base.oi.type_hinting.providers.interfaces
105 | # For example, you can add you own providers for Django Models, or disable the search
106 | # type-hinting in a class hierarchy, etc.
107 | prefs['type_hinting_factory'] = 'rope.base.oi.type_hinting.factory.default_type_hinting_factory'
108 |
109 |
110 | def project_opened(project):
111 | """This function is called after opening the project"""
112 | # Do whatever you like here!
113 |
--------------------------------------------------------------------------------
/RosENet/preprocessing/preprocess_vina.py:
--------------------------------------------------------------------------------
1 | """
2 | Module extracted from prepare_receptor4.py and prepare_ligand4.py from AutoDockTools scripts.
3 | http://autodock.scripps.edu/faqs-help/how-to/how-to-prepare-a-receptor-file-for-autodock4
4 | http://autodock.scripps.edu/faqs-help/how-to/how-to-prepare-a-ligand-file-for-autodock4
5 | """
6 | import os
7 | from MolKit import Read
8 | import MolKit.molecule
9 | import MolKit.protein
10 | from AutoDockTools.MoleculePreparation import AD4ReceptorPreparation, AD4LigandPreparation
11 | import sys
12 | import getopt
13 |
14 |
15 | def preprocess_receptor(receptor_filename, outputfilename):
16 | repairs = ''
17 | charges_to_add = 'gasteiger'
18 | preserve_charge_types=None
19 | cleanup = ""
20 | mode = "automatic"
21 | delete_single_nonstd_residues = False
22 | dictionary = None
23 |
24 | mols = Read(receptor_filename)
25 | mol = mols[0]
26 | preserved = {}
27 | if charges_to_add is not None and preserve_charge_types is not None:
28 | preserved_types = preserve_charge_types.split(',')
29 | for t in preserved_types:
30 | if not len(t): continue
31 | ats = mol.allAtoms.get(lambda x: x.autodock_element==t)
32 | for a in ats:
33 | if a.chargeSet is not None:
34 | preserved[a] = [a.chargeSet, a.charge]
35 |
36 | if len(mols)>1:
37 | ctr = 1
38 | for m in mols[1:]:
39 | ctr += 1
40 | if len(m.allAtoms)>len(mol.allAtoms):
41 | mol = m
42 | mol.buildBondsByDistance()
43 |
44 | RPO = AD4ReceptorPreparation(mol, mode, repairs, charges_to_add,
45 | cleanup, outputfilename=outputfilename,
46 | preserved=preserved,
47 | delete_single_nonstd_residues=delete_single_nonstd_residues,
48 | dict=dictionary)
49 |
50 | if charges_to_add is not None:
51 | for atom, chargeList in preserved.items():
52 | atom._charges[chargeList[0]] = chargeList[1]
53 | atom.chargeSet = chargeList[0]
54 |
55 | def preprocess_ligand(ligand_filename, outputfilename):
56 | verbose = None
57 | repairs = "" #"hydrogens_bonds"
58 | charges_to_add = 'gasteiger'
59 | preserve_charge_types=''
60 | cleanup = ""
61 | allowed_bonds = "backbone"
62 | root = 'auto'
63 | check_for_fragments = True
64 | bonds_to_inactivate = ""
65 | inactivate_all_torsions = True
66 | attach_nonbonded_fragments = True
67 | attach_singletons = True
68 | mode = "automatic"
69 | dict = None
70 |
71 | mols = Read(ligand_filename)
72 | if verbose: print 'read ', ligand_filename
73 | mol = mols[0]
74 | if len(mols)>1:
75 | ctr = 1
76 | for m in mols[1:]:
77 | ctr += 1
78 | if len(m.allAtoms)>len(mol.allAtoms):
79 | mol = m
80 | coord_dict = {}
81 | for a in mol.allAtoms: coord_dict[a] = a.coords
82 |
83 | mol.buildBondsByDistance()
84 | if charges_to_add is not None:
85 | preserved = {}
86 | preserved_types = preserve_charge_types.split(',')
87 | for t in preserved_types:
88 | if not len(t): continue
89 | ats = mol.allAtoms.get(lambda x: x.autodock_element==t)
90 | for a in ats:
91 | if a.chargeSet is not None:
92 | preserved[a] = [a.chargeSet, a.charge]
93 |
94 |
95 |
96 | LPO = AD4LigandPreparation(mol, mode, repairs, charges_to_add,
97 | cleanup, allowed_bonds, root,
98 | outputfilename=outputfilename,
99 | dict=dict, check_for_fragments=check_for_fragments,
100 | bonds_to_inactivate=bonds_to_inactivate,
101 | inactivate_all_torsions=inactivate_all_torsions,
102 | attach_nonbonded_fragments=attach_nonbonded_fragments,
103 | attach_singletons=attach_singletons)
104 | if charges_to_add is not None:
105 | for atom, chargeList in preserved.items():
106 | atom._charges[chargeList[0]] = chargeList[1]
107 | atom.chargeSet = chargeList[0]
108 | bad_list = []
109 | for a in mol.allAtoms:
110 | if a in coord_dict.keys() and a.coords!=coord_dict[a]:
111 | bad_list.append(a)
112 | if len(bad_list):
113 | print len(bad_list), ' atom coordinates changed!'
114 | for a in bad_list:
115 | print a.name, ":", coord_dict[a], ' -> ', a.coords
116 | else:
117 | if verbose: print "No change in atomic coordinates"
118 | if mol.returnCode!=0:
119 | sys.stderr.write(mol.returnMsg+"\n")
120 |
121 | def process_folder(receptor, ligand):
122 | try:
123 | preprocess_receptor(receptor, receptor.replace('.pdb','.pdbqt'))
124 | except Exception, e:
125 | print 'Protein', receptor
126 | raise e
127 | try:
128 | preprocess_ligand(ligand, ligand.replace('.mol2','.pdbqt'))
129 | except Exception, e:
130 | print 'Ligand', ligand
131 | raise e
132 |
133 | if __name__ == "__main__":
134 | process_folder(sys.argv[1], sys.argv[2])
135 |
--------------------------------------------------------------------------------
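
preprocess_vina.py is written for the Python 2 interpreter bundled with MGLTools (note the print-statement syntax), which is why make_pdbqt runs it as a subprocess via constants.mgl_python_path instead of importing it. An illustrative invocation (paths are placeholders):

/path/to/mgltools/bin/pythonsh RosENet/preprocessing/preprocess_vina.py 10gs_protein.pdb 10gs_ligand.mol2
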
/RosENet/objects/model.py:
--------------------------------------------------------------------------------
1 | import RosENet.storage.storage as storage
2 | import RosENet.network.input as input
3 | import RosENet.settings as settings
4 | import tensorflow as tf
5 | import importlib.util
6 | import numpy as np
7 | from RosENet.objects.file import File
8 |
9 | class _ModelAction:
10 | """Model action base class. Represents an action (training, evaluation or prediction)
11 | of a dataset using a model(neural network)"""
12 | def __init__(self, model_object, dataset_object, channels, action, seed):
13 | self.dataset_object = dataset_object
14 | if not self.results.path.exists():
15 | self.results.write("")
16 | self.model_object = model_object
17 | self.channels = channels
18 | input_fn = input.load_fn(dataset_object, channels, settings.rotate, action)
19 | self.iterator = input_fn().make_initializable_iterator()
20 | model_fn = model_object.load_fn()
21 | self.id, self.X, self.y = self.iterator.get_next()
22 | self.shape = tf.shape(self.y)[0]
23 | model = model_fn(self.X, self.y, action, settings.rotate)
24 | self.op = model.train_op
25 | self.loss = model.loss
26 | self.predictions = model.predictions
27 | checkpoint_folder = self.dataset_object.model(self.model_object, self.channels, seed)
28 | storage.make_directory(checkpoint_folder)
29 | self.save_path = checkpoint_folder / "model.ckpt"
30 | self.saver = tf.train.Saver()
31 |
32 | def do(self, sess = None):
33 | if sess is None:
34 | sess = tf.get_default_session()
35 | try:
36 | sess.run(self.iterator.initializer)
37 | self.start_epoch()
38 | while True:
39 | self.do_batch(sess)
40 | except (KeyboardInterrupt, SystemExit):
41 | raise
42 |         except tf.errors.OutOfRangeError:
43 |             pass  # iterator exhausted: end of the epoch
44 | return self.end_epoch()
45 |
46 | def save(self, sess = None):
47 | if sess is None:
48 | sess = tf.get_default_session()
49 | self.saver.save(sess, str(self.save_path.absolute()))
50 |
51 | @property
52 | def dataset_name(self):
53 | return self.dataset_object.name
54 |
55 | @property
56 | def model_name(self):
57 | return self.model_object.name
58 |
59 | @property
60 | def results(self):
61 | return File.create(f"{self.dataset_name}.txt",self.dataset_object.path)
62 |
63 | def load(self, sess = None):
64 | if sess is None:
65 | sess = tf.get_default_session()
66 | self.saver.restore(sess, str(self.save_path))
67 |
68 | class _ModelTrain(_ModelAction):
69 | def __init__(self, model_path, dataset_path, channels,seed):
70 | super(_ModelTrain, self).__init__(model_path,
71 | dataset_path,
72 | channels,
73 | tf.estimator.ModeKeys.TRAIN,
74 | seed)
75 | def start_epoch(self):
76 | self.size = 0
77 | self.loss_value = 0
78 |
79 | def do_batch(self, sess):
80 | _, loss, batch_shape = sess.run([self.op, self.loss, self.shape])
81 | self.size += batch_shape
82 |         self.loss_value += batch_shape*loss
83 |
84 | def end_epoch(self):
85 | return np.sqrt(self.loss_value/self.size)
86 |
87 | class _ModelEvaluate(_ModelAction):
88 | def __init__(self, model_path, dataset_path, channels, seed):
89 | super(_ModelEvaluate, self).__init__(model_path,
90 | dataset_path,
91 | channels,
92 | tf.estimator.ModeKeys.EVAL,
93 | seed)
94 | def start_epoch(self):
95 | self.size = 0
96 | self.loss_value = 0
97 |
98 | def do_batch(self, sess):
99 | loss, batch_shape = sess.run([self.loss, self.shape])
100 | self.size += batch_shape
101 |         self.loss_value += batch_shape*loss
102 |
103 | def end_epoch(self):
104 | return np.sqrt(self.loss_value/self.size)
105 |
106 | class _ModelPredict(_ModelAction):
107 | def __init__(self, model_path, dataset_path, channels, seed):
108 | super(_ModelPredict, self).__init__(model_path,
109 | dataset_path,
110 | channels,
111 | tf.estimator.ModeKeys.EVAL,
112 | seed)
113 | def start_epoch(self):
114 | self.prediction_dict = {}
115 |
116 | def do_batch(self, sess):
117 | names, predictions = sess.run([self.id, self.predictions])
118 | for name, prediction in zip(names, predictions):
119 | self.prediction_dict[str(name)] = prediction
120 |
121 | def end_epoch(self):
122 | return self.prediction_dict
123 |
124 | class _Model:
125 | """Inner Model class, represents a CNN."""
126 | _instance_dict = {}
127 | def __init__(self, path):
128 | self.path = path
129 |
130 | def train_object(self, dataset_object, channels, seed):
131 | return _ModelTrain(self, dataset_object, channels, seed)
132 |
133 | def evaluate_object(self, dataset_object, channels, seed):
134 | return _ModelEvaluate(self, dataset_object, channels, seed)
135 |
136 | def predict_object(self, dataset_object, channels, seed):
137 | return _ModelPredict(self, dataset_object, channels, seed)
138 |
139 | @property
140 | def name(self):
141 | return self.path.stem
142 |
143 | def load_fn(self):
144 | spec = importlib.util.spec_from_file_location("aux_module", str(self.path.absolute()))
145 | foo = importlib.util.module_from_spec(spec)
146 | spec.loader.exec_module(foo)
147 | return foo.model_fn
148 | #return __import__(str(self.path.absolute()), fromlist=["model_fn"])
149 |
150 | def ModelObject(path):
151 | if str(path.absolute()) not in _Model._instance_dict:
152 | _Model._instance_dict[str(path.absolute())] = _Model(path)
153 | return _Model._instance_dict[str(path.absolute())]
154 |
155 |
--------------------------------------------------------------------------------
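
A small sketch of how a model file is resolved and loaded (assumes RosENet/models/resnet.py defines model_fn and that the TensorFlow environment is available):

from pathlib import Path
from RosENet.objects.model import ModelObject

model = ModelObject(Path("RosENet/models/resnet.py"))
print(model.name)           # "resnet"
model_fn = model.load_fn()  # imports the file and returns its model_fn
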
/RosENet/network/input.py:
--------------------------------------------------------------------------------
1 | import random
2 | import tensorflow as tf
3 | import numpy as np
4 | from RosENet.network.utils import random_rot, rots_90, random_rot_90, all_rot_90
5 | import RosENet.settings as settings
6 |
7 | def parse_fn(shape):
8 | """Parse function for the TensorFlow pipeline.
9 | The inner function outputs three tensors: the name of the complex,
10 | the image and the target binding affinity.
11 |
12 | shape : list of int
13 |         4D shape of the input image
14 | """
15 | def fn(example):
16 | example_fmt = {
17 | "id": tf.VarLenFeature(tf.string),
18 | "X": tf.FixedLenFeature((int(np.prod(shape)),), tf.float32),
19 | "y": tf.FixedLenFeature((1,), tf.float32)
20 | }
21 | parsed = tf.parse_single_example(example, example_fmt)
22 | X = tf.reshape(parsed['X'], shape=list(shape))
23 | y = parsed['y']
24 | id = parsed['id'].values
25 | return id, X, y
26 | return fn
27 |
28 | def rotate_fn(training, shape):
29 | """Rotation function for the TensorFlow pipeline.
30 | The inner function rotates the images. If training, the image will be rotated randomly to one of its 24 90º rotations. Otherwise, each image will be expanded to 24 images representing all possible 90º rotations.
31 |
32 | training : bool
33 | True if training
34 | shape : list of int
35 |         4D shape of the input images
36 | """
37 | def fn(id, X, y):
38 | if training:
39 | X = random_rot_90(X,list((-1,) + tuple(shape)))
40 | else:
41 | X = all_rot_90(X, list((-1,) + tuple(shape)))
42 | return id, X, y
43 | return fn
44 |
45 | def take_channels(ch, channel_order):
46 | """Extraction function for the TensorFlow pipeline.
47 | Given a list of channels to be extracted and a list of total channels,
48 | the inner function will output a new image tensor with only the
49 | selected channels.
50 |
51 | ch : list of string
52 | List of channels to be extracted
53 | channel_order : list of string
54 | Ordered list of the channels in the current input tensor
55 | """
56 | idx = [i for i,x in enumerate(channel_order) if x in ch]
57 | def f(id, X, y):
58 | ret = tf.gather(X,idx,axis=-1)
59 | return id, ret, y
60 | return ch, f
61 |
62 | def make_input_fn(input_path, shape, training, rot, merge):
63 | """Create TensorFlow input pipeline.
64 |
65 | input_path : pathlib.Path
66 | Globbed path for the .tfrecord files
67 | shape : list of int
68 |         4D shape of the input images
69 | training : bool
70 | True if training
71 | rot : bool
72 | True if rotations are enabled
73 |     merge : list of string
74 | List of channel selectors to extract channels
75 | """
76 | def in_fn():
77 | channel_order = ['htmd_hydrophobic',
78 | 'htmd_aromatic',
79 | 'htmd_hbond_acceptor',
80 | 'htmd_hbond_donor',
81 | 'htmd_positive_ionizable',
82 | 'htmd_negative_ionizable',
83 | 'htmd_metal',
84 | 'htmd_occupancies',
85 | 'htmd_hydrophobic',
86 | 'htmd_aromatic',
87 | 'htmd_hbond_acceptor',
88 | 'htmd_hbond_donor',
89 | 'htmd_positive_ionizable',
90 | 'htmd_negative_ionizable',
91 | 'htmd_metal',
92 | 'htmd_occupancies',
93 | 'elec_p',
94 | 'elec_l',
95 | 'rosetta_atr_p',
96 | 'rosetta_rep_p',
97 | 'rosetta_sol_p_pos',
98 | 'rosetta_elec_p_pos',
99 | 'rosetta_sol_p_neg',
100 | 'rosetta_elec_p_neg',
101 | 'rosetta_atr_l',
102 |                          'rosetta_rep_l',
103 | 'rosetta_sol_l_pos',
104 | 'rosetta_elec_l_pos',
105 | 'rosetta_sol_l_neg',
106 | 'rosetta_elec_l_neg']
107 | files = tf.data.Dataset.list_files(str(input_path))
108 | dataset = files.apply(
109 | tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset,
110 | settings.parallel_calls))
111 | dataset = dataset.map(parse_fn(shape), settings.parallel_calls)
112 | take = list(channel_order)
113 | filts = list(map(lambda m: lambda x: m in x, merge))
114 | take = [x for x in take if any([f(x) for f in filts])]
115 | channel_order, take_fn = take_channels(take, channel_order)
116 | dataset = dataset.map(take_fn, settings.parallel_calls)
117 | seed_t = tf.py_func(lambda : random.randint(0,214748364), [], tf.int64)
118 | dataset = dataset.shuffle(settings.shuffle_buffer_size, seed=seed_t)
119 | shape[-1] = len(channel_order)
120 | if not training and rot:
121 | dataset = dataset.batch(20)
122 | else:
123 | dataset = dataset.batch(settings.batch_size)
124 | dataset = dataset.prefetch(buffer_size=settings.prefetch_buffer_size)
125 | if rot:
126 | dataset = dataset.map(rotate_fn(training,shape),settings.parallel_calls)
127 | return dataset
128 | return in_fn
129 |
130 | def load_fn(dataset_object, channels, rotate, training):
131 | """Create input pipeline function.
132 |
133 | dataset_object : DatasetObject
134 | Dataset for the images being used
135 | channels : list of string
136 | Channel selectors to use
137 | rotate : bool
138 | True if rotations are enabled
139 | training : bool
140 | True if training
141 | """
142 | name = dataset_object.name
143 | shape = [settings.size]*3 + [30]
144 | training = (training == tf.estimator.ModeKeys.TRAIN)
145 | dataset_files = dataset_object.tfrecords / '*.tfrecords'
146 | return make_input_fn(dataset_files, shape, training, rotate, channels)
147 |
148 |
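149 | # Example wiring (editorial sketch): the closure returned by load_fn is what
150 | # tf.estimator expects as an input_fn, e.g.
151 | #
152 | #   input_fn = load_fn(dataset_object, ["aromatic", "acceptor", "ion", "rosetta"],
153 | #                      rotate=True, training=tf.estimator.ModeKeys.TRAIN)
154 | #   estimator.train(input_fn)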
--------------------------------------------------------------------------------
/RosENet/preprocessing/compute_rosetta_energy.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from RosENet.preprocessing.step import Step
3 | import argparse
4 | import os
5 | import h5py
6 | from pyrosetta.rosetta.protocols.scoring import Interface
7 | from pyrosetta.rosetta import *
8 | from pyrosetta import *
9 | from pathlib import Path
10 | import numpy as np
11 | from multiprocessing import Pool, cpu_count
12 | from collections import defaultdict
13 | from pyrosetta.toolbox.atom_pair_energy import etable_atom_pair_energies
14 | from RosENet.preprocessing.minimize_rosetta import MinimizeRosetta
15 | init('-in:auto_setup_metals')  # optionally add '-mute core.conformation.Conformation'
16 |
17 | def compute_atom_pair_energy(pdb_filename, ligand_params, interface_cutoff = 21.0):
18 | """Compute pairwise energies and aggregate them to atom-wise values.
19 |
20 | pdb_filename : pathlib.Path
21 | Path for the complex pdb file
22 | ligand_params : pathlib.Path
23 | Path for the ligand params file
24 | interface_cutoff : float
25 | Distance for which to check pairwise interactions between atoms
26 | """
27 |     if not isinstance(ligand_params, (list, tuple)):
28 |         ligand_params = [ligand_params]
29 |     ligand_params = Vector1([str(p) for p in ligand_params])
30 |
31 | pose = Pose()
32 | res_set = pose.conformation().modifiable_residue_type_set_for_conf()
33 | res_set.read_files_for_base_residue_types( ligand_params )
34 |
35 | pose.conformation().reset_residue_type_set_for_conf( res_set )
36 | pose_from_file(pose, str(pdb_filename))
37 | scorefxn = create_score_function('ref2015')
38 | pose_score = scorefxn(pose)
39 |
40 |     # Detect the jump whose downstream residue is the ligand (residue name WER)
41 | fold_tree = pose.fold_tree()
42 | for jump in range(1, pose.num_jump()+1):
43 | name = pose.residue(fold_tree.downstream_jump_residue(jump)).name()
44 | if name == 'WER':
45 | break
46 | interface = Interface(jump)
47 | interface.distance(interface_cutoff)
48 | interface.calculate(pose)
49 |
50 |     # Accumulate per-atom sums of the pairwise energies over all interface pairs
51 |     en = defaultdict(lambda: np.zeros((1, 4)))
52 | 
53 | for rnum1 in range(1, pose.total_residue() + 1):
54 | if interface.is_interface(rnum1):
55 | r1 = pose.residue(rnum1)
56 | for a1 in range(1, len(r1.atoms()) + 1):
57 | seq1 = pose.pdb_info().pose2pdb(rnum1).strip().replace(' ','-')
58 | at1 = r1.atom_name(a1).strip()
59 | key1 = seq1 + '-' + at1
60 | for rnum2 in range(rnum1+1, pose.total_residue() + 1):
61 | if interface.is_interface(rnum2):
62 | r2 = pose.residue(rnum2)
63 | for a2 in range(1, len(r2.atoms())+1):
64 | seq2 = pose.pdb_info().pose2pdb(rnum2).strip().replace(' ','-')
65 | at2 = r2.atom_name(a2).strip()
66 | key2 = seq2 + '-' + at2
67 | ee = etable_atom_pair_energies(r1, a1, r2, a2, scorefxn)
68 | if all(e == 0.0 for e in ee):
69 | continue
70 | en[key1] += np.array(ee)
71 | en[key2] += np.array(ee)
72 | energy_matrix = np.array([v for v in en.values()])
73 | return list(en.keys()), energy_matrix
74 |
75 | def get_radii_and_charges(pdb_filename, ligand_params):
76 | """Compute radii and charges for the atoms of the complex.
77 |
78 | pdb_filename : pathlib.Path
79 | Path for the complex pdb file
80 | ligand_params : pathlib.Path
81 | Path for the ligand params file
82 | """
83 | keys = []
84 | charges = []
85 | radii = []
86 |
87 |     if not isinstance(ligand_params, (list, tuple)):
88 |         ligand_params = [ligand_params]
89 |     ligand_params = Vector1([str(p) for p in ligand_params])
90 |
91 | pose = Pose()
92 | res_set = pose.conformation().modifiable_residue_type_set_for_conf()
93 | res_set.read_files_for_base_residue_types(ligand_params)
94 |
95 | pose.conformation().reset_residue_type_set_for_conf(res_set)
96 | pose_from_file(pose, str(pdb_filename))
97 | for rnum1 in range(1, pose.total_residue() + 1):
98 | r1 = pose.residue(rnum1)
99 | for a1 in range(1, len(r1.atoms()) + 1):
100 | seq1 = pose.pdb_info().pose2pdb(rnum1).strip().replace(' ','-')
101 | at1 = r1.atom_name(a1).strip()
102 | key1 = seq1 + '-' + at1
103 | charges.append(r1.atomic_charge(a1))
104 | radii.append(r1.atom_type(a1).lj_radius())
105 | keys.append(key1)
106 |
107 | return keys, charges, radii
108 |
109 | class ComputeRosettaEnergy(metaclass=Step,requirements=[MinimizeRosetta]):
110 |
111 | @classmethod
112 | def files(cls, pdb_object):
113 | """List of files being created
114 |
115 | pdb_object : PDBObject
116 | PDB structure being handled
117 | """
118 | return [pdb_object.minimized.complex.attr]
119 |
120 | @classmethod
121 | def _run(cls, pdb_object):
122 | """Inner function for the preprocessing step.
123 |
124 | pdb_object : PDBObject
125 | PDB structure being handled
126 | """
127 | pdb_file = pdb_object.minimized.complex.pdb.path
128 | folder = pdb_file.parent
129 | pdb_code = folder.stem
130 | ligand_params = pdb_object.ligand.params.path
131 | output_file = pdb_object.minimized.complex.attr.path
132 | try:
133 | e_keys, e_values = compute_atom_pair_energy(pdb_file, ligand_params)
134 | rc_keys, charges, radii = get_radii_and_charges(pdb_file, ligand_params)
135 | except Exception as e:
136 | print("Error at ", pdb_file)
137 | print(e)
138 | return
139 | energy_keys = np.array(e_keys)
140 | energy_values = np.array(e_values)
141 | rc_keys = np.array(rc_keys)
142 | radius_values = np.array(radii)
143 | charge_values = np.array(charges)
144 | np.savez_compressed(str(output_file),
145 |                             energy_keys=energy_keys,
146 | energy_values=energy_values,
147 | rc_keys=rc_keys,
148 | radius_values=radius_values,
149 | charge_values=charge_values)
150 |
151 |
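152 | # Reading the saved attributes back (editorial sketch; the actual file path is
153 | # managed by the storage layer, "complex.attr.npz" is a placeholder name):
154 | #
155 | #   data = np.load("complex.attr.npz")
156 | #   radii = dict(zip(data["rc_keys"], data["radius_values"]))
157 | #   charges = dict(zip(data["rc_keys"], data["charge_values"]))
158 | #   energies = dict(zip(data["energy_keys"], data["energy_values"].squeeze()))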
--------------------------------------------------------------------------------
/test_dataset/10gs/10gs_ligand.mol2:
--------------------------------------------------------------------------------
1 | ###
2 | ### Created by X-TOOL on Mon Sep 10 21:12:46 2018
3 | ###
4 |
5 | @MOLECULE
6 | 10gs_ligand
7 | 59 60 1 0 0
8 | SMALL
9 | GAST_HUCK
10 |
11 |
12 | @ATOM
13 | 1 N 15.0880 10.7980 23.5470 N.4 1 VWW 0.2328
14 | 2 CA 15.0100 9.9870 24.7920 C.3 1 VWW 0.0304
15 | 3 C 16.1150 8.9240 24.8300 C.2 1 VWW 0.0846
16 | 4 O 16.5200 8.5150 25.9400 O.co2 1 VWW -0.5643
17 | 5 CB 13.6350 9.3270 24.9080 C.3 1 VWW 0.0193
18 | 6 CG 13.3940 8.7080 26.2710 C.3 1 VWW 0.0441
19 | 7 CD 12.0450 8.0460 26.4020 C.2 1 VWW 0.1785
20 | 8 OE1 11.2930 7.9360 25.4350 O.2 1 VWW -0.3969
21 | 9 OXT 16.5780 8.5240 23.7440 O.co2 1 VWW -0.5643
22 | 10 N 11.7260 7.6420 27.6280 N.am 1 VWW -0.2648
23 | 11 CA 10.4720 6.9670 27.9340 C.3 1 VWW 0.1400
24 | 12 CB 10.7260 5.4840 28.2060 C.3 1 VWW 0.0361
25 | 13 SG 11.2910 4.5240 26.8100 S.3 1 VWW -0.1422
26 | 14 CD 9.7290 3.8040 26.2620 C.3 1 VWW 0.0276
27 | 15 CE 8.9300 3.1710 27.3700 C.ar 1 VWW -0.0332
28 | 16 CZ1 7.6400 3.6140 27.6500 C.ar 1 VWW -0.0595
29 | 17 CZ2 9.4640 2.1350 28.1330 C.ar 1 VWW -0.0595
30 | 18 CT1 6.8930 3.0370 28.6730 C.ar 1 VWW -0.0685
31 | 19 CT2 8.7230 1.5500 29.1610 C.ar 1 VWW -0.0685
32 | 20 CH 7.4370 2.0010 29.4300 C.ar 1 VWW -0.0687
33 | 21 C 9.8340 7.5500 29.1800 C.2 1 VWW 0.2059
34 | 22 O 10.5220 8.0230 30.0840 O.2 1 VWW -0.3942
35 | 23 N 8.5120 7.4680 29.2290 N.am 1 VWW -0.2587
36 | 24 CA 7.7400 7.9330 30.3660 C.3 1 VWW 0.1278
37 | 25 CB 6.5550 7.0620 30.6330 C.ar 1 VWW -0.0049
38 | 26 CG1 5.3300 7.3150 30.0270 C.ar 1 VWW -0.0564
39 | 27 CG2 6.6830 5.9410 31.4410 C.ar 1 VWW -0.0564
40 | 28 CD1 4.2500 6.4590 30.2200 C.ar 1 VWW -0.0682
41 | 29 CD2 5.6110 5.0810 31.6400 C.ar 1 VWW -0.0682
42 | 30 CE 4.3920 5.3390 31.0270 C.ar 1 VWW -0.0685
43 | 31 C 7.4520 9.4330 30.3540 C.2 1 VWW 0.0723
44 | 32 O 7.1160 9.9570 31.4330 O.co2 1 VWW -0.5642
45 | 33 OXT 7.5690 10.0680 29.2840 O.co2 1 VWW -0.5642
46 | 34 H1 14.3522 11.4870 23.5482 H 1 VWW 0.2010
47 | 35 H2 14.9824 10.1951 22.7461 H 1 VWW 0.2010
48 | 36 H3 15.9821 11.2614 23.5033 H 1 VWW 0.2010
49 | 37 H4 15.1478 10.6593 25.6517 H 1 VWW 0.1026
50 | 38 H5 13.5582 8.5382 24.1452 H 1 VWW 0.0363
51 | 39 H6 12.8628 10.0891 24.7265 H 1 VWW 0.0363
52 | 40 H7 13.4683 9.4998 27.0309 H 1 VWW 0.0505
53 | 41 H8 14.1719 7.9516 26.4522 H 1 VWW 0.0505
54 | 42 H9 12.3741 7.8085 28.3711 H 1 VWW 0.1883
55 | 43 H10 9.7864 7.0776 27.0810 H 1 VWW 0.0808
56 | 44 H11 9.7847 5.0400 28.5622 H 1 VWW 0.0426
57 | 45 H12 11.4878 5.4113 28.9961 H 1 VWW 0.0426
58 | 46 H13 9.9478 3.0333 25.5083 H 1 VWW 0.0538
59 | 47 H14 9.1210 4.5998 25.8070 H 1 VWW 0.0538
60 | 48 H15 7.2118 4.4189 27.0636 H 1 VWW 0.0557
61 | 49 H16 10.4667 1.7790 27.9259 H 1 VWW 0.0557
62 | 50 H17 5.8903 3.3927 28.8804 H 1 VWW 0.0599
63 | 51 H18 9.1502 0.7453 29.7483 H 1 VWW 0.0599
64 | 52 H19 6.8580 1.5482 30.2268 H 1 VWW 0.0559
65 | 53 H20 8.0270 7.0694 28.4506 H 1 VWW 0.1901
66 | 54 H21 8.3977 7.7875 31.2356 H 1 VWW 0.0887
67 | 55 H22 5.2146 8.1893 29.3967 H 1 VWW 0.0560
68 | 56 H23 7.6317 5.7341 31.9229 H 1 VWW 0.0560
69 | 57 H24 3.2999 6.6666 29.7411 H 1 VWW 0.0600
70 | 58 H25 5.7257 4.2088 32.2735 H 1 VWW 0.0600
71 | 59 H26 3.5543 4.6679 31.1783 H 1 VWW 0.0560
72 | @BOND
73 | 1 2 1 1
74 | 2 2 5 1
75 | 3 2 3 1
76 | 4 3 4 ar
77 | 5 3 9 ar
78 | 6 5 6 1
79 | 7 6 7 1
80 | 8 7 8 2
81 | 9 7 10 am
82 | 10 10 11 1
83 | 11 11 12 1
84 | 12 11 21 1
85 | 13 12 13 1
86 | 14 13 14 1
87 | 15 14 15 1
88 | 16 15 16 ar
89 | 17 15 17 ar
90 | 18 16 18 ar
91 | 19 17 19 ar
92 | 20 18 20 ar
93 | 21 19 20 ar
94 | 22 21 22 2
95 | 23 21 23 am
96 | 24 23 24 1
97 | 25 24 25 1
98 | 26 24 31 1
99 | 27 25 26 ar
100 | 28 25 27 ar
101 | 29 26 28 ar
102 | 30 27 29 ar
103 | 31 28 30 ar
104 | 32 29 30 ar
105 | 33 31 32 ar
106 | 34 31 33 ar
107 | 35 1 34 1
108 | 36 1 35 1
109 | 37 1 36 1
110 | 38 2 37 1
111 | 39 5 38 1
112 | 40 5 39 1
113 | 41 6 40 1
114 | 42 6 41 1
115 | 43 10 42 1
116 | 44 11 43 1
117 | 45 12 44 1
118 | 46 12 45 1
119 | 47 14 46 1
120 | 48 14 47 1
121 | 49 16 48 1
122 | 50 17 49 1
123 | 51 18 50 1
124 | 52 19 51 1
125 | 53 20 52 1
126 | 54 23 53 1
127 | 55 24 54 1
128 | 56 26 55 1
129 | 57 27 56 1
130 | 58 28 57 1
131 | 59 29 58 1
132 | 60 30 59 1
133 | @SUBSTRUCTURE
134 | 1 VWW 1
135 |
136 |
--------------------------------------------------------------------------------
/RosENet/network/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | def save_results(model_train_object, channels="", *args):
5 | """Save information about a trained model in csv format.
6 |
7 | model_train_object : _ModelTrain
8 | ModelTrain object representing the instantiation of the training
9 | channels : string
10 | Channel selectors used for the training
11 |     args : iterable
12 |         Additional values to be written in the file
13 |     """
14 |     result_line = ",".join([model_train_object.dataset_name, model_train_object.model_name] + [str(a) for a in args]) + "\n"
15 |     model_train_object.results.write(model_train_object.results.read() + result_line)  # append by rewriting the whole file
16 |
17 | def random_rotation_matrix():
18 | """Create a random 3D matrix in TensorFlow"""
19 | A = tf.random_normal((3,3))
20 | Q, _ = tf.linalg.qr(A)
21 | Q = tf.convert_to_tensor([[tf.sign(tf.linalg.det(Q)),0,0],
22 | [0,1,0],
23 |                               [0,0,1]],dtype=tf.float32) @ Q  # flip one axis if det(Q) < 0, so Q is a proper rotation
24 | return Q
25 |
26 | def line3D(o,t,step):  # `step` evenly spaced 3D points from o to t, shape (step, 3)
27 | l0 = tf.linspace(o[0],t[0],step)
28 | l1 = tf.linspace(o[1],t[1],step)
29 | l2 = tf.linspace(o[2],t[2],step)
30 | return tf.stack((l0,l1,l2),axis=1)
31 |
32 | def outer_sum_diag(a,b):  # broadcasted outer sum: adds every row of b to every entry of a
33 | return tf.expand_dims(a,axis=-2) + tf.expand_dims(b,axis=0)
34 |
35 | def random_rot(X, input_shape, output_shape):  # resample a randomly rotated cube from each batch element (nearest neighbor)
36 | output_step = output_shape[0]
37 | batch_size = tf.shape(X)[0]
38 | m = (np.array(input_shape[:-1]) - 1)/2
39 | o = -(np.array(output_shape)-1)/2
40 | corners = np.array([[0,0,0],
41 | [output_step-1,0,0],
42 | [0,output_step-1,0],
43 | [0,0,output_step-1]]) + o
44 | corners = tf.convert_to_tensor(corners, dtype=tf.float32)
45 | rotation_matrix = random_rotation_matrix()
46 | rotated_corners = corners @ rotation_matrix
47 | rotated_corners = rotated_corners + tf.convert_to_tensor(m, dtype=tf.float32)
48 | o = rotated_corners[0]
49 | a = rotated_corners[1]
50 | b = rotated_corners[2]
51 | c = rotated_corners[3]
52 | line_oa = line3D(o,a,output_step)
53 | vector_ob = line3D(o,b,output_step) - o
54 | vector_oc = line3D(o,c,output_step) - o
55 | plane_oab = outer_sum_diag(vector_ob,line_oa)
56 | rotated_query_points = outer_sum_diag(plane_oab, vector_oc)
57 | rotated_query_points = tf.cast(tf.round(rotated_query_points),tf.int64)
58 | rotated_query_points = tf.expand_dims(rotated_query_points, axis=0)
59 | rotated_query_points = tf.reshape(rotated_query_points, [1,-1,3])
60 | rotated_query_points = tf.tile(rotated_query_points, [batch_size, 1, 1])
61 | batch_column = tf.cast(tf.reshape(tf.tile(tf.reshape(tf.range(batch_size),(-1,1,1)), [1, 1, np.prod(output_shape)]),(-1,np.prod(output_shape),1)),tf.int64)
62 | rotated_query_points = tf.concat((batch_column,rotated_query_points),axis=2)
63 | X_output = tf.gather_nd(X, rotated_query_points)
64 | X_output = tf.reshape(X_output, [batch_size] + output_shape + [input_shape[-1]])
65 | return X_output
66 |
67 | def basic_rot(X, axis, ccw=True):  # 90-degree rotation of a 4D (x, y, z, channel) volume about the given axis
68 | rot_axes = [x for x in [0,1,2] if x != axis]
69 | axis_1, axis_2 = rot_axes
70 | perm = [x if x not in rot_axes else axis_1 if x == axis_2 else axis_2 for x in [0,1,2,3]]
71 | rev_axis = axis_1 if ccw else axis_2
72 | return tf.reverse(tf.transpose(X,perm),[rev_axis])
73 |
74 | def basic_rot_5D(X, axis, ccw=True):  # batched variant of basic_rot for 5D (batch, x, y, z, channel) tensors
75 | axis = axis + 1
76 | rot_axes = [x for x in [1,2,3] if x != axis]
77 | axis_1, axis_2 = rot_axes
78 | perm = [x if x not in rot_axes else axis_1 if x == axis_2 else axis_2 for x in [0,1,2,3,4]]
79 | rev_axis = axis_1 if ccw else axis_2
80 | return tf.reverse(tf.transpose(X,perm),[rev_axis])
81 |
82 |
83 | def rots_90(X):
84 | x_90 = basic_rot(X, 0)
85 | x_180 = tf.reverse(X, [1,2])
86 | x_270 = basic_rot(X, 0, ccw=False)
87 | maps = [X, x_90, x_180, x_270]
88 | extended_maps = list(maps)
89 | for m in maps:
90 | z_90 = basic_rot(m, 2)
91 | z_180 = tf.reverse(m, [0,1])
92 | z_270 = basic_rot(m, 2, ccw=False)
93 | y_90 = basic_rot(m, 1)
94 | y_270 = basic_rot(m, 1, ccw=False)
95 | extended_maps += [z_90, z_180, z_270, y_90, y_270]
96 | return tf.stack(extended_maps)
97 |
98 | def rots_90_5D(X):
99 | x_90 = basic_rot_5D(X, 0)
100 | x_180 = tf.reverse(X, [1,2])
101 | x_270 = basic_rot_5D(X, 0, ccw=False)
102 | maps = [X, x_90, x_180, x_270]
103 | extended_maps = list(maps)
104 | for m in maps:
105 | z_90 = basic_rot_5D(m, 2)
106 | z_180 = tf.reverse(m, [0,1])
107 | z_270 = basic_rot_5D(m, 2, ccw=False)
108 | y_90 = basic_rot_5D(m, 1)
109 | y_270 = basic_rot_5D(m, 1, ccw=False)
110 | extended_maps += [z_90, z_180, z_270, y_90, y_270]
111 | return tf.stack(extended_maps)
112 |
113 | def random_rot_90(X,shape):
114 | def my_func(X):
115 |         # X arrives as a numpy array via tf.py_func; pick one random 90-degree rotation
116 | anti = np.random.randint(0,2,1)
117 | if anti == 0:
118 | shift = np.random.randint(0,3,1)
119 | x, y, z = ((np.array([0,1,2]) + shift) % 3) + 1
120 | flip = np.random.choice([None,(1,2),(1,3),(2,3)],1)[0]
121 | X = X.transpose((0,x,y,z,4))
122 | if flip is not None:
123 | a,b = flip
124 | X = np.flip(X,(a,b))
125 | else:
126 | shift = np.random.randint(0,3,1)
127 | x, y, z = ((np.array([1,0,2]) + shift) % 3) + 1
128 | flip = np.random.choice([1,2,3,None],1)[0]
129 | X = X.transpose((0,x,y,z,4))
130 | if flip is not None:
131 | X = np.flip(X,flip)
132 | else:
133 | X = np.flip(X,(1,2,3))
134 | return X
135 | return tf.reshape(tf.py_func(my_func, [X], tf.float32),shape)
136 |
137 | def all_rot_90(X,shape):
138 | def my_func(X):
139 |         # X arrives as a numpy array via tf.py_func; enumerate all 24 rotations per example
140 | results = []
141 | for point in X:
142 | for shift in range(3):
143 | x, y, z = ((np.array([0,1,2]) + shift) % 3)
144 | for flip in [None,(0,1),(0,2),(1,2)]:
145 | aux = point.transpose((x,y,z,3))
146 | if flip is not None:
147 | a,b = flip
148 | results.append(np.flip(aux,(a,b)))
149 | else:
150 | results.append(aux)
151 | for shift in range(3):
152 | x, y, z = ((np.array([1,0,2]) + shift) % 3)
153 | for flip in [0,1,2,None]:
154 | aux = point.transpose((x,y,z,3))
155 | if flip is not None:
156 | results.append(np.flip(aux,flip))
157 | else:
158 | results.append(np.flip(aux,(0,1,2)))
159 | ret = np.stack(results,axis=0)
160 | return ret
161 | return tf.reshape(tf.py_func(my_func, [X], tf.float32),shape)
162 |
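163 | # Editorial note: rots_90/rots_90_5D enumerate the 24 proper rotations of a
164 | # cube (the four rotations about the x axis, each composed with one of six
165 | # further axis-aligned rotations). random_rot_90 applies one randomly chosen
166 | # rotation to a whole batch during training, while all_rot_90 expands every
167 | # example into all 24 rotations for evaluation.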
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![RosENet logo](logo.png)
2 |
3 | # RosENet
4 |
5 | This is the repository for the RosENet project.
6 |
7 | RosENet: A 3D Convolutional Neural Network for predicting absolute binding affinity using molecular mechanics energies
8 |
9 | Hussein Hassan-Harrirou¹, Ce Zhang¹, Thomas Lemmin¹,²,*
10 |
11 | ¹DS3Lab, Systems Group, Department of Computer Science, ETH Zurich, CH-8092 Zurich, Switzerland
12 |
13 | ²Institute of Medical Virology, University of Zurich (UZH), CH-8057 Zurich, Switzerland
14 |
15 | \*corresponding author: thomas.lemmin@inf.ethz.ch
16 |
17 | ## Prerequisites
18 |
19 | A script `install.sh` is included, which creates a Conda environment and installs the necessary requirements.
20 |
21 | Rosetta and PyRosetta must be installed manually because they require a license.
22 |
23 | PyRosetta can be made available by copying it next to the RosENet folder or by adding it to the `PYTHONPATH`.
24 |
25 | You must set the path to Rosetta's main folder in the attribute `rosetta.root` in file `RosENet/constants.py` [LINK](RosENet/constants.py#L65)
26 |
27 | ## How to run
28 |
29 | This repository is built as a Python module. To execute it, run the following command from the directory above the `RosENet` package (the repository root):
30 |
31 | ```
32 | python3 -m RosENet -- arguments
33 | ```
34 |
35 | Currently, six actions are implemented:
36 |
37 | 1. preprocess
38 | 2. voxelize
39 | 3. postprocess
40 | 4. train
41 | 5. evaluate
42 | 6. predict
43 |
44 | These actions are meant to be executed in this order.
45 | Preprocessing computes the file transformations and the energy minimization on the text files.
46 | Voxelization creates the 3D images for the HTMD, Rosetta and electronegativity features and combines them.
47 | Postprocessing transforms the 3D images into the TFRecord format used as input for the neural networks.
48 | Training, evaluation and prediction have the usual meanings.
49 |
50 | Substeps in `preprocess` are cached. If a step fails (usually due to a faulty installation), fix the underlying problem and then clean the cached dataset by running `bash clear_dataset.sh <dataset>`, where `<dataset>` is the path to the dataset (e.g. `test_dataset`).
51 |
52 | ## Folder structure
53 |
54 | To ease the setup, there are some requirements to the structure of the inputs.
55 | The coarsest unit of work is the dataset: a folder containing pdb folders and a file called `labels`.
56 | The next unit of work is the pdb folder, which stores a `{code}_protein.pdb` file and a `{code}_ligand.mol2` file.
57 | The `labels` file stores one line per pdb folder with the pdb code and the binding affinity assigned to it, separated by a space.
58 |
59 |
60 | A valid folder structure:
61 | ```
62 | test_dataset/
63 |   10gs/
64 | 10gs_protein.pdb
65 | 10gs_ligand.mol2
66 | labels
67 | ```
68 |
69 | A valid `labels` file:
70 | ```
71 | 10gs 12
72 | ```
73 |
74 | ## Neural networks
75 |
76 | There are three CNNs implemented in the module, but any other implementation can be used. The only requirement is following TensorFlow's Estimator API.
77 | It is easy to replicate the structure of the models given here; a minimal sketch follows.
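
Below is a minimal regression `model_fn` sketch following the Estimator API. The architecture and optimizer are illustrative assumptions, not one of the shipped models (see `RosENet/models` for those):

```python
import tensorflow as tf

def model_fn(features, labels, mode, params):
    # Assumption: `features` is the batched 5D voxel image tensor produced
    # by the input pipeline; the layer sizes below are illustrative only.
    net = tf.layers.conv3d(features, filters=32, kernel_size=3,
                           activation=tf.nn.relu)
    net = tf.layers.max_pooling3d(net, pool_size=2, strides=2)
    net = tf.layers.flatten(net)
    predictions = tf.squeeze(tf.layers.dense(net, 1), axis=-1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
    loss = tf.losses.mean_squared_error(labels, predictions)
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdamOptimizer(1e-4).minimize(
            loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
    return tf.estimator.EstimatorSpec(mode, loss=loss)
```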
78 |
79 | After training a neural network, the module will save the trained network in the dataset folder, with the essential parameters also written in the folder name. Additionally, a line with the minimal validation error and the epoch when it was achieved will be saved in a results file in the dataset folder. The random seed used during training will also be stored in this line. This seed will be necessary to load the network afterwards.
80 |
81 | ## Selecting feature subsets
82 |
83 | To select feature subsets, we can use channel selectors. These are implemented by matching each selector as a substring of the feature names. Multiple channel selectors can be combined with underscores; a sketch of the matching logic is shown at the end of this section.
84 |
85 | The feature names are:
86 |
87 | * `htmd_hydrophobic`
88 | * `htmd_aromatic`
89 | * `htmd_hbond_acceptor`
90 | * `htmd_hbond_donor`
91 | * `htmd_positive_ionizable`
92 | * `htmd_negative_ionizable`
93 | * `htmd_metal`
94 | * `htmd_occupancies`
95 | * `elec_p`
96 | * `elec_l`
97 | * `rosetta_atr_p`
98 | * `rosetta_rep_p`
99 | * `rosetta_sol_p_pos`
100 | * `rosetta_elec_p_pos`
101 | * `rosetta_sol_p_neg`
102 | * `rosetta_elec_p_neg`
103 | * `rosetta_atr_l`
104 | * `rosetta_rep_l`
105 | * `rosetta_sol_l_pos`
106 | * `rosetta_elec_l_pos`
107 | * `rosetta_sol_l_neg`
108 | * `rosetta_elec_l_neg`
109 |
110 | All `htmd` feature names represent both the protein and ligand features, so they are effectively two channels.
111 |
112 | For example, using `htmd_rosetta` will include all the HTMD features, and all the Rosetta features.
113 |
114 | The combination used in the paper release is `aromatic_acceptor_ion_rosetta`, adding up to 20 feature maps, as the sketch below verifies.
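
As a sketch of the matching logic (mirroring `RosENet/network/input.py`), a channel is kept when any selector appears as a substring of its name:

```python
channel_order = [
    "htmd_hydrophobic", "htmd_aromatic", "htmd_hbond_acceptor",
    "htmd_hbond_donor", "htmd_positive_ionizable",
    "htmd_negative_ionizable", "htmd_metal", "htmd_occupancies",
] * 2 + [  # HTMD channels appear twice: once for the protein, once for the ligand
    "elec_p", "elec_l",
    "rosetta_atr_p", "rosetta_rep_p", "rosetta_sol_p_pos", "rosetta_elec_p_pos",
    "rosetta_sol_p_neg", "rosetta_elec_p_neg",
    "rosetta_atr_l", "rosetta_rep_l", "rosetta_sol_l_pos", "rosetta_elec_l_pos",
    "rosetta_sol_l_neg", "rosetta_elec_l_neg",
]

selectors = "aromatic_acceptor_ion_rosetta".split("_")
kept = [c for c in channel_order if any(s in c for s in selectors)]
print(len(kept))  # 20 feature maps, the paper configuration
```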
115 |
116 |
117 | ## Setting the randomness
118 |
119 | The neural network actions use a seed parameter. For training, this seed is optional and can be randomly chosen by the module.
120 |
121 | The seed is used to identify different trainings of the same model/data. When evaluating and predicting, the seed must be specified. The seed can be found in the name of the trained model folder, located under the training dataset folder; the folder name joins the training parameters, including the seed, with underscores.
122 |
123 | ## Settings
124 |
125 | There are a few parameters that may be modified to change the behavior of the module. The file `settings.py` stores these options.
126 | There we can change the voxelization methods and the size of the voxel image, as well as the parameters of TensorFlow's input pipeline; the options referenced from the code are sketched below.
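
The following sketch lists the options that the code reads from `settings.py` (the names are taken from their uses as `settings.*` in `RosENet/network/input.py`); the values shown are illustrative assumptions, not the shipped defaults:

```python
size = 25                   # voxels per side of the cubic 3D image
batch_size = 64             # batch size of the input pipeline
shuffle_buffer_size = 1000  # tf.data shuffle buffer
prefetch_buffer_size = 1    # tf.data prefetch buffer
parallel_calls = 4          # parallelism for tf.data map/interleave calls
```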
127 |
128 | ## Running instructions
129 |
130 | To run the actions mentioned above with the example dataset, we need to run the following commands:
131 |
132 | ### Data preparation
133 |
134 | For the example dataset, substitute `<dataset>` with `test_dataset`.
135 |
136 | 1. preprocess
137 | ```
138 | python3 -m RosENet -- preprocess <dataset>
139 | ```
140 | 2. voxelize
141 | ```
142 | python3 -m RosENet -- voxelize <dataset>
143 | ```
144 | 3. postprocess
145 | ```
146 | python3 -m RosENet -- postprocess <dataset>
147 | ```
148 |
149 | ### Network training/evaluation
150 |
151 | Substitute `<*_dataset>` with the paths to the respective datasets.
152 |
153 | Substitute `<model>` with the path of a network model file. Some examples are located under `RosENet/models` (e.g. `RosENet/models/resnet.py`).
154 |
155 | Substitute `<channels>` with an underscore-separated string of channel selectors.
156 |
157 | Substitute `<seed>` with a non-negative integer to be used as seed.
158 |
159 | 4. train
160 | ```
161 | python3 -m RosENet -- train <train_dataset> <validation_dataset> <model> <channels> [<seed>]
162 | ```
163 | 5. evaluate
164 | ```
165 | python3 -m RosENet -- evaluate <train_dataset> <evaluate_dataset> <model> <channels> <seed>
166 | ```
167 | 6. predict
168 | ```
169 | python3 -m RosENet -- predict <train_dataset> <predict_dataset> <model> <channels> <seed>
170 | ```
171 |
172 | ### Preprocessed datasets
173 |
174 | The datasets for training and validation of RosENet have been published and are accessible at https://doi.org/10.5281/zenodo.4625486
175 |
176 |
--------------------------------------------------------------------------------
/RosENet/voxelization/voxelizers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from htmd.molecule.molecule import Molecule
3 | from htmd.molecule.voxeldescriptors import getVoxelDescriptors
4 | from operator import itemgetter
5 | from itertools import compress
6 | from prody import parsePDB, calcCenter, moveAtoms
7 | from mendeleev import element
8 | from types import SimpleNamespace
9 | from RosENet.preprocessing.make_pdbqt import MakePDBQT
10 | from RosENet.preprocessing.compute_rosetta_energy import ComputeRosettaEnergy
11 | from .utils import *
12 | from .filter import voxel_filter
13 | from .interpolation import voxel_interpolation
14 |
15 |
16 |
17 | class Voxelizer:
18 | """Voxelizer abstract class.
19 | Voxelizers implement the algorithms that generate 3D voxelized images of
20 | structural attributes.
21 |
22 | Parameters
23 | ----------
24 | size : int
25 | Size of each side of the 3D image, represented as a cube of voxels.
26 |
27 | Attributes
28 | ----------
29 | size : int
30 | Size of each side of the 3D image, represented as a cube of voxels.
31 | protein : SimpleNamespace
32 | Storage object for protein-specific values used during computation and
33 | protein-specific images.
34 | ligand : SimpleNamespace
35 | Storage object for ligand-specific values used during computation and
36 | ligand-specific images.
37 | image : numpy.ndarray
38 | 3D image of the structure obtained as a result of the voxelization.
39 |
40 | """
41 |
42 | def __init__(self, size):
43 |         self.protein = SimpleNamespace()
44 |         self.ligand = SimpleNamespace()
45 | self.image = None
46 | self.size = size
47 |
48 | def voxelize(self):
49 | """Compute the voxelized 3D image and store it in self.image"""
50 |         raise NotImplementedError  # implemented by the concrete voxelizers
51 |
52 |
53 | class PointwiseVoxelizer(Voxelizer):
54 | """Base voxelizer class implementing the methods to generate 3D images given
55 | pointwise-values assigned to each atom.
56 |
57 | See Also
58 | --------
59 | RosettaVoxelizer : Voxelizer for the Rosetta energy-function features.
60 | ElectronegativityVoxelizer : Voxelizer for the electronegativity features.
61 |
62 | Parameters
63 | ----------
64 |     complex_path : str or os.PathLike
65 |         Path to the complex structure, a .pdb file.
66 |     data_object : file object
67 |         Storage whose read() yields the precomputed per-atom energies,
68 |         radii and charges (see ComputeRosettaEnergy).
69 |     size : int
70 |         Size of each side of the 3D image, represented as a cube of voxels.
71 |     method_type : str
72 |         Voxelization method name, either "filter" or "interpolation".
73 |     method_fn : str or callable
74 |         If method_type == "filter", a two-parameter function implementing the
75 |         filter; if "interpolation", a string or function handle suitable for scipy.interpolate.Rbf.
76 |     """
77 |
78 | def __init__(self, complex_path, data_object, size, method_type, method_fn):
79 | super(PointwiseVoxelizer, self).__init__(size)
80 | self.path = str(complex_path)
81 | self.data_object = data_object
82 | self.method_type = method_type
83 | self.method_fn = method_fn
84 |
85 | def _prepare_points(self):
86 | """Load structures and compute the location of the points of the
87 | 3D image to be generated.
88 | """
89 | self.complex = parsePDB(self.path)
90 | protein = self.complex.select("not (resname WER or water)")
91 | ligand = self.complex.select("resname WER")
92 | center = calcCenter(ligand.getCoords())
93 | moveAtoms(self.complex, by=-center)
94 |         center = calcCenter(self.complex.select("resname WER").getCoords())  # recompute after recentering (now ~ the origin)
95 | self.protein.structure = protein
96 | self.ligand.structure = ligand
97 | self.points = grid_around(center, self.size, spacing=24/(self.size-1))
98 |
99 | def _load_attributes(self):
100 | """Load the computed energies, radii and charges for the atoms in the
101 |         complex. These may not cover all atoms, but only those within roughly
102 |         20 Å of the ligand's center of mass."""
103 | data = self.data_object.read()
104 | self.radii = dict(zip(data['rc_keys'], data['radius_values']))
105 | self.charges = dict(zip(data['rc_keys'], data['charge_values']))
106 | self.energies = data['energy_values'].squeeze()
107 | self.energy_keys = data['energy_keys']
108 | self.energy_dict = dict(zip(self.energy_keys, self.energies))
109 |
110 | def _apply_filter(self):
111 | """Apply the filter to generate the voxelized images, for both protein
112 | and ligand separately."""
113 | self.protein.image = voxel_filter(
114 | self.method_fn, self.protein, self.points)\
115 | .reshape(3*(self.size,) + (-1,))
116 | self.ligand.image = voxel_filter(
117 | self.method_fn, self.ligand, self.points)\
118 | .reshape(3*(self.size,) + (-1,))
119 |
120 | def _apply_interpolation(self):
121 | """Apply the interpolation method to generate the voxelized images, for
122 | both protein and ligand separately."""
123 | self.protein.image = voxel_interpolation(
124 | self.method_fn, self.protein, self.points)\
125 | .reshape(3*(self.size,) + (-1,))
126 | self.ligand.image = voxel_interpolation(
127 | self.method_fn, self.ligand, self.points)\
128 | .reshape(3*(self.size,) + (-1,))
129 |
130 | def _prepare_attributes(self, obj):
131 | """Special method for loading and masking the atom coordinates, names,
132 | radii and charges, masking the selection to only the atoms that have
133 | energies computed.
134 |
135 | Notes
136 | -----
137 |
138 | This method is only used to load the attributes of self.protein
139 | and self.ligand.
140 |
141 | Parameters
142 | ----------
143 | obj : SimpleNamespace
144 | """
145 | try:
146 | obj.keys = get_keys(obj.structure)
147 | obj.atoms_in_scope = [x in self.energy_keys for x in obj.keys]
148 | obj.coordinates = np.compress(obj.atoms_in_scope, obj.structure.getCoords(), axis=0).reshape((-1, 3))
149 | obj.keys = list(compress(obj.keys, obj.atoms_in_scope))
150 | obj.radii = np.array(itemgetter(*obj.keys)
151 | (self.radii)).reshape((-1, 1))
152 | obj.charges = np.array(itemgetter(*obj.keys)
153 | (self.charges)).reshape((-1, 1))
154 | except Exception as e:
155 | print(self.path)
156 | print("#"*81)
157 | raise e
158 |
159 | def _prepare_values(self):
160 | """Voxelizer specific value loading."""
161 | pass
162 |
163 | def _normalize(self):
164 | """Voxelizer specific after-voxelization normalization."""
165 | pass
166 |
167 | def _merge(self):
168 | """Merges protein and ligand 3D images into the resulting self.image."""
169 | self.image = np.concatenate(
170 | (self.protein.image, self.ligand.image), axis=-1)
171 |
172 | def voxelize(self):
173 | """Main voxelization method, overrides Voxelizer's voxelize method.
174 | Implements the steps to perform a pointwise-valued voxelization."""
175 | self._load_attributes()
176 | self._prepare_points()
177 | self._prepare_attributes(self.protein)
178 | self._prepare_attributes(self.ligand)
179 | self._prepare_values()
180 | if self.method_type == "filter":
181 | self._apply_filter()
182 | elif self.method_type == "interpolation":
183 | self._apply_interpolation()
184 | self._normalize()
185 | self._merge()
186 |
187 | class ElectronegativityVoxelizer(PointwiseVoxelizer):
188 | """Voxelizer class to generate 3D images given per-atom electronegativity
189 | values.
190 |
191 | See Also
192 | --------
193 | RosettaVoxelizer : Voxelizer for the Rosetta energy-function features.
194 | PointwiseVoxelizer : Base voxelizer for pointwise features.
195 | """
196 | def _prepare_values(self):
197 | """Sets up the electronegativities as values for the voxelization,
198 | for both protein and ligand."""
199 | elements = set(self.complex.getElements())
200 | el_dict = {name: element(
201 | name.capitalize()).en_pauling for name in elements}
202 | self.protein.values = np.array(list(compress((el_dict[name] for name in self.protein.structure.getElements()), self.protein.atoms_in_scope)))\
203 | .reshape((-1, 1))
204 | self.ligand.values = np.array(list(compress([el_dict[name] for name in self.ligand.structure.getElements()], self.ligand.atoms_in_scope)))\
205 | .reshape((-1, 1))
206 |
207 |
208 | def _normalize(self):
209 | """Normalizes the electronegativities by dividing by the highest values
210 | in the protein's and ligand's electronegativies."""
211 | self.protein.image = self.protein.image / 3.44
212 | self.ligand.image = self.ligand.image / 3.98
213 |
214 |
215 | class RosettaVoxelizer(PointwiseVoxelizer):
216 | """Voxelizer class to generate 3D images given per-atom Rosetta
217 | energy-function values.
218 |
219 | See Also
220 | --------
221 | ElectronegativityVoxelizer : Voxelizer for electronegativity features.
222 | PointwiseVoxelizer : Base voxelizer for pointwise features.
223 | """
224 | def _prepare_values(self):
225 | """Prepare energy values for voxelization"""
226 | self.protein.values = np.array(itemgetter(
227 | *self.protein.keys)(self.energy_dict)).reshape((-1, 4))
228 | self.ligand.values = np.array(itemgetter(
229 | *self.ligand.keys)(self.energy_dict)).reshape((-1, 4))
230 |
231 | def _normalize(self):
232 | """Normalize the energy maps to 0-1 range and split positive and
233 | negative maps."""
234 | protein_limits = [-2.27475644, 4.11150004e+02,
235 | 4.01129569e+00, 1.28136470e+00,
236 | -0.37756078, -3.79981632]
237 | ligand_limits = [-1.7244696, 0.59311066, 2.67294434,
238 | 0.40072521, -0.44943017, -2.00621753]
239 | self.protein.image = np.concatenate((self.protein.image,
240 |                                              self.protein.image[...,2:]),axis=-1)  # duplicate sol/elec channels so they can be split into positive/negative maps
241 | self.ligand.image = np.concatenate((self.ligand.image,
242 | self.ligand.image[...,2:]),axis=-1)
243 | self.protein.image = clip(self.protein.image, protein_limits)
244 | self.ligand.image = clip(self.ligand.image, ligand_limits)
245 |
246 |
247 | class HTMDVoxelizer(Voxelizer):
248 | """Base voxelizer class implementing the methods to generate 3D images given
249 | pointwise-values assigned to each atom.
250 |
251 | Parameters
252 | ----------
253 | protein_path : str or os.PathLike
254 |         Path to the protein structure, a .pdbqt file.
255 |     ligand_path : str or os.PathLike
256 |         Path to the ligand structure, a .pdbqt file.
257 | size : int
258 | Size of each side of the 3D image, represented as a cube of voxels.
259 | """
260 | def __init__(self, protein_path, ligand_path, size):
261 | super(HTMDVoxelizer, self).__init__(size)
262 | self.protein.path = protein_path
263 | self.ligand.path = ligand_path
264 |
265 | def _prepare_points(self):
266 | """Load structures and compute the location of the points of the
267 | 3D image to be generated.
268 | """
269 | protein = Molecule(str(self.protein.path))
270 | ligand = Molecule(str(self.ligand.path))
271 | protein.filter(
272 | 'not (water or name CO or name NI or name CU or name NA)')
273 | center = np.mean(ligand.get('coords'), axis=0)
274 | ligand.moveBy(-center)
275 | protein.moveBy(-center)
276 | center = np.mean(ligand.get('coords'), axis=0)
277 | self.protein.structure = protein
278 | self.ligand.structure = ligand
279 | self.points = grid_around(center, self.size, spacing=24/(self.size-1)).reshape((-1,3))
280 |
281 |
282 | def voxelize(self):
283 | """Compute the voxelized 3D image and store it in self.image"""
284 | self._prepare_points()
285 | self.protein.image = getVoxelDescriptors(
286 | self.protein.structure, usercenters=self.points)[0]\
287 | .reshape(3*(self.size,) + (-1,))
288 | self.ligand.image = getVoxelDescriptors(
289 | self.ligand.structure, usercenters=self.points)[0]\
290 | .reshape(3*(self.size,) + (-1,))
291 | self.image = np.concatenate((self.protein.image, self.ligand.image), axis=-1)
292 |
293 | def VoxelizeRosetta(pdb_object, method, size):
294 | """Wrapper function for voxelizing and storing Rosetta features.
295 |
296 | pdb_object : PDBObject
297 | PDB structure to voxelize.
298 | method : tuple of (string, string or callable)
299 | Voxelization method and function. See underlying function for more information.
300 | size : int
301 | Size of voxel cube side.
302 | """
303 | if not ComputeRosettaEnergy.computed(pdb_object):
304 | return False
305 | complex_path = pdb_object.minimized.complex.pdb.path
306 | data_object = pdb_object.minimized.complex.attr
307 | output_path = pdb_object.image.rosetta.path
308 | if output_path.exists():
309 | return True
310 | method_type, method_fn = method
311 | voxelizer = RosettaVoxelizer(complex_path, data_object, size, method_type, method_fn)
312 | voxelizer.voxelize()
313 | pdb_object.image.rosetta.write(voxelizer.image)
314 | return True
315 |
316 | def VoxelizeElectronegativity(pdb_object, method, size):
317 | """Wrapper function for voxelizing and storing electronegativity features.
318 |
319 | pdb_object : PDBObject
320 | PDB structure to voxelize.
321 | method : tuple of (string, string or callable)
322 | Voxelization method and function. See underlying function for more information.
323 | size : int
324 | Size of voxel cube side.
325 | """
326 | if not ComputeRosettaEnergy.computed(pdb_object):
327 | return False
328 | complex_path = pdb_object.minimized.complex.pdb.path
329 | data_object = pdb_object.minimized.complex.attr
330 | output_path = pdb_object.image.electronegativity.path
331 | if output_path.exists():
332 | return True
333 | method_type, method_fn = method
334 | voxelizer = ElectronegativityVoxelizer(complex_path, data_object, size, method_type, method_fn)
335 | voxelizer.voxelize()
336 |     pdb_object.image.electronegativity.write(voxelizer.image)
337 |     return True
338 | def VoxelizeHTMD(pdb_object, size):
339 | """Wrapper function for voxelizing and storing HTMD features.
340 |
341 | pdb_object : PDBObject
342 | PDB structure to voxelize.
343 | size : int
344 | Size of voxel cube side.
345 | """
346 | if not MakePDBQT.computed(pdb_object):
347 | return False
348 | protein_path = pdb_object.minimized.protein.pdbqt.path
349 | ligand_path = pdb_object.minimized.ligand.pdbqt.path
350 | output_path = pdb_object.image.htmd.path
351 | if output_path.exists():
352 | return True
353 | voxelizer = HTMDVoxelizer(protein_path, ligand_path, size)
354 | voxelizer.voxelize()
355 |     pdb_object.image.htmd.write(voxelizer.image)
356 |     return True
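357 | 
358 | # Editorial usage sketch: the voxelization step presumably calls these
359 | # wrappers once per structure, with the method given as a
360 | # (method_type, method_fn) tuple, e.g.
361 | #
362 | #   VoxelizeHTMD(pdb_object, settings.size)
363 | #   VoxelizeRosetta(pdb_object, ("interpolation", "linear"), settings.size)
364 | #   VoxelizeElectronegativity(pdb_object, ("interpolation", "linear"), settings.size)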
--------------------------------------------------------------------------------