├── README.md
├── Warmup_with_SimSiam
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── byol_pytorch
│   │   ├── __init__.py
│   │   ├── byol_pytorch.py
│   │   ├── dataset.py
│   │   └── model.py
│   ├── diagram.png
│   ├── examples
│   │   ├── dataset.py
│   │   └── lightning
│   │       ├── README.md
│   │       └── train.py
│   ├── setup.py
│   └── warmup.py
├── args_parser.py
├── finaledgegcncopy.py
├── inference
│   ├── 001.png
│   ├── 002.png
│   ├── 003.png
│   ├── 004.png
│   ├── 005.png
│   ├── 006.png
│   └── 007.png
├── input_data
│   ├── 001.png
│   ├── 002.png
│   ├── 003.png
│   ├── 004.png
│   ├── 005.png
│   ├── 006.png
│   └── 007.png
├── model.py
├── state_dict
│   ├── FCN1000.pt
│   ├── FCNopt1000.pt
│   ├── GCN1000.pt
│   └── GCNopt1000.pt
├── train_process_files
│   └── .keep
└── utilities.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Joint Fully Convolutional and Graph Convolutional Networks for Weakly-Supervised Segmentation of Pathology Images
 2 | 
 3 | # Instructions
 4 | A trained checkpoint numbered 1000 is provided, together with seven HER2 pathology images for use in inference.
 5 | This checkpoint was trained on 226 HER2 pathology images from a private dataset.
 6 | 
 7 | ## To run inference:
 8 | `python3 finaledgegcncopy.py --inference-path full_path_to/inference --checkpoint xxxx`
 9 | For example, with the provided images and state dict, run:
10 | `python3 finaledgegcncopy.py --inference-path full_path_to/inference --checkpoint 1000`
11 | 
12 | ## To train:
13 | `python3 finaledgegcncopy.py`
14 | 
15 | ## To resume training:
16 | `python3 finaledgegcncopy.py --checkpoint xxxx`
17 | 
18 | ## Flags and folders:
19 | 
20 | `--train-path` or `./train_process_files`: the folder to which the pipeline saves training visualization files.
21 | 
22 | `--input-path` or `./input_data`: images used for inference or training. For our weakly supervised loss to work, the training images must be named AreaRatio_Uncertainty_*.png.
23 | For example, 0.4_0.05_*.png means the target region occupies (40 +/- 5)% of the image.
24 | 
25 | `--inference-path` or `full_path_to/inference`: the folder to which the pipeline writes the inferred masks.
26 | Setting this argument switches on inference mode; it must be used together with `--checkpoint`.
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | 
 6 | # C extensions
 7 | *.so
 8 | 
 9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | #  Usually these files are written by a python script from a template
32 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /Warmup_with_SimSiam/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Warmup_with_SimSiam/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Bootstrap Your Own Latent (BYOL), in Pytorch 4 | 5 | [![PyPI version](https://badge.fury.io/py/byol-pytorch.svg)](https://badge.fury.io/py/byol-pytorch) 6 | 7 | Practical implementation of an astoundingly simple method for self-supervised learning that achieves a new state of the art (surpassing SimCLR) without contrastive learning and having to designate negative pairs. 8 | 9 | This repository offers a module that one can easily wrap any image-based neural network (residual network, discriminator, policy network) to immediately start benefitting from unlabelled image data. 10 | 11 | Update 1: There is now new evidence that batch normalization is key to making this technique work well 12 | 13 | Update 2: A new paper has successfully replaced batch norm with group norm + weight standardization, refuting that batch statistics are needed for BYOL to work 14 | 15 | Yannic Kilcher's excellent explanation 16 | 17 | ## Install 18 | 19 | ```bash 20 | $ pip install byol-pytorch 21 | ``` 22 | 23 | ## Usage 24 | 25 | Simply plugin your neural network, specifying (1) the image dimensions as well as (2) the name (or index) of the hidden layer, whose output is used as the latent representation used for self-supervised training. 26 | 27 | ```python 28 | import torch 29 | from byol_pytorch import BYOL 30 | from torchvision import models 31 | 32 | resnet = models.resnet50(pretrained=True) 33 | 34 | learner = BYOL( 35 | resnet, 36 | image_size = 256, 37 | hidden_layer = 'avgpool' 38 | ) 39 | 40 | opt = torch.optim.Adam(learner.parameters(), lr=3e-4) 41 | 42 | def sample_unlabelled_images(): 43 | return torch.randn(20, 3, 256, 256) 44 | 45 | for _ in range(100): 46 | images = sample_unlabelled_images() 47 | loss = learner(images) 48 | opt.zero_grad() 49 | loss.backward() 50 | opt.step() 51 | learner.update_moving_average() # update moving average of target encoder 52 | 53 | # save your improved network 54 | torch.save(resnet.state_dict(), './improved-net.pt') 55 | ``` 56 | 57 | That's pretty much it. After much training, the residual network should now perform better on its downstream supervised tasks. 58 | 59 | ## BYOL → SimSiam 60 | 61 | A new paper from Kaiming He suggests that BYOL does not even need the target encoder to be an exponential moving average of the online encoder. I've decided to build in this option so that you can easily use that variant for training, simply by setting the `use_momentum` flag to `False`. You will no longer need to invoke `update_moving_average` if you go this route as shown in the example below. 
62 | 63 | ```python 64 | import torch 65 | from byol_pytorch import BYOL 66 | from torchvision import models 67 | 68 | resnet = models.resnet50(pretrained=True) 69 | 70 | learner = BYOL( 71 | resnet, 72 | image_size = 256, 73 | hidden_layer = 'avgpool', 74 | use_momentum = False # turn off momentum in the target encoder 75 | ) 76 | 77 | opt = torch.optim.Adam(learner.parameters(), lr=3e-4) 78 | 79 | def sample_unlabelled_images(): 80 | return torch.randn(20, 3, 256, 256) 81 | 82 | for _ in range(100): 83 | images = sample_unlabelled_images() 84 | loss = learner(images) 85 | opt.zero_grad() 86 | loss.backward() 87 | opt.step() 88 | 89 | # save your improved network 90 | torch.save(resnet.state_dict(), './improved-net.pt') 91 | ``` 92 | 93 | ## Advanced 94 | 95 | While the hyperparameters have already been set to what the paper has found optimal, you can change them with extra keyword arguments to the base wrapper class. 96 | 97 | ```python 98 | learner = BYOL( 99 | resnet, 100 | image_size = 256, 101 | hidden_layer = 'avgpool', 102 | projection_size = 256, # the projection size 103 | projection_hidden_size = 4096, # the hidden dimension of the MLP for both the projection and prediction 104 | moving_average_decay = 0.99 # the moving average decay factor for the target encoder, already set at what paper recommends 105 | ) 106 | ``` 107 | 108 | By default, this library will use the augmentations from the SimCLR paper (which is also used in the BYOL paper). However, if you would like to specify your own augmentation pipeline, you can simply pass in your own custom augmentation function with the `augment_fn` keyword. 109 | 110 | Augmentations must work in the tensor space. `kornia` library is highly recommended for this. If you decide to use torchvision augmentations, make sure the tensor is first converted to PIL `.toPILImage()`, and then back to tensors `.ToTensor()` 111 | 112 | ```python 113 | augment_fn = nn.Sequential( 114 | kornia.augmentation.RandomHorizontalFlip() 115 | ) 116 | 117 | learner = BYOL( 118 | resnet, 119 | image_size = 256, 120 | hidden_layer = -2, 121 | augment_fn = augment_fn 122 | ) 123 | ``` 124 | 125 | In the paper, they seem to assure that one of the augmentations have a higher gaussian blur probability than the other. You can also adjust this to your heart's delight. 126 | 127 | ```python 128 | augment_fn = nn.Sequential( 129 | kornia.augmentation.RandomHorizontalFlip() 130 | ) 131 | 132 | augment_fn2 = nn.Sequential( 133 | kornia.augmentation.RandomHorizontalFlip(), 134 | kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5)) 135 | ) 136 | 137 | learner = BYOL( 138 | resnet, 139 | image_size = 256, 140 | hidden_layer = -2, 141 | augment_fn = augment_fn, 142 | augment_fn2 = augment_fn2, 143 | ) 144 | ``` 145 | 146 | ## Citation 147 | 148 | ```bibtex 149 | @misc{grill2020bootstrap, 150 | title = {Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning}, 151 | author = {Jean-Bastien Grill and Florian Strub and Florent Altché and Corentin Tallec and Pierre H. 
Richemond and Elena Buchatskaya and Carl Doersch and Bernardo Avila Pires and Zhaohan Daniel Guo and Mohammad Gheshlaghi Azar and Bilal Piot and Koray Kavukcuoglu and Rémi Munos and Michal Valko}, 152 | year = {2020}, 153 | eprint = {2006.07733}, 154 | archivePrefix = {arXiv}, 155 | primaryClass = {cs.LG} 156 | } 157 | ``` 158 | 159 | ```bibtex 160 | @misc{chen2020exploring, 161 | title={Exploring Simple Siamese Representation Learning}, 162 | author={Xinlei Chen and Kaiming He}, 163 | year={2020}, 164 | eprint={2011.10566}, 165 | archivePrefix={arXiv}, 166 | primaryClass={cs.CV} 167 | } 168 | ``` 169 | -------------------------------------------------------------------------------- /Warmup_with_SimSiam/byol_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from byol_pytorch.byol_pytorch import BYOL 2 | -------------------------------------------------------------------------------- /Warmup_with_SimSiam/byol_pytorch/byol_pytorch.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | from functools import wraps 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | from kornia import augmentation as augs 10 | from kornia import filters, color 11 | 12 | # helper functions 13 | 14 | def default(val, def_val): 15 | return def_val if val is None else val 16 | 17 | def flatten(t): 18 | return t.reshape(t.shape[0], -1) 19 | 20 | def singleton(cache_key): 21 | def inner_fn(fn): 22 | @wraps(fn) 23 | def wrapper(self, *args, **kwargs): 24 | instance = getattr(self, cache_key) 25 | if instance is not None: 26 | return instance 27 | 28 | instance = fn(self, *args, **kwargs) 29 | setattr(self, cache_key, instance) 30 | return instance 31 | return wrapper 32 | return inner_fn 33 | 34 | def get_module_device(module): 35 | return next(module.parameters()).device 36 | 37 | def set_requires_grad(model, val): 38 | for p in model.parameters(): 39 | p.requires_grad = val 40 | 41 | # loss fn 42 | 43 | def loss_fn(x, y): 44 | x = F.normalize(x, dim=1, p=2) 45 | y = F.normalize(y, dim=1, p=2) 46 | return 2 - 2 * (x * y).sum(dim=1) 47 | 48 | # augmentation utils 49 | 50 | class RandomApply(nn.Module): 51 | def __init__(self, fn, p): 52 | super().__init__() 53 | self.fn = fn 54 | self.p = p 55 | def forward(self, x): 56 | if random.random() > self.p: 57 | return x 58 | return self.fn(x) 59 | 60 | # exponential moving average 61 | 62 | class EMA(): 63 | def __init__(self, beta): 64 | super().__init__() 65 | self.beta = beta 66 | 67 | def update_average(self, old, new): 68 | if old is None: 69 | return new 70 | return old * self.beta + (1 - self.beta) * new 71 | 72 | def update_moving_average(ema_updater, ma_model, current_model): 73 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 74 | old_weight, up_weight = ma_params.data, current_params.data 75 | ma_params.data = ema_updater.update_average(old_weight, up_weight) 76 | 77 | # MLP class for projector and predictor 78 | 79 | class MLP(nn.Module): 80 | def __init__(self, dim, projection_size, hidden_size = 4096): 81 | super().__init__() 82 | self.net = nn.Sequential( 83 | nn.Linear(dim, hidden_size), 84 | nn.BatchNorm1d(hidden_size), 85 | nn.ReLU(inplace=True), 86 | nn.Linear(hidden_size, projection_size) 87 | ) 88 | 89 | def forward(self, x): 90 | return self.net(x) 91 | 92 | # a wrapper class for the base neural network 93 | # will manage the interception of the hidden 
layer output 94 | # and pipe it into the projecter and predictor nets 95 | 96 | class NetWrapper(nn.Module): 97 | def __init__(self, net, projection_size, projection_hidden_size, layer = -2): 98 | super().__init__() 99 | self.net = net 100 | self.layer = layer 101 | 102 | self.projector = None 103 | self.projection_size = projection_size 104 | self.projection_hidden_size = projection_hidden_size 105 | 106 | self.hidden = None 107 | self.hook_registered = False 108 | 109 | def _find_layer(self): 110 | if type(self.layer) == str: 111 | modules = dict([*self.net.named_modules()]) 112 | return modules.get(self.layer, None) 113 | elif type(self.layer) == int: 114 | children = [*self.net.children()] 115 | return children[self.layer] 116 | return None 117 | 118 | def _hook(self, _, __, output): 119 | self.hidden = flatten(output) 120 | 121 | def _register_hook(self): 122 | layer = self._find_layer() 123 | assert layer is not None, f'hidden layer ({self.layer}) not found' 124 | handle = layer.register_forward_hook(self._hook) 125 | self.hook_registered = True 126 | 127 | @singleton('projector') 128 | def _get_projector(self, hidden): 129 | _, dim = hidden.shape 130 | projector = MLP(dim, self.projection_size, self.projection_hidden_size) 131 | return projector.to(hidden) 132 | 133 | def get_representation(self, x): 134 | if self.layer == -1: 135 | return self.net(x) 136 | 137 | if not self.hook_registered: 138 | self._register_hook() 139 | 140 | _ = self.net(x) 141 | hidden = self.hidden 142 | self.hidden = None 143 | assert hidden is not None, f'hidden layer {self.layer} never emitted an output' 144 | return hidden 145 | 146 | def forward(self, x): 147 | representation = self.get_representation(x) 148 | projector = self._get_projector(representation) 149 | projection = projector(representation) 150 | return projection 151 | 152 | # main class 153 | 154 | class BYOL(nn.Module): 155 | def __init__( 156 | self, 157 | encoder, 158 | predictor, 159 | image_size, 160 | hidden_layer = -2, 161 | projection_size = 256, 162 | projection_hidden_size = 4096, 163 | augment_fn = None, 164 | augment_fn2 = None, 165 | moving_average_decay = 0.99, 166 | use_momentum = True 167 | ): 168 | super().__init__() 169 | 170 | # default SimCLR augmentation 171 | 172 | DEFAULT_AUG = nn.Sequential( 173 | RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), 174 | augs.RandomGrayscale(p=0.2), 175 | # augs.RandomHorizontalFlip(), 176 | RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), 177 | # augs.RandomResizedCrop((image_size, image_size)), 178 | # augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])) 179 | ) 180 | 181 | self.augment1 = default(augment_fn, DEFAULT_AUG) 182 | self.augment2 = default(augment_fn2, self.augment1) 183 | 184 | # self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer) 185 | self.online_encoder = encoder 186 | 187 | self.use_momentum = use_momentum 188 | self.target_encoder = None 189 | self.target_ema_updater = EMA(moving_average_decay) 190 | 191 | self.online_predictor = predictor 192 | 193 | # get device of network and make wrapper same device 194 | # device = get_module_device(net) 195 | device = torch.device(2) 196 | self.to(device) 197 | 198 | # send a mock image tensor to instantiate singleton parameters 199 | self.forward(torch.randn(2, 3, image_size, image_size, device=device)) 200 | 201 | @singleton('target_encoder') 202 | def _get_target_encoder(self): 203 | target_encoder = 
copy.deepcopy(self.online_encoder) 204 | set_requires_grad(target_encoder, False) 205 | return target_encoder 206 | 207 | def reset_moving_average(self): 208 | del self.target_encoder 209 | self.target_encoder = None 210 | 211 | def update_moving_average(self): 212 | assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder' 213 | assert self.target_encoder is not None, 'target encoder has not been created yet' 214 | update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) 215 | 216 | def forward(self, x): 217 | image_one, image_two = self.augment1(x), self.augment2(x) 218 | 219 | online_proj_one = self.online_encoder(image_one) 220 | online_proj_two = self.online_encoder(image_two) 221 | 222 | online_pred_one = self.online_predictor(online_proj_one) 223 | online_pred_two = self.online_predictor(online_proj_two) 224 | 225 | with torch.no_grad(): 226 | target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder 227 | target_proj_one = target_encoder(image_one).detach() 228 | target_proj_two = target_encoder(image_two).detach() 229 | 230 | loss_one = loss_fn(online_pred_one, target_proj_two.detach()) 231 | loss_two = loss_fn(online_pred_two, target_proj_one.detach()) 232 | 233 | loss = loss_one + loss_two 234 | return loss.mean() 235 | -------------------------------------------------------------------------------- /Warmup_with_SimSiam/byol_pytorch/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | from PIL import Image 4 | import torchvision.transforms as transforms 5 | from torch.utils.data import Dataset, DataLoader 6 | 7 | def pil_loader(path): 8 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 9 | with open(path, 'rb') as f: 10 | img = Image.open(f) 11 | return img.convert('RGB') 12 | 13 | 14 | class SegDataset(Dataset): 15 | """Face Landmarks dataset.""" 16 | 17 | def __init__(self, root_dir): 18 | """ 19 | Args: 20 | root_dir (string): Directory with all the images. 21 | """ 22 | self.root_dir = root_dir 23 | self.image_paths = sorted(os.listdir(root_dir)) 24 | self.transform = transforms.Compose([transforms.Resize((1024, 1024)), transforms.ToTensor()]) 25 | 26 | def __len__(self): 27 | return len(self.image_paths) 28 | 29 | def __getitem__(self, idx): 30 | image_name = self.image_paths[idx] 31 | img_name = os.path.join(self.root_dir, image_name) 32 | 33 | image = pil_loader(img_name) 34 | 35 | return self.transform(image) 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /Warmup_with_SimSiam/byol_pytorch/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | class FCN(nn.Module): 6 | """ 7 | Proposed Fully Convolutional Network 8 | This function/module uses fully convolutional blocks to extract pixel-wise image features. 
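    A minimal usage sketch (illustrative only; the 3-channel input and 24 output
    channels are assumptions matching the defaults used elsewhere in this repo):

        >>> import torch
        >>> net = FCN(input_dim=3, output_classes=24)
        >>> net(torch.randn(1, 3, 1024, 1024)).shape   # spatial size is preserved
        torch.Size([1, 24, 1024, 1024])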
9 | Tested on 1024*1024, 512*512 resolution; RGB, Immunohistochemical color channels 10 | 11 | Keyword arguments: 12 | input_dim -- input channel, 3 for RGB images (default) 13 | """ 14 | def __init__(self, input_dim, output_classes, p_mode = 'replicate'): 15 | super(FCN, self).__init__() 16 | #self.Dropout = nn.Dropout(p=0.05) 17 | self.conv1 = nn.Conv2d(input_dim, 32, kernel_size=3, stride=1, padding=1 ,padding_mode=p_mode) 18 | self.bn1 = nn.BatchNorm2d(32) 19 | 20 | self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, padding_mode=p_mode) 21 | self.bn2 = nn.BatchNorm2d(32) 22 | 23 | self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, padding_mode=p_mode) 24 | self.bn3 = nn.BatchNorm2d(64) 25 | 26 | self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, padding_mode=p_mode) 27 | self.bn4 = nn.BatchNorm2d(64) 28 | 29 | self.conv5 = nn.Conv2d(64, output_classes, kernel_size=1, stride=1, padding=0) 30 | 31 | #self.Dropout = nn.Dropout(p=0.3) 32 | 33 | def forward(self, x): 34 | x = self.conv1(x) 35 | x = F.relu(x) 36 | x = self.bn1(x) 37 | #x = self.Dropout(x) 38 | 39 | x = self.conv2(x) 40 | x = F.relu(x) 41 | x = self.bn2(x) 42 | #x = self.Dropout(x) 43 | 44 | x = self.conv3(x) 45 | x = F.relu(x) 46 | x = self.bn3(x) 47 | #x = self.Dropout(x) 48 | 49 | x = self.conv4(x) 50 | x = F.relu(x) 51 | x = self.bn4(x) 52 | 53 | x = self.conv5(x) 54 | return x 55 | 56 | class Predictor(nn.Module): 57 | def __init__(self, dim, p_mode = 'replicate'): 58 | super(Predictor, self).__init__() 59 | self.conv4 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, padding_mode=p_mode) 60 | self.bn4 = nn.BatchNorm2d(dim) 61 | 62 | self.conv5 = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0) 63 | 64 | def forward(self, x): 65 | x = self.conv4(x) 66 | x = F.relu(x) 67 | x = self.bn4(x) 68 | 69 | x = self.conv5(x) 70 | return x 71 | 72 | 73 | class GCN(torch.nn.Module): 74 | """ 75 | Proposed Graph Convolutional Network 76 | This function/module uses classic GCN layers to generate superpixels(nodes) classification. 77 | --"Semi-Supervised Classification with Graph Convolutional Networks", 78 | --Thomas N. 
Kipf, Max Welling, ICLR2017 79 | 80 | Keyword arguments: 81 | input_dim -- input channel, aligns with output channel from FCN 82 | output_classes --output channel, default 1 for our proposed loss function 83 | """ 84 | def __init__(self, input_dim, output_classes): 85 | super(GCN, self).__init__() 86 | self.conv1 = GCNConv(input_dim, 64) 87 | self.conv2 = GCNConv(64, 128) 88 | self.conv3 = GCNConv(128, 256) 89 | self.conv4 = GCNConv(256, 64) 90 | self.conv5 = GCNConv(64, output_classes) 91 | #self.Dropout = nn.Dropout(p=0.5) 92 | 93 | # self.bn1 = nn.BatchNorm1d(64) 94 | # self.bn2 = nn.BatchNorm1d(128) 95 | # self.bn3 = nn.BatchNorm1d(256) 96 | # self.bn4 = nn.BatchNorm1d(64) 97 | # 98 | # self.lin1 = Linear(64, 256) 99 | # self.lin2 = Linear(256, 128) 100 | # self.lin3 = Linear(128, output_classes) 101 | 102 | def forward(self, data): 103 | x = self.conv1(data.x, edge_index = data.edge_index, edge_weight = data.edge_weight) 104 | x = F.relu(x) 105 | #x = self.Dropout(x) 106 | #x = self.bn1(x) 107 | 108 | x = self.conv2(x, edge_index = data.edge_index, edge_weight = data.edge_weight) 109 | x = F.relu(x) 110 | #x = self.Dropout(x) 111 | #x = self.bn2(x) 112 | 113 | x = self.conv3(x, edge_index = data.edge_index, edge_weight = data.edge_weight) 114 | x = F.relu(x) 115 | #x = self.Dropout(x) 116 | #x = self.bn3(x) 117 | 118 | x = self.conv4(x, edge_index = data.edge_index, edge_weight = data.edge_weight) 119 | x = F.relu(x) 120 | #x = self.bn4(x) 121 | 122 | x = self.conv5(x, edge_index = data.edge_index, edge_weight = data.edge_weight) 123 | 124 | 125 | return torch.tanh(x) 126 | -------------------------------------------------------------------------------- /Warmup_with_SimSiam/diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/Warmup_with_SimSiam/diagram.png -------------------------------------------------------------------------------- /Warmup_with_SimSiam/examples/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | from PIL import Image 4 | import torchvision.transforms as transforms 5 | from torch.utils.data import Dataset, DataLoader 6 | 7 | def pil_loader(path): 8 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 9 | with open(path, 'rb') as f: 10 | img = Image.open(f) 11 | return img.convert('RGB') 12 | 13 | 14 | class SegDataset(Dataset): 15 | """Face Landmarks dataset.""" 16 | 17 | def __init__(self, root_dir): 18 | """ 19 | Args: 20 | root_dir (string): Directory with all the images. 
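        A minimal usage sketch (illustrative; the path below is hypothetical):

            ds = SegDataset(root_dir='/path/to/images')
            img = ds[0]   # a [3, H, W] float tensor scaled to [0, 1]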
21 | """ 22 | self.root_dir = root_dir 23 | self.image_paths = sorted(os.listdir(root_dir)) 24 | self.transform = transforms.Compose([transforms.ToTensor()]) 25 | 26 | def __len__(self): 27 | return len(self.landmarks_frame) 28 | 29 | def __getitem__(self, idx): 30 | image_name = self.image_paths[idx] 31 | img_name = os.path.join(self.root_dir, image_name) 32 | 33 | image = pil_loader(img_name) 34 | 35 | return self.transform(image) 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /Warmup_with_SimSiam/examples/lightning/README.md: -------------------------------------------------------------------------------- 1 | ## Pytorch-lightning example script 2 | 3 | ### Requirements 4 | 5 | ```bash 6 | $ pip install pytorch-lightning 7 | $ pip install pillow 8 | ``` 9 | 10 | ### Run 11 | 12 | ```bash 13 | $ python train.py --image_folder /path/to/your/images 14 | ``` 15 | -------------------------------------------------------------------------------- /Warmup_with_SimSiam/examples/lightning/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import multiprocessing 4 | from pathlib import Path 5 | from PIL import Image 6 | 7 | import torch 8 | from torchvision import models, transforms 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | from byol_pytorch import BYOL 12 | import pytorch_lightning as pl 13 | 14 | # test model, a resnet 50 15 | 16 | resnet = models.resnet50(pretrained=True) 17 | 18 | # arguments 19 | 20 | parser = argparse.ArgumentParser(description='byol-lightning-test') 21 | 22 | parser.add_argument('--image_folder', type=str, required = True, 23 | help='path to your folder of images for self-supervised learning') 24 | 25 | args = parser.parse_args() 26 | 27 | # constants 28 | 29 | BATCH_SIZE = 32 30 | EPOCHS = 1000 31 | LR = 3e-4 32 | NUM_GPUS = 2 33 | IMAGE_SIZE = 256 34 | IMAGE_EXTS = ['.jpg', '.png', '.jpeg'] 35 | NUM_WORKERS = multiprocessing.cpu_count() 36 | 37 | # pytorch lightning module 38 | 39 | class SelfSupervisedLearner(pl.LightningModule): 40 | def __init__(self, net, **kwargs): 41 | super().__init__() 42 | self.learner = BYOL(net, **kwargs) 43 | 44 | def forward(self, images): 45 | return self.learner(images) 46 | 47 | def training_step(self, images, _): 48 | loss = self.forward(images) 49 | return {'loss': loss} 50 | 51 | def configure_optimizers(self): 52 | return torch.optim.Adam(self.parameters(), lr=LR) 53 | 54 | def on_before_zero_grad(self, _): 55 | self.learner.update_moving_average() 56 | 57 | # images dataset 58 | 59 | def expand_greyscale(t): 60 | return t.expand(3, -1, -1) 61 | 62 | class ImagesDataset(Dataset): 63 | def __init__(self, folder, image_size): 64 | super().__init__() 65 | self.folder = folder 66 | self.paths = [] 67 | 68 | for path in Path(f'{folder}').glob('**/*'): 69 | _, ext = os.path.splitext(path) 70 | if ext.lower() in IMAGE_EXTS: 71 | self.paths.append(path) 72 | 73 | print(f'{len(self.paths)} images found') 74 | 75 | self.transform = transforms.Compose([ 76 | transforms.Resize(image_size), 77 | transforms.CenterCrop(image_size), 78 | transforms.ToTensor(), 79 | transforms.Lambda(expand_greyscale) 80 | ]) 81 | 82 | def __len__(self): 83 | return len(self.paths) 84 | 85 | def __getitem__(self, index): 86 | path = self.paths[index] 87 | img = Image.open(path) 88 | img = img.convert('RGB') 89 | return self.transform(img) 90 | 91 | # main 92 | 93 | if __name__ == '__main__': 94 | ds = 
ImagesDataset(args.image_folder, IMAGE_SIZE) 95 | train_loader = DataLoader(ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True) 96 | 97 | model = SelfSupervisedLearner( 98 | resnet, 99 | image_size = IMAGE_SIZE, 100 | hidden_layer = 'avgpool', 101 | projection_size = 256, 102 | projection_hidden_size = 4096, 103 | moving_average_decay = 0.99 104 | ) 105 | 106 | trainer = pl.Trainer( 107 | gpus = NUM_GPUS, 108 | max_epochs = EPOCHS, 109 | accumulate_grad_batches = 1 110 | ) 111 | 112 | trainer.fit(model, train_loader) 113 | -------------------------------------------------------------------------------- /Warmup_with_SimSiam/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'byol-pytorch', 5 | packages = find_packages(exclude=['examples']), 6 | version = '0.4.0', 7 | license='MIT', 8 | description = 'Self-supervised contrastive learning made simple', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | url = 'https://github.com/lucidrains/byol-pytorch', 12 | keywords = ['self-supervised learning', 'artificial intelligence'], 13 | install_requires=[ 14 | 'torch>=1.6', 15 | 'kornia>=0.4.0' 16 | ], 17 | classifiers=[ 18 | 'Development Status :: 4 - Beta', 19 | 'Intended Audience :: Developers', 20 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 21 | 'License :: OSI Approved :: MIT License', 22 | 'Programming Language :: Python :: 3.6', 23 | ], 24 | ) -------------------------------------------------------------------------------- /Warmup_with_SimSiam/warmup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from byol_pytorch import BYOL 3 | from byol_pytorch.model import FCN, Predictor 4 | from byol_pytorch.dataset import SegDataset 5 | 6 | encoder = FCN(input_dim=3, output_classes=24) 7 | predictor = Predictor(dim=24) 8 | 9 | learner = BYOL( 10 | encoder=encoder, 11 | predictor=predictor, 12 | image_size = 1024, 13 | hidden_layer = 'avgpool', 14 | use_momentum = False 15 | ) 16 | 17 | opt = torch.optim.Adam(learner.parameters(), lr=3e-4) 18 | 19 | imagenet = SegDataset(root_dir="/images/HER2/images") 20 | dataloader = torch.utils.data.DataLoader(imagenet, batch_size=2, shuffle=True, num_workers=0) 21 | device = torch.device(2) 22 | 23 | for _ in range(100): 24 | print("start epoch, ", _) 25 | for local_batch in dataloader: 26 | local_batch = local_batch.to(device) 27 | loss = learner(local_batch) 28 | opt.zero_grad() 29 | loss.backward() 30 | opt.step() 31 | # learner.update_moving_average() # update moving average of target encoder 32 | torch.cuda.empty_cache() 33 | torch.save(encoder.state_dict(), 'checkpoints/improved-net_{}.pt'.format(_)) 34 | 35 | # save your improved network 36 | -------------------------------------------------------------------------------- /args_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description='PyTorch Weakly Supervised Segmentation') 4 | parser.add_argument('--nChannel', metavar='N', default=24, type=int, 5 | help='number of channels of histogram') 6 | parser.add_argument('--maxIter', metavar='T', default=10000, type=int, 7 | help='number of maximum iterations') 8 | parser.add_argument('--checkpoint', metavar='ckpt number', default=-1, type=int, 9 | help='continue training from checkpoint #') 10 | 
parser.add_argument('--num_superpixels', metavar='K', default=7000, type=int, 11 | help='number of initial superpixels') 12 | parser.add_argument('--compactness', metavar='C', default=3, type=float, 13 | help='compactness of superpixels') 14 | parser.add_argument('--visualize', metavar='1 or 0', default=0, type=int, 15 | help='0:display image using OPENCV | 1:save images to train-path') 16 | parser.add_argument('--color_channel_separation', metavar='True or False', default=False, type=bool, 17 | help='Immunohistochemical staining colors separation toggle') 18 | parser.add_argument('--half-precision', metavar='True or False', default=False, type=bool, 19 | help='Half precision training, requires torch 0.4.0 and apex from nVidia') 20 | parser.add_argument('--optimizer', default='Adam', help='optimizer(SGD|Adam)') 21 | parser.add_argument('--gcn-lr', default=0.0005, type=float, help='gcn learning rate') 22 | parser.add_argument('--fcn-lr', default=0.0001, type=float, help='fcn learning rate') 23 | parser.add_argument('-b','--batch-size', default=2, type=int, help='batch size') 24 | parser.add_argument('-t','--cpu-threads', default=8, type=int, help='number of threads for multiprocessing') 25 | parser.add_argument('--switch-iter', default=13, type=int, help='switch GCN into small \ 26 | batch training after # of iterations') 27 | parser.add_argument('--adjust-iter', default=0.1, type=float, help='each iteration, \ 28 | global_segments *= (1+/- slic_adjust_ratio)') 29 | parser.add_argument('--weight-ratio', default=1.5, type=float, help='edge weight complementing \ 30 | (to 1) ratio, decreasing with training') 31 | parser.add_argument('--warmup-threshold', default=2, type=float, help='when FCN warmup loss \ 32 | reaches #, terminate warmup and start training GCN') 33 | parser.add_argument('--output-size', default=1024, type=int, help='The resolution along one axis\ 34 | of the image. 
input images will be scaled to the size defined.') 35 | parser.add_argument('--fuse-thresh', default=0.001, type=float, help='the cutoff to use when \ 36 | outputting the final fused mask in inference') 37 | parser.add_argument('--train-path', type=str, 38 | default="./train_process_files", 39 | help='a folder to save train progress visualization', 40 | required=False) 41 | parser.add_argument('--input-path', type=str, 42 | default="./input_data", 43 | help='training set containing the labeled images', 44 | required=False) 45 | parser.add_argument('--checkpoint-path', type=str, 46 | default='./state_dict', 47 | help='path where the checkpoints are saved to', 48 | required=False) 49 | parser.add_argument('--inference-path', type=str, 50 | help='path where the inferenced masks are saved to', 51 | required=False) 52 | args = parser.parse_args() 53 | -------------------------------------------------------------------------------- /finaledgegcncopy.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import torch 3 | from multiprocessing import Pool 4 | import math 5 | from torch.nn import Sequential, Linear, ReLU 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torch_geometric.nn import DataParallel as GeoParallel 10 | from torch_geometric.data import Data, Batch 11 | from torchvision import datasets, transforms 12 | import torchvision 13 | import torch.utils.data as data 14 | from torch.autograd import Variable 15 | from skimage.color import rgb2hed 16 | from skimage.util import img_as_float 17 | from skimage.segmentation import slic, mark_boundaries 18 | from skimage.exposure import rescale_intensity 19 | from skimage import io 20 | import cv2 21 | import sys 22 | import numpy as np 23 | import random 24 | from skimage import segmentation 25 | import matplotlib.pyplot as plt 26 | import torch.nn.init 27 | from PIL import ImageFile, Image 28 | import os 29 | from collections import defaultdict 30 | from utilities import * 31 | from model import * 32 | from args_parser import * 33 | 34 | 35 | #amp_handle = amp.init(enabled=True) 36 | 37 | global_segments = 0 38 | if args.half_precision: from apex import amp 39 | ImageFile.LOAD_TRUNCATED_IMAGES = True 40 | ##CUDA_LAUNCH_BLOCKING=1 #if cuda fails to backpropagate use this toggle to debug 41 | OUTPUT_SIZE = args.output_size 42 | use_cuda = torch.cuda.is_available() 43 | 44 | SAVE_PATH = args.train_path 45 | SD_SAVE_PATH = args.checkpoint_path 46 | DATA_PATH = args.input_path 47 | 48 | 49 | 50 | class random_dataloader(torch.utils.data.Dataset): 51 | """ 52 | Proposed Dataset with Random Matched Pairs 53 | This function is a dataset designed for small batch initialization. 54 | When initializing with small batches (< 8), less than ideal GCN convergence may occur. 55 | This dataset loads image patches from two subfolder: ./positives ./negatives 56 | -and combines positive and negative samples by adjustable ratio self.positive_negative_ratio 57 | -in one mini-batch. 
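    An illustrative folder layout this loader assumes (file names follow the
    AreaRatio_Uncertainty_*.png convention described in the top-level README,
    so gt_percent and moe can be parsed from the first two underscore-separated fields):

        path/positives/0.40_0.05_tile001.png
        path/negatives/0.00_0.05_tile002.png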
58 | 59 | Keyword arguments: 60 | path --parent path which contains ./positives ./negatives subfolders 61 | """ 62 | def __init__(self, path): 63 | super(random_dataloader, self).__init__() 64 | self.path = path 65 | self.positive_images = [names for names in os.listdir(os.path.join(path, 'positives'))] 66 | self.negative_images = [names for names in os.listdir(os.path.join(path, 'negatives'))] 67 | self.positive_image_files = [] 68 | self.negative_image_files = [] 69 | self.toTensor = transforms.ToTensor() 70 | self.positive_negative_ratio = 3 #3 positives and 1 negative in one mini-batch 71 | self.ratio_counter = 0 72 | self.negative_index = np.random.randint(low = 0, high = self.positive_negative_ratio + 1) 73 | for img in self.positive_images: 74 | if img[-4:] is not None and (img[-4:] == '.png' or img[-4:] == '.jpg'): 75 | img_file = os.path.join(os.path.join(path, 'positives'), "%s" % img) 76 | self.positive_image_files.append({ 77 | "img": img_file 78 | 79 | }) 80 | for img in self.negative_images: 81 | if img[-4:] is not None and (img[-4:] == '.png' or img[-4:] == '.jpg'): 82 | img_file = os.path.join(os.path.join(path, 'negatives'), "%s" % img) 83 | self.negative_image_files.append({ 84 | "img": img_file 85 | 86 | }) 87 | 88 | def __len__(self): 89 | return (len(self.positive_image_files) + len(self.negative_image_files)) 90 | 91 | def __getitem__(self, index): 92 | if self.ratio_counter == self.negative_index: 93 | index = index % len(self.negative_images) 94 | data_file = self.negative_image_files[index] 95 | else: 96 | index = index % len(self.positive_images) 97 | data_file = self.positive_image_files[index] 98 | 99 | self.ratio_counter += 1 100 | if self.ratio_counter > self.positive_negative_ratio: 101 | self.negative_index = np.random.randint(low = 0, high = self.positive_negative_ratio + 1) 102 | self.ratio_counter = 0 103 | image = cv2.imread(data_file["img"]) 104 | image = cv2.resize(image, (OUTPUT_SIZE, OUTPUT_SIZE)) 105 | angle = np.random.randint(4) 106 | image = rotate(image, angle) 107 | image = cv2.flip(image, np.random.randint(2) - 1) 108 | if args.color_channel_separation: 109 | ihc_hed = rgb2hed(image) 110 | h = rescale_intensity(ihc_hed[:, :, 0], out_range=(0, 1)) 111 | d = rescale_intensity(ihc_hed[:, :, 2], out_range=(0, 1)) 112 | image = np.dstack((np.zeros_like(h), d, h)) 113 | #image = zdh.transpose(2, 0, 1).astype('float32')/255 114 | name = data_file["img"] 115 | path, file = os.path.split(name) 116 | split_filename = file.split("_") 117 | gt_percent = float(split_filename[0]) 118 | moe = float(split_filename[1]) 119 | image = self.toTensor(image) 120 | return (image, name, gt_percent, moe) 121 | 122 | 123 | class normal_dataloader(torch.utils.data.Dataset): 124 | """ 125 | Typical Dataset 126 | This function loads images from path and returns: 127 | image: image (resized, color channel separated(optional)) 128 | name: full image path(contains image name) 129 | gt_percent: Ground-Truth percent, an image-level weak annotation 130 | moe: Margin of Error, an image-level weak annotation 131 | 132 | Keyword arguments: 133 | path --path which contains *.png or *.jpg image patches 134 | """ 135 | def __init__(self, path): 136 | super(normal_dataloader, self).__init__() 137 | self.path = path 138 | self.images = [names for names in os.listdir(path)] 139 | self.image_files = [] 140 | self.toTensor = transforms.ToTensor() 141 | 142 | for img in self.images: 143 | if img[-4:] is not None and (img[-4:] == '.png' or img[-4:] == '.jpg'): 144 | img_file = 
os.path.join(path, "%s" % img) 145 | self.image_files.append({ 146 | "img": img_file, 147 | "label": "1" 148 | }) 149 | 150 | def __len__(self): 151 | return len(self.image_files) 152 | 153 | def __getitem__(self, index): 154 | index = index % len(self.image_files) 155 | data_file = self.image_files[index] 156 | image = cv2.imread(data_file["img"]) 157 | image = cv2.resize(image, (OUTPUT_SIZE, OUTPUT_SIZE)) 158 | angle = np.random.randint(4) 159 | image = rotate(image, angle) 160 | image = cv2.flip(image, np.random.randint(2) - 1) 161 | if args.color_channel_separation: 162 | ihc_hed = rgb2hed(image) 163 | h = rescale_intensity(ihc_hed[:, :, 0], out_range=(0, 1)) 164 | d = rescale_intensity(ihc_hed[:, :, 2], out_range=(0, 1)) 165 | image = np.dstack((np.zeros_like(h), d, h)) 166 | #image = zdh.transpose(2, 0, 1).astype('float32')/255 167 | name = data_file["img"] 168 | path, file = os.path.split(name) 169 | split_filename = file.split("_") 170 | gt_percent = float(split_filename[0]) 171 | moe = float(split_filename[1]) 172 | image = self.toTensor(image) 173 | return (image, name, gt_percent, moe) 174 | 175 | 176 | class inference_dataset(torch.utils.data.Dataset): 177 | """ 178 | dataset loader used when inferencing 179 | """ 180 | def __init__(self, path): 181 | super(inference_dataset, self).__init__() 182 | self.path = path 183 | self.images = [names for names in os.listdir(path)] 184 | self.image_files = [] 185 | self.toTensor = transforms.ToTensor() 186 | 187 | for img in self.images: 188 | if img[-4:] is not None and (img[-4:] == '.jpg' or img[-4:] == '.png'): 189 | img_file = os.path.join(path, "%s" % img) 190 | #print(img) 191 | #print(img_file) 192 | self.image_files.append({ 193 | "img": img_file, 194 | "label": "1" 195 | }) 196 | 197 | def __len__(self): 198 | return len(self.image_files) 199 | 200 | def __getitem__(self, index): 201 | index = index % len(self.image_files) 202 | data_file = self.image_files[index] 203 | image = cv2.imread(data_file["img"]) 204 | image = cv2.resize(image, (OUTPUT_SIZE, OUTPUT_SIZE)) 205 | name = data_file["img"] 206 | image = self.toTensor(image) 207 | return (image, name, 0,0) 208 | 209 | 210 | def load_dataset(batch_size): 211 | """ 212 | Typical Pytorch Dataloader 213 | This function loades image from Dataset 214 | 215 | Keyword arguments: 216 | path --please refer to class random_dataloader and class normal_dataloader 217 | """ 218 | data_path = DATA_PATH 219 | train_dataset = normal_dataloader(path = data_path) 220 | train_loader = torch.utils.data.DataLoader( 221 | train_dataset, 222 | batch_size= batch_size, # <8*1024 res @ P40 cards, <25*512 res @ 1 P40 card 223 | num_workers= 8, #depends on RAM and CPU 224 | shuffle=True 225 | ) 226 | return train_loader 227 | 228 | def inference_loader(batch_size = 1): 229 | """ 230 | Typical Pytorch Dataloader 231 | This function loades image from Dataset 232 | 233 | Keyword arguments: 234 | path --please refer to class random_dataloader and class normal_dataloader 235 | """ 236 | data_path = DATA_PATH 237 | train_dataset = inference_dataset(path = data_path) 238 | train_loader = torch.utils.data.DataLoader( 239 | train_dataset, 240 | batch_size= batch_size, # <8*1024 res @ P40 cards, <25*512 res @ 1 P40 card 241 | num_workers= 8, #depends on RAM and CPU 242 | shuffle=False 243 | ) 244 | return train_loader 245 | 246 | def multithread_slic(multi_input): 247 | """ 248 | Multi-Thread SLIC superpixel 249 | Since SLIC algorithm is implemented on CPU, we use Python Pool to accelerate 250 | -SLIC algorithm 
by taking full advantage of CPU's multi-threading capability. 251 | This function also calculate mutable edges' weight as one of GCN's input data. 252 | 253 | 254 | Keyword arguments: 255 | multi_input --tuple lists containing: 256 | (FCN output, Max Channel Response, Mutable Edge Weight Ratio) 257 | """ 258 | (multi_output, max_channel_response, weight_ratio, i) = multi_input 259 | if i > 0: 260 | multi_output_slic = slic(multi_output, n_segments = i, compactness = args.compactness\ 261 | , sigma = 0, multichannel = True) 262 | else: 263 | multi_output_slic = slic(multi_output, n_segments = global_segments, compactness = args.compactness\ 264 | , sigma = 0, multichannel = True) 265 | 266 | num_segments = len(np.unique(multi_output_slic)) 267 | multi_adj = adjacency3(multi_output_slic,num_segments) 268 | ## true_euclidean_distance = [] 269 | ## f_norm = [] 270 | ## mses = [] 271 | chisq = [] 272 | classes_raw = np.zeros((num_segments, args.nChannel),dtype="float32") 273 | for y in range(OUTPUT_SIZE): 274 | for x in range(OUTPUT_SIZE): 275 | curr_index = multi_output_slic[x][y] 276 | max_channel = max_channel_response[x][y] 277 | classes_raw[curr_index][max_channel] += 1.0 278 | 279 | for x in range(len(classes_raw)): 280 | max_in_row = np.amax(classes_raw[x]) 281 | classes_raw[x] = classes_raw[x] / max_in_row 282 | 283 | for (p1, p2) in multi_adj: 284 | p1_class = np.asarray(classes_raw[p1]) 285 | p2_class = np.asarray(classes_raw[p2]) 286 | #true_euclidean_distance.append( euclidean_dist(p1center, p2center) ) 287 | #f_norm.append( fnorm(p1_class,p2_class) ) 288 | chisq.append( np.absolute(chisq_dist(p1_class, p2_class)) ) 289 | #mses.append( mse(p1_class, p2_class) ) 290 | chisq_max_value = np.amax(chisq) 291 | chisq = chisq / chisq_max_value 292 | #complementary_weight = np.ones_like(chisq) - chisq 293 | #chisq = chisq + weight_ratio * complementary_weight 294 | edge_weight = torch.from_numpy(chisq) 295 | return multi_output_slic, multi_adj, edge_weight, num_segments 296 | 297 | 298 | if __name__ == '__main__': 299 | # train 300 | model = FCN(3, args.nChannel) 301 | modelgcn = GCN(args.nChannel, 1) 302 | gcn_batch_iter = 1 #how many iteration on one GCN batch 303 | batch_counter = 0 304 | global_segments = args.num_superpixels #mutable superpixel quantity 305 | change_dataloader = args.switch_iter #switch GCN into small batch training after # of iterations 306 | model_loss = 99999 #variable keeping track of FCN loss during warmup phase 307 | in_GCN = False #True when FCN exiting warmup phase and feeding output into GCN 308 | inference_mode = False 309 | #slic_multiscale_descending = False #True when mutable superpixel quantity is decreasing per iteration 310 | slic_adjust_ratio = args.adjust_iter #each iteration, global_segments *= (1+/- slic_adjust_ratio) 311 | weight_ratio = args.weight_ratio #edge weight complementing (to 1) ratio, decreasing with training 312 | warmup_threshold = args.warmup_threshold# 0.5 #when FCN warmup loss reaches #, terminate warmup and start training GCN 313 | half_precision = args.half_precision 314 | 315 | if args.checkpoint > 0: 316 | #if given a checkpoint to resume training 317 | model.load_state_dict(torch.load(os.path.join(SD_SAVE_PATH, "FCN" + str(args.checkpoint) + ".pt"))) 318 | modelgcn.load_state_dict(torch.load(os.path.join(SD_SAVE_PATH, "GCN" + str(args.checkpoint) + ".pt"))) 319 | batch_counter = int(args.checkpoint) 320 | change_dataloader = batch_counter - 1 321 | in_GCN = True 322 | 323 | if args.inference_path is not None: 324 | 
dataset_loader = inference_loader() 325 | inference_mode = True 326 | args.maxIter = 1 327 | if args.checkpoint < 1: 328 | print("please define a inference checkpoint using --checkpoint") 329 | exit(-1) 330 | else: 331 | dataset_loader = load_dataset(args.batch_size) 332 | 333 | if not inference_mode and args.cpu_threads > args.batch_size: 334 | args.cpu_threads = args.batch_size 335 | 336 | if use_cuda: 337 | model = model.to('cuda') 338 | modelgcn = modelgcn.to('cuda') 339 | if half_precision: 340 | model, optimizer = amp.initialize(model, optimizer) 341 | modelgcn, optimizergcn = amp.initialize(modelgcn, optimizergcn) 342 | else: 343 | print("model using CPU, please check CUDA settings") 344 | 345 | if use_cuda: 346 | if torch.cuda.device_count() > 1: 347 | print(str(torch.cuda.device_count()) + " GPUs visible") 348 | model = nn.DataParallel(model) 349 | modelgcn = GeoParallel(modelgcn) 350 | if inference_mode: 351 | model.train() 352 | else: 353 | model.train() 354 | modelgcn = modelgcn.float() #necessary for edge_weight initialized training 355 | loss_fn = torch.nn.CrossEntropyLoss() 356 | if args.optimizer == 'SGD': 357 | optimizer = optim.SGD(model.parameters(), lr=args.fcn_lr, momentum=0) 358 | optimizergcn = optim.SGD(modelgcn.parameters(), lr=args.gcn_lr, momentum=0) 359 | elif args.optimizer == 'Adam': 360 | optimizer = optim.Adam(model.parameters(), lr=args.fcn_lr) 361 | optimizergcn = optim.Adam(modelgcn.parameters(), lr=args.gcn_lr) 362 | else: 363 | print("please reselect optimizer, curr value:", args.optimizer) 364 | exit(-1) 365 | 366 | 367 | 368 | 369 | if args.checkpoint > 0: 370 | #if given a checkpoint to resume training 371 | optimizer.load_state_dict(torch.load(os.path.join(SD_SAVE_PATH, "FCNopt" + str(args.checkpoint) + ".pt"))) 372 | optimizergcn.load_state_dict(torch.load(os.path.join(SD_SAVE_PATH, "GCNopt" + str(args.checkpoint) + ".pt"))) 373 | print("training successfully resumed!") 374 | 375 | 376 | #an RGB colors array used to visualize channel response 377 | label_colours = np.array([[0,0,0], [255,255,255],[0,0,255], [0,255,0], 378 | [255,0,0],[128,0,0],[0,128,0],[0,0,128], 379 | [255,255,0], [255,128,0], [128,255,0], 380 | [0,255,255],[255,0,255],[255,255,255], 381 | [128,128,128],[255,0,128],[0,128,255], 382 | [128,0,255],[0,255,128],[100,200,200], 383 | [200,100,100],[200,255,0],[100,255,0], 384 | [200,0,255],[30,99,212],[40,222,100], 385 | [100,200,25],[30,199,20],[0,211,200], 386 | [3,44,122],[23,44,100],[90,22,0],[233,111,222], 387 | [122,122,150],[0,233,149],[3,111,23]]) 388 | 389 | 390 | for epoch in range(args.maxIter): 391 | 392 | """switch large batch into small batch""" 393 | # if batch_counter < change_dataloader: 394 | # dataset_loader = load_dataset(2) 395 | # else: 396 | # gcn_batch_iter = 1 397 | # dataset_loader = load_dataset(2) 398 | # 399 | 400 | for batch_idx, (data, name, gt_percent, moe) in \ 401 | enumerate(dataset_loader): 402 | 403 | if not inference_mode: 404 | print("iteration: " + str(batch_counter) + " epoch: " + str(epoch)) 405 | else: 406 | print("---------------------------------------------------") 407 | print("inferencing ", str(os.path.basename(name[0]))) 408 | if args.visualize: 409 | """visualize using opencv""" 410 | originalimg = cv2.imread(name[0]) 411 | originalimg = cv2.resize(originalimg, (OUTPUT_SIZE,OUTPUT_SIZE)) 412 | cv2.imshow("original", originalimg) 413 | cv2.waitKey(1) 414 | 415 | batch_counter += 1 #records in-epoch progress 416 | if use_cuda: 417 | data = data.to('cuda') 418 | optimizer.zero_grad() 419 | 
output = model(data) 420 | 421 | 422 | 423 | nLabels = -1 424 | 425 | if not inference_mode and ((not in_GCN and batch_counter % 50 == 0) or (in_GCN and batch_counter % 10 == 0)): 426 | """FCN output visualization, either save to SAVE_PATH or display using opencv""" 427 | #model.eval() 428 | #output = model(data) 429 | ignore, target = torch.max( output, 1 ) 430 | im_target = target.data.cpu().numpy() #label map original 431 | num_in_minibatch = 0 432 | for i in im_target: 433 | im_target = i.flatten() 434 | nLabels = len(np.unique(im_target)) 435 | label_num = rank(im_target) 436 | label_rank = [i[0] for i in label_num] 437 | im_target_rgb = np.array([label_colours [label_rank.index(c)] for c in im_target]) 438 | im_target_rgb = im_target_rgb.reshape( OUTPUT_SIZE, OUTPUT_SIZE, 3 ).astype( np.uint8 ) 439 | curr_filename = name[num_in_minibatch][-16:-4] 440 | if args.visualize: 441 | cv2.imshow("pre", im_target_rgb) 442 | cv2.waitKey(1) 443 | else: 444 | cv2.imwrite(os.path.join(SAVE_PATH, 'PRE' + str(batch_counter) \ 445 | + 'N' + str(num_in_minibatch) + "_" \ 446 | + str(curr_filename) + '.png'), \ 447 | cv2.cvtColor(im_target_rgb, cv2.COLOR_RGB2BGR)) 448 | num_in_minibatch += 1 449 | torch.save(model.module.state_dict(), os.path.join(SD_SAVE_PATH,"FCN" + str(batch_counter) + ".pt")) 450 | 451 | 452 | 453 | if not inference_mode and (model_loss > warmup_threshold and not in_GCN): #stable 0.3 2000epoch 454 | """warming up FCN, loss is the cross entropy between pixel-wise 455 | -max channel responses and the FCN model's output""" 456 | ignore, target = torch.max( output, 1 ) 457 | loss = loss_fn(output, target) 458 | if half_precision: 459 | with amp.scale_loss(loss, optimizer) as scaled_loss: 460 | scaled_loss.backward() 461 | else: 462 | loss.backward() 463 | optimizer.step() 464 | if model_loss > loss.data: 465 | model_loss = loss.data 466 | print (epoch, '/', args.maxIter, ':', nLabels, loss.data) 467 | change_dataloader += 1 468 | continue 469 | else: 470 | """when FCN model_loss is below warmup_threshold, enter GCN training""" 471 | #optimizer = torch.optim.Adam(model.parameters(), lr=0.0005) 472 | #change the learning rate of FCN to a less conservative number 473 | #optimizer = optim.Adam(model.parameters(), lr=0.00005) 474 | #dataset_loader = load_dataset(2) 475 | in_GCN = True 476 | 477 | if model_loss < warmup_threshold or in_GCN: 478 | """Proposed method bridging output of FCN and input to GCN""" 479 | 480 | """mutable superpixel count, changes every training iteration""" 481 | """count changes by a fixed pattern""" 482 | # if slic_multiscale_descending: 483 | # global_segments = int(global_segments * (1 - slic_adjust_ratio)) 484 | # if global_segments < 2000: 485 | # slic_adjust_ratio = slic_adjust_ratio * 0.7 486 | # slic_multiscale_descending = False 487 | # else: 488 | # global_segments = int(global_segments * (1 + slic_adjust_ratio)) 489 | # if global_segments > 8000: 490 | # global_segments = global_segments + 17 491 | # slic_multiscale_descending = True 492 | """count is randomized between [2000, 8000)""" 493 | global_segments = int(random.random() * 6000.0 + 2000.0) 494 | 495 | if not inference_mode: print("slic segments count:" + str(global_segments)) 496 | 497 | gcn_batch_list = [] #a batch which is later fed into the GCN 498 | segments_list = [] #SLIC segmentations of each image in current GCN batch 499 | batch_node_num = [] #number of nodes(superpixel tiles) that each image has in\ 500 | #current GCN batch 501 | 502 | print("computing multithread slic") 503 | """prepares FCN 
output for use in the GCN, 504 | computes SLIC segmentations, adjacency edges, & edge weight using 505 | multi-threads""" 506 | multi_input = [] 507 | for multi_one_graph in output: 508 | multi_one_graph = multi_one_graph.permute( 1 , 2 , 0 ) 509 | _, max_channel_response = torch.max(multi_one_graph, 2) 510 | multi_one_graph = multi_one_graph.cpu().detach().numpy().astype(np.float64) 511 | max_channel_response = max_channel_response.cpu().detach().numpy() 512 | if not inference_mode: 513 | multi_input.append((multi_one_graph, max_channel_response, weight_ratio, -1)) 514 | else: 515 | for i in (2000,3000,4000,5000,6000,7000,8000): 516 | multi_input.append((multi_one_graph, max_channel_response, weight_ratio, i)) 517 | with Pool(args.cpu_threads) as p: 518 | multi_slic_adj_list = p.map(multithread_slic, multi_input) 519 | print("multithread slic finished") 520 | if weight_ratio > 0.2: 521 | weight_ratio = weight_ratio * 0.99 #reduce edge weight's complementary ratio 522 | 523 | for one_graph, (segments, adj, edge_weight, num_segments) in zip(output, multi_slic_adj_list): 524 | """Bridging FCN's output into GCN's input, initialize GCN batch""" 525 | segments_list.append(segments) 526 | one_graph = one_graph.permute( 1 , 2 , 0 ) 527 | #one_graph.shape: [img_size, img_size, channel_size] 528 | original_one_graph = torch.flatten(one_graph, start_dim = 0, end_dim = 1) 529 | #original_one_graph.shape: [img_size*img_size, channel_size] 530 | one_graph = None 531 | 532 | batch_node_num.append(num_segments) 533 | 534 | slic_pixels = [[] for _ in range (num_segments)] 535 | """slic_pixels stores flattened x,y coordinates according to superpixel's index""" 536 | for y in range (OUTPUT_SIZE): 537 | for x in range (OUTPUT_SIZE): 538 | curr_label = segments[x,y] 539 | slic_pixels[curr_label].append(x * OUTPUT_SIZE + y) 540 | #each slic seg's x y axis 541 | classes = None 542 | 543 | """ 544 | For each superpixel tile, select the PyTorch Variables(FCN) inside. 545 | These Variables(FCN) combine into new Node Variables(GCN) while 546 | -carrying gradients. 
547 | """ 548 | for n in slic_pixels: 549 | index_tensor = torch.LongTensor(n) 550 | if use_cuda: 551 | index_tensor = index_tensor.to('cuda') 552 | one_class = torch.unsqueeze(torch.sum( \ 553 | torch.index_select(original_one_graph, dim = 0, index = index_tensor), \ 554 | 0), dim = 0) 555 | if classes is None: 556 | classes = one_class 557 | else: 558 | classes = torch.cat((classes, one_class), 0) 559 | one_class = None 560 | index_tensor = None 561 | original_one_graph = None 562 | temp_ind = 0 563 | adj = np.asarray(adj) 564 | adj = torch.from_numpy(adj) 565 | adj = adj.t().contiguous() 566 | adj = Variable(adj).type(torch.LongTensor) 567 | 568 | """datagcn: GCN-ready wrapped data for one image 569 | gcn_batch_list: GCN-ready minibatch containings same images from 570 | -previous FCN's minibatch.""" 571 | datagcn = Data(x = classes, edge_index = adj, \ 572 | edge_weight = edge_weight.type(torch.FloatTensor)) 573 | 574 | gcn_batch_list.append(datagcn) 575 | #print(gcn_batch_list) 576 | classes = None #releases cached GPU memory immediately 577 | adj = None 578 | datagcn = None 579 | if torch.cuda.device_count() == 1: 580 | gcn_batch_list = Batch.from_data_list(gcn_batch_list) 581 | if use_cuda: 582 | gcn_batch_list = gcn_batch_list.to('cuda') 583 | """GCN training iterations""" 584 | if inference_mode: 585 | modelgcn.train() 586 | else: 587 | modelgcn.train() 588 | print(gt_percent.data.cpu().numpy()) 589 | print(moe.data.cpu().numpy()) 590 | for epochgcn in range(0, gcn_batch_iter): 591 | optimizergcn.zero_grad() 592 | outgcn = modelgcn(gcn_batch_list) 593 | 594 | """visualize GCN output""" 595 | if not inference_mode and epochgcn == gcn_batch_iter - 1 and batch_counter % 10 == 0: 596 | start_index = 0 597 | counter = 0 598 | #modelgcn.eval() 599 | #outgcn = modelgcn(gcn_batch_list) 600 | for curr_batch_idx in range(len(name)): 601 | outgcn_slice = torch.narrow(input = outgcn, dim = 0, \ 602 | start = start_index, \ 603 | length = batch_node_num[curr_batch_idx]) 604 | start_index += batch_node_num[curr_batch_idx] 605 | outputgcn_np = outgcn_slice.detach().data.cpu().numpy() 606 | segments_copy = segments_list[curr_batch_idx].copy() 607 | segments_copy = segments_copy.astype(np.float64) 608 | for segInd in range(len(outputgcn_np)): 609 | segments_copy[segments_copy == segInd] = outputgcn_np[segInd] 610 | gcn_target_rgb = np.array([[255*(c + 1) / 2, 255*(c + 1) / 2, 255*(c + 1) / 2] \ 611 | for c in segments_copy]) 612 | gcn_target_rgb = np.moveaxis(gcn_target_rgb, 1, 2) 613 | gcn_target_rgb = gcn_target_rgb.reshape( (OUTPUT_SIZE,OUTPUT_SIZE,3) ).astype( np.uint8 ) 614 | if args.visualize: 615 | cv2.imshow("gcn", gcn_target_rgb) 616 | cv2.waitKey(1) 617 | else: 618 | cv2.imwrite(os.path.join(SAVE_PATH, 'GCN' + str(batch_counter) + "N" + str(counter)\ 619 | + "_" + str(name[curr_batch_idx][-16:-4]) + '.png'), \ 620 | cv2.cvtColor(gcn_target_rgb, cv2.COLOR_RGB2BGR)) 621 | counter += 1 622 | outgcn_slice = None 623 | 624 | if not inference_mode: 625 | """ 626 | loss_top --cost, positive % of nodes responded correctly 627 | loss_bottom --cost, negative <1-gt_percent-moe> % of nodes responded correctly 628 | positive refers to desired region(cancer), negative refers to other regions(background) 629 | """ 630 | loss_top, loss_bottom = one_label_loss(gt_percent = gt_percent.data.cpu().numpy(), \ 631 | predict = outgcn, \ 632 | moe = moe.data.cpu().numpy(), \ 633 | batch_node_num = batch_node_num) 634 | if loss_top is None: 635 | total_gcn_loss = loss_bottom 636 | elif loss_bottom is None: 637 | 
total_gcn_loss = loss_top 638 | else: 639 | total_gcn_loss = loss_top + loss_bottom 640 | if half_precision: 641 | with amp.scale_loss(total_gcn_loss, optimizergcn) as scaled_loss2: 642 | scaled_loss2.backward(retain_graph=True) 643 | else: 644 | total_gcn_loss.backward(retain_graph=True) 645 | #backward calculating GCN gradients according to combined loss 646 | #print(total_gcn_loss) 647 | print("GCN+FCN loss: " + str(total_gcn_loss.data.cpu().numpy())) 648 | if not inference_mode: 649 | #backpropagate through GCN's layers 650 | optimizergcn.step() 651 | #backpropagate accumulated gradients through FCN's layers 652 | optimizer.step() 653 | #saving models & optimizers state_dict for later training and inference 654 | #cpu = torch.device("cpu") 655 | #model = model.to(cpu) 656 | #modelgcn = modelgcn.to(cpu) 657 | torch.save(model.module.state_dict(), os.path.join(SD_SAVE_PATH,"FCN" + str(batch_counter) + ".pt")) 658 | torch.save(modelgcn.module.state_dict(), os.path.join(SD_SAVE_PATH,"GCN" + str(batch_counter) + ".pt")) 659 | torch.save(optimizer.state_dict(), os.path.join(SD_SAVE_PATH,"FCNopt" + str(batch_counter) + ".pt")) 660 | torch.save(optimizergcn.state_dict(), os.path.join(SD_SAVE_PATH,"GCNopt" + str(batch_counter) + ".pt")) 661 | #model = model.to('cuda') 662 | #modelgcn = modelgcn.to('cuda') 663 | if inference_mode: 664 | start_index = 0 665 | counter = 0 666 | final_map = np.zeros((args.output_size, args.output_size)) 667 | multi_input = [] 668 | print("fusing") 669 | for input_index in range(len(segments_list)): 670 | outgcn_numpy = torch.narrow(input = outgcn, dim = 0, start = start_index, length = batch_node_num[input_index]).detach().data.cpu().numpy() 671 | segments_copy = segments_list[input_index].astype(np.float64) 672 | multi_input.append((outgcn_numpy, segments_copy)) 673 | start_index += batch_node_num[input_index] 674 | with Pool(args.cpu_threads) as p: 675 | multi_graph = p.map(fuse_results, multi_input) 676 | for graph in multi_graph: 677 | final_map += graph 678 | 679 | final_map = final_map / float(len(multi_graph)) 680 | final_map += 1.0 681 | final_map = final_map / 2.0 682 | final_map[final_map < args.fuse_thresh] = 0 683 | 684 | gcn_target_rgb = np.array([[255 * c , 255* c , 255* c] for c in final_map]) 685 | gcn_target_rgb = np.moveaxis(gcn_target_rgb, 1, 2) 686 | gcn_target_rgb = gcn_target_rgb.reshape( (args.output_size,args.output_size,3) ).astype( np.uint8 ) 687 | if args.visualize: 688 | cv2.imshow("gcn", gcn_target_rgb) 689 | cv2.waitKey(1) 690 | else: 691 | basename = os.path.basename(name[0]) 692 | cv2.imwrite(os.path.join(args.inference_path, basename), cv2.cvtColor(gcn_target_rgb, cv2.COLOR_RGB2BGR)) 693 | print("inference for", str(basename), "saved to", str(args.inference_path)) 694 | -------------------------------------------------------------------------------- /inference/001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/001.png -------------------------------------------------------------------------------- /inference/002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/002.png -------------------------------------------------------------------------------- /inference/003.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/003.png -------------------------------------------------------------------------------- /inference/004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/004.png -------------------------------------------------------------------------------- /inference/005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/005.png -------------------------------------------------------------------------------- /inference/006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/006.png -------------------------------------------------------------------------------- /inference/007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/007.png -------------------------------------------------------------------------------- /input_data/001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/001.png -------------------------------------------------------------------------------- /input_data/002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/002.png -------------------------------------------------------------------------------- /input_data/003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/003.png -------------------------------------------------------------------------------- /input_data/004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/004.png -------------------------------------------------------------------------------- /input_data/005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/005.png -------------------------------------------------------------------------------- /input_data/006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/006.png -------------------------------------------------------------------------------- /input_data/007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/007.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
torch_geometric.nn import GCNConv 3 | import torch.nn.functional as F 4 | from torch_geometric.data import Data, Batch 5 | import torch.nn as nn 6 | 7 | class FCN(nn.Module): 8 | """ 9 | Proposed Fully Convolutional Network 10 | This function/module uses fully convolutional blocks to extract pixel-wise image features. 11 | Tested on 1024*1024, 512*512 resolution; RGB, Immunohistochemical color channels 12 | 13 | Keyword arguments: 14 | input_dim -- input channel, 3 for RGB images (default) 15 | """ 16 | def __init__(self,input_dim, output_classes, p_mode = 'replicate'): 17 | super(FCN, self).__init__() 18 | #self.Dropout = nn.Dropout(p=0.05) 19 | self.conv1 = nn.Conv2d(input_dim, 32, kernel_size=3, stride=1, padding=1 ,padding_mode=p_mode) 20 | self.bn1 = nn.BatchNorm2d(32) 21 | 22 | self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, padding_mode=p_mode) 23 | self.bn2 = nn.BatchNorm2d(32) 24 | 25 | self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, padding_mode=p_mode) 26 | self.bn3 = nn.BatchNorm2d(64) 27 | 28 | self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, padding_mode=p_mode) 29 | self.bn4 = nn.BatchNorm2d(64) 30 | 31 | self.conv5 = nn.Conv2d(64, output_classes, kernel_size=1, stride=1, padding=0) 32 | 33 | #self.Dropout = nn.Dropout(p=0.3) 34 | 35 | def forward(self, x): 36 | x = self.conv1(x) 37 | x = F.relu(x) 38 | x = self.bn1(x) 39 | #x = self.Dropout(x) 40 | 41 | x = self.conv2(x) 42 | x = F.relu(x) 43 | x = self.bn2(x) 44 | #x = self.Dropout(x) 45 | 46 | x = self.conv3(x) 47 | x = F.relu(x) 48 | x = self.bn3(x) 49 | #x = self.Dropout(x) 50 | 51 | x = self.conv4(x) 52 | x = F.relu(x) 53 | x = self.bn4(x) 54 | 55 | x = self.conv5(x) 56 | return x 57 | 58 | 59 | class GCN(torch.nn.Module): 60 | """ 61 | Proposed Graph Convolutional Network 62 | This function/module uses classic GCN layers to generate superpixels(nodes) classification. 63 | --"Semi-Supervised Classification with Graph Convolutional Networks", 64 | --Thomas N. 
Kipf, Max Welling, ICLR2017 65 | 66 | Keyword arguments: 67 | input_dim -- input channel, aligns with output channel from FCN 68 | output_classes --output channel, default 1 for our proposed loss function 69 | """ 70 | def __init__(self, input_dim, output_classes): 71 | super(GCN, self).__init__() 72 | self.conv1 = GCNConv(input_dim, 64) 73 | self.conv2 = GCNConv(64, 128) 74 | self.conv3 = GCNConv(128, 256) 75 | self.conv4 = GCNConv(256, 64) 76 | self.conv5 = GCNConv(64, output_classes) 77 | #self.Dropout = nn.Dropout(p=0.5) 78 | 79 | # self.bn1 = nn.BatchNorm1d(64) 80 | # self.bn2 = nn.BatchNorm1d(128) 81 | # self.bn3 = nn.BatchNorm1d(256) 82 | # self.bn4 = nn.BatchNorm1d(64) 83 | # 84 | # self.lin1 = Linear(64, 256) 85 | # self.lin2 = Linear(256, 128) 86 | # self.lin3 = Linear(128, output_classes) 87 | 88 | def forward(self, data): 89 | x = self.conv1(data.x, edge_index = data.edge_index, edge_weight = data.edge_weight) 90 | x = F.relu(x) 91 | #x = self.Dropout(x) 92 | #x = self.bn1(x) 93 | 94 | x = self.conv2(x, edge_index = data.edge_index, edge_weight = data.edge_weight) 95 | x = F.relu(x) 96 | #x = self.Dropout(x) 97 | #x = self.bn2(x) 98 | 99 | x = self.conv3(x, edge_index = data.edge_index, edge_weight = data.edge_weight) 100 | x = F.relu(x) 101 | #x = self.Dropout(x) 102 | #x = self.bn3(x) 103 | 104 | x = self.conv4(x, edge_index = data.edge_index, edge_weight = data.edge_weight) 105 | x = F.relu(x) 106 | #x = self.bn4(x) 107 | 108 | x = self.conv5(x, edge_index = data.edge_index, edge_weight = data.edge_weight) 109 | 110 | 111 | return torch.tanh(x) 112 | -------------------------------------------------------------------------------- /state_dict/FCN1000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/state_dict/FCN1000.pt -------------------------------------------------------------------------------- /state_dict/FCNopt1000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/state_dict/FCNopt1000.pt -------------------------------------------------------------------------------- /state_dict/GCN1000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/state_dict/GCN1000.pt -------------------------------------------------------------------------------- /state_dict/GCNopt1000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/state_dict/GCNopt1000.pt -------------------------------------------------------------------------------- /train_process_files/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/train_process_files/.keep -------------------------------------------------------------------------------- /utilities.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from torch_geometric.data import Data, Batch 5 | from torch.autograd import Variable 6 | from skimage.segmentation import slic, mark_boundaries 7 | from skimage import io 8 | 
import cv2 9 | import sys 10 | import numpy as np 11 | from skimage import segmentation 12 | import matplotlib.pyplot as plt 13 | import os 14 | from collections import defaultdict 15 | use_cuda = torch.cuda.is_available() 16 | 17 | 18 | def rotate(image,angle): 19 | return np.rot90(image,angle) #angle is the number of 90-degree rotations (np.rot90's k), not degrees 20 | 21 | def rank(l): 22 | """ 23 | Rank/Frequency Counter 24 | This function efficiently calculates number rank/frequency from an ndarray 25 | e.g. rank([3,5,2,2,2,2,1,1]) -> [(2, 4), (1, 2), (3, 1), (5, 1)], i.e. (value, count) pairs sorted by descending count 26 | 27 | 28 | Keyword arguments: 29 | l --a 1-dimensional numpy array 30 | """ 31 | d = defaultdict(int) 32 | for i in l: 33 | d[i] += 1 34 | return sorted(d.items(), key=lambda x: x[1], reverse=True) 35 | 36 | 37 | def adjacency(sMask): 38 | """ 39 | Adjacency Matrix 40 | This function (not so efficiently) finds bidirectional adjacency edges 41 | -from a SLIC superpixel mask. 42 | 43 | 44 | Keyword arguments: 45 | sMask --a 2-dimensional numpy array generated by SLIC algorithm 46 | """ 47 | curr = 0 48 | adj = [] 49 | for y in range(sMask.shape[0]): 50 | #print(y) 51 | for x in range(sMask.shape[1]): 52 | if x >= sMask.shape[1] - 1: #reach end of row 53 | if y >= sMask.shape[0] - 1: #reach end of graph 54 | return adj 55 | else: #switch to new line, update curr 56 | curr = sMask[0, y+1] 57 | continue 58 | else: #still iterating row 59 | if sMask[x,y] != curr: 60 | if (curr, sMask[x,y]) not in adj: 61 | #new edge not in adjacency list 62 | adj.append((curr, sMask[x,y])) 63 | adj.append((sMask[x,y], curr)) #bidirectional 64 | curr = sMask[x,y] 65 | else: #already in adj 66 | curr = sMask[x,y] 67 | continue 68 | else: 69 | continue 70 | return adj 71 | 72 | 73 | def adjacency3(sMask, mask_length): 74 | """ 75 | Adjacency Matrix 76 | This function efficiently finds bidirectional adjacency edges 77 | -from a SLIC superpixel mask at the cost of higher RAM/cache usage 78 | 79 | 80 | Keyword arguments: 81 | sMask --a 2-dimensional numpy array generated by SLIC algorithm 82 | mask_length --number of unique SLIC superpixels in sMask 83 | """ 84 | curr = 0 85 | adj = [] 86 | edge_visited = np.full((mask_length, mask_length), False) 87 | for y in range(sMask.shape[0]): 88 | for x in range(sMask.shape[1]): 89 | if x >= sMask.shape[1] - 1: #reach end of row 90 | if y >= sMask.shape[0] - 1: #reach end of graph 91 | return adj 92 | else: #switch to new line, update curr 93 | curr = sMask[0, y+1] 94 | continue 95 | else: #still iterating row 96 | target = sMask[x,y] 97 | if target != curr: 98 | if edge_visited[curr, target] == False or \ 99 | edge_visited[target, curr] == False: 100 | #new edge not in adjacency list 101 | adj.append((curr, target)) 102 | adj.append((target, curr)) #bidirectional 103 | edge_visited[curr, target] = True 104 | edge_visited[target, curr] = True 105 | curr = sMask[x,y] 106 | else: #already in adj 107 | curr = sMask[x,y] 108 | continue 109 | else: 110 | continue 111 | return adj 112 | 113 | 114 | def center(pixel): 115 | """ 116 | Center of Superpixels 117 | This function efficiently finds an approximate (per-axis median) center of SLIC superpixel 118 | -tiles 119 | 120 | 121 | Keyword arguments: 122 | pixel --(x,y) tuple pixel coordinates list from a SLIC superpixel tile. 
123 | """ 124 | x_sorted = sorted(pixel.copy()) 125 | y_sorted = sorted(pixel, key=lambda tup: tup[1]) #sorted by y 126 | assert(len(x_sorted) == len(y_sorted)) 127 | mid = int(len(x_sorted) / 2) 128 | mid_x = int(x_sorted[mid][0]) 129 | mid_y = int(y_sorted[mid][1]) #x, y individual middle points 130 | return (mid_x, mid_y) 131 | 132 | 133 | def euclidean_dist(a, b): 134 | """ 135 | Euclidean Distance 136 | Calculate Euclidean Distance between two matrices. 137 | 138 | 139 | Keyword arguments: 140 | a --first matrix's ndarray 141 | b --second matrix's ndarray, must match input a's shape 142 | """ 143 | return np.linalg.norm(a-b) 144 | 145 | 146 | def mse(a,b): 147 | """ 148 | Mean Squared Error 149 | Calculate Mean Squared Error between two matrices. 150 | 151 | 152 | Keyword arguments: 153 | a --first matrix's ndarray 154 | b --second matrix's ndarray, must match input a's shape 155 | """ 156 | return ((a - b)**2).mean(axis=None) 157 | 158 | 159 | def fnorm(a, b): 160 | """ 161 | Frobenius Norm Distance 162 | Calculate a Frobenius-norm-based distance between two matrices 163 | (compares the two inputs' total sums of squares rather than computing ||a-b||_F). 164 | 165 | Keyword arguments: 166 | a --first matrix's ndarray 167 | b --second matrix's ndarray, must match input a's shape 168 | """ 169 | a_ss = np.sum(a = np.square(a)) #sum of squares 170 | b_ss = np.sum(a = np.square(b)) 171 | return np.sqrt(np.absolute(a_ss - b_ss)) 172 | 173 | 174 | def chisq_dist(a, b, gamma = 1): 175 | """ 176 | Chi-Square Distance 177 | Calculate Chi-Square Distance between two matrices and return exp(-gamma * distance), a chi-square kernel similarity. 178 | 179 | 180 | Keyword arguments: 181 | a --first matrix's ndarray 182 | b --second matrix's ndarray, must match input a's shape 183 | """ 184 | numerator = np.sum(np.square(a-b)) 185 | denominator = np.sum(a+b) 186 | dist = 0.5 * numerator / denominator 187 | return np.exp(- gamma * dist) 188 | 189 | 190 | def one_label_loss(gt_percent, predict, moe, batch_node_num): 191 | """ 192 | Proposed Loss Function 193 | Our proposed loss function calculates the cost of a training batch using 194 | -GCN's output graphs and weak image-level annotations. 195 | For more information, please refer to our paper. 
196 | 197 | 198 | Keyword arguments: 199 | gt_percent --Ground-Truth percent, a weak image-level annotation 200 | predict --GCN module output, gradient required 201 | moe --Margin of Error, a weak image-level annotation 202 | batch_node_num --integer list of node numbers per image in batch 203 | """ 204 | curr_index = 0 205 | batch_top_k_loss = [] 206 | batch_bottom_k_loss = [] 207 | batch_pairwise_loss = [] 208 | positive_num = 0.00000001 209 | negative_num = 0.00000001 210 | for i in range(len(gt_percent)): 211 | total_length = batch_node_num[i] #one graph length 212 | predict_slice = torch.narrow(input = predict, dim = 0, start = curr_index, length = total_length) 213 | curr_index += total_length 214 | one_gt_percent = gt_percent[i] 215 | one_moe = moe[i] 216 | select = torch.tensor([0]) 217 | if use_cuda: 218 | select = select.to('cuda') 219 | 220 | threshold_ceil = int(total_length * (one_gt_percent - one_moe)) #100 * (0.8 - 0.1) = top 70 % 221 | if threshold_ceil < 0: 222 | threshold_ceil = 0 223 | threshold_floor = int(total_length * (1.0 - one_gt_percent - one_moe)) #100 * (1 - 0.8 - 0.1) = bottom 10 % 224 | if threshold_floor < 0: 225 | threshold_floor = 0 226 | 227 | top_k, _ = torch.topk(input = predict_slice, k = threshold_ceil, dim = 0, largest = True, sorted = False) 228 | bottom_k, _ = torch.topk(input = predict_slice, k = threshold_floor, dim = 0, largest = False, sorted = False) 229 | 230 | top_k_mean = torch.mean(top_k,dim=0) 231 | bottom_k_mean = torch.mean(bottom_k,dim=0) 232 | 233 | predict_slice = None 234 | top_k = None 235 | select = None 236 | bottom_k = None 237 | loss_fn = nn.SmoothL1Loss() 238 | if use_cuda: 239 | temp_ones = torch.ones(1, dtype = torch.float).to('cuda') 240 | temp_zeros = torch.tensor([-1], dtype = torch.float).to('cuda') #background target (-1, matching the GCN's tanh output range) 241 | temp_ground = torch.zeros(1, dtype = torch.float).to('cuda') 242 | if threshold_ceil > 0: 243 | #top_k_loss = F.l1_loss(top_k_mean, temp_ones) 244 | top_k_loss = loss_fn(top_k_mean, temp_ones) 245 | positive_num += top_k_loss.detach().cpu().numpy() 246 | else: 247 | top_k_loss = None 248 | 249 | if threshold_floor > 0: 250 | #bottom_k_loss = F.l1_loss(bottom_k_mean, temp_zeros) 251 | bottom_k_loss = loss_fn(bottom_k_mean, temp_zeros) 252 | negative_num += bottom_k_loss.detach().cpu().numpy() 253 | else: 254 | bottom_k_loss = None 255 | temp_ones = None 256 | temp_zeros = None 257 | else: 258 | if threshold_ceil > 0: 259 | #top_k_loss = F.l1_loss(top_k_mean, torch.ones(1, dtype = torch.float)) 260 | top_k_loss = loss_fn(top_k_mean, torch.ones(1, dtype = torch.float)) 261 | positive_num += 1.0 262 | else: 263 | top_k_loss = None 264 | 265 | if threshold_floor > 0: 266 | #bottom_k_loss = F.l1_loss(bottom_k_mean, torch.zeros(1, dtype = torch.float)) 267 | bottom_k_loss = loss_fn(bottom_k_mean, torch.zeros(1, dtype = torch.float)) 268 | negative_num += 1.0 269 | else: 270 | bottom_k_loss = None 271 | batch_top_k_loss.append(top_k_loss) 272 | batch_bottom_k_loss.append(bottom_k_loss) 273 | top_k_loss = None 274 | bottom_k_loss = None 275 | pairwise_loss = None 276 | print("-------------------------------------------------------------------------------") 277 | print("Targeted Regions Losses Per Image") 278 | print([round(float(x.data.cpu().detach().numpy()),2) if x is not None else -1.00 for x in batch_top_k_loss]) 279 | print("Background Regions Losses Per Image") 280 | print([round(float(x.data.cpu().detach().numpy()),2) if x is not None else -1.00 for x in batch_bottom_k_loss]) 281 | 
print("-------------------------------------------------------------------------------") 282 | 283 | for t, b, g, a in zip(batch_top_k_loss, batch_bottom_k_loss, gt_percent, moe): 284 | if top_k_loss is None and t is not None: 285 | top_k_loss = (g - a) * t 286 | elif t is not None: 287 | top_k_loss += (g - a) * t 288 | if bottom_k_loss is None and b is not None: 289 | bottom_k_loss = (1.0 - g - a) * b 290 | elif b is not None: 291 | bottom_k_loss += (1.0 - g - a) * b 292 | return top_k_loss, bottom_k_loss 293 | 294 | 295 | 296 | 297 | def plot_grad_flow(named_parameters): 298 | """ 299 | Utility Function-Visualize Gradient Flow 300 | This utility function can assist in checking gradient flow between layers. 301 | 302 | 303 | Keyword arguments: 304 | named_parameters --module's parameter() 305 | """ 306 | ave_grads = [] 307 | layers = [] 308 | for n, p in named_parameters: 309 | if(p.requires_grad) and ("bias" not in n): 310 | layers.append(n) 311 | ave_grads.append(p.grad.abs().mean()) 312 | plt.plot(ave_grads, alpha=0.3, color="b") 313 | plt.hlines(0, 0, len(ave_grads)+1, linewidth=1, color="k" ) 314 | plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical") 315 | plt.xlim(xmin=0, xmax=len(ave_grads)) 316 | plt.xlabel("Layers") 317 | plt.ylabel("average gradient") 318 | plt.title("Gradient flow") 319 | plt.grid(True) 320 | plt.show() 321 | 322 | def fuse_results(multi_args): 323 | (outgcn_numpy, segments_copy) = multi_args 324 | for segInd in range(len(outgcn_numpy)): 325 | segments_copy[segments_copy == segInd] = outgcn_numpy[segInd] 326 | return segments_copy 327 | --------------------------------------------------------------------------------