├── FROC_CPM.ipynb ├── GenerateCSV.py ├── LICENSE ├── README.md ├── adable.py ├── config_training.py ├── data_detector.py ├── json └── Read.md ├── layers.py ├── loss.py ├── main_detector_recon.py ├── net └── OSAF_YOLOv3.py ├── noduleCADEvaluationLUNA16.py ├── prepare.py └── split_combine.py /FROC_CPM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "FROC.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "code", 22 | "source": [ 23 | "%cd /content/drive/MyDrive/DL" 24 | ], 25 | "metadata": { 26 | "colab": { 27 | "base_uri": "https://localhost:8080/" 28 | }, 29 | "id": "-j3cGnQPVTGs", 30 | "outputId": "8680ec9f-9ad0-42db-d298-18c6b7b62bbb" 31 | }, 32 | "execution_count": 1, 33 | "outputs": [ 34 | { 35 | "output_type": "stream", 36 | "name": "stdout", 37 | "text": [ 38 | "/content/drive/MyDrive/DL\n" 39 | ] 40 | } 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "source": [ 46 | "import numpy as np\n", 47 | "from sklearn import metrics\n", 48 | "import matplotlib.pyplot as plt\n", 49 | "%matplotlib inline\n", 50 | "\n", 51 | "res_fps = np.load('OSAF_YOLOv3_fps.npy')\n", 52 | "res_sens = np.load('OSAF_YOLOv3_sens.npy')\n", 53 | "\n", 54 | "\n", 55 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n", 56 | "plt.plot(res_fps, res_sens, color='red', label='OSAF_YOLOv3, CPM = 0.905', linestyle=':')\n", 57 | "\n", 58 | "plt.legend(loc='lower right', fontsize=16)\n", 59 | "plt.xlabel('FPs per scan')\n", 60 | "plt.ylabel('Sensitivity')\n", 61 | "plt.title('FROC Curve')\n", 62 | "plt.savefig('FROC_OSAF_YOLOv3.png', dpi=300)" 63 | ], 64 | "metadata": { 65 | "colab": { 66 | "base_uri": "https://localhost:8080/", 
67 | "height": 347 68 | }, 69 | "id": "blKsmMRzVZpr", 70 | "outputId": "32890c80-db92-427f-e9c8-d4d5a516071e" 71 | }, 72 | "execution_count": 3, 73 | "outputs": [ 74 | { 75 | "output_type": "stream", 76 | "name": "stderr", 77 | "text": [ 78 | "findfont: Font family ['Times New Roman'] not found. Falling back to DejaVu Sans.\n", 79 | "findfont: Font family ['Times New Roman'] not found. Falling back to DejaVu Sans.\n", 80 | "findfont: Font family ['Times New Roman'] not found. Falling back to DejaVu Sans.\n" 81 | ] 82 | }, 83 | { 84 | "output_type": "display_data", 85 | "data": { 86 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEWCAYAAABxMXBSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXgUVfbw8e8hEMK+JYJD2ISgIv5UiCiDEEZFwQVBZ1QWQeVVZ9zFDVBGQMENlXEBBJUlIsgIKiObbGEGBSXIooAgCLIqYREEIhI47x9VnXSSJukO6VQnOZ/nqSd9b22nEqjTdW/VLVFVjDHGmGCV8ToAY4wxxYslDmOMMSGxxGGMMSYkljiMMcaExBKHMcaYkFjiMMYYExJLHMYYY0JiicOUOiKyVUTSReSw3/QnEWkoIupXt1VE+gVY/3YR+VZEjorIzyIySkSq51imqYj8W0T2ishBEVkjIn1FJOoUMVUVkREiss3d92a3HBuu34MxBWWJw5RW16tqZb9pl9+86qpaGfgrMFBEOvhmiMijwIvA40A14FKgATBPRKLdZRoDXwHbgfNVtRrwNyARqJIzEHe9BcB5QEegKtAa2Ae0CvXARKRsqOsYEwpLHMacgqqmAmuBC8G5KgAGAw+o6hxVPa6qW4GbgYZAT3fVwcCXqtpXVXe729qgqt1V9dcAu+oF1Ae6quo6VT2pqntU9VlVneXuW0WkiW8FERkvIs+5n9uLyA4ReVJEfgbGich6EbnOb/myIpImIi3c8qUi8qWI/Coiq0WkfWH93kzJZ4nDmFMQkUuB5sAmt+rPQAww3X85VT0MzAJ8VyZXAh+FsKsrgTnudgqqDlAT5+rnbmAy0M1v/tXAXlX9RkTqAjOB59x1HgOmiUjcaezflCKWOExp9Yn7bftXEfkkx7y9IpIOLAVGAr75sTgn34wA29vtzgeo5ZaDFerygZwEnlHVY6qaDnwAdBaRiu787jjJBJwro1mqOsu9upkHpALXnGYMppSwxGFKqy6qWt2duuSYFwtUBh4F2gPl3Pq9QOwp+hDOdOeD0zdxZgixhLp8IGmq+ruvoKqbgPXA9W7y6IyTTMC5KvmbX+L8FbisEGIwpYQlDmMCUNUTqvoq8Dtwr1u9FDgG3Oi/rIhUBjrhdHADzAduCmF384GrRaRSHsscBSr6levkDDnAOr7mqhuAdW4yAafTPtkvcVZX1Uqq+kIIMZtSzBKHMXl7AXhCRGJU9SBOx/cbItJRRMqJSENgKrADSHbXeQb4s4i8LCJ1AESkiYi8n/O2XVcyzsl8moicIyJlRKSWiAwQEV/z0Sqgu4hEiUhHICmI2KcAVwH/IOtqA+B9nCuRq93txbgd7PHB/1pMaWa
Jw5i8zQQOAHcBqOpLwABgOHCIrNtur1DVY+4ym3Fup20IrBWRg8A0nH6E33LuwF3vSuB7YJ673a9xmsy+chd7CLge+BXoQVa/yym5d3QtxenU/9CvfjvOVcgAIM2N/3HsfGCCJPYiJ2OMMaGwbxjGGGNCYonDGGNMSCxxGGOMCYklDmOMMSEpFYOhxcbGasOGDb0OwxhjipUVK1bsVdVcQ9GUisTRsGFDUlNTvQ7DGGOKFRH5KVC9NVUZY4wJiSUOY4wxIbHEYYwxJiSWOIwxxoTEEocxxpiQWOIwxhgTEkscxhhjQmKJw5hIk5EBe/fCH3845d9+g5Ur4cgRp/zLLzBjBvz6q1PevBlGjoT9+53yd9/BkCHONgBWrXLKvuVXrHDKv7kjvH/9tVM+etQpL13qlI8dc8pLljjlDPeNuSkpMHhwVrwLFsCzz2aVP/8cnn8+qzx7Nrz0Ulb5s8/glVeyyp98AiNGZJWnTYM33sgqT50Ko0ZllT/4AN5+O6ucnAzvvJNVHj/emXzefddZxuftt51t+IwcCR9+mFV+4w34yO+V8SNGODH6vPIK/Oc/WeUXX4SZM7PKQ4fCnDlZ5SFDYN68rPIzz8CiRc7njAyn/N//OuVjx5zyl1865aNHnfJX7uj6v/3mlH3PpR044JRXrnTKe/fCP/8Ja9YQVqpa4qeWLVuqMaeUkaG6Zo3qjh1OOT1d9ZVXVL/6yikfPqz60EOqixY55YMHVe+6K6uclqbas6dqSopT3rVL9aabVBcvdsqbN6u2a6e6cKFT/vZb1caNVefPd8pLl6qWL686d65TXrhQFbK2P3euU/7iC6c8Y4ZTXr7cKU+b5pRXrXLKH3zglL//3imPG+eUt2xxym+/7ZR37nTKr7/ulPfudcrDhzvlQ4ec8rBhTvn3353yM884ZZ/+/VXLlcsq9+2rWrlyVvn++1Vr1swq3323ap06WeXevVUbNMgqd+ummpCQVb7xRtXmzbPK112n2qJFVvmqq1QvvTSrnJTkTD6tW6t26JBVbtHC2YZP8+bOPnyaNlW99dascoMGTow+Z57p/P19atZ0jtGncmXnd+ATHa3ar19WGZzfoarqsWNOeehQp3zokFMePtwp793rlF9/3Snv2uWUR492ylu2OOVx45zyhg2qIqqTJmlhAFI1wDnV85N6UUyWOIyePKl65Ijz+cQJ5+T03ntO2fefd/Bgp/z77075+eed8sGDqlWrqo4c6ZTT0pyTx/jxTnnXLtWzzlL98EOn/NNPqs2aOSd4VdWtW50T2YIFWeVu3bJO/Nu2qT7xhOr69U55xw7nRLFtm1P+5RfVjz/OOrEfOKCamuokNFXVo0dVf/5Z9fjxrGM9ccL5Ga6y77NXZVMkTpU4SsWLnBITE9WGHCllDhxwLtsTEpxyQgIkJWU1abRpAzfeCI8+6pQ//hjOOw+aNnXKBw9CpUpQtlSMymNMQCKyQlUTc9bb/wpTPKk6fQDlyzvliRPh55/hiSeccufOIJLVdnz//VC/ftb6X3yRfXtdu2YvV6sWnriNKQGsc9wUjbVr4f33nRM+OB2g3bplzf/XvyDR74vNsGFw4YVZ5aefhvPPzyr/v/8HzZpllefPh3//O6s8YIDTKenz0EO5k4MxpkDsisOEx+bNzp0tAwdCdLRzZ83jj8MNN0CVKrB7N6xe7Vw1REdDrVpw1llZ69evDy1bZpXPPRcOH84qd+3q1PmMGwdRUVnlTp3CdmjGlHbWx2EKh6pzi2CjRhAb69xy2aWL0yR0ySXOraL79zvz/U/wxpiIdao+DmuqMqfH98Vj0yZo1Srr/virrnKeN7jkEqdcsyY0aWJJw5gSwBKHKRhV+NvfoG9fp5yQ4PQx9OzplH3NT8aYEsf6OExojhxxblMVgbp1oU6drHl//at3cRljikxYrzhEpKOIbBCRTSLSL8D8BiKyQETWiEiKiMT7zTshIqvcaYZffSMR+crd5ociEh3OYzB+Zs9
2Oq03bHDKI0bAk096G5MxpsiFLXGISBTwFtAJaAZ0E5FmORYbDkxU1f8DhgB+A9yQrqoXulNnv/oXgddUtQlwAOgTrmMwLt8YRS1awNVXQ4UK3sZjjPFUOK84WgGbVPVHVf0DmALckGOZZsBC9/OiAPOzEREBLgd8I5BNALoUWsQmtwcfdPoyAGrXdjq//R+kM8aUOuFMHHWB7X7lHW6dv9XAje7nrkAVEfH1qMaISKqILBMRX3KoBfyqqhl5bBMAEbnbXT81LS3tdI+ldDl5Mutzo0bOMBwnTngXjzEmonh9V9VjQJKIrASSgJ2A7wzVwL1/uDswQkQah7JhVR2jqomqmhgXF1eoQZdoO3c6t9AudC8EH3nEGTbabqM1xrjCmTh2AvX8yvFuXSZV3aWqN6rqRcBTbt2v7s+d7s8fgRTgImAfUF1Eyp5qm6aAfv/d+RkbC2XKZL2LwRhjcghn4lgOJLh3QUUDtwIz/BcQkVgR8cXQH3jPra8hIuV9ywBtgHXuML+LAN99n72BT8N4DKVD//7QurXTHFW+PCxbZkN2GGNOKWyJw+2HuB+YC6wHpqrqWhEZIiK+u6TaAxtEZCNQGxjq1p8LpIrIapxE8YKqrnPnPQn0FZFNOH0e74brGEq0Q4ey7pZq1cp50tv3xjkR7+IyxkQ8G6uqNPrpJ7j4YucVl3fd5XU0xpgIZWNVlXaqsHWr87l+fejRw3kuwxhjQmSJo7R4+GHnbqnffnOaol57Lfuw5cYYEyQbq6ok27PHefdFhQrQq5fzIiR76tsYc5rsiqOkOnTIeYf200875ZYtnbfm2Tu0jTGnyc4iJVXVqs77t6+91utIjDEljF1xlCS7dsF118E6987lxx/P/l5uY4wpBJY4SpI//nCGPN+0yetIjDElmDVVlQQ7dzovVWrYEL7/3saVMsaElV1xFHc//uiMXvvmm07ZkoYxJswscRR3DRvCP/5hr201xhQZa6oqrn74AeLioHp1GD7c62iMMaWIXXEUR8ePO6PX3nyz15EYY0ohu+IojsqVg9GjoVat/Jc1xphCZomjuNm0CZo0gSuv9DoSY0wpZU1Vxcns2c4DfUuXeh2JMaYUs8RRnLRo4TwNbsOhG2M8ZE1VxcUff0Dt2s7Ll4wxxkN2xVEczJmT/dWuxhjjIUscxcHvv8PBg3DypNeRGGNMeBOHiHQUkQ0isklE+gWY30BEFojIGhFJEZF4t/5CEVkqImvdebf4rTNeRLaIyCp3ujCcxxARunSBZcsgJsbrSIwxJnyJQ0SigLeATkAzoJuI5BzjezgwUVX/DxgCPO/WHwV6qep5QEdghIhU91vvcVW90J1WhesYPPfVV86dVADly3sbizHGuMJ5xdEK2KSqP6rqH8AU4IYcyzQDFrqfF/nmq+pGVf3B/bwL2APEhTHWyPTSS/D3vztNVcYYEyHCmTjqAtv9yjvcOn+rgRvdz12BKiKS7XFoEWkFRAOb/aqHuk1Yr4lIwK/iInK3iKSKSGpaWtrpHId3Jk+Gzz+3JipjTETxunP8MSBJRFYCScBO4IRvpoicCSQDd6iqr2e4P3AOcDFQE3gy0IZVdYyqJqpqYlxcMbtYSU+HEycgOhrOPtvraIwxJptwJo6dQD2/crxbl0lVd6nqjap6EfCUW/crgIhUBWYCT6nqMr91dqvjGDAOp0msZBkwAM4/H44e9ToSY4zJJZwPAC4HEkSkEU7CuBXo7r+AiMQC+92rif7Ae259NPAxTsf5RznWOVNVd4uIAF2A78J4DN644grnmY2KFb2OxBhjcglb4lDVDBG5H5gLRAHvqepaERkCpKrqDKA98LyIKPBf4D539ZuBdkAtEbndrbvdvYNqkojEAQKsAv4ermPwzHXXOZMxxkQgUVWvYwi7xMRETU1N9TqM/Kk6HeI33ACVKnkdjTGmlBORFaqamLPe685x42/7dmcsqokTvY7EGGNOyRJHJKlfH269FW68Mf9ljTHGIzY
6bqQZONDrCIwxJk92xREpvvwS7rsP9uzxOhJjjMmTJY5IsXo1TJ9uneLGmIhniSNS/OMfTue4JQ5jTISzxBEJjh93fpa1LidjTOSzxBEJevSwO6mMMcWGfcWNBK1a2dWGMabYsLNVJHjsMa8jMMaYoFlTldc++SSrj8MYY4oBSxxeWrsWunZ13vRnjDHFhDVVealpU5g50+njMMaYYsISh5fKlYNrrvE6CmOMCYk1VXnlgw/g+eedV8QaY0wxYonDK//9L3z8MURFeR2JMcaExJqqvDJ6NBw+7HUUxhgTMrvi8MKRI87PypW9jcMYYwrAEkdR+/57qF0bZs3yOhJjjCmQsCYOEekoIhtEZJOI9Aswv4GILBCRNSKSIiLxfvN6i8gP7tTbr76liHzrbvN1EZFwHkOhi46G7t2hZUuvIzHGmAIJW+IQkSjgLaAT0AzoJiLNciw2HJioqv8HDAGed9etCTwDXAK0Ap4RkRruOqOAu4AEd+oYrmMIi7POgjFjnKsOY4wphsJ5xdEK2KSqP6rqH8AU4IYcyzQDFrqfF/nNvxqYp6r7VfUAMA/oKCJnAlVVdZmqKjAR6BLGYyhc334Lmzd7HYUxxpyWcCaOusB2v/IOt87fasA3nnhXoIqI1Mpj3bru57y2CYCI3C0iqSKSmpaWVuCDKFRPPgl/+QucPOl1JMYYU2Bed44/BiSJyEogCdgJFMoTcao6RlUTVTUxLi6uMDZ5+saOheRkKOP1r90YYwounM9x7ATq+ZXj3bpMqroL94pDRCoDN6nqryKyE2ifY90Ud/34HPXZthnR6tZ1JmOMKcbC+dV3OZAgIo1EJBq4FZjhv4CIxIqIL4b+wHvu57nAVSJSw+0UvwqYq6q7gUMicql7N1Uv4NMwHkPhOHkSHngAvvnG60iMMea0hS1xqGoGcD9OElgPTFXVtSIyREQ6u4u1BzaIyEagNjDUXXc/8CxO8lkODHHrAO4F3gE2AZuB2eE6hkLzww/w/vuwaZPXkRhjzGkT5+akki0xMVFTU1O9DSI93Xk9bLly3sZhjDFBEpEVqpqYs97Gqgo3VRCBChW8jsQYYwpFUE1VIjJdRK71648wwXr5ZWjTxgY0NMaUGMEmgpFAd+AHEXlBRM4OY0wlS506cM45NqChMabECCpxqOp8Ve0BtAC2AvNF5EsRuUNErNE+L716wbvveh2FMcYUmqCbntwnum8H/h+wEvgXTiKZF5bIiruDB2HKFKePwxhjSpBg+zg+Bv4HVASuV9XOqvqhqj4AWBtMIKNGQY8esH6915EYY0yhCvauqrGqmu0FEiJSXlWPBbpVywBPPAHt2kGznAMCG2NM8RZsU9VzAeqWFmYgJcaxY7BvnzMe1Z//7HU0xhhT6PJMHCJSR0RaAhVE5CIRaeFO7XGarUxOY8Y4L2k6eNDrSIwxJizya6q6GqdDPB541a/+N2BAmGIq3pKSYOVKqFrV60iMMSYsghpyRERuUtVpRRBPWETEkCPGGFPMFGjIERHpqarvAw1FpG/O+ar6aoDVSqeMDHjrLbjjDrvaMMaUaPl1jldyf1YGqgSYjE9qKjz8MCxe7HUkxhgTVnlecajq2+7HkaoaIe9fjVCXXuokjbZtvY7EGGPCKtjbcb8Qkc9FpI/7YiXjz/cO8XbtnJFwjTGmBAt2rKqmwNPAecAKEflMRHqGNbLiIiMD/u//nCfFjTGmFAh6rCpV/VpV+wKtgP3AhLBFVZwcOAAXXwz163sdiTHGFImghhwRkapAV5z3hjcGPsZJICYuDsaN8zoKY4wpMsGOVbUa+ATn3d821IjPiROwcaPzvg3r2zDGlBLBNlWdpaqPhJo0RKSjiGwQkU0i0i/A/PoiskhEVorIGhG5xq3vISKr/KaTInKhOy/F3aZv3hmhxFSoli93BjH84APPQjDGmKKW3wOAI1T1YWCGiOR6xFxVO+exbhTwFtAB2AEsF5EZqrrOb7GngamqOkpEmgGzgIa
qOgmY5G7nfOATVV3lt14PVfX+UfCzzoJXX4XLL/c6EmOMKTL5NVUluz+HF2DbrYBNqvojgIhMAW4A/BOHAr7HrKsBuwJspxswpQD7D78zzoBHHvE6CmOMKVJ5NlWp6gr344Wquth/Ai7MZ9t1ge1+5R1unb9BQE8R2YFztfFAgO3cAkzOUTfObaYaKBK4c0FE7haRVBFJTUsL07OL778P27fnv5wxxpQgwfZx9A5Qd3sh7L8bMF5V44FrgGQRyYxJRC4Bjqrqd37r9FDV84G27nRboA2r6hhVTVTVxLi4uEIINYeMDOjdGwbYIMHGmNIlvz6ObkB3oJGIzPCbVQXnWY687ATq+ZXj3Tp/fYCOAKq6VERigFhgjzv/VnJcbajqTvfnbyLyAU6T2MR8Yil8Zco4nePVqxf5ro0xxkv59XF8CezGOZm/4lf/G7Amn3WXAwki0ggnYdyKk4T8bQOuAMaLyLlADJAG4F553IxzVYFbVxaorqp7RaQccB0wP584wqNMGWjRwpNdG2OMl/Ib5PAn4CegdagbVtUMEbkfmAtEAe+p6loRGQKkquoM4FFgrIg8gtNRfrtmvSCkHbDd17nuKg/MdZNGFE7SGBtqbIVi9WrYuhWuvRbKBvs4jDHGFH95vshJRJao6mUi8hvOiT1zFqCqWixePBGWFzl17w7TpkF6unP1YYwxJUyBXuSkqpe5P+3dGzmNHg133mlJwxhT6gR11hORxiJS3v3cXkQeFJHS3StctSpceaXXURhjTJEL9uvyNOCEiDQBxuDcLVW6x9kYNQq+/trrKIwxpsgFmzhOqmoGzgi5b6jq48CZ4QsrwmVkwH33wcyZXkdijDFFLtjbgY67z3T0Bq5368qFJ6RiICoK9uf3GIsxxpRMwV5x3IFzS+5QVd3iPpuRnM86JZeI8+CfPfxnjCmFgn117DpVfVBVJ7vlLar6YnhDi2ArVsDQofDrr15HYowxRS7Yu6raiMg8EdkoIj+KyBYR+TH/NUuoZcvg6aedFzkZY0wpk+cDgJkLiXwPPAKsADLPlqq6L3yhFZ6wPAD4++9Qvry9+c8YU2IV6AFAPwdVdXYhx1S8xcR4HYExxngi2MSxSEReBqYDx3yVqvpNWKKKdC++CLGx0KeP15EYY0yRCzZxXOL+9L9kUaB0vjP1s8+gfn1LHMaYUimoxKGqfwl3IMXK//4HQfQNGWNMSRTsXVW1ReRdEZntlpuJSOn+um2d4saYUirYBwDH47xX409ueSPwcDgCing7dkDXrrBypdeRGGOMJ4JNHLGqOhU4Cc5LmvC7LbdU2boVliyB3bu9jsQYYzwRbOf4ERGphfsyJxG5FDgYtqgi2WWXwbZtdjuuMabUCjZx9AVmAI1F5AsgDvhr2KKKdBUqeB2BMcZ4Js+mKhG5WETquM9rJAEDcJ7j+BzYUQTxRZ7+/e02XGNMqZZfH8fbwB/u5z8DTwFvAQdwXuiUJxHpKCIbRGSTiPQLML++iCwSkZUiskZErnHrG4pIuoiscqfRfuu0FJFv3W2+LlLEtzeVLQurVhXpLo0xJpLklziiVNX34olbgDGqOk1VBwJN8lpRRKJwkkwnoBnQTUSa5VjsaWCqql4E3AqM9Ju3WVUvdKe/+9WPAu4CEtypYz7HULiefdYZHdcYY0qpfBOHiPj6Qa4AFvrNy69/pBWwSVV/VNU/gCnADTmWUaCq+7kasCuvDYrImUBVVV2mzuiME4Eu+cRhjDGmEOWXOCYDi0XkUyAd+B+A++7x/O6qqgts9yvvcOv8DQJ6isgOYBbwgN+8Rm4T1mIRaeu3Tf++lUDbxI3xbhFJFZHUtLS0fEIN0smT0KYNJJfed1gZY0yeiUNVhwKP4jwAeJlmjcFehuwn+YLqBoxX1XjgGiBZRMoAu4H6bhNWX+ADEamax3YCxT5GVRNVNTEuLq4QQgXS06FSJaefwxhjSql8z4CquixA3cYgtr0TqOdXjnfr/PXB7aNQ1aU
iEoPzsOEe3FF4VXWFiGwGmrrrx+ezzfCpVAk+/7zIdmeMMZEo2CfHC2I5kCAijUQkGqfze0aOZbbh9J0gIucCMUCaiMS5neuIyFk4neA/qupu4JCIXOreTdUL+DSMx2CMMSaHsCUOd1iS+3HGuFqPc/fUWhEZIiKd3cUeBe4SkdU4/Sm3u81h7YA1IrIK+Aj4u9/dXfcC7wCbgM1A0b1gat48aNUKNm8usl0aY0ykCWtjvarOwun09q/7p9/ndUCbAOtNA6adYpupQPPCjTRIZctCzZpQsaInuzfGmEhgvbyh+MtfnMkYY0qxcPZxGGOMKYEscYTi/vuhUyevozDGGE9ZU1UomjaFypW9jsIYYzxliSMUDz7odQTGGOM5a6oyxhgTEkscoYiPh2ee8ToKY4zxlCWOUNx8M1x0kddRGGOMp6yPIxSvvup1BMYY4zm74jDGGBMSSxzB+vZbZ6iRGTnHaTTGmNLFEkewqleH++6DJnm+MdcYY0o86+MIVr168PLLXkdhjDGesyuOYP3xB2RkeB2FMcZ4zhJHsEaNgnLlYP/+/Jc1xpgSzBJHsC69FAYPhqohvfrcGGNKHOvjCNYllziTMcaUcnbFEaz9++H3372OwhhjPGeJI1jXXQcJCV5HYYwxngtr4hCRjiKyQUQ2iUi/APPri8giEVkpImtE5Bq3voOIrBCRb92fl/utk+Juc5U7nRHOY8j01FPQv3+R7MoYYyJZ2Po4RCQKeAvoAOwAlovIDFVd57fY08BUVR0lIs2AWUBDYC9wvaruEpHmwFygrt96PVQ1NVyxB3TttUW6O2OMiVThvOJoBWxS1R9V9Q9gCnBDjmUU8N2mVA3YBaCqK1V1l1u/FqggIuXDGGv+tm+HQ4c8DcEYYyJBOBNHXWC7X3kH2a8aAAYBPUVkB87VxgMBtnMT8I2qHvOrG+c2Uw0UEQm0cxG5W0RSRSQ1LS2twAeRqWVLePLJ09+OMcYUc153jncDxqtqPHANkCwimTGJyHnAi8A9fuv0UNXzgbbudFugDavqGFVNVNXEuLi404/0tdegZ8/T344xxhRz4XyOYydQz68c79b56wN0BFDVpSISA8QCe0QkHvgY6KWqm30rqOpO9+dvIvIBTpPYxLAdhU+PHmHfhTHGFAfhvOJYDiSISCMRiQZuBXKOSb4NuAJARM4FYoA0EakOzAT6qeoXvoVFpKyIxLqfywHXAd+F8Rgcx4/Dxo1w+HDYd2WMMZEubIlDVTOA+3HuiFqPc/fUWhEZIiKd3cUeBe4SkdXAZOB2VVV3vSbAP3PcdlsemCsia4BVOFcwY8N1DJm2b4ezz4bp08O+K2OMiXTinKdLtsTERE1NPY27dw8dgs8+g9atoVGjwgvMGGMimIisUNXEnPU2VlUwqlaF7t29jsIYYyKC13dVFQ979sCKFXDsWP7LGmNMCWeJIxgzZkBiopNAjDGmlLPEEYyrr4ZPPoEzimZYLGOMiWTWxxGMevWcyRhjjF1xBOX7750+DmOMMZY4gvLCC9C1q9dRGGNMRLDEEYx+/eD9972OwhhjIoL1cQTjnHOcyRhjjF1xBGXhQtiwwesojDEmIljiCMYtt8CIEV5HYYwxEcGaqoIxZ44z7IgxxhhLHEFp2bLIdnXo0CH27NnD8ePHi2yfxpjSo2zZssTExBAXF0dMTEzBtlHIMZU8Bw/CokXOyLi1a4d1V4cOHeKXXzu5qTcAABr2SURBVH6hbt26VKhQgVO8FdcYYwpEVcnIyODw4cNs27aN2rVrU61atZC3Y30c+dmwwXmG43SGZQ/Snj17qFu3LhUrVrSkYYwpdCJCuXLlqFGjBvHx8ezbt69A27Erjvw0bw7ffFMk7+E4fvw4FSpUCPt+jDGmQoUKHCvgiN+WOPJTsSJcdFGR7c6uNIwxReF0zjXWVJWf5cth0iQoBW9KNMaYYFjiyM/kyXDPPWBXAsYYA4Q5cYhIRxHZICKbRKRfgPn1RWSRiKw
UkTUico3fvP7uehtE5Opgt1noBg2C1avDvpuS7PPPP6dTp07UqlWLmJgYmjZtypNPPsmBAwdyLfvLL7/w4IMP0rRpUypUqEBsbCwtW7bkoYceOmV7bIcOHRAR/vWvfwWcP2jQIEQk4LRp06Z843/zzTcRET7//PNc8/bu3csZZ5xBV79BMJcvX85NN91E7dq1KV++PA0bNuTee+9l586dudZv3749l112Wb4xTJkyhaSkJKpXr07FihU5//zzGTZsGOnp6fmuG8j7779PmzZtiIuLy4yxT58+bNu2rUDb81m3bh133HEHDRo0oHz58lSrVo22bdvy+uuv8/vvvwOwdevWbH+D6OhomjZtyiOPPJLt34Tv71ahQgUOHjyYa18TJkwI6e9YmMaOHcs555xD+fLlOfvssxk9enTQ644aNSpz3fr16zNw4MCAt9+vXbuWq666isqVK1OrVi3uuOMO9u/fn22ZlJSUgP+uq1evftrHmCdVDcsERAGbgbOAaGA10CzHMmOAf7ifmwFb/T6vBsoDjdztRAWzzUBTy5YttThYt26d1yEUuqFDhyqgXbp00enTp2tKSoq+8soresYZZ2jjxo1127ZtmcsePHhQGzRooI0aNdKRI0fqwoUL9aOPPtKnn35aGzVqpAcOHMi1/e3bt2uZMmUU0BYtWgSM4ZlnnlFAlyxZokuXLs02/f777/kew8mTJ7Vt27basGFDPXz4cLZ5PXr00Bo1auju3btVVXXixIkaFRWlSUlJOnnyZF28eLGOHj1azzrrLI2Li9PVq1dnWz8pKUnbtGmT5/7vvvtuFRG944479LPPPtOFCxfqoEGDtHLlynrxxRfrwYMH8z2GnEaMGKEDBw7UTz75RFNSUvTtt9/W+Ph4jY+P10OHDoW8PVXVqVOnanR0tCYmJurYsWM1JSVFZ86cqU888YRWr15dR4wYoaqqW7ZsUUD79++vS5cu1ZSUFB06dKhWqFBBk5KS9OTJk6qa9XerUqWKvvPOO7n295e//EWrVKmigP7www8FirkgxowZoyKiAwYM0IULF+pTTz2lIqIjR47Md91hw4apiGjfvn31888/15deekkrVKigffr0ybbczp07NTY2Vi+77DKdPXu2Tp48WePj47V169Z64sSJzOUWLVqkgL7++uvZ/l0vX748qGPJ75wDpGqg83ugysKYgNbAXL9yf6B/jmXeBp70W/7LQMsCc935+W4z0HRaiWPsWNW5cwu+fghKWuJYuHChiog+/PDDueb9+OOPWqNGDW3fvn1m3bvvvquArlq1KtfyJ0+ezDyh+Bs2bJgCes011yig3377ba5lfCeg48ePF/hYNm7cqBUqVNCHHnoos27WrFkK6Pjx41VVdf369Vq+fHm96aabsv3nVlXdu3evNm7cWBMSEvSPP/7IrM8vcYwbN06BzJOuv6+//lqjo6P19ttvL/Bx+ZszZ44C+tFHH4W87saNGzUmJka7dOkS8Pe8Z88eXbJkiapmJY6xY8dmW2bQoEEK6IoVK1Q16+/Wu3dvTUpKyrbstm3bVET09ttvL9LEcfz4cY2Li9NevXplq7/jjju0Vq1a2f62OaWnp2vlypW1d+/e2epffvllFRH97rvvMusefvhhrVatWrYvS4sXL1ZAp02bllnnSxzz5s0r0PEUNHGEs6mqLrDdr7zDrfM3COgpIjuAWcAD+awbzDYBEJG7RSRVRFLT0tIKegxOU9WHHxZ8/VLspZdeombNmjz//PO55jVq1Ih+/fqRkpLCV199BZB5GV6nTp1cy/suwXOaMGEC5513HiPcscQmTJhQmIeQKSEhgSFDhvDGG2+wbNkyfvvtN/7+97/TqVMnevfuDcC//vUvTpw4wRtvvEGZMtn/a9WqVYthw4bxww8/MH369KD3++KLL3Leeefx4IMP5pp38cUX06dPH5KTk9m1axfHjh2jZs2a9O3bN9eyU6dORURYuXLlKfdVq1YtwHmyOFQjRow
gIyODkSNHBlw/Li6ONm3a5LmNiy++GCBXs1OvXr3473//y08//ZRZl5ycTIMGDWjXrl3IsZ6OpUuXkpaWRs+ePbPV33bbbezbt48lS5acct3vvvuOw4cP06lTp2z1HTt2RFX55JNPMutmzJjBtddem63JqV27dtSvX59PP/20kI6m4LzuHO8GjFfVeOAaIFlECiUmVR2jqomqmhgXF1fwDW3caAMcFkBGRgaLFy+mQ4cOpxzWoHPnzgAsXLgQgFatWgFw6623MnfuXI4cOZLnPr766is2bNjAbbfdRkJCAq1bt2bSpEmcOHEi4PInTpwgIyMjczp58mRIx/TII4+QmJhInz59eOyxx/j11195++23M+cvWLCAxMREzjzzzIDrX3vttZQpUybzePOza9cuvv/+e66//vpT3jrZuXNnTpw4weLFiylfvjw333wzkydPzvU7SE5Opnnz5lyU49byEydOcOzYMdasWUPfvn1p1qwZV199NaGaN28eF1988SmPPRhbtmwByNU+37ZtWxo2bMikSZMy65KTk+nZs2fQt5Sq+8R0flN+/ybWrl0LQPPmzbPVn3feeYDTx3MqUVFRAERHR2erL1++POAkFoD09HS2bNmSax++/QTaR48ePYiKiqJWrVp07979tPuq8hPOxLET8H9Rd7xb568PMBVAVZcCMUBsHusGs83CVbEiVKkS1l3kq317GD/e+Xz8uFP2vVjq6FGn7LsqOnjQKfu+1e7d65T/8x+n/PPPTnnOHKe8fbtTnj/fKf/4o1NevNgpF3A4+X379pGenk7Dhg1PuYxv3vbtzkVku3btGDx4MF988QUdO3akWrVqJCYmMmjQIH799ddc60+YMIEyZcpkfvvr3bs3u3fvZt68eQH3FxMTQ7ly5TKnXr16hXRMUVFRjBs3jk2bNjFmzBheeukl6vm9i3779u15Hm+lSpWIi4vLPN78+JYL5Xd422238fPPPzPf9/cE0tLSmDNnDrfddluu9WvXrk1MTAwXXHAB6enpzJ8/v0DjF23fvp0GDRqEtM7JkyfJyMjg6NGjzJs3j+eee44zzzyTtm3bZltOROjZsyfJyckAfP3113z//fch/f0mTJiQ7W9/qunOO+/Mczu+q+IaNWpkq69Zs2a2+YEkJCRQpkwZli1blq1+6dKl2dY9cOAAqpprH779+O+jWrVqPProo7zzzjssXLiQgQMHMn/+fFq3bs2ePXvyPJbTEc4HAJcDCSLSCOfkfivQPccy24ArgPEici5O4kgDZgAfiMirwJ+ABOBrQILYZuE5cgReew2uuw4uvDBsuzFZ/vnPf3L33Xczc+ZMlixZQkpKCoMHD+add95hxYoV1HbHCzt27BhTpkzh8ssvp25dp7Xylltu4aGHHmLChAl07Ngx17aXLVuW+a0PsppmQtGsWTNuuukmFixYwN13313AowyfNm3a0LhxY5KTkzOvHKZMmcLJkyfp0aNHruUXLFjA0aNHWb9+Pc8//zwdOnRgyZIl4b8rB7jnnnu45557MsuXXXYZb731VsDRE3r16sWzzz7L8uXLmThxIpdeeikJCQl88cUXQe3r+uuvZ/ny5fkuFxsbG/wBhKhy5crceeedvPnmm1x00UV07NiRlStXMmDAAKKionI1bwbjoosuynYVmZSURLt27WjVqhWvv/46zz33XGEeQqawJQ5VzRCR+3E6tqOA91R1rYgMwelwmQE8CowVkUcABW53O2TWishUYB2QAdynqicAAm0zXMdAWhoMHAh163qbOFJSsj6XK5e9XLFi9nK1atnLsbHZy3XqZC/Xq5e9fNZZ2ctnn12gkH233m7duvWUy/jm+X9rd0KsQ58+fejTpw8Ab731Fvfffz8vv/wyw4cPB+A///kPBw4coGvXrtmuRq6++mo+/fRTDh06RNUcQ+G3bNmyQO33OUVHR1OuXLlczSTx8fF5Hu+RI0dIS0vLdbynEh8fDxDy77Bnz54MHz6cI0e
OUKlSJZKTk7MlWH8XXHABAK1bt6Z9+/Y0adKE0aNH069faHe616tXL1sfRDCefvppbrjhhszbUvMabK9Jkya0bt2ad999l48++ohnn302pH3VrFkzqMH88jt5+64CDhw4kK1ZzncV4LvyOJVXXnmFffv20b17d1SVmJgYhgwZwksvvZS5verVqyMiAW9X379/f777aNGiBU2bNg0qURZUWPs4VHWWqjZV1caqOtSt+6ebNFDVdaraRlUvUNULVfVzv3WHuuudraqz89pm2DRsCMeOQYBvaiZvZcuWJSkpiXnz5mXev5/TjBkzALj88svz3NZ9991HjRo1srXt+jrBffN804wZM0hPT2fq1KmFdCTBu+KKK0hNTWX37t0B58+cOZOTJ0/me7w+devW5eyzz+Y/vmbGAGbMmEFUVBRJSUmZdbfddhtHjhxh+vTpbNy4keXLlwdspsrprLPOombNmgV6JuLKK68kNTWVn3/+Oeh1GjRoQGJiIueff35QJ/VevXoxduxYfvvtN2699daQ4iuspipfX4avr8PH92+zWbNmea5ftWpVpk+fzi+//MKaNWvYs2cPvXr1Yu/evZnP81SsWJGGDRvm2odvP/ntwyecwxd53Tke+aKjncmE7LHHHmPfvn0MGDAg17wtW7bw4osv0q5dOy655BLAefgvUOfk7t27OXjwYOY3sj179jBnzhxuuOEGFi1alGuqU6dO2O6uystDDz1EmTJleOCBB3Idx/79+xkwYABNmjThxhtvDHqbjz/+OGvXruX111/PNW/58uW8++679OjRgz/96U+Z9Y0bN+bPf/4zycnJJCcnU6lSpaD2uXbtWvbt20fjxo2Djs/nkUceISoqinvvvTfgzQl79+4NulnpVG655RY6d+5Mv379Arb/58XXVJXfNGjQoDy307p1a2JjY7N11IPzQGXNmjXzvXPMJy4ujvPPP58qVarw2muvERsby9/+9rfM+Z07d2bmzJnZHnxcsmQJP/30U+ZNJaeSmprKhg0bMm82CYtA9+iWtKnAz3GsW6f6z3+q7tpVsPVD3l3Jeo5DVXXw4MEKaNeuXfXjjz/WlJQUffXVV/WMM87QRo0a6U8//ZS57Msvv6wJCQk6aNAgnTVrlqakpOiYMWP07LPP1piYmMz7+1999VUFNCUlJeA+n3zySRUR3bx5s6oWznMc/nr37q1169YNOG/cuHEaFRWl7du31ylTpujixYv17bff1saNG2utWrX0m2++ybZ8UlKSnnPOOfrvf/8717RhwwZVVb3zzjtVRPTOO+/UmTNn6sKFC3Xw4MFapUoVbdGiRcAHI0eNGqVlypTROnXqaM+ePXPNb9OmjQ4fPlw/++wznT9/vr766qtat25djY+P17S0tMzlfM8JjBs3Lt/fi/8DgO+8844uXrxYZ82apf3799eaNWvmegAw53McOQXzd/M951KUDwCOGjVKRUSfeuopXbRokQ4cOFBFRN98881sy915550aFRWVrW7KlCk6cuRIXbBggX700UfavXt3LVu2rH766afZltuxY4fWqlVL27Vrp7Nnz9YpU6Zo/fr19ZJLLsn2jFD37t31qaee0mnTpumCBQt0+PDhWqtWLa1Xr162v+OpRNwDgJE0FThx/Pvfzq+oiE7oJTFxqKrOnj1br7rqKq1evbpGR0drkyZN9LHHHtN9+/ZlW27dunX68MMP64UXXqg1a9bUsmXLap06dfSmm27KTBqqqhdccIE2btw44AOBqqobNmxQQJ955hlVLdrEoaq6dOlS7dKli8bGxmq5cuW0fv36es8992R7St4nKSlJcfr3ck0vv/xy5nKTJk3Stm3bapUqVTQmJkbPO+88ffbZZ/XIkSMBY9i/f79GR0croHMDPMDat29fbd68uVapUkUrVaqk5557rj722GP6yy+/ZFvus88+U0Bnz54d1O/mu+++0969e2u9evW0XLlyWrVqVb3sssv0rbfeynxKv7gnDlX
V0aNHa0JCQua/57feeivXMr1791bnu3mWDz/8UJs3b64VKlTQKlWqaIcOHTIfjMxpzZo1euWVV2rFihW1evXq2rt3b927d2+2ZYYNG6bnn3++Vq1aVcuWLavx8fF611136a4gv+wWNHGIM69kS0xM1NSCvojJ1+RQgDseQrV+/XrOPffcsO/HmGANGDCAGTNm8O2339qQ/yVQfuccEVmhqok56+19HPkpgoRhTKRavHgxAwYMsKRhsrHEkZcFC5xpyBAohNs4TWTKyMjIc35UVFSpPXGeboe2KZns63RevvoKXnkF/B4aMyVPfrdoenGHljGRzL5G52XAAGcyJVp+D0o1KoL3zRtTnFjiMKVeYmKuvj9jTB6sqSovEyfCsGFeR2GMMRHFEkdeypWDHE+IhltpuD3aGOO90znXWFNVXrp1g0qVimx35cqVIz09nYoVKxbZPo0xpVN6enrmu0BCZVcc+clnXJjCdMYZZ7Bz506OHj1qVx7GmEKnqhw/fpz9+/ezY8eOAr1aAOyKI6L4hgHftWsXx48f9zgaY0xJVLZsWWJiYqhfv36BXtoFljgiTtWqVXO9R8IYYyKJNVUZY4wJiSUOY4wxIbHEYYwxJiSWOIwxxoTEEocxxpiQWOIwxhgTklLxBkARSQN+CmLRWGBvmMMpKIstdJEaF0RubJEaF1hsBXG6cTVQ1biclaUicQRLRFIDvSYxElhsoYvUuCByY4vUuMBiK4hwxWVNVcYYY0JiicMYY0xILHFkN8brAPJgsYUuUuOCyI0tUuMCi60gwhKX9XEYY4wJiV1xGGOMCYklDmOMMSGxxOESkY4iskFENolIP6/j8RGR90Rkj4h853Us/kSknogsEpF1IrJWRB7yOiYfEYkRka9FZLUb22CvY/InIlEislJEPvM6Fn8islVEvhWRVSKS6nU8/kSkuoh8JCLfi8h6EWkdATGd7f6ufNMhEXnY67h8ROQR99//dyIyWUQK9vKNQNu2Pg7nPzKwEegA7ACWA91UdZ2ngQEi0g44DExU1eZex+MjImcCZ6rqNyJSBVgBdImQ35kAlVT1sIiUA5YAD6nqMo9DA0BE+gKJQFVVvc7reHxEZCuQqKoR9yCbiEwA/qeq74hINFBRVX/1Oi4f9xyyE7hEVYN52Djc8dTF+XffTFXTRWQqMEtVxxfG9u2Kw9EK2KSqP6rqH8AU4AaPYwJAVf8L7Pc6jpxUdbeqfuN+/g1YD9T1NiqHOg67xXLuFBHfkEQkHrgWeMfrWIoLEakGtAPeBVDVPyIpabiuADZHQtLwUxaoICJlgYrArsLasCUOR11gu195BxFyEiwORKQhcBHwlbeRZHGbg1YBe4B5qhopsY0AngBOeh1IAAp8LiIrRORur4Px0whIA8a5TXzviEglr4PK4VZgstdB+KjqTmA4sA3YDRxU1c8La/uWOMxpEZHKwDTgYVU95HU8Pqp6QlUvBOKBViLieTOfiFwH7FHVFV7HcgqXqWoLoBNwn9tMGgnKAi2AUap6EXAEiKR+yGigM/Bvr2PxEZEaOK0mjYA/AZVEpGdhbd8Sh2MnUM+vHO/WmTy4/QfTgEmqOt3reAJxmzQWAR29jgVoA3R2+xKmAJeLyPvehpTF/ZaKqu4BPsZpwo0EO4AdfleNH+EkkkjRCfhGVX/xOhA/VwJbVDVNVY8D04E/F9bGLXE4lgMJItLI/fZwKzDD45gimtsB/S6wXlVf9ToefyISJyLV3c8VcG56+N7bqEBV+6tqvKo2xPk3tlBVC+1b4OkQkUruTQ64zUBXARFxJ5+q/gxsF5Gz3aorAM9vwvDTjQhqpnJtAy4VkYru/9UrcPohC0XZwtpQcaaqGSJyPzAXiALeU9W1HocFgIhMBtoDsSKyA3hGVd/1NirA+fZ8G/Ct25cAMEBVZ3kYk8+ZwAT3TpcywFRVjahbXyNQbeBj5xxDWeADVZ3jbUjZPABMcr/Y/Qjc4XE8QGaS7QDc43Us/lT
1KxH5CPgGyABWUojDj9jtuMYYY0JiTVXGGGNCYonDGGNMSCxxGGOMCYklDmOMMSGxxGGMMSYkdjuuMYCInAC+9avqAjQEPgW2AOWBKaoaUSPtGuMFSxzGONLdIUoyuWNw/U9Vr3Pv118lIv/xDe4YLiJSVlUzwrkPY06HNVUZEwRVPYIzdHwTEUnyewfDSt8T1z4i0tB9b8Qk990RH4lIRXdeSxFZ7A4kONcdnh4RSRGREe57MB7Ksb2A+xORJ933Z6wWkRfcurtEZLlbN81vv+NF5HUR+VJEfhSRv4b9l2ZKLEscxjgq+J2cP845U0RqAZcCa4HHgPvcK5S2QHqA7Z0NjFTVc4FDwL3u2F5vAH9V1ZbAe8BQv3WiVTVRVV/Jsa1c+xORTjiD2F2iqhcAL7nLTlfVi9269UAfv+2cCVwGXAe8EOTvxZhcrKnKGEeupipXWxFZiTMM+guqulZEvgBeFZFJOCfqHQHW266qX7if3wceBOYAzYF57tAeUThDXvt8eIrYcu1PRK4ExqnqUQBV9b2zpbmIPAdUByrjDKPj84mqngTWiUjtPH4XxuTJEocxeftfzrf0qeoLIjITuAb4QkSuVtWcgyjmHMtHAQHWquqpXnt6JFBloP3lEe94nDcxrhaR23HGOfM55vdZ8tiGMXmypipjQiQijVX1W1V9EWdk5XMCLFZfst6L3R3nNZ4bgDhfvYiUE5HzCri/ecAdfn0YNd3FqwC73WaxHgU/SmNOzRKHMaF7WES+E5E1wHFgdoBlNuC8DGk9UAPnJUR/AH8FXhSR1cAqgntHQq79uSPXzgBS3dGJH3OXHYjzJsYviICh5E3JZKPjGlPI3Nt4P1NVz986aEw42BWHMcaYkNgVhzHGmJDYFYcxxpiQWOIwxhgTEkscxhhjQmKJwxhjTEgscRhjjAnJ/wd3FSffnTh1SQAAAABJRU5ErkJggg==\n", 87 | "text/plain": [ 88 | "
" 89 | ] 90 | }, 91 | "metadata": { 92 | "needs_background": "light" 93 | } 94 | } 95 | ] 96 | } 97 | ] 98 | } 99 | -------------------------------------------------------------------------------- /GenerateCSV.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import metrics 3 | import matplotlib.pyplot as plt 4 | import json 5 | from pathlib import Path 6 | import matplotlib.gridspec as gridspec 7 | from matplotlib.patches import FancyBboxPatch 8 | import math 9 | import SimpleITK as sitk 10 | import csv 11 | from tqdm import tqdm 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser(description='PyTorch DataBowl3 Detector') 15 | parser.add_argument('--model', '-m', metavar='MODEL', default='base', 16 | help='model') 17 | parser.add_argument('--cross', default=None, type=str, metavar='N', 18 | help='which data cross be used') 19 | parser.add_argument('--epoch', default=None, type=str, metavar='N', 20 | help='which data cross be used') 21 | 22 | args = parser.parse_args() 23 | 24 | 25 | 26 | def worldToVoxelCoord(worldCoord, origin, spacing): 27 | stretchedVoxelCoord = np.absolute(worldCoord - origin) 28 | voxelCoord = stretchedVoxelCoord / spacing 29 | return voxelCoord 30 | 31 | 32 | def VoxelToWorldCoord(voxelCoord, origin, spacing): 33 | strechedVocelCoord = voxelCoord * spacing 34 | worldCoord = strechedVocelCoord + origin 35 | return worldCoord 36 | 37 | def nms(output, nms_th): 38 | if len(output) == 0: 39 | return output 40 | 41 | output = output[np.argsort(-output[:, 0])] 42 | bboxes = [output[0]] 43 | 44 | for i in np.arange(1, len(output)): 45 | bbox = output[i] 46 | 47 | for j in range(len(bboxes)): 48 | if iou(bbox[1:5], bboxes[j][1:5]) >= nms_th: 49 | break 50 | else: 51 | bboxes.append(bbox) 52 | 53 | bboxes = np.asarray(bboxes, np.float32) 54 | return bboxes 55 | 56 | def iou(box0, box1): 57 | r0 = box0[3] / 2 58 | s0 = box0[:3] - r0 59 | e0 = box0[:3] + r0 60 | 61 | 
r1 = box1[3] / 2 62 | s1 = box1[:3] - r1 63 | e1 = box1[:3] + r1 64 | 65 | overlap = [] 66 | for i in range(len(s0)):  # per-axis overlap length, clamped at 0 67 | overlap.append(max(0, min(e0[i], e1[i]) - max(s0[i], s1[i]))) 68 | 69 | intersection = overlap[0] * overlap[1] * overlap[2] 70 | union = box0[3] * box0[3] * box0[3] + box1[3] * box1[3] * box1[3] - intersection  # cube volumes d**3; no guard against zero union 71 | return intersection / union 72 | 73 | def load_itk_image(filename):  # load an .mhd volume; isflip=True when the header TransformMatrix is not the identity 74 | with open(filename) as f: 75 | contents = f.readlines() 76 | line = [k for k in contents if k.startswith('TransformMatrix')][0] 77 | transformM = np.array(line.split(' = ')[1].split(' ')).astype('float') 78 | transformM = np.round(transformM) 79 | if np.any( transformM!=np.array([1,0,0, 0, 1, 0, 0, 0, 1])):  # non-identity orientation => volume was flipped 80 | isflip = True 81 | else: 82 | isflip = False 83 | 84 | itkimage = sitk.ReadImage(filename) 85 | numpyImage = sitk.GetArrayFromImage(itkimage) 86 | 87 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin())))  # reversed from SimpleITK (x, y, z) to numpy (z, y, x) order 88 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing()))) 89 | 90 | return numpyImage, numpyOrigin, numpySpacing, isflip 91 | 92 | def main(bbox_path, preprocess_path, lunaseg_path, save_file):  # gather detections from 5 folds, map voxel coords back to world mm, write one CSV 93 | total_list = [] 94 | epochs = args.epoch  # '.'-separated list, one epoch per fold 95 | epochs = epochs.split('.') 96 | count = 0 97 | for i in range(5): 98 | total_list.append([]) 99 | with Path('test_0222_%s/LUNA_test.json' %str(i+1)).open('rt', encoding='utf-8') as fp: 100 | idcs = json.load(fp) 101 | for x in tqdm(range(len(idcs))): 102 | pbb = np.load('%s%s/bbox_%s/%s_pbb.npy' %(bbox_path, str(i+1), epochs[i], idcs[x]), mmap_mode='r') 103 | lbb = np.load("%s%s_label.npy" % (preprocess_path, idcs[x]), allow_pickle=True) 104 | pbb = nms(pbb, 0.1)  # suppress overlapping candidates (IoU >= 0.1) 105 | 106 | Mask,origin,spacing,isflip = load_itk_image('%s%s.mhd' %(lunaseg_path, idcs[x]))  # only isflip is ultimately used from this call 107 | 108 | origin = np.load('%s%s_origin.npy' %(preprocess_path, idcs[x]), mmap_mode='r')  # NOTE(review): overwrites origin/spacing returned by load_itk_image above 109 | spacing = np.load('%s%s_spacing.npy' %(preprocess_path, idcs[x]), mmap_mode='r') 110 | resolution = np.array([1, 1, 1]) 111 | extendbox = 
np.load('%s%s_extendbox.npy' %(preprocess_path, idcs[x]), mmap_mode='r') 112 | 113 | pbb = np.array(pbb[:, :-1])  # drop last column; remaining cols: [score, z, y, x] 114 | pbb[:, 1:] = np.array(pbb[:, 1:] + np.expand_dims(extendbox[:,0], 1).T)  # undo preprocessing crop: add extendbox offset 115 | pbb[:, 1:] = np.array(pbb[:, 1:] * np.expand_dims(resolution, 1).T / np.expand_dims(spacing, 1).T)  # 1x1x1mm voxel coords -> native-spacing voxel coords 116 | 117 | if isflip:  # presumably undoes the flip applied during preprocessing — verify against prepare.py 118 | Mask = np.load('%s%s_mask.npy' %(preprocess_path, idcs[x]), mmap_mode='r') 119 | pbb[:, 2] = pbb[:, 2] - Mask.shape[1] 120 | pbb[:, 3] = pbb[:, 3] - Mask.shape[2] 121 | 122 | pos = VoxelToWorldCoord(pbb[:, 1:], origin, spacing)  # voxel (z, y, x) -> world mm 123 | 124 | rowlist = [] 125 | for nk in range(pos.shape[0]):  # pos is (z, y, x); emit in coordX, coordY, coordZ order 126 | rowlist.append([idcs[x], pos[nk, 2], pos[nk, 1], pos[nk, 0], pbb[nk,0]]) 127 | 128 | total_list[i].append(rowlist) 129 | 130 | with open(save_file, "w") as f:  # LUNA16-style submission CSV 131 | first_row = ['seriesuid', 'coordX', 'coordY', 'coordZ', 'probability'] 132 | 133 | f.write("%s,%s,%s,%s,%s\n" %(first_row[0], first_row[1], first_row[2], first_row[3], first_row[4])) 134 | 135 | for k in total_list:  # fold -> scan -> detection row 136 | for i in k: 137 | for j in i: 138 | f.write("%s,%.9f,%.9f,%.9f,%.9f\n" %(j[0], j[1], j[2], j[3], j[4])) 139 | 140 | if __name__=='__main__': 141 | main(bbox_path='./results/'+args.model+'_testcross', preprocess_path='../data/preprocess/all/', lunaseg_path='../data/LUNA16/seg-lungs-LUNA16/', save_file=args.model+'_80_all.csv') 142 | 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lung Nodule Detection in Pytorch 2 | A One-stage Method for lung nodule detection in LUNA16 dataset 3 | 4 | ## Scripts: 5 | - Would be executed: 6 | - config_training.py: set filepaths 7 | - data_detector: generate data loader during training and testing 8 | - prepare.py: LUNA16 dataset preprocessing 9 | - main_detector_recon.py: model training and inference 10 | - GenerateCSV.py: Generate result.csv for computing CPM 11 | - noduleCADEvaluationLUNA16.py: Compute CPM of .csv 12 | - FROC_CPM.ipynb: Plot FROC curve 13 | 14 | - Others 15 | - data_detector.py: generate data loader during training and testing 16 | - preprocess.py: some preprocessing-related codes 17 | - layers.py 18 | - loss.py 19 | - split_combine.py (At Testing stage) 20 | - utils.py 21 | 22 | ## Requirements: 23 | - Python 3.6 24 | - torch 0.4.1 25 | - torchvision 0.2.0 26 | - SimpleITK 27 | - scikit-image 28 | 29 | ## Files: 30 | - LUNA.json: Every Case ID stored in .json for
make_dataset.py 31 | - json: Directory contains train/val/test.json (Case ID) respectively 32 | - Download LUNA16 dataset from Grand Challenge: https://luna16.grand-challenge.org and save at the following filepaths 33 | ./data/LUNA16/allset: all .raw and .mhd of LUNA16 data 34 | ./data/LUNA16/seg-lungs-LUNA16: all .zraw and .mhd of LUNA16 mask 35 | 36 | ## How to Do step by step: 37 | - Preprocessing for LUNA16 38 | - python prepare.py 39 | - output file path: config_training -> config[preprocess_result_path] 40 | ``` 41 | Output: id_clean.npy & id_label.npy (for training) ; id_extendbox.npy & id_mask.npy & id_origin.npy & id_spacing.npy (for vox2world) 42 | ``` 43 | - Start training and testing 44 | - training 45 | ``` 46 | python main_detector_recon.py --model OSAF_YOLOv3 -b [batch_size] --epochs [num_epochs] --save-dir [save_dir_path] --save-freq [save_freq_ckpt] --gpu '0' --n_test [number of gpu for test] --lr [lr_rate] --cross [1-5 set which cross_data be used] 47 | eg: python main_detector_recon.py --model OSAF_YOLOv3 -b 2 --epochs 100 --save-dir OSAF_YOLOv3_testcross1 --save-freq 1 --gpu '0' --n_test 1 --lr 0.001 --cross 1 48 | ``` 49 | - testing 50 | ``` 51 | python main_detector_recon.py --model OSAF_YOLOv3 --resume [resume_ckpt] --save-dir [] --test 1 --gpu '0' --n_test [] --cross [] 52 | eg: python main_detector_recon.py --model OSAF_YOLOv3 --test 1 --cross 1 --resume "./results/OSAF_YOLOv3_testcross1/1.ckpt" --save-dir "OSAF_YOLOv3_testcross1" --gpu 0 53 | ``` 54 | 55 | - Compute CPM (After test all 5 fold) 56 | - Generate result.csv 57 | ``` 58 | python GenerateCSV.py 59 | ``` 60 | ``` 61 | output: OSAF_YOLOv3_80_all.csv 62 | ``` 63 | - Then compute CPM and save related .png and .npy 64 | ``` 65 | python noduleCADEvaluationLUNA16.py (Remember to modify the filepath in noduleCADEvaluationLUNA16.py) 66 | ``` 67 | ``` 68 | output: print(csv_name, CPM, seven_sensitivities@predefined_fps) and save ./CPM_Results_OSAF_YOLOv3_80/_.npy & _.png 69 | ``` 70 | 71 | - Plot
import math

import torch
from torch.optim.optimizer import Optimizer

# ``torch.zeros_like`` only accepts the ``memory_format`` keyword from torch
# 1.5 onward.  BUGFIX: the original lexicographic string comparison
# (torch.__version__ >= "1.5.0") is wrong for versions such as "1.10.0";
# compare the parsed (major, minor) tuple instead.
try:
    version_higher = tuple(
        int(part) for part in torch.__version__.split('+')[0].split('.')[:2]
    ) >= (1, 5)
except ValueError:
    # Unusual version strings (local/nightly builds) -- assume a modern torch.
    version_higher = True


def _zeros_like(tensor):
    """Return a zero tensor shaped like *tensor*, preserving layout when supported."""
    if version_higher:
        return torch.zeros_like(tensor, memory_format=torch.preserve_format)
    return torch.zeros_like(tensor)


class AdaBelief(Optimizer):
    r"""Implements the AdaBelief algorithm. Modified from Adam in PyTorch.

    AdaBelief adapts the step size by the "belief" in the observed gradient:
    its second-moment estimate tracks the variance of the gradient around its
    EMA (``grad - exp_avg``) instead of the raw squared gradient.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        weight_decouple (boolean, optional): (default: False) If set as True,
            the optimizer uses decoupled weight decay as in AdamW
        fixed_decay (boolean, optional): (default: False) Used when
            weight_decouple is True.  When True the decay is
            ``W_new = W_old - W_old * decay``; when False it is
            ``W_new = W_old - W_old * decay * lr`` (the decay ratio then
            decreases with the learning rate).
        rectify (boolean, optional): (default: False) If set as True, perform
            the rectified update similar to RAdam.

    Reference: AdaBelief Optimizer, adapting stepsizes by the belief in
    observed gradients, NeurIPS 2020 Spotlight.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, weight_decouple=False,
                 fixed_decay=False, rectify=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdaBelief, self).__init__(params, defaults)

        self.weight_decouple = weight_decouple
        self.rectify = rectify
        self.fixed_decay = fixed_decay
        if self.weight_decouple:
            print('Weight decoupling enabled in AdaBelief')
            if self.fixed_decay:
                print('Weight decay fixed')
        if self.rectify:
            print('Rectification enabled in AdaBelief')
        if amsgrad:
            print('AMS enabled in AdaBelief')

    def __setstate__(self, state):
        super(AdaBelief, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def reset(self):
        """Re-initialize the per-parameter optimizer state."""
        for group in self.param_groups:
            amsgrad = group['amsgrad']
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                # Exponential moving average of gradient values
                state['exp_avg'] = _zeros_like(p.data)
                # Exponential moving average of the squared gradient residual
                state['exp_avg_var'] = _zeros_like(p.data)
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state['max_exp_avg_var'] = _zeros_like(p.data)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('AdaBelief does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                beta1, beta2 = group['betas']

                # State initialization (rho_inf is needed only by rectify).
                if len(state) == 0:
                    state['rho_inf'] = 2.0 / (1.0 - beta2) - 1.0
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = _zeros_like(p.data)
                    # Exponential moving average of squared gradient residual
                    state['exp_avg_var'] = _zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_var'] = _zeros_like(p.data)

                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Weight decay: decoupled (AdamW-style) on the weights, or
                # classic L2 folded into the gradient.
                if self.weight_decouple:
                    if not self.fixed_decay:
                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
                    else:
                        p.data.mul_(1.0 - group['weight_decay'])
                else:
                    if group['weight_decay'] != 0:
                        grad.add_(p.data, alpha=group['weight_decay'])

                # Update first moment and the "belief": variance of the
                # gradient around its running mean.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                grad_residual = grad - exp_avg
                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

                # BUGFIX: eps must be added OUT-OF-PLACE when building the
                # denominator.  The original ``.add_(group['eps'])`` mutated
                # the stored running statistic, so eps accumulated into the
                # optimizer state on every step and slowly inflated it.
                if amsgrad:
                    max_exp_avg_var = state['max_exp_avg_var']
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_var, exp_avg_var, out=max_exp_avg_var)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_var.add(group['eps']).sqrt() /
                             math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_var.add(group['eps']).sqrt() /
                             math.sqrt(bias_correction2)).add_(group['eps'])

                if not self.rectify:
                    # Default (Adam-style) update
                    step_size = group['lr'] / bias_correction1
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    # Rectified update similar to RAdam: rho_t approximates
                    # the length of the SMA of the adaptive learning rate.
                    state['rho_t'] = state['rho_inf'] - 2 * state['step'] * beta2 ** state['step'] / (
                        1.0 - beta2 ** state['step'])

                    if state['rho_t'] > 4:  # variance is tractable: Adam-style update
                        rho_inf, rho_t = state['rho_inf'], state['rho_t']
                        rt = (rho_t - 4.0) * (rho_t - 2.0) * rho_inf / (rho_inf - 4.0) / (rho_inf - 2.0) / rho_t
                        rt = math.sqrt(rt)

                        step_size = rt * group['lr'] / bias_correction1

                        p.data.addcdiv_(exp_avg, denom, value=-step_size)
                    else:  # otherwise fall back to an SGD-style update
                        p.data.add_(exp_avg, alpha=-group['lr'])

        return loss
class DataBowl3Detector(Dataset):
    """LUNA16 lung-nodule detection dataset.

    phase='train'/'val': ``__getitem__`` returns (patch, anchor label map,
        normalized coord grid) for a crop around a (possibly augmented) nodule.
    phase='test': returns (split sub-volumes, ground-truth bboxes, coord grid,
        nzhw) — the full scan is padded to a multiple of ``stride`` and split
        by ``split_comber`` for sliding-window inference.
    """

    def __init__(self, data_dir, split, config, phase='train', split_comber=None):
        assert phase in ['train', 'val', 'test']
        self.phase = phase
        self.max_stride = config['max_stride']
        self.stride = config['stride']
        # Diameter thresholds (voxels at the working resolution) used below to
        # oversample larger nodules, which are rarer in the dataset.
        sizelim = config['sizelim'] / config['reso']
        sizelim2 = config['sizelim2'] / config['reso']
        sizelim3 = config['sizelim3'] / config['reso']
        sizelim4 = config['sizelim4'] / config['reso']
        self.blacklist = config['blacklist']
        self.isScale = config['aug_scale']
        self.r_rand = config['r_rand_crop']  # random ratio for sample augmentation == 0.3
        self.augtype = config['augtype']
        self.pad_value = config['pad_value']
        self.split_comber = split_comber
        if isinstance(split, list):
            idcs = split
        elif isinstance(split, str) and Path(split).suffix == '.json':
            with Path(split).open('rt', encoding='utf-8') as fp:
                idcs = json.load(fp)
            # A fix for python2 ascii compatibility
            idcs = [str(i, encoding='utf-8') if not isinstance(i, str) else i for i in idcs]
        else:
            # BUGFIX: previously neither branch matched and ``idcs`` stayed
            # unbound, crashing later with an unrelated NameError.
            raise ValueError(
                'split must be a list of case ids or a path to a .json file, '
                'got {!r}'.format(split))

        if phase != 'test':
            idcs = [f for f in idcs if f not in self.blacklist]

        self.channel = config['channel']  # channel==1
        self.filenames = [os.path.join(data_dir, '{}_clean.npy'.format(idx)) for idx in idcs]

        labels = []
        for idx in idcs:
            # l = [[z, y, x, d], ...] for every annotated nodule in the scan
            l = np.load(Path(data_dir) / '{}_label.npy'.format(idx), allow_pickle=True)
            if np.all(l == 0):
                l = np.array([])
            labels.append(l)
        self.sample_bboxes = labels

        # Balance nodules of different diameters: bigger nodules are fewer in
        # the dataset, so duplicate them 2x/4x/8x past each size threshold.
        if self.phase != 'test':
            self.bboxes = []
            for index, label in enumerate(labels):
                if len(label) > 0:
                    for t in label:  # t = (z, y, x, d)
                        if t[3] > sizelim:
                            self.bboxes.append([np.concatenate([[index], t])])
                        if t[3] > sizelim2:
                            self.bboxes += [[np.concatenate([[index], t])]] * 2
                        if t[3] > sizelim3:
                            self.bboxes += [[np.concatenate([[index], t])]] * 4
                        if t[3] > sizelim4:
                            self.bboxes += [[np.concatenate([[index], t])]] * 8

            # Finally, a balanced label collection is generated.
            if len(self.bboxes) > 0:
                self.bboxes = np.concatenate(self.bboxes, axis=0)
            else:
                self.bboxes = np.array(self.bboxes)

        self.crop = Crop(config)
        self.label_mapping = LabelMapping(config, self.phase)

    def __getitem__(self, idx, split=None):
        # BUGFIX: the original ``int(str(time.time() % 1)[2:7])`` raises
        # ValueError when the fractional part formats in scientific notation
        # (e.g. 1e-05 -> '-05' -> negative seed); derive a valid 32-bit seed.
        np.random.seed(int(time.time() * 1e7) % (2 ** 32))

        isRandomImg = False  # NOTE: random whole-image sampling is disabled here
        if self.phase == 'train' or self.phase == 'val':
            if idx >= len(self.bboxes):
                # Indices beyond the bbox list (see __len__) become random crops.
                isRandom = True
                idx = np.random.randint(0, len(self.bboxes))
                isRandomImg = False
            else:
                isRandom = False
        elif self.phase == 'test':
            isRandom = False

        if self.phase == 'train' or self.phase == 'val':
            if not isRandomImg:
                bbox = self.bboxes[idx]  # bbox = (scan index, z, y, x, d)
                filename = self.filenames[int(bbox[0])]
                imgs = np.load(filename)[0:self.channel]
                bboxes = self.sample_bboxes[int(bbox[0])]
                isScale = self.augtype['scale'] and (self.phase == 'train')
                sample, target, bboxes, coord = self.crop(imgs, bbox[1:], bboxes, isScale=isScale, isRand=isRandom)

                if self.phase == 'train' and not isRandom:
                    sample, target, bboxes, coord = augment(sample, target, bboxes, coord,
                                                            ifflip=self.augtype['flip'],
                                                            ifrotate=self.augtype['rotate'],
                                                            ifswap=self.augtype['swap'])
            else:
                randimid = np.random.randint(len(self.filenames))
                filename = self.filenames[randimid]
                imgs = np.load(filename)[0:self.channel]
                bboxes = self.sample_bboxes[randimid]
                sample, target, bboxes, coord = self.crop(imgs, [], bboxes, isScale=False, isRand=True)

            try:
                label = self.label_mapping(sample.shape[1:], target, bboxes, filename)
            except ZeroDivisionError:
                raise Exception('Bug in {}'.format(os.path.basename(filename).split('_clean')[0]))

            # Normalize uint8 intensities to roughly [-1, 1).
            sample = (sample.astype(np.float32) - 128) / 128

            return torch.from_numpy(sample), torch.from_numpy(label), coord

        elif self.phase == 'test':
            imgs = np.load(self.filenames[idx])
            bboxes = self.sample_bboxes[idx]
            nz, nh, nw = imgs.shape[1:]
            # Pad every spatial axis up to a multiple of the stride.
            pz = int(np.ceil(float(nz) / self.stride)) * self.stride
            ph = int(np.ceil(float(nh) / self.stride)) * self.stride
            pw = int(np.ceil(float(nw) / self.stride)) * self.stride
            imgs = np.pad(imgs, [[0, 0], [0, pz - nz], [0, ph - nh], [0, pw - nw]],
                          'constant', constant_values=self.pad_value)
            xx, yy, zz = np.meshgrid(np.linspace(-0.5, 0.5, imgs.shape[1] // self.stride),
                                     np.linspace(-0.5, 0.5, imgs.shape[2] // self.stride),
                                     np.linspace(-0.5, 0.5, imgs.shape[3] // self.stride), indexing='ij')
            coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32')
            imgs, nzhw = self.split_comber.split(imgs)
            coord2, nzhw2 = self.split_comber.split(coord,
                                                    side_len=self.split_comber.side_len // self.stride,
                                                    max_stride=self.split_comber.max_stride // self.stride,
                                                    margin=self.split_comber.margin // self.stride)
            assert np.all(nzhw == nzhw2)
            imgs = (imgs.astype(np.float32) - 128) / 128
            return torch.from_numpy(imgs.astype(np.float32)), bboxes, torch.from_numpy(coord2.astype(np.float32)), np.array(nzhw)

    def __len__(self):
        if self.phase == 'train':
            # Inflate length so ~r_rand of drawn samples are random crops.
            return int(len(self.bboxes) // (1 - self.r_rand))
        elif self.phase == 'val':
            return len(self.bboxes)
        else:
            return len(self.sample_bboxes)
class NoduleMalignancyDetector(Dataset):
    """LUNA16 detection dataset that also carries a per-nodule malignancy label.

    Expects each ``<id>_label.npy`` row to be [z, y, x, d, malignancy].
    train/val items are (patch, anchor label map, coord grid, malignancy);
    test items match ``DataBowl3Detector``'s test output.
    """

    def __init__(self, data_dir, split, config, phase='train', split_comber=None):
        assert phase in ['train', 'val', 'test']
        self.phase = phase
        self.max_stride = config['max_stride']
        self.stride = config['stride']
        # Diameter thresholds (voxels at the working resolution) used below to
        # oversample larger nodules, which are rarer in the dataset.
        sizelim = config['sizelim'] / config['reso']
        sizelim2 = config['sizelim2'] / config['reso']
        sizelim3 = config['sizelim3'] / config['reso']
        sizelim4 = config['sizelim4'] / config['reso']
        self.blacklist = config['blacklist']
        self.isScale = config['aug_scale']
        self.r_rand = config['r_rand_crop']  # random ratio for sample augmentation == 0.3
        self.augtype = config['augtype']
        self.pad_value = config['pad_value']
        self.split_comber = split_comber
        if isinstance(split, list):
            idcs = split
        elif isinstance(split, str) and Path(split).suffix == '.json':
            with Path(split).open('rt', encoding='utf-8') as fp:
                idcs = json.load(fp)
            # A fix for python2 ascii compatibility
            idcs = [str(i, encoding='utf-8') if not isinstance(i, str) else i for i in idcs]
        else:
            # BUGFIX: previously neither branch matched and ``idcs`` stayed
            # unbound, crashing later with an unrelated NameError.
            raise ValueError(
                'split must be a list of case ids or a path to a .json file, '
                'got {!r}'.format(split))

        if phase != 'test':
            idcs = [f for f in idcs if f not in self.blacklist]

        self.channel = config['channel']  # channel==1
        self.filenames = [os.path.join(data_dir, '{}_clean.npy'.format(idx)) for idx in idcs]

        self.sample_bboxes = []

        for idx in idcs:
            # l = [[z, y, x, d, malignancy], ...].
            # CONSISTENCY FIX: pass allow_pickle=True like DataBowl3Detector
            # does, so object-dtype label arrays load identically.
            l = np.load(Path(data_dir) / '{}_label.npy'.format(idx), allow_pickle=True)
            if np.all(l == 0):
                l = np.array([])
            self.sample_bboxes.append(l)

        # Balance nodules of different diameters: bigger nodules are fewer in
        # the dataset, so duplicate them 2x/4x/8x past each size threshold.
        if self.phase != 'test':
            self.bboxes = []
            for index, label in enumerate(self.sample_bboxes):
                if len(label) > 0:
                    for t in label:  # t = (z, y, x, d, malignancy)
                        if t[3] > sizelim:
                            self.bboxes.append([np.concatenate([[index], t])])
                        if t[3] > sizelim2:
                            self.bboxes += [[np.concatenate([[index], t])]] * 2
                        if t[3] > sizelim3:
                            self.bboxes += [[np.concatenate([[index], t])]] * 4
                        if t[3] > sizelim4:
                            self.bboxes += [[np.concatenate([[index], t])]] * 8

            # Finally, a balanced label collection is generated.
            if len(self.bboxes) > 0:
                self.bboxes = np.concatenate(self.bboxes, axis=0)
            else:
                self.bboxes = np.array(self.bboxes)

        self.crop = Crop(config)
        self.label_mapping = LabelMapping(config, self.phase)

    def __getitem__(self, idx, split=None):
        # BUGFIX: the original ``int(str(time.time() % 1)[2:7])`` raises
        # ValueError when the fractional part formats in scientific notation
        # (e.g. 1e-05 -> '-05' -> negative seed); derive a valid 32-bit seed.
        np.random.seed(int(time.time() * 1e7) % (2 ** 32))

        isRandomImg = False
        if self.phase == 'train' or self.phase == 'val':
            if idx >= len(self.bboxes):
                # Indices beyond the bbox list (see __len__): wrap around and
                # decide at random whether to sample a whole random image.
                isRandom = True
                idx = idx % len(self.bboxes)
                isRandomImg = np.random.randint(2)
            else:
                isRandom = False
        elif self.phase == 'test':
            isRandom = False

        if self.phase == 'train' or self.phase == 'val':
            if not isRandomImg:
                bbox = self.bboxes[idx]  # bbox = (idx, z, y, x, d, malignancy)
                filename = self.filenames[int(bbox[0])]
                imgs = np.load(filename)[0:self.channel]
                bboxes = self.sample_bboxes[int(bbox[0])]
                isScale = self.augtype['scale'] and (self.phase == 'train')
                sample, target, bboxes, coord = self.crop(imgs, bbox[1:5], bboxes, isScale=isScale, isRand=isRandom)
                try:
                    malignancy = bbox[5]
                except IndexError:
                    # Old-format labels without a malignancy column.
                    malignancy = 0

                if self.phase == 'train' and not isRandom:
                    sample, target, bboxes, coord = augment(sample, target, bboxes, coord,
                                                            ifflip=self.augtype['flip'],
                                                            ifrotate=self.augtype['rotate'],
                                                            ifswap=self.augtype['swap'])
            else:
                randimid = np.random.randint(len(self.filenames))
                filename = self.filenames[randimid]
                imgs = np.load(filename)[0:self.channel]
                bboxes = self.sample_bboxes[randimid]
                sample, target, bboxes, coord = self.crop(imgs, [], bboxes, isScale=False, isRand=True)
                malignancy = 0  # it's randomly selected, so the malignancy is unknown.

            try:
                label = self.label_mapping(sample.shape[1:], target, bboxes, filename)
            except ZeroDivisionError:
                raise Exception('Bug in {}'.format(os.path.basename(filename).split('_clean')[0]))

            # Normalize uint8 intensities to roughly [-1, 1).
            sample = (sample.astype(np.float32) - 128) / 128
            return torch.from_numpy(sample), torch.from_numpy(label), coord, torch.tensor(malignancy, dtype=torch.int)

        elif self.phase == 'test':
            imgs = np.load(self.filenames[idx])
            bboxes = self.sample_bboxes[idx]
            nz, nh, nw = imgs.shape[1:]
            # Pad every spatial axis up to a multiple of the stride.
            pz = int(np.ceil(float(nz) / self.stride)) * self.stride
            ph = int(np.ceil(float(nh) / self.stride)) * self.stride
            pw = int(np.ceil(float(nw) / self.stride)) * self.stride
            imgs = np.pad(imgs, [[0, 0], [0, pz - nz], [0, ph - nh], [0, pw - nw]], 'constant',
                          constant_values=self.pad_value)
            xx, yy, zz = np.meshgrid(np.linspace(-0.5, 0.5, imgs.shape[1] // self.stride),
                                     np.linspace(-0.5, 0.5, imgs.shape[2] // self.stride),
                                     np.linspace(-0.5, 0.5, imgs.shape[3] // self.stride), indexing='ij')
            coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32')
            imgs, nzhw = self.split_comber.split(imgs)
            coord2, nzhw2 = self.split_comber.split(coord,
                                                    side_len=self.split_comber.side_len // self.stride,
                                                    max_stride=self.split_comber.max_stride // self.stride,
                                                    margin=self.split_comber.margin // self.stride)
            assert np.all(nzhw == nzhw2)
            imgs = (imgs.astype(np.float32) - 128) / 128
            return torch.from_numpy(imgs.astype(np.float32)), bboxes, torch.from_numpy(coord2.astype(np.float32)), np.array(nzhw)

    def __len__(self):
        if self.phase == 'train':
            # Inflate length so ~r_rand of drawn samples are random crops.
            return int(len(self.bboxes) // (1 - self.r_rand))
        elif self.phase == 'val':
            return len(self.bboxes)
        else:
            return len(self.sample_bboxes)
337 | else: 338 | start.append(int(target[i] - crop_size[i] / 2 + np.random.randint(-bound_size / 2, bound_size / 2))) 339 | 340 | normstart = np.array(start).astype('float32') / np.array(imgs.shape[1:]) - 0.5 341 | normsize = np.array(crop_size).astype('float32') / np.array(imgs.shape[1:]) 342 | xx, yy, zz = np.meshgrid(np.linspace(normstart[0], normstart[0] + normsize[0], self.crop_size[0] // self.stride), 343 | np.linspace(normstart[1], normstart[1] + normsize[1], self.crop_size[1] // self.stride), 344 | np.linspace(normstart[2], normstart[2] + normsize[2], self.crop_size[2] // self.stride), 345 | indexing='ij') 346 | coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32') 347 | 348 | pad = [] 349 | pad.append([0, 0]) 350 | 351 | for i in range(3): 352 | leftpad = max(0, -start[i]) 353 | rightpad = max(0, start[i] + crop_size[i] - imgs.shape[i + 1]) 354 | pad.append([leftpad, rightpad]) 355 | 356 | crop = imgs[:, 357 | max(start[0], 0):min(start[0] + crop_size[0], imgs.shape[1]), 358 | max(start[1], 0):min(start[1] + crop_size[1], imgs.shape[2]), 359 | max(start[2], 0):min(start[2] + crop_size[2], imgs.shape[3])] 360 | crop = np.pad(crop, pad, 'constant', constant_values=self.pad_value) 361 | 362 | for i in range(3): 363 | target[i] = target[i] - start[i] 364 | for i in range(len(bboxes)): 365 | for j in range(3): 366 | bboxes[i][j] = bboxes[i][j] - start[j] 367 | 368 | if isScale: 369 | with warnings.catch_warnings(): 370 | warnings.simplefilter("ignore") 371 | crop = zoom(crop, [1, scale, scale, scale], order=2) 372 | newpad = self.crop_size[0] - crop.shape[1:][0] 373 | 374 | if newpad < 0: 375 | crop = crop[:, :-newpad, :-newpad, :-newpad] 376 | elif newpad > 0: 377 | pad2 = [[0, 0], [0, newpad], [0, newpad], [0, newpad]] 378 | crop = np.pad(crop, pad2, 'constant', constant_values=self.pad_value) 379 | 380 | for i in range(4): 381 | target[i] = target[i] * scale 382 | for i in range(len(bboxes)): 383 | for j 
class LabelMapping(object):
    """Convert a crop's ground-truth boxes into the dense anchor label tensor.

    The output has shape (Z/stride, H/stride, W/stride, #anchors, 5) where
    channel 0 is the class label (1 positive, -1 negative, 0 ignored) and
    channels 1-4 hold the (dz, dh, dw, dd) regression targets for the single
    positive cell.  Consumed by ``Loss``/``Loss_recon``.
    """

    def __init__(self, config, phase):
        self.stride = np.array(config['stride']) #4
        self.num_neg = int(config['num_neg']) #800
        self.th_neg = config['th_neg'] #0.02
        self.anchors = np.asarray(config['anchors'])
        self.phase = phase
        # Validation uses a stricter positive-IoU threshold than training.
        if phase == 'train':
            self.th_pos = config['th_pos_train'] #0.5
        elif phase == 'val':
            self.th_pos = config['th_pos_val'] #1

    def __call__(self, input_size, target, bboxes, filename):
        """Build the label tensor for one crop.

        input_size: spatial shape of the crop; must be divisible by stride.
        target:     (z, h, w, d) of the sampled nodule, or NaNs for a random
                    (target-free) crop.
        bboxes:     all ground-truth boxes inside the crop.
        filename:   only used to make the divisibility assertion readable.
        """
        stride = self.stride
        num_neg = self.num_neg
        th_neg = self.th_neg
        anchors = self.anchors
        th_pos = self.th_pos

        output_size = []
        for i in range(3):
            assert(input_size[i] % stride == 0), 'input_size[{}]={}, stride={}, filename={}'.format(i, input_size[i], str(stride), filename)
            output_size.append(input_size[i] // stride)

        # Initialize all grid labels to -1
        label = -1 * np.ones(output_size + [len(anchors), 5], np.float32) #(24, 24, 24, #anchor, 5)
        # Voxel coordinate of each output cell's center along every axis.
        offset = ((stride.astype('float')) - 1) / 2
        oz = np.arange(offset, offset + stride * (output_size[0] - 1) + 1, stride)
        oh = np.arange(offset, offset + stride * (output_size[1] - 1) + 1, stride)
        ow = np.arange(offset, offset + stride * (output_size[2] - 1) + 1, stride)

        # Find the positively-labeled grids in bboxes, and set them to 0
        # (cells overlapping any ground-truth box above th_neg are excluded
        # from the negative pool).
        for bbox in bboxes:
            for i, anchor in enumerate(anchors):
                iz, ih, iw = select_samples(bbox, anchor, th_neg, oz, oh, ow)
                label[iz, ih, iw, i, 0] = 0

        if self.phase == 'train' and self.num_neg > 0:
            # Now, all grids which are labeled as -1 are negative grids.
            neg_z, neg_h, neg_w, neg_a = np.where(label[:, :, :, :, 0] == -1)

            # Select num_neg(=800) of them, set as -1, leave all others(including positive grid) to 0
            neg_idcs = random.sample(range(len(neg_z)), min(num_neg, len(neg_z)))
            neg_z, neg_h, neg_w, neg_a = neg_z[neg_idcs], neg_h[neg_idcs], neg_w[neg_idcs], neg_a[neg_idcs]
            label[:, :, :, :, 0] = 0
            label[neg_z, neg_h, neg_w, neg_a, 0] = -1

        # If no target in this crop, return negative grids(labeled as -1) only.
        if np.isnan(target[0]):
            return label

        # Locate the target on the grids
        iz, ih, iw, ia = [], [], [], []
        for i, anchor in enumerate(anchors):
            iiz, iih, iiw = select_samples(target, anchor, th_pos, oz, oh, ow)
            iz.append(iiz)
            ih.append(iih)
            iw.append(iiw)
            ia.append(i * np.ones((len(iiz),), np.int64))
        iz = np.concatenate(iz, 0)
        ih = np.concatenate(ih, 0)
        iw = np.concatenate(iw, 0)
        ia = np.concatenate(ia, 0)

        if len(iz) == 0:
            # No cell/anchor pair reaches th_pos: fall back to the nearest
            # cell and the anchor closest in (log) diameter.
            pos = []
            for i in range(3):
                pos.append(max(0, int(np.round((target[i] - offset) / stride))))
            idx = np.argmin(np.abs(np.log(target[3] / anchors)))
            pos.append(idx)
        else: # randomly choose one if there is more than one positive grid
            idx = random.sample(range(len(iz)), 1)[0]
            pos = [iz[idx], ih[idx], iw[idx], ia[idx]]

        # Calculate the difference ratio of (z,h,w,d) between target and positive grid(=pos) relative to anchor
        dz = (target[0] - oz[pos[0]]) / anchors[pos[3]]
        dh = (target[1] - oh[pos[1]]) / anchors[pos[3]]
        dw = (target[2] - ow[pos[2]]) / anchors[pos[3]]
        dd = np.log(target[3] / anchors[pos[3]])
        label[pos[0], pos[1], pos[2], pos[3], :] = [1, dz, dh, dw, dd]
        return label
validrot: 476 | newtarget = np.copy(target) 477 | angle1 = (np.random.rand()-0.5)*20 478 | size = np.array(sample.shape[2:4]).astype('float') 479 | rotmat = np.array([[np.cos(angle1/180*np.pi),-np.sin(angle1/180*np.pi)],[np.sin(angle1/180*np.pi),np.cos(angle1/180*np.pi)]]) 480 | newtarget[1:3] = np.dot(rotmat,target[1:3]-size/2)+size/2 481 | if np.all(newtarget[:3]>target[3]) and np.all(newtarget[:3]< np.array(sample.shape[1:4])-newtarget[3]): 482 | validrot = True 483 | target = newtarget 484 | sample = rotate(sample,angle1,axes=(2,3),reshape=False) 485 | coord = rotate(coord,angle1,axes=(2,3),reshape=False) 486 | for box in bboxes: 487 | box[1:3] = np.dot(rotmat,box[1:3]-size/2)+size/2 488 | else: 489 | counter += 1 490 | if counter ==3: 491 | break 492 | if ifswap: 493 | if sample.shape[1]==sample.shape[2] and sample.shape[1]==sample.shape[3]: 494 | axisorder = np.random.permutation(3) 495 | sample = np.transpose(sample,np.concatenate([[0],axisorder+1])) 496 | coord = np.transpose(coord,np.concatenate([[0],axisorder+1])) 497 | target[:3] = target[:3][axisorder] 498 | bboxes[:,:3] = bboxes[:,:3][:,axisorder] 499 | 500 | if ifflip: 501 | flipid = np.array([1,np.random.randint(2),np.random.randint(2)])*2-1 502 | sample = np.ascontiguousarray(sample[:,::flipid[0],::flipid[1],::flipid[2]]) 503 | coord = np.ascontiguousarray(coord[:,::flipid[0],::flipid[1],::flipid[2]]) 504 | for ax in range(3): 505 | if flipid[ax]==-1: 506 | target[ax] = np.array(sample.shape[ax+1])-target[ax] 507 | bboxes[:,ax]= np.array(sample.shape[ax+1])-bboxes[:,ax] 508 | 509 | # normal = np.random.uniform(-0.1, 0.1, sample.shape) 510 | # if np.random.randint(0, 2): 511 | # sample = sample + normal 512 | 513 | return sample, target, bboxes, coord 514 | 515 | def select_samples(bbox, anchor, th, oz, oh, ow): 516 | z, h, w, d = bbox 517 | 518 | if d == 0: 519 | return np.zeros((0,), np.int64), np.zeros((0,), np.int64), np.zeros((0,), np.int64) 520 | 521 | max_overlap = min(d, anchor) 522 | 
min_overlap = np.power(max(d, anchor), 3) * th / max_overlap / max_overlap 523 | 524 | if min_overlap > max_overlap: 525 | return np.zeros((0,), np.int64), np.zeros((0,), np.int64), np.zeros((0,), np.int64) 526 | else: 527 | s = z - 0.5 * np.abs(d - anchor) - (max_overlap - min_overlap) 528 | e = z + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap) 529 | mz = np.logical_and(oz >= s, oz <= e) 530 | iz = np.where(mz)[0] 531 | 532 | s = h - 0.5 * np.abs(d - anchor) - (max_overlap - min_overlap) 533 | e = h + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap) 534 | mh = np.logical_and(oh >= s, oh <= e) 535 | ih = np.where(mh)[0] 536 | 537 | s = w - 0.5 * np.abs(d - anchor) - (max_overlap - min_overlap) 538 | e = w + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap) 539 | mw = np.logical_and(ow >= s, ow <= e) 540 | iw = np.where(mw)[0] 541 | 542 | if len(iz) == 0 or len(ih) == 0 or len(iw) == 0: 543 | return np.zeros((0,), np.int64), np.zeros((0,), np.int64), np.zeros((0,), np.int64) 544 | 545 | lz, lh, lw = len(iz), len(ih), len(iw) 546 | iz = iz.reshape((-1, 1, 1)) 547 | ih = ih.reshape((1, -1, 1)) 548 | iw = iw.reshape((1, 1, -1)) 549 | iz = np.tile(iz, (1, lh, lw)).reshape((-1)) 550 | ih = np.tile(ih, (lz, 1, lw)).reshape((-1)) 551 | iw = np.tile(iw, (lz, lh, 1)).reshape((-1)) 552 | 553 | centers = np.concatenate([ 554 | oz[iz].reshape((-1, 1)), 555 | oh[ih].reshape((-1, 1)), 556 | ow[iw].reshape((-1, 1))], axis=1) 557 | 558 | r0 = anchor / 2 559 | s0 = centers - r0 560 | e0 = centers + r0 561 | 562 | r1 = d / 2 563 | s1 = bbox[:3] - r1 564 | s1 = s1.reshape((1, -1)) 565 | e1 = bbox[:3] + r1 566 | e1 = e1.reshape((1, -1)) 567 | 568 | overlap = np.maximum(0, np.minimum(e0, e1) - np.maximum(s0, s1)) 569 | 570 | intersection = overlap[:, 0] * overlap[:, 1] * overlap[:, 2] 571 | union = anchor * anchor * anchor + d * d * d - intersection 572 | 573 | iou = intersection / union 574 | mask = iou >= th 575 | 576 | iz = iz[mask] 577 | ih = ih[mask] 578 | iw = 
iw[mask] 579 | 580 | return iz, ih, iw 581 | 582 | def collate(batch): 583 | if torch.is_tensor(batch[0]): 584 | return [b.unsqueeze(0) for b in batch] 585 | elif isinstance(batch[0], np.ndarray): 586 | return batch 587 | elif isinstance(batch[0], int): 588 | return torch.LongTensor(batch) 589 | elif isinstance(batch[0], collections.Iterable): 590 | transposed = zip(*batch) 591 | return [collate(samples) for samples in transposed] 592 | 593 | -------------------------------------------------------------------------------- /json/Read.md: -------------------------------------------------------------------------------- 1 | Put your train, val, and test.json (Case ID) 2 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | 9 | class PostRes(nn.Module): 10 | def __init__(self, n_in, n_out, stride=1): 11 | super(PostRes, self).__init__() 12 | self.conv1 = nn.Conv3d(n_in, n_out, kernel_size=3, stride=stride, padding=1) 13 | self.bn1 = nn.BatchNorm3d(n_out) 14 | self.relu = nn.ReLU(inplace=True) 15 | self.conv2 = nn.Conv3d(n_out, n_out, kernel_size=3, padding=1) 16 | self.bn2 = nn.BatchNorm3d(n_out) 17 | 18 | if stride != 1 or n_out != n_in: 19 | self.shortcut = nn.Sequential( 20 | nn.Conv3d(n_in, n_out, kernel_size=1, stride=stride), 21 | nn.BatchNorm3d(n_out)) 22 | else: 23 | self.shortcut = None 24 | 25 | def forward(self, x): 26 | residual = x 27 | if self.shortcut is not None: 28 | residual = self.shortcut(x) 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | out = self.conv2(out) 33 | out = self.bn2(out) 34 | 35 | out += residual 36 | out = self.relu(out) 37 | return out 38 | 39 | 40 | class ResidualBlock(nn.Module): 41 | """ 42 | Simple residual block 43 | Design inspiraed by 
Kaiming's 'Identity Mappings in Deep Residual Networks' 44 | https://arxiv.org/pdf/1603.05027v3.pdf 45 | """ 46 | 47 | def __init__(self, last_planes, in_planes, out_planes, kernel_size, stride, padding, dilation, debug=False): 48 | super(ResidualBlock, self).__init__() 49 | self.debug = debug 50 | self.last_planes = last_planes 51 | self.in_planes = in_planes 52 | self.out_planes = out_planes 53 | 54 | self.n0 = nn.InstanceNorm3d(last_planes) 55 | self.conv1 = nn.Conv3d(last_planes, in_planes, kernel_size=1, bias=False) 56 | self.n1 = nn.InstanceNorm3d(in_planes) 57 | self.conv2 = nn.Conv3d(in_planes, in_planes, kernel_size=kernel_size, stride=stride, 58 | padding=padding, dilation=dilation, bias=False) 59 | self.n2 = nn.InstanceNorm3d(in_planes) 60 | self.conv3 = nn.Conv3d(in_planes, out_planes, kernel_size=1, bias=False) 61 | self.shortcut = nn.Conv3d(last_planes, out_planes, kernel_size=1, stride=stride, bias=False) 62 | 63 | def forward(self, x): 64 | if self.debug: print(f'bottleneck: x={x.size()}') 65 | if self.debug: print(f'last_planes={self.last_planes} ' 66 | f'in_planes={self.in_planes} ' 67 | f'out_planes={self.out_planes} ' 68 | ) 69 | 70 | out = F.relu(self.n0(x)) 71 | if self.debug: print(f'ResidualBlock:x={out.size()}') 72 | 73 | out = F.relu(self.n1(self.conv1(out))) 74 | if self.debug: print(f'ResidualBlock: conv1={out.size()}') 75 | 76 | out = F.relu(self.n2(self.conv2(out))) 77 | if self.debug: print(f'ResidualBlock: conv2={out.size()}') 78 | 79 | out = self.conv3(out) 80 | if self.debug: print(f'ResidualBlock: conv3={out.size()}') 81 | 82 | x = self.shortcut(x) 83 | if self.debug: print(f'ResidualBlock: shortcut={x.size()}') 84 | 85 | out = out + x 86 | if self.debug: print(f'ResidualBlock: conv3+shortcut={out.size()}') 87 | 88 | return out 89 | 90 | 91 | class GetPBB(object): 92 | def __init__(self, config): 93 | self.stride = config['stride'] 94 | self.anchors = np.asarray(config['anchors']) 95 | 96 | def __call__(self, output, thresh=-3, 
ismask=False): 97 | stride = self.stride 98 | anchors = self.anchors 99 | output = np.copy(output) 100 | offset = (float(stride) - 1) / 2 101 | output_size = output.shape 102 | oz = np.arange(offset, offset + stride * (output_size[0] - 1) + 1, stride) 103 | oh = np.arange(offset, offset + stride * (output_size[1] - 1) + 1, stride) 104 | ow = np.arange(offset, offset + stride * (output_size[2] - 1) + 1, stride) 105 | 106 | output[:, :, :, :, 1] = oz.reshape((-1, 1, 1, 1)) + output[:, :, :, :, 1] * anchors.reshape((1, 1, 1, -1)) 107 | output[:, :, :, :, 2] = oh.reshape((1, -1, 1, 1)) + output[:, :, :, :, 2] * anchors.reshape((1, 1, 1, -1)) 108 | output[:, :, :, :, 3] = ow.reshape((1, 1, -1, 1)) + output[:, :, :, :, 3] * anchors.reshape((1, 1, 1, -1)) 109 | output[:, :, :, :, 4] = np.exp(output[:, :, :, :, 4]) * anchors.reshape((1, 1, 1, -1)) 110 | mask = output[..., 0] > thresh 111 | xx, yy, zz, aa = np.where(mask) 112 | 113 | output = output[xx, yy, zz, aa] 114 | if ismask: 115 | return output, [xx, yy, zz, aa] 116 | else: 117 | return output 118 | # output = output[output[:, 0] >= self.conf_th] 119 | # bboxes = nms(output, self.nms_th) 120 | 121 | 122 | class CReLU(nn.Module): 123 | """ 124 | Understanding and Improving Convolutional Neural Networks via Concatenated Rectified Linear Units 125 | arXiv:1603.05201 126 | https://arxiv.org/pdf/1603.05201.pdf 127 | """ 128 | def __init__(self): 129 | super(CReLU, self).__init__() 130 | 131 | def forward(self, x): 132 | return torch.cat((F.relu(x), F.relu(-x)), 1) 133 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | #coding=utf-8 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | 10 | class Loss(nn.Module): 11 | def __init__(self, num_hard=0, class_loss='BCELoss', average=True): 12 | super().__init__() 13 | self.sigmoid = 
class Loss(nn.Module):
    """Detector loss: anchor classification + SmoothL1 box regression.

    Consumes the (.., #anchors, 5) label tensor produced by LabelMapping:
    channel 0 is 1 (positive), -1 (negative) or 0 (ignored); channels 1-4 are
    the regression targets of the single positive anchor.
    """

    def __init__(self, num_hard=0, class_loss='BCELoss', average=True):
        # num_hard:   negatives kept per sample by online hard-example mining
        #             (0 disables mining).
        # class_loss: key selecting the classification loss below.
        # average:    mean vs. sum reduction for the selected loss.
        super().__init__()
        self.sigmoid = nn.Sigmoid()
        self.regress_loss = nn.SmoothL1Loss()
        self.num_hard = num_hard

        loss_dict = nn.ModuleDict(
            {'BCELoss': nn.BCELoss(),
             'MarginLoss': MarginLoss(size_average=average),
             'MSELoss': nn.MSELoss(reduction='mean' if average else 'sum'),
             'FocalMarginLoss': FocalMarginLoss(size_average=average)
             }
        )
        self.classify_loss = loss_dict[class_loss]

    @staticmethod
    def hard_mining(neg_output, neg_labels, num_hard):
        # Keep the num_hard negatives with the highest predicted scores,
        # i.e. the most confidently wrong ones.
        _, idcs = torch.topk(neg_output, min(num_hard, len(neg_output)))
        neg_output = torch.index_select(neg_output, 0, idcs)
        neg_labels = torch.index_select(neg_labels, 0, idcs)
        return neg_output, neg_labels

    def forward(self, output, labels, train=True):
        """Return [loss, cls_loss, 4x regress_loss, pos/neg accuracy counts]."""
        batch_size = labels.size(0)
        # Flatten all grid cells and anchors into rows of 5.
        output = output.view(-1, 5)
        labels = labels.view(-1, 5)

        # positive grids are labeled as 1
        pos_idcs = labels[:, 0] > 0.5
        pos_idcs = pos_idcs.unsqueeze(1).expand(pos_idcs.size(0), 5)
        pos_output = output[pos_idcs].view(-1, 5)
        pos_labels = labels[pos_idcs].view(-1, 5)

        # negative grids are labeled as -1
        neg_idcs = labels[:, 0] < -0.5
        neg_output = output[:, 0][neg_idcs]
        neg_labels = labels[:, 0][neg_idcs]

        if self.num_hard > 0 and train:
            # Pick up the grid with the most wrong output (ie. highest output >> -1)
            neg_output, neg_labels = self.hard_mining(neg_output, neg_labels, self.num_hard * batch_size)

        neg_prob = neg_output

        if len(pos_output) > 0:
            pos_prob = pos_output[:, 0]

            pz, ph, pw, pd = pos_output[:, 1], pos_output[:, 2], pos_output[:, 3], pos_output[:, 4]
            lz, lh, lw, ld = pos_labels[:, 1], pos_labels[:, 2], pos_labels[:, 3], pos_labels[:, 4]

            regress_losses = [
                self.regress_loss(pz, lz),
                self.regress_loss(ph, lh),
                self.regress_loss(pw, lw),
                self.regress_loss(pd, ld)]
            regress_losses_data = [l.item() for l in regress_losses]
            # neg_labels are -1, so `neg_labels + 1` gives the 0 target the
            # classification loss expects.
            classify_loss = 0.5 * self.classify_loss(pos_prob, pos_labels[:, 0]) + \
                            0.5 * self.classify_loss(neg_prob, neg_labels + 1)
            # 0.9 matches MarginLoss's m_plus margin.  NOTE(review):
            # Loss_recon uses 0.5 here — confirm the difference is intended.
            pos_correct = (pos_prob.data >= 0.9).sum()
            pos_total = len(pos_prob)

        else:
            # No positive anchors in this batch: classification on negatives only.
            regress_losses = [0, 0, 0, 0]
            classify_loss = 0.5 * self.classify_loss(neg_prob, neg_labels + 1)
            pos_correct = 0
            pos_total = 0
            regress_losses_data = [0, 0, 0, 0]
        classify_loss_data = classify_loss.item()

        # Total loss = classification + the four coordinate regressions.
        loss = classify_loss
        for regress_loss in regress_losses:
            loss += regress_loss

        neg_correct = (neg_prob.data < 0.9).sum()
        neg_total = len(neg_prob)

        return [loss, classify_loss_data] + regress_losses_data + [pos_correct, pos_total, neg_correct, neg_total]
class Loss_recon(nn.Module):
    """Detector loss with an added input-reconstruction term.

    Same classification + regression structure as ``Loss``, plus an MSE
    between the network's reconstruction and the input image, scaled by
    ``recon_loss_scale``.
    """

    def __init__(self, num_hard=0, class_loss='MarginLoss', recon_loss_scale=1e-6, average=True):
        # num_hard:         negatives kept per sample by hard-example mining.
        # class_loss:       key selecting the classification loss below.
        # recon_loss_scale: weight of the reconstruction MSE in the total loss.
        # average:          mean vs. sum reduction for the selected losses.
        super().__init__()
        self.sigmoid = nn.Sigmoid()
        self.regress_loss = nn.SmoothL1Loss()
        self.num_hard = num_hard
        self.recon_loss_scale = recon_loss_scale
        self.reconstruction_loss = nn.MSELoss(reduction='mean' if average else 'sum')

        loss_dict = nn.ModuleDict(
            {'BCELoss': nn.BCELoss(),
             'MarginLoss': MarginLoss(size_average=average),
             'MSELoss': nn.MSELoss(reduction='mean' if average else 'sum'),
             'FocalMarginLoss': FocalMarginLoss(size_average=average)
             }
        )
        self.classify_loss = loss_dict[class_loss]

    @staticmethod
    def hard_mining(neg_output, neg_labels, num_hard):
        # Keep the num_hard negatives with the highest predicted scores,
        # i.e. the most confidently wrong ones.
        _, idcs = torch.topk(neg_output, min(num_hard, len(neg_output)))
        neg_output = torch.index_select(neg_output, 0, idcs)
        neg_labels = torch.index_select(neg_labels, 0, idcs)
        return neg_output, neg_labels

    def forward(self, output, labels, images, reconstructions, train=True):
        """Return [loss, cls, 4x regress, counts..., recon_loss]."""
        batch_size = labels.size(0)
        # Flatten all grid cells and anchors into rows of 5.
        output = output.view(-1, 5)
        labels = labels.view(-1, 5)

        # positive grids are labeled as 1
        pos_idcs = labels[:, 0] > 0.5
        pos_idcs = pos_idcs.unsqueeze(1).expand(pos_idcs.size(0), 5)
        pos_output = output[pos_idcs].view(-1, 5)
        pos_labels = labels[pos_idcs].view(-1, 5)

        # negative grids are labeled as -1
        neg_idcs = labels[:, 0] < -0.5
        neg_output = output[:, 0][neg_idcs]
        neg_labels = labels[:, 0][neg_idcs]

        if self.num_hard > 0 and train:
            # Pick up the grid with the most wrong output (ie. highest output >> -1)
            neg_output, neg_labels = self.hard_mining(neg_output, neg_labels, self.num_hard * batch_size)

        neg_prob = neg_output

        if len(pos_output) > 0: # there are positive anchors in this crop
            pos_prob = pos_output[:, 0]

            pz, ph, pw, pd = pos_output[:, 1], pos_output[:, 2], pos_output[:, 3], pos_output[:, 4]
            lz, lh, lw, ld = pos_labels[:, 1], pos_labels[:, 2], pos_labels[:, 3], pos_labels[:, 4]

            regress_losses = [
                self.regress_loss(pz, lz),
                self.regress_loss(ph, lh),
                self.regress_loss(pw, lw),
                self.regress_loss(pd, ld)]
            regress_losses_data = [l.item() for l in regress_losses]
            # neg_labels are -1, so `neg_labels + 1` gives the 0 target the
            # classification loss expects.
            classify_loss = .5 * self.classify_loss(pos_prob, pos_labels[:, 0]) + \
                            .5 * self.classify_loss(neg_prob, neg_labels + 1)
            pos_correct = (pos_prob.data >= 0.5).sum()
            pos_total = len(pos_prob)

        else:
            # No positive anchors in this batch: classification on negatives only.
            regress_losses = [0, 0, 0, 0]
            classify_loss = 0.5 * self.classify_loss(neg_prob, neg_labels + 1)
            pos_correct = 0
            pos_total = 0
            regress_losses_data = [0, 0, 0, 0]
        classify_loss_data = classify_loss.item()

        # Total loss = classify loss + regress_loss of z, h, w, d + recon_loss * recon_loss_scale
        loss = classify_loss
        for regress_loss in regress_losses:
            loss += regress_loss

        reconstruction_loss = self.recon_loss_scale * self.reconstruction_loss(reconstructions, images)
        reconstruction_loss_data = reconstruction_loss.item()
        loss += reconstruction_loss

        neg_correct = (neg_prob.data < 0.5).sum()
        neg_total = len(neg_prob)

        return [loss, classify_loss_data] + regress_losses_data + [pos_correct, pos_total, neg_correct, neg_total] + [reconstruction_loss_data]
super(FocalLoss, self).__init__() 182 | self.num_classes = num_classes # exclude the background 183 | self.alpha = alpha 184 | self.gamma = gamma 185 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 186 | 187 | def one_hot_embedding(self, labels, num_classes): 188 | '''Embedding labels to one-hot form. 189 | 190 | Args: 191 | labels: (LongTensor) class labels, sized [N,]. 192 | num_classes: (int) number of classes. 193 | 194 | Returns: 195 | (tensor) encoded labels, sized [N,#classes]. 196 | ''' 197 | y = torch.eye(num_classes) # [D,D] 198 | return y[labels] # [N,D] 199 | 200 | def focal_loss(self, x, y): 201 | '''Focal loss. 202 | 203 | Args: 204 | x: (tensor) sized [N] --> [N, 1]. 205 | y: (tensor) sized [N]. 206 | 207 | Return: 208 | (tensor) focal loss. 209 | ''' 210 | n = x.size(0) 211 | x = x.view(n, -1) # (N, 1) 212 | 213 | # Convert the label to one-hot encoded target 214 | t = self.one_hot_embedding(y.type(torch.long), 1 + self.num_classes) 215 | t = t[:, 1:] # exclude background 216 | t = t.view(n, -1) 217 | t = t.to(self.device) # [N, num_classes] 218 | 219 | # Calculate weight from target and prediction distribution 220 | p = x.sigmoid() 221 | pt = p * t + (1 - p) * (1 - t) # pt = p if t = 1 else 1-p 222 | w = self.alpha * t + (1 - self.alpha) * (1 - t) # w = alpha if t = 1 else 1-alpha 223 | w = w * (1 - pt).pow(self.gamma) 224 | return F.binary_cross_entropy_with_logits(x, t, w, size_average=False) 225 | 226 | def forward(self, output, labels): 227 | output = output.view(-1, 1) 228 | labels = labels.view(-1, 1) 229 | classify_loss = self.focal_loss(output, labels) 230 | return classify_loss 231 | 232 | 233 | class MarginLoss(nn.Module): 234 | def __init__(self, num_classes=1, size_average=True, loss_lambda=0.5): 235 | ''' 236 | Margin loss for digit existence 237 | Eq. 
(4): L_k = T_k * max(0, m+ - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2 238 | 239 | Args: 240 | size_average: should the losses be averaged (True) or summed (False) over observations for each minibatch. 241 | loss_lambda: parameter for down-weighting the loss for missing digits 242 | num_classes: number of classes (exclude the background) 243 | ''' 244 | super().__init__() 245 | self.size_average = size_average 246 | self.m_plus = 0.9 247 | self.m_minus = 0.1 248 | self.loss_lambda = loss_lambda 249 | self.num_classes = num_classes 250 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 251 | 252 | def one_hot_embedding(self, labels, num_classes): 253 | '''Embedding labels to one-hot form. 254 | 255 | Args: 256 | labels: (LongTensor) class labels, sized [N,]. 257 | num_classes: (int) number of classes. 258 | 259 | Returns: 260 | (tensor) encoded labels, sized [N, #classes]. 261 | ''' 262 | y = torch.eye(num_classes) # [D,D] 263 | return y[labels] # [N,D] 264 | 265 | def forward(self, inputs, labels): 266 | """ 267 | :param inputs: [n, num_classes] with one-hot encoded 268 | :param labels: [n,] 269 | """ 270 | n = inputs.size(0) 271 | inputs = inputs.view(n, -1) 272 | 273 | # Convert the label to one-hot encoded target 274 | labels = self.one_hot_embedding(labels.type(torch.long), 1 + self.num_classes) 275 | labels = labels[:, 1:] # exclude background 276 | labels = labels.view(n, -1) 277 | labels = labels.to(self.device) # (N, num_classes) 278 | 279 | left = labels * F.relu(self.m_plus - inputs)**2 280 | right = self.loss_lambda * (1 - labels) * F.relu(inputs - self.m_minus)**2 281 | L_k = left + right 282 | 283 | # Summation of all classes 284 | L_k = L_k.sum(dim=-1) 285 | 286 | if self.size_average: 287 | return L_k.mean() 288 | else: 289 | return L_k.sum() 290 | 291 | 292 | class FocalMarginLoss(nn.Module): 293 | def __init__(self, num_classes=1, alpha=0.25, gamma=2, size_average=True): 294 | ''' 295 | Focal loss binds with 
class FocalMarginLoss(nn.Module):
    def __init__(self, num_classes=1, alpha=0.25, gamma=2, size_average=True):
        '''
        Focal loss binds with Margin loss for one-hot label
        Eq. (4): L_k = T_k * max(0, m+ - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2

        Args:
            size_average: should the losses be averaged (True) or summed (False) over observations for each minibatch.
            loss_lambda: parameter for down-weighting the loss for missing digits
            num_classes: number of classes (exclude the background)
        '''
        super().__init__()
        self.size_average = size_average
        self.m_plus = 0.9       # margin for present classes
        self.m_minus = 0.1      # margin for absent classes
        self.alpha = alpha      # focal weight of the positive class
        self.gamma = gamma      # focal focusing exponent
        self.num_classes = num_classes
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def one_hot_embedding(self, labels, num_classes):
        '''Embedding labels to one-hot form.

        Args:
            labels: (LongTensor) class labels, sized [N,].
            num_classes: (int) number of classes.

        Returns:
            (tensor) encoded labels, sized [N,#classes].
        '''
        y = torch.eye(num_classes)  # [D,D]
        return y[labels]  # [N,D]

    def forward(self, inputs, labels):
        """
        :param inputs: (n, num_classes) with one-hot encoded
        :param labels: (n,) binary class labels
        :return: loss value
        """
        n = inputs.size(0)
        inputs = inputs.view(n, -1)

        # Convert the label to one-hot encoded target
        t = self.one_hot_embedding(labels.type(torch.long), 1 + self.num_classes)
        t = t[:, 1:]  # exclude background
        t = t.view(n, -1)
        t = t.to(self.device)  # (N, num_classes)

        # Calculate weight from target and prediction distribution
        p = inputs  # capsule's output has already squashed to 0-1, so sigmoid is not needed.
        pt = p * t + (1 - p) * (1 - t)  # pt = p if t = 1 else 1-p
        w = self.alpha * t + (1 - self.alpha) * (1 - t)  # w = alpha if t = 1 else 1-alpha
        w = w * (1 - pt).pow(self.gamma)

        # NOTE(review): the positive term below uses the raw `labels` (not the
        # one-hot `t`) and carries no focal weight `w`, while the negative
        # term is focal-weighted.  This only works for binary labels in
        # {0, 1} — confirm the asymmetry is intended.
        labels = labels.view(n, 1)
        left = labels * F.relu(self.m_plus - inputs)**2
        right = w * (1 - labels) * F.relu(inputs - self.m_minus)**2
        L_k = left + right

        # Summation of all classes
        L_k = L_k.sum(dim=-1)

        if self.size_average:
            return L_k.mean()
        else:
            return L_k.sum()
epoch number (useful on restarts)') 36 | parser.add_argument('-b', '--batch-size', default=16, type=int, 37 | metavar='N', help='mini-batch size (default: 16)') 38 | parser.add_argument('--lr', '--learning-rate', default=1e-2, type=float, 39 | metavar='LR', help='initial learning rate') 40 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 41 | help='momentum') 42 | parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float, 43 | metavar='W', help='weight decay (default: 1e-4)') 44 | parser.add_argument('--save-freq', default='10', type=int, metavar='S', 45 | help='save frequency') 46 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 47 | help='path to latest checkpoint (default: none)') 48 | parser.add_argument('--save-dir', default=None, type=str, metavar='SAVE', 49 | help='directory to save checkpoint (default: none)') 50 | parser.add_argument('--test', default=0, type=int, metavar='TEST', 51 | help='1 do test evaluation, 0 not') 52 | parser.add_argument('--split', default=8, type=int, metavar='SPLIT', 53 | help='In the test phase, split the image to 8 parts') 54 | parser.add_argument('--gpu', default='all', type=str, metavar='N', 55 | help='use gpu, "all" or "0,1,2,3" or "0,2" etc') 56 | parser.add_argument('--n_test', default=4, type=int, metavar='N', 57 | help='number of gpu for test') 58 | parser.add_argument('--cross', default=None, type=str, metavar='N', 59 | help='which data cross be used') 60 | args = parser.parse_args() 61 | best_loss = 100.0 62 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 63 | 64 | use_tqdm = True 65 | 66 | def get_lr(epoch): 67 | if epoch <= 10: 68 | lr = 0.002 69 | elif epoch <= 200: 70 | lr = 0.001 71 | else: 72 | lr = 0.0001 73 | return lr 74 | 75 | def main(): 76 | global args, best_loss 77 | datadir = config_training['preprocess_result_path'] 78 | 79 | train_id = './json/' + args.cross + '/LUNA_train.json' 80 | val_id = './json/' + args.cross + 
def main():
    """Entry point: build the model, then either run test-phase inference
    (``--test 1``) or the train/validate loop with checkpointing.

    Reads the module-level ``args`` (argparse namespace) and ``best_loss``;
    updates ``best_loss`` across epochs.
    """
    global args, best_loss
    datadir = config_training['preprocess_result_path']

    # Case-ID splits for the selected cross-validation fold (--cross).
    train_id = './json/' + args.cross + '/LUNA_train.json'
    val_id = './json/' + args.cross + '/LUNA_val.json'
    test_id = './json/' + args.cross + '/LUNA_test.json'

    # Fixed seed for reproducibility; benchmark off for determinism.
    torch.manual_seed(0)
    cudnn.benchmark = False

    # Load model
    print("=> loading model '{}'".format(args.model))
    model_root = 'net'
    model = import_module('{}.{}'.format(model_root, args.model))
    config, net, criterion, get_pbb = model.get_model(output_feature=False)

    # If possible, resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume)
        net.load_state_dict(checkpoint['state_dict'])
        best_loss = checkpoint['best_loss']
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    # Determine the save dir: explicit flag > resumed checkpoint's dir >
    # a fresh timestamped directory.
    if args.save_dir is None:
        if args.resume:
            save_dir = checkpoint['save_dir']
        else:
            exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
            save_dir = os.path.join('results', f'{args.model}_{exp_id}')
    else:
        save_dir = os.path.join('results', args.save_dir)

    # Determine the start epoch (explicit flag > checkpoint epoch + 1 > 1).
    if args.start_epoch is None:
        if args.resume:
            start_epoch = checkpoint['epoch'] + 1
        else:
            start_epoch = 1
    else:
        start_epoch = args.start_epoch

    # If no save_dir, make a new one
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    # Preserve training parameters for future analysis: snapshot all .py
    # sources (including net/) into the run directory.
    if args.test != 1:
        pyfiles = list(Path('.').glob('*.py')) + list(Path('net').glob('*.py'))
        if not (Path(save_dir)/'net').is_dir():
            os.makedirs(Path(save_dir)/'net')
        for f in pyfiles:
            shutil.copy(f, Path(save_dir)/f)

    # Setup GPU
    '''
    n_gpu = setgpu(args.gpu)
    args.n_gpu = n_gpu
    gpu_id = range(torch.cuda.device_count()) if args.gpu == 'all' else [int(idx.strip()) for idx in args.gpu.split(',')]
    '''
    net = net.to(device)

    # Define loss function (criterion) and optimizer
    criterion = criterion.to(device)
    optimizer = AdaBelief(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    pytorch_total_params = sum(p.numel() for p in net.parameters())
    print("Total number of params = ", pytorch_total_params)

    # Infer luna16's pbb/lbb, which are used in training the classifier.
    if args.test == 1:
        # Full volumes don't fit in memory: split into overlapping sub-volumes
        # and recombine the outputs.
        margin = 16#16#32
        sidelen = 48#64#144
        split_comber = SplitComb(sidelen, config['max_stride'], config['stride'], margin, config['pad_value'])
        testset = DataBowl3Detector(datadir, test_id, config,
                                    phase='test', split_comber=split_comber)
        test_loader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=0,
                                 collate_fn=collate, pin_memory=False)
        test(test_loader, net, get_pbb, save_dir, config)
        return

    trainset = DataBowl3Detector(datadir, train_id, config, phase='train')
    train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
                              pin_memory=True)
    valset = DataBowl3Detector(datadir, val_id, config, phase='val')
    val_loader = DataLoader(valset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
                            pin_memory=True)

    # run train and validate
    for epoch in range(start_epoch, args.epochs + 1):
        # Train for one epoch
        train(train_loader, net, criterion, epoch, optimizer)
        # Evaluate on validation set
        val_loss = validate(val_loader, net, criterion, epoch, save_dir)
        # Remember the best val_loss and save checkpoint
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)

        # Checkpoint periodically and whenever validation improves; weights
        # are moved to CPU so the file loads on any device.
        if epoch % args.save_freq == 0 or is_best:
            state_dict = net.state_dict()
            state_dict = {k:v.cpu() for k, v in state_dict.items()}
            state = {'epoch': epoch,
                     'save_dir': save_dir,
                     'state_dict': state_dict,
                     'args': args,
                     'best_loss': best_loss}
            save_checkpoint(state, is_best, os.path.join(save_dir, '{:>03d}.ckpt'.format(epoch)))
239 | recall, 240 | f1_score)) 241 | print('loss %2.4f, classify loss %2.4f, regress loss %2.4f, %2.4f, %2.4f, %2.4f' % ( 242 | np.mean(metrics[:, 0]), 243 | np.mean(metrics[:, 1]), 244 | np.mean(metrics[:, 2]), 245 | np.mean(metrics[:, 3]), 246 | np.mean(metrics[:, 4]), 247 | np.mean(metrics[:, 5]),)) 248 | print() 249 | 250 | def validate(data_loader, net, criterion, epoch, save_dir): 251 | start_time = time.time() 252 | 253 | # Switch to evaluate mode 254 | net.eval() 255 | 256 | metrics = [] 257 | 258 | pred = 0 259 | targ = 0 260 | global f1 261 | with torch.no_grad(): 262 | pbar = tqdm(data_loader) if use_tqdm else data_loader 263 | for i, (input, target, coord) in enumerate(pbar): 264 | input, target, coord = input.to(device), target.to(device), coord.to(device) 265 | 266 | # Compute output and loss 267 | output, _ = net(input, coord, 'val') 268 | loss = criterion(output, target, input, input, train=False) 269 | loss[0] = loss[0].item() 270 | metrics.append(loss) 271 | 272 | end_time = time.time() 273 | 274 | metrics = np.asarray(metrics, np.float32) 275 | eps = 1e-9 276 | total_postive = np.sum(metrics[:, 7]) 277 | total_negative = np.sum(metrics[:, 9]) 278 | total = total_postive + total_negative 279 | tpn = np.sum(metrics[:, 6]) 280 | tnn = np.sum(metrics[:, 8]) 281 | fpn = total_negative - tnn 282 | fnn = total_postive - tpn 283 | accuracy = 100.0 * (tpn + tnn) / total 284 | precision = 100.0 * tpn / (tpn + fpn + eps) 285 | recall = 100.0 * tpn / (tpn + fnn + eps) 286 | f1_score = 2 * precision * recall / (precision + recall + eps) 287 | 288 | 289 | print('Valid: tpr %3.2f, tnr %3.2f, total pos %d, total neg %d, time %3.2f' % ( 290 | 100.0 * tpn / total_postive, 291 | 100.0 * tnn / total_negative, 292 | total_postive, 293 | total_negative, 294 | end_time - start_time) 295 | ) 296 | print('Valid: Acc %3.2f, P %3.2f, R %3.2f, F1 %3.2f' % ( 297 | accuracy, 298 | precision, 299 | recall, 300 | f1_score) 301 | ) 302 | print('loss %2.4f, classify loss %2.4f, 
regress loss %2.4f, %2.4f, %2.4f, %2.4f' % ( 303 | np.mean(metrics[:, 0]), 304 | np.mean(metrics[:, 1]), 305 | np.mean(metrics[:, 2]), 306 | np.mean(metrics[:, 3]), 307 | np.mean(metrics[:, 4]), 308 | np.mean(metrics[:, 5]),) 309 | ) 310 | print() 311 | 312 | val_loss = np.mean(metrics[:, 0]) 313 | return val_loss 314 | 315 | def test(data_loader, net, get_pbb, save_dir, config): 316 | start_time = time.time() 317 | epoch = args.resume.split('/')[-1].split('.')[0] 318 | 319 | bbox_dir = Path(save_dir)/'bbox' 320 | bbox_dir_back = Path(save_dir)/'bbox_{}'.format(epoch) 321 | 322 | if not bbox_dir.is_dir(): 323 | os.makedirs(bbox_dir) 324 | 325 | if not bbox_dir_back.is_dir(): 326 | os.makedirs(bbox_dir_back) 327 | print('Save pbb/lbb in {}'.format(bbox_dir)) 328 | 329 | net.eval() 330 | split_comber = data_loader.dataset.split_comber 331 | 332 | pbar = tqdm(data_loader) if use_tqdm else data_loader 333 | for i_name, (data, target, coord, nzhw) in enumerate(pbar): 334 | target = [np.asarray(t, np.float32) for t in target] 335 | lbb = target[0] 336 | nzhw = nzhw[0] 337 | name = os.path.basename(data_loader.dataset.filenames[i_name]).split('_clean.npy')[0] 338 | data = data[0][0] 339 | coord = coord[0][0] 340 | isfeat = False 341 | splitlist = list(range(0, len(data)+1, args.n_test)) 342 | 343 | if splitlist[-1] != len(data): 344 | splitlist.append(len(data)) 345 | 346 | outputlist = [] 347 | featurelist = [] 348 | 349 | with torch.no_grad(): 350 | for i in range(len(splitlist)-1): 351 | input = data[splitlist[i]:splitlist[i+1]].to(device) 352 | inputcoord = coord[splitlist[i]:splitlist[i+1]].to(device) 353 | if isfeat: 354 | feature, output, recon = net(input, inputcoord) 355 | featurelist.append(feature.detach().cpu().numpy()) 356 | else: 357 | output, recon = net(input, inputcoord, 'val') 358 | outputlist.append(output.detach().cpu().numpy()) 359 | output = np.concatenate(outputlist, axis=0) 360 | 361 | output = split_comber.combine(output, nzhw=nzhw) 362 | thresh = 
config['conf_thresh'] 363 | pbb, mask = get_pbb(output, thresh, ismask=True) 364 | # Save nodule prediction 365 | np.save(os.path.join(bbox_dir, name + '_pbb.npy'), pbb) 366 | # Save nodule ground truth 367 | np.save(os.path.join(bbox_dir, name + '_lbb.npy'), lbb) 368 | np.save(os.path.join(bbox_dir_back, name + '_pbb.npy'), pbb) 369 | # Save nodule ground truth 370 | np.save(os.path.join(bbox_dir_back, name + '_lbb.npy'), lbb) 371 | 372 | if isfeat: 373 | feature = np.concatenate(featurelist,0).transpose([0,2,3,4,1])[:,:,:,:,:,np.newaxis] 374 | feature = split_comber.combine(feature, nzhw=nzhw)[...,0] 375 | feature_selected = feature[mask[0], mask[1], mask[2]] 376 | np.save(os.path.join(bbox_dir, name+'_feature.npy'), feature_selected) 377 | 378 | end_time = time.time() 379 | print('elapsed time is %3.2f seconds' % (end_time - start_time)) 380 | print() 381 | print() 382 | 383 | def save_checkpoint(state, is_best, filename): 384 | torch.save(state, filename) 385 | if is_best: 386 | shutil.copyfile(filename, os.path.join(os.path.dirname(filename), 'best_loss.ckpt')) 387 | 388 | 389 | 390 | if __name__ == '__main__': 391 | status = main() 392 | sys.exit(status) 393 | 394 | -------------------------------------------------------------------------------- /net/OSAF_YOLOv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | import torch.nn as nn 5 | from collections import OrderedDict 6 | # from layers_se import * 7 | from loss import Loss_recon, FocalLoss 8 | import torch.nn.functional as F 9 | from layers import GetPBB 10 | 11 | config = {} 12 | config['anchors'] = [5.0, 10.0, 20.] 
config['channel'] = 1
config['crop_size'] = [80, 80, 80]
config['stride'] = 4
config['max_stride'] = 16
config['num_neg'] = 800
config['th_neg'] = 0.02
config['th_pos_train'] = 0.5
config['th_pos_val'] = 1.0
config['num_hard'] = 2
config['bound_size'] = 12
config['reso'] = 1
config['sizelim'] = 3.  # mm, smallest nodule size
config['sizelim2'] = 10
config['sizelim3'] = 20
config['sizelim4'] = 30
config['aug_scale'] = True
config['r_rand_crop'] = 0.3
config['pad_value'] = 0
config['augtype'] = {'flip':True,'swap':False,'scale':True,'rotate':False, 'noise':False}
config['blacklist'] = ['868b024d9fa388b7ddab12ec1c06af38','990fbe3f0a1b53878669967b9afd1441','adc3bbc63d40f8761c59be10f1e504c3']
config['conf_thresh'] = 0.15


class Conv3d_WS(nn.Conv3d):
    """3-D convolution with Weight Standardization: the kernel is
    mean-centred and divided by its per-filter std before convolving."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv3d_WS, self).__init__(in_channels, out_channels, kernel_size, stride,
                                        padding, dilation, groups, bias)

    def forward(self, x):
        weight = self.weight
        # Per-output-channel mean over (in_ch, kD, kH, kW).
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) \
                            .mean(dim=3, keepdim=True).mean(dim=4, keepdim=True)
        weight = weight - weight_mean
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv3d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x))."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Inlining this saves ~1 s/epoch (V100) vs a temp variable.
        return x * (torch.tanh(F.softplus(x)))


class SpatialPyramidPooling(nn.Module):
    """SPP head (YOLO-style).

    NOTE(review): this class is not used by VoVNet below and would fail if
    instantiated — it references an undefined `Conv` and uses 2-D layers in
    an otherwise 3-D network. Kept verbatim; confirm before using.
    """

    def __init__(self, feature_channels, pool_sizes=[3, 5]):
        super(SpatialPyramidPooling, self).__init__()

        # head conv
        self.head_conv = nn.Sequential(
            Conv(feature_channels[-1], feature_channels[-1] // 2, 1),
            Conv(feature_channels[-1] // 2, feature_channels[-1], 3),
            Conv(feature_channels[-1], feature_channels[-1] // 2, 1),
        )

        self.maxpools = nn.ModuleList(
            [
                nn.MaxPool2d(pool_size, 1, pool_size // 2)
                for pool_size in pool_sizes
            ]
        )
        self.__initialize_weights()

    def forward(self, x):
        x = self.head_conv(x)
        features = [maxpool(x) for maxpool in self.maxpools]
        features = torch.cat([x] + features, dim=1)

        return features

    def __initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
                print("initing {}".format(m))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
                print("initing {}".format(m))


class ChannelGate(nn.Module):
    """CBAM channel-attention gate: sums MLP responses over several global
    poolings and rescales the input channel-wise by their sigmoid."""

    def __init__(self, gate_channels, pool_types=['avg', 'max', 'lse']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            nn.Conv3d(gate_channels, gate_channels, 1, 1, 0, bias=False),
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = F.avg_pool3d(x, (x.size(2), x.size(3), x.size(4)),
                                        stride=(x.size(2), x.size(3), x.size(4)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = F.max_pool3d(x, (x.size(2), x.size(3), x.size(4)),
                                        stride=(x.size(2), x.size(3), x.size(4)))
                channel_att_raw = self.mlp(max_pool)
            elif pool_type == 'lse':
                # LSE pool only
                lse_pool = logsumexp_3d(x).unsqueeze(-1).unsqueeze(-1)
                channel_att_raw = self.mlp(lse_pool)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        # Fixed: F.sigmoid is deprecated; torch.sigmoid is the supported form.
        scale = torch.sigmoid(channel_att_sum).expand_as(x)
        return x * scale


def logsumexp_3d(tensor):
    """Numerically stable log-sum-exp over all spatial positions, per channel."""
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs


class ChannelPool(nn.Module):
    """Concatenate channel-wise max and mean maps (2 output channels)."""

    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class SpatialGate(nn.Module):
    """CBAM spatial-attention gate over the pooled 2-channel map."""

    def __init__(self):
        super(SpatialGate, self).__init__()
        self.compress = ChannelPool()
        self.spatial = nn.Sequential(
            nn.Conv3d(2, 1, 3, 1, 1, bias=False),
            nn.BatchNorm3d(1)
        )

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        # Fixed: F.sigmoid is deprecated; torch.sigmoid is the supported form.
        scale = torch.sigmoid(x_out)  # broadcasting
        return x * scale


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate then (optionally)
    spatial gate."""

    def __init__(self, gate_channels, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out


class eSE(nn.Module):
    """Squeeze-and-excitation-style channel reweighting (global avg pool +
    two-layer MLP + sigmoid)."""

    def __init__(self, ch):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(ch, ch),
            nn.ReLU(inplace=True),
            nn.Linear(ch, ch),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1, 1)
        return x * y


class conv3x3(nn.Module):
    """3x3x3 conv -> BatchNorm3d -> activation ('mish' or 'leaky')."""

    def __init__(self, in_ch, out_ch, stride=1, act='mish', dilation=1):
        super().__init__()

        self.conv = nn.Conv3d(in_ch, out_ch, 3, stride, dilation=dilation,
                              padding=dilation, bias=False)
        self.norm = nn.BatchNorm3d(out_ch)
        if act == 'mish':
            self.act = Mish()
        elif act == 'leaky':
            self.act = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, input):
        return self.act(self.norm(self.conv(input)))


