├── requirements.txt ├── train.sh ├── prepare_train_val.py ├── loss.py ├── LICENSE ├── validation.py ├── .gitignore ├── dataset.py ├── evaluate.py ├── utils.py ├── train.py ├── generate_masks.py ├── README.rst ├── models.py └── transforms.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==0.3.1.post2 2 | numpy==1.14.0 3 | opencv-python==3.3.0.10 4 | tqdm==4.19.4 5 | torchvision==0.1.9 6 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_DEVICE_ORDER=PCI_BUS_ID 4 | export CUDA_VISIBLE_DEVICES=0,1,2,3 5 | 6 | for i in 0 1 2 3 4 7 | do 8 | python train.py --device-ids 0,1,2,3 --limit 10000 --batch-size 16 --n-epochs 10 --fold $i --model UNet16 9 | python train.py --device-ids 0,1,2,3 --limit 10000 --batch-size 16 --n-epochs 15 --fold $i --lr 0.00001 --model UNet16 10 | done -------------------------------------------------------------------------------- /prepare_train_val.py: -------------------------------------------------------------------------------- 1 | from dataset import data_path 2 | from sklearn.model_selection import KFold 3 | import numpy as np 4 | 5 | 6 | def get_split(fold, num_splits=5): 7 | train_path = data_path / 'train' / 'angyodysplasia' / 'images' 8 | 9 | train_file_names = np.array(sorted(list(train_path.glob('*')))) 10 | 11 | kf = KFold(n_splits=num_splits, random_state=2018) 12 | 13 | ids = list(kf.split(train_file_names)) 14 | 15 | train_ids, val_ids = ids[fold] 16 | 17 | if fold == -1: 18 | return train_file_names, train_file_names 19 | else: 20 | return train_file_names[train_ids], train_file_names[val_ids] 21 | 22 | 23 | if __name__ == '__main__': 24 | ids = get_split(0) 25 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import functional as F 3 | 4 | 5 | def soft_jaccard(outputs, targets): 6 | eps = 1e-15 7 | jaccard_target = (targets == 1).float() 8 | jaccard_output = F.sigmoid(outputs) 9 | 10 | intersection = (jaccard_output * jaccard_target).sum() 11 | union = jaccard_output.sum() + jaccard_target.sum() 12 | return intersection / (union - intersection + eps) 13 | 14 | 15 | class LossBinary: 16 | """ 17 | Loss defined as BCE - log(soft_jaccard) 18 | 19 | Vladimir Iglovikov, Sergey Mushinskiy, Vladimir Osin, 20 | Satellite Imagery Feature Detection using Deep Convolutional Neural Network: A Kaggle Competition 21 | arXiv:1706.06169 22 | """ 23 | 24 | def __init__(self, jaccard_weight=0): 25 | self.nll_loss = nn.BCEWithLogitsLoss() 26 | self.jaccard_weight = jaccard_weight 27 | 28 | def __call__(self, outputs, targets): 29 | loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets) 30 | 31 | if self.jaccard_weight: 32 | loss += self.jaccard_weight * (1 - soft_jaccard(outputs, targets)) 33 | return loss 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Vladimir Iglovikov, Alexey Shvets, Alexandr A.
Kalinin, Alexander Rakhlin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /validation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import utils 3 | from torch import nn 4 | 5 | 6 | def validation_binary(model: nn.Module, criterion, valid_loader): 7 | model.eval() 8 | losses = [] 9 | 10 | jaccard = [] 11 | dice = [] 12 | 13 | for inputs, targets in valid_loader: 14 | inputs = utils.variable(inputs, volatile=True) 15 | targets = utils.variable(targets) 16 | outputs = model(inputs) 17 | loss = criterion(outputs, targets) 18 | losses.append(loss.data[0]) 19 | outputs = (outputs > 0).float() 20 | jaccard += get_jaccard(targets, outputs) 21 | dice += get_dice(targets, outputs) 22 | 23 | valid_loss = np.mean(losses) # type: float 24 | 25 | valid_jaccard = np.mean(jaccard).astype(np.float64) 26 | valid_dice = np.mean(dice).astype(np.float64) 27 | 28 | print('Valid loss: {:.5f}, jaccard: {:.5f}, dice: {:.5f}'.format(valid_loss, valid_jaccard, valid_dice)) 29 | metrics = {'valid_loss': valid_loss, 'jaccard_loss': valid_jaccard, 'dice_loss': valid_dice} 30 | return metrics 31 | 32 | 33 | def get_jaccard(y_true, y_pred): 34 | epsilon = 1e-15 35 | intersection = (y_pred * y_true).sum(dim=-2).sum(dim=-1) 36 | union = y_true.sum(dim=-2).sum(dim=-1) + y_pred.sum(dim=-2).sum(dim=-1) 37 | 38 | return list((intersection / (union + epsilon - intersection)).data.cpu().numpy()) 39 | 40 | 41 | def get_dice(y_true, y_pred): 42 | epsilon = 1e-15 43 | intersection = (y_pred * y_true).sum(dim=-2).sum(dim=-1) 44 | union = y_true.sum(dim=-2).sum(dim=-1) + y_pred.sum(dim=-2).sum(dim=-1) 45 | return list((2 * intersection / (union + epsilon)).data.cpu().numpy()) 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller 
builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | from torch.utils.data import Dataset 5 | from pathlib import Path 6 | 7 | data_path = Path('data') 8 | 9 | 10 | class AngyodysplasiaDataset(Dataset): 11 | def __init__(self, img_paths: list, to_augment=False, transform=None, mode='train', limit=None): 12 | self.img_paths = img_paths 13 | self.to_augment = to_augment 14 | self.transform = transform 15 | self.mode = mode 16 | self.limit = limit 17 | 18 | def __len__(self): 19 | if self.limit is None: 20 | return len(self.img_paths) 21 | else: 22 | return self.limit 23 | 24 | def __getitem__(self, idx): 25 | if self.limit is None: 26 | img_file_name = self.img_paths[idx] 27 | else: 28 | img_file_name = np.random.choice(self.img_paths) 29 | 30 | img = load_image(img_file_name) 31 | 32 | if self.mode == 'train': 33 | mask = load_mask(img_file_name) 34 | 35 | img, mask = self.transform(img, mask) 36 | 37 | return to_float_tensor(img), torch.from_numpy(np.expand_dims(mask, 0)).float() 38 | else: 39 | mask = np.zeros(img.shape[:2]) 40 | img, mask = self.transform(img, mask) 41 | 42 | return to_float_tensor(img), str(img_file_name) 43 | 44 | 45 | def to_float_tensor(img): 46 | return torch.from_numpy(np.moveaxis(img, -1, 0)).float() 47 | 48 | 49 | def load_image(path): 50 | img = cv2.imread(str(path)) 51 | return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 52 | 53 | 54 | def load_mask(path): 55 | mask = cv2.imread(str(path).replace('images', 'masks').replace(r'.jpg', r'_a.jpg'), 0) 56 | return (mask > 0).astype(np.uint8) 57 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | import cv2 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | 8 | def general_dice(y_true, y_pred): 9 | if y_true.sum() == 0: 10 | if y_pred.sum() == 0: 11 | return 1 12 | else: 13 | return 0 14 | 15 | return dice(y_true == 1, y_pred == 1) 16 | 17 | 18 | def general_jaccard(y_true, y_pred): 19 | if y_true.sum() == 0: 20 | if y_pred.sum() == 0: 21 | return 1 22 | else: 23 | return 0 24 | 25 | return jaccard(y_true == 1, y_pred == 1) 26 | 27 | 28 | def 
jaccard(y_true, y_pred): 29 | intersection = (y_true * y_pred).sum() 30 | union = y_true.sum() + y_pred.sum() - intersection 31 | return (intersection + 1e-15) / (union + 1e-15) 32 | 33 | 34 | def dice(y_true, y_pred): 35 | return (2 * (y_true * y_pred).sum() + 1e-15) / (y_true.sum() + y_pred.sum() + 1e-15) 36 | 37 | 38 | if __name__ == '__main__': 39 | parser = argparse.ArgumentParser() 40 | arg = parser.add_argument 41 | 42 | arg('--train_path', type=str, default='data/train/angyodysplasia/masks', help='path where train images with ground truth are located') 43 | arg('--target_path', type=str, default='predictions/UNet', help='path with predictions') 44 | args = parser.parse_args() 45 | 46 | result_dice = [] 47 | result_jaccard = [] 48 | 49 | for file_name in tqdm(list(Path(args.train_path).glob('*'))): 50 | y_true = (cv2.imread(str(file_name), 0) > 255 * 0.5).astype(np.uint8) 51 | 52 | pred_file_name = Path(args.target_path) / (file_name.stem.replace('_a', '') + '.png') 53 | 54 | y_pred = (cv2.imread(str(pred_file_name), 0) > 255 * 0.5).astype(np.uint8) 55 | 56 | result_dice += [dice(y_true, y_pred)] 57 | result_jaccard += [jaccard(y_true, y_pred)] 58 | 59 | print('Dice = ', np.mean(result_dice), np.std(result_dice)) 60 | print('Jaccard = ', np.mean(result_jaccard), np.std(result_jaccard)) 61 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datetime import datetime 3 | from pathlib import Path 4 | 5 | import random 6 | import numpy as np 7 | 8 | import torch 9 | from torch.autograd import Variable 10 | import tqdm 11 | 12 | 13 | def variable(x, volatile=False): 14 | if isinstance(x, (list, tuple)): 15 | return [variable(y, volatile=volatile) for y in x] 16 | return cuda(Variable(x, volatile=volatile)) 17 | 18 | 19 | def cuda(x): 20 | return x.cuda(async=True) if torch.cuda.is_available() else x 21 | 22 | 23 | def write_event(log, step: int, **data): 24 | data['step'] = step 25 | data['dt'] = datetime.now().isoformat() 26 | log.write(json.dumps(data, sort_keys=True)) 27 | log.write('\n') 28 | log.flush() 29 | 30 | 31 | def train(args, model, criterion, train_loader, valid_loader, validation, init_optimizer, n_epochs=None, fold=None): 32 | lr = args.lr 33 | n_epochs = n_epochs or args.n_epochs 34 | optimizer = init_optimizer(lr) 35 | 36 | root = Path(args.root) 37 | model_path = root / 'model_{fold}.pt'.format(fold=fold) 38 | if model_path.exists(): 39 | state = torch.load(str(model_path)) 40 | epoch = state['epoch'] 41 | step = state['step'] 42 | model.load_state_dict(state['model']) 43 | print('Restored model, epoch {}, step {:,}'.format(epoch, step)) 44 | else: 45 | epoch = 1 46 | step = 0 47 | 48 | save = lambda ep: torch.save({ 49 | 'model': model.state_dict(), 50 | 'epoch': ep, 51 | 'step': step, 52 | }, str(model_path)) 53 | 54 | report_each = 10 55 | log = root.joinpath('train_{fold}.log'.format(fold=fold)).open('at', encoding='utf8') 56 | valid_losses = [] 57 | for epoch in range(epoch, n_epochs + 1): 58 | model.train() 59 | random.seed() 60 | tq = tqdm.tqdm(total=(len(train_loader) * args.batch_size)) 61 | tq.set_description('Epoch {}, lr {}'.format(epoch, lr)) 62 | losses = [] 63 | tl = train_loader 64 | try: 65 | mean_loss = 0 66 | for i, (inputs, targets) in enumerate(tl): 67 | inputs, targets = variable(inputs), variable(targets) 68 | outputs = model(inputs) 69 | loss = criterion(outputs, targets) 70 | optimizer.zero_grad() 71 
| batch_size = inputs.size(0) 72 | loss.backward() 73 | optimizer.step() 74 | step += 1 75 | tq.update(batch_size) 76 | losses.append(loss.data[0]) 77 | mean_loss = np.mean(losses[-report_each:]) 78 | tq.set_postfix(loss='{:.5f}'.format(mean_loss)) 79 | if i and i % report_each == 0: 80 | write_event(log, step, loss=mean_loss) 81 | write_event(log, step, loss=mean_loss) 82 | tq.close() 83 | save(epoch + 1) 84 | valid_metrics = validation(model, criterion, valid_loader) 85 | write_event(log, step, **valid_metrics) 86 | valid_loss = valid_metrics['valid_loss'] 87 | valid_losses.append(valid_loss) 88 | except KeyboardInterrupt: 89 | tq.close() 90 | print('Ctrl+C, saving snapshot') 91 | save(epoch) 92 | print('done.') 93 | return 94 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from validation import validation_binary 5 | import torch 6 | from torch import nn 7 | from torch.optim import Adam 8 | from torch.utils.data import DataLoader 9 | import torch.backends.cudnn as cudnn 10 | import torch.backends.cudnn 11 | 12 | from models import UNet, UNet11, UNet16, LinkNet34, AlbuNet34 13 | from loss import LossBinary 14 | from dataset import AngyodysplasiaDataset 15 | import utils 16 | 17 | from prepare_train_val import get_split 18 | 19 | from transforms import (DualCompose, 20 | ImageOnly, 21 | Normalize, 22 | HorizontalFlip, 23 | Rotate, 24 | CenterCrop, 25 | RandomHueSaturationValue, 26 | VerticalFlip) 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser() 31 | arg = parser.add_argument 32 | arg('--jaccard-weight', default=0.3, type=float) 33 | arg('--device-ids', type=str, default='0', help='For example 0,1 to run on two GPUs') 34 | arg('--fold', type=int, help='fold', default=0) 35 | arg('--root', default='runs/debug', help='checkpoint root') 36 | arg('--batch-size', type=int, default=1) 37 | arg('--limit', type=int, default=10000, help='number of images in epoch') 38 | arg('--n-epochs', type=int, default=100) 39 | arg('--lr', type=float, default=0.0001) 40 | arg('--workers', type=int, default=12) 41 | arg('--model', type=str, default='UNet', choices=['UNet', 'UNet11', 'LinkNet34', 'UNet16', 'AlbuNet34']) 42 | 43 | args = parser.parse_args() 44 | 45 | root = Path(args.root) 46 | root.mkdir(exist_ok=True, parents=True) 47 | 48 | num_classes = 1 49 | if args.model == 'UNet': 50 | model = UNet(num_classes=num_classes) 51 | elif args.model == 'UNet11': 52 | model = UNet11(num_classes=num_classes, pretrained=True) 53 | elif args.model == 'UNet16': 54 | model = UNet16(num_classes=num_classes, pretrained=True) 55 | elif args.model == 'LinkNet34': 56 | model = LinkNet34(num_classes=num_classes, pretrained=True) 57 | elif args.model == 'AlbuNet34': 58 | model = AlbuNet34(num_classes=num_classes, pretrained=True) 59 | else: 60 | model = UNet(num_classes=num_classes, input_channels=3) 61 | 62 | if torch.cuda.is_available(): 63 | if args.device_ids: 64 | device_ids = list(map(int, args.device_ids.split(','))) 65 | else: 66 | device_ids = None 67 | model = nn.DataParallel(model, device_ids=device_ids).cuda() 68 | 69 | loss = LossBinary(jaccard_weight=args.jaccard_weight) 70 | 71 | cudnn.benchmark = True 72 | 73 | def make_loader(file_names, shuffle=False, transform=None, limit=None): 74 | return DataLoader( 75 | dataset=AngyodysplasiaDataset(file_names, transform=transform, limit=limit), 76 |
shuffle=shuffle, 77 | num_workers=args.workers, 78 | batch_size=args.batch_size, 79 | pin_memory=torch.cuda.is_available() 80 | ) 81 | 82 | train_file_names, val_file_names = get_split(args.fold) 83 | 84 | print('num train = {}, num_val = {}'.format(len(train_file_names), len(val_file_names))) 85 | 86 | train_transform = DualCompose([ 87 | CenterCrop(512), 88 | HorizontalFlip(), 89 | VerticalFlip(), 90 | Rotate(), 91 | ImageOnly(RandomHueSaturationValue()), 92 | ImageOnly(Normalize()) 93 | ]) 94 | 95 | val_transform = DualCompose([ 96 | CenterCrop(512), 97 | ImageOnly(Normalize()) 98 | ]) 99 | 100 | train_loader = make_loader(train_file_names, shuffle=True, transform=train_transform, limit=args.limit) 101 | valid_loader = make_loader(val_file_names, transform=val_transform) 102 | 103 | root.joinpath('params.json').write_text( 104 | json.dumps(vars(args), indent=True, sort_keys=True)) 105 | 106 | utils.train( 107 | init_optimizer=lambda lr: Adam(model.parameters(), lr=lr), 108 | args=args, 109 | model=model, 110 | criterion=loss, 111 | train_loader=train_loader, 112 | valid_loader=valid_loader, 113 | validation=validation_binary, 114 | fold=args.fold 115 | ) 116 | 117 | 118 | if __name__ == '__main__': 119 | main() 120 | -------------------------------------------------------------------------------- /generate_masks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script generates predictions: images are center-cropped to 512x512, passed through the model, and the predicted masks are padded back to the original 576x576 size 3 | """ 4 | import argparse 5 | from prepare_train_val import get_split 6 | from dataset import AngyodysplasiaDataset 7 | import cv2 8 | from models import UNet, UNet11, UNet16, AlbuNet34 9 | import torch 10 | from pathlib import Path 11 | from tqdm import tqdm 12 | import numpy as np 13 | import utils 14 | # import prepare_data 15 | from torch.utils.data import DataLoader 16 | from torch.nn import functional as F 17 | 18 | from transforms import (ImageOnly, 19 | Normalize, 20 | CenterCrop, 21 | DualCompose) 22 | 23 | img_transform = DualCompose([ 24 | CenterCrop(512), 25 | ImageOnly(Normalize()) 26 | ]) 27 | 28 | 29 | def get_model(model_path, model_type): 30 | """ 31 | 32 | :param model_path: 33 | :param model_type: 'UNet', 'UNet11', 'UNet16', 'AlbuNet34' 34 | :return: 35 | """ 36 | 37 | num_classes = 1 38 | 39 | if model_type == 'UNet11': 40 | model = UNet11(num_classes=num_classes) 41 | elif model_type == 'UNet16': 42 | model = UNet16(num_classes=num_classes) 43 | elif model_type == 'AlbuNet34': 44 | model = AlbuNet34(num_classes=num_classes) 45 | elif model_type == 'UNet': 46 | model = UNet(num_classes=num_classes) 47 | else: 48 | model = UNet(num_classes=num_classes) 49 | 50 | state = torch.load(str(model_path)) 51 | state = {key.replace('module.', ''): value for key, value in state['model'].items()} 52 | model.load_state_dict(state) 53 | 54 | model.eval() 55 | 56 | if torch.cuda.is_available(): 57 | return model.cuda() 58 | 59 | return model 60 | 61 | 62 | def predict(model, from_file_names, batch_size: int, to_path): 63 | loader = DataLoader( 64 | dataset=AngyodysplasiaDataset(from_file_names, transform=img_transform, mode='predict'), 65 | shuffle=False, 66 | batch_size=batch_size, 67 | num_workers=args.workers, 68 | pin_memory=torch.cuda.is_available() 69 | ) 70 | 71 | for batch_num, (inputs, paths) in enumerate(tqdm(loader, desc='Predict')): 72 | inputs = utils.variable(inputs, volatile=True) 73 | 74 | outputs = model(inputs) 75 | 76 | for i, image_name in
enumerate(paths): 77 | mask = (F.sigmoid(outputs[i, 0]).data.cpu().numpy() * 255).astype(np.uint8) 78 | 79 | h, w = mask.shape 80 | 81 | full_mask = np.zeros((576, 576), dtype=np.uint8) 82 | full_mask[32:32 + h, 32:32 + w] = mask 83 | 84 | (to_path / args.model_type).mkdir(exist_ok=True, parents=True) 85 | 86 | cv2.imwrite(str(to_path / args.model_type / (Path(paths[i]).stem + '.png')), full_mask) 87 | 88 | 89 | if __name__ == '__main__': 90 | parser = argparse.ArgumentParser() 91 | arg = parser.add_argument 92 | arg('--model_path', type=str, default='data/models/UNet', help='path to model folder') 93 | arg('--model_type', type=str, default='UNet', help='network architecture', 94 | choices=['UNet', 'UNet11', 'UNet16', 'AlbuNet34']) 95 | arg('--batch-size', type=int, default=4) 96 | arg('--fold', type=int, default=-1, choices=[0, 1, 2, 3, 4, -1], help='-1: all folds') 97 | arg('--workers', type=int, default=12) 98 | 99 | args = parser.parse_args() 100 | 101 | if args.fold == -1: 102 | for fold in [0, 1, 2, 3, 4]: 103 | _, file_names = get_split(fold) 104 | 105 | print(len(file_names)) 106 | 107 | model = get_model(str(Path(args.model_path).joinpath('model_{fold}.pt'.format(fold=fold))), 108 | model_type=args.model_type) 109 | 110 | print('num file_names = {}'.format(len(file_names))) 111 | 112 | output_path = Path(args.model_path) 113 | output_path.mkdir(exist_ok=True, parents=True) 114 | 115 | predict(model, file_names, args.batch_size, output_path) 116 | else: 117 | _, file_names = get_split(args.fold) 118 | model = get_model(str(Path(args.model_path).joinpath('model_{fold}.pt'.format(fold=args.fold))), 119 | model_type=args.model_type) 120 | 121 | print('num file_names = {}'.format(len(file_names))) 122 | 123 | output_path = Path(args.model_path) 124 | output_path.mkdir(exist_ok=True, parents=True) 125 | 126 | predict(model, file_names, args.batch_size, output_path) 127 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ================================================================================= 2 | MICCAI 2017 Endoscopic Vision Challenge Angiodysplasia Detection and Localization 3 | ================================================================================= 4 | 5 | Here we present our winning solution and its further development for the `MICCAI 2017 Endoscopic Vision Challenge Angiodysplasia Detection and Localization`_. It addresses a binary segmentation problem, where every pixel in an image is labeled as either an angiodysplasia lesion or background. We then analyze the connected components of each predicted mask. Based on this analysis, we developed a classifier that predicts the presence of angiodysplasia lesions (a binary variable) and a detector that localizes them (the center of a component). 6 | 7 | .. contents:: 8 | 9 | Team members 10 | ------------ 11 | `Alexey Shvets`_, `Vladimir Iglovikov`_, `Alexander Rakhlin`_, `Alexandr A.
Kalinin`_ 12 | 13 | Citation 14 | ---------- 15 | 16 | If you find this work useful for your publications, please consider citing:: 17 | 18 | @inproceedings{shvets2018angiodysplasia, 19 | title={Angiodysplasia Detection and Localization using Deep Convolutional Neural Networks}, 20 | author={Shvets, Alexey A and Iglovikov, Vladimir I and Rakhlin, Alexander and Kalinin, Alexandr A}, 21 | booktitle={2018 17th IEEE International Conference on Machine Learning and Applications (ICMLA)}, 22 | pages={612--617}, 23 | year={2018} 24 | } 25 | 26 | Overview 27 | -------- 28 | Angiodysplasias are degenerative lesions of previously healthy blood vessels, in which the bowel wall has microvascular abnormalities. These lesions are the most common source of small bowel bleeding in patients older than 50 years, and cause approximately 8% of all gastrointestinal bleeding episodes. The gold-standard examination for angiodysplasia detection and localization in the small bowel is performed using Wireless Capsule Endoscopy (WCE). The latest generation of this pill-like device is able to acquire more than 60 000 images with a resolution of approximately 520 |times| 520 pixels. According to the latest state-of-the-art, only 69% of angiodysplasias are detected by gastroenterologist experts during the reading of WCE videos, and blood indicator software (provided by WCE vendors such as Given Imaging), in the presence of angiodysplasias, shows sensitivity and specificity values of only 41% and 67%, respectively. 29 | 30 | .. figure:: https://habrastorage.org/webt/if/5p/tj/if5ptjnbzeswfgqderpcww0sstm.jpeg 31 | 32 | Data 33 | ---- 34 | The dataset consists of 1200 color images obtained with WCE. The images are in 24-bit PNG format, with 576 |times| 576 pixel resolution. The dataset is split into two equal parts, 600 images for training and 600 for evaluation. Each subset is composed of 300 images with apparent AD and 300 without any pathology. The training subset is annotated by a human expert and contains 300 binary masks in JPEG format of the same 576 |times| 576 pixel resolution. White pixels in the masks correspond to lesion localization. 35 | 36 | .. figure:: https://hsto.org/webt/nq/3v/wf/nq3vwfqtoutrzmnbzmrnyligwym.png 37 | :scale: 30 % 38 | 39 | The first row corresponds to images without pathology, the second row to images with several AD lesions in every image, and the last row contains masks that correspond to the pathology images from the second row. 40 | 41 | | 42 | | 43 | | 44 | 45 | .. figure:: https://habrastorage.org/webt/t3/p6/yy/t3p6yykecrvr9mim7fqgevodgu4.png 46 | :scale: 45 % 47 | 48 | Most images contain a single lesion. The distribution of AD lesion areas reaches a maximum of 12,000 pixels and has a median of 1,648 pixels. 49 | 50 | 51 | Method 52 | ------ 53 | We evaluate 4 different deep architectures for segmentation: `U-Net`_ (Ronneberger et al., 2015; Iglovikov et al., 2017a), 2 modifications of `TernausNet`_ (Iglovikov and Shvets, 2018), and `AlbuNet34`_, a modification of `LinkNet`_ (Chaurasia and Culurciello, 2017; Shvets et al., 2018). As an improvement over the standard `U-Net`_, we use similar networks with pre-trained encoders. `TernausNet`_ (Iglovikov and Shvets, 2018) is a U-Net-like architecture that uses relatively simple pre-trained VGG11 or VGG16 (Simonyan and Zisserman, 2014) networks as an encoder. VGG11 consists of seven convolutional layers, each followed by a ReLU activation function, and five max pooling operations, each reducing the feature map by 2. All convolutional layers have 3 |times| 3 kernels. TernausNet16 has a similar structure and uses the VGG16 network as an encoder.
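
As a minimal illustration of how these networks are used, the models defined in ``models.py`` can be instantiated and applied to a single 512 |times| 512 crop as follows (a sketch assuming the PyTorch 0.3.1 pinned in ``requirements.txt``, where inference is done with volatile variables as in ``utils.py``)::

    import torch
    from torch.autograd import Variable
    from torch.nn import functional as F

    from models import UNet16

    # encoder initialized with ImageNet-pretrained VGG16 weights
    model = UNet16(num_classes=1, pretrained=True)

    x = Variable(torch.randn(1, 3, 512, 512), volatile=True)  # one RGB 512x512 crop
    logits = model(x)          # raw per-pixel logits, shape (1, 1, 512, 512)
    probs = F.sigmoid(logits)  # per-pixel lesion probabilities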
54 | 55 | .. figure:: https://hsto.org/webt/vz/ok/wt/vzokwtntgqe6lb-g2oyhzj0qcyo.png 56 | :scale: 72 % 57 | 58 | .. figure:: https://hsto.org/webt/vs/by/8y/vsby8yt4bj_6n3pqdqlf2tb8r9a.png 59 | :scale: 72 % 60 | 61 | Training 62 | -------- 63 | 64 | We use the Jaccard index (Intersection Over Union) as the evaluation metric. It can be interpreted as a similarity measure between two finite sets. For two sets A and B, it can be defined as follows:: 65 | 66 | 67 | 68 | J(A, B) = |A ∩ B| / |A ∪ B| = |A ∩ B| / (|A| + |B| - |A ∩ B|)
69 | 70 |
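
For two binary masks this is what the ``jaccard`` helper in ``evaluate.py`` computes (shown here in condensed form)::

    intersection = (y_true * y_pred).sum()
    union = y_true.sum() + y_pred.sum() - intersection
    iou = (intersection + 1e-15) / (union + 1e-15)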
71 | 72 | Since an image consists of pixels, the expression can be adapted for discrete objects in the following way: 73 | 74 | .. figure:: https://habrastorage.org/webt/_8/wc/j1/_8wcj1to6ahxfsmb8s3nrxumqjy.gif 75 | :align: center 76 | 77 | where |y| and |y_hat| are a binary value (label) and a predicted probability for the pixel |i|, respectively. 78 | 79 | Since the image segmentation task can also be considered a pixel classification problem, we additionally use common classification loss functions, denoted as H. For a binary segmentation problem, H is the binary cross entropy, while for a multi-class segmentation problem, H is the categorical cross entropy. 80 | 81 | .. figure:: https://habrastorage.org/webt/tf/d0/kn/tfd0kn2l612do_wmlc6zp5rdgdw.gif 82 | :align: center 83 | 84 | As the output of a model, we obtain an image in which each pixel value corresponds to the probability of belonging to the area of interest or a class. The size of the output image matches the input image size. For binary segmentation, we use 0.3 as a threshold value (chosen using the validation dataset) to binarize pixel probabilities. All pixel values below the specified threshold are set to 0, while all values above the threshold are set to 255 to produce the final prediction mask. 85 | 86 | Following the segmentation step, we perform postprocessing in order to find the coordinates of angiodysplasia lesions in the image. In the postprocessing step we use the OpenCV implementation of the connected component labeling function ``connectedComponentsWithStats``. This function returns the number of connected components, their sizes (areas), and the centroid coordinates of each component. In our detector we use another threshold to discard all components smaller than 300 pixels. Therefore, to establish the presence of lesions, the number of remaining components should be greater than 0; otherwise, the image corresponds to a normal condition. Then, for localization of angiodysplasia lesions, we return the centroid coordinates of all remaining components (a minimal code sketch of this step is given at the end of the How to run section). 87 | 88 | Results 89 | ------- 90 | 91 | The quantitative comparison of our models' performance is presented in Table 1. For the segmentation task the best result is achieved by `AlbuNet34`_, providing IoU = 0.754 and Dice = 0.850. When compared by inference time, `AlbuNet34`_ is also the fastest of the pre-trained models due to its light encoder; in the segmentation task this network takes around 20 ms. 92 | 93 | .. figure:: https://hsto.org/webt/mw/yj/-l/mwyj-l6ddk6xz-ykydduixzhrdk.png 94 | :scale: 60 % 95 | 96 | Predictions of our detector on a validation image. The left picture is the original image, the central one is the ground truth mask, and the right one is the predicted mask. Green dots correspond to centroid coordinates that define the localization of the angiodysplasia. 97 | 98 | | 99 | | 100 | | 101 | 102 | .. table:: Table 1. Segmentation results per task. Intersection over Union, Dice coefficient and inference time, ms.
103 | 104 | ============= ========= ========= ================== 105 | Model IOU, % Dice, % Inference time, ms 106 | ============= ========= ========= ================== 107 | U-Net 73.18 83.06 21 108 | TernausNet-11 74.94 84.43 51 109 | TernausNet-16 73.83 83.05 60 110 | AlbuNet34 75.35 84.98 30 111 | ============= ========= ========= ================== 112 | 113 | Pre-trained weights for all models can be found on `google drive`_. 114 | 115 | Dependencies 116 | ------------ 117 | 118 | * Python 3.6 119 | * PyTorch 0.3.1 120 | * TorchVision 0.1.9 121 | * numpy 1.14.0 122 | * opencv-python 3.3.0.10 123 | * tqdm 4.19.4 124 | 125 | These dependencies can be installed by running:: 126 | 127 | pip install -r requirements.txt 128 | 129 | 130 | How to run 131 | ---------- 132 | The dataset is organized in the following way:: 133 | 134 | 135 | ├── data 136 | │   ├── test 137 | │   └── train 138 | │   ├── angyodysplasia 139 | │   │   ├── images 140 | │   │   └── masks 141 | │   └── normal 142 | │   ├── images 143 | │   └── masks 144 | │   ....................... 145 | 146 | The training dataset contains 2 sets of images, one with angyodysplasia and the second without it. For training we used only the images with angyodysplasia, which were split into 5 folds. 147 | 148 | 1. Training 149 | 150 | The main file used to train all models is ``train.py``. Running ``python train.py --help`` will return the set of all possible input parameters. 151 | To train all models we used the following bash script (batch size was chosen depending on how many samples fit into the GPU RAM, limit was adjusted accordingly to keep the same number of updates for every network):: 152 | 153 | #!/bin/bash 154 | 155 | for i in 0 1 2 3 4 156 | do 157 | python train.py --device-ids 0,1,2,3 --limit 10000 --batch-size 12 --fold $i --workers 12 --lr 0.0001 --n-epochs 10 --jaccard-weight 0.3 --model UNet11 158 | python train.py --device-ids 0,1,2,3 --limit 10000 --batch-size 12 --fold $i --workers 12 --lr 0.00001 --n-epochs 15 --jaccard-weight 0.3 --model UNet11 159 | done 160 | 161 | 2. Mask generation. 162 | 163 | The main file to generate masks is ``generate_masks.py``. Running ``python generate_masks.py --help`` will return the set of all possible input parameters. The predictions are written to ``<model_path>/<model_type>``. Example:: 164 | 165 | python generate_masks.py --model_type UNet16 --model_path data/models/UNet16 --fold -1 --batch-size 4 166 | 167 | 3. Evaluation. 168 | 169 | The evaluation is different for a binary and multi-class segmentation: 170 | 171 | [a] In the case of binary segmentation it calculates Jaccard (Dice) per image and then the values are averaged. 172 | 173 | [b] In the case of multi-class segmentation it calculates Jaccard (Dice) for every class independently, then averages them for each image:: 174 | 175 | python evaluate.py --target_path data/models/UNet16/UNet16 --train_path data/train/angyodysplasia/masks 176 | 177 | 4. Further Improvements. 178 | 179 | Our results can be improved further by a few percent using simple measures such as additional augmentation of the training images and training the models for longer. In addition, a cyclic learning rate or cosine annealing could also be applied. To do so, one can use our pre-trained weights as initialization. To improve test predictions, test-time augmentation (TTA) could be used, as well as averaging predictions from all folds.
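
5. Detection.

For reference, the localization step described in the Method section (binarize the predicted mask, label connected components with ``connectedComponentsWithStats``, drop components smaller than 300 pixels and report the centroids of the rest) can be sketched as follows. The helper below is only an illustration and assumes the 8-bit masks written by ``generate_masks.py``::

    import cv2
    import numpy as np

    def detect_lesions(mask_path, prob_threshold=0.3, min_area=300):
        # generate_masks.py stores probability * 255 as an 8-bit image
        mask = cv2.imread(str(mask_path), 0)
        binary = (mask > 255 * prob_threshold).astype(np.uint8)

        n, labels, stats, centroids = cv2.connectedComponentsWithStats(binary)

        # label 0 is the background; keep components of at least min_area pixels
        keep = [i for i in range(1, n) if stats[i, cv2.CC_STAT_AREA] >= min_area]

        # (lesion present?, centroid coordinates of the detected lesions)
        return len(keep) > 0, centroids[keep]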
180 | 181 | Demo Example 182 | ------------ 183 | You can start working with our models using the demonstration example: `Demo.ipynb`_ 184 | 185 | .. _`Demo.ipynb`: Demo.ipynb 186 | .. _`Alexander Rakhlin`: https://www.linkedin.com/in/alrakhlin/ 187 | .. _`Alexey Shvets`: https://www.linkedin.com/in/shvetsiya/ 188 | .. _`Vladimir Iglovikov`: https://www.linkedin.com/in/iglovikov/ 189 | .. _`Alexandr A. Kalinin`: https://alxndrkalinin.github.io/ 190 | .. _`MICCAI 2017 Endoscopic Vision Challenge Angiodysplasia Detection and Localization`: https://endovissub2017-giana.grand-challenge.org/angiodysplasia-etisdb/ 191 | .. _`TernausNet`: https://arxiv.org/abs/1801.05746 192 | .. _`U-Net`: https://arxiv.org/abs/1505.04597 193 | .. _`AlbuNet34`: https://arxiv.org/abs/1803.01207 194 | .. _`LinkNet`: https://arxiv.org/abs/1707.03718 195 | .. _`google drive`: https://drive.google.com/drive/folders/1V_bLBTzsl_Z8Ln9Iq8gjcFDxodfiHxul 196 | 197 | .. |br| raw:: html 198 | 199 | <br />
200 | 201 | .. |plusmn| raw:: html 202 | 203 | ± 204 | 205 | .. |times| raw:: html 206 | 207 | × 208 | 209 | .. |micro| raw:: html 210 | 211 | µm 212 | 213 | .. |y| image:: https://hsto.org/webt/jm/sn/i0/jmsni0y8mao8vnaij8a4eyuoqmu.gif 214 | .. |y_hat| image:: https://hsto.org/webt/xf/j2/a4/xfj2a4obgqhdzneysar5_us5pgk.gif 215 | .. |i| image:: https://hsto.org/webt/87/cc/ca/87cccaz4gjp2lgyeip17utljvvi.gif 216 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torchvision import models 4 | import torchvision 5 | from torch.nn import functional as F 6 | 7 | 8 | def conv3x3(in_, out): 9 | return nn.Conv2d(in_, out, 3, padding=1) 10 | 11 | 12 | class ConvRelu(nn.Module): 13 | def __init__(self, in_: int, out: int): 14 | super().__init__() 15 | self.conv = conv3x3(in_, out) 16 | self.activation = nn.ReLU(inplace=True) 17 | 18 | def forward(self, x): 19 | x = self.conv(x) 20 | x = self.activation(x) 21 | return x 22 | 23 | 24 | class DecoderBlock(nn.Module): 25 | def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True): 26 | super(DecoderBlock, self).__init__() 27 | self.in_channels = in_channels 28 | 29 | if is_deconv: 30 | self.block = nn.Sequential( 31 | ConvRelu(in_channels, middle_channels), 32 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2, 33 | padding=1), 34 | nn.ReLU(inplace=True) 35 | ) 36 | else: 37 | self.block = nn.Sequential( 38 | nn.Upsample(scale_factor=2, mode='bilinear'), 39 | ConvRelu(in_channels, middle_channels), 40 | ConvRelu(middle_channels, out_channels), 41 | ) 42 | 43 | def forward(self, x): 44 | return self.block(x) 45 | 46 | 47 | class UNet11(nn.Module): 48 | def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False): 49 | """ 50 | :param num_classes: 51 | :param num_filters: 52 | :param pretrained: 53 | False - no pre-trained network used 54 | vgg - encoder pre-trained with VGG11 55 | """ 56 | super().__init__() 57 | self.pool = nn.MaxPool2d(2, 2) 58 | 59 | self.num_classes = num_classes 60 | 61 | self.encoder = models.vgg11(pretrained=pretrained).features 62 | 63 | self.relu = nn.ReLU(inplace=True) 64 | self.conv1 = nn.Sequential(self.encoder[0], 65 | self.relu) 66 | 67 | self.conv2 = nn.Sequential(self.encoder[3], 68 | self.relu) 69 | 70 | self.conv3 = nn.Sequential( 71 | self.encoder[6], 72 | self.relu, 73 | self.encoder[8], 74 | self.relu, 75 | ) 76 | self.conv4 = nn.Sequential( 77 | self.encoder[11], 78 | self.relu, 79 | self.encoder[13], 80 | self.relu, 81 | ) 82 | 83 | self.conv5 = nn.Sequential( 84 | self.encoder[16], 85 | self.relu, 86 | self.encoder[18], 87 | self.relu, 88 | ) 89 | 90 | self.center = DecoderBlock(256 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv) 91 | self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv) 92 | self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 4, is_deconv=is_deconv) 93 | self.dec3 = DecoderBlock(256 + num_filters * 4, num_filters * 4 * 2, num_filters * 2, is_deconv=is_deconv) 94 | self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv=is_deconv) 95 | self.dec1 = ConvRelu(64 + num_filters, num_filters) 96 | 97 | self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1) 98 | 99 | def forward(self, x): 100 | conv1 = 
self.conv1(x) 101 | conv2 = self.conv2(self.pool(conv1)) 102 | conv3 = self.conv3(self.pool(conv2)) 103 | conv4 = self.conv4(self.pool(conv3)) 104 | conv5 = self.conv5(self.pool(conv4)) 105 | center = self.center(self.pool(conv5)) 106 | 107 | dec5 = self.dec5(torch.cat([center, conv5], 1)) 108 | dec4 = self.dec4(torch.cat([dec5, conv4], 1)) 109 | dec3 = self.dec3(torch.cat([dec4, conv3], 1)) 110 | dec2 = self.dec2(torch.cat([dec3, conv2], 1)) 111 | dec1 = self.dec1(torch.cat([dec2, conv1], 1)) 112 | 113 | if self.num_classes > 1: 114 | x_out = F.log_softmax(self.final(dec1), dim=1) 115 | else: 116 | x_out = self.final(dec1) 117 | 118 | return x_out 119 | 120 | 121 | class UNet16(nn.Module): 122 | def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False): 123 | """ 124 | :param num_classes: 125 | :param num_filters: 126 | :param pretrained: if encoder uses pre-trained weigths from VGG16 127 | """ 128 | super().__init__() 129 | self.num_classes = num_classes 130 | 131 | self.pool = nn.MaxPool2d(2, 2) 132 | 133 | self.encoder = torchvision.models.vgg16(pretrained=pretrained).features 134 | 135 | self.relu = nn.ReLU(inplace=True) 136 | 137 | self.conv1 = nn.Sequential(self.encoder[0], 138 | self.relu, 139 | self.encoder[2], 140 | self.relu) 141 | 142 | self.conv2 = nn.Sequential(self.encoder[5], 143 | self.relu, 144 | self.encoder[7], 145 | self.relu) 146 | 147 | self.conv3 = nn.Sequential(self.encoder[10], 148 | self.relu, 149 | self.encoder[12], 150 | self.relu, 151 | self.encoder[14], 152 | self.relu) 153 | 154 | self.conv4 = nn.Sequential(self.encoder[17], 155 | self.relu, 156 | self.encoder[19], 157 | self.relu, 158 | self.encoder[21], 159 | self.relu) 160 | 161 | self.conv5 = nn.Sequential(self.encoder[24], 162 | self.relu, 163 | self.encoder[26], 164 | self.relu, 165 | self.encoder[28], 166 | self.relu) 167 | 168 | self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv) 169 | 170 | self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv) 171 | self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv) 172 | self.dec3 = DecoderBlock(256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv=is_deconv) 173 | self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv=is_deconv) 174 | self.dec1 = ConvRelu(64 + num_filters, num_filters) 175 | self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1) 176 | 177 | def forward(self, x): 178 | conv1 = self.conv1(x) 179 | conv2 = self.conv2(self.pool(conv1)) 180 | conv3 = self.conv3(self.pool(conv2)) 181 | conv4 = self.conv4(self.pool(conv3)) 182 | conv5 = self.conv5(self.pool(conv4)) 183 | 184 | center = self.center(self.pool(conv5)) 185 | 186 | dec5 = self.dec5(torch.cat([center, conv5], 1)) 187 | 188 | dec4 = self.dec4(torch.cat([dec5, conv4], 1)) 189 | dec3 = self.dec3(torch.cat([dec4, conv3], 1)) 190 | dec2 = self.dec2(torch.cat([dec3, conv2], 1)) 191 | dec1 = self.dec1(torch.cat([dec2, conv1], 1)) 192 | 193 | if self.num_classes > 1: 194 | x_out = F.log_softmax(self.final(dec1), dim=1) 195 | else: 196 | x_out = self.final(dec1) 197 | 198 | return x_out 199 | 200 | 201 | class Conv3BN(nn.Module): 202 | def __init__(self, in_: int, out: int, bn=False): 203 | super().__init__() 204 | self.conv = conv3x3(in_, out) 205 | self.bn = nn.BatchNorm2d(out) if bn else None 206 | self.activation = nn.ReLU(inplace=True) 207 | 208 | def 
forward(self, x): 209 | x = self.conv(x) 210 | if self.bn is not None: 211 | x = self.bn(x) 212 | x = self.activation(x) 213 | return x 214 | 215 | 216 | class UNetModule(nn.Module): 217 | def __init__(self, in_: int, out: int): 218 | super().__init__() 219 | self.l1 = Conv3BN(in_, out) 220 | self.l2 = Conv3BN(out, out) 221 | 222 | def forward(self, x): 223 | x = self.l1(x) 224 | x = self.l2(x) 225 | return x 226 | 227 | 228 | class UNet(nn.Module): 229 | """ 230 | Vanilla UNet. 231 | 232 | Implementation from https://github.com/lopuhin/mapillary-vistas-2017/blob/master/unet_models.py 233 | """ 234 | output_downscaled = 1 235 | module = UNetModule 236 | 237 | def __init__(self, 238 | input_channels: int = 3, 239 | filters_base: int = 32, 240 | down_filter_factors=(1, 2, 4, 8, 16), 241 | up_filter_factors=(1, 2, 4, 8, 16), 242 | bottom_s=4, 243 | num_classes=1, 244 | add_output=True): 245 | super().__init__() 246 | self.num_classes = num_classes 247 | assert len(down_filter_factors) == len(up_filter_factors) 248 | assert down_filter_factors[-1] == up_filter_factors[-1] 249 | down_filter_sizes = [filters_base * s for s in down_filter_factors] 250 | up_filter_sizes = [filters_base * s for s in up_filter_factors] 251 | self.down, self.up = nn.ModuleList(), nn.ModuleList() 252 | self.down.append(self.module(input_channels, down_filter_sizes[0])) 253 | for prev_i, nf in enumerate(down_filter_sizes[1:]): 254 | self.down.append(self.module(down_filter_sizes[prev_i], nf)) 255 | for prev_i, nf in enumerate(up_filter_sizes[1:]): 256 | self.up.append(self.module( 257 | down_filter_sizes[prev_i] + nf, up_filter_sizes[prev_i])) 258 | pool = nn.MaxPool2d(2, 2) 259 | pool_bottom = nn.MaxPool2d(bottom_s, bottom_s) 260 | upsample = nn.Upsample(scale_factor=2) 261 | upsample_bottom = nn.Upsample(scale_factor=bottom_s) 262 | self.downsamplers = [None] + [pool] * (len(self.down) - 1) 263 | self.downsamplers[-1] = pool_bottom 264 | self.upsamplers = [upsample] * len(self.up) 265 | self.upsamplers[-1] = upsample_bottom 266 | self.add_output = add_output 267 | if add_output: 268 | self.conv_final = nn.Conv2d(up_filter_sizes[0], num_classes, 1) 269 | 270 | def forward(self, x): 271 | xs = [] 272 | for downsample, down in zip(self.downsamplers, self.down): 273 | x_in = x if downsample is None else downsample(xs[-1]) 274 | x_out = down(x_in) 275 | xs.append(x_out) 276 | 277 | x_out = xs[-1] 278 | for x_skip, upsample, up in reversed( 279 | list(zip(xs[:-1], self.upsamplers, self.up))): 280 | x_out = upsample(x_out) 281 | x_out = up(torch.cat([x_out, x_skip], 1)) 282 | 283 | if self.add_output: 284 | x_out = self.conv_final(x_out) 285 | if self.num_classes > 1: 286 | x_out = F.log_softmax(x_out, dim=1) 287 | return x_out 288 | 289 | 290 | class AlbuNet34(nn.Module): 291 | """ 292 | UNet (https://arxiv.org/abs/1505.04597) with Resnet34(https://arxiv.org/abs/1512.03385) encoder 293 | Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/ 294 | """ 295 | 296 | def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False): 297 | """ 298 | :param num_classes: 299 | :param num_filters: 300 | :param pretrained: 301 | False - no pre-trained network is used 302 | True - encoder is pre-trained with resnet34 303 | :is_deconv: 304 | False: bilinear interpolation is used in decoder 305 | True: deconvolution is used in decoder 306 | """ 307 | super().__init__() 308 | self.num_classes = num_classes 309 | 310 | self.pool = nn.MaxPool2d(2, 2) 311 | 312 | self.encoder = 
torchvision.models.resnet34(pretrained=pretrained) 313 | 314 | self.relu = nn.ReLU(inplace=True) 315 | 316 | self.conv1 = nn.Sequential(self.encoder.conv1, 317 | self.encoder.bn1, 318 | self.encoder.relu, 319 | self.pool) 320 | 321 | self.conv2 = self.encoder.layer1 322 | 323 | self.conv3 = self.encoder.layer2 324 | 325 | self.conv4 = self.encoder.layer3 326 | 327 | self.conv5 = self.encoder.layer4 328 | 329 | self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8, is_deconv) 330 | 331 | self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv) 332 | self.dec4 = DecoderBlock(256 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv) 333 | self.dec3 = DecoderBlock(128 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv) 334 | self.dec2 = DecoderBlock(64 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2, is_deconv) 335 | self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv) 336 | self.dec0 = ConvRelu(num_filters, num_filters) 337 | self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1) 338 | 339 | def forward(self, x): 340 | conv1 = self.conv1(x) 341 | conv2 = self.conv2(conv1) 342 | conv3 = self.conv3(conv2) 343 | conv4 = self.conv4(conv3) 344 | conv5 = self.conv5(conv4) 345 | 346 | center = self.center(self.pool(conv5)) 347 | 348 | dec5 = self.dec5(torch.cat([center, conv5], 1)) 349 | 350 | dec4 = self.dec4(torch.cat([dec5, conv4], 1)) 351 | dec3 = self.dec3(torch.cat([dec4, conv3], 1)) 352 | dec2 = self.dec2(torch.cat([dec3, conv2], 1)) 353 | dec1 = self.dec1(dec2) 354 | dec0 = self.dec0(dec1) 355 | 356 | if self.num_classes > 1: 357 | x_out = F.log_softmax(self.final(dec0), dim=1) 358 | else: 359 | x_out = self.final(dec0) 360 | 361 | return x_out 362 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on a set of transformations developed by Alexander Buslaev as a part of the winning solution (1 out of 735) 3 | to the Kaggle: Carvana Image Masking Challenge. 4 | 5 | https://github.com/asanakoy/kaggle_carvana_segmentation/blob/master/albu/src/transforms.py 6 | """ 7 | 8 | import random 9 | import cv2 10 | import numpy as np 11 | import math 12 | 13 | 14 | class DualCompose: 15 | def __init__(self, transforms): 16 | self.transforms = transforms 17 | 18 | def __call__(self, x, mask=None): 19 | for t in self.transforms: 20 | x, mask = t(x, mask) 21 | return x, mask 22 | 23 | 24 | class OneOf: 25 | def __init__(self, transforms, prob=0.5): 26 | self.transforms = transforms 27 | self.prob = prob 28 | 29 | def __call__(self, x, mask=None): 30 | if random.random() < self.prob: 31 | t = random.choice(self.transforms) 32 | t.prob = 1. 33 | x, mask = t(x, mask) 34 | return x, mask 35 | 36 | 37 | class OneOrOther: 38 | def __init__(self, first, second, prob=0.5): 39 | self.first = first 40 | first.prob = 1. 41 | self.second = second 42 | second.prob = 1. 
43 | self.prob = prob 44 | 45 | def __call__(self, x, mask=None): 46 | if random.random() < self.prob: 47 | x, mask = self.first(x, mask) 48 | else: 49 | x, mask = self.second(x, mask) 50 | return x, mask 51 | 52 | 53 | class ImageOnly: 54 | def __init__(self, trans): 55 | self.trans = trans 56 | 57 | def __call__(self, x, mask=None): 58 | return self.trans(x), mask 59 | 60 | 61 | class VerticalFlip: 62 | def __init__(self, prob=0.5): 63 | self.prob = prob 64 | 65 | def __call__(self, img, mask=None): 66 | if random.random() < self.prob: 67 | img = cv2.flip(img, 0) 68 | if mask is not None: 69 | mask = cv2.flip(mask, 0) 70 | return img, mask 71 | 72 | 73 | class HorizontalFlip: 74 | def __init__(self, prob=0.5): 75 | self.prob = prob 76 | 77 | def __call__(self, img, mask=None): 78 | if random.random() < self.prob: 79 | img = cv2.flip(img, 1) 80 | if mask is not None: 81 | mask = cv2.flip(mask, 1) 82 | return img, mask 83 | 84 | 85 | class RandomFlip: 86 | def __init__(self, prob=0.5): 87 | self.prob = prob 88 | 89 | def __call__(self, img, mask=None): 90 | if random.random() < self.prob: 91 | d = random.randint(-1, 1) 92 | img = cv2.flip(img, d) 93 | if mask is not None: 94 | mask = cv2.flip(mask, d) 95 | return img, mask 96 | 97 | 98 | class Transpose: 99 | def __init__(self, prob=0.5): 100 | self.prob = prob 101 | 102 | def __call__(self, img, mask=None): 103 | if random.random() < self.prob: 104 | img = img.transpose(1, 0, 2) 105 | if mask is not None: 106 | mask = mask.transpose(1, 0) 107 | return img, mask 108 | 109 | 110 | class RandomRotate90: 111 | def __init__(self, prob=0.5): 112 | self.prob = prob 113 | 114 | def __call__(self, img, mask=None): 115 | if random.random() < self.prob: 116 | factor = random.randint(0, 4) 117 | img = np.rot90(img, factor) 118 | if mask is not None: 119 | mask = np.rot90(mask, factor) 120 | return img.copy(), mask.copy() 121 | 122 | 123 | class Rotate: 124 | def __init__(self, limit=90, prob=0.5): 125 | self.prob = prob 126 | self.limit = limit 127 | 128 | def __call__(self, img, mask=None): 129 | if random.random() < self.prob: 130 | angle = random.uniform(-self.limit, self.limit) 131 | 132 | height, width = img.shape[0:2] 133 | mat = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0) 134 | img = cv2.warpAffine(img, mat, (height, width), 135 | flags=cv2.INTER_LINEAR, 136 | borderMode=cv2.BORDER_REFLECT_101) 137 | if mask is not None: 138 | mask = cv2.warpAffine(mask, mat, (height, width), 139 | flags=cv2.INTER_NEAREST, 140 | borderMode=cv2.BORDER_REFLECT_101) 141 | 142 | return img, mask 143 | 144 | 145 | class RandomCrop: 146 | def __init__(self, size): 147 | self.h = size[0] 148 | self.w = size[1] 149 | 150 | def __call__(self, img, mask=None): 151 | height, width, _ = img.shape 152 | 153 | h_start = np.random.randint(0, height - self.h) 154 | w_start = np.random.randint(0, width - self.w) 155 | 156 | img = img[h_start: h_start + self.h, w_start: w_start + self.w] 157 | 158 | assert img.shape[0] == self.h 159 | assert img.shape[1] == self.w 160 | 161 | if mask is not None: 162 | mask = mask[h_start: h_start + self.h, w_start: w_start + self.w] 163 | 164 | return img, mask 165 | 166 | 167 | class Shift: 168 | def __init__(self, limit=4, prob=.5): 169 | self.limit = limit 170 | self.prob = prob 171 | 172 | def __call__(self, img, mask=None): 173 | if random.random() < self.prob: 174 | limit = self.limit 175 | dx = round(random.uniform(-limit, limit)) 176 | dy = round(random.uniform(-limit, limit)) 177 | 178 | height, width, channel = 
img.shape 179 | y1 = limit + 1 + dy 180 | y2 = y1 + height 181 | x1 = limit + 1 + dx 182 | x2 = x1 + width 183 | 184 | img1 = cv2.copyMakeBorder(img, limit + 1, limit + 1, limit + 1, limit + 1, 185 | borderType=cv2.BORDER_REFLECT_101) 186 | img = img1[y1:y2, x1:x2, :] 187 | if mask is not None: 188 | msk1 = cv2.copyMakeBorder(mask, limit + 1, limit + 1, limit + 1, limit + 1, 189 | borderType=cv2.BORDER_REFLECT_101) 190 | mask = msk1[y1:y2, x1:x2, :] 191 | 192 | return img, mask 193 | 194 | 195 | class ShiftScale: 196 | def __init__(self, limit=4, prob=.25): 197 | self.limit = limit 198 | self.prob = prob 199 | 200 | def __call__(self, img, mask=None): 201 | limit = self.limit 202 | if random.random() < self.prob: 203 | height, width, channel = img.shape 204 | assert (width == height) 205 | size0 = width 206 | size1 = width + 2 * limit 207 | size = round(random.uniform(size0, size1)) 208 | 209 | dx = round(random.uniform(0, size1 - size)) 210 | dy = round(random.uniform(0, size1 - size)) 211 | 212 | y1 = dy 213 | y2 = y1 + size 214 | x1 = dx 215 | x2 = x1 + size 216 | 217 | img1 = cv2.copyMakeBorder(img, limit, limit, limit, limit, borderType=cv2.BORDER_REFLECT_101) 218 | img = (img1[y1:y2, x1:x2, :] if size == size0 219 | else cv2.resize(img1[y1:y2, x1:x2, :], (size0, size0), interpolation=cv2.INTER_LINEAR)) 220 | 221 | if mask is not None: 222 | msk1 = cv2.copyMakeBorder(mask, limit, limit, limit, limit, borderType=cv2.BORDER_REFLECT_101) 223 | mask = (msk1[y1:y2, x1:x2, :] if size == size0 224 | else cv2.resize(msk1[y1:y2, x1:x2, :], (size0, size0), interpolation=cv2.INTER_LINEAR)) 225 | 226 | return img, mask 227 | 228 | 229 | class ShiftScaleRotate: 230 | def __init__(self, shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, prob=0.5): 231 | self.shift_limit = shift_limit 232 | self.scale_limit = scale_limit 233 | self.rotate_limit = rotate_limit 234 | self.prob = prob 235 | 236 | def __call__(self, img, mask=None): 237 | if random.random() < self.prob: 238 | height, width, channel = img.shape 239 | 240 | angle = random.uniform(-self.rotate_limit, self.rotate_limit) 241 | scale = random.uniform(1 - self.scale_limit, 1 + self.scale_limit) 242 | dx = round(random.uniform(-self.shift_limit, self.shift_limit)) * width 243 | dy = round(random.uniform(-self.shift_limit, self.shift_limit)) * height 244 | 245 | cc = math.cos(angle / 180 * math.pi) * scale 246 | ss = math.sin(angle / 180 * math.pi) * scale 247 | rotate_matrix = np.array([[cc, -ss], [ss, cc]]) 248 | 249 | box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ]) 250 | box1 = box0 - np.array([width / 2, height / 2]) 251 | box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy]) 252 | 253 | box0 = box0.astype(np.float32) 254 | box1 = box1.astype(np.float32) 255 | mat = cv2.getPerspectiveTransform(box0, box1) 256 | img = cv2.warpPerspective(img, mat, (width, height), 257 | flags=cv2.INTER_LINEAR, 258 | borderMode=cv2.BORDER_REFLECT_101) 259 | if mask is not None: 260 | mask = cv2.warpPerspective(mask, mat, (width, height), 261 | flags=cv2.INTER_NEAREST, 262 | borderMode=cv2.BORDER_REFLECT_101) 263 | 264 | return img, mask 265 | 266 | 267 | class CenterCrop: 268 | def __init__(self, size): 269 | if isinstance(size, int): 270 | size = (size, size) 271 | 272 | self.height = size[0] 273 | self.width = size[1] 274 | 275 | def __call__(self, img, mask=None): 276 | h, w, c = img.shape 277 | dy = (h - self.height) // 2 278 | dx = (w - self.width) // 2 279 | 280 | y1 = dy 281 | y2 = y1 + self.height 
282 | x1 = dx 283 | x2 = x1 + self.width 284 | img = img[y1:y2, x1:x2] 285 | 286 | if mask is not None: 287 | mask = mask[y1:y2, x1:x2] 288 | 289 | return img, mask 290 | 291 | 292 | class Normalize: 293 | def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 294 | self.mean = mean 295 | self.std = std 296 | 297 | def __call__(self, img): 298 | max_pixel_value = 255.0 299 | 300 | img = img.astype(np.float32) / max_pixel_value 301 | 302 | img -= np.ones(img.shape) * self.mean 303 | img /= np.ones(img.shape) * self.std 304 | return img 305 | 306 | 307 | class Distort1: 308 | """" 309 | ## unconverntional augmnet ################################################################################3 310 | ## https://stackoverflow.com/questions/6199636/formulas-for-barrel-pincushion-distortion 311 | 312 | ## https://stackoverflow.com/questions/10364201/image-transformation-in-opencv 313 | ## https://stackoverflow.com/questions/2477774/correcting-fisheye-distortion-programmatically 314 | ## http://www.coldvision.io/2017/03/02/advanced-lane-finding-using-opencv/ 315 | 316 | ## barrel\pincushion distortion 317 | """ 318 | 319 | def __init__(self, distort_limit=0.35, shift_limit=0.25, prob=0.5): 320 | self.distort_limit = distort_limit 321 | self.shift_limit = shift_limit 322 | self.prob = prob 323 | 324 | def __call__(self, img, mask=None): 325 | if random.random() < self.prob: 326 | height, width, channel = img.shape 327 | 328 | if 0: 329 | img = img.copy() 330 | for x in range(0, width, 10): 331 | cv2.line(img, (x, 0), (x, height), (1, 1, 1), 1) 332 | for y in range(0, height, 10): 333 | cv2.line(img, (0, y), (width, y), (1, 1, 1), 1) 334 | 335 | k = random.uniform(-self.distort_limit, self.distort_limit) * 0.00001 336 | dx = random.uniform(-self.shift_limit, self.shift_limit) * width 337 | dy = random.uniform(-self.shift_limit, self.shift_limit) * height 338 | 339 | # map_x, map_y = 340 | # cv2.initUndistortRectifyMap(intrinsics, dist_coeffs, None, None, (width,height),cv2.CV_32FC1) 341 | # https://stackoverflow.com/questions/6199636/formulas-for-barrel-pincushion-distortion 342 | # https://stackoverflow.com/questions/10364201/image-transformation-in-opencv 343 | x, y = np.mgrid[0:width:1, 0:height:1] 344 | x = x.astype(np.float32) - width / 2 - dx 345 | y = y.astype(np.float32) - height / 2 - dy 346 | theta = np.arctan2(y, x) 347 | d = (x * x + y * y) ** 0.5 348 | r = d * (1 + k * d * d) 349 | map_x = r * np.cos(theta) + width / 2 + dx 350 | map_y = r * np.sin(theta) + height / 2 + dy 351 | 352 | img = cv2.remap(img, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101) 353 | if mask is not None: 354 | mask = cv2.remap(mask, map_x, map_y, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REFLECT_101) 355 | return img, mask 356 | 357 | 358 | class Distort2: 359 | """ 360 | #http://pythology.blogspot.sg/2014/03/interpolation-on-regular-distorted-grid.html 361 | ## grid distortion 362 | """ 363 | 364 | def __init__(self, num_steps=10, distort_limit=0.2, prob=0.5): 365 | self.num_steps = num_steps 366 | self.distort_limit = distort_limit 367 | self.prob = prob 368 | 369 | def __call__(self, img, mask=None): 370 | if random.random() < self.prob: 371 | height, width, channel = img.shape 372 | 373 | x_step = width // self.num_steps 374 | xx = np.zeros(width, np.float32) 375 | prev = 0 376 | for x in range(0, width, x_step): 377 | start = x 378 | end = x + x_step 379 | if end > width: 380 | end = width 381 | cur = width 382 | else: 383 | cur = prev + 
384 | 
385 |                 xx[start:end] = np.linspace(prev, cur, end - start)
386 |                 prev = cur
387 | 
388 |             y_step = height // self.num_steps
389 |             yy = np.zeros(height, np.float32)
390 |             prev = 0
391 |             for y in range(0, height, y_step):
392 |                 start = y
393 |                 end = y + y_step
394 |                 if end > height:
395 |                     end = height
396 |                     cur = height
397 |                 else:
398 |                     cur = prev + y_step * (1 + random.uniform(-self.distort_limit, self.distort_limit))
399 | 
400 |                 yy[start:end] = np.linspace(prev, cur, end - start)
401 |                 prev = cur
402 | 
403 |             map_x, map_y = np.meshgrid(xx, yy)
404 |             map_x = map_x.astype(np.float32)
405 |             map_y = map_y.astype(np.float32)
406 |             img = cv2.remap(img, map_x, map_y,
407 |                             interpolation=cv2.INTER_LINEAR,
408 |                             borderMode=cv2.BORDER_REFLECT_101)
409 |             if mask is not None:
410 |                 mask = cv2.remap(mask, map_x, map_y,
411 |                                  interpolation=cv2.INTER_LINEAR,
412 |                                  borderMode=cv2.BORDER_REFLECT_101)
413 | 
414 |         return img, mask
415 | 
416 | 
417 | def clip(img, dtype, maxval):
418 |     return np.clip(img, 0, maxval).astype(dtype)
419 | 
420 | 
421 | class RandomFilter:
422 |     """
423 |     Randomly blends the image with a box-filtered (blurred) copy of itself.
424 |     """
425 | 
426 |     def __init__(self, limit=.5, prob=.5):
427 |         self.limit = limit
428 |         self.prob = prob
429 | 
430 |     def __call__(self, img):
431 |         if random.random() < self.prob:
432 |             alpha = self.limit * random.uniform(0, 1)
433 |             kernel = np.ones((3, 3), np.float32) / 9 * 0.2
434 | 
435 |             colored = img[..., :3]
436 |             colored = alpha * cv2.filter2D(colored, -1, kernel) + (1 - alpha) * colored
437 |             maxval = np.max(img[..., :3])
438 |             dtype = img.dtype
439 |             img[..., :3] = clip(colored, dtype, maxval)
440 | 
441 |         return img
442 | 
443 | 
444 | # https://github.com/pytorch/vision/pull/27/commits/659c854c6971ecc5b94dca3f4459ef2b7e42fb70
445 | # color augmentation
446 | 
447 | # brightness, contrast, saturation
448 | # from the mxnet code, see: https://github.com/dmlc/mxnet/blob/master/python/mxnet/image.py
449 | 
450 | class RandomBrightness:
451 |     def __init__(self, limit=0.1, prob=0.5):
452 |         self.limit = limit
453 |         self.prob = prob
454 | 
455 |     def __call__(self, img):
456 |         if random.random() < self.prob:
457 |             alpha = 1.0 + self.limit * random.uniform(-1, 1)
458 | 
459 |             maxval = np.max(img[..., :3])
460 |             dtype = img.dtype
461 |             img[..., :3] = clip(alpha * img[..., :3], dtype, maxval)
462 |         return img
463 | 
464 | 
465 | class RandomContrast:
466 |     def __init__(self, limit=.1, prob=.5):
467 |         self.limit = limit
468 |         self.prob = prob
469 | 
470 |     def __call__(self, img):
471 |         if random.random() < self.prob:
472 |             alpha = 1.0 + self.limit * random.uniform(-1, 1)
473 | 
474 |             gray = cv2.cvtColor(img[:, :, :3], cv2.COLOR_BGR2GRAY)
475 |             gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray)
476 |             maxval = np.max(img[..., :3])
477 |             dtype = img.dtype
478 |             img[:, :, :3] = clip(alpha * img[:, :, :3] + gray, dtype, maxval)
479 |         return img
480 | 
481 | 
482 | class RandomSaturation:
483 |     def __init__(self, limit=0.3, prob=0.5):
484 |         self.limit = limit
485 |         self.prob = prob
486 | 
487 |     def __call__(self, img):
488 |         # NOTE: known issue - this transform does not work as intended
489 |         if random.random() < self.prob:
490 |             maxval = np.max(img[..., :3])
491 |             dtype = img.dtype
492 |             alpha = 1.0 + random.uniform(-self.limit, self.limit)
493 |             gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
494 |             gray = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
495 |             img[..., :3] = alpha * img[..., :3] + (1.0 - alpha) * gray
496 |             img[..., :3] = clip(img[..., :3], dtype, maxval)
497 |         return img
498 | 
499 | 
500 | class RandomHueSaturationValue:
501 |     def __init__(self, hue_shift_limit=(-20, 20), sat_shift_limit=(-35, 35), val_shift_limit=(-35, 35), prob=0.5):
502 |         self.hue_shift_limit = hue_shift_limit
503 |         self.sat_shift_limit = sat_shift_limit
504 |         self.val_shift_limit = val_shift_limit
505 |         self.prob = prob
506 | 
507 |     def __call__(self, image):
508 |         if random.random() < self.prob:
509 |             image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
510 |             h, s, v = cv2.split(image)
511 |             hue_shift = np.random.uniform(self.hue_shift_limit[0], self.hue_shift_limit[1])
512 |             h = cv2.add(h, hue_shift)
513 |             sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1])
514 |             s = cv2.add(s, sat_shift)
515 |             val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1])
516 |             v = cv2.add(v, val_shift)
517 |             image = cv2.merge((h, s, v))
518 |             image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
519 |         return image
520 | 
521 | 
522 | class CLAHE:
523 |     def __init__(self, clipLimit=2.0, tileGridSize=(8, 8)):
524 |         self.clipLimit = clipLimit
525 |         self.tileGridSize = tileGridSize
526 | 
527 |     def __call__(self, im):
528 |         img_yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV)
529 |         clahe = cv2.createCLAHE(clipLimit=self.clipLimit, tileGridSize=self.tileGridSize)
530 |         img_yuv[:, :, 0] = clahe.apply(img_yuv[:, :, 0])
531 |         img_output = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2BGR)
532 |         return img_output
533 | 
534 | 
535 | def augment(x, mask=None, prob=0.5):
536 |     return DualCompose([
537 |         OneOrOther(
538 |             *(OneOf([
539 |                 Distort1(distort_limit=0.05, shift_limit=0.05),
540 |                 Distort2(num_steps=2, distort_limit=0.05)]),
541 |               ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.10, rotate_limit=45)), prob=prob),
542 |         RandomFlip(prob=0.5),
543 |         Transpose(prob=0.5),
544 |         ImageOnly(RandomContrast(limit=0.2, prob=0.5)),
545 |         ImageOnly(RandomFilter(limit=0.5, prob=0.2)),
546 |     ])(x, mask)
547 | 
--------------------------------------------------------------------------------
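Editor's note: a minimal usage sketch for the transforms defined above. The file paths, the 256-pixel crop size, and loading with cv2.imread are illustrative assumptions, not values taken from this repository; augment, CenterCrop, and Normalize are the definitions from transforms.py.

import cv2

from transforms import CenterCrop, Normalize, augment

# Hypothetical sample paths - adjust to the actual dataset layout.
img = cv2.imread('data/train/images/sample.jpg')   # uint8 BGR image, HxWx3
mask = cv2.imread('data/train/masks/sample.png')   # mask with the same spatial size

# Joint augmentation: geometric transforms are applied to image and mask
# together, photometric transforms (wrapped in ImageOnly) touch the image only.
img, mask = augment(img, mask, prob=0.5)

# Fixed-size crop, then scale intensities to ImageNet statistics.
img, mask = CenterCrop(256)(img, mask)
img = Normalize()(img)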