├── README.md
└── COVID_19_detection.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # covid-19-detection-project
2 |
3 | ## Jupyter notebook for the blog post:
4 |
5 | [تشخیص کرونا با یادگیری عمیق](https://howsam.org/covid19-detection/) ("COVID-19 detection with deep learning", in Persian)
6 |
--------------------------------------------------------------------------------
/COVID_19_detection.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "COVID-19-detection.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "toc_visible": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "accelerator": "GPU"
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "gTv0ayLu_Nsf",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | "# **Download the COVID-19 chest X-ray dataset from GitHub**"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "metadata": {
31 | "id": "VPw6cHEZO-SJ",
32 | "colab_type": "code",
33 | "colab": {
34 | "base_uri": "https://localhost:8080/",
35 | "height": 150
36 | },
37 | "outputId": "1b1137c0-6f00-4de2-b3cc-15a5eec8f5b0"
38 | },
39 | "source": [
40 | "! git clone https://github.com/ieee8023/covid-chestxray-dataset"
41 | ],
42 | "execution_count": 1,
43 | "outputs": [
44 | {
45 | "output_type": "stream",
46 | "text": [
47 | "Cloning into 'covid-chestxray-dataset'...\n",
48 | "remote: Enumerating objects: 36, done.\u001b[K\n",
49 | "remote: Counting objects: 100% (36/36), done.\u001b[K\n",
50 | "remote: Compressing objects: 100% (28/28), done.\u001b[K\n",
51 | "remote: Total 3651 (delta 13), reused 21 (delta 8), pack-reused 3615\u001b[K\n",
52 | "Receiving objects: 100% (3651/3651), 632.29 MiB | 42.82 MiB/s, done.\n",
53 | "Resolving deltas: 100% (1461/1461), done.\n",
54 | "Checking out files: 100% (1164/1164), done.\n"
55 | ],
56 | "name": "stdout"
57 | }
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {
63 | "id": "HSq5-DLR_fXl",
64 | "colab_type": "text"
65 | },
66 | "source": [
67 | "# **Create the `dataset/{train,test}/{covid,normal}` folder structure**"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "metadata": {
73 | "id": "VL9QTzwsPu2z",
74 | "colab_type": "code",
75 | "colab": {}
76 | },
77 | "source": [
78 | "! mkdir dataset\n",
79 | "! mkdir /content/dataset/train\n",
80 | "! mkdir /content/dataset/train/covid\n",
81 | "! mkdir /content/dataset/train/normal\n",
82 | "! mkdir /content/dataset/test\n",
83 | "! mkdir /content/dataset/test/covid\n",
84 | "! mkdir /content/dataset/test/normal"
85 | ],
86 | "execution_count": 2,
87 | "outputs": []
88 | },
89 | {
90 | "cell_type": "markdown",
91 | "metadata": {
92 | "id": "ltMFV5BLAQIG",
93 | "colab_type": "text"
94 | },
95 | "source": [
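96 | "# **Download the normal chest X-rays from Kaggle**\n",
97 | "On Kaggle, open *My Account* and press \"Create New API Token\", then upload the downloaded `kaggle.json` file here. **Note:** `files.upload()` echoes the uploaded file contents (including your API key) into the cell output, so clear that output before saving or sharing the notebook."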
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "metadata": {
103 | "id": "_lo53zC6SAIR",
104 | "colab_type": "code",
105 | "colab": {
106 | "resources": {
107 | "http://localhost:8080/nbextensions/google.colab/files.js": {
108 | "data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kI
HRoZSBQeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgU
HJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCkgewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlR
GF0YS5ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwogICAgICBwZXJjZW50LnRleHRDb250ZW50ID0KICAgICAgICAgIGAke01hdGgucm91bmQoKHBvc2l0aW9uIC8gZmlsZURhdGEuYnl0ZUxlbmd0aCkgKiAxMDApfSUgZG9uZWA7CiAgICB9CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK",
109 | "ok": true,
110 | "headers": [
111 | [
112 | "content-type",
113 | "application/javascript"
114 | ]
115 | ],
116 | "status": 200,
117 | "status_text": ""
118 | }
119 | },
120 | "base_uri": "https://localhost:8080/",
121 | "height": 290
122 | },
123 | "outputId": "3365e6fd-1f2e-42e4-f45e-8f728a36931f"
124 | },
125 | "source": [
126 | "! pip install kaggle\n",
127 | "from google.colab import files\n",
128 | "files.upload()"
129 | ],
130 | "execution_count": 3,
131 | "outputs": [
132 | {
133 | "output_type": "stream",
134 | "text": [
135 | "Requirement already satisfied: kaggle in /usr/local/lib/python3.6/dist-packages (1.5.8)\n",
136 | "Requirement already satisfied: slugify in /usr/local/lib/python3.6/dist-packages (from kaggle) (0.0.1)\n",
137 | "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.6/dist-packages (from kaggle) (1.15.0)\n",
138 | "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.6/dist-packages (from kaggle) (2.8.1)\n",
139 | "Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle) (4.0.1)\n",
140 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from kaggle) (4.41.1)\n",
141 | "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle) (1.24.3)\n",
142 | "Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle) (2020.6.20)\n",
143 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from kaggle) (2.23.0)\n",
144 | "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify->kaggle) (1.3)\n",
145 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->kaggle) (2.10)\n",
146 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->kaggle) (3.0.4)\n"
147 | ],
148 | "name": "stdout"
149 | },
150 | {
151 | "output_type": "display_data",
152 | "data": {
153 | "text/html": [
154 | "\n",
155 | " \n",
157 | " \n",
161 | " "
162 | ],
163 | "text/plain": [
164 | ""
165 | ]
166 | },
167 | "metadata": {
168 | "tags": []
169 | }
170 | },
171 | {
172 | "output_type": "stream",
173 | "text": [
174 | "Saving kaggle.json to kaggle.json\n"
175 | ],
176 | "name": "stdout"
177 | },
178 | {
179 | "output_type": "execute_result",
180 | "data": {
181 | "text/plain": [
182 | "{'kaggle.json': b'{\"username\":\"<REDACTED>\",\"key\":\"<REDACTED>\"}'}"
183 | ]
184 | },
185 | "metadata": {
186 | "tags": []
187 | },
188 | "execution_count": 3
189 | }
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "metadata": {
195 | "id": "9QFcJfhqSRv8",
196 | "colab_type": "code",
197 | "colab": {
198 | "base_uri": "https://localhost:8080/",
199 | "height": 67
200 | },
201 | "outputId": "45f905b9-23b8-4b41-eaa1-1d511371dbf9"
202 | },
203 | "source": [
204 | "# ******* make root/kaggle folder **********\n",
205 | "! mkdir -p ~/.kaggle\n",
206 | "!cp kaggle.json ~/.kaggle/\n",
207 | "!chmod 600 ~/.kaggle/kaggle.json\n",
208 | "# ************** download data **********\n",
209 | "! kaggle datasets download -d paultimothymooney/chest-xray-pneumonia\n",
210 | "# !unzip\n",
211 | "!unzip -q /content/chest-xray-pneumonia.zip"
212 | ],
213 | "execution_count": 4,
214 | "outputs": [
215 | {
216 | "output_type": "stream",
217 | "text": [
218 | "Downloading chest-xray-pneumonia.zip to /content\n",
219 | " 99% 2.27G/2.29G [00:30<00:00, 41.8MB/s]\n",
220 | "100% 2.29G/2.29G [00:30<00:00, 79.8MB/s]\n"
221 | ],
222 | "name": "stdout"
223 | }
224 | ]
225 | },
226 | {
227 | "cell_type": "markdown",
228 | "metadata": {
229 | "id": "0ENy-aFwAe6a",
230 | "colab_type": "text"
231 | },
232 | "source": [
233 | "# **Extract all COVID-19 cases (PA view), grouped by patient**"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "metadata": {
239 | "id": "qjYzBvMV-wNu",
240 | "colab_type": "code",
241 | "colab": {}
242 | },
243 | "source": [
244 | "import pandas as pd\n",
245 | "\n",
246 | "csvPath = '/content/covid-chestxray-dataset/metadata.csv'\n",
247 | "df = pd.read_csv(csvPath)"
248 | ],
249 | "execution_count": 5,
250 | "outputs": []
251 | },
252 | {
253 | "cell_type": "code",
254 | "metadata": {
255 | "id": "k7cC4zSkV_od",
256 | "colab_type": "code",
257 | "colab": {
258 | "base_uri": "https://localhost:8080/",
259 | "height": 33
260 | },
261 | "outputId": "40f2debe-6686-4793-c825-6c2c133ae395"
262 | },
263 | "source": [
264 | "df.__len__()"
265 | ],
266 | "execution_count": 6,
267 | "outputs": [
268 | {
269 | "output_type": "execute_result",
270 | "data": {
271 | "text/plain": [
272 | "941"
273 | ]
274 | },
275 | "metadata": {
276 | "tags": []
277 | },
278 | "execution_count": 6
279 | }
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "metadata": {
285 | "id": "gBP-mLlhTOo4",
286 | "colab_type": "code",
287 | "colab": {}
288 | },
289 | "source": [
290 | "p_id = 0\n",
291 | "cov_fn = []\n",
292 | "allfiles = []\n",
293 | "\n",
294 | "for (i, row) in df.iterrows():\n",
295 | "\tn_id = row[\"patientid\"]\n",
296 | "\n",
297 | "\tif n_id != p_id and len(cov_fn)>0 and p_id != 0:\n",
298 | "\t\tallfiles.append(cov_fn)\n",
299 | "\t\tcov_fn = []\n",
300 | "\tif row[\"finding\"] == \"Pneumonia/Viral/COVID-19\" and row[\"view\"] == \"PA\":\n",
301 | "\t\tcov_fn.append(row[\"filename\"])\n",
302 | "\t\tp_id = row[\"patientid\"]"
303 | ],
304 | "execution_count": 7,
305 | "outputs": []
306 | },
307 | {
308 | "cell_type": "markdown",
309 | "metadata": {
310 | "id": "jgkCE6i6A1Ul",
311 | "colab_type": "text"
312 | },
313 | "source": [
314 | "# **Split the COVID-19 patients into train and test sets (80/20, patient-level so all images of a patient stay in one set)**"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "metadata": {
320 | "id": "ZMRRFuP6E4E3",
321 | "colab_type": "code",
322 | "colab": {}
323 | },
324 | "source": [
325 | "from sklearn.model_selection import train_test_split\n",
326 | "\t\n",
327 | "x_train_c, x_test_c = train_test_split(allfiles, test_size=0.20, random_state=23)"
328 | ],
329 | "execution_count": 8,
330 | "outputs": []
331 | },
332 | {
333 | "cell_type": "markdown",
334 | "metadata": {
335 | "id": "SiF5jPyIBF2o",
336 | "colab_type": "text"
337 | },
338 | "source": [
339 | "# **Copy the COVID-19 images, plus the same number of normal images, into the train/test folders**"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "metadata": {
345 | "id": "69d8Klcx_C3Q",
346 | "colab_type": "code",
347 | "colab": {}
348 | },
349 | "source": [
350 | "import shutil\n",
351 | "\n",
352 | "for img in sum(x_train_c, []):\n",
353 | " src = '/content/covid-chestxray-dataset/images/' + img\n",
354 | " dst = '/content/dataset/train/covid/' + img\n",
355 | " shutil.copy2(src, dst)\n",
356 | "\n",
357 | "for img in sum(x_test_c, []):\n",
358 | " src = '/content/covid-chestxray-dataset/images/' + img\n",
359 | " dst = '/content/dataset/test/covid/' + img\n",
360 | " shutil.copy2(src, dst)"
361 | ],
362 | "execution_count": 9,
363 | "outputs": []
364 | },
365 | {
366 | "cell_type": "code",
367 | "metadata": {
368 | "id": "bkAHDv_VUJto",
369 | "colab_type": "code",
370 | "colab": {}
371 | },
372 | "source": [
373 | "import os\n",
374 | "import cv2\n",
375 | "import random\n",
376 | "\n",
377 | "n_samples = len(sum(allfiles, []))\n",
378 | "kaggle_data_path = '/content/chest_xray/train/NORMAL/'\n",
379 | "output_path_train = '/content/dataset/train/normal/'\n",
380 | "output_path_test = '/content/dataset/test/normal/'\n",
381 | "\n",
382 | "filenames = os.listdir(kaggle_data_path)\n",
383 | "random.seed(42)\n",
384 | "filenames = random.sample(filenames, len(filenames))\n",
385 | "for i in range(n_samples):\n",
386 | " n_image = cv2.imread(kaggle_data_path + filenames[i])\n",
387 | " if i < sum(x_train_c, []).__len__():\n",
388 | " cv2.imwrite(output_path_train + filenames[i], n_image)\n",
389 | " else:\n",
390 | " cv2.imwrite(output_path_test + filenames[i], n_image)"
391 | ],
392 | "execution_count": 10,
393 | "outputs": []
394 | },
395 | {
396 | "cell_type": "markdown",
397 | "metadata": {
398 | "id": "xEypj6CLBW11",
399 | "colab_type": "text"
400 | },
401 | "source": [
402 | "# **Prepare data: load, resize to 224x224, scale to [0, 1] and label (normal = 0, COVID-19 = 1)**"
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "metadata": {
408 | "id": "xcx1GbycEdsP",
409 | "colab_type": "code",
410 | "colab": {}
411 | },
412 | "source": [
413 | "import os \n",
414 | "import numpy as np \n",
415 | "\n",
416 | "path = \"/content/dataset\" \n",
417 | "x_train_n = [] \n",
418 | "x_train_c = [] \n",
419 | "x_test_n = [] \n",
420 | "x_test_c = [] \n",
421 | "\n",
422 | "for p in os.listdir(path+ '/train/normal/'): \n",
423 | " x_train_n.append(cv2.imread(path + '/train/normal/' + p)) \n",
424 | "for p in os.listdir(path+ '/train/covid/'): \n",
425 | " x_train_c.append(cv2.imread(path + '/train/covid/' + p)) \n",
426 | "for p in os.listdir(path+ '/test/normal/'): \n",
427 | " x_test_n.append(cv2.imread(path + '/test/normal/' + p)) \n",
428 | "for p in os.listdir(path+ '/test/covid/'): \n",
429 | " x_test_c.append(cv2.imread(path + '/test/covid/' + p))"
430 | ],
431 | "execution_count": 11,
432 | "outputs": []
433 | },
434 | {
435 | "cell_type": "code",
436 | "metadata": {
437 | "id": "nYdYmFZ1E0Nx",
438 | "colab_type": "code",
439 | "colab": {}
440 | },
441 | "source": [
442 | "for i in range(len(x_train_n)):\n",
443 | " x_train_n[i] = cv2.resize(x_train_n[i], (224, 224))\n",
444 | " x_train_c[i] = cv2.resize(x_train_c[i], (224, 224))\n",
445 | "\n",
446 | "for i in range(len(x_test_n)):\n",
447 | " x_test_n[i] = cv2.resize(x_test_n[i], (224, 224))\n",
448 | " x_test_c[i] = cv2.resize(x_test_c[i], (224, 224))"
449 | ],
450 | "execution_count": 12,
451 | "outputs": []
452 | },
453 | {
454 | "cell_type": "code",
455 | "metadata": {
456 | "id": "h1DuI0Y7GGRx",
457 | "colab_type": "code",
458 | "colab": {}
459 | },
460 | "source": [
461 | "x_train_n = np.array(x_train_n)\n",
462 | "x_train_c = np.array(x_train_c)\n",
463 | "x_test_n = np.array(x_test_n)\n",
464 | "x_test_c = np.array(x_test_c)"
465 | ],
466 | "execution_count": 13,
467 | "outputs": []
468 | },
469 | {
470 | "cell_type": "code",
471 | "metadata": {
472 | "id": "VsLZxeigF1Af",
473 | "colab_type": "code",
474 | "colab": {
475 | "base_uri": "https://localhost:8080/",
476 | "height": 84
477 | },
478 | "outputId": "0341825e-1dcd-432a-86b2-6be8c9160a2f"
479 | },
480 | "source": [
481 | "print(x_train_n.shape) \n",
482 | "print(x_train_c.shape) \n",
483 | "print(x_test_n.shape) \n",
484 | "print(x_test_c.shape)"
485 | ],
486 | "execution_count": 14,
487 | "outputs": [
488 | {
489 | "output_type": "stream",
490 | "text": [
491 | "(148, 224, 224, 3)\n",
492 | "(148, 224, 224, 3)\n",
493 | "(47, 224, 224, 3)\n",
494 | "(47, 224, 224, 3)\n"
495 | ],
496 | "name": "stdout"
497 | }
498 | ]
499 | },
500 | {
501 | "cell_type": "code",
502 | "metadata": {
503 | "id": "4ZdxezcRHL9N",
504 | "colab_type": "code",
505 | "colab": {}
506 | },
507 | "source": [
508 | "x_train = np.concatenate((x_train_n, x_train_c))/255.0\n",
509 | "x_test = np.concatenate((x_test_n, x_test_c))/255.0"
510 | ],
511 | "execution_count": 15,
512 | "outputs": []
513 | },
514 | {
515 | "cell_type": "code",
516 | "metadata": {
517 | "id": "z_6N6l9xHPrP",
518 | "colab_type": "code",
519 | "colab": {}
520 | },
521 | "source": [
522 | "y_train_n = np.zeros(x_train_n.shape[0])\n",
523 | "y_train_c = np.ones(x_train_c.shape[0])\n",
524 | "y_test_n = np.zeros(x_test_n.shape[0])\n",
525 | "y_test_c = np.ones(x_test_c.shape[0])"
526 | ],
527 | "execution_count": 16,
528 | "outputs": []
529 | },
530 | {
531 | "cell_type": "code",
532 | "metadata": {
533 | "id": "JJxnOSioHOs0",
534 | "colab_type": "code",
535 | "colab": {}
536 | },
537 | "source": [
538 | "y_train = np.concatenate((y_train_n, y_train_c))\n",
539 | "y_test = np.concatenate((y_test_n, y_test_c))"
540 | ],
541 | "execution_count": 17,
542 | "outputs": []
543 | },
544 | {
545 | "cell_type": "code",
546 | "metadata": {
547 | "id": "uIcyqQlTJpKh",
548 | "colab_type": "code",
549 | "colab": {}
550 | },
551 | "source": [
552 | "y_train = np.expand_dims(y_train, -1)\n",
553 | "y_test = np.expand_dims(y_test, -1)"
554 | ],
555 | "execution_count": 18,
556 | "outputs": []
557 | },
558 | {
559 | "cell_type": "code",
560 | "metadata": {
561 | "id": "spSzm24VJe5O",
562 | "colab_type": "code",
563 | "colab": {
564 | "base_uri": "https://localhost:8080/",
565 | "height": 84
566 | },
567 | "outputId": "d9a1c551-ff94-448b-f3fd-a09b5a5730b4"
568 | },
569 | "source": [
570 | "print(x_train.shape)\n",
571 | "print(x_test.shape)\n",
572 | "print(y_train.shape)\n",
573 | "print(y_test.shape)"
574 | ],
575 | "execution_count": 19,
576 | "outputs": [
577 | {
578 | "output_type": "stream",
579 | "text": [
580 | "(296, 224, 224, 3)\n",
581 | "(94, 224, 224, 3)\n",
582 | "(296, 1)\n",
583 | "(94, 1)\n"
584 | ],
585 | "name": "stdout"
586 | }
587 | ]
588 | },
589 | {
590 | "cell_type": "markdown",
591 | "metadata": {
592 | "id": "S4nPCHWlDpbY",
593 | "colab_type": "text"
594 | },
595 | "source": [
596 | "# **Define and train the model (frozen VGG16 backbone + small classification head)**"
597 | ]
598 | },
599 | {
600 | "cell_type": "code",
601 | "metadata": {
602 | "id": "utkkMYzLDOAp",
603 | "colab_type": "code",
604 | "colab": {}
605 | },
606 | "source": [
607 | "# import the necessary packages\n",
608 | "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
609 | "from tensorflow.keras.applications import VGG16\n",
610 | "from tensorflow.keras.layers import AveragePooling2D\n",
611 | "from tensorflow.keras.layers import Dropout\n",
612 | "from tensorflow.keras.layers import Flatten\n",
613 | "from tensorflow.keras.layers import Dense\n",
614 | "from tensorflow.keras.layers import Input\n",
615 | "from tensorflow.keras.models import Model\n",
616 | "from tensorflow.keras.optimizers import Adam\n",
617 | "from tensorflow.keras.utils import to_categorical\n",
618 | "from sklearn.preprocessing import LabelBinarizer\n",
619 | "from sklearn.model_selection import train_test_split\n",
620 | "from sklearn.metrics import classification_report\n",
621 | "from sklearn.metrics import confusion_matrix\n",
622 | "from imutils import paths\n",
623 | "import matplotlib.pyplot as plt\n",
624 | "import numpy as np\n",
625 | "import argparse\n",
626 | "import cv2\n",
627 | "import os"
628 | ],
629 | "execution_count": 20,
630 | "outputs": []
631 | },
632 | {
633 | "cell_type": "code",
634 | "metadata": {
635 | "id": "j9e9uz_dGTGJ",
636 | "colab_type": "code",
637 | "colab": {}
638 | },
639 | "source": [
640 | "# initialize the initial learning rate, number of epochs to train for,\n",
641 | "# and batch size\n",
642 | "INIT_LR = 1e-3\n",
643 | "EPOCHS = 50\n",
644 | "BS = 8"
645 | ],
646 | "execution_count": 21,
647 | "outputs": []
648 | },
649 | {
650 | "cell_type": "code",
651 | "metadata": {
652 | "id": "8-ZzRZgRGbOb",
653 | "colab_type": "code",
654 | "colab": {}
655 | },
656 | "source": [
657 | "# initialize the training data augmentation object\n",
658 | "trainAug = ImageDataGenerator(\n",
659 | "\trotation_range=15,\n",
660 | "\tfill_mode=\"nearest\")"
661 | ],
662 | "execution_count": 22,
663 | "outputs": []
664 | },
665 | {
666 | "cell_type": "code",
667 | "metadata": {
668 | "id": "ZQj-zQ-JGiVv",
669 | "colab_type": "code",
670 | "colab": {
671 | "base_uri": "https://localhost:8080/",
672 | "height": 70
673 | },
674 | "outputId": "7b709beb-5f94-48f0-848d-ead6f944e2c2"
675 | },
676 | "source": [
677 | "# load the VGG16 network, ensuring the head FC layer sets are left\n",
678 | "# off\n",
679 | "baseModel = VGG16(weights=\"imagenet\", include_top=False,\n",
680 | "\tinput_tensor=Input(shape=(224, 224, 3)))\n",
681 | "# construct the head of the model that will be placed on top of the\n",
682 | "# the base model\n",
683 | "headModel = baseModel.output\n",
684 | "headModel = AveragePooling2D(pool_size=(4, 4))(headModel)\n",
685 | "headModel = Flatten(name=\"flatten\")(headModel)\n",
686 | "headModel = Dense(64, activation=\"relu\")(headModel)\n",
687 | "headModel = Dropout(0.5)(headModel)\n",
688 | "headModel = Dense(1, activation=\"sigmoid\")(headModel)\n",
689 | "# place the head FC model on top of the base model (this will become\n",
690 | "# the actual model we will train)\n",
691 | "model = Model(inputs=baseModel.input, outputs=headModel)\n",
692 | "# loop over all layers in the base model and freeze them so they will\n",
693 | "# *not* be updated during the first training process\n",
694 | "for layer in baseModel.layers:\n",
695 | "\tlayer.trainable = False"
696 | ],
697 | "execution_count": 23,
698 | "outputs": [
699 | {
700 | "output_type": "stream",
701 | "text": [
702 | "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5\n",
703 | "58892288/58889256 [==============================] - 0s 0us/step\n"
704 | ],
705 | "name": "stdout"
706 | }
707 | ]
708 | },
709 | {
710 | "cell_type": "code",
711 | "metadata": {
712 | "id": "1q2hyqzfGlcU",
713 | "colab_type": "code",
714 | "colab": {
715 | "base_uri": "https://localhost:8080/",
716 | "height": 1000
717 | },
718 | "outputId": "e899105d-77b1-4d5d-880e-56f19fdf4c3b"
719 | },
720 | "source": [
721 | "# compile our model\n",
722 | "print(\"[INFO] compiling model...\")\n",
723 | "opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)\n",
724 | "model.compile(loss=\"binary_crossentropy\", optimizer=opt,\n",
725 | "\tmetrics=[\"accuracy\"])\n",
726 | "# train the head of the network\n",
727 | "print(\"[INFO] training head...\")\n",
728 | "H = model.fit_generator(\n",
729 | "\ttrainAug.flow(x_train, y_train, batch_size=BS),\n",
730 | "\tsteps_per_epoch=len(x_train) // BS,\n",
731 | "\tvalidation_data=(x_test, y_test),\n",
732 | "\tvalidation_steps=len(x_test) // BS,\n",
733 | "\tepochs=EPOCHS)"
734 | ],
735 | "execution_count": 24,
736 | "outputs": [
737 | {
738 | "output_type": "stream",
739 | "text": [
740 | "[INFO] compiling model...\n",
741 | "[INFO] training head...\n",
742 | "WARNING:tensorflow:From :13: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
743 | "Instructions for updating:\n",
744 | "Please use Model.fit, which supports generators.\n",
745 | "Epoch 1/50\n",
746 | "37/37 [==============================] - 5s 139ms/step - loss: 0.6236 - accuracy: 0.6520 - val_loss: 0.4694 - val_accuracy: 0.9468\n",
747 | "Epoch 2/50\n",
748 | "37/37 [==============================] - 4s 108ms/step - loss: 0.4234 - accuracy: 0.8480 - val_loss: 0.3183 - val_accuracy: 0.9468\n",
749 | "Epoch 3/50\n",
750 | "37/37 [==============================] - 4s 108ms/step - loss: 0.3176 - accuracy: 0.9155 - val_loss: 0.2423 - val_accuracy: 0.9574\n",
751 | "Epoch 4/50\n",
752 | "37/37 [==============================] - 4s 108ms/step - loss: 0.2675 - accuracy: 0.9122 - val_loss: 0.2038 - val_accuracy: 0.9468\n",
753 | "Epoch 5/50\n",
754 | "37/37 [==============================] - 4s 108ms/step - loss: 0.2471 - accuracy: 0.9189 - val_loss: 0.1899 - val_accuracy: 0.9574\n",
755 | "Epoch 6/50\n",
756 | "37/37 [==============================] - 4s 108ms/step - loss: 0.2090 - accuracy: 0.9392 - val_loss: 0.1580 - val_accuracy: 0.9574\n",
757 | "Epoch 7/50\n",
758 | "37/37 [==============================] - 4s 107ms/step - loss: 0.1850 - accuracy: 0.9493 - val_loss: 0.1663 - val_accuracy: 0.9574\n",
759 | "Epoch 8/50\n",
760 | "37/37 [==============================] - 4s 110ms/step - loss: 0.1551 - accuracy: 0.9459 - val_loss: 0.1494 - val_accuracy: 0.9681\n",
761 | "Epoch 9/50\n",
762 | "37/37 [==============================] - 4s 110ms/step - loss: 0.1647 - accuracy: 0.9493 - val_loss: 0.1320 - val_accuracy: 0.9681\n",
763 | "Epoch 10/50\n",
764 | "37/37 [==============================] - 4s 110ms/step - loss: 0.1231 - accuracy: 0.9662 - val_loss: 0.1285 - val_accuracy: 0.9681\n",
765 | "Epoch 11/50\n",
766 | "37/37 [==============================] - 4s 108ms/step - loss: 0.1397 - accuracy: 0.9493 - val_loss: 0.1168 - val_accuracy: 0.9787\n",
767 | "Epoch 12/50\n",
768 | "37/37 [==============================] - 4s 108ms/step - loss: 0.1209 - accuracy: 0.9696 - val_loss: 0.1252 - val_accuracy: 0.9681\n",
769 | "Epoch 13/50\n",
770 | "37/37 [==============================] - 4s 107ms/step - loss: 0.1034 - accuracy: 0.9764 - val_loss: 0.1101 - val_accuracy: 0.9787\n",
771 | "Epoch 14/50\n",
772 | "37/37 [==============================] - 4s 107ms/step - loss: 0.1115 - accuracy: 0.9595 - val_loss: 0.1311 - val_accuracy: 0.9681\n",
773 | "Epoch 15/50\n",
774 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0916 - accuracy: 0.9696 - val_loss: 0.1219 - val_accuracy: 0.9681\n",
775 | "Epoch 16/50\n",
776 | "37/37 [==============================] - 4s 109ms/step - loss: 0.1037 - accuracy: 0.9662 - val_loss: 0.1416 - val_accuracy: 0.9681\n",
777 | "Epoch 17/50\n",
778 | "37/37 [==============================] - 4s 110ms/step - loss: 0.1053 - accuracy: 0.9628 - val_loss: 0.1376 - val_accuracy: 0.9681\n",
779 | "Epoch 18/50\n",
780 | "37/37 [==============================] - 4s 109ms/step - loss: 0.0840 - accuracy: 0.9764 - val_loss: 0.1014 - val_accuracy: 0.9787\n",
781 | "Epoch 19/50\n",
782 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0718 - accuracy: 0.9730 - val_loss: 0.1123 - val_accuracy: 0.9681\n",
783 | "Epoch 20/50\n",
784 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0683 - accuracy: 0.9730 - val_loss: 0.0975 - val_accuracy: 0.9681\n",
785 | "Epoch 21/50\n",
786 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0708 - accuracy: 0.9865 - val_loss: 0.0992 - val_accuracy: 0.9787\n",
787 | "Epoch 22/50\n",
788 | "37/37 [==============================] - 4s 110ms/step - loss: 0.0740 - accuracy: 0.9831 - val_loss: 0.0986 - val_accuracy: 0.9787\n",
789 | "Epoch 23/50\n",
790 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0815 - accuracy: 0.9730 - val_loss: 0.0928 - val_accuracy: 0.9787\n",
791 | "Epoch 24/50\n",
792 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0638 - accuracy: 0.9797 - val_loss: 0.0921 - val_accuracy: 0.9787\n",
793 | "Epoch 25/50\n",
794 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0672 - accuracy: 0.9899 - val_loss: 0.0948 - val_accuracy: 0.9787\n",
795 | "Epoch 26/50\n",
796 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0637 - accuracy: 0.9764 - val_loss: 0.0881 - val_accuracy: 0.9787\n",
797 | "Epoch 27/50\n",
798 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0581 - accuracy: 0.9899 - val_loss: 0.0995 - val_accuracy: 0.9681\n",
799 | "Epoch 28/50\n",
800 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0561 - accuracy: 0.9797 - val_loss: 0.0873 - val_accuracy: 0.9787\n",
801 | "Epoch 29/50\n",
802 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0523 - accuracy: 0.9899 - val_loss: 0.1072 - val_accuracy: 0.9681\n",
803 | "Epoch 30/50\n",
804 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0644 - accuracy: 0.9764 - val_loss: 0.0988 - val_accuracy: 0.9681\n",
805 | "Epoch 31/50\n",
806 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0535 - accuracy: 0.9831 - val_loss: 0.0851 - val_accuracy: 0.9787\n",
807 | "Epoch 32/50\n",
808 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0480 - accuracy: 0.9899 - val_loss: 0.0952 - val_accuracy: 0.9681\n",
809 | "Epoch 33/50\n",
810 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0458 - accuracy: 0.9865 - val_loss: 0.0916 - val_accuracy: 0.9787\n",
811 | "Epoch 34/50\n",
812 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0388 - accuracy: 0.9932 - val_loss: 0.1017 - val_accuracy: 0.9681\n",
813 | "Epoch 35/50\n",
814 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0419 - accuracy: 0.9865 - val_loss: 0.0887 - val_accuracy: 0.9787\n",
815 | "Epoch 36/50\n",
816 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0490 - accuracy: 0.9831 - val_loss: 0.1045 - val_accuracy: 0.9681\n",
817 | "Epoch 37/50\n",
818 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0554 - accuracy: 0.9865 - val_loss: 0.0823 - val_accuracy: 0.9787\n",
819 | "Epoch 38/50\n",
820 | "37/37 [==============================] - 4s 106ms/step - loss: 0.0474 - accuracy: 0.9865 - val_loss: 0.1017 - val_accuracy: 0.9681\n",
821 | "Epoch 39/50\n",
822 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0375 - accuracy: 0.9865 - val_loss: 0.0863 - val_accuracy: 0.9787\n",
823 | "Epoch 40/50\n",
824 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0261 - accuracy: 0.9966 - val_loss: 0.0862 - val_accuracy: 0.9787\n",
825 | "Epoch 41/50\n",
826 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0368 - accuracy: 0.9899 - val_loss: 0.0959 - val_accuracy: 0.9681\n",
827 | "Epoch 42/50\n",
828 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0262 - accuracy: 0.9932 - val_loss: 0.0879 - val_accuracy: 0.9787\n",
829 | "Epoch 43/50\n",
830 | "37/37 [==============================] - 4s 109ms/step - loss: 0.0347 - accuracy: 0.9865 - val_loss: 0.0725 - val_accuracy: 0.9787\n",
831 | "Epoch 44/50\n",
832 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0340 - accuracy: 0.9932 - val_loss: 0.0813 - val_accuracy: 0.9787\n",
833 | "Epoch 45/50\n",
834 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0338 - accuracy: 0.9932 - val_loss: 0.0870 - val_accuracy: 0.9787\n",
835 | "Epoch 46/50\n",
836 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0212 - accuracy: 0.9966 - val_loss: 0.0966 - val_accuracy: 0.9681\n",
837 | "Epoch 47/50\n",
838 | "37/37 [==============================] - 4s 107ms/step - loss: 0.0326 - accuracy: 0.9865 - val_loss: 0.0823 - val_accuracy: 0.9787\n",
839 | "Epoch 48/50\n",
840 | "37/37 [==============================] - 4s 109ms/step - loss: 0.0278 - accuracy: 0.9932 - val_loss: 0.0751 - val_accuracy: 0.9787\n",
841 | "Epoch 49/50\n",
842 | "37/37 [==============================] - 4s 109ms/step - loss: 0.0304 - accuracy: 0.9899 - val_loss: 0.0750 - val_accuracy: 0.9787\n",
843 | "Epoch 50/50\n",
844 | "37/37 [==============================] - 4s 108ms/step - loss: 0.0412 - accuracy: 0.9797 - val_loss: 0.0665 - val_accuracy: 0.9894\n"
845 | ],
846 | "name": "stdout"
847 | }
848 | ]
849 | },
850 | {
851 | "cell_type": "markdown",
852 | "metadata": {
853 | "id": "n8WBvnWaEc2t",
854 | "colab_type": "text"
855 | },
856 | "source": [
857 | "# **Prediction and threshold selection**"
858 | ]
859 | },
860 | {
861 | "cell_type": "code",
862 | "metadata": {
863 | "id": "ax8loRHIqXri",
864 | "colab_type": "code",
865 | "colab": {
866 | "base_uri": "https://localhost:8080/",
867 | "height": 1000
868 | },
869 | "outputId": "639b5206-0578-4247-9eea-946615358a08"
870 | },
871 | "source": [
872 | "model.predict(x_test)"
873 | ],
874 | "execution_count": 25,
875 | "outputs": [
876 | {
877 | "output_type": "execute_result",
878 | "data": {
879 | "text/plain": [
880 | "array([[3.3645108e-04],\n",
881 | " [1.2787229e-04],\n",
882 | " [3.0687204e-04],\n",
883 | " [1.8594293e-04],\n",
884 | " [8.8789128e-04],\n",
885 | " [2.2062780e-03],\n",
886 | " [8.5091895e-05],\n",
887 | " [2.3272367e-04],\n",
888 | " [6.8524624e-03],\n",
889 | " [2.3365160e-03],\n",
890 | " [4.6214741e-04],\n",
891 | " [4.8456270e-01],\n",
892 | " [2.1608867e-04],\n",
893 | " [9.8477978e-01],\n",
894 | " [3.7198868e-05],\n",
895 | " [5.3255004e-04],\n",
896 | " [3.3552865e-03],\n",
897 | " [3.3389160e-04],\n",
898 | " [9.6598919e-04],\n",
899 | " [9.8803849e-04],\n",
900 | " [6.0224649e-04],\n",
901 | " [6.5070181e-04],\n",
902 | " [3.8730315e-04],\n",
903 | " [2.2554809e-01],\n",
904 | " [3.8252815e-04],\n",
905 | " [4.0165377e-03],\n",
906 | " [1.2528970e-03],\n",
907 | " [9.4746792e-05],\n",
908 | " [4.5105047e-04],\n",
909 | " [2.0995762e-03],\n",
910 | " [4.0024426e-04],\n",
911 | " [2.2839874e-03],\n",
912 | " [7.5507886e-04],\n",
913 | " [4.5164018e-03],\n",
914 | " [1.0044893e-02],\n",
915 | " [4.7172961e-04],\n",
916 | " [1.3611122e-04],\n",
917 | " [7.9747280e-03],\n",
918 | " [4.3664771e-04],\n",
919 | " [7.0768646e-03],\n",
920 | " [3.5995257e-04],\n",
921 | " [2.3332106e-02],\n",
922 | " [3.4441127e-04],\n",
923 | " [3.0484118e-03],\n",
924 | " [5.2329859e-05],\n",
925 | " [1.0854526e-03],\n",
926 | " [1.5450418e-03],\n",
927 | " [9.9200833e-01],\n",
928 | " [9.9901772e-01],\n",
929 | " [9.9973100e-01],\n",
930 | " [9.9732852e-01],\n",
931 | " [9.9922824e-01],\n",
932 | " [9.9430370e-01],\n",
933 | " [9.9756771e-01],\n",
934 | " [9.5160627e-01],\n",
935 | " [9.9445552e-01],\n",
936 | " [8.8230652e-01],\n",
937 | " [9.9962628e-01],\n",
938 | " [9.9718899e-01],\n",
939 | " [9.9953568e-01],\n",
940 | " [9.2060763e-01],\n",
941 | " [9.9866438e-01],\n",
942 | " [9.9978584e-01],\n",
943 | " [9.9577904e-01],\n",
944 | " [9.9678314e-01],\n",
945 | " [9.9796391e-01],\n",
946 | " [9.9735093e-01],\n",
947 | " [9.9920744e-01],\n",
948 | " [9.9899489e-01],\n",
949 | " [9.9944538e-01],\n",
950 | " [9.9932742e-01],\n",
951 | " [9.9955171e-01],\n",
952 | " [9.9872005e-01],\n",
953 | " [9.8695487e-01],\n",
954 | " [9.9914718e-01],\n",
955 | " [9.9996626e-01],\n",
956 | " [9.9959117e-01],\n",
957 | " [9.9994636e-01],\n",
958 | " [9.9727505e-01],\n",
959 | " [9.9915981e-01],\n",
960 | " [9.9951649e-01],\n",
961 | " [9.9986589e-01],\n",
962 | " [9.9372864e-01],\n",
963 | " [9.7922003e-01],\n",
964 | " [9.9798799e-01],\n",
965 | " [9.9974900e-01],\n",
966 | " [9.9997377e-01],\n",
967 | " [9.6405590e-01],\n",
968 | " [9.2889494e-01],\n",
969 | " [9.9977261e-01],\n",
970 | " [9.9922121e-01],\n",
971 | " [9.9944884e-01],\n",
972 | " [6.2954772e-01],\n",
973 | " [8.8289243e-01]], dtype=float32)"
974 | ]
975 | },
976 | "metadata": {
977 | "tags": []
978 | },
979 | "execution_count": 25
980 | }
981 | ]
982 | },
983 | {
984 | "cell_type": "code",
985 | "metadata": {
986 | "id": "90LWTRpDVYaN",
987 | "colab_type": "code",
988 | "colab": {
989 | "base_uri": "https://localhost:8080/",
990 | "height": 134
991 | },
992 | "outputId": "1eb38e27-762d-4f5e-fd27-afd5926e62da"
993 | },
994 | "source": [
995 | "# Pick the decision threshold (0.2..0.8) that maximizes test accuracy.\n",
996 | "# The scores do not depend on the threshold, so predict once instead of once per threshold.\n",
997 | "probs = model.predict(x_test).ravel()\n",
998 | "th = np.linspace(0.2, 0.8, 7)\n",
999 | "all_cms, all_sens, all_spec, all_acc = [], [], [], []\n",
1000 | "for t in th:\n",
1001 | "    # single >= comparison so a score exactly equal to t is also binarized\n",
1002 | "    preds = (probs >= t).astype(int)\n",
1003 | "    cm = confusion_matrix(y_test, preds)\n",
1004 | "    acc = np.trace(cm) / cm.sum()\n",
1005 | "    # NOTE(review): these formulas treat row 0 of cm (class 0) as the positive class;\n",
1006 | "    # if the COVID label is 1, sensitivity and specificity are swapped - confirm.\n",
1007 | "    sensitivity = cm[0, 0] / (cm[0, 0] + cm[0, 1])\n",
1008 | "    specificity = cm[1, 1] / (cm[1, 0] + cm[1, 1])\n",
1009 | "    all_acc.append(acc)\n",
1010 | "    all_cms.append(cm)\n",
1011 | "    all_sens.append(sensitivity)\n",
1012 | "    all_spec.append(specificity)\n",
1013 | "best = np.argmax(all_acc)\n",
1014 | "print('best accuracy: ', all_acc[best])\n",
1015 | "print('threshold: ', th[best])\n",
1016 | "print('confusion matrix: ', '\\n', all_cms[best])\n",
1017 | "print('sensitivity: ', all_sens[best])\n",
1018 | "print('specificity: ', all_spec[best])\n",
1019 | "\n"
1020 | ],
1021 | "execution_count": 26,
1022 | "outputs": [
1023 | {
1024 | "output_type": "stream",
1025 | "text": [
1026 | "best accuracy: 0.9893617021276596\n",
1027 | "threshold: 0.5\n",
1028 | "confusion matrix: \n",
1029 | " [[46 1]\n",
1030 | " [ 0 47]]\n",
1031 | "sensitivity: 0.9787234042553191\n",
1032 | "specificity: 1.0\n"
1033 | ],
1034 | "name": "stdout"
1035 | }
1036 | ]
1037 | }
1038 | ]
1039 | }
--------------------------------------------------------------------------------