class conv1x1(nn.Module):
    """1x1x1 conv -> BatchNorm3d -> activation ('mish' or 'leaky')."""

    def __init__(self, in_ch, out_ch, stride=1, act='mish'):
        super().__init__()

        self.conv = nn.Conv3d(in_ch, out_ch, 1, stride, padding=0, bias=False)
        self.norm = nn.BatchNorm3d(out_ch)
        if act == 'mish':
            self.act = Mish()
        elif act == 'leaky':
            self.act = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, input):
        return self.act(self.norm(self.conv(input)))


class OSA_Module(nn.Module):
    """One-Shot-Aggregation block (VoVNet): a chain of convs whose outputs
    are all concatenated once and fused by a 1x1 conv, optionally with an
    eSE gate and an identity shortcut."""

    def __init__(self, in_ch, stage_ch, concat_ch, layer_per_block, identity=False, SE=False):
        super().__init__()
        self.identity = identity
        self.layers = nn.ModuleList()
        in_channel = in_ch
        for i in range(layer_per_block):
            # First layer may need a channel change; 1x1 is used for that.
            if in_channel == stage_ch:
                self.layers.append(nn.Sequential(
                    conv3x3(in_channel, stage_ch)))
            else:
                self.layers.append(nn.Sequential(
                    conv1x1(in_channel, stage_ch)))

            in_channel = stage_ch

        # feature aggregation: input + every intermediate output.
        in_channel = in_ch + layer_per_block * stage_ch
        self.concat = nn.Sequential(
            conv1x1(in_channel, concat_ch, 1))

        self.SE = SE
        if self.SE:
            self.ese = eSE(concat_ch)
            # self.cbam = CBAM(concat_ch)

    def forward(self, x):
        identity_feat = x
        output = []
        output.append(x)
        for layer in self.layers:
            x = layer(x)
            output.append(x)

        x = torch.cat(output, dim=1)
        xt = self.concat(x)

        if self.SE:
            xt = self.ese(xt)
        if self.identity:
            xt = xt + identity_feat

        return xt


