├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── covid_example.jpg
├── config.py
├── data
│   ├── dataset.py
│   └── transforms.py
├── experiments
│   └── README.md
├── generate_dataset.py
├── minimal_prediction.py
├── model
│   ├── __init__.py
│   ├── architecture.py
│   └── layers.py
├── requirements.txt
├── train.py
└── util.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
experiments/*
!experiments/README.md

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# static files generated from Django application using `collectstatic`
media
static

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Velebit AI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# COVID-Next → PyTorch upgrade of the COVID-Net

Inspired by the recent paper [COVID-Net: A Tailored Deep Convolutional Neural Network Design for Detection of COVID-19 Cases from Chest Radiography Images](https://arxiv.org/pdf/2003.09871.pdf) and its TensorFlow [implementation](https://github.com/lindawangg/COVID-Net), we are open-sourcing COVID-Next, an upgraded PyTorch version of the model.

COVID-Next builds upon the well-known ResNeXt50 architecture; it has around **5x** fewer parameters than the original COVID-Net while achieving comparable performance.

TensorFlow and PyTorch are the two major deep learning frameworks, and our motivation was to give the PyTorch research community the same starting point for AI COVID-19 research that the TensorFlow community already has. As the authors of the paper have already noted, this **model still doesn't offer production-ready performance**. The key unresolved issue is the set of COVID-19 images, which is currently **neither diverse nor large enough** to provide the representative predictions end users would expect from a production system.

## Requirements

As always, we recommend [virtual environments](https://docs.python.org/3/tutorial/venv.html), where you install all requirements separately from your system-wide ones. This step is optional :)

To install all requirements, simply run `pip3 install -r requirements.txt`.
The code was tested with Python 3.6.9.
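
On Linux or macOS, for example, the whole setup boils down to:

```
python3 -m venv venv
source venv/bin/activate
pip3 install -r requirements.txt
```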

## Pretrained model

Download the pretrained COVID-Next model from [here](https://drive.google.com/open?id=1G8vQKBObt52b4qe5cQdoQkdPxjZK3ucI).

## Training

Training configuration is currently modified through the `config.py` module. Check it out before starting training.

Running `python3 train.py` will start model training.

### Dataset

We have created a script that automates dataset generation from the two sources referenced in the original repo. To generate the dataset, follow these steps:

1. Download the datasets listed below:
    * COVID ChestXray [dataset](https://github.com/ieee8023/covid-chestxray-dataset.git). Be aware that this repository is constantly adding new images.
    * Kaggle RSNA pneumonia [dataset](https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data)
2. Run the `generate_dataset.py` script. Run `python3 generate_dataset.py -h` to see the supported arguments.

The script will create a new folder with `train` and `test` subfolders containing the images, along with a metadata file for each of the two subsets.


### Note

IO will probably be a bottleneck during training because most of the images are large and a lot of time is wasted on loading them into memory. To avoid this issue, we suggest downscaling the images beforehand to the input size used by the model (256×256 by default, see `config.py`).

You can also try increasing `config.n_threads` to alleviate this issue, but beware that more loader threads will increase memory usage.
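
A minimal sketch of such a preprocessing pass is shown below; the `downscale_folder` helper and the `src`/`dst` paths are placeholders, and the 256×256 target size matches the `config.py` defaults:

```python
import os
from PIL import Image


def downscale_folder(src, dst, size=(256, 256)):
    """Resize every image in `src` and write the result to `dst`."""
    os.makedirs(dst, exist_ok=True)
    for name in os.listdir(src):
        img = Image.open(os.path.join(src, name)).convert("RGB")
        img.resize(size, Image.BILINEAR).save(os.path.join(dst, name))


# Hypothetical paths; point these at your generated COVIDx folders
downscale_folder("path/to/data/train", "path/to/data/train_256")
```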

## Results

The following results were obtained on the dataset used in the original repo as of March 20, 2020.

|                   | Accuracy | F1 Macro | Precision Macro | Recall Macro |
|:-----------------:|:--------:|:--------:|:---------------:|:------------:|
| COVID-Net (Large) |  91.90%  |  91.39%  |      91.4%      |    91.33%    |
| **COVID-Next**    |  94.76%  |  92.98%  |     96.40%      |    90.33%    |

### Minimal prediction example

You can find the minimal prediction example in `minimal_prediction.py`.
The example demonstrates how to load the model and use it to predict the disease class for an input image.

## Upgrades

* [x] Training image augmentations
* [x] Pretrained model
* [x] Minimal prediction example
* [x] Loss weights
* [x] Automated dataset generation
* [ ] Define train, validation, and test data splits for a more rigorous model evaluation
* [ ] TensorBoard logging
* [ ] Smart sampling

--------------------------------------------------------------------------------
/assets/covid_example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velebit-ai/COVID-Next-Pytorch/73419cc698aabc3ed9ef0942aae6781f9a146d74/assets/covid_example.jpg

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# General
name = "COVIDNext50_NewData"
gpu = True
batch_size = 64
n_threads = 20
random_seed = 1337

# Model
# Path to a model checkpoint (.pth file) to warm-start training from.
# Set it to None or "" to train from scratch.
weights = "./experiments/ckpts/"

# Optimizer
lr = 1e-4
weight_decay = 1e-3
lr_reduce_factor = 0.7
lr_reduce_patience = 5

# Data
train_imgs = "/data/ssd/datasets/covid/COVIDxV2/data/train"
train_labels = "/data/ssd/datasets/covid/COVIDxV2/data/train_COVIDx.txt"

val_imgs = "/data/ssd/datasets/covid/COVIDxV2/data/test"
val_labels = "/data/ssd/datasets/covid/COVIDxV2/data/test_COVIDx.txt"

# Categories mapping
mapping = {
    'normal': 0,
    'pneumonia': 1,
    'COVID-19': 2
}
# Loss weights order follows the order in the categories mapping dict above
loss_weights = [0.05, 0.05, 1.0]

width = 256
height = 256
n_classes = len(mapping)

# Training
epochs = 300
log_steps = 200
eval_steps = 400
ckpts_dir = "./experiments/ckpts"

--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
import os

from torch.utils.data import Dataset
import torch
from PIL import Image

from config import mapping


class COVIDxFolder(Dataset):
    def __init__(self, img_dir, labels_file, transforms):
        self.img_pths, self.labels = self._prepare_data(img_dir, labels_file)
        self.transforms = transforms

    def _prepare_data(self, img_dir, labels_file):
        with open(labels_file, 'r') as f:
            labels_raw = f.readlines()

        labels, img_pths = [], []
        for line in labels_raw:
            # Each metadata line is: <patient_id> <filename> <category>
            data = line.split()
            img_name = data[1]
            img_pth = os.path.join(img_dir, img_name)
            img_pths.append(img_pth)
            labels.append(mapping[data[2]])

        return img_pths, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img = Image.open(self.img_pths[idx]).convert("RGB")
        img_tensor = self.transforms(img)

        label = self.labels[idx]
        label_tensor = torch.tensor(label, dtype=torch.long)

        return img_tensor, label_tensor
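
# Example usage (a sketch; assumes the paths in config.py point to a
# generated COVIDx dataset):
#
#   from data.transforms import val_transforms
#   import config
#
#   folder = COVIDxFolder(config.val_imgs, config.val_labels,
#                         val_transforms(config.width, config.height))
#   img_tensor, label_tensor = folder[0]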

--------------------------------------------------------------------------------
/data/transforms.py:
--------------------------------------------------------------------------------
from torchvision import transforms


def train_transforms(width, height):
    trans_list = [
        transforms.Resize((height, width)),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([
            transforms.RandomAffine(degrees=20,
                                    translate=(0.15, 0.15),
                                    scale=(0.8, 1.2),
                                    shear=5)], p=0.5),
        transforms.RandomApply([
            transforms.ColorJitter(brightness=0.3, contrast=0.3)], p=0.5),
        transforms.ToTensor()
    ]
    return transforms.Compose(trans_list)


def val_transforms(width, height):
    trans_list = [
        transforms.Resize((height, width)),
        transforms.ToTensor()
    ]
    return transforms.Compose(trans_list)

--------------------------------------------------------------------------------
/experiments/README.md:
--------------------------------------------------------------------------------
# Experiment files

All experiment related files such as checkpoints, logs, and configs go here.

--------------------------------------------------------------------------------
/generate_dataset.py:
--------------------------------------------------------------------------------
"""
Generates the COVIDx dataset from the following sources:
* https://github.com/ieee8023/covid-chestxray-dataset.git
* https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data

Code inspired by:
https://github.com/lindawangg/COVID-Net/blob/master/create_COVIDx_v2.ipynb
"""

import logging
import os
from shutil import copyfile
import argparse

import numpy as np
import pandas as pd
import pydicom as dicom
from PIL import Image

log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def write_metadata(pth, data):
    # Each metadata line has the form "<patient_id> <filename> <category>",
    # which is the format data/dataset.py expects
    with open(pth, "w") as file:
        for patient_id, filename, category in data:
            info = "{} {} {}\n".format(patient_id, filename, category)
            file.write(info)


def main(args):
    train = []
    test = []
    test_count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}
    train_count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}

    # Create export test and train dirs
    TEST_EXPORT = os.path.join(args.save_path, 'test')
    os.makedirs(TEST_EXPORT, exist_ok=True)
    TRAIN_EXPORT = os.path.join(args.save_path, 'train')
    os.makedirs(TRAIN_EXPORT, exist_ok=True)

    # Collapse the source findings into the three COVIDx categories
    mapping = dict()
    mapping['COVID-19'] = 'COVID-19'
    mapping['SARS'] = 'pneumonia'
    mapping['MERS'] = 'pneumonia'
    mapping['Streptococcus'] = 'pneumonia'
    mapping['No Finding'] = 'normal'
    mapping['Lung Opacity'] = 'pneumonia'
    mapping['1'] = 'pneumonia'

    covid_imgs = os.path.join(args.covid_dir, "images")
    covid_csv = os.path.join(args.covid_dir, "metadata.csv")

    csv = pd.read_csv(covid_csv, nrows=None)
    csv = csv[csv["view"] == "PA"]
    log.info("Metadata contains {} items with PA".format(len(csv)))

    pneumonias = ["COVID-19", "SARS", "MERS", "ARDS", "Streptococcus"]
    # Full pathology list, kept for reference; only the `mapping` keys above
    # are used for filtering below
    pathologies = ["Pneumonia", "Viral Pneumonia", "Bacterial Pneumonia",
                   "No Finding"] + pneumonias
    pathologies = sorted(pathologies)

    filename_label = {'normal': [], 'pneumonia': [], 'COVID-19': []}
    count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}
    for row in csv.itertuples():
        f = row.finding.split('/')[-1]
        if f in mapping:
            count[mapping[f]] += 1
            entry = [row.patientid, row.filename, mapping[f]]
            filename_label[mapping[f]].append(entry)

    log.info('Data distribution from covid-chestxray-dataset:')
    log.info(count)

    # add covid-chestxray-dataset into COVIDx dataset
    for key in filename_label.keys():
        arr = np.array(filename_label[key])
        if arr.size == 0:
            continue
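
        # Split at the patient level rather than the image level, so that
        # images from the same patient never appear in both train and test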
        # Randomly sample test set patients
        patient_ids = np.unique(arr[:, 0])
        test_size = int(len(patient_ids) * args.test_size)
        test_patients = np.random.choice(patient_ids, test_size,
                                         replace=False)
        log.info('Category: {}, N test patients {}'.format(key, test_size))

        # go through all the images; each entry is [patient_id, filename,
        # category]
        for patient in arr:
            src_img_pth = os.path.join(covid_imgs, patient[1])
            if patient[0] in test_patients:
                dst_img_pth = os.path.join(TEST_EXPORT, patient[1])
                copyfile(src_img_pth, dst_img_pth)
                test.append(patient)
                test_count[patient[2]] += 1
            else:
                dst_img_pth = os.path.join(TRAIN_EXPORT, patient[1])
                copyfile(src_img_pth, dst_img_pth)
                train.append(patient)
                train_count[patient[2]] += 1

    log.info('test count: {}'.format(test_count))
    log.info('train count: {}'.format(train_count))

    # add normal and rest of pneumonia cases from
    # https://www.kaggle.com/c/rsna-pneumonia-detection-challenge
    kaggle_csv_normal = os.path.join(args.kaggle_data,
                                     "stage_2_detailed_class_info.csv")
    kaggle_csv_pneu = os.path.join(args.kaggle_data,
                                   "stage_2_train_labels.csv")
    csv_normal = pd.read_csv(kaggle_csv_normal, nrows=None)
    csv_pneu = pd.read_csv(kaggle_csv_pneu, nrows=None)
    patients = {'normal': [], 'pneumonia': []}

    for row in csv_normal.itertuples():
        if row[2] == 'Normal':
            patients['normal'].append(row.patientId)

    for row in csv_pneu.itertuples():
        if row.Target == 1:
            patients['pneumonia'].append(row.patientId)

    log.info("Preparing Kaggle dataset...")
    counter = 0
    for key in patients.keys():
        arr = np.array(patients[key])
        if arr.size == 0:
            continue

        # Choose random test patients
        patient_ids = np.unique(arr)
        test_size = int(len(patient_ids) * args.test_size)
        test_patients = np.random.choice(patient_ids, test_size,
                                         replace=False)
        log.info('Category: {}, N Test examples: {}'.format(key, test_size))

        # Iterate over the unique ids: both Kaggle CSVs can contain multiple
        # rows per patient, and iterating over `arr` directly would convert
        # the same DICOM several times and duplicate metadata entries
        for patient in patient_ids:
            ds = dicom.dcmread(os.path.join(args.kaggle_data,
                                            "stage_2_train_images",
                                            patient + '.dcm'))
            pixel_array_numpy = ds.pixel_array
            imgname = patient + '.png'
            pil_img = Image.fromarray(pixel_array_numpy)

            if patient in test_patients:
                pil_img.save(os.path.join(TEST_EXPORT, imgname))
                test.append([patient, imgname, key])
                test_count[key] += 1
            else:
                pil_img.save(os.path.join(TRAIN_EXPORT, imgname))
                train.append([patient, imgname, key])
                train_count[key] += 1
            counter += 1

            if counter % 500 == 0 and counter > 0:
                log.info("Converted {} Kaggle dataset images".format(counter))

    log.info('test count: {}'.format(test_count))
    log.info('train count: {}'.format(train_count))

    write_metadata(os.path.join(args.save_path, 'train_metadata.txt'), train)
    write_metadata(os.path.join(args.save_path, 'test_metadata.txt'), test)
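

# Example invocation (all three paths below are hypothetical):
#   python3 generate_dataset.py /data/covid-chestxray-dataset \
#       /data/rsna-pneumonia /data/COVIDx --test-size 0.1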
Defaults to 10%.", 183 | default=0.1, type=float) 184 | 185 | args = parser.parse_args() 186 | 187 | if args.test_size < 0 or args.test_size > 1: 188 | raise ValueError("Test fraction value must be in range [0, 1]") 189 | 190 | main(args) 191 | -------------------------------------------------------------------------------- /minimal_prediction.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimal prediction example 3 | """ 4 | 5 | import torch 6 | from PIL import Image 7 | 8 | from model.architecture import COVIDNext50 9 | from data.transforms import val_transforms 10 | 11 | import config 12 | 13 | rev_mapping = {idx: name for name, idx in config.mapping.items()} 14 | 15 | model = COVIDNext50(n_classes=len(rev_mapping)) 16 | 17 | ckpt_pth = './experiments/ckpts/best/' 18 | weights = torch.load(ckpt_pth)['state_dict'] 19 | model.load_state_dict(weights) 20 | model.eval() 21 | 22 | transforms = val_transforms(width=config.width, height=config.height) 23 | 24 | img_pth = 'assets/covid_example.jpg' 25 | img = Image.open(img_pth).convert("RGB") 26 | img_tensor = transforms(img).unsqueeze(0) 27 | 28 | with torch.no_grad(): 29 | logits = model(img_tensor) 30 | cat_id = int(torch.argmax(logits)) 31 | print("Prediction for {} is: {}".format(img_pth, rev_mapping[cat_id])) 32 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/velebit-ai/COVID-Next-Pytorch/73419cc698aabc3ed9ef0942aae6781f9a146d74/model/__init__.py -------------------------------------------------------------------------------- /model/architecture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision import models 4 | 5 | from .layers import Trainable, ConvBn2d 6 | 7 | 8 | class COVIDNext50(nn.Module): 9 | def __init__(self, n_classes): 10 | super(COVIDNext50, self).__init__() 11 | self.n_classes = n_classes 12 | trainable = True 13 | 14 | # Layers 15 | backbone = models.resnext50_32x4d(pretrained=True) 16 | self.block0 = Trainable(nn.Sequential( 17 | backbone.conv1, 18 | backbone.bn1, 19 | backbone.relu, 20 | backbone.maxpool), 21 | trainable=trainable, 22 | name="conv1") 23 | self.block1 = Trainable(backbone.layer1, 24 | trainable=trainable, 25 | name="block1") 26 | self.block2 = Trainable(backbone.layer2, 27 | trainable=trainable, 28 | name="block2") 29 | self.block3 = Trainable(backbone.layer3, 30 | trainable=trainable, 31 | name="block3") 32 | self.block4 = Trainable(backbone.layer4, 33 | trainable=trainable, 34 | name="block4") 35 | self.backbone_end = Trainable(nn.Sequential( 36 | ConvBn2d(2048, 512, 3), 37 | ConvBn2d(512, 1024, 1), 38 | ConvBn2d(1024, 512, 3)), 39 | name="back", 40 | trainable=True) 41 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 42 | self.logits = Trainable(nn.Linear(512, n_classes), 43 | name="logits", 44 | trainable=True) 45 | 46 | def forward(self, input): 47 | net = input 48 | for layer in [self.block0, self.block1, self.block2, self.block3, 49 | self.block4]: 50 | net = layer(net) 51 | net = self.backbone_end(net) 52 | net = self.avg_pool(net) 53 | net = torch.squeeze(net) 54 | return self.logits(net) 55 | 56 | def probability(self, logits): 57 | return nn.functional.softmax(logits, dim=-1) 58 | -------------------------------------------------------------------------------- /model/layers.py: 

--------------------------------------------------------------------------------
/model/layers.py:
--------------------------------------------------------------------------------
from torch import nn


class Trainable(nn.Module):
    """
    Wraps an arbitrary module with a Trainable module. The Trainable module
    is used as a wrapper for freezing and thawing module layers.
    """
    def __init__(self, module, name, trainable=True):
        super().__init__()
        self.module = module
        self.name = name
        self.trainable_switch(trainable)

    def forward(self, *args, **kwargs):
        # Defining forward() instead of overriding __call__ keeps nn.Module
        # hooks working
        return self.module(*args, **kwargs)

    def trainable_switch(self, trainable):
        """
        Makes module layers trainable or not.

        :param trainable: bool, False to freeze the layers, True to unfreeze
            them.
        """
        for p in self.parameters():
            p.requires_grad = trainable


def ConvBn2d(in_dim, out_dim, kernel_size, activation=None):
    """
    Wraps Conv2d, BatchNorm2d, and an arbitrary activation layer into an
    nn.Sequential block.

    :param in_dim: int, Input feature map dimension
    :param out_dim: int, Output feature map dimension
    :param kernel_size: int or tuple, Convolution kernel size
    :param activation: Activation module; defaults to LeakyReLU(0.1)
    :return: nn.Sequential structure containing the above listed layers
    """
    # Create the default activation per call so that each block gets its own
    # module instance instead of sharing one created at import time
    if activation is None:
        activation = nn.LeakyReLU(0.1, inplace=True)
    padding = kernel_size // 2
    net = nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size,
                  padding=padding, bias=False),
        nn.BatchNorm2d(out_dim),
        activation)
    return net

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.18.2
torch==1.4.0
torchvision==0.5.0
scikit-learn==0.22.2.post1
Pillow==8.3.2
pydicom==1.4.2
pandas==1.0.3

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import logging
import os

import numpy as np
from sklearn.metrics import classification_report
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import CrossEntropyLoss

from data.dataset import COVIDxFolder
from data import transforms
from torch.utils.data import DataLoader
from model import architecture
import util
import config


log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def save_model(model, config):
    if isinstance(model, torch.nn.DataParallel):
        # Save without the DataParallel module
        model_dict = model.module.state_dict()
    else:
        model_dict = model.state_dict()

    state = {
        "state_dict": model_dict,
        "global_step": config['global_step'],
        "clf_report": config['clf_report']
    }
    f1_macro = config['clf_report']['macro avg']['f1-score'] * 100
    name = "{}_F1_{:.2f}_step_{}.pth".format(config['name'],
                                             f1_macro,
                                             config['global_step'])
    model_path = os.path.join(config['save_dir'], name)
    torch.save(state, model_path)
    log.info("Saved model to {}".format(model_path))

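
# Saved checkpoints can be warm-started from by pointing `config.weights` at
# the .pth file, or loaded manually with torch.load(path)['state_dict'] as in
# minimal_prediction.py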
def validate(data_loader, model, best_score, global_step, cfg):
    model.eval()
    gts, predictions = [], []

    log.info("Validation started...")
    for data in data_loader:
        imgs, labels = data
        imgs = util.to_device(imgs, gpu=cfg.gpu)

        with torch.no_grad():
            logits = model(imgs)
            # The model is wrapped in DataParallel only when running on GPU
            if isinstance(model, torch.nn.DataParallel):
                probs = model.module.probability(logits)
            else:
                probs = model.probability(logits)
            preds = torch.argmax(probs, dim=1).cpu().numpy()

        labels = labels.cpu().detach().numpy()

        predictions.extend(preds)
        gts.extend(labels)

    predictions = np.array(predictions, dtype=np.int32)
    gts = np.array(gts, dtype=np.int32)
    acc, f1, prec, rec = util.clf_metrics(predictions=predictions,
                                          targets=gts,
                                          average="macro")
    report = classification_report(gts, predictions, output_dict=True)

    log.info("VALIDATION | Accuracy {:.4f} | F1 {:.4f} | Precision {:.4f} | "
             "Recall {:.4f}".format(acc, f1, prec, rec))

    if f1 > best_score:
        save_config = {
            'name': config.name,
            'save_dir': config.ckpts_dir,
            'global_step': global_step,
            'clf_report': report
        }
        save_model(model=model, config=save_config)
        best_score = f1
    log.info("Validation end")

    model.train()
    return best_score


def main():
    if config.gpu and not torch.cuda.is_available():
        raise ValueError("GPU not supported or enabled on this system.")
    use_gpu = config.gpu

    log.info("Loading train dataset")
    train_dataset = COVIDxFolder(config.train_imgs, config.train_labels,
                                 transforms.train_transforms(config.width,
                                                             config.height))
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=config.n_threads,
                              pin_memory=use_gpu)
    log.info("Number of training examples {}".format(len(train_dataset)))

    log.info("Loading val dataset")
    val_dataset = COVIDxFolder(config.val_imgs, config.val_labels,
                               transforms.val_transforms(config.width,
                                                         config.height))
    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.n_threads,
                            pin_memory=use_gpu)
    log.info("Number of validation examples {}".format(len(val_dataset)))

    if config.weights:
        state = torch.load(config.weights)
        log.info("Loaded model weights from: {}".format(config.weights))
    else:
        state = None

    state_dict = state["state_dict"] if state else None
    model = architecture.COVIDNext50(n_classes=config.n_classes)
    if state_dict:
        model = util.load_model_weights(model=model, state_dict=state_dict)

    if use_gpu:
        model.cuda()
        model = torch.nn.DataParallel(model)
    optim_layers = filter(lambda p: p.requires_grad, model.parameters())

    # optimizer and lr scheduler
    optimizer = Adam(optim_layers,
                     lr=config.lr,
                     weight_decay=config.weight_decay)
    scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                  factor=config.lr_reduce_factor,
                                  patience=config.lr_reduce_patience,
                                  mode='max',
                                  min_lr=1e-7)
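    # mode='max' because the tracked metric is macro F1 (higher is better);
    # the LR is multiplied by lr_reduce_factor after lr_reduce_patience
    # validation rounds without improvement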

    # Load the last global_step from the checkpoint if existing
    global_step = 0 if state is None else state['global_step'] + 1

    class_weights = util.to_device(torch.FloatTensor(config.loss_weights),
                                   gpu=use_gpu)
    loss_fn = CrossEntropyLoss(reduction='mean', weight=class_weights)

    # Reset the best metric score
    best_score = -1
    for epoch in range(config.epochs):
        log.info("Started epoch {}/{}".format(epoch + 1, config.epochs))
        for data in train_loader:
            imgs, labels = data
            imgs = util.to_device(imgs, gpu=use_gpu)
            labels = util.to_device(labels, gpu=use_gpu)

            logits = model(imgs)
            loss = loss_fn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % config.log_steps == 0 and global_step > 0:
                if isinstance(model, torch.nn.DataParallel):
                    probs = model.module.probability(logits)
                else:
                    probs = model.probability(logits)
                preds = torch.argmax(probs, dim=1).detach().cpu().numpy()
                labels = labels.cpu().detach().numpy()
                acc, f1, _, _ = util.clf_metrics(preds, labels)
                lr = util.get_learning_rate(optimizer)

                log.info("Step {} | TRAINING batch: Loss {:.4f} | F1 {:.4f} | "
                         "Accuracy {:.4f} | LR {:.2e}".format(global_step,
                                                              loss.item(),
                                                              f1, acc,
                                                              lr))

            if global_step % config.eval_steps == 0 and global_step > 0:
                best_score = validate(val_loader,
                                      model,
                                      best_score=best_score,
                                      global_step=global_step,
                                      cfg=config)
                scheduler.step(best_score)
            global_step += 1


if __name__ == '__main__':
    seed = config.random_seed
    if seed:
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    main()

--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import logging
from sklearn.metrics import f1_score, precision_score, recall_score, \
    accuracy_score

log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def load_model_weights(model, state_dict, verbose=True):
    """
    Loads the model weights from the state dictionary. The function will only
    load the weights which have matching key names and dimensions in the
    state dictionary.

    :param state_dict: PyTorch model state dictionary
    :param verbose: bool, If True, the function prints the weight keys of
        parameters that can and cannot be loaded from the checkpoint state
        dictionary.
    :return: The model with loaded weights
    """
    new_state_dict = model.state_dict()
    non_loadable, loadable = set(), set()

    for k, v in state_dict.items():
        if k not in new_state_dict:
            non_loadable.add(k)
            continue

        if v.shape != new_state_dict[k].shape:
            non_loadable.add(k)
            continue

        new_state_dict[k] = v
        loadable.add(k)

    if verbose:
        log.info("### Checkpoint weights that WILL be loaded: ###")
        for k in loadable:
            log.info(k)

        log.info("### Checkpoint weights that CANNOT be loaded: ###")
        for k in non_loadable:
            log.info(k)

    model.load_state_dict(new_state_dict)
    return model


def to_device(tensor, gpu=False):
    """
    Places a PyTorch Tensor object on a GPU or CPU device.

    :param tensor: PyTorch Tensor object
    :param gpu: bool, Flag which specifies GPU placement
    :return: Tensor object
    """
    return tensor.cuda() if gpu else tensor.cpu()


def clf_metrics(predictions, targets, average='macro'):
    f1 = f1_score(targets, predictions, average=average)
    precision = precision_score(targets, predictions, average=average)
    recall = recall_score(targets, predictions, average=average)
    acc = accuracy_score(targets, predictions)

    return acc, f1, precision, recall


def get_learning_rate(optimizer):
    """
    Retrieves the current learning rate. If the optimizer doesn't have
    trainable variables, it will raise an error.

    :param optimizer: Optimizer object
    :return: float, Current learning rate
    """
    if len(optimizer.param_groups) > 0:
        return optimizer.param_groups[0]['lr']
    else:
        raise ValueError('No trainable parameters.')

--------------------------------------------------------------------------------