├── requirements.txt
├── train.sh
├── prepare_train_val.py
├── loss.py
├── LICENSE
├── validation.py
├── .gitignore
├── dataset.py
├── evaluate.py
├── utils.py
├── train.py
├── generate_masks.py
├── README.rst
├── models.py
└── transforms.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==0.3.1.post2
2 | numpy==1.14.0
3 | opencv-python==3.3.0.10
4 | tqdm==4.19.4
5 | torchvision==0.1.9
6 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_DEVICE_ORDER=PCI_BUS_ID
4 | export CUDA_VISIBLE_DEVICES=0,1,2,3
5 |
6 | for i in 0 1 2 3 4
7 | do
8 | python train.py --device-ids 0,1,2,3 --limit 10000 --batch-size 16 --n-epochs 10 --fold $i --model UNet16
9 | python train.py --device-ids 0,1,2,3 --limit 10000 --batch-size 16 --n-epochs 15 --fold $i --lr 0.00001 --model UNet16
10 | done
--------------------------------------------------------------------------------
/prepare_train_val.py:
--------------------------------------------------------------------------------
1 | from dataset import data_path
2 | from sklearn.model_selection import KFold
3 | import numpy as np
4 |
5 |
6 | def get_split(fold, num_splits=5):
7 | train_path = data_path / 'train' / 'angyodysplasia' / 'images'
8 |
9 | train_file_names = np.array(sorted(list(train_path.glob('*'))))
10 |
11 | kf = KFold(n_splits=num_splits, random_state=2018)
12 |
13 | ids = list(kf.split(train_file_names))
14 |
15 |     if fold == -1:
16 |         return train_file_names, train_file_names
17 |
18 |     train_ids, val_ids = ids[fold]
19 |
20 |     return train_file_names[train_ids], train_file_names[val_ids]
21 |
22 |
23 | if __name__ == '__main__':
24 | ids = get_split(0)
25 |
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn import functional as F
3 |
4 |
5 | def soft_jaccard(outputs, targets):
6 | eps = 1e-15
7 | jaccard_target = (targets == 1).float()
8 | jaccard_output = F.sigmoid(outputs)
9 |
10 | intersection = (jaccard_output * jaccard_target).sum()
11 | union = jaccard_output.sum() + jaccard_target.sum()
12 | return intersection / (union - intersection + eps)
13 |
14 |
15 | class LossBinary:
16 | """
17 |     Loss defined as (1 - jaccard_weight) * BCE + jaccard_weight * (1 - soft_jaccard)
18 |
19 | Vladimir Iglovikov, Sergey Mushinskiy, Vladimir Osin,
20 | Satellite Imagery Feature Detection using Deep Convolutional Neural Network: A Kaggle Competition
21 | arXiv:1706.06169
22 | """
23 |
24 | def __init__(self, jaccard_weight=0):
25 | self.nll_loss = nn.BCEWithLogitsLoss()
26 | self.jaccard_weight = jaccard_weight
27 |
28 | def __call__(self, outputs, targets):
29 | loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets)
30 |
31 | if self.jaccard_weight:
32 | loss += self.jaccard_weight * (1 - soft_jaccard(outputs, targets))
33 | return loss
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Vladimir Iglovikov, Alexey Shvets, Alexandr A. Kalinin, Alexander Rakhlin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/validation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import utils
3 | from torch import nn
4 |
5 |
6 | def validation_binary(model: nn.Module, criterion, valid_loader):
7 | model.eval()
8 | losses = []
9 |
10 | jaccard = []
11 | dice = []
12 |
13 | for inputs, targets in valid_loader:
14 | inputs = utils.variable(inputs, volatile=True)
15 | targets = utils.variable(targets)
16 | outputs = model(inputs)
17 | loss = criterion(outputs, targets)
18 | losses.append(loss.data[0])
19 | outputs = (outputs > 0).float()
20 | jaccard += get_jaccard(targets, outputs)
21 | dice += get_dice(targets, outputs)
22 |
23 | valid_loss = np.mean(losses) # type: float
24 |
25 | valid_jaccard = np.mean(jaccard).astype(np.float64)
26 | valid_dice = np.mean(dice).astype(np.float64)
27 |
28 | print('Valid loss: {:.5f}, jaccard: {:.5f}, dice: {:.5f}'.format(valid_loss, valid_jaccard, valid_dice))
29 | metrics = {'valid_loss': valid_loss, 'jaccard_loss': valid_jaccard, 'dice_loss': valid_dice}
30 | return metrics
31 |
32 |
33 | def get_jaccard(y_true, y_pred):
34 | epsilon = 1e-15
35 | intersection = (y_pred * y_true).sum(dim=-2).sum(dim=-1)
36 | union = y_true.sum(dim=-2).sum(dim=-1) + y_pred.sum(dim=-2).sum(dim=-1)
37 |
38 | return list((intersection / (union + epsilon - intersection)).data.cpu().numpy())
39 |
40 |
41 | def get_dice(y_true, y_pred):
42 | epsilon = 1e-15
43 | intersection = (y_pred * y_true).sum(dim=-2).sum(dim=-1)
44 | union = y_true.sum(dim=-2).sum(dim=-1) + y_pred.sum(dim=-2).sum(dim=-1)
45 | return list((2 * intersection / (union + epsilon)).data.cpu().numpy())
46 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 | from torch.utils.data import Dataset
5 | from pathlib import Path
6 |
7 | data_path = Path('data')
8 |
9 |
10 | class AngyodysplasiaDataset(Dataset):
11 | def __init__(self, img_paths: list, to_augment=False, transform=None, mode='train', limit=None):
12 | self.img_paths = img_paths
13 | self.to_augment = to_augment
14 | self.transform = transform
15 | self.mode = mode
16 | self.limit = limit
17 |
18 | def __len__(self):
19 | if self.limit is None:
20 | return len(self.img_paths)
21 | else:
22 | return self.limit
23 |
24 | def __getitem__(self, idx):
25 | if self.limit is None:
26 | img_file_name = self.img_paths[idx]
27 | else:
28 | img_file_name = np.random.choice(self.img_paths)
29 |
30 | img = load_image(img_file_name)
31 |
32 | if self.mode == 'train':
33 | mask = load_mask(img_file_name)
34 |
35 | img, mask = self.transform(img, mask)
36 |
37 | return to_float_tensor(img), torch.from_numpy(np.expand_dims(mask, 0)).float()
38 | else:
39 | mask = np.zeros(img.shape[:2])
40 | img, mask = self.transform(img, mask)
41 |
42 | return to_float_tensor(img), str(img_file_name)
43 |
44 |
45 | def to_float_tensor(img):
46 | return torch.from_numpy(np.moveaxis(img, -1, 0)).float()
47 |
48 |
49 | def load_image(path):
50 | img = cv2.imread(str(path))
51 | return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
52 |
53 |
54 | def load_mask(path):
55 | mask = cv2.imread(str(path).replace('images', 'masks').replace(r'.jpg', r'_a.jpg'), 0)
56 | return (mask > 0).astype(np.uint8)
57 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import argparse
3 | import cv2
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 |
8 | def general_dice(y_true, y_pred):
9 | if y_true.sum() == 0:
10 | if y_pred.sum() == 0:
11 | return 1
12 | else:
13 | return 0
14 |
15 | return dice(y_true == 1, y_pred == 1)
16 |
17 |
18 | def general_jaccard(y_true, y_pred):
19 | if y_true.sum() == 0:
20 | if y_pred.sum() == 0:
21 | return 1
22 | else:
23 | return 0
24 |
25 | return jaccard(y_true == 1, y_pred == 1)
26 |
27 |
28 | def jaccard(y_true, y_pred):
29 | intersection = (y_true * y_pred).sum()
30 | union = y_true.sum() + y_pred.sum() - intersection
31 | return (intersection + 1e-15) / (union + 1e-15)
32 |
33 |
34 | def dice(y_true, y_pred):
35 | return (2 * (y_true * y_pred).sum() + 1e-15) / (y_true.sum() + y_pred.sum() + 1e-15)
36 |
37 |
38 | if __name__ == '__main__':
39 | parser = argparse.ArgumentParser()
40 | arg = parser.add_argument
41 |
42 | arg('--train_path', type=str, default='data/train/angyodysplasia/masks', help='path where train images with ground truth are located')
43 | arg('--target_path', type=str, default='predictions/UNet', help='path with predictions')
44 | args = parser.parse_args()
45 |
46 | result_dice = []
47 | result_jaccard = []
48 |
49 | for file_name in tqdm(list(Path(args.train_path).glob('*'))):
50 | y_true = (cv2.imread(str(file_name), 0) > 255 * 0.5).astype(np.uint8)
51 |
52 | pred_file_name = Path(args.target_path) / (file_name.stem.replace('_a', '') + '.png')
53 |
54 | y_pred = (cv2.imread(str(pred_file_name), 0) > 255 * 0.5).astype(np.uint8)
55 |
56 | result_dice += [dice(y_true, y_pred)]
57 | result_jaccard += [jaccard(y_true, y_pred)]
58 |
59 | print('Dice = ', np.mean(result_dice), np.std(result_dice))
60 | print('Jaccard = ', np.mean(result_jaccard), np.std(result_jaccard))
61 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from datetime import datetime
3 | from pathlib import Path
4 |
5 | import random
6 | import numpy as np
7 |
8 | import torch
9 | from torch.autograd import Variable
10 | import tqdm
11 |
12 |
13 | def variable(x, volatile=False):
14 | if isinstance(x, (list, tuple)):
15 | return [variable(y, volatile=volatile) for y in x]
16 | return cuda(Variable(x, volatile=volatile))
17 |
18 |
19 | def cuda(x):
20 | return x.cuda(async=True) if torch.cuda.is_available() else x
21 |
22 |
23 | def write_event(log, step: int, **data):
24 | data['step'] = step
25 | data['dt'] = datetime.now().isoformat()
26 | log.write(json.dumps(data, sort_keys=True))
27 | log.write('\n')
28 | log.flush()
29 |
30 |
31 | def train(args, model, criterion, train_loader, valid_loader, validation, init_optimizer, n_epochs=None, fold=None):
32 | lr = args.lr
33 | n_epochs = n_epochs or args.n_epochs
34 | optimizer = init_optimizer(lr)
35 |
36 | root = Path(args.root)
37 | model_path = root / 'model_{fold}.pt'.format(fold=fold)
38 | if model_path.exists():
39 | state = torch.load(str(model_path))
40 | epoch = state['epoch']
41 | step = state['step']
42 | model.load_state_dict(state['model'])
43 | print('Restored model, epoch {}, step {:,}'.format(epoch, step))
44 | else:
45 | epoch = 1
46 | step = 0
47 |
48 | save = lambda ep: torch.save({
49 | 'model': model.state_dict(),
50 | 'epoch': ep,
51 | 'step': step,
52 | }, str(model_path))
53 |
54 | report_each = 10
55 | log = root.joinpath('train_{fold}.log'.format(fold=fold)).open('at', encoding='utf8')
56 | valid_losses = []
57 | for epoch in range(epoch, n_epochs + 1):
58 | model.train()
59 | random.seed()
60 | tq = tqdm.tqdm(total=(len(train_loader) * args.batch_size))
61 | tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
62 | losses = []
63 | tl = train_loader
64 | try:
65 | mean_loss = 0
66 | for i, (inputs, targets) in enumerate(tl):
67 | inputs, targets = variable(inputs), variable(targets)
68 | outputs = model(inputs)
69 | loss = criterion(outputs, targets)
70 | optimizer.zero_grad()
71 | batch_size = inputs.size(0)
72 | loss.backward()
73 | optimizer.step()
74 | step += 1
75 | tq.update(batch_size)
76 | losses.append(loss.data[0])
77 | mean_loss = np.mean(losses[-report_each:])
78 | tq.set_postfix(loss='{:.5f}'.format(mean_loss))
79 | if i and i % report_each == 0:
80 | write_event(log, step, loss=mean_loss)
81 | write_event(log, step, loss=mean_loss)
82 | tq.close()
83 | save(epoch + 1)
84 | valid_metrics = validation(model, criterion, valid_loader)
85 | write_event(log, step, **valid_metrics)
86 | valid_loss = valid_metrics['valid_loss']
87 | valid_losses.append(valid_loss)
88 | except KeyboardInterrupt:
89 | tq.close()
90 | print('Ctrl+C, saving snapshot')
91 | save(epoch)
92 | print('done.')
93 | return
94 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from pathlib import Path
4 | from validation import validation_binary
5 | import torch
6 | from torch import nn
7 | from torch.optim import Adam
8 | from torch.utils.data import DataLoader
9 | import torch.backends.cudnn as cudnn
11 |
12 | from models import UNet, UNet11, UNet16, LinkNet34, AlbuNet34
13 | from loss import LossBinary
14 | from dataset import AngyodysplasiaDataset
15 | import utils
16 |
17 | from prepare_train_val import get_split
18 |
19 | from transforms import (DualCompose,
20 | ImageOnly,
21 | Normalize,
22 | HorizontalFlip,
23 | Rotate,
24 | CenterCrop,
25 | RandomHueSaturationValue,
26 | VerticalFlip)
27 |
28 |
29 | def main():
30 | parser = argparse.ArgumentParser()
31 | arg = parser.add_argument
32 | arg('--jaccard-weight', default=0.3, type=float)
33 | arg('--device-ids', type=str, default='0', help='For example 0,1 to run on two GPUs')
34 | arg('--fold', type=int, help='fold', default=0)
35 | arg('--root', default='runs/debug', help='checkpoint root')
36 | arg('--batch-size', type=int, default=1)
37 | arg('--limit', type=int, default=10000, help='number of images in epoch')
38 | arg('--n-epochs', type=int, default=100)
39 | arg('--lr', type=float, default=0.0001)
40 | arg('--workers', type=int, default=12)
41 | arg('--model', type=str, default='UNet', choices=['UNet', 'UNet11', 'LinkNet34', 'UNet16', 'AlbuNet34'])
42 |
43 | args = parser.parse_args()
44 |
45 | root = Path(args.root)
46 | root.mkdir(exist_ok=True, parents=True)
47 |
48 | num_classes = 1
49 | if args.model == 'UNet':
50 | model = UNet(num_classes=num_classes)
51 | elif args.model == 'UNet11':
52 | model = UNet11(num_classes=num_classes, pretrained=True)
53 | elif args.model == 'UNet16':
54 | model = UNet16(num_classes=num_classes, pretrained=True)
55 | elif args.model == 'LinkNet34':
56 | model = LinkNet34(num_classes=num_classes, pretrained=True)
57 |     elif args.model == 'AlbuNet34':
58 | model = AlbuNet34(num_classes=num_classes, pretrained=True)
59 | else:
60 | model = UNet(num_classes=num_classes, input_channels=3)
61 |
62 | if torch.cuda.is_available():
63 | if args.device_ids:
64 | device_ids = list(map(int, args.device_ids.split(',')))
65 | else:
66 | device_ids = None
67 | model = nn.DataParallel(model, device_ids=device_ids).cuda()
68 |
69 | loss = LossBinary(jaccard_weight=args.jaccard_weight)
70 |
71 | cudnn.benchmark = True
72 |
73 | def make_loader(file_names, shuffle=False, transform=None, limit=None):
74 | return DataLoader(
75 | dataset=AngyodysplasiaDataset(file_names, transform=transform, limit=limit),
76 | shuffle=shuffle,
77 | num_workers=args.workers,
78 | batch_size=args.batch_size,
79 | pin_memory=torch.cuda.is_available()
80 | )
81 |
82 | train_file_names, val_file_names = get_split(args.fold)
83 |
84 | print('num train = {}, num_val = {}'.format(len(train_file_names), len(val_file_names)))
85 |
86 | train_transform = DualCompose([
87 | CenterCrop(512),
88 | HorizontalFlip(),
89 | VerticalFlip(),
90 | Rotate(),
91 | ImageOnly(RandomHueSaturationValue()),
92 | ImageOnly(Normalize())
93 | ])
94 |
95 | val_transform = DualCompose([
96 | CenterCrop(512),
97 | ImageOnly(Normalize())
98 | ])
99 |
100 | train_loader = make_loader(train_file_names, shuffle=True, transform=train_transform, limit=args.limit)
101 | valid_loader = make_loader(val_file_names, transform=val_transform)
102 |
103 | root.joinpath('params.json').write_text(
104 | json.dumps(vars(args), indent=True, sort_keys=True))
105 |
106 | utils.train(
107 | init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
108 | args=args,
109 | model=model,
110 | criterion=loss,
111 | train_loader=train_loader,
112 | valid_loader=valid_loader,
113 | validation=validation_binary,
114 | fold=args.fold
115 | )
116 |
117 |
118 | if __name__ == '__main__':
119 | main()
120 |
--------------------------------------------------------------------------------
/generate_masks.py:
--------------------------------------------------------------------------------
1 | """
2 | Script generates predictions, splitting original images into tiles, and assembling prediction back together
3 | """
4 | import argparse
5 | from prepare_train_val import get_split
6 | from dataset import AngyodysplasiaDataset
7 | import cv2
8 | from models import UNet, UNet11, UNet16, AlbuNet34
9 | import torch
10 | from pathlib import Path
11 | from tqdm import tqdm
12 | import numpy as np
13 | import utils
14 | # import prepare_data
15 | from torch.utils.data import DataLoader
16 | from torch.nn import functional as F
17 |
18 | from transforms import (ImageOnly,
19 | Normalize,
20 | CenterCrop,
21 | DualCompose)
22 |
23 | img_transform = DualCompose([
24 | CenterCrop(512),
25 | ImageOnly(Normalize())
26 | ])
27 |
28 |
29 | def get_model(model_path, model_type):
30 | """
31 |
32 | :param model_path:
33 | :param model_type: 'UNet', 'UNet11', 'UNet16', 'AlbuNet34'
34 | :return:
35 | """
36 |
37 | num_classes = 1
38 |
39 | if model_type == 'UNet11':
40 | model = UNet11(num_classes=num_classes)
41 | elif model_type == 'UNet16':
42 | model = UNet16(num_classes=num_classes)
43 | elif model_type == 'AlbuNet34':
44 | model = AlbuNet34(num_classes=num_classes)
45 | elif model_type == 'UNet':
46 | model = UNet(num_classes=num_classes)
47 | else:
48 | model = UNet(num_classes=num_classes)
49 |
50 | state = torch.load(str(model_path))
51 | state = {key.replace('module.', ''): value for key, value in state['model'].items()}
52 | model.load_state_dict(state)
53 |
54 |     model.eval()
55 |
56 |     if torch.cuda.is_available():
57 |         return model.cuda()
58 |
59 |     return model
60 |
61 |
62 | def predict(model, from_file_names, batch_size: int, to_path):
63 | loader = DataLoader(
64 | dataset=AngyodysplasiaDataset(from_file_names, transform=img_transform, mode='predict'),
65 | shuffle=False,
66 | batch_size=batch_size,
67 | num_workers=args.workers,
68 | pin_memory=torch.cuda.is_available()
69 | )
70 |
71 | for batch_num, (inputs, paths) in enumerate(tqdm(loader, desc='Predict')):
72 | inputs = utils.variable(inputs, volatile=True)
73 |
74 | outputs = model(inputs)
75 |
76 | for i, image_name in enumerate(paths):
77 | mask = (F.sigmoid(outputs[i, 0]).data.cpu().numpy() * 255).astype(np.uint8)
78 |
79 | h, w = mask.shape
80 |
81 | full_mask = np.zeros((576, 576))
82 | full_mask[32:32 + h, 32:32 + w] = mask
83 |
84 | (to_path / args.model_type).mkdir(exist_ok=True, parents=True)
85 |
86 | cv2.imwrite(str(to_path / args.model_type / (Path(paths[i]).stem + '.png')), full_mask)
87 |
88 |
89 | if __name__ == '__main__':
90 | parser = argparse.ArgumentParser()
91 | arg = parser.add_argument
92 | arg('--model_path', type=str, default='data/models/UNet', help='path to model folder')
93 | arg('--model_type', type=str, default='UNet', help='network architecture',
94 | choices=['UNet', 'UNet11', 'UNet16', 'AlbuNet34'])
95 | arg('--batch-size', type=int, default=4)
96 | arg('--fold', type=int, default=-1, choices=[0, 1, 2, 3, 4, -1], help='-1: all folds')
97 | arg('--workers', type=int, default=12)
98 |
99 | args = parser.parse_args()
100 |
101 | if args.fold == -1:
102 | for fold in [0, 1, 2, 3, 4]:
103 | _, file_names = get_split(fold)
104 |
105 | print(len(file_names))
106 |
107 | model = get_model(str(Path(args.model_path).joinpath('model_{fold}.pt'.format(fold=fold))),
108 | model_type=args.model_type)
109 |
110 | print('num file_names = {}'.format(len(file_names)))
111 |
112 | output_path = Path(args.model_path)
113 | output_path.mkdir(exist_ok=True, parents=True)
114 |
115 | predict(model, file_names, args.batch_size, output_path)
116 | else:
117 | _, file_names = get_split(args.fold)
118 | model = get_model(str(Path(args.model_path).joinpath('model_{fold}.pt'.format(fold=args.fold))),
119 | model_type=args.model_type)
120 |
121 | print('num file_names = {}'.format(len(file_names)))
122 |
123 | output_path = Path(args.model_path)
124 | output_path.mkdir(exist_ok=True, parents=True)
125 |
126 | predict(model, file_names, args.batch_size, output_path)
127 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | =================================================================================
2 | MICCAI 2017 Endoscopic Vision Challenge Angiodysplasia Detection and Localization
3 | =================================================================================
4 |
5 | Here we present our winning solution and its further development for the `MICCAI 2017 Endoscopic Vision Challenge Angiodysplasia Detection and Localization`_. It addresses a binary segmentation problem, where every pixel in an image is labeled as either an angiodysplasia lesion or background. We then analyze the connected components of each predicted mask. Based on this analysis, we developed a classifier that predicts the presence of angiodysplasia lesions (a binary variable) and a detector for their localization (the center of a component).
6 |
7 | .. contents::
8 |
9 | Team members
10 | ------------
11 | `Alexey Shvets`_, `Vladimir Iglovikov`_, `Alexander Rakhlin`_, `Alexandr A. Kalinin`_
12 |
13 | Citation
14 | ----------
15 |
16 | If you find this work useful for your publications, please consider citing::
17 |
18 | @inproceedings{shvets2018angiodysplasia,
19 | title={Angiodysplasia Detection and Localization using Deep Convolutional Neural Networks},
20 | author={Shvets, Alexey A and Iglovikov, Vladimir I and Rakhlin, Alexander and Kalinin, Alexandr A},
21 | booktitle={2018 17th IEEE International Conference on Machine Learning and Applications (ICMLA)},
22 | pages={612--617},
23 | year={2018}
24 | }
25 |
26 | Overview
27 | --------
28 | Angiodysplasias are degenerative lesions of previously healthy blood vessels, in which the bowel wall has microvascular abnormalities. These lesions are the most common source of small bowel bleeding in patients older than 50 years, and cause approximately 8% of all gastrointestinal bleeding episodes. The gold-standard examination for angiodysplasia detection and localization in the small bowel is performed using Wireless Capsule Endoscopy (WCE). The latest generation of this pill-like device is able to acquire more than 60 000 images with a resolution of approximately 520 |times| 520 pixels. According to the latest state of the art, only 69% of angiodysplasias are detected by gastroenterologist experts during the reading of WCE videos, and blood indicator software (provided by WCE vendors such as Given Imaging) shows sensitivity and specificity values of only 41% and 67%, respectively, in the presence of angiodysplasias.
29 |
30 | .. figure:: https://habrastorage.org/webt/if/5p/tj/if5ptjnbzeswfgqderpcww0sstm.jpeg
31 |
32 | Data
33 | ----
34 | The dataset consists of 1200 color images obtained with WCE. The images are in 24-bit PNG format, with 576 |times| 576 pixel resolution. The dataset is split into two equal parts, 600 images for training and 600 for evaluation. Each subset is composed of 300 images with apparent AD and 300 without any pathology. The training subset is annotated by a human expert and contains 300 binary masks in JPEG format of the same 576 |times| 576 pixel resolution. White pixels in the masks correspond to lesion localization.
35 |
36 | .. figure:: https://hsto.org/webt/nq/3v/wf/nq3vwfqtoutrzmnbzmrnyligwym.png
37 | :scale: 30 %
38 |
39 | First row corresponds to images without pathology, the second row to images with several AD lesions in every image, and the last row contains masks that correspond to the pathology images from the second row.
40 |
41 | |
42 | |
43 | |
44 |
45 | .. figure:: https://habrastorage.org/webt/t3/p6/yy/t3p6yykecrvr9mim7fqgevodgu4.png
46 | :scale: 45 %
47 |
48 | Most images contain one lesion. The distribution of AD lesion areas reaches a maximum of 12,000 pixels and has a median of 1,648 pixels.
49 |
50 |
51 | Method
52 | ------
53 | We evaluate 4 different deep architectures for segmentation: `U-Net`_ (Ronneberger et al., 2015; Iglovikov et al., 2017a), 2 modifications of `TernausNet`_ (Iglovikov and Shvets, 2018), and `AlbuNet34`_, a modification of `LinkNet`_ (Chaurasia and Culurciello, 2017; Shvets et al., 2018). As an improvement over the standard `U-Net`_, we use similar networks with pre-trained encoders. `TernausNet`_ (Iglovikov and Shvets, 2018) is a U-Net-like architecture that uses a relatively simple pre-trained VGG11 or VGG16 (Simonyan and Zisserman, 2014) network as an encoder. VGG11 consists of seven convolutional layers, each followed by a ReLU activation function, and five max pooling operations, each reducing the feature map by 2. All convolutional layers have 3 |times| 3 kernels. TernausNet16 has a similar structure and uses the VGG16 network as an encoder.
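
A minimal sketch of how these architectures can be instantiated from ``models.py`` in this repository (constructor arguments as defined there)::

    from models import UNet16, AlbuNet34

    # TernausNet16: U-Net-like network with a VGG16 encoder pre-trained on ImageNet
    model = UNet16(num_classes=1, pretrained=True)

    # AlbuNet34: U-Net-like network with a ResNet-34 encoder
    # model = AlbuNet34(num_classes=1, pretrained=True)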
54 |
55 | .. figure:: https://hsto.org/webt/vz/ok/wt/vzokwtntgqe6lb-g2oyhzj0qcyo.png
56 | :scale: 72 %
57 |
58 | .. figure:: https://hsto.org/webt/vs/by/8y/vsby8yt4bj_6n3pqdqlf2tb8r9a.png
59 | :scale: 72 %
60 |
61 | Training
62 | --------
63 |
64 | We use the Jaccard index (Intersection over Union) as the evaluation metric. It can be interpreted as a similarity measure between a finite number of sets. For two sets A and B, it can be defined as follows:
65 |
66 | .. math::
67 |
68 |     J(A, B) = \frac{|A \cap B|}{|A \cup B|} = \frac{|A \cap B|}{|A| + |B| - |A \cap B|}
69 |
72 | Since an image consists of pixels, the expression can be adapted for discrete objects in the following way:
73 |
74 | .. figure:: https://habrastorage.org/webt/_8/wc/j1/_8wcj1to6ahxfsmb8s3nrxumqjy.gif
75 | :align: center
76 |
77 | where |y| and |y_hat| are a binary value (label) and a predicted probability for the pixel |i|, respectively.
78 |
79 | Since the image segmentation task can also be considered a pixel classification problem, we additionally use a common classification loss function, denoted H. For a binary segmentation problem, H is the binary cross-entropy, while for a multi-class segmentation problem H is the categorical cross-entropy.
80 |
81 | .. figure:: https://habrastorage.org/webt/tf/d0/kn/tfd0kn2l612do_wmlc6zp5rdgdw.gif
82 | :align: center
83 |
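In this repository the combined loss is implemented in ``loss.py``; a condensed sketch of the binary case (BCE as H plus a weighted soft Jaccard term, mirroring ``LossBinary``)::

    from torch import nn
    from torch.nn import functional as F

    def combined_loss(logits, targets, jaccard_weight=0.3):
        # H: binary cross-entropy computed on the raw logits
        bce = nn.BCEWithLogitsLoss()(logits, targets)
        # soft Jaccard computed on sigmoid probabilities
        probs = F.sigmoid(logits)
        intersection = (probs * targets).sum()
        union = probs.sum() + targets.sum()
        soft_jaccard = intersection / (union - intersection + 1e-15)
        return (1 - jaccard_weight) * bce + jaccard_weight * (1 - soft_jaccard)
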
84 | As an output of a model, we obtain an image in which each pixel value corresponds to a probability of belonging to the area of interest or a class. The size of the output image matches the input image size. For binary segmentation, we use 0.3 as the threshold value (chosen using the validation dataset) to binarize pixel probabilities. All pixel values below the specified threshold are set to 0, while all values above the threshold are set to 255 to produce the final prediction mask.
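
A minimal sketch of this binarization step (0.3 is the threshold mentioned above; the helper name is illustrative)::

    import numpy as np

    def binarize(prob_mask, threshold=0.3):
        # prob_mask: H x W array of per-pixel probabilities in [0, 1]
        return np.where(prob_mask > threshold, 255, 0).astype(np.uint8)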
85 |
86 | Following the segmentation step, we perform postprocessing in order to find the coordinates of angiodysplasia lesions in the image. In the postprocessing step we use the OpenCV implementation of the connected component labeling function ``connectedComponentsWithStats``. This function returns the number of connected components, their sizes (areas), and the centroid coordinates of the corresponding connected components. In our detector we use another threshold to neglect all clusters with an area smaller than 300 pixels. Therefore, in order to establish the presence of lesions, the number of found components should be higher than 0; otherwise the image corresponds to a normal condition. Then, for localization of angiodysplasia lesions, we return the centroid coordinates of all connected components.
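
A sketch of this postprocessing step with OpenCV's ``connectedComponentsWithStats`` (the 300-pixel area threshold is the one described above; the helper name is illustrative)::

    import cv2

    def detect_lesions(binary_mask, min_area=300):
        # binary_mask: H x W uint8 mask with values 0 / 255
        n_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask)
        lesions = []
        for label in range(1, n_labels):  # label 0 is the background
            if stats[label, cv2.CC_STAT_AREA] >= min_area:
                lesions.append(tuple(centroids[label]))  # (x, y) centroid of the component
        return lesions  # an empty list corresponds to a normal (lesion-free) image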
87 |
88 | Results
89 | -------
90 |
91 | The quantitative comparison of our models' performance is presented in Table 1. For the segmentation task, the best result is achieved by `AlbuNet34`_, providing IoU = 0.754 and Dice = 0.850. When compared by inference time, `AlbuNet34`_ is also the fastest model due to the light encoder: in the segmentation task this network takes around 20 ms.
92 |
93 | .. figure:: https://hsto.org/webt/mw/yj/-l/mwyj-l6ddk6xz-ykydduixzhrdk.png
94 | :scale: 60 %
95 |
96 | Prediction of our detector on a validation image. The left picture is the original image, the central one is the ground truth mask, and the right one is the predicted mask. Green dots correspond to the centroid coordinates that define the localization of the angiodysplasia.
97 |
98 | |
99 | |
100 | |
101 |
102 | .. table:: Table 1. Segmentation results per task. Intersection over Union, Dice coefficient and inference time, ms.
103 |
104 | ============= ========= ========= ==================
105 | Model IOU, % Dice, % Inference time, ms
106 | ============= ========= ========= ==================
107 | U-Net 73.18 83.06 21
108 | TernausNet-11 74.94 84.43 51
109 | TernausNet-16 73.83 83.05 60
110 | AlbuNet34 75.35 84.98 30
111 | ============= ========= ========= ==================
112 |
113 | Pre-trained weights for all models can be found on `google drive`_.
114 |
115 | Dependencies
116 | ------------
117 |
118 | * Python 3.6
119 | * PyTorch 0.3.1
120 | * TorchVision 0.1.9
121 | * numpy 1.14.0
122 | * opencv-python 3.3.0.10
123 | * tqdm 4.19.4
124 |
125 | These dependencies can be installed by running::
126 |
127 | pip install -r requirements.txt
128 |
129 |
130 | How to run
131 | ----------
132 | The dataset is organized in the following way::
134 |
135 | ├── data
136 | │ ├── test
137 | │ └── train
138 | │ ├── angyodysplasia
139 | │ │ ├── images
140 | │ │ └── masks
141 | │ └── normal
142 | │ ├── images
143 | │ └── masks
144 | │ .......................
145 |
146 | The training dataset contains 2 sets of images, one with angiodysplasia and the second without it. For training we used only the images with angiodysplasia, which were split into 5 folds.
147 |
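The 5-fold split itself comes from ``get_split`` in ``prepare_train_val.py``, for example::

    from prepare_train_val import get_split

    train_file_names, val_file_names = get_split(fold=0)  # folds 0..4; fold=-1 returns all images for both
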
148 | 1. Training
149 |
150 | The main file used to train all models is ``train.py``. Running ``python train.py --help`` will return the set of all possible input parameters.
151 | To train all models we used the following bash script (the batch size was chosen depending on how many samples fit into the GPU RAM; the limit was adjusted accordingly to keep the same number of updates for every network)::
152 |
153 | #!/bin/bash
154 |
155 |     for i in 0 1 2 3 4
156 | do
157 | python train.py --device-ids 0,1,2,3 --limit 10000 --batch-size 12 --fold $i --workers 12 --lr 0.0001 --n-epochs 10 --jaccard-weight 0.3 --model UNet11
158 | python train.py --device-ids 0,1,2,3 --limit 10000 --batch-size 12 --fold $i --workers 12 --lr 0.00001 --n-epochs 15 --jaccard-weight 0.3 --model UNet11
159 | done
160 |
161 | 2. Mask generation.
162 |
163 | The main file to generate masks is ``generate_masks.py``. Running ``python generate_masks.py --help`` will return the set of all possible input parameters. Example::
164 |
165 | python generate_masks.py --output_path predictions/UNet16 --model_type UNet16 --model_path data/models/UNet16 --fold -1 --batch-size 4
166 |
167 | 3. Evaluation.
168 |
169 | The evaluation is different for a binary and multi-class segmentation:
170 |
171 | [a] In the case of binary segmentation, it calculates the Jaccard index (Dice) per image / per video, and then the predictions are averaged.
172 |
173 | [b] In the case of multi-class segmentation, it calculates the Jaccard index (Dice) for every class independently, then averages them for each image and then for every video::
174 |
175 | python evaluate.py --target_path predictions/UNet16 --train_path data/train/angyodysplasia/masks
176 |
177 | 4. Further Improvements.
178 |
179 | Our results can be improved further by a few percent using simple measures such as additional augmentation of the training images and training the models for a longer time. In addition, a cyclic learning rate or cosine annealing could also be applied; to do so, one can use our pre-trained weights as initialization. To improve test predictions, test-time augmentation (TTA) could be used, as well as averaging predictions from all folds (a minimal sketch is given below).
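
As an illustration, a minimal flip-based TTA sketch (not part of this repository; assumes a PyTorch version where ``torch.flip`` is available)::

    import torch
    from torch.nn import functional as F

    def predict_tta(model, inputs):
        # average sigmoid probabilities over the original and the horizontally flipped input
        probs = F.sigmoid(model(inputs))
        flipped = torch.flip(inputs, dims=[3])  # flip along the width axis
        probs_flipped = torch.flip(F.sigmoid(model(flipped)), dims=[3])
        return (probs + probs_flipped) / 2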
180 |
181 | Demo Example
182 | ------------
183 | You can start working with our models using the demonstration example: `Demo.ipynb`_
184 |
185 | .. _`Demo.ipynb`: Demo.ipynb
186 | .. _`Alexander Rakhlin`: https://www.linkedin.com/in/alrakhlin/
187 | .. _`Alexey Shvets`: https://www.linkedin.com/in/shvetsiya/
188 | .. _`Vladimir Iglovikov`: https://www.linkedin.com/in/iglovikov/
189 | .. _`Alexandr A. Kalinin`: https://alxndrkalinin.github.io/
190 | .. _`MICCAI 2017 Endoscopic Vision Challenge Angiodysplasia Detection and Localization`: https://endovissub2017-giana.grand-challenge.org/angiodysplasia-etisdb/
191 | .. _`TernausNet`: https://arxiv.org/abs/1801.05746
192 | .. _`U-Net`: https://arxiv.org/abs/1505.04597
193 | .. _`AlbuNet34`: https://arxiv.org/abs/1803.01207
194 | .. _`LinkNet`: https://arxiv.org/abs/1707.03718
195 | .. _`google drive`: https://drive.google.com/drive/folders/1V_bLBTzsl_Z8Ln9Iq8gjcFDxodfiHxul
196 |
197 | .. |br| raw:: html
198 |
199 |     <br />
200 |
201 | .. |plusmn| raw:: html
202 |
203 | ±
204 |
205 | .. |times| raw:: html
206 |
207 | ×
208 |
209 | .. |micro| raw:: html
210 |
211 | µm
212 |
213 | .. |y| image:: https://hsto.org/webt/jm/sn/i0/jmsni0y8mao8vnaij8a4eyuoqmu.gif
214 | .. |y_hat| image:: https://hsto.org/webt/xf/j2/a4/xfj2a4obgqhdzneysar5_us5pgk.gif
215 | .. |i| image:: https://hsto.org/webt/87/cc/ca/87cccaz4gjp2lgyeip17utljvvi.gif
216 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from torchvision import models
4 | import torchvision
5 | from torch.nn import functional as F
6 |
7 |
8 | def conv3x3(in_, out):
9 | return nn.Conv2d(in_, out, 3, padding=1)
10 |
11 |
12 | class ConvRelu(nn.Module):
13 | def __init__(self, in_: int, out: int):
14 | super().__init__()
15 | self.conv = conv3x3(in_, out)
16 | self.activation = nn.ReLU(inplace=True)
17 |
18 | def forward(self, x):
19 | x = self.conv(x)
20 | x = self.activation(x)
21 | return x
22 |
23 |
24 | class DecoderBlock(nn.Module):
25 | def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
26 | super(DecoderBlock, self).__init__()
27 | self.in_channels = in_channels
28 |
29 | if is_deconv:
30 | self.block = nn.Sequential(
31 | ConvRelu(in_channels, middle_channels),
32 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
33 | padding=1),
34 | nn.ReLU(inplace=True)
35 | )
36 | else:
37 | self.block = nn.Sequential(
38 | nn.Upsample(scale_factor=2, mode='bilinear'),
39 | ConvRelu(in_channels, middle_channels),
40 | ConvRelu(middle_channels, out_channels),
41 | )
42 |
43 | def forward(self, x):
44 | return self.block(x)
45 |
46 |
47 | class UNet11(nn.Module):
48 | def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
49 | """
50 | :param num_classes:
51 | :param num_filters:
52 | :param pretrained:
53 | False - no pre-trained network used
54 | vgg - encoder pre-trained with VGG11
55 | """
56 | super().__init__()
57 | self.pool = nn.MaxPool2d(2, 2)
58 |
59 | self.num_classes = num_classes
60 |
61 | self.encoder = models.vgg11(pretrained=pretrained).features
62 |
63 | self.relu = nn.ReLU(inplace=True)
64 | self.conv1 = nn.Sequential(self.encoder[0],
65 | self.relu)
66 |
67 | self.conv2 = nn.Sequential(self.encoder[3],
68 | self.relu)
69 |
70 | self.conv3 = nn.Sequential(
71 | self.encoder[6],
72 | self.relu,
73 | self.encoder[8],
74 | self.relu,
75 | )
76 | self.conv4 = nn.Sequential(
77 | self.encoder[11],
78 | self.relu,
79 | self.encoder[13],
80 | self.relu,
81 | )
82 |
83 | self.conv5 = nn.Sequential(
84 | self.encoder[16],
85 | self.relu,
86 | self.encoder[18],
87 | self.relu,
88 | )
89 |
90 | self.center = DecoderBlock(256 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
91 | self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
92 | self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 4, is_deconv=is_deconv)
93 | self.dec3 = DecoderBlock(256 + num_filters * 4, num_filters * 4 * 2, num_filters * 2, is_deconv=is_deconv)
94 | self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv=is_deconv)
95 | self.dec1 = ConvRelu(64 + num_filters, num_filters)
96 |
97 | self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
98 |
99 | def forward(self, x):
100 | conv1 = self.conv1(x)
101 | conv2 = self.conv2(self.pool(conv1))
102 | conv3 = self.conv3(self.pool(conv2))
103 | conv4 = self.conv4(self.pool(conv3))
104 | conv5 = self.conv5(self.pool(conv4))
105 | center = self.center(self.pool(conv5))
106 |
107 | dec5 = self.dec5(torch.cat([center, conv5], 1))
108 | dec4 = self.dec4(torch.cat([dec5, conv4], 1))
109 | dec3 = self.dec3(torch.cat([dec4, conv3], 1))
110 | dec2 = self.dec2(torch.cat([dec3, conv2], 1))
111 | dec1 = self.dec1(torch.cat([dec2, conv1], 1))
112 |
113 | if self.num_classes > 1:
114 | x_out = F.log_softmax(self.final(dec1), dim=1)
115 | else:
116 | x_out = self.final(dec1)
117 |
118 | return x_out
119 |
120 |
121 | class UNet16(nn.Module):
122 | def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
123 | """
124 | :param num_classes:
125 | :param num_filters:
126 | :param pretrained: if encoder uses pre-trained weigths from VGG16
127 | """
128 | super().__init__()
129 | self.num_classes = num_classes
130 |
131 | self.pool = nn.MaxPool2d(2, 2)
132 |
133 | self.encoder = torchvision.models.vgg16(pretrained=pretrained).features
134 |
135 | self.relu = nn.ReLU(inplace=True)
136 |
137 | self.conv1 = nn.Sequential(self.encoder[0],
138 | self.relu,
139 | self.encoder[2],
140 | self.relu)
141 |
142 | self.conv2 = nn.Sequential(self.encoder[5],
143 | self.relu,
144 | self.encoder[7],
145 | self.relu)
146 |
147 | self.conv3 = nn.Sequential(self.encoder[10],
148 | self.relu,
149 | self.encoder[12],
150 | self.relu,
151 | self.encoder[14],
152 | self.relu)
153 |
154 | self.conv4 = nn.Sequential(self.encoder[17],
155 | self.relu,
156 | self.encoder[19],
157 | self.relu,
158 | self.encoder[21],
159 | self.relu)
160 |
161 | self.conv5 = nn.Sequential(self.encoder[24],
162 | self.relu,
163 | self.encoder[26],
164 | self.relu,
165 | self.encoder[28],
166 | self.relu)
167 |
168 | self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
169 |
170 | self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
171 | self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=is_deconv)
172 | self.dec3 = DecoderBlock(256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv=is_deconv)
173 | self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv=is_deconv)
174 | self.dec1 = ConvRelu(64 + num_filters, num_filters)
175 | self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
176 |
177 | def forward(self, x):
178 | conv1 = self.conv1(x)
179 | conv2 = self.conv2(self.pool(conv1))
180 | conv3 = self.conv3(self.pool(conv2))
181 | conv4 = self.conv4(self.pool(conv3))
182 | conv5 = self.conv5(self.pool(conv4))
183 |
184 | center = self.center(self.pool(conv5))
185 |
186 | dec5 = self.dec5(torch.cat([center, conv5], 1))
187 |
188 | dec4 = self.dec4(torch.cat([dec5, conv4], 1))
189 | dec3 = self.dec3(torch.cat([dec4, conv3], 1))
190 | dec2 = self.dec2(torch.cat([dec3, conv2], 1))
191 | dec1 = self.dec1(torch.cat([dec2, conv1], 1))
192 |
193 | if self.num_classes > 1:
194 | x_out = F.log_softmax(self.final(dec1), dim=1)
195 | else:
196 | x_out = self.final(dec1)
197 |
198 | return x_out
199 |
200 |
201 | class Conv3BN(nn.Module):
202 | def __init__(self, in_: int, out: int, bn=False):
203 | super().__init__()
204 | self.conv = conv3x3(in_, out)
205 | self.bn = nn.BatchNorm2d(out) if bn else None
206 | self.activation = nn.ReLU(inplace=True)
207 |
208 | def forward(self, x):
209 | x = self.conv(x)
210 | if self.bn is not None:
211 | x = self.bn(x)
212 | x = self.activation(x)
213 | return x
214 |
215 |
216 | class UNetModule(nn.Module):
217 | def __init__(self, in_: int, out: int):
218 | super().__init__()
219 | self.l1 = Conv3BN(in_, out)
220 | self.l2 = Conv3BN(out, out)
221 |
222 | def forward(self, x):
223 | x = self.l1(x)
224 | x = self.l2(x)
225 | return x
226 |
227 |
228 | class UNet(nn.Module):
229 | """
230 | Vanilla UNet.
231 |
232 | Implementation from https://github.com/lopuhin/mapillary-vistas-2017/blob/master/unet_models.py
233 | """
234 | output_downscaled = 1
235 | module = UNetModule
236 |
237 | def __init__(self,
238 | input_channels: int = 3,
239 | filters_base: int = 32,
240 | down_filter_factors=(1, 2, 4, 8, 16),
241 | up_filter_factors=(1, 2, 4, 8, 16),
242 | bottom_s=4,
243 | num_classes=1,
244 | add_output=True):
245 | super().__init__()
246 | self.num_classes = num_classes
247 | assert len(down_filter_factors) == len(up_filter_factors)
248 | assert down_filter_factors[-1] == up_filter_factors[-1]
249 | down_filter_sizes = [filters_base * s for s in down_filter_factors]
250 | up_filter_sizes = [filters_base * s for s in up_filter_factors]
251 | self.down, self.up = nn.ModuleList(), nn.ModuleList()
252 | self.down.append(self.module(input_channels, down_filter_sizes[0]))
253 | for prev_i, nf in enumerate(down_filter_sizes[1:]):
254 | self.down.append(self.module(down_filter_sizes[prev_i], nf))
255 | for prev_i, nf in enumerate(up_filter_sizes[1:]):
256 | self.up.append(self.module(
257 | down_filter_sizes[prev_i] + nf, up_filter_sizes[prev_i]))
258 | pool = nn.MaxPool2d(2, 2)
259 | pool_bottom = nn.MaxPool2d(bottom_s, bottom_s)
260 | upsample = nn.Upsample(scale_factor=2)
261 | upsample_bottom = nn.Upsample(scale_factor=bottom_s)
262 | self.downsamplers = [None] + [pool] * (len(self.down) - 1)
263 | self.downsamplers[-1] = pool_bottom
264 | self.upsamplers = [upsample] * len(self.up)
265 | self.upsamplers[-1] = upsample_bottom
266 | self.add_output = add_output
267 | if add_output:
268 | self.conv_final = nn.Conv2d(up_filter_sizes[0], num_classes, 1)
269 |
270 | def forward(self, x):
271 | xs = []
272 | for downsample, down in zip(self.downsamplers, self.down):
273 | x_in = x if downsample is None else downsample(xs[-1])
274 | x_out = down(x_in)
275 | xs.append(x_out)
276 |
277 | x_out = xs[-1]
278 | for x_skip, upsample, up in reversed(
279 | list(zip(xs[:-1], self.upsamplers, self.up))):
280 | x_out = upsample(x_out)
281 | x_out = up(torch.cat([x_out, x_skip], 1))
282 |
283 | if self.add_output:
284 | x_out = self.conv_final(x_out)
285 | if self.num_classes > 1:
286 | x_out = F.log_softmax(x_out, dim=1)
287 | return x_out
288 |
289 |
290 | class AlbuNet34(nn.Module):
291 | """
292 | UNet (https://arxiv.org/abs/1505.04597) with Resnet34(https://arxiv.org/abs/1512.03385) encoder
293 | Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/
294 | """
295 |
296 | def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
297 | """
298 | :param num_classes:
299 | :param num_filters:
300 | :param pretrained:
301 | False - no pre-trained network is used
302 | True - encoder is pre-trained with resnet34
303 | :is_deconv:
304 | False: bilinear interpolation is used in decoder
305 | True: deconvolution is used in decoder
306 | """
307 | super().__init__()
308 | self.num_classes = num_classes
309 |
310 | self.pool = nn.MaxPool2d(2, 2)
311 |
312 | self.encoder = torchvision.models.resnet34(pretrained=pretrained)
313 |
314 | self.relu = nn.ReLU(inplace=True)
315 |
316 | self.conv1 = nn.Sequential(self.encoder.conv1,
317 | self.encoder.bn1,
318 | self.encoder.relu,
319 | self.pool)
320 |
321 | self.conv2 = self.encoder.layer1
322 |
323 | self.conv3 = self.encoder.layer2
324 |
325 | self.conv4 = self.encoder.layer3
326 |
327 | self.conv5 = self.encoder.layer4
328 |
329 | self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8, is_deconv)
330 |
331 | self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
332 | self.dec4 = DecoderBlock(256 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
333 | self.dec3 = DecoderBlock(128 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
334 | self.dec2 = DecoderBlock(64 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2, is_deconv)
335 | self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
336 | self.dec0 = ConvRelu(num_filters, num_filters)
337 | self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
338 |
339 | def forward(self, x):
340 | conv1 = self.conv1(x)
341 | conv2 = self.conv2(conv1)
342 | conv3 = self.conv3(conv2)
343 | conv4 = self.conv4(conv3)
344 | conv5 = self.conv5(conv4)
345 |
346 | center = self.center(self.pool(conv5))
347 |
348 | dec5 = self.dec5(torch.cat([center, conv5], 1))
349 |
350 | dec4 = self.dec4(torch.cat([dec5, conv4], 1))
351 | dec3 = self.dec3(torch.cat([dec4, conv3], 1))
352 | dec2 = self.dec2(torch.cat([dec3, conv2], 1))
353 | dec1 = self.dec1(dec2)
354 | dec0 = self.dec0(dec1)
355 |
356 | if self.num_classes > 1:
357 | x_out = F.log_softmax(self.final(dec0), dim=1)
358 | else:
359 | x_out = self.final(dec0)
360 |
361 | return x_out
362 |
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Based on a set of transformations developed by Alexander Buslaev as a part of the winning solution (1 out of 735)
3 | to the Kaggle: Carvana Image Masking Challenge.
4 |
5 | https://github.com/asanakoy/kaggle_carvana_segmentation/blob/master/albu/src/transforms.py
6 | """
7 |
8 | import random
9 | import cv2
10 | import numpy as np
11 | import math
12 |
13 |
14 | class DualCompose:
15 | def __init__(self, transforms):
16 | self.transforms = transforms
17 |
18 | def __call__(self, x, mask=None):
19 | for t in self.transforms:
20 | x, mask = t(x, mask)
21 | return x, mask
22 |
23 |
24 | class OneOf:
25 | def __init__(self, transforms, prob=0.5):
26 | self.transforms = transforms
27 | self.prob = prob
28 |
29 | def __call__(self, x, mask=None):
30 | if random.random() < self.prob:
31 | t = random.choice(self.transforms)
32 | t.prob = 1.
33 | x, mask = t(x, mask)
34 | return x, mask
35 |
36 |
37 | class OneOrOther:
38 | def __init__(self, first, second, prob=0.5):
39 | self.first = first
40 | first.prob = 1.
41 | self.second = second
42 | second.prob = 1.
43 | self.prob = prob
44 |
45 | def __call__(self, x, mask=None):
46 | if random.random() < self.prob:
47 | x, mask = self.first(x, mask)
48 | else:
49 | x, mask = self.second(x, mask)
50 | return x, mask
51 |
52 |
53 | class ImageOnly:
54 | def __init__(self, trans):
55 | self.trans = trans
56 |
57 | def __call__(self, x, mask=None):
58 | return self.trans(x), mask
59 |
60 |
61 | class VerticalFlip:
62 | def __init__(self, prob=0.5):
63 | self.prob = prob
64 |
65 | def __call__(self, img, mask=None):
66 | if random.random() < self.prob:
67 | img = cv2.flip(img, 0)
68 | if mask is not None:
69 | mask = cv2.flip(mask, 0)
70 | return img, mask
71 |
72 |
73 | class HorizontalFlip:
74 | def __init__(self, prob=0.5):
75 | self.prob = prob
76 |
77 | def __call__(self, img, mask=None):
78 | if random.random() < self.prob:
79 | img = cv2.flip(img, 1)
80 | if mask is not None:
81 | mask = cv2.flip(mask, 1)
82 | return img, mask
83 |
84 |
85 | class RandomFlip:
86 | def __init__(self, prob=0.5):
87 | self.prob = prob
88 |
89 | def __call__(self, img, mask=None):
90 | if random.random() < self.prob:
91 | d = random.randint(-1, 1)
92 | img = cv2.flip(img, d)
93 | if mask is not None:
94 | mask = cv2.flip(mask, d)
95 | return img, mask
96 |
97 |
98 | class Transpose:
99 | def __init__(self, prob=0.5):
100 | self.prob = prob
101 |
102 | def __call__(self, img, mask=None):
103 | if random.random() < self.prob:
104 | img = img.transpose(1, 0, 2)
105 | if mask is not None:
106 | mask = mask.transpose(1, 0)
107 | return img, mask
108 |
109 |
110 | class RandomRotate90:
111 | def __init__(self, prob=0.5):
112 | self.prob = prob
113 |
114 | def __call__(self, img, mask=None):
115 | if random.random() < self.prob:
116 | factor = random.randint(0, 4)
117 | img = np.rot90(img, factor)
118 | if mask is not None:
119 | mask = np.rot90(mask, factor)
120 | return img.copy(), mask.copy()
121 |
122 |
123 | class Rotate:
124 | def __init__(self, limit=90, prob=0.5):
125 | self.prob = prob
126 | self.limit = limit
127 |
128 | def __call__(self, img, mask=None):
129 | if random.random() < self.prob:
130 | angle = random.uniform(-self.limit, self.limit)
131 |
132 | height, width = img.shape[0:2]
133 | mat = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0)
134 | img = cv2.warpAffine(img, mat, (height, width),
135 | flags=cv2.INTER_LINEAR,
136 | borderMode=cv2.BORDER_REFLECT_101)
137 | if mask is not None:
138 | mask = cv2.warpAffine(mask, mat, (height, width),
139 | flags=cv2.INTER_NEAREST,
140 | borderMode=cv2.BORDER_REFLECT_101)
141 |
142 | return img, mask
143 |
144 |
145 | class RandomCrop:
146 | def __init__(self, size):
147 | self.h = size[0]
148 | self.w = size[1]
149 |
150 | def __call__(self, img, mask=None):
151 | height, width, _ = img.shape
152 |
153 | h_start = np.random.randint(0, height - self.h)
154 | w_start = np.random.randint(0, width - self.w)
155 |
156 | img = img[h_start: h_start + self.h, w_start: w_start + self.w]
157 |
158 | assert img.shape[0] == self.h
159 | assert img.shape[1] == self.w
160 |
161 | if mask is not None:
162 | mask = mask[h_start: h_start + self.h, w_start: w_start + self.w]
163 |
164 | return img, mask
165 |
166 |
167 | class Shift:
168 | def __init__(self, limit=4, prob=.5):
169 | self.limit = limit
170 | self.prob = prob
171 |
172 | def __call__(self, img, mask=None):
173 | if random.random() < self.prob:
174 | limit = self.limit
175 | dx = round(random.uniform(-limit, limit))
176 | dy = round(random.uniform(-limit, limit))
177 |
178 | height, width, channel = img.shape
179 | y1 = limit + 1 + dy
180 | y2 = y1 + height
181 | x1 = limit + 1 + dx
182 | x2 = x1 + width
183 |
184 | img1 = cv2.copyMakeBorder(img, limit + 1, limit + 1, limit + 1, limit + 1,
185 | borderType=cv2.BORDER_REFLECT_101)
186 | img = img1[y1:y2, x1:x2, :]
187 | if mask is not None:
188 | msk1 = cv2.copyMakeBorder(mask, limit + 1, limit + 1, limit + 1, limit + 1,
189 | borderType=cv2.BORDER_REFLECT_101)
190 | mask = msk1[y1:y2, x1:x2, :]
191 |
192 | return img, mask
193 |
194 |
195 | class ShiftScale:
196 | def __init__(self, limit=4, prob=.25):
197 | self.limit = limit
198 | self.prob = prob
199 |
200 | def __call__(self, img, mask=None):
201 | limit = self.limit
202 | if random.random() < self.prob:
203 | height, width, channel = img.shape
204 | assert (width == height)
205 | size0 = width
206 | size1 = width + 2 * limit
207 | size = round(random.uniform(size0, size1))
208 |
209 | dx = round(random.uniform(0, size1 - size))
210 | dy = round(random.uniform(0, size1 - size))
211 |
212 | y1 = dy
213 | y2 = y1 + size
214 | x1 = dx
215 | x2 = x1 + size
216 |
217 | img1 = cv2.copyMakeBorder(img, limit, limit, limit, limit, borderType=cv2.BORDER_REFLECT_101)
218 | img = (img1[y1:y2, x1:x2, :] if size == size0
219 | else cv2.resize(img1[y1:y2, x1:x2, :], (size0, size0), interpolation=cv2.INTER_LINEAR))
220 |
221 | if mask is not None:
222 | msk1 = cv2.copyMakeBorder(mask, limit, limit, limit, limit, borderType=cv2.BORDER_REFLECT_101)
223 | mask = (msk1[y1:y2, x1:x2, :] if size == size0
224 | else cv2.resize(msk1[y1:y2, x1:x2, :], (size0, size0), interpolation=cv2.INTER_LINEAR))
225 |
226 | return img, mask
227 |
228 |
229 | class ShiftScaleRotate:
230 | def __init__(self, shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, prob=0.5):
231 | self.shift_limit = shift_limit
232 | self.scale_limit = scale_limit
233 | self.rotate_limit = rotate_limit
234 | self.prob = prob
235 |
236 | def __call__(self, img, mask=None):
237 | if random.random() < self.prob:
238 | height, width, channel = img.shape
239 |
240 | angle = random.uniform(-self.rotate_limit, self.rotate_limit)
241 | scale = random.uniform(1 - self.scale_limit, 1 + self.scale_limit)
242 |             dx = round(random.uniform(-self.shift_limit, self.shift_limit) * width)
243 |             dy = round(random.uniform(-self.shift_limit, self.shift_limit) * height)
244 |
245 | cc = math.cos(angle / 180 * math.pi) * scale
246 | ss = math.sin(angle / 180 * math.pi) * scale
247 | rotate_matrix = np.array([[cc, -ss], [ss, cc]])
248 |
249 | box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
250 | box1 = box0 - np.array([width / 2, height / 2])
251 | box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])
252 |
253 | box0 = box0.astype(np.float32)
254 | box1 = box1.astype(np.float32)
255 | mat = cv2.getPerspectiveTransform(box0, box1)
256 | img = cv2.warpPerspective(img, mat, (width, height),
257 | flags=cv2.INTER_LINEAR,
258 | borderMode=cv2.BORDER_REFLECT_101)
259 | if mask is not None:
260 | mask = cv2.warpPerspective(mask, mat, (width, height),
261 | flags=cv2.INTER_NEAREST,
262 | borderMode=cv2.BORDER_REFLECT_101)
263 |
264 | return img, mask
265 |
266 |
267 | class CenterCrop:
268 | def __init__(self, size):
269 | if isinstance(size, int):
270 | size = (size, size)
271 |
272 | self.height = size[0]
273 | self.width = size[1]
274 |
275 | def __call__(self, img, mask=None):
276 | h, w, c = img.shape
277 | dy = (h - self.height) // 2
278 | dx = (w - self.width) // 2
279 |
280 | y1 = dy
281 | y2 = y1 + self.height
282 | x1 = dx
283 | x2 = x1 + self.width
284 | img = img[y1:y2, x1:x2]
285 |
286 | if mask is not None:
287 | mask = mask[y1:y2, x1:x2]
288 |
289 | return img, mask
290 |
291 |
292 | class Normalize:
293 | def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
294 | self.mean = mean
295 | self.std = std
296 |
297 | def __call__(self, img):
298 | max_pixel_value = 255.0
299 |
300 | img = img.astype(np.float32) / max_pixel_value
301 |
302 | img -= np.ones(img.shape) * self.mean
303 | img /= np.ones(img.shape) * self.std
304 | return img
305 |
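# A minimal sketch (assumption, not part of the original file) showing how Normalize is
# usually chained after cropping for validation. DualCompose and ImageOnly are defined
# earlier in this module; CenterCrop and Normalize are defined above. The crop size 512
# and the function name are purely illustrative.
def example_val_transform(image, mask):
    val_transform = DualCompose([
        CenterCrop(512),
        ImageOnly(Normalize()),
    ])
    return val_transform(image, mask)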
306 |
307 | class Distort1:
308 |     """
309 |     Barrel / pincushion distortion (unconventional augmentation).
310 |     Radial model: r = d * (1 + k * d^2), where d is the distance from the image centre.
311 |
312 |     References:
313 |     https://stackoverflow.com/questions/6199636/formulas-for-barrel-pincushion-distortion
314 |     https://stackoverflow.com/questions/10364201/image-transformation-in-opencv
315 |     https://stackoverflow.com/questions/2477774/correcting-fisheye-distortion-programmatically
316 |     http://www.coldvision.io/2017/03/02/advanced-lane-finding-using-opencv/
317 |     """
318 |
319 | def __init__(self, distort_limit=0.35, shift_limit=0.25, prob=0.5):
320 | self.distort_limit = distort_limit
321 | self.shift_limit = shift_limit
322 | self.prob = prob
323 |
324 | def __call__(self, img, mask=None):
325 | if random.random() < self.prob:
326 | height, width, channel = img.shape
327 |
328 |             # Debug helper (disabled): uncomment to draw a grid and visualise the distortion field.
329 |             # img = img.copy()
330 |             # for x in range(0, width, 10):
331 |             #     cv2.line(img, (x, 0), (x, height), (1, 1, 1), 1)
332 |             # for y in range(0, height, 10):
333 |             #     cv2.line(img, (0, y), (width, y), (1, 1, 1), 1)
334 |
335 | k = random.uniform(-self.distort_limit, self.distort_limit) * 0.00001
336 | dx = random.uniform(-self.shift_limit, self.shift_limit) * width
337 | dy = random.uniform(-self.shift_limit, self.shift_limit) * height
338 |
339 | # map_x, map_y =
340 | # cv2.initUndistortRectifyMap(intrinsics, dist_coeffs, None, None, (width,height),cv2.CV_32FC1)
341 | # https://stackoverflow.com/questions/6199636/formulas-for-barrel-pincushion-distortion
342 | # https://stackoverflow.com/questions/10364201/image-transformation-in-opencv
343 |             y, x = np.mgrid[0:height:1, 0:width:1]  # (row, col) order -> maps of shape (height, width)
344 |             x = x.astype(np.float32) - width / 2 - dx
345 |             y = y.astype(np.float32) - height / 2 - dy
346 | theta = np.arctan2(y, x)
347 | d = (x * x + y * y) ** 0.5
348 | r = d * (1 + k * d * d)
349 | map_x = r * np.cos(theta) + width / 2 + dx
350 | map_y = r * np.sin(theta) + height / 2 + dy
351 |
352 | img = cv2.remap(img, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)
353 | if mask is not None:
354 | mask = cv2.remap(mask, map_x, map_y, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REFLECT_101)
355 | return img, mask
356 |
357 |
358 | class Distort2:
359 | """
360 | #http://pythology.blogspot.sg/2014/03/interpolation-on-regular-distorted-grid.html
361 | ## grid distortion
362 | """
363 |
364 | def __init__(self, num_steps=10, distort_limit=0.2, prob=0.5):
365 | self.num_steps = num_steps
366 | self.distort_limit = distort_limit
367 | self.prob = prob
368 |
369 | def __call__(self, img, mask=None):
370 | if random.random() < self.prob:
371 | height, width, channel = img.shape
372 |
373 | x_step = width // self.num_steps
374 | xx = np.zeros(width, np.float32)
375 | prev = 0
376 | for x in range(0, width, x_step):
377 | start = x
378 | end = x + x_step
379 | if end > width:
380 | end = width
381 | cur = width
382 | else:
383 | cur = prev + x_step * (1 + random.uniform(-self.distort_limit, self.distort_limit))
384 |
385 | xx[start:end] = np.linspace(prev, cur, end - start)
386 | prev = cur
387 |
388 | y_step = height // self.num_steps
389 | yy = np.zeros(height, np.float32)
390 | prev = 0
391 | for y in range(0, height, y_step):
392 | start = y
393 | end = y + y_step
394 |                 if end > height:
395 | end = height
396 | cur = height
397 | else:
398 | cur = prev + y_step * (1 + random.uniform(-self.distort_limit, self.distort_limit))
399 |
400 | yy[start:end] = np.linspace(prev, cur, end - start)
401 | prev = cur
402 |
403 | map_x, map_y = np.meshgrid(xx, yy)
404 | map_x = map_x.astype(np.float32)
405 | map_y = map_y.astype(np.float32)
406 | img = cv2.remap(img, map_x, map_y,
407 | interpolation=cv2.INTER_LINEAR,
408 | borderMode=cv2.BORDER_REFLECT_101)
409 | if mask is not None:
410 | mask = cv2.remap(mask, map_x, map_y,
411 |                                  interpolation=cv2.INTER_NEAREST,  # keep the binary mask crisp
412 | borderMode=cv2.BORDER_REFLECT_101)
413 |
414 | return img, mask
415 |
416 |
417 | def clip(img, dtype, maxval):
418 | return np.clip(img, 0, maxval).astype(dtype)
419 |
420 |
421 | class RandomFilter:
422 | """
423 | blur sharpen, etc
424 | """
425 |
426 | def __init__(self, limit=.5, prob=.5):
427 | self.limit = limit
428 | self.prob = prob
429 |
430 | def __call__(self, img):
431 | if random.random() < self.prob:
432 | alpha = self.limit * random.uniform(0, 1)
433 | kernel = np.ones((3, 3), np.float32) / 9 * 0.2
434 |
435 | colored = img[..., :3]
436 | colored = alpha * cv2.filter2D(colored, -1, kernel) + (1 - alpha) * colored
437 | maxval = np.max(img[..., :3])
438 | dtype = img.dtype
439 | img[..., :3] = clip(colored, dtype, maxval)
440 |
441 | return img
442 |
443 |
444 | # https://github.com/pytorch/vision/pull/27/commits/659c854c6971ecc5b94dca3f4459ef2b7e42fb70
445 | # color augmentation
446 |
447 | # brightness, contrast, saturation-------------
448 | # from mxnet code, see: https://github.com/dmlc/mxnet/blob/master/python/mxnet/image.py
449 |
450 | class RandomBrightness:
451 | def __init__(self, limit=0.1, prob=0.5):
452 | self.limit = limit
453 | self.prob = prob
454 |
455 | def __call__(self, img):
456 | if random.random() < self.prob:
457 | alpha = 1.0 + self.limit * random.uniform(-1, 1)
458 |
459 | maxval = np.max(img[..., :3])
460 | dtype = img.dtype
461 | img[..., :3] = clip(alpha * img[..., :3], dtype, maxval)
462 | return img
463 |
464 |
465 | class RandomContrast:
466 | def __init__(self, limit=.1, prob=.5):
467 | self.limit = limit
468 | self.prob = prob
469 |
470 | def __call__(self, img):
471 | if random.random() < self.prob:
472 | alpha = 1.0 + self.limit * random.uniform(-1, 1)
473 |
474 | gray = cv2.cvtColor(img[:, :, :3], cv2.COLOR_BGR2GRAY)
475 | gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray)
476 | maxval = np.max(img[..., :3])
477 | dtype = img.dtype
478 | img[:, :, :3] = clip(alpha * img[:, :, :3] + gray, dtype, maxval)
479 | return img
480 |
481 |
482 | class RandomSaturation:
483 | def __init__(self, limit=0.3, prob=0.5):
484 | self.limit = limit
485 | self.prob = prob
486 |
487 |     def __call__(self, img):
488 |         if random.random() < self.prob:
489 |             maxval = np.max(img[..., :3])
490 |             dtype = img.dtype
491 |             alpha = 1.0 + random.uniform(-self.limit, self.limit)
492 |             gray = cv2.cvtColor(img[..., :3], cv2.COLOR_BGR2GRAY)
493 |             gray = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
494 |             # Blend in float so the result is clipped before being cast back to the original dtype.
495 |             blended = alpha * img[..., :3].astype(np.float32) + (1.0 - alpha) * gray.astype(np.float32)
496 |             img[..., :3] = clip(blended, dtype, maxval)
497 |         return img
498 |
499 |
500 | class RandomHueSaturationValue:
501 | def __init__(self, hue_shift_limit=(-20, 20), sat_shift_limit=(-35, 35), val_shift_limit=(-35, 35), prob=0.5):
502 | self.hue_shift_limit = hue_shift_limit
503 | self.sat_shift_limit = sat_shift_limit
504 | self.val_shift_limit = val_shift_limit
505 | self.prob = prob
506 |
507 | def __call__(self, image):
508 | if random.random() < self.prob:
509 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
510 | h, s, v = cv2.split(image)
511 | hue_shift = np.random.uniform(self.hue_shift_limit[0], self.hue_shift_limit[1])
512 | h = cv2.add(h, hue_shift)
513 | sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1])
514 | s = cv2.add(s, sat_shift)
515 | val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1])
516 | v = cv2.add(v, val_shift)
517 | image = cv2.merge((h, s, v))
518 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
519 | return image
520 |
521 |
522 | class CLAHE:
523 | def __init__(self, clipLimit=2.0, tileGridSize=(8, 8)):
524 | self.clipLimit = clipLimit
525 | self.tileGridSize = tileGridSize
526 |
527 | def __call__(self, im):
528 | img_yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV)
529 | clahe = cv2.createCLAHE(clipLimit=self.clipLimit, tileGridSize=self.tileGridSize)
530 | img_yuv[:, :, 0] = clahe.apply(img_yuv[:, :, 0])
531 | img_output = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2BGR)
532 | return img_output
533 |
534 |
535 | def augment(x, mask=None, prob=0.5):
536 | return DualCompose([
537 |         OneOrOther(
538 |             OneOf([
539 |                 Distort1(distort_limit=0.05, shift_limit=0.05),
540 |                 Distort2(num_steps=2, distort_limit=0.05)]),
541 |             ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.10, rotate_limit=45), prob=prob),
542 | RandomFlip(prob=0.5),
543 | Transpose(prob=0.5),
544 | ImageOnly(RandomContrast(limit=0.2, prob=0.5)),
545 | ImageOnly(RandomFilter(limit=0.5, prob=0.2)),
546 | ])(x, mask)
547 |
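# Minimal usage sketch (assumption, not part of the original file): apply `augment` to a
# BGR uint8 image and a single-channel mask read with OpenCV. The file names below are
# purely illustrative.
if __name__ == '__main__':
    example_img = cv2.imread('example_image.jpg')      # H x W x 3, uint8, BGR
    example_mask = cv2.imread('example_mask.png', 0)   # H x W, uint8
    aug_img, aug_mask = augment(example_img, example_mask, prob=0.5)
    print(aug_img.shape, aug_mask.shape)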
--------------------------------------------------------------------------------