class CSP_OSA_Stage(nn.Module):
    """A stage of stacked OSA modules, optionally preceded by 2x downsampling
    (max-pool or Focus)."""

    def __init__(self, in_ch, stage_ch, concat_ch, block_per_stage, layer_per_block,
                 isDown=False, isFocus=False):
        super().__init__()

        self.block_per_stage = block_per_stage

        self.isDown = isDown
        if self.isDown:
            if isFocus:
                self.downsample = Focus(in_ch, in_ch)
            else:
                self.downsample = nn.Sequential(
                    nn.MaxPool3d(2, 2)
                    # conv1x1(in_ch * 2, in_ch, 1)
                )

        # First block without identity; later blocks use identity shortcuts
        # and the last one enables eSE.
        m = [OSA_Module(in_ch, stage_ch, concat_ch, layer_per_block, False, False)]
        for i in range(block_per_stage - 1):
            m.append(OSA_Module(concat_ch, stage_ch, concat_ch, layer_per_block,
                                True, i == block_per_stage - 2))
        self.m = nn.Sequential(*m)

    def forward(self, input):
        if self.isDown:
            input = self.downsample(input)

        out = self.m(input)

        return out


class Focus(nn.Module):
    """Downsample by averaging the conv of all eight 2x2x2 sub-lattices."""

    def __init__(self, c1, c2, k=1):
        super(Focus, self).__init__()
        self.conv = conv3x3(c1, c2, 1)

    def forward(self, input):
        l1 = self.conv(input[..., ::2, ::2, ::2]) + self.conv(input[..., 1::2, ::2, ::2]) \
           + self.conv(input[..., ::2, 1::2, ::2]) + self.conv(input[..., ::2, ::2, 1::2]) \
           + self.conv(input[..., 1::2, 1::2, ::2]) + self.conv(input[..., ::2, 1::2, 1::2]) \
           + self.conv(input[..., 1::2, ::2, 1::2]) + self.conv(input[..., 1::2, 1::2, 1::2])
        l1 = l1 / 8
        return l1


