├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
│   ├── weights_iw1.pth
│   ├── weights_iw16.pth
│   ├── weights_iw3.pth
│   ├── weights_iw4.pth
│   ├── weights_pw1.pth
│   ├── weights_pw16.pth
│   ├── weights_pw3.pth
│   └── weights_pw4.pth
├── dataset
│   ├── test
│   │   └── .gitignore
│   ├── train
│   │   ├── Benign
│   │   │   └── .gitignore
│   │   ├── InSitu
│   │   │   └── .gitignore
│   │   ├── Invasive
│   │   │   └── .gitignore
│   │   └── Normal
│   │       └── .gitignore
│   └── validation
│       ├── Benign
│       │   └── .gitignore
│       ├── InSitu
│       │   └── .gitignore
│       ├── Invasive
│       │   └── .gitignore
│       └── Normal
│           └── .gitignore
├── img
│   ├── dataset.jpg
│   └── network.png
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── datasets.py
│   ├── models.py
│   ├── networks.py
│   ├── options.py
│   └── patch_extractor.py
├── test.py
├── train.py
└── validate.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 |
104 | .idea
105 | draft
106 | checkpoints
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Imaging Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ICIAR2018
2 | ### Two-Stage Convolutional Neural Network for Breast Cancer Histology Image Classification
3 | This repository addresses Part A of the ICIAR 2018 Grand Challenge on BreAst Cancer Histology (BACH) images, which asks for automatic classification of H&E-stained breast histology microscopy images into four classes: normal, benign, in situ carcinoma, and invasive carcinoma.
4 |
5 | We present a CNN approach that uses two convolutional networks to classify histology images in a patch-wise fashion. The first network receives overlapping patches (35 patches) of the microscopy image and learns to generate spatially smaller outputs. The second network is trained on the downsampled patches of the whole image using the output of the first network. The number of channels in the input to the second network equals the total number of patches extracted from the microscopy image in a non-overlapping fashion (12 patches) times the depth of the feature maps generated by the first network (C).
6 |
7 |
8 |
9 |
10 | ## Prerequisites
11 | - Linux
12 | - Python 3
13 | - NVIDIA GPU (12G or 24G memory) + CUDA cuDNN
14 |
15 | ## Getting Started
16 | ### Installation
17 | - Clone this repo:
18 | ```bash
19 | git clone https://github.com/ImagingLab/ICIAR2018
20 | cd ICIAR2018
21 | ```
22 | - Install PyTorch and dependencies from http://pytorch.org
23 | - Install python requirements:
24 | ```bash
25 | pip install -r requirements.txt
26 | ```
27 |
28 | ### Dataset
29 | - We use the ICIAR2018 dataset. To train a model on the full dataset, download it from the [official website](https://iciar2018-challenge.grand-challenge.org/dataset/) (registration required). The dataset comprises 400 high-resolution Hematoxylin and Eosin (H&E) stained breast histology microscopy images labelled as normal, benign, in situ carcinoma, or invasive carcinoma (100 images per category).
30 |
31 |
32 |
33 | After downloading, place the images under the `dataset` folder, following the sub-directory layout provided in this repository.
34 |
35 |
36 | ### Testing
37 | - The pre-trained ICIAR2018 dataset model resides under `./checkpoints`.
38 | - To test the model, run the `test.py` script.
39 | - Use the `--testset-path` command-line argument to provide the path to the `test` folder.
40 | ```bash
41 | python test.py --testset-path ./dataset/test
42 | ```
43 | - If you don't provide the test-set path, an open-file dialog box will appear so you can select a single image to test.
44 | The test results will be printed on the screen.
45 |
46 |
47 |
48 | ### Training
49 | - To train the model, run the `train.py` script:
50 | ```bash
51 | python train.py
52 | ```
53 | - To change the number of feature maps generated by the patch-wise network, use the `--channels` argument:
54 | ```bash
55 | python train.py --channels 1
56 | ```
57 |
58 |
59 | ### Validation & ROC Curves
60 | - To validate the model on the validation set and plot the ROC curves, run the `validate.py` script:
61 | ```bash
62 | python validate.py
63 | ```
64 | - To change the number of feature maps generated by the patch-wise network, use the `--channels` argument:
65 | ```bash
66 | python validate.py --channels 1
67 | ```
68 |
69 | ## Citation
70 | If you use this code for your research, please cite our paper Two-Stage Convolutional Neural Network for Breast Cancer Histology Image Classification:
71 |
72 | ```
73 | @inproceedings{nazeri2018two,
74 | title={Two-Stage Convolutional Neural Network for Breast Cancer Histology Image Classification},
75 | author={Nazeri, Kamyar and Aminpour, Azad and Ebrahimi, Mehran},
76 | booktitle={International Conference Image Analysis and Recognition},
77 | pages={717--726},
78 | year={2018},
79 | organization={Springer}
80 | }
81 | ```
82 |
--------------------------------------------------------------------------------
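The patch counts quoted in the README (35 overlapping, 12 non-overlapping) can be reproduced from the `IMAGE_SIZE` and `PATCH_SIZE` constants in `src/datasets.py`. The sketch below is illustrative only: `patch_grid` and `image_wise_in_channels` are hypothetical helper names that do not exist in the repository, and the stride of 256 for the overlapping case is an assumption, chosen because it yields the 35 patches the README mentions.

```python
IMAGE_SIZE = (2048, 1536)  # (width, height), same constant as src/datasets.py
PATCH_SIZE = 512


def patch_grid(stride, image_size=IMAGE_SIZE, patch_size=PATCH_SIZE):
    """Return (columns, rows) of patches for a sliding window with the given stride."""
    cols = (image_size[0] - patch_size) // stride + 1
    rows = (image_size[1] - patch_size) // stride + 1
    return cols, rows


def image_wise_in_channels(feature_channels, stride=PATCH_SIZE):
    """Input channels of the second network: non-overlapping patches times C."""
    cols, rows = patch_grid(stride)
    return cols * rows * feature_channels


if __name__ == "__main__":
    print(patch_grid(256))            # assumed overlapping stride: 7 x 5 grid
    print(patch_grid(PATCH_SIZE))     # non-overlapping: 4 x 3 grid
    print(image_wise_in_channels(1))  # --channels 1
    print(image_wise_in_channels(16)) # --channels 16
```

With a 7 x 5 grid the first network sees 35 patches per image, and with the 4 x 3 non-overlapping grid the second network receives 12 x C input channels, matching the description in the README.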
/checkpoints/weights_iw1.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_iw1.pth
--------------------------------------------------------------------------------
/checkpoints/weights_iw16.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_iw16.pth
--------------------------------------------------------------------------------
/checkpoints/weights_iw3.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_iw3.pth
--------------------------------------------------------------------------------
/checkpoints/weights_iw4.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_iw4.pth
--------------------------------------------------------------------------------
/checkpoints/weights_pw1.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_pw1.pth
--------------------------------------------------------------------------------
/checkpoints/weights_pw16.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_pw16.pth
--------------------------------------------------------------------------------
/checkpoints/weights_pw3.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_pw3.pth
--------------------------------------------------------------------------------
/checkpoints/weights_pw4.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_pw4.pth
--------------------------------------------------------------------------------
/dataset/test/.gitignore:
--------------------------------------------------------------------------------
1 | *.tif
--------------------------------------------------------------------------------
/dataset/train/Benign/.gitignore:
--------------------------------------------------------------------------------
1 | *.tif
--------------------------------------------------------------------------------
/dataset/train/InSitu/.gitignore:
--------------------------------------------------------------------------------
1 | *.tif
--------------------------------------------------------------------------------
/dataset/train/Invasive/.gitignore:
--------------------------------------------------------------------------------
1 | *.tif
--------------------------------------------------------------------------------
/dataset/train/Normal/.gitignore:
--------------------------------------------------------------------------------
1 | *.tif
--------------------------------------------------------------------------------
/dataset/validation/Benign/.gitignore:
--------------------------------------------------------------------------------
1 | *.tif
--------------------------------------------------------------------------------
/dataset/validation/InSitu/.gitignore:
--------------------------------------------------------------------------------
1 | *.tif
--------------------------------------------------------------------------------
/dataset/validation/Invasive/.gitignore:
--------------------------------------------------------------------------------
1 | *.tif
--------------------------------------------------------------------------------
/dataset/validation/Normal/.gitignore:
--------------------------------------------------------------------------------
1 | *.tif
--------------------------------------------------------------------------------
/img/dataset.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/img/dataset.jpg
--------------------------------------------------------------------------------
/img/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/img/network.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy ~= 1.14.3
2 | scipy ~= 1.0.1
3 | future ~= 0.16.0
4 | matplotlib ~= 2.2.2
5 | pillow >= 6.2.0
6 | scikit-learn ~= 0.19.1
7 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import *
2 | from .networks import *
3 | from .options import *
4 | from .datasets import *
5 | from .patch_extractor import *
6 |
--------------------------------------------------------------------------------
/src/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import numpy as np
5 | from PIL import Image, ImageEnhance
6 | from torch.utils.data import Dataset
7 | from torchvision.transforms import transforms
8 | from .patch_extractor import PatchExtractor
9 |
10 | LABELS = ['Normal', 'Benign', 'InSitu', 'Invasive']
11 | IMAGE_SIZE = (2048, 1536)  # (width, height) in pixels
12 | PATCH_SIZE = 512
13 |
14 |
15 | class PatchWiseDataset(Dataset):
16 |     def __init__(self, path, stride=PATCH_SIZE, rotate=False, flip=False, enhance=False):
17 |         super().__init__()
18 |
19 |         wp = int((IMAGE_SIZE[0] - PATCH_SIZE) / stride + 1)  # patches per row
20 |         hp = int((IMAGE_SIZE[1] - PATCH_SIZE) / stride + 1)  # patches per column
21 |         labels = {name: index for index in range(len(LABELS)) for name in glob.glob(path + '/' + LABELS[index] + '/*.tif')}
22 |
23 |         self.path = path
24 |         self.stride = stride
25 |         self.labels = labels
26 |         self.names = list(sorted(labels.keys()))
27 |         self.shape = (len(labels), wp, hp, (4 if rotate else 1), (2 if flip else 1), (2 if enhance else 1))  # (files, x_patches, y_patches, rotations, flip, enhance)
28 |         self.augment_size = np.prod(self.shape) / len(labels)
29 |
30 |     def __getitem__(self, index):
31 |         im, xpatch, ypatch, rotation, flip, enhance = np.unravel_index(index, self.shape)
32 |
33 |         with Image.open(self.names[im]) as img:
34 |             extractor = PatchExtractor(img=img, patch_size=PATCH_SIZE, stride=self.stride)
35 |             patch = extractor.extract_patch((xpatch, ypatch))
36 |
37 |             if rotation != 0:
38 |                 patch = patch.rotate(rotation * 90)
39 |
40 |             if flip != 0:
41 |                 patch = patch.transpose(Image.FLIP_LEFT_RIGHT)
42 |
43 |             if enhance != 0:
44 |                 factors = np.random.uniform(.5, 1.5, 3)  # random color, contrast, and brightness factors
45 |                 patch = ImageEnhance.Color(patch).enhance(factors[0])
46 |                 patch = ImageEnhance.Contrast(patch).enhance(factors[1])
47 |                 patch = ImageEnhance.Brightness(patch).enhance(factors[2])
48 |
49 |             label = self.labels[self.names[im]]
50 |             return transforms.ToTensor()(patch), label
51 |
52 |     def __len__(self):
53 |         return int(np.prod(self.shape))
54 |
55 |
56 | class ImageWiseDataset(Dataset):
57 |     def __init__(self, path, stride=PATCH_SIZE, rotate=False, flip=False, enhance=False):
58 |         super().__init__()
59 |
60 |         labels = {name: index for index in range(len(LABELS)) for name in glob.glob(path + '/' + LABELS[index] + '/*.tif')}
61 |
62 |         self.path = path
63 |         self.stride = stride
64 |         self.labels = labels
65 |         self.names = list(sorted(labels.keys()))
66 |         self.shape = (len(labels), (4 if rotate else 1), (2 if flip else 1), (2 if enhance else 1))  # (files, rotations, flip, enhance)
67 |         self.augment_size = np.prod(self.shape) / len(labels)
68 |
69 |     def __getitem__(self, index):
70 |         im, rotation, flip, enhance = np.unravel_index(index, self.shape)
71 |
72 |         with Image.open(self.names[im]) as img:
73 |
74 |             if flip != 0:
75 |                 img = img.transpose(Image.FLIP_LEFT_RIGHT)
76 |
77 |             if rotation != 0:
78 |                 img = img.rotate(rotation * 90)
79 |
80 |             if enhance != 0:
81 |                 factors = np.random.uniform(.5, 1.5, 3)  # random color, contrast, and brightness factors
82 |                 img = ImageEnhance.Color(img).enhance(factors[0])
83 |                 img = ImageEnhance.Contrast(img).enhance(factors[1])
84 |                 img = ImageEnhance.Brightness(img).enhance(factors[2])
85 |
86 |             extractor = PatchExtractor(img=img, patch_size=PATCH_SIZE, stride=self.stride)
87 |             patches = extractor.extract_patches()
88 |
89 |             label = self.labels[self.names[im]]
90 |
91 |             b = torch.zeros((len(patches), 3, PATCH_SIZE, PATCH_SIZE))
92 |             for i in range(len(patches)):
93 |                 b[i] = transforms.ToTensor()(patches[i])
94 |
95 |             return b, label
96 |
97 |     def __len__(self):
98 |         return int(np.prod(self.shape))
99 |
100 |
101 | class TestDataset(Dataset):
102 |     def __init__(self, path, stride=PATCH_SIZE, augment=False):
103 |         super().__init__()
104 |
105 |         if os.path.isdir(path):
106 |             names = glob.glob(path + '/*.tif')
107 |         else:
108 |             names = [path]
109 |
110 |         self.path = path
111 |         self.stride = stride
112 |         self.augment = augment
113 |         self.names = list(sorted(names))
114 |
115 |     def __getitem__(self, index):
116 |         file = self.names[index]
117 |         with Image.open(file) as img:
118 |
119 |             bins = 8 if self.augment else 1  # 4 rotations x 2 flips when augmenting
120 |             extractor = PatchExtractor(img=img, patch_size=PATCH_SIZE, stride=self.stride)
121 |             b = torch.zeros((bins, extractor.shape()[0] * extractor.shape()[1], 3, PATCH_SIZE, PATCH_SIZE))
122 |
123 |             for k in range(bins):
124 |                 view = img  # start every bin from the original image so transforms do not accumulate
125 |                 if k % 4 != 0:
126 |                     view = view.rotate((k % 4) * 90)
127 |
128 |                 if k // 4 != 0:
129 |                     view = view.transpose(Image.FLIP_LEFT_RIGHT)
130 |
131 |                 extractor = PatchExtractor(img=view, patch_size=PATCH_SIZE, stride=self.stride)
132 |                 patches = extractor.extract_patches()
133 |
134 |                 for i in range(len(patches)):
135 |                     b[k, i] = transforms.ToTensor()(patches[i])
136 |
137 |             return b, file
138 |
139 |     def __len__(self):
140 |         return len(self.names)
141 |
--------------------------------------------------------------------------------
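Both training datasets in `src/datasets.py` treat the flat sample index as a coordinate into `self.shape` and decode it with `np.unravel_index`, so every combination of image, patch position, and augmentation is visited exactly once per epoch. The sketch below illustrates that decoding on a made-up shape; `SHAPE`, `decode`, and `encode` are hypothetical names used only for this example.

```python
import numpy as np

# Hypothetical dataset: 2 images, a 4 x 3 patch grid, and every augmentation enabled.
# Axis order mirrors PatchWiseDataset.shape:
# (files, x_patches, y_patches, rotations, flip, enhance)
SHAPE = (2, 4, 3, 4, 2, 2)


def decode(index, shape=SHAPE):
    """Split a flat sample index into one coordinate per axis (C order, last axis fastest)."""
    return tuple(int(v) for v in np.unravel_index(index, shape))


def encode(coords, shape=SHAPE):
    """Inverse of decode: fold per-axis coordinates back into a flat index."""
    return int(np.ravel_multi_index(coords, shape))


if __name__ == "__main__":
    total = int(np.prod(SHAPE))  # number of samples such a dataset would report
    print(total)
    print(decode(0))
    print(decode(total - 1))
```

Because the last axis varies fastest, consecutive indices first toggle the enhancement flag, then the flip, then the rotation, and only then move to the next patch or image. A `DataLoader` with `shuffle=True` breaks up that ordering during training.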
/src/models.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import ntpath
4 | import datetime
5 | import matplotlib.pyplot as plt
6 | import torch.optim as optim
7 | import torch.nn.functional as F
8 | import matplotlib.pyplot as ply
9 | from torch.autograd import Variable
10 | from torch.utils.data import DataLoader, TensorDataset
11 | from sklearn.metrics import roc_curve, auc
12 | from sklearn.preprocessing import label_binarize
13 |
14 | from .datasets import *
15 |
16 | TRAIN_PATH = '/train'
17 | VALIDATION_PATH = '/validation'
18 |
19 |
20 | class BaseModel:
21 | def __init__(self, args, network, weights_path):
22 | self.args = args
23 | self.weights = weights_path
24 | self.network = network.cuda() if args.cuda else network
25 | self.load()
26 |
27 | def load(self):
28 | try:
29 | if os.path.exists(self.weights):
30 | print('Loading model from "{}"...'.format(self.weights))
31 | self.network.load_state_dict(torch.load(self.weights))
32 | except Exception as e:
33 | print('Failed to load pre-trained network: {}'.format(e))
34 |
35 | def save(self):
36 | print('Saving model to "{}"'.format(self.weights))
37 | torch.save(self.network.state_dict(), self.weights)
38 |
39 |
40 | class PatchWiseModel(BaseModel):
41 | def __init__(self, args, network):
42 | super(PatchWiseModel, self).__init__(args, network, args.checkpoints_path + '/weights_' + network.name() + '.pth')
43 |
44 | def train(self):
45 | self.network.train()
46 | print('Start training patch-wise network: {}\n'.format(time.strftime('%Y/%m/%d %H:%M')))
47 |
48 | train_loader = DataLoader(
49 | dataset=PatchWiseDataset(path=self.args.dataset_path + TRAIN_PATH, stride=self.args.patch_stride, rotate=True, flip=True, enhance=True),
50 | batch_size=self.args.batch_size,
51 | shuffle=True,
52 | num_workers=4
53 | )
54 | optimizer = optim.Adam(self.network.parameters(), lr=self.args.lr, betas=(self.args.beta1, self.args.beta2))
55 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
56 | best = self.validate(verbose=False)
57 | mean = 0
58 | epoch = 0
59 |
60 | for epoch in range(1, self.args.epochs + 1):
61 |
62 | self.network.train()
63 | scheduler.step()
64 | stime = datetime.datetime.now()
65 |
66 | correct = 0
67 | total = 0
68 |
69 | for index, (images, labels) in enumerate(train_loader):
70 |
71 | if self.args.cuda:
72 | images, labels = images.cuda(), labels.cuda()
73 |
74 | optimizer.zero_grad()
75 | output = self.network(Variable(images))
76 | loss = F.nll_loss(output, Variable(labels))
77 | loss.backward()
78 | optimizer.step()
79 |
80 | _, predicted = torch.max(output.data, 1)
81 | correct += torch.sum(predicted == labels)
82 | total += len(images)
83 |
84 | if index > 0 and index % self.args.log_interval == 0:
85 | print('Epoch: {}/{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Accuracy: {:.2f}%'.format(
86 | epoch,
87 | self.args.epochs,
88 | index * len(images),
89 | len(train_loader.dataset),
90 | 100. * index / len(train_loader),
91 | loss.item(),
92 | 100 * correct / total
93 | ))
94 |
95 | print('\nEnd of epoch {}, time: {}'.format(epoch, datetime.datetime.now() - stime))
96 | acc = self.validate()
97 | mean += acc
98 | if acc > best:
99 | best = acc
100 |
101 | self.save()
102 |
103 | print('\nEnd of training, best accuracy: {}, mean accuracy: {}\n'.format(best, mean // epoch))
104 |
105 | def validate(self, verbose=True):
106 | self.network.eval()
107 |
108 | test_loss = 0
109 | correct = 0
110 | classes = len(LABELS)
111 |
112 | tp = [0] * classes
113 | tpfp = [0] * classes
114 | tpfn = [0] * classes
115 | precision = [0] * classes
116 | recall = [0] * classes
117 | f1 = [0] * classes
118 |
119 | test_loader = DataLoader(
120 | dataset=PatchWiseDataset(path=self.args.dataset_path + VALIDATION_PATH, stride=self.args.patch_stride),
121 | batch_size=self.args.batch_size,
122 | shuffle=False,
123 | num_workers=4
124 | )
125 | if verbose:
126 | print('\nEvaluating....')
127 |
128 | for images, labels in test_loader:
129 |
130 | if self.args.cuda:
131 | images, labels = images.cuda(), labels.cuda()
132 |
133 | output = self.network(Variable(images, volatile=True))
134 |
135 | test_loss += F.nll_loss(output, Variable(labels), reduction='sum').item()
136 | _, predicted = torch.max(output.data, 1)
137 | correct += torch.sum(predicted == labels)
138 |
139 | for label in range(classes):
140 | t_labels = labels == label
141 | p_labels = predicted == label
142 | tp[label] += torch.sum(t_labels == (p_labels * 2 - 1))
143 | tpfp[label] += torch.sum(p_labels)
144 | tpfn[label] += torch.sum(t_labels)
145 |
146 | for label in range(classes):
147 | precision[label] += (tp[label] / (tpfp[label] + 1e-8))
148 | recall[label] += (tp[label] / (tpfn[label] + 1e-8))
149 | f1[label] = 2 * precision[label] * recall[label] / (precision[label] + recall[label] + 1e-8)
150 |
151 | test_loss /= len(test_loader.dataset)
152 | acc = 100. * correct / len(test_loader.dataset)
153 |
154 | if verbose:
155 | print('Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
156 | test_loss,
157 | correct,
158 | len(test_loader.dataset),
159 | 100. * correct / len(test_loader.dataset)
160 | ))
161 |
162 | for label in range(classes):
163 | print('{}: \t Precision: {:.2f}, Recall: {:.2f}, F1: {:.2f}'.format(
164 | LABELS[label],
165 | precision[label],
166 | recall[label],
167 | f1[label]
168 | ))
169 |
170 | print('')
171 | return acc
172 |
173 | def test(self, path, verbose=True):
174 | self.network.eval()
175 | dataset = TestDataset(path=path, stride=PATCH_SIZE, augment=False)
176 | data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
177 | stime = datetime.datetime.now()
178 |
179 | if verbose:
180 | print('\t sum\t\t max\t\t maj')
181 |
182 | res = []
183 |
184 | for index, (image, file_name) in enumerate(data_loader):
185 | image = image.squeeze()
186 | if self.args.cuda:
187 | image = image.cuda()
188 |
189 | output = self.network(Variable(image))
190 | _, predicted = torch.max(output.data, 1)
191 |
192 | #
193 | # the following measures are prioritised based on [invasive, insitu, benign, normal]
194 | # the original labels are [normal, benign, insitu, invasive], so we reverse the order using [::-1]
195 | # output data shape is 12x4: one row of log-probabilities per patch, hence the np.exp below
196 | # sum_prob: sum of probabilities along axis 0: (4,), reverse, and take the index of the largest value
197 | # max_prob: max of probabilities along axis 0: (4,), reverse, and take the index of the largest value
198 | # maj_prob: majority voting: create a one-hot vector of predicted values: (12, 4), sum along axis 0: (4,), reverse, and take the index of the largest value
199 | # `3 - argmax` maps the reversed index back to the original label order; ties resolve to the more severe class
200 | sum_prob = 3 - np.argmax(np.sum(np.exp(output.data.cpu().numpy()), axis=0)[::-1])
201 | max_prob = 3 - np.argmax(np.max(np.exp(output.data.cpu().numpy()), axis=0)[::-1])
202 | maj_prob = 3 - np.argmax(np.sum(np.eye(4)[predicted.cpu().numpy().reshape(-1)], axis=0)[::-1])
203 |
204 | res.append([sum_prob, max_prob, maj_prob, file_name[0]])
205 |
206 | if verbose:
207 | np.sum(output.data.cpu().numpy(), axis=0)
208 | print('{}) \t {} \t {} \t {} \t {}'.format(
209 | str(index + 1).rjust(2, '0'),
210 | LABELS[sum_prob].ljust(8),
211 | LABELS[max_prob].ljust(8),
212 | LABELS[maj_prob].ljust(8),
213 | ntpath.basename(file_name[0])))
214 |
215 | if verbose:
216 | print('\nInference time: {}\n'.format(datetime.datetime.now() - stime))
217 |
218 | return res
219 |
220 | def output(self, input_tensor):
221 | self.network.eval()
222 | res = self.network.features(Variable(input_tensor, volatile=True))
223 | return res.squeeze()
224 |
225 | def visualize(self, path, channel=0):
226 | self.network.eval()
227 | dataset = TestDataset(path=path, stride=PATCH_SIZE)
228 | data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
229 |
230 | for index, (image, file_name) in enumerate(data_loader):
231 |
232 | if self.args.cuda:
233 | image = image[0].cuda()
234 |
235 | patches = self.output(image)
236 |
237 | output = patches.cpu().data.numpy()
238 |
239 | map = np.zeros((3 * 64, 4 * 64))
240 |
241 | for i in range(12):
242 | row = i // 4
243 | col = i % 4
244 | map[row * 64:(row + 1) * 64, col * 64:(col + 1) * 64] = output[i]
245 |
246 | if len(map.shape) > 2:
247 | map = map[channel]
248 |
249 | with Image.open(file_name[0]) as img:
250 | ply.subplot(121)
251 | ply.axis('off')
252 | ply.imshow(np.array(img))
253 |
254 | ply.subplot(122)
255 | ply.imshow(map, cmap='gray')
256 | ply.axis('off')
257 |
258 | ply.show()
259 |
260 |
261 | class ImageWiseModel(BaseModel):
262 | def __init__(self, args, image_wise_network, patch_wise_network):
263 | super(ImageWiseModel, self).__init__(args, image_wise_network, args.checkpoints_path + '/weights_' + image_wise_network.name() + '.pth')
264 |
265 | self.patch_wise_model = PatchWiseModel(args, patch_wise_network)
266 | self._test_loader = None
267 |
268 | def train(self):
269 | self.network.train()
270 | print('Evaluating patch-wise model...')
271 |
272 | train_loader = self._patch_loader(self.args.dataset_path + TRAIN_PATH, True)
273 |
274 | print('Start training image-wise network: {}\n'.format(time.strftime('%Y/%m/%d %H:%M')))
275 |
276 | optimizer = optim.Adam(self.network.parameters(), lr=self.args.lr, betas=(self.args.beta1, self.args.beta2))
277 | best = self.validate(verbose=False)
278 | mean = 0
279 | epoch = 0
280 |
281 | for epoch in range(1, self.args.epochs + 1):
282 |
283 | self.network.train()
284 | stime = datetime.datetime.now()
285 |
286 | correct = 0
287 | total = 0
288 |
289 | for index, (images, labels) in enumerate(train_loader):
290 |
291 | if self.args.cuda:
292 | images, labels = images.cuda(), labels.cuda()
293 |
294 | optimizer.zero_grad()
295 | output = self.network(Variable(images))
296 | loss = F.nll_loss(output, Variable(labels))
297 | loss.backward()
298 | optimizer.step()
299 |
300 | _, predicted = torch.max(output.data, 1)
301 | correct += torch.sum(predicted == labels)
302 | total += len(images)
303 |
304 | if index > 0 and index % 10 == 0:
305 | print('Epoch: {}/{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Accuracy: {:.2f}%'.format(
306 | epoch,
307 | self.args.epochs,
308 | index * len(images),
309 | len(train_loader.dataset),
310 | 100. * index / len(train_loader),
311 | loss.item(),
312 | 100 * correct / total
313 | ))
314 |
315 | print('\nEnd of epoch {}, time: {}'.format(epoch, datetime.datetime.now() - stime))
316 | acc = self.validate()
317 | mean += acc
318 | if acc > best:
319 | best = acc
320 | self.save()
321 |
322 | print('\nEnd of training, best accuracy: {}, mean accuracy: {}\n'.format(best, mean // epoch))
323 |
324 | def validate(self, verbose=True, roc=False):
325 | self.network.eval()
326 |
327 | if self._test_loader is None:
328 | self._test_loader = self._patch_loader(self.args.dataset_path + VALIDATION_PATH, False)
329 |
330 | val_loss = 0
331 | correct = 0
332 | classes = len(LABELS)
333 |
334 | tp = [0] * classes
335 | tpfp = [0] * classes
336 | tpfn = [0] * classes
337 | precision = [0] * classes
338 | recall = [0] * classes
339 | f1 = [0] * classes
340 |
341 | if verbose:
342 | print('\nEvaluating....')
343 |
344 | labels_true = []
345 | labels_pred = np.empty((0, 4))
346 |
347 | for images, labels in self._test_loader:
348 |
349 | if self.args.cuda:
350 | images, labels = images.cuda(), labels.cuda()
351 |
352 | output = self.network(Variable(images, volatile=True))
353 |
354 | val_loss += F.nll_loss(output, Variable(labels), reduction='sum').item()
355 | _, predicted = torch.max(output.data, 1)
356 | correct += torch.sum(predicted == labels)
357 |
358 | labels_true = np.append(labels_true, labels.cpu().numpy())
359 | labels_pred = np.append(labels_pred, torch.exp(output.data).cpu().numpy(), axis=0)
360 |
361 | for label in range(classes):
362 | t_labels = labels == label
363 | p_labels = predicted == label
364 | tp[label] += torch.sum(t_labels == (p_labels * 2 - 1))
365 | tpfp[label] += torch.sum(p_labels)
366 | tpfn[label] += torch.sum(t_labels)
367 |
368 | for label in range(classes):
369 | precision[label] += (tp[label] / (tpfp[label] + 1e-8))
370 | recall[label] += (tp[label] / (tpfn[label] + 1e-8))
371 | f1[label] = 2 * precision[label] * recall[label] / (precision[label] + recall[label] + 1e-8)
372 |
373 | val_loss /= len(self._test_loader.dataset)
374 | acc = 100. * correct / len(self._test_loader.dataset)
375 |
376 | if roc == 1:
377 | labels_true = label_binarize(labels_true, classes=range(classes))
378 | for lbl in range(classes):
379 | fpr, tpr, _ = roc_curve(labels_true[:, lbl], labels_pred[:, lbl])
380 | roc_auc = auc(fpr, tpr)
381 | plt.plot(fpr, tpr, lw=2, label='{} (AUC: {:.1f})'.format(LABELS[lbl], roc_auc * 100))
382 |
383 | plt.xlim([0, 1])
384 | plt.ylim([0, 1.05])
385 | plt.xlabel('False Positive Rate')
386 | plt.ylabel('True Positive Rate')
387 | plt.legend(loc="lower right")
388 | plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
389 | plt.title('Receiver Operating Characteristic')
390 | plt.show()
391 |
392 | if verbose:
393 | print('Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
394 | val_loss,
395 | correct,
396 | len(self._test_loader.dataset),
397 | acc
398 | ))
399 |
400 | for label in range(classes):
401 | print('{}: \t Precision: {:.2f}, Recall: {:.2f}, F1: {:.2f}'.format(
402 | LABELS[label],
403 | precision[label],
404 | recall[label],
405 | f1[label]
406 | ))
407 |
408 | print('')
409 |
410 | return acc
411 |
412 | def test(self, path, verbose=True, ensemble=True):
413 | self.network.eval()
414 | dataset = TestDataset(path=path, stride=PATCH_SIZE, augment=ensemble)
415 | data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
416 | stime = datetime.datetime.now()
417 |
418 | if verbose:
419 | print('')
420 |
421 | res = []
422 |
423 | for index, (image, file_name) in enumerate(data_loader):
424 | n_bins, n_patches = image.shape[1], image.shape[2]
425 | image = image.view(-1, 3, PATCH_SIZE, PATCH_SIZE)
426 |
427 | if self.args.cuda:
428 | image = image.cuda()
429 |
430 | patches = self.patch_wise_model.output(image)
431 | patches = patches.view(n_bins, -1, 64, 64)
432 |
433 | if self.args.cuda:
434 | patches = patches.cuda()
435 |
436 | output = self.network(patches)
437 | _, predicted = torch.max(output.data, 1)
438 |
439 | # maj_prob: majority voting: build one-hot vectors of the predicted values: (12, 4),
440 | # sum along axis 0: (4,), reverse, and take the index of the largest value (ties go to the higher class index)
441 |
442 | predicted = predicted.cpu().numpy().reshape(-1)  # move to host memory before handing to NumPy
443 | maj_prob = 3 - np.argmax(np.sum(np.eye(4)[predicted], axis=0)[::-1])
444 | confidence = np.sum(predicted == maj_prob) / n_bins if ensemble else torch.max(torch.exp(output.data)).item()
445 | confidence = np.round(confidence * 100, 2)
446 |
447 | res.append([maj_prob, confidence, file_name[0]])
448 |
449 | if verbose:
450 | print('{}) {} ({}%) \t {}'.format(
451 | str(index).rjust(2, '0'),
452 | LABELS[maj_prob],
453 | confidence,
454 | ntpath.basename(file_name[0])))
455 |
456 | if verbose:
457 | print('\nInference time: {}\n'.format(datetime.datetime.now() - stime))
458 |
459 | return res
460 |
461 | def _patch_loader(self, path, augment):
462 | images_path = '{}/{}_images.npy'.format(self.args.checkpoints_path, self.network.name())
463 | labels_path = '{}/{}_labels.npy'.format(self.args.checkpoints_path, self.network.name())
464 |
465 | if self.args.debug and augment and os.path.exists(images_path):
466 | np_images = np.load(images_path)
467 | np_labels = np.load(labels_path)
468 |
469 | else:
470 | dataset = ImageWiseDataset(
471 | path=path,
472 | stride=PATCH_SIZE,
473 | flip=augment,
474 | rotate=augment,
475 | enhance=augment)
476 |
477 | bsize = 8
478 | output_loader = DataLoader(dataset=dataset, batch_size=bsize, shuffle=True, num_workers=4)
479 | output_images = []
480 | output_labels = []
481 |
482 | for index, (images, labels) in enumerate(output_loader):
483 | if index > 0 and index % 10 == 0:
484 | print('{} images loaded'.format(int((index * bsize) / dataset.augment_size)))
485 |
486 | if self.args.cuda:
487 | images = images.cuda()
488 |
489 | bsize = images.shape[0]
490 |
491 | res = self.patch_wise_model.output(images.view((-1, 3, 512, 512)))
492 | res = res.view((bsize, -1, 64, 64)).data.cpu().numpy()
493 |
494 | for i in range(bsize):
495 | output_images.append(res[i])
496 | output_labels.append(labels.numpy()[i])
497 |
498 | np_images = np.array(output_images)
499 | np_labels = np.array(output_labels)
500 |
501 | if self.args.debug and augment:
502 | np.save(images_path, np_images)
503 | np.save(labels_path, np_labels)
504 |
505 | images, labels = torch.from_numpy(np_images), torch.from_numpy(np_labels).squeeze()
506 |
507 | return DataLoader(
508 | dataset=TensorDataset(images, labels),
509 | batch_size=self.args.batch_size,
510 | shuffle=True,
511 | num_workers=2
512 | )
513 |
--------------------------------------------------------------------------------
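The majority-voting line in `ImageWiseModel.test` is dense, so the following is a minimal standalone sketch of the same arithmetic in plain NumPy. The function name `majority_vote` and the sample votes are illustrative assumptions, not part of the repository. Reversing the count vector before `argmax` means that ties resolve toward the higher class index.

```python
import numpy as np


def majority_vote(predicted, n_classes=4):
    """Return (winning class, fraction of votes it received).

    Hypothetical helper mirroring the one-liner in ImageWiseModel.test.
    """
    predicted = np.asarray(predicted).reshape(-1)
    # One-hot encode each vote, then sum down the rows to get per-class counts.
    counts = np.eye(n_classes)[predicted].sum(axis=0)
    # argmax returns the first maximum; reversing first makes it pick the
    # last maximum of the original order, i.e. the highest tied class index.
    winner = (n_classes - 1) - int(np.argmax(counts[::-1]))
    confidence = float(counts[winner] / predicted.size)
    return winner, confidence


votes = [0, 3, 3, 1, 3, 1, 1, 2, 3, 1, 0, 2]  # 12 made-up predictions
print(majority_vote(votes))
```

In the sample, classes 1 and 3 both receive four votes, and the reversal makes class 3 win.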
/src/networks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class BaseNetwork(nn.Module):
7 | def __init__(self, name, channels=1):
8 | super(BaseNetwork, self).__init__()
9 | self._name = name
10 | self._channels = channels
11 |
12 | def name(self):
13 | return self._name
14 |
15 | def initialize_weights(self):
16 | for m in self.modules():
17 | if isinstance(m, nn.Conv2d):
18 | nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
19 | if m.bias is not None:
20 | m.bias.data.zero_()
21 |
22 | elif isinstance(m, nn.BatchNorm2d):
23 | m.weight.data.fill_(1)
24 | m.bias.data.zero_()
25 |
26 | elif isinstance(m, nn.Linear):
27 | m.weight.data.normal_(0, 0.01)
28 | m.bias.data.zero_()
29 |
30 |
31 | class PatchWiseNetwork(BaseNetwork):
32 | def __init__(self, channels=1):
33 | super(PatchWiseNetwork, self).__init__('pw' + str(channels), channels)
34 |
35 | self.features = nn.Sequential(
36 | # Block 1
37 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1),
38 | nn.BatchNorm2d(16),
39 | nn.ReLU(inplace=True),
40 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
41 | nn.BatchNorm2d(16),
42 | nn.ReLU(inplace=True),
43 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=2, stride=2),
44 | nn.BatchNorm2d(16),
45 | nn.ReLU(inplace=True),
46 |
47 | # Block 2
48 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
49 | nn.BatchNorm2d(32),
50 | nn.ReLU(inplace=True),
51 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
52 | nn.BatchNorm2d(32),
53 | nn.ReLU(inplace=True),
54 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, stride=2),
55 | nn.BatchNorm2d(32),
56 | nn.ReLU(inplace=True),
57 |
58 | # Block 3
59 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
60 | nn.BatchNorm2d(64),
61 | nn.ReLU(inplace=True),
62 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
63 | nn.BatchNorm2d(64),
64 | nn.ReLU(inplace=True),
65 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=2),
66 | nn.BatchNorm2d(64),
67 | nn.ReLU(inplace=True),
68 |
69 | # Block 4
70 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
71 | nn.BatchNorm2d(128),
72 | nn.ReLU(inplace=True),
73 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
74 | nn.BatchNorm2d(128),
75 | nn.ReLU(inplace=True),
76 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
77 | nn.BatchNorm2d(128),
78 | nn.ReLU(inplace=True),
79 |
80 | # Block 5
81 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
82 | nn.BatchNorm2d(256),
83 | nn.ReLU(inplace=True),
84 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
85 | nn.BatchNorm2d(256),
86 | nn.ReLU(inplace=True),
87 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
88 | nn.BatchNorm2d(256),
89 | nn.ReLU(inplace=True),
90 |
91 | nn.Conv2d(in_channels=256, out_channels=channels, kernel_size=1, stride=1),
92 | )
93 |
94 | self.classifier = nn.Sequential(
95 | nn.Linear(channels * 64 * 64, 4),
96 | )
97 |
98 | self.initialize_weights()
99 |
100 | def forward(self, x):
101 | x = self.features(x)
102 | x = x.view(x.size(0), -1)
103 | x = self.classifier(x)
104 | x = F.log_softmax(x, dim=1)
105 | return x
106 |
107 |
108 | class ImageWiseNetwork(BaseNetwork):
109 | def __init__(self, channels=1):
110 | super(ImageWiseNetwork, self).__init__('iw' + str(channels), channels)
111 |
112 | self.features = nn.Sequential(
113 | # Block 1
114 | nn.Conv2d(in_channels=12 * channels, out_channels=64, kernel_size=3, stride=1, padding=1),
115 | nn.BatchNorm2d(64),
116 | nn.ReLU(inplace=True),
117 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
118 | nn.BatchNorm2d(64),
119 | nn.ReLU(inplace=True),
120 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=2),
121 | nn.BatchNorm2d(64),
122 | nn.ReLU(inplace=True),
123 |
124 | # Block 2
125 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
126 | nn.BatchNorm2d(128),
127 | nn.ReLU(inplace=True),
128 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
129 | nn.BatchNorm2d(128),
130 | nn.ReLU(inplace=True),
131 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=2, stride=2),
132 | nn.BatchNorm2d(128),
133 | nn.ReLU(inplace=True),
134 |
135 | nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1),
136 | )
137 |
138 | self.classifier = nn.Sequential(
139 | nn.Linear(1 * 16 * 16, 128),
140 | nn.ReLU(inplace=True),
141 | nn.Dropout(0.5, inplace=True),
142 |
143 | nn.Linear(128, 128),
144 | nn.ReLU(inplace=True),
145 | nn.Dropout(0.5, inplace=True),
146 |
147 | nn.Linear(128, 64),
148 | nn.ReLU(inplace=True),
149 | nn.Dropout(0.5, inplace=True),
150 |
151 | nn.Linear(64, 4),
152 | )
153 |
154 | self.initialize_weights()
155 |
156 | def forward(self, x):
157 | x = self.features(x)
158 | x = x.view(x.size(0), -1)
159 | x = self.classifier(x)
160 | x = F.log_softmax(x, dim=1)
161 | return x
162 |
--------------------------------------------------------------------------------
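The `nn.Linear` input sizes in both networks (`channels * 64 * 64` and `1 * 16 * 16`) depend on how the convolution stacks shrink the input. The sketch below traces the spatial size through each stack using the standard convolution output formula with dilation 1. The helper names `conv_out` and `trace` are hypothetical, and the starting sizes of 512 and 64 are taken from the reshapes in `models.py`.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace(size, layers):
    """Apply a list of (kernel, stride, padding) tuples in order."""
    for kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
    return size


SAME = (3, 1, 1)   # 3x3 conv, stride 1, padding 1: keeps the size
DOWN = (2, 2, 0)   # 2x2 conv, stride 2: halves the size
POINT = (1, 1, 0)  # 1x1 conv: keeps the size

# PatchWiseNetwork.features: three downsampling blocks, two plain blocks, 1x1 head.
patch_wise = [SAME, SAME, DOWN] * 3 + [SAME] * 6 + [POINT]
# ImageWiseNetwork.features: two downsampling blocks, 1x1 head.
image_wise = [SAME, SAME, DOWN] * 2 + [POINT]

print(trace(512, patch_wise))  # 64
print(trace(64, image_wise))   # 16
```

Under these assumptions a 512x512 patch leaves the patch-wise feature stack as a 64x64 map, and a 64x64 input leaves the image-wise stack as a 16x16 map, which agrees with the classifier input sizes.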
/src/options.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import torch
4 | import argparse
5 |
6 |
7 | class ModelOptions:
8 | def __init__(self):
9 | parser = argparse.ArgumentParser(description='Classification of breast cancer histology')
10 | parser.add_argument('--dataset-path', type=str, default='./dataset', help='dataset path (default: ./dataset)')
11 | parser.add_argument('--testset-path', type=str, default='', help='file or directory address to the test set')
12 | parser.add_argument('--checkpoints-path', type=str, default='./checkpoints', help='models are saved here')
13 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)')
14 | parser.add_argument('--test-batch-size', type=int, default=64, metavar='N', help='input batch size for testing (default: 64)')
15 | parser.add_argument('--patch-stride', type=int, default=256, metavar='N', help='distance in pixels between the centers of two consecutive patches (default: 256)')
16 | parser.add_argument('--epochs', type=int, default=30, metavar='N', help='number of epochs to train (default: 30)')
17 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learning rate (default: 0.001)')
18 | parser.add_argument('--beta1', type=float, default=0.9, metavar='M', help='Adam beta1 (default: 0.9)')
19 | parser.add_argument('--beta2', type=float, default=0.999, metavar='M', help='Adam beta2 (default: 0.999)')
20 | parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
21 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
22 | parser.add_argument('--log-interval', type=int, default=50, metavar='N', help='how many batches to wait before logging training status')
23 | parser.add_argument('--gpu-ids', type=str, default='0', help='gpu ids, e.g. "0", "0,1,2" or "0,2"; use -1 for CPU')
24 | parser.add_argument('--ensemble', type=int, default='1', help='whether to use model ensemble on test-set prediction (default: 1)')
25 | parser.add_argument('--network', type=str, default='0', help='train patch-wise network: 1, image-wise network: 2 or both: 0 (default: 0)')
26 | parser.add_argument('--channels', type=int, default=1, help='number of channels created by the patch-wise network that feeds into the image-wise network (default: 1)')
27 | parser.add_argument('--debug', type=int, default=0, help='debugging (default: 0)')
28 |
29 | self._parser = parser
30 |
31 | def parse(self):
32 | opt = self._parser.parse_args()
33 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
34 | opt.cuda = not opt.no_cuda and torch.cuda.is_available()
35 | opt.debug = opt.debug != 0
36 |
37 | args = vars(opt)
38 | print('\n------------ Options -------------')
39 | for k, v in sorted(args.items()):
40 | print('%s: %s' % (str(k), str(v)))
41 | print('-------------- End ----------------\n')
42 |
43 | return opt
44 |
--------------------------------------------------------------------------------
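Two `argparse` behaviours are relied on in `options.py`: dashes in option names become underscores in attribute names, and a string default such as `default='1'` on `--ensemble` is passed through `type` before it is stored. The following sketch reproduces a few of the options in isolation. Passing an explicit list to `parse_args` is only done here so the example does not read `sys.argv`.

```python
import argparse

parser = argparse.ArgumentParser(description='argparse behaviour sketch')
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--ensemble', type=int, default='1')  # string default
parser.add_argument('--no-cuda', action='store_true', default=False)
parser.add_argument('--gpu-ids', type=str, default='0')

# No arguments: every option falls back to its default.
defaults = parser.parse_args([])
# Explicit arguments, as they would appear on the command line.
custom = parser.parse_args(['--lr', '0.01', '--no-cuda', '--gpu-ids', '0,2'])

print(defaults.lr, defaults.ensemble, defaults.no_cuda)
print(custom.lr, custom.no_cuda, custom.gpu_ids)
```

Because the default is converted, the comparison `args.ensemble == 1` used in `test.py` holds when the flag is omitted.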
/src/patch_extractor.py:
--------------------------------------------------------------------------------
1 | class PatchExtractor:
2 | def __init__(self, img, patch_size, stride):
3 | """
4 | :param img: :py:class:`~PIL.Image.Image`
5 | :param patch_size: integer, size of the patch
6 | :param stride: integer, size of the stride
7 | """
8 | self.img = img
9 | self.size = patch_size
10 | self.stride = stride
11 |
12 | def extract_patches(self):
13 | """
14 | extracts all patches from an image
15 | :returns: A list of :py:class:`~PIL.Image.Image` objects.
16 | """
17 | wp, hp = self.shape()
18 | return [self.extract_patch((w, h)) for h in range(hp) for w in range(wp)]
19 |
20 | def extract_patch(self, patch):
21 | """
22 | extracts a patch from an input image
23 | :param patch: a tuple
24 | :rtype: :py:class:`~PIL.Image.Image`
25 | :returns: An :py:class:`~PIL.Image.Image` object.
26 | """
27 | return self.img.crop((
28 | patch[0] * self.stride, # left
29 | patch[1] * self.stride, # up
30 | patch[0] * self.stride + self.size, # right
31 | patch[1] * self.stride + self.size # down
32 | ))
33 |
34 | def shape(self):
35 | wp = int((self.img.width - self.size) / self.stride + 1)
36 | hp = int((self.img.height - self.size) / self.stride + 1)
37 | return wp, hp
38 |
39 |
40 |
--------------------------------------------------------------------------------
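The grid arithmetic in `PatchExtractor.shape` and `extract_patch` can be checked without loading an image. The sketch below is a hypothetical stand-in, `patch_boxes`, that returns the crop boxes directly. The image size of 2048x1536 is an assumption made for illustration, and the patch size of 512 matches the 512x512 reshape in `models.py`. It assumes the image is at least as large as the patch.

```python
def patch_boxes(width, height, size, stride):
    """Crop boxes (left, upper, right, lower) in row-major order.

    Hypothetical stand-in for PatchExtractor that needs no PIL image.
    """
    wp = (width - size) // stride + 1   # patches per row
    hp = (height - size) // stride + 1  # patches per column
    return [
        (w * stride, h * stride, w * stride + size, h * stride + size)
        for h in range(hp)
        for w in range(wp)
    ]


# Assumed image size of 2048x1536 (not stated in this listing).
non_overlapping = patch_boxes(2048, 1536, size=512, stride=512)
overlapping = patch_boxes(2048, 1536, size=512, stride=256)

print(len(non_overlapping), len(overlapping))
print(non_overlapping[0], non_overlapping[-1])
```

With the assumed image size, a stride equal to the patch size gives a 4x3 grid of 12 patches, which would be consistent with the `12 * channels` input of `ImageWiseNetwork`. The default stride of 256 gives a 7x5 grid of 35 overlapping patches.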
/test.py:
--------------------------------------------------------------------------------
1 | from src import *
2 |
3 | args = ModelOptions().parse()
4 |
5 | torch.manual_seed(args.seed)
6 | if args.cuda:
7 | torch.cuda.manual_seed(args.seed)
8 |
9 | pw_network = PatchWiseNetwork(args.channels)
10 | iw_network = ImageWiseNetwork(args.channels)
11 |
12 | if args.testset_path == '':
13 | import tkinter.filedialog as fdialog
14 |
15 | args.testset_path = fdialog.askopenfilename(initialdir=r"./dataset/test", title="choose your file", filetypes=(("tiff files", "*.tif"), ("all files", "*.*")))
16 |
17 | if args.network == '1':
18 | pw_model = PatchWiseModel(args, pw_network)
19 | pw_model.test(args.testset_path)
20 |
21 | else:
22 | im_model = ImageWiseModel(args, iw_network, pw_network)
23 | im_model.test(args.testset_path, ensemble=args.ensemble == 1)
24 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from src import *
2 |
3 | args = ModelOptions().parse()
4 |
5 | torch.manual_seed(args.seed)
6 | if args.cuda:
7 | torch.cuda.manual_seed(args.seed)
8 |
9 | pw_network = PatchWiseNetwork(args.channels)
10 | iw_network = ImageWiseNetwork(args.channels)
11 |
12 | if args.network == '0' or args.network == '1':
13 | pw_model = PatchWiseModel(args, pw_network)
14 | pw_model.train()
15 |
16 | if args.network == '0' or args.network == '2':
17 | iw_model = ImageWiseModel(args, iw_network, pw_network)
18 | iw_model.train()
19 |
--------------------------------------------------------------------------------
/validate.py:
--------------------------------------------------------------------------------
1 | from src import *
2 |
3 | args = ModelOptions().parse()
4 |
5 | torch.manual_seed(args.seed)
6 | if args.cuda:
7 | torch.cuda.manual_seed(args.seed)
8 |
9 | pw_network = PatchWiseNetwork(args.channels)
10 | iw_network = ImageWiseNetwork(args.channels)
11 |
12 | if args.network == '0' or args.network == '1':
13 | pw_model = PatchWiseModel(args, pw_network)
14 | pw_model.validate()
15 |
16 | if args.network == '0' or args.network == '2':
17 | iw_model = ImageWiseModel(args, iw_network, pw_network)
18 | iw_model.validate(roc=True)
19 |
--------------------------------------------------------------------------------
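The per-class precision, recall and F1 values printed by `validate` follow the usual definitions. A minimal NumPy sketch of that bookkeeping is shown below. The function name `per_class_metrics` and the sample labels are illustrative assumptions, and the small `eps` term mirrors the `1e-8` used in `models.py` to avoid division by zero.

```python
import numpy as np


def per_class_metrics(y_true, y_pred, n_classes, eps=1e-8):
    """Return a list of (precision, recall, f1) tuples, one per class."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    results = []
    for label in range(n_classes):
        t = y_true == label                 # samples that truly have this label
        p = y_pred == label                 # samples predicted as this label
        tp = np.sum(t & p)                  # true positives
        precision = tp / (np.sum(p) + eps)  # tp / (tp + fp)
        recall = tp / (np.sum(t) + eps)     # tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall + eps)
        results.append((float(precision), float(recall), float(f1)))
    return results


y_true = [0, 0, 1, 1, 2, 2, 3, 3]  # made-up ground truth
y_pred = [0, 1, 1, 1, 2, 0, 3, 3]  # made-up predictions
for label, (prec, rec, f1) in enumerate(per_class_metrics(y_true, y_pred, 4)):
    print('{}: precision {:.2f}, recall {:.2f}, F1 {:.2f}'.format(label, prec, rec, f1))
```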