├── input └── .gitkeep ├── checkpoints └── .gitkeep ├── metrics ├── __init__.py └── accuracy.py ├── models ├── __init__.py └── resnet34.py ├── trainer ├── __init__.py └── trainer.py ├── dataloaders ├── __init__.py └── subset_random_dataloader.py ├── .gitignore ├── datasets ├── __init__.py ├── chest_xray_pneumonia_dataset.py ├── covid_chestxray_dataset.py └── nih_cx38_dataset.py ├── diagnostics.py ├── tools └── docker │ ├── setup.sh │ └── setup.fish ├── LICENSE ├── train.py ├── README.md ├── Dockerfile └── environment.yml /input/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .accuracy import * -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet34 import Resnet34 -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .subset_random_dataloader import SubsetRandomDataLoader -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__/ 3 | *.py[cod] 4 | .ipynb_checkpoints 5 | input/* 6 | *.pth 7 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .chest_xray_pneumonia_dataset import ChestXRayPneumoniaDataset 2 | from .covid_chestxray_dataset import COVIDChestXRayDataset 3 | from .nih_cx38_dataset import NIHCX38Dataset 4 | -------------------------------------------------------------------------------- /metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class Accuracy(nn.Module): 4 | def __init__(self, threshold=0.5): 5 | super().__init__() 6 | self.threshold = threshold 7 | 8 | def forward(self, y_true, y_pred): 9 | preds = (y_pred > self.threshold).int() 10 | return (preds == y_true).sum().float() / len(preds) -------------------------------------------------------------------------------- /diagnostics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import sklearn 4 | import pandas as pd 5 | import cv2 6 | import torch 7 | 8 | print('Python version: {}'. format(sys.version)) 9 | print('NumPy version: {}'. format(np.__version__)) 10 | print('scikit-learn version: {}'. format(sklearn.__version__)) 11 | print('pandas version: {}'. format(pd.__version__)) 12 | print('OpenCV version: {}'. format(cv2.__version__)) 13 | print('Torch version: {}'. format(torch.__version__)) 14 | print('Available GPUs: {}'.format(torch.cuda.device_count())) 15 | if torch.cuda.is_available: 16 | device = torch.device('cuda') 17 | print('Cuda version: {}'.format(torch.version.cuda)) 18 | else: 19 | device = torch.device('cpu') 20 | print("Torch device: {}".format(device)) -------------------------------------------------------------------------------- /dataloaders/subset_random_dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader, SubsetRandomSampler 4 | 5 | def null_collate(batch): 6 | batch_size = len(batch) 7 | images = np.array([x[0] for x in batch]) 8 | images = torch.from_numpy(images) 9 | 10 | labels = np.array([x[1] for x in batch]) 11 | labels = torch.from_numpy(labels) 12 | labels = labels.unsqueeze(1) 13 | 14 | assert(images.shape[0] == labels.shape[0] == batch_size) 15 | 16 | return images, labels 17 | 18 | class SubsetRandomDataLoader(DataLoader): 19 | def __init__(self, dataset, indexes, batch_size): 20 | loader_params = dict( 21 | batch_size=batch_size, 22 | num_workers=1, 23 | pin_memory=True, 24 | collate_fn=null_collate 25 | ) 26 | super(SubsetRandomDataLoader, self).__init__(dataset=dataset, sampler=SubsetRandomSampler(indexes), **loader_params) 27 | -------------------------------------------------------------------------------- /tools/docker/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DOCKER_IMAGE=defeatcovid:0.1 4 | 5 | if [[ "$(docker images -q $DOCKER_IMAGE 2> /dev/null)" == "" ]]; then 6 | docker build --tag defeatcovid:0.1 . 7 | fi 8 | 9 | is_repo_root() { 10 | if [ ! -f "Dockerfile" ]; then 11 | echo "This needs to be source from the root of the repository!" 12 | return 13 | fi 14 | } 15 | 16 | is_repo_root 17 | 18 | WD=$PWD 19 | 20 | dkrun() { 21 | if [[ -z "$1" ]]; then 22 | echo "Script to run missing!" 23 | return 24 | fi 25 | if [ ! is_repo_root ]; then 26 | return 27 | fi 28 | docker run -it -v $PWD:$PWD $DOCKER_IMAGE python3 $PWD/$1 29 | echo "docker run -ti --gpus all -e NVIDIA_DRIVER_CAPABILITIES=compute,utility -e NVIDIA_VISIBLE_DEVICES=all -v $WD:$WD -w $WD $DOCKER_IMAGE python3 $WD/$1" 30 | docker run -ti --gpus all -e NVIDIA_DRIVER_CAPABILITIES=compute,utility -e NVIDIA_VISIBLE_DEVICES=all -v $WD:$WD -w $WD $DOCKER_IMAGE python3 $WD/$1 31 | } -------------------------------------------------------------------------------- /tools/docker/setup.fish: -------------------------------------------------------------------------------- 1 | #!/bin/fish 2 | 3 | set DOCKER_IMAGE defeatcovid:0.1 4 | 5 | if test (docker images -q $DOCKER_IMAGE 2> /dev/null) = ""; 6 | echo "Building docker image..." 7 | docker build --tag defeatcovid:0.1 . 8 | else 9 | echo "Docker image already exists! Skipping build phase..." 10 | end 11 | 12 | function is_repo_root 13 | if not test -e Dockerfile; 14 | echo "This needs to be source from the root of the repository!" 15 | exit 2 16 | end 17 | end 18 | 19 | is_repo_root 20 | 21 | set WD $PWD 22 | 23 | function dkrun 24 | if not test -n "$argv"; 25 | echo "Script to run missing!" 26 | return 27 | end 28 | if not is_repo_root; 29 | return 30 | end 31 | echo "docker run -ti --gpus all -e NVIDIA_DRIVER_CAPABILITIES=compute,utility -e NVIDIA_VISIBLE_DEVICES=all -v $WD:$WD -w $WD $DOCKER_IMAGE python3 $WD/$argv" 32 | docker run -ti --gpus all -e NVIDIA_DRIVER_CAPABILITIES=compute,utility -e NVIDIA_VISIBLE_DEVICES=all -v $WD:$WD -w $WD $DOCKER_IMAGE python3 $WD/$argv 33 | end -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 defeatcovid19 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/resnet34.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision.models import resnet34 3 | 4 | class Resnet34(nn.Module): 5 | def __init__(self, num_classes=1, dropout=0.5): 6 | super(Resnet34, self).__init__() 7 | resnet = resnet34(pretrained=True) 8 | 9 | self.conv1 = resnet.conv1 10 | self.bn1 = resnet.bn1 11 | self.relu = resnet.relu 12 | self.maxpool = resnet.maxpool 13 | self.layer1 = resnet.layer1 14 | self.layer2 = resnet.layer2 15 | self.layer3 = resnet.layer3 16 | self.layer4 = resnet.layer4 17 | self.avgpool = resnet.avgpool 18 | bottleneck_features = resnet.fc.in_features 19 | self.fc = nn.Sequential( 20 | nn.BatchNorm1d(bottleneck_features), 21 | nn.Dropout(dropout), 22 | nn.Linear(bottleneck_features, num_classes), 23 | nn.Sigmoid() 24 | ) 25 | 26 | def forward(self, x): 27 | # mean = MEAN 28 | # std = STD 29 | x = x / 255. 30 | # x = torch.cat([ 31 | # (x[:, [0]] - mean[0]) / std[0], 32 | # (x[:, [1]] - mean[1]) / std[1], 33 | # (x[:, [2]] - mean[2]) / std[2], 34 | # (x[:, [3]] - mean[3]) / std[3], 35 | # ], 1) 36 | x = self.conv1(x) 37 | x = self.bn1(x) 38 | x = self.relu(x) 39 | x = self.maxpool(x) 40 | 41 | x = self.layer1(x) 42 | x = self.layer2(x) 43 | x = self.layer3(x) 44 | x = self.layer4(x) 45 | 46 | x = self.avgpool(x) 47 | x = x.view(x.size(0), -1) 48 | x = self.fc(x) 49 | return x -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import diagnostics 3 | from pathlib import Path 4 | from datasets import ChestXRayPneumoniaDataset, COVIDChestXRayDataset, NIHCX38Dataset 5 | from models import Resnet34 6 | from trainer import Trainer 7 | from sklearn.model_selection import train_test_split, StratifiedKFold 8 | 9 | 10 | import random 11 | import numpy as np 12 | import torch 13 | 14 | # Fix seed to improve reproducibility 15 | SEED = 6666 16 | random.seed(SEED) 17 | np.random.seed(SEED) 18 | torch.manual_seed(SEED) 19 | torch.cuda.manual_seed(SEED) 20 | torch.backends.cudnn.deterministic = True 21 | 22 | batch_size = 64 23 | size = 256 24 | n_splits = 5 25 | 26 | # Pretrain with Chest XRay Pneumonia dataset (>5k images) 27 | pneumonia_classifier = Resnet34() 28 | dataset = ChestXRayPneumoniaDataset(Path('input/chest-xray-pneumonia'), size) 29 | # dataset = NIHCX38Dataset(Path('input/nih-cx38'), size, balance=True) 30 | train_idx, validation_idx = train_test_split( 31 | list(range(len(dataset))), 32 | test_size=0.2, 33 | stratify=dataset.labels 34 | ) 35 | trainer = Trainer(pneumonia_classifier, dataset, batch_size, train_idx, validation_idx) 36 | trainer.run(max_epochs=2) 37 | 38 | # Fine tune with COVID-19 Chest XRay dataset (~120 images) 39 | dataset = COVIDChestXRayDataset(Path('input/covid_chestxray'), size) 40 | print('Executing a {}-fold cross validation'.format(n_splits)) 41 | split = 1 42 | skf = StratifiedKFold(n_splits=n_splits) 43 | for train_idx, validation_idx in skf.split(dataset.df, dataset.labels): 44 | print('===Split #{}==='.format(split)) 45 | # Start from the pneumonia classifier 46 | classifier = copy.deepcopy(pneumonia_classifier) 47 | trainer = Trainer(classifier, dataset, batch_size, train_idx, validation_idx) 48 | trainer.run(max_epochs=15) 49 | split += 1 50 | -------------------------------------------------------------------------------- /datasets/chest_xray_pneumonia_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import pandas as pd 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | class ChestXRayPneumoniaDataset(Dataset): 8 | def __init__(self, path, size=128, augment=None): 9 | super(ChestXRayPneumoniaDataset, self).__init__() 10 | print('{} initialized with size={}, augment={}'.format(self.__class__.__name__, size, augment)) 11 | print('Dataset is located in {}'.format(path)) 12 | self.size = size 13 | self.augment = augment 14 | 15 | train_dir = path / 'train' 16 | val_dir = path / 'val' 17 | test_dir = path / 'test' 18 | 19 | normal_cases = [] 20 | pneumonia_cases = [] 21 | for folder in [train_dir, val_dir, test_dir]: 22 | normal_cases.extend((folder / 'NORMAL').glob('*.jpeg')) 23 | pneumonia_cases.extend((folder / 'PNEUMONIA').glob('*.jpeg')) 24 | 25 | self.labels = np.concatenate(( 26 | np.zeros(len(normal_cases)), 27 | np.ones(len(pneumonia_cases)) 28 | )).reshape(-1, 1) 29 | images = np.concatenate((normal_cases, pneumonia_cases)).reshape(-1, 1) 30 | 31 | self.df = pd.DataFrame(np.concatenate((images, self.labels), axis=1), columns=['image', 'label']) 32 | 33 | del images 34 | 35 | print("Dataset: {}".format(self.df)) 36 | 37 | 38 | @staticmethod 39 | def _load_image(path, size): 40 | img = Image.open(path) 41 | img = cv2.resize(np.array(img), (size, size), interpolation=cv2.INTER_AREA) 42 | if len(img.shape) == 2: 43 | img = np.expand_dims(img, axis=2) 44 | img = np.dstack([img, img, img]) 45 | else: 46 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 47 | 48 | # size, size, chan -> chan, size, size 49 | img = np.transpose(img, axes=[2, 0, 1]) 50 | 51 | return img 52 | 53 | def __getitem__(self, index): 54 | row = self.df.iloc[index] 55 | img = self._load_image(row['image'], self.size) 56 | label = row['label'] 57 | 58 | if self.augment is not None: 59 | img = self.augment(img) 60 | 61 | return img, label 62 | 63 | def __len__(self): 64 | return self.df.shape[0] -------------------------------------------------------------------------------- /datasets/covid_chestxray_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import pandas as pd 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | class COVIDChestXRayDataset(Dataset): 8 | def __init__(self, path, size=128, augment=None): 9 | super(COVIDChestXRayDataset, self).__init__() 10 | print('{} initialized with size={}, augment={}'.format(self.__class__.__name__, size, augment)) 11 | print('Dataset is located in {}'.format(path)) 12 | self.size = size 13 | self.augment = augment 14 | 15 | image_dir = path / 'images' 16 | metadata_path = path / 'metadata.csv' 17 | 18 | df_metadata = pd.read_csv(metadata_path, header=0) 19 | # Drop CT scans 20 | df_metadata = df_metadata[df_metadata['modality'] == 'X-ray'] 21 | # Keep only PA/AP/AP Supine, drop Axial, L (lateral) 22 | allowed_views = ['PA', 'AP', 'AP Supine'] 23 | df_metadata = df_metadata[df_metadata['view'].isin(allowed_views)] 24 | 25 | # COVID-19 = 1, SARS/ARDS/Pneumocystis/Streptococcus/No finding = 0 26 | self.labels = (df_metadata.finding == 'COVID-19').values.reshape(-1, 1) 27 | images = df_metadata.filename 28 | images = images.apply(lambda x: image_dir / x).values.reshape(-1, 1) 29 | 30 | self.df = pd.DataFrame(np.concatenate((images, self.labels), axis=1), columns=['image', 'label']) 31 | 32 | del images 33 | 34 | print("Dataset: {}".format(self.df)) 35 | 36 | 37 | @staticmethod 38 | def _load_image(path, size): 39 | img = Image.open(path) 40 | img = cv2.resize(np.array(img), (size, size), interpolation=cv2.INTER_AREA) 41 | if len(img.shape) == 2: 42 | img = np.expand_dims(img, axis=2) 43 | img = np.dstack([img, img, img]) 44 | else: 45 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 46 | 47 | # size, size, chan -> chan, size, size 48 | img = np.transpose(img, axes=[2, 0, 1]) 49 | 50 | return img 51 | 52 | def __getitem__(self, index): 53 | row = self.df.iloc[index] 54 | img = self._load_image(row['image'], self.size) 55 | label = row['label'] 56 | 57 | if self.augment is not None: 58 | img = self.augment(img) 59 | 60 | return img, label 61 | 62 | def __len__(self): 63 | return self.df.shape[0] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # defeatcovid19-net-pytorch 2 | 3 | This repo provides a Pytorch solution for predictions on X-ray images for COVID-19 patients. 4 | 5 | ## Motivation 6 | It is intended to be used as a template for **defeatcovid19** group partecipants who like to contribute. You can find more info on our group's effort [here](https://github.com/defeatcovid19/defeatcovid19-project). At the moment we're actively trying to contact local hospitals to collect radiologic (mainly XRay and Eco) images to build a robust dataset for deep learning training. 7 | 8 | ## Implementation 9 | 10 | The network of choice is ResNet34, provided by torchvision and pretrained on Imagenet. 11 | The net is first trained on the [Kaggle Chest X-Ray Pneumonia dataset](https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia) (5856 images) and then on the [COVID-19 Chest X-Ray dataset](https://github.com/ieee8023/covid-chestxray-dataset) (123 usable images). 12 | 13 | Axial and lateral images were removed from the latter dataset. COVID-19 diagnoses were labelled 1, 0 otherwise (SARS/ARDS/Pneumocystis/Streptococcus/No finding). 14 | 15 | ### Requirements 16 | An `environment.yml` file is provided to list the package requirements (mainly numpy, pandas, opencv, torch). The train entrypoint expects to find the aforementioned datasets in `./input`. Adjust your paths accordingly. 17 | 18 | 19 | ### Training 20 | You can train the network and see the results of the cross validation with 21 | ``` 22 | python train.py 23 | ``` 24 | 25 | ## Running with Docker 26 | 27 | ### Requirements 28 | NVIDIA Driver Installation 29 | [Docker installation](https://docs.docker.com/install/linux/docker-ce/ubuntu/) 30 | [NVIDIA Docker installation](https://github.com/NVIDIA/nvidia-docker) 31 | 32 | ### Build docker image 33 | From the root of the repository (the image takes several minutes to build, due to download and compilation): 34 | ``` 35 | source tools/docker/setup.sh 36 | ``` 37 | Or if you are using shell fish: 38 | ``` 39 | source tools/docker/setup.fish 40 | ``` 41 | For running the training process: 42 | ``` 43 | dkrun train.py 44 | ``` 45 | ## Results (initial) 46 | The first part of the training (on the "Pneumonia" dataset) uses a simple 80/20 train/valid split. It achieves a ROC AUC score close to 1 for the selected fold. 47 | The second part of the training (on the "COVID" dataset) uses a more robust 5-fold cross validation and it results in a ~0.77 ROC AUC score. 48 | 49 | 50 | ## Citations 51 | - Paul Mooney, Chest X-Ray Images (Pneumonia), Kaggle dataset, https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia, 2018 52 | - Joseph Paul Cohen, COVID-19 image data collection, https://github.com/ieee8023/covid-chestxray-dataset, 2020 53 | 54 | ## License 55 | 56 | This repo serves as a template for future effort of the **defeatcovid19** group and as such is intended to be released under the MIT license. 57 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.4-cuda10.1-cudnn7-devel 2 | 3 | ENV TZ=Europe/Rome 4 | ENV DEBIAN_FRONTEND=noninteractive 5 | RUN apt-get update && apt-get install -y --no-install-recommends \ 6 | build-essential \ 7 | cmake \ 8 | git \ 9 | libavcodec-dev \ 10 | libavformat-dev \ 11 | libswscale-dev \ 12 | libgstreamer-plugins-base1.0-dev libgstreamer1.0-dev \ 13 | libgtk2.0-dev \ 14 | libgtk-3-dev \ 15 | libpng-dev \ 16 | libjpeg-dev \ 17 | libopenexr-dev \ 18 | libtiff-dev \ 19 | libtbb2 \ 20 | libtbb-dev \ 21 | libwebp-dev \ 22 | qtbase5-dev \ 23 | qtdeclarative5-dev \ 24 | qttools5-dev \ 25 | python3-setuptools \ 26 | python3-pip \ 27 | git \ 28 | wget \ 29 | unzip \ 30 | yasm \ 31 | cython \ 32 | && rm -rf /var/lib/apt/lists/* 33 | 34 | RUN conda install Cython numpy=1.18.1 scipy=1.4.1 matplotlib scikit-learn=0.22.1 pandas=1.0.3 35 | 36 | WORKDIR / 37 | ENV OPENCV_VERSION="4.2.0" 38 | RUN wget https://github.com/opencv/opencv/archive/${OPENCV_VERSION}.zip \ 39 | && unzip ${OPENCV_VERSION}.zip \ 40 | && mkdir /opencv-${OPENCV_VERSION}/cmake_binary \ 41 | && cd /opencv-${OPENCV_VERSION}/cmake_binary \ 42 | && cmake -DBUILD_TIFF=ON \ 43 | -DBUILD_opencv_java=OFF \ 44 | -DWITH_CUDA=OFF \ 45 | -DWITH_OPENGL=ON \ 46 | -DWITH_OPENCL=ON \ 47 | -DWITH_IPP=ON \ 48 | -DWITH_TBB=ON \ 49 | -DWITH_EIGEN=ON \ 50 | -DWITH_V4L=ON \ 51 | -DBUILD_TESTS=OFF \ 52 | -DBUILD_PERF_TESTS=OFF \ 53 | -DCMAKE_BUILD_TYPE=RELEASE \ 54 | -DCMAKE_INSTALL_PREFIX=$(python3.7 -c "import sys; print(sys.prefix)") \ 55 | -DPYTHON_EXECUTABLE=$(which python3.7) \ 56 | -DPYTHON_INCLUDE_DIR=$(python3.7 -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())") \ 57 | -DPYTHON_PACKAGES_PATH=$(python3.7 -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())") \ 58 | .. \ 59 | && make -j install \ 60 | && rm /${OPENCV_VERSION}.zip \ 61 | && rm -r /opencv-${OPENCV_VERSION} 62 | 63 | RUN mkdir -p /root/.cache/torch/checkpoints/ && \ 64 | wget https://download.pytorch.org/models/resnet34-333f7ec4.pth -P /root/.cache/torch/checkpoints/ -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: defeatcovi19 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - blas=1.0=mkl 9 | - bzip2=1.0.8=h516909a_2 10 | - ca-certificates=2019.11.28=hecc5488_0 11 | - cairo=1.16.0=hcf35c78_1003 12 | - certifi=2019.11.28=py38h32f6830_1 13 | - cudatoolkit=10.1.243=h6bb024c_0 14 | - dbus=1.13.6=he372182_0 15 | - expat=2.2.9=he1b5a44_2 16 | - ffmpeg=4.1.3=h167e202_0 17 | - fontconfig=2.13.1=h86ecdb6_1001 18 | - freetype=2.9.1=h8a8886c_1 19 | - gettext=0.19.8.1=hc5be6a0_1002 20 | - giflib=5.2.1=h516909a_2 21 | - glib=2.58.3=py38h73cb85d_1003 22 | - gmp=6.2.0=he1b5a44_2 23 | - gnutls=3.6.5=hd3a4fd2_1002 24 | - graphite2=1.3.13=he1b5a44_1001 25 | - gst-plugins-base=1.14.5=h0935bb2_2 26 | - gstreamer=1.14.5=h36ae1b5_2 27 | - harfbuzz=2.4.0=h9f30f68_3 28 | - hdf5=1.10.5=nompi_h3c11f04_1104 29 | - icu=64.2=he1b5a44_1 30 | - intel-openmp=2020.0=166 31 | - jasper=1.900.1=h07fcdf6_1006 32 | - joblib=0.14.1=py_0 33 | - jpeg=9c=h14c3975_1001 34 | - lame=3.100=h14c3975_1001 35 | - ld_impl_linux-64=2.33.1=h53a641e_7 36 | - libblas=3.8.0=15_mkl 37 | - libcblas=3.8.0=15_mkl 38 | - libclang=9.0.1=default_hde54327_0 39 | - libedit=3.1.20181209=hc058e9b_0 40 | - libffi=3.2.1=hd88cf55_4 41 | - libgcc-ng=9.1.0=hdf63c60_0 42 | - libgfortran-ng=7.3.0=hdf63c60_0 43 | - libiconv=1.15=h516909a_1006 44 | - liblapack=3.8.0=15_mkl 45 | - liblapacke=3.8.0=15_mkl 46 | - libllvm9=9.0.1=hc9558a2_0 47 | - libopencv=4.2.0=py38_2 48 | - libpng=1.6.37=hbc83047_0 49 | - libstdcxx-ng=9.1.0=hdf63c60_0 50 | - libtiff=4.1.0=h2733197_0 51 | - libuuid=2.32.1=h14c3975_1000 52 | - libwebp=1.0.2=h56121f0_5 53 | - libxcb=1.13=h14c3975_1002 54 | - libxkbcommon=0.10.0=he1b5a44_0 55 | - libxml2=2.9.10=hee79883_0 56 | - mkl=2020.0=166 57 | - mkl-service=2.3.0=py38he904b0f_0 58 | - mkl_fft=1.0.15=py38ha843d7b_0 59 | - mkl_random=1.1.0=py38h962f231_0 60 | - ncurses=6.2=he6710b0_0 61 | - nettle=3.4.1=h1bed415_1002 62 | - ninja=1.9.0=py38hfd86e86_0 63 | - nspr=4.25=he1b5a44_0 64 | - nss=3.47=he751ad9_0 65 | - numpy=1.18.1=py38h4f9e942_0 66 | - numpy-base=1.18.1=py38hde5b4d6_1 67 | - olefile=0.46=py_0 68 | - opencv=4.2.0=py38_2 69 | - openh264=1.8.0=hdbcaa40_1000 70 | - openssl=1.1.1e=h516909a_0 71 | - pandas=1.0.3=py38h0573a6f_0 72 | - pcre=8.44=he1b5a44_0 73 | - pillow=7.0.0=py38hb39fc2d_0 74 | - pip=20.0.2=py38_1 75 | - pixman=0.38.0=h516909a_1003 76 | - pthread-stubs=0.4=h14c3975_1001 77 | - py-opencv=4.2.0=py38h5ca1d4c_2 78 | - python=3.8.2=h191fe78_0 79 | - python-dateutil=2.8.1=py_0 80 | - python_abi=3.8=1_cp38 81 | - pytorch=1.4.0=py3.8_cuda10.1.243_cudnn7.6.3_0 82 | - pytz=2019.3=py_0 83 | - qt=5.12.5=hd8c4c69_1 84 | - readline=7.0=h7b6447c_5 85 | - scikit-learn=0.22.1=py38hd81dba3_0 86 | - scipy=1.4.1=py38h0b6359f_0 87 | - setuptools=46.1.1=py38_0 88 | - six=1.14.0=py38_0 89 | - sqlite=3.31.1=h7b6447c_0 90 | - tk=8.6.8=hbc83047_0 91 | - torchvision=0.5.0=py38_cu101 92 | - wheel=0.34.2=py38_0 93 | - x264=1!152.20180806=h14c3975_0 94 | - xorg-kbproto=1.0.7=h14c3975_1002 95 | - xorg-libice=1.0.10=h516909a_0 96 | - xorg-libsm=1.2.3=h84519dc_1000 97 | - xorg-libx11=1.6.9=h516909a_0 98 | - xorg-libxau=1.0.9=h14c3975_0 99 | - xorg-libxdmcp=1.1.3=h516909a_0 100 | - xorg-libxext=1.3.4=h516909a_0 101 | - xorg-libxrender=0.9.10=h516909a_1002 102 | - xorg-renderproto=0.11.1=h14c3975_1002 103 | - xorg-xextproto=7.3.0=h14c3975_1002 104 | - xorg-xproto=7.0.31=h14c3975_1007 105 | - xz=5.2.4=h14c3975_4 106 | - zlib=1.2.11=h7b6447c_3 107 | - zstd=1.3.7=h0b5b093_0 108 | 109 | -------------------------------------------------------------------------------- /datasets/nih_cx38_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import pandas as pd 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | class NIHCX38Dataset(Dataset): 8 | ''' Integrates the National Institutes of Health Clinical Center chest x-ray dataset. 9 | 10 | Dataset description: https://www.nih.gov/news-events/news-releases/nih-clinical-center-provides-one-largest-publicly-available-chest-x-ray-datasets-scientific-community 11 | Download the dataset from https://nihcc.app.box.com/v/ChestXray-NIHCC to the folder input/nih-cx38 12 | Extract all image_??.tar.gz to the input/nih-cx38/images/ folder and ensure input/nih-cx38/Data_Entry_2017.csv is present. 13 | ''' 14 | 15 | def __init__(self, path, size=128, augment=None, balance=False): 16 | super(NIHCX38Dataset, self).__init__() 17 | print('{} initialized with size={}, augment={}'.format(self.__class__.__name__, size, augment)) 18 | print('Dataset is located in {}'.format(path)) 19 | self.size = size 20 | self.augment = augment 21 | 22 | image_dir = path / 'images' 23 | metadata_path = path / 'Data_Entry_2017.csv' 24 | 25 | df_metadata = pd.read_csv(metadata_path) 26 | df_metadata['labels'] = df_metadata['Finding Labels'].str.split('|') 27 | 28 | # Pneumonia = 1, no Pneumonia = 0 29 | finding_mask = lambda df, finding: df['labels'].apply(lambda l: finding in l) 30 | pneumonia_mask = finding_mask(df_metadata, 'Pneumonia') 31 | 32 | if balance: 33 | pneumonia_indices = np.arange(len(pneumonia_mask))[pneumonia_mask] 34 | normal_indices = np.arange(len(pneumonia_mask))[~pneumonia_mask] 35 | if len(pneumonia_indices) < len(normal_indices): 36 | normal_indices = np.random.choice(normal_indices, len(pneumonia_indices)) 37 | else: 38 | pneumonia_indices = np.random.choice(pneumonia_indices, len(normal_indices)) 39 | self.labels = np.concatenate(( 40 | np.zeros(len(normal_indices)), 41 | np.ones(len(pneumonia_indices)) 42 | )).reshape(-1, 1) 43 | images = df_metadata['Image Index'][np.concatenate((normal_indices, pneumonia_indices))] 44 | else: 45 | self.labels = pneumonia_mask.values.reshape(-1, 1) 46 | images = df_metadata['Image Index'] 47 | 48 | images = images.apply(lambda x: image_dir / x).values.reshape(-1, 1) 49 | self.df = pd.DataFrame(np.concatenate((images, self.labels), axis=1), columns=['image', 'label']) 50 | 51 | del images 52 | 53 | print("Dataset: {}".format(self.df)) 54 | print(" Number of positive cases: {}".format(sum(self.labels))) 55 | print(" Number of negative cases: {}".format(len(self.labels) - sum(self.labels))) 56 | 57 | 58 | @staticmethod 59 | def _load_image(path, size): 60 | img = Image.open(path) 61 | img = cv2.resize(np.array(img), (size, size), interpolation=cv2.INTER_AREA) 62 | if len(img.shape) == 2: 63 | img = np.expand_dims(img, axis=2) 64 | img = np.dstack([img, img, img]) 65 | else: 66 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 67 | 68 | # size, size, chan -> chan, size, size 69 | img = np.transpose(img, axes=[2, 0, 1]) 70 | 71 | return img 72 | 73 | def __getitem__(self, index): 74 | row = self.df.iloc[index] 75 | img = self._load_image(row['image'], self.size) 76 | label = row['label'] 77 | 78 | if self.augment is not None: 79 | img = self.augment(img) 80 | 81 | return img, label 82 | 83 | def __len__(self): 84 | return self.df.shape[0] 85 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.optim as optim 5 | import torch.nn as nn 6 | from timeit import default_timer as timer 7 | from torch.utils.data import Dataset 8 | from sklearn.model_selection import train_test_split 9 | from sklearn.metrics import roc_auc_score 10 | from dataloaders import SubsetRandomDataLoader 11 | from metrics import Accuracy 12 | 13 | class Trainer: 14 | def __init__(self, classifier, dataset, batch_size, train_idx, validation_idx): 15 | self.classifier = classifier 16 | self.batch_size = batch_size 17 | print('Trainer started with classifier: {} dataset: {} batch size: {}'.format(classifier.__class__.__name__, dataset, batch_size)) 18 | 19 | self.optimizer = None 20 | self.scheduler = None 21 | 22 | self.train_dataset = dataset 23 | self.validation_dataset = dataset 24 | 25 | 26 | # train_idx, validation_idx = train_test_split( 27 | # list(range(len(self.train_dataset))), 28 | # test_size=0.2, 29 | # stratify=self.train_dataset.labels 30 | # ) 31 | 32 | self.train_loader = SubsetRandomDataLoader(dataset, train_idx, batch_size) 33 | self.validation_loader = SubsetRandomDataLoader(dataset, validation_idx, batch_size) 34 | 35 | print('Train set: {}'.format(len(train_idx))) 36 | print('Validation set: {}'.format(len(validation_idx))) 37 | 38 | self.it_per_epoch = math.ceil(len(train_idx) / self.batch_size) 39 | print('Training with {} mini-batches per epoch'.format(self.it_per_epoch)) 40 | 41 | 42 | def run(self, max_epochs=10, lr=0.01): 43 | self.classifier = self.classifier.cuda() 44 | model = self.classifier 45 | 46 | it = 0 47 | epoch = 0 48 | it_save = self.it_per_epoch * 5 49 | it_log = math.ceil(self.it_per_epoch / 5) 50 | it_smooth = self.it_per_epoch 51 | print("Logging performance every {} iter, smoothing every: {} iter".format(it_log, it_smooth)) 52 | 53 | self.optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, momentum=0.9, weight_decay=0.0001) 54 | self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 2 * self.it_per_epoch, gamma=0.9) 55 | 56 | criterion = nn.BCELoss() 57 | criterion = criterion.cuda() 58 | metrics = [Accuracy(), roc_auc_score] 59 | 60 | print("{}'".format(self.optimizer)) 61 | print("{}'".format(self.scheduler)) 62 | print("{}'".format(criterion)) 63 | print("{}'".format(metrics)) 64 | 65 | train_loss = 0 66 | train_roc = 0 67 | train_acc = 0 68 | 69 | print(' | VALID | TRAIN | ') 70 | print(' lr iter epoch | loss roc acc | loss roc acc | time ') 71 | print('------------------------------------------------------------------------------') 72 | 73 | start = timer() 74 | while epoch < max_epochs: 75 | epoch_labels = [] 76 | epoch_preds = [] 77 | epoch_losses = [] 78 | epoch_loss_weights = [] 79 | 80 | for inputs, labels in self.train_loader: 81 | epoch = (it + 1) / self.it_per_epoch 82 | 83 | lr = self.scheduler.get_last_lr()[0] 84 | 85 | # checkpoint 86 | if it % it_save == 0 and it != 0: 87 | self.save(model, self.optimizer, it, epoch) 88 | 89 | # training 90 | 91 | model.train() 92 | inputs = inputs.cuda().float() 93 | labels = labels.cuda().float() 94 | 95 | preds = model(inputs) 96 | loss = criterion(preds, labels) 97 | 98 | epoch_labels.append(labels.cpu()) 99 | epoch_preds.append(preds.cpu()) 100 | epoch_losses.append(loss.data.cpu().numpy()) 101 | epoch_loss_weights.append(len(preds)) 102 | 103 | self.optimizer.zero_grad() 104 | loss.backward() 105 | 106 | nn.utils.clip_grad_norm_(model.parameters(), 1.0) 107 | self.optimizer.step() 108 | self.scheduler.step() 109 | 110 | if it % it_log == 0: 111 | print( 112 | "{:5f} {:4d} {:5.1f} | | | {:6.2f}".format( 113 | lr, it, epoch, timer() - start 114 | )) 115 | 116 | it += 1 117 | 118 | # loss and metrics 119 | with torch.no_grad(): 120 | epoch_labels = torch.cat(epoch_labels) 121 | epoch_preds = torch.cat(epoch_preds) 122 | train_acc, train_roc = [i(epoch_labels.cpu(), epoch_preds.cpu()).item() for i in metrics] 123 | 124 | train_loss = np.average(epoch_losses, weights=epoch_loss_weights) 125 | 126 | # validation 127 | valid_loss, valid_m = self.do_valid(model, criterion, metrics) 128 | valid_acc, valid_roc = valid_m 129 | 130 | print( 131 | "{:5f} {:4d} {:5.1f} | {:0.3f}* {:0.3f} {:0.3f} | {:0.3f}* {:0.3f} {:0.3f} | {:6.2f}".format( 132 | lr, it, epoch, valid_loss, valid_roc, valid_acc, train_loss, train_roc, train_acc, timer() - start 133 | )) 134 | 135 | # Data loader end 136 | # Training end 137 | 138 | self.save(model, self.optimizer, it, epoch) 139 | 140 | def do_valid(self, model, criterion, metrics): 141 | model.eval() 142 | valid_num = 0 143 | valid_labels = [] 144 | valid_preds = [] 145 | losses = [] 146 | loss_weights = [] 147 | 148 | for inputs, labels in self.validation_loader: 149 | inputs = inputs.cuda().float() 150 | labels = labels.cuda().float() 151 | 152 | with torch.no_grad(): 153 | preds = model(inputs) 154 | loss = criterion(preds, labels) 155 | 156 | valid_num += len(inputs) 157 | valid_labels.append(labels.cpu()) 158 | valid_preds.append(preds.cpu()) 159 | losses.append(loss.data.cpu().numpy()) 160 | loss_weights.append(len(inputs)) 161 | 162 | assert (valid_num == len(self.validation_loader.sampler)) 163 | 164 | with torch.no_grad(): 165 | valid_labels = torch.cat(valid_labels) 166 | valid_preds = torch.cat(valid_preds) 167 | m = [i(valid_labels.cpu(), valid_preds.cpu()).item() for i in metrics] 168 | 169 | loss = np.average(losses, weights=loss_weights) 170 | return loss, m 171 | 172 | def save(self, model, optimizer, iter, epoch): 173 | torch.save(model.state_dict(), "checkpoints/{}_model.pth".format(iter)) 174 | torch.save({ 175 | "optimizer": optimizer.state_dict(), 176 | "iter": iter, 177 | "epoch": epoch 178 | }, "checkpoints/{}_optimizer.pth".format(iter)) 179 | --------------------------------------------------------------------------------