class Upsample(nn.Module):
    """1x1 channel projection followed by nearest-neighbour upsampling."""

    def __init__(self, in_ch, out_ch, scale):
        super().__init__()
        self.conv = conv1x1(in_ch, out_ch, 1)
        # Fixed: the scale argument was ignored (scale_factor was hard-coded
        # to 2). All existing call sites pass scale=2, so behavior is unchanged.
        self.up = nn.Upsample(scale_factor=scale)

    def forward(self, input):
        input = self.conv(input)
        return self.up(input)


class RFB(nn.Module):
    """Receptive-Field-Block fusion: three parallel branches with growing
    receptive fields, concatenated, gated by eSE, plus a residual."""

    def __init__(self, plane, out):
        super().__init__()
        self.conv1x1 = conv1x1(plane, out)

        self.conv1x1_in = nn.Sequential(
            conv1x1(out, out),
            conv3x3(out, out, dilation=1)
        )

        self.conv3x3_in = nn.Sequential(
            conv3x3(out, out),
            conv3x3(out, out, dilation=3)
        )
        self.conv5x5_in = nn.Sequential(
            nn.Conv3d(out, out, 5, 1, 2, bias=False),
            nn.BatchNorm3d(out),
            Mish(),
            conv3x3(out, out, dilation=5)
        )

        self.conv1x1_out = conv1x1(out * 3, out)

        self.ese = eSE(out)

    def forward(self, input):
        input = self.conv1x1(input)
        identity = input
        conv1x1_in = self.conv1x1_in(input)
        conv3x3_in = self.conv3x3_in(input)
        conv5x5_in = self.conv5x5_in(input)
        combine = torch.cat((conv1x1_in, conv3x3_in, conv5x5_in), 1)
        return self.ese(self.conv1x1_out(combine)) + identity


