├── README.md
├── Warmup_with_SimSiam
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── byol_pytorch
│   │   ├── __init__.py
│   │   ├── byol_pytorch.py
│   │   ├── dataset.py
│   │   └── model.py
│   ├── diagram.png
│   ├── examples
│   │   ├── dataset.py
│   │   └── lightning
│   │       ├── README.md
│   │       └── train.py
│   ├── setup.py
│   └── warmup.py
├── args_parser.py
├── finaledgegcncopy.py
├── inference
│   ├── 001.png
│   ├── 002.png
│   ├── 003.png
│   ├── 004.png
│   ├── 005.png
│   ├── 006.png
│   └── 007.png
├── input_data
│   ├── 001.png
│   ├── 002.png
│   ├── 003.png
│   ├── 004.png
│   ├── 005.png
│   ├── 006.png
│   └── 007.png
├── model.py
├── state_dict
│   ├── FCN1000.pt
│   ├── FCNopt1000.pt
│   ├── GCN1000.pt
│   └── GCNopt1000.pt
├── train_process_files
│   └── .keep
└── utilities.py
/README.md:
--------------------------------------------------------------------------------
1 | # Joint Fully Convolutional and Graph Convolutional Networks for Weakly-Supervised Segmentation of Pathology Images
2 |
3 | # Instructions
4 | A trained checkpoint numbered 1000 is provided, along with 9 HER2 pathology images, for use in inference.
5 | This checkpoint was trained on 226 HER2 pathology images from a private dataset.
6 |
7 | ## To run inference:
8 | `python3 finaledgegcncopy.py --inference-path full_path_to/inference --checkpoint xxxx`
9 | For example, with the provided images and state dict, run:
10 | `python3 finaledgegcncopy.py --inference-path full_path_to/inference --checkpoint 1000`
11 |
12 | ## To run training:
13 | `python3 finaledgegcncopy.py`
14 |
15 | ## To resume training:
16 | `python3 finaledgegcncopy.py --checkpoint xxxx`
17 |
18 | ## Flags and folders:
19 |
20 | `--train-path` (default `./train_process_files`): the folder to which the pipeline saves training visualization files
21 |
22 | `--input-path` (default `./input_data`): images used for inference or training. For our weakly supervised loss to work, training images must be named `AreaRatio_Uncertainty_*.png` (a parsing sketch is given at the end of this section).
23 | For example, `0.4_0.05_*.png` means the target region occupies (40 ± 5)% of the image.
24 |
25 | `--inference-path` (e.g. `full_path_to/inference`): the folder to which the pipeline writes the inferred masks.
26 | Setting this argument switches on inference mode; it must be used together with `--checkpoint`.
27 |
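28 | The training dataloaders read this weak label straight from the filename: the area ratio and its margin of error are the first two underscore-separated fields. A minimal parsing sketch (the helper name `parse_weak_label` is illustrative only, not part of the pipeline):
29 |
30 | ```python
31 | import os
32 |
33 | def parse_weak_label(image_path):
34 |     # "0.4_0.05_patch17.png" -> (0.4, 0.05): area ratio and margin of error
35 |     area_ratio, uncertainty = os.path.basename(image_path).split("_")[:2]
36 |     return float(area_ratio), float(uncertainty)
37 |
38 | print(parse_weak_label("input_data/0.4_0.05_patch17.png"))  # (0.4, 0.05)
39 | ```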
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/README.md:
--------------------------------------------------------------------------------
1 | ![diagram](./diagram.png)
2 |
3 | ## Bootstrap Your Own Latent (BYOL), in Pytorch
4 |
5 | [![PyPI version](https://badge.fury.io/py/byol-pytorch.svg)](https://badge.fury.io/py/byol-pytorch)
6 |
7 | Practical implementation of an astoundingly simple method for self-supervised learning that achieves a new state of the art (surpassing SimCLR) without contrastive learning and without having to designate negative pairs.
8 |
9 | This repository offers a module with which one can easily wrap any image-based neural network (residual network, discriminator, policy network) to immediately start benefiting from unlabelled image data.
10 |
11 | Update 1: There is now new evidence that batch normalization is key to making this technique work well
12 |
13 | Update 2: A new paper has successfully replaced batch norm with group norm + weight standardization, refuting that batch statistics are needed for BYOL to work
14 |
15 | Yannic Kilcher has an excellent video explanation of this paper.
16 |
17 | ## Install
18 |
19 | ```bash
20 | $ pip install byol-pytorch
21 | ```
22 |
23 | ## Usage
24 |
25 | Simply plug in your neural network, specifying (1) the image dimensions as well as (2) the name (or index) of the hidden layer whose output is used as the latent representation for self-supervised training.
26 |
27 | ```python
28 | import torch
29 | from byol_pytorch import BYOL
30 | from torchvision import models
31 |
32 | resnet = models.resnet50(pretrained=True)
33 |
34 | learner = BYOL(
35 | resnet,
36 | image_size = 256,
37 | hidden_layer = 'avgpool'
38 | )
39 |
40 | opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
41 |
42 | def sample_unlabelled_images():
43 | return torch.randn(20, 3, 256, 256)
44 |
45 | for _ in range(100):
46 | images = sample_unlabelled_images()
47 | loss = learner(images)
48 | opt.zero_grad()
49 | loss.backward()
50 | opt.step()
51 | learner.update_moving_average() # update moving average of target encoder
52 |
53 | # save your improved network
54 | torch.save(resnet.state_dict(), './improved-net.pt')
55 | ```
56 |
57 | That's pretty much it. After much training, the residual network should now perform better on its downstream supervised tasks.
58 |
59 | ## BYOL → SimSiam
60 |
61 | A new paper from Kaiming He suggests that BYOL does not even need the target encoder to be an exponential moving average of the online encoder. I've decided to build in this option so that you can easily use that variant for training, simply by setting the `use_momentum` flag to `False`. You will no longer need to invoke `update_moving_average` if you go this route as shown in the example below.
62 |
63 | ```python
64 | import torch
65 | from byol_pytorch import BYOL
66 | from torchvision import models
67 |
68 | resnet = models.resnet50(pretrained=True)
69 |
70 | learner = BYOL(
71 | resnet,
72 | image_size = 256,
73 | hidden_layer = 'avgpool',
74 | use_momentum = False # turn off momentum in the target encoder
75 | )
76 |
77 | opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
78 |
79 | def sample_unlabelled_images():
80 | return torch.randn(20, 3, 256, 256)
81 |
82 | for _ in range(100):
83 | images = sample_unlabelled_images()
84 | loss = learner(images)
85 | opt.zero_grad()
86 | loss.backward()
87 | opt.step()
88 |
89 | # save your improved network
90 | torch.save(resnet.state_dict(), './improved-net.pt')
91 | ```
92 |
93 | ## Advanced
94 |
95 | While the hyperparameters have already been set to what the paper has found optimal, you can change them with extra keyword arguments to the base wrapper class.
96 |
97 | ```python
98 | learner = BYOL(
99 | resnet,
100 | image_size = 256,
101 | hidden_layer = 'avgpool',
102 | projection_size = 256, # the projection size
103 | projection_hidden_size = 4096, # the hidden dimension of the MLP for both the projection and prediction
104 | moving_average_decay = 0.99 # the moving average decay factor for the target encoder, already set at what paper recommends
105 | )
106 | ```
107 |
108 | By default, this library will use the augmentations from the SimCLR paper (which is also used in the BYOL paper). However, if you would like to specify your own augmentation pipeline, you can simply pass in your own custom augmentation function with the `augment_fn` keyword.
109 |
110 | Augmentations must work in tensor space. The `kornia` library is highly recommended for this. If you decide to use torchvision augmentations, make sure each tensor is first converted to a PIL image (`ToPILImage()`) and then back to a tensor (`ToTensor()`)
111 |
112 | ```python
113 | augment_fn = nn.Sequential(
114 | kornia.augmentation.RandomHorizontalFlip()
115 | )
116 |
117 | learner = BYOL(
118 | resnet,
119 | image_size = 256,
120 | hidden_layer = -2,
121 | augment_fn = augment_fn
122 | )
123 | ```
124 |
125 | In the paper, they seem to ensure that one of the two augmentation pipelines has a higher gaussian blur probability than the other. You can also adjust this to your heart's delight.
126 |
127 | ```python
128 | augment_fn = nn.Sequential(
129 | kornia.augmentation.RandomHorizontalFlip()
130 | )
131 |
132 | augment_fn2 = nn.Sequential(
133 | kornia.augmentation.RandomHorizontalFlip(),
134 | kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
135 | )
136 |
137 | learner = BYOL(
138 | resnet,
139 | image_size = 256,
140 | hidden_layer = -2,
141 | augment_fn = augment_fn,
142 | augment_fn2 = augment_fn2,
143 | )
144 | ```
145 |
146 | ## Citation
147 |
148 | ```bibtex
149 | @misc{grill2020bootstrap,
150 | title = {Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning},
151 | author = {Jean-Bastien Grill and Florian Strub and Florent Altché and Corentin Tallec and Pierre H. Richemond and Elena Buchatskaya and Carl Doersch and Bernardo Avila Pires and Zhaohan Daniel Guo and Mohammad Gheshlaghi Azar and Bilal Piot and Koray Kavukcuoglu and Rémi Munos and Michal Valko},
152 | year = {2020},
153 | eprint = {2006.07733},
154 | archivePrefix = {arXiv},
155 | primaryClass = {cs.LG}
156 | }
157 | ```
158 |
159 | ```bibtex
160 | @misc{chen2020exploring,
161 | title={Exploring Simple Siamese Representation Learning},
162 | author={Xinlei Chen and Kaiming He},
163 | year={2020},
164 | eprint={2011.10566},
165 | archivePrefix={arXiv},
166 | primaryClass={cs.CV}
167 | }
168 | ```
169 |
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/byol_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from byol_pytorch.byol_pytorch import BYOL
2 |
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/byol_pytorch/byol_pytorch.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | from functools import wraps
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 |
9 | from kornia import augmentation as augs
10 | from kornia import filters, color
11 |
12 | # helper functions
13 |
14 | def default(val, def_val):
15 | return def_val if val is None else val
16 |
17 | def flatten(t):
18 | return t.reshape(t.shape[0], -1)
19 |
20 | def singleton(cache_key):
21 | def inner_fn(fn):
22 | @wraps(fn)
23 | def wrapper(self, *args, **kwargs):
24 | instance = getattr(self, cache_key)
25 | if instance is not None:
26 | return instance
27 |
28 | instance = fn(self, *args, **kwargs)
29 | setattr(self, cache_key, instance)
30 | return instance
31 | return wrapper
32 | return inner_fn
33 |
34 | def get_module_device(module):
35 | return next(module.parameters()).device
36 |
37 | def set_requires_grad(model, val):
38 | for p in model.parameters():
39 | p.requires_grad = val
40 |
41 | # loss fn
42 |
43 | def loss_fn(x, y):
44 | x = F.normalize(x, dim=1, p=2)
45 | y = F.normalize(y, dim=1, p=2)
46 | return 2 - 2 * (x * y).sum(dim=1)
47 |
48 | # augmentation utils
49 |
50 | class RandomApply(nn.Module):
51 | def __init__(self, fn, p):
52 | super().__init__()
53 | self.fn = fn
54 | self.p = p
55 | def forward(self, x):
56 | if random.random() > self.p:
57 | return x
58 | return self.fn(x)
59 |
60 | # exponential moving average
61 |
62 | class EMA():
63 | def __init__(self, beta):
64 | super().__init__()
65 | self.beta = beta
66 |
67 | def update_average(self, old, new):
68 | if old is None:
69 | return new
70 | return old * self.beta + (1 - self.beta) * new
71 |
72 | def update_moving_average(ema_updater, ma_model, current_model):
73 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
74 | old_weight, up_weight = ma_params.data, current_params.data
75 | ma_params.data = ema_updater.update_average(old_weight, up_weight)
76 |
77 | # MLP class for projector and predictor
78 |
79 | class MLP(nn.Module):
80 | def __init__(self, dim, projection_size, hidden_size = 4096):
81 | super().__init__()
82 | self.net = nn.Sequential(
83 | nn.Linear(dim, hidden_size),
84 | nn.BatchNorm1d(hidden_size),
85 | nn.ReLU(inplace=True),
86 | nn.Linear(hidden_size, projection_size)
87 | )
88 |
89 | def forward(self, x):
90 | return self.net(x)
91 |
92 | # a wrapper class for the base neural network
93 | # will manage the interception of the hidden layer output
94 | # and pipe it into the projector and predictor nets
95 |
96 | class NetWrapper(nn.Module):
97 | def __init__(self, net, projection_size, projection_hidden_size, layer = -2):
98 | super().__init__()
99 | self.net = net
100 | self.layer = layer
101 |
102 | self.projector = None
103 | self.projection_size = projection_size
104 | self.projection_hidden_size = projection_hidden_size
105 |
106 | self.hidden = None
107 | self.hook_registered = False
108 |
109 | def _find_layer(self):
110 | if type(self.layer) == str:
111 | modules = dict([*self.net.named_modules()])
112 | return modules.get(self.layer, None)
113 | elif type(self.layer) == int:
114 | children = [*self.net.children()]
115 | return children[self.layer]
116 | return None
117 |
118 | def _hook(self, _, __, output):
119 | self.hidden = flatten(output)
120 |
121 | def _register_hook(self):
122 | layer = self._find_layer()
123 | assert layer is not None, f'hidden layer ({self.layer}) not found'
124 | handle = layer.register_forward_hook(self._hook)
125 | self.hook_registered = True
126 |
127 | @singleton('projector')
128 | def _get_projector(self, hidden):
129 | _, dim = hidden.shape
130 | projector = MLP(dim, self.projection_size, self.projection_hidden_size)
131 | return projector.to(hidden)
132 |
133 | def get_representation(self, x):
134 | if self.layer == -1:
135 | return self.net(x)
136 |
137 | if not self.hook_registered:
138 | self._register_hook()
139 |
140 | _ = self.net(x)
141 | hidden = self.hidden
142 | self.hidden = None
143 | assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
144 | return hidden
145 |
146 | def forward(self, x):
147 | representation = self.get_representation(x)
148 | projector = self._get_projector(representation)
149 | projection = projector(representation)
150 | return projection
151 |
152 | # main class
153 |
154 | class BYOL(nn.Module):
155 | def __init__(
156 | self,
157 | encoder,
158 | predictor,
159 | image_size,
160 | hidden_layer = -2,
161 | projection_size = 256,
162 | projection_hidden_size = 4096,
163 | augment_fn = None,
164 | augment_fn2 = None,
165 | moving_average_decay = 0.99,
166 | use_momentum = True
167 | ):
168 | super().__init__()
169 |
170 | # default SimCLR augmentation
171 |
172 | DEFAULT_AUG = nn.Sequential(
173 | RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
174 | augs.RandomGrayscale(p=0.2),
175 | # augs.RandomHorizontalFlip(),
176 | RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
177 | # augs.RandomResizedCrop((image_size, image_size)),
178 | # augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
179 | )
180 |
181 | self.augment1 = default(augment_fn, DEFAULT_AUG)
182 | self.augment2 = default(augment_fn2, self.augment1)
183 |
184 | # self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)
185 | self.online_encoder = encoder
186 |
187 | self.use_momentum = use_momentum
188 | self.target_encoder = None
189 | self.target_ema_updater = EMA(moving_average_decay)
190 |
191 | self.online_predictor = predictor
192 |
193 | # get device of network and make wrapper same device
194 | # device = get_module_device(net)
195 | device = torch.device(2)  # note: hard-coded to CUDA device index 2 rather than taken from the wrapped network
196 | self.to(device)
197 |
198 | # send a mock image tensor to instantiate singleton parameters
199 | self.forward(torch.randn(2, 3, image_size, image_size, device=device))
200 |
201 | @singleton('target_encoder')
202 | def _get_target_encoder(self):
203 | target_encoder = copy.deepcopy(self.online_encoder)
204 | set_requires_grad(target_encoder, False)
205 | return target_encoder
206 |
207 | def reset_moving_average(self):
208 | del self.target_encoder
209 | self.target_encoder = None
210 |
211 | def update_moving_average(self):
212 | assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
213 | assert self.target_encoder is not None, 'target encoder has not been created yet'
214 | update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
215 |
216 | def forward(self, x):
217 | image_one, image_two = self.augment1(x), self.augment2(x)
218 |
219 | online_proj_one = self.online_encoder(image_one)
220 | online_proj_two = self.online_encoder(image_two)
221 |
222 | online_pred_one = self.online_predictor(online_proj_one)
223 | online_pred_two = self.online_predictor(online_proj_two)
224 |
225 | with torch.no_grad():
226 | target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
227 | target_proj_one = target_encoder(image_one).detach()
228 | target_proj_two = target_encoder(image_two).detach()
229 |
230 | loss_one = loss_fn(online_pred_one, target_proj_two.detach())
231 | loss_two = loss_fn(online_pred_two, target_proj_one.detach())
232 |
233 | loss = loss_one + loss_two
234 | return loss.mean()
235 |
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/byol_pytorch/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | from PIL import Image
4 | import torchvision.transforms as transforms
5 | from torch.utils.data import Dataset, DataLoader
6 |
7 | def pil_loader(path):
8 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
9 | with open(path, 'rb') as f:
10 | img = Image.open(f)
11 | return img.convert('RGB')
12 |
13 |
14 | class SegDataset(Dataset):
15 | """Face Landmarks dataset."""
16 |
17 | def __init__(self, root_dir):
18 | """
19 | Args:
20 | root_dir (string): Directory with all the images.
21 | """
22 | self.root_dir = root_dir
23 | self.image_paths = sorted(os.listdir(root_dir))
24 | self.transform = transforms.Compose([transforms.Resize((1024, 1024)), transforms.ToTensor()])
25 |
26 | def __len__(self):
27 | return len(self.image_paths)
28 |
29 | def __getitem__(self, idx):
30 | image_name = self.image_paths[idx]
31 | img_name = os.path.join(self.root_dir, image_name)
32 |
33 | image = pil_loader(img_name)
34 |
35 | return self.transform(image)
36 |
37 |
38 |
39 |
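40 | if __name__ == '__main__':
41 |     # Illustrative usage sketch; the directory path below is a placeholder.
42 |     # Each item is a 3x1024x1024 float tensor, ready for the SimSiam warmup loop.
43 |     ds = SegDataset(root_dir='./images')
44 |     loader = DataLoader(ds, batch_size=2, shuffle=True)
45 |     for batch in loader:
46 |         print(batch.shape)  # torch.Size([2, 3, 1024, 1024])
47 |         break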
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/byol_pytorch/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch_geometric.nn import GCNConv
5 | class FCN(nn.Module):
6 | """
7 | Proposed Fully Convolutional Network
8 | This function/module uses fully convolutional blocks to extract pixel-wise image features.
9 | Tested on 1024*1024, 512*512 resolution; RGB, Immunohistochemical color channels
10 |
11 | Keyword arguments:
12 | input_dim -- input channel, 3 for RGB images (default)
13 | """
14 | def __init__(self, input_dim, output_classes, p_mode = 'replicate'):
15 | super(FCN, self).__init__()
16 | #self.Dropout = nn.Dropout(p=0.05)
17 | self.conv1 = nn.Conv2d(input_dim, 32, kernel_size=3, stride=1, padding=1 ,padding_mode=p_mode)
18 | self.bn1 = nn.BatchNorm2d(32)
19 |
20 | self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, padding_mode=p_mode)
21 | self.bn2 = nn.BatchNorm2d(32)
22 |
23 | self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, padding_mode=p_mode)
24 | self.bn3 = nn.BatchNorm2d(64)
25 |
26 | self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, padding_mode=p_mode)
27 | self.bn4 = nn.BatchNorm2d(64)
28 |
29 | self.conv5 = nn.Conv2d(64, output_classes, kernel_size=1, stride=1, padding=0)
30 |
31 | #self.Dropout = nn.Dropout(p=0.3)
32 |
33 | def forward(self, x):
34 | x = self.conv1(x)
35 | x = F.relu(x)
36 | x = self.bn1(x)
37 | #x = self.Dropout(x)
38 |
39 | x = self.conv2(x)
40 | x = F.relu(x)
41 | x = self.bn2(x)
42 | #x = self.Dropout(x)
43 |
44 | x = self.conv3(x)
45 | x = F.relu(x)
46 | x = self.bn3(x)
47 | #x = self.Dropout(x)
48 |
49 | x = self.conv4(x)
50 | x = F.relu(x)
51 | x = self.bn4(x)
52 |
53 | x = self.conv5(x)
54 | return x
55 |
56 | class Predictor(nn.Module):
57 | def __init__(self, dim, p_mode = 'replicate'):
58 | super(Predictor, self).__init__()
59 | self.conv4 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, padding_mode=p_mode)
60 | self.bn4 = nn.BatchNorm2d(dim)
61 |
62 | self.conv5 = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)
63 |
64 | def forward(self, x):
65 | x = self.conv4(x)
66 | x = F.relu(x)
67 | x = self.bn4(x)
68 |
69 | x = self.conv5(x)
70 | return x
71 |
72 |
73 | class GCN(torch.nn.Module):
74 | """
75 | Proposed Graph Convolutional Network
76 | This function/module uses classic GCN layers to classify superpixels (nodes).
77 | --"Semi-Supervised Classification with Graph Convolutional Networks",
78 | --Thomas N. Kipf, Max Welling, ICLR2017
79 |
80 | Keyword arguments:
81 | input_dim -- input channel, aligns with output channel from FCN
82 | output_classes --output channel, default 1 for our proposed loss function
83 | """
84 | def __init__(self, input_dim, output_classes):
85 | super(GCN, self).__init__()
86 | self.conv1 = GCNConv(input_dim, 64)
87 | self.conv2 = GCNConv(64, 128)
88 | self.conv3 = GCNConv(128, 256)
89 | self.conv4 = GCNConv(256, 64)
90 | self.conv5 = GCNConv(64, output_classes)
91 | #self.Dropout = nn.Dropout(p=0.5)
92 |
93 | # self.bn1 = nn.BatchNorm1d(64)
94 | # self.bn2 = nn.BatchNorm1d(128)
95 | # self.bn3 = nn.BatchNorm1d(256)
96 | # self.bn4 = nn.BatchNorm1d(64)
97 | #
98 | # self.lin1 = Linear(64, 256)
99 | # self.lin2 = Linear(256, 128)
100 | # self.lin3 = Linear(128, output_classes)
101 |
102 | def forward(self, data):
103 | x = self.conv1(data.x, edge_index = data.edge_index, edge_weight = data.edge_weight)
104 | x = F.relu(x)
105 | #x = self.Dropout(x)
106 | #x = self.bn1(x)
107 |
108 | x = self.conv2(x, edge_index = data.edge_index, edge_weight = data.edge_weight)
109 | x = F.relu(x)
110 | #x = self.Dropout(x)
111 | #x = self.bn2(x)
112 |
113 | x = self.conv3(x, edge_index = data.edge_index, edge_weight = data.edge_weight)
114 | x = F.relu(x)
115 | #x = self.Dropout(x)
116 | #x = self.bn3(x)
117 |
118 | x = self.conv4(x, edge_index = data.edge_index, edge_weight = data.edge_weight)
119 | x = F.relu(x)
120 | #x = self.bn4(x)
121 |
122 | x = self.conv5(x, edge_index = data.edge_index, edge_weight = data.edge_weight)
123 |
124 |
125 | return torch.tanh(x)
126 |
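127 | if __name__ == '__main__':
128 |     # Illustrative smoke test only, not part of the training pipeline; it assumes
129 |     # torch_geometric is installed and uses arbitrary sizes and a toy two-node graph.
130 |     from torch_geometric.data import Data
131 |     fcn = FCN(input_dim=3, output_classes=24)
132 |     print(fcn(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 24, 64, 64])
133 |     gcn = GCN(input_dim=24, output_classes=1)
134 |     graph = Data(x=torch.randn(2, 24),
135 |                  edge_index=torch.tensor([[0, 1], [1, 0]]),
136 |                  edge_weight=torch.tensor([1.0, 1.0]))
137 |     print(gcn(graph).shape)  # torch.Size([2, 1])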
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/Warmup_with_SimSiam/diagram.png
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/examples/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | from PIL import Image
4 | import torchvision.transforms as transforms
5 | from torch.utils.data import Dataset, DataLoader
6 |
7 | def pil_loader(path):
8 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
9 | with open(path, 'rb') as f:
10 | img = Image.open(f)
11 | return img.convert('RGB')
12 |
13 |
14 | class SegDataset(Dataset):
15 | """Face Landmarks dataset."""
16 |
17 | def __init__(self, root_dir):
18 | """
19 | Args:
20 | root_dir (string): Directory with all the images.
21 | """
22 | self.root_dir = root_dir
23 | self.image_paths = sorted(os.listdir(root_dir))
24 | self.transform = transforms.Compose([transforms.ToTensor()])
25 |
26 | def __len__(self):
27 | return len(self.image_paths)
28 |
29 | def __getitem__(self, idx):
30 | image_name = self.image_paths[idx]
31 | img_name = os.path.join(self.root_dir, image_name)
32 |
33 | image = pil_loader(img_name)
34 |
35 | return self.transform(image)
36 |
37 |
38 |
39 |
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/examples/lightning/README.md:
--------------------------------------------------------------------------------
1 | ## Pytorch-lightning example script
2 |
3 | ### Requirements
4 |
5 | ```bash
6 | $ pip install pytorch-lightning
7 | $ pip install pillow
8 | ```
9 |
10 | ### Run
11 |
12 | ```bash
13 | $ python train.py --image_folder /path/to/your/images
14 | ```
15 |
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/examples/lightning/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import multiprocessing
4 | from pathlib import Path
5 | from PIL import Image
6 |
7 | import torch
8 | from torchvision import models, transforms
9 | from torch.utils.data import DataLoader, Dataset
10 |
11 | from byol_pytorch import BYOL
12 | import pytorch_lightning as pl
13 |
14 | # test model, a resnet 50
15 |
16 | resnet = models.resnet50(pretrained=True)
17 |
18 | # arguments
19 |
20 | parser = argparse.ArgumentParser(description='byol-lightning-test')
21 |
22 | parser.add_argument('--image_folder', type=str, required = True,
23 | help='path to your folder of images for self-supervised learning')
24 |
25 | args = parser.parse_args()
26 |
27 | # constants
28 |
29 | BATCH_SIZE = 32
30 | EPOCHS = 1000
31 | LR = 3e-4
32 | NUM_GPUS = 2
33 | IMAGE_SIZE = 256
34 | IMAGE_EXTS = ['.jpg', '.png', '.jpeg']
35 | NUM_WORKERS = multiprocessing.cpu_count()
36 |
37 | # pytorch lightning module
38 |
39 | class SelfSupervisedLearner(pl.LightningModule):
40 | def __init__(self, net, **kwargs):
41 | super().__init__()
42 | self.learner = BYOL(net, **kwargs)
43 |
44 | def forward(self, images):
45 | return self.learner(images)
46 |
47 | def training_step(self, images, _):
48 | loss = self.forward(images)
49 | return {'loss': loss}
50 |
51 | def configure_optimizers(self):
52 | return torch.optim.Adam(self.parameters(), lr=LR)
53 |
54 | def on_before_zero_grad(self, _):
55 | self.learner.update_moving_average()
56 |
57 | # images dataset
58 |
59 | def expand_greyscale(t):
60 | return t.expand(3, -1, -1)
61 |
62 | class ImagesDataset(Dataset):
63 | def __init__(self, folder, image_size):
64 | super().__init__()
65 | self.folder = folder
66 | self.paths = []
67 |
68 | for path in Path(f'{folder}').glob('**/*'):
69 | _, ext = os.path.splitext(path)
70 | if ext.lower() in IMAGE_EXTS:
71 | self.paths.append(path)
72 |
73 | print(f'{len(self.paths)} images found')
74 |
75 | self.transform = transforms.Compose([
76 | transforms.Resize(image_size),
77 | transforms.CenterCrop(image_size),
78 | transforms.ToTensor(),
79 | transforms.Lambda(expand_greyscale)
80 | ])
81 |
82 | def __len__(self):
83 | return len(self.paths)
84 |
85 | def __getitem__(self, index):
86 | path = self.paths[index]
87 | img = Image.open(path)
88 | img = img.convert('RGB')
89 | return self.transform(img)
90 |
91 | # main
92 |
93 | if __name__ == '__main__':
94 | ds = ImagesDataset(args.image_folder, IMAGE_SIZE)
95 | train_loader = DataLoader(ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
96 |
97 | model = SelfSupervisedLearner(
98 | resnet,
99 | image_size = IMAGE_SIZE,
100 | hidden_layer = 'avgpool',
101 | projection_size = 256,
102 | projection_hidden_size = 4096,
103 | moving_average_decay = 0.99
104 | )
105 |
106 | trainer = pl.Trainer(
107 | gpus = NUM_GPUS,
108 | max_epochs = EPOCHS,
109 | accumulate_grad_batches = 1
110 | )
111 |
112 | trainer.fit(model, train_loader)
113 |
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'byol-pytorch',
5 | packages = find_packages(exclude=['examples']),
6 | version = '0.4.0',
7 | license='MIT',
8 | description = 'Self-supervised contrastive learning made simple',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/byol-pytorch',
12 | keywords = ['self-supervised learning', 'artificial intelligence'],
13 | install_requires=[
14 | 'torch>=1.6',
15 | 'kornia>=0.4.0'
16 | ],
17 | classifiers=[
18 | 'Development Status :: 4 - Beta',
19 | 'Intended Audience :: Developers',
20 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
21 | 'License :: OSI Approved :: MIT License',
22 | 'Programming Language :: Python :: 3.6',
23 | ],
24 | )
--------------------------------------------------------------------------------
/Warmup_with_SimSiam/warmup.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from byol_pytorch import BYOL
3 | from byol_pytorch.model import FCN, Predictor
4 | from byol_pytorch.dataset import SegDataset
5 |
6 | encoder = FCN(input_dim=3, output_classes=24)
7 | predictor = Predictor(dim=24)
8 |
9 | learner = BYOL(
10 | encoder=encoder,
11 | predictor=predictor,
12 | image_size = 1024,
13 | hidden_layer = 'avgpool',
14 | use_momentum = False
15 | )
16 |
17 | opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
18 |
19 | imagenet = SegDataset(root_dir="/images/HER2/images")
20 | dataloader = torch.utils.data.DataLoader(imagenet, batch_size=2, shuffle=True, num_workers=0)
21 | device = torch.device(2)  # note: hard-coded to CUDA device index 2; adjust for your machine
22 |
23 | for _ in range(100):
24 | print("start epoch, ", _)
25 | for local_batch in dataloader:
26 | local_batch = local_batch.to(device)
27 | loss = learner(local_batch)
28 | opt.zero_grad()
29 | loss.backward()
30 | opt.step()
31 | # learner.update_moving_average() # update moving average of target encoder
32 | torch.cuda.empty_cache()
33 | torch.save(encoder.state_dict(), 'checkpoints/improved-net_{}.pt'.format(_))
34 |
35 | # note: the improved encoder is saved at the end of every epoch in the loop above
36 |
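37 | # Hypothetical follow-up (illustration only): the warmed-up weights saved above can
38 | # be reloaded into a fresh FCN before the main FCN+GCN training, e.g.
39 | #   reloaded = FCN(input_dim=3, output_classes=24)
40 | #   reloaded.load_state_dict(torch.load('checkpoints/improved-net_99.pt'))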
--------------------------------------------------------------------------------
/args_parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser(description='PyTorch Weakly Supervised Segmentation')
4 | parser.add_argument('--nChannel', metavar='N', default=24, type=int,
5 | help='number of channels of histogram')
6 | parser.add_argument('--maxIter', metavar='T', default=10000, type=int,
7 | help='number of maximum iterations')
8 | parser.add_argument('--checkpoint', metavar='ckpt number', default=-1, type=int,
9 | help='continue training from checkpoint #')
10 | parser.add_argument('--num_superpixels', metavar='K', default=7000, type=int,
11 | help='number of initial superpixels')
12 | parser.add_argument('--compactness', metavar='C', default=3, type=float,
13 | help='compactness of superpixels')
14 | parser.add_argument('--visualize', metavar='1 or 0', default=0, type=int,
15 | help='0:display image using OPENCV | 1:save images to train-path')
16 | parser.add_argument('--color_channel_separation', metavar='True or False', default=False, type=bool,
17 | help='Immunohistochemical staining colors separation toggle')
18 | parser.add_argument('--half-precision', metavar='True or False', default=False, type=bool,
19 | help='Half precision training, requires torch 0.4.0 and apex from nVidia')
20 | parser.add_argument('--optimizer', default='Adam', help='optimizer(SGD|Adam)')
21 | parser.add_argument('--gcn-lr', default=0.0005, type=float, help='gcn learning rate')
22 | parser.add_argument('--fcn-lr', default=0.0001, type=float, help='fcn learning rate')
23 | parser.add_argument('-b','--batch-size', default=2, type=int, help='batch size')
24 | parser.add_argument('-t','--cpu-threads', default=8, type=int, help='number of threads for multiprocessing')
25 | parser.add_argument('--switch-iter', default=13, type=int, help='switch GCN into small \
26 | batch training after # of iterations')
27 | parser.add_argument('--adjust-iter', default=0.1, type=float, help='each iteration, \
28 | global_segments *= (1+/- slic_adjust_ratio)')
29 | parser.add_argument('--weight-ratio', default=1.5, type=float, help='edge weight complementing \
30 | (to 1) ratio, decreasing with training')
31 | parser.add_argument('--warmup-threshold', default=2, type=float, help='when FCN warmup loss \
32 | reaches #, terminate warmup and start training GCN')
33 | parser.add_argument('--output-size', default=1024, type=int, help='The resolution along one axis\
34 | of the image. input images will be scaled to the size defined.')
35 | parser.add_argument('--fuse-thresh', default=0.001, type=float, help='the cutoff to use when \
36 | outputting the final fused mask in inference')
37 | parser.add_argument('--train-path', type=str,
38 | default="./train_process_files",
39 | help='a folder to save train progress visualization',
40 | required=False)
41 | parser.add_argument('--input-path', type=str,
42 | default="./input_data",
43 | help='training set containing the labeled images',
44 | required=False)
45 | parser.add_argument('--checkpoint-path', type=str,
46 | default='./state_dict',
47 | help='path where the checkpoints are saved to',
48 | required=False)
49 | parser.add_argument('--inference-path', type=str,
50 | help='path where the inferenced masks are saved to',
51 | required=False)
52 | args = parser.parse_args()
53 |
--------------------------------------------------------------------------------
/finaledgegcncopy.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import torch
3 | from multiprocessing import Pool
4 | import math
5 | from torch.nn import Sequential, Linear, ReLU
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | from torch_geometric.nn import DataParallel as GeoParallel
10 | from torch_geometric.data import Data, Batch
11 | from torchvision import datasets, transforms
12 | import torchvision
13 | import torch.utils.data as data
14 | from torch.autograd import Variable
15 | from skimage.color import rgb2hed
16 | from skimage.util import img_as_float
17 | from skimage.segmentation import slic, mark_boundaries
18 | from skimage.exposure import rescale_intensity
19 | from skimage import io
20 | import cv2
21 | import sys
22 | import numpy as np
23 | import random
24 | from skimage import segmentation
25 | import matplotlib.pyplot as plt
26 | import torch.nn.init
27 | from PIL import ImageFile, Image
28 | import os
29 | from collections import defaultdict
30 | from utilities import *
31 | from model import *
32 | from args_parser import *
33 |
34 |
35 | #amp_handle = amp.init(enabled=True)
36 |
37 | global_segments = 0
38 | if args.half_precision: from apex import amp
39 | ImageFile.LOAD_TRUNCATED_IMAGES = True
40 | ##CUDA_LAUNCH_BLOCKING=1 #if cuda fails to backpropagate use this toggle to debug
41 | OUTPUT_SIZE = args.output_size
42 | use_cuda = torch.cuda.is_available()
43 |
44 | SAVE_PATH = args.train_path
45 | SD_SAVE_PATH = args.checkpoint_path
46 | DATA_PATH = args.input_path
47 |
48 |
49 |
50 | class random_dataloader(torch.utils.data.Dataset):
51 | """
52 | Proposed Dataset with Random Matched Pairs
53 | This function is a dataset designed for small batch initialization.
54 | When initializing with small batches (< 8), less than ideal GCN convergence may occur.
55 | This dataset loads image patches from two subfolders: ./positives ./negatives
56 | -and combines positive and negative samples in one mini-batch at the adjustable
57 | -ratio self.positive_negative_ratio.
58 |
59 | Keyword arguments:
60 | path --parent path which contains ./positives ./negatives subfolders
61 | """
62 | def __init__(self, path):
63 | super(random_dataloader, self).__init__()
64 | self.path = path
65 | self.positive_images = [names for names in os.listdir(os.path.join(path, 'positives'))]
66 | self.negative_images = [names for names in os.listdir(os.path.join(path, 'negatives'))]
67 | self.positive_image_files = []
68 | self.negative_image_files = []
69 | self.toTensor = transforms.ToTensor()
70 | self.positive_negative_ratio = 3 #3 positives and 1 negative in one mini-batch
71 | self.ratio_counter = 0
72 | self.negative_index = np.random.randint(low = 0, high = self.positive_negative_ratio + 1)
73 | for img in self.positive_images:
74 | if img[-4:] is not None and (img[-4:] == '.png' or img[-4:] == '.jpg'):
75 | img_file = os.path.join(os.path.join(path, 'positives'), "%s" % img)
76 | self.positive_image_files.append({
77 | "img": img_file
78 |
79 | })
80 | for img in self.negative_images:
81 | if img[-4:] is not None and (img[-4:] == '.png' or img[-4:] == '.jpg'):
82 | img_file = os.path.join(os.path.join(path, 'negatives'), "%s" % img)
83 | self.negative_image_files.append({
84 | "img": img_file
85 |
86 | })
87 |
88 | def __len__(self):
89 | return (len(self.positive_image_files) + len(self.negative_image_files))
90 |
91 | def __getitem__(self, index):
92 | if self.ratio_counter == self.negative_index:
93 | index = index % len(self.negative_images)
94 | data_file = self.negative_image_files[index]
95 | else:
96 | index = index % len(self.positive_images)
97 | data_file = self.positive_image_files[index]
98 |
99 | self.ratio_counter += 1
100 | if self.ratio_counter > self.positive_negative_ratio:
101 | self.negative_index = np.random.randint(low = 0, high = self.positive_negative_ratio + 1)
102 | self.ratio_counter = 0
103 | image = cv2.imread(data_file["img"])
104 | image = cv2.resize(image, (OUTPUT_SIZE, OUTPUT_SIZE))
105 | angle = np.random.randint(4)
106 | image = rotate(image, angle)
107 | image = cv2.flip(image, np.random.randint(2) - 1)
108 | if args.color_channel_separation:
109 | ihc_hed = rgb2hed(image)
110 | h = rescale_intensity(ihc_hed[:, :, 0], out_range=(0, 1))
111 | d = rescale_intensity(ihc_hed[:, :, 2], out_range=(0, 1))
112 | image = np.dstack((np.zeros_like(h), d, h))
113 | #image = zdh.transpose(2, 0, 1).astype('float32')/255
114 | name = data_file["img"]
115 | path, file = os.path.split(name)
116 | split_filename = file.split("_")
117 | gt_percent = float(split_filename[0])
118 | moe = float(split_filename[1])
119 | image = self.toTensor(image)
120 | return (image, name, gt_percent, moe)
121 |
122 |
123 | class normal_dataloader(torch.utils.data.Dataset):
124 | """
125 | Typical Dataset
126 | This function loads images from path and returns:
127 | image: image (resized, color channel separated(optional))
128 | name: full image path(contains image name)
129 | gt_percent: Ground-Truth percent, an image-level weak annotation
130 | moe: Margin of Error, an image-level weak annotation
131 |
132 | Keyword arguments:
133 | path --path which contains *.png or *.jpg image patches
134 | """
135 | def __init__(self, path):
136 | super(normal_dataloader, self).__init__()
137 | self.path = path
138 | self.images = [names for names in os.listdir(path)]
139 | self.image_files = []
140 | self.toTensor = transforms.ToTensor()
141 |
142 | for img in self.images:
143 | if img[-4:] is not None and (img[-4:] == '.png' or img[-4:] == '.jpg'):
144 | img_file = os.path.join(path, "%s" % img)
145 | self.image_files.append({
146 | "img": img_file,
147 | "label": "1"
148 | })
149 |
150 | def __len__(self):
151 | return len(self.image_files)
152 |
153 | def __getitem__(self, index):
154 | index = index % len(self.image_files)
155 | data_file = self.image_files[index]
156 | image = cv2.imread(data_file["img"])
157 | image = cv2.resize(image, (OUTPUT_SIZE, OUTPUT_SIZE))
158 | angle = np.random.randint(4)
159 | image = rotate(image, angle)
160 | image = cv2.flip(image, np.random.randint(2) - 1)
161 | if args.color_channel_separation:
162 | ihc_hed = rgb2hed(image)
163 | h = rescale_intensity(ihc_hed[:, :, 0], out_range=(0, 1))
164 | d = rescale_intensity(ihc_hed[:, :, 2], out_range=(0, 1))
165 | image = np.dstack((np.zeros_like(h), d, h))
166 | #image = zdh.transpose(2, 0, 1).astype('float32')/255
167 | name = data_file["img"]
168 | path, file = os.path.split(name)
169 | split_filename = file.split("_")
170 | gt_percent = float(split_filename[0])
171 | moe = float(split_filename[1])
172 | image = self.toTensor(image)
173 | return (image, name, gt_percent, moe)
174 |
175 |
176 | class inference_dataset(torch.utils.data.Dataset):
177 | """
178 | dataset loader used when inferencing
179 | """
180 | def __init__(self, path):
181 | super(inference_dataset, self).__init__()
182 | self.path = path
183 | self.images = [names for names in os.listdir(path)]
184 | self.image_files = []
185 | self.toTensor = transforms.ToTensor()
186 |
187 | for img in self.images:
188 | if img[-4:] is not None and (img[-4:] == '.jpg' or img[-4:] == '.png'):
189 | img_file = os.path.join(path, "%s" % img)
190 | #print(img)
191 | #print(img_file)
192 | self.image_files.append({
193 | "img": img_file,
194 | "label": "1"
195 | })
196 |
197 | def __len__(self):
198 | return len(self.image_files)
199 |
200 | def __getitem__(self, index):
201 | index = index % len(self.image_files)
202 | data_file = self.image_files[index]
203 | image = cv2.imread(data_file["img"])
204 | image = cv2.resize(image, (OUTPUT_SIZE, OUTPUT_SIZE))
205 | name = data_file["img"]
206 | image = self.toTensor(image)
207 | return (image, name, 0,0)
208 |
209 |
210 | def load_dataset(batch_size):
211 | """
212 | Typical Pytorch Dataloader
213 | This function loads images from the Dataset
214 |
215 | Keyword arguments:
216 | path --please refer to class random_dataloader and class normal_dataloader
217 | """
218 | data_path = DATA_PATH
219 | train_dataset = normal_dataloader(path = data_path)
220 | train_loader = torch.utils.data.DataLoader(
221 | train_dataset,
222 | batch_size= batch_size, # <8*1024 res @ P40 cards, <25*512 res @ 1 P40 card
223 | num_workers= 8, #depends on RAM and CPU
224 | shuffle=True
225 | )
226 | return train_loader
227 |
228 | def inference_loader(batch_size = 1):
229 | """
230 | Typical Pytorch Dataloader
231 | This function loads images from the Dataset
232 |
233 | Keyword arguments:
234 | path --please refer to class random_dataloader and class normal_dataloader
235 | """
236 | data_path = DATA_PATH
237 | train_dataset = inference_dataset(path = data_path)
238 | train_loader = torch.utils.data.DataLoader(
239 | train_dataset,
240 | batch_size= batch_size, # <8*1024 res @ P40 cards, <25*512 res @ 1 P40 card
241 | num_workers= 8, #depends on RAM and CPU
242 | shuffle=False
243 | )
244 | return train_loader
245 |
246 | def multithread_slic(multi_input):
247 | """
248 | Multi-Thread SLIC superpixel
249 | Since SLIC algorithm is implemented on CPU, we use Python Pool to accelerate
250 | -SLIC algorithm by taking full advantage of CPU's multi-threading capability.
251 | This function also calculate mutable edges' weight as one of GCN's input data.
252 |
253 |
254 | Keyword arguments:
255 | multi_input --tuple containing:
256 | (FCN output, max channel response, mutable edge weight ratio, superpixel count or -1)
257 | """
258 | (multi_output, max_channel_response, weight_ratio, i) = multi_input
259 | if i > 0:
260 | multi_output_slic = slic(multi_output, n_segments = i, compactness = args.compactness\
261 | , sigma = 0, multichannel = True)
262 | else:
263 | multi_output_slic = slic(multi_output, n_segments = global_segments, compactness = args.compactness\
264 | , sigma = 0, multichannel = True)
265 |
266 | num_segments = len(np.unique(multi_output_slic))
267 | multi_adj = adjacency3(multi_output_slic,num_segments)
268 | ## true_euclidean_distance = []
269 | ## f_norm = []
270 | ## mses = []
271 | chisq = []
272 | classes_raw = np.zeros((num_segments, args.nChannel),dtype="float32")
273 | for y in range(OUTPUT_SIZE):
274 | for x in range(OUTPUT_SIZE):
275 | curr_index = multi_output_slic[x][y]
276 | max_channel = max_channel_response[x][y]
277 | classes_raw[curr_index][max_channel] += 1.0
278 |
279 | for x in range(len(classes_raw)):
280 | max_in_row = np.amax(classes_raw[x])
281 | classes_raw[x] = classes_raw[x] / max_in_row
282 |
283 | for (p1, p2) in multi_adj:
284 | p1_class = np.asarray(classes_raw[p1])
285 | p2_class = np.asarray(classes_raw[p2])
286 | #true_euclidean_distance.append( euclidean_dist(p1center, p2center) )
287 | #f_norm.append( fnorm(p1_class,p2_class) )
288 | chisq.append( np.absolute(chisq_dist(p1_class, p2_class)) )
289 | #mses.append( mse(p1_class, p2_class) )
290 | chisq_max_value = np.amax(chisq)
291 | chisq = chisq / chisq_max_value
292 | #complementary_weight = np.ones_like(chisq) - chisq
293 | #chisq = chisq + weight_ratio * complementary_weight
294 | edge_weight = torch.from_numpy(chisq)
295 | return multi_output_slic, multi_adj, edge_weight, num_segments
296 |
297 |
298 | if __name__ == '__main__':
299 | # train
300 | model = FCN(3, args.nChannel)
301 | modelgcn = GCN(args.nChannel, 1)
302 | gcn_batch_iter = 1 #how many iteration on one GCN batch
303 | batch_counter = 0
304 | global_segments = args.num_superpixels #mutable superpixel quantity
305 | change_dataloader = args.switch_iter #switch GCN into small batch training after # of iterations
306 | model_loss = 99999 #variable keeping track of FCN loss during warmup phase
307 | in_GCN = False #True when FCN exiting warmup phase and feeding output into GCN
308 | inference_mode = False
309 | #slic_multiscale_descending = False #True when mutable superpixel quantity is decreasing per iteration
310 | slic_adjust_ratio = args.adjust_iter #each iteration, global_segments *= (1+/- slic_adjust_ratio)
311 | weight_ratio = args.weight_ratio #edge weight complementing (to 1) ratio, decreasing with training
312 | warmup_threshold = args.warmup_threshold# 0.5 #when FCN warmup loss reaches #, terminate warmup and start training GCN
313 | half_precision = args.half_precision
314 |
315 | if args.checkpoint > 0:
316 | #if given a checkpoint to resume training
317 | model.load_state_dict(torch.load(os.path.join(SD_SAVE_PATH, "FCN" + str(args.checkpoint) + ".pt")))
318 | modelgcn.load_state_dict(torch.load(os.path.join(SD_SAVE_PATH, "GCN" + str(args.checkpoint) + ".pt")))
319 | batch_counter = int(args.checkpoint)
320 | change_dataloader = batch_counter - 1
321 | in_GCN = True
322 |
323 | if args.inference_path is not None:
324 | dataset_loader = inference_loader()
325 | inference_mode = True
326 | args.maxIter = 1
327 | if args.checkpoint < 1:
328 | print("please define a inference checkpoint using --checkpoint")
329 | exit(-1)
330 | else:
331 | dataset_loader = load_dataset(args.batch_size)
332 |
333 | if not inference_mode and args.cpu_threads > args.batch_size:
334 | args.cpu_threads = args.batch_size
335 |
336 | if use_cuda:
337 | model = model.to('cuda')
338 | modelgcn = modelgcn.to('cuda')
339 | if half_precision:
340 | model, optimizer = amp.initialize(model, optimizer)
341 | modelgcn, optimizergcn = amp.initialize(modelgcn, optimizergcn)
342 | else:
343 | print("model using CPU, please check CUDA settings")
344 |
345 | if use_cuda:
346 | if torch.cuda.device_count() > 1:
347 | print(str(torch.cuda.device_count()) + " GPUs visible")
348 | model = nn.DataParallel(model)
349 | modelgcn = GeoParallel(modelgcn)
350 | if inference_mode:
351 | model.train()
352 | else:
353 | model.train()
354 | modelgcn = modelgcn.float() #necessary for edge_weight initialized training
355 | loss_fn = torch.nn.CrossEntropyLoss()
356 | if args.optimizer == 'SGD':
357 | optimizer = optim.SGD(model.parameters(), lr=args.fcn_lr, momentum=0)
358 | optimizergcn = optim.SGD(modelgcn.parameters(), lr=args.gcn_lr, momentum=0)
359 | elif args.optimizer == 'Adam':
360 | optimizer = optim.Adam(model.parameters(), lr=args.fcn_lr)
361 | optimizergcn = optim.Adam(modelgcn.parameters(), lr=args.gcn_lr)
362 | else:
363 | print("please reselect optimizer, curr value:", args.optimizer)
364 | exit(-1)
365 |
366 |
367 |
368 |
369 | if args.checkpoint > 0:
370 | #if given a checkpoint to resume training
371 | optimizer.load_state_dict(torch.load(os.path.join(SD_SAVE_PATH, "FCNopt" + str(args.checkpoint) + ".pt")))
372 | optimizergcn.load_state_dict(torch.load(os.path.join(SD_SAVE_PATH, "GCNopt" + str(args.checkpoint) + ".pt")))
373 | print("training successfully resumed!")
374 |
375 |
376 | #an RGB colors array used to visualize channel response
377 | label_colours = np.array([[0,0,0], [255,255,255],[0,0,255], [0,255,0],
378 | [255,0,0],[128,0,0],[0,128,0],[0,0,128],
379 | [255,255,0], [255,128,0], [128,255,0],
380 | [0,255,255],[255,0,255],[255,255,255],
381 | [128,128,128],[255,0,128],[0,128,255],
382 | [128,0,255],[0,255,128],[100,200,200],
383 | [200,100,100],[200,255,0],[100,255,0],
384 | [200,0,255],[30,99,212],[40,222,100],
385 | [100,200,25],[30,199,20],[0,211,200],
386 | [3,44,122],[23,44,100],[90,22,0],[233,111,222],
387 | [122,122,150],[0,233,149],[3,111,23]])
388 |
389 |
390 | for epoch in range(args.maxIter):
391 |
392 | """switch large batch into small batch"""
393 | # if batch_counter < change_dataloader:
394 | # dataset_loader = load_dataset(2)
395 | # else:
396 | # gcn_batch_iter = 1
397 | # dataset_loader = load_dataset(2)
398 | #
399 |
400 | for batch_idx, (data, name, gt_percent, moe) in \
401 | enumerate(dataset_loader):
402 |
403 | if not inference_mode:
404 | print("iteration: " + str(batch_counter) + " epoch: " + str(epoch))
405 | else:
406 | print("---------------------------------------------------")
407 | print("inferencing ", str(os.path.basename(name[0])))
408 | if args.visualize:
409 | """visualize using opencv"""
410 | originalimg = cv2.imread(name[0])
411 | originalimg = cv2.resize(originalimg, (OUTPUT_SIZE,OUTPUT_SIZE))
412 | cv2.imshow("original", originalimg)
413 | cv2.waitKey(1)
414 |
415 | batch_counter += 1 #records in-epoch progress
416 | if use_cuda:
417 | data = data.to('cuda')
418 | optimizer.zero_grad()
419 | output = model(data)
420 |
421 |
422 |
423 | nLabels = -1
424 |
425 | if not inference_mode and ((not in_GCN and batch_counter % 50 == 0) or (in_GCN and batch_counter % 10 == 0)):
426 | """FCN output visualization, either save to SAVE_PATH or display using opencv"""
427 | #model.eval()
428 | #output = model(data)
429 | ignore, target = torch.max( output, 1 )
430 | im_target = target.data.cpu().numpy() #label map original
431 | num_in_minibatch = 0
432 | for i in im_target:
433 | im_target = i.flatten()
434 | nLabels = len(np.unique(im_target))
435 | label_num = rank(im_target)
436 | label_rank = [i[0] for i in label_num]
437 | im_target_rgb = np.array([label_colours [label_rank.index(c)] for c in im_target])
438 | im_target_rgb = im_target_rgb.reshape( OUTPUT_SIZE, OUTPUT_SIZE, 3 ).astype( np.uint8 )
439 | curr_filename = name[num_in_minibatch][-16:-4]
440 | if args.visualize:
441 | cv2.imshow("pre", im_target_rgb)
442 | cv2.waitKey(1)
443 | else:
444 | cv2.imwrite(os.path.join(SAVE_PATH, 'PRE' + str(batch_counter) \
445 | + 'N' + str(num_in_minibatch) + "_" \
446 | + str(curr_filename) + '.png'), \
447 | cv2.cvtColor(im_target_rgb, cv2.COLOR_RGB2BGR))
448 | num_in_minibatch += 1
449 | torch.save(model.module.state_dict(), os.path.join(SD_SAVE_PATH,"FCN" + str(batch_counter) + ".pt"))
450 |
451 |
452 |
453 | if not inference_mode and (model_loss > warmup_threshold and not in_GCN): #stable 0.3 2000epoch
454 | """warning up FCN, loss is the cross entropy between pixel-wise
455 | -max channel responses and FCN model's output"""
456 | ignore, target = torch.max( output, 1 )
457 | loss = loss_fn(output, target)
458 | if half_precision:
459 | with amp.scale_loss(loss, optimizer) as scaled_loss:
460 | scaled_loss.backward()
461 | else:
462 | loss.backward()
463 | optimizer.step()
464 | if model_loss > loss.data:
465 | model_loss = loss.data
466 | print (epoch, '/', args.maxIter, ':', nLabels, loss.data)
467 | change_dataloader += 1
468 | continue
469 | else:
470 | """when FCN model_loss is below warmup_threshold, enter GCN traning"""
471 | #optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
472 | #change the learning rate of FCN to a less conservative number
473 | #optimizer = optim.Adam(model.parameters(), lr=0.00005)
474 | #dataset_loader = load_dataset(2)
475 | in_GCN = True
476 |
477 | if model_loss < warmup_threshold or in_GCN:
478 | """Proposed method bridging output of FCN and input to GCN"""
479 |
480 | """mutable superpixel count, changes every training iterations"""
481 | """count changes by a fixed pattern"""
482 | # if slic_multiscale_descending:
483 | # global_segments = int(global_segments * (1 - slic_adjust_ratio))
484 | # if global_segments < 2000:
485 | # slic_adjust_ratio = slic_adjust_ratio * 0.7
486 | # slic_multiscale_descending = False
487 | # else:
488 | # global_segments = int(global_segments * (1 + slic_adjust_ratio))
489 | # if global_segments > 8000:
490 | # global_segments = global_segments + 17
491 | # slic_multiscale_descending = True
492 | """count is randomized between [2000, 8000)"""
493 | global_segments = int(random.random() * 6000.0 + 2000.0)
494 |
495 | if not inference_mode: print("slic segments count:" + str(global_segments))
496 |
497 | gcn_batch_list = [] #a batch which later feed into GCN
498 | segments_list = [] #SLIC segmentations of each image in current GCN batch
499 | batch_node_num = [] #number of nodes(superpixel tiles) that each image has in\
500 | #current GCN batch
501 |
502 | print("computing multithread slic")
503 | """prepares FCN output for using in ,
504 | computes SLIC segmentations, adjacency edges, & edge weight using
505 | multi-threads"""
506 | multi_input = []
507 | for multi_one_graph in output:
508 | multi_one_graph = multi_one_graph.permute( 1 , 2 , 0 )
509 | _, max_channel_response = torch.max(multi_one_graph, 2)
510 | multi_one_graph = multi_one_graph.cpu().detach().numpy().astype(np.float64)
511 | max_channel_response = max_channel_response.cpu().detach().numpy()
512 | if not inference_mode:
513 | multi_input.append((multi_one_graph, max_channel_response, weight_ratio, -1))
514 | else:
515 | for i in (2000,3000,4000,5000,6000,7000,8000):
516 | multi_input.append((multi_one_graph, max_channel_response, weight_ratio, i))
517 | with Pool(args.cpu_threads) as p:
518 | multi_slic_adj_list = p.map(multithread_slic, multi_input)
519 | print("multithread slic finished")
520 | if weight_ratio > 0.2:
521 |                 weight_ratio = weight_ratio * 0.99 #gradually reduce the edge weights' complementary ratio
522 |
523 | for one_graph, (segments, adj, edge_weight, num_segments) in zip(output, multi_slic_adj_list):
524 | """Bridging FCN's output into GCN's input, initialize GCN batch"""
525 | segments_list.append(segments)
526 | one_graph = one_graph.permute( 1 , 2 , 0 )
527 | #one_graph.shape: [img_size, img_size, channel_size]
528 | original_one_graph = torch.flatten(one_graph, start_dim = 0, end_dim = 1)
529 | #original_one_graph.shape: [channel_size, img_size*img_size]
530 | one_graph = None
531 |
532 | batch_node_num.append(num_segments)
533 |
534 | slic_pixels = [[] for _ in range (num_segments)]
535 |                 """slic_pixels stores flattened (x, y) coordinates indexed by superpixel label"""
536 | for y in range (OUTPUT_SIZE):
537 | for x in range (OUTPUT_SIZE):
538 | curr_label = segments[x,y]
539 | slic_pixels[curr_label].append(x * OUTPUT_SIZE + y)
540 |                             #record this pixel's flattened index under its segment label
541 | classes = None
542 |
543 | """
544 |                 For each superpixel tile, select the PyTorch Variables (FCN output) inside it.
545 |                 These Variables combine into new node Variables (GCN input) while
546 |                 -carrying their gradients.
547 | """
548 | for n in slic_pixels:
549 | index_tensor = torch.LongTensor(n)
550 | if use_cuda:
551 | index_tensor = index_tensor.to('cuda')
552 | one_class = torch.unsqueeze(torch.sum( \
553 | torch.index_select(original_one_graph, dim = 0, index = index_tensor), \
554 | 0), dim = 0)
555 | if classes is None:
556 | classes = one_class
557 | else:
558 | classes = torch.cat((classes, one_class), 0)
559 | one_class = None
560 | index_tensor = None
561 | original_one_graph = None
562 | temp_ind = 0
563 | adj = np.asarray(adj)
564 | adj = torch.from_numpy(adj)
565 | adj = adj.t().contiguous()
566 | adj = Variable(adj).type(torch.LongTensor)
567 |
568 | """datagcn: GCN-ready wrapped data for one image
569 |                 gcn_batch_list: GCN-ready minibatch containing the same images as
570 |                 -the preceding FCN minibatch."""
571 | datagcn = Data(x = classes, edge_index = adj, \
572 | edge_weight = edge_weight.type(torch.FloatTensor))
573 |
574 | gcn_batch_list.append(datagcn)
575 | #print(gcn_batch_list)
576 |                 classes = None #drop the reference so cached GPU memory can be reused
577 | adj = None
578 | datagcn = None
579 | if torch.cuda.device_count() == 1:
580 | gcn_batch_list = Batch.from_data_list(gcn_batch_list)
581 | if use_cuda:
582 | gcn_batch_list = gcn_batch_list.to('cuda')
583 | """GCN training iterations"""
584 | if inference_mode:
585 |                 modelgcn.train() #same call as the training branch; the GCN has no active dropout/BN layers
586 | else:
587 | modelgcn.train()
588 | print(gt_percent.data.cpu().numpy())
589 | print(moe.data.cpu().numpy())
590 | for epochgcn in range(0, gcn_batch_iter):
591 | optimizergcn.zero_grad()
592 | outgcn = modelgcn(gcn_batch_list)
593 |
594 | """visualize GCN output"""
595 | if not inference_mode and epochgcn == gcn_batch_iter - 1 and batch_counter % 10 == 0:
596 | start_index = 0
597 | counter = 0
598 | #modelgcn.eval()
599 | #outgcn = modelgcn(gcn_batch_list)
600 | for curr_batch_idx in range(len(name)):
601 | outgcn_slice = torch.narrow(input = outgcn, dim = 0, \
602 | start = start_index, \
603 | length = batch_node_num[curr_batch_idx])
604 | start_index += batch_node_num[curr_batch_idx]
605 | outputgcn_np = outgcn_slice.detach().data.cpu().numpy()
606 | segments_copy = segments_list[curr_batch_idx].copy()
607 | segments_copy = segments_copy.astype(np.float64)
608 | for segInd in range(len(outputgcn_np)):
609 | segments_copy[segments_copy == segInd] = outputgcn_np[segInd]
610 | gcn_target_rgb = np.array([[255*(c + 1) / 2, 255*(c + 1) / 2, 255*(c + 1) / 2] \
611 | for c in segments_copy])
612 | gcn_target_rgb = np.moveaxis(gcn_target_rgb, 1, 2)
613 | gcn_target_rgb = gcn_target_rgb.reshape( (OUTPUT_SIZE,OUTPUT_SIZE,3) ).astype( np.uint8 )
614 | if args.visualize:
615 | cv2.imshow("gcn", gcn_target_rgb)
616 | cv2.waitKey(1)
617 | else:
618 | cv2.imwrite(os.path.join(SAVE_PATH, 'GCN' + str(batch_counter) + "N" + str(counter)\
619 | + "_" + str(name[curr_batch_idx][-16:-4]) + '.png'), \
620 | cv2.cvtColor(gcn_target_rgb, cv2.COLOR_RGB2BGR))
621 | counter += 1
622 | outgcn_slice = None
623 |
624 | if not inference_mode:
625 | """
626 |                     loss_top --cost that the top (gt_percent - moe) fraction of nodes responded correctly
627 |                     loss_bottom --cost that the bottom (1 - gt_percent - moe) fraction of nodes responded correctly
628 |                     positive refers to the desired region (cancer), negative refers to other regions (background)
629 | """
630 | loss_top, loss_bottom = one_label_loss(gt_percent = gt_percent.data.cpu().numpy(), \
631 | predict = outgcn, \
632 | moe = moe.data.cpu().numpy(), \
633 | batch_node_num = batch_node_num)
634 | if loss_top is None:
635 | total_gcn_loss = loss_bottom
636 | elif loss_bottom is None:
637 | total_gcn_loss = loss_top
638 | else:
639 | total_gcn_loss = loss_top + loss_bottom
640 | if half_precision:
641 | with amp.scale_loss(total_gcn_loss, optimizergcn) as scaled_loss2:
642 | scaled_loss2.backward(retain_graph=True)
643 | else:
644 | total_gcn_loss.backward(retain_graph=True)
645 | #backward calculating GCN gradients according to combined loss
646 | #print(total_gcn_loss)
647 | print("GCN+FCN loss: " + str(total_gcn_loss.data.cpu().numpy()))
648 | if not inference_mode:
649 | #backpropagate through GCN's layers
650 | optimizergcn.step()
651 | #backpropagate accumulated gradients through FCN's layers
652 | optimizer.step()
653 | #saving models & optimizers state_dict for later training and inference
654 | #cpu = torch.device("cpu")
655 | #model = model.to(cpu)
656 | #modelgcn = modelgcn.to(cpu)
657 | torch.save(model.module.state_dict(), os.path.join(SD_SAVE_PATH,"FCN" + str(batch_counter) + ".pt"))
658 | torch.save(modelgcn.module.state_dict(), os.path.join(SD_SAVE_PATH,"GCN" + str(batch_counter) + ".pt"))
659 | torch.save(optimizer.state_dict(), os.path.join(SD_SAVE_PATH,"FCNopt" + str(batch_counter) + ".pt"))
660 | torch.save(optimizergcn.state_dict(), os.path.join(SD_SAVE_PATH,"GCNopt" + str(batch_counter) + ".pt"))
661 | #model = model.to('cuda')
662 | #modelgcn = modelgcn.to('cuda')
663 | if inference_mode:
664 | start_index = 0
665 | counter = 0
666 | final_map = np.zeros((args.output_size, args.output_size))
667 | multi_input = []
668 | print("fusing")
669 | for input_index in range(len(segments_list)):
670 | outgcn_numpy = torch.narrow(input = outgcn, dim = 0, start = start_index, length = batch_node_num[input_index]).detach().data.cpu().numpy()
671 | segments_copy = segments_list[input_index].astype(np.float64)
672 | multi_input.append((outgcn_numpy, segments_copy))
673 | start_index += batch_node_num[input_index]
674 | with Pool(args.cpu_threads) as p:
675 | multi_graph = p.map(fuse_results, multi_input)
676 | for graph in multi_graph:
677 | final_map += graph
678 |
679 | final_map = final_map / float(len(multi_graph))
680 | final_map += 1.0
681 | final_map = final_map / 2.0
682 | final_map[final_map < args.fuse_thresh] = 0
683 |
684 | gcn_target_rgb = np.array([[255 * c , 255* c , 255* c] for c in final_map])
685 | gcn_target_rgb = np.moveaxis(gcn_target_rgb, 1, 2)
686 | gcn_target_rgb = gcn_target_rgb.reshape( (args.output_size,args.output_size,3) ).astype( np.uint8 )
687 | if args.visualize:
688 | cv2.imshow("gcn", gcn_target_rgb)
689 | cv2.waitKey(1)
690 | else:
691 | basename = os.path.basename(name[0])
692 | cv2.imwrite(os.path.join(args.inference_path, basename), cv2.cvtColor(gcn_target_rgb, cv2.COLOR_RGB2BGR))
693 | print("inference for", str(basename), "saved to", str(args.inference_path))
694 |
--------------------------------------------------------------------------------
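The loop above pools the FCN's dense output into per-superpixel node features before the GCN sees them. The following minimal sketch (not part of the repository) illustrates that pooling step on toy data; the sizes `H`, `W`, `C` and the random 5-label map standing in for a SLIC segmentation are illustrative assumptions.

```python
# Minimal sketch (not part of the repository) of the FCN-to-GCN bridging step:
# dense per-pixel features are summed per superpixel label to form one node
# feature per superpixel. All sizes here are toy assumptions.
import numpy as np
import torch

H, W, C = 8, 8, 4                                      # toy image size and FCN channel count
features = torch.randn(H, W, C, requires_grad=True)    # stand-in for one FCN output (HWC layout)
segments = np.random.randint(0, 5, (H, W))             # stand-in for a SLIC label map with up to 5 superpixels

flat_features = features.reshape(H * W, C)             # [H*W, C]
flat_segments = torch.from_numpy(segments.reshape(-1))

node_features = []
for label in range(int(flat_segments.max()) + 1):
    idx = (flat_segments == label).nonzero(as_tuple=True)[0]
    # Sum the pixel features belonging to this superpixel; the sum keeps the
    # computation graph intact, so a loss on the nodes can backpropagate to the pixels.
    node_features.append(flat_features.index_select(0, idx).sum(dim=0, keepdim=True))
node_features = torch.cat(node_features, dim=0)        # one pooled feature vector per superpixel label
print(node_features.shape)
```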
/inference/001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/001.png
--------------------------------------------------------------------------------
/inference/002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/002.png
--------------------------------------------------------------------------------
/inference/003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/003.png
--------------------------------------------------------------------------------
/inference/004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/004.png
--------------------------------------------------------------------------------
/inference/005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/005.png
--------------------------------------------------------------------------------
/inference/006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/006.png
--------------------------------------------------------------------------------
/inference/007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/inference/007.png
--------------------------------------------------------------------------------
/input_data/001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/001.png
--------------------------------------------------------------------------------
/input_data/002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/002.png
--------------------------------------------------------------------------------
/input_data/003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/003.png
--------------------------------------------------------------------------------
/input_data/004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/004.png
--------------------------------------------------------------------------------
/input_data/005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/005.png
--------------------------------------------------------------------------------
/input_data/006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/006.png
--------------------------------------------------------------------------------
/input_data/007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/input_data/007.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.nn import GCNConv
3 | import torch.nn.functional as F
4 | from torch_geometric.data import Data, Batch
5 | import torch.nn as nn
6 |
7 | class FCN(nn.Module):
8 | """
9 | Proposed Fully Convolutional Network
10 | This function/module uses fully convolutional blocks to extract pixel-wise image features.
11 |     Tested at 1024*1024 and 512*512 resolutions with RGB and immunohistochemical color channels
12 |
13 | Keyword arguments:
14 | input_dim -- input channel, 3 for RGB images (default)
15 | """
16 | def __init__(self,input_dim, output_classes, p_mode = 'replicate'):
17 | super(FCN, self).__init__()
18 | #self.Dropout = nn.Dropout(p=0.05)
19 | self.conv1 = nn.Conv2d(input_dim, 32, kernel_size=3, stride=1, padding=1 ,padding_mode=p_mode)
20 | self.bn1 = nn.BatchNorm2d(32)
21 |
22 | self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, padding_mode=p_mode)
23 | self.bn2 = nn.BatchNorm2d(32)
24 |
25 | self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, padding_mode=p_mode)
26 | self.bn3 = nn.BatchNorm2d(64)
27 |
28 | self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, padding_mode=p_mode)
29 | self.bn4 = nn.BatchNorm2d(64)
30 |
31 | self.conv5 = nn.Conv2d(64, output_classes, kernel_size=1, stride=1, padding=0)
32 |
33 | #self.Dropout = nn.Dropout(p=0.3)
34 |
35 | def forward(self, x):
36 | x = self.conv1(x)
37 | x = F.relu(x)
38 | x = self.bn1(x)
39 | #x = self.Dropout(x)
40 |
41 | x = self.conv2(x)
42 | x = F.relu(x)
43 | x = self.bn2(x)
44 | #x = self.Dropout(x)
45 |
46 | x = self.conv3(x)
47 | x = F.relu(x)
48 | x = self.bn3(x)
49 | #x = self.Dropout(x)
50 |
51 | x = self.conv4(x)
52 | x = F.relu(x)
53 | x = self.bn4(x)
54 |
55 | x = self.conv5(x)
56 | return x
57 |
58 |
59 | class GCN(torch.nn.Module):
60 | """
61 | Proposed Graph Convolutional Network
62 |     This function/module uses classic GCN layers to classify superpixels (nodes).
63 | --"Semi-Supervised Classification with Graph Convolutional Networks",
64 | --Thomas N. Kipf, Max Welling, ICLR2017
65 |
66 | Keyword arguments:
67 | input_dim -- input channel, aligns with output channel from FCN
68 | output_classes --output channel, default 1 for our proposed loss function
69 | """
70 | def __init__(self, input_dim, output_classes):
71 | super(GCN, self).__init__()
72 | self.conv1 = GCNConv(input_dim, 64)
73 | self.conv2 = GCNConv(64, 128)
74 | self.conv3 = GCNConv(128, 256)
75 | self.conv4 = GCNConv(256, 64)
76 | self.conv5 = GCNConv(64, output_classes)
77 | #self.Dropout = nn.Dropout(p=0.5)
78 |
79 | # self.bn1 = nn.BatchNorm1d(64)
80 | # self.bn2 = nn.BatchNorm1d(128)
81 | # self.bn3 = nn.BatchNorm1d(256)
82 | # self.bn4 = nn.BatchNorm1d(64)
83 | #
84 | # self.lin1 = Linear(64, 256)
85 | # self.lin2 = Linear(256, 128)
86 | # self.lin3 = Linear(128, output_classes)
87 |
88 | def forward(self, data):
89 | x = self.conv1(data.x, edge_index = data.edge_index, edge_weight = data.edge_weight)
90 | x = F.relu(x)
91 | #x = self.Dropout(x)
92 | #x = self.bn1(x)
93 |
94 | x = self.conv2(x, edge_index = data.edge_index, edge_weight = data.edge_weight)
95 | x = F.relu(x)
96 | #x = self.Dropout(x)
97 | #x = self.bn2(x)
98 |
99 | x = self.conv3(x, edge_index = data.edge_index, edge_weight = data.edge_weight)
100 | x = F.relu(x)
101 | #x = self.Dropout(x)
102 | #x = self.bn3(x)
103 |
104 | x = self.conv4(x, edge_index = data.edge_index, edge_weight = data.edge_weight)
105 | x = F.relu(x)
106 | #x = self.bn4(x)
107 |
108 | x = self.conv5(x, edge_index = data.edge_index, edge_weight = data.edge_weight)
109 |
110 |
111 | return torch.tanh(x)
112 |
--------------------------------------------------------------------------------
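A minimal usage sketch (not part of the repository) wiring the two modules defined in model.py together on toy data. It assumes it is run from the repository root so that `model.py` is importable; the channel counts and the 3-node toy graph are illustrative assumptions, not the pipeline's actual configuration.

```python
import torch
from torch_geometric.data import Data
from model import FCN, GCN

fcn = FCN(input_dim=3, output_classes=64)    # pixel-wise feature extractor
gcn = GCN(input_dim=64, output_classes=1)    # per-superpixel (node) classifier

image = torch.randn(1, 3, 64, 64)            # one toy RGB image
pixel_features = fcn(image)                  # [1, 64, 64, 64] dense feature map

# Pretend the image was segmented into 3 superpixels and pool one feature per node.
node_features = pixel_features.flatten(2).mean(dim=2).repeat(3, 1)   # [3, 64]
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
edge_weight = torch.ones(edge_index.size(1))

graph = Data(x=node_features, edge_index=edge_index, edge_weight=edge_weight)
print(gcn(graph).shape)                      # [3, 1]; tanh output, values in (-1, 1)
```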
/state_dict/FCN1000.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/state_dict/FCN1000.pt
--------------------------------------------------------------------------------
/state_dict/FCNopt1000.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/state_dict/FCNopt1000.pt
--------------------------------------------------------------------------------
/state_dict/GCN1000.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/state_dict/GCN1000.pt
--------------------------------------------------------------------------------
/state_dict/GCNopt1000.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/state_dict/GCNopt1000.pt
--------------------------------------------------------------------------------
/train_process_files/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangjun001/FGNet/a601b84d55f0580061ce2bfdcd9dede024ece540/train_process_files/.keep
--------------------------------------------------------------------------------
/utilities.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | from torch_geometric.data import Data, Batch
5 | from torch.autograd import Variable
6 | from skimage.segmentation import slic, mark_boundaries
7 | from skimage import io
8 | import cv2
9 | import sys
10 | import numpy as np
11 | from skimage import segmentation
12 | import matplotlib.pyplot as plt
13 | import os
14 | from collections import defaultdict
15 | use_cuda = torch.cuda.is_available()
16 |
17 |
18 | def rotate(image,angle):
19 | return np.rot90(image,angle)
20 |
21 | def rank(l):
22 | """
23 | Rank/Frequency Counter
24 |     This function efficiently counts how often each value occurs in an ndarray,
25 |     e.g. rank([3,5,2,2,2,2,1,1]) -> [(2, 4), (1, 2), (3, 1), (5, 1)], (value, count) pairs sorted by count
26 |
27 |
28 | Keyword arguments:
29 | l --a 1-dimensional numpy array
30 | """
31 | d = defaultdict(int)
32 | for i in l:
33 | d[i] += 1
34 | return sorted(d.items(), key=lambda x: x[1], reverse=True)
35 |
36 |
37 | def adjacency(sMask):
38 | """
39 | Adjacency Matrix
40 | This function (not so efficiently) finds bidirectional adjacency edges
41 | -from a SLIC superpixel mask.
42 |
43 |
44 | Keyword arguments:
45 | sMask --a 2-dimensional numpy array generated by SLIC algorithm
46 | """
47 | curr = 0
48 | adj = []
49 | for y in range(sMask.shape[0]):
50 | #print(y)
51 | for x in range(sMask.shape[1]):
52 | if x >= sMask.shape[1] - 1: #reach end of row
53 | if y >= sMask.shape[0] - 1: #reach end of graph
54 | return adj
55 | else: #switch to new line, update curr
56 | curr = sMask[0, y+1]
57 | continue
58 | else: #still iterating row
59 | if sMask[x,y] != curr:
60 | if (curr, sMask[x,y]) not in adj:
61 | #new edge not in adjacency list
62 | adj.append((curr, sMask[x,y]))
63 | adj.append((sMask[x,y], curr)) #bidirectional
64 | curr = sMask[x,y]
65 | else: #already in adj
66 | curr = sMask[x,y]
67 | continue
68 | else:
69 | continue
70 | return adj
71 |
72 |
73 | def adjacency3(sMask, mask_length):
74 | """
75 | Adjacency Matrix
76 | This function efficiently finds bidirectional adjacency edges
77 | -from a SLIC superpixel mask at the cost of higher RAM/cache usage
78 |
79 |
80 | Keyword arguments:
81 | sMask --a 2-dimensional numpy array generated by SLIC algorithm
82 | mask_length --number of unique SLIC superpixels in sMask
83 | """
84 | curr = 0
85 | adj = []
86 | edge_visited = np.full((mask_length, mask_length), False)
87 | for y in range(sMask.shape[0]):
88 | for x in range(sMask.shape[1]):
89 | if x >= sMask.shape[1] - 1: #reach end of row
90 | if y >= sMask.shape[0] - 1: #reach end of graph
91 | return adj
92 | else: #switch to new line, update curr
93 | curr = sMask[0, y+1]
94 | continue
95 | else: #still iterating row
96 | target = sMask[x,y]
97 | if target != curr:
98 | if edge_visited[curr, target] == False or \
99 | edge_visited[target, curr] == False:
100 | #new edge not in adjacency list
101 | adj.append((curr, target))
102 | adj.append((target, curr)) #bidirectional
103 | edge_visited[curr, target] = True
104 | edge_visited[target, curr] = True
105 | curr = sMask[x,y]
106 | else: #already in adj
107 | curr = sMask[x,y]
108 | continue
109 | else:
110 | continue
111 | return adj
112 |
113 |
114 | def center(pixel):
115 | """
116 | Center of Superpixels
117 | This function efficiently finds weighted center of SLIC superpixel
118 | -tiles
119 |
120 |
121 | Keyword arguments:
122 |     pixel --list of (x, y) pixel coordinate tuples from a SLIC superpixel tile.
123 | """
124 | x_sorted = sorted(pixel.copy())
125 | y_sorted = sorted(pixel, key=lambda tup: tup[1]) #sorted by y
126 | assert(len(x_sorted) == len(y_sorted))
127 | mid = int(len(x_sorted) / 2)
128 | mid_x = int(x_sorted[mid][0])
129 | mid_y = int(y_sorted[mid][1]) #x, y individual middle points
130 | return (mid_x, mid_y)
131 |
132 |
133 | def euclidean_dist(a, b):
134 | """
135 | Euclidean Distance
136 |     Calculate the Euclidean distance between two matrices.
137 |
138 |
139 | Keyword arguments:
140 | a --first matrix's ndarray
141 | b --second matrix's ndarray, must match input a's shape
142 | """
143 | return np.linalg.norm(a-b)
144 |
145 |
146 | def mse(a,b):
147 | """
148 | Mean Squared Error
149 |     Calculate the Mean Squared Error between two matrices.
150 |
151 |
152 | Keyword arguments:
153 | a --first matrix's ndarray
154 | b --second matrix's ndarray, must match input a's shape
155 | """
156 | return ((a - b)**2).mean(axis=None)
157 |
158 |
159 | def fnorm(a, b):
160 | """
161 | Frobenius Norm Distance
162 |     Calculate the Frobenius Norm distance between two matrices.
163 |
164 |
165 | Keyword arguments:
166 | a --first matrix's ndarray
167 | b --second matrix's ndarray, must match input a's shape
168 | """
169 | a_ss = np.sum(a = np.square(a)) #sum of squares
170 | b_ss = np.sum(a = np.square(b))
171 | return np.sqrt(np.absolute(a_ss - b_ss))
172 |
173 |
174 | def chisq_dist(a, b, gamma = 1):
175 | """
176 | Chi-Square Distance
177 |     Calculate the Chi-Square distance between two matrices.
178 |
179 |
180 | Keyword arguments:
181 | a --first matrix's ndarray
182 | b --second matrix's ndarray, must match input a's shape
183 | """
184 | numerator = np.sum(np.square(a-b))
185 |     denominator = np.sum(a+b)
186 |     dist = 0.5 * numerator / denominator
187 | return np.exp(- gamma * dist)
188 |
189 |
190 | def one_label_loss(gt_percent, predict, moe, batch_node_num):
191 | """
192 | Proposed Loss Function
193 |     Our proposed loss function calculates the cost of a training batch using
194 |     -the GCN's output graphs and weak image-level annotations.
195 | For more information, please refer to our paper.
196 |
197 |
198 | Keyword arguments:
199 |     gt_percent --Ground-truth percent, a weak image-level annotation
200 | predict --GCN module output, gradient required
201 | moe --Margin of Error, a weak image-level annotation
202 | batch_node_num --integer list of node numbers per image in batch
203 | """
204 | curr_index = 0
205 | batch_top_k_loss = []
206 | batch_bottom_k_loss = []
207 | batch_pairwise_loss = []
208 | positive_num = 0.00000001
209 | negative_num = 0.00000001
210 | for i in range(len(gt_percent)):
211 | total_length = batch_node_num[i] #one graph length
212 | predict_slice = torch.narrow(input = predict, dim = 0, start = curr_index, length = total_length)
213 | curr_index += total_length
214 | one_gt_percent = gt_percent[i]
215 | one_moe = moe[i]
216 | select = torch.tensor([0])
217 | if use_cuda:
218 | select = select.to('cuda')
219 |
220 | threshold_ceil = int(total_length * (one_gt_percent - one_moe)) #100 * (0.8 - 0.1) = top 70 %
221 | if threshold_ceil < 0:
222 | threshold_ceil = 0
223 | threshold_floor = int(total_length * (1.0 - one_gt_percent - one_moe)) #100 * (1 - 0.8 - 0.1) = bottom 10 %
224 | if threshold_floor < 0:
225 | threshold_floor = 0
226 |
227 | top_k, _ = torch.topk(input = predict_slice, k = threshold_ceil, dim = 0, largest = True, sorted = False)
228 | bottom_k, _ = torch.topk(input = predict_slice, k = threshold_floor, dim = 0, largest = False, sorted = False)
229 |
230 | top_k_mean = torch.mean(top_k,dim=0)
231 | bottom_k_mean = torch.mean(bottom_k,dim=0)
232 |
233 | predict_slice = None
234 | top_k = None
235 | select = None
236 | bottom_k = None
237 | loss_fn = nn.SmoothL1Loss()
238 | if use_cuda:
239 | temp_ones = torch.ones(1, dtype = torch.float).to('cuda')
240 | temp_zeros = torch.tensor([-1], dtype = torch.float).to('cuda')
241 | temp_ground = torch.zeros(1, dtype = torch.float).to('cuda')
242 | if threshold_ceil > 0:
243 | #top_k_loss = F.l1_loss(top_k_mean, temp_ones)
244 | top_k_loss = loss_fn(top_k_mean, temp_ones)
245 | positive_num += top_k_loss.detach().cpu().numpy()
246 | else:
247 | top_k_loss = None
248 |
249 | if threshold_floor > 0:
250 | #bottom_k_loss = F.l1_loss(bottom_k_mean, temp_zeros)
251 | bottom_k_loss = loss_fn(bottom_k_mean, temp_zeros)
252 | negative_num += bottom_k_loss.detach().cpu().numpy()
253 | else:
254 | bottom_k_loss = None
255 | temp_ones = None
256 |             temp_zeros = None
257 | else:
258 | if threshold_ceil > 0:
259 | #top_k_loss = F.l1_loss(top_k_mean, torch.ones(1, dtype = torch.float))
260 | top_k_loss = loss_fn(top_k_mean, torch.ones(1, dtype = torch.float))
261 | positive_num += 1.0
262 | else:
263 | top_k_loss = None
264 |
265 | if threshold_floor > 0:
266 | #bottom_k_loss = F.l1_loss(bottom_k_mean, torch.zeros(1, dtype = torch.float))
267 | bottom_k_loss = loss_fn(bottom_k_mean, torch.zeros(1, dtype = torch.float))
268 | negative_num += 1.0
269 | else:
270 | bottom_k_loss = None
271 | batch_top_k_loss.append(top_k_loss)
272 | batch_bottom_k_loss.append(bottom_k_loss)
273 | top_k_loss = None
274 | bottom_k_loss = None
275 | pairwise_loss = None
276 | print("-------------------------------------------------------------------------------")
277 | print("Targeted Regions Losses Per Image")
278 | print([round(float(x.data.cpu().detach().numpy()),2) if x is not None else -1.00 for x in batch_top_k_loss])
279 | print("Background Regions Losses Per Image")
280 | print([round(float(x.data.cpu().detach().numpy()),2) if x is not None else -1.00 for x in batch_bottom_k_loss])
281 | print("-------------------------------------------------------------------------------")
282 |
283 | for t, b, g, a in zip(batch_top_k_loss, batch_bottom_k_loss, gt_percent, moe):
284 | if top_k_loss is None and t is not None:
285 | top_k_loss = (g - a) * t
286 | elif t is not None:
287 | top_k_loss += (g - a) * t
288 | if bottom_k_loss is None and b is not None:
289 | bottom_k_loss = (1.0 - g - a) * b
290 | elif b is not None:
291 | bottom_k_loss += (1.0 - g - a) * b
292 | return top_k_loss, bottom_k_loss
293 |
294 |
295 |
296 |
297 | def plot_grad_flow(named_parameters):
298 | """
299 | Utility Function-Visualize Gradient Flow
300 | This utility function can assist in checking gradient flow between layers.
301 |
302 |
303 | Keyword arguments:
304 | named_parameters --module's parameter()
305 | """
306 | ave_grads = []
307 | layers = []
308 | for n, p in named_parameters:
309 | if(p.requires_grad) and ("bias" not in n):
310 | layers.append(n)
311 | ave_grads.append(p.grad.abs().mean())
312 | plt.plot(ave_grads, alpha=0.3, color="b")
313 | plt.hlines(0, 0, len(ave_grads)+1, linewidth=1, color="k" )
314 | plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
315 | plt.xlim(xmin=0, xmax=len(ave_grads))
316 | plt.xlabel("Layers")
317 | plt.ylabel("average gradient")
318 | plt.title("Gradient flow")
319 | plt.grid(True)
320 | plt.show()
321 |
322 | def fuse_results(multi_args):
323 | (outgcn_numpy, segments_copy) = multi_args
324 | for segInd in range(len(outgcn_numpy)):
325 | segments_copy[segments_copy == segInd] = outgcn_numpy[segInd]
326 | return segments_copy
327 |
--------------------------------------------------------------------------------
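A minimal sketch (not part of the repository) exercising `one_label_loss` above on a single toy graph: 10 node predictions and a weak label stating that roughly 40% of the image (with a 10% margin of error) is target region. It assumes it is run from the repository root so `utilities.py` is importable; all numbers are illustrative assumptions.

```python
import torch
from utilities import one_label_loss, use_cuda

device = 'cuda' if use_cuda else 'cpu'
predict = torch.tanh(torch.randn(10, 1, device=device, requires_grad=True))  # toy GCN outputs in (-1, 1)

loss_top, loss_bottom = one_label_loss(
    gt_percent=[0.4],        # weak annotation: target region covers ~40% of the image
    moe=[0.1],               # margin of error on that percentage
    predict=predict,
    batch_node_num=[10],     # the single toy graph has 10 nodes (superpixels)
)

# Either term may be None when its corresponding node count rounds to zero.
total = sum(l for l in (loss_top, loss_bottom) if l is not None)
total.backward()             # gradients flow back to the node predictions
```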