├── MutComputeX ├── __init__.py ├── data_loader.py ├── inference.py ├── post_process.py └── model.py ├── .gitignore ├── requirements.txt ├── models └── download_models.sh ├── data └── norbelladine_4OMTase │ └── boxes │ └── download_dataset.sh ├── Dockerfile ├── scripts ├── generate_norbelladine_predictions.py ├── generate_norbelladine_predictions_docker.py └── generate_predictions.py ├── LICENSE.md └── README.md /MutComputeX/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | models/ 2 | data/ 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black==23.1.0 2 | numpy==1.22.4 3 | pandas==1.4.2 4 | tensorflow-rocm==2.9.1 -------------------------------------------------------------------------------- /models/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | aws s3 cp --no-sign-request s3://mutcomputex/models . --recursive 4 | -------------------------------------------------------------------------------- /data/norbelladine_4OMTase/boxes/download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | aws s3 cp --no-sign-request s3://mutcomputex/data/norbelladine_4OMTase/boxes/4OMTase_dataset.pkl . 
4 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM rocm/tensorflow:rocm5.6-tf2.12-dev 2 | 3 | WORKDIR /deps 4 | 5 | RUN pip install pandas==1.4.2 && \ 6 | curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" && \ 7 | unzip awscliv2.zip && \ 8 | ./aws/install 9 | 10 | WORKDIR /models 11 | RUN aws s3 cp s3://mutcomputex/models . --recursive --no-sign-request 12 | 13 | COPY scripts/generate_norbelladine_predictions.py /scripts/generate_norbelladine_predictions.py 14 | COPY MutComputeX /opt/MutComputeX/MutComputeX 15 | ENV PYTHONPATH=/opt/MutComputeX 16 | WORKDIR /scripts -------------------------------------------------------------------------------- /scripts/generate_norbelladine_predictions.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from os import environ 3 | from pprint import pprint 4 | 5 | from generate_predictions import generate_inference 6 | 7 | 8 | if __name__ == "__main__": 9 | model_dir = Path("../models") 10 | data = Path("../data/norbelladine_4OMTase/boxes/4OMTase_dataset.pkl") 11 | out_file = None 12 | 13 | model_glob = "*" 14 | use_cpu = True 15 | 16 | if use_cpu: 17 | environ["CUDA_VISIBLE_DEVICES"] = "" 18 | 19 | else: 20 | environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" 21 | environ["TF_CPP_MIN_LOG_LEVEL"] = "5" 22 | 23 | models = [m_dir.resolve() for m_dir in model_dir.glob(model_glob) if m_dir.is_dir()] 24 | 25 | print(f"\nSelected model directories:") 26 | pprint(models) 27 | 28 | generate_inference(models, data, out_file) 29 | -------------------------------------------------------------------------------- /MutComputeX/data_loader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Dict, Tuple, Union 3 | import pickle as 
pkl 4 | 5 | import numpy as np 6 | 7 | 8 | def load_dataset(pkl_file: Path) -> Tuple[List[Dict[str, Union[str, int]]], np.ndarray]: 9 | """Snapshots provide the residue order of the boxes""" 10 | 11 | assert isinstance(pkl_file, Path) 12 | assert pkl_file.is_file(), pkl_file.resolve() 13 | assert pkl_file.suffix == ".pkl", pkl_file.suffix 14 | 15 | with pkl_file.open("rb") as f: 16 | protein_data = pkl.load(f) 17 | 18 | assert "snapshots" in protein_data.keys(), protein_data.keys() 19 | assert "boxes" in protein_data.keys(), protein_data.keys() 20 | 21 | snapshots = protein_data["snapshots"] 22 | boxes = protein_data["boxes"] 23 | 24 | print(f"Loaded data: {boxes.shape}") 25 | 26 | return snapshots, boxes 27 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Danny Diaz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
"""
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/scripts/generate_norbelladine_predictions_docker.py:
--------------------------------------------------------------------------------
"""
from subprocess import run
from argparse import ArgumentParser
from pathlib import Path
import os


def cli():
    """Parse command-line options for the dockerized prediction run."""
    cwd = os.getcwd()

    parser = ArgumentParser()
    parser.add_argument(
        "--data",
        default=Path(f"{cwd}/../data/norbelladine_4OMTase/boxes/4OMTase_dataset.pkl"),
        type=Path,
    )
    parser.add_argument(
        "--model-dir",
        default=Path("../models"),
        type=Path,
        help="directory where model checkpoints are located",
    )
    parser.add_argument(
        "--model-glob",
        type=str,
        default="*",
        help="glob to select specific models/directories in the model-dir folder",
    )
    # type=Path added for consistency with generate_predictions.py: main()
    # calls .parent/.name on out_file, which requires a Path, not a str.
    parser.add_argument("--out-file", default=None, type=Path)
    parser.add_argument("--use-cpu", action="store_true")
    return parser.parse_args()