class VoVNet(nn.Module):
    """OSA/VoVNet backbone with a YOLOv3-style FPN head for 3-D nodule
    detection.

    NOTE(review): the constructor arguments (config_stage_ch, config_concat_ch,
    block_per_stage, layer_per_block) are currently ignored — the stages below
    are hard-coded. Kept for interface compatibility with get_model().
    """

    def __init__(self, config_stage_ch, config_concat_ch, block_per_stage, layer_per_block):
        super(VoVNet, self).__init__()

        basic_ch = 64
        self.isFocus = False
        # Stem: 5x5x5 conv + 2x downsample.
        self.basic_conv = nn.Sequential(
            nn.Conv3d(1, basic_ch, 5, 1, 2, bias=False),
            nn.BatchNorm3d(basic_ch),
            Mish(),
            nn.MaxPool3d(2, 2)
        )

        # Four OSA stages; each isDown stage halves the spatial resolution.
        self.stage00 = CSP_OSA_Stage(64, 32, 64, 2, 8, isDown=False, isFocus=False)
        self.stage01 = CSP_OSA_Stage(64, 32, 64, 4, 8, isDown=True, isFocus=False)
        self.stage12 = CSP_OSA_Stage(64, 32, 64, 4, 8, isDown=True, isFocus=False)
        self.stage23 = CSP_OSA_Stage(64, 32, 64, 2, 8, isDown=True, isFocus=False)

        # Top-down FPN path with RFB fusion at each merge.
        self.up1 = Upsample(64, 64, 2)
        self.rfb1 = RFB(128, 64)

        self.up2 = Upsample(64, 64, 2)
        self.rfb2 = RFB(128, 64)

        self.up3 = Upsample(64, 64, 2)
        self.rfb3 = RFB(128, 64)

        self.downsample = nn.MaxPool3d(2, 2)
        self.rfb4 = RFB(128, 64)

        # Detection head: 5 values (conf + 4 box terms) per anchor.
        self.head20 = nn.Conv3d(64, len(config['anchors']) * 5, 1, 1, 0)

    def forward(self, input, coord, mode='train'):
        recon = input

        x = self.basic_conv(input)  # 128

        l00 = self.stage00(x)    # 40
        l01 = self.stage01(l00)  # 20
        l12 = self.stage12(l01)  # 10
        l23 = self.stage23(l12)  # 5

        l23_up = self.up1(l23)
        l12_combine = torch.cat((l23_up, l12), 1)
        l12_combine = F.dropout(l12_combine, 0.3) if mode == 'train' else l12_combine
        l12_final = self.rfb1(l12_combine)

        l12_up = self.up2(l12_final)
        l01_combine = torch.cat((l12_up, l01), 1)
        l01_combine = F.dropout(l01_combine, 0.3) if mode == 'train' else l01_combine
        l01_final = self.rfb2(l01_combine)

        l01_up = self.up3(l01_final)
        l00_combine = torch.cat((l01_up, l00), 1)
        l00_combine = F.dropout(l00_combine, 0.3) if mode == 'train' else l00_combine
        l00_final = self.rfb3(l00_combine)

        final = self.downsample(l00_final)
        final = torch.cat((final, l01_final), 1)
        final = F.dropout(final, 0.3) if mode == 'train' else final
        final = self.rfb4(final)

        cls_out20 = self.head20(final)
        # Reshape (N, A*5, D, H, W) -> (N, D, H, W, A, 5) and squash the
        # confidence logit through a sigmoid.
        cls_size = cls_out20.size()
        cls_out20 = cls_out20.view(cls_out20.size(0), cls_out20.size(1), -1)
        cls_out20 = cls_out20.transpose(1, 2).contiguous().view(
            cls_size[0], cls_size[2], cls_size[3], cls_size[4], len(config['anchors']), -1)
        cls_out20[..., 0] = torch.sigmoid(cls_out20[..., 0])

        return cls_out20, recon


