├── LICENSE
├── Protvec.ipynb
├── Readme.md
├── pfamclassification.R
└── protvec.py


/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Mike Huang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/Protvec.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# ProtVec: Amino Acid Embedding Representation of Proteins for Function Classification\n",
8 | "\n",
9 | "## Objectives\n",
10 | "1. Extract features from amino acid sequences for machine learning\n",
11 | "2. Use features to predict protein family and other structural properties\n",
12 | "\n",
13 | "## Abstract\n",
14 | "This project attempts to reproduce the results of [Asgari 2015](http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0141287) and to extend them to phage sequences and their protein families. Currently, Asgari's classification of protein families can be reproduced using his [trained embedding](https://github.com/ehsanasgari/Deep-Proteomics). However, his results cannot yet be reproduced by training from scratch with the skip-gram negative sampling method detailed in [this tutorial](http://adventuresinmachinelearning.com/word2vec-keras-tutorial/). Training has so far been attempted on samples from the SwissProt database. \n",
15 | "\n",
16 | "## Introduction\n",
17 | "\n",
18 | "Predicting protein function with machine learning methods requires informative features extracted from the data. A natural language processing (NLP) technique known as Word2Vec represents a word by its context, using a vector that encodes the probability of each context occurring around that word. These vectors are effective at representing word meanings, because words with similar meanings appear in similar contexts. For example, the words cat and kitten are used in very similar contexts, so they end up with very similar vectors. \n",
19 | "\n",
20 | "\n",
21 | "## Methods\n",
22 | "1. Preprocessing\n",
23 | " 1. Load dataset containing protein amino acid sequences and Asgari's embedding\n",
24 | " 2. 
[Convert sequences to three lists of non-overlapping 3-mer words](https://www.researchgate.net/profile/Mohammad_Mofrad/publication/283644387/figure/fig4/AS:341292040114179@1458381771303/Protein-sequence-splitting-In-order-to-prepare-the-training-data-each-protein-sequence.png) \n",
25 | " 3. Convert 3-mers to a numerical encoding using the k-mer indices from Asgari's embedding (row dimension)\n",
26 | " 4. Generate skipgrams with the [Keras function](https://keras.io/preprocessing/sequence/) \n",
27 | " Output: [target word, context word](http://mccormickml.com/assets/word2vec/training_data.png), label \n",
28 | " The label marks whether the target/context pair is a true pairing or a false one generated for the negative sampling technique \n",
29 | "2. Training embedding\n",
30 | " 1. Create a negative sampling skipgram model with Keras [using the technique from this tutorial](http://adventuresinmachinelearning.com/word2vec-keras-tutorial/)\n",
31 | "3. Generate ProtVecs from the embedding for a given protein sequence\n",
32 | " 1. Break the protein sequence into a list of k-mers\n",
33 | " 2. Convert each k-mer to a vector by taking the dot product of its one-hot vector with the embedding \n",
34 | " 3. Sum the vectors of all k-mers into a single vector representation for the protein (length 100) \n",
35 | "4. Classify protein function with ProtVec features (results currently not working; refer to the R script)\n",
36 | " 1. Use ProtVecs as training features\n",
37 | " 2. Use Pfam as labels\n",
38 | " 3. For a given Pfam classification, perform binary classification with all of its positive samples and an equal number of randomly sampled negative samples\n",
39 | " 4. Train an SVM model \n",
40 | " \n",
41 | " \n",
42 | "## Resources \n",
43 | "1. Intuition behind Word2Vec http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/\n",
44 | "2. Tutorial followed for implementation of skip-gram negative sampling (includes code) http://adventuresinmachinelearning.com/word2vec-keras-tutorial/\n",
45 | "3. 
Introduction to protein function prediction\n", 46 | "http://biofunctionprediction.org/cafa-targets/Introduction_to_protein_prediction.pdf\n", 47 | "\n", 48 | "## Author\n", 49 | "Mike Huang \n", 50 | "huangjmike@gmail.com" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 1, 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stderr", 60 | "output_type": "stream", 61 | "text": [ 62 | "Using TensorFlow backend.\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "import pandas as pd\n", 68 | "import numpy as np\n", 69 | "from keras.preprocessing.sequence import skipgrams, pad_sequences, make_sampling_table\n", 70 | "from keras.preprocessing.text import hashing_trick\n", 71 | "from keras.layers import Embedding, Input, Reshape, Dense, merge\n", 72 | "from keras.models import Sequential, Model\n", 73 | "from sklearn.manifold import TSNE\n", 74 | "from joblib import Parallel, delayed\n", 75 | "import multiprocessing\n", 76 | "\n", 77 | "import csv\n", 78 | "\n", 79 | "\n", 80 | "#Load Ehsan Asgari's embeddings\n", 81 | "#Source: http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0141287\n", 82 | "#Embedding: https://github.com/ehsanasgari/Deep-Proteomics\n", 83 | "ehsanEmbed = []\n", 84 | "with open(\"protVec_100d_3grams.csv\") as tsvfile:\n", 85 | " tsvreader = csv.reader(tsvfile, delimiter=\"\\t\")\n", 86 | " for line in tsvreader:\n", 87 | " ehsanEmbed.append(line[0].split('\\t'))\n", 88 | "threemers = [vec[0] for vec in ehsanEmbed]\n", 89 | "embeddingMat = [[float(n) for n in vec[1:]] for vec in ehsanEmbed]\n", 90 | "threemersidx = {} #generate word to index translation dictionary. Use for kmersdict function arguments.\n", 91 | "for i, kmer in enumerate(threemers):\n", 92 | " threemersidx[kmer] = i\n", 93 | "\n", 94 | " \n", 95 | "#Load NCBI Phage processed dataset - 38420 sequences\n", 96 | "#table = pd.read_csv(\"90filter.ncbi.statis_phage_gene.csv\", index_col=0)\n", 97 | "#Remove entries without vector representation due to amino acid sequence not reaching threshold length\n", 98 | "#table = table[table['Protein'].apply(lambda x: type(x)!=float)]\n", 99 | "#table[:10]\n", 100 | "\n", 101 | "#Load Second NCBI Phage processed dataset - 99520 sequences\n", 102 | "#cherry = pd.read_csv(\"CherryProteins.csv\")\n", 103 | "#cherry = cherry.loc[cherry['Component'] != 'HYP'] #Filter hypothetical sequences\n", 104 | "#cherry = cherry.loc[cherry['Component'] != 'UNS'] #Filter unsorted sequences\n", 105 | "#cherryseqs = pd.read_csv(\"cherryaaseqs.csv\")\n", 106 | "\n", 107 | "#Load SwissProt 2015 data\n", 108 | "swissprot = pd.read_csv(\"family_classification_metadata.tab\", sep='\\t')\n", 109 | "swissprot['Sequence'] = pd.read_csv(\"family_classification_sequences.tab\", sep='\\t')\n", 110 | "\n", 111 | "#Create non-redundant, concatenated sequences list between two NCBI datasets for training\n", 112 | "#seqsunique = [seq[0] for seq in cherryseqs['Protein'].values if seq not in table['Protein'].values]\n", 113 | "#uniqueinds = [i for i in range(len(cherryseqs)) if cherryseqs['Protein'].iloc[i] not in table['Protein'].values]\n", 114 | "#cherryunq = cherryseqs.iloc[uniqueinds]\n", 115 | "#cherryunq = cherryunq.append(table[['Function','Protein']])\n", 116 | "cherryunq = pd.read_csv(\"cherryall.csv\")\n", 117 | "\n", 118 | "#Set parameters\n", 119 | "vocabsize = len(threemersidx)\n", 120 | "window_size = 25\n", 121 | "num_cores = multiprocessing.cpu_count() #For parallel computing\n", 122 | "\n", 123 | "#Path to model weights\n", 124 | "weightspath = 
'38420sample10000000epochsAsgari.hdf'"
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "metadata": {},
130 | "source": [
131 | "### Ehsan's embedding trained with SwissProt 2015\n",
132 | "The embedding is a 9048 x 100 matrix: one row for each 3-mer, and each row holds that 3-mer's 100-dimensional vector representation. The matrix is a lookup table to get the vector for a given 3-mer. 3-mers not in the table are represented by the unk entry.\n",
133 | "\n",
134 | "![](http://mccormickml.com/assets/word2vec/word2vec_weight_matrix_lookup_table.png)"
135 | ]
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "metadata": {},
140 | "source": [
141 | "### Data preprocessing\n",
142 | "\n",
143 | "Let's create the three lists of non-overlapping 3-mers as described in the paper.\n",
144 | "\n",
145 | "\n",
146 | "\n",
147 | "Next, encode each 3-mer to its row index in the embedding. "
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 4,
153 | "metadata": {
154 | "collapsed": true
155 | },
156 | "outputs": [],
157 | "source": [
158 | "#Convert sequences to three lists of non-overlapping 3-mers, one per reading frame \n",
159 | "def kmerlists(seq):\n",
160 | " kmer0 = []\n",
161 | " kmer1 = []\n",
162 | " kmer2 = []\n",
163 | " for i in range(0,len(seq)-2,3):\n",
164 | " if len(seq[i:i+3]) == 3:\n",
165 | " kmer0.append(seq[i:i+3])\n",
166 | " i+=1 #shift to the next reading frame within this window; the for loop resets i on the next iteration\n",
167 | " if len(seq[i:i+3]) == 3:\n",
168 | " kmer1.append(seq[i:i+3])\n",
169 | " i+=1\n",
170 | " if len(seq[i:i+3]) == 3:\n",
171 | " kmer2.append(seq[i:i+3])\n",
172 | " return [kmer0,kmer1,kmer2]\n",
173 | "\n",
174 | "#Same as kmerlists function but outputs an index number assigned to each kmer. Index number is from Asgari's embedding\n",
175 | "def kmersindex(seqs, kmersdict=threemersidx):\n",
176 | " kmers = []\n",
177 | " for i in range(len(seqs)):\n",
178 | " kmers.append(kmerlists(seqs[i]))\n",
179 | " kmers = np.array(kmers).flatten().flatten(order='F')\n",
180 | " kmersindex = []\n",
181 | " for seq in kmers:\n",
182 | " temp = []\n",
183 | " for kmer in seq:\n",
184 | " try:\n",
185 | " temp.append(kmersdict[kmer])\n",
186 | " except KeyError: #3-mer missing from the table; fall back to the unk entry\n",
187 | " temp.append(kmersdict[''])\n",
188 | " kmersindex.append(temp)\n",
189 | " return kmersindex\n",
190 | "\n",
191 | "sampling_table = make_sampling_table(vocabsize)\n",
192 | "def generateskipgramshelper(kmersindicies): \n",
193 | " couples, labels = skipgrams(kmersindicies, vocabsize, window_size=window_size, sampling_table=sampling_table)\n",
194 | " if len(couples)==0: #sampling_table may reject every kmer of a short sequence; retry once\n",
195 | " couples, labels = skipgrams(kmersindicies, vocabsize, window_size=window_size, sampling_table=sampling_table)\n",
196 | " if len(couples)==0: #still empty: return empty results so the caller can skip this sequence\n",
197 | " return [], [], []\n",
198 | " word_target, word_context = zip(*couples)\n",
199 | " return word_target, word_context, labels\n",
200 | " \n",
201 | " \n",
202 | "def generateskipgrams(seqs,kmersdict=threemersidx):\n",
203 | " kmersidx = kmersindex(seqs,kmersdict)\n",
204 | " return Parallel(n_jobs=num_cores)(delayed(generateskipgramshelper)(kmers) for kmers in kmersidx)\n"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": 20,
210 | "metadata": {},
211 | "outputs": [
212 | {
213 | "name": "stdout",
214 | "output_type": "stream",
215 | "text": [
216 | "Sample sequence\n",
217 | 
"MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL\n", 218 | "\n", 219 | "Convert sequence to list of kmers\n", 220 | "[['MAF', 'SAE', 'DVL', 'KEY', 'DRR', 'RRM', 'EAL', 'LLS', 'LYY', 'PND', 'RKL', 'LDY', 'KEW', 'SPP', 'RVQ', 'VEC', 'PKA', 'PVE', 'WNN', 'PPS', 'EKG', 'LIV', 'GHF', 'SGI', 'KYK', 'GEK', 'AQA', 'SEV', 'DVN', 'KMC', 'CWV', 'SKF', 'KDA', 'MRR', 'YQG', 'IQT', 'CKI', 'PGK', 'VLS', 'DLD', 'AKI', 'KAY', 'NLT', 'VEG', 'VEG', 'FVR', 'YSR', 'VTK', 'QHV', 'AAF', 'LKE', 'LRH', 'SKQ', 'YEN', 'VNL', 'IHY', 'ILT', 'DKR', 'VDI', 'QHL', 'EKD', 'LVK', 'DFK', 'ALV', 'ESA', 'HRM', 'RQG', 'HMI', 'NVK', 'YIL', 'YQL', 'LKK', 'HGH', 'GPD', 'GPD', 'ILT', 'VKT', 'GSK', 'GVL', 'YDD', 'SFR', 'KIY', 'TDL', 'GWK', 'FTP'], ['AFS', 'AED', 'VLK', 'EYD', 'RRR', 'RME', 'ALL', 'LSL', 'YYP', 'NDR', 'KLL', 'DYK', 'EWS', 'PPR', 'VQV', 'ECP', 'KAP', 'VEW', 'NNP', 'PSE', 'KGL', 'IVG', 'HFS', 'GIK', 'YKG', 'EKA', 'QAS', 'EVD', 'VNK', 'MCC', 'WVS', 'KFK', 'DAM', 'RRY', 'QGI', 'QTC', 'KIP', 'GKV', 'LSD', 'LDA', 'KIK', 'AYN', 'LTV', 'EGV', 'EGF', 'VRY', 'SRV', 'TKQ', 'HVA', 'AFL', 'KEL', 'RHS', 'KQY', 'ENV', 'NLI', 'HYI', 'LTD', 'KRV', 'DIQ', 'HLE', 'KDL', 'VKD', 'FKA', 'LVE', 'SAH', 'RMR', 'QGH', 'MIN', 'VKY', 'ILY', 'QLL', 'KKH', 'GHG', 'PDG', 'PDI', 'LTV', 'KTG', 'SKG', 'VLY', 'DDS', 'FRK', 'IYT', 'DLG', 'WKF', 'TPL'], ['FSA', 'EDV', 'LKE', 'YDR', 'RRR', 'MEA', 'LLL', 'SLY', 'YPN', 'DRK', 'LLD', 'YKE', 'WSP', 'PRV', 'QVE', 'CPK', 'APV', 'EWN', 'NPP', 'SEK', 'GLI', 'VGH', 'FSG', 'IKY', 'KGE', 'KAQ', 'ASE', 'VDV', 'NKM', 'CCW', 'VSK', 'FKD', 'AMR', 'RYQ', 'GIQ', 'TCK', 'IPG', 'KVL', 'SDL', 'DAK', 'IKA', 'YNL', 'TVE', 'GVE', 'GFV', 'RYS', 'RVT', 'KQH', 'VAA', 'FLK', 'ELR', 'HSK', 'QYE', 'NVN', 'LIH', 'YIL', 'TDK', 'RVD', 'IQH', 'LEK', 'DLV', 'KDF', 'KAL', 'VES', 'AHR', 'MRQ', 'GHM', 'INV', 'KYI', 'LYQ', 'LLK', 'KHG', 'HGP', 'DGP', 'DIL', 'TVK', 'TGS', 'KGV', 'LYD', 'DSF', 'RKI', 'YTD', 'LGW', 'KFT']]\n", 221 | "\n", 222 | "Convert kmers to their index on the embedding\n", 223 | "[[4330, 704, 165, 2795, 2594, 4177, 9, 12, 4155, 4300, 467, 2012, 6034, 1854, 3001, 5719, 2112, 1382, 7163, 1380, 756, 593, 4582, 718, 3482, 648, 291, 956, 2337, 7690, 7833, 3151, 986, 4003, 4117, 3390, 6159, 1915, 128, 575, 941, 2787, 1260, 507, 507, 2641, 3455, 1665, 5098, 792, 49, 2474, 2708, 4170, 1220, 6212, 566, 2977, 1079, 3490, 1401, 294, 2997, 45, 1012, 6887, 2326, 7200, 2252, 2647, 3514, 105, 5221, 2962, 2962, 566, 1453, 1437, 126, 3808, 2895, 3362, 890, 5668, 3102], [1268, 892, 354, 3434, 376, 4614, 5, 24, 6082, 4205, 59, 4097, 6338, 2660, 1934, 6388, 2277, 5623, 3420, 1820, 410, 470, 4729, 1038, 3533, 220, 1937, 1057, 2191, 7960, 5572, 2975, 3608, 3041, 2208, 7016, 2430, 448, 243, 112, 954, 4503, 428, 579, 2026, 3606, 1337, 3779, 3357, 696, 83, 4422, 4470, 1414, 1146, 6171, 537, 1189, 2959, 2485, 411, 1088, 2518, 108, 3879, 4619, 4982, 5097, 3870, 2748, 167, 3990, 3140, 1281, 2442, 428, 1225, 1388, 2257, 1980, 3114, 4007, 456, 7310, 750], [1528, 687, 49, 4107, 376, 3070, 2, 1977, 5138, 2820, 33, 3143, 5868, 2353, 1907, 6220, 857, 6915, 4110, 1134, 304, 3648, 1332, 3972, 1005, 1789, 682, 641, 5164, 7966, 1153, 2533, 3127, 4646, 2862, 6234, 1458, 202, 431, 1464, 722, 3510, 699, 499, 1576, 3743, 1540, 5448, 17, 1158, 278, 4846, 5132, 3133, 3189, 2647, 2383, 1905, 5440, 48, 211, 2461, 73, 1006, 3681, 5038, 6485, 2575, 3252, 3153, 38, 
3956, 5005, 2828, 325, 1188, 571, 842, 2324, 2671, 1478, 4128, 4201, 3964]]\n", 224 | "\n", 225 | "Sample skipgram input:\n", 226 | "Word Target: 986\n", 227 | "Word Context: 12\n", 228 | "Label: 1\n" 229 | ] 230 | } 231 | ], 232 | "source": [ 233 | "print(\"Sample sequence\")\n", 234 | "print(swissprot['Sequence'].iloc[0])\n", 235 | "print(\"\")\n", 236 | "print(\"Convert sequence to list of kmers\")\n", 237 | "print(kmerlists(swissprot['Sequence'].iloc[0]))\n", 238 | "print(\"\")\n", 239 | "print(\"Convert kmers to their index on the embedding\")\n", 240 | "print(kmersindex(swissprot['Sequence'].iloc[:1]))\n", 241 | "print(\"\")\n", 242 | "testskipgrams = generateskipgrams(swissprot['Sequence'].iloc[:1])\n", 243 | "print(\"Sample skipgram input:\")\n", 244 | "print(\"Word Target:\", testskipgrams[0][0][0])\n", 245 | "print(\"Word Context:\", testskipgrams[0][1][0])\n", 246 | "print(\"Label:\", testskipgrams[0][2][0])" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 3, 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "name": "stdout", 256 | "output_type": "stream", 257 | "text": [ 258 | "____________________________________________________________________________________________________\n", 259 | "Layer (type) Output Shape Param # Connected to \n", 260 | "====================================================================================================\n", 261 | "input_1 (InputLayer) (None, 1) 0 \n", 262 | "____________________________________________________________________________________________________\n", 263 | "input_2 (InputLayer) (None, 1) 0 \n", 264 | "____________________________________________________________________________________________________\n", 265 | "embedding (Embedding) (None, 1, 100) 904800 input_1[0][0] \n", 266 | " input_2[0][0] \n", 267 | "____________________________________________________________________________________________________\n", 268 | "reshape_1 (Reshape) (None, 100, 1) 0 embedding[0][0] \n", 269 | "____________________________________________________________________________________________________\n", 270 | "reshape_2 (Reshape) (None, 100, 1) 0 embedding[1][0] \n", 271 | "____________________________________________________________________________________________________\n", 272 | "merge_2 (Merge) (None, 1, 1) 0 reshape_1[0][0] \n", 273 | " reshape_2[0][0] \n", 274 | "____________________________________________________________________________________________________\n", 275 | "reshape_3 (Reshape) (None, 1) 0 merge_2[0][0] \n", 276 | "____________________________________________________________________________________________________\n", 277 | "dense_1 (Dense) (None, 1) 2 reshape_3[0][0] \n", 278 | "====================================================================================================\n", 279 | "Total params: 904,802\n", 280 | "Trainable params: 904,802\n", 281 | "Non-trainable params: 0\n", 282 | "____________________________________________________________________________________________________\n" 283 | ] 284 | }, 285 | { 286 | "name": "stderr", 287 | "output_type": "stream", 288 | "text": [ 289 | "c:\\programdata\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\ipykernel_launcher.py:15: UserWarning: The `merge` function is deprecated and will be removed after 08/2017. Use instead layers from `keras.layers.merge`, e.g. 
`add`, `concatenate`, etc.\n",
290 | " from ipykernel import kernelapp as app\n",
291 | "c:\\programdata\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\keras\\legacy\\layers.py:458: UserWarning: The `Merge` layer is deprecated and will be removed after 08/2017. Use instead layers from `keras.layers.merge`, e.g. `add`, `concatenate`, etc.\n",
292 | " name=name)\n",
293 | "c:\\programdata\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\ipykernel_launcher.py:18: UserWarning: The `merge` function is deprecated and will be removed after 08/2017. Use instead layers from `keras.layers.merge`, e.g. `add`, `concatenate`, etc.\n",
294 | "c:\\programdata\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\ipykernel_launcher.py:24: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=[']])\n",
448 | " return kmersvec\n",
449 | "\n",
450 | "def formatprotvecs(protvecs):\n",
451 | " #Format protvecs for classifier inputs by transposing the matrix\n",
452 | " protfeatures = []\n",
453 | " for i in range(100):\n",
454 | " protfeatures.append([vec[i] for vec in protvecs])\n",
455 | " protfeatures = np.array(protfeatures).T #true transpose of the (100, n) matrix; reshape would scramble the values\n",
456 | " return protfeatures\n",
457 | "\n",
458 | "def formatprotvecsnormalized(protvecs):\n",
459 | " #Formatted protvecs with feature normalization\n",
460 | " protfeatures = []\n",
461 | " for i in range(100):\n",
462 | " tempvec = [vec[i] for vec in protvecs]\n",
463 | " mean = np.mean(tempvec)\n",
464 | " var = np.var(tempvec)\n",
465 | " protfeatures.append([(vec[i]-mean)/var for vec in protvecs]) #note: divides by the variance; standard z-scoring would use the standard deviation\n",
466 | " protfeatures = np.array(protfeatures).T #true transpose of the (100, n) matrix; reshape would scramble the values\n",
467 | " return protfeatures\n",
468 | "\n",
469 | "def sequences2protvecsCSV(filename, seqs, kmersdict=threemersidx, embeddingweights=embeddingMat):\n",
470 | " #Convert a list of sequences to protvecs and save protvecs to a csv file\n",
471 | " #ARGUMENTS:\n",
472 | " #filename: string, name of csv file to save to, i.e. \"sampleprotvecs.csv\"\n",
473 | " #seqs: list, list of amino acid sequences\n",
474 | " #kmersdict: dict to look up index of kmer on embedding, default: Asgari's embedding index\n",
475 | " #embeddingweights: 2D list or np.array, embedding vectors, default: Asgari's embedding vectors\n",
476 | "\n",
477 | " swissprotvecs = Parallel(n_jobs=num_cores)(delayed(protvec)(kmersdict, seq, embeddingweights) for seq in seqs)\n",
478 | " swissprotvecsdf = pd.DataFrame(formatprotvecs(swissprotvecs))\n",
479 | " swissprotvecsdf.to_csv(filename, index=False)\n",
480 | " return swissprotvecsdf"
481 | ]
482 | },
483 | {
484 | "cell_type": "code",
485 | "execution_count": 33,
486 | "metadata": {},
487 | "outputs": [
488 | {
489 | "data": {
490 | "text/html": [
491 | "
\n", 492 | "\n", 505 | "\n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | "
0123456789...90919293949596979899
0-15.906601-15.786575-12.858975-20.429255-25.811314-1.425858-0.2732611.218515-2.768131-1.975148...-2.3538961.683457-2.855346-4.200350-6.863898-8.9754500.160262-3.139883-12.887104-15.625758
1-1.501441-2.5760110.5726823.0178903.397684-8.050783-7.461454-4.145822-1.219607-9.140709...0.2796024.0740853.937202-5.4397270.134528-1.727660-6.800861-4.274481-0.687948-6.683564
24.5716416.087940-0.0931531.7883963.1435070.5875364.3431743.2816484.5941396.544705...-5.428951-8.151144-4.075138-3.999045-6.826098-19.353062-13.091738-14.026761-31.865056-33.461093
31.18622115.0313104.099640-7.4293060.67338912.43583228.05398813.09455918.31983123.192327...3.3382617.016894-0.6241160.0523453.364417-5.378887-2.588111-7.852132-9.400000-9.117914
4-11.274578-12.735979-9.380683-24.705562-24.733077-0.129001-6.0325882.4547676.5139421.185852...-4.999941-8.337236-1.4823022.712613-1.11347418.63433228.02064713.63500225.41957733.273547
\n", 655 | "

5 rows × 100 columns

\n", 656 | "
" 657 | ], 658 | "text/plain": [ 659 | " 0 1 2 3 4 5 \\\n", 660 | "0 -15.906601 -15.786575 -12.858975 -20.429255 -25.811314 -1.425858 \n", 661 | "1 -1.501441 -2.576011 0.572682 3.017890 3.397684 -8.050783 \n", 662 | "2 4.571641 6.087940 -0.093153 1.788396 3.143507 0.587536 \n", 663 | "3 1.186221 15.031310 4.099640 -7.429306 0.673389 12.435832 \n", 664 | "4 -11.274578 -12.735979 -9.380683 -24.705562 -24.733077 -0.129001 \n", 665 | "\n", 666 | " 6 7 8 9 ... 90 91 \\\n", 667 | "0 -0.273261 1.218515 -2.768131 -1.975148 ... -2.353896 1.683457 \n", 668 | "1 -7.461454 -4.145822 -1.219607 -9.140709 ... 0.279602 4.074085 \n", 669 | "2 4.343174 3.281648 4.594139 6.544705 ... -5.428951 -8.151144 \n", 670 | "3 28.053988 13.094559 18.319831 23.192327 ... 3.338261 7.016894 \n", 671 | "4 -6.032588 2.454767 6.513942 1.185852 ... -4.999941 -8.337236 \n", 672 | "\n", 673 | " 92 93 94 95 96 97 98 \\\n", 674 | "0 -2.855346 -4.200350 -6.863898 -8.975450 0.160262 -3.139883 -12.887104 \n", 675 | "1 3.937202 -5.439727 0.134528 -1.727660 -6.800861 -4.274481 -0.687948 \n", 676 | "2 -4.075138 -3.999045 -6.826098 -19.353062 -13.091738 -14.026761 -31.865056 \n", 677 | "3 -0.624116 0.052345 3.364417 -5.378887 -2.588111 -7.852132 -9.400000 \n", 678 | "4 -1.482302 2.712613 -1.113474 18.634332 28.020647 13.635002 25.419577 \n", 679 | "\n", 680 | " 99 \n", 681 | "0 -15.625758 \n", 682 | "1 -6.683564 \n", 683 | "2 -33.461093 \n", 684 | "3 -9.117914 \n", 685 | "4 33.273547 \n", 686 | "\n", 687 | "[5 rows x 100 columns]" 688 | ] 689 | }, 690 | "execution_count": 33, 691 | "metadata": {}, 692 | "output_type": "execute_result" 693 | } 694 | ], 695 | "source": [ 696 | "sequences2protvecsCSV(\"testprotvecs.csv\", swissprot['Sequence'][:5])" 697 | ] 698 | }, 699 | { 700 | "cell_type": "markdown", 701 | "metadata": {}, 702 | "source": [ 703 | "## Classification of Protein Function Category" 704 | ] 705 | }, 706 | { 707 | "cell_type": "code", 708 | "execution_count": null, 709 | "metadata": { 710 | "collapsed": true 711 | }, 712 | "outputs": [], 713 | "source": [ 714 | "from sklearn.preprocessing import LabelBinarizer\n", 715 | "from sklearn.model_selection import train_test_split, StratifiedKFold\n", 716 | "from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, AdaBoostClassifier, ExtraTreesClassifier\n", 717 | "from sklearn.svm import SVC\n", 718 | "from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB\n", 719 | "from sklearn.linear_model import LogisticRegression\n", 720 | "from sklearn.metrics import confusion_matrix, classification_report,log_loss\n", 721 | "from sklearn.model_selection import cross_val_score\n", 722 | "from scipy.ndimage.measurements import center_of_mass, label\n", 723 | "from skimage.measure import regionprops\n", 724 | "from sklearn.cross_validation import ShuffleSplit\n", 725 | "from sklearn.grid_search import GridSearchCV\n", 726 | "from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score\n", 727 | "from scipy.stats import percentileofscore" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": null, 733 | "metadata": { 734 | "collapsed": true 735 | }, 736 | "outputs": [], 737 | "source": [ 738 | "lb = LabelBinarizer()\n", 739 | "binlab=lb.fit_transform(table['labels'])" 740 | ] 741 | }, 742 | { 743 | "cell_type": "code", 744 | "execution_count": null, 745 | "metadata": { 746 | "collapsed": true 747 | }, 748 | "outputs": [], 749 | "source": [ 750 | "n_splits=10\n", 751 | 
"kfold=StratifiedKFold(n_splits=n_splits, shuffle=True)\n", 752 | "\n", 753 | "models=[RandomForestClassifier(),\n", 754 | " GradientBoostingClassifier(),]\n", 755 | "name=[\"Random Forest\", \"Gradient Boosting\"]\n", 756 | "\n", 757 | "predictedmodels={}\n", 758 | "\n", 759 | "for nm, clf in zip(name[:-1], models[:-1]):\n", 760 | " print(nm)\n", 761 | " predicted=[]\n", 762 | " for train,test in kfold.split(featuretable,binlab[:,0]):\n", 763 | " scores=cross_val_score(clf,featuretable, binlab, cv=StratifiedKFold(n_splits=n_splits, shuffle=True), n_jobs=-1, scoring='neg_log_loss')\n", 764 | " print(\"Cross-validated logloss\",-np.mean(scores))\n", 765 | " print(\"---------------------------------------\")\n", 766 | "\n" 767 | ] 768 | }, 769 | { 770 | "cell_type": "code", 771 | "execution_count": null, 772 | "metadata": { 773 | "collapsed": true 774 | }, 775 | "outputs": [], 776 | "source": [ 777 | "clf.fit(featuretable, binlab)" 778 | ] 779 | }, 780 | { 781 | "cell_type": "code", 782 | "execution_count": null, 783 | "metadata": { 784 | "collapsed": true 785 | }, 786 | "outputs": [], 787 | "source": [ 788 | "featuretable = pd.DataFrame(np.array(protfeatures).reshape(9173,100))" 789 | ] 790 | }, 791 | { 792 | "cell_type": "code", 793 | "execution_count": null, 794 | "metadata": { 795 | "collapsed": true 796 | }, 797 | "outputs": [], 798 | "source": [ 799 | "#Binary classification of single function\n", 800 | "intsamples = table.loc[table['labels']=='INF']\n", 801 | "intsamples['binarylabel'] = [1]*len(intsamples)\n", 802 | "nonint = table.loc[table['labels'] != 'INF']\n", 803 | "nonint['binarylabel'] = [0]*len(nonint)\n", 804 | "intsamples = intsamples.append(nonint.sample(frac=len(intsamples)/len(nonint)))\n", 805 | "intsamples = intsamples.sample(frac=1)" 806 | ] 807 | }, 808 | { 809 | "cell_type": "code", 810 | "execution_count": null, 811 | "metadata": { 812 | "collapsed": true 813 | }, 814 | "outputs": [], 815 | "source": [ 816 | "features = formatprotvecs(intsamples['ProtVecs'].values)\n" 817 | ] 818 | }, 819 | { 820 | "cell_type": "code", 821 | "execution_count": null, 822 | "metadata": { 823 | "collapsed": true 824 | }, 825 | "outputs": [], 826 | "source": [ 827 | "models=[LogisticRegression(C=0.1),\n", 828 | " RandomForestClassifier(),\n", 829 | " GradientBoostingClassifier(),\n", 830 | " SVC(C=0.02,kernel='rbf', probability=True)]\n", 831 | "name=[\"Logistic Regression\",\"Random Forest\", \"Gradient Boosting\",\"SVM rbf kernel\"]\n", 832 | "\n", 833 | "predictedmodels={}\n", 834 | "\n", 835 | "for nm, clf in zip(name, models):\n", 836 | " print(nm)\n", 837 | " scores=cross_val_score(clf,features, intsamples['binarylabel'], cv=StratifiedKFold(n_splits=n_splits, shuffle=True), n_jobs=-1, scoring='neg_log_loss')\n", 838 | " print(\"Cross-validated logloss\",-np.mean(scores))\n", 839 | " print(\"---------------------------------------\")\n", 840 | " \n", 841 | "\n" 842 | ] 843 | }, 844 | { 845 | "cell_type": "code", 846 | "execution_count": null, 847 | "metadata": { 848 | "collapsed": true 849 | }, 850 | "outputs": [], 851 | "source": [ 852 | "import matplotlib.pyplot as plt\n", 853 | "\n", 854 | "def BinaryClassification(x,y):\n", 855 | " n_splits=10\n", 856 | " kfold=StratifiedKFold(n_splits=n_splits, shuffle=True)\n", 857 | " models=[LogisticRegression(C=0.1),\n", 858 | " RandomForestClassifier(),\n", 859 | " GradientBoostingClassifier(),\n", 860 | " SVC(C=1,kernel='rbf')]\n", 861 | " name=[\"Logistic Regression\", \"Random Forest\", \"Gradient Boosting\", \"SVM with rbf 
kernel\"]\n",
862 | "\n",
863 | " predictedmodels={}\n",
864 | "\n",
865 | " for nm, clf in zip(name[-1:], models[-1:]):\n",
866 | " print(nm)\n",
867 | " predicted=[]\n",
868 | " labelcv=[]\n",
869 | " for train,test in kfold.split(x, y):\n",
870 | " clf.fit(x[train],y[train])\n",
871 | " predicted.append(clf.predict(x[test]))\n",
872 | " labelcv.append(y[test])\n",
873 | " #scores=cross_val_score(clf,x, y, cv=StratifiedKFold(n_splits=n_splits, shuffle=True), n_jobs=-1, scoring='neg_log_loss')\n",
874 | " predicted=np.concatenate(np.array(predicted),axis=0)\n",
875 | " labelcv=np.concatenate(np.array(labelcv),axis=0)\n",
876 | " predictedmodels[nm]=predicted\n",
877 | " #roc=roc_curve(labelcv,predicted)\n",
878 | " #print(\"Average precision score:\", average_precision_score(labelcv,predicted))\n",
879 | " #print(\"Area under curve:\", auc(roc[0],roc[1]))\n",
880 | " #plt.plot(roc[0],roc[1])\n",
881 | " #print(-scores)\n",
882 | " print(classification_report(labelcv,predicted))\n",
883 | " print(confusion_matrix(labelcv,predicted))\n",
884 | " #print(\"Cross-validated logloss\",-np.mean(scores)) #scores is only defined if the cross_val_score line above is uncommented\n",
885 | " print(\"---------------------------------------\")\n",
886 | " #plt.plot(rocrandom[0],rocrandom[1])\n",
887 | " #plt.title('ROC')\n",
888 | " #plt.ylabel('TPrate')\n",
889 | " #plt.xlabel('FPrate')\n",
890 | " #plt.legend(name)\n",
891 | " #plt.savefig(\"clfroccomparison.png\",dpi=300)\n",
892 | " #plt.show()"
893 | ]
894 | },
895 | {
896 | "cell_type": "code",
897 | "execution_count": null,
898 | "metadata": {
899 | "collapsed": true
900 | },
901 | "outputs": [],
902 | "source": [
903 | "#Set up accuracy, sensitivity, specificity as evaluation scores\n",
904 | "#Normalize features\n",
905 | "#Try Asgari 2015 initial weights\n",
906 | "#Try regularization\n"
907 | ]
908 | },
909 | {
910 | "cell_type": "markdown",
911 | "metadata": {},
912 | "source": [
913 | "## Replicate Asgari 2015 Protein family classification results"
914 | ]
915 | },
916 | {
917 | "cell_type": "code",
918 | "execution_count": 5,
919 | "metadata": {
920 | "collapsed": true
921 | },
922 | "outputs": [],
923 | "source": [
924 | "swissprot = pd.read_csv(\"family_classification_metadata.tab\", sep='\\t')\n",
925 | "swissprot['Sequence'] = pd.read_csv(\"family_classification_sequences.tab\", sep='\\t')"
926 | ]
927 | },
928 | {
929 | "cell_type": "code",
930 | "execution_count": null,
931 | "metadata": {
932 | "collapsed": true
933 | },
934 | "outputs": [],
935 | "source": [
936 | "swissprot.loc[swissprot['FamilyDescription'] == '50S ribosome-binding GTPase']"
937 | ]
938 | },
939 | {
940 | "cell_type": "code",
941 | "execution_count": null,
942 | "metadata": {
943 | "collapsed": true
944 | },
945 | "outputs": [],
946 | "source": [
947 | "#del swissprotseq #leftover cleanup: swissprotseq is never defined in this notebook"
948 | ]
949 | },
950 | {
951 | "cell_type": "code",
952 | "execution_count": null,
953 | "metadata": {
954 | "collapsed": true
955 | },
956 | "outputs": [],
957 | "source": [
958 | "concordance = [table.iloc[i] for i in range(len(table)) if table['Protein'].iloc[i] in swissprot['Sequence'].values]"
959 | ]
960 | },
961 | {
962 | "cell_type": "code",
963 | "execution_count": null,
964 | "metadata": {
965 | "collapsed": true
966 | },
967 | "outputs": [],
968 | "source": [
969 | "results = Parallel(n_jobs=num_cores)(delayed(generateskipgramshelper)(kmers) for kmers in kmersindex(swissprot['Sequence'].values)[20000:]) #kmersindex is a function: build the index lists first, then run the skipgram helper per sequence\n",
970 | "word_target = []\n",
971 | "word_context = []\n",
972 | "labels = []\n",
973 | "for sample in results:\n",
974 | " if type(sample) == tuple:\n",
975 | " word_target += sample[0]\n",
976 | " word_context += sample[1]\n",
977 | " labels += sample[2]\n",
978 | "del results"
979 | ]
980 | },
981 | {
982 | "cell_type": "code",
983 | "execution_count": null,
984 | "metadata": {
985 | "collapsed": true
986 | },
987 | "outputs": [],
988 | "source": [
989 | "#Sample 50S ribosome-binding GTPase and equal amount of negative cases\n",
990 | "def SampleBinaryClassification(table, function,ProtVecs):\n",
991 | " pos = table.loc[table['FamilyDescription'] == function]\n",
992 | " neg = table.loc[table['FamilyDescription'] != function]\n",
993 | " pos['binarylabel'] = np.ones(len(pos), dtype=bool)\n",
994 | " neg = neg.sample(frac=len(pos)/len(neg))\n",
995 | " neg['binarylabel'] = np.zeros(len(neg), dtype=bool)\n",
996 | " pos = pos.append(neg)\n",
997 | " pos = pos.sample(frac=1)\n",
998 | " #print(\"Generating ProtVecs\")\n",
999 | " #ProtVecs = Parallel(n_jobs=num_cores)(delayed(protvec)(len(threemers), threemersidx, seq, embeddingMat) for seq in pos['Sequence'])\n",
1000 | " #pos['ProtVecs'] = ProtVecs\n",
1001 | " features = formatprotvecs(ProtVecs)\n",
1002 | " BinaryClassification(features,pfambinary['binarylabel'].values) #note: relies on the global pfambinary table, which must be row-aligned with the ProtVecs argument\n",
1003 | " return pos\n"
1004 | ]
1005 | },
1006 | {
1007 | "cell_type": "code",
1008 | "execution_count": null,
1009 | "metadata": {
1010 | "collapsed": true
1011 | },
1012 | "outputs": [],
1013 | "source": [
1014 | "#ProtVecs = Parallel(n_jobs=num_cores)(delayed(protvec)(threemersidx, seq, embeddingMat) for seq in pfambinary['Sequence'])"
1015 | ]
1016 | },
1017 | {
1018 | "cell_type": "code",
1019 | "execution_count": null,
1020 | "metadata": {
1021 | "collapsed": true
1022 | },
1023 | "outputs": [],
1024 | "source": [
1025 | "SampleBinaryClassification(swissprot,'50S ribosome-binding GTPase',famclass.iloc[pfambinary.index].values)"
1026 | ]
1027 | },
1028 | {
1029 | "cell_type": "code",
1030 | "execution_count": null,
1031 | "metadata": {
1032 | "collapsed": true,
1033 | "scrolled": false
1034 | },
1035 | "outputs": [],
1036 | "source": [
1037 | "features = formatprotvecsnormalized(famclass.iloc[pfambinary.index].values)\n",
1038 | "#labels = LabelBinarizer().fit_transform(pfambinary['binarylabel'].values)\n",
1039 | "BinaryClassification(features,pfambinary['binarylabel'].values)"
1040 | ]
1041 | },
1042 | {
1043 | "cell_type": "code",
1044 | "execution_count": null,
1045 | "metadata": {
1046 | "collapsed": true
1047 | },
1048 | "outputs": [],
1049 | "source": [
1050 | "features = formatprotvecsnormalized(famclass.iloc[pfambinary.index].values)\n",
1051 | "#labels = LabelBinarizer().fit_transform(pfambinary['binarylabel'].values)\n",
1052 | "BinaryClassification(features,pfambinary['binarylabel'].values)"
1053 | ]
1054 | },
1055 | {
1056 | "cell_type": "code",
1057 | "execution_count": null,
1058 | "metadata": {
1059 | "collapsed": true
1060 | },
1061 | "outputs": [],
1062 | "source": [
1063 | "labels = pfambinary['binarylabel'].values\n",
1064 | "def fit_model(X, y, clf):\n",
1065 | " cv_sets = ShuffleSplit(X.shape[0], n_iter = 5, test_size = 0.20, random_state = 42)\n",
1066 | " params = {'C':np.arange(10,100),\n",
1067 | " 'gamma':np.linspace(1e-2,1e-1,10)} #np.arange with its default step of 1 would test only a single gamma value here\n",
1068 | " grid = GridSearchCV(clf, params, cv=cv_sets, n_jobs=-1)\n",
1069 | " grid = grid.fit(X, y)\n",
1070 | " return grid.best_params_, grid.best_score_, grid.best_estimator_\n",
1071 | "\n",
1072 | "best_params, best_score, optimal_svm=fit_model(features,labels,SVC())\n",
1073 | "\n",
1074 | "print(\"The best parameters are %s with a score of %0.2f\"\n",
1075 | " % (best_params, 
best_score))\n",
1076 | "print(optimal_svm)\n",
1077 | "\n",
1078 | "name=[\"Optimized SVM\"]\n",
1079 | "print(name)\n",
1080 | "#scores=cross_val_score(optimal_gb,inputfeatures[featurelist], malignantlabel, cv=5, scoring='neg_log_loss')\n",
1081 | "#print(-scores)\n",
1082 | "#print(\"Cross-validated logloss\",-np.mean(scores))\n",
1083 | "print(\"---------------------------------------\")\n",
1084 | "clf=optimal_svm\n",
1085 | "clf.fit(features[train],labels[train]) #train/test here are left over from the last kfold split above\n",
1086 | "print(classification_report(labels[test],clf.predict(features[test])))\n",
1087 | "print(confusion_matrix(labels[test],clf.predict(features[test])))\n",
1088 | "#roc=roc_curve(Ytest,clf.predict_proba(Xtest[featurelist])[:,1])\n",
1089 | "#print(clf.feature_importances)\n",
1090 | "#ROC curve\n",
1091 | "#plt.plot(roc[0],roc[1], alpha=0.5)\n",
1092 | "#plt.plot(rocrandom[0],rocrandom[1])\n",
1093 | "\n",
1094 | "#scores=cross_val_score(GradientBoostingClassifier(),inputfeatures[featurelist], malignantlabel, cv=5, scoring='neg_log_loss')\n",
1095 | "#print(classification_report(Ytest,model.predict(Xtest[featurelist])))\n",
1096 | "#print(-scores)\n",
1097 | "#print(\"Cross-validated logloss\",-np.mean(scores))\n",
1098 | "#print(\"---------------------------------------\")\n",
1099 | "#clf=SVC()\n",
1100 | "#clf.fit(Xtrain[featurelist],Ytrain)\n",
1101 | "#roc=roc_curve(Ytest,clf.predict_proba(Xtest[featurelist])[:,1])"
1102 | ]
1103 | },
1104 | {
1105 | "cell_type": "code",
1106 | "execution_count": null,
1107 | "metadata": {
1108 | "collapsed": true,
1109 | "scrolled": true
1110 | },
1111 | "outputs": [],
1112 | "source": [
1113 | "#Collect cherry annotated dataset\n",
1114 | "#Import into R and use bioconductor to translate DNAseqs to AASeqs\n",
1115 | "#Get protvecs for each AAseq\n",
1116 | "#Load into classifier to determine prediction rate for each category"
1117 | ]
1118 | }
1119 | ],
1120 | "metadata": {
1121 | "kernelspec": {
1122 | "display_name": "Python 3",
1123 | "language": "python",
1124 | "name": "python3"
1125 | },
1126 | "language_info": {
1127 | "codemirror_mode": {
1128 | "name": "ipython",
1129 | "version": 3
1130 | },
1131 | "file_extension": ".py",
1132 | "mimetype": "text/x-python",
1133 | "name": "python",
1134 | "nbconvert_exporter": "python",
1135 | "pygments_lexer": "ipython3",
1136 | "version": "3.5.4"
1137 | }
1138 | },
1139 | "nbformat": 4,
1140 | "nbformat_minor": 2
1141 | }
1142 | 
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # Protvec: Amino Acid Embedding Representation for Machine Learning Features
2 | 
3 | ## Objectives
4 | 1. Extract features from amino acid sequences for machine learning
5 | 2. Use features to predict protein family and other structural properties
6 | 
7 | ## Requirements
8 | * anaconda3
9 | * Python 3.4
10 | * Tensorflow
11 | * Keras
12 | * joblib - for multiprocessing - pip install joblib
13 | 
14 | ## Abstract
15 | This project attempts to reproduce the results of [Asgari 2015](http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0141287) and to extend them to phage sequences and their protein families. Currently, Asgari's classification of protein families can be reproduced using his [trained embedding](https://github.com/ehsanasgari/Deep-Proteomics). However, his results cannot yet be reproduced by training from scratch with the skip-gram negative sampling method detailed in [this tutorial](http://adventuresinmachinelearning.com/word2vec-keras-tutorial/). Training has so far been attempted on samples from the SwissProt database. 
16 | 
17 | ## Introduction
18 | Predicting protein function with machine learning methods requires informative features extracted from the data. A natural language processing (NLP) technique known as Word2Vec represents a word by its context, using a vector that encodes the probability of each context occurring around that word. These vectors are effective at representing word meanings, because words with similar meanings appear in similar contexts. For example, the words cat and kitten are used in very similar contexts, so they end up with very similar vectors. 
19 | 
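20 | To make the skip-gram idea concrete, the toy sketch below builds (target, context) training pairs from a short "sentence" of 3-mer words. The names here are illustrative only; the notebook uses Keras' `skipgrams` helper for the real pair generation.
21 | 
22 | ```python
23 | # Toy skip-gram pair generation: each word is paired with every word
24 | # inside a +/- window around it. Word2Vec then learns vectors such that
25 | # words sharing many contexts end up with similar vectors.
26 | words = ["MAF", "SAE", "DVL", "KEY", "DRR"]  # a "sentence" of 3-mers
27 | window = 2
28 | pairs = []
29 | for i, target in enumerate(words):
30 |     lo, hi = max(0, i - window), min(len(words), i + window + 1)
31 |     for j in range(lo, hi):
32 |         if j != i:
33 |             pairs.append((target, words[j]))
34 | print(pairs[:4])  # [('MAF', 'SAE'), ('MAF', 'DVL'), ('SAE', 'MAF'), ('SAE', 'DVL')]
35 | ```
36 | 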
37 | ## Methods
38 | 1. Preprocessing
39 |  1. Load dataset containing protein amino acid sequences and Asgari's embedding
40 |  2. [Convert sequences to three lists of non-overlapping 3-mer words](https://www.researchgate.net/profile/Mohammad_Mofrad/publication/283644387/figure/fig4/AS:341292040114179@1458381771303/Protein-sequence-splitting-In-order-to-prepare-the-training-data-each-protein-sequence.png) 
41 |  3. Convert 3-mers to a numerical encoding using the k-mer indices from Asgari's embedding (row dimension)
42 |  4. Generate skipgrams with the [Keras function](https://keras.io/preprocessing/sequence/) 
43 |  Output: [target word, context word](http://mccormickml.com/assets/word2vec/training_data.png), label 
44 |  The label marks whether the target/context pair is a true pairing or a false one generated for the negative sampling technique 
45 | 2. Training embedding
46 |  1. Create a negative sampling skipgram model with Keras [using the technique from this tutorial](http://adventuresinmachinelearning.com/word2vec-keras-tutorial/)
47 | 3. Generate ProtVecs from the embedding for a given protein sequence (see the sketch after this list)
48 |  1. Break the protein sequence into a list of k-mers
49 |  2. Convert each k-mer to a vector by taking the dot product of its one-hot vector with the embedding 
50 |  3. Sum the vectors of all k-mers into a single vector representation for the protein (length 100) 
51 | 4. Classify protein function with ProtVec features (results currently not working; refer to the R script)
52 |  1. Use ProtVecs as training features
53 |  2. Use Pfam as labels
54 |  3. For a given Pfam classification, perform binary classification with all of its positive samples and an equal number of randomly sampled negative samples
55 |  4. Train an SVM model 
56 | 
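57 | The following is a minimal sketch of step 3 above, assuming a dict-style lookup table. `embedding` here is a random stand-in; the real 100-dimensional vectors come from protVec_100d_3grams.csv, and the full implementation lives in protvec.py.
58 | 
59 | ```python
60 | import numpy as np
61 | 
62 | # Random stand-in for Asgari's lookup table (the real table maps each of
63 | # 9048 3-mers to a 100-dimensional vector).
64 | np.random.seed(0)
65 | embedding = {k: np.random.randn(100) for k in ("MAF", "AFS", "FSA", "SAE")}
66 | unk = np.zeros(100)  # fallback for 3-mers absent from the table
67 | 
68 | def simple_protvec(seq):
69 |     vec = np.zeros(100)
70 |     for frame in range(3):                       # three reading frames
71 |         for i in range(frame, len(seq) - 2, 3):  # non-overlapping 3-mers
72 |             vec += embedding.get(seq[i:i + 3], unk)
73 |     return vec                                   # the length-100 ProtVec
74 | 
75 | print(simple_protvec("MAFSAE").shape)  # (100,)
76 | ```
77 | 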
78 | ## Resources 
79 | 1. Intuition behind Word2Vec http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/
80 | 2. Tutorial followed for implementation of skip-gram negative sampling (includes code) http://adventuresinmachinelearning.com/word2vec-keras-tutorial/
81 | 3. Introduction to protein function prediction
82 | http://biofunctionprediction.org/cafa-targets/Introduction_to_protein_prediction.pdf
83 | 
84 | ## Author
85 | Mike Huang 
86 | huangjmike@gmail.com
--------------------------------------------------------------------------------
/pfamclassification.R:
--------------------------------------------------------------------------------
1 | library(caret)
2 | library(kernlab)
3 | library(e1071)
4 | library(Rtsne)
5 | library(ggplot2)
6 | setwd('/phagegenes/genevec/')
7 | 
8 | #Load SwissProt 2015 pfam and protvec
9 | pfam <- read.delim("family_classification_metadata.tab") #pfam annotations for SwissProt 2015 dataset
10 | protvec <- read.csv("family_classification_protVec.csv", header=FALSE) #Original protvecs from Asgari for SwissProt 2015 dataset
11 | 
12 | #Load Cherry's NCBI phage protein functions and protvecs
13 | pfam <- read.csv("cherryall.csv") #All of Cherry's NCBI data, 120949 phage sequences
14 | protvec <- read.csv("CherryAllProtVecs.csv") #Protvecs for Cherry's NCBI data, with Asgari's embedding
15 | 
16 | 
17 | ClassifyQuery <- function(strQuery, pfamColumn, negQuery="", visType='tsne', perplexity=10, savePlot = FALSE){
18 |   posind = grep(strQuery, pfamColumn, ignore.case=TRUE)
19 |   pos <- pfam[posind,]
20 |   pos <- cbind(pos,data.frame(label=rep(1,nrow(pos))))
21 |   View(pos)
22 |   neg <- pfam[-posind,]
23 |   #Negative samples from phage
24 |   if(nchar(negQuery)>0){
25 |     neg <- neg[grep(negQuery,neg$FamilyDescription, ignore.case=TRUE),] }
26 |   neg <- cbind(neg,data.frame(label=rep(0,nrow(neg))))
27 |   dat <- rbind(pos,neg[sample(nrow(neg),size=nrow(pos)),])
28 |   dat <- dat[sample(nrow(dat),size=nrow(dat)),]
29 |   dat$label <- as.factor(dat$label)
30 |   features <- protvec[rownames(dat),]
31 |   pca <- prcomp(features) #principal component analysis (prcomp is covariance-based by default, not correlation-based)
32 |   features <- cbind(features,pca$x[,1])
33 |   #dat<- dat[!is.na(features$X0),]
34 |   #features <- features[!is.na(features$X0),]
35 |   inTrain <- createDataPartition(y=dat$label, p=0.8, list=FALSE)
36 |   Xtrain <- features[inTrain,]
37 |   Ytrain <- dat$label[inTrain]
38 |   Xtest <- features[-inTrain,]
39 |   Ytest <- dat$label[-inTrain]
40 | 
41 |   if(visType == 'tsne'){
42 |     tsne <- Rtsne(pca$x[,1:5], dims = 2, perplexity=perplexity, verbose=TRUE, max_iter = 500, check_duplicates=FALSE)
43 |     #tsne <- Rtsne(features, dims = 2, perplexity=5, verbose=TRUE, max_iter = 500, check_duplicates=FALSE)
44 |     tsnedf <- data.frame(tsne$Y)
45 |     colnames(tsnedf) <- c('Xtsne','Ytsne')
46 |     dat <- cbind(dat, tsnedf)
47 |     chart = ggplot(dat,aes(Xtsne, Ytsne)) + geom_point(aes(color=label),alpha=0.5) + ggtitle(paste0("ProtVec tSNE: ", strQuery, " vs non-", strQuery, " phage proteins"))
48 |     print(chart) #explicit print is needed for the plot to display inside a function
49 |   }
50 |   if(visType == 'pca'){
51 |     dat<-cbind(dat, pca$x[,1:2])
52 |     chart = ggplot(dat,aes(PC1, PC2)) + geom_point(aes(color=label),alpha=0.5) + ggtitle(paste0("ProtVec PCA: ", strQuery, " vs non-", strQuery, " phage proteins"))
53 |     print(chart)
54 |   }
55 |   if(savePlot == TRUE){
56 |     ggsave(paste0(strQuery, " vs non-", strQuery, " Proteins Asgari Embedding tsne.png"),plot=chart, width=7, height=5, units="in")
57 |   }
58 |   svm <- train(Xtrain,Ytrain,method="svmRadial")
59 |   #Ypred <- predict(svm, pca$x[-inTrain,1:2]) #leftover from an earlier experiment; overwritten below
60 |   Ypred <- predict(svm, Xtest)
61 |   confusion <- confusionMatrix(Ypred,Ytest)
62 |   print(confusion)
63 | }
64 | 
65 | 
66 | #Plot visualization of top X pfams
67 | topPFamsVis <- function(pfamscol, topn, remove.names=""){
68 | 
69 | 
70 |   if(length(remove.names)>1){
71 |     for(i in 1:length(remove.names)){
72 |       pfam <- 
pfam[-grep(remove.names[i],pfamscol,ignore.case=TRUE),]
73 |     }
74 |   }
75 |   else if(nchar(remove.names)>0){
76 |     pfam <- pfam[-grep(remove.names,pfamscol,ignore.case=TRUE),]
77 |   }
78 | 
79 |   toppfams <- names(sort(table(pfamscol),decreasing=TRUE))[1:topn]
80 | 
81 |   posind = pfamscol %in% toppfams
82 |   features = protvec[posind,]
83 |   dat <- pfam[posind,]
84 |   pca <- prcomp(features)
85 |   tsne <- Rtsne(pca$x[,1:3], dims = 2, perplexity=50, verbose=TRUE, max_iter = 500, check_duplicates=FALSE)
86 |   tsnedf <- data.frame(tsne$Y)
87 |   colnames(tsnedf) <- c('Xtsne','Ytsne')
88 |   dat <- cbind(dat, tsnedf)
89 |   dat <- cbind(dat, pca$x[,1:2])
90 |   chart = ggplot(dat,aes(Xtsne, Ytsne)) + geom_point(aes(color=FamilyDescription),alpha=0.5) + ggtitle(paste0("ProtVec tSNE: top ", topn, " Pfams in SwissProt"))
91 |   print(chart) #explicit print is needed for the plot to display inside a function
92 |   ggsave(paste0("Top20PFamSwissProtTrainedwith20millionEpochsPhageData - tsne.png"),plot=chart, width=9, height=5, units="in")
93 |   chart = ggplot(dat,aes(PC1, PC2)) + geom_point(aes(color=FamilyDescription),alpha=0.5) + ggtitle(paste0("ProtVec PCA: top ", topn, " Pfams in SwissProt"))
94 |   print(chart)
95 | }
96 | 
--------------------------------------------------------------------------------
/protvec.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from keras.preprocessing.sequence import skipgrams, pad_sequences, make_sampling_table
4 | from keras.preprocessing.text import hashing_trick
5 | from keras.layers import Embedding, Input, Reshape, Dense, merge
6 | from keras.models import Sequential, Model
7 | from sklearn.manifold import TSNE
8 | from joblib import Parallel, delayed
9 | import multiprocessing
10 | import csv
11 | 
12 | 
13 | #Load Ehsan Asgari's embeddings
14 | #Source: http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0141287
15 | #Embedding: https://github.com/ehsanasgari/Deep-Proteomics
16 | ehsanEmbed = []
17 | with open("protVec_100d_3grams.csv") as tsvfile:
18 |     tsvreader = csv.reader(tsvfile, delimiter="\t")
19 |     for line in tsvreader:
20 |         ehsanEmbed.append(line[0].split('\t'))  #each row is read as one tab-joined field; splitting recovers the 3-mer and its 100 values
21 | threemers = [vec[0] for vec in ehsanEmbed]
22 | embeddingMat = [[float(n) for n in vec[1:]] for vec in ehsanEmbed]
23 | threemersidx = {} #generate word to index translation dictionary. Use for kmersdict function arguments.
24 | for i, kmer in enumerate(threemers):
25 |     threemersidx[kmer] = i
26 | #Set parameters
27 | vocabsize = len(threemersidx)
28 | window_size = 25
29 | num_cores = multiprocessing.cpu_count() #For parallel computing
30 | 
31 | 
32 | # Convert sequences to three lists of non-overlapping 3-mers, one per reading frame, e.g. kmerlists("MAFSAE") -> [['MAF', 'SAE'], ['AFS'], ['FSA']]
33 | def kmerlists(seq):
34 |     kmer0 = []
35 |     kmer1 = []
36 |     kmer2 = []
37 |     for i in range(0, len(seq) - 2, 3):
38 |         if len(seq[i:i + 3]) == 3:
39 |             kmer0.append(seq[i:i + 3])
40 |         i += 1  #shift to the next reading frame within this window; the for loop resets i on the next iteration
41 |         if len(seq[i:i + 3]) == 3:
42 |             kmer1.append(seq[i:i + 3])
43 |         i += 1
44 |         if len(seq[i:i + 3]) == 3:
45 |             kmer2.append(seq[i:i + 3])
46 |     return [kmer0, kmer1, kmer2]
47 | 
48 | 
49 | # Same as kmerlists function but outputs an index number assigned to each kmer. 
Index number is from Asgari's embedding
50 | def kmersindex(seqs, kmersdict):
51 |     kmers = []
52 |     for i in range(len(seqs)):
53 |         kmers.append(kmerlists(seqs[i]))
54 |     kmers = np.array(kmers).flatten().flatten(order='F')
55 |     kmersindex = []
56 |     for seq in kmers:
57 |         temp = []
58 |         for kmer in seq:
59 |             try:
60 |                 temp.append(kmersdict[kmer])
61 |             except KeyError:  #3-mer missing from the table; fall back to the unk entry
62 |                 temp.append(kmersdict[''])
63 |         kmersindex.append(temp)
64 |     return kmersindex
65 | 
66 | 
67 | sampling_table = make_sampling_table(vocabsize)
68 | 
69 | 
70 | def generateskipgramshelper(kmersindicies):
71 |     couples, labels = skipgrams(kmersindicies, vocabsize, window_size=window_size, sampling_table=sampling_table)
72 |     if len(couples) == 0:  #sampling_table may reject every kmer of a short sequence; retry once
73 |         couples, labels = skipgrams(kmersindicies, vocabsize, window_size=window_size, sampling_table=sampling_table)
74 |     if len(couples) == 0:  #still empty: return empty results so the caller can skip this sequence
75 |         return [], [], []
76 |     word_target, word_context = zip(*couples)
77 |     return word_target, word_context, labels
78 | 
79 | 
80 | 
81 | def generateskipgrams(seqs, kmersdict=threemersidx):
82 |     #Generate skipgrams for training keras embedding model with negative sampling technique
83 |     #ARGUMENTS:
84 |     #  seqs: list, list of amino acid sequences
85 |     #  kmersdict: dict to look up index of kmer on embedding, default: Asgari's embedding index
86 |     kmersidx = kmersindex(seqs, kmersdict)
87 |     return Parallel(n_jobs=num_cores)(delayed(generateskipgramshelper)(kmers) for kmers in kmersidx)
88 | 
89 | def protvec(kmersdict, seq, embeddingweights):
90 |     #Convert seq to three lists of kmers
91 |     kmerlist = kmerlists(seq)
92 |     kmerlist = [j for i in kmerlist for j in i]
93 |     #Sum the vector representations of all kmers into one length-100 ProtVec
94 |     kmersvec = [0]*100
95 |     for kmer in kmerlist:
96 |         try:
97 |             kmersvec = np.add(kmersvec,embeddingweights[kmersdict[kmer]])
98 |         except KeyError:  #3-mer missing from the table; add the unk vector instead
99 |             kmersvec = np.add(kmersvec,embeddingweights[kmersdict['']])
100 |     return kmersvec
101 | 
102 | def formatprotvecs(protvecs):  #reshape a list of protvecs into an (n, 100) feature matrix
103 |     protfeatures = []
104 |     for i in range(100):
105 |         protfeatures.append([vec[i] for vec in protvecs])
106 |     protfeatures = np.array(protfeatures).T  #true transpose of the (100, n) matrix; reshape() would scramble the values
107 |     return protfeatures
108 | 
109 | def formatprotvecsnormalized(protvecs):  #same as formatprotvecs, with per-feature normalization
110 |     protfeatures = []
111 |     for i in range(100):
112 |         tempvec = [vec[i] for vec in protvecs]
113 |         mean = np.mean(tempvec)
114 |         var = np.var(tempvec)
115 |         protfeatures.append([(vec[i]-mean)/var for vec in protvecs])  #note: divides by the variance; standard z-scoring would use the standard deviation
116 |     protfeatures = np.array(protfeatures).T  #true transpose of the (100, n) matrix; reshape() would scramble the values
117 |     return protfeatures
118 | 
119 | def sequences2protvecsCSV(filename, seqs, kmersdict=threemersidx, embeddingweights=embeddingMat):
120 |     #Convert a list of sequences to protvecs and save protvecs to a csv file
121 |     #ARGUMENTS:
122 |     #filename: string, name of csv file to save to, i.e. "sampleprotvecs.csv"
123 |     #seqs: list, list of amino acid sequences
124 |     #kmersdict: dict to look up index of kmer on embedding, default: Asgari's embedding index
125 |     #embeddingweights: 2D list or np.array, embedding vectors, default: Asgari's embedding vectors
126 | 
127 |     swissprotvecs = Parallel(n_jobs=num_cores)(delayed(protvec)(kmersdict, seq, embeddingweights) for seq in seqs)
128 |     swissprotvecsdf = pd.DataFrame(formatprotvecs(swissprotvecs))
129 |     swissprotvecsdf.to_csv(filename, index=False)
130 |     return swissprotvecsdf
131 | 
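132 | 
133 | #Minimal usage sketch (assumes protVec_100d_3grams.csv is on hand, as required above).
134 | #The sequence below is an arbitrary example taken from the notebook, not a specific target protein.
135 | if __name__ == "__main__":
136 |     demoseq = "MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEK"
137 |     vec = protvec(threemersidx, demoseq, embeddingMat)  #sum of the 3-mer vectors in all three frames
138 |     print("ProtVec length:", len(vec))  #100
139 |     df = sequences2protvecsCSV("demoprotvecs.csv", [demoseq])
140 |     print(df.shape)  #(1, 100)
141 | 
--------------------------------------------------------------------------------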