├── README.md ├── cupy_chembl_example.ipynb ├── get_multitask_data.ipynb ├── name2chembl.ipynb └── train_chembl_multitask.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # ChEMBL blogpost notebooks 2 | 3 | - Multi-task neural network on ChEMBL with PyTorch 1.0 and RDKit: https://chembl.blogspot.com/2019/05/multi-task-neural-network-on-chembl.html. [dataset](http://ftp.ebi.ac.uk/pub/databases/chembl/blog/pytorch_mtl/mt_data.h5) 4 | - CuPy example for CUDA based similarity search in Python: http://chembl.blogspot.com/2019/07/cupy-example-for-cuda-based-similarity.html 5 | -------------------------------------------------------------------------------- /cupy_chembl_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "cupy_chembl_example.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "k_ZTVo0BdhJh", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "# CuPy CUDA ChEMBL similarity search example\n", 26 | "\n", 27 | "If you are using NumPy >= 1.16, PyTables > 3.4.4 is required. Install PyTables 3.5.1 if you're running this notebook in Colab.
All other dependencies are already installed in Colab's default environment.\n", 28 | "\n", 29 | "Remember to restart the runtime after upgrading PyTables with pip!!!\n", 30 | "\n", 31 | "\n", 32 | "You will also need to download the ChEMBL25 FPSim2 database file.\n", 33 | "\n", 34 | "Did you know, BTW, that we recently updated [FPSim2](https://github.com/chembl/FPSim2), replacing its Cython core with C++ bound via [PyBind11](https://github.com/pybind/pybind11)? It now performs better and is also compatible with Windows and Python 3.7." 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": { 40 | "id": "Iye5Wa4jf9MQ", 41 | "colab_type": "text" 42 | }, 43 | "source": [ 44 | "Preflight config: run this cell only the first time." 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "metadata": { 50 | "id": "v5AcWu7RfwPZ", 51 | "colab_type": "code", 52 | "colab": {} 53 | }, 54 | "source": [ 55 | "# update PyTables; you'll need to restart the runtime in Colab after the install!!!\n", 56 | "# !pip install tables==3.5.1\n", 57 | "\n", 58 | "# download the ChEMBL25 FPSim2 FP db\n", 59 | "# !wget \"http://ftp.ebi.ac.uk/pub/databases/chembl/fpsim2/chembl_25.h5\"" 60 | ], 61 | "execution_count": 0, 62 | "outputs": [] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "metadata": { 67 | "id": "rB3HLtMr2LsA", 68 | "colab_type": "code", 69 | "colab": {} 70 | }, 71 | "source": [ 72 | "import cupy as cp\n", 73 | "import tables as tb\n", 74 | "import numpy as np" 75 | ], 76 | "execution_count": 0, 77 | "outputs": [] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": { 82 | "id": "THvf0361Bl4v", 83 | "colab_type": "text" 84 | }, 85 | "source": [ 86 | "# Load FPSim2 fingerprint database\n", 87 | "\n", 88 | "\n", 89 | "The fps variable contains fingerprints (2048-bit hashed Morgan, radius 2) for all molecules in the ChEMBL25 database.
Each row represents a molecule, and its structure is as follows:\n", 90 | "\n", 91 | "\n", 92 | "```\n", 93 | "array([84419, 0,140737488355328, 17592186044416, 1024, 1099549376512, 0, 0, 0, 0, 0, 9007199254741248, 0,16777216, 0, 2305843009213693952, 0, 1073741824, 0, 0, 2199023255552, 0, 0, 0, 0, 0, 0, 32, 0, 0, 34359738372, 0, 0, 15], dtype=uint64)\n", 94 | "```\n", 95 | "The first element of the array is the ChEMBL molregno, and the last one is the number of ON bits in its fingerprint (popcount). The 32 values in between hold the 2048 fingerprint bits packed into 64-bit unsigned integers (32 × 64 = 2048).\n", 96 | "\n", 97 | "Molecules in the FP db are sorted by popcount, which is required to apply the sublinear-time bounds described in this classic paper: [10.1021/ci600358f](https://doi.org/10.1021/ci600358f)\n" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "metadata": { 103 | "id": "wyJdmMupFXl2", 104 | "colab_type": "code", 105 | "colab": {} 106 | }, 107 | "source": [ 108 | "# using the same FPSim2 ChEMBL FP database :)\n", 109 | "fp_filename = \"chembl_25.h5\"\n", 110 | "with tb.open_file(fp_filename, mode=\"r\") as fp_file:\n", 111 | " fps = fp_file.root.fps[:]\n", 112 | " num_fields = len(fps[0])\n", 113 | " fps = fps.view(\"u8\")\n", 114 | " fps = fps.reshape(int(fps.size / num_fields), num_fields)\n", 115 | " # we'll use popcnt_ranges for the bounds optimisation; it stores\n", 116 | " # the ranges for each popcount in the database\n", 117 | " popcnt_ranges = fp_file.root.config[3]" 118 | ], 119 | "execution_count": 0, 120 | "outputs": [] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "metadata": { 125 | "id": "nnjGuUnGFs41", 126 | "colab_type": "code", 127 | "colab": {} 128 | }, 129 | "source": [ 130 | "# aspirin, ChEMBL molregno 1280\n", 131 | "query_molregno = 1280" 132 | ], 133 | "execution_count": 0, 134 | "outputs": [] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": { 139 | "id": "GTXWEq77Evb1", 140 | 
"colab_type": "text" 141 | }, 142 | "source": [ 143 | "# 
Let's try the ElementWise kernel\n", 144 | "\n", 145 | "CuPy's [ElementWise](https://docs-cupy.chainer.org/en/stable/tutorial/kernel.html#basics-of-elementwise-kernels) kernel applies the same operation to each row. This fits our problem, because we want to compute the similarity between a given query molecule and every molecule in the FP db file.\n", 146 | "\n", 147 | "You probably noticed that we are using the **i** variable, which is not declared in the code. It is a special variable that holds the index of the element currently being processed.\n", 148 | "\n", 149 | "__popcll is a CUDA intrinsic, similar to the popcount instructions found in CPUs, that efficiently counts the number of set bits in a 64-bit integer." 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "metadata": { 155 | "id": "5qpH8XKIEuO-", 156 | "colab_type": "code", 157 | "colab": {} 158 | }, 159 | "source": [ 160 | "taniEW = cp.ElementwiseKernel(\n", 161 | " in_params=\"raw T db, raw U query, uint64 in_width, float32 threshold\",\n", 162 | " out_params=\"raw V out\",\n", 163 | " operation=r\"\"\"\n", 164 | " int comm_sum = 0;\n", 165 | " for(int j = 1; j < in_width - 1; ++j){\n", 166 | " int pos = i * in_width + j;\n", 167 | " comm_sum += __popcll(db[pos] & query[j]);\n", 168 | " }\n", 169 | " float coeff = 0.0;\n", 170 | " coeff = query[in_width - 1] + db[i * in_width + in_width - 1] - comm_sum;\n", 171 | " if (coeff != 0.0)\n", 172 | " coeff = comm_sum / coeff;\n", 173 | " out[i] = coeff >= threshold ?
coeff : 0.0;\n", 174 | " \"\"\",\n", 175 | " name='taniEW',\n", 176 | " options=('-std=c++14',),\n", 177 | " reduce_dims=False\n", 178 | ")" 179 | ], 180 | "execution_count": 0, 181 | "outputs": [] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "metadata": { 186 | "id": "Nh1lehEWFPuV", 187 | "colab_type": "code", 188 | "colab": {} 189 | }, 190 | "source": [ 191 | "# get the query molecule from the FP database\n", 192 | "query = cp.asarray(fps[(fps[:,0] == query_molregno)][0])\n", 193 | "# copy the database to GPU\n", 194 | "database = cp.asarray(fps)\n", 195 | "\n", 196 | "def cupy_elementwise_search(db, query, threshold):\n", 197 | " # init the results array, one float per row of db\n", 198 | " sim = cp.zeros(db.shape[0], dtype=\"f4\")\n", 199 | " \n", 200 | " # set the threshold variable and run the search\n", 201 | " threshold = cp.asarray(threshold, dtype=\"f4\")\n", 202 | " taniEW(db, query, db.shape[1], threshold, sim, size=db.shape[0])\n", 203 | "\n", 204 | " mask = sim.nonzero()[0]\n", 205 | " np_sim = cp.asnumpy(sim[mask])\n", 206 | " np_ids = cp.asnumpy(db[:,0][mask])\n", 207 | " \n", 208 | " dtype = np.dtype([(\"mol_id\", \"u4\"), (\"coeff\", \"f4\")])\n", 209 | " results = np.empty(len(np_ids), dtype=dtype)\n", 210 | " results[\"mol_id\"] = np_ids\n", 211 | " results[\"coeff\"] = np_sim\n", 212 | " results[::-1].sort(order='coeff')\n", 213 | " return results" 214 | ], 215 | "execution_count": 0, 216 | "outputs": [] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "metadata": { 221 | "id": "gFP4HKROF5l_", 222 | "colab_type": "code", 223 | "outputId": "74d4ddb5-6ca5-461b-d311-5134110b7405", 224 | "colab": { 225 | "base_uri": "https://localhost:8080/", 226 | "height": 69 227 | } 228 | }, 229 | "source": [ 230 | "results = cupy_elementwise_search(database, query, 0.7)\n", 231 | "results" 232 | ], 233 | "execution_count": 7, 234 | "outputs": [ 235 | { 236 | "output_type": "execute_result", 237 | "data": { 238 | 
"text/plain": [ 239 | "array([( 1280, 1. 
), (2096455, 0.8888889 ),\n", 240 | " ( 271022, 0.85714287), ( 875057, 0.7 )],\n", 241 | " dtype=[('mol_id', '= threshold:\n", 310 | " range_to_screen.append(c_range)\n", 311 | " if range_to_screen:\n", 312 | " range_to_screen = (range_to_screen[0][0], \n", 313 | " range_to_screen[len(range_to_screen) - 1][1])\n", 314 | " return range_to_screen" 315 | ], 316 | "execution_count": 0, 317 | "outputs": [] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "metadata": { 322 | "id": "0ARyUwHQKaUH", 323 | "colab_type": "code", 324 | "colab": {} 325 | }, 326 | "source": [ 327 | "def cupy_elementwise_search_bounds(db, query, popcnt_ranges, threshold):\n", 328 | " \n", 329 | " # get the range of molecules to screen\n", 330 | " rk = get_tanimoto_bounds(int(query[-1]), popcnt_ranges, threshold)\n", 331 | " \n", 332 | " # set the threshold variable\n", 333 | " threshold = cp.asarray(threshold, dtype=\"f4\")\n", 334 | "\n", 335 | " # get the subset of molecule ids\n", 336 | " ids = db[:,0][slice(*rk)]\n", 337 | " subset_size = int(rk[1]-rk[0])\n", 338 | "\n", 339 | " # init the results variable\n", 340 | " sim = cp.zeros(subset_size, dtype=cp.float32)\n", 341 | "\n", 342 | " # run the search. 
It will compile the kernel only the first time it runs\n", 343 | " taniEW(db[slice(*rk)], query, db.shape[1], threshold, sim, size=subset_size)\n", 344 | "\n", 345 | " # get all non 0 values and ids\n", 346 | " mask = sim.nonzero()[0]\n", 347 | " np_sim = cp.asnumpy(sim[mask])\n", 348 | " np_ids = cp.asnumpy(ids[mask])\n", 349 | "\n", 350 | " # create results numpy array\n", 351 | " dtype = np.dtype([(\"mol_id\", \"u4\"), (\"coeff\", \"f4\")])\n", 352 | " results = np.empty(len(np_ids), dtype=dtype)\n", 353 | " results[\"mol_id\"] = np_ids\n", 354 | " results[\"coeff\"] = np_sim\n", 355 | " results[::-1].sort(order='coeff')\n", 356 | " return results" 357 | ], 358 | "execution_count": 0, 359 | "outputs": [] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "metadata": { 364 | "id": "bS5tt0voW-Lt", 365 | "colab_type": "code", 366 | "outputId": "e8511868-5502-4b74-f2ca-533631856a08", 367 | "colab": { 368 | "base_uri": "https://localhost:8080/", 369 | "height": 69 370 | } 371 | }, 372 | "source": [ 373 | "results = cupy_elementwise_search_bounds(database, query, popcnt_ranges, 0.7)\n", 374 | "results" 375 | ], 376 | "execution_count": 11, 377 | "outputs": [ 378 | { 379 | "output_type": "execute_result", 380 | "data": { 381 | "text/plain": [ 382 | "array([( 1280, 1. ), (2096455, 0.8888889 ),\n", 383 | " ( 271022, 0.85714287), ( 875057, 0.7 )],\n", 384 | " dtype=[('mol_id', '= *threshold ? 
coeff : 0.0;\n", 488 | " }}\n", 489 | " }}\n", 490 | " \"\"\".format(block=database.shape[1]),\n", 491 | " name=\"taniRAW\",\n", 492 | " options=('-std=c++14',),\n", 493 | ")" 494 | ], 495 | "execution_count": 0, 496 | "outputs": [] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "metadata": { 501 | "id": "fINESKQAZWo8", 502 | "colab_type": "code", 503 | "colab": {} 504 | }, 505 | "source": [ 506 | "def cupy_sim_search_bounds(db, db_popcnts, db_ids, query, popcnt_ranges, threshold):\n", 507 | " \n", 508 | " c_query = cp.asarray(query[:,1:-1])\n", 509 | " qpopcnt = cp.asarray(query[:,-1])\n", 510 | "\n", 511 | " # get the range of the molecule subset to screen\n", 512 | " rk = get_tanimoto_bounds(int(query[:,-1]), popcnt_ranges, threshold)\n", 513 | " \n", 514 | " threshold = cp.asarray(threshold, dtype=\"f4\")\n", 515 | "\n", 516 | " # get the subset of molecule ids\n", 517 | " subset_size = int(rk[1]-rk[0])\n", 518 | " ids2 = db_ids[slice(*rk)]\n", 519 | "\n", 520 | " # init results array\n", 521 | " sim = cp.zeros(subset_size, dtype=cp.float32)\n", 522 | "\n", 523 | " # run the search, it compiles the kernel only the first time it runs\n", 524 | " # grid, block and arguments\n", 525 | " taniRAW((subset_size,), \n", 526 | " (db.shape[1],), \n", 527 | " (c_query, qpopcnt, db[slice(*rk)], db_popcnts[slice(*rk)], threshold, sim))\n", 528 | "\n", 529 | " # get all non 0 values and ids\n", 530 | " mask = sim.nonzero()[0]\n", 531 | " np_sim = cp.asnumpy(sim[mask])\n", 532 | " np_ids = cp.asnumpy(ids2[mask])\n", 533 | "\n", 534 | " # create results numpy array\n", 535 | " dtype = np.dtype([(\"mol_id\", \"u4\"), (\"coeff\", \"f4\")])\n", 536 | " results = np.empty(len(np_ids), dtype=dtype)\n", 537 | " results[\"mol_id\"] = np_ids\n", 538 | " results[\"coeff\"] = np_sim\n", 539 | " results[::-1].sort(order='coeff')\n", 540 | " return results" 541 | ], 542 | "execution_count": 0, 543 | "outputs": [] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "metadata": { 548 | 
"id": "Jkx39KZKZW4a", 549 | "colab_type": "code", 550 | "outputId": "5b062cd2-60b9-4c02-d9cf-6b6cb193be70", 551 | "colab": { 552 | "base_uri": "https://localhost:8080/", 553 | "height": 69 554 | } 555 | }, 556 | "source": [ 557 | "results = cupy_sim_search_bounds(database, popcnts, ids, query, popcnt_ranges, 0.7)\n", 558 | "results" 559 | ], 560 | "execution_count": 15, 561 | "outputs": [ 562 | { 563 | "output_type": "execute_result", 564 | "data": { 565 | "text/plain": [ 566 | "array([( 1280, 1. ), (2096455, 0.8888889 ),\n", 567 | " ( 271022, 0.85714287), ( 875057, 0.7 )],\n", 568 | " dtype=[('mol_id', '= 8 AND\n", 80 | " target_dictionary.target_type = 'SINGLE PROTEIN'\"\"\"\n", 81 | "\n", 82 | "with engine.begin() as conn:\n", 83 | " res = conn.execute(text(qtext))\n", 84 | " df = pd.DataFrame(res.fetchall())\n", 85 | "\n", 86 | "df.columns = res.keys()\n", 87 | "df = df.where((pd.notnull(df)), None)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "# Drop duplicate activities keeping the activity with lower concentration for each molecule-target pair" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 3, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "df = df.sort_values(by=['standard_value', 'molregno', 'tid'], ascending=True)\n", 104 | "df = df.drop_duplicates(subset=['molregno', 'tid'], keep='first')\n", 105 | "\n", 106 | "# save to csv\n", 107 | "df.to_csv('chembl_activity_data.csv', index=False)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "# Set to active/inactive by threshold\n", 115 | "- Depending on family type from IDG: https://druggablegenome.net/ProteinFam\n", 116 | "\n", 117 | " - Kinases: <= 30nM\n", 118 | " - GPCRs: <= 100nM\n", 119 | " - Nuclear Receptors: <= 100nM\n", 120 | " - Ion Channels: <= 10μM\n", 121 | " - Non-IDG Family Targets: <= 1μM\n" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 
| "execution_count": 4, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "def set_active(row):\n", 131 | " active = 0\n", 132 | " if row['standard_value'] <= 1000:\n", 133 | " active = 1\n", 134 | " if row['l1'] == 'Ion channel':\n", 135 | " if row['standard_value'] <= 10000:\n", 136 | " active = 1\n", 137 | " if row['l2'] == 'Kinase':\n", 138 | " if row['standard_value'] > 30:\n", 139 | " active = 0\n", 140 | " if row['l2'] == 'Nuclear receptor':\n", 141 | " if row['standard_value'] > 100:\n", 142 | " active = 0\n", 143 | " if row['l3'] and 'GPCR' in row['l3']:\n", 144 | " if row['standard_value'] > 100:\n", 145 | " active = 0\n", 146 | " return active\n", 147 | "\n", 148 | "df['active'] = df.apply(set_active, axis=1)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": {}, 154 | "source": [ 155 | "# Filter target data\n", 156 | "\n", 157 | "- Keep targets mentioned in at least two different documents\n", 158 | "- Keep targets with at least 100 active and 100 inactive molecules. The threshold is set to 100 to get a 'small' dataset that trains faster in this example."
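The three target filters listed above can be sketched as one pandas helper. This is a minimal illustration rather than the notebook's own code. It assumes one row per molecule-target pair, with `target_chembl_id`, `molregno`, `doc_id` and a 0/1 `active` column, matching the DataFrame used in the next cell. The name `filter_targets` and its parameters are hypothetical.

```python
import pandas as pd


def filter_targets(df, min_actives=100, min_inactives=100, min_docs=2):
    """Hypothetical helper: return the target_chembl_id values that have
    enough distinct actives, enough distinct inactives, and are mentioned
    in enough distinct documents."""
    key = "target_chembl_id"
    # distinct active / inactive molecules per target
    n_act = df[df["active"] == 1].groupby(key)["molregno"].nunique()
    n_inact = df[df["active"] == 0].groupby(key)["molregno"].nunique()
    # distinct (non-null) documents mentioning each target
    n_docs = df.groupby(key)["doc_id"].nunique()
    return (
        set(n_act[n_act >= min_actives].index)
        & set(n_inact[n_inact >= min_inactives].index)
        & set(n_docs[n_docs >= min_docs].index)
    )


# toy data: T1 passes, T2 has a single active and a single doc,
# T3 has no inactives
toy = pd.DataFrame({
    "target_chembl_id": ["T1", "T1", "T1", "T1", "T2", "T2", "T3"],
    "molregno": [1, 2, 3, 4, 1, 2, 5],
    "doc_id": [10, 10, 11, 11, 12, 12, 13],
    "active": [1, 1, 0, 0, 1, 0, 1],
})
print(filter_targets(toy, min_actives=2, min_inactives=2, min_docs=2))  # {'T1'}
```

The notebook has already dropped duplicate molecule-target pairs, so counting rows per target (which is what `agg('count')` does in the next cell) gives the same numbers as `nunique` on `molregno`. Using `nunique` simply makes the "distinct molecules" intent explicit.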
159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 5, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "# get targets with at least 100 different active molecules\n", 168 | "acts = df[df['active'] == 1].groupby(['target_chembl_id']).agg('count')\n", 169 | "acts = acts[acts['molregno'] >= 100].reset_index()['target_chembl_id']\n", 170 | "\n", 171 | "# get targets with at least 100 different inactive molecules\n", 172 | "inacts = df[df['active'] == 0].groupby(['target_chembl_id']).agg('count')\n", 173 | "inacts = inacts[inacts['molregno'] >= 100].reset_index()['target_chembl_id']\n", 174 | "\n", 175 | "# get targets mentioned in at least two docs\n", 176 | "docs = df.drop_duplicates(subset=['doc_id', 'target_chembl_id'])\n", 177 | "docs = docs.groupby(['target_chembl_id']).agg('count')\n", 178 | "docs = docs[docs['doc_id'] >= 2.0].reset_index()['target_chembl_id']\n" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 6, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "name": "stdout", 188 | "output_type": "stream", 189 | "text": [ 190 | "Number of unique targets: 560\n", 191 | " Ion channel: 5\n", 192 | " Kinase: 96\n", 193 | " Nuclear receptor: 21\n", 194 | " GPCR: 180\n", 195 | " Others: 258\n" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "t_keep = set(acts).intersection(set(inacts)).intersection(set(docs))\n", 201 | "\n", 202 | "# get data for filtered targets\n", 203 | "activities = df[df['target_chembl_id'].isin(t_keep)]\n", 204 | "\n", 205 | "ion = pd.unique(activities[activities['l1'] == 'Ion channel']['tid']).shape[0]\n", 206 | "kin = pd.unique(activities[activities['l2'] == 'Kinase']['tid']).shape[0]\n", 207 | "nuc = pd.unique(activities[activities['l2'] == 'Nuclear receptor']['tid']).shape[0]\n", 208 | "gpcr = pd.unique(activities[activities['l3'].str.contains('GPCR', na=False)]['tid']).shape[0]\n", 209 | "\n", 210 | "print('Number of unique targets: ', len(t_keep))\n", 211 | 
"print(' Ion channel: ', ion)\n", 212 | "print(' Kinase: ', kin)\n", 213 | "print(' Nuclear receptor: ', nuc)\n", 214 | "print(' GPCR: ', gpcr)\n", 215 | "print(' Others: ', len(t_keep) - ion - kin - nuc - gpcr)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 7, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "# save it to a file\n", 225 | "activities.to_csv('chembl_activity_data_filtered.csv', index=False)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "# Prepare the label matrix for the multi-task deep neural network\n", 233 | "\n", 234 | " - known active = 1\n", 235 | " - known inactive = 0\n", 236 | " - unknown activity = -1, so these entries can be easily filtered out and won't be taken into account when calculating the loss during model training.\n", 237 | " \n", 238 | "The matrix is extremely sparse, so using sparse matrices (COO/CSR/CSC) should be considered. A couple of issues make this trickier than it should be, so we'll keep the example without them.\n", 239 | "\n", 240 | "- https://github.com/pytorch/pytorch/issues/20248\n", 241 | "- https://github.com/scipy/scipy/issues/7531\n" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 8, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "def gen_dict(group):\n", 251 | " return {tid: act for tid, act in zip(group['target_chembl_id'], group['active'])}\n", 252 | "\n", 253 | "group = activities.groupby('chembl_id')\n", 254 | "temp = pd.DataFrame(group.apply(gen_dict))\n", 255 | "mt_df = pd.DataFrame(temp[0].tolist())\n", 256 | "mt_df['chembl_id'] = temp.index\n", 257 | "mt_df = mt_df.where((pd.notnull(mt_df)), -1)" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 9, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "structs = activities[['chembl_id', 
'canonical_smiles']].drop_duplicates(subset='chembl_id')\n", 267 | "\n", 268 | "# drop molecules that RDKit fails to parse or sanitize\n", 269 | "structs['romol'] = structs.apply(lambda row: Chem.MolFromSmiles(row['canonical_smiles']), axis=1)\n", 270 | "structs = structs.dropna()\n", 271 | "del structs['romol']\n", 272 | "\n", 273 | "# add the structures to the final df\n", 274 | "mt_df = pd.merge(structs, mt_df, how='inner', on='chembl_id')" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 10, 280 | "metadata": {}, 281 | "outputs": [], 282 | "source": [ 283 | "# save to csv\n", 284 | "mt_df.to_csv('chembl_multi_task_data.csv', index=False)" 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "metadata": {}, 290 | "source": [ 291 | "# Calc fingerprints and save data to a PyTables H5 file" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 11, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "FP_SIZE = 1024\n", 301 | "RADIUS = 2\n", 302 | "\n", 303 | "def calc_fp(smiles, fp_size, radius):\n", 304 | " \"\"\"\n", 305 | " Calculates a Morgan fingerprint and returns it as a numpy array.\n", 306 | " \"\"\"\n", 307 | " mol = Chem.MolFromSmiles(smiles, sanitize=False)\n", 308 | " mol.UpdatePropertyCache(False)\n", 309 | " Chem.GetSSSR(mol)\n", 310 | " fp = rdMolDescriptors.GetMorganFingerprintAsBitVect(mol, radius, nBits=fp_size)\n", 311 | " a = np.zeros((0,), dtype=np.float32)\n", 312 | " Chem.DataStructs.ConvertToNumpyArray(fp, a)\n", 313 | " return a\n", 314 | "\n", 315 | "# calc fps\n", 316 | "descs = [calc_fp(smi, FP_SIZE, RADIUS) for smi in mt_df['canonical_smiles'].values]\n", 317 | "descs = np.asarray(descs, dtype=np.float32)\n", 318 | "\n", 319 | "# put all training data in a pytables file\n", 320 | "with tb.open_file('mt_data.h5', mode='w') as t_file:\n", 321 | "\n", 322 | " # set compression filter.
It will make the file much smaller\n", 323 | " filters = tb.Filters(complib='blosc', complevel=5)\n", 324 | "\n", 325 | " # save chembl_ids\n", 326 | " tatom = ObjectAtom()\n", 327 | " cids = t_file.create_vlarray(t_file.root, 'chembl_ids', atom=tatom)\n", 328 | " for cid in mt_df['chembl_id'].values:\n", 329 | " cids.append(cid)\n", 330 | "\n", 331 | " # save fps\n", 332 | " fatom = tb.Atom.from_dtype(descs.dtype)\n", 333 | " fps = t_file.create_carray(t_file.root, 'fps', fatom, descs.shape, filters=filters)\n", 334 | " fps[:] = descs\n", 335 | "\n", 336 | " del mt_df['chembl_id']\n", 337 | " del mt_df['canonical_smiles']\n", 338 | "\n", 339 | " # save target chembl ids\n", 340 | " tcids = t_file.create_vlarray(t_file.root, 'target_chembl_ids', atom=tatom)\n", 341 | " for tcid in mt_df.columns.values:\n", 342 | " tcids.append(tcid)\n", 343 | "\n", 344 | " # save labels\n", 345 | " labs = t_file.create_carray(t_file.root, 'labels', fatom, mt_df.values.shape, filters=filters)\n", 346 | " labs[:] = mt_df.values\n", 347 | " \n", 348 | " # save task weights\n", 349 | " # each task loss will be weighted inversely proportional to its number of data points\n", 350 | " weights = []\n", 351 | " for col in mt_df.columns.values:\n", 352 | " c = mt_df[mt_df[col] >= 0.0].shape[0]\n", 353 | " weights.append(1 / c)\n", 354 | " weights = np.array(weights)\n", 355 | " ws = t_file.create_carray(t_file.root, 'weights', fatom, weights.shape)\n", 356 | " ws[:] = weights" 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "metadata": {}, 362 | "source": [ 363 | "# Open H5 file and show the shape of all collections" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 12, 369 | "metadata": {}, 370 | "outputs": [ 371 | { 372 | "name": "stdout", 373 | "output_type": "stream", 374 | "text": [ 375 | "(711591,)\n", 376 | "(560,)\n", 377 | "(711591, 1024)\n", 378 | "(711591, 560)\n", 379 | "(560,)\n" 380 | ] 381 | } 382 | ], 383 | "source": [ 384 | "with 
tb.open_file('mt_data.h5', mode='r') as t_file:\n", 385 | " print(t_file.root.chembl_ids.shape)\n", 386 | " print(t_file.root.target_chembl_ids.shape)\n", 387 | " print(t_file.root.fps.shape)\n", 388 | " print(t_file.root.labels.shape)\n", 389 | " print(t_file.root.weights.shape)\n", 390 | " \n", 391 | " # save targets to a json file\n", 392 | " with open('targets.json', 'w') as f:\n", 393 | " json.dump(t_file.root.target_chembl_ids[:], f)" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "metadata": {}, 400 | "outputs": [], 401 | "source": [] 402 | } 403 | ], 404 | "metadata": { 405 | "kernelspec": { 406 | "display_name": "Python 3", 407 | "language": "python", 408 | "name": "python3" 409 | }, 410 | "language_info": { 411 | "codemirror_mode": { 412 | "name": "ipython", 413 | "version": 3 414 | }, 415 | "file_extension": ".py", 416 | "mimetype": "text/x-python", 417 | "name": "python", 418 | "nbconvert_exporter": "python", 419 | "pygments_lexer": "ipython3", 420 | "version": "3.6.8" 421 | } 422 | }, 423 | "nbformat": 4, 424 | "nbformat_minor": 2 425 | } 426 | -------------------------------------------------------------------------------- /name2chembl.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Drug name to ChEMBL function\n", 8 | "\n", 9 | "Single function using chembl_webresource_client: https://pypi.org/project/chembl-webresource-client/\n", 10 | "\n", 11 | "Tries 3 different things (in order):\n", 12 | "\n", 13 | "1. Case insensitive match against molecule_dictionary.pref_name\n", 14 | "2. Case insensitive match against molecule_synonyms.synonyms\n", 15 | "3. 
Use Elasticsearch as a last resort (optional)\n", 16 | "\n", 17 | "Note: not all pref_name values are included in molecule_synonyms, so step 1 cannot be skipped.\n", 18 | "\n", 19 | "A name can match several ChEMBL compounds, in which case they are sorted by max_phase. Manual curation should be considered in that situation (this is why SMILES, InChI and InChI Key are kept)." 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "metadata": { 26 | "tags": [] 27 | }, 28 | "outputs": [ 29 | { 30 | "output_type": "stream", 31 | "name": "stdout", 32 | "text": "Requirement already satisfied: chembl_webresource_client in /Users/efelix/miniconda3/envs/python38/lib/python3.8/site-packages (0.10.2)\nRequirement already satisfied: requests-cache>=0.4.7 in /Users/efelix/miniconda3/envs/python38/lib/python3.8/site-packages (from chembl_webresource_client) (0.5.2)\nRequirement already satisfied: easydict in /Users/efelix/miniconda3/envs/python38/lib/python3.8/site-packages (from chembl_webresource_client) (1.9)\nRequirement already satisfied: requests>=2.18.4 in /Users/efelix/miniconda3/envs/python38/lib/python3.8/site-packages (from chembl_webresource_client) (2.23.0)\nRequirement already satisfied: urllib3 in /Users/efelix/miniconda3/envs/python38/lib/python3.8/site-packages (from chembl_webresource_client) (1.25.8)\nRequirement already satisfied: chardet<4,>=3.0.2 in /Users/efelix/miniconda3/envs/python38/lib/python3.8/site-packages (from requests>=2.18.4->chembl_webresource_client) (3.0.4)\nRequirement already satisfied: certifi>=2017.4.17 in /Users/efelix/miniconda3/envs/python38/lib/python3.8/site-packages (from requests>=2.18.4->chembl_webresource_client) (2020.6.20)\nRequirement already satisfied: idna<3,>=2.5 in /Users/efelix/miniconda3/envs/python38/lib/python3.8/site-packages (from requests>=2.18.4->chembl_webresource_client) (2.9)\n" 33 | } 34 | ], 35 | "source": [ 36 | "# install the webresource client\n", 37 | "!pip install 
chembl_webresource_client" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "from chembl_webresource_client.new_client import new_client\n", 47 | "\n", 48 | "def name2chembl(name, use_search=False):\n", 49 | " \"\"\"\n", 50 | " Tries to retrieve the chembl_id and the structure for a given drug name.\n", 51 | " \"\"\"\n", 52 | " molecule = new_client.molecule\n", 53 | " fields = [\"molecule_chembl_id\", \"pref_name\", \"max_phase\", \"molecule_structures\"]\n", 54 | " # search in pref_name\n", 55 | " # iexact does exact case insensitive search\n", 56 | " res = molecule.filter(pref_name__iexact=name).only(fields)\n", 57 | " res = list(res)\n", 58 | " if res:\n", 59 | " # sort by max_phase\n", 60 | " res = sorted(res, key=lambda k: k[\"max_phase\"], reverse=True)\n", 61 | " return res, \"pref_name\"\n", 62 | " else:\n", 63 | " # if no pref_name match, look at the synonyms\n", 64 | " # some pref_name values are not included in the molecule_synonyms table, so it is not possible\n", 65 | " # to skip the first step\n", 66 | " res = molecule.filter(molecule_synonyms__molecule_synonym__iexact=name).only(fields)\n", 67 | " res = list(res)\n", 68 | " if res:\n", 69 | " # sort by max_phase\n", 70 | " res = sorted(res, key=lambda k: k[\"max_phase\"], reverse=True)\n", 71 | " return res, \"synonyms\"\n", 72 | " else:\n", 73 | " if use_search:\n", 74 | " # last resort:\n", 75 | " # search function uses Elasticsearch and may retrieve inexact matches\n", 76 | " # can also take longer than previous calls\n", 77 | " res = molecule.search(name).only(fields)\n", 78 | " if len(res) > 0:\n", 79 | " return [res[0]], \"search\"\n", 80 | " return None, None\n" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "# Example where it gets a match from molecule_dictionary.pref_name" 88 | ] 89 | }, 90 | { 91 | 
"cell_type": "code", 92 | "execution_count": 3, 93 | "metadata": { 94 | 
"tags": [] 95 | }, 96 | "outputs": [ 97 | { 98 | "output_type": "stream", 99 | "name": "stdout", 100 | "text": "pref_name\n" 101 | }, 102 | { 103 | "output_type": "execute_result", 104 | "data": { 105 | "text/plain": "[{'max_phase': 4,\n 'molecule_chembl_id': 'CHEMBL192',\n 'molecule_structures': {'canonical_smiles': 'CCCc1nn(C)c2c(=O)[nH]c(-c3cc(S(=O)(=O)N4CCN(C)CC4)ccc3OCC)nc12',\n 'molfile': '\\n RDKit 2D\\n\\n 33 36 0 0 0 0 0 0 0 0999 V2000\\n 2.1000 -0.0042 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 2.1000 0.7000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -1.5375 -0.0042 0.0000 S 0 0 0 0 0 0 0 0 0 0 0 0\\n 1.4917 -0.3667 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.8792 -0.0042 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 2.8042 0.9083 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 1.4917 1.0625 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.8792 0.6833 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 3.2042 0.3458 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 2.8042 -0.2417 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.2875 -0.3750 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -2.1583 -0.3750 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n -0.9333 -0.3750 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -0.3208 -0.0333 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -1.1875 0.6083 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n -1.8958 0.6083 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n -3.3958 -1.0917 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n -2.7833 -0.0042 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -2.1583 -1.0917 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.2875 -1.1125 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 1.4917 1.7708 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n -0.9333 -1.1125 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -0.3208 -1.4542 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -3.3958 -0.3750 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -2.7833 -1.4417 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 3.0750 1.5750 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 2.8042 -0.9500 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.8792 -1.4542 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n -3.9958 -1.4292 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 1.4958 -1.1000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 3.4167 
-1.3125 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 2.1125 -1.4500 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 4.0375 -0.9542 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 2 1 2 0\\n 3 13 1 0\\n 4 1 1 0\\n 5 4 2 0\\n 6 2 1 0\\n 7 2 1 0\\n 8 5 1 0\\n 9 10 2 0\\n 10 1 1 0\\n 11 5 1 0\\n 12 3 1 0\\n 13 14 2 0\\n 14 11 1 0\\n 15 3 2 0\\n 16 3 2 0\\n 17 25 1 0\\n 18 12 1 0\\n 19 12 1 0\\n 20 11 2 0\\n 21 7 2 0\\n 22 23 2 0\\n 23 20 1 0\\n 24 18 1 0\\n 25 19 1 0\\n 26 6 1 0\\n 27 10 1 0\\n 28 20 1 0\\n 29 17 1 0\\n 30 28 1 0\\n 31 27 1 0\\n 32 30 1 0\\n 33 31 1 0\\n 9 6 1 0\\n 8 7 1 0\\n 22 13 1 0\\n 17 24 1 0\\nM END\\n\\n> \\nCHEMBL192\\n\\n> \\nSILDENAFIL\\n\\n',\n 'standard_inchi': 'InChI=1S/C22H30N6O4S/c1-5-7-17-19-20(27(4)25-17)22(29)24-21(23-19)16-14-15(8-9-18(16)32-6-2)33(30,31)28-12-10-26(3)11-13-28/h8-9,14H,5-7,10-13H2,1-4H3,(H,23,24,29)',\n 'standard_inchi_key': 'BNRNXUUZRGQAQC-UHFFFAOYSA-N'},\n 'pref_name': 'SILDENAFIL'}]" 106 | }, 107 | "metadata": {}, 108 | "execution_count": 3 109 | } 110 | ], 111 | "source": [ 112 | "matches, where = name2chembl('sildenafil')\n", 113 | "\n", 114 | "print(where)\n", 115 | "matches" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "# Example where it gets matches from molecule_dictionary.synonyms" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 4, 128 | "metadata": { 129 | "tags": [] 130 | }, 131 | "outputs": [ 132 | { 133 | "output_type": "stream", 134 | "name": "stdout", 135 | "text": "synonyms\n" 136 | }, 137 | { 138 | "output_type": "execute_result", 139 | "data": { 140 | "text/plain": "[{'max_phase': 4,\n 'molecule_chembl_id': 'CHEMBL192',\n 'molecule_structures': {'canonical_smiles': 'CCCc1nn(C)c2c(=O)[nH]c(-c3cc(S(=O)(=O)N4CCN(C)CC4)ccc3OCC)nc12',\n 'molfile': '\\n RDKit 2D\\n\\n 33 36 0 0 0 0 0 0 0 0999 V2000\\n 2.1000 -0.0042 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 2.1000 0.7000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -1.5375 -0.0042 0.0000 S 0 0 0 0 0 0 0 0 0 0 0 
0\\n 1.4917 -0.3667 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.8792 -0.0042 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 2.8042 0.9083 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 1.4917 1.0625 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.8792 0.6833 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 3.2042 0.3458 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 2.8042 -0.2417 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.2875 -0.3750 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -2.1583 -0.3750 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n -0.9333 -0.3750 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -0.3208 -0.0333 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -1.1875 0.6083 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n -1.8958 0.6083 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n -3.3958 -1.0917 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n -2.7833 -0.0042 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -2.1583 -1.0917 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.2875 -1.1125 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 1.4917 1.7708 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n -0.9333 -1.1125 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -0.3208 -1.4542 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -3.3958 -0.3750 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -2.7833 -1.4417 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 3.0750 1.5750 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 2.8042 -0.9500 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.8792 -1.4542 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n -3.9958 -1.4292 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 1.4958 -1.1000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 3.4167 -1.3125 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 2.1125 -1.4500 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 4.0375 -0.9542 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 2 1 2 0\\n 3 13 1 0\\n 4 1 1 0\\n 5 4 2 0\\n 6 2 1 0\\n 7 2 1 0\\n 8 5 1 0\\n 9 10 2 0\\n 10 1 1 0\\n 11 5 1 0\\n 12 3 1 0\\n 13 14 2 0\\n 14 11 1 0\\n 15 3 2 0\\n 16 3 2 0\\n 17 25 1 0\\n 18 12 1 0\\n 19 12 1 0\\n 20 11 2 0\\n 21 7 2 0\\n 22 23 2 0\\n 23 20 1 0\\n 24 18 1 0\\n 25 19 1 0\\n 26 6 1 0\\n 27 10 1 0\\n 28 20 1 0\\n 29 17 1 0\\n 30 28 1 0\\n 31 27 1 0\\n 32 30 1 0\\n 33 31 1 0\\n 9 6 1 0\\n 8 7 1 0\\n 22 13 1 0\\n 17 24 1 0\\nM END\\n\\n> 
\\nCHEMBL192\\n\\n> \\nSILDENAFIL\\n\\n',\n 'standard_inchi': 'InChI=1S/C22H30N6O4S/c1-5-7-17-19-20(27(4)25-17)22(29)24-21(23-19)16-14-15(8-9-18(16)32-6-2)33(30,31)28-12-10-26(3)11-13-28/h8-9,14H,5-7,10-13H2,1-4H3,(H,23,24,29)',\n 'standard_inchi_key': 'BNRNXUUZRGQAQC-UHFFFAOYSA-N'},\n 'pref_name': 'SILDENAFIL'},\n {'max_phase': 4,\n 'molecule_chembl_id': 'CHEMBL1737',\n 'molecule_structures': {'canonical_smiles': 'CCCc1nn(C)c2c(=O)[nH]c(-c3cc(S(=O)(=O)N4CCN(C)CC4)ccc3OCC)nc12.O=C(O)CC(O)(CC(=O)O)C(=O)O',\n 'molfile': '\\n RDKit 2D\\n\\n 46 48 0 0 0 0 0 0 0 0999 V2000\\n 9.2182 1.4870 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 11.8182 1.4870 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 7.9182 2.2370 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 13.1182 2.2370 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 10.5182 3.7377 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 10.5182 2.2370 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 6.8788 1.6373 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n 7.9182 3.4370 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n 14.1576 1.6373 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n 13.1182 3.4370 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n 11.5571 4.3384 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n 9.4788 4.3374 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n 10.5182 1.0370 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n 4.0244 4.0756 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -6.2133 -2.7101 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -9.8661 8.0854 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 2.0907 -2.3426 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 3.6500 2.9355 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -6.2134 -1.5101 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 2.1812 2.6271 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -6.2185 2.9892 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -6.2151 1.4892 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -8.8235 5.9871 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -7.5281 8.2391 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -7.5233 5.2391 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -6.2278 7.4912 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -3.6204 2.9950 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -4.9211 3.7421 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 
-3.6168 1.4950 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 1.7138 1.2033 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -4.9144 0.7421 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.2917 0.7475 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.2917 -0.7475 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -2.3155 0.7475 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -1.0028 -1.5132 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -1.0028 1.5132 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n -2.3155 -0.7475 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 2.5889 0.0182 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n -8.8259 7.4871 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 1.7138 -1.2033 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n -6.2254 5.9912 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n -0.9991 -2.7132 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n -3.8864 5.8449 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n -3.8840 4.6451 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n -4.9142 -0.7587 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n -4.9245 5.2429 0.0000 S 0 0 0 0 0 0 0 0 0 0 0 0\\n 1 3 1 0\\n 1 6 1 0\\n 2 4 1 0\\n 2 6 1 0\\n 3 7 2 0\\n 3 8 1 0\\n 4 9 2 0\\n 4 10 1 0\\n 5 6 1 0\\n 5 11 2 0\\n 5 12 1 0\\n 6 13 1 0\\n 14 18 1 0\\n 15 19 1 0\\n 16 39 1 0\\n 17 40 1 0\\n 18 20 1 0\\n 19 45 1 0\\n 20 30 1 0\\n 21 22 2 0\\n 21 28 1 0\\n 22 31 1 0\\n 23 25 1 0\\n 23 39 1 0\\n 24 26 1 0\\n 24 39 1 0\\n 25 41 1 0\\n 26 41 1 0\\n 27 28 2 0\\n 27 29 1 0\\n 28 46 1 0\\n 29 31 2 0\\n 29 34 1 0\\n 30 32 1 0\\n 30 38 2 0\\n 31 45 1 0\\n 32 33 2 0\\n 32 36 1 0\\n 33 35 1 0\\n 33 40 1 0\\n 34 36 2 0\\n 34 37 1 0\\n 35 37 1 0\\n 35 42 2 0\\n 38 40 1 0\\n 41 46 1 0\\n 43 46 2 0\\n 44 46 2 0\\nM END\\n\\n> \\nCHEMBL1737\\n\\n> \\nSILDENAFIL CITRATE\\n\\n',\n 'standard_inchi': 'InChI=1S/C22H30N6O4S.C6H8O7/c1-5-7-17-19-20(27(4)25-17)22(29)24-21(23-19)16-14-15(8-9-18(16)32-6-2)33(30,31)28-12-10-26(3)11-13-28;7-3(8)1-6(13,5(11)12)2-4(9)10/h8-9,14H,5-7,10-13H2,1-4H3,(H,23,24,29);13H,1-2H2,(H,7,8)(H,9,10)(H,11,12)',\n 'standard_inchi_key': 'DEIYFTQMQPDXOT-UHFFFAOYSA-N'},\n 'pref_name': 'SILDENAFIL CITRATE'}]" 141 | }, 142 | "metadata": {}, 143 | "execution_count": 4 
144 | } 145 | ], 146 | "source": [ 147 | "matches, where = name2chembl('viagra')\n", 148 | "\n", 149 | "print(where)\n", 150 | "matches" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "# Example where a match is only found with the search feature" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 5, 163 | "metadata": { 164 | "tags": [] 165 | }, 166 | "outputs": [ 167 | { 168 | "output_type": "stream", 169 | "name": "stdout", 170 | "text": "None\n" 171 | } 172 | ], 173 | "source": [ 174 | "matches, where = name2chembl('Azaguanine-8')\n", 175 | "\n", 176 | "print(where)\n", 177 | "matches" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "## Elasticsearch always tries to return results, so matches retrieved with use_search=True should be curated manually" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 6, 190 | "metadata": { 191 | "tags": [] 192 | }, 193 | "outputs": [ 194 | { 195 | "output_type": "stream", 196 | "name": "stdout", 197 | "text": "search\n" 198 | }, 199 | { 200 | "output_type": "execute_result", 201 | "data": { 202 | "text/plain": "[{'max_phase': 0,\n 'molecule_chembl_id': 'CHEMBL374107',\n 'molecule_structures': {'canonical_smiles': 'Nc1nc(O)c2[nH]nnc2n1',\n 'molfile': '\\n RDKit 2D\\n\\n 11 12 0 0 0 0 0 0 0 0999 V2000\\n -0.4152 0.4905 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.2993 0.9030 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 1.0137 0.4905 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 1.7282 0.9030 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 1.0137 -0.3345 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.2993 -0.7470 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n 0.2993 -1.5720 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0\\n -0.4152 -0.3345 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\\n -1.1998 -0.5894 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n -1.6848 0.0780 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n -1.1998 0.7455 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0\\n 1 2 1 0\\n 1 8 2 0\\n 1
11 1 0\\n 3 2 2 0\\n 3 4 1 0\\n 3 5 1 0\\n 6 5 2 0\\n 6 7 1 0\\n 8 6 1 0\\n 8 9 1 0\\n 9 10 1 0\\n 10 11 2 0\\nM END\\n\\n> \\nCHEMBL374107\\n\\n> \\n8-AZAGUANINE\\n\\n',\n 'standard_inchi': 'InChI=1S/C4H4N6O/c5-4-6-2-1(3(11)7-4)8-10-9-2/h(H4,5,6,7,8,9,10,11)',\n 'standard_inchi_key': 'LPXQRXLUHJKZIE-UHFFFAOYSA-N'},\n 'pref_name': '8-AZAGUANINE'}]" 203 | }, 204 | "metadata": {}, 205 | "execution_count": 6 206 | } 207 | ], 208 | "source": [ 209 | "matches, where = name2chembl('Azaguanine-8', use_search=True)\n", 210 | "\n", 211 | "print(where)\n", 212 | "matches" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [] 221 | } 222 | ], 223 | "metadata": { 224 | "kernelspec": { 225 | "display_name": "Python 3", 226 | "language": "python", 227 | "name": "python3" 228 | }, 229 | "language_info": { 230 | "codemirror_mode": { 231 | "name": "ipython", 232 | "version": 3 233 | }, 234 | "file_extension": ".py", 235 | "mimetype": "text/x-python", 236 | "name": "python", 237 | "nbconvert_exporter": "python", 238 | "pygments_lexer": "ipython3", 239 | "version": "3.8.2-final" 240 | } 241 | }, 242 | "nbformat": 4, 243 | "nbformat_minor": 4 244 | } -------------------------------------------------------------------------------- /train_chembl_multitask.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Requirements for running this notebook:\n", 8 | "- Python >= 3.6 (uses f-strings)\n", 9 | "- PyTorch >= 1.0\n", 10 | "- scikit-learn\n", 11 | "- NumPy\n", 12 | "- PyTables\n", 13 | "- mt_data.h5 file: you can generate it yourself with the get_multitask_data.ipynb notebook or [download it](http://ftp.ebi.ac.uk/pub/databases/chembl/blog/pytorch_mtl/mt_data.h5)\n", 14 | "\n", 15 | "## This notebook trains and tests a multi-task neural network on ChEMBL data\n", 16 | "- It uses a
simple shuffled 80/20 train/test split\n", 17 | "- Automatically sizes the output layer to the number of targets in the training data\n", 18 | "- Tries to use the GPU if available\n", 19 | "- Saves and loads a model to/from a file\n" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import numpy as np\n", 29 | "import torch\n", 30 | "from torch import nn\n", 31 | "import torch.nn.functional as F\n", 32 | "import torch.utils.data as D\n", 33 | "import tables as tb\n", 34 | "from sklearn.metrics import (\n", 35 | " matthews_corrcoef,\n", 36 | " confusion_matrix,\n", 37 | " f1_score,\n", 38 | " roc_auc_score,\n", 39 | " accuracy_score)\n" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# set the device to GPU if available\n", 49 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "# Set some config values" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "MAIN_PATH = '.'\n", 66 | "DATA_FILE = 'mt_data.h5'\n", 67 | "MODEL_FILE = 'chembl_mt.model'\n", 68 | "N_WORKERS = 8 # DataLoader workers: prefetch batches in parallel so the next one is ready when the model finishes the current one\n", 69 | "BATCH_SIZE = 32 # https://twitter.com/ylecun/status/989610208497360896?lang=es\n", 70 | "LR = 2 # Learning rate. Large value because of the way the per-target losses are weighted\n", 71 | "N_EPOCHS = 2 # You should train longer!!!"
72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "# Set the dataset loaders\n", 79 | "\n", 80 | "Simple 80/20 train/test split for the example" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 4, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "class ChEMBLDataset(D.Dataset):\n", 90 | " \n", 91 | " def __init__(self, file_path):\n", 92 | " self.file_path = file_path\n", 93 | " with tb.open_file(self.file_path, mode='r') as t_file:\n", 94 | " self.length = t_file.root.fps.shape[0]\n", 95 | " self.n_targets = t_file.root.labels.shape[1]\n", 96 | " \n", 97 | " def __len__(self):\n", 98 | " return self.length\n", 99 | " \n", 100 | " def __getitem__(self, index):\n", 101 | " with tb.open_file(self.file_path, mode='r') as t_file:\n", 102 | " structure = t_file.root.fps[index]\n", 103 | " labels = t_file.root.labels[index]\n", 104 | " return structure, labels\n", 105 | "\n", 106 | "\n", 107 | "dataset = ChEMBLDataset(f\"{MAIN_PATH}/{DATA_FILE}\")\n", 108 | "validation_split = .2\n", 109 | "random_seed = 42\n", 110 | "\n", 111 | "dataset_size = len(dataset)\n", 112 | "indices = list(range(dataset_size))\n", 113 | "split = int(np.floor(validation_split * dataset_size))\n", 114 | "\n", 115 | "np.random.seed(random_seed)\n", 116 | "np.random.shuffle(indices)\n", 117 | "train_indices, test_indices = indices[split:], indices[:split]\n", 118 | "\n", 119 | "train_sampler = D.sampler.SubsetRandomSampler(train_indices)\n", 120 | "test_sampler = D.sampler.SubsetRandomSampler(test_indices)\n", 121 | "\n", 122 | "# with N_WORKERS > 0, the DataLoaders prefetch the next batches in parallel\n", 123 | "# while the model is training\n", 124 | "train_loader = torch.utils.data.DataLoader(dataset,\n", 125 | " batch_size=BATCH_SIZE,\n", 126 | " num_workers=N_WORKERS,\n", 127 | " sampler=train_sampler)\n", 128 | "\n", 129 | "test_loader = torch.utils.data.DataLoader(dataset,\n", 130 | " batch_size=BATCH_SIZE,\n", 131 | "
num_workers=N_WORKERS,\n", 132 | " sampler=test_sampler)\n" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "# Define the model, the optimizer and the loss criterion" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 5, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "class ChEMBLMultiTask(nn.Module):\n", 149 | " \"\"\"\n", 150 | " Architecture borrowed from: https://arxiv.org/abs/1502.02072\n", 151 | " \"\"\"\n", 152 | " def __init__(self, n_tasks):\n", 153 | " super(ChEMBLMultiTask, self).__init__()\n", 154 | " self.n_tasks = n_tasks\n", 155 | " self.fc1 = nn.Linear(1024, 2000)\n", 156 | " self.fc2 = nn.Linear(2000, 100)\n", 157 | " self.dropout = nn.Dropout(0.25)\n", 158 | "\n", 159 | " # add an independent output for each task in the output layer\n", 160 | " for n_m in range(self.n_tasks):\n", 161 | " self.add_module(f\"y{n_m}o\", nn.Linear(100, 1))\n", 162 | " \n", 163 | " def forward(self, x):\n", 164 | " h1 = self.dropout(F.relu(self.fc1(x)))\n", 165 | " h2 = F.relu(self.fc2(h1))\n", 166 | " out = [torch.sigmoid(getattr(self, f\"y{n_m}o\")(h2)) for n_m in range(self.n_tasks)]\n", 167 | " return out\n", 168 | " \n", 169 | "# create the model and move it to the GPU if available\n", 170 | "model = ChEMBLMultiTask(dataset.n_targets).to(device)\n", 171 | "\n", 172 | "# binary cross entropy\n", 173 | "# each task's loss is weighted inversely proportional to its number of data points, borrowed from:\n", 174 | "# http://www.bioinf.at/publications/2014/NIPS2014a.pdf\n", 175 | "with tb.open_file(f\"{MAIN_PATH}/{DATA_FILE}\", mode='r') as t_file:\n", 176 | " weights = torch.tensor(t_file.root.weights[:])\n", 177 | " weights = weights.to(device)\n", 178 | "\n", 179 | "criterion = [nn.BCELoss(weight=w) for x, w in zip(range(dataset.n_targets), weights.float())]\n", 180 | "\n", 181 | "# stochastic gradient descent as the optimizer\n", 182 | "optimizer = torch.optim.SGD(model.parameters(),
LR)\n" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": {}, 188 | "source": [ 189 | "# Train the model\n", 190 | "Given the extremely sparse nature of the dataset, it is difficult to see clearly how the loss improves after every batch. The trend becomes clearer after several epochs, and much clearer when testing :)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 6, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "Epoch: [1/2], Step: [500/17789], Loss: 0.01780553348362446\n", 203 | "Epoch: [1/2], Step: [1000/17789], Loss: 0.01136045902967453\n", 204 | "Epoch: [1/2], Step: [1500/17789], Loss: 0.018664617091417313\n", 205 | "Epoch: [1/2], Step: [2000/17789], Loss: 0.013626799918711185\n", 206 | "Epoch: [1/2], Step: [2500/17789], Loss: 0.012855792418122292\n", 207 | "Epoch: [1/2], Step: [3000/17789], Loss: 0.013796127401292324\n", 208 | "Epoch: [1/2], Step: [3500/17789], Loss: 0.021601887419819832\n", 209 | "Epoch: [1/2], Step: [4000/17789], Loss: 0.00950919184833765\n", 210 | "Epoch: [1/2], Step: [4500/17789], Loss: 0.02028888650238514\n", 211 | "Epoch: [1/2], Step: [5000/17789], Loss: 0.013251284137368202\n", 212 | "Epoch: [1/2], Step: [5500/17789], Loss: 0.008788244798779488\n", 213 | "Epoch: [1/2], Step: [6000/17789], Loss: 0.012066680938005447\n", 214 | "Epoch: [1/2], Step: [6500/17789], Loss: 0.013928443193435669\n", 215 | "Epoch: [1/2], Step: [7000/17789], Loss: 0.011484757997095585\n", 216 | "Epoch: [1/2], Step: [7500/17789], Loss: 0.0071386718191206455\n", 217 | "Epoch: [1/2], Step: [8000/17789], Loss: 0.014712771400809288\n", 218 | "Epoch: [1/2], Step: [8500/17789], Loss: 0.010457032360136509\n", 219 | "Epoch: [1/2], Step: [9000/17789], Loss: 0.00854165107011795\n", 220 | "Epoch: [1/2], Step: [9500/17789], Loss: 0.009312299080193043\n", 221 | "Epoch: [1/2], Step: [10000/17789], Loss: 0.010153095237910748\n", 222 | "Epoch: [1/2], Step:
[10500/17789], Loss: 0.006983090192079544\n", 223 | "Epoch: [1/2], Step: [11000/17789], Loss: 0.010238541290163994\n", 224 | "Epoch: [1/2], Step: [11500/17789], Loss: 0.012679124251008034\n", 225 | "Epoch: [1/2], Step: [12000/17789], Loss: 0.01116170920431614\n", 226 | "Epoch: [1/2], Step: [12500/17789], Loss: 0.011749005876481533\n", 227 | "Epoch: [1/2], Step: [13000/17789], Loss: 0.015176426619291306\n", 228 | "Epoch: [1/2], Step: [13500/17789], Loss: 0.013586488552391529\n", 229 | "Epoch: [1/2], Step: [14000/17789], Loss: 0.012365413829684258\n", 230 | "Epoch: [1/2], Step: [14500/17789], Loss: 0.009591283276677132\n", 231 | "Epoch: [1/2], Step: [15000/17789], Loss: 0.01857740990817547\n", 232 | "Epoch: [1/2], Step: [15500/17789], Loss: 0.009823130443692207\n", 233 | "Epoch: [1/2], Step: [16000/17789], Loss: 0.01805167831480503\n", 234 | "Epoch: [1/2], Step: [16500/17789], Loss: 0.011896809563040733\n", 235 | "Epoch: [1/2], Step: [17000/17789], Loss: 0.008349821902811527\n", 236 | "Epoch: [1/2], Step: [17500/17789], Loss: 0.013517800718545914\n", 237 | "Epoch: [2/2], Step: [500/17789], Loss: 0.007128629367798567\n", 238 | "Epoch: [2/2], Step: [1000/17789], Loss: 0.01153416559100151\n", 239 | "Epoch: [2/2], Step: [1500/17789], Loss: 0.02041609212756157\n", 240 | "Epoch: [2/2], Step: [2000/17789], Loss: 0.0165218748152256\n", 241 | "Epoch: [2/2], Step: [2500/17789], Loss: 0.011772445403039455\n", 242 | "Epoch: [2/2], Step: [3000/17789], Loss: 0.011200090870261192\n", 243 | "Epoch: [2/2], Step: [3500/17789], Loss: 0.012209323234856129\n", 244 | "Epoch: [2/2], Step: [4000/17789], Loss: 0.007769708056002855\n", 245 | "Epoch: [2/2], Step: [4500/17789], Loss: 0.012243629433214664\n", 246 | "Epoch: [2/2], Step: [5000/17789], Loss: 0.018942933529615402\n", 247 | "Epoch: [2/2], Step: [5500/17789], Loss: 0.013197326101362705\n", 248 | "Epoch: [2/2], Step: [6000/17789], Loss: 0.011520257219672203\n", 249 | "Epoch: [2/2], Step: [6500/17789], Loss: 0.020596494898200035\n", 250 
| "Epoch: [2/2], Step: [7000/17789], Loss: 0.018161792308092117\n", 251 | "Epoch: [2/2], Step: [7500/17789], Loss: 0.01610906422138214\n", 252 | "Epoch: [2/2], Step: [8000/17789], Loss: 0.004183729644864798\n", 253 | "Epoch: [2/2], Step: [8500/17789], Loss: 0.01284581795334816\n", 254 | "Epoch: [2/2], Step: [9000/17789], Loss: 0.014269811101257801\n", 255 | "Epoch: [2/2], Step: [9500/17789], Loss: 0.009626287035644054\n", 256 | "Epoch: [2/2], Step: [10000/17789], Loss: 0.008639814332127571\n", 257 | "Epoch: [2/2], Step: [10500/17789], Loss: 0.011639382690191269\n", 258 | "Epoch: [2/2], Step: [11000/17789], Loss: 0.005331861320883036\n", 259 | "Epoch: [2/2], Step: [11500/17789], Loss: 0.011540957726538181\n", 260 | "Epoch: [2/2], Step: [12000/17789], Loss: 0.010148015804588795\n", 261 | "Epoch: [2/2], Step: [12500/17789], Loss: 0.011556670069694519\n", 262 | "Epoch: [2/2], Step: [13000/17789], Loss: 0.0069694253616034985\n", 263 | "Epoch: [2/2], Step: [13500/17789], Loss: 0.008971192874014378\n", 264 | "Epoch: [2/2], Step: [14000/17789], Loss: 0.02061212807893753\n", 265 | "Epoch: [2/2], Step: [14500/17789], Loss: 0.013362124562263489\n", 266 | "Epoch: [2/2], Step: [15000/17789], Loss: 0.00966110359877348\n", 267 | "Epoch: [2/2], Step: [15500/17789], Loss: 0.017838571220636368\n", 268 | "Epoch: [2/2], Step: [16000/17789], Loss: 0.007174369413405657\n", 269 | "Epoch: [2/2], Step: [16500/17789], Loss: 0.0074622794054448605\n", 270 | "Epoch: [2/2], Step: [17000/17789], Loss: 0.015448285266757011\n", 271 | "Epoch: [2/2], Step: [17500/17789], Loss: 0.011626753024756908\n" 272 | ] 273 | } 274 | ], 275 | "source": [ 276 | "# model is by default in train mode. 
Training can be resumed after .eval() but needs to be set to .train() again\n", 277 | "model.train()\n", 278 | "for ep in range(N_EPOCHS):\n", 279 | " for i, (fps, labels) in enumerate(train_loader):\n", 280 | " # move it to GPU if available\n", 281 | " fps, labels = fps.to(device), labels.to(device)\n", 282 | "\n", 283 | " optimizer.zero_grad()\n", 284 | " outputs = model(fps)\n", 285 | " \n", 286 | " # calc the loss\n", 287 | " loss = torch.tensor(0.0).to(device)\n", 288 | " for j, crit in enumerate(criterion):\n", 289 | " # mask keeping only the molecules labeled for this task (unlabeled entries are < 0)\n", 290 | " mask = labels[:, j] >= 0.0\n", 291 | " if len(labels[:, j][mask]) != 0:\n", 292 | " # the loss is the sum of each task/target loss.\n", 293 | " # there are labeled samples for this task, so we add its loss\n", 294 | " loss += crit(outputs[j][mask], labels[:, j][mask].view(-1, 1))\n", 295 | "\n", 296 | " loss.backward()\n", 297 | " optimizer.step()\n", 298 | "\n", 299 | " if (i+1) % 500 == 0:\n", 300 | " print(f\"Epoch: [{ep+1}/{N_EPOCHS}], Step: [{i+1}/{len(train_indices)//BATCH_SIZE}], Loss: {loss.item()}\")\n", 301 | " " 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": {}, 307 | "source": [ 308 | "# Test the model" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 7, 314 | "metadata": {}, 315 | "outputs": [ 316 | { 317 | "name": "stdout", 318 | "output_type": "stream", 319 | "text": [ 320 | "accuracy: 0.8371918235997756, auc: 0.8942389411754185, sens: 0.7053822792666977, spec: 0.8987519347341067, prec: 0.7649158653846154, mcc: 0.6179805824644773, f1: 0.733943790291889\n", 321 | "Not bad for only 2 epochs!\n" 322 | ] 323 | } 324 | ], 325 | "source": [ 326 | "y_trues = []\n", 327 | "y_preds = []\n", 328 | "y_preds_proba = []\n", 329 | "\n", 330 | "# do not track gradients during evaluation\n", 331 | "with torch.no_grad():\n", 332 | " for fps, labels in test_loader:\n", 333 | " # move it to GPU if available\n", 334 | " fps, labels = fps.to(device),
labels.to(device)\n", 335 | " # set the model to eval mode so the dropout layer is disabled\n", 336 | " model.eval()\n", 337 | " outputs = model(fps)\n", 338 | " for j, out in enumerate(outputs):\n", 339 | " mask = labels[:, j] >= 0.0\n", 340 | " y_pred = (out[mask] > 0.5).float().view(1, -1)\n", 341 | "\n", 342 | " if y_pred.shape[1] > 0:\n", 343 | " for l in labels[:, j][mask].long().tolist():\n", 344 | " y_trues.append(l)\n", 345 | " for p in y_pred.view(-1, 1).tolist():\n", 346 | " y_preds.append(int(p[0]))\n", 347 | " for p in out[mask].view(-1, 1).tolist():\n", 348 | " y_preds_proba.append(float(p[0]))\n", 349 | "\n", 350 | "tn, fp, fn, tp = confusion_matrix(y_trues, y_preds).ravel()\n", 351 | "sens = tp / (tp + fn)\n", 352 | "spec = tn / (tn + fp)\n", 353 | "prec = tp / (tp + fp)\n", 354 | "f1 = f1_score(y_trues, y_preds)\n", 355 | "acc = accuracy_score(y_trues, y_preds)\n", 356 | "mcc = matthews_corrcoef(y_trues, y_preds)\n", 357 | "auc = roc_auc_score(y_trues, y_preds_proba)\n", 358 | "\n", 359 | "print(f\"accuracy: {acc}, auc: {auc}, sens: {sens}, spec: {spec}, prec: {prec}, mcc: {mcc}, f1: {f1}\")\n", 360 | "print(f\"Not bad for only {N_EPOCHS} epochs!\")" 361 | ] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "metadata": {}, 366 | "source": [ 367 | "# Save the model to a file" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 8, 373 | "metadata": {}, 374 | "outputs": [], 375 | "source": [ 376 | "torch.save(model.state_dict(), f\"./{MODEL_FILE}\")" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": {}, 382 | "source": [ 383 | "# Load the model from the file" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 9, 389 | "metadata": {}, 390 | "outputs": [ 391 | { 392 | "data": { 393 | "text/plain": [ 394 | "ChEMBLMultiTask(\n", 395 | " (fc1): Linear(in_features=1024, out_features=2000, bias=True)\n", 396 | " (fc2): Linear(in_features=2000,
out_features=100, bias=True)\n", 397 | " (dropout): Dropout(p=0.25)\n", 398 | " (y0o): Linear(in_features=100, out_features=1, bias=True)\n", 399 | " (y1o): Linear(in_features=100, out_features=1, bias=True)\n", 400 | " (y2o): Linear(in_features=100, out_features=1, bias=True)\n", 401 | " (y3o): Linear(in_features=100, out_features=1, bias=True)\n", 402 | " (y4o): Linear(in_features=100, out_features=1, bias=True)\n", 403 | " (y5o): Linear(in_features=100, out_features=1, bias=True)\n", 404 | " (y6o): Linear(in_features=100, out_features=1, bias=True)\n", 405 | " (y7o): Linear(in_features=100, out_features=1, bias=True)\n", 406 | " (y8o): Linear(in_features=100, out_features=1, bias=True)\n", 407 | " (y9o): Linear(in_features=100, out_features=1, bias=True)\n", 408 | " (y10o): Linear(in_features=100, out_features=1, bias=True)\n", 409 | " (y11o): Linear(in_features=100, out_features=1, bias=True)\n", 410 | " (y12o): Linear(in_features=100, out_features=1, bias=True)\n", 411 | " (y13o): Linear(in_features=100, out_features=1, bias=True)\n", 412 | " (y14o): Linear(in_features=100, out_features=1, bias=True)\n", 413 | " (y15o): Linear(in_features=100, out_features=1, bias=True)\n", 414 | " (y16o): Linear(in_features=100, out_features=1, bias=True)\n", 415 | " (y17o): Linear(in_features=100, out_features=1, bias=True)\n", 416 | " (y18o): Linear(in_features=100, out_features=1, bias=True)\n", 417 | " (y19o): Linear(in_features=100, out_features=1, bias=True)\n", 418 | " (y20o): Linear(in_features=100, out_features=1, bias=True)\n", 419 | " (y21o): Linear(in_features=100, out_features=1, bias=True)\n", 420 | " (y22o): Linear(in_features=100, out_features=1, bias=True)\n", 421 | " (y23o): Linear(in_features=100, out_features=1, bias=True)\n", 422 | " (y24o): Linear(in_features=100, out_features=1, bias=True)\n", 423 | " (y25o): Linear(in_features=100, out_features=1, bias=True)\n", 424 | " (y26o): Linear(in_features=100, out_features=1, bias=True)\n", 425 | " (y27o): 
Linear(in_features=100, out_features=1, bias=True)\n", 426 | " (y28o): Linear(in_features=100, out_features=1, bias=True)\n", 427 | " (y29o): Linear(in_features=100, out_features=1, bias=True)\n", 428 | " (y30o): Linear(in_features=100, out_features=1, bias=True)\n", 429 | " (y31o): Linear(in_features=100, out_features=1, bias=True)\n", 430 | " (y32o): Linear(in_features=100, out_features=1, bias=True)\n", 431 | " (y33o): Linear(in_features=100, out_features=1, bias=True)\n", 432 | " (y34o): Linear(in_features=100, out_features=1, bias=True)\n", 433 | " (y35o): Linear(in_features=100, out_features=1, bias=True)\n", 434 | " (y36o): Linear(in_features=100, out_features=1, bias=True)\n", 435 | " (y37o): Linear(in_features=100, out_features=1, bias=True)\n", 436 | " (y38o): Linear(in_features=100, out_features=1, bias=True)\n", 437 | " (y39o): Linear(in_features=100, out_features=1, bias=True)\n", 438 | " (y40o): Linear(in_features=100, out_features=1, bias=True)\n", 439 | " (y41o): Linear(in_features=100, out_features=1, bias=True)\n", 440 | " (y42o): Linear(in_features=100, out_features=1, bias=True)\n", 441 | " (y43o): Linear(in_features=100, out_features=1, bias=True)\n", 442 | " (y44o): Linear(in_features=100, out_features=1, bias=True)\n", 443 | " (y45o): Linear(in_features=100, out_features=1, bias=True)\n", 444 | " (y46o): Linear(in_features=100, out_features=1, bias=True)\n", 445 | " (y47o): Linear(in_features=100, out_features=1, bias=True)\n", 446 | " (y48o): Linear(in_features=100, out_features=1, bias=True)\n", 447 | " (y49o): Linear(in_features=100, out_features=1, bias=True)\n", 448 | " (y50o): Linear(in_features=100, out_features=1, bias=True)\n", 449 | " (y51o): Linear(in_features=100, out_features=1, bias=True)\n", 450 | " (y52o): Linear(in_features=100, out_features=1, bias=True)\n", 451 | " (y53o): Linear(in_features=100, out_features=1, bias=True)\n", 452 | " (y54o): Linear(in_features=100, out_features=1, bias=True)\n", 453 | " (y55o): 
Linear(in_features=100, out_features=1, bias=True)\n", 454 | " (y56o): Linear(in_features=100, out_features=1, bias=True)\n", 455 | " (y57o): Linear(in_features=100, out_features=1, bias=True)\n", 456 | " (y58o): Linear(in_features=100, out_features=1, bias=True)\n", 457 | " (y59o): Linear(in_features=100, out_features=1, bias=True)\n", 458 | " (y60o): Linear(in_features=100, out_features=1, bias=True)\n", 459 | " (y61o): Linear(in_features=100, out_features=1, bias=True)\n", 460 | " (y62o): Linear(in_features=100, out_features=1, bias=True)\n", 461 | " (y63o): Linear(in_features=100, out_features=1, bias=True)\n", 462 | " (y64o): Linear(in_features=100, out_features=1, bias=True)\n", 463 | " (y65o): Linear(in_features=100, out_features=1, bias=True)\n", 464 | " (y66o): Linear(in_features=100, out_features=1, bias=True)\n", 465 | " (y67o): Linear(in_features=100, out_features=1, bias=True)\n", 466 | " (y68o): Linear(in_features=100, out_features=1, bias=True)\n", 467 | " (y69o): Linear(in_features=100, out_features=1, bias=True)\n", 468 | " (y70o): Linear(in_features=100, out_features=1, bias=True)\n", 469 | " (y71o): Linear(in_features=100, out_features=1, bias=True)\n", 470 | " (y72o): Linear(in_features=100, out_features=1, bias=True)\n", 471 | " (y73o): Linear(in_features=100, out_features=1, bias=True)\n", 472 | " (y74o): Linear(in_features=100, out_features=1, bias=True)\n", 473 | " (y75o): Linear(in_features=100, out_features=1, bias=True)\n", 474 | " (y76o): Linear(in_features=100, out_features=1, bias=True)\n", 475 | " (y77o): Linear(in_features=100, out_features=1, bias=True)\n", 476 | " (y78o): Linear(in_features=100, out_features=1, bias=True)\n", 477 | " (y79o): Linear(in_features=100, out_features=1, bias=True)\n", 478 | " (y80o): Linear(in_features=100, out_features=1, bias=True)\n", 479 | " (y81o): Linear(in_features=100, out_features=1, bias=True)\n", 480 | " (y82o): Linear(in_features=100, out_features=1, bias=True)\n", 481 | " (y83o): 
Linear(in_features=100, out_features=1, bias=True)\n", 482 | " (y84o): Linear(in_features=100, out_features=1, bias=True)\n", 483 | " (y85o): Linear(in_features=100, out_features=1, bias=True)\n", 484 | " (y86o): Linear(in_features=100, out_features=1, bias=True)\n", 485 | " (y87o): Linear(in_features=100, out_features=1, bias=True)\n", 486 | " (y88o): Linear(in_features=100, out_features=1, bias=True)\n", 487 | " (y89o): Linear(in_features=100, out_features=1, bias=True)\n", 488 | " (y90o): Linear(in_features=100, out_features=1, bias=True)\n", 489 | " (y91o): Linear(in_features=100, out_features=1, bias=True)\n", 490 | " (y92o): Linear(in_features=100, out_features=1, bias=True)\n", 491 | " (y93o): Linear(in_features=100, out_features=1, bias=True)\n", 492 | " (y94o): Linear(in_features=100, out_features=1, bias=True)\n", 493 | " (y95o): Linear(in_features=100, out_features=1, bias=True)\n", 494 | " (y96o): Linear(in_features=100, out_features=1, bias=True)\n", 495 | " (y97o): Linear(in_features=100, out_features=1, bias=True)\n", 496 | " (y98o): Linear(in_features=100, out_features=1, bias=True)\n", 497 | " (y99o): Linear(in_features=100, out_features=1, bias=True)\n", 498 | " (y100o): Linear(in_features=100, out_features=1, bias=True)\n", 499 | " (y101o): Linear(in_features=100, out_features=1, bias=True)\n", 500 | " (y102o): Linear(in_features=100, out_features=1, bias=True)\n", 501 | " (y103o): Linear(in_features=100, out_features=1, bias=True)\n", 502 | " (y104o): Linear(in_features=100, out_features=1, bias=True)\n", 503 | " (y105o): Linear(in_features=100, out_features=1, bias=True)\n", 504 | " (y106o): Linear(in_features=100, out_features=1, bias=True)\n", 505 | " (y107o): Linear(in_features=100, out_features=1, bias=True)\n", 506 | " (y108o): Linear(in_features=100, out_features=1, bias=True)\n", 507 | " (y109o): Linear(in_features=100, out_features=1, bias=True)\n", 508 | " (y110o): Linear(in_features=100, out_features=1, bias=True)\n", 509 | " (y111o): 
Linear(in_features=100, out_features=1, bias=True)\n", 510 | " (y112o): Linear(in_features=100, out_features=1, bias=True)\n", 511 | " (y113o): Linear(in_features=100, out_features=1, bias=True)\n", 512 | " (y114o): Linear(in_features=100, out_features=1, bias=True)\n", 513 | " (y115o): Linear(in_features=100, out_features=1, bias=True)\n", 514 | " (y116o): Linear(in_features=100, out_features=1, bias=True)\n", 515 | " (y117o): Linear(in_features=100, out_features=1, bias=True)\n", 516 | " (y118o): Linear(in_features=100, out_features=1, bias=True)\n", 517 | " (y119o): Linear(in_features=100, out_features=1, bias=True)\n", 518 | " (y120o): Linear(in_features=100, out_features=1, bias=True)\n", 519 | " (y121o): Linear(in_features=100, out_features=1, bias=True)\n", 520 | " (y122o): Linear(in_features=100, out_features=1, bias=True)\n", 521 | " (y123o): Linear(in_features=100, out_features=1, bias=True)\n", 522 | " (y124o): Linear(in_features=100, out_features=1, bias=True)\n", 523 | " (y125o): Linear(in_features=100, out_features=1, bias=True)\n", 524 | " (y126o): Linear(in_features=100, out_features=1, bias=True)\n", 525 | " (y127o): Linear(in_features=100, out_features=1, bias=True)\n", 526 | " (y128o): Linear(in_features=100, out_features=1, bias=True)\n", 527 | " (y129o): Linear(in_features=100, out_features=1, bias=True)\n", 528 | " (y130o): Linear(in_features=100, out_features=1, bias=True)\n", 529 | " (y131o): Linear(in_features=100, out_features=1, bias=True)\n", 530 | " (y132o): Linear(in_features=100, out_features=1, bias=True)\n", 531 | " (y133o): Linear(in_features=100, out_features=1, bias=True)\n", 532 | " (y134o): Linear(in_features=100, out_features=1, bias=True)\n", 533 | " (y135o): Linear(in_features=100, out_features=1, bias=True)\n", 534 | " (y136o): Linear(in_features=100, out_features=1, bias=True)\n", 535 | " (y137o): Linear(in_features=100, out_features=1, bias=True)\n", 536 | " (y138o): Linear(in_features=100, out_features=1, bias=True)\n", 
537 | " (y139o): Linear(in_features=100, out_features=1, bias=True)\n", 538 | " (y140o): Linear(in_features=100, out_features=1, bias=True)\n", 539 | " (y141o): Linear(in_features=100, out_features=1, bias=True)\n", 540 | " (y142o): Linear(in_features=100, out_features=1, bias=True)\n", 541 | " (y143o): Linear(in_features=100, out_features=1, bias=True)\n", 542 | " (y144o): Linear(in_features=100, out_features=1, bias=True)\n", 543 | " (y145o): Linear(in_features=100, out_features=1, bias=True)\n", 544 | " (y146o): Linear(in_features=100, out_features=1, bias=True)\n", 545 | " (y147o): Linear(in_features=100, out_features=1, bias=True)\n", 546 | " (y148o): Linear(in_features=100, out_features=1, bias=True)\n", 547 | " (y149o): Linear(in_features=100, out_features=1, bias=True)\n", 548 | " (y150o): Linear(in_features=100, out_features=1, bias=True)\n", 549 | " (y151o): Linear(in_features=100, out_features=1, bias=True)\n", 550 | " (y152o): Linear(in_features=100, out_features=1, bias=True)\n", 551 | " (y153o): Linear(in_features=100, out_features=1, bias=True)\n", 552 | " (y154o): Linear(in_features=100, out_features=1, bias=True)\n", 553 | " (y155o): Linear(in_features=100, out_features=1, bias=True)\n", 554 | " (y156o): Linear(in_features=100, out_features=1, bias=True)\n", 555 | " (y157o): Linear(in_features=100, out_features=1, bias=True)\n", 556 | " (y158o): Linear(in_features=100, out_features=1, bias=True)\n", 557 | " (y159o): Linear(in_features=100, out_features=1, bias=True)\n", 558 | " (y160o): Linear(in_features=100, out_features=1, bias=True)\n", 559 | " (y161o): Linear(in_features=100, out_features=1, bias=True)\n", 560 | " (y162o): Linear(in_features=100, out_features=1, bias=True)\n", 561 | " (y163o): Linear(in_features=100, out_features=1, bias=True)\n", 562 | " (y164o): Linear(in_features=100, out_features=1, bias=True)\n", 563 | " (y165o): Linear(in_features=100, out_features=1, bias=True)\n", 564 | " (y166o): Linear(in_features=100, 
out_features=1, bias=True)\n", 565 | " (y167o): Linear(in_features=100, out_features=1, bias=True)\n", 566 | " (y168o): Linear(in_features=100, out_features=1, bias=True)\n", 567 | " (y169o): Linear(in_features=100, out_features=1, bias=True)\n", 568 | " (y170o): Linear(in_features=100, out_features=1, bias=True)\n", 569 | " (y171o): Linear(in_features=100, out_features=1, bias=True)\n", 570 | " (y172o): Linear(in_features=100, out_features=1, bias=True)\n", 571 | " (y173o): Linear(in_features=100, out_features=1, bias=True)\n", 572 | " (y174o): Linear(in_features=100, out_features=1, bias=True)\n", 573 | " (y175o): Linear(in_features=100, out_features=1, bias=True)\n", 574 | " (y176o): Linear(in_features=100, out_features=1, bias=True)\n", 575 | " (y177o): Linear(in_features=100, out_features=1, bias=True)\n", 576 | " (y178o): Linear(in_features=100, out_features=1, bias=True)\n", 577 | " (y179o): Linear(in_features=100, out_features=1, bias=True)\n", 578 | " (y180o): Linear(in_features=100, out_features=1, bias=True)\n", 579 | " (y181o): Linear(in_features=100, out_features=1, bias=True)\n", 580 | " (y182o): Linear(in_features=100, out_features=1, bias=True)\n", 581 | " (y183o): Linear(in_features=100, out_features=1, bias=True)\n", 582 | " (y184o): Linear(in_features=100, out_features=1, bias=True)\n", 583 | " (y185o): Linear(in_features=100, out_features=1, bias=True)\n", 584 | " (y186o): Linear(in_features=100, out_features=1, bias=True)\n", 585 | " (y187o): Linear(in_features=100, out_features=1, bias=True)\n", 586 | " (y188o): Linear(in_features=100, out_features=1, bias=True)\n", 587 | " (y189o): Linear(in_features=100, out_features=1, bias=True)\n", 588 | " (y190o): Linear(in_features=100, out_features=1, bias=True)\n", 589 | " (y191o): Linear(in_features=100, out_features=1, bias=True)\n", 590 | " (y192o): Linear(in_features=100, out_features=1, bias=True)\n", 591 | " (y193o): Linear(in_features=100, out_features=1, bias=True)\n", 592 | " (y194o): 
Linear(in_features=100, out_features=1, bias=True)\n", 593 | " (y195o): Linear(in_features=100, out_features=1, bias=True)\n", 594 | " (y196o): Linear(in_features=100, out_features=1, bias=True)\n", 595 | " (y197o): Linear(in_features=100, out_features=1, bias=True)\n", 596 | " (y198o): Linear(in_features=100, out_features=1, bias=True)\n", 597 | " (y199o): Linear(in_features=100, out_features=1, bias=True)\n", 598 | " (y200o): Linear(in_features=100, out_features=1, bias=True)\n", 599 | " (y201o): Linear(in_features=100, out_features=1, bias=True)\n", 600 | " (y202o): Linear(in_features=100, out_features=1, bias=True)\n", 601 | " (y203o): Linear(in_features=100, out_features=1, bias=True)\n", 602 | " (y204o): Linear(in_features=100, out_features=1, bias=True)\n", 603 | " (y205o): Linear(in_features=100, out_features=1, bias=True)\n", 604 | " (y206o): Linear(in_features=100, out_features=1, bias=True)\n", 605 | " (y207o): Linear(in_features=100, out_features=1, bias=True)\n", 606 | " (y208o): Linear(in_features=100, out_features=1, bias=True)\n", 607 | " (y209o): Linear(in_features=100, out_features=1, bias=True)\n", 608 | " (y210o): Linear(in_features=100, out_features=1, bias=True)\n", 609 | " (y211o): Linear(in_features=100, out_features=1, bias=True)\n", 610 | " (y212o): Linear(in_features=100, out_features=1, bias=True)\n", 611 | " (y213o): Linear(in_features=100, out_features=1, bias=True)\n", 612 | " (y214o): Linear(in_features=100, out_features=1, bias=True)\n", 613 | " (y215o): Linear(in_features=100, out_features=1, bias=True)\n", 614 | " (y216o): Linear(in_features=100, out_features=1, bias=True)\n", 615 | " (y217o): Linear(in_features=100, out_features=1, bias=True)\n", 616 | " (y218o): Linear(in_features=100, out_features=1, bias=True)\n", 617 | " (y219o): Linear(in_features=100, out_features=1, bias=True)\n", 618 | " (y220o): Linear(in_features=100, out_features=1, bias=True)\n", 619 | " (y221o): Linear(in_features=100, out_features=1, bias=True)\n", 
620 | " (y222o): Linear(in_features=100, out_features=1, bias=True)\n", 621 | " (y223o): Linear(in_features=100, out_features=1, bias=True)\n", 622 | " (y224o): Linear(in_features=100, out_features=1, bias=True)\n", 623 | " (y225o): Linear(in_features=100, out_features=1, bias=True)\n", 624 | " (y226o): Linear(in_features=100, out_features=1, bias=True)\n", 625 | " (y227o): Linear(in_features=100, out_features=1, bias=True)\n", 626 | " (y228o): Linear(in_features=100, out_features=1, bias=True)\n", 627 | " (y229o): Linear(in_features=100, out_features=1, bias=True)\n", 628 | " (y230o): Linear(in_features=100, out_features=1, bias=True)\n", 629 | " (y231o): Linear(in_features=100, out_features=1, bias=True)\n", 630 | " (y232o): Linear(in_features=100, out_features=1, bias=True)\n", 631 | " (y233o): Linear(in_features=100, out_features=1, bias=True)\n", 632 | " (y234o): Linear(in_features=100, out_features=1, bias=True)\n", 633 | " (y235o): Linear(in_features=100, out_features=1, bias=True)\n", 634 | " (y236o): Linear(in_features=100, out_features=1, bias=True)\n", 635 | " (y237o): Linear(in_features=100, out_features=1, bias=True)\n", 636 | " (y238o): Linear(in_features=100, out_features=1, bias=True)\n", 637 | " (y239o): Linear(in_features=100, out_features=1, bias=True)\n", 638 | " (y240o): Linear(in_features=100, out_features=1, bias=True)\n", 639 | " (y241o): Linear(in_features=100, out_features=1, bias=True)\n", 640 | " (y242o): Linear(in_features=100, out_features=1, bias=True)\n", 641 | " (y243o): Linear(in_features=100, out_features=1, bias=True)\n", 642 | " (y244o): Linear(in_features=100, out_features=1, bias=True)\n", 643 | " (y245o): Linear(in_features=100, out_features=1, bias=True)\n", 644 | " (y246o): Linear(in_features=100, out_features=1, bias=True)\n", 645 | " (y247o): Linear(in_features=100, out_features=1, bias=True)\n", 646 | " (y248o): Linear(in_features=100, out_features=1, bias=True)\n", 647 | " (y249o): Linear(in_features=100, 
out_features=1, bias=True)\n", 648 | " (y250o): Linear(in_features=100, out_features=1, bias=True)\n", 649 | " (y251o): Linear(in_features=100, out_features=1, bias=True)\n", 650 | " (y252o): Linear(in_features=100, out_features=1, bias=True)\n", 651 | " (y253o): Linear(in_features=100, out_features=1, bias=True)\n", 652 | " (y254o): Linear(in_features=100, out_features=1, bias=True)\n", 653 | " (y255o): Linear(in_features=100, out_features=1, bias=True)\n", 654 | " (y256o): Linear(in_features=100, out_features=1, bias=True)\n", 655 | " (y257o): Linear(in_features=100, out_features=1, bias=True)\n", 656 | " (y258o): Linear(in_features=100, out_features=1, bias=True)\n", 657 | " (y259o): Linear(in_features=100, out_features=1, bias=True)\n", 658 | " (y260o): Linear(in_features=100, out_features=1, bias=True)\n", 659 | " (y261o): Linear(in_features=100, out_features=1, bias=True)\n", 660 | " (y262o): Linear(in_features=100, out_features=1, bias=True)\n", 661 | " (y263o): Linear(in_features=100, out_features=1, bias=True)\n", 662 | " (y264o): Linear(in_features=100, out_features=1, bias=True)\n", 663 | " (y265o): Linear(in_features=100, out_features=1, bias=True)\n", 664 | " (y266o): Linear(in_features=100, out_features=1, bias=True)\n", 665 | " (y267o): Linear(in_features=100, out_features=1, bias=True)\n", 666 | " (y268o): Linear(in_features=100, out_features=1, bias=True)\n", 667 | " (y269o): Linear(in_features=100, out_features=1, bias=True)\n", 668 | " (y270o): Linear(in_features=100, out_features=1, bias=True)\n", 669 | " (y271o): Linear(in_features=100, out_features=1, bias=True)\n", 670 | " (y272o): Linear(in_features=100, out_features=1, bias=True)\n", 671 | " (y273o): Linear(in_features=100, out_features=1, bias=True)\n", 672 | " (y274o): Linear(in_features=100, out_features=1, bias=True)\n", 673 | " (y275o): Linear(in_features=100, out_features=1, bias=True)\n", 674 | " (y276o): Linear(in_features=100, out_features=1, bias=True)\n", 675 | " (y277o): 
Linear(in_features=100, out_features=1, bias=True)\n", 676 | " (y278o): Linear(in_features=100, out_features=1, bias=True)\n", 677 | " (y279o): Linear(in_features=100, out_features=1, bias=True)\n", 678 | " (y280o): Linear(in_features=100, out_features=1, bias=True)\n", 679 | " (y281o): Linear(in_features=100, out_features=1, bias=True)\n", 680 | " (y282o): Linear(in_features=100, out_features=1, bias=True)\n", 681 | " (y283o): Linear(in_features=100, out_features=1, bias=True)\n", 682 | " (y284o): Linear(in_features=100, out_features=1, bias=True)\n", 683 | " (y285o): Linear(in_features=100, out_features=1, bias=True)\n", 684 | " (y286o): Linear(in_features=100, out_features=1, bias=True)\n", 685 | " (y287o): Linear(in_features=100, out_features=1, bias=True)\n", 686 | " (y288o): Linear(in_features=100, out_features=1, bias=True)\n", 687 | " (y289o): Linear(in_features=100, out_features=1, bias=True)\n", 688 | " (y290o): Linear(in_features=100, out_features=1, bias=True)\n", 689 | " (y291o): Linear(in_features=100, out_features=1, bias=True)\n", 690 | " (y292o): Linear(in_features=100, out_features=1, bias=True)\n", 691 | " (y293o): Linear(in_features=100, out_features=1, bias=True)\n", 692 | " (y294o): Linear(in_features=100, out_features=1, bias=True)\n", 693 | " (y295o): Linear(in_features=100, out_features=1, bias=True)\n", 694 | " (y296o): Linear(in_features=100, out_features=1, bias=True)\n", 695 | " (y297o): Linear(in_features=100, out_features=1, bias=True)\n", 696 | " (y298o): Linear(in_features=100, out_features=1, bias=True)\n", 697 | " (y299o): Linear(in_features=100, out_features=1, bias=True)\n", 698 | " (y300o): Linear(in_features=100, out_features=1, bias=True)\n", 699 | " (y301o): Linear(in_features=100, out_features=1, bias=True)\n", 700 | " (y302o): Linear(in_features=100, out_features=1, bias=True)\n", 701 | " (y303o): Linear(in_features=100, out_features=1, bias=True)\n", 702 | " (y304o): Linear(in_features=100, out_features=1, bias=True)\n", 
703 | " (y305o): Linear(in_features=100, out_features=1, bias=True)\n", 704 | " (y306o): Linear(in_features=100, out_features=1, bias=True)\n", 705 | " (y307o): Linear(in_features=100, out_features=1, bias=True)\n", 706 | " (y308o): Linear(in_features=100, out_features=1, bias=True)\n", 707 | " (y309o): Linear(in_features=100, out_features=1, bias=True)\n", 708 | " (y310o): Linear(in_features=100, out_features=1, bias=True)\n", 709 | " (y311o): Linear(in_features=100, out_features=1, bias=True)\n", 710 | " (y312o): Linear(in_features=100, out_features=1, bias=True)\n", 711 | " (y313o): Linear(in_features=100, out_features=1, bias=True)\n", 712 | " (y314o): Linear(in_features=100, out_features=1, bias=True)\n", 713 | " (y315o): Linear(in_features=100, out_features=1, bias=True)\n", 714 | " (y316o): Linear(in_features=100, out_features=1, bias=True)\n", 715 | " (y317o): Linear(in_features=100, out_features=1, bias=True)\n", 716 | " (y318o): Linear(in_features=100, out_features=1, bias=True)\n", 717 | " (y319o): Linear(in_features=100, out_features=1, bias=True)\n", 718 | " (y320o): Linear(in_features=100, out_features=1, bias=True)\n", 719 | " (y321o): Linear(in_features=100, out_features=1, bias=True)\n", 720 | " (y322o): Linear(in_features=100, out_features=1, bias=True)\n", 721 | " (y323o): Linear(in_features=100, out_features=1, bias=True)\n", 722 | " (y324o): Linear(in_features=100, out_features=1, bias=True)\n", 723 | " (y325o): Linear(in_features=100, out_features=1, bias=True)\n", 724 | " (y326o): Linear(in_features=100, out_features=1, bias=True)\n", 725 | " (y327o): Linear(in_features=100, out_features=1, bias=True)\n", 726 | " (y328o): Linear(in_features=100, out_features=1, bias=True)\n", 727 | " (y329o): Linear(in_features=100, out_features=1, bias=True)\n", 728 | " (y330o): Linear(in_features=100, out_features=1, bias=True)\n", 729 | " (y331o): Linear(in_features=100, out_features=1, bias=True)\n", 730 | " (y332o): Linear(in_features=100, 
out_features=1, bias=True)\n", 731 | " (y333o): Linear(in_features=100, out_features=1, bias=True)\n", 732 | " (y334o): Linear(in_features=100, out_features=1, bias=True)\n", 733 | " (y335o): Linear(in_features=100, out_features=1, bias=True)\n", 734 | " (y336o): Linear(in_features=100, out_features=1, bias=True)\n", 735 | " (y337o): Linear(in_features=100, out_features=1, bias=True)\n", 736 | " (y338o): Linear(in_features=100, out_features=1, bias=True)\n", 737 | " (y339o): Linear(in_features=100, out_features=1, bias=True)\n", 738 | " (y340o): Linear(in_features=100, out_features=1, bias=True)\n", 739 | " (y341o): Linear(in_features=100, out_features=1, bias=True)\n", 740 | " (y342o): Linear(in_features=100, out_features=1, bias=True)\n", 741 | " (y343o): Linear(in_features=100, out_features=1, bias=True)\n", 742 | " (y344o): Linear(in_features=100, out_features=1, bias=True)\n", 743 | " (y345o): Linear(in_features=100, out_features=1, bias=True)\n", 744 | " (y346o): Linear(in_features=100, out_features=1, bias=True)\n", 745 | " (y347o): Linear(in_features=100, out_features=1, bias=True)\n", 746 | " (y348o): Linear(in_features=100, out_features=1, bias=True)\n", 747 | " (y349o): Linear(in_features=100, out_features=1, bias=True)\n", 748 | " (y350o): Linear(in_features=100, out_features=1, bias=True)\n", 749 | " (y351o): Linear(in_features=100, out_features=1, bias=True)\n", 750 | " (y352o): Linear(in_features=100, out_features=1, bias=True)\n", 751 | " (y353o): Linear(in_features=100, out_features=1, bias=True)\n", 752 | " (y354o): Linear(in_features=100, out_features=1, bias=True)\n", 753 | " (y355o): Linear(in_features=100, out_features=1, bias=True)\n", 754 | " (y356o): Linear(in_features=100, out_features=1, bias=True)\n", 755 | " (y357o): Linear(in_features=100, out_features=1, bias=True)\n", 756 | " (y358o): Linear(in_features=100, out_features=1, bias=True)\n", 757 | " (y359o): Linear(in_features=100, out_features=1, bias=True)\n", 758 | " (y360o): 
Linear(in_features=100, out_features=1, bias=True)\n", 759 | " (y361o): Linear(in_features=100, out_features=1, bias=True)\n", 760 | " (y362o): Linear(in_features=100, out_features=1, bias=True)\n", 761 | " (y363o): Linear(in_features=100, out_features=1, bias=True)\n", 762 | " (y364o): Linear(in_features=100, out_features=1, bias=True)\n", 763 | " (y365o): Linear(in_features=100, out_features=1, bias=True)\n", 764 | " (y366o): Linear(in_features=100, out_features=1, bias=True)\n", 765 | " (y367o): Linear(in_features=100, out_features=1, bias=True)\n", 766 | " (y368o): Linear(in_features=100, out_features=1, bias=True)\n", 767 | " (y369o): Linear(in_features=100, out_features=1, bias=True)\n", 768 | " (y370o): Linear(in_features=100, out_features=1, bias=True)\n", 769 | " (y371o): Linear(in_features=100, out_features=1, bias=True)\n", 770 | " (y372o): Linear(in_features=100, out_features=1, bias=True)\n", 771 | " (y373o): Linear(in_features=100, out_features=1, bias=True)\n", 772 | " (y374o): Linear(in_features=100, out_features=1, bias=True)\n", 773 | " (y375o): Linear(in_features=100, out_features=1, bias=True)\n", 774 | " (y376o): Linear(in_features=100, out_features=1, bias=True)\n", 775 | " (y377o): Linear(in_features=100, out_features=1, bias=True)\n", 776 | " (y378o): Linear(in_features=100, out_features=1, bias=True)\n", 777 | " (y379o): Linear(in_features=100, out_features=1, bias=True)\n", 778 | " (y380o): Linear(in_features=100, out_features=1, bias=True)\n", 779 | " (y381o): Linear(in_features=100, out_features=1, bias=True)\n", 780 | " (y382o): Linear(in_features=100, out_features=1, bias=True)\n", 781 | " (y383o): Linear(in_features=100, out_features=1, bias=True)\n", 782 | " (y384o): Linear(in_features=100, out_features=1, bias=True)\n", 783 | " (y385o): Linear(in_features=100, out_features=1, bias=True)\n", 784 | " (y386o): Linear(in_features=100, out_features=1, bias=True)\n", 785 | " (y387o): Linear(in_features=100, out_features=1, bias=True)\n", 
786 | " (y388o): Linear(in_features=100, out_features=1, bias=True)\n", 787 | " (y389o): Linear(in_features=100, out_features=1, bias=True)\n", 788 | " (y390o): Linear(in_features=100, out_features=1, bias=True)\n", 789 | " (y391o): Linear(in_features=100, out_features=1, bias=True)\n", 790 | " (y392o): Linear(in_features=100, out_features=1, bias=True)\n", 791 | " (y393o): Linear(in_features=100, out_features=1, bias=True)\n", 792 | " (y394o): Linear(in_features=100, out_features=1, bias=True)\n", 793 | " (y395o): Linear(in_features=100, out_features=1, bias=True)\n", 794 | " (y396o): Linear(in_features=100, out_features=1, bias=True)\n", 795 | " (y397o): Linear(in_features=100, out_features=1, bias=True)\n", 796 | " (y398o): Linear(in_features=100, out_features=1, bias=True)\n", 797 | " (y399o): Linear(in_features=100, out_features=1, bias=True)\n", 798 | " (y400o): Linear(in_features=100, out_features=1, bias=True)\n", 799 | " (y401o): Linear(in_features=100, out_features=1, bias=True)\n", 800 | " (y402o): Linear(in_features=100, out_features=1, bias=True)\n", 801 | " (y403o): Linear(in_features=100, out_features=1, bias=True)\n", 802 | " (y404o): Linear(in_features=100, out_features=1, bias=True)\n", 803 | " (y405o): Linear(in_features=100, out_features=1, bias=True)\n", 804 | " (y406o): Linear(in_features=100, out_features=1, bias=True)\n", 805 | " (y407o): Linear(in_features=100, out_features=1, bias=True)\n", 806 | " (y408o): Linear(in_features=100, out_features=1, bias=True)\n", 807 | " (y409o): Linear(in_features=100, out_features=1, bias=True)\n", 808 | " (y410o): Linear(in_features=100, out_features=1, bias=True)\n", 809 | " (y411o): Linear(in_features=100, out_features=1, bias=True)\n", 810 | " (y412o): Linear(in_features=100, out_features=1, bias=True)\n", 811 | " (y413o): Linear(in_features=100, out_features=1, bias=True)\n", 812 | " (y414o): Linear(in_features=100, out_features=1, bias=True)\n", 813 | " (y415o): Linear(in_features=100, 
out_features=1, bias=True)\n", 814 | " (y416o): Linear(in_features=100, out_features=1, bias=True)\n", 815 | " (y417o): Linear(in_features=100, out_features=1, bias=True)\n", 816 | " (y418o): Linear(in_features=100, out_features=1, bias=True)\n", 817 | " (y419o): Linear(in_features=100, out_features=1, bias=True)\n", 818 | " (y420o): Linear(in_features=100, out_features=1, bias=True)\n", 819 | " (y421o): Linear(in_features=100, out_features=1, bias=True)\n", 820 | " (y422o): Linear(in_features=100, out_features=1, bias=True)\n", 821 | " (y423o): Linear(in_features=100, out_features=1, bias=True)\n", 822 | " (y424o): Linear(in_features=100, out_features=1, bias=True)\n", 823 | " (y425o): Linear(in_features=100, out_features=1, bias=True)\n", 824 | " (y426o): Linear(in_features=100, out_features=1, bias=True)\n", 825 | " (y427o): Linear(in_features=100, out_features=1, bias=True)\n", 826 | " (y428o): Linear(in_features=100, out_features=1, bias=True)\n", 827 | " (y429o): Linear(in_features=100, out_features=1, bias=True)\n", 828 | " (y430o): Linear(in_features=100, out_features=1, bias=True)\n", 829 | " (y431o): Linear(in_features=100, out_features=1, bias=True)\n", 830 | " (y432o): Linear(in_features=100, out_features=1, bias=True)\n", 831 | " (y433o): Linear(in_features=100, out_features=1, bias=True)\n", 832 | " (y434o): Linear(in_features=100, out_features=1, bias=True)\n", 833 | " (y435o): Linear(in_features=100, out_features=1, bias=True)\n", 834 | " (y436o): Linear(in_features=100, out_features=1, bias=True)\n", 835 | " (y437o): Linear(in_features=100, out_features=1, bias=True)\n", 836 | " (y438o): Linear(in_features=100, out_features=1, bias=True)\n", 837 | " (y439o): Linear(in_features=100, out_features=1, bias=True)\n", 838 | " (y440o): Linear(in_features=100, out_features=1, bias=True)\n", 839 | " (y441o): Linear(in_features=100, out_features=1, bias=True)\n", 840 | " (y442o): Linear(in_features=100, out_features=1, bias=True)\n", 841 | " (y443o): 
Linear(in_features=100, out_features=1, bias=True)\n", 842 | " (y444o): Linear(in_features=100, out_features=1, bias=True)\n", 843 | " (y445o): Linear(in_features=100, out_features=1, bias=True)\n", 844 | " (y446o): Linear(in_features=100, out_features=1, bias=True)\n", 845 | " (y447o): Linear(in_features=100, out_features=1, bias=True)\n", 846 | " (y448o): Linear(in_features=100, out_features=1, bias=True)\n", 847 | " (y449o): Linear(in_features=100, out_features=1, bias=True)\n", 848 | " (y450o): Linear(in_features=100, out_features=1, bias=True)\n", 849 | " (y451o): Linear(in_features=100, out_features=1, bias=True)\n", 850 | " (y452o): Linear(in_features=100, out_features=1, bias=True)\n", 851 | " (y453o): Linear(in_features=100, out_features=1, bias=True)\n", 852 | " (y454o): Linear(in_features=100, out_features=1, bias=True)\n", 853 | " (y455o): Linear(in_features=100, out_features=1, bias=True)\n", 854 | " (y456o): Linear(in_features=100, out_features=1, bias=True)\n", 855 | " (y457o): Linear(in_features=100, out_features=1, bias=True)\n", 856 | " (y458o): Linear(in_features=100, out_features=1, bias=True)\n", 857 | " (y459o): Linear(in_features=100, out_features=1, bias=True)\n", 858 | " (y460o): Linear(in_features=100, out_features=1, bias=True)\n", 859 | " (y461o): Linear(in_features=100, out_features=1, bias=True)\n", 860 | " (y462o): Linear(in_features=100, out_features=1, bias=True)\n", 861 | " (y463o): Linear(in_features=100, out_features=1, bias=True)\n", 862 | " (y464o): Linear(in_features=100, out_features=1, bias=True)\n", 863 | " (y465o): Linear(in_features=100, out_features=1, bias=True)\n", 864 | " (y466o): Linear(in_features=100, out_features=1, bias=True)\n", 865 | " (y467o): Linear(in_features=100, out_features=1, bias=True)\n", 866 | " (y468o): Linear(in_features=100, out_features=1, bias=True)\n", 867 | " (y469o): Linear(in_features=100, out_features=1, bias=True)\n", 868 | " (y470o): Linear(in_features=100, out_features=1, bias=True)\n", 
869 | " (y471o): Linear(in_features=100, out_features=1, bias=True)\n", 870 | " (y472o): Linear(in_features=100, out_features=1, bias=True)\n", 871 | " (y473o): Linear(in_features=100, out_features=1, bias=True)\n", 872 | " (y474o): Linear(in_features=100, out_features=1, bias=True)\n", 873 | " (y475o): Linear(in_features=100, out_features=1, bias=True)\n", 874 | " (y476o): Linear(in_features=100, out_features=1, bias=True)\n", 875 | " (y477o): Linear(in_features=100, out_features=1, bias=True)\n", 876 | " (y478o): Linear(in_features=100, out_features=1, bias=True)\n", 877 | " (y479o): Linear(in_features=100, out_features=1, bias=True)\n", 878 | " (y480o): Linear(in_features=100, out_features=1, bias=True)\n", 879 | " (y481o): Linear(in_features=100, out_features=1, bias=True)\n", 880 | " (y482o): Linear(in_features=100, out_features=1, bias=True)\n", 881 | " (y483o): Linear(in_features=100, out_features=1, bias=True)\n", 882 | " (y484o): Linear(in_features=100, out_features=1, bias=True)\n", 883 | " (y485o): Linear(in_features=100, out_features=1, bias=True)\n", 884 | " (y486o): Linear(in_features=100, out_features=1, bias=True)\n", 885 | " (y487o): Linear(in_features=100, out_features=1, bias=True)\n", 886 | " (y488o): Linear(in_features=100, out_features=1, bias=True)\n", 887 | " (y489o): Linear(in_features=100, out_features=1, bias=True)\n", 888 | " (y490o): Linear(in_features=100, out_features=1, bias=True)\n", 889 | " (y491o): Linear(in_features=100, out_features=1, bias=True)\n", 890 | " (y492o): Linear(in_features=100, out_features=1, bias=True)\n", 891 | " (y493o): Linear(in_features=100, out_features=1, bias=True)\n", 892 | " (y494o): Linear(in_features=100, out_features=1, bias=True)\n", 893 | " (y495o): Linear(in_features=100, out_features=1, bias=True)\n", 894 | " (y496o): Linear(in_features=100, out_features=1, bias=True)\n", 895 | " (y497o): Linear(in_features=100, out_features=1, bias=True)\n", 896 | " (y498o): Linear(in_features=100, 
out_features=1, bias=True)\n", 897 | " (y499o): Linear(in_features=100, out_features=1, bias=True)\n", 898 | " (y500o): Linear(in_features=100, out_features=1, bias=True)\n", 899 | " (y501o): Linear(in_features=100, out_features=1, bias=True)\n", 900 | " (y502o): Linear(in_features=100, out_features=1, bias=True)\n", 901 | " (y503o): Linear(in_features=100, out_features=1, bias=True)\n", 902 | " (y504o): Linear(in_features=100, out_features=1, bias=True)\n", 903 | " (y505o): Linear(in_features=100, out_features=1, bias=True)\n", 904 | " (y506o): Linear(in_features=100, out_features=1, bias=True)\n", 905 | " (y507o): Linear(in_features=100, out_features=1, bias=True)\n", 906 | " (y508o): Linear(in_features=100, out_features=1, bias=True)\n", 907 | " (y509o): Linear(in_features=100, out_features=1, bias=True)\n", 908 | " (y510o): Linear(in_features=100, out_features=1, bias=True)\n", 909 | " (y511o): Linear(in_features=100, out_features=1, bias=True)\n", 910 | " (y512o): Linear(in_features=100, out_features=1, bias=True)\n", 911 | " (y513o): Linear(in_features=100, out_features=1, bias=True)\n", 912 | " (y514o): Linear(in_features=100, out_features=1, bias=True)\n", 913 | " (y515o): Linear(in_features=100, out_features=1, bias=True)\n", 914 | " (y516o): Linear(in_features=100, out_features=1, bias=True)\n", 915 | " (y517o): Linear(in_features=100, out_features=1, bias=True)\n", 916 | " (y518o): Linear(in_features=100, out_features=1, bias=True)\n", 917 | " (y519o): Linear(in_features=100, out_features=1, bias=True)\n", 918 | " (y520o): Linear(in_features=100, out_features=1, bias=True)\n", 919 | " (y521o): Linear(in_features=100, out_features=1, bias=True)\n", 920 | " (y522o): Linear(in_features=100, out_features=1, bias=True)\n", 921 | " (y523o): Linear(in_features=100, out_features=1, bias=True)\n", 922 | " (y524o): Linear(in_features=100, out_features=1, bias=True)\n", 923 | " (y525o): Linear(in_features=100, out_features=1, bias=True)\n", 924 | " (y526o): 
Linear(in_features=100, out_features=1, bias=True)\n", 925 | " (y527o): Linear(in_features=100, out_features=1, bias=True)\n", 926 | " (y528o): Linear(in_features=100, out_features=1, bias=True)\n", 927 | " (y529o): Linear(in_features=100, out_features=1, bias=True)\n", 928 | " (y530o): Linear(in_features=100, out_features=1, bias=True)\n", 929 | " (y531o): Linear(in_features=100, out_features=1, bias=True)\n", 930 | " (y532o): Linear(in_features=100, out_features=1, bias=True)\n", 931 | " (y533o): Linear(in_features=100, out_features=1, bias=True)\n", 932 | " (y534o): Linear(in_features=100, out_features=1, bias=True)\n", 933 | " (y535o): Linear(in_features=100, out_features=1, bias=True)\n", 934 | " (y536o): Linear(in_features=100, out_features=1, bias=True)\n", 935 | " (y537o): Linear(in_features=100, out_features=1, bias=True)\n", 936 | " (y538o): Linear(in_features=100, out_features=1, bias=True)\n", 937 | " (y539o): Linear(in_features=100, out_features=1, bias=True)\n", 938 | " (y540o): Linear(in_features=100, out_features=1, bias=True)\n", 939 | " (y541o): Linear(in_features=100, out_features=1, bias=True)\n", 940 | " (y542o): Linear(in_features=100, out_features=1, bias=True)\n", 941 | " (y543o): Linear(in_features=100, out_features=1, bias=True)\n", 942 | " (y544o): Linear(in_features=100, out_features=1, bias=True)\n", 943 | " (y545o): Linear(in_features=100, out_features=1, bias=True)\n", 944 | " (y546o): Linear(in_features=100, out_features=1, bias=True)\n", 945 | " (y547o): Linear(in_features=100, out_features=1, bias=True)\n", 946 | " (y548o): Linear(in_features=100, out_features=1, bias=True)\n", 947 | " (y549o): Linear(in_features=100, out_features=1, bias=True)\n", 948 | " (y550o): Linear(in_features=100, out_features=1, bias=True)\n", 949 | " (y551o): Linear(in_features=100, out_features=1, bias=True)\n", 950 | " (y552o): Linear(in_features=100, out_features=1, bias=True)\n", 951 | " (y553o): Linear(in_features=100, out_features=1, bias=True)\n", 
952 | " (y554o): Linear(in_features=100, out_features=1, bias=True)\n", 953 | " (y555o): Linear(in_features=100, out_features=1, bias=True)\n", 954 | " (y556o): Linear(in_features=100, out_features=1, bias=True)\n", 955 | " (y557o): Linear(in_features=100, out_features=1, bias=True)\n", 956 | " (y558o): Linear(in_features=100, out_features=1, bias=True)\n", 957 | " (y559o): Linear(in_features=100, out_features=1, bias=True)\n", 958 | ")" 959 | ] 960 | }, 961 | "execution_count": 10, 962 | "metadata": {}, 963 | "output_type": "execute_result" 964 | } 965 | ], 966 | "source": [ 967 | "model = ChEMBLMultiTask(560) # number of tasks\n", 968 | "model.load_state_dict(torch.load(f\"./{MODEL_FILE}\"))\n", 969 | "model.eval()" 970 | ] 971 | }, 972 | { 973 | "cell_type": "code", 974 | "execution_count": null, 975 | "metadata": {}, 976 | "outputs": [], 977 | "source": [] 978 | } 979 | ], 980 | "metadata": { 981 | "kernelspec": { 982 | "display_name": "Python 3", 983 | "language": "python", 984 | "name": "python3" 985 | }, 986 | "language_info": { 987 | "codemirror_mode": { 988 | "name": "ipython", 989 | "version": 3 990 | }, 991 | "file_extension": ".py", 992 | "mimetype": "text/x-python", 993 | "name": "python", 994 | "nbconvert_exporter": "python", 995 | "pygments_lexer": "ipython3", 996 | "version": "3.6.8" 997 | } 998 | }, 999 | "nbformat": 4, 1000 | "nbformat_minor": 2 1001 | } 1002 | --------------------------------------------------------------------------------