def get_model(output_feature=False):
    """Factory: return (config, network, loss, pbb-decoder) for the trainer."""
    net = VoVNet(
        config_stage_ch=[32, 32, 32, 32],
        config_concat_ch=[64, 96, 128, 196],
        block_per_stage=[1, 2, 2, 1],
        layer_per_block=[3, 6, 8, 12],
    )
    # loss = FocalLoss(config['num_hard'])
    loss = Loss_recon(config['num_hard'], class_loss='BCELoss')

    get_pbb = GetPBB(config)
    return config, net, loss, get_pbb


if __name__ == '__main__':
    net = VoVNet([32, 32, 32, 32], [64, 64, 64, 64], [1, 1, 2, 2], 5)
    print(net)
    x = torch.zeros((2, 1, 80, 80, 80))
    coord = torch.zeros(2, 3, 20, 20, 20)
    # Fixed: forward() returns a (prediction, reconstruction) tuple, so the
    # old `net(x, coord).size()` raised AttributeError.
    out, recon = net(x, coord)
    print(out.size())

# --------------------------------------------------------------------------
# /noduleCADEvaluationLUNA16.py:
# --------------------------------------------------------------------------
import os
import math
import sys
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter,LogFormatter,StrMethodFormatter,FixedFormatter
import sklearn.metrics as skl_metrics
import numpy as np

