├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
│   ├── weights_iw1.pth
│   ├── weights_iw16.pth
│   ├── weights_iw3.pth
│   ├── weights_iw4.pth
│   ├── weights_pw1.pth
│   ├── weights_pw16.pth
│   ├── weights_pw3.pth
│   └── weights_pw4.pth
├── dataset
│   ├── test
│   │   └── .gitignore
│   ├── train
│   │   ├── Benign
│   │   │   └── .gitignore
│   │   ├── InSitu
│   │   │   └── .gitignore
│   │   ├── Invasive
│   │   │   └── .gitignore
│   │   └── Normal
│   │       └── .gitignore
│   └── validation
│       ├── Benign
│       │   └── .gitignore
│       ├── InSitu
│       │   └── .gitignore
│       ├── Invasive
│       │   └── .gitignore
│       └── Normal
│           └── .gitignore
├── img
│   ├── dataset.jpg
│   └── network.png
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── datasets.py
│   ├── models.py
│   ├── networks.py
│   ├── options.py
│   └── patch_extractor.py
├── test.py
├── train.py
└── validate.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | 104 | .idea 105 | draft 106 | checkpoints -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Imaging Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ICIAR2018 2 | ### Two-Stage Convolutional Neural Network for Breast Cancer Histology Image Classification 3 | This repository addresses part A of the ICIAR 2018 Grand Challenge on BreAst Cancer Histology (BACH) images: automatically classifying H&E stained breast histology microscopy images into four classes (normal, benign, in situ carcinoma and invasive carcinoma). 4 | 5 | We present a CNN approach that uses two convolutional networks to classify histology images in a patch-wise fashion. The first (patch-wise) network receives overlapping patches (35 patches) of the whole microscopy image and learns to generate spatially smaller outputs. The second (image-wise) network is trained on these downsampled outputs, computed by the first network over the patches of the whole image. The number of channels in the input to the second network equals the total number of patches extracted from the microscopy image in a non-overlapping fashion (12 patches) times the depth of the feature maps generated by the first network (C): 6 |
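The patch counts above follow from the image and patch sizes in `src/datasets.py` (`IMAGE_SIZE = (2048, 1536)`, `PATCH_SIZE = 512`) and the 64x64 feature maps that `src/models.py` reshapes for the image-wise network. The short sketch below reproduces that arithmetic. The stride of 256 for the overlapping patches is an assumption (it is the stride that yields 35 patches), and `patch_grid` and `image_wise_input_shape` are illustrative helpers, not functions from this repository.

```python
# Sketch: patch counts and image-wise input depth for ICIAR 2018 microscopy images.
IMAGE_W, IMAGE_H = 2048, 1536  # IMAGE_SIZE in src/datasets.py
PATCH = 512                    # PATCH_SIZE in src/datasets.py
FEATURE_SIDE = 64              # side of each patch-wise feature map (64x64 in src/models.py)


def patch_grid(stride, width=IMAGE_W, height=IMAGE_H, patch=PATCH):
    """Patches along x and y for a sliding window (same formula as PatchWiseDataset)."""
    return (width - patch) // stride + 1, (height - patch) // stride + 1


def image_wise_input_shape(channels, stride=PATCH):
    """(depth, height, width) of one image-wise input: n_patches * C stacked 64x64 maps."""
    wp, hp = patch_grid(stride)
    return wp * hp * channels, FEATURE_SIDE, FEATURE_SIDE


if __name__ == '__main__':
    # Assumed stride of 256 (50% overlap) for the patch-wise training patches.
    print(patch_grid(256))            # (7, 5) -> 35 overlapping patches
    print(patch_grid(512))            # (4, 3) -> 12 non-overlapping patches
    print(image_wise_input_shape(3))  # (36, 64, 64) for C = 3
```

With C = 16, for example, the image-wise network would receive 12 x 16 = 192 input channels.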


9 | 10 | ## Prerequisites 11 | - Linux 12 | - Python 3 13 | - NVIDIA GPU (12 GB or 24 GB memory) + CUDA and cuDNN 14 | 15 | ## Getting Started 16 | ### Installation 17 | - Clone this repo: 18 | ```bash 19 | git clone https://github.com/ImagingLab/ICIAR2018 20 | cd ICIAR2018 21 | ``` 22 | - Install PyTorch and torchvision (both are imported by `src/` but not listed in `requirements.txt`) from http://pytorch.org 23 | - Install the Python requirements: 24 | ```bash 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ### Dataset 29 | - We use the ICIAR2018 dataset. To train a model on the full dataset, please download it from the [official website](https://iciar2018-challenge.grand-challenge.org/dataset/) (registration required). The dataset is composed of 400 high-resolution Hematoxylin and Eosin (H&E) stained breast histology microscopy images labelled as normal, benign, in situ carcinoma, and invasive carcinoma (100 images for each category): 30 |
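The loaders in `src/datasets.py` derive labels from the class sub-folder names (`Normal`, `Benign`, `InSitu`, `Invasive`) under the train and validation splits, and read `test` as a flat folder of `.tif` files. A small check like the one below can confirm that the downloaded images landed where those loaders look for them. It is a sketch that assumes it is run from the repository root; `count_images` is a hypothetical helper, not part of the repository.

```python
import glob
import os

# Class folder names, in the same order as LABELS in src/datasets.py.
LABELS = ['Normal', 'Benign', 'InSitu', 'Invasive']


def count_images(split_dir):
    """Map each class folder under split_dir (e.g. dataset/train) to its number of .tif files."""
    return {label: len(glob.glob(os.path.join(split_dir, label, '*.tif'))) for label in LABELS}


if __name__ == '__main__':
    for split in ('train', 'validation'):
        counts = count_images(os.path.join('dataset', split))
        print(split, counts, 'total:', sum(counts.values()))
    # The test folder is flat (no class sub-folders), as TestDataset expects.
    print('test', len(glob.glob(os.path.join('dataset', 'test', '*.tif'))))
```

The script only reports counts; it does not assume any particular train/validation split of the 400 images.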


33 | After downloading, place the images under the `dataset` folder, following the sub-directories provided (`train/<Class>`, `validation/<Class>` and `test`). 34 | 35 | 36 | ### Testing 37 | - The pre-trained model for the ICIAR2018 dataset resides under `./checkpoints` (`weights_pw<C>.pth` for the patch-wise network and `weights_iw<C>.pth` for the image-wise network, with C = 1, 3, 4 or 16). 38 | - To test the model, run the `test.py` script. 39 | - Use the `--testset-path` command-line argument to provide the path to the `test` folder: 40 | ```bash 41 | python test.py --testset-path ./dataset/test 42 | ``` 43 | - If you don't provide the test-set path, an open-file dialog box will appear so you can select a single image to test. 44 | The test results are printed to the screen. 45 | 46 | 47 | 48 | ### Training 49 | - To train the model, run the `train.py` script: 50 | ```bash 51 | python train.py 52 | ``` 53 | - To change the number of feature maps generated by the patch-wise network, use the `--channels` argument: 54 | ```bash 55 | python train.py --channels 1 56 | ``` 57 | 58 | 59 | ### Validation & ROC Curves 60 | - To validate the model on the validation set and plot the ROC curves, run the `validate.py` script: 61 | ```bash 62 | python validate.py 63 | ``` 64 | - To change the number of feature maps generated by the patch-wise network, use the `--channels` argument: 65 | ```bash 66 | python validate.py --channels 1 67 | ``` 68 | 69 | ## Citation 70 | If you use this code for your research, please cite our paper, Two-Stage Convolutional Neural Network for Breast Cancer Histology Image Classification: 71 | 72 | ``` 73 | @inproceedings{nazeri2018two, 74 | title={Two-Stage Convolutional Neural Network for Breast Cancer Histology Image Classification}, 75 | author={Nazeri, Kamyar and Aminpour, Azad and Ebrahimi, Mehran}, 76 | booktitle={International Conference Image Analysis and Recognition}, 77 | pages={717--726}, 78 | year={2018}, 79 | organization={Springer} 80 | } 81 | ``` 82 | -------------------------------------------------------------------------------- /checkpoints/weights_iw1.pth: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_iw1.pth -------------------------------------------------------------------------------- /checkpoints/weights_iw16.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_iw16.pth -------------------------------------------------------------------------------- /checkpoints/weights_iw3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_iw3.pth -------------------------------------------------------------------------------- /checkpoints/weights_iw4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_iw4.pth -------------------------------------------------------------------------------- /checkpoints/weights_pw1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_pw1.pth -------------------------------------------------------------------------------- /checkpoints/weights_pw16.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_pw16.pth -------------------------------------------------------------------------------- /checkpoints/weights_pw3.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_pw3.pth -------------------------------------------------------------------------------- /checkpoints/weights_pw4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/checkpoints/weights_pw4.pth -------------------------------------------------------------------------------- /dataset/test/.gitignore: -------------------------------------------------------------------------------- 1 | *.tif -------------------------------------------------------------------------------- /dataset/train/Benign/.gitignore: -------------------------------------------------------------------------------- 1 | *.tif -------------------------------------------------------------------------------- /dataset/train/InSitu/.gitignore: -------------------------------------------------------------------------------- 1 | *.tif -------------------------------------------------------------------------------- /dataset/train/Invasive/.gitignore: -------------------------------------------------------------------------------- 1 | *.tif -------------------------------------------------------------------------------- /dataset/train/Normal/.gitignore: -------------------------------------------------------------------------------- 1 | *.tif -------------------------------------------------------------------------------- /dataset/validation/Benign/.gitignore: -------------------------------------------------------------------------------- 1 | *.tif -------------------------------------------------------------------------------- /dataset/validation/InSitu/.gitignore: -------------------------------------------------------------------------------- 1 | *.tif -------------------------------------------------------------------------------- 
/dataset/validation/Invasive/.gitignore: -------------------------------------------------------------------------------- 1 | *.tif -------------------------------------------------------------------------------- /dataset/validation/Normal/.gitignore: -------------------------------------------------------------------------------- 1 | *.tif -------------------------------------------------------------------------------- /img/dataset.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/img/dataset.jpg -------------------------------------------------------------------------------- /img/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/ICIAR2018/0a98dd9c21e3069052a38bc75973832b3a5e775e/img/network.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy ~=1.14.3 2 | scipy ~= 1.0.1 3 | future ~= 0.16.0 4 | matplotlib ~= 2.2.2 5 | pillow >= 6.2.0 6 | scikit-learn ~= 0.19.1 7 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .networks import * 3 | from .options import * 4 | from .datasets import * 5 | from .patch_extractor import * 6 | -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import numpy as np 5 | from PIL import Image, ImageEnhance 6 | from torch.utils.data import Dataset 7 | from torchvision.transforms import transforms 8 | from .patch_extractor import PatchExtractor 
9 | 10 | LABELS = ['Normal', 'Benign', 'InSitu', 'Invasive'] 11 | IMAGE_SIZE = (2048, 1536) 12 | PATCH_SIZE = 512 13 | 14 | 15 | class PatchWiseDataset(Dataset): 16 | def __init__(self, path, stride=PATCH_SIZE, rotate=False, flip=False, enhance=False): 17 | super().__init__() 18 | 19 | wp = int((IMAGE_SIZE[0] - PATCH_SIZE) / stride + 1) 20 | hp = int((IMAGE_SIZE[1] - PATCH_SIZE) / stride + 1) 21 | labels = {name: index for index in range(len(LABELS)) for name in glob.glob(path + '/' + LABELS[index] + '/*.tif')} 22 | 23 | self.path = path 24 | self.stride = stride 25 | self.labels = labels 26 | self.names = list(sorted(labels.keys())) 27 | self.shape = (len(labels), wp, hp, (4 if rotate else 1), (2 if flip else 1), (2 if enhance else 1)) # (files, x_patches, y_patches, rotations, flip, enhance) 28 | self.augment_size = np.prod(self.shape) / len(labels) 29 | 30 | def __getitem__(self, index): 31 | im, xpatch, ypatch, rotation, flip, enhance = np.unravel_index(index, self.shape) 32 | 33 | with Image.open(self.names[im]) as img: 34 | extractor = PatchExtractor(img=img, patch_size=PATCH_SIZE, stride=self.stride) 35 | patch = extractor.extract_patch((xpatch, ypatch)) 36 | 37 | if rotation != 0: 38 | patch = patch.rotate(rotation * 90) 39 | 40 | if flip != 0: 41 | patch = patch.transpose(Image.FLIP_LEFT_RIGHT) 42 | 43 | if enhance != 0: 44 | factors = np.random.uniform(.5, 1.5, 3) 45 | patch = ImageEnhance.Color(patch).enhance(factors[0]) 46 | patch = ImageEnhance.Contrast(patch).enhance(factors[1]) 47 | patch = ImageEnhance.Brightness(patch).enhance(factors[2]) 48 | 49 | label = self.labels[self.names[im]] 50 | return transforms.ToTensor()(patch), label 51 | 52 | def __len__(self): 53 | return np.prod(self.shape) 54 | 55 | 56 | class ImageWiseDataset(Dataset): 57 | def __init__(self, path, stride=PATCH_SIZE, rotate=False, flip=False, enhance=False): 58 | super().__init__() 59 | 60 | labels = {name: index for index in range(len(LABELS)) for name in glob.glob(path + '/' + 
LABELS[index] + '/*.tif')} 61 | 62 | self.path = path 63 | self.stride = stride 64 | self.labels = labels 65 | self.names = list(sorted(labels.keys())) 66 | self.shape = (len(labels), (4 if rotate else 1), (2 if flip else 1), (2 if enhance else 1)) # (files, rotations, flip, enhance) 67 | self.augment_size = np.prod(self.shape) / len(labels) 68 | 69 | def __getitem__(self, index): 70 | im, rotation, flip, enhance = np.unravel_index(index, self.shape) 71 | 72 | with Image.open(self.names[im]) as img: 73 | 74 | if flip != 0: 75 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 76 | 77 | if rotation != 0: 78 | img = img.rotate(rotation * 90)  # Image.rotate keeps the 2048x1536 canvas, so 90/270-degree turns crop the image and pad it with black 79 | 80 | if enhance != 0: 81 | factors = np.random.uniform(.5, 1.5, 3) 82 | img = ImageEnhance.Color(img).enhance(factors[0]) 83 | img = ImageEnhance.Contrast(img).enhance(factors[1]) 84 | img = ImageEnhance.Brightness(img).enhance(factors[2]) 85 | 86 | extractor = PatchExtractor(img=img, patch_size=PATCH_SIZE, stride=self.stride) 87 | patches = extractor.extract_patches() 88 | 89 | label = self.labels[self.names[im]] 90 | 91 | b = torch.zeros((len(patches), 3, PATCH_SIZE, PATCH_SIZE)) 92 | for i in range(len(patches)): 93 | b[i] = transforms.ToTensor()(patches[i]) 94 | 95 | return b, label 96 | 97 | def __len__(self): 98 | return np.prod(self.shape) 99 | 100 | 101 | class TestDataset(Dataset): 102 | def __init__(self, path, stride=PATCH_SIZE, augment=False): 103 | super().__init__() 104 | 105 | if os.path.isdir(path): 106 | names = [name for name in glob.glob(path + '/*.tif')] 107 | else: 108 | names = [path] 109 | 110 | self.path = path 111 | self.stride = stride 112 | self.augment = augment 113 | self.names = list(sorted(names)) 114 | 115 | def __getitem__(self, index): 116 | file = self.names[index] 117 | with Image.open(file) as img: 118 | 119 | bins = 8 if self.augment else 1 120 | extractor = PatchExtractor(img=img, patch_size=PATCH_SIZE, stride=self.stride) 121 | b = torch.zeros((bins,
extractor.shape()[0] * extractor.shape()[1], 3, PATCH_SIZE, PATCH_SIZE)) 122 | 123 | for k in range(bins): 124 | aug = img  # each view starts from the original image so rotations and flips do not accumulate across k 125 | if k % 4 != 0: 126 | aug = aug.rotate((k % 4) * 90) 127 | 128 | if k // 4 != 0: 129 | aug = aug.transpose(Image.FLIP_LEFT_RIGHT) 130 | 131 | extractor = PatchExtractor(img=aug, patch_size=PATCH_SIZE, stride=self.stride) 132 | patches = extractor.extract_patches() 133 | 134 | for i in range(len(patches)): 135 | b[k, i] = transforms.ToTensor()(patches[i]) 136 | 137 | return b, file 138 | 139 | def __len__(self): 140 | return len(self.names) 141 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import ntpath 4 | import datetime 5 | import matplotlib.pyplot as plt 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | import matplotlib.pyplot as ply 9 | from torch.autograd import Variable 10 | from torch.utils.data import DataLoader, TensorDataset 11 | from sklearn.metrics import roc_curve, auc 12 | from sklearn.preprocessing import label_binarize 13 | 14 | from .datasets import * 15 | 16 | TRAIN_PATH = '/train' 17 | VALIDATION_PATH = '/validation' 18 | 19 | 20 | class BaseModel: 21 | def __init__(self, args, network, weights_path): 22 | self.args = args 23 | self.weights = weights_path 24 | self.network = network.cuda() if args.cuda else network 25 | self.load() 26 | 27 | def load(self): 28 | try: 29 | if os.path.exists(self.weights): 30 | print('Loading model from "{}"...'.format(self.weights)) 31 | self.network.load_state_dict(torch.load(self.weights, map_location=lambda storage, loc: storage))  # map to CPU first so GPU-saved checkpoints also load on CPU-only machines 32 | except Exception as e: 33 | print('Failed to load pre-trained network: {}'.format(e)) 34 | 35 | def save(self): 36 | print('Saving model to "{}"'.format(self.weights)) 37 | torch.save(self.network.state_dict(), self.weights) 38 | 39 | 40 | class PatchWiseModel(BaseModel): 41 | def __init__(self, args, network): 42 | super(PatchWiseModel, self).__init__(args,
network, args.checkpoints_path + '/weights_' + network.name() + '.pth') 43 | 44 | def train(self): 45 | self.network.train() 46 | print('Start training patch-wise network: {}\n'.format(time.strftime('%Y/%m/%d %H:%M'))) 47 | 48 | train_loader = DataLoader( 49 | dataset=PatchWiseDataset(path=self.args.dataset_path + TRAIN_PATH, stride=self.args.patch_stride, rotate=True, flip=True, enhance=True), 50 | batch_size=self.args.batch_size, 51 | shuffle=True, 52 | num_workers=4 53 | ) 54 | optimizer = optim.Adam(self.network.parameters(), lr=self.args.lr, betas=(self.args.beta1, self.args.beta2)) 55 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1) 56 | best = self.validate(verbose=False) 57 | mean = 0 58 | epoch = 0 59 | 60 | for epoch in range(1, self.args.epochs + 1): 61 | 62 | self.network.train() 63 | scheduler.step() 64 | stime = datetime.datetime.now() 65 | 66 | correct = 0 67 | total = 0 68 | 69 | for index, (images, labels) in enumerate(train_loader): 70 | 71 | if self.args.cuda: 72 | images, labels = images.cuda(), labels.cuda() 73 | 74 | optimizer.zero_grad() 75 | output = self.network(Variable(images)) 76 | loss = F.nll_loss(output, Variable(labels)) 77 | loss.backward() 78 | optimizer.step() 79 | 80 | _, predicted = torch.max(output.data, 1) 81 | correct += torch.sum(predicted == labels) 82 | total += len(images) 83 | 84 | if index > 0 and index % self.args.log_interval == 0: 85 | print('Epoch: {}/{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Accuracy: {:.2f}%'.format( 86 | epoch, 87 | self.args.epochs, 88 | index * len(images), 89 | len(train_loader.dataset), 90 | 100. 
* index / len(train_loader), 91 | loss.item(), 92 | 100 * correct / total 93 | )) 94 | 95 | print('\nEnd of epoch {}, time: {}'.format(epoch, datetime.datetime.now() - stime)) 96 | acc = self.validate() 97 | mean += acc 98 | if acc > best: 99 | best = acc 100 | 101 | self.save() 102 | 103 | print('\nEnd of training, best accuracy: {}, mean accuracy: {}\n'.format(best, mean / epoch)) 104 | 105 | def validate(self, verbose=True): 106 | self.network.eval() 107 | 108 | test_loss = 0 109 | correct = 0 110 | classes = len(LABELS) 111 | 112 | tp = [0] * classes 113 | tpfp = [0] * classes 114 | tpfn = [0] * classes 115 | precision = [0] * classes 116 | recall = [0] * classes 117 | f1 = [0] * classes 118 | 119 | test_loader = DataLoader( 120 | dataset=PatchWiseDataset(path=self.args.dataset_path + VALIDATION_PATH, stride=self.args.patch_stride), 121 | batch_size=self.args.batch_size, 122 | shuffle=False, 123 | num_workers=4 124 | ) 125 | if verbose: 126 | print('\nEvaluating....') 127 | 128 | for images, labels in test_loader: 129 | 130 | if self.args.cuda: 131 | images, labels = images.cuda(), labels.cuda() 132 | 133 | output = self.network(Variable(images, volatile=True)) 134 | 135 | test_loss += F.nll_loss(output, Variable(labels), size_average=False).item() 136 | _, predicted = torch.max(output.data, 1) 137 | correct += torch.sum(predicted == labels) 138 | 139 | for label in range(classes): 140 | t_labels = labels == label 141 | p_labels = predicted == label 142 | tp[label] += torch.sum(t_labels == (p_labels * 2 - 1))  # true positives: p * 2 - 1 never equals t when p == 0, so only t == p == 1 matches 143 | tpfp[label] += torch.sum(p_labels) 144 | tpfn[label] += torch.sum(t_labels) 145 | 146 | for label in range(classes): 147 | precision[label] += (tp[label] / (tpfp[label] + 1e-8)) 148 | recall[label] += (tp[label] / (tpfn[label] + 1e-8)) 149 | f1[label] = 2 * precision[label] * recall[label] / (precision[label] + recall[label] + 1e-8) 150 | 151 | test_loss /= len(test_loader.dataset) 152 | acc = 100.
* correct / len(test_loader.dataset) 153 | 154 | if verbose: 155 | print('Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format( 156 | test_loss, 157 | correct, 158 | len(test_loader.dataset), 159 | 100. * correct / len(test_loader.dataset) 160 | )) 161 | 162 | for label in range(classes): 163 | print('{}: \t Precision: {:.2f}, Recall: {:.2f}, F1: {:.2f}'.format( 164 | LABELS[label], 165 | precision[label], 166 | recall[label], 167 | f1[label] 168 | )) 169 | 170 | print('') 171 | return acc 172 | 173 | def test(self, path, verbose=True): 174 | self.network.eval() 175 | dataset = TestDataset(path=path, stride=PATCH_SIZE, augment=False) 176 | data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False) 177 | stime = datetime.datetime.now() 178 | 179 | if verbose: 180 | print('\t sum\t\t max\t\t maj') 181 | 182 | res = [] 183 | 184 | for index, (image, file_name) in enumerate(data_loader): 185 | image = image.squeeze() 186 | if self.args.cuda: 187 | image = image.cuda() 188 | 189 | output = self.network(Variable(image)) 190 | _, predicted = torch.max(output.data, 1) 191 | 192 | # 193 | # the following measures are prioritised based on [invasive, insitu, benign, normal] 194 | # the original labels are [normal, benign, insitu, invasive], so we reverse the order using [::-1] 195 | # output data shape is 12x4 196 | # sum_prop: sum of probabilities among y axis: (1, 4), reverse, and take the index of the largest value 197 | # max_prop: max of probabilities among y axis: (1, 4), reverse, and take the index of the largest value 198 | # maj_prop: majority voting: create a one-hot vector of predicted values: (12, 4), sum among y axis: (1, 4), reverse, and take the index of the largest value 199 | 200 | sum_prob = 3 - np.argmax(np.sum(np.exp(output.data.cpu().numpy()), axis=0)[::-1]) 201 | max_prob = 3 - np.argmax(np.max(np.exp(output.data.cpu().numpy()), axis=0)[::-1]) 202 | maj_prob = 3 - np.argmax(np.sum(np.eye(4)[np.array(predicted).reshape(-1)], axis=0)[::-1]) 
203 | 204 | res.append([sum_prob, max_prob, maj_prob, file_name[0]]) 205 | 206 | if verbose: 207 | 208 | print('{}) \t {} \t {} \t {} \t {}'.format( 209 | str(index + 1).rjust(2, '0'), 210 | LABELS[sum_prob].ljust(8), 211 | LABELS[max_prob].ljust(8), 212 | LABELS[maj_prob].ljust(8), 213 | ntpath.basename(file_name[0]))) 214 | 215 | if verbose: 216 | print('\nInference time: {}\n'.format(datetime.datetime.now() - stime)) 217 | 218 | return res 219 | 220 | def output(self, input_tensor): 221 | self.network.eval() 222 | res = self.network.features(Variable(input_tensor, volatile=True)) 223 | return res.squeeze() 224 | 225 | def visualize(self, path, channel=0): 226 | self.network.eval() 227 | dataset = TestDataset(path=path, stride=PATCH_SIZE) 228 | data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False) 229 | 230 | for index, (image, file_name) in enumerate(data_loader): 231 | image = image.squeeze()  # (1, 1, 12, 3, 512, 512) -> (12, 3, 512, 512), as in test() 232 | if self.args.cuda: 233 | image = image.cuda() 234 | 235 | patches = self.output(image) 236 | 237 | output = patches.cpu().data.numpy() 238 | 239 | feature_map = np.zeros((3 * 64, 4 * 64)) 240 | 241 | for i in range(12): 242 | row = i // 4 243 | col = i % 4 244 | patch = output[i] if output[i].ndim == 2 else output[i][channel]  # select one feature map when C > 1 245 | feature_map[row * 64:(row + 1) * 64, col * 64:(col + 1) * 64] = patch 246 | 247 | 248 | 249 | with Image.open(file_name[0]) as img: 250 | ply.subplot(121) 251 | ply.axis('off') 252 | ply.imshow(np.array(img)) 253 | 254 | ply.subplot(122) 255 | ply.imshow(feature_map, cmap='gray') 256 | ply.axis('off') 257 | 258 | ply.show() 259 | 260 | 261 | class ImageWiseModel(BaseModel): 262 | def __init__(self, args, image_wise_network, patch_wise_network): 263 | super(ImageWiseModel, self).__init__(args, image_wise_network, args.checkpoints_path + '/weights_' + image_wise_network.name() + '.pth') 264 | 265 | self.patch_wise_model = PatchWiseModel(args, patch_wise_network) 266 | self._test_loader = None 267 | 268 | def train(self):
self.network.train() 270 | print('Evaluating patch-wise model...') 271 | 272 | train_loader = self._patch_loader(self.args.dataset_path + TRAIN_PATH, True) 273 | 274 | print('Start training image-wise network: {}\n'.format(time.strftime('%Y/%m/%d %H:%M'))) 275 | 276 | optimizer = optim.Adam(self.network.parameters(), lr=self.args.lr, betas=(self.args.beta1, self.args.beta2)) 277 | best = self.validate(verbose=False) 278 | mean = 0 279 | epoch = 0 280 | 281 | for epoch in range(1, self.args.epochs + 1): 282 | 283 | self.network.train() 284 | stime = datetime.datetime.now() 285 | 286 | correct = 0 287 | total = 0 288 | 289 | for index, (images, labels) in enumerate(train_loader): 290 | 291 | if self.args.cuda: 292 | images, labels = images.cuda(), labels.cuda() 293 | 294 | optimizer.zero_grad() 295 | output = self.network(Variable(images)) 296 | loss = F.nll_loss(output, Variable(labels)) 297 | loss.backward() 298 | optimizer.step() 299 | 300 | _, predicted = torch.max(output.data, 1) 301 | correct += torch.sum(predicted == labels) 302 | total += len(images) 303 | 304 | if index > 0 and index % 10 == 0: 305 | print('Epoch: {}/{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Accuracy: {:.2f}%'.format( 306 | epoch, 307 | self.args.epochs, 308 | index * len(images), 309 | len(train_loader.dataset), 310 | 100. 
* index / len(train_loader), 311 | loss.item(), 312 | 100 * correct / total 313 | )) 314 | 315 | print('\nEnd of epoch {}, time: {}'.format(epoch, datetime.datetime.now() - stime)) 316 | acc = self.validate() 317 | mean += acc 318 | if acc > best: 319 | best = acc 320 | self.save() 321 | 322 | print('\nEnd of training, best accuracy: {}, mean accuracy: {}\n'.format(best, mean / epoch)) 323 | 324 | def validate(self, verbose=True, roc=False): 325 | self.network.eval() 326 | 327 | if self._test_loader is None: 328 | self._test_loader = self._patch_loader(self.args.dataset_path + VALIDATION_PATH, False) 329 | 330 | val_loss = 0 331 | correct = 0 332 | classes = len(LABELS) 333 | 334 | tp = [0] * classes 335 | tpfp = [0] * classes 336 | tpfn = [0] * classes 337 | precision = [0] * classes 338 | recall = [0] * classes 339 | f1 = [0] * classes 340 | 341 | if verbose: 342 | print('\nEvaluating....') 343 | 344 | labels_true = [] 345 | labels_pred = np.empty((0, 4)) 346 | 347 | for images, labels in self._test_loader: 348 | 349 | if self.args.cuda: 350 | images, labels = images.cuda(), labels.cuda() 351 | 352 | output = self.network(Variable(images, volatile=True)) 353 | 354 | val_loss += F.nll_loss(output, Variable(labels), size_average=False).item() 355 | _, predicted = torch.max(output.data, 1) 356 | correct += torch.sum(predicted == labels) 357 | 358 | labels_true = np.append(labels_true, labels.cpu().numpy()) 359 | labels_pred = np.append(labels_pred, torch.exp(output.data).cpu().numpy(), axis=0) 360 | 361 | for label in range(classes): 362 | t_labels = labels == label 363 | p_labels = predicted == label 364 | tp[label] += torch.sum(t_labels == (p_labels * 2 - 1))  # true positives: p * 2 - 1 never equals t when p == 0, so only t == p == 1 matches 365 | tpfp[label] += torch.sum(p_labels) 366 | tpfn[label] += torch.sum(t_labels) 367 | 368 | for label in range(classes): 369 | precision[label] += (tp[label] / (tpfp[label] + 1e-8)) 370 | recall[label] += (tp[label] / (tpfn[label] + 1e-8)) 371 | f1[label] = 2 * precision[label] * recall[label] / (precision[label] +
recall[label] + 1e-8) 372 | 373 | val_loss /= len(self._test_loader.dataset) 374 | acc = 100. * correct / len(self._test_loader.dataset) 375 | 376 | if roc == 1: 377 | labels_true = label_binarize(labels_true, classes=range(classes)) 378 | for lbl in range(classes): 379 | fpr, tpr, _ = roc_curve(labels_true[:, lbl], labels_pred[:, lbl]) 380 | roc_auc = auc(fpr, tpr) 381 | plt.plot(fpr, tpr, lw=2, label='{} (AUC: {:.1f})'.format(LABELS[lbl], roc_auc * 100)) 382 | 383 | plt.xlim([0, 1]) 384 | plt.ylim([0, 1.05]) 385 | plt.xlabel('False Positive Rate') 386 | plt.ylabel('True Positive Rate') 387 | plt.legend(loc="lower right") 388 | plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--') 389 | plt.title('Receiver Operating Characteristic') 390 | plt.show() 391 | 392 | if verbose: 393 | print('Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format( 394 | val_loss, 395 | correct, 396 | len(self._test_loader.dataset), 397 | acc 398 | )) 399 | 400 | for label in range(classes): 401 | print('{}: \t Precision: {:.2f}, Recall: {:.2f}, F1: {:.2f}'.format( 402 | LABELS[label], 403 | precision[label], 404 | recall[label], 405 | f1[label] 406 | )) 407 | 408 | print('') 409 | 410 | return acc 411 | 412 | def test(self, path, verbose=True, ensemble=True): 413 | self.network.eval() 414 | dataset = TestDataset(path=path, stride=PATCH_SIZE, augment=ensemble) 415 | data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False) 416 | stime = datetime.datetime.now() 417 | 418 | if verbose: 419 | print('') 420 | 421 | res = [] 422 | 423 | for index, (image, file_name) in enumerate(data_loader): 424 | n_bins, n_patches = image.shape[1], image.shape[2] 425 | image = image.view(-1, 3, PATCH_SIZE, PATCH_SIZE) 426 | 427 | if self.args.cuda: 428 | image = image.cuda() 429 | 430 | patches = self.patch_wise_model.output(image) 431 | patches = patches.view(n_bins, -1, 64, 64) 432 | 433 | if self.args.cuda: 434 | patches = patches.cuda() 435 | 436 | output = self.network(patches) 437 
| _, predicted = torch.max(output.data, 1) 438 | 439 | # maj_prob: majority voting: one-hot encode the predicted labels, shape (n_bins, 4), 440 | # sum along axis 0 to get the votes per class, shape (4,), reverse so that ties go to the higher class index, and take the index of the largest count 441 | 442 | maj_prob = 3 - np.argmax(np.sum(np.eye(4)[predicted.cpu().numpy().reshape(-1)], axis=0)[::-1]) 443 | 444 | confidence = np.sum(predicted.cpu().numpy() == maj_prob) / n_bins if ensemble else torch.max(torch.exp(output.data)).item() 445 | confidence = np.round(confidence * 100, 2) 446 | 447 | res.append([maj_prob, confidence, file_name[0]]) 448 | 449 | if verbose: 450 | print('{}) {} ({}%) \t {}'.format( 451 | str(index).rjust(2, '0'), 452 | LABELS[maj_prob], 453 | confidence, 454 | ntpath.basename(file_name[0]))) 455 | 456 | if verbose: 457 | print('\nInference time: {}\n'.format(datetime.datetime.now() - stime)) 458 | 459 | return res 460 | 461 | def _patch_loader(self, path, augment): 462 | images_path = '{}/{}_images.npy'.format(self.args.checkpoints_path, self.network.name()) 463 | labels_path = '{}/{}_labels.npy'.format(self.args.checkpoints_path, self.network.name()) 464 | 465 | if self.args.debug and augment and os.path.exists(images_path): 466 | np_images = np.load(images_path) 467 | np_labels = np.load(labels_path) 468 | 469 | else: 470 | dataset = ImageWiseDataset( 471 | path=path, 472 | stride=PATCH_SIZE, 473 | flip=augment, 474 | rotate=augment, 475 | enhance=augment) 476 | 477 | bsize = 8 478 | output_loader = DataLoader(dataset=dataset, batch_size=bsize, shuffle=True, num_workers=4) 479 | output_images = [] 480 | output_labels = [] 481 | 482 | for index, (images, labels) in enumerate(output_loader): 483 | if index > 0 and index % 10 == 0: 484 | print('{} images loaded'.format(int((index * bsize) / dataset.augment_size))) 485 | 486 | if self.args.cuda: 487 | images = images.cuda() 488 | 489 | bsize = images.shape[0] 490 | 491 | res = self.patch_wise_model.output(images.view((-1, 3, 512, 512))) 492 | res = 
res.view((bsize, -1, 64,
64)).data.cpu().numpy() 493 | 494 | for i in range(bsize): 495 | output_images.append(res[i]) 496 | output_labels.append(labels.numpy()[i]) 497 | 498 | np_images = np.array(output_images) 499 | np_labels = np.array(output_labels) 500 | 501 | if self.args.debug and augment: 502 | np.save(images_path, np_images) 503 | np.save(labels_path, np_labels) 504 | 505 | images, labels = torch.from_numpy(np_images), torch.from_numpy(np_labels).squeeze() 506 | 507 | return DataLoader( 508 | dataset=TensorDataset(images, labels), 509 | batch_size=self.args.batch_size, 510 | shuffle=True, 511 | num_workers=2 512 | ) 513 | -------------------------------------------------------------------------------- /src/networks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BaseNetwork(nn.Module): 7 | def __init__(self, name, channels=1): 8 | super(BaseNetwork, self).__init__() 9 | self._name = name 10 | self._channels = channels 11 | 12 | def name(self): 13 | return self._name 14 | 15 | def initialize_weights(self): 16 | for m in self.modules(): 17 | if isinstance(m, nn.Conv2d): 18 | nn.init.kaiming_normal_(m.weight, nonlinearity='relu') 19 | if m.bias is not None: 20 | m.bias.data.zero_() 21 | 22 | elif isinstance(m, nn.BatchNorm2d): 23 | m.weight.data.fill_(1) 24 | m.bias.data.zero_() 25 | 26 | elif isinstance(m, nn.Linear): 27 | m.weight.data.normal_(0, 0.01) 28 | m.bias.data.zero_() 29 | 30 | 31 | class PatchWiseNetwork(BaseNetwork): 32 | def __init__(self, channels=1): 33 | super(PatchWiseNetwork, self).__init__('pw' + str(channels), channels) 34 | 35 | self.features = nn.Sequential( 36 | # Block 1 37 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1), 38 | nn.BatchNorm2d(16), 39 | nn.ReLU(inplace=True), 40 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), 41 | nn.BatchNorm2d(16), 42 | 
nn.ReLU(inplace=True), 43 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=2, stride=2), 44 | nn.BatchNorm2d(16), 45 | nn.ReLU(inplace=True), 46 | 47 | # Block 2 48 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1), 49 | nn.BatchNorm2d(32), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 52 | nn.BatchNorm2d(32), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, stride=2), 55 | nn.BatchNorm2d(32), 56 | nn.ReLU(inplace=True), 57 | 58 | # Block 3 59 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1), 60 | nn.BatchNorm2d(64), 61 | nn.ReLU(inplace=True), 62 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 63 | nn.BatchNorm2d(64), 64 | nn.ReLU(inplace=True), 65 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=2), 66 | nn.BatchNorm2d(64), 67 | nn.ReLU(inplace=True), 68 | 69 | # Block 4 70 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 71 | nn.BatchNorm2d(128), 72 | nn.ReLU(inplace=True), 73 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 74 | nn.BatchNorm2d(128), 75 | nn.ReLU(inplace=True), 76 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 77 | nn.BatchNorm2d(128), 78 | nn.ReLU(inplace=True), 79 | 80 | # Block 5 81 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1), 82 | nn.BatchNorm2d(256), 83 | nn.ReLU(inplace=True), 84 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 85 | nn.BatchNorm2d(256), 86 | nn.ReLU(inplace=True), 87 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 88 | nn.BatchNorm2d(256), 89 | nn.ReLU(inplace=True), 90 | 91 | nn.Conv2d(in_channels=256, out_channels=channels, kernel_size=1, stride=1), 92 | ) 93 | 94 | self.classifier 
= nn.Sequential( 95 | nn.Linear(channels * 64 * 64, 4), 96 | ) 97 | 98 | self.initialize_weights() 99 | 100 | def forward(self, x): 101 | x = self.features(x) 102 | x = x.view(x.size(0), -1) 103 | x = self.classifier(x) 104 | x = F.log_softmax(x, dim=1) 105 | return x 106 | 107 | 108 | class ImageWiseNetwork(BaseNetwork): 109 | def __init__(self, channels=1): 110 | super(ImageWiseNetwork, self).__init__('iw' + str(channels), channels) 111 | 112 | self.features = nn.Sequential( 113 | # Block 1 114 | nn.Conv2d(in_channels=12 * channels, out_channels=64, kernel_size=3, stride=1, padding=1), 115 | nn.BatchNorm2d(64), 116 | nn.ReLU(inplace=True), 117 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 118 | nn.BatchNorm2d(64), 119 | nn.ReLU(inplace=True), 120 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=2), 121 | nn.BatchNorm2d(64), 122 | nn.ReLU(inplace=True), 123 | 124 | # Block 2 125 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 126 | nn.BatchNorm2d(128), 127 | nn.ReLU(inplace=True), 128 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 129 | nn.BatchNorm2d(128), 130 | nn.ReLU(inplace=True), 131 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=2, stride=2), 132 | nn.BatchNorm2d(128), 133 | nn.ReLU(inplace=True), 134 | 135 | nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1), 136 | ) 137 | 138 | self.classifier = nn.Sequential( 139 | nn.Linear(1 * 16 * 16, 128), 140 | nn.ReLU(inplace=True), 141 | nn.Dropout(0.5, inplace=True), 142 | 143 | nn.Linear(128, 128), 144 | nn.ReLU(inplace=True), 145 | nn.Dropout(0.5, inplace=True), 146 | 147 | nn.Linear(128, 64), 148 | nn.ReLU(inplace=True), 149 | nn.Dropout(0.5, inplace=True), 150 | 151 | nn.Linear(64, 4), 152 | ) 153 | 154 | self.initialize_weights() 155 | 156 | def forward(self, x): 157 | x = self.features(x) 158 | x = x.view(x.size(0), -1) 159 | x = self.classifier(x) 160 | 
x = F.log_softmax(x, dim=1) 161 | return x 162 | -------------------------------------------------------------------------------- /src/options.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import torch 4 | import argparse 5 | 6 | 7 | class ModelOptions: 8 | def __init__(self): 9 | parser = argparse.ArgumentParser(description='Classification of breast cancer histology') 10 | parser.add_argument('--dataset-path', type=str, default='./dataset', help='dataset path (default: ./dataset)') 11 | parser.add_argument('--testset-path', type=str, default='', help='file or directory address to the test set') 12 | parser.add_argument('--checkpoints-path', type=str, default='./checkpoints', help='models are saved here') 13 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)') 14 | parser.add_argument('--test-batch-size', type=int, default=64, metavar='N', help='input batch size for testing (default: 64)') 15 | parser.add_argument('--patch-stride', type=int, default=256, metavar='N', help='How far the centers of two consecutive patches are in the image (default: 256)') 16 | parser.add_argument('--epochs', type=int, default=30, metavar='N', help='number of epochs to train (default: 30)') 17 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learning rate (default: 0.001)') 18 | parser.add_argument('--beta1', type=float, default=0.9, metavar='M', help='Adam beta1 (default: 0.9)') 19 | parser.add_argument('--beta2', type=float, default=0.999, metavar='M', help='Adam beta2 (default: 0.999)') 20 | parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') 21 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') 22 | parser.add_argument('--log-interval', type=int, default=50, metavar='N', help='how many batches
to wait before logging training status') 23 | parser.add_argument('--gpu-ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 24 | parser.add_argument('--ensemble', type=int, default=1, help='whether to use model ensemble on test-set prediction (default: 1)') 25 | parser.add_argument('--network', type=str, default='0', help='train patch-wise network: 1, image-wise network: 2 or both: 0 (default: 0)') 26 | parser.add_argument('--channels', type=int, default=1, help='number of channels created by the patch-wise network that feeds into the image-wise network (default: 1)') 27 | parser.add_argument('--debug', type=int, default=0, help='debugging (default: 0)') 28 | 29 | self._parser = parser 30 | 31 | def parse(self): 32 | opt = self._parser.parse_args() 33 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids 34 | opt.cuda = not opt.no_cuda and torch.cuda.is_available() 35 | opt.debug = opt.debug != 0 36 | 37 | args = vars(opt) 38 | print('\n------------ Options -------------') 39 | for k, v in sorted(args.items()): 40 | print('%s: %s' % (str(k), str(v))) 41 | print('-------------- End ----------------\n') 42 | 43 | return opt 44 | -------------------------------------------------------------------------------- /src/patch_extractor.py: -------------------------------------------------------------------------------- 1 | class PatchExtractor: 2 | def __init__(self, img, patch_size, stride): 3 | ''' 4 | :param img: :py:class:`~PIL.Image.Image` 5 | :param patch_size: integer, size of the patch 6 | :param stride: integer, size of the stride 7 | ''' 8 | self.img = img 9 | self.size = patch_size 10 | self.stride = stride 11 | 12 | def extract_patches(self): 13 | """ 14 | extracts all patches from an image 15 | :returns: A list of :py:class:`~PIL.Image.Image` objects.
16 | """ 17 | wp, hp = self.shape() 18 | return [self.extract_patch((w, h)) for h in range(hp) for w in range(wp)] 19 | 20 | def extract_patch(self, patch): 21 | """ 22 | extracts a patch from an input image 23 | :param patch: a tuple 24 | :rtype: :py:class:`~PIL.Image.Image` 25 | :returns: An :py:class:`~PIL.Image.Image` object. 26 | """ 27 | return self.img.crop(( 28 | patch[0] * self.stride, # left 29 | patch[1] * self.stride, # upper 30 | patch[0] * self.stride + self.size, # right 31 | patch[1] * self.stride + self.size # lower 32 | )) 33 | 34 | def shape(self): 35 | wp = int((self.img.width - self.size) / self.stride + 1) 36 | hp = int((self.img.height - self.size) / self.stride + 1) 37 | return wp, hp 38 | 39 | 40 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from src import * 2 | 3 | args = ModelOptions().parse() 4 | 5 | torch.manual_seed(args.seed) 6 | if args.cuda: 7 | torch.cuda.manual_seed(args.seed) 8 | 9 | pw_network = PatchWiseNetwork(args.channels) 10 | iw_network = ImageWiseNetwork(args.channels) 11 | 12 | if args.testset_path == '': 13 | import tkinter.filedialog as fdialog 14 | 15 | args.testset_path = fdialog.askopenfilename(initialdir=r"./dataset/test", title="choose your file", filetypes=(("tiff files", "*.tif"), ("all files", "*.*"))) 16 | 17 | if args.network == '1': 18 | pw_model = PatchWiseModel(args, pw_network) 19 | pw_model.test(args.testset_path) 20 | 21 | else: 22 | im_model = ImageWiseModel(args, iw_network, pw_network) 23 | im_model.test(args.testset_path, ensemble=args.ensemble == 1) 24 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from src import * 2 | 3 | args = ModelOptions().parse() 4 | 5 | torch.manual_seed(args.seed) 6 | if args.cuda: 7 |
torch.cuda.manual_seed(args.seed) 8 | 9 | pw_network = PatchWiseNetwork(args.channels) 10 | iw_network = ImageWiseNetwork(args.channels) 11 | 12 | if args.network == '0' or args.network == '1': 13 | pw_model = PatchWiseModel(args, pw_network) 14 | pw_model.train() 15 | 16 | if args.network == '0' or args.network == '2': 17 | iw_model = ImageWiseModel(args, iw_network, pw_network) 18 | iw_model.train() 19 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | from src import * 2 | 3 | args = ModelOptions().parse() 4 | 5 | torch.manual_seed(args.seed) 6 | if args.cuda: 7 | torch.cuda.manual_seed(args.seed) 8 | 9 | pw_network = PatchWiseNetwork(args.channels) 10 | iw_network = ImageWiseNetwork(args.channels) 11 | 12 | if args.network == '0' or args.network == '1': 13 | pw_model = PatchWiseModel(args, pw_network) 14 | pw_model.validate() 15 | 16 | if args.network == '0' or args.network == '2': 17 | iw_model = ImageWiseModel(args, iw_network, pw_network) 18 | iw_model.validate(roc=True) 19 | --------------------------------------------------------------------------------
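The `validate` method above reports per-class precision, recall and F1, and adds `1e-8` to the F1 denominator so that an empty class does not cause a division by zero. The lines that compute precision and recall fall outside this excerpt. The sketch below therefore assumes the usual one-vs-rest definitions built from confusion counts. `per_class_scores` is a hypothetical helper that uses only the standard library and is not part of the repository:

```python
def per_class_scores(y_true, y_pred, n_classes=4, eps=1e-8):
    """Per-class precision, recall and F1 plus overall accuracy (%).

    Assumes one-vs-rest counting: a wrong prediction is a false positive
    for the predicted class and a false negative for the true class.
    """
    tp = [0] * n_classes
    fp = [0] * n_classes
    fn = [0] * n_classes
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1
            fn[t] += 1
    # eps keeps classes with no predictions or no samples at 0.0
    # instead of raising ZeroDivisionError.
    precision = [tp[c] / (tp[c] + fp[c] + eps) for c in range(n_classes)]
    recall = [tp[c] / (tp[c] + fn[c] + eps) for c in range(n_classes)]
    f1 = [2 * p * r / (p + r + eps) for p, r in zip(precision, recall)]
    # Same definition as `acc = 100. * correct / len(dataset)` above.
    accuracy = 100.0 * sum(tp) / len(y_true)
    return precision, recall, f1, accuracy
```

In this sketch, a class that is never predicted correctly, such as class 3 with `y_true=[0, 0, 1, 1, 2, 3]` and `y_pred=[0, 1, 1, 1, 2, 2]`, receives 0.0 for all three scores instead of raising an error.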
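In `test`, each image is classified by a majority vote over its per-bin predictions. The expression `3 - np.argmax(counts[::-1])` reverses the vote counts before taking the argmax. Because `argmax` returns the first maximum it finds, the reversal makes ties go to the higher class index. A minimal stdlib sketch of the same rule, assuming four classes as in the code above. `majority_vote` is a hypothetical name, not a function in the repository:

```python
from collections import Counter


def majority_vote(predictions, n_classes=4):
    """Return (label, confidence %) for a list of integer class predictions.

    Mirrors `3 - argmax(votes[::-1])`: the most-voted class wins and,
    on a tie, the largest class index is chosen.
    """
    counts = Counter(predictions)
    votes = [counts.get(c, 0) for c in range(n_classes)]
    best = max(votes)
    # Among the classes tied for the most votes, keep the highest index.
    label = max(c for c in range(n_classes) if votes[c] == best)
    # Confidence is the share of votes the winning class received.
    confidence = round(100 * votes[label] / len(predictions), 2)
    return label, confidence
```

Suppose eight votes of `[1, 2, 1, 2, 0, 0, 3, 3]` are cast. Every class then gets two votes, and the sketch returns `(3, 25.0)`, which is the same tie-break the NumPy expression produces.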
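The linear layers in `src/networks.py` get their input sizes from the strided convolutions. Each `kernel_size=2, stride=2` convolution halves the spatial size, while the `kernel_size=3, padding=1` and 1×1 convolutions leave it unchanged. The sketch below applies the standard output-size formula, out = floor((in + 2p - k) / s) + 1, to layer lists transcribed by hand from the two `features` stacks. The lists are not imported from the repository, and the helper names are hypothetical:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for each Conv2d in PatchWiseNetwork.features
PATCH_WISE = ([(3, 1, 1), (3, 1, 1), (2, 2, 0)] * 3  # blocks 1-3 downsample
              + [(3, 1, 1)] * 6                      # blocks 4-5 keep the size
              + [(1, 1, 0)])                         # final 1x1 projection

# ImageWiseNetwork.features: two downsampling blocks, then a 1x1 projection
IMAGE_WISE = [(3, 1, 1), (3, 1, 1), (2, 2, 0)] * 2 + [(1, 1, 0)]


def feature_size(size, layers):
    """Propagate a square input size through a list of conv specs."""
    for kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
    return size
```

A 512 × 512 patch comes out of the patch-wise stack as a `channels × 64 × 64` map, which matches `nn.Linear(channels * 64 * 64, 4)`. The stacked 64 × 64 patch maps come out of the image-wise stack as 16 × 16, which matches `nn.Linear(1 * 16 * 16, 128)`.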
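`PatchExtractor.shape` computes the number of patches per axis as `int((size - patch_size) / stride + 1)`, and `extract_patches` walks those patches row by row. The first convolution of `ImageWiseNetwork` takes `12 * channels` input channels. That count is consistent with a 4 × 3 grid of 512-pixel patches at stride 512, the stride that `test` passes to `TestDataset`, but only if the input images are 2048 × 1536 pixels. That image size is an assumption made here and is not stated in this code. A stdlib sketch of the same arithmetic, with hypothetical helper names:

```python
def patch_grid(width, height, patch_size, stride):
    """Patches along (x, y), as in PatchExtractor.shape().

    Floor division matches the original int(...) truncation whenever
    the image is at least as large as the patch.
    """
    wp = (width - patch_size) // stride + 1
    hp = (height - patch_size) // stride + 1
    return wp, hp


def patch_boxes(width, height, patch_size, stride):
    """Crop boxes (left, upper, right, lower) in the row-major order
    used by PatchExtractor.extract_patches()."""
    wp, hp = patch_grid(width, height, patch_size, stride)
    return [(w * stride, h * stride, w * stride + patch_size, h * stride + patch_size)
            for h in range(hp) for w in range(wp)]
```

Under the same 2048 × 1536 assumption, the `--patch-stride` default of 256 would give a 7 × 5 grid of 35 overlapping patches per image.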