├── .gitignore
├── LICENSE
├── README.md
├── dataset
│   ├── __init__.py
│   ├── ade.py
│   ├── cityscapes_domain.py
│   ├── consep.py
│   ├── monusac.py
│   ├── transform.py
│   ├── utils.py
│   └── voc.py
├── docs
│   └── framework.png
├── models
│   ├── __init__.py
│   ├── resnet.py
│   └── util.py
└── modules
    ├── __init__.py
    ├── deeplab.py
    ├── fca_cid.py
    ├── misc.py
    ├── residual.py
    └── unet.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 why19991
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # InSeg
2 | ### Incremental Nuclei Segmentation from Histopathological Images
3 | ### via Future-class Awareness and Compatibility-inspired Distillation [CVPR 2024]
4 | ###### Huyong Wang<sup>1</sup>, Huisi Wu<sup>1</sup>, Jing Qin<sup>2</sup>
5 | ###### <sup>1</sup>College of Computer Science and Software Engineering, Shenzhen University
6 | ###### <sup>2</sup>Centre for Smart Health, School of Nursing, The Hong Kong Polytechnic University
7 |
8 | # Method
9 |
10 | ![framework](docs/framework.png)
11 |
12 |
13 | # Notes
14 | Due to a confidentiality agreement with our commercial partners, we only release the code of the core modules and the complete trainable models, which should be sufficient for comparison.
15 |
16 | # Environment
17 | * Python>=3.8
18 | * PyTorch>=1.8.1
19 | * Install [inplace_abn](https://github.com/mapillary/inplace_abn)
20 | # Thanks
21 | This code borrows heavily from [[EWF](https://github.com/schuy1er/EWF_official)] and [[PLOP](https://github.com/arthurdouillard/CVPR2021_PLOP)]. We appreciate their contributions to the community.
22 |
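23 | # Usage
24 | The training and evaluation scripts are not included in this release (see Notes above), so the snippet below is only an illustrative sketch of how the provided incremental dataset classes might be instantiated. The data root, transform parameters, and class split are assumptions rather than values taken from the released code.
25 | ```python
26 | from dataset import transform
27 | from dataset.monusac import MoNuSACSegmentationIncremental
28 |
29 | # Assumed preprocessing pipeline built from the transforms in dataset/transform.py.
30 | train_transform = transform.Compose([
31 |     transform.RandomHorizontalFlip(p=0.5),
32 |     transform.ToTensor(),
33 |     transform.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
34 | ])
35 |
36 | # Assumed incremental step: classes 1 and 2 were learned previously and class 3 is new.
37 | # The dataset root is expected to contain a splits/train.txt file listing
38 | # "<image path> <mask path>" pairs (paths relative to the root, starting with "/").
39 | train_set = MoNuSACSegmentationIncremental(
40 |     root="/path/to/MoNuSAC",
41 |     train=True,
42 |     image_Set="train",
43 |     transform=train_transform,
44 |     labels=[3],
45 |     labels_old=[1, 2],
46 |     masking=True,
47 |     overlap=True,
48 | )
49 | img, target, name = train_set[0]  # image tensor, remapped label tensor, sample name
50 | ```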
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .ade import AdeSegmentation, AdeSegmentationIncremental
2 | from .cityscapes_domain import (CityscapesSegmentationDomain,
3 | CityscapesSegmentationIncrementalDomain)
4 | from .voc import VOCSegmentation, VOCSegmentationIncremental
5 |
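6 | # Note: the histopathology dataset classes used in the paper (CoNSeP and MoNuSAC) are
7 | # not re-exported here and must be imported from their modules directly. If package-level
8 | # exports are desired (an assumption, since the training scripts are not released), they
9 | # could be added as:
10 | # from .consep import CoNSePSegmentation, CoNSePSegmentationIncremental
11 | # from .monusac import MoNuSACSegmentation, MoNuSACSegmentationIncremental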
--------------------------------------------------------------------------------
/dataset/ade.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import torch.utils.data as data
6 | import torchvision as tv
7 | from PIL import Image
8 | from torch import distributed
9 |
10 | from .utils import Subset, filter_images, group_images
11 |
12 | classes = [
13 | "void", "wall", "building", "sky", "floor", "tree", "ceiling", "road", "bed ", "windowpane",
14 | "grass", "cabinet", "sidewalk", "person", "earth", "door", "table", "mountain", "plant",
15 | "curtain", "chair", "car", "water", "painting", "sofa", "shelf", "house", "sea", "mirror",
16 | "rug", "field", "armchair", "seat", "fence", "desk", "rock", "wardrobe", "lamp", "bathtub",
17 | "railing", "cushion", "base", "box", "column", "signboard", "chest of drawers", "counter",
18 | "sand", "sink", "skyscraper", "fireplace", "refrigerator", "grandstand", "path", "stairs",
19 | "runway", "case", "pool table", "pillow", "screen door", "stairway", "river", "bridge",
20 | "bookcase", "blind", "coffee table", "toilet", "flower", "book", "hill", "bench", "countertop",
21 | "stove", "palm", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine",
22 | "hovel", "bus", "towel", "light", "truck", "tower", "chandelier", "awning", "streetlight",
23 | "booth", "television receiver", "airplane", "dirt track", "apparel", "pole", "land",
24 | "bannister", "escalator", "ottoman", "bottle", "buffet", "poster", "stage", "van", "ship",
25 | "fountain", "conveyer belt", "canopy", "washer", "plaything", "swimming pool", "stool",
26 | "barrel", "basket", "waterfall", "tent", "bag", "minibike", "cradle", "oven", "ball", "food",
27 | "step", "tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher",
28 | "screen", "blanket", "sculpture", "hood", "sconce", "vase", "traffic light", "tray", "ashcan",
29 | "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator",
30 | "glass", "clock", "flag"
31 | ]
32 |
33 |
34 | class AdeSegmentation(data.Dataset):
35 |
36 | def __init__(self, root, train=True, transform=None):
37 |
38 | root = os.path.expanduser(root)
39 | base_dir = "ADEChallengeData2016"
40 | ade_root = os.path.join(root, base_dir)
41 | if train:
42 | split = 'training'
43 | else:
44 | split = 'validation'
45 | annotation_folder = os.path.join(ade_root, 'annotations', split)
46 | image_folder = os.path.join(ade_root, 'images', split)
47 |
48 | self.images = []
49 | fnames = sorted(os.listdir(image_folder))
50 | self.images = [
51 | (os.path.join(image_folder, x), os.path.join(annotation_folder, x[:-3] + "png"))
52 | for x in fnames
53 | ]
54 |
55 | self.transform = transform
56 |
57 | def __getitem__(self, index):
58 | """
59 | Args:
60 | index (int): Index
61 | Returns:
62 | tuple: (image, target) where target is the image segmentation.
63 | """
64 | img = Image.open(self.images[index][0]).convert('RGB')
65 | target = Image.open(self.images[index][1])
66 |
67 | if self.transform is not None:
68 | img, target = self.transform(img, target)
69 |
70 | return img, target
71 |
72 | def __len__(self):
73 | return len(self.images)
74 |
75 |
76 | class AdeSegmentationIncremental(data.Dataset):
77 |
78 | def __init__(
79 | self,
80 | root,
81 | train=True,
82 | transform=None,
83 | labels=None,
84 | labels_old=None,
85 | idxs_path=None,
86 | masking=True,
87 | overlap=True,
88 | data_masking="current",
89 | ignore_test_bg=False,
90 | **kwargs
91 | ):
92 |
93 | full_data = AdeSegmentation(root, train)
94 |
95 | self.labels = []
96 | self.labels_old = []
97 |
98 | if labels is not None:
99 | # store the labels
100 | labels_old = labels_old if labels_old is not None else []
101 |
102 | self.__strip_zero(labels)
103 | self.__strip_zero(labels_old)
104 |
105 | assert not any(
106 | l in labels_old for l in labels
107 | ), "labels and labels_old must be disjoint sets"
108 |
109 | self.labels = labels
110 | self.labels_old = labels_old
111 |
112 | self.order = [0] + labels_old + labels
113 |
114 | # take index of images with at least one class in labels and all classes in labels+labels_old+[255]
115 | if idxs_path is not None and os.path.exists(idxs_path):
116 | idxs = np.load(idxs_path).tolist()
117 | else:
118 | idxs = filter_images(full_data, labels, labels_old, overlap=overlap)
119 | if idxs_path is not None and distributed.get_rank() == 0:
120 | np.save(idxs_path, np.array(idxs, dtype=int))
121 |
122 | #if train:
123 | # masking_value = 0
124 | #else:
125 | # masking_value = 255
126 |
127 | #self.inverted_order = {label: self.order.index(label) for label in self.order}
128 | #self.inverted_order[0] = masking_value
129 |
130 | self.inverted_order = {label: self.order.index(label) for label in self.order}
131 | if ignore_test_bg:
132 | masking_value = 255
133 | self.inverted_order[0] = masking_value
134 | else:
135 | masking_value = 0 # Future classes will be considered as background.
136 | self.inverted_order[255] = 255
137 |
138 | reorder_transform = tv.transforms.Lambda(
139 | lambda t: t.apply_(
140 | lambda x: self.inverted_order[x] if x in self.inverted_order else masking_value
141 | )
142 | )
143 |
144 | if masking:
145 | target_transform = tv.transforms.Lambda(
146 | lambda t: t.
147 | apply_(lambda x: self.inverted_order[x] if x in self.labels else masking_value)
148 | )
149 | else:
150 | target_transform = reorder_transform
151 |
152 | # make the subset of the dataset
153 | self.dataset = Subset(full_data, idxs, transform, target_transform)
154 | else:
155 | self.dataset = full_data
156 |
157 | def __getitem__(self, index):
158 | """
159 | Args:
160 | index (int): Index
161 | Returns:
162 | tuple: (image, target) where target is the image segmentation.
163 | """
164 |
165 | return self.dataset[index]
166 |
167 | @staticmethod
168 | def __strip_zero(labels):
169 | while 0 in labels:
170 | labels.remove(0)
171 |
172 | def __len__(self):
173 | return len(self.dataset)
174 |
--------------------------------------------------------------------------------
/dataset/cityscapes_domain.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import glob
3 | import os
4 |
5 | import numpy as np
6 | import torch.utils.data as data
7 | import torchvision as tv
8 | from PIL import Image
9 | from torch import distributed
10 |
11 | from .utils import Subset, group_images
12 |
13 |
14 | # Converting the id to the train_id. Many objects have a train id at
15 | # 255 (unknown / ignored).
16 | # See there for more information:
17 | # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py
18 | id_to_trainid = {
19 | 0: 255,
20 | 1: 255,
21 | 2: 255,
22 | 3: 255,
23 | 4: 255,
24 | 5: 255,
25 | 6: 255,
26 | 7: 0, # road
27 | 8: 1, # sidewalk
28 | 9: 255,
29 | 10: 255,
30 | 11: 2, # building
31 | 12: 3, # wall
32 | 13: 4, # fence
33 | 14: 255,
34 | 15: 255,
35 | 16: 255,
36 | 17: 5, # pole
37 | 18: 255,
38 | 19: 6, # traffic light
39 | 20: 7, # traffic sign
40 | 21: 8, # vegetation
41 | 22: 9, # terrain
42 | 23: 10, # sky
43 | 24: 11, # person
44 | 25: 12, # rider
45 | 26: 13, # car
46 | 27: 14, # truck
47 | 28: 15, # bus
48 | 29: 255,
49 | 30: 255,
50 | 31: 16, # train
51 | 32: 17, # motorcycle
52 | 33: 18, # bicycle
53 | -1: 255
54 | }
55 |
56 | city_to_id = {
57 | "aachen": 0, "bremen": 1, "darmstadt": 2, "erfurt": 3, "hanover": 4,
58 | "krefeld": 5, "strasbourg": 6, "tubingen": 7, "weimar": 8, "bochum": 9,
59 | "cologne": 10, "dusseldorf": 11, "hamburg": 12, "jena": 13,
60 | "monchengladbach": 14, "stuttgart": 15, "ulm": 16, "zurich": 17,
61 | "frankfurt": 18, "lindau": 19, "munster": 20
62 | }
63 |
64 |
65 | def filter_images(dataset, labels):
66 | # Filter images without any label in LABELS (using labels not reordered)
67 | idxs = []
68 |
69 | print(f"Filtering images...")
70 | for i in range(len(dataset)):
71 | domain_id = dataset.__getitem__(i, get_domain=True) # taking domain id
72 | if domain_id in labels:
73 | idxs.append(i)
74 | if i % 1000 == 0:
75 | print(f"\t{i}/{len(dataset)} ...")
76 | return idxs
77 |
78 |
79 | class CityscapesSegmentationDomain(data.Dataset):
80 |
81 | def __init__(self, root, train=True, transform=None, domain_transform=None):
82 | root = os.path.expanduser(root)
83 | annotation_folder = os.path.join(root, 'gtFine')
84 | image_folder = os.path.join(root, 'leftImg8bit')
85 |
86 | self.images = [ # Add train cities
87 | (
88 | path,
89 | os.path.join(
90 | annotation_folder,
91 | "train",
92 | path.split("/")[-2],
93 | path.split("/")[-1][:-15] + "gtFine_labelIds.png"
94 | ),
95 | city_to_id[path.split("/")[-2]]
96 | ) for path in sorted(glob.glob(os.path.join(image_folder, "train/*/*.png")))
97 | ]
98 | self.images += [ # Add validation cities
99 | (
100 | path,
101 | os.path.join(
102 | annotation_folder,
103 | "val",
104 | path.split("/")[-2],
105 | path.split("/")[-1][:-15] + "gtFine_labelIds.png"
106 | ),
107 | city_to_id[path.split("/")[-2]]
108 | ) for path in sorted(glob.glob(os.path.join(image_folder, "val/*/*.png")))
109 | ]
110 |
111 | self.transform = transform
112 | self.domain_transform = domain_transform
113 |
114 | def __getitem__(self, index, get_domain=False):
115 | """
116 | Args:
117 | index (int): Index
118 | Returns:
119 | tuple: (image, target) where target is the image segmentation.
120 | """
121 | if get_domain:
122 | domain = self.images[index][2]
123 | if self.domain_transform is not None:
124 | domain = self.domain_transform(domain)
125 | return domain
126 |
127 | try:
128 | img = Image.open(self.images[index][0]).convert('RGB')
129 | target = Image.open(self.images[index][1])
130 | except Exception as e:
131 | raise Exception(f"Index: {index}, len: {len(self)}, message: {str(e)}")
132 |
133 | if self.transform is not None:
134 | img, target = self.transform(img, target)
135 |
136 | return img, target
137 |
138 | def __len__(self):
139 | return len(self.images)
140 |
141 |
142 | class CityscapesSegmentationIncrementalDomain(data.Dataset):
143 | """Labels correspond to domains not classes in this case."""
144 | def __init__(
145 | self,
146 | root,
147 | train=True,
148 | transform=None,
149 | labels=None,
150 | idxs_path=None,
151 | masking=True,
152 | overlap=True,
153 | **kwargs
154 | ):
155 | full_data = CityscapesSegmentationDomain(root, train)
156 |
157 | # take index of images with at least one class in labels and all classes in labels+labels_old+[255]
158 | if idxs_path is not None and os.path.exists(idxs_path):
159 | idxs = np.load(idxs_path).tolist()
160 | else:
161 | idxs = filter_images(full_data, labels)
162 | if idxs_path is not None and distributed.get_rank() == 0:
163 | np.save(idxs_path, np.array(idxs, dtype=int))
164 |
165 | rnd = np.random.RandomState(1)
166 | rnd.shuffle(idxs)
167 | train_len = int(0.8 * len(idxs))
168 | if train:
169 | idxs = idxs[:train_len]
170 | print(f"{len(idxs)} images for train")
171 | else:
172 | idxs = idxs[train_len:]
173 | print(f"{len(idxs)} images for val")
174 |
175 | target_transform = tv.transforms.Lambda(
176 | lambda t: t.
177 | apply_(lambda x: id_to_trainid.get(x, 255))
178 | )
179 | # make the subset of the dataset
180 | self.dataset = Subset(full_data, idxs, transform, target_transform)
181 |
182 | def __getitem__(self, index):
183 | """
184 | Args:
185 | index (int): Index
186 | Returns:
187 | tuple: (image, target) where target is the image segmentation.
188 | """
189 |
190 | return self.dataset[index]
191 |
192 | def __len__(self):
193 | return len(self.dataset)
194 |
--------------------------------------------------------------------------------
/dataset/consep.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import copy
4 |
5 | import numpy as np
6 | import torch
7 | import torch.utils.data as data
8 | import torchvision as tv
9 | from PIL import Image
10 | from torch import distributed
11 |
12 | from dataset import transform
13 | from .utils import Subset, filter_images, group_images
14 |
15 | classes = {
16 | 0: 'background',
17 | 1: 'class 1',
18 | 2: 'class 2',
19 | 3: 'class 3',
20 | 4: 'class 4'
21 | }
22 |
23 | class CoNSePSegmentation(data.Dataset):
24 | """
25 | Args:
26 | root (string): Root directory of the CoNSeP dataset.
27 | image_set (string, optional): Select the image_set to use: ``train``, ``trainval`` or ``val``
28 | is_aug (bool, optional): Whether to use the augmented train set (default is True)
29 | transform (callable, optional): A function/transform that takes in a PIL image
30 | and returns a transformed version. E.g., ``transforms.RandomCrop``
31 | """
32 |
33 | def __init__(self, root, image_set='train', is_aug=True, transform=None):
34 |
35 | self.root = os.path.expanduser(root)
36 | self.year = 2020
37 |
38 | self.transform = transform
39 |
40 | self.image_set = image_set
41 | consep_root = self.root
42 | splits_dir = os.path.join(consep_root, 'splits')
43 |
44 | if not os.path.exists(consep_root):
45 | raise RuntimeError(
46 | f'Dataset {consep_root} not found or corrupted.'
47 | )
48 |
49 | split_f = os.path.join(splits_dir,image_set.rstrip('\n') + '.txt')
50 | if not os.path.exists(split_f):
51 | raise ValueError(
52 | 'Wrong image_set entered! Please use image_set="train" '
53 | 'or image_set="trainval" or image_set="val" '
54 | f'{split_f}'
55 | )
56 |
57 | # each line lists "<image path> <mask path>"; strip the trailing newline before splitting
58 | with open(split_f, "r") as f:
59 | file_names = [x[:-1].split(" ") for x in f.readlines()]
60 |
61 |
62 | self.images = [
63 | (
64 | os.path.join(consep_root, x[0][1:]), os.path.join(consep_root, x[1][1:])
65 | ) for x in file_names
66 | ]
67 |
68 | def __getitem__(self, index):
69 | """
70 | Args:
71 | index (int): Index
72 | Returns:
73 | tuple: (image, target) where target is the image segmentation.
74 | """
75 |
76 | img = Image.open(self.images[index][0]).convert('RGB')
77 | target = Image.open(self.images[index][1])
78 | # img = Image.open("/data/wzz/MoNuSAC/train/image_crop/TCGA-5P-A9K0-01Z-00-DX1_1_11.png").convert('RGB')
79 | # target = Image.open("/data/wzz/MoNuSAC/train/mask_crop/TCGA-5P-A9K0-01Z-00-DX1_1_11.png")
80 | if self.transform is not None:
81 | img, target = self.transform(img, target)
82 |
83 | # return img, target, "TCGA-5P-A9K0-01Z-00-DX1_1_11"
84 | return img, target, self.images[index][1][-31:-4]  # NOTE: extracts the sample name from the mask path via fixed offsets
85 |
86 |
87 | def __len__(self):
88 | return len(self.images)
89 |
90 | def viz_getter(self, index):
91 | image_path = self.images[index][0]
92 | raw_image = Image.open(self.images[index][0]).convert('RGB')
93 | target = Image.open(self.images[index][1])
94 | if self.transform is not None:
95 | img, target = self.transform(raw_image, target)
96 | else:
97 | img = copy.deepcopy(raw_image)
98 | return image_path, raw_image, img, target
99 |
100 |
101 | class CoNSePSegmentationIncremental(data.Dataset):
102 | def __init__(self,
103 | root,
104 | train=True,
105 | transform=None,
106 | labels=None,
107 | labels_old=None,
108 | idxs_path=None,
109 | masking=True,
110 | overlap=True,
111 | data_masking="current",
112 | test_on_val=False,
113 | **kwargs
114 | ):
115 |
116 | full_consep = CoNSePSegmentation(root, 'train' if train else 'val', is_aug=True, transform=None)
117 | self.labels = []
118 | self.labels_old = []
119 |
120 | if labels is not None:
121 |
122 | labels_old = labels_old if labels_old is not None else []
123 |
124 | self.__strip_zero(labels)
125 | self.__strip_zero(labels_old)
126 |
127 | assert not any(
128 | l in labels_old for l in labels
129 | ), "labels and labels_old must be disjoint sets"
130 |
131 | self.labels = [0] + labels
132 | self.labels_old = [0] + labels_old
133 |
134 | self.order = [0] + labels_old + labels
135 |
136 | if idxs_path is not None and os.path.exists(idxs_path):
137 | idxs = np.load(idxs_path).tolist()
138 |
139 | else:
140 | idxs = filter_images(full_consep, labels, labels_old,overlap=overlap)
141 | if idxs_path is not None and distributed.get_rank() == 0:
142 | np.save(idxs_path,np.array(idxs,dtype=int))
143 |
144 | if test_on_val:
145 | rnd = np.random.RandomState(1)
146 | rnd.shuffle(idxs)
147 | train_len = int(0.8 * len(idxs))
148 | if train:
149 | idxs = idxs[:train_len]
150 | else:
151 | idxs = idxs[train_len:]
152 |
153 | masking_value = 0 # Future classes will be considered as background.
154 | self.inverted_order = {label: self.order.index(label) for label in self.order}
155 | self.inverted_order[255] = 255
156 | reorder_transform = tv.transforms.Lambda(
157 | lambda t: t.apply_(
158 | lambda x: self.inverted_order[x] if x in self.inverted_order else masking_value
159 | )
160 | )
161 |
162 | if masking:
163 | if data_masking == 'current':
164 | tmp_labels = self.labels + [255]
165 | elif data_masking == 'current+old':
166 | tmp_labels = labels_old + self.labels + [255]
167 | elif data_masking == 'all':
168 | raise NotImplementedError(
169 | f"data_masking={data_masking} is not yet implemented."
170 | )
171 |
172 | elif data_masking == "new":
173 | tmp_labels = self.labels
174 | masking_value = 255
175 |
176 | target_transform = tv.transforms.Lambda(
177 | lambda t: t.
178 | apply_(lambda x: self.inverted_order[x] if x in tmp_labels else masking_value)
179 | )
180 | else:
181 | assert False
182 | target_transform = reorder_transform
183 |
184 | # make the subset of the dataset
185 | self.dataset = Subset(full_consep, idxs, transform, target_transform)
186 | else:
187 | self.dataset = full_consep
188 |
189 |
190 |
191 | @staticmethod
192 | def __strip_zero(labels):
193 | while 0 in labels:
194 | labels.remove(0)
195 |
196 |
197 | def __len__(self):
198 | return len(self.dataset)
199 |
200 | def __getitem__(self, index):
201 | """
202 | Args:
203 | index (int): Index
204 | Returns:
205 | tuple: (image, target) where target is the image segmentation.
206 | """
207 |
208 | return self.dataset[index]
--------------------------------------------------------------------------------
/dataset/monusac.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import copy
4 |
5 | import numpy as np
6 | import torch.utils.data as data
7 | import torchvision as tv
8 | from PIL import Image
9 | from torch import distributed
10 |
11 | from .utils import Subset, filter_images, group_images
12 |
13 | classes = {
14 | 0: 'background',
15 | 1: 'class 1',
16 | 2: 'class 2',
17 | 3: 'class 3'
18 | }
19 |
20 | class MoNuSACSegmentation(data.Dataset):
21 | """
22 | Args:
23 | root (string): Root directory of the MoNuSAC dataset.
24 | image_set (string, optional): Select the image_set to use: ``train``, ``trainval`` or ``val``
25 | is_aug (bool, optional): Whether to use the augmented train set (default is True)
26 | transform (callable, optional): A function/transform that takes in a PIL image
27 | and returns a transformed version. E.g., ``transforms.RandomCrop``
28 | """
29 |
30 | def __init__(self, root, image_set='train', is_aug=True, transform=None):
31 |
32 | self.root = os.path.expanduser(root)
33 | self.year = 2020
34 |
35 | self.transform = transform
36 |
37 | self.image_set = image_set
38 | monusac_root = self.root
39 | splits_dir = os.path.join(monusac_root,'splits')
40 |
41 | if not os.path.exists(monusac_root):
42 | raise RuntimeError(
43 | f'Dataset {monusac_root} not found or corrupted.'
44 | )
45 |
46 | split_f = os.path.join(splits_dir,image_set.rstrip('\n') + '.txt')
47 | if not os.path.exists(split_f):
48 | raise ValueError(
49 | 'Wrong image_set entered! Please use image_set="train" '
50 | 'or image_set="trainval" or image_set="val" '
51 | f'{split_f}'
52 | )
53 |
54 | # each line lists "<image path> <mask path>"; strip the trailing newline before splitting
55 | with open(split_f, "r") as f:
56 | file_names = [x[:-1].split(" ") for x in f.readlines()]
57 |
58 |
59 | self.images = [
60 | (
61 | os.path.join(monusac_root,x[0][1:]), os.path.join(monusac_root, x[1][1:])
62 | ) for x in file_names
63 | ]
64 |
65 | def __getitem__(self, index):
66 | """
67 | Args:
68 | index (int): Index
69 | Returns:
70 | tuple: (image, target) where target is the image segmentation.
71 | """
72 |
73 |
74 | img = Image.open(self.images[index][0]).convert('RGB')
75 | target = Image.open(self.images[index][1])
76 | # img = Image.open("/data/why/MoNuSAC/train/image_crop/TCGA-IZ-A6M9-01Z-00-DX1_3_1.png").convert('RGB')
77 | # target = Image.open("/data/why/MoNuSAC/test/mask_crop/TCGA-IZ-A6M9-01Z-00-DX1_3_1.png")
78 | if self.transform is not None:
79 | img, target = self.transform(img, target)
80 |
81 | # return img, target, "TCGA-IZ-A6M9-01Z-00-DX1_3_1"
82 | # return img, target, self.images[index][1][-31:-4]
83 | return img, target, self.images[index][0][35:]  # NOTE: strips a fixed-length root prefix from the image path
84 |
85 |
86 | def __len__(self):
87 | return len(self.images)
88 |
89 | def viz_getter(self, index):
90 | image_path = self.images[index][0]
91 | raw_image = Image.open(self.images[index][0]).convert('RGB')
92 | target = Image.open(self.images[index][1])
93 | if self.transform is not None:
94 | img, target = self.transform(raw_image, target)
95 | else:
96 | img = copy.deepcopy(raw_image)
97 | return image_path, raw_image, img, target
98 |
99 |
100 | class MoNuSACSegmentationIncremental(data.Dataset):
101 | def __init__(self,
102 | root,
103 | train=True,
104 | transform=None,
105 | labels=None,
106 | labels_old=None,
107 | idxs_path=None,
108 | masking=True,
109 | overlap=True,
110 | data_masking="current",
111 | test_on_val=False,
112 | image_Set=None,
113 | **kwargs
114 | ):
115 |
116 | full_monusac = MoNuSACSegmentation(root, image_set=image_Set, is_aug=True, transform=None)
117 | self.labels = []
118 | self.labels_old = []
119 |
120 | if labels is not None:
121 |
122 | labels_old = labels_old if labels_old is not None else []
123 |
124 | self.__strip_zero(labels)
125 | self.__strip_zero(labels_old)
126 |
127 | assert not any(
128 | l in labels_old for l in labels
129 | ), "labels and labels_old must be disjoint sets"
130 |
131 | self.labels = [0] + labels
132 | self.labels_old = [0] + labels_old
133 |
134 | self.order = [0] + labels_old + labels
135 |
136 | if idxs_path is not None and os.path.exists(idxs_path):
137 | idxs = np.load(idxs_path).tolist()
138 |
139 | else:
140 | idxs = filter_images(full_monusac, labels, labels_old,overlap=overlap)
141 | if idxs_path is not None and distributed.get_rank() == 0:
142 | np.save(idxs_path,np.array(idxs,dtype=int))
143 |
144 | if test_on_val:
145 | rnd = np.random.RandomState(1)
146 | rnd.shuffle(idxs)
147 | train_len = int(0.8 * len(idxs))
148 | if train:
149 | idxs = idxs[:train_len]
150 | else:
151 | idxs = idxs[train_len:]
152 |
153 | masking_value = 0 # Future classes will be considered as background.
154 | self.inverted_order = {label: self.order.index(label) for label in self.order}
155 | self.inverted_order[255] = 255
156 | reorder_transform = tv.transforms.Lambda(
157 | lambda t: t.apply_(
158 | lambda x: self.inverted_order[x] if x in self.inverted_order else masking_value
159 | )
160 | )
161 |
162 | if masking:
163 | if data_masking == 'current':
164 | tmp_labels = self.labels + [255]
165 | elif data_masking == 'current+old':
166 | tmp_labels = labels_old + self.labels + [255]
167 | elif data_masking == 'all':
168 | raise NotImplementedError(
169 | f"data_masking={data_masking} is not yet implemented."
170 | )
171 |
172 | elif data_masking == "new":
173 | tmp_labels = self.labels
174 | masking_value = 255
175 |
176 | target_transform = tv.transforms.Lambda(
177 | lambda t: t.
178 | apply_(lambda x: self.inverted_order[x] if x in tmp_labels else masking_value)
179 | )
180 | else:
181 | assert False
182 | target_transform = reorder_transform
183 |
184 | # make the subset of the dataset
185 | self.dataset = Subset(full_monusac, idxs, transform, target_transform)
186 | else:
187 | self.dataset = full_monusac
188 |
189 |
190 |
191 | @staticmethod
192 | def __strip_zero(labels):
193 | while 0 in labels:
194 | labels.remove(0)
195 |
196 |
197 | def __len__(self):
198 | return len(self.dataset)
199 |
200 | def __getitem__(self, index):
201 | """
202 | Args:
203 | index (int): Index
204 | Returns:
205 | tuple: (image, target) where target is the image segmentation.
206 | """
207 |
208 | return self.dataset[index]
--------------------------------------------------------------------------------
/dataset/transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | import random
4 | import numbers
5 | import numpy as np
6 | import collections.abc
7 | from PIL import Image
8 | import warnings
9 | import math
10 |
11 | _pil_interpolation_to_str = {
12 | Image.NEAREST: 'PIL.Image.NEAREST',
13 | Image.BILINEAR: 'PIL.Image.BILINEAR',
14 | Image.BICUBIC: 'PIL.Image.BICUBIC',
15 | Image.LANCZOS: 'PIL.Image.LANCZOS',
16 | Image.HAMMING: 'PIL.Image.HAMMING',
17 | Image.BOX: 'PIL.Image.BOX',
18 | }
19 |
20 |
21 | class Compose(object):
22 | """Composes several transforms together.
23 |
24 | Args:
25 | transforms (list of ``Transform`` objects): list of transforms to compose.
26 |
27 | """
28 |
29 | def __init__(self, transforms):
30 | self.transforms = transforms
31 |
32 | def __call__(self, img, lbl=None):
33 | if lbl is not None:
34 | for t in self.transforms:
35 | img, lbl = t(img, lbl)
36 | return img, lbl
37 | else:
38 | for t in self.transforms:
39 | img = t(img)
40 | return img
41 |
42 | def __repr__(self):
43 | format_string = self.__class__.__name__ + '('
44 | for t in self.transforms:
45 | format_string += '\n'
46 | format_string += ' {0}'.format(t)
47 | format_string += '\n)'
48 | return format_string
49 |
50 |
51 | class Resize(object):
52 | """Resize the input PIL Image to the given size.
53 |
54 | Args:
55 | size (sequence or int): Desired output size. If size is a sequence like
56 | (h, w), output size will be matched to this. If size is an int,
57 | smaller edge of the image will be matched to this number.
58 | i.e, if height > width, then image will be rescaled to
59 | (size * height / width, size)
60 | interpolation (int, optional): Desired interpolation. Default is
61 | ``PIL.Image.BILINEAR``
62 | """
63 |
64 | def __init__(self, size, interpolation=Image.BILINEAR):
65 | assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
66 | self.size = size
67 | self.interpolation = interpolation
68 |
69 | def __call__(self, img, lbl=None):
70 | """
71 | Args:
72 | img (PIL Image): Image to be scaled.
73 |
74 | Returns:
75 | PIL Image: Rescaled image.
76 | """
77 | if lbl is not None:
78 | return F.resize(img, self.size, self.interpolation), F.resize(lbl, self.size, Image.NEAREST)
79 | else:
80 | return F.resize(img, self.size, self.interpolation)
81 |
82 | def __repr__(self):
83 | interpolate_str = _pil_interpolation_to_str[self.interpolation]
84 | return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
85 |
86 |
87 | class CenterCrop(object):
88 | """Crops the given PIL Image at the center.
89 |
90 | Args:
91 | size (sequence or int): Desired output size of the crop. If size is an
92 | int instead of sequence like (h, w), a square crop (size, size) is
93 | made.
94 | """
95 |
96 | def __init__(self, size):
97 | if isinstance(size, numbers.Number):
98 | self.size = (int(size), int(size))
99 | else:
100 | self.size = size
101 |
102 | def __call__(self, img, lbl=None):
103 | """
104 | Args:
105 | img (PIL Image): Image to be cropped.
106 |
107 | Returns:
108 | PIL Image: Cropped image.
109 | """
110 | if lbl is not None:
111 | return F.center_crop(img, self.size), F.center_crop(lbl, self.size)
112 | else:
113 | return F.center_crop(img, self.size)
114 |
115 | def __repr__(self):
116 | return self.__class__.__name__ + f'(size={self.size})'
117 |
118 |
119 | class Pad(object):
120 | """Pad the given PIL Image on all sides with the given "pad" value.
121 | Args:
122 | padding (int or tuple): Padding on each border. If a single int is provided this
123 | is used to pad all borders. If tuple of length 2 is provided this is the padding
124 | on left/right and top/bottom respectively. If a tuple of length 4 is provided
125 | this is the padding for the left, top, right and bottom borders
126 | respectively.
127 | fill (int): Pixel fill value for constant fill. Default is 0.
128 | This value is only used when the padding_mode is constant
129 | padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
130 | Default is constant.
131 | - constant: pads with a constant value, this value is specified with fill
132 | - edge: pads with the last value at the edge of the image
133 | - reflect: pads with reflection of image without repeating the last value on the edge
134 | For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
135 | will result in [3, 2, 1, 2, 3, 4, 3, 2]
136 | - symmetric: pads with reflection of image repeating the last value on the edge
137 | For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
138 | will result in [2, 1, 1, 2, 3, 4, 4, 3]
139 | """
140 |
141 | def __init__(self, padding, fill=0, padding_mode='constant'):
142 | assert isinstance(padding, (numbers.Number, tuple))
143 | assert isinstance(fill, (numbers.Number, str))
144 | assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
145 | if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]:
146 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
147 | "{} element tuple".format(len(padding)))
148 |
149 | self.padding = padding
150 | self.fill = fill
151 | self.padding_mode = padding_mode
152 |
153 | def __call__(self, img, lbl=None):
154 | """
155 | Args:
156 | img (PIL Image): Image to be padded.
157 | Returns:
158 | PIL Image: Padded image.
159 | """
160 | if lbl is not None:
161 | return F.pad(img, self.padding, self.fill, self.padding_mode), F.pad(lbl, self.padding, self.fill, self.padding_mode)
162 | else:
163 | return F.pad(img, self.padding, self.fill, self.padding_mode)
164 |
165 | def __repr__(self):
166 | return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\
167 | format(self.padding, self.fill, self.padding_mode)
168 |
169 |
170 | class Lambda(object):
171 | """Apply a user-defined lambda as a transform.
172 | Args:
173 | lambd (function): Lambda/function to be used for transform.
174 | """
175 |
176 | def __init__(self, lambd):
177 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
178 | self.lambd = lambd
179 |
180 | def __call__(self, img, lbl=None):
181 | if lbl is not None:
182 | return self.lambd(img), self.lambd(lbl)
183 | else:
184 | return self.lambd(img)
185 |
186 | def __repr__(self):
187 | return self.__class__.__name__ + '()'
188 |
189 |
190 | class RandomRotation(object):
191 | """Rotate the image by angle.
192 |
193 | Args:
194 | degrees (sequence or float or int): Range of degrees to select from.
195 | If degrees is a number instead of sequence like (min, max), the range of degrees
196 | will be (-degrees, +degrees).
197 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
198 | An optional resampling filter.
199 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
200 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
201 | expand (bool, optional): Optional expansion flag.
202 | If true, expands the output to make it large enough to hold the entire rotated image.
203 | If false or omitted, make the output image the same size as the input image.
204 | Note that the expand flag assumes rotation around the center and no translation.
205 | center (2-tuple, optional): Optional center of rotation.
206 | Origin is the upper left corner.
207 | Default is the center of the image.
208 | """
209 |
210 | def __init__(self, degrees, resample=False, expand=False, center=None):
211 | if isinstance(degrees, numbers.Number):
212 | if degrees < 0:
213 | raise ValueError("If degrees is a single number, it must be positive.")
214 | self.degrees = (-degrees, degrees)
215 | else:
216 | if len(degrees) != 2:
217 | raise ValueError("If degrees is a sequence, it must be of len 2.")
218 | self.degrees = degrees
219 |
220 | self.resample = resample
221 | self.expand = expand
222 | self.center = center
223 |
224 | @staticmethod
225 | def get_params(degrees):
226 | """Get parameters for ``rotate`` for a random rotation.
227 |
228 | Returns:
229 | sequence: params to be passed to ``rotate`` for random rotation.
230 | """
231 | angle = random.uniform(degrees[0], degrees[1])
232 |
233 | return angle
234 |
235 | def __call__(self, img, lbl):
236 | """
237 | img (PIL Image): Image to be rotated.
238 | lbl (PIL Image): Label to be rotated.
239 |
240 | Returns:
241 | PIL Image: Rotated image.
242 | PIL Image: Rotated label.
243 | """
244 |
245 | angle = self.get_params(self.degrees)
246 | if lbl is not None:
247 | return F.rotate(img, angle, self.resample, self.expand, self.center), \
248 | F.rotate(lbl, angle, self.resample, self.expand, self.center)
249 | else:
250 | return F.rotate(img, angle, self.resample, self.expand, self.center)
251 |
252 | def __repr__(self):
253 | format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
254 | format_string += ', resample={0}'.format(self.resample)
255 | format_string += ', expand={0}'.format(self.expand)
256 | if self.center is not None:
257 | format_string += ', center={0}'.format(self.center)
258 | format_string += ')'
259 | return format_string
260 |
261 |
262 | class RandomHorizontalFlip(object):
263 | """Horizontally flip the given PIL Image randomly with a given probability.
264 |
265 | Args:
266 | p (float): probability of the image being flipped. Default value is 0.5
267 | """
268 |
269 | def __init__(self, p=0.5):
270 | self.p = p
271 |
272 | def __call__(self, img, lbl=None):
273 | """
274 | Args:
275 | img (PIL Image): Image to be flipped.
276 |
277 | Returns:
278 | PIL Image: Randomly flipped image.
279 | """
280 | if random.random() < self.p:
281 | if lbl is not None:
282 | return F.hflip(img), F.hflip(lbl)
283 | else:
284 | return F.hflip(img)
285 | if lbl is not None:
286 | return img, lbl
287 | else:
288 | return img
289 |
290 | def __repr__(self):
291 | return self.__class__.__name__ + '(p={})'.format(self.p)
292 |
293 |
294 | class RandomVerticalFlip(object):
295 | """Vertically flip the given PIL Image randomly with a given probability.
296 |
297 | Args:
298 | p (float): probability of the image being flipped. Default value is 0.5
299 | """
300 |
301 | def __init__(self, p=0.5):
302 | self.p = p
303 |
304 | def __call__(self, img, lbl):
305 | """
306 | Args:
307 | img (PIL Image): Image to be flipped.
308 | lbl (PIL Image): Label to be flipped.
309 |
310 | Returns:
311 | PIL Image: Randomly flipped image.
312 | PIL Image: Randomly flipped label.
313 | """
314 | if random.random() < self.p:
315 | if lbl is not None:
316 | return F.vflip(img), F.vflip(lbl)
317 | else:
318 | return F.vflip(img)
319 | if lbl is not None:
320 | return img, lbl
321 | else:
322 | return img
323 |
324 | def __repr__(self):
325 | return self.__class__.__name__ + '(p={})'.format(self.p)
326 |
327 |
328 | class ToTensor(object):
329 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
330 |
331 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range
332 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
333 | if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
334 | or if the numpy.ndarray has dtype = np.uint8
335 | In the other cases, tensors are returned without scaling.
336 |
337 | """
338 |
339 | def __call__(self, pic, lbl=None):
340 | """
341 | Note that labels will not be normalized to [0, 1].
342 |
343 | Args:
344 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
345 | lbl (PIL Image or numpy.ndarray): Label to be converted to tensor.
346 | Returns:
347 | Tensor: Converted image and label
348 | """
349 | if lbl is not None:
350 | return F.to_tensor(pic), torch.from_numpy(np.array(lbl, dtype=np.uint8))
351 | else:
352 | return F.to_tensor(pic)
353 |
354 | def __repr__(self):
355 | return self.__class__.__name__ + '()'
356 |
357 |
358 | class Normalize(object):
359 | """Normalize a tensor image with mean and standard deviation.
360 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
361 | will normalize each channel of the input ``torch.*Tensor`` i.e.
362 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
363 |
364 | Args:
365 | mean (sequence): Sequence of means for each channel.
366 | std (sequence): Sequence of standard deviations for each channel.
367 | """
368 |
369 | def __init__(self, mean, std):
370 | self.mean = mean
371 | self.std = std
372 |
373 | def __call__(self, tensor, lbl=None):
374 | """
375 | Args:
376 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
377 | lbl (Tensor): Tensor of label; passed through unchanged (dummy input for Compose)
378 | Returns:
379 | Tensor: Normalized Tensor image.
380 | Tensor: Unchanged Tensor label
381 | """
382 | if lbl is not None:
383 | return F.normalize(tensor, self.mean, self.std), lbl
384 | else:
385 | return F.normalize(tensor, self.mean, self.std)
386 |
387 | def __repr__(self):
388 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
389 |
390 |
391 | class RandomCrop(object):
392 | """Crop the given PIL Image at a random location.
393 |
394 | Args:
395 | size (sequence or int): Desired output size of the crop. If size is an
396 | int instead of sequence like (h, w), a square crop (size, size) is
397 | made.
398 | padding (int or sequence, optional): Optional padding on each border
399 | of the image. Default is 0, i.e no padding. If a sequence of length
400 | 4 is provided, it is used to pad left, top, right, bottom borders
401 | respectively.
402 | pad_if_needed (boolean): It will pad the image if smaller than the
403 | desired size to avoid raising an exception.
404 | """
405 |
406 | def __init__(self, size, padding=0, pad_if_needed=False):
407 | if isinstance(size, numbers.Number):
408 | self.size = (int(size), int(size))
409 | else:
410 | self.size = size
411 | self.padding = padding
412 | self.pad_if_needed = pad_if_needed
413 |
414 | @staticmethod
415 | def get_params(img, output_size):
416 | """Get parameters for ``crop`` for a random crop.
417 |
418 | Args:
419 | img (PIL Image): Image to be cropped.
420 | output_size (tuple): Expected output size of the crop.
421 |
422 | Returns:
423 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
424 | """
425 | w, h = img.size
426 | th, tw = output_size
427 | if w == tw and h == th:
428 | return 0, 0, h, w
429 |
430 | i = random.randint(0, h - th)
431 | j = random.randint(0, w - tw)
432 | return i, j, th, tw
433 |
434 | def __call__(self, img, lbl=None):
435 | """
436 | Args:
437 | img (PIL Image): Image to be cropped.
438 | lbl (PIL Image): Label to be cropped.
439 | Returns:
440 | PIL Image: Cropped image.
441 | PIL Image: Cropped label.
442 | """
443 | if lbl is None:
444 | if self.padding > 0:
445 | img = F.pad(img, self.padding)
446 | # pad the width if needed
447 | if self.pad_if_needed and img.size[0] < self.size[1]:
448 | img = F.pad(img, padding=int((1 + self.size[1] - img.size[0]) / 2))
449 | # pad the height if needed
450 | if self.pad_if_needed and img.size[1] < self.size[0]:
451 | img = F.pad(img, padding=int((1 + self.size[0] - img.size[1]) / 2))
452 |
453 | i, j, h, w = self.get_params(img, self.size)
454 |
455 | return F.crop(img, i, j, h, w)
456 |
457 | else:
458 | assert img.size == lbl.size, 'size of img and lbl should be the same. %s, %s' % (img.size, lbl.size)
459 | if self.padding > 0:
460 | img = F.pad(img, self.padding)
461 | lbl = F.pad(lbl, self.padding)
462 |
463 | # pad the width if needed
464 | if self.pad_if_needed and img.size[0] < self.size[1]:
465 | img = F.pad(img, padding=int((1 + self.size[1] - img.size[0]) / 2))
466 | lbl = F.pad(lbl, padding=int((1 + self.size[1] - lbl.size[0]) / 2))
467 |
468 | # pad the height if needed
469 | if self.pad_if_needed and img.size[1] < self.size[0]:
470 | img = F.pad(img, padding=int((1 + self.size[0] - img.size[1]) / 2))
471 | lbl = F.pad(lbl, padding=int((1 + self.size[0] - lbl.size[1]) / 2))
472 |
473 | i, j, h, w = self.get_params(img, self.size)
474 |
475 | return F.crop(img, i, j, h, w), F.crop(lbl, i, j, h, w)
476 |
477 | def __repr__(self):
478 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
479 |
480 |
481 | class RandomResizedCrop(object):
482 | """Crop the given PIL Image to random size and aspect ratio.
483 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random
484 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
485 | is finally resized to given size.
486 | This is popularly used to train the Inception networks.
487 | Args:
488 | size: expected output size of each edge
489 | scale: range of size of the origin size cropped
490 | ratio: range of aspect ratio of the origin aspect ratio cropped
491 | interpolation: Default: PIL.Image.BILINEAR
492 | """
493 |
494 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
495 | if isinstance(size, tuple):
496 | self.size = size
497 | else:
498 | self.size = (size, size)
499 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
500 | warnings.warn("range should be of kind (min, max)")
501 |
502 | self.interpolation = interpolation
503 | self.scale = scale
504 | self.ratio = ratio
505 |
506 | @staticmethod
507 | def get_params(img, scale, ratio):
508 | """Get parameters for ``crop`` for a random sized crop.
509 | Args:
510 | img (PIL Image): Image to be cropped.
511 | scale (tuple): range of size of the origin size cropped
512 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
513 | Returns:
514 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random
515 | sized crop.
516 | """
517 | area = img.size[0] * img.size[1]
518 |
519 | for attempt in range(10):
520 | target_area = random.uniform(*scale) * area
521 | log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
522 | aspect_ratio = math.exp(random.uniform(*log_ratio))
523 |
524 | w = int(round(math.sqrt(target_area * aspect_ratio)))
525 | h = int(round(math.sqrt(target_area / aspect_ratio)))
526 |
527 | if w <= img.size[0] and h <= img.size[1]:
528 | i = random.randint(0, img.size[1] - h)
529 | j = random.randint(0, img.size[0] - w)
530 | return i, j, h, w
531 |
532 | # Fallback to central crop
533 | in_ratio = img.size[0] / img.size[1]
534 | if (in_ratio < min(ratio)):
535 | w = img.size[0]
536 | h = int(round(w / min(ratio)))
537 | elif (in_ratio > max(ratio)):
538 | h = img.size[1]
539 | w = int(round(h * max(ratio)))
540 | else: # whole image
541 | w = img.size[0]
542 | h = img.size[1]
543 | i = (img.size[1] - h) // 2
544 | j = (img.size[0] - w) // 2
545 | return i, j, h, w
546 |
547 | def __call__(self, img, lbl=None):
548 | """
549 | Args:
550 | img (PIL Image): Image to be cropped and resized.
551 | Returns:
552 | PIL Image: Randomly cropped and resized image.
553 | """
554 | i, j, h, w = self.get_params(img, self.scale, self.ratio)
555 | if lbl is not None:
556 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), \
557 | F.resized_crop(lbl, i, j, h, w, self.size, Image.NEAREST)
558 | else:
559 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
560 |
561 | def __repr__(self):
562 | interpolate_str = _pil_interpolation_to_str[self.interpolation]
563 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
564 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
565 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
566 | format_string += ', interpolation={0})'.format(interpolate_str)
567 | return format_string
568 |
569 |
570 | class ColorJitter(object):
571 | """Randomly change the brightness, contrast and saturation of an image.
572 | Args:
573 | brightness (float or tuple of float (min, max)): How much to jitter brightness.
574 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
575 | or the given [min, max]. Should be non negative numbers.
576 | contrast (float or tuple of float (min, max)): How much to jitter contrast.
577 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
578 | or the given [min, max]. Should be non negative numbers.
579 | saturation (float or tuple of float (min, max)): How much to jitter saturation.
580 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
581 | or the given [min, max]. Should be non negative numbers.
582 | hue (float or tuple of float (min, max)): How much to jitter hue.
583 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
584 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
585 | """
586 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
587 | self.brightness = self._check_input(brightness, 'brightness')
588 | self.contrast = self._check_input(contrast, 'contrast')
589 | self.saturation = self._check_input(saturation, 'saturation')
590 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
591 | clip_first_on_zero=False)
592 |
593 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
594 | if isinstance(value, numbers.Number):
595 | if value < 0:
596 | raise ValueError("If {} is a single number, it must be non negative.".format(name))
597 | value = [center - value, center + value]
598 | if clip_first_on_zero:
599 | value[0] = max(value[0], 0)
600 | elif isinstance(value, (tuple, list)) and len(value) == 2:
601 | if not bound[0] <= value[0] <= value[1] <= bound[1]:
602 | raise ValueError("{} values should be between {}".format(name, bound))
603 | else:
604 | raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
605 |
606 | # if value is 0 or (1., 1.) for brightness/contrast/saturation
607 | # or (0., 0.) for hue, do nothing
608 | if value[0] == value[1] == center:
609 | value = None
610 | return value
611 |
612 | @staticmethod
613 | def get_params(brightness, contrast, saturation, hue):
614 | """Get a randomized transform to be applied on image.
615 | Arguments are same as that of __init__.
616 | Returns:
617 | Transform which randomly adjusts brightness, contrast and
618 | saturation in a random order.
619 | """
620 | transforms = []
621 |
622 | if brightness is not None:
623 | brightness_factor = random.uniform(brightness[0], brightness[1])
624 | transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
625 |
626 | if contrast is not None:
627 | contrast_factor = random.uniform(contrast[0], contrast[1])
628 | transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
629 |
630 | if saturation is not None:
631 | saturation_factor = random.uniform(saturation[0], saturation[1])
632 | transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
633 |
634 | if hue is not None:
635 | hue_factor = random.uniform(hue[0], hue[1])
636 | transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
637 |
638 | random.shuffle(transforms)
639 | transform = Compose(transforms)
640 |
641 | return transform
642 |
643 | def __call__(self, img, lbl=None):
644 | """
645 | Args:
646 | img (PIL Image): Input image.
647 | Returns:
648 | PIL Image: Color jittered image.
649 | """
650 | transform = self.get_params(self.brightness, self.contrast,
651 | self.saturation, self.hue)
652 | if lbl is not None:
653 | return transform(img), lbl
654 | else:
655 | return transform(img)
656 |
657 | def __repr__(self):
658 | format_string = self.__class__.__name__ + '('
659 | format_string += 'brightness={0}'.format(self.brightness)
660 | format_string += ', contrast={0}'.format(self.contrast)
661 | format_string += ', saturation={0}'.format(self.saturation)
662 | format_string += ', hue={0})'.format(self.hue)
663 | return format_string
--------------------------------------------------------------------------------
/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def group_images(dataset, labels):
6 | # Group images based on the label in LABELS (using labels not reordered)
7 | idxs = {lab: [] for lab in labels}
8 |
9 | labels_cum = labels + [0, 255]
10 | for i in range(len(dataset)):
11 | cls = np.unique(np.array(dataset[i][1]))
12 | if all(x in labels_cum for x in cls):
13 | for x in cls:
14 | if x in labels:
15 | idxs[x].append(i)
16 | return idxs
17 |
18 |
19 | def filter_images(dataset, labels, labels_old=None, overlap=True):
20 | # Filter images without any label in LABELS (using labels not reordered)
21 | idxs = []
22 |
23 | if 0 in labels:
24 | labels.remove(0)
25 |
26 | print(f"Filtering images...")
27 | if labels_old is None:
28 | labels_old = []
29 | labels_cum = labels + labels_old + [0, 255]
30 |
31 | if overlap:
32 | fil = lambda c: any(x in labels for x in c)
33 | else:
34 | fil = lambda c: any(x in labels for x in c) and all(x in labels_cum for x in c)
35 |
36 | for i in range(len(dataset)):
37 | cls = np.unique(np.array(dataset[i][1]))
38 | if fil(cls):
39 | idxs.append(i)
40 | if i % 1000 == 0:
41 | print(f"\t{i}/{len(dataset)} ...")
42 | return idxs
43 |
44 |
45 | class Subset(torch.utils.data.Dataset):
46 | """
47 | Subset of a dataset at specified indices.
48 | Arguments:
49 | dataset (Dataset): The whole Dataset
50 | indices (sequence): Indices in the whole set selected for subset
51 | transform (callable): way to transform the images and the targets
52 | target_transform(callable): way to transform the target labels
53 | """
54 |
55 | def __init__(self, dataset, indices, transform=None, target_transform=None):
56 | self.dataset = dataset
57 | self.indices = indices
58 | self.transform = transform
59 | self.target_transform = target_transform
60 |
61 | def __getitem__(self, idx):
62 | try:
63 |             sample, target, name = self.dataset[self.indices[idx]]
64 | except Exception as e:
65 | raise Exception(
66 | f"dataset = {len(self.dataset)}, indices = {len(self.indices)}, idx = {idx}, msg = {str(e)}"
67 | )
68 |
69 | if self.transform is not None:
70 | sample, target = self.transform(sample, target)
71 |
72 | if self.target_transform is not None:
73 | target = self.target_transform(target)
74 |
75 |         return sample, target, name
76 |
77 | def viz_getter(self, idx):
78 | image_path, raw_image, sample, target = self.dataset.viz_getter(self.indices[idx])
79 | if self.transform is not None:
80 | sample, target = self.transform(sample, target)
81 | if self.target_transform is not None:
82 | target = self.target_transform(target)
83 |
84 | return image_path, raw_image, sample, target
85 |
86 | def __len__(self):
87 | return len(self.indices)
88 |
89 |
90 | class MaskLabels:
91 | """
92 | Use this class to mask labels that you don't want in your dataset.
93 | Arguments:
94 | labels_to_keep (list): The list of labels to keep in the target images
95 | mask_value (int): The value to replace ignored values (def: 0)
96 | """
97 |
98 | def __init__(self, labels_to_keep, mask_value=0):
99 | self.labels = labels_to_keep
100 |         self.value = mask_value  # plain int, so it can be returned from Tensor.apply_
101 |
102 | def __call__(self, sample):
103 | # sample must be a tensor
104 | assert isinstance(sample, torch.Tensor), "Sample must be a tensor"
105 |
106 |         sample.apply_(lambda x: x if x in self.labels else self.value)  # apply_ passes each element as a Python number
107 |
108 | return sample
109 |
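
A minimal sketch of how filter_images and Subset compose for one incremental step (the toy dataset and the 15-5 label split are assumptions; any dataset returning (image, mask, name) triples works the same way):

import numpy as np
import torch.utils.data as data

class ToySegDataset(data.Dataset):
    """Hypothetical stand-in for a segmentation dataset; yields (image, mask, name) triples."""
    def __init__(self, n=10, num_classes=21):
        rng = np.random.RandomState(0)
        self.masks = [rng.randint(0, num_classes, size=(4, 4)) for _ in range(n)]
    def __getitem__(self, i):
        return np.zeros((4, 4, 3), dtype=np.uint8), self.masks[i], f"img_{i}"
    def __len__(self):
        return len(self.masks)

toy = ToySegDataset()
new_labels = [16, 17, 18, 19, 20]   # classes introduced at this step (assumed split)
old_labels = list(range(1, 16))     # classes from previous steps

idxs = filter_images(toy, new_labels, old_labels, overlap=True)   # keep images containing a new class
step_set = Subset(toy, idxs)                                      # lazily re-indexes the base dataset
print(f"{len(step_set)} of {len(toy)} images kept for this step")
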
--------------------------------------------------------------------------------
/dataset/voc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import copy
4 |
5 | import numpy as np
6 | import torch.utils.data as data
7 | import torchvision as tv
8 | from PIL import Image
9 | from torch import distributed
10 |
11 | from .utils import Subset, filter_images, group_images
12 |
13 | classes = {
14 | 0: 'background',
15 | 1: 'aeroplane',
16 | 2: 'bicycle',
17 | 3: 'bird',
18 | 4: 'boat',
19 | 5: 'bottle',
20 | 6: 'bus',
21 | 7: 'car',
22 | 8: 'cat',
23 | 9: 'chair',
24 | 10: 'cow',
25 | 11: 'diningtable',
26 | 12: 'dog',
27 | 13: 'horse',
28 | 14: 'motorbike',
29 | 15: 'person',
30 | 16: 'pottedplant',
31 | 17: 'sheep',
32 | 18: 'sofa',
33 | 19: 'train',
34 | 20: 'tvmonitor'
35 | }
36 |
37 |
38 | class VOCSegmentation(data.Dataset):
39 |     """Pascal VOC Segmentation Dataset.
40 | Args:
41 | root (string): Root directory of the VOC Dataset.
42 | image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
43 | is_aug (bool, optional): If you want to use the augmented train set or not (default is True)
44 | transform (callable, optional): A function/transform that takes in an PIL image
45 | and returns a transformed version. E.g, ``transforms.RandomCrop``
46 | """
47 |
48 | def __init__(self, root, image_set='train', is_aug=True, transform=None):
49 |
50 | self.root = os.path.expanduser(root)
51 | self.year = "2012"
52 |
53 | self.transform = transform
54 |
55 | self.image_set = image_set
56 | voc_root = self.root
57 | splits_dir = os.path.join(voc_root, 'list')
58 |
59 | if not os.path.isdir(voc_root):
60 |             raise RuntimeError(
61 |                 'Dataset not found or corrupted '
62 |                 f'at location = {voc_root}'
63 |             )
64 |
65 | if is_aug and image_set == 'train':
66 | mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
67 | assert os.path.exists(mask_dir), "SegmentationClassAug not found"
68 | split_f = os.path.join(splits_dir, 'train_aug.txt')
69 | else:
70 | split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
71 |
72 | if not os.path.exists(split_f):
73 | raise ValueError(
74 | 'Wrong image_set entered! Please use image_set="train" '
75 | 'or image_set="trainval" or image_set="val" '
76 | f'{split_f}'
77 | )
78 |
79 |         # strip the trailing newline from each line
80 | with open(os.path.join(split_f), "r") as f:
81 | file_names = [x[:-1].split(' ') for x in f.readlines()]
82 |
83 |         # strip the leading slash, otherwise os.path.join would restart from the filesystem root
84 | self.images = [
85 | (
86 | os.path.join(voc_root, "VOCdevkit/VOC2012",
87 | x[0][1:]), os.path.join(voc_root, x[1][1:])
88 | ) for x in file_names
89 | ]
90 | # print(file_names)
91 | # self.sal_maps = [os.path.join(voc_root, "saliency_map", x + "")]
92 |
93 | def __getitem__(self, index):
94 | """
95 | Args:
96 | index (int): Index
97 | Returns:
98 | tuple: (image, target) where target is the image segmentation.
99 | """
100 | img = Image.open(self.images[index][0]).convert('RGB')
101 | target = Image.open(self.images[index][1])
102 | if self.transform is not None:
103 | img, target = self.transform(img, target)
104 |
105 | return img, target
106 |
107 | def viz_getter(self, index):
108 | image_path = self.images[index][0]
109 | raw_image = Image.open(self.images[index][0]).convert('RGB')
110 | target = Image.open(self.images[index][1])
111 | if self.transform is not None:
112 | img, target = self.transform(raw_image, target)
113 | else:
114 | img = copy.deepcopy(raw_image)
115 | return image_path, raw_image, img, target
116 |
117 | def __len__(self):
118 | return len(self.images)
119 |
120 |
121 | class VOCSegmentationIncremental(data.Dataset):
122 |
123 | def __init__(
124 | self,
125 | root,
126 | train=True,
127 | transform=None,
128 | labels=None,
129 | labels_old=None,
130 | idxs_path=None,
131 | masking=True,
132 | overlap=True,
133 | data_masking="current",
134 | test_on_val=False,
135 | **kwargs
136 | ):
137 |
138 | full_voc = VOCSegmentation(root, 'train' if train else 'val', is_aug=True, transform=None)
139 |
140 | self.labels = []
141 | self.labels_old = []
142 |
143 | if labels is not None:
144 | # store the labels
145 | labels_old = labels_old if labels_old is not None else []
146 |
147 | self.__strip_zero(labels)
148 | self.__strip_zero(labels_old)
149 |
150 | assert not any(
151 | l in labels_old for l in labels
152 | ), "labels and labels_old must be disjoint sets"
153 |
154 | self.labels = [0] + labels
155 | self.labels_old = [0] + labels_old
156 |
157 | self.order = [0] + labels_old + labels
158 |
159 | # take index of images with at least one class in labels and all classes in labels+labels_old+[0,255]
160 | if idxs_path is not None and os.path.exists(idxs_path):
161 | idxs = np.load(idxs_path).tolist()
162 | else:
163 | idxs = filter_images(full_voc, labels, labels_old, overlap=overlap)
164 | if idxs_path is not None and distributed.get_rank() == 0:
165 | np.save(idxs_path, np.array(idxs, dtype=int))
166 |
167 | if test_on_val:
168 | rnd = np.random.RandomState(1)
169 | rnd.shuffle(idxs)
170 | train_len = int(0.8 * len(idxs))
171 | if train:
172 | idxs = idxs[:train_len]
173 | else:
174 | idxs = idxs[train_len:]
175 |
176 | #if train:
177 | # masking_value = 0
178 | #else:
179 | # masking_value = 255
180 |
181 | #self.inverted_order = {label: self.order.index(label) for label in self.order}
182 | #self.inverted_order[255] = masking_value
183 |
184 | masking_value = 0 # Future classes will be considered as background.
185 | self.inverted_order = {label: self.order.index(label) for label in self.order}
186 | self.inverted_order[255] = 255
187 |
188 | reorder_transform = tv.transforms.Lambda(
189 | lambda t: t.apply_(
190 | lambda x: self.inverted_order[x] if x in self.inverted_order else masking_value
191 | )
192 | )
193 |
194 | if masking:
195 | if data_masking == "current":
196 | tmp_labels = self.labels + [255]
197 | elif data_masking == "current+old":
198 | tmp_labels = labels_old + self.labels + [255]
199 | elif data_masking == "all":
200 |                     raise NotImplementedError(
201 |                         f"data_masking={data_masking} is not implemented."
202 |                     )
203 | elif data_masking == "new":
204 | tmp_labels = self.labels
205 | masking_value = 255
206 |
207 | target_transform = tv.transforms.Lambda(
208 | lambda t: t.
209 | apply_(lambda x: self.inverted_order[x] if x in tmp_labels else masking_value)
210 | )
211 | else:
212 | assert False
213 | target_transform = reorder_transform
214 |
215 | # make the subset of the dataset
216 | self.dataset = Subset(full_voc, idxs, transform, target_transform)
217 | else:
218 | self.dataset = full_voc
219 |
220 | def __getitem__(self, index):
221 | """
222 | Args:
223 | index (int): Index
224 | Returns:
225 | tuple: (image, target) where target is the image segmentation.
226 | """
227 |
228 | return self.dataset[index]
229 |
230 | def viz_getter(self, index):
231 | return self.dataset.viz_getter(index)
232 |
233 | def __len__(self):
234 | return len(self.dataset)
235 |
236 | @staticmethod
237 | def __strip_zero(labels):
238 | while 0 in labels:
239 | labels.remove(0)
240 |
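
A construction sketch for the 15-5 overlapped setting (the dataset root and split are assumptions, and the augmented VOC file lists must already be on disk, so this is illustrative rather than self-contained):

old_labels = list(range(1, 16))     # classes 1-15 learned in the previous step
new_labels = [16, 17, 18, 19, 20]   # classes added in the current step

step1_train = VOCSegmentationIncremental(
    root="data/PascalVOC12",        # assumed root containing 'list/' and 'SegmentationClassAug/'
    train=True,
    labels=new_labels,
    labels_old=old_labels,
    overlap=True,                   # overlapped protocol: future classes may appear as background
    masking=True,
    data_masking="current",         # only current-step classes (and 255) keep their ids
)
print(len(step1_train), "training images for step 1")
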
--------------------------------------------------------------------------------
/docs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/why19991/InSeg/d45ebeb795776d60373c4031b163e2efd693000b/docs/framework.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from collections import OrderedDict
3 | from functools import partial
4 |
5 | import torch.nn as nn
6 |
7 | from modules import GlobalAvgPool2d, ResidualBlock
8 |
9 | from .util import try_index
10 |
11 |
12 | class ResNet(nn.Module):
13 | """Standard residual network
14 |
15 | Parameters
16 | ----------
17 | structure : list of int
18 | Number of residual blocks in each of the four modules of the network
19 | bottleneck : bool
20 | If `True` use "bottleneck" residual blocks with 3 convolutions, otherwise use standard blocks
21 | norm_act : callable
22 | Function to create normalization / activation Module
23 | classes : int
24 | If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end
25 | of the network
26 |     output_stride : int
27 |         Output stride of the network, either 8 or 16; this selects the dilation pattern used in the later modules
28 | keep_outputs : bool
29 | If `True` output a list with the outputs of all modules
30 | """
31 |
32 | def __init__(
33 | self,
34 | structure,
35 | bottleneck,
36 | norm_act=nn.BatchNorm2d,
37 | classes=0,
38 | output_stride=16,
39 | keep_outputs=False
40 | ):
41 | super(ResNet, self).__init__()
42 | self.structure = structure
43 | self.bottleneck = bottleneck
44 | self.keep_outputs = keep_outputs
45 |         self.keep_outputs = True  # NOTE: overrides the argument so forward always returns all stage outputs
46 |
47 | if len(structure) != 4:
48 | raise ValueError("Expected a structure with four values")
49 | if output_stride != 8 and output_stride != 16:
50 | raise ValueError("Output stride must be 8 or 16")
51 |
52 | if output_stride == 16:
53 | dilation = [1, 1, 1, 2] # dilated conv for last 3 blocks (9 layers)
54 | elif output_stride == 8:
55 | dilation = [1, 1, 2, 4] # 23+3 blocks (78 layers)
56 | else:
57 | raise NotImplementedError
58 |
59 | self.dilation = dilation
60 |
61 | # Initial layers
62 | layers = [
63 | ("conv1", nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)), ("bn1", norm_act(64))
64 | ]
65 | if try_index(dilation, 0) == 1:
66 | layers.append(("pool1", nn.MaxPool2d(3, stride=2, padding=1)))
67 | self.mod1 = nn.Sequential(OrderedDict(layers))
68 |
69 | # Groups of residual blocks
70 | in_channels = 64
71 | if self.bottleneck:
72 | channels = (64, 64, 256)
73 | else:
74 | channels = (64, 64)
75 | for mod_id, num in enumerate(structure):
76 | # Create blocks for module
77 | blocks = []
78 | for block_id in range(num):
79 | stride, dil = self._stride_dilation(dilation, mod_id, block_id)
80 | blocks.append(
81 | (
82 | "block%d" % (block_id + 1),
83 | ResidualBlock(
84 | in_channels,
85 | channels,
86 | norm_act=norm_act,
87 | stride=stride,
88 | dilation=dil,
89 | last=block_id == num - 1,
90 | block_id=block_id
91 | )
92 | )
93 | )
94 |
95 |             # Update the input channel count for the next module
96 | in_channels = channels[-1]
97 |
98 | # Create module
99 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))
100 |
101 | # Double the number of channels for the next module
102 | channels = [c * 2 for c in channels]
103 |
104 | self.out_channels = in_channels
105 |
106 | # Pooling and predictor
107 | if classes != 0:
108 | self.classifier = nn.Sequential(
109 | OrderedDict(
110 | [("avg_pool", GlobalAvgPool2d()), ("fc", nn.Linear(in_channels, classes))]
111 | )
112 | )
113 |
114 | @staticmethod
115 | def _stride_dilation(dilation, mod_id, block_id):
116 | d = try_index(dilation, mod_id)
117 | s = 2 if d == 1 and block_id == 0 and mod_id > 0 else 1
118 | return s, d
119 |
120 | def forward(self, x):
121 | outs = []
122 | attentions = []
123 |
124 | x = self.mod1(x)
125 | #attentions.append(x)
126 | outs.append(x)
127 |
128 | x, att = self.mod2(x)
129 | attentions.append(att)
130 | outs.append(x)
131 |
132 | x, att = self.mod3(x)
133 | attentions.append(att)
134 | outs.append(x)
135 |
136 | x, att = self.mod4(x)
137 | attentions.append(att)
138 | outs.append(x)
139 |
140 | x, att = self.mod5(x)
141 | attentions.append(att)
142 | outs.append(x)
143 |
144 | if hasattr(self, "classifier"):
145 | outs.append(self.classifier(outs[-1]))
146 |
147 | if self.keep_outputs:
148 | return outs, attentions
149 | else:
150 | return outs[-1], attentions
151 |
152 |
153 | _NETS = {
154 | "18": {
155 | "structure": [2, 2, 2, 2],
156 | "bottleneck": False
157 | },
158 | "34": {
159 | "structure": [3, 4, 6, 3],
160 | "bottleneck": False
161 | },
162 | "50": {
163 | "structure": [3, 4, 6, 3],
164 | "bottleneck": True
165 | },
166 | "101": {
167 | "structure": [3, 4, 23, 3],
168 | "bottleneck": True
169 | },
170 | "152": {
171 | "structure": [3, 8, 36, 3],
172 | "bottleneck": True
173 | },
174 | }
175 |
176 | __all__ = []
177 | for name, params in _NETS.items():
178 | net_name = "net_resnet" + name
179 | setattr(sys.modules[__name__], net_name, partial(ResNet, **params))
180 | __all__.append(net_name)
181 |
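
A construction sketch for the factories registered above. The residual blocks read `activation` / `activation_param` from their normalization layers, so `norm_act` is assumed here to be a fused norm-activation such as InPlaceABN from the inplace_abn package (an assumption about the intended dependency, not something declared in this file):

from functools import partial

import torch
from inplace_abn import InPlaceABN

from models import net_resnet101

norm_act = partial(InPlaceABN, activation="leaky_relu", activation_param=0.01)
backbone = net_resnet101(norm_act=norm_act, output_stride=16)   # classes=0: no classification head

backbone.eval()
with torch.no_grad():
    outs, attentions = backbone(torch.randn(1, 3, 512, 512))
print([tuple(o.shape) for o in outs])   # five feature maps at strides 4, 4, 8, 16, 16
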
--------------------------------------------------------------------------------
/models/util.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from functools import partial
7 |
8 |
9 | def try_index(scalar_or_list, i):
10 | try:
11 | return scalar_or_list[i]
12 | except TypeError:
13 | return scalar_or_list
14 |
15 | class Upsample(nn.Module):
16 | def __init__(self, size=None, scale_factor=None, mode='bilinear', align_corners=False):
17 | super(Upsample, self).__init__()
18 | self.align_corners = align_corners
19 | self.mode = mode
20 | self.scale_factor = scale_factor
21 | self.size = size
22 |
23 | def forward(self, x):
24 | return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
25 | align_corners=self.align_corners)
26 |
27 |
28 | class DecoderBlock(nn.Module):
29 |     def __init__(self, in_channels, n_filters, norm_act, use_transpose=True):
30 | super(DecoderBlock, self).__init__()
31 | if use_transpose:
32 | self.up_op = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 3, stride=2, padding=1, output_padding=1)
33 | else:
34 | self.up_op = Upsample(scale_factor=2, align_corners=True)
35 |
36 | self.conv1 = nn.Conv2d(in_channels, in_channels // 2, 1)
37 | self.norm1 = norm_act(in_channels//2)
38 | # self.norm1 = nn.BatchNorm2d(in_channels // 2)
39 | # self.relu1 = nn.ReLU(inplace=True)
40 |
41 | # self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1)
42 | self.norm2 = norm_act(in_channels // 2)
43 |
44 | self.conv3 = nn.Conv2d(in_channels // 2, n_filters, 1)
45 | self.norm3 = norm_act(n_filters)
46 |
47 | def forward(self, x):
48 | x = self.conv1(x)
49 | x = self.norm1(x)
50 | x = self.up_op(x)
51 | x = self.norm2(x)
52 | x = self.conv3(x)
53 | x = self.norm3(x)
54 | return x
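
A small shape check for DecoderBlock (plain BatchNorm2d is used here purely for illustration; elsewhere in the repo norm_act may also carry a fused activation):

import torch
import torch.nn as nn

block = DecoderBlock(in_channels=256, n_filters=64, norm_act=nn.BatchNorm2d)
y = block(torch.randn(1, 256, 32, 32))
print(tuple(y.shape))   # (1, 64, 64, 64): channels reduced, spatial resolution doubled
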
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .deeplab import DeeplabV3
2 | from .residual import IdentityResidualBlock, ResidualBlock
3 | from .misc import GlobalAvgPool2d
4 |
--------------------------------------------------------------------------------
/modules/deeplab.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as functional
4 |
5 | from models.util import try_index
6 |
7 |
8 | class DeeplabV3(nn.Module):
9 | def __init__(self,
10 | in_channels,
11 | out_channels,
12 | hidden_channels=256,
13 | out_stride=16,
14 | norm_act=nn.BatchNorm2d,
15 | pooling_size=None):
16 | super(DeeplabV3, self).__init__()
17 | self.pooling_size = pooling_size
18 |
19 | if out_stride == 16:
20 | dilations = [6, 12, 18]
21 | elif out_stride == 8:
22 | dilations = [12, 24, 32]
23 |
24 | self.map_convs = nn.ModuleList([
25 | nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
26 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]),
27 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]),
28 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2])
29 | ])
30 | self.map_bn = norm_act(hidden_channels * 4)
31 |
32 | self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
33 | self.global_pooling_bn = norm_act(hidden_channels)
34 |
35 | self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
36 | self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
37 | self.red_bn = norm_act(out_channels)
38 |
39 | self.reset_parameters(self.map_bn.activation, self.map_bn.activation_param)
40 |
41 | def reset_parameters(self, activation, slope):
42 | gain = nn.init.calculate_gain(activation, slope)
43 | for m in self.modules():
44 | if isinstance(m, nn.Conv2d):
45 | nn.init.xavier_normal_(m.weight.data, gain)
46 | if hasattr(m, "bias") and m.bias is not None:
47 | nn.init.constant_(m.bias, 0)
48 | elif isinstance(m, nn.BatchNorm2d):
49 | if hasattr(m, "weight") and m.weight is not None:
50 | nn.init.constant_(m.weight, 1)
51 | if hasattr(m, "bias") and m.bias is not None:
52 | nn.init.constant_(m.bias, 0)
53 |
54 | def forward(self, x):
55 | # Map convolutions
56 | out = torch.cat([m(x) for m in self.map_convs], dim=1)
57 | out = self.map_bn(out)
58 | out = self.red_conv(out)
59 |
60 | # Global pooling
61 |         pool = self._global_pooling(x)  # global average pooling (1x1) during training, a larger window otherwise
62 | pool = self.global_pooling_conv(pool)
63 | pool = self.global_pooling_bn(pool)
64 | pool = self.pool_red_conv(pool)
65 | if self.training or self.pooling_size is None:
66 | pool = pool.repeat(1, 1, x.size(2), x.size(3))
67 |
68 | out += pool
69 | out = self.red_bn(out)
70 | return out
71 |
72 | def _global_pooling(self, x):
73 | if self.training or self.pooling_size is None:
74 | # this is like Adaptive Average Pooling (1,1)
75 | pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
76 | pool = pool.view(x.size(0), x.size(1), 1, 1)
77 | else:
78 | pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]),
79 | min(try_index(self.pooling_size, 1), x.shape[3]))
80 | padding = (
81 | (pooling_size[1] - 1) // 2,
82 | (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1,
83 | (pooling_size[0] - 1) // 2,
84 | (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1
85 | )
86 |
87 | pool = functional.avg_pool2d(x, pooling_size, stride=1)
88 | pool = functional.pad(pool, pad=padding, mode="replicate")
89 | return pool
90 |
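
A shape sketch for the ASPP head above; as with the backbone, norm_act is assumed to be a fused norm-activation (e.g. InPlaceABN) because reset_parameters reads its activation attributes:

from functools import partial

import torch
from inplace_abn import InPlaceABN

head = DeeplabV3(
    in_channels=2048, out_channels=256, hidden_channels=256,
    out_stride=16,
    norm_act=partial(InPlaceABN, activation="leaky_relu", activation_param=0.01),
)
head.eval()
feats = torch.randn(1, 2048, 32, 32)   # hypothetical stride-16 backbone features
print(tuple(head(feats).shape))        # (1, 256, 32, 32)
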
--------------------------------------------------------------------------------
/modules/fca_cid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def FCA(labels=None, features=None, centroids=None, unknown_label=None, features_counter=None, strategy='running'):
6 |     b, h, w = labels.shape
7 |     labels_down = labels.unsqueeze(dim=1)
8 |     cl_present = torch.unique(input=labels_down)
9 |     for cl in cl_present:
10 |         if cl > 0 and cl != 255:
11 |             features_cl = features[(labels_down == cl).expand(-1, features.shape[1], -1, -1)].view(-1, features.shape[1])
12 |             if strategy == 'running':
13 |                 features_counter[cl] = features_counter[cl] + features_cl.shape[0]
14 |                 centroids[cl] = (centroids[cl] * features_counter[cl] + features_cl.detach().sum(dim=0)) / features_counter[cl]
15 |             else:
16 |                 centroids[cl] = centroids[cl] * 0.999 + (1 - 0.999) * features_cl.detach().mean(dim=0)
17 |     centroids = F.normalize(centroids, p=2, dim=1)
18 |     features_bg = features[(labels_down == 0).expand(-1, features.shape[1], -1, -1)].view(-1, features.shape[1])
19 |     features_bg = F.normalize(features_bg, p=2, dim=1)
20 |     re_labels = labels_down.view(-1)
21 |     similarity = torch.matmul(features_bg.detach(), centroids.T)
22 |     similarity = similarity.mean(dim=-1)
23 |     value, index = torch.sort(similarity, descending=True)
24 |     fill_mask = torch.zeros_like(index)
25 |     value_index = value >= 0.8
26 |     fill_index = index[value_index]
27 |     fill_mask[fill_index] = unknown_label  # we set the label of unknown class as C_t+1 in our experiment
28 |     re_labels[re_labels == 0] = fill_mask
29 |     re_labels = re_labels.view(b, h, w)
30 |     return re_labels, centroids
31 |
32 |
33 | def CID(outputs=None, outputs_old=None, nb_old_classes=None, nb_current_classes=None, nb_future_classes=None, labels=None):
34 |
35 |     outputs = outputs.permute(0, 2, 3, 1).contiguous()
36 |     b, h, w, c = outputs.shape
37 |     outputs_old = outputs_old.permute(0, 2, 3, 1).contiguous()
38 |     out_old = torch.zeros_like(outputs)
39 |     labels_unique = torch.unique(labels)
40 |     out_old[..., :nb_old_classes + nb_future_classes] = outputs_old[..., :]
41 |     for cl in range(nb_old_classes, nb_current_classes):
42 |         out_old[..., cl] = outputs_old[..., 0] * (labels == cl).squeeze(dim=-1)
43 |         for j in range(nb_future_classes):
44 |             out_old[..., cl] = out_old[..., cl] + outputs_old[..., nb_old_classes + j] * (labels == cl).squeeze(dim=-1)
45 |             # out_old[..., nb_old_classes+j] = out_old[..., cl] + outputs_old[..., nb_old_classes + j] * (labels == cl).squeeze(dim=-1)
46 |         out_old[..., 0] = (labels != cl).squeeze(dim=-1) * out_old[..., 0]
47 |     # out_old[..., :nb_old_classes+nb_future_classes] = outputs_old[..., :] * ((labels < nb_old_classes) + (labels == 255)).unsqueeze(dim=-1)
48 |     out_old = torch.log_softmax(out_old, dim=-1)
49 |     outputs = torch.softmax(outputs, dim=-1)
50 |     # NOTE: the original expression for `out` was garbled during extraction; the line below is a plausible
51 |     # reconstruction (soft cross-entropy between the remapped old predictions and the new ones on new-class pixels):
52 |     # out = (out_old * outputs * ((labels ...)).unsqueeze(dim=-1).expand(-1, -1, -1, c)).sum(dim=-1) / c
53 |     out = (out_old * outputs * (labels >= nb_old_classes).unsqueeze(dim=-1).expand(-1, -1, -1, c)).sum(dim=-1) / c
54 |
55 |     return - torch.mean(out)
56 |
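
A smoke-test sketch for FCA with synthetic tensors (all shapes, the number of classes, and the pseudo-label id are assumptions; in training, `labels` would be the ground truth downsampled to the feature resolution):

import torch

b, c, h, w = 2, 16, 8, 8
num_known = 5                 # 0 = background plus classes 1-4 already seen (assumed)
unknown_label = num_known     # pseudo-label id assigned to discovered "unknown" background regions

labels = torch.randint(0, num_known, (b, h, w))
features = torch.randn(b, c, h, w)
centroids = torch.zeros(num_known + 1, c)
counter = torch.zeros(num_known + 1)

new_labels, centroids = FCA(labels=labels, features=features, centroids=centroids,
                            unknown_label=unknown_label, features_counter=counter,
                            strategy='running')
print(new_labels.shape, int((new_labels == unknown_label).sum()))
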
--------------------------------------------------------------------------------
/modules/misc.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class GlobalAvgPool2d(nn.Module):
5 | def __init__(self):
6 | """Global average pooling over the input's spatial dimensions"""
7 | super(GlobalAvgPool2d, self).__init__()
8 |
9 | def forward(self, inputs):
10 | in_size = inputs.size()
11 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2)
12 |
13 |
--------------------------------------------------------------------------------
/modules/residual.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as functional
5 | import torch
6 |
7 | class GroupBatchnorm2d(nn.Module):
8 | def __init__(self, c_num: int,
9 | group_num: int = 16,
10 | eps: float = 1e-10
11 | ):
12 | super(GroupBatchnorm2d, self).__init__()
13 | assert c_num >= group_num
14 | self.group_num = group_num
15 | self.weight = nn.Parameter(torch.randn(c_num, 1, 1))
16 | self.bias = nn.Parameter(torch.zeros(c_num, 1, 1))
17 | self.eps = eps
18 |
19 | def forward(self, x):
20 | N, C, H, W = x.size()
21 | x = x.view(N, self.group_num, -1)
22 | mean = x.mean(dim=2, keepdim=True)
23 | std = x.std(dim=2, keepdim=True)
24 | x = (x - mean) / (std + self.eps)
25 | x = x.view(N, C, H, W)
26 | return x * self.weight + self.bias
27 |
28 |
29 | class SRU(nn.Module):
30 | def __init__(self,
31 | oup_channels: int,
32 | group_num: int = 16,
33 | gate_treshold: float = 0.5,
34 | torch_gn: bool = False
35 | ):
36 | super().__init__()
37 |
38 | self.gn = nn.GroupNorm(num_channels=oup_channels, num_groups=group_num) if torch_gn else GroupBatchnorm2d(
39 | c_num=oup_channels, group_num=group_num)
40 | self.gate_treshold = gate_treshold
41 |         self.sigmoid = nn.Sigmoid()
42 |
43 | def forward(self, x):
44 | gn_x = self.gn(x)
45 | w_gamma = self.gn.weight / torch.sum(self.gn.weight)
46 | w_gamma = w_gamma.view(1, -1, 1, 1)
47 |         reweights = self.sigmoid(gn_x * w_gamma)
48 |         # Gate
49 |         info_mask = reweights >= self.gate_treshold
50 |         noninfo_mask = reweights < self.gate_treshold
51 | x_1 = info_mask * gn_x
52 | x_2 = noninfo_mask * gn_x
53 | x = self.reconstruct(x_1, x_2)
54 | return x
55 |
56 | def reconstruct(self, x_1, x_2):
57 | x_11, x_12 = torch.split(x_1, x_1.size(1) // 2, dim=1)
58 | x_21, x_22 = torch.split(x_2, x_2.size(1) // 2, dim=1)
59 | return torch.cat([x_11 + x_22, x_12 + x_21], dim=1)
60 |
61 | class ResidualBlock(nn.Module):
62 | """Configurable residual block
63 |
64 | Parameters
65 | ----------
66 | in_channels : int
67 | Number of input channels.
68 | channels : list of int
69 |         Number of channels in the internal feature maps. Can either have two or three elements: with two, construct
70 |         a residual block with two `3 x 3` convolutions; with three, construct a bottleneck block with `1 x 1`, then
71 |         `3 x 3`, then `1 x 1` convolutions.
72 | stride : int
73 | Stride of the first `3 x 3` convolution
74 | dilation : int
75 | Dilation to apply to the `3 x 3` convolutions.
76 | groups : int
77 | Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
78 | bottleneck blocks.
79 | norm_act : callable
80 | Function to create normalization / activation Module.
81 | dropout: callable
82 | Function to create Dropout Module.
83 | """
84 |
85 | def __init__(
86 | self,
87 | in_channels,
88 | channels,
89 | stride=1,
90 | dilation=1,
91 | groups=1,
92 | norm_act=nn.BatchNorm2d,
93 | dropout=None,
94 | last=False,
95 | block_id=None
96 | ):
97 | super(ResidualBlock, self).__init__()
98 |
99 | self.block_id = block_id
100 |
101 | # Check parameters for inconsistencies
102 | if len(channels) != 2 and len(channels) != 3:
103 | raise ValueError("channels must contain either two or three values")
104 | if len(channels) == 2 and groups != 1:
105 | raise ValueError("groups > 1 are only valid if len(channels) == 3")
106 |
107 | is_bottleneck = len(channels) == 3
108 | need_proj_conv = stride != 1 or in_channels != channels[-1]
109 |
110 | if not is_bottleneck:
111 | bn2 = norm_act(channels[1])
112 | bn2.activation = "identity"
113 | layers = [
114 | (
115 | "conv1",
116 | nn.Conv2d(
117 | in_channels,
118 | channels[0],
119 | 3,
120 | stride=stride,
121 | padding=dilation,
122 | bias=False,
123 | dilation=dilation
124 | )
125 | ), ("bn1", norm_act(channels[0])),
126 | (
127 | "conv2",
128 | nn.Conv2d(
129 | channels[0],
130 | channels[1],
131 | 3,
132 | stride=1,
133 | padding=dilation,
134 | bias=False,
135 | dilation=dilation
136 | )
137 | ), ("bn2", bn2)
138 | ]
139 | if dropout is not None:
140 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
141 | else:
142 | bn3 = norm_act(channels[2])
143 | bn3.activation = "identity"
144 | layers = [
145 | ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=1, padding=0, bias=False)),
146 | ("bn1", norm_act(channels[0])),
147 | (
148 | "conv2",
149 | nn.Conv2d(channels[0],channels[1],3,stride=stride,padding=dilation,bias=False,groups=groups,dilation=dilation)
150 | ),
151 | ("bn2", norm_act(channels[1])),
152 | ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False)),
153 | ("bn3", bn3)
154 | ]
155 |
156 | if dropout is not None:
157 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
158 | self.convs = nn.Sequential(OrderedDict(layers))
159 |
160 | if need_proj_conv:
161 | self.proj_conv = nn.Conv2d(
162 | in_channels, channels[-1], 1, stride=stride, padding=0, bias=False
163 | )
164 | self.proj_bn = norm_act(channels[-1])
165 | self.proj_bn.activation = "identity"
166 |
167 | self._last = last
168 |
169 | def forward(self, x):
170 | if hasattr(self, "proj_conv"):
171 | residual = self.proj_conv(x)
172 | residual = self.proj_bn(residual)
173 | else:
174 | residual = x
175 |
176 | x = self.convs(x) + residual
177 |
178 | if self.convs.bn1.activation == "leaky_relu":
179 | act = functional.leaky_relu(
180 | x, negative_slope=self.convs.bn1.activation_param, inplace=not self._last
181 | )
182 | elif self.convs.bn1.activation == "elu":
183 | act = functional.elu(x, alpha=self.convs.bn1.activation_param, inplace=not self._last)
184 | elif self.convs.bn1.activation == "identity":
185 | act = x
186 |
187 | if self._last:
188 | return act, x
189 | return act
190 |
191 |
192 | class IdentityResidualBlock(nn.Module):
193 |
194 | def __init__(
195 | self,
196 | in_channels,
197 | channels,
198 | stride=1,
199 | dilation=1,
200 | groups=1,
201 | norm_act=nn.BatchNorm2d,
202 | dropout=None
203 | ):
204 | """Configurable identity-mapping residual block
205 |
206 | Parameters
207 | ----------
208 | in_channels : int
209 | Number of input channels.
210 | channels : list of int
211 |             Number of channels in the internal feature maps. Can either have two or three elements: with two, construct
212 |             a residual block with two `3 x 3` convolutions; with three, construct a bottleneck block with `1 x 1`, then
213 |             `3 x 3`, then `1 x 1` convolutions.
214 | stride : int
215 | Stride of the first `3 x 3` convolution
216 | dilation : int
217 | Dilation to apply to the `3 x 3` convolutions.
218 | groups : int
219 | Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
220 | bottleneck blocks.
221 | norm_act : callable
222 | Function to create normalization / activation Module.
223 | dropout: callable
224 | Function to create Dropout Module.
225 | """
226 | super(IdentityResidualBlock, self).__init__()
227 |
228 | # Check parameters for inconsistencies
229 | if len(channels) != 2 and len(channels) != 3:
230 | raise ValueError("channels must contain either two or three values")
231 | if len(channels) == 2 and groups != 1:
232 | raise ValueError("groups > 1 are only valid if len(channels) == 3")
233 |
234 | is_bottleneck = len(channels) == 3
235 | need_proj_conv = stride != 1 or in_channels != channels[-1]
236 |
237 | self.bn1 = norm_act(in_channels)
238 | if not is_bottleneck:
239 | layers = [
240 | (
241 | "conv1",
242 | nn.Conv2d(
243 | in_channels,
244 | channels[0],
245 | 3,
246 | stride=stride,
247 | padding=dilation,
248 | bias=False,
249 | dilation=dilation
250 | )
251 | ), ("bn2", norm_act(channels[0])),
252 | (
253 | "conv2",
254 | nn.Conv2d(
255 | channels[0],
256 | channels[1],
257 | 3,
258 | stride=1,
259 | padding=dilation,
260 | bias=False,
261 | dilation=dilation
262 | )
263 | )
264 | ]
265 | if dropout is not None:
266 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
267 | else:
268 | layers = [
269 | (
270 | "conv1",
271 | nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)
272 | ), ("bn2", norm_act(channels[0])),
273 | (
274 | "conv2",
275 | nn.Conv2d(
276 | channels[0],
277 | channels[1],
278 | 3,
279 | stride=1,
280 | padding=dilation,
281 | bias=False,
282 | groups=groups,
283 | dilation=dilation
284 | )
285 | ), ("bn3", norm_act(channels[1])),
286 | ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
287 | ]
288 | if dropout is not None:
289 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
290 | self.convs = nn.Sequential(OrderedDict(layers))
291 |
292 | if need_proj_conv:
293 | self.proj_conv = nn.Conv2d(
294 | in_channels, channels[-1], 1, stride=stride, padding=0, bias=False
295 | )
296 |
297 | def forward(self, x):
298 | if hasattr(self, "proj_conv"):
299 | bn1 = self.bn1(x)
300 | shortcut = self.proj_conv(bn1)
301 | else:
302 | shortcut = x.clone()
303 | bn1 = self.bn1(x)
304 |
305 | out = self.convs(bn1)
306 | out.add_(shortcut)
307 |
308 | return out
309 |
--------------------------------------------------------------------------------
/modules/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as functional
4 |
5 | from models.util import try_index, DecoderBlock
6 |
7 |
8 | class Unet(nn.Module):
9 | def __init__(self,
10 | in_channels,
11 | filters,
12 | norm_act=nn.BatchNorm2d,
13 | ):
14 |
15 | super(Unet, self).__init__()
16 |
17 |
18 | self.u_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=filters[3], kernel_size=1)
19 | self.u_conv2 = nn.Conv2d(in_channels=filters[1], out_channels=filters[0], kernel_size=1)
20 |
21 | self.u_decoder4 = DecoderBlock(filters[3], filters[2],norm_act=norm_act)
22 | self.u_decoder3 = DecoderBlock(filters[2], filters[1],norm_act=norm_act)
23 | self.u_decoder2 = DecoderBlock(filters[0], filters[0],norm_act=norm_act)
24 | self.u_decoder1 = DecoderBlock(filters[0], filters[0],norm_act=norm_act)
25 |
26 | self.u_conv1_new = nn.Conv2d(in_channels=in_channels, out_channels=filters[3], kernel_size=1)
27 | self.u_conv2_new = nn.Conv2d(in_channels=filters[1], out_channels=filters[0], kernel_size=1)
28 |
29 | self.u_decoder4_new = DecoderBlock(filters[3], filters[2], norm_act=norm_act)
30 | self.u_decoder3_new = DecoderBlock(filters[2], filters[1], norm_act=norm_act)
31 | self.u_decoder2_new = DecoderBlock(filters[0], filters[0], norm_act=norm_act)
32 | self.u_decoder1_new = DecoderBlock(filters[0], filters[0], norm_act=norm_act)
33 |
34 | def forward(self, x):
35 |
36 | e = functional.leaky_relu_(self.u_conv1(x[4]),negative_slope=0.01)
37 | d4 = self.u_decoder4(e+x[3]) + x[2]
38 | d3 = self.u_decoder3(d4) + x[1]
39 | e = functional.leaky_relu_(self.u_conv2(d3),negative_slope=0.01)
40 | d2 = self.u_decoder2(e+x[0])
41 | d1 = self.u_decoder1(d2)
42 | return functional.leaky_relu(d1,negative_slope=0.01)
43 |
44 |
45 |
46 | class Unet_2(nn.Module):
47 | def __init__(self,
48 | in_channels,
49 | filters,
50 | norm_act=nn.BatchNorm2d,
51 | ):
52 |
53 | super(Unet_2, self).__init__()
54 |
55 |
56 | self.u_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=filters[3], kernel_size=1)
57 | self.u_conv2 = nn.Conv2d(in_channels=filters[1], out_channels=filters[0], kernel_size=1)
58 |
59 | self.u_decoder4 = DecoderBlock(filters[3], filters[2],norm_act=norm_act)
60 | self.u_decoder3 = DecoderBlock(filters[2], filters[1],norm_act=norm_act)
61 | self.u_decoder2 = DecoderBlock(filters[0], filters[0],norm_act=norm_act)
62 | self.u_decoder1 = DecoderBlock(filters[0], filters[0],norm_act=norm_act)
63 |
64 | self.u_conv1_new = nn.Conv2d(in_channels=in_channels, out_channels=filters[3], kernel_size=1)
65 | self.u_conv2_new = nn.Conv2d(in_channels=filters[1], out_channels=filters[0], kernel_size=1)
66 |
67 | self.u_decoder4_new = DecoderBlock(filters[3], filters[2], norm_act=norm_act)
68 | self.u_decoder3_new = DecoderBlock(filters[2], filters[1], norm_act=norm_act)
69 | self.u_decoder2_new = DecoderBlock(filters[0], filters[0], norm_act=norm_act)
70 | self.u_decoder1_new = DecoderBlock(filters[0], filters[0], norm_act=norm_act)
71 |
72 | def forward(self, x):
73 |
74 | # Decoder
75 | e = functional.leaky_relu_(self.u_conv1(x[4]),negative_slope=0.01)
76 | d4 = self.u_decoder4(e+x[3]) + x[2]
77 | d3 = self.u_decoder3(d4) + x[1]
78 | e = functional.leaky_relu_(self.u_conv2(d3),negative_slope=0.01)
79 | d2 = self.u_decoder2(e+x[0])
80 | d1 = self.u_decoder1(d2)
81 |
82 | e_new = functional.leaky_relu_(self.u_conv1_new(x[4]), negative_slope=0.01)
83 | d4_new = self.u_decoder4_new(e_new + x[3]) + x[2]
84 | d3_new = self.u_decoder3_new(d4_new) + x[1]
85 | e_new = functional.leaky_relu_(self.u_conv2_new(d3_new), negative_slope=0.01)
86 | d2_new = self.u_decoder2_new(e_new + x[0])
87 | d1_new = self.u_decoder1_new(d2_new)
88 |
89 | # r = torch.rand(1, d1.shape[1], 1, 1, dtype=torch.float32)
90 | # if self.training == False:
91 | # r[:, :, :, :] = 1.0
92 | # weight_out_branch = torch.zeros_like(r)
93 | # weight_out_new_branch = torch.zeros_like(r)
94 | # weight_out_branch[r < 0.33] = 2.
95 | # weight_out_new_branch[r < 0.33] = 0.
96 | # weight_out_branch[(r < 0.66) * (r >= 0.33)] = 0.
97 | # weight_out_new_branch[(r < 0.66) * (r >= 0.33)] = 2.
98 | # weight_out_branch[r >= 0.66] = 1.
99 | # weight_out_new_branch[r >= 0.66] = 1.
100 | # out = d1 * weight_out_branch.to(d1.device) * 0.5 + d1_new * weight_out_new_branch.to(d1_new.device) * 0.5
101 |
102 | out = d1 * 0.5 + d1_new * 0.5
103 | return functional.leaky_relu(out,negative_slope=0.01)
104 |
105 | class Unet_Intermediate(nn.Module):
106 | def __init__(self,
107 | in_channels,
108 | filters,
109 | norm_act=nn.BatchNorm2d,
110 | ):
111 |
112 | super(Unet_Intermediate, self).__init__()
113 |
114 |
115 | # self.u_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=filters[3], kernel_size=1)
116 | self.u_conv2 = nn.Conv2d(in_channels=filters[1], out_channels=filters[0], kernel_size=1)
117 |
118 | self.u_decoder4 = DecoderBlock(filters[3], filters[2],norm_act=norm_act)
119 | self.u_decoder3 = DecoderBlock(filters[2], filters[1],norm_act=norm_act)
120 | self.u_decoder2 = DecoderBlock(filters[0], filters[0],norm_act=norm_act)
121 | self.u_decoder1 = DecoderBlock(filters[0], filters[0],norm_act=norm_act)
122 |
123 | # self.u_conv1_new = nn.Conv2d(in_channels=in_channels, out_channels=filters[3], kernel_size=1)
124 | # self.u_conv2_new = nn.Conv2d(in_channels=filters[1], out_channels=filters[0], kernel_size=1)
125 |
126 | # self.u_decoder4_new = DecoderBlock(filters[3], filters[2], norm_act=norm_act)
127 | # self.u_decoder3_new = DecoderBlock(filters[2], filters[1], norm_act=norm_act)
128 | # self.u_decoder2_new = DecoderBlock(filters[0], filters[0], norm_act=norm_act)
129 | # self.u_decoder1_new = DecoderBlock(filters[0], filters[0], norm_act=norm_act)
130 |
131 | def forward(self, x):
132 | # print(x[0].shape)
133 | # print(x[1].shape)
134 | # print(x[2].shape)
135 | # print(x[3].shape)
136 | # print(x[4].shape)
137 | # Decoder
138 | # e = functional.leaky_relu_(self.u_conv1(x[4]),negative_slope=0.01)
139 | d4 = self.u_decoder4(x[3]) + x[2]
140 | d3 = self.u_decoder3(d4) + x[1]
141 | e = functional.leaky_relu_(self.u_conv2(d3),negative_slope=0.01)
142 | d2 = self.u_decoder2(e+x[0])
143 | d1 = self.u_decoder1(d2)
144 |
145 | # e_new = functional.leaky_relu_(self.u_conv1_new(x[4]), negative_slope=0.01)
146 | # d4_new = self.u_decoder4_new(e_new + x[3]) + x[2]
147 | # d3_new = self.u_decoder3_new(d4_new) + x[1]
148 | # e_new = functional.leaky_relu_(self.u_conv2_new(d3_new), negative_slope=0.01)
149 | # d2_new = self.u_decoder2_new(e_new + x[0])
150 | # d1_new = self.u_decoder1_new(d2_new)
151 | #
152 | # r = torch.rand(1, d1.shape[1], 1, 1, dtype=torch.float32)
153 | # if self.training == False:
154 | # r[:, :, :, :] = 1.0
155 | # weight_out_branch = torch.zeros_like(r)
156 | # weight_out_new_branch = torch.zeros_like(r)
157 | # weight_out_branch[r < 0.33] = 2.
158 | # weight_out_new_branch[r < 0.33] = 0.
159 | # weight_out_branch[(r < 0.66) * (r >= 0.33)] = 0.
160 | # weight_out_new_branch[(r < 0.66) * (r >= 0.33)] = 2.
161 | # weight_out_branch[r >= 0.66] = 1.
162 | # weight_out_new_branch[r >= 0.66] = 1.
163 | # out = d1 * weight_out_branch.to(d1.device) * 0.5 + d1_new * weight_out_new_branch.to(d1_new.device) * 0.5
164 |
165 |
166 | return functional.leaky_relu(d1,negative_slope=0.01)
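
A shape sketch for the Unet decoder above, using a synthetic 5-level feature pyramid (the channel widths in `filters` and the stride-16 layout are assumptions chosen to be consistent with the additions in `forward`):

import torch
import torch.nn as nn

filters = [64, 128, 256, 512]          # assumed channel widths, shallow to deep
in_channels = 512                      # channels of the deepest feature map x[4]

decoder = Unet(in_channels=in_channels, filters=filters, norm_act=nn.BatchNorm2d)

x = [
    torch.randn(1, filters[0], 64, 64),   # x[0]: stride 4
    torch.randn(1, filters[1], 64, 64),   # x[1]: stride 4
    torch.randn(1, filters[2], 32, 32),   # x[2]: stride 8
    torch.randn(1, filters[3], 16, 16),   # x[3]: stride 16
    torch.randn(1, in_channels, 16, 16),  # x[4]: stride 16 (dilated)
]
print(tuple(decoder(x).shape))   # (1, 64, 256, 256)
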
--------------------------------------------------------------------------------