from NoduleFinding import NoduleFinding

from tools import csvTools
font = {'family' : 'normal',
        'size'   : 17}
import argparse

parser = argparse.ArgumentParser(description='PyTorch DataBowl3 Detector')
parser.add_argument('--model', '-m', metavar='MODEL', default='base',
                    help='model')
args = parser.parse_args()

matplotlib.rc('font', **font)
# Evaluation settings
bPerformBootstrapping = True
bNumberOfBootstrapSamples = 1000
bOtherNodulesAsIrrelevant = True
bConfidence = 0.95

seriesuid_label = 'seriesuid'
coordX_label = 'coordX'
coordY_label = 'coordY'
coordZ_label = 'coordZ'
diameter_mm_label = 'diameter_mm'
CADProbability_label = 'probability'

# plot settings
FROC_minX = 0.125  # Minimum value of x-axis of FROC curve
FROC_maxX = 8  # Maximum value of x-axis of FROC
# curve  (tail of the wrapped comment on the preceding line)
bLogPlot = True


def generateBootstrapSet(scanToCandidatesDict, FROCImList):
    '''
    Generates bootstrapped version of set: sample scans with replacement and
    collect the candidate columns (rows = [GT, prob, exclude]) of each sampled
    scan into one array.
    '''
    imageLen = FROCImList.shape[0]

    # get a random list of images using sampling with replacement
    rand_index_im = np.random.randint(imageLen, size=imageLen)
    FROCImList_rand = FROCImList[rand_index_im]

    # get a new list of candidates
    candidates = None
    for im in FROCImList_rand:
        if im not in scanToCandidatesDict:
            continue

        if candidates is None:
            candidates = np.copy(scanToCandidatesDict[im])
        else:
            candidates = np.concatenate((candidates, scanToCandidatesDict[im]), axis=1)

    # Fixed: previously an UnboundLocalError if no sampled scan had candidates;
    # fail with a clear message instead.
    if candidates is None:
        raise ValueError('generateBootstrapSet: no candidates found for the sampled scans')
    return candidates


def compute_mean_ci(interp_sens, confidence=0.95):
    """Per-FP-point mean sensitivity and (lb, ub) percentile confidence bounds
    across the bootstrap samples (rows of interp_sens)."""
    sens_mean = np.zeros((interp_sens.shape[1]), dtype='float32')
    sens_lb = np.zeros((interp_sens.shape[1]), dtype='float32')
    sens_up = np.zeros((interp_sens.shape[1]), dtype='float32')

    Pz = (1.0 - confidence) / 2.0

    for i in range(interp_sens.shape[1]):
        # Fixed: the old in-place vec.sort() silently reordered the caller's
        # interp_sens column; np.sort returns a sorted copy with identical
        # resulting statistics.
        vec = np.sort(interp_sens[:, i])

        sens_mean[i] = np.average(vec)
        sens_lb[i] = vec[int(math.floor(Pz * len(vec)))]
        sens_up[i] = vec[int(math.floor((1.0 - Pz) * len(vec)))]

    return sens_mean, sens_lb, sens_up


def computeFROC_bootstrap(FROCGTList, FROCProbList, FPDivisorList, FROCImList,
                          excludeList, numberOfBootstrapSamples=1000, confidence=0.95):
    """Bootstrap the FROC curve over scans; return the FP grid plus mean
    sensitivity and its confidence band."""
    set1 = np.concatenate(([FROCGTList], [FROCProbList], [excludeList]), axis=0)

    fps_lists = []
    sens_lists = []
    thresholds_lists = []

    FPDivisorList_np = np.asarray(FPDivisorList)
    FROCImList_np = np.asarray(FROCImList)

    # Make a dict with all candidates of all scans (columns grouped per scan).
    scanToCandidatesDict = {}
    for i in range(len(FPDivisorList_np)):
        seriesuid = FPDivisorList_np[i]
        candidate = set1[:, i:i + 1]

        if seriesuid not in scanToCandidatesDict:
            scanToCandidatesDict[seriesuid] = np.copy(candidate)
        else:
            scanToCandidatesDict[seriesuid] = np.concatenate(
                (scanToCandidatesDict[seriesuid], candidate), axis=1)

    for i in range(numberOfBootstrapSamples):
        # Generate a bootstrapped set and compute its FROC curve.
        btpsamp = generateBootstrapSet(scanToCandidatesDict, FROCImList_np)
        fps, sens, thresholds = computeFROC(btpsamp[0, :], btpsamp[1, :],
                                            len(FROCImList_np), btpsamp[2, :])

        fps_lists.append(fps)
        sens_lists.append(sens)
        thresholds_lists.append(thresholds)

    # Common FP axis to interpolate every bootstrap curve onto.
    all_fps = np.linspace(FROC_minX, FROC_maxX, num=10000)

    interp_sens = np.zeros((numberOfBootstrapSamples, len(all_fps)), dtype='float32')
    for i in range(numberOfBootstrapSamples):
        interp_sens[i, :] = np.interp(all_fps, fps_lists[i], sens_lists[i])

    # compute mean and CI
    sens_mean, sens_lb, sens_up = compute_mean_ci(interp_sens, confidence=confidence)

    return all_fps, sens_mean, sens_lb, sens_up


def computeFROC(FROCGTList, FROCProbList, totalNumberOfImages, excludeList):
    """Compute one FROC curve: FPs/scan and sensitivity per ROC threshold."""
    # Remove excluded candidates
    FROCGTList_local = []
    FROCProbList_local = []
    for i in range(len(excludeList)):
        if not excludeList[i]:
            FROCGTList_local.append(FROCGTList[i])
            FROCProbList_local.append(FROCProbList[i])

    numberOfDetectedLesions = sum(FROCGTList_local)
    totalNumberOfLesions = sum(FROCGTList)
    totalNumberOfCandidates = len(FROCProbList_local)
    fpr, tpr, thresholds = skl_metrics.roc_curve(FROCGTList_local, FROCProbList_local)
    # Border case: no false positives at all makes the ROC analysis give NaNs.
    if sum(FROCGTList) == len(FROCGTList):
        print("WARNING, this system has no false positives..")
        fps = np.zeros(len(fpr))
    else:
        fps = fpr * (totalNumberOfCandidates - numberOfDetectedLesions) / totalNumberOfImages
    sens = (tpr * numberOfDetectedLesions) / totalNumberOfLesions
    return fps, sens, thresholds


def evaluateCAD(seriesUIDs, results_filename, outputDir, allNodules, CADSystemName, maxNumberOfCADMarks=-1,
                performBootstrapping=False, numberOfBootstrapSamples=1000, confidence=0.95):
    '''
    function to evaluate a CAD algorithm
    @param seriesUIDs: list of the seriesUIDs of the cases to be processed
    @param results_filename: file with results
    @param outputDir: output directory
    @param allNodules: dictionary with all nodule annotations of all cases, keys of the dictionary are the seriesuids
    @param CADSystemName: name of the CAD system, to be used in filenames and on FROC curve
    '''
    nodOutputfile = open(os.path.join(outputDir, 'CADAnalysis.txt'), 'w')
    nodOutputfile.write("\n")
    nodOutputfile.write((60 * "*") + "\n")
    nodOutputfile.write("CAD Analysis: %s\n" % CADSystemName)
    nodOutputfile.write((60 * "*") + "\n")
    nodOutputfile.write("\n")

    results = csvTools.readCSV(results_filename)

    allCandsCAD = {}

    for seriesuid in seriesUIDs:
        # collect candidates from result file
        nodules = {}
        header = results[0]

        i = 0
        for result in results[1:]:
            nodule_seriesuid = result[header.index(seriesuid_label)]

            if seriesuid == nodule_seriesuid:
                nodule = getNodule(result, header)
                nodule.candidateID = i
                nodules[nodule.candidateID] = nodule
                i += 1

        if (maxNumberOfCADMarks > 0):
            # number of CAD marks, only keep most suspicious marks
            if len(nodules.keys()) > maxNumberOfCADMarks:
                # make a list of all probabilities
                probs = []