def build_command(args) -> str:
    """Assemble the ``docker run`` command string for the prediction container.

    Args:
        args: parsed CLI namespace with ``data``, ``model_dir``, ``out_file``
            (Path or None) and ``use_cpu`` attributes.

    Returns:
        The full shell command as a single string.

    Bug fix vs. the original: the old code did ``cmd += (str1, str2)`` —
    concatenating a *tuple* onto a string — which raised
    ``TypeError: can only concatenate str (not "tuple") to str`` before
    docker was ever invoked. It also omitted the space after
    ``--model-dir /models``, which would have fused it with the following
    ``--out-file`` flag.
    """
    # Bind-mount the input pickle into the container at a fixed path.
    cmd = f"docker run -v {args.data.resolve()}:/input/input_file.pkl "

    cmd += f"-v {args.model_dir}:/models "

    if args.out_file:
        # Mount the output directory so the CSV survives the container.
        cmd += f"-v {args.out_file.parent.resolve()}:/output "

    cmd += (
        "-t mutcomputex:latest python generate_predictions.py "
        "--data /input/input_file.pkl "
        "--model-dir /models "
    )

    if args.out_file:
        cmd += f"--out-file /output/{args.out_file.name} "

    if args.use_cpu:
        cmd += "--use-cpu"

    return cmd


def main(args):
    """Build and execute the docker command (shell=True: cmd is one string)."""
    run(build_command(args), shell=True)


if __name__ == "__main__":
    args = cli()
    main(args)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MutComputeX: A self-supervised 3D-Residual Neural Network for protein-X interface engineering

## X = nucleic acids, glycans, ligands, cofactors, other proteins


## To generate predictions on Norbelladine 4O-methyltransferase:
The models can be executed either natively, or from within a docker container. Regardless, the Norbelladine data must be downloaded from an s3 bucket.
### Download dataset
The following command, when run from the root directory of this repo, will download the dataset to the appropriate location.

`$ cd data/norbelladine_4OMTase/boxes/ && ./download_dataset.sh`

### Run the Model (native)
All of the following commands must be run from the root directory of this repository.

To run the model natively, you must first download the model with the following command:

`$ cd models && ./download_models.sh`

Then you must set the PYTHONPATH to the root directory of the repository with the following command:

`$ export PYTHONPATH=$(pwd)`

Finally, you can run the inference:

`$ cd scripts && python generate_norbelladine_predictions.py`

### Run the Model (docker)
The model and its dependencies are all bundled in the docker image. To run the model within a docker container, first build the docker image with the following command:

`$ docker build -t mutcomputex:latest .`

Then run the following script:

`$ cd scripts && python generate_norbelladine_predictions_docker.py`

## System Requirements

