├── user
│   ├── environment.yaml
│   ├── readme.md
│   ├── 0_get_file_from_hugging_face.ipynb
│   └── 2_mdCATH_ML.ipynb
├── generator
│   ├── process
│   │   ├── append_info_from_joined.py
│   │   ├── readme.md
│   │   ├── join_multiple_h5.py
│   │   ├── tools.py
│   │   ├── read_info.ipynb
│   │   ├── append_info_toh5.py
│   │   └── write_info_toh5.py
│   ├── builder
│   │   ├── scheduler.py
│   │   ├── utils.py
│   │   ├── trajManager.py
│   │   ├── generator.py
│   │   └── molAnalyzer.py
│   ├── prepare
│   │   └── run.py
│   └── sanity_check
│       └── sanity_check.py
├── LICENSE
├── analysis
│   ├── ks_tests_for_rsf_distributions_between_superfamilies.py
│   ├── ks_tests.out
│   ├── plot_from_h5.py
│   └── utils.py
├── user-utils
│   ├── README.md
│   ├── load_mdCATH.tcl
│   └── convert_mdCATH.py
├── README.md
└── .gitignore
/user/environment.yaml:
--------------------------------------------------------------------------------
1 | name: mdcath_torchmdnet
2 | channels:
3 |   - conda-forge
4 |   - defaults
5 | dependencies:
6 |   - torchmd-net>=2.4.0
7 |   - pip:
8 |     - ipykernel
9 |     - ase
--------------------------------------------------------------------------------
/user/readme.md:
--------------------------------------------------------------------------------
1 | ## mdCATH User Guide
2 | 
3 | To run the 2_mdCATH_ML.ipynb notebook, follow these steps to set up the environment:
4 | 
5 | 1. Create the environment by running the following command:
6 | 
7 | ```bash
8 | mamba env create -f environment.yaml
9 | ```
10 | 
11 | 2. Activate the environment named mdcath_torchmdnet and select it as the kernel in your notebook.
--------------------------------------------------------------------------------
/generator/process/append_info_from_joined.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | from tqdm import tqdm
3 | from tools import readPDBs
4 | 
5 | if __name__ == "__main__":
6 |     # Set the source/destination h5 files (mdcath_analysis or mdcath_source) and the domain list
7 |     source = '/PATH/TO/SOURCE/FILE/FROM/WHICH/TO/COPY.h5'
8 |     dest = 'h5files/mdcath_noh_source.h5'
9 |     pdb_list = '/PATH/TO/PDB/LIST/FILE.txt/OR/LIST'
10 | 
11 |     doms_list = readPDBs(pdb_list)
12 | 
13 |     with h5py.File(source, mode='r') as source_h5:
14 |         with h5py.File(dest, mode='a') as dest_h5:
15 |             for dom in tqdm(doms_list, total=len(doms_list)):
16 |                 # delete the group from dest_h5 if it already exists
17 |                 if dom in dest_h5:
18 |                     del dest_h5[dom]
19 |                 # copy the group from source_h5 to dest_h5
20 |                 source_h5.copy(dom, dest_h5)
--------------------------------------------------------------------------------
/generator/process/readme.md:
--------------------------------------------------------------------------------
1 | This directory contains the scripts to write essential info to a unique h5 file. There are two possible file types:
2 | - mdcath_source.h5: used by the mdcath dataloader in torchmd-net to select the data and set up the indices.
3 | - mdcath_analysis.h5: used to store the data for the analysis and visualization by analysis/plot_from_h5.py
4 | 
5 | Script usage:
6 | - append_info_from_joined.py: append info to an mdcath source/analysis h5 file by copying it from another source/analysis h5 file
7 | - append_info_toh5.py: append to or modify a source/analysis h5 file, retrieving the info from the mdCATH dataset h5 files
8 | - write_info_toh5.py: generate the selected file type (source/analysis) using batching
9 | - read_info.ipynb: inspect the content of the mdCATH source/analysis h5 file
10 | - join_multiple_h5.py: join multiple h5 files (batches), output of write_info_toh5.py, into a single h5 file
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 The authors
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/generator/builder/scheduler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | 
4 | 
5 | class ComputationScheduler:
6 |     """Class for the parallelization of the dataset generation and computation using batches."""
7 | 
8 |     def __init__(self, batchSize, startBatch, numBatches, molecules):
9 |         self.batchSize = batchSize
10 |         self.numBatches = numBatches
11 |         self.molecules = molecules
12 |         self.startBatch = startBatch if startBatch is not None else 0
13 | 
14 |     def getBatches(self):
15 |         self.allBatches = np.arange(self.startBatch, self.numBatches)
16 |         return self.allBatches
17 | 
18 |     def process(self, idBatch):
19 |         assert idBatch >= 0
20 |         assert idBatch in self.allBatches
21 |         self.idStart = self.batchSize * idBatch
22 |         self.idEnd = min(self.batchSize * (idBatch + 1), len(self.molecules))
23 |         indices = self.molecules[self.idStart : self.idEnd]
24 |         return indices
25 | 
26 |     def getFileName(self, outPath, idBatch, fileName=None):
27 |         name = fileName if fileName is not None else "cath_dataset"
28 |         resFile = os.path.join(
29 |             outPath,
30 |             f"{name}_{idBatch:06d}_{self.idStart:09d}_{self.idEnd-1:09d}.h5",
31 |         )
32 |         return resFile
33 | 
--------------------------------------------------------------------------------
/generator/process/join_multiple_h5.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import shutil
3 | import tempfile
4 | from glob import glob
5 | from tqdm import tqdm
6 | from os.path import join as opj
7 | 
8 | 
9 | def sorter(filelist):
10 |     """Return the files sorted by the last number (batch number) in the filename."""
11 |     sortdict = {}
12 |     for file in filelist:
13 |         sortdict[int(file.split("_")[-1].split(".")[0])] = file
14 |     return [sortdict[key] for key in sorted(sortdict.keys())]
15 | 
16 | 
17 | if __name__ == "__main__":
18 |     # Set the directory and base filename (mdcath_analysis or mdcath_source)
19 |     batches_dir = "h5files/batch_files/"
20 |     base_filename = "mdcath_analysis"
21 | 
22 |     h5_list = sorter(glob(opj(batches_dir, f"{base_filename}_*.h5")))
23 |     with tempfile.TemporaryDirectory() as temp:
24 |         with h5py.File(opj(temp, "merged.h5"), "w") as merged:
25 |             for h5_file in tqdm(h5_list, total=len(h5_list), desc="Merging"):
26 |                 with h5py.File(h5_file, "r") as h5:
27 |                     for key in h5.keys():
28 |                         if key in merged.keys():
29 |                             print(f"Key {key} already in merged file")
30 |                             continue
31 |                         merged.copy(h5[key], key)
32 | 
33 |         shutil.copyfile(opj(temp, "merged.h5"), f"./{base_filename}.h5")
34 |     print(f"Merged file saved to ./{base_filename}.h5")
35 | 
--------------------------------------------------------------------------------
/generator/builder/utils.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | 
4 | class LoadFromFile(argparse.Action):
5 |     """Load a configuration file and update the namespace"""
6 | 
7 |     def __call__(self, parser, namespace, values, option_string=None):
8 |         if values.name.endswith("yaml") or values.name.endswith("yml"):
9 |             with values as f:
10 |                 namespace.__dict__.update(yaml.load(f, Loader=yaml.FullLoader))
11 |             return
12 | 
13 |         with values as f:
14 |             content = f.read()
15 |             content = content.rstrip()
16 |             for line in content.split("\n"):
17 |                 k, v = line.split("=")
18 |                 typ = type(namespace.__dict__[k])
19 |                 v = typ(v)  # cast the value to the type of the existing default
20 | 
                namespace.__dict__[k] = v
21 | 
22 | def save_argparse(args, filename, exclude=None):
23 |     if filename.endswith("yaml") or filename.endswith("yml"):
24 |         if isinstance(exclude, str):
25 |             exclude = [exclude]
26 |         args = args.__dict__.copy()
27 |         for exl in exclude or []:
28 |             args.pop(exl, None)
29 |         yaml.dump(args, open(filename, "w"))
30 |     else:
31 |         raise ValueError("Configuration file should end with yaml or yml")
32 | 
33 | def readPDBs(pdbList):
34 |     if isinstance(pdbList, list):
35 |         return pdbList
36 |     pdblist = []
37 |     with open(pdbList, "r") as f:
38 |         for line in f:
39 |             pdblist.append(line.strip())
40 |     return sorted(pdblist)
--------------------------------------------------------------------------------
/analysis/ks_tests_for_rsf_distributions_between_superfamilies.py:
--------------------------------------------------------------------------------
1 | # Kolmogorov-Smirnov test for the difference between distributions of
2 | # secondary structure contents at the end of the trajectories between the
3 | # four CATH superfamilies (reviewer 4 request).
4 | 
5 | import scipy
6 | import pandas as pd
7 | 
8 | def ks_pairwise_tests(df):
9 |     for sf1 in range(4):
10 |         for sf2 in range(sf1 + 1, 4):
11 |             p = scipy.stats.ks_2samp(df.loc[df.sf == sf1 + 1, "all_alpha_beta"],
12 |                                      df.loc[df.sf == sf2 + 1, "all_alpha_beta"]).pvalue
13 |             print(f"KS test between superfamily {sf1+1} and {sf2+1}: p = {p}")
14 | 
15 | def mw_pairwise_tests(df):
16 |     for sf1 in range(4):
17 |         for sf2 in range(sf1 + 1, 4):
18 |             p = scipy.stats.mannwhitneyu(df.loc[df.sf == sf1 + 1, "all_alpha_beta"],
19 |                                          df.loc[df.sf == sf2 + 1, "all_alpha_beta"]).pvalue
20 |             print(f"MW test between superfamily {sf1+1} and {sf2+1}: p = {p}")
21 | 
22 | 
23 | T = 450
24 | timepoint = 400
25 | 
26 | print(f"Comparing at {T} K and {timepoint} ns")
27 | 
28 | print("Dataset 50...")
29 | d = pd.read_csv("HeatMap_RSF_vs_TIME_50Samples_4Superfamilies.csv.gz")
30 | df = d.loc[(d.temp == T) & (d.time_points == timepoint), ]
31 | ks_pairwise_tests(df)
32 | mw_pairwise_tests(df)
33 | 
34 | print("Dataset ALL...")
35 | d = pd.read_csv("HeatMap_RSF_vs_TIME_NoneSamples_4Superfamilies.csv.gz")
36 | df = d.loc[(d.temp == T) & (d.time_points == timepoint), ]
37 | ks_pairwise_tests(df)
38 | mw_pairwise_tests(df)
39 | 
--------------------------------------------------------------------------------
/analysis/ks_tests.out:
--------------------------------------------------------------------------------
1 | Comparing at 450 K and 400 ns
2 | Dataset 50...
3 | KS test between superfamily 1 and 2: p = 6.581573872996166e-38
4 | KS test between superfamily 1 and 3: p = 1.7307902029125603e-14
5 | KS test between superfamily 1 and 4: p = 5.372525605655362e-07
6 | KS test between superfamily 2 and 3: p = 7.681808528159001e-21
7 | KS test between superfamily 2 and 4: p = 3.515704846981462e-15
8 | KS test between superfamily 3 and 4: p = 0.00043792185579115856
9 | MW test between superfamily 1 and 2: p = 1.0457841378617276e-28
10 | MW test between superfamily 1 and 3: p = 6.4703084412835936e-18
11 | MW test between superfamily 1 and 4: p = 1.204967764513355e-05
12 | MW test between superfamily 2 and 3: p = 3.364412437589386e-15
13 | MW test between superfamily 2 and 4: p = 0.021690045008652396
14 | MW test between superfamily 3 and 4: p = 0.021690045008652396
15 | Dataset ALL...
16 | KS test between superfamily 1 and 2: p = 0.0 17 | KS test between superfamily 1 and 3: p = 0.0 18 | KS test between superfamily 1 and 4: p = 4.311088393615429e-08 19 | KS test between superfamily 2 and 3: p = 0.0 20 | KS test between superfamily 2 and 4: p = 9.618254579953343e-35 21 | KS test between superfamily 3 and 4: p = 1.5269634056171598e-10 22 | MW test between superfamily 1 and 2: p = 0.0 23 | MW test between superfamily 1 and 3: p = 0.0 24 | MW test between superfamily 1 and 4: p = 1.3523629171669599e-05 25 | MW test between superfamily 2 and 3: p = 3.0645888323936354e-283 26 | MW test between superfamily 2 and 4: p = 2.1610578472697972e-34 27 | MW test between superfamily 3 and 4: p = 9.648918941532586e-09 28 | -------------------------------------------------------------------------------- /analysis/plot_from_h5.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import math 3 | import numpy as np 4 | from tqdm import tqdm 5 | import seaborn as sns 6 | import json 7 | from os.path import join as opj 8 | import matplotlib.pyplot as plt 9 | import matplotlib.colors as mcolors 10 | from mpl_toolkits.axes_grid1.inset_locator import inset_axes 11 | from utils import * 12 | 13 | 14 | 15 | if __name__ == "__main__": 16 | output_dir = "figures/" 17 | h5metrics = h5py.File("../generator/process/h5files/mdcath_analysis.h5", "r") 18 | 19 | plot_len_trajs(h5metrics, output_dir) 20 | plot_numAtoms(h5metrics, output_dir) 21 | plot_numResidues(h5metrics, output_dir) 22 | plot_RMSD(h5metrics, output_dir, rmsdcutoff=10, yscale="linear") 23 | plot_RMSF(h5metrics, output_dir, yscale="linear", temp_oi=None) 24 | plot_numRes_trajLength(h5metrics, output_dir) 25 | plot_GyrRad_SecondaryStruc(h5metrics, output_dir, numSamples=6, shared_axes=False, plot_type=['A', 'B']) 26 | plot_solidFraction_RMSF(h5metrics, output_dir, numSamples=3, simplified=True, repl='1') 27 | plot_solidFraction_vs_numResidues(h5metrics, output_dir, mean_across='all', temps=None, simplified=True) 28 | plot_heatmap_ss_time_superfamilies(h5metrics, output_dir, mean_across='all', temps=None, num_pdbs=None, simplified=True) 29 | plot_ternary_superfamilies(h5metrics, output_dir, mean_across='all', temps=None, num_pdbs=None, cbar=True) 30 | plot_combine_metrics(h5metrics, output_dir) 31 | plot_maxNumNeighbors(h5metrics, output_dir, cutoff=['5A']) 32 | scatterplot_maxNumNeighbors_numNoHAtoms(h5metrics, output_dir, cutoff=['5A', '9A']) 33 | plot_numNoHAtoms(h5metrics, output_dir) 34 | -------------------------------------------------------------------------------- /user-utils/README.md: -------------------------------------------------------------------------------- 1 | # Helper functions for the mdCATH dataset 2 | 3 | 4 | ## 1. Command-line conversion 5 | 6 | Converts an mdCATH HDF5 file to PDB and XTC. 7 | 8 | #### Usage 9 | 10 | ```bash 11 | convert_mdCATH.py [-h] [--basename BASENAME] [--temp_list TEMP_LIST [TEMP_LIST ...]] [--replica_list REPLICA_LIST [REPLICA_LIST ...]] fn 12 | ``` 13 | 14 | Requires the `mdtraj` and `h5py` packages. 15 | 16 | 17 | 18 | ## 2. VMD/TCL 19 | 20 | The `load_mdCATH` VMD/TCL procedure loads molecular dynamics (MD) simulation data from a specified HDF5 file of the mdCATH dataset into VMD. 21 | 22 | **Note:** The utilities `h5ls` and `h5dump` are required and must be accessible in the system's path. 
23 | 
24 | 
25 | #### Usage
26 | ```tcl
27 | load_mdCATH filename temperature replica
28 | ```
29 | 
30 | #### Parameters
31 | 
32 | - `filename`: Path to the HDF5 file containing the MD simulation data.
33 | - `temperature`: Temperature (K) of the simulation data to load (320, 348, 379, 413, or 450).
34 | - `replica`: Identifier of the simulation replica (0 to 4).
35 | 
36 | 
37 | #### Return Values
38 | - On successful execution, the procedure sets up the molecular visualization with the loaded data but does not return a value.
39 | - On failure, it returns an error with a specific message detailing the cause of the failure.
40 | 
41 | 
42 | 
43 | #### Example Call
44 | ```tcl
45 | source load_mdCATH.tcl
46 | load_mdCATH cath_dataset_153lA00.h5 320 1
47 | ```
48 | 
49 | This call loads the MD simulation data from `cath_dataset_153lA00.h5` for the simulation run at 320 K, replica 1.
50 | 
51 | #### Notes
52 | - Ensure that the environment variable `TMPDIR` is set, as it is used to define temporary file paths.
53 | - This procedure assumes that the required utilities `h5ls` and `h5dump` are installed and accessible in the system's path.
54 | 
55 | 
56 | 
57 | ## 3. Python
58 | 
59 | Python functions are also provided to convert mdCATH HDF5 into [HTMD](https://software.acellera.com/htmd/index.html)/[MoleculeKit](https://software.acellera.com/moleculekit/index.html) and [MDTraj](https://www.mdtraj.org) trajectory objects, for further analysis. See the docstrings inside `convert_mdCATH.py` for usage.
--------------------------------------------------------------------------------
/generator/prepare/run.py:
--------------------------------------------------------------------------------
1 | """
2 | HTMD version used: 1.16
3 | Note: CATH files don't have HETATM
4 | wget ftp://orengoftp.biochem.ucl.ac.uk/cath/releases/latest-release/non-redundant-data-sets/cath-dataset-nonredundant-S20.pdb.tgz
5 | tar -zxvf cath-dataset-nonredundant-S20.pdb.tgz
6 | """
7 | 
8 | import os
9 | import ray
10 | import glob
11 | from htmd.ui import *
12 | 
13 | 
14 | def getBBsize(m):
15 |     Xmin = np.amin(m.coords[:, :, 0], axis=0)
16 |     Xmax = np.amax(m.coords[:, :, 0], axis=0)
17 |     return Xmax - Xmin
18 | 
19 | @ray.remote
20 | def cbuild(pdb):
21 |     # Get a cube fitting the bounding box of the given molecule + at least R per side
22 |     def getCubicSize(m, R=9.0):
23 |         box = getBBsize(m)
24 |         rmax = np.max(box)
25 |         rmax_h = rmax / 2.0 + R
26 |         cminmax = np.array([[-rmax_h, -rmax_h, -rmax_h], [rmax_h, rmax_h, rmax_h]])
27 |         return cminmax
28 | 
29 |     try:
30 |         m = Molecule(f"dompdb/{pdb}", type="pdb")
31 |         m.center()
32 |         nRes = len(np.unique(m.resid))
33 |         if nRes < 50 or nRes > 500:
34 |             raise ValueError(f"Domain has {nRes} residues, out of range.")
35 |         cminmax = getCubicSize(m)
36 |         if -2.0 * cminmax[0, 0] > 100:
37 |             raise ValueError(f"Cubic box {-2*cminmax[0,0]} too large, out of range.")
38 |         mp = proteinPrepare(m, pH=7.0)
39 |         ms = autoSegment(mp)
40 |         mw = solvate(ms, minmax=cminmax)
41 |         charmm.build(
42 |             mw,
43 |             topo=["top/top_all22star_prot.rtf", "top/top_water_ions.rtf"],
44 |             param=["par/par_all22star_prot.prm", "par/par_water_ions.prm"],
45 |             outdir=f"build/{pdb}",
46 |             saltconc=0.150,
47 |         )
48 |         return f"{pdb}: OK"
49 |     except Exception as e:
50 |         return f"{pdb}: {e}"
51 | 
52 | 
53 | if __name__ == "__main__":
54 |     # https://colab.research.google.com/github/ray-project/tutorial/blob/master/exercises/colab01-03.ipynb#scrollTo=IlrIrAyldfu4
55 |     ray.init()
56 | 
57 |     pdblist = glob.glob("dompdb/???????")
58 | 
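    # CATH domain files in dompdb/ are named by their 7-character domain ID
    # (e.g. 153lA00), which is exactly what the seven-'?' glob pattern matches.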
    pdblist = [os.path.basename(p) for p in pdblist]
59 |     out = [cbuild.remote(pdb) for pdb in pdblist]
60 |     outf = ray.get(out)
61 | 
62 |     with open("build_failures.log", "w") as f:
63 |         f.write("\n".join(outf))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # mdCATH Dataset Repository
2 | 
3 | Welcome to the mdCATH dataset repository! This repository houses all the scripts and notebooks used to generate, analyze, and validate the mdCATH dataset. The dataset is available on the Hugging Face platform. All mdCATH trajectories can be visualized directly on PlayMolecule without downloading anything, or downloaded from PlayMolecule in XTC format if needed.
4 | 
5 | ## Useful Links
6 | - Playmolecule: https://open.playmolecule.org/mdcath
7 | - Hugging Face: https://huggingface.co/datasets/compsciencelab/mdCATH 8 | 9 | ## Repository Structure 10 | 11 | - #### `user` 12 | - Provides tutorials and example scripts to help new users familiarize themselves with the dataset. 13 | - Step-by-step tutorials to guide users through common tasks and procedures using the dataset. 14 | - Example scripts that demonstrate practical applications of the dataset in research scenarios. 15 | 16 | - #### `user-utils` 17 | - TCL code to load mdCATH's HDF5 files in VMD (for end-users) 18 | - Python code to convert files to XTC format (for end-users) 19 | 20 | - #### `generator` 21 | - Directory with the scripts used to generate the dataset. 22 | - `builder/generator.py`: is the main script responsible for dataset creation. It processes a list of CATH domains and their molecular dynamics outputs to produce H5 files for the mdCATH dataset. It features multiprocessing to accelerate the dataset generation process. For each domain, an H5 file is created accompanied by a log file that records the progress. 23 | 24 | - #### `analysis` 25 | - Houses tools required for analyzing the dataset. 26 | - This directory includes various scripts and functions used to perform the analyses and generate the plots presented in the paper. 27 | 28 | 29 | ## Citation 30 | 31 | > Antonio Mirarchi, Toni Giorgino and Gianni De Fabritiis. *mdCATH: A Large-Scale MD Dataset for Data-Driven Computational Biophysics*. https://arxiv.org/abs/2407.14794 32 | 33 | ``` 34 | @misc{mirarchi2024mdcathlargescalemddataset, 35 | title={mdCATH: A Large-Scale MD Dataset for Data-Driven Computational Biophysics}, 36 | author={Antonio Mirarchi and Toni Giorgino and Gianni De Fabritiis}, 37 | year={2024}, 38 | eprint={2407.14794}, 39 | archivePrefix={arXiv}, 40 | primaryClass={q-bio.BM}, 41 | url={https://arxiv.org/abs/2407.14794}, 42 | } 43 | ``` 44 | -------------------------------------------------------------------------------- /user-utils/load_mdCATH.tcl: -------------------------------------------------------------------------------- 1 | proc load_mdCATH {fn temperature replica} { 2 | 3 | set pid [pid] 4 | 5 | # Try to execute h5ls and handle errors 6 | set status [catch {exec h5ls $fn} tmp] 7 | if {$status} { 8 | return -code error "Error executing h5ls on file $fn: $tmp" 9 | } 10 | set code [lindex $tmp 0] 11 | 12 | set tmpdir $::env(TMPDIR) 13 | set pdbname $tmpdir/loadmdcath.$pid.pdb 14 | 15 | # Handle potential errors from h5dump 16 | if {[catch {exec h5dump -b -o $pdbname -d /$code/pdbProteinAtoms $fn} result]} { 17 | return -code error "Error dumping pdbProteinAtoms from $fn: $result" 18 | } 19 | 20 | # Load the molecular data 21 | mol new $pdbname 22 | file delete $pdbname 23 | 24 | set N [molinfo top get numatoms] 25 | if {$N == 0} { 26 | return -code error "No atoms found in the molecule loaded from $pdbname" 27 | } 28 | 29 | animate delete all 30 | set cbin $tmpdir/loadmdcath.$pid.coords.bin 31 | if {[catch {exec h5dump -b -o $cbin -d /$code/sims${temperature}/$replica/coords $fn} result]} { 32 | return -code error "Error dumping coords from $fn: $result" 33 | } 34 | 35 | # Handle file opening and binary data reading 36 | if {[catch {open $cbin r} fp msg]} { 37 | return -code error "Error opening coordinates file $cbin: $msg" 38 | } 39 | fconfigure $fp -translation binary 40 | if {[catch {read $fp} cdat]} { 41 | close $fp 42 | return -code error "Error reading data from coordinates file $cbin" 43 | } 44 | close $fp 45 | file delete $cbin 46 | 47 | # Binary data processing 48 | 
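    # The coords dataset is float32 with shape (frames, atoms, 3), so the flat
    # list produced by the scan below holds T*N*3 values in frame-major order.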
set M [binary scan $cdat f* dat]
49 |     if {$M == 0} {
50 |         return -code error "Failed to scan binary data from $cbin"
51 |     }
52 | 
53 |     set L [llength $dat]
54 |     set T [expr {$L/$N/3.0}]
55 |     set N3 [expr {$N * 3}]
56 |     set N3m1 [expr {$N3-1}]
57 | 
58 |     puts "Assuming $T frames"
59 | 
60 |     set a [atomselect top all]
61 | 
62 |     for {set t 0} {$t<$T} {incr t} {
63 |         animate dup top
64 |         set xyz {}
65 |         set fcoor [lrange $dat 0 $N3m1]
66 |         set dat [lreplace $dat 0 $N3m1]
67 |         foreach {x y z} $fcoor {
68 |             lappend xyz [list $x $y $z]
69 |         }
70 |         $a set {x y z} $xyz
71 |         $a update
72 |     }
73 |     $a delete
74 |     mol rename top "mdCATH: $code $temperature $replica"
75 | 
76 |     # TODO set box
77 | }
--------------------------------------------------------------------------------
/generator/builder/trajManager.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join as opj
3 | from glob import glob
4 | import logging
5 | 
6 | # The exception cases are the trajectories that do not have a dir in gpugrid_extend_results, so gpugrid_run_results is used
7 | exception_cases = {"2k88A00_413_0", "3vsmA01_320_4", "3qdkA02_320_2", "3qdkA02_413_0", "4qxdB02_450_2"}
8 | 
9 | class TrajectoryFileManager:
10 |     def __init__(self, gpugridResultsPath, concatTrajPath):
11 |         """Initialize the TrajectoryFileManager object, which is responsible for managing the trajectory files.
12 |         Parameters
13 |         ----------
14 |         gpugridResultsPath: str
15 |             The path to the gpugrid results directory.
16 |         concatTrajPath: str or None
17 |             The path to the concatenated trajectory files.
18 |         """
19 |         self.logger = logging.getLogger("TrajectoryFileManager")
20 |         self.gpugridResultsPath = gpugridResultsPath
21 |         self.gpugridRunResults = "/workspace7/toni_cath/gpugrid_run_results/"
22 |         self.concatTrajPath = concatTrajPath
23 | 
24 |     def getTrajFiles(self, pdbname, temp, repl):
25 |         """Get the trajectory files from the input directory. If concatTrajPath is
26 |         not None, look for the concatenated trajectory file directly; in concatTrajPath the files are stored per replica.
27 |         Parameters
28 |         ----------
29 |         pdbname: str
30 |             The PDB name.
31 |         temp: str
32 |             The temperature.
33 |         repl: int
34 |             The replica number, to retrieve the corresponding trajectory file.
35 | 
36 |         Returns
37 |         -------
38 |         list
39 |             The list of trajectory files (xtc files).
40 |         """
41 |         basename = f"{pdbname}_{temp}_{repl}"
42 |         if self.concatTrajPath:
43 |             trajFiles = sorted(
44 |                 glob(opj(self.concatTrajPath, pdbname, f"{basename}.xtc"))
45 |             )
46 |             if len(trajFiles) > 0:
47 |                 return trajFiles
48 |             self.logger.info(
49 |                 f"No concatenated trajectory files found for {pdbname} at {temp}K"
50 |             )
51 | 
52 |         alltrajs = []
53 |         if basename not in exception_cases:
54 |             trajs = sorted(
55 |                 glob(opj(self.gpugridResultsPath, basename, f"{basename}*.xtc"))
56 |             )
57 |         else:
58 |             trajs = sorted(
59 |                 glob(opj(self.gpugridRunResults, basename, f"{basename}*.xtc"))
60 |             )
61 |         alltrajs.extend(trajs)
62 | 
63 |         assert len(alltrajs) > 0, "No trajectory files found"
64 |         alltrajs = self.orderTrajFiles(alltrajs)
65 |         return alltrajs
66 | 
67 |     def orderTrajFiles(self, trajFiles):
68 |         """Order the trajectory files by the traj index.
69 |         Parameters
70 |         ----------
71 |         trajFiles: list
72 |             The list of trajectory files.
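
        Returns
        -------
        list
            The trajectory files, sorted by ascending trajectory index.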
73 | """ 74 | sortDict = {} 75 | for traj in trajFiles: 76 | filename = os.path.basename(traj) 77 | trajid = int(filename.split("-")[-3]) 78 | sortDict[trajid] = traj 79 | return [sortDict[trajid] for trajid in sorted(sortDict.keys())] 80 | -------------------------------------------------------------------------------- /generator/process/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def readPDBs(pdbList): 4 | if isinstance(pdbList, list): 5 | return pdbList 6 | pdblist = [] 7 | with open(pdbList, "r") as f: 8 | for line in f: 9 | pdblist.append(line.strip()) 10 | return sorted(pdblist) 11 | 12 | def get_secondary_structure_compositions(dssp): 13 | '''This funtcion returns the percentage composition of alpha, beta and coil in the protein. 14 | A special "NA" code will be assigned to each "residue" in the topology which isn"t actually 15 | a protein residue (does not contain atoms with the names "CA", "N", "C", "O") 16 | ''' 17 | floatMap = {"H": 0, "B": 1, "E": 1, "G": 0, "I": 0, "T": 2, "S": 2, " ": 2, 'NA': 3} 18 | 19 | decoded_dssp = [el.decode() for el in dssp[0]] 20 | float_dssp = np.array([floatMap[el] for el in decoded_dssp]) 21 | unique, counts = np.unique(float_dssp, return_counts=True) 22 | numResAlpha, numResBeta, numResCoil = 0, 0, 0 23 | for u, c in zip(unique, counts): 24 | if u == 0: 25 | numResAlpha += c 26 | elif u == 1: 27 | numResBeta += c 28 | else: 29 | # NA or Coil 30 | numResCoil += c 31 | # percentage composition in alpha, beta and coil 32 | alpha_comp = (numResAlpha / np.sum(counts)) * 100 33 | beta_comp = (numResBeta / np.sum(counts)) * 100 34 | coil_comp = (numResCoil / np.sum(counts)) * 100 35 | 36 | return alpha_comp, beta_comp, coil_comp 37 | 38 | def get_max_neighbors(coords, distance): 39 | """This function computes the maximum number of neighbors for all the conformations in a replica using a distance threshold, 40 | Parameters: 41 | coords: np.array, shape=(num_frames, num_atoms, 3) 42 | distance: float, the distance threshold to consider two atoms as neighbors 43 | Returns: 44 | max_neighbors: int, the maximum number of neighbors found in the replica 45 | """ 46 | from scipy.spatial import cKDTree 47 | 48 | max_neighbors = 0 49 | for i in range(coords.shape[0]): 50 | tree = cKDTree(coords[i]) 51 | # Query the tree to find neighbors within the specified distance 52 | num_neighbors = tree.query_ball_tree(tree, distance) 53 | # Get the maximum number of neighbors for this conformation 54 | max_neighbors = max(max_neighbors, max(len(n) for n in num_neighbors)) 55 | return max_neighbors 56 | 57 | def get_solid_secondary_structure(dssp): 58 | """ This function returns the percentage of solid secondary structure in the protein, computed as 59 | the sum of alpha and beta residues over the total number of residues.""" 60 | floatMap = {"H": 0, "B": 1, "E": 1, "G": 0, "I": 0, "T": 2, "S": 2, " ": 2, 'NA': 3} 61 | decoded_dssp = [el.decode() for el in dssp] 62 | float_dssp = np.array([floatMap[el] for el in decoded_dssp]) 63 | unique, counts = np.unique(float_dssp, return_counts=True) 64 | numResAlpha, numResBeta, numResCoil = 0, 0, 0 65 | for u, c in zip(unique, counts): 66 | if u == 0: 67 | numResAlpha += c 68 | elif u == 1: 69 | numResBeta += c 70 | else: 71 | # NA or Coil 72 | numResCoil += c 73 | 74 | solid_secondary_structure = (numResAlpha+numResBeta)/np.sum(counts) 75 | return solid_secondary_structure -------------------------------------------------------------------------------- 
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # format to ignore 132 | *.npz 133 | *.npy 134 | *.csv 135 | *.png 136 | *.hdf5 137 | *.h5 138 | *.ckpt 139 | *.xyz 140 | *.traj 141 | *.json 142 | *.pdb 143 | *.yaml 144 | *.h5 145 | *.png 146 | *.log 147 | *.txt 148 | *.tiff 149 | benchmark/* 150 | generators/force_eval.py 151 | fastfolder_metadata.ipynb 152 | htmdMetrics/multipleMSM.sh 153 | fix_corrupted_pdbs/* 154 | analysis/structure_h5.ipynb 155 | analysis/check_mdtraj_functions.ipynb 156 | analysis/check_data.ipynb 157 | analysis/plot_metrics_from_csv.py 158 | test_dcdLoaders/* 159 | analysis/note_reproducibility.ipynb 160 | analysis/recover_corrupted_files_from_log.ipynb 161 | build_noh/* 162 | check_reproducibility/* 163 | fixer/* 164 | support/get_superfamily_distr.py 165 | process/append_info_from_joined.py 166 | support/getCATH_info.py 167 | test/test_force_agg.py 168 | analysis/domains_mdCATH.ipynb 169 | analysis/note.ipynb 170 | -------------------------------------------------------------------------------- /user/0_get_file_from_hugging_face.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### TUTORIAL TO GET mdCATH H5 FILES" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Use the HuggingFace API to download the mdCATH dataset in H5 format directly from the HuggingFace Hub. In this notebook, we demonstrate how to use the API module for this purpose. 
Alternatively, you can download the dataset by appending the desired filename to the base URL: `https://huggingface.co/datasets/compsciencelab/mdCATH/resolve/main/`" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 3, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import os\n", 24 | "from os.path import join as opj\n", 25 | "from huggingface_hub import HfApi\n", 26 | "from huggingface_hub import hf_hub_download\n", 27 | "from huggingface_hub import hf_hub_url" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# Initialize the API\n", 37 | "api = HfApi()" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 8, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "# Output directory\n", 47 | "data_root = '.'" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 9, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "data": { 57 | "application/vnd.jupyter.widget-view+json": { 58 | "model_id": "15a7d808356d4c428da37c4bde5217de", 59 | "version_major": 2, 60 | "version_minor": 0 61 | }, 62 | "text/plain": [ 63 | "mdcath_dataset_1r9lA02.h5: 0%| | 0.00/645M [00:00 {data[pdb][temperature][replica].attrs['numFrames']} frames\")\n", 69 | " print()" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "available_repls = list(data[pdb][temp].keys())\n", 79 | "print(f'Available replicas ({temp}K): {available_repls}')" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "for k in data[pdb][temp][repl].attrs.keys():\n", 89 | " print(f'trajectory {pdb}/{temp}/{repl} -->', k, data[pdb][temp][repl].attrs[k])" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "for dat in data[pdb][temp][repl].keys():\n", 99 | " print(dat, f'shape -> {data[pdb][temp][repl][dat].shape}')" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "if \"mdcath_analysis\" in file_name:\n", 109 | " ssd = data[pdb][temp][repl][\"solid_secondary_structure\"][:]\n", 110 | " gyration = data[pdb][temp][repl][\"gyration_radius\"][:]\n", 111 | " rmsf = data[pdb][temp][repl][\"rmsf\"][:]\n", 112 | " rmsd = data[pdb][temp][repl][\"rmsd\"][:]\n" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "if \"mdcath_analysis\" in file_name:\n", 122 | " fig = plt.figure(figsize=(6, 4))\n", 123 | " plt.title(f\"{pdb} - {temp}K - {repl}\\nRMSD\")\n", 124 | " plt.plot(rmsd)\n", 125 | " plt.xlabel(\"Frame\")\n", 126 | " plt.ylabel(\"RMSD (nm)\")\n", 127 | " plt.tight_layout()\n", 128 | " plt.show()" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "if \"mdcath_analysis\" in file_name:\n", 138 | " fig, axs = plt.subplots(1, 4, figsize=(25, 5))\n", 139 | " axs[0].plot(ssd)\n", 140 | " axs[0].set_ylabel(\"Solid Secondary Structure\")\n", 141 | " axs[0].set_xlabel(\"frame\")\n", 142 | " axs[0].set_title('Solid Secondary Structure')\n", 143 | "\n", 144 | " axs[1].plot(gyration)\n", 145 | " axs[1].set_ylabel(\"Gyration Radius (nm)\")\n", 
146 | " axs[1].set_xlabel(\"frame\")\n", 147 | " axs[1].set_title('Gyration Radius')\n", 148 | "\n", 149 | " axs[2].scatter(ssd, gyration)\n", 150 | " axs[2].set_xlabel(\"Solid Secondary Structure\")\n", 151 | " axs[2].set_ylabel(\"Gyration Radius (nm)\")\n", 152 | " axs[2].set_title('SSS vs GR')\n", 153 | "\n", 154 | " axs[3].plot(rmsf)\n", 155 | " axs[3].set_ylabel(\"RMSF (nm)\")\n", 156 | " axs[3].set_xlabel(\"residue\")\n", 157 | " axs[3].set_title('RMSF')\n", 158 | " fig.suptitle(f\"{pdb}\", fontsize=16)\n", 159 | "\n", 160 | " plt.subplots_adjust(wspace=0.25, top=0.85)\n", 161 | " plt.show()" 162 | ] 163 | } 164 | ], 165 | "metadata": { 166 | "kernelspec": { 167 | "display_name": "gemini2", 168 | "language": "python", 169 | "name": "python3" 170 | }, 171 | "language_info": { 172 | "codemirror_mode": { 173 | "name": "ipython", 174 | "version": 3 175 | }, 176 | "file_extension": ".py", 177 | "mimetype": "text/x-python", 178 | "name": "python", 179 | "nbconvert_exporter": "python", 180 | "pygments_lexer": "ipython3", 181 | "version": "3.10.13" 182 | } 183 | }, 184 | "nbformat": 4, 185 | "nbformat_minor": 2 186 | } 187 | -------------------------------------------------------------------------------- /generator/process/append_info_toh5.py: -------------------------------------------------------------------------------- 1 | # Use this script if you want to append information to an existing source/analysis file. 2 | # The script will append or modify the information in the source/analysis file. It works per domain-ID, 3 | # so all the temperatures and replicas will be update. 4 | # Use this if a small number of domains need to be updated, otherwise use the write_info_toh5.py script (multiprocessing supported). 5 | 6 | import sys 7 | import h5py 8 | import logging 9 | import numpy as np 10 | from tqdm import tqdm 11 | from os.path import join as opj 12 | from tools import get_secondary_structure_compositions, get_max_neighbors, get_solid_secondary_structure, readPDBs 13 | sys.path.append("/../builder/") 14 | 15 | # Create a custom logger 16 | logger = logging.getLogger('append_info_toh5') 17 | logger.setLevel(logging.INFO) # Set the logger level to the lowest level you want to log 18 | 19 | # Create handlers 20 | console_handler = logging.StreamHandler() 21 | file_handler = logging.FileHandler('append_to_mdcath_analysis.log') 22 | 23 | # Set levels for handlers 24 | console_handler.setLevel(logging.ERROR) # Only log errors to the console 25 | file_handler.setLevel(logging.INFO) # Log info and higher level messages to the file 26 | 27 | # Create formatters and add them to handlers 28 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 29 | console_handler.setFormatter(formatter) 30 | file_handler.setFormatter(formatter) 31 | 32 | # Add handlers to the logger 33 | logger.addHandler(console_handler) 34 | logger.addHandler(file_handler) 35 | 36 | 37 | if __name__ == '__main__': 38 | # Define the h5 file for which the information will be modified 39 | origin_file = 'mdcath_analysis.h5' 40 | data_dir = "PATH/TO/MDCATH/DATASET/DIR" 41 | pdb_list = ['1cqzB02', '4i69A00', '3qdkA02'] 42 | 43 | # Define the type of file to be written, source or analysis 44 | # Based on this different attributes will be written 45 | file_type = 'analysis' 46 | noh_mode = False 47 | pdb_list = readPDBs(pdb_list) 48 | if file_type == 'analysis': 49 | to_recheck = open('log_doms_torecheck_mdcath_analysis_update.txt', 'a') 50 | basename = 'mdcath_noh' if noh_mode else 'mdcath' 51 | 52 | 
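    # Open the destination file in append mode; each domain group is deleted
    # and rebuilt from its per-domain dataset file, so the update is idempotent.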
with h5py.File(opj('h5files', origin_file), mode='a', libver='latest') as dest: 53 | for dom in tqdm(pdb_list, total=len(pdb_list)): 54 | source_file = f"{basename}_dataset_{dom}.h5" 55 | if dom in dest: 56 | del dest[dom] 57 | 58 | dom_group = dest.create_group(dom) 59 | with h5py.File(opj(data_dir, source_file), 'r') as source: 60 | dom_group.attrs['numResidues'] = source[dom].attrs['numResidues'] 61 | dom_group.attrs['numProteinAtoms'] = source[dom].attrs['numProteinAtoms'] 62 | dom_group.attrs['numChains'] = source[dom].attrs['numChains'] 63 | dom_group.attrs['numNoHAtoms'] = len([el for el in source[dom]['z'][:] if el != 1]) 64 | availample_temps = [t for t in ['320', '348', '379', '413', '450'] if t in source[dom].keys()] 65 | for temp in availample_temps: 66 | temp_group = dom_group.create_group(temp) 67 | for replica in source[dom][temp]: 68 | repl_group = temp_group.create_group(replica) 69 | if 'numFrames' not in source[dom][temp][replica].attrs.keys(): 70 | logger.error(f"numFrames not found in {dom} {temp} {replica}") 71 | continue 72 | 73 | repl_group.attrs['numFrames'] = source[dom][temp][replica].attrs['numFrames'] 74 | 75 | if file_type == 'analysis': 76 | assert noh_mode == False, "Analysis file cannot be created for noh dataset" 77 | repl_group.create_dataset('gyration_radius', data = source[dom][temp][replica]['gyrationRadius'][:]) 78 | repl_group.create_dataset('rmsd', data = source[dom][temp][replica]['rmsd'][:]) 79 | repl_group.create_dataset('rmsf', data = source[dom][temp][replica]['rmsf'][:]) 80 | repl_group.create_dataset('box', data = source[dom][temp][replica]['box'][:]) 81 | 82 | try: 83 | solid_secondary_structure = np.zeros(source[dom][temp][replica]['dssp'].shape[0]) 84 | for i in range(source[dom][temp][replica]['dssp'].shape[0]): 85 | solid_secondary_structure[i] = get_solid_secondary_structure(source[dom][temp][replica]['dssp'][i]) 86 | 87 | repl_group.create_dataset('solid_secondary_structure', data=solid_secondary_structure) 88 | except Exception as e: 89 | logger.error(f"Error in {dom} {temp} {replica}") 90 | logger.error(e) 91 | to_recheck.write(f"{dom} {temp} {replica}\n") 92 | continue 93 | 94 | 95 | elif file_type == 'source': 96 | if noh_mode: 97 | repl_group.attrs['max_num_neighbors_5A'] = get_max_neighbors(source[dom][temp][replica]['coords'][:], 5.5) # use 5.5 for confidence on the 5A 98 | repl_group.attrs['max_num_neighbors_9A'] = get_max_neighbors(source[dom][temp][replica]['coords'][:], 9.5) # use 9.5 for confidence on the 9A 99 | 100 | # The noh dataset does not have the dssp information, to store it in the source file we need to read the dssp from the original dataset 101 | with h5py.File(opj('/workspace8/antoniom/mdcath_htmd', dom, f"mdcath_dataset_{dom}.h5"), "r") as ref_h5: 102 | repl_group.attrs['min_gyration_radius'] = np.min(ref_h5[dom][temp][replica]['gyrationRadius'][:]) 103 | repl_group.attrs['max_gyration_radius'] = np.max(ref_h5[dom][temp][replica]['gyrationRadius'][:]) 104 | 105 | alpha_comp, beta_comp, coil_comp = get_secondary_structure_compositions(ref_h5[dom][temp][replica]['dssp']) 106 | 107 | repl_group.attrs['alpha'] = alpha_comp 108 | repl_group.attrs['beta'] = beta_comp 109 | repl_group.attrs['coil'] = coil_comp 110 | else: 111 | repl_group.attrs['min_gyration_radius'] = np.min(source[dom][temp][replica]['gyrationRadius'][:]) 112 | repl_group.attrs['max_gyration_radius'] = np.max(source[dom][temp][replica]['gyrationRadius'][:]) 113 | 114 | alpha_comp, beta_comp, coil_comp = 
get_secondary_structure_compositions(source[dom][temp][replica]['dssp']) 115 | 116 | repl_group.attrs['alpha'] = alpha_comp 117 | repl_group.attrs['beta'] = beta_comp 118 | repl_group.attrs['coil'] = coil_comp 119 | 120 | logger.info(f"Successfully updated information for {dom}") -------------------------------------------------------------------------------- /user-utils/convert_mdCATH.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import tempfile 3 | import numpy as np 4 | import argparse 5 | import os 6 | 7 | MDCATH_PICOSECONDS_PER_FRAME = 1000. 8 | 9 | def _open_h5_file(h5): 10 | if isinstance(h5, str): 11 | h5 = h5py.File(h5, "r") 12 | code = [_ for _ in h5][0] 13 | return h5, code 14 | 15 | 16 | def _extract_structure_and_coordinates(h5, code, temp, replica): 17 | """ 18 | Extracts the structure in PDB format and coordinates from an H5 file based on temperature and replica. 19 | 20 | Parameters: 21 | h5 : h5py.File 22 | An opened H5 file object containing protein structures and simulation data. 23 | code : str 24 | The identifier for the dataset in the H5 file. 25 | temp : int or float 26 | The temperature (in Kelvin). 27 | replica : int 28 | The replica number. 29 | 30 | Returns: 31 | tuple 32 | A tuple containing the PDB data as bytes, coordinates as a numpy array, and box as a numpy vector. 33 | """ 34 | with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as pdbfile: 35 | pdb = h5[code]["pdbProteinAtoms"][()] 36 | pdbfile.write(pdb) 37 | pdbfile.flush() 38 | coords = h5[code][f"{temp}"][f"{replica}"]["coords"][:] 39 | box = h5[code][f"{temp}"][f"{replica}"]["box"][:] 40 | coords = coords / 10.0 41 | return pdbfile.name, coords, box 42 | 43 | 44 | def convert_to_mdtraj(h5, temp, replica): 45 | """ 46 | Convert data from an H5 file to an MDTraj trajectory object. 47 | 48 | This function extracts the first protein atom structure and coordinates 49 | for a given temperature and replica from an H5 file and creates an MDTraj 50 | trajectory object. This object can be used for further molecular dynamics 51 | analysis. 52 | 53 | Parameters: 54 | h5 : h5py.File 55 | An opened H5 file object containing protein structures and simulation data. 56 | temp : int or float 57 | The temperature (in Kelvin) at which the simulation was run. This is used 58 | to select the corresponding dataset within the H5 file. 59 | replica : int 60 | The replica number of the simulation to extract data from. This is used 61 | to select the corresponding dataset within the H5 file. 62 | 63 | Returns: 64 | md.Trajectory 65 | An MDTraj trajectory object containing the loaded protein structure and 66 | simulation coordinates. 
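
    Notes:
    ------
    mdCATH coordinates are stored in Angstrom; they are divided by 10 during
    extraction because MDTraj works in nanometers.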
67 | 
68 |     Example:
69 |     -------
70 |     import h5py
71 |     import mdtraj as md
72 | 
73 |     # Open the H5 file
74 |     with h5py.File('simulation_data.h5', 'r') as h5file:
75 |         traj = convert_to_mdtraj(h5file, 320, 1)
76 | 
77 |     # Now 'traj' can be used for analysis with MDTraj
78 |     """
79 |     import mdtraj as md
80 | 
81 |     h5, code = _open_h5_file(h5)
82 |     pdb_file_name, coords, box = _extract_structure_and_coordinates(h5, code, temp, replica)
83 |     top = md.load(pdb_file_name).topology
84 |     os.unlink(pdb_file_name)
85 |     nframes = coords.shape[0]
86 |     uc_lengths = np.repeat(box.diagonal()[None, :], nframes, axis=0)
87 |     uc_angles = np.repeat(np.array([90., 90., 90.])[None, :], nframes, axis=0)
88 |     trj = md.Trajectory(coords.copy(),
89 |                         topology=top,
90 |                         time=np.arange(1, coords.shape[0] + 1) * MDCATH_PICOSECONDS_PER_FRAME,
91 |                         unitcell_lengths=uc_lengths,
92 |                         unitcell_angles=uc_angles
93 |                         )
94 |     return trj
95 | 
96 | 
97 | def convert_to_moleculekit(h5, temp, replica):
98 |     """
99 |     Convert data from an H5 file to a MoleculeKit/HTMD trajectory object.
100 | 
101 |     This function extracts the first protein atom structure and coordinates
102 |     for a given temperature and replica from an H5 file and creates a
103 |     MoleculeKit Molecule object. This object can be used for further molecular
104 |     dynamics analysis.
105 | 
106 |     Parameters:
107 |     h5 : h5py.File
108 |         An opened H5 file object containing protein structures and simulation data.
109 |     temp : int or float
110 |         The temperature (in Kelvin) at which the simulation was run. This is used
111 |         to select the corresponding dataset within the H5 file.
112 |     replica : int
113 |         The replica number of the simulation to extract data from. This is used
114 |         to select the corresponding dataset within the H5 file.
115 | 
116 |     Returns:
117 |     moleculekit.molecule.Molecule
118 |         A Molecule object containing the loaded protein structure and
119 |         simulation coordinates.
120 | 
121 |     Example:
122 |     -------
123 |     import h5py
124 |     import moleculekit as mk
125 | 
126 |     # Open the H5 file
127 |     with h5py.File('simulation_data.h5', 'r') as h5file:
128 |         traj = convert_to_moleculekit(h5file, 320, 1)
129 | 
130 |     # Now 'traj' can be used for analysis with HTMD
131 |     """
132 | 
133 |     import moleculekit.molecule as mk
134 | 
135 |     h5, code = _open_h5_file(h5)
136 |     pdb_file_name, coords, box = _extract_structure_and_coordinates(h5, code, temp, replica)
137 |     trj = mk.Molecule(pdb_file_name, name=f"{code}_{temp}_{replica}")
138 |     os.unlink(pdb_file_name)
139 |     nframes = coords.shape[0]
140 |     uc_lengths = np.repeat(box.diagonal()[None, :], nframes, axis=0)
141 |     trj.coords = coords.transpose([1, 2, 0]).copy()
142 |     trj.time = np.arange(1, coords.shape[0] + 1)
143 |     trj.box = uc_lengths.T * 10.0
144 | 
145 |     # TODO? .step, .numframes
146 |     return trj
147 | 
148 | 
149 | def convert_to_files(
150 |     fn, basename=None, temp_list=[320, 348, 379, 413, 450], replica_list=[0, 1, 2, 3, 4]
151 | ):
152 |     """
153 |     Converts data from an H5 file to separate PDB and XTC files based on specified temperatures and replicas.
154 | 
155 |     This function reads protein atom structures and simulation data from an H5 file and writes a single PDB file
156 |     and multiple XTC files. Each XTC file corresponds to a specific temperature and replica combination. The
157 |     function uses `convert_to_mdtraj` to generate MDTraj trajectory objects which are then saved in the XTC format.
158 | 
159 |     Parameters:
160 |     fn : str
161 |         The file name or path to the H5 file containing the simulation data.
162 | basename : str 163 | The base name to use for output files. If None, it is taken from the domain ID. 164 | temp_list : list of int, optional 165 | A list of temperatures (in Kelvin) for which the simulations were run. Defaults to [320, 348, 379, 413, 450]. 166 | replica_list : list of int, optional 167 | A list of replica numbers to extract data for. Defaults to [0, 1, 2, 3, 4]. 168 | 169 | Outputs: 170 | Creates a PDB file named `{basename}.pdb` and multiple XTC files named `{basename}_{temp}_{replica}.xtc`, 171 | where `{temp}` and `{replica}` are values from `temp_list` and `replica_list`. 172 | 173 | Example: 174 | ------- 175 | # Convert data to files with base name 'protein_simulation' 176 | convert_to_files('simulation_data.h5', 'protein_simulation') 177 | """ 178 | 179 | h5, code = _open_h5_file(fn) 180 | 181 | if not basename: 182 | basename = code 183 | 184 | pdbpath = f"{basename}.pdb" 185 | with open(pdbpath, "wb") as pdbfile: 186 | pdb = h5[code]["pdbProteinAtoms"][()] 187 | pdbfile.write(pdb) 188 | print(f"Wrote {pdbpath}") 189 | 190 | for temp in temp_list: 191 | for replica in replica_list: 192 | xtcpath = f"{basename}_{temp}_{replica}.xtc" 193 | trj = convert_to_mdtraj(h5, temp, replica) 194 | trj.save_xtc(xtcpath) 195 | print(f"Wrote {xtcpath}") 196 | 197 | 198 | def main(): 199 | parser = argparse.ArgumentParser( 200 | description="Convert H5 file data to PDB and XTC files." 201 | ) 202 | parser.add_argument( 203 | "fn", 204 | type=str, 205 | help="File name or path to the H5 file containing simulation data.", 206 | ) 207 | parser.add_argument( 208 | "--basename", 209 | type=str, 210 | help="Base name for output files, defaults to domain ID from H5 file.", 211 | default=None, 212 | ) 213 | parser.add_argument( 214 | "--temp_list", 215 | type=int, 216 | nargs="+", 217 | help="List of temperatures.", 218 | default=[320, 348, 379, 413, 450], 219 | ) 220 | parser.add_argument( 221 | "--replica_list", 222 | type=int, 223 | nargs="+", 224 | help="List of replicas.", 225 | default=[0, 1, 2, 3, 4], 226 | ) 227 | 228 | args = parser.parse_args() 229 | 230 | convert_to_files(args.fn, args.basename, args.temp_list, args.replica_list) 231 | 232 | 233 | if __name__ == "__main__": 234 | main() 235 | -------------------------------------------------------------------------------- /generator/process/write_info_toh5.py: -------------------------------------------------------------------------------- 1 | # This script reads the mdCATH dataset (h5 files) and writes the information to a new h5 file (source or analysis). 2 | # It includes multiprocessing to speed up the process (batch processing). 
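# Typical workflow: edit the paths and settings in launch() below, run this
# script, and then merge the per-batch h5 outputs with join_multiple_h5.py.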
3 | 4 | import os 5 | import sys 6 | import h5py 7 | import math 8 | import shutil 9 | import logging 10 | import tempfile 11 | import numpy as np 12 | from tqdm import tqdm 13 | import concurrent.futures 14 | from os.path import join as opj 15 | from tools import get_secondary_structure_compositions, get_max_neighbors, get_solid_secondary_structure, readPDBs 16 | sys.path.append("/../builder/") 17 | from scheduler import ComputationScheduler 18 | 19 | 20 | logger = logging.getLogger('writer') 21 | logger.setLevel(logging.INFO) 22 | # all the error messages will be written in the error.log file 23 | fh = logging.FileHandler('error.log') 24 | fh.setLevel(logging.ERROR) 25 | logger.addHandler(fh) 26 | 27 | class Payload: 28 | def __init__(self, scheduler, data_dir, output_dir='.', file_type='source', noh=False): 29 | self.scheduler = scheduler 30 | self.data_dir = data_dir 31 | self.output_dir = output_dir 32 | self.file_type = file_type 33 | self.noh = noh 34 | 35 | def runComputation(self, batch_idx): 36 | logger.info(f"Batch {batch_idx} started") 37 | run(self.scheduler, batch_idx, self.data_dir, self.output_dir, self.file_type, self.noh) 38 | 39 | def run(scheduler, batch_idx, data_dir, output_dir='.', file_type='source', noh=False): 40 | """Extract information from the mdCATH dataset and write them to a h5 file per batch 41 | Parameters: 42 | scheduler: ComputationScheduler 43 | the scheduler object that will process the batch 44 | batch_idx: int 45 | the index of the batch to process 46 | data_dir: str 47 | the path to the directory containing the mdCATH dataset 48 | output_dir: str 49 | the path to the directory where the h5 files will be written 50 | file_type: str 51 | the type of file to be written: source or analysis 52 | noh: bool 53 | if True, the information will be extracted from the noh dataset 54 | """ 55 | pdb_idxs = scheduler.process(batch_idx) 56 | basename = 'mdcath_noh' if noh else 'mdcath' 57 | file_name = f"{basename}_{file_type}_{batch_idx}.h5" 58 | resfile = opj(output_dir, file_name) 59 | 60 | with tempfile.NamedTemporaryFile() as tmp: 61 | tmp_file = tmp.name 62 | with h5py.File(tmp_file, "w", libver='latest') as h5: 63 | for i, pdb in tqdm(enumerate(pdb_idxs), total=len(pdb_idxs), desc=f"processing batch {batch_idx}"): 64 | h5_file = opj(data_dir, f"{basename}_dataset_{pdb}.h5") 65 | if not os.path.exists(h5_file): 66 | logger.error(f"File {h5_file} does not exist") 67 | continue 68 | 69 | group = h5.create_group(pdb) 70 | with h5py.File(h5_file, "r") as origin: 71 | group.attrs['numResidues'] = origin[pdb].attrs['numResidues'] 72 | group.attrs['numProteinAtoms'] = origin[pdb].attrs['numProteinAtoms'] 73 | group.attrs['numChains'] = origin[pdb].attrs['numChains'] 74 | group.attrs['numNoHAtoms'] = len([el for el in origin[pdb]['z'][:] if el != 1]) 75 | availample_temps = [t for t in ['320', '348', '379', '413', '450'] if t in origin[pdb].keys()] 76 | for temp in availample_temps: 77 | temp_group = group.create_group(temp) 78 | for replica in origin[pdb][temp]: 79 | repl_group = temp_group.create_group(replica) 80 | if 'numFrames' not in origin[pdb][temp][replica].attrs.keys(): 81 | logger.error(f"numFrames not found in {pdb} {temp} {replica}") 82 | continue 83 | 84 | repl_group.attrs['numFrames'] = origin[pdb][temp][replica].attrs['numFrames'] 85 | 86 | if file_type == 'analysis': 87 | assert noh == False, "Analysis file cannot be created for noh dataset" 88 | repl_group.create_dataset('gyration_radius', data = origin[pdb][temp][replica]['gyrationRadius'][:]) 89 | 
repl_group.create_dataset('rmsd', data = origin[pdb][temp][replica]['rmsd'][:]) 90 | repl_group.create_dataset('rmsf', data = origin[pdb][temp][replica]['rmsf'][:]) 91 | repl_group.create_dataset('box', data = origin[pdb][temp][replica]['box'][:]) 92 | solid_secondary_structure = np.zeros(origin[pdb][temp][replica]['dssp'].shape[0]) 93 | for i in range(origin[pdb][temp][replica]['dssp'].shape[0]): 94 | solid_secondary_structure[i] = get_solid_secondary_structure(origin[pdb][temp][replica]['dssp'][i]) 95 | 96 | repl_group.create_dataset('solid_secondary_structure', data=solid_secondary_structure) 97 | 98 | elif file_type == 'source': 99 | if noh: 100 | repl_group.attrs['max_num_neighbors_5A'] = get_max_neighbors(origin[pdb][temp][replica]['coords'][:], 5.5) # use 5.5 for confidence on the 5A 101 | repl_group.attrs['max_num_neighbors_9A'] = get_max_neighbors(origin[pdb][temp][replica]['coords'][:], 9.5) # use 9.5 for confidence on the 9A 102 | 103 | # The noh dataset does not have the dssp information, to store it in the source file we need to read the dssp from the original dataset 104 | with h5py.File(opj('/workspace3/mdcath', f"mdcath_dataset_{pdb}.h5"), "r") as ref_h5: 105 | repl_group.attrs['min_gyration_radius'] = np.min(ref_h5[pdb][temp][replica]['gyrationRadius'][:]) 106 | repl_group.attrs['max_gyration_radius'] = np.max(ref_h5[pdb][temp][replica]['gyrationRadius'][:]) 107 | 108 | alpha_comp, beta_comp, coil_comp = get_secondary_structure_compositions(ref_h5[pdb][temp][replica]['dssp']) 109 | 110 | repl_group.attrs['alpha'] = alpha_comp 111 | repl_group.attrs['beta'] = beta_comp 112 | repl_group.attrs['coil'] = coil_comp 113 | else: 114 | repl_group.attrs['min_gyration_radius'] = np.min(origin[pdb][temp][replica]['gyrationRadius'][:]) 115 | repl_group.attrs['max_gyration_radius'] = np.max(origin[pdb][temp][replica]['gyrationRadius'][:]) 116 | 117 | alpha_comp, beta_comp, coil_comp = get_secondary_structure_compositions(origin[pdb][temp][replica]['dssp']) 118 | 119 | repl_group.attrs['alpha'] = alpha_comp 120 | repl_group.attrs['beta'] = beta_comp 121 | repl_group.attrs['coil'] = coil_comp 122 | 123 | 124 | shutil.copyfile(tmp_file, resfile) 125 | 126 | def launch(): 127 | data_dir = "PATH/TO/MDCATH/DATASET/DIR" 128 | output_dir = "batch_files" 129 | pdb_list_file = '../mdcath_domains.txt' 130 | # Define the type of file to be written, source or analysis 131 | # Based on this different attributes will be written 132 | file_type = 'source' 133 | noh_mode = True 134 | pdb_list = readPDBs(pdb_list_file) 135 | batch_size = 250 136 | toRunBatches = None 137 | startBatch = None 138 | max_workers = 24 139 | 140 | os.makedirs(output_dir, exist_ok=True) 141 | # Get a number of batches 142 | numBatches = int(math.ceil(len(pdb_list) / batch_size)) 143 | logger.info(f"Batch size: {batch_size}") 144 | logger.info(f"Number of total batches: {numBatches}") 145 | 146 | if toRunBatches is not None and startBatch is not None: 147 | numBatches = toRunBatches + startBatch 148 | elif toRunBatches is not None: 149 | numBatches = toRunBatches 150 | elif startBatch is not None: 151 | pass 152 | 153 | # Initialize the parallelization system 154 | scheduler = ComputationScheduler(batch_size, startBatch, numBatches, pdb_list) 155 | toRunBatches = scheduler.getBatches() 156 | logger.info(f"numBatches to run: {len(toRunBatches)}") 157 | 158 | payload = Payload(scheduler, data_dir, output_dir, file_type, noh_mode) 159 | 160 | with concurrent.futures.ProcessPoolExecutor(max_workers) as executor: 161 | try: 162 | 
results = list( 163 | tqdm( 164 | executor.map(payload.runComputation, toRunBatches), 165 | total=len(toRunBatches), 166 | ) 167 | ) 168 | except Exception as e: 169 | print(e) 170 | raise e 171 | # this return is needed for the tqdm progress bar 172 | return results 173 | 174 | if __name__ == "__main__": 175 | launch() 176 | 
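# A minimal read-back sketch for one produced batch file (the path below is
# hypothetical; the attribute names are the ones written by run() above):

import h5py

with h5py.File("batch_files/mdcath_source_0.h5", "r") as h5:
    for dom in h5:
        print(dom, dict(h5[dom].attrs))  # numResidues, numProteinAtoms, numChains, numNoHAtoms
        for temp in h5[dom]:
            for repl in h5[dom][temp]:
                # per-replica attrs include numFrames and, for 'source' files,
                # min/max_gyration_radius and the alpha/beta/coil compositions
                print(" ", temp, repl, h5[dom][temp][repl].attrs["numFrames"])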
-------------------------------------------------------------------------------- /user/2_mdCATH_ML.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### MDCATH DATASET IN MACHINE LEARNING FRAMEWORK" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "This tutorial provides a practical example of training ML models using the mdCATH dataset in TorchMD-Net. Before you begin, please ensure that TorchMD-Net is correctly installed. You can find installation instructions and further details [here](https://torchmd-net.readthedocs.io/en/latest/installation.html). Note that the MDCATH dataloader is available starting from TorchMD-Net version 2.4.0. " 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "name": "stderr", 24 | "output_type": "stream", 25 | "text": [ 26 | "/shared/antoniom/mambaforge/envs/mdcath_torchmdnet/lib/python3.12/site-packages/torchmdnet/extensions/__init__.py:150: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.\n", 27 | " impl_abstract(\n", 28 | "/shared/antoniom/mambaforge/envs/mdcath_torchmdnet/lib/python3.12/site-packages/torchmdnet/extensions/__init__.py:153: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.\n", 29 | " impl_abstract(\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "import os\n", 35 | "import torch\n", 36 | "import lightning.pytorch as pl\n", 37 | "from torchmdnet.data import DataModule\n", 38 | "from torchmdnet.module import LNNP\n", 39 | "from torchmdnet.scripts.train import get_args\n", 40 | "from lightning.pytorch.callbacks import TQDMProgressBar, ModelCheckpoint\n", 41 | "from lightning.pytorch.loggers import CSVLogger" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "# Define the arguments\n", 51 | "args = get_args() # default arguments by tmdnet\n", 52 | "args = vars(args) # convert to dictionary\n", 53 | "\n", 54 | "pargs = {\n", 55 | " # DATA\n", 56 | " 'dataset': 'MDCATH',\n", 57 | " 'dataset_arg':{\n", 58 | " 'numAtoms': None,\n", 59 | " 'numResidues': None,\n", 60 | " 'pdb_list': ['1balA00', '1ce3A00', '1e8rA00'],\n", 61 | " 'temperatures': ['348'],\n", 62 | " 'skip_frames': 2,\n", 63 | " 'solid_ss': None,\n", 64 | " },\n", 65 | " 'dataset_root': 'data',\n", 66 | " # MODEL\n", 67 | " 'model': 'tensornet',\n", 68 | " 'embedding_dimension': 32,\n", 69 | " 'num_layers': 0,\n", 70 | " 'num_rbf': 8,\n", 71 | " 'rbf_type': 'expnorm',\n", 72 | " 'activation': 'silu',\n", 73 | " 'cutoff_lower': 0.0,\n", 74 | " 'cutoff_upper': 5.0,\n", 75 | " 'max_z': 20,\n", 76 | " 'num_epochs': 10,\n", 77 | " 'max_num_neighbors': 48,\n", 78 | " 'derivative': True, \n", 79 | " # TRAIN\n", 80 | " 'batch_size': 3,\n", 81 | " 'train_size': 200, \n", 82 | " 'val_size': 50,\n", 83 | " 'test_size': 100,\n", 84 | " 'lr': 1e-3,\n", 85 | " 'lr_metric': 'val',\n", 86 | " 'log_dir': 'logs/',\n", 87 | " 'check_errors': True,\n", 88 | " 'static_shapes': False,\n", 89 | " 'num_workers': 2,\n", 90 | "}\n", 91 | "\n", 92 | "# Update the default arguments with the new ones\n", 93 | "args.update(pargs)\n", 94 | "os.makedirs(args['log_dir'], exist_ok=True)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 3, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "name": "stderr", 104 | "output_type": "stream", 105 | "text": [ 106 | "Processing mdcath source: 100%|██████████| 3/3 [00:00<00:00, 13.53it/s]" 107 | ] 108 | }, 109 | { 110 | "name": "stdout", 111 | "output_type": "stream", 112 | "text": [ 113 | "train 200, val 50, test 100\n" 114 | ] 115 | }, 116 | { 117 | "name": "stderr", 118 | "output_type": "stream", 119 | "text": [ 120 | "\n", 121 | "/shared/antoniom/mambaforge/envs/mdcath_torchmdnet/lib/python3.12/site-packages/torchmdnet/utils.py:221: UserWarning: 2970 samples were excluded from the dataset\n", 122 | " rank_zero_warn(f\"{dset_len - total} samples were excluded from the dataset\")\n" 123 | ] 124 | } 125 | ], 126 | "source": [ 127 | "# Here the MDCATH torch_geometric dataset class is used \n", 128 | "# If the h5 files are not present in the 'dataset_root' then they will be downloaded from HF\n", 129 | "# The download process can take some time\n", 130 | "\n", 131 | "data = DataModule(args)\n", 132 | "data.prepare_data()\n", 133 | "data.setup(\"fit\")" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 4, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "# Lightning wrapper for the Neural Network Potentials in TorchMD-Net\n", 143 | "lnnp = LNNP(args, \n", 144 | " prior_model=None, \n", 145 | " mean=data.mean, \n", 146 | " std=data.std)" 147 | ]
148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 5, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "# Callbacks, used to save model ckpts\n", 156 | "val_loss_name = 'val_total_mse_loss'\n", 157 | "checkpoint_callback = ModelCheckpoint(dirpath=args['log_dir'], \n", 158 | " monitor=val_loss_name, \n", 159 | " every_n_epochs=2, \n", 160 | " filename=f\"epoch={{epoch}}-val_loss={{{val_loss_name}:.4f}}\",\n", 161 | " save_top_k=3)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "# Logger for the training process, it will save the training logs in a csv file\n", 171 | "csv_logger = CSVLogger(args['log_dir'], name=\"\", version=\"\")" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 7, 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 | "name": "stdout", 181 | "output_type": "stream", 182 | "text": [ 183 | "cuda available: True\n", 184 | "cuda device count: 1\n", 185 | "CUDA_VISIBLE_DEVICES ID: 0\n" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" \n", 191 | "print(f'cuda available: {torch.cuda.is_available()}')\n", 192 | "print(f'cuda device count: {torch.cuda.device_count()}')\n", 193 | "print(f'CUDA_VISIBLE_DEVICES ID: {os.environ[\"CUDA_VISIBLE_DEVICES\"]}')" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 11, 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "name": "stderr", 203 | "output_type": "stream", 204 | "text": [ 205 | "GPU available: True (cuda), used: True\n", 206 | "TPU available: False, using: 0 TPU cores\n", 207 | "HPU available: False, using: 0 HPUs\n", 208 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", 209 | "\n", 210 | " | Name | Type | Params | Mode \n", 211 | "----------------------------------------------\n", 212 | "0 | model | TorchMD_Net | 18.9 K | train\n", 213 | "----------------------------------------------\n", 214 | "18.9 K Trainable params\n", 215 | "0 Non-trainable params\n", 216 | "18.9 K Total params\n", 217 | "0.076 Total estimated model params size (MB)\n", 218 | "31 Modules in train mode\n", 219 | "0 Modules in eval mode\n" 220 | ] 221 | }, 222 | { 223 | "name": "stdout", 224 | "output_type": "stream", 225 | "text": [ 226 | "train 200, val 50, test 100\n", 227 | "Epoch 9: 100%|██████████| 67/67 [00:04<00:00, 16.14it/s] " 228 | ] 229 | }, 230 | { 231 | "name": "stderr", 232 | "output_type": "stream", 233 | "text": [ 234 | "`Trainer.fit` stopped: `max_epochs=10` reached.\n" 235 | ] 236 | }, 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "Epoch 9: 100%|██████████| 67/67 [00:04<00:00, 16.03it/s]\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "# Train\n", 247 | "trainer = pl.Trainer(strategy=\"auto\",\n", 248 | " devices=1,\n", 249 | " max_epochs=args['num_epochs'], \n", 250 | " precision=args['precision'],\n", 251 | " default_root_dir = args['log_dir'],\n", 252 | " logger=csv_logger,\n", 253 | " callbacks=[checkpoint_callback, TQDMProgressBar(refresh_rate=1)])\n", 254 | "\n", 255 | "trainer.fit(lnnp, data, ckpt_path=None)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 12, 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "name": "stderr", 265 | "output_type": "stream", 266 | "text": [ 267 | "GPU available: True (cuda), used: True\n", 268 | "TPU available: False, using: 0 TPU cores\n", 269 | "HPU available: False, using: 0 HPUs\n", 270 | 
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" 271 | ] 272 | }, 273 | { 274 | "name": "stdout", 275 | "output_type": "stream", 276 | "text": [ 277 | "train 200, val 50, test 100\n", 278 | "Testing DataLoader 0: 100%|██████████| 4/4 [00:00<00:00, 9.54it/s]\n", 279 | "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", 280 | " Test metric DataLoader 0\n", 281 | "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", 282 | " test_neg_dy_l1_loss 4.174280643463135\n", 283 | " test_total_l1_loss 4.174280643463135\n", 284 | " test_y_l1_loss 0.0\n", 285 | "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n" 286 | ] 287 | }, 288 | { 289 | "data": { 290 | "text/plain": [ 291 | "[{'test_total_l1_loss': 4.174280643463135,\n", 292 | " 'test_y_l1_loss': 0.0,\n", 293 | " 'test_neg_dy_l1_loss': 4.174280643463135}]" 294 | ] 295 | }, 296 | "execution_count": 12, 297 | "metadata": {}, 298 | "output_type": "execute_result" 299 | } 300 | ], 301 | "source": [ 302 | "# Test\n", 303 | "model = LNNP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)\n", 304 | "trainer = pl.Trainer(inference_mode=False)\n", 305 | "trainer.test(model, data)" 306 | ] 307 | } 308 | ], 309 | "metadata": { 310 | "kernelspec": { 311 | "display_name": "gemini2", 312 | "language": "python", 313 | "name": "python3" 314 | }, 315 | "language_info": { 316 | "codemirror_mode": { 317 | "name": "ipython", 318 | "version": 3 319 | }, 320 | "file_extension": ".py", 321 | "mimetype": "text/x-python", 322 | "name": "python", 323 | "nbconvert_exporter": "python", 324 | "pygments_lexer": "ipython3", 325 | "version": "3.12.6" 326 | } 327 | }, 328 | "nbformat": 4, 329 | "nbformat_minor": 2 330 | } 331 | -------------------------------------------------------------------------------- /generator/builder/generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["NUMEXPR_MAX_THREADS"] = "24" 4 | os.environ["OMP_NUM_THREADS"] = "24" 5 | import math 6 | import sys 7 | import h5py 8 | import shutil 9 | import argparse 10 | import logging 11 | import tempfile 12 | from glob import glob 13 | from tqdm import tqdm 14 | import concurrent.futures 15 | from os.path import join as opj 16 | from molAnalyzer import molAnalyzer 17 | from scheduler import ComputationScheduler 18 | from trajManager import TrajectoryFileManager 19 | from utils import readPDBs, save_argparse, LoadFromFile 20 | 21 | 22 | import warnings 23 | 24 | warnings.filterwarnings("ignore", category=DeprecationWarning, module="MDAnalysis") 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | logger = logging.getLogger("builder") 28 | 29 | def check_readers(coords, forces, numTrajFiles): 30 | # Considering that each trajectory file has 10 frames, one frame save every 1ns 31 | if coords is None or forces is None: 32 | return False 33 | nframes = coords.shape[2] 34 | if nframes / 10 != numTrajFiles: 35 | return False 36 | else: 37 | return True 38 | 39 | def get_argparse(): 40 | parser = argparse.ArgumentParser( 41 | description="mdCATH dataset builder", prefix_chars="--" 42 | ) 43 | 44 | # fmt: off 45 | parser.add_argument('--conf', '-c', type=open, action=LoadFromFile, help='Configuration yaml file') # keep second 46 | parser.add_argument('--pdblist', help='Path to the list of accepted PDBs or a list of 
PDBs') 47 | parser.add_argument('--gpugridResultsPath', type=str, help='Path to GPU grid results') 48 | parser.add_argument('--gpugridInputsPath', type=str, help='Path to GPU grid inputs') 49 | parser.add_argument('--concatTrajPath', type=str, default=None, help='Path to concatenated trajectory') 50 | parser.add_argument('--finaldatasetPath', type=str, default='mdcath', help='Path to the final dataset') 51 | parser.add_argument('--temperatures', type=list, default=["320", "348", "379", "413", "450"], help='The simulation temperatures to consider') 52 | parser.add_argument('--numReplicas', type=int, default=1, help='Number of replicas, available for each temperature', choices=range(0,4,1)) 53 | parser.add_argument('--trajAttrs', type=list, default=['numFrames'], help='Trajectory attributes for each replica') 54 | parser.add_argument('--trajDatasets', type=list, default=['rmsd', 'gyrationRadius', 'rmsf', 'dssp'], help='Trajectory datasets for each replica') 55 | parser.add_argument('--pdbAttrs', type=list, default=['numProteinAtoms', 'numResidues', 'numChains'], help='PDB attributes, shared by temperatures and replicas') 56 | parser.add_argument('--pdbDatasets', type=list, default=['element', 'z', 'resname', 'resid', 'chain'], help='PDB datasets, shared by temperatures and replicas') 57 | parser.add_argument('--batchSize', type=int, default=1, help='batch size to use in the computation') 58 | parser.add_argument('--toRunBatches', type=int, default=None, help='Number of batches to run, if None all the batches will be run') 59 | parser.add_argument('--startBatch', type=int, default=None, help='Start batch, if None the first batch will be run') 60 | parser.add_argument('--endBatch', type=int, default=None, help='End batch, if None the last batch will be run') 61 | parser.add_argument('--maxWorkers', type=int, default=24, help='Number of workers to use in the multiprocessing') 62 | # fmt: on 63 | return parser 64 | 65 | 66 | def get_args(): 67 | parser = get_argparse() 68 | args = parser.parse_args() 69 | os.makedirs(args.finaldatasetPath, exist_ok=True) 70 | save_argparse( 71 | args, os.path.join(args.finaldatasetPath, "input.yaml"), exclude=["conf"] 72 | ) 73 | return args 74 | 75 | 76 | class Payload: 77 | def __init__(self, scheduler, args): 78 | self.scheduler = scheduler 79 | self.args = args 80 | 81 | def runComputation(self, batch_idx): 82 | logger.info(f"Batch {batch_idx} started") 83 | logger.info(f"OMP_NUM_THREADS= {os.environ.get('OMP_NUM_THREADS')}") 84 | run(self.scheduler, self.args, batch_idx) 85 | 86 | 87 | def run(scheduler, args, batch_idx): 88 | """Run the dataset generation for a specific batch. 
89 | Parameters 90 | ---------- 91 | scheduler : Scheduler object 92 | The scheduler object is used to get the indices of the molecules to be processed in the batch, 93 | and to get the name of the file to be generated 94 | args: argparse.Namespace 95 | The arguments from the command line 96 | batch_idx: int 97 | The index of the batch to be processed 98 | """ 99 | pdb_idxs = scheduler.process(batch_idx) 100 | trajFileManager = TrajectoryFileManager( 101 | args.gpugridResultsPath, args.concatTrajPath 102 | ) 103 | desc = pdb_idxs[0] if len(pdb_idxs) == 1 else "reading PDBs" 104 | for pdb in tqdm(pdb_idxs, total=len(pdb_idxs), desc=desc): 105 | with tempfile.TemporaryDirectory() as temp: 106 | tmpFile = opj(temp, f"mdcath_dataset_{pdb}.h5") 107 | tmplogfile = tmpFile.replace(".h5", ".txt") 108 | 109 | resFile = opj(args.finaldatasetPath, pdb, f"mdcath_dataset_{pdb}.h5") 110 | if os.path.exists(resFile): 111 | logger.info( 112 | f"File {resFile} already exists, skipping batch {batch_idx} for {pdb}" 113 | ) 114 | continue 115 | logFile = opj(args.finaldatasetPath, pdb, f"log_{pdb}.txt") 116 | 117 | pdbLogger = logging.getLogger(f"builder_{pdb}") 118 | file_handler = logging.FileHandler(tmplogfile) 119 | file_handler.setLevel(logging.INFO) 120 | pdbLogger.addHandler(file_handler) 121 | pdbLogger.setLevel(logging.INFO) 122 | 123 | pdbLogger.info(f"Starting the dataset generation for {pdb} and batch {batch_idx}") 124 | 125 | pdbFilePath = next(iter(glob(opj(args.gpugridInputsPath, pdb, "*/*.pdb"))), None) # get structure.pdb from input folder (same for all replicas and temps); None if no match 126 | if pdbFilePath is None or not os.path.exists(pdbFilePath): 127 | logger.warning(f"No input pdb found for {pdb}, skipping") 128 | continue 129 | 130 | os.makedirs(os.path.dirname(resFile), exist_ok=True) 131 | 132 | with h5py.File(tmpFile, "w", libver='latest') as h5: 133 | 134 | h5.attrs["layout"] = "mdcath-only-protein-v1.0" 135 | pdbGroup = h5.create_group(pdb) 136 | Analyzer = molAnalyzer(pdbFilePath, file_handler, os.path.dirname(resFile)) 137 | Analyzer.computeProperties() 138 | 139 | for temp in args.temperatures: 140 | pdbTempGroup = pdbGroup.create_group(temp) 141 | pdbLogger.info( 142 | "---------------------------------------------------" 143 | ) 144 | pdbLogger.info(f"Starting the analysis for {pdb} at {temp}K \n") 145 | for repl in range(args.numReplicas): 146 | pdbLogger.info(f"## REPLICA {repl} ##") 147 | pdbTempReplGroup = pdbTempGroup.create_group(str(repl)) 148 | try: 149 | trajFiles = trajFileManager.getTrajFiles(pdb, temp, repl) 150 | dcdFiles = [ 151 | f.replace("9.xtc", "8.vel.dcd") for f in trajFiles 152 | ] 153 | pdbLogger.info(f"numTrajFiles: {len(trajFiles)}") 154 | except AssertionError as e: 155 | pdbLogger.error(e) 156 | continue 157 | 158 | Analyzer.readXTC(trajFiles, batch_idx) 159 | Analyzer.readDCD(dcdFiles, batch_idx) 160 | 161 | status = check_readers(Analyzer.coords, Analyzer.forces, len(trajFiles)) # True if the number of frames is correct 162 | if not status: 163 | pdbLogger.error( 164 | f"Number of frames is not correct for {pdb}_{temp}_{repl} and batch {batch_idx}" 165 | ) 166 | pdbLogger.error("Fixing the readers") 167 | Analyzer.fix_readers(trajFiles, dcdFiles) 168 | 169 | Analyzer.trajAnalysis() 170 | 171 | # write the data to the h5 file for the replica 172 | Analyzer.write_toH5( 173 | molGroup=None, 174 | replicaGroup=pdbTempReplGroup, 175 | attrs=args.trajAttrs, 176 | datasets=args.trajDatasets, 177 | ) 178 | pdbLogger.info("\n") 179 | 180 | # If no replica was found, skip the molecule. 
The molecule will be written to the h5 file only if it has at least one replica at one temperature 181 | if not hasattr(Analyzer, "molAttrs"): 182 | pdbLogger.error( 183 | f"molAttrs not found for {pdb} and batch {batch_idx}" 184 | ) 185 | continue 186 | 187 | # write the data to the h5 file for the molecule 188 | Analyzer.write_toH5( 189 | molGroup=pdbGroup, 190 | replicaGroup=None, 191 | attrs=args.pdbAttrs, 192 | datasets=args.pdbDatasets, 193 | ) 194 | 195 | shutil.move(tmpFile, resFile) 196 | pdbLogger.info( 197 | f"\n{pdb} batch {batch_idx} completed successfully and added to mdCATH dataset: {args.finaldatasetPath}" 198 | ) 199 | shutil.move(tmplogfile, logFile) 200 | 201 | 202 | def launch(): 203 | args = get_args() 204 | 205 | acceptedPDBs = readPDBs(args.pdblist) if args.pdblist else None 206 | if acceptedPDBs is None: 207 | logger.error( 208 | "Please provide a list of accepted PDBs which will be used to generate the dataset." 209 | ) 210 | sys.exit(1) 211 | 212 | logger.info(f"numAcceptedPDBs: {len(acceptedPDBs)}") 213 | 214 | # Get a number of batches 215 | numBatches = int(math.ceil(len(acceptedPDBs) / args.batchSize)) 216 | logger.info(f"Batch size: {args.batchSize}") 217 | logger.info(f"Number of total batches: {numBatches}") 218 | 219 | if args.toRunBatches is not None and args.startBatch is not None: 220 | numBatches = args.toRunBatches + args.startBatch 221 | elif args.toRunBatches is not None: 222 | numBatches = args.toRunBatches 223 | elif args.startBatch is not None: 224 | pass 225 | 226 | # Initialize the parallelization system 227 | scheduler = ComputationScheduler( 228 | args.batchSize, args.startBatch, numBatches, acceptedPDBs 229 | ) 230 | toRunBatches = scheduler.getBatches() 231 | logger.info(f"numBatches to run: {len(toRunBatches)}") 232 | logger.info(f"starting from batch: {args.startBatch}") 233 | 234 | payload = Payload(scheduler, args) 235 | 236 | error_domains = open("errors.txt", "w") 237 | results = [] 238 | with concurrent.futures.ProcessPoolExecutor(max_workers=args.maxWorkers) as executor: 239 | future_to_batch = {executor.submit(payload.runComputation, batch): batch for batch in toRunBatches} 240 | 241 | for future in tqdm(concurrent.futures.as_completed(future_to_batch), total=len(toRunBatches)): 242 | batch = future_to_batch[future] 243 | try: 244 | result = future.result() 245 | results.append(result) 246 | except Exception as e: 247 | error_domains.write(f"Batch {batch} failed with exception: {e}\n") 248 | # Optionally, log the error and continue with the next computation 249 | error_domains.close() 250 | return results 251 | 252 | 253 | if __name__ == "__main__": 254 | launch() 255 | logger.info("mdCATH-DATASET BUILD COMPLETED!") 256 | 
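# A worked example of the batch bookkeeping in launch() above (the domain count
# and settings are hypothetical):

import math

numAcceptedPDBs, batchSize = 5500, 25
numBatches = int(math.ceil(numAcceptedPDBs / batchSize))  # 220 batches in total
startBatch, toRunBatches = 100, 50
numBatches = toRunBatches + startBatch  # becomes 150, so batches 100..149 are run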
-------------------------------------------------------------------------------- /generator/builder/molAnalyzer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | from moleculekit.molecule import Molecule 5 | from moleculekit.periodictable import periodictable 6 | from moleculekit.projections.metricrmsd import MetricRmsd 7 | from moleculekit.projections.metricgyration import MetricGyration 8 | from moleculekit.projections.metricfluctuation import MetricFluctuation 9 | from moleculekit.projections.metricsecondarystructure import MetricSecondaryStructure 10 | 11 | # get the moleculekit.bondguesser logger and silence it 12 | # in moleculekit v1.8.36 the get function uses atom_select, which defaults to guessBonds=True 13 | logging.getLogger("moleculekit.bondguesser").setLevel(logging.CRITICAL) 14 | 15 | ANGSTROM_TO_NM = 0.1 16 | RMSD_CUTOFF = 40 # nm 17 | 18 | def encodeDSSP(dssp): 19 | encodedDSSP = [] 20 | for i in range(len(dssp)): 21 | encodedDSSP.append([x.encode("utf-8") for x in dssp[i]]) 22 | return encodedDSSP 23 | 24 | def txt_toH5(txtfile, h5group, dataset_name="pdb"): 25 | """Write the content of the txt file to the h5 group as a dataset. 26 | Parameters 27 | ---------- 28 | txtfile : str 29 | The path to the txt file to be written in the h5 group (pdb or psf file) 30 | h5group : h5py.Group 31 | The group of the h5 file where the dataset will be written 32 | dataset_name : str 33 | The name of the dataset to be written in the h5 group, only used for the pdb extension. It can be either "pdb" or "pdbProteinAtoms" 34 | """ 35 | if txtfile.endswith(".pdb"): 36 | with open(txtfile, "r") as pdb_file: 37 | if dataset_name == "pdb": 38 | pdbcontent = pdb_file.read() 39 | elif dataset_name == "pdbProteinAtoms": 40 | pdb_lines = [ 41 | line 42 | for line in pdb_file.readlines() 43 | if line.startswith("ATOM") or line.startswith("MODEL") 44 | ] 45 | pdbcontent = "".join(pdb_lines) 46 | h5group.create_dataset(dataset_name, data=pdbcontent.encode("utf-8")) 47 | 48 | elif txtfile.endswith(".psf"): 49 | with open(txtfile, "r") as psf_file: 50 | psfcontent = psf_file.read() 51 | h5group.create_dataset("psf", data=psfcontent.encode("utf-8")) 52 | else: 53 | raise ValueError(f"Unknown file type: {txtfile}") 54 | 
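# A minimal usage sketch of txt_toH5 (the file names are hypothetical): store the
# starting structure files of a domain inside its h5 group.

import h5py

with h5py.File("example.h5", "w") as h5:
    group = h5.create_group("1balA00")
    txt_toH5("structure.pdb", group, dataset_name="pdb")  # full pdb content
    txt_toH5("structure.psf", group)  # stored under the "psf" key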
55 | class molAnalyzer: 56 | def __init__(self, pdbFile, file_handler=None, processed_path="."): 57 | """MolAnalyzer takes care of the analysis of the molecule: it builds the molecule object and computes a series of properties 58 | that will then be used to generate the h5 dataset. 59 | Parameters 60 | ---------- 61 | pdbFile : str 62 | The path to the pdb file 63 | file_handler : logging.FileHandler 64 | The file handler to be used to write the log file 65 | processed_path : str 66 | The path where the processed files will be saved, in this case the filtered pdb file 67 | """ 68 | self.processed_path = processed_path 69 | self.molLogger = logging.getLogger("MolAnalyzer") 70 | if file_handler is not None: 71 | self.molLogger.addHandler(file_handler) 72 | #logging.getLogger("moleculekit").handlers = [self.molLogger] 73 | 74 | self.pdbFile = pdbFile 75 | self.pdbName = os.path.basename(pdbFile).split(".")[0] 76 | 77 | # the mol object is created from the structure.pdb and structure.psf files which were used to start the simulation 78 | # all atoms, solvent included, are considered 79 | self.mol = Molecule(pdbFile) 80 | self.mol.read(pdbFile.replace(".pdb", ".psf")) 81 | self.proteinIdxs = self.mol.get("index", sel="protein") 82 | 83 | # for backward compatibility with the initial version of mdCATH, the chain of the protein atoms is set to 0 84 | # TODO: change in future version of mdcath 85 | self.mol.chain[self.proteinIdxs] = "0" 86 | 87 | self.protein_mol = self.mol.copy() 88 | self.protein_mol.filter("protein") 89 | 90 | # Log the properties of the molecule 91 | self.molLogger.info('Molecule object created from pdb and psf files') 92 | self.molLogger.info(f"Number of residues: {self.protein_mol.numResidues}") 93 | self.molLogger.info(f"Number of atoms: {self.mol.numAtoms}") 94 | self.molLogger.info(f"Number of protein atoms: {len(self.proteinIdxs)}") 95 | 96 | self.pdb_filtered_name = ( 97 | f"{self.processed_path}/{self.pdbName}_protein_filter.pdb" 98 | ) 99 | self.protein_mol.write(self.pdb_filtered_name) 100 | 101 | def computeProperties( 102 | self, 103 | ): 104 | """Compute the properties of the molecule""" 105 | tmpmol = self.protein_mol.copy() 106 | 107 | # dataset 108 | self.molData = {} 109 | self.molData["chain"] = tmpmol.chain 110 | self.molData["resname"] = tmpmol.resname 111 | self.molData["resid"] = tmpmol.resid 112 | self.molData["element"] = tmpmol.element 113 | self.molData["z"] = np.array([periodictable[x].number for x in tmpmol.element]) 114 | ## attrs 115 | self.molAttrs = {} 116 | self.molAttrs["numProteinAtoms"] = tmpmol.numAtoms 117 | self.molAttrs["numResidues"] = tmpmol.numResidues 118 | self.molAttrs["numChains"] = len(set(list(self.molData["chain"]))) 119 | self.molAttrs["numBonds"] = tmpmol.numBonds 120 | 121 | def readXTC(self, xtcFiles, batch_idx): 122 | """Read the xtc trajectory files 123 | Parameters 124 | ---------- 125 | xtcFiles : list 126 | The list of xtc files to be read 127 | 128 | batch_idx : int 129 | The index of the batch to be used in the log file """ 130 | 131 | self.trajmol = self.mol.copy() 132 | try: 133 | self.trajmol.read(xtcFiles) 134 | self.trajmol.filter("protein") 135 | 136 | except (RuntimeError, ValueError, OSError) as e: 137 | self.molLogger.error( 138 | f"TRAJECTORY LOADING ERROR ON BATCH:{batch_idx} | SIM: {os.path.basename(xtcFiles[0]).split('-')[0]}" 139 | ) 140 | self.molLogger.error(e) 141 | self.coords = None 142 | return 143 | 144 | # COORDS 145 | self.coords = self.trajmol.coords.copy() # Angstrom (numAtoms, 3, numFrames) 146 | 147 | def readDCD(self, dcdFiles, batch_idx): 148 | dcdmol = self.mol.copy() 149 | try: 150 | dcdmol.read(dcdFiles) 151 | dcdmol.filter("protein") 152 | self.forces = dcdmol.coords.copy() # kcal/mol/Angstrom 153 | 154 | except (RuntimeError, ValueError, OSError) as e: 155 | 
self.molLogger.error( 156 | f"FORCE LOADING ERROR ON BATCH:{batch_idx} | SIM: {os.path.basename(dcdFiles[0]).split('-')[0]}" 157 | ) 158 | self.molLogger.error(e) 159 | self.forces = None 160 | return 161 | 162 | if self.coords is not None: 163 | if self.forces.shape != self.coords.shape: 164 | self.molLogger.warning( 165 | f"Forces {self.forces.shape} and Coords {self.coords.shape} shapes do not match" 166 | ) 167 | last_idx = min(self.forces.shape[2], self.coords.shape[2]) 168 | self.forces = self.forces[:, :, :last_idx] 169 | self.coords = self.coords[:, :, :last_idx] 170 | 171 | 172 | def trajAnalysis(self): 173 | """Perform the analysis of the trajectory""" 174 | 175 | if self.trajmol.numFrames != self.coords.shape[2]: 176 | # a mismatch between the number of frames in the trajectory and the number of frames in the coords 177 | # can be found since in readDCD we take the minimum between forces and coords 178 | self.trajmol.dropFrames(keep=np.arange(self.coords.shape[2])) 179 | 180 | self.trajAttrs = {} 181 | self.metricAnalysis = {} 182 | 183 | # first frame is used as reference for the rmsd 184 | # TODO: compute rmsd wrt the input structure of the md-simulation (it's not the first frame of the trajectory) 185 | refmol = self.protein_mol.copy() 186 | refmol.coords = self.trajmol.coords[:, :, 0].copy()[:, :, np.newaxis] 187 | 188 | # RMSD 189 | # the rmsd is computed for the heavy atoms only wrt the first frame 190 | rmsd_metric = MetricRmsd( 191 | refmol=refmol, 192 | trajrmsdstr="protein and not element H", 193 | trajalnstr="protein and name CA", 194 | pbc=True, 195 | ) 196 | 197 | rmsd = rmsd_metric.project(self.trajmol) * ANGSTROM_TO_NM # shape (numFrames) [nm] 198 | rmsd_accepted_frames = np.where(rmsd < RMSD_CUTOFF)[0] 199 | if len(rmsd_accepted_frames) != self.trajmol.numFrames: 200 | self.molLogger.warning(f"RMSD cutoff {RMSD_CUTOFF} nm, {self.trajmol.numFrames - len(rmsd_accepted_frames)} frames were removed") 201 | self.molLogger.warning('coords and forces shapes were updated to match the rmsd_accepted_frames') 202 | self.coords = self.coords[:, :, rmsd_accepted_frames] 203 | self.forces = self.forces[:, :, rmsd_accepted_frames] 204 | 205 | self.metricAnalysis["rmsd"] = rmsd[rmsd_accepted_frames] 206 | 207 | self.trajmol.dropFrames(keep=rmsd_accepted_frames) 208 | 209 | # GYRATION RADIUS 210 | # gyration radius computed for the heavy atoms only 211 | gyr_metric = MetricGyration(atomsel="not element H", refmol=refmol, 212 | trajalnsel='name CA', refalnsel='name CA', centersel='protein', pbc=True) 213 | 214 | # the gyr_metric projection outputs rg, rg_x, rg_y, rg_z.
 We take only the first column, which is the radius of gyration averaged over the three dimensions 215 | # the dtype is cast explicitly for backward compatibility with the initial version of mdCATH 216 | # TODO: make everything float32 in future version of mdcath 217 | self.metricAnalysis["gyrationRadius"] = (gyr_metric.project(self.trajmol)[:, 0] * ANGSTROM_TO_NM).astype(np.float32) # nm 218 | 219 | # RMSF 220 | # compute the rmsf of the CA atoms wrt their mean positions 221 | rmsf_metric = MetricFluctuation(atomsel="name CA") 222 | self.metricAnalysis["rmsf"] = (np.sqrt(np.mean(rmsf_metric.project(self.trajmol), axis=0)) * ANGSTROM_TO_NM).astype(np.float32) # nm 223 | 224 | # DSSP 225 | dssp_metric = MetricSecondaryStructure(sel="protein", simplified=False, integer=False) 226 | dssp = dssp_metric.project(self.trajmol) 227 | self.metricAnalysis["dssp"] = np.array(encodeDSSP(dssp)).astype(object) 228 | 229 | # BOX 230 | # the box has shape (3, numFrames), we take the first frame only 231 | box = self.trajmol.box.copy()[:, 0] * ANGSTROM_TO_NM # nm, shape (3,) 232 | self.box = np.diag(box) # shape (3, 3) 233 | 234 | def sanityCheck(self): 235 | """Sanity check on the shapes of the arrays""" 236 | numAtoms = self.protein_mol.numAtoms 237 | numResidues = self.protein_mol.numResidues 238 | 239 | # Fix coords and forces shapes to (numFrames, numAtoms, 3) 240 | if self.coords.shape[0] == numAtoms: 241 | self.coords = np.moveaxis(self.coords, -1, 0) 242 | if self.forces.shape[0] == numAtoms: 243 | self.forces = np.moveaxis(self.forces, -1, 0) 244 | 245 | numFrames = self.coords.shape[0] 246 | assert self.coords.shape == (numFrames, numAtoms, 3), f"Coords shape {self.coords.shape} does not match (numFrames, numAtoms, 3)" 247 | assert self.coords.shape == self.forces.shape, f"Shapes of coords {self.coords.shape} and forces {self.forces.shape} do not match" 248 | assert self.metricAnalysis["rmsd"].shape[0] == numFrames, f'rmsd shape {self.metricAnalysis["rmsd"].shape[0]} and numFrames {numFrames} do not match' 249 | assert self.metricAnalysis["gyrationRadius"].shape[0] == numFrames, f'gyrationRadius shape {self.metricAnalysis["gyrationRadius"].shape[0]} and numFrames {numFrames} do not match' 250 | assert self.metricAnalysis["rmsf"].shape[0] == numResidues, f'rmsf shape {self.metricAnalysis["rmsf"].shape[0]} and numResidues {numResidues} do not match' 251 | assert self.metricAnalysis["dssp"].shape[0] == numFrames, f'dssp shape {self.metricAnalysis["dssp"].shape[0]} and numFrames {numFrames} do not match' 252 | 
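# The shape convention in a nutshell: moleculekit returns coords/forces as
# (numAtoms, 3, numFrames) and sanityCheck above moves the frame axis to the
# front. A toy check with hypothetical sizes:

import numpy as np

coords = np.zeros((1024, 3, 450))    # (numAtoms, 3, numFrames)
coords = np.moveaxis(coords, -1, 0)  # -> (numFrames, numAtoms, 3)
assert coords.shape == (450, 1024, 3)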
253 | def write_toH5(self, molGroup, replicaGroup, attrs, datasets): 254 | """Write the data to the h5 file, according to the properties defined in the input for the dataset 255 | Parameters 256 | ---------- 257 | molGroup : h5py.Group 258 | The group of the molecule, this will be the parent group of the replicas 259 | so the properties of the molecule will be written here, they are shared among all replicas 260 | replicaGroup : h5py.Group 261 | The group of the replica, this will be the parent group of the properties of the replica, each replica has its own properties 262 | defined by trajectory analysis 263 | attrs: 264 | list of attributes to be written in the h5 group 265 | datasets: 266 | list of datasets to be written in the h5 group 267 | """ 268 | if molGroup is not None and replicaGroup is None: 269 | # write the pdb file to the h5 file 270 | txt_toH5(self.pdbFile, molGroup, dataset_name="pdb") 271 | # write the filtered pdb file to the h5 file 272 | txt_toH5(self.pdb_filtered_name, molGroup, dataset_name="pdbProteinAtoms") 273 | # write the psf file to the h5 file 274 | txt_toH5(self.pdbFile.replace(".pdb", ".psf"), molGroup) 275 | # mol attributes 276 | for key, value in self.molAttrs.items(): 277 | if key in attrs: 278 | molGroup.attrs[key] = value 279 | # mol datasets 280 | for key, value in self.molData.items(): 281 | if key in datasets: 282 | molGroup.create_dataset(key, data=value) 283 | 284 | elif molGroup is None and replicaGroup is not None: 285 | self.sanityCheck() 286 | # replica attributes 287 | replicaGroup.attrs["numFrames"] = self.coords.shape[0] 288 | # replica datasets 289 | for key, value in self.metricAnalysis.items(): 290 | if key in datasets: 291 | replicaGroup.create_dataset(key, data=value) 292 | if key == "dssp": 293 | continue # dssp does not have a unit 294 | replicaGroup[key].attrs["unit"] = "nm" 295 | 296 | replicaGroup.create_dataset("box", data=self.box) 297 | 298 | # coords and forces are written here as h5 datasets 299 | replicaGroup.create_dataset("coords", data=self.coords) 300 | replicaGroup.create_dataset("forces", data=self.forces) 301 | 302 | self.molLogger.info(f'coords shape: {self.coords.shape}') 303 | self.molLogger.info(f'forces shape: {self.forces.shape}') 304 | 305 | # add units attributes 306 | replicaGroup["coords"].attrs["unit"] = "Angstrom" 307 | replicaGroup["forces"].attrs["unit"] = "kcal/mol/Angstrom" 308 | replicaGroup["box"].attrs["unit"] = "nm" 309 | 310 | else: 311 | self.molLogger.error("Exactly one of molGroup and replicaGroup must be provided") 312 | return 313 | 
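# For reference, a minimal read-back sketch of the per-domain layout written by
# write_toH5 (dataset, attribute and unit names come from the code above; the
# file name is hypothetical):

import h5py

with h5py.File("mdcath_dataset_1balA00.h5", "r") as h5:
    dom = h5["1balA00"]
    print(dom.attrs["numProteinAtoms"], dom.attrs["numResidues"])  # molecule-level attrs
    repl = dom["320"]["0"]               # temperature group / replica group
    print(repl["coords"].shape)          # (numFrames, numAtoms, 3)
    print(repl["coords"].attrs["unit"])  # Angstrom
    print(repl["rmsd"].attrs["unit"])    # nm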
""" 320 | sanitize = False 321 | num_frames = 0 322 | 323 | for i, (xtc, dcd) in enumerate(zip(xtc_files, dcd_files)): 324 | fixmol = self.mol.copy() 325 | # xtc 326 | try: 327 | fixmol.read(xtc) 328 | xtc_frames = fixmol.numFrames 329 | except Exception as e: 330 | self.molLogger.error(f"Error reading xtc file: {xtc}") 331 | self.molLogger.error(e) 332 | sanitize = True 333 | 334 | # dcd 335 | try: 336 | fixmol.read(dcd) 337 | dcd_frames = fixmol.numFrames 338 | except Exception as e: 339 | self.molLogger.error(f"Error reading dcd file: {dcd}") 340 | self.molLogger.error(e) 341 | sanitize = True 342 | 343 | if xtc_frames != dcd_frames or sanitize == True: 344 | if sanitize == False: 345 | self.molLogger.info(f"Trajectory {i} has different number of frames: XTC [{xtc_frames}] vs DCD[{dcd_frames}]") 346 | self.molLogger.info(f'dcd file: {dcd}') 347 | self.molLogger.info(f'xtc file: {xtc}') 348 | last_frame = num_frames + min(xtc_frames, dcd_frames) 349 | last_file = i 350 | self.molLogger.info(f"Last frame: {last_frame}") 351 | 352 | else: 353 | last_frame = num_frames 354 | last_file = i-1 355 | # read the trajectory and filter only the protein atoms to get the correct number of frames also in trajAnalysis 356 | # if the exception occur in reading one by one, then we need to consder until the file i-1 357 | self.trajmol = self.mol.copy() 358 | self.trajmol.read(xtc_files[:last_file+1]) 359 | self.trajmol.filter("protein") 360 | self.trajmol.dropFrames(keep=np.arange(last_frame)) 361 | 362 | # fix coords and forces shapes, if None then the files need to be read again and the coords and forces will be updated 363 | if self.coords is not None: 364 | self.coords = self.coords[:, :, :last_frame] 365 | else: 366 | self.molLogger.error("Coords is None") 367 | self.coords = self.trajmol.coords.copy() 368 | 369 | if self.forces is not None: 370 | self.forces = self.forces[:, :, :last_frame] 371 | else: 372 | self.molLogger.error("Forces is None") 373 | self.dcdmol = self.mol.copy() 374 | self.dcdmol.read(dcd_files[:last_file+1]) 375 | self.dcdmol.filter("protein") 376 | self.dcdmol.dropFrames(keep=np.arange(last_frame)) 377 | self.forces = self.dcdmol.coords.copy() 378 | 379 | self.molLogger.info(f"FixReaders adjusted coords shape: {self.coords.shape}, forces shape: {self.forces.shape}, trajmol numFrames: {self.trajmol.numFrames}") 380 | 381 | break # stop the loop 382 | 383 | else: 384 | num_frames += xtc_frames 385 | return -------------------------------------------------------------------------------- /analysis/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import math 4 | import json 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from os.path import join as opj 9 | import matplotlib.pyplot as plt 10 | import matplotlib.colors as mcolors 11 | from mpl_toolkits.axes_grid1.inset_locator import inset_axes 12 | 13 | # Set global plotting parameters 14 | fsize = 19 15 | plt.rcParams.update({'font.size': fsize, 16 | 'axes.labelsize': fsize-1, 17 | 'axes.titlesize': fsize+2, 18 | 'xtick.labelsize': fsize-2 , 19 | 'ytick.labelsize': fsize-2, 20 | 'legend.fontsize': fsize-2, 21 | 'figure.titlesize': fsize+4, 22 | }) 23 | 24 | def get_stats(data, metric=""): 25 | # print, mean, std, min, max, median for a specific metric 26 | print(f"Stats for mdCATH: {metric}") 27 | print("--------------------------") 28 | print(f"Mean: {np.mean(data)}") 29 | print(f"Std: {np.std(data)}") 30 | print(f"Min: 
{np.min(data)}") 31 | print(f"Max: {np.max(data)}") 32 | print(f"Median: {np.median(data)}") 33 | if metric in ["Trajectory length"]: 34 | print(f'Total time of simulation: {np.sum(data)*1e-6} ms') 35 | print(f'Total number of trajectories: {len(data)}') 36 | if metric in ["Number of atoms", "Number of residues"]: 37 | print(f"Total {metric}: {np.sum(data)}") 38 | if metric in ["RMSD", "Trajectory length"]: 39 | return 40 | print(" ") 41 | 42 | def plot_len_trajs(h5metrics, output_dir): 43 | data = [] 44 | for pdb in tqdm(h5metrics.keys(), total=len(h5metrics.keys()), desc="Trajectory length"): 45 | for temp in h5metrics[pdb].keys(): 46 | for repl in h5metrics[pdb][temp].keys(): 47 | data.append(h5metrics[pdb][temp][repl].attrs['numFrames']) 48 | get_stats(data, metric="Trajectory length") 49 | plt.figure() 50 | plt.hist(data, bins=50) 51 | plt.xlabel("Trajectory length (ns)") 52 | plt.ylabel("Counts") 53 | plt.tight_layout() 54 | plt.savefig(opj(output_dir, "traj_len.png"), dpi=600) 55 | 56 | def plot_numAtoms(h5metrics, output_dir): 57 | data = [] 58 | for pdb in tqdm(h5metrics.keys(), total=len(h5metrics.keys()), desc="Number of atoms"): 59 | data.append(h5metrics[pdb].attrs['numProteinAtoms']) 60 | get_stats(data, metric="Number of atoms") 61 | plt.figure() 62 | plt.hist(data, bins=50) 63 | plt.xlabel("Number of atoms") 64 | plt.ylabel("Counts") 65 | plt.tight_layout() 66 | plt.savefig(opj(output_dir, "num_atoms.png"), dpi=600) 67 | 68 | def plot_numResidues(h5metrics, output_dir): 69 | data = [] 70 | for pdb in tqdm(h5metrics.keys(), total=len(h5metrics.keys()), desc="Number of residues"): 71 | data.append(h5metrics[pdb].attrs['numResidues']) 72 | get_stats(data, metric="Number of residues") 73 | plt.figure() 74 | plt.hist(data, bins=50) 75 | plt.xlabel("Number of residues") 76 | plt.ylabel("Counts") 77 | plt.tight_layout() 78 | plt.savefig(opj(output_dir, "num_residues.png"), dpi=600) 79 | 80 | def plot_RMSD(h5metrics, output_dir, rmsdcutoff=10, yscale="linear"): 81 | # Compute RMSD distribution considering only the last frame of each trajectory 82 | 83 | data = [] 84 | for pdb in tqdm(h5metrics.keys(), total=len(h5metrics.keys()), desc="RMSD"): 85 | for temp in h5metrics[pdb].keys(): 86 | for repl in h5metrics[pdb][temp].keys(): 87 | rmsd = h5metrics[pdb][temp][repl]['rmsd'][-1] 88 | if rmsd > rmsdcutoff: 89 | print(f"RMSD above cutoff {rmsdcutoff}: {rmsd} nm for {pdb} at {temp} K and replica {repl}") 90 | continue 91 | data.append(h5metrics[pdb][temp][repl]['rmsd'][-1]) 92 | get_stats(data, metric="RMSD") 93 | plt.figure() 94 | plt.hist(data, bins=50) 95 | plt.xlabel("RMSD (nm)") 96 | plt.ylabel("Counts") 97 | if yscale == "log": 98 | plt.yscale("log") 99 | plt.tight_layout() 100 | plt.savefig(opj(output_dir, f"rmsd{'_log' if yscale == 'log' else ''}.png"), dpi=600) 101 | 102 | def plot_RMSF(h5metrics, output_dir, yscale="linear", temp_oi=None): 103 | data = [] 104 | temperatures = ['320', '348', '379', '413', '450'] if temp_oi is None else [temp_oi] 105 | for pdb in tqdm(h5metrics.keys(), total=len(h5metrics.keys()), desc="RMSF"): 106 | for temp in temperatures: 107 | for repl in h5metrics[pdb][temp].keys(): 108 | data.extend(h5metrics[pdb][temp][repl]['rmsf'][:]) 109 | plt.figure() 110 | plt.hist(data, bins=50) 111 | plt.xlabel("RMSF (nm)") 112 | plt.ylabel("Counts") 113 | if yscale == "log": 114 | plt.yscale("log") 115 | plt.tight_layout() 116 | plt.savefig(opj(output_dir, f"rmsf{'_log' if yscale == 'log' else ''}.png"), dpi=200) 117 | 118 | def plot_numRes_trajLength(h5metrics, 
output_dir): 119 | # Number of residues vs Trajectory length 120 | data = [] 121 | for pdb in tqdm(h5metrics.keys(), total=len(h5metrics.keys()), desc="numResidues vs trajLength"): 122 | for temp in h5metrics[pdb].keys(): 123 | for repl in h5metrics[pdb][temp].keys(): 124 | data.append([h5metrics[pdb][temp][repl].attrs['numResidues'], h5metrics[pdb][temp][repl].attrs['numFrames']]) 125 | 126 | data = np.array(data) 127 | plt.figure() 128 | plt.scatter(data[:, 0], data[:, 1]) 129 | plt.xlabel("Number of residues") 130 | plt.ylabel("Trajectory length (ns)") 131 | plt.tight_layout() 132 | plt.savefig(opj(output_dir, "numRes_trajLen.png"), dpi=600) 133 | 134 | def plot_GyrRad_SecondaryStruc(h5data, output_dir, numSamples=6, shared_axes=False, plot_type=['A']): 135 | ''' Select numSamples random keys from the h5 file and plot the gyration radius and secondary structure 136 | plot1: numSamples different pdbs, same temperature and replica (A) 137 | plot2: numSamples different temperatures, same pdb and replica (B) 138 | ''' 139 | np.random.seed(42) 140 | numFrames = 450 141 | deltaFrames = 50 # an arbitrary tolerance, so that the selected trajectories do not differ too much in length 142 | # domain_figures is the directory where the images of the domains are stored; these are overlaid on the scatter plots 143 | domain_figures = '/shared/antoniom/buildCATHDataset/analysis/figures/domains_figure4' 144 | 145 | ## cbar common settings ## 146 | cbar_kws = {"orientation":"vertical", "shrink":0.8, "aspect":40} 147 | cbar_label = "Simulation time (ns)" 148 | cbar_ticks = [0, 250, 500] 149 | cbar_ticklabels = [0, 250, 500] 150 | 151 | if 'A' in plot_type: 152 | # HERE WE PLOT A GRID OF SAMPLES, LOWEST TEMP AND ONE REPLICA (SAME FOR ALL SAMPLES) 153 | temp = '320' 154 | repl = '1' 155 | samples = np.random.choice(list(h5data.keys()), numSamples, replace=False) 156 | 157 | ncols = 3 158 | nrows = math.ceil(numSamples / ncols) 159 | fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 5, nrows * 5), sharex=shared_axes, sharey=shared_axes) 160 | fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.85, hspace=0.35, wspace=0.35) 161 | 162 | for i, sample in tqdm(enumerate(samples), total=numSamples, desc="GyrRad_solidSS (A)"): 163 | if nrows == 1 or ncols == 1: # Single row or column case 164 | ax = axs.flatten()[i] if numSamples != 1 else axs 165 | else: 166 | ax = axs[i // ncols, i % ncols] 167 | 168 | temp_repl_group = h5data[sample][temp][repl] 169 | 170 | # make sure that the trajectory length falls inside the range numFrames-deltaFrames, numFrames+deltaFrames 171 | if temp_repl_group.attrs['numFrames'] >= numFrames-deltaFrames and temp_repl_group.attrs['numFrames'] <= numFrames+deltaFrames: 172 | pass 173 | else: 174 | print(f"Sample {sample} does not have the right number of frames ({temp_repl_group.attrs['numFrames']})") 175 | wrong_sample = sample 176 | while not (numFrames - deltaFrames < temp_repl_group.attrs['numFrames'] < numFrames + deltaFrames): 177 | sample = np.random.choice(list(h5data.keys()), 1, replace=False)[0] 178 | temp_repl_group = h5data[sample][temp][repl] 179 | print(f"Sample {wrong_sample} has been replaced by {sample}") 180 | # add the sample to the list of samples and replace the one that was not good 181 | samples[i] = sample 182 | 183 | ss = temp_repl_group['solid_secondary_structure'][:] 184 | gr = temp_repl_group['gyration_radius'][:] 185 | 186 | # Normalized color mapping 187 | norm = mcolors.Normalize(vmin=0, vmax=numFrames+deltaFrames) 188 | 
cmap = plt.get_cmap('viridis') 189 | 190 | # Scatter plot 191 | scatter = ax.scatter(ss, gr, c=range(len(ss)), cmap=cmap, norm=norm, s=5, zorder=2) 192 | if i != len(samples)-1: 193 | loc = 'upper right' 194 | else: 195 | loc = 'lower left' 196 | # Add the image of the domain as an inset 197 | axins = inset_axes(ax, width="40%", height="40%", loc=loc) 198 | pngpath = opj(domain_figures, f"{sample}.png") 199 | if not os.path.exists(pngpath): 200 | print(f"Image {pngpath} not found") 201 | continue 202 | img = plt.imread(pngpath) 203 | axins.imshow(img) 204 | axins.axis('off') # Hide the axis of the inset 205 | 206 | ax.set_title(f"{sample}") 207 | 208 | xmin = min([ax.get_xlim()[0] for ax in axs.flatten()]) 209 | xmax = max([ax.get_xlim()[1] for ax in axs.flatten()]) 210 | ymin = min([ax.get_ylim()[0] for ax in axs.flatten()]) 211 | ymax = max([ax.get_ylim()[1] for ax in axs.flatten()]) 212 | if shared_axes: 213 | for axi, ax in enumerate(axs.flatten()): 214 | ax.xaxis.set_tick_params(labelbottom=True) 215 | ax.yaxis.set_tick_params(labelleft=True) 216 | if axi % ncols == 0: 217 | ax.set_ylabel("Gyration radius (nm)") 218 | ax.set_ylim(ymin-0.1, ymax+0.1) 219 | if axi // ncols == nrows-1: 220 | ax.set_xlabel("Fraction of α+β structure") 221 | ax.set_xlim(xmin-0.1, xmax+0.1) 222 | 223 | else: 224 | for ax in axs.flatten(): 225 | ax.set_xlim(xmin-0.1, xmax+0.1) 226 | ax.set_ylim(ymin-0.1, ymax+0.1) 227 | ax.set_xlabel("Fraction of α+β structure") 228 | ax.set_ylabel("Gyration radius (nm)") 229 | ax.set_yticks([round(el,1) for el in np.linspace(ymin+0.1, ymax-0.1, 4)]) 230 | 231 | # Colorbar with dedicated space 232 | cbar_ax = fig.add_axes([0.9, 0.25, 0.02, 0.5]) # x, y, width, height 233 | cbar = fig.colorbar(scatter, cax=cbar_ax, **cbar_kws) 234 | cbar.set_label(cbar_label) 235 | cbar.set_ticks(cbar_ticks) 236 | cbar.set_ticklabels(cbar_ticklabels) 237 | plt.savefig(opj(output_dir, f"GyrRad_solidSS_A_domainImages{'_ShareAxs' if shared_axes else ''}.png"), dpi=600) 238 | 239 | 240 | ## HERE WE PLOT A GRID FOR THE SAME SAMPLE BUT DIFFERENT TEMPERATURES (SAME REPLICA) ## 241 | if 'B' in plot_type: 242 | sample_i = '5sicI00' 243 | repl = '1' 244 | ncols = 3 245 | nrows = 2 246 | fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 5, nrows * 5)) 247 | axs = axs.flatten() 248 | fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.85, hspace=0.35, wspace=0.35) 249 | for i, temp in tqdm(enumerate(list(h5data[sample_i].keys())), total=len(h5data[sample_i].keys()), desc="GyrRad_solidSS (B)"): 250 | ax = axs[i] 251 | temp_repl_group = h5data[sample_i][temp][repl] 252 | 253 | ss = temp_repl_group['solid_secondary_structure'][:] 254 | gr = temp_repl_group['gyration_radius'][:] 255 | 256 | # Normalized color mapping 257 | norm = mcolors.Normalize(vmin=0, vmax=numFrames+deltaFrames) 258 | cmap = plt.get_cmap('viridis') 259 | scatter = ax.scatter(ss, gr, c=range(len(ss)), cmap=cmap, norm=norm, s=5) 260 | ax.set_title(f"{temp}K") 261 | ax.set_xlabel("Fraction of α+β structure") 262 | ax.set_ylabel("Gyration radius (nm)") 263 | 264 | # On the first subplot, a thumbnail of the domain is added 265 | if i == 0: 266 | axins = inset_axes(ax, width="45%", height="45%", loc='upper right') 267 | pngpath = opj(domain_figures, f"{sample_i}.png") 268 | assert os.path.exists(pngpath), f"Image {pngpath} not found" 269 | img = plt.imread(pngpath) 270 | axins.imshow(img) 271 | axins.axis('off') 272 | 273 | xmin = 0.28 274 | xmax = 0.62 275 | ymin = 1.28 276 | ymax = 1.52 277 | 278 | # last axis instead 
of being empty, shows the last temperature with a zoom out 279 | # the zoom out uses xmin, xmax, ymin, ymax taken from the min and max values of the other plots 280 | ax = axs[-1] 281 | ax.scatter(ss, gr, c=range(len(ss)), cmap=cmap, norm=norm, s=5) 282 | # add a rectangle to show the zoom out 283 | rect = plt.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, linewidth=1, edgecolor='r', facecolor='none') 284 | ax.add_patch(rect) 285 | ax.set_xlabel("Fraction of α+β structure") 286 | ax.set_ylabel("Gyration radius (nm)") 287 | ax.set_title(f"{temp}K (zoom out)") 288 | 289 | for ax in axs[:-1]: 290 | ax.set_xlim(xmin, xmax) 291 | ax.set_ylim(ymin, ymax) 292 | ax.set_xticks([0.3, 0.4, 0.5, 0.6]) 293 | ax.set_yticks([1.3, 1.4, 1.5]) 294 | 295 | axs[-1].set_yticks([round(el,1) for el in np.linspace(np.min(gr)+0.1, np.max(gr)-0.1, 3)]) 296 | 297 | cbar_ax = fig.add_axes([0.9, 0.25, 0.02, 0.5]) # x, y, width, height 298 | cbar = fig.colorbar(scatter, cax=cbar_ax, **cbar_kws) 299 | cbar.set_label(cbar_label) 300 | cbar.set_ticks(cbar_ticks) 301 | cbar.set_ticklabels(cbar_ticklabels) 302 | plt.savefig(opj(output_dir, f"GyrRad_solidSS_B_{sample_i}_replica{repl}_thumbnail.png"), dpi=600) 303 | 304 | def get_solid_fraction(dssp, simplified=False): 305 | # Compute the solid fraction of α+β structure in the secondary structure with respect to time. 306 | if simplified: 307 | floatMap = {"H": 0, "B": 1, "E": 1, "G": 0, "I": 0, "T": 2, "S": 2, " ": 2} # 3 type differentiation 308 | else: 309 | floatMap = {"H": 0, "B": 1, "E": 2, "G": 3, "I": 4, "T": 5, "S": 6, " ": 7} 310 | 311 | dssp_decoded_float = np.zeros((dssp.shape[0], dssp.shape[1]), dtype=np.float32) # shape (numFrames, numResidues) 312 | for i in range(dssp.shape[0]): 313 | dssp_decoded_float[i] = [floatMap[el.decode()] for el in dssp[i]] 314 | solid_fraction_time = np.logical_or(dssp_decoded_float == 0, dssp_decoded_float == 1) # shape (numFrames, numResidues) 315 | return solid_fraction_time 316 | 
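# A toy check of get_solid_fraction (the 2-frame, 3-residue DSSP array is made
# up). Note that with the full 8-letter map only H and B count as solid; E is
# counted as solid only when simplified=True:

import numpy as np

toy_dssp = np.array([[b"H", b"E", b" "],
                     [b"H", b"T", b"B"]], dtype=object)
mask = get_solid_fraction(toy_dssp)  # boolean, shape (numFrames, numResidues)
print(mask.mean(axis=0))  # fraction of frames each residue is solid: [1. 0. 0.5]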
317 | def plot_solidFraction_RMSF(h5metrics, output_dir, numSamples=3, simplified=False, repl='1'): 318 | """ Solid fraction vs RMSF for N random samples at 320K and 450K. 319 | Solid fraction is defined as the fraction of α+β structure in the secondary structure. 320 | simplified: if True, the solid fraction is simplified to 3 types: α, β, and other (dssp based) 321 | """ 322 | np.random.seed(2) 323 | samples = np.random.choice(list(h5metrics.keys()), numSamples, replace=False) 324 | mdcath_dir = "/workspace8/antoniom/mdcath_htmd" 325 | ncols = 2 # 2 temperatures 326 | nrows = numSamples 327 | fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 5, nrows * 5)) 328 | fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.85, hspace=0.35, wspace=0.3) 329 | 330 | for ax_i, sample in tqdm(enumerate(samples), total=len(samples), desc="Solid fraction vs RMSF"): 331 | with h5py.File(opj(mdcath_dir, sample, f"mdcath_dataset_{sample}.h5"), "r") as h5file: 332 | for j, temp in enumerate(['320', '450']): 333 | solid_fraction_time = get_solid_fraction(h5file[sample][temp][repl]['dssp'], simplified=simplified) # shape (numFrames, numResidues) 334 | solid_fraction_time = solid_fraction_time.mean(axis=0) # mean across the frames 335 | rmsf = h5file[sample][temp][repl]['rmsf'][:] 336 | ax = axs[ax_i, j] 337 | ax.scatter(rmsf, solid_fraction_time, c=np.arange(len(rmsf)), cmap='rainbow') 338 | ax.set_title(f"{sample} {temp}K") 339 | ax.set_xlabel("RMSF (nm)") 340 | ax.set_ylabel("Fraction of time in α/β config.") 341 | # column zero set xlim 342 | if j == 0: 343 | ax.set_xlim(0, 1) 344 | else: 345 | ax.set_xlim(0, 2.6) 346 | ax.axvline(x=1, color='grey', linestyle='--') 347 | ax.set_ylim(-0.1, 1.1) 348 | 349 | plt.tight_layout() 350 | plt.savefig(opj(output_dir, f"solidFraction_RMSF{'_simplified' if simplified else ''}.png"), dpi=600) 351 | 352 | def get_replicas(mean_across): 353 | if mean_across == 'all': 354 | return ['0', '1', '2', '3', '4'] 355 | elif isinstance(mean_across, list): 356 | return mean_across 357 | elif isinstance(mean_across, int) or isinstance(mean_across, str): 358 | return [str(mean_across)] 359 | else: 360 | raise ValueError("The mean_across should be 'all' or a list of one element") 361 | 362 | def get_solid_fraction_extended(h5group, replicas, numResidues, simplified=False): 363 | # Compute the solid fraction of α+β structure in the secondary structure with respect to time, across multiple replicas. 364 | # h5group is the group of a specific pdb at a specific temperature 365 | if simplified: 366 | floatMap = {"H": 0, "B": 1, "E": 1, "G": 0, "I": 0, "T": 2, "S": 2, " ": 2} # 3 type differentiation 367 | else: 368 | floatMap = {"H": 0, "B": 1, "E": 2, "G": 3, "I": 4, "T": 5, "S": 6, " ": 7} # 8 type differentiation 369 | 370 | max_num_frames = max([h5group[repl].attrs['numFrames'] for repl in replicas]) # used to build the array of decoded dssp 371 | # dssp_decoded_float shape (len(replicas), numFrames, numResidues); the fill value -1 marks frames missing in shorter replicas (0 is the H code, so it cannot be used as padding) 372 | dssp_decoded_float = np.full((len(replicas), max_num_frames, numResidues), -1, dtype=np.float32) 373 | for repl_i, repl in enumerate(replicas): 374 | encoded_dssp = h5group[repl]['dssp'] 375 | for frame_i in range(encoded_dssp.shape[0]): 376 | # axis 0 indexes the replicas 377 | dssp_decoded_float[repl_i, frame_i, :] = [floatMap[el.decode()] for el in encoded_dssp[frame_i]] 378 | 379 | return dssp_decoded_float 380 | 381 | def plot_solidFraction_vs_numResidues(h5metrics, output_dir, mean_across='all', temps=None, simplified=False): 382 | """ Plot the fraction of alpha+beta structure vs the number of residues in the protein, 383 | this is done for all the proteins in the dataset. 
def plot_solidFraction_vs_numResidues(h5metrics, output_dir, mean_across='all', temps=None, simplified=False):
    """ Plot the fraction of alpha+beta structure vs the number of residues in the protein,
    for all the proteins in the dataset: one value of the fraction of alpha+beta per domain
    and one value of the number of residues per domain. The mean of the secondary structure can be computed
    across all the replicas ('all') or just a replica of interest (replica id). If temps is None, one plot is
    generated for each of the five temperatures; otherwise only for the temperatures in the list temps."""

    temps = ['320', '348', '379', '413', '450'] if temps is None else temps
    replicas = get_replicas(mean_across)
    mdcath_dir = "/workspace8/antoniom/mdcath_htmd"

    nPlots = len(temps)
    nCols = nPlots if nPlots < 3 else 3
    nRows = math.ceil(nPlots / nCols)
    fig, axs = plt.subplots(nrows=nRows, ncols=nCols, figsize=(nCols * 5, nRows * 5))
    fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.85, hspace=0.35, wspace=0.3)
    axs = axs.flatten() if nPlots > 1 else [axs]

    for temp_i, temp in enumerate(temps):
        print(f"Temperature: {temp}")
        all_alpha_beta_mean = []
        all_numResidues = []
        for pdb in tqdm(h5metrics.keys(), total=len(h5metrics.keys()), desc="Solid fraction vs numResidues"):
            numResidues = h5metrics[pdb].attrs['numResidues']
            with h5py.File(opj(mdcath_dir, pdb, f"mdcath_dataset_{pdb}.h5"), "r") as h5file:
                # dssp_decoded_float of shape (len(replicas), maxNumFrames, numResidues)
                dssp_decoded_float = get_solid_fraction_extended(h5file[pdb][temp], replicas, numResidues, simplified=simplified)
                # drop the NaN padding, i.e. the frames that are not present in all the replicas
                dssp_decoded_float = dssp_decoded_float.flatten()
                dssp_decoded_float = dssp_decoded_float[~np.isnan(dssp_decoded_float)]

            # helix codes (H, G, I) and strand codes (B, E) count as alpha+beta
            solid_codes = (0, 1) if simplified else (0, 1, 2, 3, 4)
            solid_fraction_time = np.isin(dssp_decoded_float, solid_codes).mean()
            all_alpha_beta_mean.append(solid_fraction_time)
            all_numResidues.append(numResidues)

        # scatter plot for the specific temperature
        ax = axs[temp_i]
        ax.scatter(all_numResidues, all_alpha_beta_mean, s=2.5)
        ax.set_title(f"{temp}K")
        ax.set_xlabel("Number of residues")
        ax.set_ylabel("Fraction of α+β structure")
        # save also the single plot
        plt.savefig(opj(output_dir, f"all_dataset_plots_studycase/solidFraction_vs_numResidues_{temp}K.png"), dpi=600)

    if nPlots < len(axs):
        axs[-1].axis('off')  # hide the unused subplot of the grid
    plt.savefig(opj(output_dir, f"all_dataset_plots_studycase/solidFraction_vs_numResidues_replica_{mean_across}_{len(temps)}Temps.png"), dpi=600)

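# Hedged usage sketch for the function above (file and directory names are placeholders):
# one scatter plot per requested temperature, averaging the secondary structure over replica '0'.
# with h5py.File("h5files/mdcath_analysis.h5", "r") as h5metrics:
#     plot_solidFraction_vs_numResidues(h5metrics, "figures", mean_across=0, temps=['320', '450'])
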
def get_secondary_structure_compositions(dssp):
    '''This function returns the fractional composition of alpha, beta and coil in the protein,
    computed on the last frame of the trajectory.
    A special "NA" code is assigned to each "residue" in the topology which isn't actually
    a protein residue (i.e. it does not contain atoms with the names "CA", "N", "C", "O").
    '''
    floatMap = {"H": 0, "B": 1, "E": 1, "G": 0, "I": 0, "T": 2, "S": 2, " ": 2, 'NA': 3}

    decoded_dssp = [el.decode() for el in dssp[-1]]  # consider only the last frame
    float_dssp = np.array([floatMap[el] for el in decoded_dssp])
    unique, counts = np.unique(float_dssp, return_counts=True)
    numResAlpha, numResBeta, numResCoil = 0, 0, 0
    for u, c in zip(unique, counts):
        if u == 0:
            numResAlpha += c
        elif u == 1:
            numResBeta += c
        else:
            # NA or coil
            numResCoil += c
    # fractional composition (0-1) in alpha, beta and coil
    alpha_comp = numResAlpha / np.sum(counts)
    beta_comp = numResBeta / np.sum(counts)
    coil_comp = numResCoil / np.sum(counts)

    return alpha_comp, beta_comp, coil_comp

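# Illustrative check of get_secondary_structure_compositions on a synthetic one-frame DSSP
# array (byte strings mimic the h5 encoding); the three returned fractions always sum to 1.
def _example_compositions():
    fake_dssp = np.array([[b"H", b"H", b"E", b"T", b"NA"]])  # one frame, five residues
    alpha, beta, coil = get_secondary_structure_compositions(fake_dssp)
    assert abs(alpha + beta + coil - 1.0) < 1e-9
    return alpha, beta, coil  # -> (0.4, 0.2, 0.4)
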
def plot_heatmap_ss_time_superfamilies(h5metrics, output_dir, mean_across='all', temps=None, num_pdbs=None, simplified=False):
    """ Plot the time in ns on the x-axis and, on the y-axis, the fraction of alpha+beta structure relative to the start of the simulation.
    Rows are the temperatures and columns are the superfamilies. The relative solid fraction (RSF) is the per-frame fraction of
    α+β structure normalized by its value in the first frame. [figure S1 of the paper]

    Params:
    ------------
    - h5metrics:
        h5 file with the metrics of the dataset
    - output_dir:
        directory where to save the plots
    - mean_across:
        replica to consider; if 'all', all the replicas are considered
    - temps:
        temperatures to consider; if None, all the temperatures are considered
    - num_pdbs:
        number of pdbs to consider per superfamily
    - simplified:
        if True, the secondary structure is simplified to 3 types: α, β, and other (dssp based)
    """
    np.random.seed(7)
    superfamily_labels = {1: 'Mainly Alpha', 2: 'Mainly Beta', 3: 'Mixed Alpha-Beta', 4: 'Few Secondary Structures'}
    super_family_json = json.load(open("/shared/antoniom/buildCATHDataset/support/cath_info.json", "r"))
    mdcath_dir = "/workspace8/antoniom/mdcath_htmd"

    temps = ['320', '348', '379', '413', '450'] if temps is None else temps
    replicas = get_replicas(mean_across)
    # Determine the number of columns based on the superfamilies present in the dataset
    superfamilies = sorted({int(super_family_json[pdb]['superfamily_id'].split(".")[0]) for pdb in h5metrics.keys() if pdb in super_family_json.keys()})

    # In order to avoid bias, the list of pdbs is shuffled if a subset is requested
    pdb_list = list(h5metrics.keys()) if num_pdbs is None else np.random.choice(list(h5metrics.keys()), len(h5metrics.keys()), replace=False)

    nRows = len(temps)
    nCols = len(superfamilies)

    # Setup figure and axes
    fig, axs = plt.subplots(nrows=nRows, ncols=nCols, figsize=(nCols * 6, nRows * 5))
    axs = np.atleast_2d(axs)  # keep 2D indexing valid also for a single temperature
    fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.85, hspace=0.4, wspace=0.4)

    # Iterate over temperatures and superfamilies
    result_dataset = []
    for row, temp in enumerate(temps):
        print(f"Temperature: {temp}")
        for col, sf in enumerate(superfamilies):
            ax = axs[row, col]

            # initialize the arrays storing the data for the 2D histogram (heat-map)
            time_points = []
            all_alpha_beta = []
            accepted_superfamilies_domains = 0

            for pdb in tqdm(pdb_list, total=len(pdb_list), desc=f"Solid Fraction vs Time {superfamily_labels[sf]}"):
                if num_pdbs is not None and accepted_superfamilies_domains >= num_pdbs:
                    break  # exit the loop once the requested number of PDBs for the superfamily is reached
                super_family_id = int(super_family_json[pdb]['superfamily_id'].split(".")[0])
                if super_family_id != sf:
                    continue  # skip non-matching superfamilies

                accepted_superfamilies_domains += 1

                with h5py.File(opj(mdcath_dir, pdb, f"mdcath_dataset_{pdb}.h5"), "r") as h5file:
                    numResidues = h5metrics[pdb].attrs['numResidues']
                    for repl in replicas:
                        dssp = h5file[pdb][temp][repl]['dssp']  # shape (numFrames, numResidues)
                        assert dssp.shape[1] == numResidues, f"Number of residues mismatch for {pdb} {temp}K {repl}"

                        solid_fraction = get_solid_fraction(dssp, simplified=simplified)
                        mean_across_time = np.mean(solid_fraction, axis=1)  # mean across the residues, shape (numFrames,)
                        assert not np.isnan(mean_across_time).any(), f"NaN values in the mean_across_time for {pdb} {temp}K {repl}"

                        if mean_across_time[0] == 0:
                            # bare Warning(...) would only build an exception object without reporting it
                            print(f"Warning: first value of the solid fraction is zero for {pdb} {temp}K replica {repl}, the trajectory will be skipped!")
                            continue

                        normalized_ss_time = mean_across_time / mean_across_time[0]  # shape (numFrames,)
                        time_points.extend(np.arange(0, len(normalized_ss_time), 1))  # frame index used as time (frames are 1 ns apart)
                        all_alpha_beta.extend(normalized_ss_time)

            print(f"Number of domains in {superfamily_labels[sf]} superfamily: {accepted_superfamilies_domains}")

            result_dataset.append(pd.DataFrame({'temp': temp,
                                                'sf': sf,
                                                'time_points': np.array(time_points),
                                                'all_alpha_beta': np.array(all_alpha_beta)}))

            # Create the 2D histogram
            hist, xedges, yedges = np.histogram2d(time_points, all_alpha_beta, bins=50, range=[[0, 450], [0, 1.5]], density=True)
            ax.imshow(hist.T, origin='lower', aspect='auto', extent=(xedges[0], xedges[-1], yedges[0], yedges[-1]), cmap='viridis')

            # Axis labels and title
            if col == 0:
                ax.set_ylabel(f"{temp}K\nRel. frac. of α+β structure", fontsize=20)
            else:
                ax.tick_params(axis='y', which='both', left=True, right=False, labelleft=False)
            if row == nRows - 1:
                ax.set_xlabel("Time (ns)", fontsize=20)
            else:
                ax.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=False)

            ax.set_title(superfamily_labels[sf] if row == 0 else "", fontsize=21)
            ax.set_xlim(0, 450)
            ax.set_ylim(0, 1.5)

    result_dataset = pd.concat(result_dataset)
    result_dataset.to_csv(f"HeatMap_RSF_vs_TIME_{num_pdbs}Samples_4Superfamilies.csv")

    plt.tight_layout()
    plt.savefig(opj(output_dir, f"HeatMap_RSF_vs_TIME_{num_pdbs}Samples_4Superfamilies.png"), dpi=300, bbox_inches='tight')

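# Hedged usage sketch reproducing figure S1 on a subset (paths are placeholders):
# 100 random domains per superfamily, all five replicas and all five temperatures.
# with h5py.File("h5files/mdcath_analysis.h5", "r") as h5metrics:
#     plot_heatmap_ss_time_superfamilies(h5metrics, "figures", mean_across='all', num_pdbs=100)
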
def plot_ternary_superfamilies(h5metrics, output_dir, mean_across='all', temps=None, num_pdbs=None, cbar=False):
    ''' Use mpltern to plot alpha, beta and coil fractions for each superfamily at each temperature in a ternary
    plot, considering only the last frame of the trajectory.

    Params:
    - h5metrics:
        h5 file with the metrics of the dataset
    - output_dir:
        directory where to save the plots
    - mean_across:
        replica to consider; if 'all', all the replicas are considered
    - temps:
        temperatures to consider; if None, all the temperatures are considered
    - num_pdbs:
        number of pdbs to consider per superfamily
    - cbar:
        if True, a colorbar is added to the plot
    '''
    import mpltern  # registers the 'ternary' projection with matplotlib
    from matplotlib.ticker import LogFormatter

    np.random.seed(7)
    superfamily_labels = {1: 'Mainly Alpha', 2: 'Mainly Beta', 3: 'Mixed Alpha-Beta', 4: 'Few Secondary Structures'}
    super_family_json = json.load(open("/shared/antoniom/buildCATHDataset/support/cath_info.json", "r"))
    mdcath_dir = "/workspace8/antoniom/mdcath_htmd"

    temps = ['320', '348', '379', '413', '450'] if temps is None else temps
    replicas = get_replicas(mean_across)
    superfamilies = sorted({int(super_family_json[pdb]['superfamily_id'].split(".")[0]) for pdb in h5metrics.keys() if pdb in super_family_json.keys()})

    vmin, vmax = 1, 350  # shared color scale across all the hexbin panels
    # In order to avoid bias, the list of pdbs is shuffled if a subset is requested
    pdb_list = list(h5metrics.keys()) if num_pdbs is None else np.random.choice(list(h5metrics.keys()), len(h5metrics.keys()), replace=False)

    nRows = len(temps)
    nCols = len(superfamilies)

    # Setup figure and axes
    fig = plt.figure(figsize=(nCols * 5, nRows * 5))

    for row, temp in enumerate(temps):
        for col, sf in enumerate(superfamilies):
            accepted_superfamilies_domains = 0
            all_alpha = []
            all_beta = []
            all_coil = []

            for pdb in tqdm(pdb_list, total=len(pdb_list), desc=f"Ternary Plot for temp {temp} {superfamily_labels[sf]}"):
                if num_pdbs is not None and accepted_superfamilies_domains >= num_pdbs:
                    break  # exit the loop once the requested number of PDBs for the superfamily is reached
                super_family_id = int(super_family_json[pdb]['superfamily_id'].split(".")[0])
                if super_family_id != sf:
                    continue
                accepted_superfamilies_domains += 1

                with h5py.File(opj(mdcath_dir, pdb, f"mdcath_dataset_{pdb}.h5"), "r") as h5file:
                    for repl in replicas:
                        alpha_comp, beta_comp, coil_comp = get_secondary_structure_compositions(h5file[pdb][temp][repl]['dssp'])
                        all_alpha.append(alpha_comp)
                        all_beta.append(beta_comp)
                        all_coil.append(coil_comp)

            # ternary plot for the specific superfamily and temperature
            ax = plt.subplot(nRows, nCols, row * nCols + col + 1, projection='ternary')
            ax.set_tlabel('Alpha', fontsize=fsize - 2)
            ax.set_llabel('Beta', fontsize=fsize - 2)
            ax.set_rlabel('Coil/turn', fontsize=fsize - 2)

            t, l, r = np.array(all_alpha), np.array(all_beta), np.array(all_coil)
            # 'hb' avoids shadowing the built-in hex(); the 'Few Secondary Structures' class
            # has fewer points, so a coarser hexbin grid is used for it
            if sf == 4:
                hb = ax.hexbin(t, l, r, bins='log', edgecolors='face', cmap='viridis', gridsize=30, linewidths=0, vmin=vmin, vmax=vmax)
            else:
                hb = ax.hexbin(t, l, r, bins='log', edgecolors='face', cmap='viridis', gridsize=50, linewidths=0, vmin=vmin, vmax=vmax)

            if col == 0:
                ax.annotate(f"{temp}K", xy=(0.5, 0.5),
                            xytext=(-0.52, 0.5),
                            fontsize=fsize + 2,
                            ha='center',
                            va='center',
                            xycoords='axes fraction',
                            textcoords='axes fraction')
            if row == 0:
                ax.annotate(superfamily_labels[sf],
                            xy=(0.5, 0.5),
                            xytext=(0.5, 1.6),
                            fontsize=fsize + 2,
                            ha='center',
                            va='center',
                            xycoords='axes fraction',
                            textcoords='axes fraction')

    if cbar:
        # Place the color bar at the bottom of the plots
        fig.subplots_adjust(bottom=0.17, top=0.84, left=0.1, right=0.95, hspace=0.69, wspace=0.55)
        cax = fig.add_axes([0.1, 0.1, 0.8, 0.02])
        formatter = LogFormatter(10, labelOnlyBase=False)
        cb = fig.colorbar(hb, format=formatter, cax=cax, orientation='horizontal')
        cb.set_label('Counts', fontsize=22)
        cb.ax.tick_params(labelsize=20)
    else:
        plt.tight_layout()

    plt.savefig(opj(output_dir, f"ternary_plot_{(str(num_pdbs) + 'Samples_') if num_pdbs is not None else ''}4Superfamilies{'_cbar' if cbar else ''}.png"),
                dpi=600, bbox_inches='tight')
    plt.close()

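# Hedged usage sketch (placeholder paths): ternary plots with the shared log-scaled
# colorbar; mpltern must be installed for the 'ternary' projection to be available.
# with h5py.File("h5files/mdcath_analysis.h5", "r") as h5metrics:
#     plot_ternary_superfamilies(h5metrics, "figures", num_pdbs=100, cbar=True)
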
def plot_combine_metrics(h5metrics, output_dir):
    # Figure 3 of the paper
    labels = ['Number of Atoms', 'Number of Residues', 'Trajectory length (ns)', 'RMSD (nm)']
    data = {label: [] for label in labels}
    for pdb in tqdm(h5metrics.keys(), total=len(h5metrics.keys()), desc="Figure 3"):
        data['Number of Atoms'].append(h5metrics[pdb].attrs['numProteinAtoms'])
        data['Number of Residues'].append(h5metrics[pdb].attrs['numResidues'])
        for temp in h5metrics[pdb].keys():
            for repl in h5metrics[pdb][temp].keys():
                data['Trajectory length (ns)'].append(h5metrics[pdb][temp][repl].attrs['numFrames'])
                data['RMSD (nm)'].append(h5metrics[pdb][temp][repl]['rmsd'][-1])

    fig, axs = plt.subplots(2, 2, figsize=(12, 10))
    fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.85, hspace=0.4, wspace=0.4)
    axs = axs.flatten()
    letters = ['a', 'b', 'c', 'd']
    for i, label in enumerate(labels):
        axs[i].set_title(letters[i], loc='left', fontweight='bold')
        axs[i].hist(data[label], linewidth=1.2, bins=40, color='cornflowerblue', edgecolor='black')
        axs[i].set_xlabel(label)
        axs[i].set_ylabel("Counts")

    plt.tight_layout()
    plt.savefig(opj(output_dir, "dataset_info.png"), dpi=300)
    plt.close()

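# Minimal sketch of the h5metrics layout assumed by the plotting functions in this script
# (attribute names as used above; the metrics file path is a placeholder):
#   h5metrics[pdb].attrs             -> 'numProteinAtoms', 'numResidues', 'numNoHAtoms'
#   h5metrics[pdb][temp][repl].attrs -> 'numFrames', 'max_num_neighbors_5A', ...
#   h5metrics[pdb][temp][repl]       -> datasets such as 'rmsd'
def _example_dataset_summary(metrics_h5_path):
    """Illustrative only: count domains and trajectories in the analysis file."""
    with h5py.File(metrics_h5_path, "r") as h5metrics:
        n_domains = len(h5metrics.keys())
        n_trajectories = sum(len(h5metrics[pdb][temp].keys())
                             for pdb in h5metrics.keys()
                             for temp in h5metrics[pdb].keys())
    return n_domains, n_trajectories
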
def plot_maxNumNeighbors(h5metrics, output_dir, cutoff=('5A',)):
    ''' Plot the distribution of the maximum number of neighbors per replica, one panel per cutoff. '''
    data_dict = {}
    for c in cutoff:
        if c not in data_dict:
            data_dict[c] = []
        for pdb in tqdm(h5metrics.keys(), total=len(h5metrics.keys()), desc=f"MaxNumNeighbors {c}"):
            for temp in h5metrics[pdb].keys():
                for repl in h5metrics[pdb][temp].keys():
                    data_dict[c].append(h5metrics[pdb][temp][repl].attrs[f'max_num_neighbors_{c}'])

    f, axs = plt.subplots(1, len(cutoff), figsize=(len(cutoff) * 6, 5))
    axs = np.atleast_1d(axs)  # plt.subplots returns a bare Axes when len(cutoff) == 1
    for i, c in enumerate(cutoff):
        axs[i].hist(data_dict[c], bins=50, color='skyblue', edgecolor='black', linewidth=1.2)
        axs[i].set_xlabel("Max number of neighbors per replica")
        axs[i].set_title(f"Cutoff {c}")
        axs[i].set_ylabel("Counts")
        axs[i].set_yscale('log')

    plt.tight_layout()
    plt.savefig(opj(output_dir, "maxNumNeighbors.png"), dpi=600)

def scatterplot_maxNumNeighbors_numNoHAtoms(h5metrics, output_dir, cutoff=('5A', '9A')):
    """ Scatter the maximum number of neighbors per replica against the number of heavy atoms in the protein, one panel per cutoff. """
    cutoff_neighbors_results = {}
    num_heavy_atoms = []
    for pdb in tqdm(h5metrics.keys(), total=len(h5metrics.keys()), desc="MaxNumNeighbors vs numNoHAtoms"):
        counter_ = 0
        numNoHAtoms = h5metrics[pdb].attrs['numNoHAtoms']  # per-domain attribute
        for temp in h5metrics[pdb].keys():
            for repl in h5metrics[pdb][temp].keys():
                counter_ += 1
                for c in cutoff:
                    if c not in cutoff_neighbors_results:
                        cutoff_neighbors_results[c] = []
                    cutoff_neighbors_results[c].append(h5metrics[pdb][temp][repl].attrs[f'max_num_neighbors_{c}'])

        # one x-value per (temperature, replica) pair of this domain
        num_heavy_atoms.extend([numNoHAtoms] * counter_)

    f, axs = plt.subplots(1, len(cutoff), figsize=(len(cutoff) * 6, 5), sharey=False)
    axs = np.atleast_1d(axs)
    for i, c in enumerate(cutoff):
        axs[i].scatter(num_heavy_atoms, cutoff_neighbors_results[c], color='dodgerblue', s=0.8)
        axs[i].set_ylabel("Max number of neighbors per replica")
        axs[i].set_title(f"Cutoff {c}")
        axs[i].set_xlabel("Number of heavy atoms")

    plt.subplots_adjust(hspace=0.75, wspace=0.25)
    plt.savefig(opj(output_dir, "maxNumNeighbors_numNoHAtoms.png"), dpi=600)

def plot_numNoHAtoms(h5metrics, output_dir):
    ''' Plot the distribution of the number of heavy atoms per protein in the dataset. '''
    numNoHAtoms = []
    for pdb in tqdm(h5metrics.keys(), total=len(h5metrics.keys()), desc="numNoHAtoms"):
        numNoHAtoms.append(h5metrics[pdb].attrs['numNoHAtoms'])

    plt.figure(figsize=(6, 5))
    plt.hist(numNoHAtoms, bins=50)
    plt.xlabel("Number of heavy atoms")
    plt.ylabel("Counts")
    plt.tight_layout()
    plt.savefig(opj(output_dir, "numNoHAtoms.png"), dpi=300)
--------------------------------------------------------------------------------