# NOTE(review): function continues beyond this chunk of the dump.
in nodules.items(): 195 | probs.append(float(noduletemp.CADprobability)) 196 | probs.sort(reverse=True) # sort from large to small 197 | probThreshold = probs[maxNumberOfCADMarks] 198 | nodules2 = {} 199 | nrNodules2 = 0 200 | for keytemp, noduletemp in nodules.items(): 201 | if nrNodules2 >= maxNumberOfCADMarks: 202 | break 203 | if float(noduletemp.CADprobability) > probThreshold: 204 | nodules2[keytemp] = noduletemp 205 | nrNodules2 += 1 206 | 207 | nodules = nodules2 208 | 209 | allCandsCAD[seriesuid] = nodules 210 | 211 | # open output files 212 | nodNoCandFile = open(os.path.join(outputDir, "nodulesWithoutCandidate_%s.txt" % CADSystemName), 'w') 213 | 214 | # --- iterate over all cases (seriesUIDs) and determine how 215 | # often a nodule annotation is not covered by a candidate 216 | 217 | # initialize some variables to be used in the loop 218 | candTPs = 0 219 | candFPs = 0 220 | candFNs = 0 221 | candTNs = 0 222 | totalNumberOfCands = 0 223 | totalNumberOfNodules = 0 224 | doubleCandidatesIgnored = 0 225 | irrelevantCandidates = 0 226 | minProbValue = -1000000000.0 # minimum value of a float 227 | FROCGTList = [] 228 | FROCProbList = [] 229 | FPDivisorList = [] 230 | excludeList = [] 231 | FROCtoNoduleMap = [] 232 | ignoredCADMarksList = [] 233 | 234 | # -- loop over the cases 235 | for seriesuid in seriesUIDs: 236 | # get the candidates for this case 237 | try: 238 | candidates = allCandsCAD[seriesuid] 239 | except KeyError: 240 | candidates = {} 241 | 242 | # add to the total number of candidates 243 | totalNumberOfCands += len(candidates.keys()) 244 | 245 | # make a copy in which items will be deleted 246 | candidates2 = candidates.copy() 247 | 248 | # get the nodule annotations on this case 249 | try: 250 | noduleAnnots = allNodules[seriesuid] 251 | except KeyError: 252 | noduleAnnots = [] 253 | 254 | # - loop over the nodule annotations 255 | for noduleAnnot in noduleAnnots: 256 | # increment the number of nodules 257 | if noduleAnnot.state == 
"Included": 258 | totalNumberOfNodules += 1 259 | 260 | x = float(noduleAnnot.coordX) 261 | y = float(noduleAnnot.coordY) 262 | z = float(noduleAnnot.coordZ) 263 | 264 | # 2. Check if the nodule annotation is covered by a candidate 265 | # A nodule is marked as detected when the center of mass of the candidate is within a distance R of 266 | # the center of the nodule. In order to ensure that the CAD mark is displayed within the nodule on the 267 | # CT scan, we set R to be the radius of the nodule size. 268 | diameter = float(noduleAnnot.diameter_mm) 269 | if diameter < 0.0: 270 | diameter = 10.0 271 | radiusSquared = pow((diameter / 2.0), 2.0) 272 | 273 | found = False 274 | noduleMatches = [] 275 | for key, candidate in candidates.items(): 276 | x2 = float(candidate.coordX) 277 | y2 = float(candidate.coordY) 278 | z2 = float(candidate.coordZ) 279 | dist = math.pow(x - x2, 2.) + math.pow(y - y2, 2.) + math.pow(z - z2, 2.) 280 | if dist < radiusSquared: 281 | if (noduleAnnot.state == "Included"): 282 | found = True 283 | noduleMatches.append(candidate) 284 | if key not in candidates2.keys(): 285 | print("This is strange: CAD mark %s detected two nodules! 
Check for overlapping nodule annotations, SeriesUID: %s, nodule Annot ID: %s" % (str(candidate.id), seriesuid, str(noduleAnnot.id))) 286 | else: 287 | del candidates2[key] 288 | elif (noduleAnnot.state == "Excluded"): # an excluded nodule 289 | if bOtherNodulesAsIrrelevant: # delete marks on excluded nodules so they don't count as false positives 290 | if key in candidates2.keys(): 291 | irrelevantCandidates += 1 292 | ignoredCADMarksList.append("%s,%s,%s,%s,%s,%s,%.9f" % (seriesuid, -1, candidate.coordX, candidate.coordY, candidate.coordZ, str(candidate.id), float(candidate.CADprobability))) 293 | del candidates2[key] 294 | if len(noduleMatches) > 1: # double detection 295 | doubleCandidatesIgnored += (len(noduleMatches) - 1) 296 | if noduleAnnot.state == "Included": 297 | # only include it for FROC analysis if it is included 298 | # otherwise, the candidate will not be counted as FP, but ignored in the 299 | # analysis since it has been deleted from the nodules2 vector of candidates 300 | if found == True: 301 | # append the sample with the highest probability for the FROC analysis 302 | maxProb = None 303 | for idx in range(len(noduleMatches)): 304 | candidate = noduleMatches[idx] 305 | if (maxProb is None) or (float(candidate.CADprobability) > maxProb): 306 | maxProb = float(candidate.CADprobability) 307 | 308 | FROCGTList.append(1.0) 309 | FROCProbList.append(float(maxProb)) 310 | FPDivisorList.append(seriesuid) 311 | excludeList.append(False) 312 | FROCtoNoduleMap.append("%s,%s,%s,%s,%s,%.9f,%s,%.9f" % (seriesuid, noduleAnnot.id, noduleAnnot.coordX, noduleAnnot.coordY, noduleAnnot.coordZ, float(noduleAnnot.diameter_mm), str(candidate.id), float(candidate.CADprobability))) 313 | candTPs += 1 314 | else: 315 | candFNs += 1 316 | # append a positive sample with the lowest probability, such that this is added in the FROC analysis 317 | FROCGTList.append(1.0) 318 | FROCProbList.append(minProbValue) 319 | FPDivisorList.append(seriesuid) 320 | 
excludeList.append(True) 321 | FROCtoNoduleMap.append("%s,%s,%s,%s,%s,%.9f,%s,%s" % (seriesuid, noduleAnnot.id, noduleAnnot.coordX, noduleAnnot.coordY, noduleAnnot.coordZ, float(noduleAnnot.diameter_mm), int(-1), "NA")) 322 | nodNoCandFile.write("%s,%s,%s,%s,%s,%.9f,%s\n" % (seriesuid, noduleAnnot.id, noduleAnnot.coordX, noduleAnnot.coordY, noduleAnnot.coordZ, float(noduleAnnot.diameter_mm), str(-1))) 323 | 324 | # add all false positives to the vectors 325 | for key, candidate3 in candidates2.items(): 326 | candFPs += 1 327 | FROCGTList.append(0.0) 328 | FROCProbList.append(float(candidate3.CADprobability)) 329 | FPDivisorList.append(seriesuid) 330 | excludeList.append(False) 331 | FROCtoNoduleMap.append("%s,%s,%s,%s,%s,%s,%.9f" % (seriesuid, -1, candidate3.coordX, candidate3.coordY, candidate3.coordZ, str(candidate3.id), float(candidate3.CADprobability))) 332 | 333 | if not (len(FROCGTList) == len(FROCProbList) and len(FROCGTList) == len(FPDivisorList) and len(FROCGTList) == len(FROCtoNoduleMap) and len(FROCGTList) == len(excludeList)): 334 | nodOutputfile.write("Length of FROC vectors not the same, this should never happen! 
Aborting..\n") 335 | 336 | nodOutputfile.write("Candidate detection results:\n") 337 | nodOutputfile.write(" True positives: %d\n" % candTPs) 338 | nodOutputfile.write(" False positives: %d\n" % candFPs) 339 | nodOutputfile.write(" False negatives: %d\n" % candFNs) 340 | nodOutputfile.write(" True negatives: %d\n" % candTNs) 341 | nodOutputfile.write(" Total number of candidates: %d\n" % totalNumberOfCands) 342 | nodOutputfile.write(" Total number of nodules: %d\n" % totalNumberOfNodules) 343 | 344 | nodOutputfile.write(" Ignored candidates on excluded nodules: %d\n" % irrelevantCandidates) 345 | nodOutputfile.write(" Ignored candidates which were double detections on a nodule: %d\n" % doubleCandidatesIgnored) 346 | if int(totalNumberOfNodules) == 0: 347 | nodOutputfile.write(" Sensitivity: 0.0\n") 348 | else: 349 | nodOutputfile.write(" Sensitivity: %.9f\n" % (float(candTPs) / float(totalNumberOfNodules))) 350 | nodOutputfile.write(" Average number of candidates per scan: %.9f\n" % (float(totalNumberOfCands) / float(len(seriesUIDs)))) 351 | 352 | # compute FROC 353 | fps, sens, thresholds = computeFROC(FROCGTList,FROCProbList,len(seriesUIDs),excludeList) 354 | 355 | if performBootstrapping: 356 | fps_bs_itp,sens_bs_mean,sens_bs_lb,sens_bs_up = computeFROC_bootstrap(FROCGTList,FROCProbList,FPDivisorList,seriesUIDs,excludeList, 357 | numberOfBootstrapSamples=numberOfBootstrapSamples, confidence = confidence) 358 | 359 | # Write FROC curve 360 | with open(os.path.join(outputDir, "froc_%s.txt" % CADSystemName), 'w') as f: 361 | for i in range(len(sens)): 362 | f.write("%.9f,%.9f,%.9f\n" % (fps[i], sens[i], thresholds[i])) 363 | 364 | # Write FROC vectors to disk as well 365 | with open(os.path.join(outputDir, "froc_gt_prob_vectors_%s.csv" % CADSystemName), 'w') as f: 366 | for i in range(len(FROCGTList)): 367 | f.write("%d,%.9f\n" % (FROCGTList[i], FROCProbList[i])) 368 | 369 | fps_itp = np.linspace(FROC_minX, FROC_maxX, num=10001) 370 | 371 | sens_itp = 
np.interp(fps_itp, fps, sens) 372 | frvvlu = 0 373 | nxth = 0.125 374 | ssss = [] 375 | for fp, ss in zip(fps_itp, sens_itp): 376 | if abs(fp - nxth) < 3e-4: 377 | frvvlu += ss 378 | ssss.append(ss) 379 | nxth *= 2 380 | if abs(nxth - 16) < 1e-5: break 381 | 382 | print(frvvlu/7) 383 | print(frvvlu/7, nxth) 384 | print(ssss[0], ssss[1], ssss[2], ssss[3], ssss[4], ssss[5], ssss[6]) 385 | if performBootstrapping: 386 | # Write mean, lower, and upper bound curves to disk 387 | with open(os.path.join(outputDir, "froc_%s_bootstrapping.csv" % CADSystemName), 'w') as f: 388 | f.write("FPrate,Sensivity[Mean],Sensivity[Lower bound],Sensivity[Upper bound]\n") 389 | for i in range(len(fps_bs_itp)): 390 | f.write("%.9f,%.9f,%.9f,%.9f\n" % (fps_bs_itp[i], sens_bs_mean[i], sens_bs_lb[i], sens_bs_up[i])) 391 | else: 392 | fps_bs_itp = None 393 | sens_bs_mean = None 394 | sens_bs_lb = None 395 | sens_bs_up = None 396 | 397 | # create FROC graphs 398 | if int(totalNumberOfNodules) > 0: 399 | graphTitle = str("") 400 | fig1 = plt.figure() 401 | ax = plt.gca() 402 | clr = 'b' 403 | plt.plot(fps_itp, sens_itp, color=clr, label="%s" % CADSystemName, lw=2) 404 | if performBootstrapping: 405 | plt.plot(fps_bs_itp, sens_bs_mean, color=clr, ls='--') 406 | 407 | np.save(os.path.join(outputDir,'%s_fps.npy' % (CADSystemName)), fps_bs_itp) 408 | np.save(os.path.join(outputDir,'%s_sens.npy' % (CADSystemName)), sens_bs_mean) 409 | 410 | plt.plot(fps_bs_itp, sens_bs_lb, color=clr, ls=':') # , label = "lb") 411 | plt.plot(fps_bs_itp, sens_bs_up, color=clr, ls=':') # , label = "ub") 412 | ax.fill_between(fps_bs_itp, sens_bs_lb, sens_bs_up, facecolor=clr, alpha=0.05) 413 | xmin = FROC_minX 414 | xmax = FROC_maxX 415 | plt.xlim(xmin, xmax) 416 | plt.ylim(0.1, 1) 417 | plt.xlabel('Average number of false positives per scan') 418 | plt.ylabel('Sensitivity') 419 | plt.legend(loc='lower right') 420 | plt.title('FROC performance - %s' % (CADSystemName)) 421 | 422 | if bLogPlot: 423 | plt.xscale('log', 
basex=2) 424 | ax.xaxis.set_major_formatter(FixedFormatter([0.125,0.25,0.5,1,2,4,8])) 425 | 426 | # set your ticks manually 427 | ax.xaxis.set_ticks([0.125,0.25,0.5,1,2,4,8]) 428 | ax.yaxis.set_ticks(np.arange(0.5, 1, 0.1)) 429 | # ax.yaxis.set_ticks(np.arange(0, 1.1, 0.1)) 430 | plt.grid(b=True, which='both') 431 | plt.tight_layout() 432 | 433 | plt.savefig(os.path.join(outputDir, "froc_%s.png" % CADSystemName), bbox_inches=0, dpi=300) 434 | 435 | return (fps, sens, thresholds, fps_bs_itp, sens_bs_mean, sens_bs_lb, sens_bs_up) 436 | 437 | def getNodule(annotation, header, state = ""): 438 | nodule = NoduleFinding() 439 | nodule.coordX = annotation[header.index(coordX_label)] 440 | nodule.coordY = annotation[header.index(coordY_label)] 441 | nodule.coordZ = annotation[header.index(coordZ_label)] 442 | 443 | if diameter_mm_label in header: 444 | nodule.diameter_mm = annotation[header.index(diameter_mm_label)] 445 | 446 | if CADProbability_label in header: 447 | nodule.CADprobability = annotation[header.index(CADProbability_label)] 448 | 449 | if not state == "": 450 | nodule.state = state 451 | 452 | return nodule 453 | 454 | def collectNoduleAnnotations(annotations, annotations_excluded, seriesUIDs): 455 | allNodules = {} 456 | noduleCount = 0 457 | noduleCountTotal = 0 458 | 459 | for seriesuid in seriesUIDs: 460 | # print 'adding nodule annotations: ' + seriesuid 461 | 462 | nodules = [] 463 | numberOfIncludedNodules = 0 464 | 465 | # add included findings 466 | header = annotations[0] 467 | for annotation in annotations[1:]: 468 | nodule_seriesuid = annotation[header.index(seriesuid_label)] 469 | 470 | if seriesuid == nodule_seriesuid: 471 | nodule = getNodule(annotation, header, state = "Included") 472 | nodules.append(nodule) 473 | numberOfIncludedNodules += 1 474 | 475 | # add excluded findings 476 | header = annotations_excluded[0] 477 | for annotation in annotations_excluded[1:]: 478 | nodule_seriesuid = annotation[header.index(seriesuid_label)] 479 | 480 | 
if seriesuid == nodule_seriesuid: 481 | nodule = getNodule(annotation, header, state = "Excluded") 482 | nodules.append(nodule) 483 | 484 | allNodules[seriesuid] = nodules 485 | noduleCount += numberOfIncludedNodules 486 | noduleCountTotal += len(nodules) 487 | 488 | return allNodules 489 | 490 | 491 | def collect(annotations_filename,annotations_excluded_filename,seriesuids_filename): 492 | annotations = csvTools.readCSV(annotations_filename) 493 | annotations_excluded = csvTools.readCSV(annotations_excluded_filename) 494 | seriesUIDs_csv = csvTools.readCSV(seriesuids_filename) 495 | 496 | seriesUIDs = [] 497 | for seriesUID in seriesUIDs_csv: 498 | seriesUIDs.append(seriesUID[0]) 499 | 500 | allNodules = collectNoduleAnnotations(annotations, annotations_excluded, seriesUIDs) 501 | 502 | return (allNodules, seriesUIDs) 503 | 504 | 505 | def noduleCADEvaluation(annotations_filename,annotations_excluded_filename,seriesuids_filename,results_filename,outputDir): 506 | ''' 507 | function to load annotations and evaluate a CAD algorithm 508 | @param annotations_filename: list of annotations 509 | @param annotations_excluded_filename: list of annotations that are excluded from analysis 510 | @param seriesuids_filename: list of CT images in seriesuids 511 | @param results_filename: list of CAD marks with probabilities 512 | @param outputDir: output directory 513 | ''' 514 | if not os.path.isdir(outputDir): 515 | os.mkdir(outputDir) 516 | # print(annotations_filename) 517 | 518 | (allNodules, seriesUIDs) = collect(annotations_filename, annotations_excluded_filename, seriesuids_filename) 519 | 520 | evaluateCAD(seriesUIDs, results_filename, outputDir, allNodules, 521 | os.path.splitext(os.path.basename(results_filename))[0], 522 | maxNumberOfCADMarks=100, performBootstrapping=bPerformBootstrapping, 523 | numberOfBootstrapSamples=bNumberOfBootstrapSamples, confidence=bConfidence) 524 | 525 | 526 | if __name__ == '__main__': 527 | annotations_filename = 
'../data/LUNA16/evaluationScript/annotations/annotations.csv' 528 | annotations_excluded_filename = '../data/LUNA16/evaluationScript/annotations/annotations_excluded.csv' 529 | seriesuids_filename = '../data/LUNA16/evaluationScript/annotations/seriesuids.csv' 530 | 531 | results_filename = args.model+'_80_all.csv' 532 | outputDir = './CPM_Results_'+args.model+'_80/' 533 | noduleCADEvaluation(annotations_filename,annotations_excluded_filename,seriesuids_filename,results_filename,outputDir) 534 | print("Finished!") -------------------------------------------------------------------------------- /prepare.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | #coding=utf-8 3 | 4 | import os 5 | import shutil 6 | import numpy as np 7 | from scipy.ndimage.interpolation import zoom 8 | import SimpleITK as sitk 9 | from scipy.ndimage.morphology import binary_dilation, generate_binary_structure 10 | from skimage.morphology import convex_hull_image 11 | import pandas 12 | import warnings 13 | from glob import glob 14 | import concurrent.futures 15 | from config_training import config 16 | 17 | def resample(imgs, spacing, new_spacing, order=2): 18 | if len(imgs.shape)==3: 19 | new_shape = np.round(imgs.shape * spacing / new_spacing) 20 | true_spacing = spacing * imgs.shape / new_shape 21 | resize_factor = new_shape / imgs.shape 22 | imgs = zoom(imgs, resize_factor, mode='nearest', order=order) 23 | return imgs, true_spacing 24 | elif len(imgs.shape) == 4: 25 | n = imgs.shape[-1] 26 | newimg = [] 27 | for i in range(n): 28 | slice = imgs[:,:,:,i] 29 | newslice,true_spacing = resample(slice, spacing, new_spacing) 30 | newimg.append(newslice) 31 | newimg = np.transpose(np.array(newimg), [1, 2, 3, 0]) 32 | return newimg, true_spacing 33 | else: 34 | raise ValueError('wrong shape') 35 | 36 | def worldToVoxelCoord(worldCoord, origin, spacing): 37 | stretchedVoxelCoord = np.absolute(worldCoord - origin) 38 | voxelCoord = 
stretchedVoxelCoord / spacing 39 | return voxelCoord 40 | 41 | def load_itk_image(filename): 42 | with open(filename) as f: 43 | contents = f.readlines() 44 | line = [k for k in contents if k.startswith('TransformMatrix')][0] 45 | transformM = np.array(line.split(' = ')[1].split(' ')).astype('float') 46 | transformM = np.round(transformM) 47 | if np.any(transformM!=np.array([1, 0, 0, 0, 1, 0, 0, 0, 1])): 48 | isflip = True 49 | else: 50 | isflip = False 51 | 52 | itkimage = sitk.ReadImage(filename) 53 | numpyImage = sitk.GetArrayFromImage(itkimage) 54 | 55 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin()))) 56 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing()))) 57 | 58 | return numpyImage, numpyOrigin, numpySpacing, isflip 59 | 60 | def process_mask(mask): 61 | convex_mask = np.copy(mask) 62 | for i_layer in range(convex_mask.shape[0]): 63 | mask1 = np.ascontiguousarray(mask[i_layer]) 64 | if np.sum(mask1) > 0: 65 | mask2 = convex_hull_image(mask1) 66 | if np.sum(mask2) > 1.5 * np.sum(mask1): 67 | mask2 = mask1 68 | else: 69 | mask2 = mask1 70 | convex_mask[i_layer] = mask2 71 | struct = generate_binary_structure(3, 1) 72 | dilatedMask = binary_dilation(convex_mask, structure=struct, iterations=10) 73 | return dilatedMask 74 | 75 | def lumTrans(img): 76 | lungwin = np.array([-1200., 600.]) 77 | newimg = (img - lungwin[0]) / (lungwin[1] - lungwin[0]) 78 | newimg[newimg < 0] = 0 79 | newimg[newimg > 1] = 1 80 | newimg = (newimg*255).astype('uint8') 81 | return newimg 82 | 83 | def savenpy_luna(id, annos, filelist, luna_segment, luna_data, savepath): 84 | """ 85 | Note: Dr. Chen adds malignancy label, so the label becomes (z,y,x,d,malignancy), <- but I cancelled it ! 
"""
    # Flags kept from the original: always save labels and the cleaned image.
    islabel = True
    isClean = True
    resolution = np.array([1, 1, 1])  # target spacing: 1 mm isotropic
    name = filelist[id]

    # Load lung-segmentation mask, and calculate extendbox from the mask
    Mask, origin, spacing, isflip = load_itk_image(os.path.join(luna_segment, name+'.mhd'))
    if isflip:
        # header TransformMatrix is not identity: undo the stored y/x flip
        Mask = Mask[:,::-1,::-1]
    newshape = np.round(np.array(Mask.shape)*spacing/resolution).astype('int')
    # LUNA16 segmentation labels 3 and 4 are the two lungs
    m1 = Mask==3
    m2 = Mask==4
    Mask = m1+m2

    # Tight bounding box of the lungs, converted to 1 mm voxel coordinates
    # (axes here are (z, y, x), despite the xx/yy/zz variable names).
    xx,yy,zz= np.where(Mask)
    box = np.array([[np.min(xx),np.max(xx)],[np.min(yy),np.max(yy)],[np.min(zz),np.max(zz)]])
    box = box*np.expand_dims(spacing,1)/np.expand_dims(resolution,1)
    box = np.floor(box).astype('int')
    margin = 5
    # Pad the box by `margin` voxels, clipped to the resampled volume extent.
    extendbox = np.vstack([np.max([[0,0,0],box[:,0]-margin],0),np.min([newshape,box[:,1]+2*margin],axis=0).T]).T


    if isClean:
        # Dilate each lung's convex-hulled mask, then mask out everything else.
        dm1 = process_mask(m1)
        dm2 = process_mask(m2)
        dilatedMask = dm1 + dm2
        Mask = m1 + m2
        extramask = dilatedMask ^ Mask  # '-' substration is deprecated in numpy, use '^'
        bone_thresh = 210  # uint8 intensity above which dilation-ring voxels count as bone
        pad_value = 170    # uint8 intensity used to fill non-lung voxels

        sliceim, origin, spacing, isflip = load_itk_image(os.path.join(luna_data, name+'.mhd'))
        if isflip:
            sliceim = sliceim[:,::-1,::-1]
            print('{}: flip!'.format(name))
        sliceim = lumTrans(sliceim)
        # Keep lung voxels, fill the rest with pad_value.
        sliceim = sliceim*dilatedMask+pad_value*(1-dilatedMask).astype('uint8')
        # Suppress bright bone inside the dilation ring.
        bones = (sliceim*extramask)>bone_thresh
        sliceim[bones] = pad_value

        # Resample to 1 mm isotropic and crop to the extended lung box.
        sliceim1,_ = resample(sliceim,spacing,resolution,order=1)
        sliceim2 = sliceim1[extendbox[0,0]:extendbox[0,1],
                    extendbox[1,0]:extendbox[1,1],
                    extendbox[2,0]:extendbox[2,1]]
        sliceim = sliceim2[np.newaxis,...]  # add a leading channel axis

        np.save(os.path.join(savepath, name + '_clean.npy'),sliceim)

    np.save(os.path.join(savepath, name+'_spacing.npy'), spacing)
    np.save(os.path.join(savepath, name+'_extendbox.npy'), extendbox)
    np.save(os.path.join(savepath, name+'_origin.npy'), origin)
    np.save(os.path.join(savepath, name+'_mask.npy'), Mask)


    if islabel:
        # Annotation rows for this scan; assumes annos columns are
        # (seriesuid, x, y, z, d[, malignancy]) -- matches the usage below.
        this_annos = np.copy(annos[annos[:,0] == name])
        label = []

        if len(this_annos)>0:
            for c in this_annos: # unit in mm --> voxel
                pos = worldToVoxelCoord(c[1:4][::-1], origin=origin, spacing=spacing) # (z,y,x)
                if isflip:
                    pos[1:] = Mask.shape[1:3] - pos[1:] # flip in y and x coordinates
                d = c[4]/spacing[1]  # diameter in voxels along y
                try:
                    malignancy = int(c[5])
                except IndexError:
                    malignancy = 0  # column absent: no malignancy label
                # label.append(np.concatenate([pos,[d],[malignancy]])) # (z,y,x,d,malignancy)
                label.append(np.concatenate([pos,[d]])) # (z,y,x,d)

        label = np.array(label)

        # Voxel --> resample to (1mm,1mm,1mm) voxel coordinate
        if len(label)==0:
            # Placeholder row so downstream code always finds a label array.
            # label2 = np.array([[0,0,0,0,0]])
            label2 = np.array([[0,0,0,0]])
        else:
            label2 = np.copy(label).T
            label2[:3] = label2[:3]*np.expand_dims(spacing,1)/np.expand_dims(resolution,1)
            label2[3] = label2[3]*spacing[1]/resolution[1]
            # Shift coordinates into the cropped (extendbox) frame.
            label2[:3] = label2[:3]-np.expand_dims(extendbox[:,0],1)
            # label2 = label2[:5].T #(z,y,x,d,malignancy)
            label2 = label2[:4].T #(z,y,x,d)

        np.save(os.path.join(savepath, name+'_label.npy'), label2)

    print('{} is done.'.format(name))

def preprocess_luna():
    """Preprocess every unprocessed LUNA16 .mhd scan in config['luna_data']
    via savenpy_luna, writing *_clean/_spacing/_extendbox/_origin/_mask/_label
    .npy files into config['preprocess_result_path']."""
    luna_segment = config['luna_segment']
    savepath = config['preprocess_result_path']
    luna_data = config['luna_data']
    luna_label = config['luna_label']
    finished_flag = '.flag_preprocess_luna'

    print('starting preprocessing luna')

    if True:
        # Skip scans that already have a *_clean.npy output.
        exist_files = {f.split('_clean.npy')[0] for f in
os.listdir(savepath) if f.endswith('_clean.npy')}
        filelist = {f.split('.mhd')[0] for f in os.listdir(luna_data) if f.endswith('.mhd')}
        filelist = list(filelist - exist_files)  # only scans not yet processed
        annos = np.array(pandas.read_csv(luna_label))

        if not os.path.isdir(savepath):
            os.mkdir(savepath)

        # Process scans in parallel; futures map back to filelist indices.
        with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
            futures = {executor.submit(savenpy_luna, f, annos=annos, filelist=filelist,
                luna_segment=luna_segment, luna_data=luna_data, savepath=savepath):f for f in range(len(filelist))}
            for future in concurrent.futures.as_completed(futures):
                filename = filelist[futures[future]]
                try:
                    _ = future.result()
                except:
                    # NOTE(review): bare except swallows the actual error
                    # (including KeyboardInterrupt); consider logging
                    # the exception instead of only the filename.
                    print('{} failed.'.format(filename))


    print('end preprocessing luna')
    # Touch a flag file marking that preprocessing completed.
    f = open(finished_flag,"w+")
    f.close()
    return


if __name__=='__main__':
    # Pre-process LUNA16 MHD files
    preprocess_luna()
--------------------------------------------------------------------------------
/split_combine.py:
--------------------------------------------------------------------------------
#!/usr/bin/python3
#coding=utf-8

import numpy as np

class SplitComb():
    """Splits a padded CT volume into overlapping cubic patches for sliding-
    window inference, and recombines per-patch outputs into one volume."""
    def __init__(self, side_len, max_stride, stride, margin, pad_value):
        # side_len: core patch size; margin: overlap added on each side;
        # stride: network output stride; pad_value unused here (kept for API).
        self.side_len = side_len
        self.max_stride = max_stride
        self.stride = stride
        self.margin = margin
        self.pad_value = pad_value

    def split(self, data, side_len=None, max_stride=None, margin=None):
        """Split data of shape (c, z, h, w) into patches of side
        side_len + 2*margin; returns (patches, [nz, nh, nw])."""
        # Fall back to the constructor's defaults when not overridden.
        if side_len == None:
            side_len = self.side_len
        if max_stride == None:
            max_stride = self.max_stride
        if margin == None:
            margin = self.margin

        assert(side_len > margin)
        assert(side_len % max_stride == 0)
        assert(margin % max_stride == 0)

        splits = []
        _, z, h, w = data.shape

        # Number of patches along each axis (ceil so the volume is covered).
        nz = int(np.ceil(float(z) / side_len))
        nh = int(np.ceil(float(h) /
side_len)) 31 | nw = int(np.ceil(float(w) / side_len)) 32 | 33 | nzhw = [nz, nh, nw] 34 | self.nzhw = nzhw 35 | 36 | pad = [ [0, 0], 37 | [int(margin), int(nz * side_len - z + margin)], 38 | [int(margin), int(nh * side_len - h + margin)], 39 | [int(margin), int(nw * side_len - w + margin)]] 40 | data = np.pad(data, pad, 'edge') 41 | 42 | for iz in range(nz): 43 | for ih in range(nh): 44 | for iw in range(nw): 45 | sz = int(iz * side_len) 46 | ez = int((iz + 1) * side_len + 2 * margin) 47 | sh = int(ih * side_len) 48 | eh = int((ih + 1) * side_len + 2 * margin) 49 | sw = int(iw * side_len) 50 | ew = int((iw + 1) * side_len + 2 * margin) 51 | 52 | split = data[np.newaxis, :, sz:ez, sh:eh, sw:ew] 53 | splits.append(split) 54 | 55 | splits = np.concatenate(splits, 0) 56 | return splits, nzhw 57 | 58 | def combine(self, output, nzhw=None, side_len=None, stride=None, margin=None): 59 | 60 | if side_len is None: 61 | side_len = self.side_len 62 | if stride is None: 63 | stride = self.stride 64 | if margin is None: 65 | margin = self.margin 66 | if nzhw is None: 67 | nz = self.nz 68 | nh = self.nh 69 | nw = self.nw 70 | else: 71 | nz, nh, nw = nzhw 72 | assert(side_len % stride == 0) 73 | assert(margin % stride == 0) 74 | side_len //= stride 75 | margin //= stride 76 | 77 | splits = [] 78 | for i in range(len(output)): 79 | splits.append(output[i]) 80 | 81 | output = -1000000 * np.ones(( 82 | nz * side_len, 83 | nh * side_len, 84 | nw * side_len, 85 | splits[0].shape[3], 86 | splits[0].shape[4]), np.float32) 87 | 88 | idx = 0 89 | for iz in range(nz): 90 | for ih in range(nh): 91 | for iw in range(nw): 92 | sz = iz * side_len 93 | ez = (iz + 1) * side_len 94 | sh = ih * side_len 95 | eh = (ih + 1) * side_len 96 | sw = iw * side_len 97 | ew = (iw + 1) * side_len 98 | 99 | split = splits[idx][margin:margin + side_len, margin:margin + side_len, margin:margin + side_len] 100 | output[sz:ez, sh:eh, sw:ew] = split 101 | idx += 1 102 | 103 | return output 104 | 
--------------------------------------------------------------------------------