### Hardware Requirements
Models were trained using AMD GPUs (MI50s) with tensorflow-rocm >= 2.9.x using 'channel first'.
Channel first tensorflow models can only run on GPUs.
Thus, an AMD GPU is required to genereate inferences. 43 | 44 | ### Software requirements 45 | This package has been tested on Ubuntu 18.04 and 20.04 and requires: 46 | - python >= 3.7.x 47 | - ROCM >= 5.1.x 48 | - tensorflow-rocm >= 2.9.x 49 | - pandas >= 1.4.x 50 | - AWS cli >= 2.9.x (download data and models) 51 | 52 | #### install requirements: 53 | `$ pip install -r requirements.txt` 54 | 55 | -------------------------------------------------------------------------------- /scripts/generate_predictions.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | from os import environ 4 | from argparse import ArgumentParser 5 | from pprint import pprint 6 | 7 | from MutComputeX.inference import EnsembleCIFPredictor 8 | 9 | 10 | def cli(): 11 | parser = ArgumentParser() 12 | parser.add_argument( 13 | "--data", 14 | required=True, 15 | type=Path, 16 | help="Pickle file with serialized microenvironments and snapshots", 17 | ) 18 | parser.add_argument( 19 | "--model-dir", 20 | default=Path("../models"), 21 | type=Path, 22 | help="directory where model checkpoints are located", 23 | ) 24 | 25 | parser.add_argument( 26 | "--model-glob", 27 | type=str, 28 | default="*", 29 | help="glob to select specific models/directories in the model-dir folder", 30 | ) 31 | parser.add_argument("--out-file", default=None, type=Path) 32 | parser.add_argument("--use-cpu", action="store_true") 33 | 34 | args = parser.parse_args() 35 | 36 | assert args.model_dir.is_dir() 37 | assert args.data.is_file() 38 | assert ( 39 | args.data.suffix == ".pkl" 40 | ), f"{args.data.resolve()} must be a pickle file (.pkl)" 41 | 42 | return args 43 | 44 | 45 | def generate_inference(models: List[Path], data: Path, out_file: Path = None): 46 | """ 47 | Generate mutation inferences with: 48 | - tf trained models 49 | - serialized snapshot/boxes pickle file 50 | """ 51 | 52 | predictor = EnsembleCIFPredictor(models, [0]) 
53 | 54 | predictor.predict(data) 55 | 56 | predictor.to_csv(out_file) 57 | 58 | 59 | if __name__ == "__main__": 60 | args = cli() 61 | 62 | if args.use_cpu: 63 | environ["CUDA_VISIBLE_DEVICES"] = "" 64 | 65 | else: 66 | environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" 67 | environ["TF_CPP_MIN_LOG_LEVEL"] = "5" 68 | 69 | models = [ 70 | m_dir.resolve() 71 | for m_dir in args.model_dir.glob(args.model_glob) 72 | if m_dir.is_dir() 73 | ] 74 | 75 | print(f"cli options:") 76 | pprint(vars(args)) 77 | print(f"\nSelected model directories:") 78 | pprint(models) 79 | 80 | generate_inference(models, args.data, args.out_file) 81 | -------------------------------------------------------------------------------- /MutComputeX/inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Set, List 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | import tensorflow.keras as K 9 | from tensorflow.keras.models import Model 10 | 11 | from MutComputeX.data_loader import load_dataset 12 | from MutComputeX.post_process import ( 13 | prediction_df, 14 | average_predictions, 15 | prediction_accuracy, 16 | ) 17 | 18 | 19 | def load_model(model_dir: Path) -> Model: 20 | assert isinstance(model_dir, Path) 21 | assert model_dir.is_dir() 22 | 23 | return K.models.load_model(str(model_dir)) 24 | 25 | 26 | def load_ensemble_models(model_dirs: List[Path]) -> List[Model]: 27 | assert all([isinstance(m_dir, Path) and m_dir.is_dir() for m_dir in model_dirs]) 28 | 29 | ens_model = [load_model(m_dir) for m_dir in model_dirs] 30 | 31 | return ens_model 32 | 33 | 34 | class EnsembleCIFPredictor: 35 | def __init__( 36 | self, 37 | model_dirs: List[Path], 38 | gpu_idx: Set[int] = [0], 39 | ensemble_name: str = "EnsResNet", 40 | ): 41 | assert all([isinstance(model_dir, Path) for model_dir in model_dirs]) 42 | for model in model_dirs: 43 | assert model.is_dir(), model.resolve() 44 | 
assert all([model_dir.is_dir() for model_dir in model_dirs]) 45 | 46 | self.gpu_idx = sorted(set(gpu_idx)) 47 | self.gpu_names = [f"/GPU:{idx}" for idx in self.gpu_idx] 48 | self.strategy = tf.distribute.MirroredStrategy(devices=self.gpu_names) 49 | 50 | self.models = [] 51 | for model_dir in model_dirs: 52 | with self.strategy.scope(): 53 | model = load_model(model_dir) 54 | model.model_name = f"{model_dir.stem}" 55 | 56 | self.models.append(model) 57 | 58 | self.model_dirs = model_dirs 59 | self.model_name = ensemble_name 60 | 61 | self.snapshots = None 62 | self.predictions = None 63 | 64 | 65 | def predict(self, pkl_data: Path) -> pd.DataFrame: 66 | assert isinstance(pkl_data, Path) 67 | assert pkl_data.is_file() 68 | assert pkl_data.suffix == ".pkl" 69 | 70 | self.dataset = pkl_data 71 | snapshots, boxes = load_dataset(pkl_data) 72 | 73 | self.model_predictions = {} 74 | for model in self.models: 75 | print(f"Generating predictions with model - {model.model_name}") 76 | 77 | with self.strategy.scope(): 78 | predictions = prediction_df(snapshots, model.predict(boxes, verbose=1)) 79 | self.model_predictions[model.model_name] = predictions 80 | 81 | self.predictions = average_predictions(self.model_predictions.values()) 82 | 83 | return self.predictions 84 | 85 | 86 | def to_csv(self, out_file: Path = None, include_model_name=True) -> Path: 87 | assert self.predictions is not None 88 | assert self.dataset is not None 89 | 90 | predictions = self.predictions.copy() 91 | 92 | accuracy = prediction_accuracy(predictions) 93 | 94 | if out_file is None: 95 | out_dir = self.dataset.parent.parent / "predictions" 96 | out_dir.mkdir(0o770, parents=True, exist_ok=True) 97 | out_file = out_dir / f"{self.dataset.stem}_predictions.csv" 98 | 99 | out_file = out_file.with_suffix(".csv") 100 | 101 | if include_model_name: 102 | predictions = predictions.assign(model=self.model_name, accuracy=accuracy) 103 | model_col = predictions.pop("model") 104 | accuracy_col = 
predictions.pop("accuracy") 105 | predictions.insert(0, "model", model_col) 106 | predictions.insert(1, "accuracy", accuracy_col) 107 | 108 | else: 109 | predictions = predictions.assign(accuracy=accuracy) 110 | accuracy_col = predictions.pop("accuracy") 111 | predictions.insert(0, "accuracy", accuracy_col) 112 | 113 | predictions.to_csv(out_file, index=False) 114 | 115 | print(f"Wrote predictions: {out_file.resolve()}") 116 | 117 | return out_file 118 | -------------------------------------------------------------------------------- /MutComputeX/post_process.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Iterable 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | RESIDUES = [ 7 | "ALA", 8 | "ARG", 9 | "ASN", 10 | "ASP", 11 | "CYS", 12 | "GLN", 13 | "GLU", 14 | "GLY", 15 | "HIS", 16 | "ILE", 17 | "LEU", 18 | "LYS", 19 | "MET", 20 | "PHE", 21 | "PRO", 22 | "SER", 23 | "THR", 24 | "TRP", 25 | "TYR", 26 | "VAL", 27 | ] 28 | 29 | 30 | PREDICTION_AA_COL_NAMES = [ 31 | "prALA", 32 | "prARG", 33 | "prASN", 34 | "prASP", 35 | "prCYS", 36 | "prGLN", 37 | "prGLU", 38 | "prGLY", 39 | "prHIS", 40 | "prILE", 41 | "prLEU", 42 | "prLYS", 43 | "prMET", 44 | "prPHE", 45 | "prPRO", 46 | "prSER", 47 | "prTHR", 48 | "prTRP", 49 | "prTYR", 50 | "prVAL", 51 | ] 52 | 53 | PREDICTION_HEADER = [ 54 | "pdb_id", 55 | "chain_id", 56 | "pos", 57 | "wtAA", 58 | "prAA", 59 | "wt_prob", 60 | "pred_prob", 61 | "avg_log_ratio", 62 | *PREDICTION_AA_COL_NAMES, 63 | ] 64 | 65 | 66 | def prediction_df( 67 | snapshots: Iterable[dict], predictions: Iterable[Tuple[int]] 68 | ) -> pd.DataFrame: 69 | assert all([isinstance(ss, dict) for ss in snapshots]), type(list(snapshots)[0]) 70 | 71 | res_dict = {r: i for i, r in enumerate(RESIDUES)} 72 | 73 | rows = [] 74 | for s, p in zip(snapshots, predictions): 75 | idx = np.argmax(p) 76 | wt_aa = s["label"] 77 | wt_pr = p[res_dict[wt_aa]] 78 | pred_prob = p[idx] 79 | pr_aa = 
RESIDUES[idx] 80 | 81 | # Check if wt_pred is 0 so you dont divide by 0 82 | if wt_pr < np.finfo(float).eps: 83 | wt_pr = np.finfo(float).eps 84 | 85 | log_rat = np.log2(pred_prob / wt_pr) 86 | chain_id = s["chain_id"] 87 | pos = s["res_seq_num"] 88 | 89 | if s["type"] == "FILE_CHAIN_RESIDUE": 90 | source = s["filename"] 91 | 92 | else: 93 | source = "" 94 | 95 | row = [source, chain_id, pos, wt_aa, pr_aa, wt_pr, pred_prob, log_rat] 96 | row.extend([aa_prob for aa_prob in p]) 97 | 98 | rows.append(row) 99 | 100 | return pd.DataFrame(rows, columns=PREDICTION_HEADER).sort_values( 101 | PREDICTION_HEADER[:3] 102 | ) 103 | 104 | 105 | def concat_predictions(DFs: Iterable[pd.DataFrame]) -> pd.DataFrame: 106 | return pd.concat(DFs, axis=0, ignore_index=True).sort_values( 107 | ["pdb_id", "chain_id", "pos", "wtAA"], ascending=[True, True, True, True] 108 | ) 109 | 110 | 111 | def find_predAA(row: pd.Series) -> Tuple[str, float]: 112 | assert isinstance(row, pd.Series) 113 | 114 | idx = np.argmax(row) 115 | pred_AA, pred_prob = row.index[idx].replace("pr", ""), row.iloc[idx] 116 | 117 | return pred_AA, pred_prob 118 | 119 | 120 | def update_wt_prob(row: pd.Series) -> float: 121 | assert isinstance(row, pd.Series) 122 | 123 | wt_col = f"pr{row['wtAA']}" 124 | 125 | return row[wt_col] 126 | 127 | 128 | def calc_log_odds(row: pd.Series) -> float: 129 | assert isinstance(row, pd.Series) 130 | assert "pred_prob" in row.index.tolist() 131 | assert "wt_prob" in row.index.tolist() 132 | 133 | return abs(round(np.log2(row["pred_prob"] / row["wt_prob"]), 4)) 134 | 135 | 136 | def average_predictions(prediction_dfs: List[pd.DataFrame]) -> pd.DataFrame: 137 | df = concat_predictions(prediction_dfs) 138 | 139 | groupby_cols = ['pdb_id', 'chain_id', 'pos', 'wtAA'] 140 | 141 | df = df[[*groupby_cols, *PREDICTION_AA_COL_NAMES]] 142 | 143 | avg_df = ( 144 | df.groupby(groupby_cols, as_index=False) 145 | .mean() 146 | .round(6) 147 | ) 148 | 149 | avg_df["wt_prob"] = 
avg_df.apply(update_wt_prob, axis=1) 150 | 151 | avg_df[["prAA", "pred_prob"]] = avg_df[PREDICTION_AA_COL_NAMES].apply( 152 | find_predAA, axis=1, result_type="expand" 153 | ) 154 | 155 | avg_df["avg_log_ratio"] = avg_df[["wt_prob", "pred_prob"]].apply( 156 | calc_log_odds, axis=1 157 | ) 158 | 159 | avg_df = avg_df[PREDICTION_HEADER] 160 | 161 | return avg_df 162 | 163 | 164 | def prediction_accuracy( 165 | prediction_df: pd.DataFrame, wt_col: str = "wtAA", pr_col: str = "prAA" 166 | ) -> float: 167 | assert wt_col in prediction_df.columns 168 | assert pr_col in prediction_df.columns 169 | 170 | total = len(prediction_df) 171 | 172 | correct = 0 173 | for idx, row in prediction_df.iterrows(): 174 | if row[wt_col] == row[pr_col]: 175 | correct += 1 176 | 177 | return round(correct / total, 5) 178 | -------------------------------------------------------------------------------- /MutComputeX/model.py: -------------------------------------------------------------------------------- 1 | 2 | from tensorflow.keras import activations 3 | from tensorflow.keras.initializers import HeNormal 4 | from tensorflow.keras.models import Model 5 | from tensorflow.keras.layers import ( 6 | Input, 7 | Dense, 8 | Flatten, 9 | BatchNormalization, 10 | Conv3D, 11 | Activation, 12 | Add, 13 | MaxPool3D, 14 | ) 15 | from tensorflow.keras.regularizers import l2 16 | 17 | 18 | def res_identity(x, f1, f2, base_name, data_format="channels_first"): 19 | 20 | x_skip = x 21 | 22 | x = Conv3D( 23 | filters=f1, 24 | kernel_size=(1, 1, 1), 25 | kernel_regularizer=l2(0.001), 26 | padding="valid", 27 | strides=(1, 1, 1), 28 | data_format=data_format, 29 | name="Conv-identity-one-" + base_name, 30 | )(x) 31 | x = BatchNormalization()(x) 32 | x = Activation(activations.relu)(x) 33 | 34 | x = Conv3D( 35 | filters=f1, 36 | kernel_size=(3, 3, 3), 37 | kernel_regularizer=l2(0.001), 38 | padding="same", 39 | strides=(1, 1, 1), 40 | data_format=data_format, 41 | name="Conv-identity-two-" + base_name, 42 | 
)(x) 43 | x = BatchNormalization()(x) 44 | x = Activation(activations.relu)(x) 45 | 46 | x = Conv3D( 47 | filters=f2, 48 | kernel_size=(1, 1, 1), 49 | kernel_regularizer=l2(0.001), 50 | padding="valid", 51 | strides=(1, 1, 1), 52 | data_format=data_format, 53 | name="Conv-identity-three-" + base_name, 54 | )(x) 55 | x = BatchNormalization()(x) 56 | 57 | x_skip = Conv3D( 58 | filters=f2, 59 | kernel_size=(1, 1, 1), 60 | kernel_regularizer=l2(0.001), 61 | padding="valid", 62 | strides=(1, 1, 1), 63 | data_format=data_format, 64 | name="Conv-identity-skip-" + base_name, 65 | )(x_skip) 66 | x_skip = BatchNormalization()(x_skip) 67 | 68 | x = Add()([x, x_skip]) 69 | x = Activation(activations.relu)(x) 70 | 71 | return x 72 | 73 | 74 | def res_conv(x, s, f1, f2, base_name, data_format="channels_first"): 75 | 76 | x_skip = x 77 | 78 | x = Conv3D( 79 | filters=f1, 80 | kernel_size=(1, 1, 1), 81 | kernel_regularizer=l2(0.001), 82 | padding="valid", 83 | strides=(s, s, s), 84 | data_format=data_format, 85 | name="Conv-redux-one-" + base_name, 86 | )(x) 87 | x = BatchNormalization()(x) 88 | x = Activation(activations.relu)(x) 89 | 90 | x = Conv3D( 91 | filters=f1, 92 | kernel_size=(3, 3, 3), 93 | kernel_regularizer=l2(0.001), 94 | padding="same", 95 | strides=(1, 1, 1), 96 | data_format=data_format, 97 | name="Conv-redux-two-" + base_name, 98 | )(x) 99 | x = BatchNormalization()(x) 100 | x = Activation(activations.relu)(x) 101 | 102 | x = Conv3D( 103 | filters=f2, 104 | kernel_size=(1, 1, 1), 105 | kernel_regularizer=l2(0.001), 106 | padding="valid", 107 | strides=(1, 1, 1), 108 | data_format=data_format, 109 | name="Conv-redux-three-" + base_name, 110 | )(x) 111 | x = BatchNormalization()(x) 112 | 113 | x_skip = Conv3D( 114 | filters=f2, 115 | kernel_size=(1, 1, 1), 116 | kernel_regularizer=l2(0.001), 117 | padding="valid", 118 | strides=(s, s, s), 119 | data_format=data_format, 120 | name="Conv-redux-skip-" + base_name, 121 | )(x_skip) 122 | x_skip = 
BatchNormalization()(x_skip) 123 | 124 | x = Add()([x, x_skip]) 125 | x = Activation(activations.relu)(x) 126 | 127 | return x 128 | 129 | 130 | def create_resnet_model(input_shape, data_format="channels_first"): 131 | 132 | m_input = Input(shape=input_shape) 133 | 134 | x = BatchNormalization(axis=[1, 2, 3, 4])(m_input) 135 | 136 | x = res_identity(x, 50, 50, "l1-1", data_format) 137 | x = res_identity(x, 50, 50, "l1-2", data_format) 138 | x = res_identity(x, 50, 50, "l1-3", data_format) 139 | 140 | x = res_conv(x, 2, 50, 100, "l2-1", data_format) 141 | x = res_identity(x, 100, 100, "l2-2", data_format) 142 | x = res_identity(x, 100, 100, "l2-3", data_format) 143 | 144 | x = res_conv(x, 2, 100, 200, "l3-1", data_format) 145 | x = res_identity(x, 200, 200, "l3-2", data_format) 146 | x = res_identity(x, 200, 200, "l3-3", data_format) 147 | 148 | x = res_conv(x, 2, 200, 400, "l4-1", data_format) 149 | x = res_identity(x, 400, 400, "l4-2", data_format) 150 | x = res_identity(x, 400, 400, "l4-3", data_format) 151 | 152 | x = MaxPool3D(data_format=data_format)(x) 153 | 154 | x = Flatten()(x) 155 | x = Dense(1000, kernel_initializer=HeNormal())(x) 156 | x = Activation(activations.relu)(x) 157 | x = Dense(20)(x) 158 | x = Activation(activations.softmax)(x) 159 | 160 | model = Model(inputs=m_input, outputs=x, name="MutComputeX") 161 | 162 | return model 163 | --------------------------------------------------------------------------------