├── .gitignore ├── README.md ├── assert ├── confusion_matrix.png ├── model_best.pth.tar └── roc_auc_curve.png ├── data └── data_augu │ ├── train │ ├── LESION │ │ └── .gitignore │ └── NORMAL │ │ └── .gitignore │ └── val │ ├── LESION │ └── .gitignore │ └── NORMAL │ └── .gitignore ├── main.py ├── preprocessing.py └── utils ├── FocalLoss.py ├── Logger.py ├── data_augument.py ├── sampler.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CheXNet-Pytorch 2 | 3 | This is a binary classification task (Pneumonia vs Normal) on ChestX-ray14, implemented in PyTorch. Densenet121 is adopted directly as the classifier; it is readily available in every mainstream deep learning framework, e.g. Keras, TensorFlow and PyTorch. After 160 epochs of training, I finally achieved a best accuracy of `94.98%`.
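The model definition itself is a single torchvision call; here is a minimal sketch of the setup used in `main.py` below (PyTorch 0.3-era API):

```python
# Stock DenseNet121 with a 2-way output head (LESION vs NORMAL),
# trained from scratch (pretrained=False), as constructed in main.py.
import torchvision.models as models

model = models.densenet121(pretrained=False, num_classes=2)
```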
4 | 5 | ## Dataset 6 | 7 | The [ChestX-ray14 dataset](http://openaccess.thecvf.com/content_cvpr_2017/papers/Wang_ChestX-ray8_Hospital-Scale_Chest_CVPR_2017_paper.pdf) comprises 112,120 frontal-view chest X-ray images of 30,805 unique patients with 14 disease labels. I first extracted all normal images and all images with pneumonia, numbering 60,412 and 1,353 respectively. These images (original size `1024*1024`) were resized to `256*256`, and the dataset was randomly split into training (`80%`) and validation (`20%`) sets. Obviously, with such a severe class imbalance, a network trained naively on the raw dataset will be strongly biased toward the majority class. Taken to the extreme, the network can simply guess every input as normal and still show a "high" accuracy (`60412/61765 ≈ 97.81%`). That is exactly the result we do not want: such a degenerate classifier is meaningless. Data augmentation is an effective method to tackle this problem. 8 | 9 | ## Data augmentation 10 | 11 | In this project, data augmentation is conducted on the pneumonia images only; it is necessary and makes a great difference. The transforms used are as follows: 12 | 13 | 1. `gaussain_blur`: apply Gaussian blur with a randomly chosen odd kernel size (up to `7*7`) 14 | 2. `gaussain_noise`: add random Gaussian noise with mean=0 and variance=0.1 15 | 3. `shift`: randomly shift the image within a specified pixel range 16 | 4. `rotation`: randomly rotate the image within a specified angle range 17 | 5. `brightness`: adjust the image's brightness (two fixed enhancement factors) 18 | 6. `contrast`: adjust the image's contrast (two fixed enhancement factors) 19 | 20 | Data augmentation adds 12 new images per original, growing the pneumonia class 13-fold compared with the raw dataset (i.e. `13*1353=17589`). For more details, please check the script [`preprocessing.py`](https://github.com/estelle1722/CheXNet-Pytorch/blob/master/preprocessing.py). 21 | 22 | In fact, the class-imbalance problem still exists after data augmentation. To eliminate it completely, I removed some normal images until the two classes were approximately equal in size: I randomly kept 18,000 normal images and left the rest out. Ultimately, the directory tree of the processed dataset is as follows: 23 | 24 | ```bash 25 | ├── data_augu 26 | │   ├── train 27 | │   │   ├── LESION 28 | │   │   └── NORMAL 29 | │   └── val 30 | │   ├── LESION 31 | │   └── NORMAL 32 | ``` 33 | 34 | ## Requirements 35 | 36 | - Python 3.6 37 | - [PyTorch 0.3](https://pytorch.org/) 38 | 39 | ## Usage 40 | 41 | 1. Clone this repository. 42 | 2. Download the ChestX-ray14 images from the [`release page`](https://nihcc.app.box.com/v/ChestXray-NIHCC), decompress them, and extract all normal images and all pneumonia images into the directories `NORMAL_ORIGINAL` and `LESION_ORIGINAL` respectively. 43 | 3. Run the script [`preprocessing.py`](https://github.com/estelle1722/CheXNet-Pytorch/blob/master/preprocessing.py) to perform data augmentation. 44 | 4. Split the entire dataset into training (`80%`) and validation (`20%`) sets. 45 | 5. Run the script [`main.py`](https://github.com/estelle1722/CheXNet-Pytorch/blob/master/main.py) to train `Densenet121` (see the example commands below).
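For reference, a plausible end-to-end invocation looks like the following (the flags come from the argparse definitions in `main.py`; the epoch count and batch size match the evaluation table below; note that `preprocessing.py` as written reads from `data/LESION_DATA`, so align the directory names with your actual layout):

```bash
# 1. Augment the pneumonia images in place.
python preprocessing.py
# 2. Train DenseNet121 on the augmented dataset.
python main.py --data ./data/data_augu --batch_size 50 --epochs 160 --lr 1e-3 --prefix classifier
```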
46 | 47 | ## Evaluation 48 | 49 | The runtime environment is shown in the following table: 50 | 51 | | Property | Values | Note | 52 | | :-------------------: | :-------------------: | :-----------------: | 53 | | Model | Densenet121 | - | 54 | | Optimizer | Adam | - | 55 | | Initial learning rate | 0.001 | decay ×0.1 every 40 epochs | 56 | | GPUs | 2*GeForce GTX 1080 Ti | - | 57 | | Epochs | 160 | - | 58 | | Mini Batch Size | 50 | - | 59 | 60 | ### Confusion matrix 61 | 62 | ![confusion_matrix](./assert/confusion_matrix.png) 63 | 64 | ### Receiver Operating Characteristic (ROC) & Area Under the Curve (AUC) 65 | 66 | ![roc_auc_curve](./assert/roc_auc_curve.png) 67 | 68 | ## More 69 | 70 | In addition, [FocalLoss](https://arxiv.org/abs/1708.02002) with its default settings was tried before data augmentation to counter the class imbalance; unfortunately it brought no distinct improvement. Another training trick, [Cyclical Learning Rates](https://arxiv.org/pdf/1506.01186.pdf), is a way of scheduling the learning rate; maybe it would work for this project. 71 | 72 | Before settling on densenet121, I tried to train resnet18, but without success. One possible reason is that resnet18 has more trainable parameters than densenet121, which makes it harder to train on this amount of data. Besides data augmentation, fine-tuning a pretrained model is another common way to approach this classification task (Pneumonia vs Normal). If you are interested in that idea, I recommend this [repository](https://github.com/arnoweng/CheXNet), which has the most GitHub stars on this topic. I would greatly appreciate it if you achieve outstanding performance and discuss it with me at your convenience, because I am stuck on this point. 73 | 74 | These results come from an initial parameter setting without many tricks, so further improvement is probably achievable. 75 | 76 | For any questions, please contact me. Email (zr.estelle@gmail.com) and WeChat (zhangrong1728) are both available. 77 | 78 | 79 | 80 | ------- 81 | 82 | 1. For PyTorch, besides data augmentation there is another useful method called `oversampling` to keep the classes balanced; however, it didn't work in this project. You can look at the [python script](https://github.com/estelle1722/CheXNet-Pytorch/blob/master/utils/sampler.py); a usage sketch follows.
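A minimal sketch of how that sampler would be wired in, assuming the `ImbalancedDatasetSampler` from `utils/sampler.py` (when a `sampler` is supplied, `shuffle` must be left off the `DataLoader`):

```python
# Hypothetical wiring of the oversampling alternative; main.py does not use it.
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms

from utils.sampler import ImbalancedDatasetSampler

# Real training would reuse the transforms from load_dataset() in main.py.
train_dataset = ImageFolder('./data/data_augu/train', transforms.ToTensor())
# Minority-class images are drawn more often, weighted by inverse class frequency.
train_loader = DataLoader(train_dataset,
                          sampler=ImbalancedDatasetSampler(train_dataset),
                          batch_size=50)
```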
83 | 84 | 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /assert/confusion_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangrong1722/CheXNet-Pytorch/a55e4fec0d4732971978c9aedef774d1cf6fc38a/assert/confusion_matrix.png -------------------------------------------------------------------------------- /assert/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangrong1722/CheXNet-Pytorch/a55e4fec0d4732971978c9aedef774d1cf6fc38a/assert/model_best.pth.tar -------------------------------------------------------------------------------- /assert/roc_auc_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangrong1722/CheXNet-Pytorch/a55e4fec0d4732971978c9aedef774d1cf6fc38a/assert/roc_auc_curve.png -------------------------------------------------------------------------------- /data/data_augu/train/LESION/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /data/data_augu/train/NORMAL/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /data/data_augu/val/LESION/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /data/data_augu/val/NORMAL/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import sys 4 | import argparse 5 | import time 6 | import itertools 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import warnings 12 | import matplotlib.pyplot as plt 13 | import torch.optim as optim 14 | import torch.nn.functional as F 15 | from sklearn.metrics import confusion_matrix 16 | import scikitplot as skplt 17 | from torch.autograd import Variable 18 | from torch.backends import cudnn 19 | from torch.nn import DataParallel 20 | import torchvision.transforms as transforms 21 | import torchvision.models as models 22 | from torch.optim import lr_scheduler 23 | from torch.utils.data import DataLoader 24 | from torchvision.datasets import ImageFolder 25 | 26 | sys.path.append('./') 27 | from utils.util import set_prefix, write, add_prefix 28 | from utils.FocalLoss import FocalLoss 29 | 30 | plt.switch_backend('agg') 31 | 32 | parser = argparse.ArgumentParser(description='Binary classification (Pneumonia vs Normal) on ChestX-ray14') 33 | parser.add_argument('--batch_size', '-b', default=90, type=int, help='batch size') 34 | parser.add_argument('--epochs', '-e', default=90, type=int, help='training epochs') 35 | parser.add_argument('--lr',
default=1e-3, type=float, help='learning rate') 36 | parser.add_argument('--cuda', default=torch.cuda.is_available(), type=bool, help='use gpu or not') 37 | parser.add_argument('--step_size', default=30, type=int, help='learning rate decay interval') 38 | parser.add_argument('--gamma', default=0.1, type=float, help='learning rate decay scope') 39 | parser.add_argument('--interval_freq', '-i', default=12, type=int, help='printing log frequency') 40 | parser.add_argument('--data', '-d', default='./data/data_augu', help='path to dataset') 41 | parser.add_argument('--prefix', '-p', default='classifier', type=str, help='folder prefix') 42 | parser.add_argument('--best_model_path', default='model_best.pth.tar', help='best model saved path') 43 | parser.add_argument('--is_focal_loss', '-f', action='store_false', 44 | help='use focal loss by default; pass this flag to fall back to cross-entropy loss') 45 | 46 | best_acc = 0.0 47 | 48 | 49 | def main(): 50 | global args, best_acc 51 | args = parser.parse_args() 52 | # save source script 53 | set_prefix(args.prefix, __file__) 54 | model = models.densenet121(pretrained=False, num_classes=2) 55 | if args.cuda: 56 | model = DataParallel(model).cuda() 57 | else: 58 | warnings.warn('there is no gpu') 59 | 60 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 61 | # accelerate the speed of training 62 | cudnn.benchmark = True 63 | 64 | train_loader, val_loader = load_dataset() 65 | # class_names=['LESION', 'NORMAL'] 66 | class_names = train_loader.dataset.classes 67 | print(class_names) 68 | if args.is_focal_loss: 69 | print('try focal loss!!') 70 | criterion = FocalLoss().cuda() if args.cuda else FocalLoss()  # move the loss to GPU only when one is available 71 | else: 72 | criterion = nn.CrossEntropyLoss().cuda() if args.cuda else nn.CrossEntropyLoss() 73 | 74 | # decay the learning rate by args.gamma every args.step_size epochs 75 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma) 76 | since = time.time() 77 | print('-' * 10) 78 | for epoch in range(args.epochs): 79 | exp_lr_scheduler.step() 80 | train(train_loader, model, optimizer, criterion, epoch) 81 | cur_accuracy = validate(model, val_loader, criterion) 82 | is_best = cur_accuracy > best_acc 83 | best_acc = max(cur_accuracy, best_acc) 84 | save_checkpoint({ 85 | 'epoch': epoch + 1, 86 | 'arch': 'densenet121',  # matches the DenseNet121 model constructed above 87 | 'state_dict': model.state_dict(), 88 | 'best_accuracy': best_acc, 89 | 'optimizer': optimizer.state_dict(), 90 | }, is_best) 91 | time_elapsed = time.time() - since 92 | print('Training complete in {:.0f}m {:.0f}s'.format( 93 | time_elapsed // 60, time_elapsed % 60)) 94 | # compute validation metrics such as the confusion matrix 95 | compute_validate_meter(model, add_prefix(args.prefix, args.best_model_path), val_loader) 96 | # save the run's parameter settings as json 97 | write(vars(args), add_prefix(args.prefix, 'paras.txt')) 98 | 99 | 100 | def compute_validate_meter(model, best_model_path, val_loader): 101 | checkpoint = torch.load(best_model_path) 102 | model.load_state_dict(checkpoint['state_dict']) 103 | best_acc = checkpoint['best_accuracy'] 104 | print('best accuracy={:.4f}'.format(best_acc)) 105 | pred_y = list() 106 | test_y = list() 107 | probas_y = list() 108 | for data, target in val_loader: 109 | if args.cuda: 110 | data, target = data.cuda(), target.cuda() 111 | data, target = Variable(data, volatile=True), Variable(target) 112 | output = model(data) 113 | probas_y.extend(output.data.cpu().numpy().tolist()) 114 | pred_y.extend(output.data.cpu().max(1, keepdim=True)[1].numpy().flatten().tolist()) 115 | test_y.extend(target.data.cpu().numpy().flatten().tolist()) 116 | 117 | confusion =
confusion_matrix(test_y, pred_y)  # sklearn's confusion_matrix expects (y_true, y_pred) 118 | plot_confusion_matrix(confusion, 119 | classes=val_loader.dataset.classes, 120 | title='Confusion matrix') 121 | plt_roc(test_y, probas_y) 122 | 123 | 124 | def plt_roc(test_y, probas_y, plot_micro=False, plot_macro=False): 125 | assert isinstance(test_y, list) and isinstance(probas_y, list), 'the type of input must be list' 126 | skplt.metrics.plot_roc(test_y, probas_y, plot_micro=plot_micro, plot_macro=plot_macro) 127 | plt.savefig(add_prefix(args.prefix, 'roc_auc_curve.png')) 128 | plt.close() 129 | 130 | 131 | def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues): 132 | """ 133 | This function prints and plots the confusion matrix. 134 | Normalization can be applied by setting `normalize=True`. 135 | reference: 136 | http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html 137 | """ 138 | if normalize: 139 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 140 | print("Normalized confusion matrix") 141 | else: 142 | print('Confusion matrix, without normalization') 143 | 144 | print(cm) 145 | 146 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 147 | plt.title(title) 148 | plt.colorbar() 149 | tick_marks = np.arange(len(classes)) 150 | plt.xticks(tick_marks, classes, rotation=45) 151 | plt.yticks(tick_marks, classes) 152 | 153 | fmt = '.2f' if normalize else 'd' 154 | thresh = cm.max() / 2. 155 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 156 | plt.text(j, i, format(cm[i, j], fmt), 157 | horizontalalignment="center", 158 | color="white" if cm[i, j] > thresh else "black") 159 | 160 | plt.tight_layout() 161 | plt.ylabel('True label') 162 | plt.xlabel('Predicted label') 163 | plt.savefig(add_prefix(args.prefix, 'confusion_matrix.png')) 164 | plt.close() 165 | 166 | 167 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 168 | # save training state after each epoch 169 | torch.save(state, add_prefix(args.prefix, filename)) 170 | if is_best: 171 | shutil.copyfile(add_prefix(args.prefix, filename), 172 | add_prefix(args.prefix, args.best_model_path)) 173 | 174 | 175 | def load_dataset(): 176 | if args.data == './data/data_augu': 177 | traindir = os.path.join(args.data, 'train') 178 | valdir = os.path.join(args.data, 'val') 179 | mean = [0.5186, 0.5186, 0.5186] 180 | std = [0.1968, 0.1968, 0.1968] 181 | normalize = transforms.Normalize(mean, std) 182 | train_transforms = transforms.Compose([ 183 | transforms.CenterCrop(224), 184 | transforms.RandomHorizontalFlip(), 185 | transforms.ToTensor(), 186 | normalize, 187 | ]) 188 | val_transforms = transforms.Compose([ 189 | transforms.CenterCrop(224), 190 | transforms.ToTensor(), 191 | normalize, 192 | ]) 193 | train_dataset = ImageFolder(traindir, train_transforms) 194 | val_dataset = ImageFolder(valdir, val_transforms) 195 | print('loaded the data-augmentation dataset successfully!!!') 196 | else: 197 | raise ValueError("parameter 'data' (the path to the dataset) must be one of " 198 | "['./data/data_augu']") 199 | 200 | train_loader = DataLoader(train_dataset, 201 | batch_size=args.batch_size, 202 | shuffle=True, 203 | num_workers=4, 204 | pin_memory=True if args.cuda else False) 205 | val_loader = DataLoader(val_dataset, 206 | batch_size=args.batch_size, 207 | shuffle=False, 208 | num_workers=1, 209 | pin_memory=True if args.cuda else False) 210 | return train_loader, val_loader 211 | 212 | 213 | def train(train_loader, model, optimizer, criterion, epoch): 214 | model.train(True) 215
| print('Epoch {}/{}'.format(epoch + 1, args.epochs)) 216 | print('-' * 10) 217 | running_loss = 0.0 218 | running_corrects = 0 219 | 220 | # Iterate over data. 221 | for idx, (inputs, labels) in enumerate(train_loader): 222 | # wrap them in Variable 223 | if args.cuda: 224 | inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda()) 225 | else: 226 | inputs, labels = Variable(inputs), Variable(labels) 227 | 228 | # zero the parameter gradients 229 | optimizer.zero_grad() 230 | 231 | # forward 232 | outputs = model(inputs) 233 | 234 | _, preds = torch.max(outputs.data, 1) 235 | 236 | loss = criterion(outputs, labels) 237 | loss.backward() 238 | optimizer.step() 239 | if idx % args.interval_freq == 0: 240 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 241 | epoch + 1, idx * len(inputs), len(train_loader.dataset), 242 | 100. * idx / len(train_loader), loss.data[0])) 243 | 244 | # statistics 245 | running_loss += loss.data[0] * inputs.size(0) 246 | running_corrects += torch.sum(preds == labels.data) 247 | 248 | epoch_loss = running_loss / len(train_loader.dataset) 249 | epoch_acc = running_corrects / len(train_loader.dataset) 250 | 251 | print('Training Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc)) 252 | 253 | 254 | def validate(model, val_loader, criterion): 255 | model.eval() 256 | test_loss = 0 257 | correct = 0 258 | for data, target in val_loader: 259 | if args.cuda: 260 | data, target = data.cuda(), target.cuda() 261 | data, target = Variable(data, volatile=True), Variable(target) 262 | output = model(data) 263 | test_loss += criterion(output, target).data[0] 264 | # get the index of the max log-probability 265 | pred = output.data.max(1, keepdim=True)[1] 266 | correct += pred.eq(target.data.view_as(pred)).long().cpu().sum() 267 | 268 | test_loss /= len(val_loader.dataset) 269 | test_acc = 100. 
* correct / len(val_loader.dataset) 270 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format( 271 | test_loss, correct, len(val_loader.dataset), test_acc)) 272 | return test_acc 273 | 274 | 275 | if __name__ == '__main__': 276 | main() 277 | -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from PIL import Image, ImageEnhance 4 | 5 | from utils.data_augument import img_contrast, img_shift, img_rotation, gaussain_blur, gaussain_noise, avg_blur 6 | 7 | 8 | _max_filiter_size = 7 # for avg_blur and gaussain_blur 9 | _sigma = 0 # for gaussain_blur 10 | 11 | _mean = 0 # for gaussain_noise 12 | _var = 0.1 # for gaussain_noise 13 | 14 | _x_min_shift_piexl = -20 # for img_shift 15 | _x_max_shift_piexl = 20 # for img_shift 16 | _y_min_shift_piexl = -20 # for img_shift 17 | _y_max_shift_piexl = 20 # for img_shift 18 | _fill_pixel = 0 # for img_shift and img_rotation: black 19 | 20 | _min_angel = -20 # for img_rotation 21 | _max_angel = 20 # for img_rotation 22 | _min_scale = 0.9 # for img_rotation 23 | _max_scale = 1.1 # for img_rotation 24 | 25 | _min_s = -10 # for img_contrast (unused below) 26 | _max_s = 10 # for img_contrast (unused below) 27 | _min_v = -10 # for img_contrast (unused below) 28 | _max_v = 10 # for img_contrast (unused below) 29 | 30 | _min_h = -30 # for img_color (unused below) 31 | _max_h = 30 # for img_color (unused below) 32 | 33 | _generate_quantity = 10 # unused below 34 | 35 | data_dir = 'data/LESION_DATA' 36 | img_lst = os.listdir(data_dir) 37 | for name in img_lst: # writes 12 augmented variants per image, i.e. 13x counting the original 38 | abs_path = os.path.join(data_dir, name) 39 | img = cv2.imread(abs_path) 40 | prefix, suffix = abs_path.rsplit('.', 1) # split only on the final dot so the extension is preserved 41 | cv2.imwrite('%s_%s.%s' % (prefix, 'blur1', suffix), gaussain_blur(img, _max_filiter_size, _sigma)) 42 | cv2.imwrite('%s_%s.%s' % (prefix, 'blur2', suffix), gaussain_blur(img, _max_filiter_size, _sigma)) 43 | cv2.imwrite('%s_%s.%s' % (prefix, 'noise1', suffix), gaussain_noise(img, _mean, _var)) 44 | cv2.imwrite('%s_%s.%s' % (prefix, 'noise2', suffix), gaussain_noise(img, _mean, _var)) 45 | cv2.imwrite('%s_%s.%s' % (prefix, 'shift1', suffix), 46 | img_shift(img, _x_min_shift_piexl, _x_max_shift_piexl, _y_min_shift_piexl, _y_max_shift_piexl, 47 | _fill_pixel)) 48 | cv2.imwrite('%s_%s.%s' % (prefix, 'shift2', suffix), 49 | img_shift(img, _x_min_shift_piexl, _x_max_shift_piexl, _y_min_shift_piexl, _y_max_shift_piexl, 50 | _fill_pixel)) 51 | cv2.imwrite('%s_%s.%s' % (prefix, 'rotation1', suffix), 52 | img_rotation(img, _min_angel, _max_angel, _min_scale, _max_scale, _fill_pixel)) 53 | cv2.imwrite('%s_%s.%s' % (prefix, 'rotation2', suffix), 54 | img_rotation(img, _min_angel, _max_angel, _min_scale, _max_scale, _fill_pixel)) 55 | 56 | img02 = Image.open(abs_path) 57 | ImageEnhance.Brightness(img02).enhance(0.5).save('%s_%s.%s' % (prefix, 'brightness1', suffix)) 58 | ImageEnhance.Brightness(img02).enhance(1.5).save('%s_%s.%s' % (prefix, 'brightness2', suffix)) 59 | ImageEnhance.Contrast(img02).enhance(0.6).save('%s_%s.%s' % (prefix, 'contrast1', suffix)) 60 | ImageEnhance.Contrast(img02).enhance(1.5).save('%s_%s.%s' % (prefix, 'contrast2', suffix)) 61 | -------------------------------------------------------------------------------- /utils/FocalLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | class FocalLoss(nn.Module): 7 | """ 8 | reference:
https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py 9 | """ 10 | def __init__(self, gamma=2, alpha=0.25, size_average=True): 11 | super(FocalLoss, self).__init__() 12 | self.gamma = gamma 13 | self.alpha = alpha 14 | if isinstance(alpha,(float, int)): self.alpha = torch.Tensor([alpha,1-alpha]) 15 | if isinstance(alpha,list): self.alpha = torch.Tensor(alpha) 16 | self.size_average = size_average 17 | 18 | def forward(self, input, target): 19 | if input.dim()>2: 20 | input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W 21 | input = input.transpose(1,2) # N,C,H*W => N,H*W,C 22 | input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C 23 | target = target.view(-1,1) 24 | logpt = F.log_softmax(input, dim=1) 25 | logpt = logpt.gather(1,target) 26 | logpt = logpt.view(-1) 27 | pt = Variable(logpt.data.exp()) 28 | 29 | if self.alpha is not None: 30 | if self.alpha.type()!=input.data.type(): 31 | self.alpha = self.alpha.type_as(input.data) 32 | at = self.alpha.gather(0,target.data.view(-1)) 33 | logpt = logpt * Variable(at) 34 | 35 | loss = -1 * (1-pt)**self.gamma * logpt 36 | if self.size_average: return loss.mean() 37 | else: return loss.sum() 38 | 39 | if __name__ == '__main__': 40 | torch.manual_seed(1) 41 | inputs = Variable(torch.randn((10, 2))) 42 | targets = Variable(torch.LongTensor(10).random_(2)) 43 | loss = FocalLoss()(inputs, targets) 44 | print(loss) -------------------------------------------------------------------------------- /utils/Logger.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import scipy.misc 4 | try: 5 | from StringIO import StringIO # Python 2.7 6 | except ImportError: 7 | from io import BytesIO # Python 3.x 8 | 9 | 10 | class Logger(object): 11 | 12 | def __init__(self, log_dir): 13 | """Create a summary writer logging to log_dir.""" 14 | self.writer = tf.summary.FileWriter(log_dir) 15 | 16 | def scalar_summary(self, tag, value, step): 17 | """Log a scalar variable.""" 18 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, 19 | simple_value=value)]) 20 | self.writer.add_summary(summary, step) 21 | 22 | def image_summary(self, tag, images, step): 23 | """Log a list of images.""" 24 | 25 | img_summaries = [] 26 | for i, img in enumerate(images): 27 | # Write the image to a string 28 | try: 29 | s = StringIO() 30 | except: 31 | s = BytesIO() 32 | scipy.misc.toimage(img).save(s, format="png") 33 | 34 | # Create an Image object 35 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 36 | height=img.shape[0], 37 | width=img.shape[1]) 38 | # Create a Summary value 39 | img_summaries.append( 40 | tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 41 | 42 | # Create and write Summary 43 | summary = tf.Summary(value=img_summaries) 44 | self.writer.add_summary(summary, step) 45 | 46 | def histo_summary(self, tag, values, step, bins=1000): 47 | """Log a histogram of the tensor of values.""" 48 | 49 | # Create a histogram using numpy 50 | counts, bin_edges = np.histogram(values, bins=bins) 51 | 52 | # Fill the fields of the histogram proto 53 | hist = tf.HistogramProto() 54 | hist.min = float(np.min(values)) 55 | hist.max = float(np.max(values)) 56 | hist.num = int(np.prod(values.shape)) 57 | hist.sum = float(np.sum(values)) 58 | hist.sum_squares = float(np.sum(values**2)) 59 | 60 | # Drop the start of the first bin 61 | bin_edges = bin_edges[1:] 62 | 63 | # Add bin edges and counts 64 | for edge in bin_edges: 65 | 
hist.bucket_limit.append(edge) 66 | for c in counts: 67 | hist.bucket.append(c) 68 | 69 | # Create and write Summary 70 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 71 | self.writer.add_summary(summary, step) 72 | self.writer.flush() 73 | -------------------------------------------------------------------------------- /utils/data_augument.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import random 5 | from matplotlib import pyplot as plt 6 | 7 | # created by Feng, edited by Feng 2017/08/14 8 | # this module performs image data augmentation for any DL/ML algorithm 9 | # 10 | 11 | # avg blur: the minimum filter size is 3 12 | 13 | def avg_blur(img, max_filiter_size = 3) : 14 | img = img.astype(np.uint8) 15 | if max_filiter_size >= 3 :  # note: `out` stays undefined if max_filiter_size < 3 16 | filter_size = random.randint(3, max_filiter_size) 17 | if filter_size % 2 == 0 : 18 | filter_size += 1 19 | out = cv2.blur(img, (filter_size, filter_size)) 20 | return out 21 | 22 | # gaussian blur: the minimum filter size is 3 23 | # when sigma = 0, the gaussian kernel weights are computed by OpenCV 24 | # the larger the sigma, the more obvious the blur effect 25 | 26 | def gaussain_blur(img, max_filiter_size = 3, sigma = 0) : 27 | img = img.astype(np.uint8) 28 | if max_filiter_size >= 3 :  # note: `out` stays undefined if max_filiter_size < 3 29 | filter_size = random.randint(3, max_filiter_size) 30 | if filter_size % 2 == 0 : 31 | filter_size += 1 32 | #print ('size = %d'% filter_size) 33 | out = cv2.GaussianBlur(img, (filter_size, filter_size), sigma) 34 | return out 35 | 36 | def gaussain_noise(img, mean = 0, var = 0.1) : 37 | img = img.astype(np.uint8) 38 | h, w, c = img.shape 39 | sigma = var ** 0.5 40 | gauss = np.random.normal(mean, sigma, (h, w, c)) 41 | gauss = gauss.reshape(h, w, c) 42 | noisy = np.clip(img.astype(np.float64) + gauss, 0, 255).astype(np.uint8)  # clip to [0, 255] to avoid uint8 wrap-around 43 | return noisy 44 | 45 | # fill_pixel is 0 (black) or 255 (white) 46 | 47 | def img_shift(img, x_min_shift_piexl = -1, x_max_shift_piexl = 1, y_min_shift_piexl = -1, y_max_shift_piexl = 1, fill_pixel = 0) : 48 | img = img.astype(np.uint8) 49 | h, w, c = img.shape 50 | out = np.zeros(img.shape) 51 | if fill_pixel == 255 : 52 | out[:, :] = 255 53 | out = out.astype(np.uint8) 54 | move_x = random.randint(x_min_shift_piexl, x_max_shift_piexl) 55 | move_y = random.randint(y_min_shift_piexl, y_max_shift_piexl) 56 | #print (('move_x = %d')% (move_x)) 57 | #print (('move_y = %d')% (move_y)) 58 | if move_x >= 0 and move_y >= 0 : 59 | out[move_y:, move_x: ] = img[0: (h - move_y), 0: (w - move_x)] 60 | elif move_x < 0 and move_y < 0 : 61 | out[0: (h + move_y), 0: (w + move_x)] = img[ - move_y:, - move_x:] 62 | elif move_x >= 0 and move_y < 0 : 63 | out[0: (h + move_y), move_x:] = img[ - move_y:, 0: (w - move_x)] 64 | elif move_x < 0 and move_y >= 0 : 65 | out[move_y:, 0: (w + move_x)] = img[0 : (h - move_y), - move_x:] 66 | return out 67 | 68 | # In img_rotation func., the
rotation center is the image center 69 | 70 | def img_rotation(img, min_angel = 0, max_angel = 0, min_scale = 1, max_scale = 1, fill_pixel = 0) : 71 | img = img.astype(np.uint8) 72 | h, w, c = img.shape 73 | _angel = random.randint(min_angel, max_angel) 74 | _scale = random.uniform(min_scale, max_scale) 75 | rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), _angel, _scale) 76 | out = cv2.warpAffine(img, rotation_matrix, (w, h)) 77 | if fill_pixel == 255 : 78 | mask = np.zeros(img.shape) 79 | mask[:, :, :] = 255 80 | mask = mask.astype(np.uint8) 81 | mask = cv2.warpAffine(mask, rotation_matrix, (w, h)) 82 | for i in range (h) : 83 | for j in range(w) : 84 | if mask[i, j, 0] == 0 and mask[i, j, 1] == 0 and mask[i, j, 2] == 0 : 85 | out[i, j, :] = 255 86 | return out 87 | 88 | # img_flip randomly flips the image 89 | # when flip factor is 1 it will do hor. flip (Horizontal) 90 | # 0 ver. flip (Vertical) 91 | # -1 hor. + ver flip 92 | 93 | def img_flip(img) : 94 | img = img.astype(np.uint8) 95 | flip_factor = random.randint(-1, 1) 96 | out = cv2.flip(img, flip_factor) 97 | return out 98 | 99 | # Zoom image by scale 100 | 101 | def img_zoom(img, min_scale = 1, max_scale = 1) : 102 | img = img.astype(np.uint8) 103 | h, w, c = img.shape 104 | scale = random.uniform(min_scale, max_scale) 105 | h = int(h * scale) 106 | w = int(w * scale) 107 | out = cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC) 108 | return out 109 | 110 | # change image contrast by hsv 111 | 112 | def img_contrast(img, min_s, max_s, min_v, max_v) : 113 | img = img.astype(np.uint8) 114 | hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 115 | _s = random.randint(min_s, max_s) 116 | _v = random.randint(min_v, max_v) 117 | if _s >= 0 : 118 | hsv_img[:, :, 1] += _s 119 | else : 120 | _s = - _s 121 | hsv_img[:, :, 1] -= _s 122 | if _v >= 0 : 123 | hsv_img[:, :, 2] += _v 124 | else : 125 | _v = - _v 126 | hsv_img[:, :, 2] -= _v  # subtract for negative adjustments, mirroring the _s branch above 127 | out = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR) 128 | return out 129 | 130 | # change image color by hsv 131 | 132 | def img_color(img, min_h, max_h) : 133 | img = img.astype(np.uint8) 134 | hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 135 | _h = random.randint(min_h, max_h) 136 | hsv_img[:, :, 0] += _h  # uint8 addition wraps, acting as a crude modulo on the hue channel 137 | out = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR) 138 | return out 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /utils/sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | resampling is a widely adopted technique for handling data imbalance; this sampler can: 3 | 1. rebalance the class distributions when sampling from the imbalanced dataset 4 | 2. estimate the sampling weights automatically 5 | 3. avoid creating a new balanced dataset 6 | 4. mitigate overfitting when it is used in conjunction with data augmentation techniques 7 | reference: https://github.com/ufoym/imbalanced-dataset-sampler 8 | """ 9 | import torch 10 | import torch.utils.data 11 | import torchvision 12 | 13 | 14 | class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler): 15 | """Samples elements randomly from a given list of indices for an imbalanced dataset 16 | Arguments: 17 | indices (list, optional): a list of indices 18 | num_samples (int, optional): number of samples to draw 19 | """ 20 | 21 | def __init__(self, dataset, indices=None, num_samples=None): 22 | 23 | #
if indices is not provided, 24 | # all elements in the dataset will be considered 25 | self.indices = list(range(len(dataset))) \ 26 | if indices is None else indices 27 | 28 | # if num_samples is not provided, 29 | # draw `len(indices)` samples in each iteration 30 | self.num_samples = len(self.indices) \ 31 | if num_samples is None else num_samples 32 | 33 | # distribution of classes in the dataset 34 | label_to_count = {} 35 | for idx in self.indices: 36 | label = self._get_label(dataset, idx) 37 | if label in label_to_count: 38 | label_to_count[label] += 1 39 | else: 40 | label_to_count[label] = 1 41 | 42 | # weight for each sample 43 | weights = [1.0 / label_to_count[self._get_label(dataset, idx)] 44 | for idx in self.indices] 45 | self.weights = torch.DoubleTensor(weights) 46 | 47 | def _get_label(self, dataset, idx): 48 | dataset_type = type(dataset) 49 | if dataset_type is torchvision.datasets.MNIST: 50 | return dataset.train_labels[idx].item() 51 | elif dataset_type is torchvision.datasets.ImageFolder: 52 | return dataset.imgs[idx][1] 53 | else: 54 | raise NotImplementedError 55 | 56 | def __iter__(self): 57 | return (self.indices[i] for i in torch.multinomial( 58 | self.weights, self.num_samples, replacement=True)) 59 | 60 | def __len__(self): 61 | return self.num_samples 62 | 63 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | import json 4 | from collections import OrderedDict 5 | 6 | import cv2 7 | import numpy as np 8 | import sys 9 | import platform 10 | import torch 11 | import torchvision.transforms as transforms 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from torchvision import datasets 15 | 16 | sys.path.append('../') 17 | 18 | from PIL import Image 19 | 20 | 21 | # source: source file path 22 | # target: target file path 23 | def copy(source, target): 24 | if not os.path.exists(source): 25 | raise RuntimeError('source file does not exist!') 26 | if os.path.exists(target): 27 | raise RuntimeError('target file already exists!') 28 | shutil.copyfile(source, target) 29 | 30 | 31 | def move(source, target): 32 | if not os.path.exists(source): 33 | raise RuntimeError('source file does not exist!') 34 | if os.path.exists(target): 35 | raise RuntimeError('target file already exists!') 36 | shutil.move(source, target) 37 | 38 | 39 | # center_crop image 40 | def center_crop(path, new_width, new_height): 41 | image = Image.open(path)  # open the image at the given path 42 | width, height = image.size 43 | 44 | # resize to (new_width, new_height) directly if the image is smaller than the crop size (i.e.
enlarge not crop) 45 | if width < new_width or height < new_height: 46 | print(path) 47 | return image.resize((new_width, new_height)) 48 | 49 | left = (width - new_width) / 2 50 | top = (height - new_height) / 2 51 | right = (width + new_width) / 2 52 | bottom = (height + new_height) / 2 53 | 54 | return image.crop((left, top, right, bottom)) 55 | 56 | 57 | # del all file 58 | def clear(path): 59 | if os.path.exists(path): 60 | shutil.rmtree(path) 61 | os.mkdir(path) 62 | 63 | 64 | # write json to txt file 65 | def write(dic, path): 66 | with open(path, 'w+') as f: 67 | f.write(json.dumps(dic)) 68 | 69 | 70 | # read from txt file and transfer to json 71 | def read(path): 72 | with open(path, 'r') as f: 73 | result = json.loads(f.read()) 74 | return result 75 | 76 | 77 | def save_list(lst, path): 78 | f = open(path, 'w') 79 | for i in lst: 80 | f.write((str)(i)) 81 | f.write('\n') 82 | f.close() 83 | 84 | 85 | def set_prefix(prefix, name): 86 | if not os.path.isdir(prefix): 87 | os.mkdir(prefix) 88 | if platform.system() == 'Windows': 89 | name = name.split('\\')[-1] 90 | else: 91 | name = name.split('/')[-1] 92 | shutil.copy(name, os.path.join(prefix, name)) 93 | 94 | 95 | def to_variable(x, has_gpu, requires_grad=False): 96 | if has_gpu: 97 | x = Variable(x.cuda(), requires_grad=requires_grad) 98 | else: 99 | x = Variable(x, requires_grad=requires_grad) 100 | return x 101 | 102 | 103 | def get_parent_diectory(name, num): 104 | """ 105 | return the parent directory 106 | :param name: __file__ 107 | :param num: parent num 108 | :return: path 109 | """ 110 | root = os.path.dirname(name) 111 | for i in range(num): 112 | root = os.path.dirname(root) 113 | return root 114 | 115 | 116 | def read_single_image(path, mean=None, std=None): 117 | # image.shape=(h, w, c) 118 | image = cv2.imread(path) 119 | # BGR -> RGB hwc [0,255]=>[0, 1] 120 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 121 | image = transforms.ToTensor()(image) 122 | if mean is not None and std is not None: 123 | for t, m, s in zip(image, mean, std): 124 | t.sub_(m).div_(s) 125 | 126 | return Variable(image.unsqueeze(0)) 127 | 128 | 129 | def write_list(lst, path): 130 | if not isinstance(lst, list): 131 | raise TypeError('parameter lst must be list.') 132 | with open(path, 'w') as file: 133 | file.write(str(lst)) 134 | 135 | 136 | def read_list(path): 137 | with open(path, 'r') as file: 138 | lst = eval(file.readline()) 139 | return lst 140 | 141 | 142 | def to_np(x): 143 | return x.data.cpu().numpy() 144 | 145 | 146 | def add_prefix(prefix, path): 147 | return os.path.join(prefix, path) 148 | 149 | 150 | def weight_to_cpu(path, is_load_on_cpu=True): 151 | if is_load_on_cpu: 152 | weights = torch.load(path, map_location=lambda storage, loc: storage) 153 | else: 154 | weights = torch.load(path) 155 | new_state_dict = OrderedDict() 156 | for k, v in weights.items(): 157 | name = k[7:] # remove `module.` 158 | new_state_dict[name] = v 159 | 160 | return new_state_dict 161 | 162 | 163 | def merge_dict(dic1, dic2): 164 | merge = dic1.copy() 165 | merge.update(dic2) 166 | return merge 167 | 168 | 169 | def to_image_type(x): 170 | x = torch.squeeze(x) 171 | x = to_np(x) 172 | 173 | x = np.transpose(x, (1, 2, 0)) 174 | x = np.clip(x * 255, 0, 255).astype(np.uint8) 175 | return x 176 | 177 | 178 | def rgb2gray(rgb): 179 | 180 | r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2] 181 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 182 | 183 | return gray 184 | 185 | 186 | def get_mean_and_std(path, transform, channels=3): 187 | dataset = 
datasets.ImageFolder(root=path, transform=transform) 188 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2) 189 | mean = torch.zeros(channels) 190 | std = torch.zeros(channels) 191 | print('==> Computing mean and std..') 192 | for inputs, targets in dataloader: 193 | for i in range(channels): 194 | mean[i] += inputs[:, i, :, :].mean() 195 | std[i] += inputs[:, i, :, :].std() 196 | mean.div_(len(dataset)) 197 | std.div_(len(dataset)) 198 | mean, std = mean.numpy().tolist(), std.numpy().tolist() 199 | return [round(x, 4) for x in mean], [round(y, 4) for y in std] 200 | 201 | 202 | if __name__ == '__main__': 203 | # tensor = Variable(torch.randn((1, 3, 224, 224))) 204 | # to_image_type(tensor) 205 | # data_dir = '../data/diabetic_without_boundry/train' 206 | # data_dir = '../data/mnist/train' 207 | # data_dir = '../data/xray_all/train' 208 | data_dir = '../data/data_augu/train' 209 | print(get_mean_and_std(path=data_dir, 210 | channels=3, 211 | transform=transforms.Compose([transforms.CenterCrop(224), transforms.ToTensor()]))) 212 | --------------------------------------------------------------------------------
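As a closing note, here is a minimal sketch of loading the best checkpoint for CPU inference, assuming the checkpoint layout written by `save_checkpoint` in `main.py` (its `state_dict` keys carry a `module.` prefix because the model is wrapped in `DataParallel` during training):

```python
# Hypothetical inference-time loading; main.py itself reloads the checkpoint
# into the still-DataParallel-wrapped model inside compute_validate_meter.
from collections import OrderedDict

import torch
import torchvision.models as models

checkpoint = torch.load('classifier/model_best.pth.tar',
                        map_location=lambda storage, loc: storage)
# Strip the 'module.' prefix left by DataParallel so the keys match a bare model.
state_dict = OrderedDict((k[len('module.'):], v)
                         for k, v in checkpoint['state_dict'].items())
model = models.densenet121(pretrained=False, num_classes=2)
model.load_state_dict(state_dict)
model.eval()
```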