├── TBiNet_overview.png ├── model └── best │ └── readme.md ├── data └── readme.md ├── readme.md ├── test.ipynb └── train.ipynb /TBiNet_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/tbinet/HEAD/TBiNet_overview.png -------------------------------------------------------------------------------- /model/best/readme.md: -------------------------------------------------------------------------------- 1 | Click [here](https://drive.google.com/open?id=16bDDb9N3NOCngERRxfR9eJDvZAbs1pBl) to download the best model of TBiNet, and place it in this folder as `tbinet_best.hdf5` (the path that `test.ipynb` loads).
2 | -------------------------------------------------------------------------------- /data/readme.md: -------------------------------------------------------------------------------- 1 | After downloading the ChIP-seq data from http://deepsea.princeton.edu/media/code/deepsea_train_bundle.v0.9.tar.gz,
2 | place the 'train.mat', 'valid.mat' and 'test.mat' files from the extracted 'deepsea_train' folder into this folder. 3 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # TBiNet: A deep neural network for predicting transcription factor binding sites using attention mechanism 2 | TBiNet is an attention-based neural network that predicts transcription factor-DNA binding in a given DNA sequence. 3 | 4 | Paper information: Park, S., Koh, Y., Jeon, H. et al. "[Enhancing the interpretability of transcription factor binding site prediction using attention mechanism](https://www.nature.com/articles/s41598-020-70218-4)", Scientific Reports (2020). 5 | 6 | - Overview of TBiNet 7 | ![model image](TBiNet_overview.png) 8 | 9 | ## Requirements 10 | - Python (version 3.6.6; installing Anaconda3 is recommended) 11 | - Numpy (version 1.14.6) 12 | - H5py (version 2.8.0) 13 | - Scipy (version 1.1.0) 14 | - Sklearn (version 0.20.1) 15 | - Theano (version 1.0.3) 16 | - Keras (version 2.2.4, backend:theano) 17 | 18 | ## Usage 19 | ### Data 20 | The ChIP-seq data used in this work can be downloaded from [the DeepSEA training bundle](http://deepsea.princeton.edu/media/code/deepsea_train_bundle.v0.9.tar.gz); see `data/readme.md` for where to place the files. 21 | 22 | ### Training TBiNet 23 | `train.ipynb` 24 | 25 | ### Testing TBiNet 26 | `test.ipynb` 27 | 28 | ## Contact information 29 | For help or issues using TBiNet, please submit a GitHub issue. Please contact Sungjoon Park (sungjoonopark@korea.ac.kr) for communication related to TBiNet.
30 | -------------------------------------------------------------------------------- /test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import scipy.io\n", 11 | "from sklearn import metrics\n", 12 | "import pandas as pd\n", 13 | "import os\n", 14 | "os.environ['THEANO_FLAGS'] = \"device=cuda0,force_device=True,floatX=float32\"\n", 15 | "import theano\n", 16 | "print(theano.config.device)\n", 17 | "# THEANO_FLAGS is set above, before theano and keras are imported, so that the flags take effect\n", 18 | "from keras.layers import Embedding\n", 19 | "from keras.models import Sequential\n", 20 | "from keras.models import Model\n", 21 | "from keras.layers import Dense, Dropout, Activation, Flatten, Layer, merge, Input, Concatenate, Reshape\n", 22 | "from keras.layers.convolutional import Conv1D, MaxPooling1D\n", 23 | "from keras.layers.pooling import GlobalMaxPooling1D\n", 24 | "from keras.layers.recurrent import LSTM\n", 25 | "from keras.layers.wrappers import Bidirectional, TimeDistributed\n", 26 | "from keras.models import load_model\n", 27 | "from keras.callbacks import ModelCheckpoint, EarlyStopping" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 13, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "def get_auroc(preds, obs):  # area under the ROC curve for one task, keeping every threshold\n", 37 | " fpr, tpr, thresholds = metrics.roc_curve(obs, preds, drop_intermediate=False)\n", 38 | " auroc = metrics.auc(fpr,tpr)\n", 39 | " return auroc\n", 40 | "\n", 41 | "def get_aupr(preds, obs):  # area under the precision-recall curve for one task (trapezoidal rule via metrics.auc)\n", 42 | " precision, recall, thresholds = metrics.precision_recall_curve(obs, preds)\n", 43 | " aupr = metrics.auc(recall,precision)\n", 44 | " return aupr\n", 45 | "\n", 46 | "def get_aurocs_and_auprs(tpreds, tobs):  # per-column AUROC and AUPR lists, each rounded to 5 decimals\n", 47 | " tpreds_df = pd.DataFrame(tpreds)\n", 48 | " tobs_df = pd.DataFrame(tobs)\n", 49 | " \n", 50 | " # a column containing a single class may score NaN; the caller averages with np.nanmean\n", 51 | " auroc_list = []\n", 52 | " aupr_list = []\n",
53 | " for task in tpreds_df:\n", 54 | " pred = tpreds_df[task]\n", 55 | " obs = tobs_df[task]\n", 56 | " auroc=round(get_auroc(pred,obs),5)\n", 57 | " aupr = round(get_aupr(pred,obs),5)\n", 58 | " # collect the rounded scores for this task column\n", 59 | " auroc_list.append(auroc)\n", 60 | " aupr_list.append(aupr)\n", 61 | " return auroc_list, aupr_list" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "### Load data (test)\n" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 14, 74 | "metadata": { 75 | "scrolled": true 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "data_folder = \"./data/\"\n", 80 | "# test.mat from the DeepSEA training bundle is expected in ./data/ (see data/readme.md)\n", 81 | "testmat = scipy.io.loadmat(data_folder+'test.mat')" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "### Load model" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "model = load_model(\"./model/best/tbinet_best.hdf5\")  # download link is in model/best/readme.md\n", 98 | "print('model summary')\n", 99 | "model.summary()" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "### Calculate averaged AUROC and AUPR" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "tpreds = model.predict(np.transpose(testmat['testxdata'],axes=(0,2,1)),verbose=1)  # presumably (N, 4, 1000) to (N, 1000, 4) to match Input(shape=(1000,4)) in train.ipynb - TODO confirm\n", 116 | "tpreds_temp = np.copy(tpreds)\n", 117 | "reverse_start_id = testmat['testdata'].shape[0] // 2  # number of rows in the first half of the test set\n", 118 | "# NOTE(review): row i is averaged with row reverse_start_id+i - presumably a sequence and its reverse complement; TODO confirm\n", 119 | "for i in range(reverse_start_id):\n", 120 | " tpreds_avg_temp = (tpreds_temp[i] + tpreds_temp[reverse_start_id+i])/2.0\n", 121 | " tpreds_temp[i] = tpreds_avg_temp\n", 122 | " tpreds_temp[reverse_start_id+i] = tpreds_avg_temp\n", 123 | "# columns 125:815 are the 690 TF-binding targets (the same slice train.ipynb applies to y_train)\n", 124 | "\n", 125 | "aurocs, auprs = get_aurocs_and_auprs(tpreds_temp,testmat['testdata'][:,125:815])\n", 126 | "print(\"Averaged
AUROC:\",np.nanmean(aurocs))\n", 127 | "print(\"Averaged AUPR:\", np.nanmean(auprs))" 128 | ] 129 | } 130 | ], 131 | "metadata": { 132 | "kernelspec": { 133 | "display_name": "Python 3", 134 | "language": "python", 135 | "name": "python3" 136 | }, 137 | "language_info": { 138 | "codemirror_mode": { 139 | "name": "ipython", 140 | "version": 3 141 | }, 142 | "file_extension": ".py", 143 | "mimetype": "text/x-python", 144 | "name": "python", 145 | "nbconvert_exporter": "python", 146 | "pygments_lexer": "ipython3", 147 | "version": "3.6.8" 148 | } 149 | }, 150 | "nbformat": 4, 151 | "nbformat_minor": 2 152 | } 153 | -------------------------------------------------------------------------------- /train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import h5py\n", 11 | "import scipy.io\n", 12 | "from sklearn import metrics\n", 13 | "import pandas as pd\n", 14 | "import os\n", 15 | "os.environ['THEANO_FLAGS'] = \"device=cuda0,force_device=True,floatX=float32,gpuarray.preallocate=0.3\"\n", 16 | "import theano\n", 17 | "print(theano.config.device)\n", 18 | "from keras.layers import Embedding\n", 19 | "from keras.models import Sequential\n", 20 | "from keras.models import Model\n", 21 | "from keras.layers import Dense, Dropout, Activation, Flatten, Layer, merge, Input, Concatenate, Reshape, concatenate,Lambda,multiply,Permute,RepeatVector\n", 22 | "from keras.layers.convolutional import Conv1D, MaxPooling1D\n", 23 | "from keras.layers.pooling import GlobalMaxPooling1D\n", 24 | "from keras.layers.recurrent import LSTM\n", 25 | "from keras.layers.wrappers import Bidirectional, TimeDistributed\n", 26 | "from keras.models import load_model\n", 27 | "from keras.callbacks import ModelCheckpoint, EarlyStopping\n", 28 | "from keras import optimizers\n", 29 | "from
keras import backend as K\n", 30 | "from keras import regularizers" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "### Load data (training and validation)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "data_folder = \"./data/\"\n", 47 | "# train.mat is opened with h5py while valid.mat uses scipy.io.loadmat - presumably train.mat is HDF5-based (MATLAB v7.3); TODO confirm\n", 48 | "trainmat = h5py.File(data_folder+'train.mat', 'r')  # explicit read-only mode; h5py 2.x defaults to 'a', which can modify or create the file\n", 49 | "validmat = scipy.io.loadmat(data_folder+'valid.mat')\n", 50 | "# NOTE(review): axes=(2,0,1) and .T presumably put samples first so X_train matches Input(shape=(1000,4)) - TODO confirm shapes\n", 51 | "X_train = np.transpose(np.array(trainmat['trainxdata']),axes=(2,0,1))\n", 52 | "y_train = np.array(trainmat['traindata']).T\n", 53 | "\n", 54 | "trainmat.close()" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "### Choose only the targets that correspond to the TF binding\n" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "y_train = y_train[:,125:815]  # keep the 690 TF-binding columns, matching the Dense(690) output layer" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "### Run TBiNet" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "sequence_input = Input(shape=(1000,4))\n", 87 | "\n", 88 | "# Convolutional Layer\n", 89 | "output = Conv1D(320,kernel_size=26,padding=\"valid\",activation=\"relu\")(sequence_input)\n", 90 | "output = MaxPooling1D(pool_size=13, strides=13)(output)\n", 91 | "output = Dropout(0.2)(output)\n", 92 | "\n", 93 | "# Attention Layer: Dense(1) scores each of the 75 pooled positions, softmax normalizes the scores over positions, and the weights are repeated across the 320 filters\n", 94 | "attention = Dense(1)(output)\n", 95 | "attention = Permute((2, 1))(attention)\n", 96 | "attention = Activation('softmax')(attention)\n", 97 | "attention = Permute((2, 1))(attention)\n", 98 | "attention = Lambda(lambda x: K.mean(x, axis=2), name='attention',output_shape=(75,))(attention)  # 75 = (1000 - 26 + 1) // 13 positions after the valid conv and pooling\n", 99 | "attention = RepeatVector(320)(attention)\n", 100 | "attention = Permute((2,1))(attention)\n", 101
| "output = multiply([output, attention])\n", 102 | "\n", 103 | "# BiLSTM Layer\n", 104 | "output = Bidirectional(LSTM(320,return_sequences=True))(output)\n", 105 | "output = Dropout(0.5)(output)\n", 106 | "\n", 107 | "flat_output = Flatten()(output)\n", 108 | "\n", 109 | "# FC Layer\n", 110 | "FC_output = Dense(695)(flat_output)\n", 111 | "FC_output = Activation('relu')(FC_output)\n", 112 | "\n", 113 | "# Output Layer\n", 114 | "output = Dense(690)(FC_output)\n", 115 | "output = Activation('sigmoid')(output)\n", 116 | "\n", 117 | "model = Model(inputs=sequence_input, outputs=output)\n", 118 | "\n", 119 | "print('compiling model')\n", 120 | "model.compile(loss='binary_crossentropy', optimizer='adam')\n", 121 | "\n", 122 | "print('model summary')\n", 123 | "model.summary()\n", 124 | "\n", 125 | "checkpointer = ModelCheckpoint(filepath=\"./model/tbinet.{epoch:02d}-{val_loss:.2f}.hdf5\", verbose=1, save_best_only=False)\n", 126 | "earlystopper = EarlyStopping(monitor='val_loss', patience=10, verbose=1)\n", 127 | "# validation data gets the same (0,2,1) transpose and 125:815 target slice that test.ipynb applies to the test set\n", 128 | "model.fit(X_train, y_train, batch_size=100, epochs=60, shuffle=True, verbose=1, validation_data=(np.transpose(validmat['validxdata'],axes=(0,2,1)),validmat['validdata'][:,125:815]), callbacks=[checkpointer,earlystopper])\n", 129 | "# NOTE(review): EarlyStopping does not restore the best weights here, so this saves the last-epoch model; the lowest val_loss model is among the per-epoch checkpoints\n", 130 | "model.save('./model/tbinet.h5')" 131 | ] 132 | } 133 | ], 134 | "metadata": { 135 | "kernelspec": { 136 | "display_name": "Python 3", 137 | "language": "python", 138 | "name": "python3" 139 | }, 140 | "language_info": { 141 | "codemirror_mode": { 142 | "name": "ipython", 143 | "version": 3 144 | }, 145 | "file_extension": ".py", 146 | "mimetype": "text/x-python", 147 | "name": "python", 148 | "nbconvert_exporter": "python", 149 | "pygments_lexer": "ipython3", 150 | "version": "3.6.8" 151 | } 152 | }, 153 | "nbformat": 4, 154 | "nbformat_minor": 2 155 | } 156 | --------------------------------------------------------------------------------