├── .gitignore ├── .ipynb_checkpoints ├── test_eval-checkpoint.ipynb └── test_train-checkpoint.ipynb ├── JupyterNotebooks ├── .ipynb_checkpoints │ └── test_evaluation-exclude-normal-checkpoint.ipynb ├── 100RandomExamples.ipynb ├── Final evaluation and viz.ipynb ├── new_transforms_examples.ipynb └── test_evaluation-exclude-normal.ipynb ├── Readme.md ├── TSNE └── tsne_visualize.ipynb ├── Tiling ├── 0b_tileLoop_deepzoom2.py ├── 0d_SortTiles.py └── BuildTileDictionary.py ├── checkpoints ├── breast_subtype.pth ├── kidney_subtype.pth └── lung_suptype.pth ├── pan-cancer ├── confusion-matrix.ipynb ├── nationwidechildrens.org_clinical_patient_brca.csv └── tsne_combine_her2_er_pr.ipynb ├── run-jupyter.sbatch ├── run_job.sh ├── run_multiple_test.sh ├── run_test.sh ├── run_tsne.sh ├── test.py ├── test_eval.ipynb ├── test_train.ipynb ├── train.py ├── tsne.py └── utils ├── .ipynb_checkpoints └── Untitled-checkpoint.ipynb ├── CLR.py ├── __pycache__ ├── __init__.cpython-35.pyc ├── auc.cpython-35.pyc ├── auc.cpython-36.pyc ├── auc_all_class.cpython-35.pyc ├── auc_test.cpython-35.pyc ├── dataloader.cpython-35.pyc ├── dataloader.cpython-36.pyc ├── eval.cpython-36.pyc ├── eval.cpython-37.pyc ├── new_transforms.cpython-35.pyc └── new_transforms.cpython-36.pyc ├── auc.py ├── confusion_matrix.py ├── dataloader.py ├── eval.py ├── extras ├── auc_all_class.py └── auc_test.py └── new_transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Logs # 2 | ###################### 3 | *.log 4 | *.out 5 | *.err 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/test_train-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /JupyterNotebooks/100RandomExamples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "ename": "ModuleNotFoundError", 10 | "evalue": "No module named 'new_transforms'", 11 | "output_type": "error", 12 | "traceback": [ 13 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 14 | "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", 15 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcollections\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtypes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mnew_transforms\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 16 | "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'new_transforms'" 17 | ] 18 | } 19 | ], 20 | "source": [ 21 | "import numpy as np\n", 22 | "import math\n", 23 | "import random\n", 24 | "from PIL import Image, ImageOps, ImageEnhance\n", 25 | "import collections\n", 26 | "import types\n", 27 | "import new_transforms\n", 28 | "import os\n", 29 | "\n", 30 | "\"\"\"\n", 31 | "Original code from new_transforms.py\n", 32 | "\n", 33 | "Taken directly from https://github.com/pytorch/vision/blob/master/torchvision/transforms.py\n", 34 | "Latest update that is not currently deployed to pip.\n", 35 | "All credits to the torchvision developers.\n", 36 | "\"\"\"\n", 37 | "\n", 38 | "# Particulars for this ipython notebook: \n", 39 | "import matplotlib.pyplot as plt\n", 40 | "from matplotlib.pyplot import imshow\n", 41 | "%matplotlib inline\n", 42 | "\n", 43 | "dir_lungTilesNormal = \"/beegfs/jmw784/Capstone/LungTilesSorted/Solid_Tissue_Normal/\"\n", 44 | "dir_lungTilesLUSC = \"/beegfs/jmw784/Capstone/LungTilesSorted/TCGA-LUSC/\"\n", 45 | "dir_lungTilesLUAD = \"/beegfs/jmw784/Capstone/LungTilesSorted/TCGA-LUAD/\"" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "# 100 examples: NORMAL" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 1, 58 | "metadata": { 59 | "scrolled": false 60 | }, 61 | "outputs": [ 62 | { 63 | "ename": "NameError", 64 | "evalue": "name 'os' is not defined", 65 | "output_type": "error", 66 | "traceback": [ 67 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 68 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 69 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfiles_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdir_lungTilesNormal\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mcRandomRotate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_transforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRandomRotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mcJitter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_transforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mColorJitter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbrightness\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.25\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontrast\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.25\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msaturation\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.25\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.05\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mcVerticalFlip\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_transforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRandomVerticalFlip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 70 | "\u001b[0;31mNameError\u001b[0m: name 'os' is not defined" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "files_list = os.listdir(dir_lungTilesNormal)\n", 76 | "\n", 77 | "cRandomRotate = new_transforms.RandomRotate()\n", 78 | "cJitter = new_transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.05)\n", 79 | "cVerticalFlip = new_transforms.RandomVerticalFlip()\n", 80 | "\n", 81 | "for repet in range(0, 100):\n", 82 | " random = np.random.randint(0,len(files_list))\n", 83 | " image = Image.open(dir_lungTilesNormal + files_list[random])\n", 84 | " modified = cRandomRotate(image)\n", 85 | " modified = cJitter(modified)\n", 86 | " modified = cVerticalFlip(modified)\n", 87 | " \n", 88 | " f, (ax1, ax2) = plt.subplots(sharex=True, sharey=False, nrows=1, ncols=2)\n", 89 | " f.set_figheight(4)\n", 90 | " f.set_figwidth(8)\n", 91 | " ax1.imshow(np.asarray(image))\n", 92 | " ax2.imshow(np.asarray(modified))\n", 93 | " ax1.set_title(\"Original\")\n", 94 | " ax2.set_title(\"Modified\")" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": { 100 | "collapsed": true 101 | }, 102 | "source": [ 103 | "# 100 examples: LUSC" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 2, 109 | "metadata": { 110 | "scrolled": false 111 | }, 112 | "outputs": [ 113 | { 114 | "ename": "NameError", 115 | "evalue": "name 'os' is not defined", 116 | "output_type": "error", 117 | "traceback": [ 118 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 119 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 120 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfiles_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdir_lungTilesLUSC\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mcRandomRotate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_transforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRandomRotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mcJitter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_transforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mColorJitter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbrightness\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.25\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontrast\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.25\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msaturation\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.25\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.05\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mcVerticalFlip\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_transforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRandomVerticalFlip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 121 | "\u001b[0;31mNameError\u001b[0m: name 'os' is not defined" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "files_list = os.listdir(dir_lungTilesLUSC)\n", 127 | "\n", 128 | "cRandomRotate = new_transforms.RandomRotate()\n", 129 | "cJitter = new_transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.05)\n", 130 | "cVerticalFlip = new_transforms.RandomVerticalFlip()\n", 131 | "\n", 132 | "for repet in range(0, 100):\n", 133 | " random = np.random.randint(0,len(files_list))\n", 134 | " image = Image.open(dir_lungTilesLUSC + files_list[random])\n", 135 | " modified = cRandomRotate(image)\n", 136 | " modified = cJitter(modified)\n", 137 | " modified = cVerticalFlip(modified)\n", 138 | " \n", 139 | " f, (ax1, ax2) = plt.subplots(sharex=True, sharey=False, nrows=1, ncols=2)\n", 140 | " f.set_figheight(4)\n", 141 | " f.set_figwidth(8)\n", 142 | " ax1.imshow(np.asarray(image))\n", 143 | " ax2.imshow(np.asarray(modified))\n", 144 | " ax1.set_title(\"Original\")\n", 145 | " ax2.set_title(\"Modified\")" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": { 151 | "collapsed": true 152 | }, 153 | "source": [ 154 | "# 100 examples: LUAD" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 3, 160 | "metadata": { 161 | "scrolled": false 162 | }, 163 | "outputs": [ 164 | { 165 | "ename": "NameError", 166 | "evalue": "name 'os' is not defined", 167 | "output_type": "error", 168 | "traceback": [ 169 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 170 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 171 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfiles_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdir_lungTilesLUAD\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mcRandomRotate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_transforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRandomRotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mcJitter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_transforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mColorJitter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbrightness\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.25\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontrast\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.25\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msaturation\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.25\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.05\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mcVerticalFlip\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_transforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRandomVerticalFlip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 172 | "\u001b[0;31mNameError\u001b[0m: name 'os' is not defined" 173 | ] 174 | } 175 | ], 176 | "source": [ 177 | "files_list = os.listdir(dir_lungTilesLUAD)\n", 178 | "\n", 179 | "cRandomRotate = new_transforms.RandomRotate()\n", 180 | "cJitter = new_transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.05)\n", 181 | "cVerticalFlip = new_transforms.RandomVerticalFlip()\n", 182 | "\n", 183 | "for repet in range(0, 100):\n", 184 | " random = np.random.randint(0,len(files_list))\n", 185 | " image = Image.open(dir_lungTilesLUAD + files_list[random])\n", 186 | " modified = cRandomRotate(image)\n", 187 | " modified = cJitter(modified)\n", 188 | " modified = cVerticalFlip(modified)\n", 189 | " \n", 190 | " f, (ax1, ax2) = plt.subplots(sharex=True, sharey=False, nrows=1, ncols=2)\n", 191 | " f.set_figheight(4)\n", 192 | " f.set_figwidth(8)\n", 193 | " ax1.imshow(np.asarray(image))\n", 194 | " ax2.imshow(np.asarray(modified))\n", 195 | " ax1.set_title(\"Original\")\n", 196 | " ax2.set_title(\"Modified\")" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [] 205 | } 206 | ], 207 | "metadata": { 208 | "kernelspec": { 209 | "display_name": "Python 3", 210 | "language": "python", 211 | "name": "python3" 212 | }, 213 | "language_info": { 214 | "codemirror_mode": { 215 | "name": "ipython", 216 | "version": 3 217 | }, 218 | "file_extension": ".py", 219 | "mimetype": "text/x-python", 220 | "name": "python", 221 | "nbconvert_exporter": "python", 222 | "pygments_lexer": "ipython3", 223 | "version": "3.7.3" 224 | } 225 | }, 226 | "nbformat": 4, 227 | "nbformat_minor": 2 228 | } 229 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Efficient pan-cancer whole-slide image classification using convolutional neural networks 2 | 3 | Code acompaining paper: [Efficient pan-cancer whole-slide image classification and outlier detection using convolutional neural networks ](https://www.biorxiv.org/content/early/2019/05/14/633123.full.pdf) 4 | 5 | # Data: 6 | 7 | * To demo the software/code described in the manuscript you can download the [GDC data transfer API](https://gdc.cancer.gov/access-data/gdc-data-transfer-tool) 8 | * Create a manifest by selecting Cases > CANCER_TYPE and Files > Data Type > Tissue Slide Image. 9 | * Download the manifest into ```manifest_file``` 10 | * Run the command ```gdc-client download -m manifest_file``` in Terminal 11 | 12 | ## 1. System requirements 13 | * software dependencies python3, PyTorch (software has been tested on Unix machine) 14 | 15 | ## 2. Installation guide 16 | * Instructions 17 | Clone this repo to your local machine using: 18 | ``` 19 | git clone https://github.com/sedab/PathCNN.git 20 | 21 | ``` 22 | * Typical install time on a "normal" desktop computer is around 3 minutes. 23 | 24 | ## 3. Demo 25 | * Expected train and test time are described in the manuscript. 26 | 27 | ### 3.1. Data processing: 28 | 29 | Note that data tiling and sorting scripts come from [Nicolas Coudray](https://github.com/ncoudray/DeepPATH/). Please refer to the README within `DeepPATH_code` for the full range of options. Additionally, note that these scripts may take a significant amount of computing power. We recommend submitting sections 2.1 and 2.2 to a high performance computing cluster with multiple CPUs. 30 | 31 | #### 3.1.1. Data tiling 32 | Run ```Tiling/0b_tileLoop_deepzoom2.py``` to tile the .svs images into .jpeg images. To replicate this particular project, select the following specifications: 33 | 34 | ```sh 35 | python -u Tiling/0b_tileLoop_deepzoom2.py -s 512 -e 0 -j 28 -f jpeg -B 25 -o "/*/*svs" 36 | ``` 37 | 38 | * ``: Path to the outer directory of the original svs files 39 | 40 | * ``: Path to which the tile files will be saved 41 | 42 | * `-s 512`: Tile size of 512x512 pixels 43 | 44 | * `-e 0`: Zero overlap in pixels for tiles 45 | 46 | * `-j 28`: 28 CPU threads 47 | 48 | * `-f jpeg`: jpeg files 49 | 50 | * `-B 25`: 25% allowed background within a tile. 51 | 52 | #### 3.1.2. Data sorting 53 | To ensure that the later sections work properly, we recommend running these commands within ``, the directory in which your images will be stored: 54 | 55 | ```sh 56 | mkdir TilesSorted 57 | cd TilesSorted 58 | ``` 59 | 60 | * ``: The dataset such as `'Lung'`, `'Breast'`, or `'Kidney'` 61 | 62 | Next, run `Tiling/0d_SortTiles.py` to sort the tiles into train, valid and test datasets with the following specifications. 63 | 64 | ```sh 65 | python -u /Tiling/0d_SortTiles.py --SourceFolder="" --JsonFile="" --Magnification=20 --MagDiffAllowed=0 --SortingOption=3 --PercentTest=15 --PercentValid=15 --PatientID=12 --nSplit 0 66 | ``` 67 | 68 | * ``: The full path to the cloned repository 69 | 70 | * ``: Path in which the tile files were saved, should be the same as `` of step 2.1. 71 | 72 | * ``: Path to the JSON file that was downloaded with the .svs tiles 73 | 74 | * `--Magnification=20`: Magnification at which the tiles should be considered (20x) 75 | 76 | * `--MagDiffAllowed=0`: If the requested magnification does not exist for a given slide, take the nearest existing magnification but only if it is at +/- the amount allowed here (0) 77 | 78 | * `--SortingOption=3`: Sort according to type of cancer (types of cancer + Solid Tissue Normal) 79 | 80 | * `--PercentValid=15 --PercentTest=15` The percentage of data to be assigned to the validation and test set. In this case, it will result in a 70 / 15 / 15 % train-valid-test split. 81 | 82 | * `--PatientID=12` This option makes sure that the tiles corresponding to one patient are either on the test set, valid set or train set, but not divided among these categories. 83 | 84 | * `--nSplit=0` If nSplit > 0, it overrides the existing PercentTest and PercentTest options, splitting the data into n even categories. 85 | 86 | #### 3.1.3. Build tile dictionary 87 | 88 | Run `Tiling/BuildTileDictionary.py` to build a dictionary of slides that is used to map each slide to a 2D array of tile paths and the true label. This is used in the `aggregate` function during training and evaluation. 89 | 90 | ```sh 91 | python3 -u Tiling/BuildTileDictionary.py --data --file_path --train_log /gpfs/scratch/bilals01/test-repo/logs/exp6_train.log 92 | ``` 93 | * --file_path `` points to the directory path for which the sorted tiles folder is stored in, same as in 2.2. 94 | 95 | * --data `` is the base name for the given type. 96 | 97 | * train_log points to the log file from training. (This option is only needed if you are testing data where not all the classes are presented ) 98 | 99 | Note that this code assumes that the sorted tiles are stored in `TilesSorted`. If you do not follow this convention, you may need to modify this code. 100 | 101 | ### 3.2 Train model: 102 | 103 | Run `train.py` to train with our CNN architecture. sbatch file `run_job.sh` is provided as an example script for submitting a GPU job for this script. Inside the run_job.sh, set parameteres: nexp, output and param as described below. You need to create two directories where the output of the training will be saved at: one for experiemnts, and one for logs. 104 | 105 | * `exp_name` = "exp8" 106 | 107 | * `nexp` = "dir/experiments/${exp_name}" : will create a subfolder under experiments folder which will save the checkpoints and predcitions, experiment name is determined by the param `exp_name` 108 | 109 | * `output` = "dir/logs/${exp_name}.log" : will create a log file under the logs folder where the training output will be printed on,log name is determined by the param `exp_name` 110 | 111 | * `nparam` = "--cuda --augment --dropout=0.1 --nonlinearity='leaky' --init=‘xavier’ --root_dir=/gpfs/scratch/bilals01/brain-kidney-lung/brain-kidney-lungTilesSorted/ --num_class=7 --tile_dict_path=/gpfs/scratch/bilals01/brain-kidney-lung/brain-kidney-lung_FileMappingDict.p" 112 | 113 | The model checkpoints at every epoch and steps (frequency determined by the user using step_freq) will be saved at experiments/checkpoints folder. And the **validation set** predictions and labels will be saved under experiments/outputs folder if calc_val_auc argument is used (Note that total training time will increase significantly if you choose to use this option). 114 | 115 | **nparam:** 116 | * `--cuda`: enables cuda 117 | 118 | * `--ngpu`: number of GPUs to use (default=1) 119 | 120 | * `--augment`: use data augmentation or not 121 | 122 | * `--batchSize`: batch size for data loaders (default=32) 123 | 124 | * `--imgSize`: the height / width that the image will be shrunk to (default=299) 125 | 126 | * `--metadata`: use metadata or not 127 | 128 | **IMPORTANT NOTE: this option is not fully implemented!** Please see section 6 for additional information about using the metadata. 129 | 130 | * `--nc`: input image channels + concatenated info channels if metadata = True (default = 3 for RGB). 131 | 132 | * `--niter`: number of epochs to train for (default=25) 133 | 134 | * `--lr`: learning rate for the optimizer (default=0.001) 135 | 136 | * `--decay_lr`: activate decay learning rate function 137 | 138 | * `--optimizer`: Adam, SGD or RMSprop (default=Adam) 139 | 140 | * `--beta1`: beta1 for Adam (default=0.5) 141 | 142 | * `--earlystop`: use early stopping 143 | 144 | * `--init`: initialization method (default=normal, xavier, kaiming) 145 | 146 | * `--model`: path to model to continue training from a checkpoint (default='') 147 | 148 | * `--experiment`: where to store samples and models (default=None) 149 | 150 | * `--nonlinearity`: nonlinearity to use (selu, prelu, leaky, default=relu) 151 | 152 | * `--dropout`: probability of dropout in each block (default=0.5) 153 | 154 | * `--method`: aggregation prediction method (max, default=average) 155 | 156 | * `--num_class`: number of classes (default=2) 157 | 158 | * `--root_dir`: path to your sorted tiles Data directory .../dataTilesSorted/ (format="TilesSorted/") 159 | 160 | * `--tile_dict_path`: path to your Tile dictinory path (format="_FileMappingDict.p") 161 | 162 | * `--step_freq`: how often to save the results and the checkpoints (default=100000000; won't save any steps, as this number set very high) 163 | 164 | * `--calc_val_auc`: save validation auc calculatio at each epoch 165 | 166 | ### 3.3 Test model: 167 | 168 | Run ```test.py``` to evaluate a specific model on the test/validation data, ```run_test.sh``` is the associated sbatch file. Inside the run_test.sh, set parameteres: nexp, output and param as described below. 169 | 170 | * `exp_name` = "exp8" : set a name for the experiment 171 | 172 | * `test_val` = "test" : choose between test and valid 173 | 174 | * `nexp` = "dir/experiments/${exp_name}" : same exp directory set for the training, experiment name is determined by the param `exp_name` 175 | 176 | * `output` = "dir/logs/${test_val}.log" : name the test log with "*_test.log", log name is determined by the param `exp_name` 177 | 178 | * `nparam` = "--model='epoch_2.pth' --root_dir=/gpfs/data/abl/deepomics/tsirigoslab/histopathology/Tiles/LungTilesSorted/ --num_class=7 --tile_dict_path=/gpfs/data/abl/deepomics/tsirigoslab/histopathology/Tiles/Lung_FileMappingDict.p --val=${test_val}" 179 | 180 | **nparam:** 181 | * `--model`: Name of model to test, e.g. `epoch_10.pth` 182 | 183 | * `--num_class`: number of classes (default=2) 184 | 185 | * `--root_dir`: path to your sorted tiles Data directory .../dataTilesSorted/ (format="TilesSorted/") 186 | 187 | * `--tile_dict_path`: path to your Tile dictinory path (format="_FileMappingDict.p") 188 | 189 | * `--val`: validation vs test (default='test', or use 'valid'), use `test_val` to set the parameter 190 | 191 | * `--train_log`: log file from training (default='') 192 | 193 | The output data will be dumped under experiments/experiment_name folder. 194 | 195 | * To run the test data with multiple check points, use run_multiple_test.sh script. Set the experiment, count (epoch number to start from) and step (increment of epoch number) variables in the script accordingly. 196 | 197 | Note: If number of classes presented is less than what the model is trained for, you will need to pass the log file created by the model as input to the test script using `--train_log` parameter 198 | 199 | 200 | ### 3.4 Evaluation: 201 | 202 | Use test_eval.ipynb to create the ROC curves and calculate the confidence intervals. To start a jupyter notebook on bigpurple, submit the run-jupyter.sbatch script and the follow the instructions on the output file. 203 | 204 | ### 3.5 TSNE Analysis: 205 | 206 | Once the model is trained, run ```tsne.py``` to extract the last layer weights to create the TSNE plots, ```run_tsne.sh``` is the associated sbatch file. 207 | 208 | **sbatch run_tsne.sh "--root_dir=/gpfs/scratch/bilals01/brain-kidney-lung/brain-kidney-lungTilesSorted/ --num_class=7 --tile_dict_path=/gpfs/scratch/bilals01/brain-kidney-lung/brain-kidney-lung_FileMappingDict.p --val=test" test** 209 | 210 | * `--num_class`: number of classes (default=2) 211 | 212 | * `--root_dir`: path to your sorted tiles Data directory .../dataTilesSorted/ (format="TilesSorted/") 213 | 214 | * `--tile_dict_path`: path to your Tile dictinory path (format="_FileMappingDict.p") 215 | 216 | * `--val`: validation vs test (default='test', or use 'valid') 217 | 218 | The output data will be saved at tsne_data folder 219 | 220 | * Use TSNE/tsne_visualize.ipynb to visualize the results (change the input file name to match with tsne.py output files at tsne_data as needed) 221 | 222 | 223 | ## 4. Instructions for use 224 | * Use the model checkpoints in files in checkpoints folder and follow 4. Test model instructions in the above section. 225 | 226 | ## 5. Additional resources: 227 | 228 | ### iPython Notebooks 229 | 230 | * ```100RandomExamples.ipynb``` visualizes of 100 random examples of tiles in the datasets 231 | * ```Final evaluation and viz.ipynb``` provides code for visualizing the output prediction of a model, and also for evaluating a model on the test set on CPU 232 | * ```new_transforms_examples.ipynb``` visualizes a few examples of the data augmentation used for training. One can tune the data augmentation here. 233 | 234 | -------------------------------------------------------------------------------- /Tiling/0b_tileLoop_deepzoom2.py: -------------------------------------------------------------------------------- 1 | ''' 2 | File name: 0b_tileLoop_deepzoom.py 3 | Date created: March/2017 4 | Date last modified: 2/25/2017 5 | Python Version: 2.7 (native on the cluster 6 | Source: 7 | Tiling code comes from: 8 | from https://github.com/openslide/openslide-python/blob/master/examples/deepzoom/deepzoom_tile.py 9 | which is Copyright (c) 2010-2015 Carnegie Mellon University 10 | Objective: 11 | Tile svs images 12 | Be careful: 13 | Overload of the node - may have memory issue if node is shared with other jobs. 14 | Initial tests: 15 | tested on Test_20_tiled/Test2 and Test5 using imgExample = "/ifs/home/coudrn01/NN/Lung/Test_20imgs/*/*svs" 16 | ''' 17 | 18 | from __future__ import print_function 19 | import json 20 | import openslide 21 | from openslide import open_slide, ImageSlide 22 | from openslide.deepzoom import DeepZoomGenerator 23 | from optparse import OptionParser 24 | import re 25 | import shutil 26 | from unicodedata import normalize 27 | import numpy as np 28 | import subprocess 29 | from glob import glob 30 | from multiprocessing import Process, JoinableQueue 31 | import time 32 | import os 33 | import sys 34 | 35 | 36 | VIEWER_SLIDE_NAME = 'slide' 37 | 38 | 39 | class TileWorker(Process): 40 | """A child process that generates and writes tiles.""" 41 | 42 | def __init__(self, queue, slidepath, tile_size, overlap, limit_bounds,quality, _Bkg): 43 | Process.__init__(self, name='TileWorker') 44 | self.daemon = True 45 | self._queue = queue 46 | self._slidepath = slidepath 47 | self._tile_size = tile_size 48 | self._overlap = overlap 49 | self._limit_bounds = limit_bounds 50 | self._quality = quality 51 | self._slide = None 52 | self._Bkg = _Bkg 53 | 54 | def run(self): 55 | self._slide = open_slide(self._slidepath) 56 | last_associated = None 57 | dz = self._get_dz() 58 | while True: 59 | data = self._queue.get() 60 | if data is None: 61 | self._queue.task_done() 62 | break 63 | #associated, level, address, outfile = data 64 | associated, level, address, outfile, format, outfile_bw = data 65 | if last_associated != associated: 66 | dz = self._get_dz(associated) 67 | last_associated = associated 68 | #try: 69 | if True: 70 | try: 71 | tile = dz.get_tile(level, address) 72 | # A single tile is being read 73 | #nc added: check the percentage of the image with "information". Should be above 50% 74 | gray = tile.convert('L') 75 | bw = gray.point(lambda x: 0 if x<220 else 1, 'F') 76 | arr = np.array(np.asarray(bw)) 77 | avgBkg = np.average(bw) 78 | bw = gray.point(lambda x: 0 if x<220 else 1, '1') 79 | #outfile = os.path.join(outfile, '%s.%s' % (str(round(avgBkg, 3)),format) ) 80 | #outfile_bw = os.path.join(outfile_bw, '%s.%s' % (str(round(avgBkg, 3)),format) ) 81 | # bw.save(outfile_bw, quality=self._quality) 82 | if avgBkg < (self._Bkg / 100): 83 | tile.save(outfile, quality=self._quality) 84 | #print("%s good: %f" %(outfile, avgBkg)) 85 | #else: 86 | #print("%s empty: %f" %(outfile, avgBkg)) 87 | self._queue.task_done() 88 | except: 89 | print(level, address) 90 | print("image %s failed at dz.get_tile for level %f" % (self._slidepath, level)) 91 | self._queue.task_done() 92 | 93 | def _get_dz(self, associated=None): 94 | if associated is not None: 95 | image = ImageSlide(self._slide.associated_images[associated]) 96 | else: 97 | image = self._slide 98 | return DeepZoomGenerator(image, self._tile_size, self._overlap, limit_bounds=self._limit_bounds) 99 | 100 | 101 | class DeepZoomImageTiler(object): 102 | """Handles generation of tiles and metadata for a single image.""" 103 | 104 | def __init__(self, dz, basename, format, associated, queue, slide, basenameJPG): 105 | self._dz = dz 106 | self._basename = basename 107 | self._basenameJPG = basenameJPG 108 | self._format = format 109 | self._associated = associated 110 | self._queue = queue 111 | self._processed = 0 112 | self._slide = slide 113 | 114 | def run(self): 115 | self._write_tiles() 116 | self._write_dzi() 117 | 118 | def _write_tiles(self): 119 | ########################################3 120 | # nc_added 121 | #level = self._dz.level_count-1 122 | Magnification = 20 123 | tol = 2 124 | #get slide dimensions, zoom levels, and objective information 125 | Factors = self._slide.level_downsamples 126 | try: 127 | Objective = float(self._slide.properties[openslide.PROPERTY_NAME_OBJECTIVE_POWER]) 128 | print(self._basename + " - Obj information found") 129 | except: 130 | print(self._basename + " - No Obj information found") 131 | return 132 | #calculate magnifications 133 | Available = tuple(Objective / x for x in Factors) 134 | #find highest magnification greater than or equal to 'Desired' 135 | Mismatch = tuple(x-Magnification for x in Available) 136 | AbsMismatch = tuple(abs(x) for x in Mismatch) 137 | if len(AbsMismatch) < 1: 138 | print(self._basename + " - Objective field empty!") 139 | return 140 | if(min(AbsMismatch) <= tol): 141 | Level = int(AbsMismatch.index(min(AbsMismatch))) 142 | Factor = 1 143 | else: #pick next highest level, downsample 144 | Level = int(max([i for (i, val) in enumerate(Mismatch) if val > 0])) 145 | Factor = Magnification / Available[Level] 146 | # end added 147 | #for level in range(self._dz.level_count): 148 | for level in range(self._dz.level_count-1,-1,-1): 149 | ThisMag = Available[0]/pow(2,self._dz.level_count-(level+1)) 150 | ######################################## 151 | #tiledir = os.path.join("%s_files" % self._basename, str(level)) 152 | tiledir = os.path.join("%s_files" % self._basename, str(ThisMag)) 153 | if not os.path.exists(tiledir): 154 | os.makedirs(tiledir) 155 | cols, rows = self._dz.level_tiles[level] 156 | for row in range(rows): 157 | for col in range(cols): 158 | InsertBaseName = False 159 | if InsertBaseName: 160 | tilename = os.path.join(tiledir, '%s_%d_%d.%s' % ( 161 | self._basenameJPG, col, row, self._format)) 162 | tilename_bw = os.path.join(tiledir, '%s_%d_%d_bw.%s' % ( 163 | self._basenameJPG, col, row, self._format)) 164 | else: 165 | tilename = os.path.join(tiledir, '%d_%d.%s' % ( 166 | col, row, self._format)) 167 | tilename_bw = os.path.join(tiledir, '%d_%d_bw.%s' % ( 168 | col, row, self._format)) 169 | 170 | 171 | if not os.path.exists(tilename): 172 | self._queue.put((self._associated, level, (col, row), 173 | tilename, self._format, tilename_bw)) 174 | self._tile_done() 175 | 176 | def _tile_done(self): 177 | self._processed += 1 178 | count, total = self._processed, self._dz.tile_count 179 | if count % 100 == 0 or count == total: 180 | print("Tiling %s: wrote %d/%d tiles" % ( 181 | self._associated or 'slide', count, total), 182 | end='\r', file=sys.stderr) 183 | if count == total: 184 | print(file=sys.stderr) 185 | 186 | def _write_dzi(self): 187 | with open('%s.dzi' % self._basename, 'w') as fh: 188 | fh.write(self.get_dzi()) 189 | 190 | def get_dzi(self): 191 | return self._dz.get_dzi(self._format) 192 | 193 | 194 | class DeepZoomStaticTiler(object): 195 | """Handles generation of tiles and metadata for all images in a slide.""" 196 | 197 | def __init__(self, slidepath, basename, format, tile_size, overlap, 198 | limit_bounds, quality, workers, with_viewer, Bkg, basenameJPG): 199 | if with_viewer: 200 | # Check extra dependency before doing a bunch of work 201 | import jinja2 202 | print("line226 - %s " % (slidepath) ) 203 | self._slide = open_slide(slidepath) 204 | self._basename = basename 205 | self._basenameJPG = basenameJPG 206 | self._format = format 207 | self._tile_size = tile_size 208 | self._overlap = overlap 209 | self._limit_bounds = limit_bounds 210 | self._queue = JoinableQueue(2 * workers) 211 | self._workers = workers 212 | self._with_viewer = with_viewer 213 | self._Bkg = Bkg 214 | self._dzi_data = {} 215 | for _i in range(workers): 216 | TileWorker(self._queue, slidepath, tile_size, overlap, 217 | limit_bounds, quality, self._Bkg).start() 218 | 219 | def run(self): 220 | self._run_image() 221 | if self._with_viewer: 222 | for name in self._slide.associated_images: 223 | self._run_image(name) 224 | self._write_html() 225 | self._write_static() 226 | self._shutdown() 227 | 228 | def _run_image(self, associated=None): 229 | """Run a single image from self._slide.""" 230 | if associated is None: 231 | image = self._slide 232 | if self._with_viewer: 233 | basename = os.path.join(self._basename, VIEWER_SLIDE_NAME) 234 | else: 235 | basename = self._basename 236 | else: 237 | image = ImageSlide(self._slide.associated_images[associated]) 238 | basename = os.path.join(self._basename, self._slugify(associated)) 239 | dz = DeepZoomGenerator(image, self._tile_size, self._overlap,limit_bounds=self._limit_bounds) 240 | tiler = DeepZoomImageTiler(dz, basename, self._format, associated,self._queue, self._slide, self._basenameJPG) 241 | tiler.run() 242 | self._dzi_data[self._url_for(associated)] = tiler.get_dzi() 243 | 244 | def _url_for(self, associated): 245 | if associated is None: 246 | base = VIEWER_SLIDE_NAME 247 | else: 248 | base = self._slugify(associated) 249 | return '%s.dzi' % base 250 | 251 | def _write_html(self): 252 | import jinja2 253 | env = jinja2.Environment(loader=jinja2.PackageLoader(__name__),autoescape=True) 254 | template = env.get_template('slide-multipane.html') 255 | associated_urls = dict((n, self._url_for(n)) 256 | for n in self._slide.associated_images) 257 | try: 258 | mpp_x = self._slide.properties[openslide.PROPERTY_NAME_MPP_X] 259 | mpp_y = self._slide.properties[openslide.PROPERTY_NAME_MPP_Y] 260 | mpp = (float(mpp_x) + float(mpp_y)) / 2 261 | except (KeyError, ValueError): 262 | mpp = 0 263 | # Embed the dzi metadata in the HTML to work around Chrome's 264 | # refusal to allow XmlHttpRequest from file:///, even when 265 | # the originating page is also a file:/// 266 | data = template.render(slide_url=self._url_for(None),slide_mpp=mpp,associated=associated_urls, properties=self._slide.properties, dzi_data=json.dumps(self._dzi_data)) 267 | with open(os.path.join(self._basename, 'index.html'), 'w') as fh: 268 | fh.write(data) 269 | 270 | def _write_static(self): 271 | basesrc = os.path.join(os.path.dirname(os.path.abspath(__file__)), 272 | 'static') 273 | basedst = os.path.join(self._basename, 'static') 274 | self._copydir(basesrc, basedst) 275 | self._copydir(os.path.join(basesrc, 'images'), 276 | os.path.join(basedst, 'images')) 277 | 278 | def _copydir(self, src, dest): 279 | if not os.path.exists(dest): 280 | os.makedirs(dest) 281 | for name in os.listdir(src): 282 | srcpath = os.path.join(src, name) 283 | if os.path.isfile(srcpath): 284 | shutil.copy(srcpath, os.path.join(dest, name)) 285 | 286 | @classmethod 287 | def _slugify(cls, text): 288 | text = normalize('NFKD', text.lower()).encode('ascii', 'ignore').decode() 289 | return re.sub('[^a-z0-9]+', '_', text) 290 | 291 | def _shutdown(self): 292 | for _i in range(self._workers): 293 | self._queue.put(None) 294 | self._queue.join() 295 | 296 | 297 | 298 | def ImgWorker(queue): 299 | print("ImgWorker started") 300 | while True: 301 | cmd = queue.get() 302 | if cmd is None: 303 | queue.task_done() 304 | break 305 | print("Execute: %s" % (cmd)) 306 | subprocess.Popen(cmd, shell=True).wait() 307 | queue.task_done() 308 | 309 | 310 | if __name__ == '__main__': 311 | parser = OptionParser(usage='Usage: %prog [options] ') 312 | 313 | parser.add_option('-L', '--ignore-bounds', dest='limit_bounds', 314 | default=True, action='store_false', 315 | help='display entire scan area') 316 | parser.add_option('-e', '--overlap', metavar='PIXELS', dest='overlap', 317 | type='int', default=1, 318 | help='overlap of adjacent tiles [1]') 319 | parser.add_option('-f', '--format', metavar='{jpeg|png}', dest='format', 320 | default='jpeg', 321 | help='image format for tiles [jpeg]') 322 | parser.add_option('-j', '--jobs', metavar='COUNT', dest='workers', 323 | type='int', default=4, 324 | help='number of worker processes to start [4]') 325 | parser.add_option('-o', '--output', metavar='NAME', dest='basename', 326 | help='base name of output file') 327 | parser.add_option('-Q', '--quality', metavar='QUALITY', dest='quality', 328 | type='int', default=90, 329 | help='JPEG compression quality [90]') 330 | parser.add_option('-r', '--viewer', dest='with_viewer', 331 | action='store_true', 332 | help='generate directory tree with HTML viewer') 333 | parser.add_option('-s', '--size', metavar='PIXELS', dest='tile_size', 334 | type='int', default=254, 335 | help='tile size [254]') 336 | parser.add_option('-B', '--Background', metavar='PIXELS', dest='Bkg', 337 | type='float', default=50, 338 | help='Max background threshold [50]') 339 | 340 | (opts, args) = parser.parse_args() 341 | 342 | 343 | try: 344 | slidepath = args[0] 345 | except IndexError: 346 | parser.error('Missing slide argument') 347 | if opts.basename is None: 348 | opts.basename = os.path.splitext(os.path.basename(slidepath))[0] 349 | 350 | 351 | 352 | # Initialization 353 | # imgExample = "/ifs/home/coudrn01/NN/Lung/RawImages/*/*svs" 354 | # tile_size = 512 355 | # max_number_processes = 10 356 | # NbrCPU = 4 357 | 358 | 359 | # get images from the data/ file. 360 | print(slidepath) 361 | files = glob(slidepath) 362 | files 363 | print('Number of files: %s' % (len(files))) 364 | print(files) 365 | print("***********************") 366 | 367 | ''' 368 | dz_queue = JoinableQueue() 369 | procs = [] 370 | print("Nb of processes:") 371 | print(opts.max_number_processes) 372 | for i in range(opts.max_number_processes): 373 | p = Process(target = ImgWorker, args = (dz_queue,)) 374 | #p.deamon = True 375 | p.setDaemon = True 376 | p.start() 377 | procs.append(p) 378 | ''' 379 | for imgNb in range(len(files)): 380 | filename = files[imgNb] 381 | # print(filename) 382 | opts.basenameJPG = os.path.splitext(os.path.basename(filename))[0] 383 | print("processing: " + opts.basenameJPG) 384 | #opts.basenameJPG = os.path.splitext(os.path.basename(slidepath))[0] 385 | #if os.path.isdir("%s_files" % (basename)): 386 | # print("EXISTS") 387 | #else: 388 | # print("Not Found") 389 | 390 | output = os.path.join(opts.basename, opts.basenameJPG) 391 | 392 | if os.path.isfile(output): 393 | print('%s: File already exists' % (output)) 394 | continue 395 | 396 | # Adding a try/except for kidney. 397 | try: 398 | DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG).run() 399 | except ValueError: 400 | print("Corrupted file: {}".format(filename)) 401 | 402 | ''' 403 | dz_queue.join() 404 | for i in range(opts.max_number_processes): 405 | dz_queue.put( None ) 406 | ''' 407 | 408 | print("End") 409 | -------------------------------------------------------------------------------- /Tiling/0d_SortTiles.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Authors: Nicolas Coudray, Theodoros Sakellaropoulos 3 | Date created: March/2017 4 | Python Version: 2.7 (native on the cluster) 5 | Objective: 6 | Starting with tiles images, select images from a given magnification and order them according the the stage of the cancer, or type of cancer, etc... 7 | Usage: 8 | SourceFolder is the folder where all the svs images have been tiled . : 9 | It encloses: 10 | * 1 subfolder per image, (the name of the subfolder being "imagename_files") 11 | * each contains 14-17 subfolders which name is the magnification of the tiles and contains the tiles 12 | It should not enclose any other folder 13 | The output folder from which the script is run should be empty 14 | ''' 15 | import json 16 | from glob import glob 17 | import os 18 | from argparse import ArgumentParser 19 | import random 20 | from shutil import copyfile 21 | 22 | def extract_stage(metadata): 23 | stage = metadata['cases'][0]['diagnoses'][0]['tumor_stage'] 24 | stage = stage.replace(" ", "_") 25 | stage = stage.rstrip("a") 26 | stage = stage.rstrip("b") 27 | return stage 28 | 29 | 30 | def extract_cancer(metadata): 31 | return metadata['cases'][0]['project']['project_id'] 32 | 33 | 34 | def extract_sample_type(metadata): 35 | return metadata['cases'][0]['samples'][0]['sample_type'] 36 | 37 | 38 | def sort_cancer_stage_separately(metadata, **kwargs): 39 | sample_type = extract_sample_type(metadata) 40 | cancer = extract_cancer(metadata) 41 | if "Normal" in sample_type: 42 | stage = sample_type.replace(" ", "_") 43 | else: 44 | stage = extract_stage(metadata) 45 | 46 | return os.path.join(cancer, stage) 47 | 48 | 49 | def sort_cancer_stage(metadata, **kwargs): 50 | sample_type = extract_sample_type(metadata) 51 | cancer = extract_cancer(metadata) 52 | stage = extract_stage(metadata) 53 | if "Normal" in sample_type: 54 | return sample_type.replace(" ", "_") 55 | return cancer + "_" + stage 56 | 57 | 58 | def sort_type(metadata, **kwargs): 59 | cancer = extract_cancer(metadata) 60 | sample_type = extract_sample_type(metadata) 61 | if "Normal" in sample_type: 62 | return sample_type.replace(" ", "_") 63 | return cancer 64 | 65 | 66 | def sort_cancer_type(metadata, **kwargs): 67 | sample_type = extract_sample_type(metadata) 68 | if "Normal" in sample_type: 69 | return None 70 | return extract_cancer(metadata) 71 | 72 | 73 | def sort_cancer_healthy_pairs(metadata, **kwargs): 74 | sample_type = extract_sample_type(metadata) 75 | cancer = extract_cancer(metadata) 76 | if "Normal" in sample_type: 77 | return os.path.join(cancer, sample_type.replace(" ", "_")) 78 | return os.path.join(cancer, cancer) 79 | 80 | 81 | def sort_cancer_healthy(metadata, **kwargs): 82 | sample_type = extract_sample_type(metadata) 83 | cancer = extract_cancer(metadata) 84 | if "Normal" in sample_type: 85 | return sample_type.replace(" ", "_") 86 | return cancer 87 | 88 | 89 | def sort_random(metadata, **kwargs): 90 | AllOptions = ['TCGA-LUAD', 'TCGA-LUSC', 'Solid_Tissue_Normal'] 91 | return AllOptions[random.randint(0, 2)] 92 | 93 | 94 | def sort_mutational_burden(metadata, load_dic, **kwargs): 95 | submitter_id = metadata["cases"][0]["submitter_id"] 96 | try: 97 | return load_dic[submitter_id] 98 | except KeyError: 99 | return None 100 | 101 | 102 | def sort_mutation_metastatic(metadata, load_dic, **kwargs): 103 | sample_type = extract_sample_type(metadata) 104 | if "Metastatic" in sample_type: 105 | submitter_id = metadata["cases"][0]["submitter_id"] 106 | try: 107 | return load_dic[submitter_id] 108 | except KeyError: 109 | return None 110 | return None 111 | 112 | 113 | def sort_setonly(metadata, load_dic, **kwargs): 114 | return 'All' 115 | 116 | def sort_location(metadata, load_dic, **kwargs): 117 | sample_type = extract_sample_type(metadata) 118 | return sample_type.replace(" ", "_") 119 | 120 | def sort_melanoma_POD(metadata, load_dic, **kwargs): 121 | Response = metadata['Response to Treatment (Best Response)'] 122 | if 'POD' in Response: 123 | return 'POD' 124 | else: 125 | return 'Response' 126 | 127 | def sort_melanoma_Toxicity(metadata, load_dic, **kwargs): 128 | return metadata['Toxicity Observed'] 129 | 130 | 131 | def sort_text(metadata, load_dic, **kwargs): 132 | return metadata 133 | 134 | def copy_svs_lymph_melanoma(metadata, load_dic, **kwargs): 135 | sample_type = extract_sample_type(metadata) 136 | if "Metastatic" in sample_type: 137 | submitter_id = metadata["cases"][0]["diagnoses"][0]["tissue_or_organ_of_origin"] 138 | if 'c77' in submitter_id: 139 | try: 140 | return True 141 | except KeyError: 142 | return False 143 | else: 144 | return False 145 | return False 146 | 147 | 148 | 149 | def copy_svs_skin_primtumor(metadata, load_dic, **kwargs): 150 | sample_type = extract_sample_type(metadata) 151 | if "Primary" in sample_type: 152 | submitter_id = metadata["cases"][0]["diagnoses"][0]["tissue_or_organ_of_origin"] 153 | if 'c44' in submitter_id: 154 | try: 155 | return True 156 | except KeyError: 157 | return False 158 | else: 159 | return False 160 | return False 161 | 162 | 163 | 164 | sort_options = [ 165 | sort_cancer_stage_separately, 166 | sort_cancer_stage, 167 | sort_type, 168 | sort_cancer_type, 169 | sort_cancer_healthy_pairs, 170 | sort_cancer_healthy, 171 | sort_random, 172 | sort_mutational_burden, 173 | sort_mutation_metastatic, 174 | sort_setonly, 175 | sort_location, 176 | sort_melanoma_POD, 177 | sort_melanoma_Toxicity, 178 | sort_text, 179 | copy_svs_lymph_melanoma, 180 | copy_svs_skin_primtumor, 181 | ] 182 | 183 | if __name__ == '__main__': 184 | # python 0d_SortTiles_stage.py '/ifs/home/coudrn01/NN/Lung/Test_All512pxTiled/512pxTiled' '/ifs/home/coudrn01/NN/Lung/RawImages/metadata.cart.2017-03-02T00_36_30.276824.json' 20 5 3 15 15 185 | 186 | descr = """ 187 | Example: python /ifs/home/coudrn01/NN/Lung/0d_SortTiles.py --SourceFolder='/ifs/data/abl/deepomics/pancreas/images_TCGA/512pxTiled_b' --JsonFile='/ifs/data/abl/deepomics/pancreas/images_TCGA/Raw/metadata.cart.2017-09-08T14_46_02.589953.json' --Magnification=20 --MagDiffAllowed=5 --SortingOption=3 --PercentTest=100 --PercentValid=0 --PatientID=12 --nSplit 0 188 | In this example, the images are expected to be in folders in this directory: '/ifs/data/abl/deepomics/pancreas/images_TCGA/512pxTiled_b' 189 | Each images should have its own sub-folder with the svs image name followed by '_files' 190 | Each images should have subfolders with names corresponding to the magnification associated with the jpeg tiles saved inside it 191 | The sorting will be done using tiles corresponding to a magnification of 20 (+/- 5 if the 20 folder does not exist) 192 | 15%% will be put for validation, 15%% for testing and the leftover for training. However, if split is > 0, then the data will be split in train/test only in "# split" non-overlapping ways (each way will have 100/(#split) % of test images). 193 | linked images' names will start with 'train_', 'test_' or 'valid_' followed by the svs name and the tile ID 194 | Sorting options are: 195 | 1. sort according to cancer stage (i, ii, iii or iv) for each cancer separately (classification can be done separately for each cancer) 196 | 2. sort according to cancer stage (i, ii, iii or iv) for each cancer (classification can be done on everything at once) 197 | 3. sort according to type of cancer (LUSC, LUAD, or Nomal Tissue) 198 | 4. sort according to type of cancer (LUSC, LUAD) 199 | 5. sort according to type of cancer / Normal Tissue (2 variables per type) 200 | 6. sort according to cancer / Normal Tissue (2 variables) 201 | 7. Random labels (3 variables for false positive control) 202 | 8. sort according to mutational load (High/Low). Must specify --TMB option. 203 | 9. sort according to BRAF mutations for metastatic only. Must specify --TMB option (BRAF mutant for each file). 204 | 10. Do not sort. Just create symbolic links and assign images to train/test/valid sets. 205 | 11. Sample location (Normal, metastatic, etc...) 206 | 12. Osman's melanoma: Response to Treatment (Best Response) (POD vs other) 207 | 13. Osman's melanoma: Toxicity observed 208 | 14. Json is actually a text file. First column is ID, second is the labels 209 | 15. Copy (not symlink) SVS slides (not jpeg tiles) to new directory if Melanoma + Lymph 210 | 16. Copy (not symlink) SVS slides (not jpeg tiles) to new directory if Primary Tumor + skin 211 | """ 212 | ## Define Arguments 213 | parser = ArgumentParser(description=descr) 214 | 215 | parser.add_argument("--SourceFolder", help="path to tiled images", dest='SourceFolder') 216 | parser.add_argument("--JsonFile", help="path to metadata json file", dest='JsonFile') 217 | parser.add_argument("--Magnification", help="magnification to use", type=float, dest='Magnification') 218 | parser.add_argument("--MagDiffAllowed", help="difference allwed on Magnification", type=float, dest='MagDiffAllowed') 219 | parser.add_argument("--SortingOption", help="see option at the epilog", type=int, dest='SortingOption') 220 | parser.add_argument("--PercentValid", help="percentage of images for validation (between 0 and 100)", type=float, dest='PercentValid') 221 | parser.add_argument("--PercentTest", help="percentage of images for testing (between 0 and 100)", type=float, dest='PercentTest') 222 | parser.add_argument("--PatientID", help="Patient ID is supposed to be the first PatientID characters (integer expected) of the folder in which the pyramidal jpgs are. Slides from same patient will be in same train/test/valid set. This option is ignored if set to 0 or -1 ", type=int, dest='PatientID') 223 | parser.add_argument("--TMB", help="path to json file with mutational loads; or to BRAF mutations", dest='TMB') 224 | parser.add_argument("--nSplit", help="interger n: Split into train/test in n different ways", dest='nSplit') 225 | 226 | ## Parse Arguments 227 | args = parser.parse_args() 228 | 229 | if args.PatientID is None: 230 | print("PatientID ignored") 231 | args.PatientID = 0 232 | 233 | if args.nSplit is None: 234 | args.nSplit = 0 235 | elif int(args.nSplit) > 0 : 236 | args.PercentValid = 100/int(args.nSplit) 237 | args.PercentTest = 0 238 | 239 | SourceFolder = os.path.abspath(args.SourceFolder) 240 | if args.SortingOption in [15, 16]: 241 | # raw TCGA svs images 242 | imgFolders = glob(os.path.join(SourceFolder, "*.svs")) 243 | random.shuffle(imgFolders) # randomize order of images 244 | else: 245 | imgFolders = glob(os.path.join(SourceFolder, "*_files")) 246 | random.shuffle(imgFolders) # randomize order of images 247 | 248 | JsonFile = args.JsonFile 249 | if '.json' in JsonFile: 250 | with open(JsonFile) as fid: 251 | jdata = json.loads(fid.read()) 252 | try: 253 | jdata = dict((jd['file_name'].rstrip('.svs'), jd) for jd in jdata) 254 | except: 255 | jdata = dict((jd['Patient ID'], jd) for jd in jdata) 256 | else: 257 | with open(JsonFile, "rU") as f: 258 | jdata = {} 259 | for line in f: 260 | tmp_PID = line.split()[0] 261 | jdata[tmp_PID[:args.PatientID]] = line.split()[1] 262 | 263 | print("jdata:") 264 | print(jdata) 265 | Magnification = args.Magnification 266 | MagDiffAllowed = args.MagDiffAllowed 267 | 268 | SortingOption = args.SortingOption - 1 # transform to 0-based index 269 | try: 270 | sort_function = sort_options[SortingOption] 271 | except IndexError: 272 | raise ValueError("Unknown sort option") 273 | 274 | 275 | # Special case: svs images - copy and exit program 276 | if args.SortingOption in [15, 16]: 277 | # raw TCGA svs images 278 | for cFolderName in imgFolders: 279 | print("-----------") 280 | print(cFolderName) 281 | 282 | imgRootName = os.path.basename(cFolderName) 283 | imgRootName = imgRootName.rstrip('.svs') 284 | try: 285 | image_meta = jdata[imgRootName] 286 | except KeyError: 287 | try: 288 | image_meta = jdata[imgRootName[:args.PatientID]] 289 | except KeyError: 290 | print("file_name %s not found in metadata" % imgRootName[:args.PatientID]) 291 | continue 292 | IsCopy = sort_function(image_meta, load_dic={}) 293 | if IsCopy: 294 | copyfile(cFolderName, os.path.join(os.getcwd(), imgRootName + '.svs' ) ) 295 | quit() 296 | 297 | PercentValid = args.PercentValid / 100. 298 | if not 0 <= PercentValid <= 1: 299 | raise ValueError("PercentValid is not between 0 and 100") 300 | PercentTest = args.PercentTest / 100. 301 | if not 0 <= PercentTest <= 1: 302 | raise ValueError("PercentTest is not between 0 and 100") 303 | # Tumor mutational burden dictionary 304 | TMBFile = args.TMB 305 | mut_load = {} 306 | if SortingOption == 7: 307 | if TMBFile: 308 | with open(TMBFile) as fid: 309 | mut_load = json.loads(fid.read()) 310 | else: 311 | raise ValueError("For SortingOption = 8 you must specify the --TMB option") 312 | elif SortingOption == 8: 313 | if TMBFile: 314 | with open(TMBFile) as fid: 315 | mut_load = json.loads(fid.read()) 316 | else: 317 | raise ValueError("For SortingOption = 9 you must specify the --TMB option") 318 | 319 | 320 | 321 | ## Main Loop 322 | print("******************") 323 | Classes = {} 324 | NbrTilesCateg = {} 325 | PercentTilesCateg = {} 326 | NbrImagesCateg = {} 327 | PercentSlidesCateg = {} 328 | Patient_set = {} 329 | NbSlides = 0 330 | if int(args.nSplit) > 0: 331 | ttv_split = [] 332 | nbr_valid = [] 333 | for nSet in range(int(args.nSplit)): 334 | ttv_split.append("train") 335 | nbr_valid.append(0) 336 | ttv_split[0] = "test" 337 | 338 | print("imgFolders") 339 | print(imgFolders) 340 | for cFolderName in imgFolders: 341 | NbSlides += 1 342 | #if NbSlides > 10: 343 | # raise ValueError("small test debug") 344 | # exit() 345 | # raise SystemExit 346 | # break 347 | 348 | print("**************** starting %s" % cFolderName) 349 | imgRootName = os.path.basename(cFolderName) 350 | imgRootName = imgRootName.rstrip('_files') 351 | 352 | try: 353 | image_meta = jdata[imgRootName] 354 | except KeyError: 355 | try: 356 | image_meta = jdata[imgRootName[:args.PatientID]] 357 | except KeyError: 358 | print("file_name %s not found in metadata" % imgRootName[:args.PatientID]) 359 | continue 360 | 361 | SubDir = sort_function(image_meta, load_dic=mut_load) 362 | if int(args.nSplit) > 0: 363 | # n-fold cross validation 364 | for nSet in range(int(args.nSplit)): 365 | SetDir = "set_" + str(nSet) 366 | if not os.path.exists(SetDir): 367 | os.makedirs(SetDir) 368 | if SubDir is None: 369 | print("image not valid for this sorting option") 370 | continue 371 | if not os.path.exists(os.path.join(SetDir, SubDir)): 372 | os.makedirs(os.path.join(SetDir, SubDir)) 373 | 374 | 375 | else: 376 | SetDir = "" 377 | if SubDir is None: 378 | print("image not valid for this sorting option") 379 | continue 380 | if not os.path.exists(SubDir): 381 | os.makedirs(SubDir) 382 | 383 | try: 384 | Classes[SubDir].append(imgRootName) 385 | except KeyError: 386 | Classes[SubDir] = [imgRootName] 387 | 388 | # Check in the reference directories if there is a set of tiles at the desired magnification 389 | AvailMagsDir = [x for x in os.listdir(cFolderName) 390 | if os.path.isdir(os.path.join(cFolderName, x))] 391 | AvailMags = tuple(float(x) for x in AvailMagsDir) 392 | # check if the mag was known for that slide 393 | if max(AvailMags) < 0: 394 | print("Magnification was not known for that file.") 395 | continue 396 | mismatch, imin = min((abs(x - Magnification), i) for i, x in enumerate(AvailMags)) 397 | if mismatch <= MagDiffAllowed: 398 | AvailMagsDir = AvailMagsDir[imin] 399 | else: 400 | # No Tiles at the mag within the allowed range 401 | print("No Tiles found at the mag within the allowed range.") 402 | continue 403 | 404 | # Copy/symbolic link the images into the appropriate folder-type 405 | print("Symlinking tiles...") 406 | SourceImageDir = os.path.join(cFolderName, AvailMagsDir, "*") 407 | AllTiles = glob(SourceImageDir) 408 | 409 | if SubDir in NbrTilesCateg.keys(): 410 | print SubDir + " already in dictionary" 411 | else: 412 | print SubDir + " not yet in dictionary" 413 | NbrTilesCateg[SubDir] = 0 414 | NbrTilesCateg[SubDir + "_train"] = 0 415 | NbrTilesCateg[SubDir + "_test"] = 0 416 | NbrTilesCateg[SubDir + "_valid"] = 0 417 | PercentTilesCateg[SubDir + "_train"] = 0 418 | PercentTilesCateg[SubDir + "_test"] = 0 419 | PercentTilesCateg[SubDir + "_valid"] = 0 420 | NbrImagesCateg[SubDir] = 0 421 | NbrImagesCateg[SubDir + "_train"] = 0 422 | NbrImagesCateg[SubDir + "_test"] = 0 423 | NbrImagesCateg[SubDir + "_valid"] = 0 424 | PercentSlidesCateg[SubDir + "_train"] = 0 425 | PercentSlidesCateg[SubDir + "_test"] = 0 426 | PercentSlidesCateg[SubDir + "_valid"] = 0 427 | 428 | NbTiles = 0 429 | for TilePath in AllTiles: 430 | NbTiles += 1 431 | TileName = os.path.basename(TilePath) 432 | 433 | print("current percent in test, valid and ID") 434 | print(PercentSlidesCateg.get(SubDir + "_test")) 435 | print(PercentSlidesCateg.get(SubDir + "_valid")) 436 | print(PercentTest, PercentValid) 437 | print(PercentSlidesCateg.get(SubDir + "_test") 0: 453 | for nSet in range(int(args.nSplit)): 454 | ttv_split[nSet] = "train" 455 | 456 | if args.PatientID > 0: 457 | Patient = imgRootName[:args.PatientID] 458 | if Patient in Patient_set: 459 | SetIndx = Patient_set[Patient] 460 | else: 461 | SetIndx = nbr_valid.index(min(nbr_valid)) 462 | Patient_set[Patient] = SetIndx 463 | else: 464 | SetIndx = nbr_valid.index(min(nbr_valid)) 465 | 466 | ttv_split[SetIndx] = "test" 467 | nbr_valid[SetIndx] = nbr_valid[SetIndx] + 1 468 | 469 | for nSet in range(int(args.nSplit)): 470 | SetDir = "set_" + str(nSet) 471 | NewImageDir = os.path.join(SetDir, SubDir, "_".join((ttv_split[nSet], imgRootName, TileName))) # all train initially 472 | os.symlink(TilePath, NewImageDir) 473 | 474 | else: 475 | if args.PatientID > 0: 476 | Patient = imgRootName[:args.PatientID] 477 | if Patient in Patient_set: 478 | ttv = Patient_set[Patient] 479 | else: 480 | Patient_set[Patient] = ttv 481 | print(ttv) 482 | 483 | NewImageDir = os.path.join(SubDir, "_".join((ttv, imgRootName, TileName))) # all train initially 484 | os.symlink(TilePath, NewImageDir) 485 | # update stats 486 | 487 | NbrTilesCateg[SubDir] = NbrTilesCateg.get(SubDir) + NbTiles 488 | NbrImagesCateg[SubDir] = NbrImagesCateg.get(SubDir) + 1 489 | if ttv == "train": 490 | NbrTilesCateg[SubDir + "_train"] = NbrTilesCateg.get(SubDir + "_train") + NbTiles 491 | NbrImagesCateg[SubDir + "_train"] = NbrImagesCateg[SubDir + "_train"] + 1 492 | elif ttv == "test": 493 | NbrTilesCateg[SubDir + "_test"] = NbrTilesCateg.get(SubDir + "_test") + NbTiles 494 | NbrImagesCateg[SubDir + "_test"] = NbrImagesCateg[SubDir + "_test"] + 1 495 | elif ttv == "valid": 496 | NbrTilesCateg[SubDir + "_valid"] = NbrTilesCateg.get(SubDir + "_valid") + NbTiles 497 | NbrImagesCateg[SubDir + "_valid"] = NbrImagesCateg[SubDir + "_valid"] + 1 498 | 499 | PercentTilesCateg[SubDir + "_train"] = float(NbrTilesCateg.get(SubDir + "_train")) / float(NbrTilesCateg.get(SubDir)) 500 | PercentTilesCateg[SubDir + "_test"] = float(NbrTilesCateg.get(SubDir + "_test")) / float(NbrTilesCateg.get(SubDir)) 501 | PercentTilesCateg[SubDir + "_valid"] = float(NbrTilesCateg.get(SubDir + "_valid")) / float(NbrTilesCateg.get(SubDir)) 502 | PercentSlidesCateg[SubDir + "_train"] = float(NbrImagesCateg.get(SubDir + "_train")) / float(NbrImagesCateg.get(SubDir)) 503 | PercentSlidesCateg[SubDir + "_test"] = float(NbrImagesCateg.get(SubDir + "_test")) / float(NbrImagesCateg.get(SubDir)) 504 | PercentSlidesCateg[SubDir + "_valid"] = float(NbrImagesCateg.get(SubDir + "_valid")) / float(NbrImagesCateg.get(SubDir)) 505 | 506 | print("Done. %d tiles linked to %s " % ( NbTiles, SubDir ) ) 507 | print("Train / Test / Validation tiles sets for %s = %f %% / %f %% / %f %%" % (SubDir, PercentTilesCateg.get(SubDir + "_train"), PercentTilesCateg.get(SubDir + "_test"), PercentTilesCateg.get(SubDir + "_valid") ) ) 508 | print("Train / Test / Validation slides sets for %s = %f %% / %f %% / %f %%" % (SubDir, PercentSlidesCateg.get(SubDir + "_train"), PercentSlidesCateg.get(SubDir + "_test"), PercentSlidesCateg.get(SubDir + "_valid") ) ) 509 | 510 | for k, v in sorted(NbrTilesCateg.iteritems()): 511 | print(k, v) 512 | for k, v in sorted(PercentTilesCateg.iteritems()): 513 | print(k, v) 514 | for k, v in sorted(NbrImagesCateg.iteritems()): 515 | print(k, v) 516 | 517 | 518 | 519 | ''' 520 | # Partition the dataset into train / test / valid 521 | print("********* Partitioning files to train / test / valid") 522 | for SubDir, Images in Classes.items(): 523 | print("Working in Class %s" % SubDir) 524 | Nimages = len(Images) 525 | Ntest = int(round(Nimages * PercentTest)) 526 | Nvalid = int(round(Nimages * PercentValid)) 527 | print("Total number of images %d" % Nimages) 528 | print("Number of test images %d" % Ntest) 529 | print("Number of validation images %d" % Nvalid) 530 | # rename first images 531 | NbTilesValid = 0 532 | for imgRootName in Images[:Nvalid]: 533 | oldprefix = "train_" + imgRootName 534 | newprefix = "valid_" + imgRootName 535 | TileGlob = os.path.join(SubDir, oldprefix + "_*") 536 | for TilePath in glob(TileGlob): 537 | os.rename(TilePath, TilePath.replace(oldprefix, newprefix)) 538 | NbTilesValid += 1 539 | # rename last images 540 | NbTilesTest = 0 541 | for imgRootName in Images[-Ntest:]: 542 | oldprefix = "train_" + imgRootName 543 | newprefix = "test_" + imgRootName 544 | TileGlob = os.path.join(SubDir, oldprefix + "_*") 545 | for TilePath in glob(TileGlob): 546 | os.rename(TilePath, TilePath.replace(oldprefix, newprefix)) 547 | NbTilesTest += 1 548 | NbTiles = len(os.listdir(SubDir)) 549 | NbTilesTrain = NbTiles - NbTilesTest - NbTilesValid 550 | pTrain = 100.0 * NbTilesTrain / NbTiles 551 | pValid = 100.0 * NbTilesValid / NbTiles 552 | pTest = 100.0 * NbTilesTest / NbTiles 553 | print("Done. %d tiles linked to %s " % (NbTiles, SubDir)) 554 | print("Train / Test / Validation sets for %s = %f %% / %f %% / %f %%" % (SubDir, pTrain, pTest, pValid)) 555 | ''' -------------------------------------------------------------------------------- /Tiling/BuildTileDictionary.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import os 4 | import pickle 5 | import argparse 6 | 7 | ''' 8 | Author: @edufierro 9 | 10 | Capstone project 11 | 12 | Purpose: Get dictionary with files:[2Darray tiles, type of cancer] 13 | ''' 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--data', type=str, default='Lung', help='Data to train on (Lung/Breast/Kidney)') 17 | parser.add_argument('--file_path', type=str, default='/beegfs/jmw784/Capstone/', help='Root path where the tiles are') 18 | parser.add_argument('--train_log', type=str, default='', help='path to the training output') 19 | 20 | opt = parser.parse_args() 21 | 22 | root_dir = opt.file_path + opt.data + "TilesSorted/" 23 | out_file = opt.file_path + opt.data + "_FileMappingDict.p" 24 | train_log = opt.train_log 25 | 26 | def find_classes(dir): 27 | # Classes are subdirectories of the root directory 28 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 29 | classes.sort() 30 | class_to_idx = {classes[i]: i for i in range(len(classes))} 31 | return classes, class_to_idx 32 | 33 | 34 | def get_class_coding(lf): 35 | auc_new = [] 36 | phrase = "Class encoding:" 37 | 38 | with open(lf, 'r+') as f: 39 | lines = f.readlines() 40 | for i in range(0, len(lines)): 41 | line = lines[i] 42 | #print(line) 43 | if phrase in line: 44 | class_encoding = lines[i + 1] # you may want to check that i < len(lines) 45 | break 46 | 47 | class_encoding = class_encoding.strip('\n').strip('{').strip('}') 48 | #print(class_encoding) 49 | 50 | class_names = [] 51 | class_codes = [] 52 | 53 | for c in class_encoding.split(','): 54 | #print(c) 55 | class_names.append(c.split(':')[0].replace("'", "").replace(" ", ""))#.split('-')[-1]) 56 | class_codes.append(int(c.split(':')[1])) 57 | 58 | 59 | class_coding = {} 60 | for i in range(len(class_names)): 61 | class_coding[class_codes[i]] = class_names[i] 62 | 63 | class_codes.sort() 64 | return class_names, class_codes, class_coding 65 | 66 | 67 | def getCoords(tile_list): 68 | 69 | ''' 70 | Given a list of tiles, with format: 71 | [test, valid, train]_NAME_x_y.jpeg 72 | Returns a two list of same size with xcoords and y coords 73 | ''' 74 | 75 | xcoords = [re.split("_", i)[-2] for i in tile_list] 76 | xcoords = list(map(int, xcoords)) 77 | ycoords = [re.split("_", i)[-1] for i in tile_list] 78 | ycoords = [re.sub(".jpeg", "", i) for i in ycoords] 79 | ycoords = list(map(int, ycoords)) 80 | 81 | return xcoords, ycoords 82 | 83 | def fileCleaner(tile_list): 84 | 85 | ''' 86 | Given a list of tiles, remove coords ("_X_Y_") and ".jpeg" termination 87 | ''' 88 | 89 | tile_list = [re.sub("_[0-9]*_[0-9]*.jpeg", "", x) for x in tile_list] 90 | 91 | return (tile_list) 92 | 93 | def get2Darray(xcoords, ycoords, tiles_input): 94 | 95 | ''' 96 | Given a list of xcoords, ycoords and files, returns a 2D array where each file 97 | correspond to the pair of coords 98 | ''' 99 | 100 | xmax = max(xcoords) + 1 101 | ymax = max(ycoords) + 1 102 | tiles_output = np.empty((ymax, xmax), dtype=np.dtype((str, 100))) 103 | for i in range(0,len(xcoords)): 104 | tiles_output[ycoords[i], xcoords[i]] = tiles_input[i] 105 | 106 | return tiles_output 107 | 108 | def fastdump(obj, file): 109 | p = pickle.Pickler(file) 110 | p.fast = True 111 | p.dump(obj) 112 | 113 | def main(): 114 | 115 | if os.path.exists(out_file): 116 | response = None 117 | 118 | while response not in ['y', 'n']: 119 | response = input('Tile dictionary already exists, do you want to continue (y/n)? ') 120 | 121 | if response == 'n': 122 | quit() 123 | 124 | #if classes provided for testing are less than for training 125 | if(train_log!=''): 126 | c_names, c_codes, c_coding = get_class_coding(train_log) 127 | c_coding_invert = {v: k for k, v in c_coding.items()} 128 | classes, _ = find_classes(root_dir) 129 | class_to_idx = {} 130 | for n in classes: 131 | class_to_idx[n] = c_coding_invert[n] 132 | else: 133 | classes, class_to_idx = find_classes(root_dir) 134 | 135 | 136 | print(class_to_idx) 137 | 138 | tile_files = {} 139 | original_files = {} 140 | main_dict = {} 141 | 142 | print("Importing File Names...") 143 | 144 | for c in classes: 145 | tile_files[c] = os.listdir(root_dir + c) 146 | original_files[c] = fileCleaner(tile_files[c]) 147 | 148 | print("Updating dict for %s files ..." % (c)) 149 | 150 | for file in set(original_files[c]): 151 | 152 | index_list = [i for i, x in enumerate(original_files[c]) if x==file] 153 | tiles = [tile_files[c][i] for i in index_list] 154 | xs, ys = getCoords(tiles) 155 | tiles_array = get2Darray(xs, ys, tiles) 156 | loop_dict = {file:[tiles_array, class_to_idx[c]]} 157 | main_dict.update(loop_dict) 158 | 159 | # Prevent running out of memory 160 | del tiles, xs, ys, tiles_array, loop_dict, tile_files[c], original_files[c] 161 | 162 | fastdump(main_dict, open(out_file, "wb" ) ) 163 | print("Dictionary Ready!!! Saved as pickle in: \n {0}".format(out_file)) 164 | 165 | return main_dict 166 | 167 | if __name__ == '__main__': 168 | main() 169 | -------------------------------------------------------------------------------- /checkpoints/breast_subtype.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/checkpoints/breast_subtype.pth -------------------------------------------------------------------------------- /checkpoints/kidney_subtype.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/checkpoints/kidney_subtype.pth -------------------------------------------------------------------------------- /checkpoints/lung_suptype.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/checkpoints/lung_suptype.pth -------------------------------------------------------------------------------- /run-jupyter.sbatch: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu8_medium 3 | #SBATCH --job-name=jupyter 4 | #SBATCH --gres=gpu:1 5 | ##SBATCH --output=outputs/rq_train1_%A_%a.out 6 | ##SBATCH --error=outputs/rq_train1_%A_%a.err 7 | #SBATCH --mem=200GB 8 | 9 | 10 | # get tunneling info 11 | XDG_RUNTIME_DIR="" 12 | port=$(shuf -i8000-9999 -n1) 13 | node=$(hostname -s) 14 | user=$(whoami) 15 | cluster='bigpurple' 16 | 17 | # print tunneling instructions jupyter-log 18 | echo -e " 19 | MacOS or linux terminal command to create your ssh tunnel: 20 | ssh -N -L ${port}:${node}:${port} ${user}@${cluster}.nyumc.org 21 | 22 | For more info and how to connect from windows, 23 | see research.computing.yale.edu/jupyter-nb 24 | Here is the MobaXterm info: 25 | 26 | Forwarded port:same as remote port 27 | Remote server: ${node} 28 | Remote port: ${port} 29 | SSH server: ${cluster}.nyumc.org 30 | SSH login: $user 31 | SSH port: 22 32 | https://research.computing.yale.edu/support/hpc/guides/running-jupyter-notebooks-clusters 33 | Use a Browser on your local machine to go to: 34 | localhost:${port} (prefix w/ https:// if using password) 35 | " 36 | 37 | #install 38 | 39 | # load modules or conda environments here 40 | # e.g. farnam: 41 | module purge 42 | module load anaconda3/gpu/5.2.0 43 | module load cuda91/toolkit/9.1.85 44 | module load miniconda3/4.5.1 45 | python3 -m pip install pybind11 --user 46 | #python3 -m pip install '/gpfs/share/apps/miniconda3/4.5.1/lib/python3.6/site-packages/fasttext_pybind.cpython-36m-x86_64-linux-gnu.so' --user 47 | module load python/gpu/3.6.5 48 | 49 | 50 | source activate myenv 51 | # DON'T USE ADDRESS BELOW. 52 | # DO USE TOKEN BELOW 53 | jupyter-notebook --no-browser --port=${port} --ip=${node} --NotebookApp.token='' 54 | 55 | -------------------------------------------------------------------------------- /run_job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu8_long 3 | #SBATCH --ntasks=8 4 | #SBATCH --cpus-per-task=1 5 | #SBATCH --job-name=train_PCNN 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=outputs/rq_train1_%A_%a.out 8 | #SBATCH --error=outputs/rq_train1_%A_%a.err 9 | #SBATCH --mem=200GB 10 | 11 | ##### above is for on nyulmc hpc: bigpurple ##### 12 | ##### below is for on nyu hpc: prince ##### 13 | ##!/bin/bash 14 | ## 15 | ##SBATCH --job-name=charrrr 16 | ##SBATCH --gres=gpu:1 17 | ##SBATCH --time=47:00:00 18 | ##SBATCH --mem=15GB 19 | ##SBATCH --output=outputs/%A.out 20 | ##SBATCH --error=outputs/%A.err 21 | #module purge 22 | #module load python3/intel/3.5.3 23 | #module load pytorch/python3.5/0.2.0_3 24 | #module load torchvision/python3.5/0.1.9 25 | #python3 -m pip install comet_ml —user 26 | 27 | echo "Starting at `date`" 28 | echo "Job name: $SLURM_JOB_NAME JobID: $SLURM_JOB_ID" 29 | echo "Running on hosts: $SLURM_NODELIST" 30 | echo "Running on $SLURM_NNODES nodes." 31 | echo "Running on $SLURM_NPROCS processors." 32 | 33 | 34 | module purge 35 | module load python/gpu/3.6.5 36 | 37 | exp_name="pancan_21c_tr" 38 | 39 | #nparam="--cuda --augment --dropout=0.1 --nonlinearity=leaky --init=xavier --calc_val_auc --root_dir=/gpfs/data/abl/deepomics/tsirigoslab/histopathology/Tiles/LungTilesSorted/ --num_class=3 --tile_dict_path=/gpfs/data/abl/deepomics/tsirigoslab/histopathology/Tiles/Lung_FileMappingDict.p" 40 | 41 | nparam="--cuda --augment --dropout=0.1 --nonlinearity=leaky --init=xavier --root_dir=/gpfs/scratch/bilals01/AllDs600/AllDs600TilesSorted/ --num_class=21 --tile_dict_path=/gpfs/scratch/bilals01/AllDs600/AllDs600_FileMappingDict.p" 42 | 43 | nexp="/gpfs/scratch/bilals01/test-repo/experiments/${exp_name}" 44 | 45 | output="/gpfs/scratch/bilals01/test-repo/logs/${exp_name}.log" 46 | 47 | python3 -u /gpfs/scratch/bilals01/test-repo/PathCNN/train.py $nparam --experiment $nexp > $output 48 | -------------------------------------------------------------------------------- /run_multiple_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Resource Request 3 | #SBATCH --partition=cpu_dev 4 | #SBATCH --job-name=gauto_conv 5 | #SBATCH --ntasks=1 6 | #SBATCH --cpus-per-task=1 7 | #SBATCH --mem=10G 8 | #SBATCH --time=3:00:00 9 | #SBATCH --output=outputs/cpu_train1_%A_%a.out 10 | #SBATCH --error=outputs/cpu_train1_%A_%a.err 11 | 12 | ##!/bin/bash 13 | ##SBATCH --partition=gpu4_short 14 | ##SBATCH --ntasks=8 15 | ##SBATCH --cpus-per-task=1 16 | ##SBATCH --job-name=multiple_PCNN 17 | ##SBATCH --gres=gpu:4 18 | ##SBATCH --output=outputs/rq_train1_%A_%a.out 19 | ##SBATCH --error=outputs/rq_train1_%A_%a.err 20 | ##SBATCH --mem=100GB 21 | 22 | ##### above is for on nyulmc hpc: bigpurple ##### 23 | ##### below is for on nyu hpc: prince ##### 24 | ##!/bin/bash 25 | ## 26 | ##SBATCH --job-name=charrrr 27 | ##SBATCH --gres=gpu:1 28 | ##SBATCH --time=47:00:00 29 | ##SBATCH --mem=15GB 30 | ##SBATCH --output=outputs/%A.out 31 | ##SBATCH --error=outputs/%A.err 32 | #module purge 33 | #module load python3/intel/3.5.3 34 | #module load pytorch/python3.5/0.2.0_3 35 | #module load torchvision/python3.5/0.1.9 36 | #python3 -m pip install comet_ml —user 37 | 38 | echo "Starting at `date`" 39 | echo "Job name: $SLURM_JOB_NAME JobID: $SLURM_JOB_ID" 40 | echo "Running on hosts: $SLURM_NODELIST" 41 | echo "Running on $SLURM_NNODES nodes." 42 | echo "Running on $SLURM_NPROCS processors." 43 | 44 | module purge 45 | module load python/gpu/3.6.5 46 | 47 | #input params 48 | exp_name="exp8" 49 | test_val="test3" 50 | 51 | nparam="--root_dir=/gpfs/data/abl/deepomics/tsirigoslab/histopathology/Tiles/LungTilesSorted/ --num_class=3 --tile_dict_path=/gpfs/data/abl/deepomics/tsirigoslab/histopathology/Tiles/Lung_FileMappingDict.p --val=${test_val}" 52 | 53 | nexp="/gpfs/scratch/bilals01/test-repo/experiments/${exp_name}" 54 | 55 | out="/gpfs/scratch/bilals01/test-repo/logs/${test_val}" 56 | 57 | if [ ! -d $out ]; then 58 | mkdir -p $out; 59 | fi 60 | 61 | # check if next checkpoint available 62 | declare -i count=1 63 | declare -i step=1 64 | 65 | while true; do 66 | echo $count 67 | PathToEpoch="${nexp}/checkpoints/" 68 | Cmodel="epoch_$count.pth" 69 | output="${out}/${test_val}_${exp_name}_${Cmodel}.log" 70 | echo $PathToEpoch 71 | echo $Cmodel 72 | echo $output 73 | if [ -f $PathToEpoch/$Cmodel ]; then 74 | python3 -u test.py --experiment $nexp --model $Cmodel $nparam > $output 75 | else 76 | break 77 | fi 78 | count=`expr "$count" + "$step"` 79 | done 80 | -------------------------------------------------------------------------------- /run_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu8_short 3 | #SBATCH --ntasks=8 4 | #SBATCH --cpus-per-task=1 5 | #SBATCH --job-name=test_PCNN 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=outputs/rq_train1_%A_%a.out 8 | #SBATCH --error=outputs/rq_train1_%A_%a.err 9 | #SBATCH --mem=50GB 10 | 11 | ##### above is for on nyulmc hpc: bigpurple ##### 12 | ##### below is for on nyu hpc: prince ##### 13 | ##!/bin/bash 14 | ## 15 | ##SBATCH --job-name=charrrr 16 | ##SBATCH --gres=gpu:1 17 | ##SBATCH --time=47:00:00 18 | ##SBATCH --mem=15GB 19 | ##SBATCH --output=outputs/%A.out 20 | ##SBATCH --error=outputs/%A.err 21 | #module purge 22 | #module load python3/intel/3.5.3 23 | #module load pytorch/python3.5/0.2.0_3 24 | #module load torchvision/python3.5/0.1.9 25 | #python3 -m pip install comet_ml —user 26 | 27 | echo "Starting at `date`" 28 | echo "Job name: $SLURM_JOB_NAME JobID: $SLURM_JOB_ID" 29 | echo "Running on hosts: $SLURM_NODELIST" 30 | echo "Running on $SLURM_NNODES nodes." 31 | echo "Running on $SLURM_NPROCS processors." 32 | 33 | 34 | module purge 35 | module load python/gpu/3.6.5 36 | 37 | #input params 38 | exp_name="exp7" 39 | model_cp="epoch_2.pth" 40 | test_val="test" 41 | 42 | nparam="--model=${model_cp} --root_dir=/gpfs/data/abl/deepomics/tsirigoslab/histopathology/Tiles/LungTilesSorted/ --num_class=3 --tile_dict_path=/gpfs/data/abl/deepomics/tsirigoslab/histopathology/Tiles/Lung_FileMappingDict.p --val=${test_val}" 43 | 44 | nexp="/gpfs/scratch/bilals01/test-repo/experiments/${exp_name}" 45 | 46 | output="/gpfs/scratch/bilals01/test-repo/logs/${exp_name}_${test_val}_${model_cp}.log" 47 | 48 | python3 -u /gpfs/scratch/bilals01/test-repo/PathCNN/test.py $nparam --experiment $nexp > $output 49 | -------------------------------------------------------------------------------- /run_tsne.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu8_medium 3 | #SBATCH --job-name=PathCNN_train 4 | #SBATCH --gres=gpu:4 5 | #SBATCH --output=outputs/rq_TSNE_%A_%a.out 6 | #SBATCH --error=outputs/rq_TSNE_%A_%a.err 7 | #SBATCH --mem=100GB 8 | 9 | ##### above is for on nyulmc hpc: bigpurple ##### 10 | ##### below is for on nyu hpc: prince ##### 11 | ##!/bin/bash 12 | ## 13 | ##SBATCH --job-name=charrrr 14 | ##SBATCH --gres=gpu:1 15 | ##SBATCH --time=47:00:00 16 | ##SBATCH --mem=15GB 17 | ##SBATCH --output=outputs/%A.out 18 | ##SBATCH --error=outputs/%A.err 19 | #module purge 20 | #module load python3/intel/3.5.3 21 | #module load pytorch/python3.5/0.2.0_3 22 | #module load torchvision/python3.5/0.1.9 23 | #python3 -m pip install comet_ml —user 24 | 25 | echo "Starting at `date`" 26 | echo "Job name: $SLURM_JOB_NAME JobID: $SLURM_JOB_ID" 27 | echo "Running on hosts: $SLURM_NODELIST" 28 | echo "Running on $SLURM_NNODES nodes." 29 | echo "Running on $SLURM_NPROCS processors." 30 | 31 | 32 | module purge 33 | module load python/gpu/3.6.5 34 | 35 | cd /gpfs/scratch/bilals01/test-repo/PathCNN/ 36 | 37 | python3 -u tsne.py $1 > logs/$2_tsne.log 38 | 39 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.parallel 5 | import torch.backends.cudnn as cudnn 6 | import torch.utils.data 7 | import torchvision.datasets as dset 8 | import torchvision.transforms as transforms 9 | import torchvision.utils as vutils 10 | import torch.nn.init as init 11 | from torch.autograd import Variable 12 | import argparse 13 | import numpy as np 14 | from PIL import Image 15 | from utils.dataloader import * 16 | #from utils.auc_test import * 17 | from utils.auc import * 18 | from utils import new_transforms 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--experiment', default='', help="name of experiment to test") 22 | parser.add_argument('--model', default='', help="name of model to test") 23 | parser.add_argument('--root_dir', type=str, default='TilesSorted/', help='Data directory .../dataTilesSorted/') 24 | parser.add_argument('--num_class', type=int, default=2, help='number of classes ') 25 | parser.add_argument('--tile_dict_path', type=str, default='"_FileMappingDict.p', help='Tile dictinory path') 26 | parser.add_argument('--val', type=str, default='test', help='validation set') 27 | parser.add_argument('--train_log', type=str, default='/gpfs/scratch/bilals01/test-repo/logs/exp6_train.log', help='point to the log file created from the training') 28 | 29 | opt = parser.parse_args() 30 | 31 | root_dir = str(opt.root_dir) 32 | num_classes = int(opt.num_class) 33 | tile_dict_path = str(opt.tile_dict_path) 34 | tl_file = str(opt.train_log) 35 | test_val = str(opt.val) 36 | 37 | imgSize = 299 38 | 39 | transform = transforms.Compose([new_transforms.Resize((imgSize,imgSize)), 40 | transforms.ToTensor(), 41 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 42 | 43 | test_data = TissueData(root_dir, test_val, train_log=tl_file, transform = transform, metadata=False) 44 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False, num_workers=8) 45 | 46 | classes = test_data.classes 47 | class_to_idx = test_data.class_to_idx 48 | 49 | print('Class encoding:') 50 | print(class_to_idx) 51 | 52 | def get_tile_probability(tile_path): 53 | 54 | """ 55 | Returns an array of probabilities for each class given a tile 56 | @param tile_path: Filepath to the tile 57 | @return: A ndarray of class probabilities for that tile 58 | """ 59 | 60 | # Some tiles are empty with no path, return nan 61 | if tile_path == '': 62 | return np.full(num_classes, np.nan) 63 | 64 | tile_path = root_dir + tile_path 65 | 66 | with open(tile_path, 'rb') as f: 67 | with Image.open(f) as img: 68 | img = img.convert('RGB') 69 | 70 | # Model expects a 4D tensor, unsqueeze first dimension 71 | img = transform(img).unsqueeze(0) 72 | img = img.cuda() 73 | 74 | # Turn output into probabilities with softmax 75 | var_img = Variable(img, volatile=True) 76 | output = F.softmax(model(var_img)).data.squeeze(0) 77 | 78 | return output.cpu().numpy() 79 | 80 | with open(tile_dict_path, 'rb') as f: 81 | tile_dict = pickle.load(f) 82 | 83 | def aggregate(file_list, method): 84 | 85 | """ 86 | Given a list of files, return scores for each class according to the 87 | method and labels for those files. 88 | @param file_list: A list of file paths to do predictions on 89 | @param method: 'average' - returns the average probability score across 90 | all tiles for that file 91 | 'max' - predicts each tile to be the class of the maximum 92 | score, and returns the proportion of tiles for 93 | each class 94 | @return: a ndarray of class probabilities for all files in the list 95 | a ndarray of the labels 96 | """ 97 | 98 | model.eval() 99 | predictions = [] 100 | true_labels = [] 101 | tile_count = [] 102 | 103 | for file in file_list: 104 | tile_paths, label = tile_dict[file] 105 | 106 | folder = classes[label] 107 | 108 | def add_folder(tile_path): 109 | if tile_path == '': 110 | return '' 111 | else: 112 | return folder + '/' + tile_path 113 | 114 | # Add the folder for the class name in front 115 | add_folder_v = np.vectorize(add_folder) 116 | tile_paths = add_folder_v(tile_paths) 117 | 118 | # Get the probability array for the file 119 | prob_v = np.vectorize(get_tile_probability, otypes=[np.ndarray]) 120 | probabilities = prob_v(tile_paths) 121 | 122 | 123 | """ 124 | imgSize = probabilities.shape() 125 | newShape = (imgSize[0], imgSize[1], 3) 126 | probabilities = np.reshape(np.stack(probabilities.flat), newShape) 127 | """ 128 | 129 | if method == 'average': 130 | probabilities = np.stack(probabilities.flat) 131 | prediction = np.nanmean(probabilities, axis = 0) 132 | 133 | elif method == 'max': 134 | probabilities = np.stack(probabilities.flat) 135 | probabilities = probabilities[~np.isnan(probabilities).all(axis=1)] 136 | votes = np.nanargmax(probabilities, axis=1) 137 | 138 | out = np.array([sum(votes == i) for i in range(num_classes)]) 139 | prediction = out / out.sum() 140 | 141 | else: 142 | raise ValueError('Method not valid') 143 | 144 | predictions.append(prediction) 145 | true_labels.append(label) 146 | tile_count.append(len(tile_paths)) 147 | 148 | return np.array(predictions), np.array(true_labels), np.array(tile_count) 149 | 150 | class BasicConv2d(nn.Module): 151 | 152 | def __init__(self, in_channels, out_channels, pool, **kwargs): 153 | super(BasicConv2d, self).__init__() 154 | 155 | self.pool = pool 156 | self.conv = nn.Conv2d(in_channels, out_channels, **kwargs) 157 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 158 | self.relu = nn.LeakyReLU() 159 | 160 | self.dropout = nn.Dropout(p=0.1) 161 | 162 | def forward(self, x): 163 | x = self.conv(x) 164 | 165 | if self.pool: 166 | x = F.max_pool2d(x, 2) 167 | 168 | x = self.relu(x) 169 | x = self.bn(x) 170 | x = self.dropout(x) 171 | return x 172 | 173 | # Define model 174 | class cancer_CNN(nn.Module): 175 | def __init__(self, nc, imgSize, ngpu): 176 | super(cancer_CNN, self).__init__() 177 | self.nc = nc 178 | self.imgSize = imgSize 179 | self.ngpu = ngpu 180 | #self.data = opt.data 181 | self.conv1 = BasicConv2d(nc, 16, False, kernel_size=5, padding=1, stride=2, bias=True) 182 | self.conv2 = BasicConv2d(16, 32, False, kernel_size=3, bias=True) 183 | self.conv3 = BasicConv2d(32, 64, True, kernel_size=3, padding=1, bias=True) 184 | self.conv4 = BasicConv2d(64, 64, True, kernel_size=3, padding=1, bias=True) 185 | self.conv5 = BasicConv2d(64, 128, True, kernel_size=3, padding=1, bias=True) 186 | self.conv6 = BasicConv2d(128, 64, True, kernel_size=3, padding=1, bias=True) 187 | self.linear = nn.Linear(5184, num_classes) 188 | 189 | def forward(self, x): 190 | x = self.conv1(x) 191 | x = self.conv2(x) 192 | x = self.conv3(x) 193 | x = self.conv4(x) 194 | x = self.conv5(x) 195 | x = self.conv6(x) 196 | x = x.view(x.size(0), -1) 197 | x = self.linear(x) 198 | return x 199 | 200 | model = cancer_CNN(3, imgSize, 1) 201 | model.cuda() 202 | 203 | model_path = opt.experiment + '/checkpoints/' + opt.model 204 | state_dict = torch.load(model_path) 205 | model.load_state_dict(state_dict) 206 | 207 | predictions, labels, num_of_tiles = aggregate(test_data.filenames, method='max') 208 | 209 | data = np.column_stack((test_data.filenames,np.asarray(predictions),np.asarray(labels),np.asarray(num_of_tiles))) 210 | 211 | 212 | data.dump(open('{0}/outputs/{1}_pred_label_max_{2}.npy'.format(opt.experiment,opt.val,opt.model), 'wb')) 213 | 214 | #This can be used if need to print the auc and save the roc curve automatically 215 | 216 | roc_auc = get_auc('{0}/images/{1}_AUC_max_{2}.jpg'.format(opt.experiment,opt.val,opt.model), 217 | predictions, labels, classes = range(num_classes)) 218 | print('Max method:') 219 | print(roc_auc) 220 | 221 | predictions, labels, num_of_tiles = aggregate(test_data.filenames, method='average') 222 | data = np.column_stack((test_data.filenames,np.asarray(predictions),np.asarray(labels),np.asarray(num_of_tiles))) 223 | data.dump(open('{0}/outputs/{1}_pred_label_avg_{2}.npy'.format(opt.experiment,opt.val,opt.model), 'wb')) 224 | 225 | #This can be used if need to print the auc and save the roc curve automatically 226 | roc_auc = get_auc('{0}/images/{1}_AUC_avg_{2}.jpg'.format(opt.experiment,opt.val,opt.model), 227 | predictions, labels, classes = range(num_classes)) 228 | print('Average method:') 229 | print(roc_auc) 230 | -------------------------------------------------------------------------------- /test_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 17, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.nn.functional as F\n", 12 | "import torch.nn.parallel\n", 13 | "import torch.backends.cudnn as cudnn\n", 14 | "import torch.optim as optim\n", 15 | "import torch.utils.data\n", 16 | "import torchvision.datasets as dset\n", 17 | "import torchvision.transforms as transforms\n", 18 | "import torchvision.utils as vutils\n", 19 | "import torch.nn.init as init\n", 20 | "from torch.autograd import Variable\n", 21 | "\n", 22 | "import os\n", 23 | "import time\n", 24 | "import numpy as np\n", 25 | "from PIL import Image\n", 26 | "from utils.dataloader import *\n", 27 | "#use AUC for AUC and CI, auc2 for precision, AUC and CI, auc3 precision auc and CI\n", 28 | "from utils.auc import *\n", 29 | "from utils import new_transforms\n", 30 | "import argparse\n", 31 | "import random" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "ngpu = 1\n", 41 | "nc = 3\n", 42 | "imgSize = 299\n", 43 | "\n", 44 | "step_freq = 20000000\n", 45 | "\n", 46 | "\n", 47 | "root_dir = '/gpfs/data/abl/deepomics/tsirigoslab/histopathology/Tiles/LngTilesSorted/'\n", 48 | "num_classes = 3\n", 49 | "tile_dict_path = '/gpfs/data/abl/deepomics/tsirigoslab/histopathology/Tiles/Lng_FileMappingDict.p'\n" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "manualSeed = random.randint(1, 10000) # fix seed\n", 59 | "\n", 60 | "random.seed(manualSeed)\n", 61 | "torch.manual_seed(manualSeed)\n", 62 | "\n", 63 | "cudnn.benchmark = True" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "Loading from: TCGA-LUSC\n", 76 | "number of samples: 52735\n", 77 | "Class encoding:\n", 78 | "{'TCGA-LUSC': 2}\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "# Random data augmentation\n", 84 | "augment = transforms.Compose([new_transforms.Resize((imgSize, imgSize)),\n", 85 | " transforms.RandomHorizontalFlip(),\n", 86 | " new_transforms.RandomRotate(),\n", 87 | " new_transforms.ColorJitter(0.25, 0.25, 0.25, 0.05),\n", 88 | " transforms.ToTensor(),\n", 89 | " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n", 90 | "\n", 91 | "transform = transforms.Compose([new_transforms.Resize((imgSize,imgSize)),\n", 92 | " transforms.ToTensor(),\n", 93 | " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n", 94 | "\n", 95 | "data = {}\n", 96 | "loaders = {}\n", 97 | "\n", 98 | "dset_type = 'test'\n", 99 | "test_data = TissueData(root_dir, dset_type, train_log='/gpfs/scratch/bilals01/test-repo/logs/exp6_train.log', transform = transform, metadata=False)\n", 100 | "\n", 101 | "test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False, num_workers=8)\n", 102 | "\n", 103 | "classes = test_data.classes\n", 104 | "class_to_idx = test_data.class_to_idx\n", 105 | "\n", 106 | "print('Class encoding:')\n", 107 | "print(class_to_idx)\n", 108 | "\n" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 5, 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "data": { 118 | "text/plain": [ 119 | "{2: 'TCGA-LUSC'}" 120 | ] 121 | }, 122 | "execution_count": 5, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | } 126 | ], 127 | "source": [ 128 | "class_to_idx_invert = {v: k for k, v in class_to_idx.items()}\n", 129 | "class_to_idx_invert" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 6, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "def get_tile_probability(tile_path):\n", 139 | "\n", 140 | " \"\"\"\n", 141 | " Returns an array of probabilities for each class given a tile\n", 142 | " @param tile_path: Filepath to the tile\n", 143 | " @return: A ndarray of class probabilities for that tile\n", 144 | " \"\"\"\n", 145 | "\n", 146 | " # Some tiles are empty with no path, return nan\n", 147 | " if tile_path == '':\n", 148 | " return np.full(num_classes, np.nan)\n", 149 | "\n", 150 | " tile_path = root_dir + tile_path\n", 151 | "\n", 152 | " with open(tile_path, 'rb') as f:\n", 153 | " with Image.open(f) as img:\n", 154 | " img = img.convert('RGB')\n", 155 | "\n", 156 | " # Model expects a 4D tensor, unsqueeze first dimension\n", 157 | " img = transform(img).unsqueeze(0)\n", 158 | " img = img.cuda()\n", 159 | "\n", 160 | " # Turn output into probabilities with softmax\n", 161 | " var_img = Variable(img, volatile=True)\n", 162 | " output = F.softmax(model(var_img)).data.squeeze(0)\n", 163 | "\n", 164 | " return output.cpu().numpy()\n", 165 | "\n", 166 | "with open(tile_dict_path, 'rb') as f:\n", 167 | " tile_dict = pickle.load(f)\n", 168 | "\n", 169 | " \n", 170 | "def aggregate(file_list, method):\n", 171 | "\n", 172 | " \"\"\"\n", 173 | " Given a list of files, return scores for each class according to the\n", 174 | " method and labels for those files.\n", 175 | " @param file_list: A list of file paths to do predictions on\n", 176 | " @param method: 'average' - returns the average probability score across\n", 177 | " all tiles for that file\n", 178 | " 'max' - predicts each tile to be the class of the maximum\n", 179 | " score, and returns the proportion of tiles for\n", 180 | " each class\n", 181 | " @return: a ndarray of class probabilities for all files in the list\n", 182 | " a ndarray of the labels\n", 183 | " \"\"\"\n", 184 | "\n", 185 | " model.eval()\n", 186 | " predictions = []\n", 187 | " true_labels = []\n", 188 | "\n", 189 | " for file in file_list:\n", 190 | " tile_paths, label = tile_dict[file]\n", 191 | " folder = class_to_idx_invert[label]\n", 192 | "\n", 193 | " def add_folder(tile_path):\n", 194 | " if tile_path == '':\n", 195 | " return ''\n", 196 | " else:\n", 197 | " return folder + '/' + tile_path\n", 198 | "\n", 199 | " # Add the folder for the class name in front\n", 200 | " add_folder_v = np.vectorize(add_folder)\n", 201 | " tile_paths = add_folder_v(tile_paths)\n", 202 | "\n", 203 | " # Get the probability array for the file\n", 204 | " prob_v = np.vectorize(get_tile_probability, otypes=[np.ndarray])\n", 205 | " probabilities = prob_v(tile_paths)\n", 206 | "\n", 207 | " \"\"\"\n", 208 | " imgSize = probabilities.shape()\n", 209 | " newShape = (imgSize[0], imgSize[1], 3)\n", 210 | " probabilities = np.reshape(np.stack(probabilities.flat), newShape)\n", 211 | " \"\"\"\n", 212 | " \n", 213 | " if method == 'average':\n", 214 | " probabilities = np.stack(probabilities.flat)\n", 215 | " prediction = np.nanmean(probabilities, axis = 0)\n", 216 | "\n", 217 | " elif method == 'max':\n", 218 | " probabilities = np.stack(probabilities.flat)\n", 219 | " probabilities = probabilities[~np.isnan(probabilities).all(axis=1)]\n", 220 | " votes = np.nanargmax(probabilities, axis=1) \n", 221 | " out = np.array([sum(votes == i) for i in range(num_classes)])\n", 222 | " prediction = out / out.sum()\n", 223 | "\n", 224 | " else:\n", 225 | " raise ValueError('Method not valid')\n", 226 | "\n", 227 | " predictions.append(prediction)\n", 228 | " true_labels.append(label)\n", 229 | "\n", 230 | " return np.array(predictions), np.array(true_labels)\n", 231 | "\n", 232 | "\n", 233 | "\n", 234 | "class BasicConv2d(nn.Module):\n", 235 | "\n", 236 | " def __init__(self, in_channels, out_channels, pool, **kwargs):\n", 237 | " super(BasicConv2d, self).__init__()\n", 238 | "\n", 239 | " self.pool = pool\n", 240 | " self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)\n", 241 | " self.bn = nn.BatchNorm2d(out_channels, eps=0.001)\n", 242 | " self.relu = nn.LeakyReLU()\n", 243 | " \n", 244 | " self.dropout = nn.Dropout(p=0.1)\n", 245 | "\n", 246 | " def forward(self, x):\n", 247 | " x = self.conv(x)\n", 248 | "\n", 249 | " if self.pool:\n", 250 | " x = F.max_pool2d(x, 2)\n", 251 | " \n", 252 | " x = self.relu(x)\n", 253 | " x = self.bn(x)\n", 254 | " x = self.dropout(x)\n", 255 | " return x\n", 256 | "\n", 257 | "# Define model\n", 258 | "class cancer_CNN(nn.Module):\n", 259 | " def __init__(self, nc, imgSize, ngpu):\n", 260 | " super(cancer_CNN, self).__init__()\n", 261 | " self.nc = nc\n", 262 | " self.imgSize = imgSize\n", 263 | " self.ngpu = ngpu\n", 264 | " #self.data = opt.data\n", 265 | " self.conv1 = BasicConv2d(nc, 16, False, kernel_size=5, padding=1, stride=2, bias=True)\n", 266 | " self.conv2 = BasicConv2d(16, 32, False, kernel_size=3, bias=True)\n", 267 | " self.conv3 = BasicConv2d(32, 64, True, kernel_size=3, padding=1, bias=True)\n", 268 | " self.conv4 = BasicConv2d(64, 64, True, kernel_size=3, padding=1, bias=True)\n", 269 | " self.conv5 = BasicConv2d(64, 128, True, kernel_size=3, padding=1, bias=True)\n", 270 | " self.conv6 = BasicConv2d(128, 64, True, kernel_size=3, padding=1, bias=True)\n", 271 | " self.linear = nn.Linear(5184, num_classes)\n", 272 | "\n", 273 | " def forward(self, x):\n", 274 | " x = self.conv1(x)\n", 275 | " x = self.conv2(x)\n", 276 | " x = self.conv3(x)\n", 277 | " x = self.conv4(x)\n", 278 | " x = self.conv5(x)\n", 279 | " x = self.conv6(x)\n", 280 | " x = x.view(x.size(0), -1)\n", 281 | " x = self.linear(x)\n", 282 | " return x\n", 283 | "\n" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 7, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "name": "stderr", 293 | "output_type": "stream", 294 | "text": [ 295 | "/gpfs/share/apps/python/gpu/3.6.5/lib/python3.6/site-packages/ipykernel_launcher.py:24: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n", 296 | "/gpfs/share/apps/python/gpu/3.6.5/lib/python3.6/site-packages/ipykernel_launcher.py:25: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n" 297 | ] 298 | } 299 | ], 300 | "source": [ 301 | "model = cancer_CNN(3, imgSize, 1)\n", 302 | "model.cuda()\n", 303 | "\n", 304 | "model_path = '/gpfs/scratch/bilals01/test-repo/experiments/exp2/checkpoints/epoch_18.pth'\n", 305 | "state_dict = torch.load(model_path)\n", 306 | "model.load_state_dict(state_dict)\n", 307 | "\n", 308 | "\n", 309 | "predictions, labels = aggregate(test_data.filenames, method='average')\n", 310 | "data = np.column_stack((test_data.filenames,np.asarray(predictions),np.asarray(labels)))\n", 311 | "\n" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 8, 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "data": { 321 | "text/plain": [ 322 | "array([['test_TCGA-66-2782-01A-01-TS1.87ca26b3-ff31-414d-afa8-4d992870128b',\n", 323 | " '0.06785143173991257', '0.0012589603789908174',\n", 324 | " '0.9308896082118011', '2'],\n", 325 | " ['test_TCGA-39-5036-01A-01-TS1.e9596e31-d551-4130-971a-feaaf8b188ad',\n", 326 | " '0.0987874454695041', '0.013180692141556484',\n", 327 | " '0.8880318617412036', '2'],\n", 328 | " ['test_TCGA-63-A5MN-01A-02-TS2.F4F8AAF4-85AC-438D-969D-0AAECAD81F8E',\n", 329 | " '0.039594568926042406', '0.016352557970564698',\n", 330 | " '0.9440528733837735', '2'],\n", 331 | " ['test_TCGA-60-2721-01A-01-BS1.e9a37468-bda9-4485-8545-fdb49a85fe6a',\n", 332 | " '0.050556040752930985', '0.00016863039277334834',\n", 333 | " '0.9492753281188873', '2'],\n", 334 | " ['test_TCGA-85-8666-01A-01-TS1.8bede180-da7c-46e5-b2ab-fd9c0011e66d',\n", 335 | " '0.017455052978947585', '0.06121043121718893',\n", 336 | " '0.9213345140362015', '2'],\n", 337 | " ['test_TCGA-22-1017-01A-01-TS1.9e5d298a-095d-4784-a73f-a92c0be4fe3a',\n", 338 | " '0.08789440218672584', '0.01538824120327584',\n", 339 | " '0.8967173568483502', '2'],\n", 340 | " ['test_TCGA-66-2794-01A-01-TS1.017a3383-b912-4701-9f9c-dbed66d18deb',\n", 341 | " '0.039267165366757964', '0.001547481643985082',\n", 342 | " '0.9591853538939742', '2'],\n", 343 | " ['test_TCGA-43-6143-01A-01-BS1.022a32c1-0fb0-49fa-ae10-4f9cef617bc5',\n", 344 | " '0.011251716462406547', '0.007440678781669186',\n", 345 | " '0.9813076018085403', '2'],\n", 346 | " ['test_TCGA-77-8144-01A-01-BS1.f4edea2d-41c3-45a6-8469-c72e2e7d29dd',\n", 347 | " '0.22300066349232092', '0.0011400880715788997',\n", 348 | " '0.7758592448519388', '2'],\n", 349 | " ['test_TCGA-85-8287-01A-01-TS1.c2eaa43c-44f1-4dae-93e7-eb691541c880',\n", 350 | " '0.013051135039283212', '0.06840057620036889',\n", 351 | " '0.918548287491927', '2'],\n", 352 | " ['test_TCGA-22-1016-01A-01-BS1.36a39a3b-c508-4ad6-9673-161e50e43ebb',\n", 353 | " '0.0789793899845651', '0.0008205008789407887',\n", 354 | " '0.920200111354075', '2'],\n", 355 | " ['test_TCGA-22-4613-01A-01-TS1.8493c28d-f0fe-4431-83ab-d6e648ac1951',\n", 356 | " '0.34124168578821673', '0.00017509260920665762',\n", 357 | " '0.6585832253036733', '2'],\n", 358 | " ['test_TCGA-39-5022-01A-02-BS2.4c32c9bd-80dd-499d-87d2-1516981ea8d4',\n", 359 | " '0.20107539880554007', '0.01408955122044748',\n", 360 | " '0.7848350501399102', '2'],\n", 361 | " ['test_TCGA-85-8582-01A-01-BS1.76e6e03d-39f2-4675-be9d-c566b266aee4',\n", 362 | " '0.1611811345130632', '0.003324958760094061',\n", 363 | " '0.8354939057050405', '2'],\n", 364 | " ['test_TCGA-43-A474-01A-01-TS1.5F9E04AE-51EE-49B5-A3A9-C5DA0391FBC1',\n", 365 | " '0.015525510996921205', '0.02393928667296021',\n", 366 | " '0.9605352024291137', '2'],\n", 367 | " ['test_TCGA-51-4080-01A-01-TS1.a68f7caa-d543-4f6d-a20f-56f845ba867a',\n", 368 | " '0.0019147164576795584', '0.004419346743548181',\n", 369 | " '0.9936659374700384', '2'],\n", 370 | " ['test_TCGA-96-A4JL-01A-01-TSA.3FA7182D-09EC-469A-92E2-EFFDC286913C',\n", 371 | " '0.048104949034871655', '0.029732101060426522',\n", 372 | " '0.9221629494273611', '2'],\n", 373 | " ['test_TCGA-18-3421-01A-01-TS1.de5fc45f-2235-460b-8c7f-d5dd9accb257',\n", 374 | " '0.1314797742172096', '0.007505402157978912',\n", 375 | " '0.8610148248401341', '2'],\n", 376 | " ['test_TCGA-39-5039-01A-01-TS1.94248e36-75ef-4af2-9648-b0e1afaffd02',\n", 377 | " '0.03916454137748189', '0.014407189508222312',\n", 378 | " '0.9464282687779058', '2'],\n", 379 | " ['test_TCGA-22-1005-01A-01-TS1.369217ed-cdd2-4fec-91a5-2db535ac3053',\n", 380 | " '0.04907398624677903', '0.0011902702398143355',\n", 381 | " '0.9497357412319576', '2'],\n", 382 | " ['test_TCGA-85-8351-01A-01-TS1.0f3d1dce-e7d9-467c-89a6-74bc378e7af9',\n", 383 | " '0.011946772245064852', '0.033736315257663414',\n", 384 | " '0.954316911342599', '2'],\n", 385 | " ['test_TCGA-66-2794-01A-01-BS1.ea1eedd8-6763-46db-abe3-4be282c19c6b',\n", 386 | " '0.026690630400231754', '0.0006511605090896269',\n", 387 | " '0.9726582075493921', '2'],\n", 388 | " ['test_TCGA-60-2704-01A-01-BS1.a014ebb5-3532-4c7a-8730-d3020a1f6a93',\n", 389 | " '0.15983368897236533', '0.0013236495472905227',\n", 390 | " '0.838842661881372', '2'],\n", 391 | " ['test_TCGA-22-5477-01A-01-TS1.25b8da50-db12-479d-98ff-da2eaa25cb69',\n", 392 | " '0.15811849092075075', '0.00037261433044065114',\n", 393 | " '0.8415088944803056', '2'],\n", 394 | " ['test_TCGA-66-2782-01A-01-BS1.8b1e40cf-3f84-4aa2-9f92-70f79807cc4c',\n", 395 | " '0.08666484667756971', '0.0010297938361041312',\n", 396 | " '0.9123053591514299', '2'],\n", 397 | " ['test_TCGA-66-2767-01A-01-TS1.6c88cd29-d984-437f-8559-6f73f0cbdc5d',\n", 398 | " '0.1687560066424062', '0.0014976952171607428',\n", 399 | " '0.8297462998840606', '2'],\n", 400 | " ['test_TCGA-77-7338-01A-01-BS1.654045da-6384-4961-ac15-6afb85bf073b',\n", 401 | " '0.06629097077284958', '0.036688846676191315',\n", 402 | " '0.8970201821488812', '2'],\n", 403 | " ['test_TCGA-46-3767-01A-01-BS1.dae112ca-4c8b-4338-8656-b3767f92cde2',\n", 404 | " '0.20322778292937624', '0.0012063394599609774',\n", 405 | " '0.7955658746011776', '2'],\n", 406 | " ['test_TCGA-22-1000-01A-01-BS1.3d74dcf8-2061-49df-b93b-9eb47028bb0d',\n", 407 | " '0.0756149681493031', '0.006650733924009752',\n", 408 | " '0.917734298563272', '2'],\n", 409 | " ['test_TCGA-39-5039-01A-01-BS1.ca14c761-f834-4288-af52-1f32faace0e2',\n", 410 | " '0.06065365811263407', '0.02162540883993418',\n", 411 | " '0.9177209314173729', '2'],\n", 412 | " ['test_TCGA-21-5787-01A-01-BS1.581d1565-9067-4664-ae46-c94dec16f6b3',\n", 413 | " '0.1402148818107994', '0.09162868698574722',\n", 414 | " '0.7681564303976683', '2'],\n", 415 | " ['test_TCGA-22-1011-01A-01-TS1.46fd126b-256f-4a9c-b6de-6ed32a28f6b9',\n", 416 | " '0.08801023164249541', '0.00017024171283984204',\n", 417 | " '0.9118195285959155', '2'],\n", 418 | " ['test_TCGA-90-A59Q-01A-01-TS1.05F08095-9185-4688-814F-DDD1C77F2A39',\n", 419 | " '0.009803155186258562', '0.008822116499347095',\n", 420 | " '0.9813747264029287', '2'],\n", 421 | " ['test_TCGA-39-5022-01A-01-BS1.c573a47e-1680-4bdb-a63a-3891a33c065d',\n", 422 | " '0.44724541076784463', '0.0013958721362135717',\n", 423 | " '0.5513587147292037', '2'],\n", 424 | " ['test_TCGA-21-5787-01A-01-TS1.2920e24c-a22a-4a24-af55-a05f046093f2',\n", 425 | " '0.08528489021408082', '0.06412599034674987',\n", 426 | " '0.8505891207661889', '2'],\n", 427 | " ['test_TCGA-85-7843-01A-01-BS1.8741366f-04cb-45ed-8275-65e1ddc58e0c',\n", 428 | " '0.05083632798381916', '0.004033475157207214',\n", 429 | " '0.9451301959449087', '2'],\n", 430 | " ['test_TCGA-66-2791-01A-01-TS1.a0fd779c-2778-4279-a740-af4035538b06',\n", 431 | " '0.071255563116017', '0.00419589217070724', '0.924548542634584',\n", 432 | " '2'],\n", 433 | " ['test_TCGA-77-A5GA-01A-01-TS1.72DDD4A4-6360-4780-824C-D5F6B197F00A',\n", 434 | " '0.1642027709238759', '0.006078153244353718',\n", 435 | " '0.829719072843882', '2'],\n", 436 | " ['test_TCGA-85-8287-01A-01-BS1.730a979f-7e48-4af4-8191-aa49a13054b7',\n", 437 | " '0.036098796412211365', '0.008026719875630863',\n", 438 | " '0.9558744835755177', '2'],\n", 439 | " ['test_TCGA-22-0940-01A-01-BS1.de73c3ec-bb4d-4a95-b359-2095bacdc28d',\n", 440 | " '0.06906686696338649', '0.03702353189947263',\n", 441 | " '0.893909599026104', '2'],\n", 442 | " ['test_TCGA-60-2695-01A-01-TS1.9a4d6631-8700-4331-8922-fc3fa59ef350',\n", 443 | " '0.011173258044713351', '0.004072309631704953',\n", 444 | " '0.984754433508577', '2'],\n", 445 | " ['test_TCGA-39-5036-01A-01-BS1.22442ecc-912e-4a4b-99f0-717b905b74eb',\n", 446 | " '0.12457853354812591', '0.014292575037000536',\n", 447 | " '0.8611288900892876', '2'],\n", 448 | " ['test_TCGA-46-3767-01A-01-TS1.517904ef-7496-4c8e-a5f6-d5608ba98661',\n", 449 | " '0.1761964823564094', '0.003992300353670075',\n", 450 | " '0.8198112145065476', '2'],\n", 451 | " ['test_TCGA-43-6143-01A-01-TS1.6716067a-f179-46cf-8d38-7f7cd8b25d5d',\n", 452 | " '0.4989744200854699', '0.08071290596770231',\n", 453 | " '0.4203126781526079', '2'],\n", 454 | " ['test_TCGA-66-2778-01A-01-BS1.5d72d9f0-94a7-44c8-a1c6-a09e7ee85b33',\n", 455 | " '0.05048965897103244', '7.839818245816096e-05',\n", 456 | " '0.9494319440907087', '2'],\n", 457 | " ['test_TCGA-22-0940-01A-01-TS1.e6789daf-32c1-4937-a0fc-b307fc728286',\n", 458 | " '0.1934131758644562', '0.001168854644899125',\n", 459 | " '0.8054179686600417', '2'],\n", 460 | " ['test_TCGA-37-5819-01A-01-TS1.598457a2-deb3-4018-ac05-06bfada57374',\n", 461 | " '0.01609545085549734', '0.013873184209280529',\n", 462 | " '0.9700313631558057', '2'],\n", 463 | " ['test_TCGA-21-1071-01A-01-BS1.2a817266-2808-4557-8f7f-0d2c79ffed0d',\n", 464 | " '0.03059375747363793', '0.0008303052487528271',\n", 465 | " '0.9685759381740141', '2'],\n", 466 | " ['test_TCGA-63-7023-01A-01-TS1.cccb8282-147e-4148-b525-6e5dd656b496',\n", 467 | " '0.0202931515566865', '5.885726669195446e-05',\n", 468 | " '0.9796479959641734', '2'],\n", 469 | " ['test_TCGA-33-4586-01A-01-TS1.ed870e19-e471-4be1-8f77-865bc05bb401',\n", 470 | " '0.7960695681382426', '0.002219800023362823',\n", 471 | " '0.2017106336002353', '2'],\n", 472 | " ['test_TCGA-56-8503-01A-01-TS1.6c8bd01a-3b8b-4f09-b98b-70878cab424a',\n", 473 | " '0.0513810579376789', '0.01939243761026569',\n", 474 | " '0.9292265058156087', '2'],\n", 475 | " ['test_TCGA-43-6773-01A-03-BS3.a242cb6f-4b19-454e-a847-90aba0651399',\n", 476 | " '0.4175630972140549', '0.0030636042137404196',\n", 477 | " '0.5793732951365786', '2'],\n", 478 | " ['test_TCGA-85-7843-01A-01-TS1.67c89376-e1ec-4ebd-b96b-68994b952462',\n", 479 | " '0.06919310717132526', '0.005429929135324942',\n", 480 | " '0.9253769629132574', '2'],\n", 481 | " ['test_TCGA-22-1011-01A-01-BS1.e5637aa9-e3af-4541-be66-e24d508a7516',\n", 482 | " '0.18068394799218104', '0.0011894545652740738',\n", 483 | " '0.8181265968398301', '2'],\n", 484 | " ['test_TCGA-22-5491-01A-01-TS1.46659c64-9b3b-4fbb-a6a1-0a894c4730c9',\n", 485 | " '0.12335239062302739', '0.00032279235367955344',\n", 486 | " '0.8763248173626653', '2'],\n", 487 | " ['test_TCGA-51-4080-01A-01-BS1.2a5e8548-9612-4d5b-b309-01388d06f2fa',\n", 488 | " '0.0010138447859034411', '0.0009293212255798874',\n", 489 | " '0.9980568318481905', '2'],\n", 490 | " ['test_TCGA-60-2708-01A-01-TS1.8a5e9d73-ff97-4ee1-8a39-af900beabb93',\n", 491 | " '0.03373031981415466', '0.0003957037720688953',\n", 492 | " '0.9658739695494826', '2'],\n", 493 | " ['test_TCGA-66-2781-01A-01-BS1.dc6170a3-8c58-4164-8ab5-b88c13c91fa1',\n", 494 | " '0.05215480596168599', '0.0029859743476270523',\n", 495 | " '0.9448592196865618', '2'],\n", 496 | " ['test_TCGA-22-1017-01A-01-BS1.11cafdf3-8c27-4dbb-a970-8fe234201ce4',\n", 497 | " '0.1292003247182724', '0.041432410454808805',\n", 498 | " '0.8293672679645712', '2'],\n", 499 | " ['test_TCGA-NC-A5HR-01A-02-TS2.1B2A21A9-E685-461D-A3FF-42A0D9D7FC23',\n", 500 | " '0.05924019818797045', '0.019272554123088755',\n", 501 | " '0.9214872452956736', '2'],\n", 502 | " ['test_TCGA-22-4605-01A-02-BS2.5396cc50-4225-45f3-b0da-a1bd2cb4033d',\n", 503 | " '0.1438984136616629', '0.0012328303508334715',\n", 504 | " '0.8548687569628111', '2'],\n", 505 | " ['test_TCGA-60-2695-01A-01-BS1.815a4b84-3a35-464f-82f0-2c61a4ead7c9',\n", 506 | " '0.13375189962221565', '0.0015723572731511266',\n", 507 | " '0.8646757432344141', '2'],\n", 508 | " ['test_TCGA-66-2767-01A-01-BS1.c35646c2-5db9-431d-a855-8bdfa7aa61c4',\n", 509 | " '0.0762949476792192', '0.0025797856949621046',\n", 510 | " '0.921125268629646', '2'],\n", 511 | " ['test_TCGA-22-4605-01A-01-BS1.863a12bd-4b89-4ed7-8c70-93dec27a21a8',\n", 512 | " '0.06124190413295925', '0.0009174076357925693',\n", 513 | " '0.9378406876026371', '2'],\n", 514 | " ['test_TCGA-18-3421-01A-01-BS1.fbc15d3d-6c4e-4fb5-8ca2-3510a827f101',\n", 515 | " '0.11214197697383887', '0.024350153468995807',\n", 516 | " '0.8635078700190257', '2'],\n", 517 | " ['test_TCGA-77-A5G7-01B-01-TS1.60CBA44D-36B7-4476-A509-1338AA1141C2',\n", 518 | " '0.19906802679944458', '0.028465051121148593',\n", 519 | " '0.772466923831962', '2'],\n", 520 | " ['test_TCGA-NC-A5HH-01A-01-TS1.DD0360BB-775E-46B2-9BF7-F86F2A181CA6',\n", 521 | " '0.22974940641883548', '3.878632162203333e-05',\n", 522 | " '0.7702118060145623', '2'],\n", 523 | " ['test_TCGA-33-4586-01A-01-BS1.1616d67c-e61c-49b7-ad7f-a36bcbc0fa6d',\n", 524 | " '0.0011156422422558723', '0.0005059432519693966',\n", 525 | " '0.9983784172095751', '2'],\n", 526 | " ['test_TCGA-66-2781-01A-01-TS1.db5a4cb1-7755-4119-b071-7b05d2e05fd8',\n", 527 | " '0.09224178470549171', '0.0005733496791384614',\n", 528 | " '0.9071848627584832', '2'],\n", 529 | " ['test_TCGA-52-7810-01A-01-TS1.42a4fb5e-ca1f-4441-bea7-65bfd2eb2f19',\n", 530 | " '0.0012641254676846873', '1.8483512292455374e-05',\n", 531 | " '0.9987173908286624', '2'],\n", 532 | " ['test_TCGA-63-7023-01A-01-BS1.cf277f61-7e00-476e-9e1b-820dc885cf51',\n", 533 | " '0.15519999087294678', '0.0016719845949115038',\n", 534 | " '0.8431280204195865', '2'],\n", 535 | " ['test_TCGA-J1-A4AH-01A-03-TSC.ECE9F5FA-FC5A-4649-9F26-A1835590556E',\n", 536 | " '0.20434232803771504', '0.0033583442041112462',\n", 537 | " '0.7922993280021474', '2'],\n", 538 | " ['test_TCGA-63-A5MM-01A-01-TSA.8D99D26F-6BD0-4F49-ABFB-8F20ACB97C25',\n", 539 | " '0.004650813094208093', '0.0014704902523060821',\n", 540 | " '0.9938786988970877', '2'],\n", 541 | " ['test_TCGA-92-8063-01A-01-TS1.df83d472-73d8-497b-ac82-d15dc5af995d',\n", 542 | " '0.18793747880490955', '0.054497750567708274',\n", 543 | " '0.757564770344859', '2'],\n", 544 | " ['test_TCGA-77-8130-01A-01-BS1.3342f535-14c7-4033-99ad-83831c99a58b',\n", 545 | " '0.03452119484401378', '0.0022108356883148803',\n", 546 | " '0.9632679697404913', '2'],\n", 547 | " ['test_TCGA-51-4079-01A-01-TS1.a9e1a7c9-cd00-4c66-838b-555a5c91c2db',\n", 548 | " '0.011018395116786884', '0.0008941166724750657',\n", 549 | " '0.9880874977483378', '2'],\n", 550 | " ['test_TCGA-18-3414-01A-01-TS1.fc2291c6-270c-4df5-a1ce-5effd58d6642',\n", 551 | " '0.03680392020508077', '0.0013130251970464655',\n", 552 | " '0.9618830599478985', '2'],\n", 553 | " ['test_TCGA-21-5784-01A-01-BS1.8fa2ef29-7a34-4355-a86e-0fb22dc7a69c',\n", 554 | " '0.22452027630869884', '0.0020236183069763823',\n", 555 | " '0.7734561069456893', '2'],\n", 556 | " ['test_TCGA-33-6738-01A-01-TS1.7e519a67-18bc-4166-bb5d-43cc34af150a',\n", 557 | " '0.6015235830144193', '0.022106948605358143',\n", 558 | " '0.37636947657981673', '2'],\n", 559 | " ['test_TCGA-33-A4WN-01A-01-TSA.D8479D83-A232-471D-A122-D9EF08EFE918',\n", 560 | " '0.019166493784446553', '0.03856216295167332',\n", 561 | " '0.9422713425215129', '2'],\n", 562 | " ['test_TCGA-85-A4JB-01A-05-TSE.9AF21BB3-E3E2-47CE-8BF6-025A1BC6FD96',\n", 563 | " '0.021740453320463756', '0.0071422116199065045',\n", 564 | " '0.971117331981659', '2'],\n", 565 | " ['test_TCGA-66-2791-01A-01-BS1.0e33b4ca-0d98-44bb-8085-9a3bdfeb76e5',\n", 566 | " '0.09840367336320795', '0.007903723374587519',\n", 567 | " '0.8936926031190171', '2'],\n", 568 | " ['test_TCGA-77-7142-01A-01-TS1.e2283588-1843-42de-a672-e0b72bc730a4',\n", 569 | " '0.10888070177467714', '0.005431332035975454',\n", 570 | " '0.8856879647398248', '2'],\n", 571 | " ['test_TCGA-NC-A5HO-01A-01-TS1.62FF3313-6DFC-4F64-9E69-9D1C852A57D4',\n", 572 | " '0.015125936438836718', '0.0010675518124606857',\n", 573 | " '0.9838065101429518', '2'],\n", 574 | " ['test_TCGA-77-7142-01A-01-BS1.64627a2a-9ffe-4eef-b61e-f51a37a255e8',\n", 575 | " '0.24814039491763548', '0.001448754346130636',\n", 576 | " '0.7504108522890701', '2'],\n", 577 | " ['test_TCGA-51-4081-01A-01-TS1.03ddb0fe-10d4-49fa-a079-3fb5f5be73b5',\n", 578 | " '0.029331708863285854', '0.0016196513805888239',\n", 579 | " '0.9690486358271705', '2'],\n", 580 | " ['test_TCGA-51-4079-01A-01-BS1.69f4fd87-b5db-431d-bce3-732e19261f79',\n", 581 | " '0.009869463108770148', '4.365540904052106e-05',\n", 582 | " '0.9900868796856603', '2'],\n", 583 | " ['test_TCGA-33-AASL-01A-01-TS1.B055F2AE-7E25-4567-A66C-ABA18A1F7993',\n", 584 | " '0.016703371333355258', '0.003646690987049061',\n", 585 | " '0.9796499345037673', '2'],\n", 586 | " ['test_TCGA-66-2778-01A-02-BS2.8065d383-a69c-4d80-a764-40caad1cab88',\n", 587 | " '0.06394212436100788', '8.712165301368655e-05',\n", 588 | " '0.9359707452248835', '2'],\n", 589 | " ['test_TCGA-77-8130-01A-01-TS1.48a98816-c488-4e48-be20-bb6088f4a7b5',\n", 590 | " '0.242074233002059', '0.001273717037481457',\n", 591 | " '0.7566520472310538', '2'],\n", 592 | " ['test_TCGA-18-3414-01A-01-BS1.160234d0-fa16-438f-96fa-c608d6c1d793',\n", 593 | " '0.021539541225301714', '0.00014174047693710752',\n", 594 | " '0.9783187164803562', '2'],\n", 595 | " ['test_TCGA-60-2704-01A-01-TS1.b534b929-b6d0-42e9-a611-ec80cac09f35',\n", 596 | " '0.24071052522044498', '0.008412645175457999',\n", 597 | " '0.7508768279007391', '2'],\n", 598 | " ['test_TCGA-60-2721-01A-01-TS1.4957a24e-9461-4c84-a44c-07489c996048',\n", 599 | " '0.015183753047595462', '0.00020122616530365722',\n", 600 | " '0.9846150206271992', '2'],\n", 601 | " ['test_TCGA-NC-A5HE-01A-01-TSA.A1F9C4DD-445C-45E4-A2A7-635FDD33CBDF',\n", 602 | " '0.05450001055909848', '0.009591504400095057',\n", 603 | " '0.9359084882640413', '2'],\n", 604 | " ['test_TCGA-77-8128-01A-01-TS1.53aff731-236e-4e9f-8938-30d1b76541f9',\n", 605 | " '0.07536921449205194', '0.007359787306679781',\n", 606 | " '0.9172709943571439', '2'],\n", 607 | " ['test_TCGA-98-A53C-01A-01-TS1.72B02688-C12E-4521-9D5D-D1CAA80730EB',\n", 608 | " '0.3989196199858851', '0.013709482295525507',\n", 609 | " '0.5873708946809891', '2'],\n", 610 | " ['test_TCGA-33-AASB-01A-01-TSA.D1A340FC-96C2-4780-88D3-4EC1DC4964D0',\n", 611 | " '0.2210918357664201', '0.10220036749398408',\n", 612 | " '0.6767077960261921', '2'],\n", 613 | " ['test_TCGA-77-7338-01A-01-TS1.4436243f-fc72-4147-bd18-ac410eb8acb4',\n", 614 | " '0.035428842985433785', '0.046422826737260496',\n", 615 | " '0.9181483256091413', '2'],\n", 616 | " ['test_TCGA-51-4081-01A-01-BS1.e1b79437-906c-4571-988b-46b237ed8a86',\n", 617 | " '0.07055771078914717', '0.00024115023969484747',\n", 618 | " '0.929201140999794', '2'],\n", 619 | " ['test_TCGA-22-1005-01A-01-BS1.c7e59cb0-62ee-4730-a8cb-b81c0f7d6c7a',\n", 620 | " '0.22519951613794548', '0.0007090876773673926',\n", 621 | " '0.7740913920279127', '2'],\n", 622 | " ['test_TCGA-60-2697-01A-01-TS1.7c1ab1f2-fb80-4d23-a31a-d183895a5d28',\n", 623 | " '0.15918543187041223', '0.004255108660714768',\n", 624 | " '0.836559469345957', '2'],\n", 625 | " ['test_TCGA-34-8456-01A-01-BS1.b64ed6d6-db32-486e-bd90-bcc4397d7446',\n", 626 | " '0.11274036011891439', '0.0004404011333008384',\n", 627 | " '0.8868192195892334', '2'],\n", 628 | " ['test_TCGA-52-7810-01A-01-BS1.ce43c735-840c-4f4a-88d3-d79b66d666d8',\n", 629 | " '0.4505572281777859', '0.0004002300473075593',\n", 630 | " '0.5490425109863282', '2']], dtype=' t+1): 348 | differences = [] 349 | for x in range(1, t+1): 350 | differences.append(val_history[-x]-val_history[-(x+1)]) 351 | differences = [y < required_progress for y in differences] 352 | if sum(differences) == t: 353 | return True 354 | else: 355 | return False 356 | else: 357 | return False 358 | 359 | if opt.earlystop: 360 | validation_history = [] 361 | else: 362 | print("No early stopping implemented") 363 | 364 | stop_training = False 365 | 366 | ############################################################################### 367 | 368 | def adjust_learning_rate(optimizer, epoch): 369 | 370 | """Sets the learning rate to the initial LR decayed by 10 every 3 epochs 371 | Function copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py""" 372 | 373 | lr = opt.lr * (0.1 ** (epoch // 3)) # Original 374 | for param_group in optimizer.param_groups: 375 | param_group['lr'] = lr 376 | 377 | ############################################################################### 378 | 379 | """ 380 | Training loop 381 | """ 382 | 383 | best_AUC = 0.0 384 | 385 | print('Starting training') 386 | start = time.time() 387 | local_time = time.ctime(start) 388 | print(local_time) 389 | 390 | print(time.time()) 391 | for epoch in range(1,opt.niter+1): 392 | data_iter = iter(loaders['train']) 393 | i = 0 394 | 395 | if opt.decay_lr: 396 | adjust_learning_rate(optimizer, epoch) 397 | print("Epoch %d :lr = %f" % (epoch, optimizer.state_dict()['param_groups'][0]['lr'])) 398 | while i < len(loaders['train']): 399 | model.train() 400 | img, label = data_iter.next() 401 | i += 1 402 | 403 | # Drop the last batch if it's not the same size as the batchsize 404 | if img.size(0) != opt.batchSize: 405 | break 406 | 407 | if opt.cuda: 408 | img = img.cuda() 409 | label = label.cuda() 410 | 411 | 412 | input_img = Variable(img) 413 | target_label = Variable(label) 414 | 415 | train_loss = criterion(model(input_img), target_label) 416 | #print(model(input_img)[0]) 417 | # Zero gradients then backward pass 418 | optimizer.zero_grad() 419 | train_loss.backward() 420 | 421 | correc=0 422 | total=0 423 | 424 | optimizer.step() 425 | 426 | print('[%d/%d][%d/%d] Training Loss: %f' 427 | % (epoch, opt.niter, i, len(loaders['train']), train_loss.item())) 428 | ii=i+((epoch)*len(loaders['train'])) 429 | #get validation AUC every step_freq 430 | if ii % step_freq == 0: 431 | val_predictions, val_labels = aggregate(data['valid'].filenames, method=opt.method) 432 | 433 | data_ = np.column_stack((data['valid'].filenames,np.asarray(val_predictions),np.asarray(val_labels))) 434 | data_.dump(open('{0}/outputs/val_pred_label_avg_step_{1}.npy'.format(opt.experiment,str(ii)), 'wb')) 435 | torch.save(model.state_dict(), '{0}/checkpoints/step_{1}.pth'.format(opt.experiment, str(ii))) 436 | print('validation scores:') 437 | 438 | roc_auc = get_auc('{0}/images/val_roc_step_{1}.jpg'.format(opt.experiment,epoch), val_predictions, val_labels, classes = range(num_classes)) 439 | for k, v in roc_auc.items(): 440 | if k in range(num_classes): 441 | k = classes[k] 442 | #experiment.log_metric("{0} AUC".format(k), v) 443 | print('%s AUC: %0.4f' % (k, v)) 444 | 445 | #save the checkpoint at every epoch 446 | torch.save(model.state_dict(), '{0}/checkpoints/epoch_{1}.pth'.format(opt.experiment, str(epoch))) 447 | 448 | #print(time.time()) 449 | # Get validation AUC once per epoch 450 | if opt.calc_val_auc: 451 | val_predictions, val_labels = aggregate(data['valid'].filenames, method=opt.method) 452 | data_ = np.column_stack((data['valid'].filenames,np.asarray(val_predictions),np.asarray(val_labels))) 453 | data_.dump(open('{0}/outputs/val_pred_label_avg_epoch_{1}.npy'.format(opt.experiment,str(epoch)), 'wb')) 454 | 455 | roc_auc = get_auc('{0}/images/val_roc_epoch_{1}.jpg'.format(opt.experiment, epoch),val_predictions, val_labels, classes = range(num_classes)) 456 | 457 | for k, v in roc_auc.items(): 458 | if k in range(num_classes): 459 | k = classes[k] 460 | 461 | #experiment.log_metric("{0} AUC".format(k), v) 462 | print('%s AUC: %0.4f' % (k, v)) 463 | 464 | # Stop training if no progress on AUC is being made 465 | if opt.earlystop: 466 | validation_history.append(roc_auc['macro']) 467 | stop_training = early_stop(validation_history) 468 | 469 | if stop_training: 470 | print("Early stop triggered") 471 | break 472 | 473 | 474 | epoch_time = time.time() 475 | local_time = time.ctime(epoch_time) 476 | print(local_time) 477 | 478 | # Final evaluation 479 | print('Finished training, best AUC: %0.4f' % (best_AUC)) 480 | end = time.time() 481 | print(end-start) 482 | -------------------------------------------------------------------------------- /tsne.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from sklearn.manifold import TSNE 3 | import pickle 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.utils.data 10 | import torchvision.datasets as dset 11 | import torchvision.transforms as transforms 12 | import torchvision.utils as vutils 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | import argparse 16 | import copy 17 | import numpy as np 18 | import time 19 | import os 20 | from PIL import Image 21 | from utils.dataloader import * 22 | from utils import new_transforms 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--root_dir', type=str, default='TilesSorted/', help='Data directory .../dataTilesSorted/') 27 | parser.add_argument('--num_class', type=int, default=2, help='number of classes ') 28 | parser.add_argument('--tile_dict_path', type=str, default='"_FileMappingDict.p', help='Tile dictinory path') 29 | parser.add_argument('--val', type=str, default='test', help='validation set') 30 | 31 | opt = parser.parse_args() 32 | 33 | root_dir = str(opt.root_dir) 34 | num_classes = int(opt.num_class) 35 | tile_dict_path = str(opt.tile_dict_path) 36 | 37 | test_val = str(opt.val) 38 | 39 | imgSize = 299 40 | 41 | transform = transforms.Compose([new_transforms.Resize((imgSize,imgSize)), 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 44 | #use Tissuedata2 for ownsampled data, for test use the whole data 45 | test_data = TissueData(root_dir, test_val, transform = transform, metadata=False) 46 | 47 | 48 | #os.chdir("/scratch/sb3923/deep-cancer/tsne_figures/") 49 | #pickle.dump( test_data.filenames, open( "test_data.p", "wb" ) ) 50 | 51 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False) 52 | classes = test_data.classes 53 | 54 | class_to_idx = test_data.class_to_idx 55 | 56 | print('Class encoding:') 57 | print(class_to_idx) 58 | 59 | print('after tissuedata') 60 | 61 | def get_tile_probability(tile_path): 62 | 63 | """ 64 | Returns an array of probabilities for each class given a tile 65 | 66 | @param tile_path: Filepath to the tile 67 | @return: A ndarray of class probabilities for that tile 68 | """ 69 | 70 | # Some tiles are empty with no path, return nan 71 | if tile_path == '': 72 | return np.full(num_classes, np.nan) 73 | 74 | tile_path = root_dir + tile_path 75 | 76 | with open(tile_path, 'rb') as f: 77 | with Image.open(f) as img: 78 | img = img.convert('RGB') 79 | 80 | # Model expects a 4D tensor, unsqueeze first dimension 81 | img = transform(img).unsqueeze(0) 82 | 83 | # Turn output into probabilities with softmax 84 | var_img = Variable(img, volatile=True) 85 | output = F.softmax(model(var_img)[0]).data.squeeze(0) 86 | return output.numpy() 87 | 88 | 89 | def get_tile_probability2(tile_path): 90 | """ 91 | Returns an array of probabilities for each class given a tile 92 | 93 | @param tile_path: Filepath to the tile 94 | @return: A ndarray of class probabilities for that tile 95 | """ 96 | # Some tiles are empty with no path, return nan 97 | if tile_path == '': 98 | return np.full(5184, np.nan) 99 | 100 | tile_path = root_dir + tile_path 101 | 102 | with open(tile_path, 'rb') as f: 103 | with Image.open(f) as img: 104 | img = img.convert('RGB') 105 | # Model expects a 4D tensor, unsqueeze first dimension 106 | img = transform(img).unsqueeze(0) 107 | 108 | # Turn output into probabilities with softmax 109 | var_img = Variable(img, volatile=True) 110 | viz = (model(var_img)[1]).squeeze(0)#torch.FloatTensor of size 1x5184 111 | return viz.data.numpy() #numpy.ndarray 112 | 113 | 114 | 115 | with open(tile_dict_path, 'rb') as f: 116 | tile_dict = pickle.load(f) 117 | 118 | 119 | def aggregate(file_list, method): 120 | 121 | """ 122 | Given a list of files, return scores for each class according to the 123 | method and labels for those files. 124 | 125 | @param file_list: A list of file paths to do predictions on 126 | @param method: 'average' - returns the average probability score across 127 | all tiles for that file 128 | 'max' - predicts each tile to be the class of the maximum 129 | score, and returns the proportion of tiles for 130 | each class 131 | 132 | @return: a ndarray of class probabilities for all files in the list 133 | a ndarray of the labels 134 | 135 | """ 136 | 137 | model.eval() 138 | predictions = [] 139 | true_labels = [] 140 | last_layer = [] 141 | file_name = [] 142 | 143 | for file in file_list: 144 | tile_paths, label = tile_dict[file] 145 | 146 | folder = classes[label] 147 | 148 | def add_folder(tile_path): 149 | if tile_path == '': 150 | return '' 151 | else: 152 | return folder + '/' + tile_path 153 | 154 | # Add the folder for the class name in front 155 | add_folder_v = np.vectorize(add_folder) 156 | tile_paths0 = add_folder_v(tile_paths) 157 | 158 | # Get the probability array for the file 159 | prob_v= np.vectorize(get_tile_probability, otypes=[np.ndarray]) 160 | probabilities = prob_v(tile_paths0) 161 | 162 | tile_paths1 = add_folder_v(tile_paths) 163 | 164 | prob_v2 = np.vectorize(get_tile_probability2, otypes=[np.ndarray]) 165 | lastlayer = prob_v2(tile_paths1) 166 | 167 | probabilities = np.stack(probabilities.flat) 168 | prediction = np.nanmean(probabilities, axis = 0) 169 | 170 | tile_label = np.argmax(probabilities,axis=1) 171 | #last layer 172 | lastlayer = np.stack(lastlayer.flat) 173 | a = lastlayer[np.ix_(label == tile_label),:] 174 | 175 | if (np.squeeze(a)).ndim>1: 176 | lastlayerweights = np.nanmean(np.squeeze(a), axis = 0) 177 | #lastlayer = np.stack(lastlayer.flat) 178 | #lastlayerweights = np.nanmean(lastlayer, axis = 0) 179 | 180 | predictions.append(prediction) 181 | true_labels.append(label) 182 | last_layer.append(lastlayerweights) 183 | file_name.append(file) 184 | 185 | return np.array(predictions), np.array(true_labels), np.array(last_layer),np.array(file_name) 186 | #return np.array(true_labels), np.array(last_layer) 187 | 188 | 189 | 190 | class BasicConv2d(nn.Module): 191 | 192 | def __init__(self, in_channels, out_channels, pool, **kwargs): 193 | super(BasicConv2d, self).__init__() 194 | 195 | self.pool = pool 196 | self.conv = nn.Conv2d(in_channels, out_channels, **kwargs) 197 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 198 | self.relu = nn.LeakyReLU() 199 | 200 | self.dropout = nn.Dropout(p=0.1) 201 | 202 | def forward(self, x): 203 | x = self.conv(x) 204 | 205 | if self.pool: 206 | x = F.max_pool2d(x, 2) 207 | 208 | x = self.relu(x) 209 | x = self.bn(x) 210 | x = self.dropout(x) 211 | return x 212 | 213 | # Define model 214 | class cancer_CNN(nn.Module): 215 | def __init__(self, nc, imgSize, ngpu): 216 | super(cancer_CNN, self).__init__() 217 | self.nc = nc 218 | self.imgSize = imgSize 219 | self.ngpu = ngpu 220 | self.data = 'all' 221 | self.conv1 = BasicConv2d(nc, 16, False, kernel_size=5, padding=1, stride=2, bias=True) 222 | self.conv2 = BasicConv2d(16, 32, False, kernel_size=3, bias=True) 223 | self.conv3 = BasicConv2d(32, 64, True, kernel_size=3, padding=1, bias=True) 224 | self.conv4 = BasicConv2d(64, 64, True, kernel_size=3, padding=1, bias=True) 225 | self.conv5 = BasicConv2d(64, 128, True, kernel_size=3, padding=1, bias=True) 226 | self.conv6 = BasicConv2d(128, 64, True, kernel_size=3, padding=1, bias=True) 227 | self.linear = nn.Linear(5184, num_classes) 228 | 229 | def forward(self, x): 230 | x = self.conv1(x) 231 | x = self.conv2(x) 232 | x = self.conv3(x) 233 | x = self.conv4(x) 234 | x = self.conv5(x) 235 | x = self.conv6(x) 236 | x = x.view(x.size(0), -1) 237 | llw=x 238 | x = self.linear(x) 239 | return x, llw 240 | 241 | 242 | model = cancer_CNN(3, imgSize, 1) 243 | 244 | model_path = "experiments/" + opt.experiment + '/' + opt.model 245 | state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) 246 | model.load_state_dict(state_dict) 247 | 248 | predictions, labels, fw_lastlayer, file_names = aggregate(test_data.filenames, method='average') 249 | print('------------------------------------------------------') 250 | print('last-layer') 251 | print(fw_lastlayer) 252 | 253 | finalWs = fw_lastlayer 254 | os.chdir("tsne_data/") 255 | pickle.dump( finalWs, open( "finalWs_all_3.p", "wb" ) ) 256 | pickle.dump( predictions, open( "predictions_all_3.p", "wb" ) ) 257 | pickle.dump( labels, open( "labels_all_3.p", "wb" ) ) 258 | pickle.dump( file_names, open( "file_names_3.p", "wb" ) ) 259 | 260 | #A function provided by Google in one of their Tensorflow tutorials 261 | #for visualizing data with t-SNE by plotting it to a graph. 262 | 263 | def plot_with_labels(lowDWeights, labels, filename='tsne.png'): 264 | assert lowDWeights.shape[0] >= len(labels), "More labels than weights" 265 | plt.figure(figsize=(20, 20)) #in inches 266 | for i, label in enumerate(labels): 267 | x, y = lowDWeights[i,:] 268 | plt.scatter(x, y) 269 | plt.annotate(label, 270 | xy=(x, y), 271 | xytext=(5, 2), 272 | textcoords='offset points', 273 | ha='right', 274 | va='bottom') 275 | 276 | plt.savefig(filename) 277 | 278 | 279 | tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000) 280 | plot_only = 500 281 | lowDWeights = tsne.fit_transform(fw_lastlayer) 282 | labels = ['0','1','2','3','4','5','6','7','8'] 283 | plot_with_labels(lowDWeights, labels) 284 | 285 | 286 | -------------------------------------------------------------------------------- /utils/.ipynb_checkpoints/Untitled-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /utils/CLR.py: -------------------------------------------------------------------------------- 1 | #Preview of CLR scheduler as submitted in: 2 | #https://github.com/pytorch/pytorch/pull/2016 3 | 4 | import numpy as np 5 | from torch.optim.optimizer import Optimizer 6 | 7 | class CyclicLR(object): 8 | """Sets the learning rate of each parameter group according to 9 | cyclical learning rate policy (CLR). The policy cycles the learning 10 | rate between two boundaries with a constant frequency, as detailed in 11 | the paper `Cyclical Learning Rates for Training Neural Networks`_. 12 | The distance between the two boundaries can be scaled on a per-iteration 13 | or per-cycle basis. 14 | Cyclical learning rate policy changes the learning rate after every batch. 15 | `batch_step` should be called after a batch has been used for training. 16 | To resume training, save `last_batch_iteration` and use it to instantiate `CycleLR`. 17 | This class has three built-in policies, as put forth in the paper: 18 | "triangular": 19 | A basic triangular cycle w/ no amplitude scaling. 20 | "triangular2": 21 | A basic triangular cycle that scales initial amplitude by half each cycle. 22 | "exp_range": 23 | A cycle that scales initial amplitude by gamma**(cycle iterations) at each 24 | cycle iteration. 25 | This implementation was adapted from the github repo: `bckenstler/CLR`_ 26 | Args: 27 | optimizer (Optimizer): Wrapped optimizer. 28 | base_lr (float or list): Initial learning rate which is the 29 | lower boundary in the cycle for eachparam groups. 30 | Default: 0.001 31 | max_lr (float or list): Upper boundaries in the cycle for 32 | each parameter group. Functionally, 33 | it defines the cycle amplitude (max_lr - base_lr). 34 | The lr at any cycle is the sum of base_lr 35 | and some scaling of the amplitude; therefore 36 | max_lr may not actually be reached depending on 37 | scaling function. Default: 0.006 38 | step_size (int): Number of training iterations per 39 | half cycle. Authors suggest setting step_size 40 | 2-8 x training iterations in epoch. Default: 2000 41 | mode (str): One of {triangular, triangular2, exp_range}. 42 | Values correspond to policies detailed above. 43 | If scale_fn is not None, this argument is ignored. 44 | Default: 'triangular' 45 | gamma (float): Constant in 'exp_range' scaling function: 46 | gamma**(cycle iterations) 47 | Default: 1.0 48 | scale_fn (function): Custom scaling policy defined by a single 49 | argument lambda function, where 50 | 0 <= scale_fn(x) <= 1 for all x >= 0. 51 | mode paramater is ignored 52 | Default: None 53 | scale_mode (str): {'cycle', 'iterations'}. 54 | Defines whether scale_fn is evaluated on 55 | cycle number or cycle iterations (training 56 | iterations since start of cycle). 57 | Default: 'cycle' 58 | last_batch_iteration (int): The index of the last batch. Default: -1 59 | Example: 60 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 61 | >>> scheduler = torch.optim.CyclicLR(optimizer) 62 | >>> data_loader = torch.utils.data.DataLoader(...) 63 | >>> for epoch in range(10): 64 | >>> for batch in data_loader: 65 | >>> scheduler.batch_step() 66 | >>> train_batch(...) 67 | .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 68 | .. _bckenstler/CLR: https://github.com/bckenstler/CLR 69 | """ 70 | 71 | def __init__(self, optimizer, base_lr=1e-3, max_lr=6e-3, 72 | step_size=2000, mode='triangular', gamma=1., 73 | scale_fn=None, scale_mode='cycle', last_batch_iteration=-1): 74 | 75 | if not isinstance(optimizer, Optimizer): 76 | raise TypeError('{} is not an Optimizer'.format( 77 | type(optimizer).__name__)) 78 | self.optimizer = optimizer 79 | 80 | if isinstance(base_lr, list) or isinstance(base_lr, tuple): 81 | if len(base_lr) != len(optimizer.param_groups): 82 | raise ValueError("expected {} base_lr, got {}".format( 83 | len(optimizer.param_groups), len(base_lr))) 84 | self.base_lrs = list(base_lr) 85 | else: 86 | self.base_lrs = [base_lr] * len(optimizer.param_groups) 87 | 88 | if isinstance(max_lr, list) or isinstance(max_lr, tuple): 89 | if len(max_lr) != len(optimizer.param_groups): 90 | raise ValueError("expected {} max_lr, got {}".format( 91 | len(optimizer.param_groups), len(max_lr))) 92 | self.max_lrs = list(max_lr) 93 | else: 94 | self.max_lrs = [max_lr] * len(optimizer.param_groups) 95 | 96 | self.step_size = step_size 97 | 98 | if mode not in ['triangular', 'triangular2', 'exp_range'] \ 99 | and scale_fn is None: 100 | raise ValueError('mode is invalid and scale_fn is None') 101 | 102 | self.mode = mode 103 | self.gamma = gamma 104 | 105 | if scale_fn is None: 106 | if self.mode == 'triangular': 107 | self.scale_fn = self._triangular_scale_fn 108 | self.scale_mode = 'cycle' 109 | elif self.mode == 'triangular2': 110 | self.scale_fn = self._triangular2_scale_fn 111 | self.scale_mode = 'cycle' 112 | elif self.mode == 'exp_range': 113 | self.scale_fn = self._exp_range_scale_fn 114 | self.scale_mode = 'iterations' 115 | else: 116 | self.scale_fn = scale_fn 117 | self.scale_mode = scale_mode 118 | 119 | self.batch_step(last_batch_iteration + 1) 120 | self.last_batch_iteration = last_batch_iteration 121 | 122 | def batch_step(self, batch_iteration=None): 123 | if batch_iteration is None: 124 | batch_iteration = self.last_batch_iteration + 1 125 | self.last_batch_iteration = batch_iteration 126 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 127 | param_group['lr'] = lr 128 | 129 | def _triangular_scale_fn(self, x): 130 | return 1. 131 | 132 | def _triangular2_scale_fn(self, x): 133 | return 1 / (2. ** (x - 1)) 134 | 135 | def _exp_range_scale_fn(self, x): 136 | return self.gamma**(x) 137 | 138 | def get_lr(self): 139 | step_size = float(self.step_size) 140 | cycle = np.floor(1 + self.last_batch_iteration / (2 * step_size)) 141 | x = np.abs(self.last_batch_iteration / step_size - 2 * cycle + 1) 142 | 143 | lrs = [] 144 | param_lrs = zip(self.optimizer.param_groups, self.base_lrs, self.max_lrs) 145 | for param_group, base_lr, max_lr in param_lrs: 146 | base_height = (max_lr - base_lr) * np.maximum(0, (1 - x)) 147 | if self.scale_mode == 'cycle': 148 | lr = base_lr + base_height * self.scale_fn(cycle) 149 | else: 150 | lr = base_lr + base_height * self.scale_fn(self.last_batch_iteration) 151 | lrs.append(lr) 152 | return lrs -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/auc.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/utils/__pycache__/auc.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/auc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/utils/__pycache__/auc.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/auc_all_class.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/utils/__pycache__/auc_all_class.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/auc_test.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/utils/__pycache__/auc_test.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloader.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/utils/__pycache__/dataloader.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/utils/__pycache__/dataloader.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/eval.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/utils/__pycache__/eval.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/eval.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/utils/__pycache__/eval.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/new_transforms.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/utils/__pycache__/new_transforms.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/new_transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sedab/PathCNN/3208c552ee2113a0cc85363373318ce175e76c71/utils/__pycache__/new_transforms.cpython-36.pyc -------------------------------------------------------------------------------- /utils/auc.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import torch 4 | import numpy as np 5 | from sklearn.metrics import roc_curve, auc 6 | from sklearn.preprocessing import label_binarize 7 | from scipy import interp 8 | from itertools import cycle 9 | import matplotlib.pyplot as plt 10 | from PIL import Image 11 | 12 | def get_auc(path, predictions, labels, classes=[0, 1, 2]): 13 | """ 14 | Given predictions and labels, return the AUCs for all classes 15 | and micro, macro AUCs. Also saves a plot of the ROC curve to the 16 | path. 17 | 18 | """ 19 | 20 | fpr = dict() 21 | tpr = dict() 22 | roc_auc = dict() 23 | 24 | 25 | if len(classes) > 2: 26 | # Convert labels to one-hot-encoding 27 | labels = label_binarize(labels, classes = classes) 28 | 29 | ### Individual class AUC ### 30 | for i in classes: 31 | fpr[i], tpr[i], _ = roc_curve(labels[:, i], predictions[:, i]) 32 | roc_auc[i] = auc(fpr[i], tpr[i]) 33 | 34 | ### Micro AUC ### 35 | fpr["micro"], tpr["micro"], _ = roc_curve(labels.ravel(), predictions.ravel()) 36 | roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) 37 | 38 | ### Macro AUC ### 39 | all_fpr = np.unique(np.concatenate([fpr[i] for i in classes])) 40 | mean_tpr = np.zeros_like(all_fpr) 41 | 42 | number_class=0 43 | for i in classes: 44 | ip = interp(all_fpr, fpr[i], tpr[i]) 45 | if(~np.isnan(ip).any()): 46 | mean_tpr += ip 47 | number_class += 1 48 | 49 | 50 | mean_tpr /= number_class 51 | 52 | fpr["macro"] = all_fpr 53 | tpr["macro"] = mean_tpr 54 | roc_auc["macro"] = auc(fpr["macro"], tpr["macro"]) 55 | 56 | ### Make plot ### 57 | 58 | plt.figure(figsize=(12, 12)) 59 | plt.plot(fpr["micro"], tpr["micro"], 60 | label='micro-average ROC curve (area = {0:0.2f})' 61 | ''.format(roc_auc["micro"]), 62 | color='deeppink', linestyle=':', linewidth=4) 63 | 64 | plt.plot(fpr["macro"], tpr["macro"], 65 | label='macro-average ROC curve (area = {0:0.2f})' 66 | ''.format(roc_auc["macro"]), 67 | color='navy', linestyle=':', linewidth=4) 68 | 69 | colors = cycle(['aqua', 'darkorange', 'cornflowerblue']) 70 | for i, color in zip(classes, colors): 71 | plt.plot(fpr[i], tpr[i], color=color, lw=2, 72 | label='ROC curve of class {0} (area = {1:0.2f})' 73 | ''.format(i, roc_auc[i])) 74 | else: 75 | fpr, tpr, _ = roc_curve(labels, predictions[:,1]) 76 | auc_result = auc(fpr, tpr) 77 | 78 | for i in list(classes) + ['macro', 'micro']: 79 | roc_auc[i] = auc_result 80 | 81 | plt.figure(figsize=(12, 12)) 82 | plt.plot(fpr, tpr, lw=2, 83 | label='ROC curve (area = {0:0.2f})' 84 | ''.format(auc_result)) 85 | 86 | plt.plot([0, 1], [0, 1], 'k--', lw=2) 87 | plt.xlim([0.0, 1.0]) 88 | plt.ylim([0.0, 1.05]) 89 | plt.xlabel('False Positive Rate') 90 | plt.ylabel('True Positive Rate') 91 | plt.title('ROC Curve') 92 | plt.legend(loc="lower right") 93 | plt.savefig(path) 94 | 95 | return roc_auc#, fpr, tpr 96 | -------------------------------------------------------------------------------- /utils/confusion_matrix.py: -------------------------------------------------------------------------------- 1 | 2 | from sklearn import metrics 3 | import matplotlib.pyplot as plt 4 | import itertools 5 | 6 | def plot_confusion_matrix( true_values , predicted_values, classes = ['Normal tissue', 'TCGA-LUAD', 'TCGA-LUSC'] , 7 | normalize=False, 8 | title='Confusion matrix', 9 | cmap=plt.cm.Blues, 10 | fig_name = 'cancer_plot.png'): 11 | """ 12 | This function prints and plots the confusion matrix. 13 | Normalization can be applied by setting `normalize=True`. 14 | """ 15 | cm = metrics.confusion_matrix( true_values , predicted_values , labels = classes ) 16 | 17 | plt.figure(figsize=(20,10)) 18 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 19 | plt.title(title) 20 | plt.colorbar() 21 | tick_marks = np.arange(len(classes)) 22 | plt.xticks(tick_marks, classes, rotation=45) 23 | plt.yticks(tick_marks, classes) 24 | 25 | if normalize: 26 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 27 | print("Normalized confusion matrix") 28 | else: 29 | print('Confusion matrix, without normalization') 30 | 31 | print(cm) 32 | 33 | thresh = cm.max() / 2. 34 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 35 | plt.text(j, i, cm[i, j], 36 | horizontalalignment="center", 37 | color="white" if cm[i, j] > thresh else "black") 38 | 39 | plt.tight_layout() 40 | plt.ylabel('True label') 41 | plt.xlabel('Predicted label') 42 | 43 | plt.savefig(fig_name) 44 | 45 | 46 | #Example: 47 | # plot_confusion_matrix( true_values , predicted_values , classes = ['Normal tissue', 'TCGA-LUAD', 'TCGA-LUSC'] , 48 | # title='Confusion matrix, without normalization') 49 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.transforms as transforms 4 | import torch.utils.data as data 5 | import pickle 6 | import numpy as np 7 | from PIL import Image 8 | 9 | def pil_loader(path): 10 | with open(path, 'rb') as f: 11 | with Image.open(f) as img: 12 | return img.convert('RGB') 13 | 14 | def find_classes(dir): 15 | # Classes are subdirectories of the root directory 16 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 17 | classes.sort() 18 | class_to_idx = {classes[i]: i for i in range(len(classes))} 19 | return classes, class_to_idx 20 | 21 | #get the train log file 22 | def get_class_coding(lf): 23 | auc_new = [] 24 | phrase = "Class encoding:" 25 | 26 | with open(lf, 'r+') as f: 27 | lines = f.readlines() 28 | for i in range(0, len(lines)): 29 | line = lines[i] 30 | #print(line) 31 | if phrase in line: 32 | class_encoding = lines[i + 1] # you may want to check that i < len(lines) 33 | break 34 | 35 | class_encoding = class_encoding.strip('\n').strip('{').strip('}') 36 | #print(class_encoding) 37 | 38 | class_names = [] 39 | class_codes = [] 40 | 41 | for c in class_encoding.split(','): 42 | #print(c) 43 | class_names.append(c.split(':')[0].replace("'", "").replace(" ", ""))#.split('-')[-1]) 44 | class_codes.append(int(c.split(':')[1])) 45 | 46 | 47 | class_coding = {} 48 | for i in range(len(class_names)): 49 | class_coding[class_codes[i]] = class_names[i] 50 | 51 | class_codes.sort() 52 | return class_names, class_codes, class_coding 53 | 54 | 55 | class TissueData(data.Dataset): 56 | def __init__(self, root, dset_type, train_log, data='lung', transform=None, metadata=False): 57 | 58 | 59 | #classes, class_to_idx = find_classes(root) 60 | if train_log != '': 61 | c_names, c_codes, c_coding = get_class_coding(train_log) 62 | c_coding_invert = {v: k for k, v in c_coding.items()} 63 | classes, _ = find_classes(root) 64 | class_to_idx = {} 65 | for n in classes: 66 | class_to_idx[n] = c_coding_invert[n] 67 | #print(class_to_idx) 68 | 69 | else: 70 | classes, class_to_idx = find_classes(root) 71 | 72 | 73 | self.data = data 74 | 75 | #json_dict_path = '' 76 | #with open(json_dict_path, 'rb') as f: 77 | # self.json = pickle.load(f) 78 | 79 | self.root = root 80 | self.classes = classes 81 | self.class_to_idx = class_to_idx 82 | self.metadata = metadata 83 | self.datapoints, self.filenames = self.make_dataset(root, dset_type, class_to_idx) 84 | self.transform = transform 85 | 86 | def parse_json(self, fname): 87 | json = self.json[fname] 88 | 89 | if self.data == 'breast': 90 | return [] 91 | elif self.data == 'kidney': 92 | return [] 93 | else: 94 | age, cigarettes, gender = json['age_at_diagnosis'], json['cigarettes_per_day'], json['gender'] 95 | return [age, cigarettes, gender] 96 | 97 | def make_dataset(self, dir, dset_type, class_to_idx): 98 | datapoints = [] 99 | filenames = [] 100 | 101 | dir = os.path.expanduser(dir) 102 | 103 | for target in os.listdir(dir): 104 | d = os.path.join(dir, target) 105 | if not os.path.isdir(d): 106 | continue 107 | print('Loading from:', target) 108 | dd=[] 109 | for root, _, fnames in os.walk(d): 110 | for fname in fnames: 111 | # Parse the filename 112 | dataset_type = fname.strip('.jpeg').split('_')[0] 113 | y = fname.strip('.jpeg').split('_')[-1] 114 | x = fname.strip('.jpeg').split('_')[-2] 115 | raw_file = '_'.join(fname.strip('.jpeg').split('_')[1:-2]) 116 | #dataset_type, raw_file, x, y = fname.strip('.jpeg').split('_') 117 | 118 | raw_file_name = dset_type + '_' + raw_file 119 | original_file = raw_file + '.svs' 120 | 121 | # Only add it if it's the correct dset_type (train, valid, test) 122 | if fname.endswith(".jpeg") and dataset_type == dset_type: 123 | path = os.path.join(root, fname) 124 | 125 | if self.metadata: 126 | item = (path, self.parse_json(original_file) + [int(x), int(y)], class_to_idx[target]) 127 | else: 128 | item = (path, class_to_idx[target]) 129 | 130 | datapoints.append(item) 131 | dd.append(item) 132 | 133 | if raw_file_name not in filenames: 134 | filenames.append(raw_file_name) 135 | print('number of samples:', len(dd)) 136 | 137 | return datapoints, filenames 138 | 139 | def __getitem__(self, index): 140 | """ 141 | Args: 142 | index (int): Index 143 | Returns: 144 | tuple: (img + concatenated extra info, label) for the given index 145 | """ 146 | 147 | if self.metadata: 148 | filepath, info, label = self.datapoints[index] 149 | else: 150 | filepath, label = self.datapoints[index] 151 | 152 | # Load image from filepath 153 | img = pil_loader(filepath) 154 | 155 | if self.transform is not None: 156 | img = self.transform(img) 157 | 158 | if self.metadata: 159 | # Reshape extra info, then concatenate to image as extra channels 160 | info = np.array(info) 161 | info_length = len(info) 162 | height, width = img.size(1), img.size(2) 163 | reshaped = torch.FloatTensor(np.repeat(info, height*width).reshape((len(info), height, width))) 164 | output = torch.cat((img, reshaped), 0) 165 | else: 166 | output = img 167 | 168 | return output, label 169 | 170 | def __len__(self): 171 | return len(self.datapoints) 172 | 173 | 174 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import roc_curve, auc, roc_auc_score, cohen_kappa_score, jaccard_similarity_score, log_loss,recall_score, precision_score 3 | from sklearn.preprocessing import label_binarize 4 | from scipy import interp 5 | from itertools import cycle 6 | import matplotlib.pyplot as plt 7 | import os 8 | from itertools import chain 9 | 10 | 11 | #confidence interval 12 | def get_error(pred, true, classes=[0, 1, 2]): 13 | """ 14 | Given predictions and labels, return the confidence intervals for all classes 15 | and micro, macro AUCs. 16 | 17 | """ 18 | 19 | num_class=len(classes) 20 | 21 | n_bootstraps = 1000 22 | rng_seed = {} 23 | bootstrapped_scores = {} 24 | sorted_scores = {} 25 | confidence_lower = {} 26 | confidence_upper = {} 27 | rng = {} 28 | indices = {} 29 | score = {} 30 | 31 | #results will be saved at 32 | all_cl = {} 33 | all_cu = {} 34 | 35 | if num_class>2: 36 | true = label_binarize(true, classes = classes) 37 | else: 38 | true = label_binarize(true, classes = [0,1,0]) 39 | true = np.hstack((true, 1 - true)) 40 | 41 | # control reproducibility 42 | seed=199 43 | rng_seed[0]= seed 44 | for c in range(0,num_class+1): 45 | rng[c]=np.random.RandomState(rng_seed[c]) 46 | seed=seed+1000 47 | rng_seed[c+1]= seed 48 | 49 | 50 | true_all=true.ravel() 51 | pred_all=pred.ravel() 52 | 53 | #initilize the bootsrapped scores 54 | for c in range(0,num_class): 55 | bootstrapped_scores[c]=[] 56 | bootstrapped_scores['micro']=[] 57 | 58 | for i in range(n_bootstraps): 59 | 60 | 61 | # bootstrap by sampling with replacement on the prediction indices 62 | indices[0] = rng[c].random_integers(0, len(pred[:,0]) - 1, len(pred[:,0])) 63 | 64 | if num_class>2: 65 | for c in range(1,num_class): 66 | indices[c] = rng[c].random_integers(0, len(pred[:,c]) - 1, len(pred[:,c])) 67 | if len(np.unique(true[indices[c],c])) < 2: 68 | continue 69 | 70 | indices['micro'] = rng[num_class].random_integers(0, len(pred_all) - 1, len(pred_all)) 71 | 72 | try: 73 | score[0]= roc_auc_score(true[indices[0],0], pred[indices[0],0]) 74 | bootstrapped_scores[0].append(score[0]) 75 | except ValueError: 76 | pass 77 | 78 | 79 | #ValueError: Only one class present in y_true. ROC AUC score is not defined in that case. 80 | if num_class>2: 81 | for c in range(1,num_class): 82 | try: 83 | score[c]= roc_auc_score(true[indices[c],c], pred[indices[c],c]) 84 | bootstrapped_scores[c].append(score[c]) 85 | except ValueError: 86 | pass 87 | 88 | score_micro = roc_auc_score(true_all[indices['micro']], pred_all[indices['micro']]) 89 | bootstrapped_scores['micro'].append(score_micro) 90 | 91 | 92 | if num_class>2: 93 | 94 | for c in range(0,num_class): 95 | if(len(bootstrapped_scores[c])>0): 96 | sorted_scores[c] = np.array(bootstrapped_scores[c]) 97 | sorted_scores[c].sort() 98 | confidence_lower[c] = sorted_scores[c][int(0.05 * len(sorted_scores[c]))] 99 | confidence_upper[c] = sorted_scores[c][int(0.95 * len(sorted_scores[c]))] 100 | else: 101 | confidence_lower[c] = np.nan 102 | confidence_upper[c] = np.nan 103 | 104 | 105 | 106 | #micro 107 | sorted_scores['micro']=np.array(bootstrapped_scores['micro']) #!!!!!!!!!!!! 108 | sorted_scores['micro'].sort()# does sort work like this? 109 | confidence_lower['micro'] = sorted_scores['micro'][int(0.05 * len(sorted_scores['micro']))] 110 | confidence_upper['micro'] = sorted_scores['micro'][int(0.95 * len(sorted_scores['micro']))] 111 | 112 | #macro 113 | all_bs = [] 114 | for c in range(0,num_class): 115 | all_bs.append(bootstrapped_scores[c]) 116 | 117 | 118 | sorted_scores['macro']=np.array(list(chain.from_iterable(all_bs))) 119 | sorted_scores['macro'].sort() 120 | 121 | if (len(sorted_scores['macro'])>0): 122 | confidence_lower['macro'] = sorted_scores['macro'][int(0.05 * len(sorted_scores['macro']))] 123 | confidence_upper['macro'] = sorted_scores['macro'][int(0.95 * len(sorted_scores['macro']))] 124 | else: 125 | confidence_lower['macro'] = np.nan 126 | confidence_upper['macro'] = np.nan 127 | 128 | 129 | for c in range(0,num_class): 130 | all_cl[c] = confidence_lower[c] 131 | all_cu[c] = confidence_upper[c] 132 | all_cl['micro'] = confidence_lower['micro'] 133 | all_cu['micro'] = confidence_upper['micro'] 134 | all_cl['macro'] = confidence_lower['macro'] 135 | all_cu['macro'] = confidence_upper['macro'] 136 | 137 | else: 138 | sorted_scores[0] = np.array(bootstrapped_scores[0]) 139 | sorted_scores[0].sort() 140 | confidence_lower[0] = sorted_scores[0][int(0.05 * len(sorted_scores[0]))] 141 | confidence_upper[0] = sorted_scores[0][int(0.95 * len(sorted_scores[0]))] 142 | all_cl[0] = confidence_lower[0] 143 | all_cu[0] = confidence_upper[0] 144 | 145 | return all_cl, all_cu 146 | 147 | 148 | 149 | 150 | def get_auc(predictions, labels, class_names, classes=[0, 1, 2]): 151 | """ 152 | Given predictions and labels, return the AUCs for all classes 153 | and micro, macro AUCs. 154 | 155 | """ 156 | fpr = dict() 157 | tpr = dict() 158 | roc_auc = dict() 159 | label=labels 160 | cu =[] 161 | cl=[] 162 | 163 | ax = plt.figure(figsize=(12, 12)) 164 | 165 | if len(classes) > 2: 166 | # Convert labels to one-hot-encoding 167 | labels = label_binarize(labels, classes = classes) 168 | 169 | ### Individual class AUC ### 170 | for i in classes: 171 | fpr[i], tpr[i], _ = roc_curve(labels[:, i], predictions[:, i]) 172 | try: 173 | #if (np.isnan(fpr[i]).all() | np.isnan(tpr[i]).all()): 174 | roc_auc[i] = auc(fpr[i], tpr[i]) 175 | except ValueError: 176 | roc_auc[i] = np.nan 177 | pass 178 | 179 | cl , cu =get_error(predictions,label,classes) 180 | 181 | ### Micro AUC ### 182 | fpr["micro"], tpr["micro"], _ = roc_curve(labels.ravel(), predictions.ravel()) 183 | roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) 184 | 185 | ### Macro AUC ### 186 | all_fpr = np.unique(np.concatenate([fpr[i] for i in classes])) 187 | mean_tpr = np.zeros_like(all_fpr) 188 | 189 | number_class=0 190 | for i in classes: 191 | ip = interp(all_fpr, fpr[i], tpr[i]) 192 | if(~np.isnan(ip).any()): 193 | mean_tpr += ip 194 | number_class += 1 195 | 196 | mean_tpr /= number_class 197 | 198 | fpr["macro"] = all_fpr 199 | tpr["macro"] = mean_tpr 200 | roc_auc["macro"] = auc(fpr["macro"], tpr["macro"]) 201 | 202 | #other metrics 203 | precision = precision_score(label, np.argmax(predictions,axis=1),average='macro') 204 | recall = recall_score(label, np.argmax(predictions,axis=1), average='macro') 205 | cohenskappa = cohen_kappa_score(label, np.argmax(predictions,axis=1)) 206 | jaccard = jaccard_similarity_score(label, np.argmax(predictions,axis=1)) 207 | logloss = log_loss(label, predictions, labels=classes) 208 | 209 | roc_auc['precision']=precision 210 | roc_auc['recall']=recall 211 | roc_auc['cohenskappa']=cohenskappa 212 | roc_auc['jaccard']=jaccard 213 | roc_auc['logloss']=logloss 214 | 215 | print('AUC:') 216 | print(roc_auc) 217 | print('CU:') 218 | print(cu) 219 | print('CL:') 220 | print(cl) 221 | 222 | ### Make plot ### 223 | 224 | plt.figure(figsize=(12, 12)) 225 | plt.plot(fpr["micro"], tpr["micro"], 226 | label='micro-average: AUC = {0:0.2f} \n CI [{1:0.2f}, {2:0.2f}]' 227 | ''.format(roc_auc['micro'], cl['micro'], cu['micro']), 228 | color='darkorange', linestyle=':', linewidth=4) 229 | 230 | plt.plot(fpr["macro"], tpr["macro"], 231 | label='macro-average: AUC = {0:0.2f} \n CI [{1:0.2f}, {2:0.2f}]' 232 | ''.format(roc_auc['macro'], cl['macro'], cu['macro']), 233 | color='forestgreen', linestyle=':', linewidth=4) 234 | 235 | colors = cycle(['deeppink','navy','aqua','darkorange', 'cornflowerblue']) 236 | for i, color in zip(classes, colors): 237 | plt.plot(fpr[i], tpr[i], color=color, lw=2, 238 | label='{0}: AUC = {1:0.2f} \n CI [{2:0.2f}, {3:0.2f}]' 239 | ''.format(class_names[i], roc_auc[i], cl[i], cu[i])) 240 | 241 | else: 242 | fpr, tpr, _ = roc_curve(labels, predictions[:,1]) 243 | auc_result = auc(fpr, tpr) 244 | 245 | for i in list(classes) + ['macro', 'micro']: 246 | roc_auc[i] = auc_result 247 | 248 | cl , cu =get_error(predictions,label,classes) 249 | 250 | #other metrics 251 | precision = precision_score(label, np.argmax(predictions,axis=1),average='macro') 252 | recall = recall_score(label, np.argmax(predictions,axis=1), average='macro') 253 | cohenskappa = cohen_kappa_score(label, np.argmax(predictions,axis=1)) 254 | jaccard = jaccard_similarity_score(label, np.argmax(predictions,axis=1)) 255 | logloss = log_loss(label, predictions, labels=classes) 256 | 257 | roc_auc['precision']=precision 258 | roc_auc['recall']=recall 259 | roc_auc['cohenskappa']=cohenskappa 260 | roc_auc['jaccard']=jaccard 261 | roc_auc['logloss']=logloss 262 | 263 | print(roc_auc) 264 | print('CU:') 265 | print(cu) 266 | print('CL') 267 | print(cl) 268 | 269 | color='navy' 270 | plt.figure(figsize=(12, 12)) 271 | plt.plot(fpr, tpr, color=color, lw=2, 272 | label='AUC = {0:0.2f} \n CI [{1:0.2f}, {2:0.2f}]' 273 | ''.format(roc_auc[0], cl[0], cu[0])) 274 | 275 | plt.plot([0, 1], [0, 1], 'k--', lw=2) 276 | plt.xlim([0.0, 1.0]) 277 | plt.ylim([0.0, 1.05]) 278 | plt.xlabel('False Positive Rate', fontsize=60) 279 | plt.ylabel('True Positive Rate', fontsize=60) 280 | plt.title('ROC Curve', fontsize=60) 281 | plt.legend(loc="lower right", fontsize=32) 282 | plt.xticks(fontsize=32) 283 | plt.yticks(fontsize=32) 284 | plt.rcParams['axes.linewidth'] = 4 285 | ax.patch.set_edgecolor('black') 286 | 287 | return fpr, tpr, roc_auc, cu, cl 288 | 289 | 290 | -------------------------------------------------------------------------------- /utils/extras/auc_all_class.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import torch 4 | import numpy as np 5 | from sklearn.metrics import roc_curve, auc, roc_auc_score, average_precision_score 6 | from sklearn.preprocessing import label_binarize 7 | from scipy import interp 8 | from itertools import cycle 9 | import matplotlib.pyplot as plt 10 | from PIL import Image 11 | 12 | def get_auc(path, predictions, labels, classes=[0, 1, 2]): 13 | 14 | """ 15 | Given predictions and labels, return the AUCs for all classes 16 | and micro, macro AUCs. Also saves a plot of the ROC curve to the 17 | path. 18 | 19 | """ 20 | 21 | fpr = dict() 22 | tpr = dict() 23 | roc_auc = dict() 24 | 25 | if len(classes) > 2: 26 | # Convert labels to one-hot-encoding 27 | labels = label_binarize(labels, classes = classes) 28 | 29 | ### Individual class AUC ### 30 | for i in classes: 31 | fpr[i], tpr[i], _ = roc_curve(labels[:, i], predictions[:, i]) 32 | roc_auc[i] = auc(fpr[i], tpr[i]) 33 | 34 | 35 | ### Micro AUC ### 36 | fpr["micro"], tpr["micro"], _ = roc_curve(labels.ravel(), predictions.ravel()) 37 | roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) 38 | #roc_auc["micro_cl"]=errors[2] 39 | #roc_auc["micro_cu"]=errors[3] 40 | 41 | 42 | ### Macro AUC ### 43 | all_fpr = np.unique(np.concatenate([fpr[i] for i in classes])) 44 | mean_tpr = np.zeros_like(all_fpr) 45 | for i in classes: 46 | mean_tpr += interp(all_fpr, fpr[i], tpr[i]) 47 | mean_tpr /= len(classes) 48 | 49 | fpr["macro"] = all_fpr 50 | tpr["macro"] = mean_tpr 51 | roc_auc["macro"] = auc(fpr["macro"], tpr["macro"]) 52 | 53 | 54 | ### Make plot ### 55 | plt.figure(figsize=(12, 12)) 56 | plt.plot(fpr["micro"], tpr["micro"], 57 | label='micro-average ROC curve (area = {0:0.2f})' 58 | ''.format(roc_auc["micro"]), 59 | color='deeppink', linestyle=':', linewidth=4) 60 | 61 | plt.plot(fpr["macro"], tpr["macro"], 62 | label='macro-average ROC curve (area = {0:0.2f})' 63 | ''.format(roc_auc["macro"]), 64 | color='navy', linestyle=':', linewidth=4) 65 | 66 | colors = cycle(['aqua', 'darkorange', 'cornflowerblue']) 67 | for i, color in zip(classes, colors): 68 | plt.plot(fpr[i], tpr[i], color=color, lw=2, 69 | label='ROC curve of class {0} (area = {1:0.2f})' 70 | ''.format(i, roc_auc[i])) 71 | 72 | plt.plot([0, 1], [0, 1], 'k--', lw=2) 73 | plt.xlim([0.0, 1.0]) 74 | plt.ylim([0.0, 1.05]) 75 | plt.xlabel('False Positive Rate') 76 | plt.ylabel('True Positive Rate') 77 | plt.title('ROC Curve') 78 | plt.legend(loc="lower right") 79 | plt.savefig(path) 80 | 81 | return roc_auc 82 | -------------------------------------------------------------------------------- /utils/extras/auc_test.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import torch 4 | import numpy as np 5 | from sklearn.metrics import roc_curve, auc 6 | from sklearn.preprocessing import label_binarize 7 | from scipy import interp 8 | from itertools import cycle 9 | import matplotlib.pyplot as plt 10 | from PIL import Image 11 | 12 | 13 | def get_auc(path, predictions, labels,class_names, classes=[0, 1, 2]): 14 | #def get_auc(path, predictions, labels, classes=[0, 1, 2]): 15 | """ 16 | Given predictions and labels, return the AUCs for all classes 17 | and micro, macro AUCs. Also saves a plot of the ROC curve to the 18 | path. 19 | 20 | """ 21 | 22 | fpr = dict() 23 | tpr = dict() 24 | roc_auc = dict() 25 | 26 | 27 | if len(classes) > 2: 28 | # Convert labels to one-hot-encoding 29 | labels = label_binarize(labels, classes = classes) 30 | 31 | ### Individual class AUC ### 32 | for i in classes: 33 | fpr[i], tpr[i], _ = roc_curve(labels[:, i], predictions[:, i]) 34 | roc_auc[i] = auc(fpr[i], tpr[i]) 35 | 36 | ### Micro AUC ### 37 | fpr["micro"], tpr["micro"], _ = roc_curve(labels.ravel(), predictions.ravel()) 38 | roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) 39 | 40 | ### Macro AUC ### 41 | all_fpr = np.unique(np.concatenate([fpr[i] for i in classes])) 42 | mean_tpr = np.zeros_like(all_fpr) 43 | for i in classes: 44 | mean_tpr += interp(all_fpr, fpr[i], tpr[i]) 45 | mean_tpr /= len(classes) 46 | 47 | fpr["macro"] = all_fpr 48 | tpr["macro"] = mean_tpr 49 | roc_auc["macro"] = auc(fpr["macro"], tpr["macro"]) 50 | 51 | ### Make plot ### 52 | 53 | plt.figure(figsize=(12, 12)) 54 | plt.plot(fpr["micro"], tpr["micro"], 55 | label='micro-average ROC curve (area = {0:0.2f})' 56 | ''.format(roc_auc["micro"]), 57 | color='deeppink', linestyle=':', linewidth=4) 58 | 59 | plt.plot(fpr["macro"], tpr["macro"], 60 | label='macro-average ROC curve (area = {0:0.2f})' 61 | ''.format(roc_auc["macro"]), 62 | color='navy', linestyle=':', linewidth=4) 63 | 64 | colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive'] 65 | for i, color in zip(classes, colors): 66 | plt.plot(fpr[i], tpr[i], color=color, lw=2, 67 | label='{0} (area = {1:0.2f})' 68 | ''.format(class_names[i], roc_auc[i])) 69 | else: 70 | fpr, tpr, _ = roc_curve(labels, predictions[:,1]) 71 | auc_result = auc(fpr, tpr) 72 | 73 | for i in list(classes) + ['macro', 'micro']: 74 | roc_auc[i] = auc_result 75 | 76 | plt.figure(figsize=(12, 12)) 77 | plt.plot(fpr, tpr, lw=2, 78 | label='ROC curve (area = {0:0.2f})' 79 | ''.format(auc_result)) 80 | 81 | plt.plot([0, 1], [0, 1], 'k--', lw=2) 82 | plt.xlim([0.0, 1.0]) 83 | plt.ylim([0.0, 1.05]) 84 | plt.xlabel('False Positive Rate') 85 | plt.ylabel('True Positive Rate') 86 | plt.title('ROC Curve') 87 | plt.legend(loc="lower right") 88 | plt.savefig(path) 89 | 90 | return roc_auc#, fpr, tpr 91 | -------------------------------------------------------------------------------- /utils/new_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import random 4 | from PIL import Image, ImageOps, ImageEnhance 5 | import collections 6 | import types 7 | 8 | """ 9 | Taken directly from https://github.com/pytorch/vision/blob/master/torchvision/transforms.py 10 | Latest update that is not currently deployed to pip. 11 | 12 | All credits to the torchvision developers. 13 | """ 14 | 15 | accimage = None 16 | 17 | def _is_pil_image(img): 18 | if accimage is not None: 19 | return isinstance(img, (Image.Image, accimage.Image)) 20 | else: 21 | return isinstance(img, Image.Image) 22 | 23 | def crop(img, i, j, h, w): 24 | """Crop the given PIL.Image. 25 | Args: 26 | img (PIL.Image): Image to be cropped. 27 | i: Upper pixel coordinate. 28 | j: Left pixel coordinate. 29 | h: Height of the cropped image. 30 | w: Width of the cropped image. 31 | Returns: 32 | PIL.Image: Cropped image. 33 | """ 34 | if not _is_pil_image(img): 35 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 36 | 37 | return img.crop((j, i, j + w, i + h)) 38 | 39 | def resize(img, size, interpolation=Image.BILINEAR): 40 | """Resize the input PIL.Image to the given size. 41 | Args: 42 | img (PIL.Image): Image to be resized. 43 | size (sequence or int): Desired output size. If size is a sequence like 44 | (h, w), the output size will be matched to this. If size is an int, 45 | the smaller edge of the image will be matched to this number maintaing 46 | the aspect ratio. i.e, if height > width, then image will be rescaled to 47 | (size * height / width, size) 48 | interpolation (int, optional): Desired interpolation. Default is 49 | ``PIL.Image.BILINEAR`` 50 | Returns: 51 | PIL.Image: Resized image. 52 | """ 53 | if not _is_pil_image(img): 54 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 55 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): 56 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 57 | 58 | if isinstance(size, int): 59 | w, h = img.size 60 | if (w <= h and w == size) or (h <= w and h == size): 61 | return img 62 | if w < h: 63 | ow = size 64 | oh = int(size * h / w) 65 | return img.resize((ow, oh), interpolation) 66 | else: 67 | oh = size 68 | ow = int(size * w / h) 69 | return img.resize((ow, oh), interpolation) 70 | else: 71 | return img.resize(size[::-1], interpolation) 72 | 73 | def vflip(img): 74 | """Vertically flip the given PIL.Image. 75 | Args: 76 | img (PIL.Image): Image to be flipped. 77 | Returns: 78 | PIL.Image: Vertically flipped image. 79 | """ 80 | if not _is_pil_image(img): 81 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 82 | 83 | return img.transpose(Image.FLIP_TOP_BOTTOM) 84 | 85 | class Compose(object): 86 | """Composes several transforms together. 87 | Args: 88 | transforms (list of ``Transform`` objects): list of transforms to compose. 89 | Example: 90 | >>> transforms.Compose([ 91 | >>> transforms.CenterCrop(10), 92 | >>> transforms.ToTensor(), 93 | >>> ]) 94 | """ 95 | 96 | def __init__(self, transforms): 97 | self.transforms = transforms 98 | 99 | def __call__(self, img): 100 | for t in self.transforms: 101 | img = t(img) 102 | return img 103 | 104 | class Resize(object): 105 | """Resize the input PIL Image to the given size. 106 | Args: 107 | size (sequence or int): Desired output size. If size is a sequence like 108 | (h, w), output size will be matched to this. If size is an int, 109 | smaller edge of the image will be matched to this number. 110 | i.e, if height > width, then image will be rescaled to 111 | (size * height / width, size) 112 | interpolation (int, optional): Desired interpolation. Default is 113 | ``PIL.Image.BILINEAR`` 114 | """ 115 | 116 | def __init__(self, size, interpolation=Image.BILINEAR): 117 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 118 | self.size = size 119 | self.interpolation = interpolation 120 | 121 | def __call__(self, img): 122 | """ 123 | Args: 124 | img (PIL Image): Image to be scaled. 125 | Returns: 126 | PIL Image: Rescaled image. 127 | """ 128 | return resize(img, self.size, self.interpolation) 129 | 130 | class RandomVerticalFlip(object): 131 | """Vertically flip the given PIL.Image randomly with a probability of 0.5.""" 132 | 133 | def __call__(self, img): 134 | """ 135 | Args: 136 | img (PIL.Image): Image to be flipped. 137 | Returns: 138 | PIL.Image: Randomly flipped image. 139 | """ 140 | if random.random() < 0.5: 141 | return vflip(img) 142 | return img 143 | 144 | def adjust_brightness(img, brightness_factor): 145 | """Adjust brightness of an Image. 146 | Args: 147 | img (PIL.Image): PIL Image to be adjusted. 148 | brightness_factor (float): How much to adjust the brightness. Can be 149 | any non negative number. 0 gives a black image, 1 gives the 150 | original image while 2 increases the brightness by a factor of 2. 151 | Returns: 152 | PIL.Image: Brightness adjusted image. 153 | """ 154 | if not _is_pil_image(img): 155 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 156 | 157 | enhancer = ImageEnhance.Brightness(img) 158 | img = enhancer.enhance(brightness_factor) 159 | return img 160 | 161 | 162 | def adjust_contrast(img, contrast_factor): 163 | """Adjust contrast of an Image. 164 | Args: 165 | img (PIL.Image): PIL Image to be adjusted. 166 | contrast_factor (float): How much to adjust the contrast. Can be any 167 | non negative number. 0 gives a solid gray image, 1 gives the 168 | original image while 2 increases the contrast by a factor of 2. 169 | Returns: 170 | PIL.Image: Contrast adjusted image. 171 | """ 172 | if not _is_pil_image(img): 173 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 174 | 175 | enhancer = ImageEnhance.Contrast(img) 176 | img = enhancer.enhance(contrast_factor) 177 | return img 178 | 179 | 180 | def adjust_saturation(img, saturation_factor): 181 | """Adjust color saturation of an image. 182 | Args: 183 | img (PIL.Image): PIL Image to be adjusted. 184 | saturation_factor (float): How much to adjust the saturation. 0 will 185 | give a black and white image, 1 will give the original image while 186 | 2 will enhance the saturation by a factor of 2. 187 | Returns: 188 | PIL.Image: Saturation adjusted image. 189 | """ 190 | if not _is_pil_image(img): 191 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 192 | 193 | enhancer = ImageEnhance.Color(img) 194 | img = enhancer.enhance(saturation_factor) 195 | return img 196 | 197 | 198 | def adjust_hue(img, hue_factor): 199 | """Adjust hue of an image. 200 | The image hue is adjusted by converting the image to HSV and 201 | cyclically shifting the intensities in the hue channel (H). 202 | The image is then converted back to original image mode. 203 | `hue_factor` is the amount of shift in H channel and must be in the 204 | interval `[-0.5, 0.5]`. 205 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 206 | Args: 207 | img (PIL.Image): PIL Image to be adjusted. 208 | hue_factor (float): How much to shift the hue channel. Should be in 209 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 210 | HSV space in positive and negative direction respectively. 211 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 212 | with complementary colors while 0 gives the original image. 213 | Returns: 214 | PIL.Image: Hue adjusted image. 215 | """ 216 | if not(-0.5 <= hue_factor <= 0.5): 217 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 218 | 219 | if not _is_pil_image(img): 220 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 221 | 222 | input_mode = img.mode 223 | if input_mode in {'L', '1', 'I', 'F'}: 224 | return img 225 | 226 | h, s, v = img.convert('HSV').split() 227 | 228 | np_h = np.array(h, dtype=np.uint8) 229 | # uint8 addition take cares of rotation across boundaries 230 | with np.errstate(over='ignore'): 231 | np_h += np.uint8(hue_factor * 255) 232 | h = Image.fromarray(np_h, 'L') 233 | 234 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 235 | return img 236 | 237 | 238 | def adjust_gamma(img, gamma, gain=1): 239 | """Perform gamma correction on an image. 240 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 241 | based on the following equation: 242 | I_out = 255 * gain * ((I_in / 255) ** gamma) 243 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 244 | Args: 245 | img (PIL.Image): PIL Image to be adjusted. 246 | gamma (float): Non negative real number. gamma larger than 1 make the 247 | shadows darker, while gamma smaller than 1 make dark regions 248 | lighter. 249 | gain (float): The constant multiplier. 250 | """ 251 | if not _is_pil_image(img): 252 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 253 | 254 | if gamma < 0: 255 | raise ValueError('Gamma should be a non-negative real number') 256 | 257 | input_mode = img.mode 258 | img = img.convert('RGB') 259 | 260 | np_img = np.array(img, dtype=np.float32) 261 | np_img = 255 * gain * ((np_img / 255) ** gamma) 262 | np_img = np.uint8(np.clip(np_img, 0, 255)) 263 | 264 | img = Image.fromarray(np_img, 'RGB').convert(input_mode) 265 | return img 266 | 267 | def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): 268 | """Crop the given PIL.Image and resize it to desired size. 269 | Notably used in RandomResizedCrop. 270 | Args: 271 | img (PIL.Image): Image to be cropped. 272 | i: Upper pixel coordinate. 273 | j: Left pixel coordinate. 274 | h: Height of the cropped image. 275 | w: Width of the cropped image. 276 | size (sequence or int): Desired output size. Same semantics as ``scale``. 277 | interpolation (int, optional): Desired interpolation. Default is 278 | ``PIL.Image.BILINEAR``. 279 | Returns: 280 | PIL.Image: Cropped image. 281 | """ 282 | assert _is_pil_image(img), 'img should be PIL Image' 283 | img = crop(img, i, j, h, w) 284 | img = resize(img, size, interpolation) 285 | return img 286 | 287 | class RandomResizedCrop(object): 288 | """Crop the given PIL.Image to random size and aspect ratio. 289 | A crop of random size of (0.08 to 1.0) of the original size and a random 290 | aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop 291 | is finally resized to given size. 292 | This is popularly used to train the Inception networks. 293 | Args: 294 | size: expected output size of each edge 295 | interpolation: Default: PIL.Image.BILINEAR 296 | """ 297 | 298 | def __init__(self, size, interpolation=Image.BILINEAR): 299 | self.size = (size, size) 300 | self.interpolation = interpolation 301 | 302 | @staticmethod 303 | def get_params(img): 304 | """Get parameters for ``crop`` for a random sized crop. 305 | Args: 306 | img (PIL.Image): Image to be cropped. 307 | Returns: 308 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 309 | sized crop. 310 | """ 311 | for attempt in range(10): 312 | area = img.size[0] * img.size[1] 313 | target_area = random.uniform(0.08, 1.0) * area 314 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 315 | 316 | w = int(round(math.sqrt(target_area * aspect_ratio))) 317 | h = int(round(math.sqrt(target_area / aspect_ratio))) 318 | 319 | if random.random() < 0.5: 320 | w, h = h, w 321 | 322 | if w <= img.size[0] and h <= img.size[1]: 323 | i = random.randint(0, img.size[1] - h) 324 | j = random.randint(0, img.size[0] - w) 325 | return i, j, h, w 326 | 327 | # Fallback 328 | w = min(img.size[0], img.size[1]) 329 | i = (img.size[1] - w) // 2 330 | j = (img.size[0] - w) // 2 331 | return i, j, w, w 332 | 333 | def __call__(self, img): 334 | """ 335 | Args: 336 | img (PIL.Image): Image to be flipped. 337 | Returns: 338 | PIL.Image: Randomly cropped and resize image. 339 | """ 340 | i, j, h, w = self.get_params(img) 341 | return resized_crop(img, i, j, h, w, self.size, self.interpolation) 342 | 343 | class Lambda(object): 344 | """Apply a user-defined lambda as a transform. 345 | Args: 346 | lambd (function): Lambda/function to be used for transform. 347 | """ 348 | 349 | def __init__(self, lambd): 350 | assert isinstance(lambd, types.LambdaType) 351 | self.lambd = lambd 352 | 353 | def __call__(self, img): 354 | return self.lambd(img) 355 | 356 | class ColorJitter(object): 357 | """Randomly change the brightness, contrast and saturation of an image. 358 | Args: 359 | brightness (float): How much to jitter brightness. brightness_factor 360 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 361 | contrast (float): How much to jitter contrast. contrast_factor 362 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 363 | saturation (float): How much to jitter saturation. saturation_factor 364 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 365 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 366 | [-hue, hue]. Should be >=0 and <= 0.5. 367 | """ 368 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 369 | self.brightness = brightness 370 | self.contrast = contrast 371 | self.saturation = saturation 372 | self.hue = hue 373 | 374 | @staticmethod 375 | def get_params(brightness, contrast, saturation, hue): 376 | """Get a randomized transform to be applied on image. 377 | Arguments are same as that of __init__. 378 | Returns: 379 | Transform which randomly adjusts brightness, contrast and 380 | saturation in a random order. 381 | """ 382 | transforms = [] 383 | if brightness > 0: 384 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 385 | transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor))) 386 | 387 | if contrast > 0: 388 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 389 | transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor))) 390 | 391 | if saturation > 0: 392 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 393 | transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor))) 394 | 395 | if hue > 0: 396 | hue_factor = np.random.uniform(-hue, hue) 397 | transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor))) 398 | 399 | np.random.shuffle(transforms) 400 | transform = Compose(transforms) 401 | 402 | return transform 403 | 404 | def __call__(self, img): 405 | """ 406 | Args: 407 | img (PIL.Image): Input image. 408 | Returns: 409 | PIL.Image: Color jittered image. 410 | """ 411 | transform = self.get_params(self.brightness, self.contrast, 412 | self.saturation, self.hue) 413 | return transform(img) 414 | 415 | def adjust_rotation(img, degree=90): 416 | """Roatete the given PIL.Image. 417 | Args: 418 | img (PIL.Image): Image to be flipped. 419 | degree: Angle to rotate: 0 to 360 420 | Returns: 421 | PIL.Image: Rotated image. 422 | """ 423 | if not _is_pil_image(img): 424 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 425 | 426 | if degree<0: 427 | raise ValueError('Negative rotation - Select degree between 0 and 360') 428 | 429 | if degree>360: 430 | raise ValueError('Negative rotation - Select degree between 0 and 360') 431 | 432 | return img.rotate(degree) 433 | 434 | class ColorJitterRotate(object): 435 | """Randomly change the brightness, contrast and saturation of an image. 436 | Args: 437 | brightness (float): How much to jitter brightness. brightness_factor 438 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 439 | contrast (float): How much to jitter contrast. contrast_factor 440 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 441 | saturation (float): How much to jitter saturation. saturation_factor 442 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 443 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 444 | [-hue, hue]. Should be >=0 and <= 0.5. 445 | rotation: Rotate image randomly 0 to the defined parameter, fixed between (0, 90, 180, 270) 446 | """ 447 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, rotation=0): 448 | self.brightness = brightness 449 | self.contrast = contrast 450 | self.saturation = saturation 451 | self.hue = hue 452 | self.rotation = rotation 453 | 454 | @staticmethod 455 | def get_params(brightness, contrast, saturation, hue, rotation): 456 | """Get a randomized transform to be applied on image. 457 | Arguments are same as that of __init__. 458 | Returns: 459 | Transform which randomly adjusts brightness, contrast and 460 | saturation in a random order. 461 | """ 462 | transforms = [] 463 | if brightness > 0: 464 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 465 | transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor))) 466 | 467 | if contrast > 0: 468 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 469 | transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor))) 470 | 471 | if saturation > 0: 472 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 473 | transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor))) 474 | 475 | if hue > 0: 476 | hue_factor = np.random.uniform(-hue, hue) 477 | transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor))) 478 | 479 | if rotation > 0: 480 | rotation_factor = np.random.uniform(0, rotation) 481 | rotation_factor = min([0,90,180,270,360], key=lambda x:abs(x-rotation_factor)) 482 | transforms.append(Lambda(lambda img: adjust_rotation(img, rotation_factor))) 483 | 484 | np.random.shuffle(transforms) 485 | transform = Compose(transforms) 486 | 487 | return transform 488 | 489 | def __call__(self, img): 490 | """ 491 | Args: 492 | img (PIL.Image): Input image. 493 | Returns: 494 | PIL.Image: Color jittered image. 495 | """ 496 | transform = self.get_params(self.brightness, self.contrast, 497 | self.saturation, self.hue, self.rotation) 498 | return transform(img) 499 | 500 | class RandomRotate(object): 501 | """Randomly change the Rotation of an image (0/90/180/270). 502 | 503 | Args: 504 | rotation: Rotate image randomly to the defined parameter, fixed between (0, 90, 180, 270) 505 | 506 | """ 507 | def __init__(self, rotation=0): 508 | self.rotation = rotation 509 | 510 | @staticmethod 511 | def get_params(rotation): 512 | 513 | transforms = [] 514 | rotation_factor = np.random.randint(0, 4, 1) * 90 515 | transforms.append(Lambda(lambda img: adjust_rotation(img, rotation_factor))) 516 | 517 | np.random.shuffle(transforms) 518 | transform = Compose(transforms) 519 | 520 | return transform 521 | 522 | def __call__(self, img): 523 | """ 524 | Args: 525 | img (PIL.Image): Input image. 526 | Returns: 527 | PIL.Image: Randomly rotated image. 528 | """ 529 | transform = self.get_params(self.rotation) 530 | 531 | return transform(img) 532 | --------------------------------------------------------------------------------