├── README.md └── COVID_19_detection.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # covid-19-detection-project 2 | 3 | ## Jupyter notebook for the blog post: 4 | 5 | [تشخیص کرونا با یادگیری عمیق](https://howsam.org/covid19-detection/) 6 | -------------------------------------------------------------------------------- /COVID_19_detection.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "COVID-19-detection.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "toc_visible": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "gTv0ayLu_Nsf", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "# **Download data from github**" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "metadata": { 31 | "id": "VPw6cHEZO-SJ", 32 | "colab_type": "code", 33 | "colab": { 34 | "base_uri": "https://localhost:8080/", 35 | "height": 150 36 | }, 37 | "outputId": "1b1137c0-6f00-4de2-b3cc-15a5eec8f5b0" 38 | }, 39 | "source": [ 40 | "! git clone https://github.com/ieee8023/covid-chestxray-dataset" 41 | ], 42 | "execution_count": 1, 43 | "outputs": [ 44 | { 45 | "output_type": "stream", 46 | "text": [ 47 | "Cloning into 'covid-chestxray-dataset'...\n", 48 | "remote: Enumerating objects: 36, done.\u001b[K\n", 49 | "remote: Counting objects: 100% (36/36), done.\u001b[K\n", 50 | "remote: Compressing objects: 100% (28/28), done.\u001b[K\n", 51 | "remote: Total 3651 (delta 13), reused 21 (delta 8), pack-reused 3615\u001b[K\n", 52 | "Receiving objects: 100% (3651/3651), 632.29 MiB | 42.82 MiB/s, done.\n", 53 | "Resolving deltas: 100% (1461/1461), done.\n", 54 | "Checking out files: 100% (1164/1164), done.\n" 55 | ], 56 | "name": "stdout" 57 | } 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": { 63 | "id": "HSq5-DLR_fXl", 64 | "colab_type": "text" 65 | }, 66 | "source": [ 67 | "# **create a folder named dataset and ...**" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "metadata": { 73 | "id": "VL9QTzwsPu2z", 74 | "colab_type": "code", 75 | "colab": {} 76 | }, 77 | "source": [ 78 | "! mkdir dataset\n", 79 | "! mkdir /content/dataset/train\n", 80 | "! mkdir /content/dataset/train/covid\n", 81 | "! mkdir /content/dataset/train/normal\n", 82 | "! mkdir /content/dataset/test\n", 83 | "! mkdir /content/dataset/test/covid\n", 84 | "! mkdir /content/dataset/test/normal" 85 | ], 86 | "execution_count": 2, 87 | "outputs": [] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": { 92 | "id": "ltMFV5BLAQIG", 93 | "colab_type": "text" 94 | }, 95 | "source": [ 96 | "# **Download data from kaggle**\n", 97 | "from kaggle - my account press \"create new token api\". upload downloaded file here." 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "metadata": { 103 | "id": "_lo53zC6SAIR", 104 | "colab_type": "code", 105 | "colab": { 106 | "resources": { 107 | "http://localhost:8080/nbextensions/google.colab/files.js": { 108 | "data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSBQeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCkgewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0YS5ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwogICAgICBwZXJjZW50LnRleHRDb250ZW50ID0KICAgICAgICAgIGAke01hdGgucm91bmQoKHBvc2l0aW9uIC8gZmlsZURhdGEuYnl0ZUxlbmd0aCkgKiAxMDApfSUgZG9uZWA7CiAgICB9CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK", 109 | "ok": true, 110 | "headers": [ 111 | [ 112 | "content-type", 113 | "application/javascript" 114 | ] 115 | ], 116 | "status": 200, 117 | "status_text": "" 118 | } 119 | }, 120 | "base_uri": "https://localhost:8080/", 121 | "height": 290 122 | }, 123 | "outputId": "3365e6fd-1f2e-42e4-f45e-8f728a36931f" 124 | }, 125 | "source": [ 126 | "! pip install kaggle\n", 127 | "from google.colab import files\n", 128 | "files.upload()" 129 | ], 130 | "execution_count": 3, 131 | "outputs": [ 132 | { 133 | "output_type": "stream", 134 | "text": [ 135 | "Requirement already satisfied: kaggle in /usr/local/lib/python3.6/dist-packages (1.5.8)\n", 136 | "Requirement already satisfied: slugify in /usr/local/lib/python3.6/dist-packages (from kaggle) (0.0.1)\n", 137 | "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.6/dist-packages (from kaggle) (1.15.0)\n", 138 | "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.6/dist-packages (from kaggle) (2.8.1)\n", 139 | "Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle) (4.0.1)\n", 140 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from kaggle) (4.41.1)\n", 141 | "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle) (1.24.3)\n", 142 | "Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle) (2020.6.20)\n", 143 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from kaggle) (2.23.0)\n", 144 | "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify->kaggle) (1.3)\n", 145 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->kaggle) (2.10)\n", 146 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->kaggle) (3.0.4)\n" 147 | ], 148 | "name": "stdout" 149 | }, 150 | { 151 | "output_type": "display_data", 152 | "data": { 153 | "text/html": [ 154 | "\n", 155 | " \n", 157 | " \n", 158 | " Upload widget is only available when the cell has been executed in the\n", 159 | " current browser session. Please rerun this cell to enable.\n", 160 | " \n", 161 | " " 162 | ], 163 | "text/plain": [ 164 | "" 165 | ] 166 | }, 167 | "metadata": { 168 | "tags": [] 169 | } 170 | }, 171 | { 172 | "output_type": "stream", 173 | "text": [ 174 | "Saving kaggle.json to kaggle.json\n" 175 | ], 176 | "name": "stdout" 177 | }, 178 | { 179 | "output_type": "execute_result", 180 | "data": { 181 | "text/plain": [ 182 | "{'kaggle.json': b'{\"username\":\"golabiabi\",\"key\":\"ccaab0263258794f607a2b32e397be80\"}'}" 183 | ] 184 | }, 185 | "metadata": { 186 | "tags": [] 187 | }, 188 | "execution_count": 3 189 | } 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "metadata": { 195 | "id": "9QFcJfhqSRv8", 196 | "colab_type": "code", 197 | "colab": { 198 | "base_uri": "https://localhost:8080/", 199 | "height": 67 200 | }, 201 | "outputId": "45f905b9-23b8-4b41-eaa1-1d511371dbf9" 202 | }, 203 | "source": [ 204 | "# ******* make root/kaggle folder **********\n", 205 | "! mkdir -p ~/.kaggle\n", 206 | "!cp kaggle.json ~/.kaggle/\n", 207 | "!chmod 600 ~/.kaggle/kaggle.json\n", 208 | "# ************** download data **********\n", 209 | "! kaggle datasets download -d paultimothymooney/chest-xray-pneumonia\n", 210 | "# !unzip\n", 211 | "!unzip -q /content/chest-xray-pneumonia.zip" 212 | ], 213 | "execution_count": 4, 214 | "outputs": [ 215 | { 216 | "output_type": "stream", 217 | "text": [ 218 | "Downloading chest-xray-pneumonia.zip to /content\n", 219 | " 99% 2.27G/2.29G [00:30<00:00, 41.8MB/s]\n", 220 | "100% 2.29G/2.29G [00:30<00:00, 79.8MB/s]\n" 221 | ], 222 | "name": "stdout" 223 | } 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": { 229 | "id": "0ENy-aFwAe6a", 230 | "colab_type": "text" 231 | }, 232 | "source": [ 233 | "# **extract all covid cases**" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "metadata": { 239 | "id": "qjYzBvMV-wNu", 240 | "colab_type": "code", 241 | "colab": {} 242 | }, 243 | "source": [ 244 | "import pandas as pd\n", 245 | "\n", 246 | "csvPath = '/content/covid-chestxray-dataset/metadata.csv'\n", 247 | "df = pd.read_csv(csvPath)" 248 | ], 249 | "execution_count": 5, 250 | "outputs": [] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "metadata": { 255 | "id": "k7cC4zSkV_od", 256 | "colab_type": "code", 257 | "colab": { 258 | "base_uri": "https://localhost:8080/", 259 | "height": 33 260 | }, 261 | "outputId": "40f2debe-6686-4793-c825-6c2c133ae395" 262 | }, 263 | "source": [ 264 | "df.__len__()" 265 | ], 266 | "execution_count": 6, 267 | "outputs": [ 268 | { 269 | "output_type": "execute_result", 270 | "data": { 271 | "text/plain": [ 272 | "941" 273 | ] 274 | }, 275 | "metadata": { 276 | "tags": [] 277 | }, 278 | "execution_count": 6 279 | } 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "metadata": { 285 | "id": "gBP-mLlhTOo4", 286 | "colab_type": "code", 287 | "colab": {} 288 | }, 289 | "source": [ 290 | "p_id = 0\n", 291 | "cov_fn = []\n", 292 | "allfiles = []\n", 293 | "\n", 294 | "for (i, row) in df.iterrows():\n", 295 | "\tn_id = row[\"patientid\"]\n", 296 | "\n", 297 | "\tif n_id != p_id and len(cov_fn)>0 and p_id != 0:\n", 298 | "\t\tallfiles.append(cov_fn)\n", 299 | "\t\tcov_fn = []\n", 300 | "\tif row[\"finding\"] == \"Pneumonia/Viral/COVID-19\" and row[\"view\"] == \"PA\":\n", 301 | "\t\tcov_fn.append(row[\"filename\"])\n", 302 | "\t\tp_id = row[\"patientid\"]" 303 | ], 304 | "execution_count": 7, 305 | "outputs": [] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "metadata": { 310 | "id": "jgkCE6i6A1Ul", 311 | "colab_type": "text" 312 | }, 313 | "source": [ 314 | "# **split covid data to train and test**" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "metadata": { 320 | "id": "ZMRRFuP6E4E3", 321 | "colab_type": "code", 322 | "colab": {} 323 | }, 324 | "source": [ 325 | "from sklearn.model_selection import train_test_split\n", 326 | "\t\n", 327 | "x_train_c, x_test_c = train_test_split(allfiles, test_size=0.20, random_state=23)" 328 | ], 329 | "execution_count": 8, 330 | "outputs": [] 331 | }, 332 | { 333 | "cell_type": "markdown", 334 | "metadata": { 335 | "id": "SiF5jPyIBF2o", 336 | "colab_type": "text" 337 | }, 338 | "source": [ 339 | "# **save data to folders**" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "metadata": { 345 | "id": "69d8Klcx_C3Q", 346 | "colab_type": "code", 347 | "colab": {} 348 | }, 349 | "source": [ 350 | "import shutil\n", 351 | "\n", 352 | "for img in sum(x_train_c, []):\n", 353 | " src = '/content/covid-chestxray-dataset/images/' + img\n", 354 | " dst = '/content/dataset/train/covid/' + img\n", 355 | " shutil.copy2(src, dst)\n", 356 | "\n", 357 | "for img in sum(x_test_c, []):\n", 358 | " src = '/content/covid-chestxray-dataset/images/' + img\n", 359 | " dst = '/content/dataset/test/covid/' + img\n", 360 | " shutil.copy2(src, dst)" 361 | ], 362 | "execution_count": 9, 363 | "outputs": [] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "metadata": { 368 | "id": "bkAHDv_VUJto", 369 | "colab_type": "code", 370 | "colab": {} 371 | }, 372 | "source": [ 373 | "import os\n", 374 | "import cv2\n", 375 | "import random\n", 376 | "\n", 377 | "n_samples = len(sum(allfiles, []))\n", 378 | "kaggle_data_path = '/content/chest_xray/train/NORMAL/'\n", 379 | "output_path_train = '/content/dataset/train/normal/'\n", 380 | "output_path_test = '/content/dataset/test/normal/'\n", 381 | "\n", 382 | "filenames = os.listdir(kaggle_data_path)\n", 383 | "random.seed(42)\n", 384 | "filenames = random.sample(filenames, len(filenames))\n", 385 | "for i in range(n_samples):\n", 386 | " n_image = cv2.imread(kaggle_data_path + filenames[i])\n", 387 | " if i < sum(x_train_c, []).__len__():\n", 388 | " cv2.imwrite(output_path_train + filenames[i], n_image)\n", 389 | " else:\n", 390 | " cv2.imwrite(output_path_test + filenames[i], n_image)" 391 | ], 392 | "execution_count": 10, 393 | "outputs": [] 394 | }, 395 | { 396 | "cell_type": "markdown", 397 | "metadata": { 398 | "id": "xEypj6CLBW11", 399 | "colab_type": "text" 400 | }, 401 | "source": [ 402 | "# **prepare data**" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "metadata": { 408 | "id": "xcx1GbycEdsP", 409 | "colab_type": "code", 410 | "colab": {} 411 | }, 412 | "source": [ 413 | "import os \n", 414 | "import numpy as np \n", 415 | "\n", 416 | "path = \"/content/dataset\" \n", 417 | "x_train_n = [] \n", 418 | "x_train_c = [] \n", 419 | "x_test_n = [] \n", 420 | "x_test_c = [] \n", 421 | "\n", 422 | "for p in os.listdir(path+ '/train/normal/'): \n", 423 | " x_train_n.append(cv2.imread(path + '/train/normal/' + p)) \n", 424 | "for p in os.listdir(path+ '/train/covid/'): \n", 425 | " x_train_c.append(cv2.imread(path + '/train/covid/' + p)) \n", 426 | "for p in os.listdir(path+ '/test/normal/'): \n", 427 | " x_test_n.append(cv2.imread(path + '/test/normal/' + p)) \n", 428 | "for p in os.listdir(path+ '/test/covid/'): \n", 429 | " x_test_c.append(cv2.imread(path + '/test/covid/' + p))" 430 | ], 431 | "execution_count": 11, 432 | "outputs": [] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "metadata": { 437 | "id": "nYdYmFZ1E0Nx", 438 | "colab_type": "code", 439 | "colab": {} 440 | }, 441 | "source": [ 442 | "for i in range(len(x_train_n)):\n", 443 | " x_train_n[i] = cv2.resize(x_train_n[i], (224, 224))\n", 444 | " x_train_c[i] = cv2.resize(x_train_c[i], (224, 224))\n", 445 | "\n", 446 | "for i in range(len(x_test_n)):\n", 447 | " x_test_n[i] = cv2.resize(x_test_n[i], (224, 224))\n", 448 | " x_test_c[i] = cv2.resize(x_test_c[i], (224, 224))" 449 | ], 450 | "execution_count": 12, 451 | "outputs": [] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "metadata": { 456 | "id": "h1DuI0Y7GGRx", 457 | "colab_type": "code", 458 | "colab": {} 459 | }, 460 | "source": [ 461 | "x_train_n = np.array(x_train_n)\n", 462 | "x_train_c = np.array(x_train_c)\n", 463 | "x_test_n = np.array(x_test_n)\n", 464 | "x_test_c = np.array(x_test_c)" 465 | ], 466 | "execution_count": 13, 467 | "outputs": [] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "metadata": { 472 | "id": "VsLZxeigF1Af", 473 | "colab_type": "code", 474 | "colab": { 475 | "base_uri": "https://localhost:8080/", 476 | "height": 84 477 | }, 478 | "outputId": "0341825e-1dcd-432a-86b2-6be8c9160a2f" 479 | }, 480 | "source": [ 481 | "print(x_train_n.shape) \n", 482 | "print(x_train_c.shape) \n", 483 | "print(x_test_n.shape) \n", 484 | "print(x_test_c.shape)" 485 | ], 486 | "execution_count": 14, 487 | "outputs": [ 488 | { 489 | "output_type": "stream", 490 | "text": [ 491 | "(148, 224, 224, 3)\n", 492 | "(148, 224, 224, 3)\n", 493 | "(47, 224, 224, 3)\n", 494 | "(47, 224, 224, 3)\n" 495 | ], 496 | "name": "stdout" 497 | } 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "metadata": { 503 | "id": "4ZdxezcRHL9N", 504 | "colab_type": "code", 505 | "colab": {} 506 | }, 507 | "source": [ 508 | "x_train = np.concatenate((x_train_n, x_train_c))/255.0\n", 509 | "x_test = np.concatenate((x_test_n, x_test_c))/255.0" 510 | ], 511 | "execution_count": 15, 512 | "outputs": [] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "metadata": { 517 | "id": "z_6N6l9xHPrP", 518 | "colab_type": "code", 519 | "colab": {} 520 | }, 521 | "source": [ 522 | "y_train_n = np.zeros(x_train_n.shape[0])\n", 523 | "y_train_c = np.ones(x_train_c.shape[0])\n", 524 | "y_test_n = np.zeros(x_test_n.shape[0])\n", 525 | "y_test_c = np.ones(x_test_c.shape[0])" 526 | ], 527 | "execution_count": 16, 528 | "outputs": [] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "metadata": { 533 | "id": "JJxnOSioHOs0", 534 | "colab_type": "code", 535 | "colab": {} 536 | }, 537 | "source": [ 538 | "y_train = np.concatenate((y_train_n, y_train_c))\n", 539 | "y_test = np.concatenate((y_test_n, y_test_c))" 540 | ], 541 | "execution_count": 17, 542 | "outputs": [] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "metadata": { 547 | "id": "uIcyqQlTJpKh", 548 | "colab_type": "code", 549 | "colab": {} 550 | }, 551 | "source": [ 552 | "y_train = np.expand_dims(y_train, -1)\n", 553 | "y_test = np.expand_dims(y_test, -1)" 554 | ], 555 | "execution_count": 18, 556 | "outputs": [] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "metadata": { 561 | "id": "spSzm24VJe5O", 562 | "colab_type": "code", 563 | "colab": { 564 | "base_uri": "https://localhost:8080/", 565 | "height": 84 566 | }, 567 | "outputId": "d9a1c551-ff94-448b-f3fd-a09b5a5730b4" 568 | }, 569 | "source": [ 570 | "print(x_train.shape)\n", 571 | "print(x_test.shape)\n", 572 | "print(y_train.shape)\n", 573 | "print(y_test.shape)" 574 | ], 575 | "execution_count": 19, 576 | "outputs": [ 577 | { 578 | "output_type": "stream", 579 | "text": [ 580 | "(296, 224, 224, 3)\n", 581 | "(94, 224, 224, 3)\n", 582 | "(296, 1)\n", 583 | "(94, 1)\n" 584 | ], 585 | "name": "stdout" 586 | } 587 | ] 588 | }, 589 | { 590 | "cell_type": "markdown", 591 | "metadata": { 592 | "id": "S4nPCHWlDpbY", 593 | "colab_type": "text" 594 | }, 595 | "source": [ 596 | "# **define model**" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "metadata": { 602 | "id": "utkkMYzLDOAp", 603 | "colab_type": "code", 604 | "colab": {} 605 | }, 606 | "source": [ 607 | "# import the necessary packages\n", 608 | "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n", 609 | "from tensorflow.keras.applications import VGG16\n", 610 | "from tensorflow.keras.layers import AveragePooling2D\n", 611 | "from tensorflow.keras.layers import Dropout\n", 612 | "from tensorflow.keras.layers import Flatten\n", 613 | "from tensorflow.keras.layers import Dense\n", 614 | "from tensorflow.keras.layers import Input\n", 615 | "from tensorflow.keras.models import Model\n", 616 | "from tensorflow.keras.optimizers import Adam\n", 617 | "from tensorflow.keras.utils import to_categorical\n", 618 | "from sklearn.preprocessing import LabelBinarizer\n", 619 | "from sklearn.model_selection import train_test_split\n", 620 | "from sklearn.metrics import classification_report\n", 621 | "from sklearn.metrics import confusion_matrix\n", 622 | "from imutils import paths\n", 623 | "import matplotlib.pyplot as plt\n", 624 | "import numpy as np\n", 625 | "import argparse\n", 626 | "import cv2\n", 627 | "import os" 628 | ], 629 | "execution_count": 20, 630 | "outputs": [] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "metadata": { 635 | "id": "j9e9uz_dGTGJ", 636 | "colab_type": "code", 637 | "colab": {} 638 | }, 639 | "source": [ 640 | "# initialize the initial learning rate, number of epochs to train for,\n", 641 | "# and batch size\n", 642 | "INIT_LR = 1e-3\n", 643 | "EPOCHS = 50\n", 644 | "BS = 8" 645 | ], 646 | "execution_count": 21, 647 | "outputs": [] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "metadata": { 652 | "id": "8-ZzRZgRGbOb", 653 | "colab_type": "code", 654 | "colab": {} 655 | }, 656 | "source": [ 657 | "# initialize the training data augmentation object\n", 658 | "trainAug = ImageDataGenerator(\n", 659 | "\trotation_range=15,\n", 660 | "\tfill_mode=\"nearest\")" 661 | ], 662 | "execution_count": 22, 663 | "outputs": [] 664 | }, 665 | { 666 | "cell_type": "code", 667 | "metadata": { 668 | "id": "ZQj-zQ-JGiVv", 669 | "colab_type": "code", 670 | "colab": { 671 | "base_uri": "https://localhost:8080/", 672 | "height": 70 673 | }, 674 | "outputId": "7b709beb-5f94-48f0-848d-ead6f944e2c2" 675 | }, 676 | "source": [ 677 | "# load the VGG16 network, ensuring the head FC layer sets are left\n", 678 | "# off\n", 679 | "baseModel = VGG16(weights=\"imagenet\", include_top=False,\n", 680 | "\tinput_tensor=Input(shape=(224, 224, 3)))\n", 681 | "# construct the head of the model that will be placed on top of the\n", 682 | "# the base model\n", 683 | "headModel = baseModel.output\n", 684 | "headModel = AveragePooling2D(pool_size=(4, 4))(headModel)\n", 685 | "headModel = Flatten(name=\"flatten\")(headModel)\n", 686 | "headModel = Dense(64, activation=\"relu\")(headModel)\n", 687 | "headModel = Dropout(0.5)(headModel)\n", 688 | "headModel = Dense(1, activation=\"sigmoid\")(headModel)\n", 689 | "# place the head FC model on top of the base model (this will become\n", 690 | "# the actual model we will train)\n", 691 | "model = Model(inputs=baseModel.input, outputs=headModel)\n", 692 | "# loop over all layers in the base model and freeze them so they will\n", 693 | "# *not* be updated during the first training process\n", 694 | "for layer in baseModel.layers:\n", 695 | "\tlayer.trainable = False" 696 | ], 697 | "execution_count": 23, 698 | "outputs": [ 699 | { 700 | "output_type": "stream", 701 | "text": [ 702 | "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5\n", 703 | "58892288/58889256 [==============================] - 0s 0us/step\n" 704 | ], 705 | "name": "stdout" 706 | } 707 | ] 708 | }, 709 | { 710 | "cell_type": "code", 711 | "metadata": { 712 | "id": "1q2hyqzfGlcU", 713 | "colab_type": "code", 714 | "colab": { 715 | "base_uri": "https://localhost:8080/", 716 | "height": 1000 717 | }, 718 | "outputId": "e899105d-77b1-4d5d-880e-56f19fdf4c3b" 719 | }, 720 | "source": [ 721 | "# compile our model\n", 722 | "print(\"[INFO] compiling model...\")\n", 723 | "opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)\n", 724 | "model.compile(loss=\"binary_crossentropy\", optimizer=opt,\n", 725 | "\tmetrics=[\"accuracy\"])\n", 726 | "# train the head of the network\n", 727 | "print(\"[INFO] training head...\")\n", 728 | "H = model.fit_generator(\n", 729 | "\ttrainAug.flow(x_train, y_train, batch_size=BS),\n", 730 | "\tsteps_per_epoch=len(x_train) // BS,\n", 731 | "\tvalidation_data=(x_test, y_test),\n", 732 | "\tvalidation_steps=len(x_test) // BS,\n", 733 | "\tepochs=EPOCHS)" 734 | ], 735 | "execution_count": 24, 736 | "outputs": [ 737 | { 738 | "output_type": "stream", 739 | "text": [ 740 | "[INFO] compiling model...\n", 741 | "[INFO] training head...\n", 742 | "WARNING:tensorflow:From :13: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n", 743 | "Instructions for updating:\n", 744 | "Please use Model.fit, which supports generators.\n", 745 | "Epoch 1/50\n", 746 | "37/37 [==============================] - 5s 139ms/step - loss: 0.6236 - accuracy: 0.6520 - val_loss: 0.4694 - val_accuracy: 0.9468\n", 747 | "Epoch 2/50\n", 748 | "37/37 [==============================] - 4s 108ms/step - loss: 0.4234 - accuracy: 0.8480 - val_loss: 0.3183 - val_accuracy: 0.9468\n", 749 | "Epoch 3/50\n", 750 | "37/37 [==============================] - 4s 108ms/step - loss: 0.3176 - accuracy: 0.9155 - val_loss: 0.2423 - val_accuracy: 0.9574\n", 751 | "Epoch 4/50\n", 752 | "37/37 [==============================] - 4s 108ms/step - loss: 0.2675 - accuracy: 0.9122 - val_loss: 0.2038 - val_accuracy: 0.9468\n", 753 | "Epoch 5/50\n", 754 | "37/37 [==============================] - 4s 108ms/step - loss: 0.2471 - accuracy: 0.9189 - val_loss: 0.1899 - val_accuracy: 0.9574\n", 755 | "Epoch 6/50\n", 756 | "37/37 [==============================] - 4s 108ms/step - loss: 0.2090 - accuracy: 0.9392 - val_loss: 0.1580 - val_accuracy: 0.9574\n", 757 | "Epoch 7/50\n", 758 | "37/37 [==============================] - 4s 107ms/step - loss: 0.1850 - accuracy: 0.9493 - val_loss: 0.1663 - val_accuracy: 0.9574\n", 759 | "Epoch 8/50\n", 760 | "37/37 [==============================] - 4s 110ms/step - loss: 0.1551 - accuracy: 0.9459 - val_loss: 0.1494 - val_accuracy: 0.9681\n", 761 | "Epoch 9/50\n", 762 | "37/37 [==============================] - 4s 110ms/step - loss: 0.1647 - accuracy: 0.9493 - val_loss: 0.1320 - val_accuracy: 0.9681\n", 763 | "Epoch 10/50\n", 764 | "37/37 [==============================] - 4s 110ms/step - loss: 0.1231 - accuracy: 0.9662 - val_loss: 0.1285 - val_accuracy: 0.9681\n", 765 | "Epoch 11/50\n", 766 | "37/37 [==============================] - 4s 108ms/step - loss: 0.1397 - accuracy: 0.9493 - val_loss: 0.1168 - val_accuracy: 0.9787\n", 767 | "Epoch 12/50\n", 768 | "37/37 [==============================] - 4s 108ms/step - loss: 0.1209 - accuracy: 0.9696 - val_loss: 0.1252 - val_accuracy: 0.9681\n", 769 | "Epoch 13/50\n", 770 | "37/37 [==============================] - 4s 107ms/step - loss: 0.1034 - accuracy: 0.9764 - val_loss: 0.1101 - val_accuracy: 0.9787\n", 771 | "Epoch 14/50\n", 772 | "37/37 [==============================] - 4s 107ms/step - loss: 0.1115 - accuracy: 0.9595 - val_loss: 0.1311 - val_accuracy: 0.9681\n", 773 | "Epoch 15/50\n", 774 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0916 - accuracy: 0.9696 - val_loss: 0.1219 - val_accuracy: 0.9681\n", 775 | "Epoch 16/50\n", 776 | "37/37 [==============================] - 4s 109ms/step - loss: 0.1037 - accuracy: 0.9662 - val_loss: 0.1416 - val_accuracy: 0.9681\n", 777 | "Epoch 17/50\n", 778 | "37/37 [==============================] - 4s 110ms/step - loss: 0.1053 - accuracy: 0.9628 - val_loss: 0.1376 - val_accuracy: 0.9681\n", 779 | "Epoch 18/50\n", 780 | "37/37 [==============================] - 4s 109ms/step - loss: 0.0840 - accuracy: 0.9764 - val_loss: 0.1014 - val_accuracy: 0.9787\n", 781 | "Epoch 19/50\n", 782 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0718 - accuracy: 0.9730 - val_loss: 0.1123 - val_accuracy: 0.9681\n", 783 | "Epoch 20/50\n", 784 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0683 - accuracy: 0.9730 - val_loss: 0.0975 - val_accuracy: 0.9681\n", 785 | "Epoch 21/50\n", 786 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0708 - accuracy: 0.9865 - val_loss: 0.0992 - val_accuracy: 0.9787\n", 787 | "Epoch 22/50\n", 788 | "37/37 [==============================] - 4s 110ms/step - loss: 0.0740 - accuracy: 0.9831 - val_loss: 0.0986 - val_accuracy: 0.9787\n", 789 | "Epoch 23/50\n", 790 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0815 - accuracy: 0.9730 - val_loss: 0.0928 - val_accuracy: 0.9787\n", 791 | "Epoch 24/50\n", 792 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0638 - accuracy: 0.9797 - val_loss: 0.0921 - val_accuracy: 0.9787\n", 793 | "Epoch 25/50\n", 794 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0672 - accuracy: 0.9899 - val_loss: 0.0948 - val_accuracy: 0.9787\n", 795 | "Epoch 26/50\n", 796 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0637 - accuracy: 0.9764 - val_loss: 0.0881 - val_accuracy: 0.9787\n", 797 | "Epoch 27/50\n", 798 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0581 - accuracy: 0.9899 - val_loss: 0.0995 - val_accuracy: 0.9681\n", 799 | "Epoch 28/50\n", 800 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0561 - accuracy: 0.9797 - val_loss: 0.0873 - val_accuracy: 0.9787\n", 801 | "Epoch 29/50\n", 802 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0523 - accuracy: 0.9899 - val_loss: 0.1072 - val_accuracy: 0.9681\n", 803 | "Epoch 30/50\n", 804 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0644 - accuracy: 0.9764 - val_loss: 0.0988 - val_accuracy: 0.9681\n", 805 | "Epoch 31/50\n", 806 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0535 - accuracy: 0.9831 - val_loss: 0.0851 - val_accuracy: 0.9787\n", 807 | "Epoch 32/50\n", 808 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0480 - accuracy: 0.9899 - val_loss: 0.0952 - val_accuracy: 0.9681\n", 809 | "Epoch 33/50\n", 810 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0458 - accuracy: 0.9865 - val_loss: 0.0916 - val_accuracy: 0.9787\n", 811 | "Epoch 34/50\n", 812 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0388 - accuracy: 0.9932 - val_loss: 0.1017 - val_accuracy: 0.9681\n", 813 | "Epoch 35/50\n", 814 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0419 - accuracy: 0.9865 - val_loss: 0.0887 - val_accuracy: 0.9787\n", 815 | "Epoch 36/50\n", 816 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0490 - accuracy: 0.9831 - val_loss: 0.1045 - val_accuracy: 0.9681\n", 817 | "Epoch 37/50\n", 818 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0554 - accuracy: 0.9865 - val_loss: 0.0823 - val_accuracy: 0.9787\n", 819 | "Epoch 38/50\n", 820 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0474 - accuracy: 0.9865 - val_loss: 0.1017 - val_accuracy: 0.9681\n", 821 | "Epoch 39/50\n", 822 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0375 - accuracy: 0.9865 - val_loss: 0.0863 - val_accuracy: 0.9787\n", 823 | "Epoch 40/50\n", 824 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0261 - accuracy: 0.9966 - val_loss: 0.0862 - val_accuracy: 0.9787\n", 825 | "Epoch 41/50\n", 826 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0368 - accuracy: 0.9899 - val_loss: 0.0959 - val_accuracy: 0.9681\n", 827 | "Epoch 42/50\n", 828 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0262 - accuracy: 0.9932 - val_loss: 0.0879 - val_accuracy: 0.9787\n", 829 | "Epoch 43/50\n", 830 | "37/37 [==============================] - 4s 109ms/step - loss: 0.0347 - accuracy: 0.9865 - val_loss: 0.0725 - val_accuracy: 0.9787\n", 831 | "Epoch 44/50\n", 832 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0340 - accuracy: 0.9932 - val_loss: 0.0813 - val_accuracy: 0.9787\n", 833 | "Epoch 45/50\n", 834 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0338 - accuracy: 0.9932 - val_loss: 0.0870 - val_accuracy: 0.9787\n", 835 | "Epoch 46/50\n", 836 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0212 - accuracy: 0.9966 - val_loss: 0.0966 - val_accuracy: 0.9681\n", 837 | "Epoch 47/50\n", 838 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0326 - accuracy: 0.9865 - val_loss: 0.0823 - val_accuracy: 0.9787\n", 839 | "Epoch 48/50\n", 840 | "37/37 [==============================] - 4s 109ms/step - loss: 0.0278 - accuracy: 0.9932 - val_loss: 0.0751 - val_accuracy: 0.9787\n", 841 | "Epoch 49/50\n", 842 | "37/37 [==============================] - 4s 109ms/step - loss: 0.0304 - accuracy: 0.9899 - val_loss: 0.0750 - val_accuracy: 0.9787\n", 843 | "Epoch 50/50\n", 844 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0412 - accuracy: 0.9797 - val_loss: 0.0665 - val_accuracy: 0.9894\n" 845 | ], 846 | "name": "stdout" 847 | } 848 | ] 849 | }, 850 | { 851 | "cell_type": "markdown", 852 | "metadata": { 853 | "id": "n8WBvnWaEc2t", 854 | "colab_type": "text" 855 | }, 856 | "source": [ 857 | "# **prediction**" 858 | ] 859 | }, 860 | { 861 | "cell_type": "code", 862 | "metadata": { 863 | "id": "ax8loRHIqXri", 864 | "colab_type": "code", 865 | "colab": { 866 | "base_uri": "https://localhost:8080/", 867 | "height": 1000 868 | }, 869 | "outputId": "639b5206-0578-4247-9eea-946615358a08" 870 | }, 871 | "source": [ 872 | "model.predict(x_test)" 873 | ], 874 | "execution_count": 25, 875 | "outputs": [ 876 | { 877 | "output_type": "execute_result", 878 | "data": { 879 | "text/plain": [ 880 | "array([[3.3645108e-04],\n", 881 | " [1.2787229e-04],\n", 882 | " [3.0687204e-04],\n", 883 | " [1.8594293e-04],\n", 884 | " [8.8789128e-04],\n", 885 | " [2.2062780e-03],\n", 886 | " [8.5091895e-05],\n", 887 | " [2.3272367e-04],\n", 888 | " [6.8524624e-03],\n", 889 | " [2.3365160e-03],\n", 890 | " [4.6214741e-04],\n", 891 | " [4.8456270e-01],\n", 892 | " [2.1608867e-04],\n", 893 | " [9.8477978e-01],\n", 894 | " [3.7198868e-05],\n", 895 | " [5.3255004e-04],\n", 896 | " [3.3552865e-03],\n", 897 | " [3.3389160e-04],\n", 898 | " [9.6598919e-04],\n", 899 | " [9.8803849e-04],\n", 900 | " [6.0224649e-04],\n", 901 | " [6.5070181e-04],\n", 902 | " [3.8730315e-04],\n", 903 | " [2.2554809e-01],\n", 904 | " [3.8252815e-04],\n", 905 | " [4.0165377e-03],\n", 906 | " [1.2528970e-03],\n", 907 | " [9.4746792e-05],\n", 908 | " [4.5105047e-04],\n", 909 | " [2.0995762e-03],\n", 910 | " [4.0024426e-04],\n", 911 | " [2.2839874e-03],\n", 912 | " [7.5507886e-04],\n", 913 | " [4.5164018e-03],\n", 914 | " [1.0044893e-02],\n", 915 | " [4.7172961e-04],\n", 916 | " [1.3611122e-04],\n", 917 | " [7.9747280e-03],\n", 918 | " [4.3664771e-04],\n", 919 | " [7.0768646e-03],\n", 920 | " [3.5995257e-04],\n", 921 | " [2.3332106e-02],\n", 922 | " [3.4441127e-04],\n", 923 | " [3.0484118e-03],\n", 924 | " [5.2329859e-05],\n", 925 | " [1.0854526e-03],\n", 926 | " [1.5450418e-03],\n", 927 | " [9.9200833e-01],\n", 928 | " [9.9901772e-01],\n", 929 | " [9.9973100e-01],\n", 930 | " [9.9732852e-01],\n", 931 | " [9.9922824e-01],\n", 932 | " [9.9430370e-01],\n", 933 | " [9.9756771e-01],\n", 934 | " [9.5160627e-01],\n", 935 | " [9.9445552e-01],\n", 936 | " [8.8230652e-01],\n", 937 | " [9.9962628e-01],\n", 938 | " [9.9718899e-01],\n", 939 | " [9.9953568e-01],\n", 940 | " [9.2060763e-01],\n", 941 | " [9.9866438e-01],\n", 942 | " [9.9978584e-01],\n", 943 | " [9.9577904e-01],\n", 944 | " [9.9678314e-01],\n", 945 | " [9.9796391e-01],\n", 946 | " [9.9735093e-01],\n", 947 | " [9.9920744e-01],\n", 948 | " [9.9899489e-01],\n", 949 | " [9.9944538e-01],\n", 950 | " [9.9932742e-01],\n", 951 | " [9.9955171e-01],\n", 952 | " [9.9872005e-01],\n", 953 | " [9.8695487e-01],\n", 954 | " [9.9914718e-01],\n", 955 | " [9.9996626e-01],\n", 956 | " [9.9959117e-01],\n", 957 | " [9.9994636e-01],\n", 958 | " [9.9727505e-01],\n", 959 | " [9.9915981e-01],\n", 960 | " [9.9951649e-01],\n", 961 | " [9.9986589e-01],\n", 962 | " [9.9372864e-01],\n", 963 | " [9.7922003e-01],\n", 964 | " [9.9798799e-01],\n", 965 | " [9.9974900e-01],\n", 966 | " [9.9997377e-01],\n", 967 | " [9.6405590e-01],\n", 968 | " [9.2889494e-01],\n", 969 | " [9.9977261e-01],\n", 970 | " [9.9922121e-01],\n", 971 | " [9.9944884e-01],\n", 972 | " [6.2954772e-01],\n", 973 | " [8.8289243e-01]], dtype=float32)" 974 | ] 975 | }, 976 | "metadata": { 977 | "tags": [] 978 | }, 979 | "execution_count": 25 980 | } 981 | ] 982 | }, 983 | { 984 | "cell_type": "code", 985 | "metadata": { 986 | "id": "90LWTRpDVYaN", 987 | "colab_type": "code", 988 | "colab": { 989 | "base_uri": "https://localhost:8080/", 990 | "height": 134 991 | }, 992 | "outputId": "1eb38e27-762d-4f5e-fd27-afd5926e62da" 993 | }, 994 | "source": [ 995 | "# predicts = model.predict(x_test)\n", 996 | "th = np.linspace(0.2, 0.8, 7)\n", 997 | "all_cms = []\n", 998 | "all_sens = []\n", 999 | "all_spec = []\n", 1000 | "all_acc = []\n", 1001 | "for t in th:\n", 1002 | " preds = model.predict(x_test)\n", 1003 | " preds[preds > t] = 1\n", 1004 | " preds[preds < t] = 0\n", 1005 | " cm = confusion_matrix(y_test, preds)\n", 1006 | " total = sum(sum(cm))\n", 1007 | " acc = (cm[0, 0] + cm[1, 1]) / total\n", 1008 | " all_acc.append(acc)\n", 1009 | " sensitivity = cm[0, 0] / (cm[0, 0] + cm[0, 1])\n", 1010 | " specificity = cm[1, 1] / (cm[1, 0] + cm[1, 1])\n", 1011 | " all_cms.append(cm)\n", 1012 | " all_sens.append(sensitivity)\n", 1013 | " all_spec.append(specificity)\n", 1014 | "print('best accuracy: ', all_acc[np.argmax(all_acc)])\n", 1015 | "print('threshold: ', th[np.argmax(all_acc)])\n", 1016 | "print('confusion matrix: ', '\\n', all_cms[np.argmax(all_acc)])\n", 1017 | "print('sensitivity: ', all_sens[np.argmax(all_acc)])\n", 1018 | "print('specificity: ', all_spec[np.argmax(all_acc)])\n", 1019 | "\n" 1020 | ], 1021 | "execution_count": 26, 1022 | "outputs": [ 1023 | { 1024 | "output_type": "stream", 1025 | "text": [ 1026 | "best accuracy: 0.9893617021276596\n", 1027 | "threshold: 0.5\n", 1028 | "confusion matrix: \n", 1029 | " [[46 1]\n", 1030 | " [ 0 47]]\n", 1031 | "sensitivity: 0.9787234042553191\n", 1032 | "specificity: 1.0\n" 1033 | ], 1034 | "name": "stdout" 1035 | } 1036 | ] 1037 | } 1038 | ] 1039 | } --------------------------------------------------------------------------------