├── .gitignore ├── LICENSE ├── README.md ├── dataset ├── __init__.py ├── ade.py ├── cityscapes_domain.py ├── consep.py ├── monusac.py ├── transform.py ├── utils.py └── voc.py ├── docs └── framework.png ├── models ├── __init__.py ├── resnet.py └── util.py └── modules ├── __init__.py ├── deeplab.py ├── fca_cid.py ├── misc.py ├── residual.py └── unet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 why19991 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # InSeg 2 | ### Incremental Nuclei Segmentation from Histopathological Images 3 | ### via Future-class Awareness and Compatibility-inspired Distillation [CVPR 2024] 4 | ###### Huyong Wang1, Huisi Wu1, Jing Qin2 5 | ###### 1College of Computer Science and Software Engineering, Shenzhen University 6 | ###### 2Centre for Smart Health, School of Nursing, The Hong Kong Polytechnic University 7 | 8 | # Method 9 |
10 | ![Framework overview](docs/framework.png) 11 |
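As an illustration of how the dataset code provided in `dataset/` fits together, below is a minimal, hypothetical setup of the incremental MoNuSAC loader. The dataset root, crop size, normalization statistics, and class split are assumptions for illustration only, not values shipped with this repository.

```python
# Hypothetical usage sketch; paths, crop size and label split are assumptions.
from dataset import transform
from dataset.monusac import MoNuSACSegmentationIncremental

train_transform = transform.Compose([
    transform.RandomCrop(256, pad_if_needed=True),    # assumed crop size
    transform.RandomHorizontalFlip(0.5),
    transform.ToTensor(),
    transform.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# One incremental step: classes 1-2 were learned previously, class 3 is new.
train_set = MoNuSACSegmentationIncremental(
    root="/path/to/MoNuSAC",        # assumed root containing splits/train.txt
    train=True,
    transform=train_transform,
    labels=[3],
    labels_old=[1, 2],
    image_Set="train",
)
img, target, name = train_set[0]    # image tensor, remapped mask, sample id
```

Internally, the incremental dataset filters the image list to the current step and applies a target transform that remaps original label ids to the current class ordering, collapsing future classes to background (0), as implemented in `dataset/monusac.py`.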
12 | 13 | # Notes 14 | Due to the confidentiality agreement in commercial cooperation, we only provide codes of core modules and the whole trainable models for the convenience of comparisons. 15 | 16 | # Environment 17 | * Python>=3.8 18 | * Pytorch>=1.8.1 19 | * Install [inplace_abn](https://github.com/mapillary/inplace_abn) 20 | # Thanks 21 | This code is heavily borrowed from [[EWF](https://github.com/schuy1er/EWF_official)] and [[PLOP](https://github.com/arthurdouillard/CVPR2021_PLOP)]. We appreciate their contributions to this community. 22 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .ade import AdeSegmentation, AdeSegmentationIncremental 2 | from .cityscapes_domain import (CityscapesSegmentationDomain, 3 | CityscapesSegmentationIncrementalDomain) 4 | from .voc import VOCSegmentation, VOCSegmentationIncremental 5 | -------------------------------------------------------------------------------- /dataset/ade.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch.utils.data as data 6 | import torchvision as tv 7 | from PIL import Image 8 | from torch import distributed 9 | 10 | from .utils import Subset, filter_images, group_images 11 | 12 | classes = [ 13 | "void", "wall", "building", "sky", "floor", "tree", "ceiling", "road", "bed ", "windowpane", 14 | "grass", "cabinet", "sidewalk", "person", "earth", "door", "table", "mountain", "plant", 15 | "curtain", "chair", "car", "water", "painting", "sofa", "shelf", "house", "sea", "mirror", 16 | "rug", "field", "armchair", "seat", "fence", "desk", "rock", "wardrobe", "lamp", "bathtub", 17 | "railing", "cushion", "base", "box", "column", "signboard", "chest of drawers", "counter", 18 | "sand", "sink", "skyscraper", "fireplace", "refrigerator", "grandstand", "path", "stairs", 19 | "runway", "case", "pool table", "pillow", "screen door", "stairway", "river", "bridge", 20 | "bookcase", "blind", "coffee table", "toilet", "flower", "book", "hill", "bench", "countertop", 21 | "stove", "palm", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", 22 | "hovel", "bus", "towel", "light", "truck", "tower", "chandelier", "awning", "streetlight", 23 | "booth", "television receiver", "airplane", "dirt track", "apparel", "pole", "land", 24 | "bannister", "escalator", "ottoman", "bottle", "buffet", "poster", "stage", "van", "ship", 25 | "fountain", "conveyer belt", "canopy", "washer", "plaything", "swimming pool", "stool", 26 | "barrel", "basket", "waterfall", "tent", "bag", "minibike", "cradle", "oven", "ball", "food", 27 | "step", "tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", 28 | "screen", "blanket", "sculpture", "hood", "sconce", "vase", "traffic light", "tray", "ashcan", 29 | "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", 30 | "glass", "clock", "flag" 31 | ] 32 | 33 | 34 | class AdeSegmentation(data.Dataset): 35 | 36 | def __init__(self, root, train=True, transform=None): 37 | 38 | root = os.path.expanduser(root) 39 | base_dir = "ADEChallengeData2016" 40 | ade_root = os.path.join(root, base_dir) 41 | if train: 42 | split = 'training' 43 | else: 44 | split = 'validation' 45 | annotation_folder = os.path.join(ade_root, 'annotations', split) 46 | image_folder = os.path.join(ade_root, 'images', split) 
47 | 48 | self.images = [] 49 | fnames = sorted(os.listdir(image_folder)) 50 | self.images = [ 51 | (os.path.join(image_folder, x), os.path.join(annotation_folder, x[:-3] + "png")) 52 | for x in fnames 53 | ] 54 | 55 | self.transform = transform 56 | 57 | def __getitem__(self, index): 58 | """ 59 | Args: 60 | index (int): Index 61 | Returns: 62 | tuple: (image, target) where target is the image segmentation. 63 | """ 64 | img = Image.open(self.images[index][0]).convert('RGB') 65 | target = Image.open(self.images[index][1]) 66 | 67 | if self.transform is not None: 68 | img, target = self.transform(img, target) 69 | 70 | return img, target 71 | 72 | def __len__(self): 73 | return len(self.images) 74 | 75 | 76 | class AdeSegmentationIncremental(data.Dataset): 77 | 78 | def __init__( 79 | self, 80 | root, 81 | train=True, 82 | transform=None, 83 | labels=None, 84 | labels_old=None, 85 | idxs_path=None, 86 | masking=True, 87 | overlap=True, 88 | data_masking="current", 89 | ignore_test_bg=False, 90 | **kwargs 91 | ): 92 | 93 | full_data = AdeSegmentation(root, train) 94 | 95 | self.labels = [] 96 | self.labels_old = [] 97 | 98 | if labels is not None: 99 | # store the labels 100 | labels_old = labels_old if labels_old is not None else [] 101 | 102 | self.__strip_zero(labels) 103 | self.__strip_zero(labels_old) 104 | 105 | assert not any( 106 | l in labels_old for l in labels 107 | ), "labels and labels_old must be disjoint sets" 108 | 109 | self.labels = labels 110 | self.labels_old = labels_old 111 | 112 | self.order = [0] + labels_old + labels 113 | 114 | # take index of images with at least one class in labels and all classes in labels+labels_old+[255] 115 | if idxs_path is not None and os.path.exists(idxs_path): 116 | idxs = np.load(idxs_path).tolist() 117 | else: 118 | idxs = filter_images(full_data, labels, labels_old, overlap=overlap) 119 | if idxs_path is not None and distributed.get_rank() == 0: 120 | np.save(idxs_path, np.array(idxs, dtype=int)) 121 | 122 | #if train: 123 | # masking_value = 0 124 | #else: 125 | # masking_value = 255 126 | 127 | #self.inverted_order = {label: self.order.index(label) for label in self.order} 128 | #self.inverted_order[0] = masking_value 129 | 130 | self.inverted_order = {label: self.order.index(label) for label in self.order} 131 | if ignore_test_bg: 132 | masking_value = 255 133 | self.inverted_order[0] = masking_value 134 | else: 135 | masking_value = 0 # Future classes will be considered as background. 136 | self.inverted_order[255] = 255 137 | 138 | reorder_transform = tv.transforms.Lambda( 139 | lambda t: t.apply_( 140 | lambda x: self.inverted_order[x] if x in self.inverted_order else masking_value 141 | ) 142 | ) 143 | 144 | if masking: 145 | target_transform = tv.transforms.Lambda( 146 | lambda t: t. 147 | apply_(lambda x: self.inverted_order[x] if x in self.labels else masking_value) 148 | ) 149 | else: 150 | target_transform = reorder_transform 151 | 152 | # make the subset of the dataset 153 | self.dataset = Subset(full_data, idxs, transform, target_transform) 154 | else: 155 | self.dataset = full_data 156 | 157 | def __getitem__(self, index): 158 | """ 159 | Args: 160 | index (int): Index 161 | Returns: 162 | tuple: (image, target) where target is the image segmentation. 
163 | """ 164 | 165 | return self.dataset[index] 166 | 167 | @staticmethod 168 | def __strip_zero(labels): 169 | while 0 in labels: 170 | labels.remove(0) 171 | 172 | def __len__(self): 173 | return len(self.dataset) 174 | -------------------------------------------------------------------------------- /dataset/cityscapes_domain.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import glob 3 | import os 4 | 5 | import numpy as np 6 | import torch.utils.data as data 7 | import torchvision as tv 8 | from PIL import Image 9 | from torch import distributed 10 | 11 | from .utils import Subset, group_images 12 | 13 | 14 | # Converting the id to the train_id. Many objects have a train id at 15 | # 255 (unknown / ignored). 16 | # See there for more information: 17 | # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py 18 | id_to_trainid = { 19 | 0: 255, 20 | 1: 255, 21 | 2: 255, 22 | 3: 255, 23 | 4: 255, 24 | 5: 255, 25 | 6: 255, 26 | 7: 0, # road 27 | 8: 1, # sidewalk 28 | 9: 255, 29 | 10: 255, 30 | 11: 2, # building 31 | 12: 3, # wall 32 | 13: 4, # fence 33 | 14: 255, 34 | 15: 255, 35 | 16: 255, 36 | 17: 5, # pole 37 | 18: 255, 38 | 19: 6, # traffic light 39 | 20: 7, # traffic sign 40 | 21: 8, # vegetation 41 | 22: 9, # terrain 42 | 23: 10, # sky 43 | 24: 11, # person 44 | 25: 12, # rider 45 | 26: 13, # car 46 | 27: 14, # truck 47 | 28: 15, # bus 48 | 29: 255, 49 | 30: 255, 50 | 31: 16, # train 51 | 32: 17, # motorcycle 52 | 33: 18, # bicycle 53 | -1: 255 54 | } 55 | 56 | city_to_id = { 57 | "aachen": 0, "bremen": 1, "darmstadt": 2, "erfurt": 3, "hanover": 4, 58 | "krefeld": 5, "strasbourg": 6, "tubingen": 7, "weimar": 8, "bochum": 9, 59 | "cologne": 10, "dusseldorf": 11, "hamburg": 12, "jena": 13, 60 | "monchengladbach": 14, "stuttgart": 15, "ulm": 16, "zurich": 17, 61 | "frankfurt": 18, "lindau": 19, "munster": 20 62 | } 63 | 64 | 65 | def filter_images(dataset, labels): 66 | # Filter images without any label in LABELS (using labels not reordered) 67 | idxs = [] 68 | 69 | print(f"Filtering images...") 70 | for i in range(len(dataset)): 71 | domain_id = dataset.__getitem__(i, get_domain=True) # taking domain id 72 | if domain_id in labels: 73 | idxs.append(i) 74 | if i % 1000 == 0: 75 | print(f"\t{i}/{len(dataset)} ...") 76 | return idxs 77 | 78 | 79 | class CityscapesSegmentationDomain(data.Dataset): 80 | 81 | def __init__(self, root, train=True, transform=None, domain_transform=None): 82 | root = os.path.expanduser(root) 83 | annotation_folder = os.path.join(root, 'gtFine') 84 | image_folder = os.path.join(root, 'leftImg8bit') 85 | 86 | self.images = [ # Add train cities 87 | ( 88 | path, 89 | os.path.join( 90 | annotation_folder, 91 | "train", 92 | path.split("/")[-2], 93 | path.split("/")[-1][:-15] + "gtFine_labelIds.png" 94 | ), 95 | city_to_id[path.split("/")[-2]] 96 | ) for path in sorted(glob.glob(os.path.join(image_folder, "train/*/*.png"))) 97 | ] 98 | self.images += [ # Add validation cities 99 | ( 100 | path, 101 | os.path.join( 102 | annotation_folder, 103 | "val", 104 | path.split("/")[-2], 105 | path.split("/")[-1][:-15] + "gtFine_labelIds.png" 106 | ), 107 | city_to_id[path.split("/")[-2]] 108 | ) for path in sorted(glob.glob(os.path.join(image_folder, "val/*/*.png"))) 109 | ] 110 | 111 | self.transform = transform 112 | self.domain_transform = domain_transform 113 | 114 | def __getitem__(self, index, get_domain=False): 115 | """ 116 | Args: 117 | index (int): Index 118 | Returns: 
119 | tuple: (image, target) where target is the image segmentation. 120 | """ 121 | if get_domain: 122 | domain = self.images[index][2] 123 | if self.domain_transform is not None: 124 | domain = self.domain_transform(domain) 125 | return domain 126 | 127 | try: 128 | img = Image.open(self.images[index][0]).convert('RGB') 129 | target = Image.open(self.images[index][1]) 130 | except Exception as e: 131 | raise Exception(f"Index: {index}, len: {len(self)}, message: {str(e)}") 132 | 133 | if self.transform is not None: 134 | img, target = self.transform(img, target) 135 | 136 | return img, target 137 | 138 | def __len__(self): 139 | return len(self.images) 140 | 141 | 142 | class CityscapesSegmentationIncrementalDomain(data.Dataset): 143 | """Labels correspond to domains not classes in this case.""" 144 | def __init__( 145 | self, 146 | root, 147 | train=True, 148 | transform=None, 149 | labels=None, 150 | idxs_path=None, 151 | masking=True, 152 | overlap=True, 153 | **kwargs 154 | ): 155 | full_data = CityscapesSegmentationDomain(root, train) 156 | 157 | # take index of images with at least one class in labels and all classes in labels+labels_old+[255] 158 | if idxs_path is not None and os.path.exists(idxs_path): 159 | idxs = np.load(idxs_path).tolist() 160 | else: 161 | idxs = filter_images(full_data, labels) 162 | if idxs_path is not None and distributed.get_rank() == 0: 163 | np.save(idxs_path, np.array(idxs, dtype=int)) 164 | 165 | rnd = np.random.RandomState(1) 166 | rnd.shuffle(idxs) 167 | train_len = int(0.8 * len(idxs)) 168 | if train: 169 | idxs = idxs[:train_len] 170 | print(f"{len(idxs)} images for train") 171 | else: 172 | idxs = idxs[train_len:] 173 | print(f"{len(idxs)} images for val") 174 | 175 | target_transform = tv.transforms.Lambda( 176 | lambda t: t. 177 | apply_(lambda x: id_to_trainid.get(x, 255)) 178 | ) 179 | # make the subset of the dataset 180 | self.dataset = Subset(full_data, idxs, transform, target_transform) 181 | 182 | def __getitem__(self, index): 183 | """ 184 | Args: 185 | index (int): Index 186 | Returns: 187 | tuple: (image, target) where target is the image segmentation. 188 | """ 189 | 190 | return self.dataset[index] 191 | 192 | def __len__(self): 193 | return len(self.dataset) 194 | -------------------------------------------------------------------------------- /dataset/consep.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import copy 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils.data as data 8 | import torchvision as tv 9 | from PIL import Image 10 | from torch import distributed 11 | 12 | from dataset import transform 13 | from .utils import Subset, filter_images, group_images 14 | 15 | classes = { 16 | 0: 'background', 17 | 1: 'class 1', 18 | 2: 'class 2', 19 | 3: 'class 3', 20 | 4: 'class 4' 21 | } 22 | 23 | class CoNSePSegmentation(data.Dataset): 24 | """ 25 | Args: 26 | root (string): Root directory of the VOC Dataset. 27 | image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` 28 | is_aug (bool, optional): If you want to use the augmented train set or not (default is True) 29 | transform (callable, optional): A function/transform that takes in an PIL image 30 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 31 | """ 32 | 33 | def __init__(self, root, image_set='train', is_aug=True, transform=None): 34 | 35 | self.root = os.path.expanduser(root) 36 | self.year = 2020 37 | 38 | self.transform = transform 39 | 40 | self.image_set = image_set 41 | monusac_root = self.root 42 | splits_dir = os.path.join(monusac_root,'splits') 43 | 44 | if not os.path.exists(monusac_root): 45 | raise RuntimeError( 46 | f'Dataset {monusac_root} not found or corrupted.' 47 | ) 48 | 49 | split_f = os.path.join(splits_dir,image_set.rstrip('\n') + '.txt') 50 | if not os.path.exists(split_f): 51 | raise ValueError( 52 | 'Wrong image_set entered! Please use image_set="train" ' 53 | 'or image_set="trainval" or image_set="val" ' 54 | f'{split_f}' 55 | ) 56 | 57 | # remove leading \n 58 | with open(os.path.join(split_f), "r") as f: 59 | file_names = [x[:-1].split(" ") for x in f.readlines()] 60 | 61 | 62 | self.images = [ 63 | ( 64 | os.path.join(monusac_root,x[0][1:]), os.path.join(monusac_root, x[1][1:]) 65 | ) for x in file_names 66 | ] 67 | 68 | def __getitem__(self, index): 69 | """ 70 | Args: 71 | index (int): Index 72 | Returns: 73 | tuple: (image, target) where target is the image segmentation. 74 | """ 75 | 76 | img = Image.open(self.images[index][0]).convert('RGB') 77 | target = Image.open(self.images[index][1]) 78 | # img = Image.open("/data/wzz/MoNuSAC/train/image_crop/TCGA-5P-A9K0-01Z-00-DX1_1_11.png").convert('RGB') 79 | # target = Image.open("/data/wzz/MoNuSAC/train/mask_crop/TCGA-5P-A9K0-01Z-00-DX1_1_11.png") 80 | if self.transform is not None: 81 | img, target = self.transform(img, target) 82 | 83 | # return img, target, "TCGA-5P-A9K0-01Z-00-DX1_1_11" 84 | return img, target, self.images[index][1][-31:-4] 85 | 86 | 87 | def __len__(self): 88 | return len(self.images) 89 | 90 | def viz_getter(self, index): 91 | image_path = self.images[index][0] 92 | raw_image = Image.open(self.images[index][0]).convert('RGB') 93 | target = Image.open(self.images[index][1]) 94 | if self.transform is not None: 95 | img, target = self.transform(raw_image, target) 96 | else: 97 | img = copy.deepcopy(raw_image) 98 | return image_path, raw_image, img, target 99 | 100 | 101 | class CoNSePSegmentationIncremental(data.Dataset): 102 | def __init__(self, 103 | root, 104 | train=True, 105 | transform=None, 106 | labels=None, 107 | labels_old=None, 108 | idxs_path=None, 109 | masking=True, 110 | overlap=True, 111 | data_masking="current", 112 | test_on_val=False, 113 | **kwargs 114 | ): 115 | 116 | full_consep = CoNSePSegmentation(root,'train' if train else 'val',is_aug= True,transform=None) 117 | self.labels = [] 118 | self.labels_old = [] 119 | 120 | if labels is not None: 121 | 122 | labels_old = labels_old if labels_old is not None else [] 123 | 124 | self.__strip_zero(labels) 125 | self.__strip_zero(labels_old) 126 | 127 | assert not any( 128 | l in labels_old for l in labels 129 | ), "labels and labels_old must be disjoint sets" 130 | 131 | self.labels = [0] + labels 132 | self.labels_old = [0] + labels_old 133 | 134 | self.order = [0] + labels_old + labels 135 | 136 | if idxs_path is not None and os.path.exists(idxs_path): 137 | idxs = np.load(idxs_path).tolist() 138 | 139 | else: 140 | idxs = filter_images(full_consep, labels, labels_old,overlap=overlap) 141 | if idxs_path is not None and distributed.get_rank() == 0: 142 | np.save(idxs_path,np.array(idxs,dtype=int)) 143 | 144 | if test_on_val: 145 | rnd = np.random.RandomState(1) 146 | rnd.shuffle(idxs) 147 | train_len = int(0.8 * len(idxs)) 148 | 
if train: 149 | idxs = idxs[:train_len] 150 | else: 151 | idxs = idxs[train_len:] 152 | 153 | masking_value = 0 # Future classes will be considered as background. 154 | self.inverted_order = {label: self.order.index(label) for label in self.order} 155 | self.inverted_order[255] = 255 156 | reorder_transform = tv.transforms.Lambda( 157 | lambda t: t.apply_( 158 | lambda x: self.inverted_order[x] if x in self.inverted_order else masking_value 159 | ) 160 | ) 161 | 162 | if masking: 163 | if data_masking == 'current': 164 | tmp_labels = self.labels + [255] 165 | elif data_masking == 'current+old': 166 | tmp_labels = labels_old + self.labels + [255] 167 | elif data_masking == 'all': 168 | raise NotImplementedError( 169 | f"data_masking={data_masking} not yet implemented sorry not sorry." 170 | ) 171 | 172 | elif data_masking == "new": 173 | tmp_labels = self.labels 174 | masking_value = 255 175 | 176 | target_transform = tv.transforms.Lambda( 177 | lambda t: t. 178 | apply_(lambda x: self.inverted_order[x] if x in tmp_labels else masking_value) 179 | ) 180 | else: 181 | assert False 182 | target_transform = reorder_transform 183 | 184 | # make the subset of the dataset 185 | self.dataset = Subset(full_consep, idxs, transform, target_transform) 186 | else: 187 | self.dataset = full_consep 188 | 189 | 190 | 191 | @staticmethod 192 | def __strip_zero(labels): 193 | while 0 in labels: 194 | labels.remove(0) 195 | 196 | 197 | def __len__(self): 198 | return len(self.dataset) 199 | 200 | def __getitem__(self, index): 201 | """ 202 | Args: 203 | index (int): Index 204 | Returns: 205 | tuple: (image, target) where target is the image segmentation. 206 | """ 207 | 208 | return self.dataset[index] -------------------------------------------------------------------------------- /dataset/monusac.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import copy 4 | 5 | import numpy as np 6 | import torch.utils.data as data 7 | import torchvision as tv 8 | from PIL import Image 9 | from torch import distributed 10 | 11 | from .utils import Subset, filter_images, group_images 12 | 13 | classes = { 14 | 0: 'background', 15 | 1: 'class 1', 16 | 2: 'class 2', 17 | 3: 'class 3' 18 | } 19 | 20 | class MoNuSACSegmentation(data.Dataset): 21 | """ 22 | Args: 23 | root (string): Root directory of the VOC Dataset. 24 | image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` 25 | is_aug (bool, optional): If you want to use the augmented train set or not (default is True) 26 | transform (callable, optional): A function/transform that takes in an PIL image 27 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 28 | """ 29 | 30 | def __init__(self, root, image_set='train', is_aug=True, transform=None): 31 | 32 | self.root = os.path.expanduser(root) 33 | self.year = 2020 34 | 35 | self.transform = transform 36 | 37 | self.image_set = image_set 38 | monusac_root = self.root 39 | splits_dir = os.path.join(monusac_root,'splits') 40 | 41 | if not os.path.exists(monusac_root): 42 | raise RuntimeError( 43 | f'Dataset {monusac_root} not found or corrupted.' 44 | ) 45 | 46 | split_f = os.path.join(splits_dir,image_set.rstrip('\n') + '.txt') 47 | if not os.path.exists(split_f): 48 | raise ValueError( 49 | 'Wrong image_set entered! 
Please use image_set="train" ' 50 | 'or image_set="trainval" or image_set="val" ' 51 | f'{split_f}' 52 | ) 53 | 54 | # remove leading \n 55 | with open(os.path.join(split_f), "r") as f: 56 | file_names = [x[:-1].split(" ") for x in f.readlines()] 57 | 58 | 59 | self.images = [ 60 | ( 61 | os.path.join(monusac_root,x[0][1:]), os.path.join(monusac_root, x[1][1:]) 62 | ) for x in file_names 63 | ] 64 | 65 | def __getitem__(self, index): 66 | """ 67 | Args: 68 | index (int): Index 69 | Returns: 70 | tuple: (image, target) where target is the image segmentation. 71 | """ 72 | 73 | 74 | img = Image.open(self.images[index][0]).convert('RGB') 75 | target = Image.open(self.images[index][1]) 76 | # img = Image.open("/data/why/MoNuSAC/train/image_crop/TCGA-IZ-A6M9-01Z-00-DX1_3_1.png").convert('RGB') 77 | # target = Image.open("/data/why/MoNuSAC/test/mask_crop/TCGA-IZ-A6M9-01Z-00-DX1_3_1.png") 78 | if self.transform is not None: 79 | img, target = self.transform(img, target) 80 | 81 | # return img, target, "TCGA-IZ-A6M9-01Z-00-DX1_3_1" 82 | # return img, target, self.images[index][1][-31:-4] 83 | return img, target, self.images[index][0][35:] 84 | 85 | 86 | def __len__(self): 87 | return len(self.images) 88 | 89 | def viz_getter(self, index): 90 | image_path = self.images[index][0] 91 | raw_image = Image.open(self.images[index][0]).convert('RGB') 92 | target = Image.open(self.images[index][1]) 93 | if self.transform is not None: 94 | img, target = self.transform(raw_image, target) 95 | else: 96 | img = copy.deepcopy(raw_image) 97 | return image_path, raw_image, img, target 98 | 99 | 100 | class MoNuSACSegmentationIncremental(data.Dataset): 101 | def __init__(self, 102 | root, 103 | train=True, 104 | transform=None, 105 | labels=None, 106 | labels_old=None, 107 | idxs_path=None, 108 | masking=True, 109 | overlap=True, 110 | data_masking="current", 111 | test_on_val=False, 112 | image_Set=None, 113 | **kwargs 114 | ): 115 | 116 | full_monusac = MoNuSACSegmentation(root,image_set=image_Set,is_aug= True,transform=None) 117 | self.labels = [] 118 | self.labels_old = [] 119 | 120 | if labels is not None: 121 | 122 | labels_old = labels_old if labels_old is not None else [] 123 | 124 | self.__strip_zero(labels) 125 | self.__strip_zero(labels_old) 126 | 127 | assert not any( 128 | l in labels_old for l in labels 129 | ), "labels and labels_old must be disjoint sets" 130 | 131 | self.labels = [0] + labels 132 | self.labels_old = [0] + labels_old 133 | 134 | self.order = [0] + labels_old + labels 135 | 136 | if idxs_path is not None and os.path.exists(idxs_path): 137 | idxs = np.load(idxs_path).tolist() 138 | 139 | else: 140 | idxs = filter_images(full_monusac, labels, labels_old,overlap=overlap) 141 | if idxs_path is not None and distributed.get_rank() == 0: 142 | np.save(idxs_path,np.array(idxs,dtype=int)) 143 | 144 | if test_on_val: 145 | rnd = np.random.RandomState(1) 146 | rnd.shuffle(idxs) 147 | train_len = int(0.8 * len(idxs)) 148 | if train: 149 | idxs = idxs[:train_len] 150 | else: 151 | idxs = idxs[train_len:] 152 | 153 | masking_value = 0 # Future classes will be considered as background. 
154 | self.inverted_order = {label: self.order.index(label) for label in self.order} 155 | self.inverted_order[255] = 255 156 | reorder_transform = tv.transforms.Lambda( 157 | lambda t: t.apply_( 158 | lambda x: self.inverted_order[x] if x in self.inverted_order else masking_value 159 | ) 160 | ) 161 | 162 | if masking: 163 | if data_masking == 'current': 164 | tmp_labels = self.labels + [255] 165 | elif data_masking == 'current+old': 166 | tmp_labels = labels_old + self.labels + [255] 167 | elif data_masking == 'all': 168 | raise NotImplementedError( 169 | f"data_masking={data_masking} not yet implemented sorry not sorry." 170 | ) 171 | 172 | elif data_masking == "new": 173 | tmp_labels = self.labels 174 | masking_value = 255 175 | 176 | target_transform = tv.transforms.Lambda( 177 | lambda t: t. 178 | apply_(lambda x: self.inverted_order[x] if x in tmp_labels else masking_value) 179 | ) 180 | else: 181 | assert False 182 | target_transform = reorder_transform 183 | 184 | # make the subset of the dataset 185 | self.dataset = Subset(full_monusac, idxs, transform, target_transform) 186 | else: 187 | self.dataset = full_monusac 188 | 189 | 190 | 191 | @staticmethod 192 | def __strip_zero(labels): 193 | while 0 in labels: 194 | labels.remove(0) 195 | 196 | 197 | def __len__(self): 198 | return len(self.dataset) 199 | 200 | def __getitem__(self, index): 201 | """ 202 | Args: 203 | index (int): Index 204 | Returns: 205 | tuple: (image, target) where target is the image segmentation. 206 | """ 207 | 208 | return self.dataset[index] -------------------------------------------------------------------------------- /dataset/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import random 4 | import numbers 5 | import numpy as np 6 | import collections 7 | from PIL import Image 8 | import warnings 9 | import math 10 | 11 | _pil_interpolation_to_str = { 12 | Image.NEAREST: 'PIL.Image.NEAREST', 13 | Image.BILINEAR: 'PIL.Image.BILINEAR', 14 | Image.BICUBIC: 'PIL.Image.BICUBIC', 15 | Image.LANCZOS: 'PIL.Image.LANCZOS', 16 | Image.HAMMING: 'PIL.Image.HAMMING', 17 | Image.BOX: 'PIL.Image.BOX', 18 | } 19 | 20 | 21 | class Compose(object): 22 | """Composes several transforms together. 23 | 24 | Args: 25 | transforms (list of ``Transform`` objects): list of transforms to compose. 26 | 27 | """ 28 | 29 | def __init__(self, transforms): 30 | self.transforms = transforms 31 | 32 | def __call__(self, img, lbl=None): 33 | if lbl is not None: 34 | for t in self.transforms: 35 | img, lbl = t(img, lbl) 36 | return img, lbl 37 | else: 38 | for t in self.transforms: 39 | img = t(img) 40 | return img 41 | 42 | def __repr__(self): 43 | format_string = self.__class__.__name__ + '(' 44 | for t in self.transforms: 45 | format_string += '\n' 46 | format_string += ' {0}'.format(t) 47 | format_string += '\n)' 48 | return format_string 49 | 50 | 51 | class Resize(object): 52 | """Resize the input PIL Image to the given size. 53 | 54 | Args: 55 | size (sequence or int): Desired output size. If size is a sequence like 56 | (h, w), output size will be matched to this. If size is an int, 57 | smaller edge of the image will be matched to this number. 58 | i.e, if height > width, then image will be rescaled to 59 | (size * height / width, size) 60 | interpolation (int, optional): Desired interpolation. 
Default is 61 | ``PIL.Image.BILINEAR`` 62 | """ 63 | 64 | def __init__(self, size, interpolation=Image.BILINEAR): 65 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 66 | self.size = size 67 | self.interpolation = interpolation 68 | 69 | def __call__(self, img, lbl=None): 70 | """ 71 | Args: 72 | img (PIL Image): Image to be scaled. 73 | 74 | Returns: 75 | PIL Image: Rescaled image. 76 | """ 77 | if lbl is not None: 78 | return F.resize(img, self.size, self.interpolation), F.resize(lbl, self.size, Image.NEAREST) 79 | else: 80 | return F.resize(img, self.size, self.interpolation) 81 | 82 | def __repr__(self): 83 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 84 | return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str) 85 | 86 | 87 | class CenterCrop(object): 88 | """Crops the given PIL Image at the center. 89 | 90 | Args: 91 | size (sequence or int): Desired output size of the crop. If size is an 92 | int instead of sequence like (h, w), a square crop (size, size) is 93 | made. 94 | """ 95 | 96 | def __init__(self, size): 97 | if isinstance(size, numbers.Number): 98 | self.size = (int(size), int(size)) 99 | else: 100 | self.size = size 101 | 102 | def __call__(self, img, lbl=None): 103 | """ 104 | Args: 105 | img (PIL Image): Image to be cropped. 106 | 107 | Returns: 108 | PIL Image: Cropped image. 109 | """ 110 | if lbl is not None: 111 | return F.center_crop(img, self.size), F.center_crop(lbl, self.size) 112 | else: 113 | return F.center_crop(img, self.size) 114 | 115 | def __repr__(self): 116 | return self.__class__.__name__ + f'(size={self.size})' 117 | 118 | 119 | class Pad(object): 120 | """Pad the given PIL Image on all sides with the given "pad" value. 121 | Args: 122 | padding (int or tuple): Padding on each border. If a single int is provided this 123 | is used to pad all borders. If tuple of length 2 is provided this is the padding 124 | on left/right and top/bottom respectively. If a tuple of length 4 is provided 125 | this is the padding for the left, top, right and bottom borders 126 | respectively. 127 | fill (int): Pixel fill value for constant fill. Default is 0. 128 | This value is only used when the padding_mode is constant 129 | padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. 130 | Default is constant. 
131 | - constant: pads with a constant value, this value is specified with fill 132 | - edge: pads with the last value at the edge of the image 133 | - reflect: pads with reflection of image without repeating the last value on the edge 134 | For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 135 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 136 | - symmetric: pads with reflection of image repeating the last value on the edge 137 | For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 138 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 139 | """ 140 | 141 | def __init__(self, padding, fill=0, padding_mode='constant'): 142 | assert isinstance(padding, (numbers.Number, tuple)) 143 | assert isinstance(fill, (numbers.Number, str)) 144 | assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] 145 | if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: 146 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 147 | "{} element tuple".format(len(padding))) 148 | 149 | self.padding = padding 150 | self.fill = fill 151 | self.padding_mode = padding_mode 152 | 153 | def __call__(self, img, lbl=None): 154 | """ 155 | Args: 156 | img (PIL Image): Image to be padded. 157 | Returns: 158 | PIL Image: Padded image. 159 | """ 160 | if lbl is not None: 161 | return F.pad(img, self.padding, self.fill, self.padding_mode), F.pad(lbl, self.padding, self.fill, self.padding_mode) 162 | else: 163 | return F.pad(img, self.padding, self.fill, self.padding_mode) 164 | 165 | def __repr__(self): 166 | return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ 167 | format(self.padding, self.fill, self.padding_mode) 168 | 169 | 170 | class Lambda(object): 171 | """Apply a user-defined lambda as a transform. 172 | Args: 173 | lambd (function): Lambda/function to be used for transform. 174 | """ 175 | 176 | def __init__(self, lambd): 177 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" 178 | self.lambd = lambd 179 | 180 | def __call__(self, img, lbl=None): 181 | if lbl is not None: 182 | return self.lambd(img), self.lambd(lbl) 183 | else: 184 | return self.lambd(img) 185 | 186 | def __repr__(self): 187 | return self.__class__.__name__ + '()' 188 | 189 | 190 | class RandomRotation(object): 191 | """Rotate the image by angle. 192 | 193 | Args: 194 | degrees (sequence or float or int): Range of degrees to select from. 195 | If degrees is a number instead of sequence like (min, max), the range of degrees 196 | will be (-degrees, +degrees). 197 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 198 | An optional resampling filter. 199 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 200 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 201 | expand (bool, optional): Optional expansion flag. 202 | If true, expands the output to make it large enough to hold the entire rotated image. 203 | If false or omitted, make the output image the same size as the input image. 204 | Note that the expand flag assumes rotation around the center and no translation. 205 | center (2-tuple, optional): Optional center of rotation. 206 | Origin is the upper left corner. 207 | Default is the center of the image. 
208 | """ 209 | 210 | def __init__(self, degrees, resample=False, expand=False, center=None): 211 | if isinstance(degrees, numbers.Number): 212 | if degrees < 0: 213 | raise ValueError("If degrees is a single number, it must be positive.") 214 | self.degrees = (-degrees, degrees) 215 | else: 216 | if len(degrees) != 2: 217 | raise ValueError("If degrees is a sequence, it must be of len 2.") 218 | self.degrees = degrees 219 | 220 | self.resample = resample 221 | self.expand = expand 222 | self.center = center 223 | 224 | @staticmethod 225 | def get_params(degrees): 226 | """Get parameters for ``rotate`` for a random rotation. 227 | 228 | Returns: 229 | sequence: params to be passed to ``rotate`` for random rotation. 230 | """ 231 | angle = random.uniform(degrees[0], degrees[1]) 232 | 233 | return angle 234 | 235 | def __call__(self, img, lbl): 236 | """ 237 | img (PIL Image): Image to be rotated. 238 | lbl (PIL Image): Label to be rotated. 239 | 240 | Returns: 241 | PIL Image: Rotated image. 242 | PIL Image: Rotated label. 243 | """ 244 | 245 | angle = self.get_params(self.degrees) 246 | if lbl is not None: 247 | return F.rotate(img, angle, self.resample, self.expand, self.center), \ 248 | F.rotate(lbl, angle, self.resample, self.expand, self.center) 249 | else: 250 | return F.rotate(img, angle, self.resample, self.expand, self.center) 251 | 252 | def __repr__(self): 253 | format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) 254 | format_string += ', resample={0}'.format(self.resample) 255 | format_string += ', expand={0}'.format(self.expand) 256 | if self.center is not None: 257 | format_string += ', center={0}'.format(self.center) 258 | format_string += ')' 259 | return format_string 260 | 261 | 262 | class RandomHorizontalFlip(object): 263 | """Horizontally flip the given PIL Image randomly with a given probability. 264 | 265 | Args: 266 | p (float): probability of the image being flipped. Default value is 0.5 267 | """ 268 | 269 | def __init__(self, p=0.5): 270 | self.p = p 271 | 272 | def __call__(self, img, lbl=None): 273 | """ 274 | Args: 275 | img (PIL Image): Image to be flipped. 276 | 277 | Returns: 278 | PIL Image: Randomly flipped image. 279 | """ 280 | if random.random() < self.p: 281 | if lbl is not None: 282 | return F.hflip(img), F.hflip(lbl) 283 | else: 284 | return F.hflip(img) 285 | if lbl is not None: 286 | return img, lbl 287 | else: 288 | return img 289 | 290 | def __repr__(self): 291 | return self.__class__.__name__ + '(p={})'.format(self.p) 292 | 293 | 294 | class RandomVerticalFlip(object): 295 | """Vertically flip the given PIL Image randomly with a given probability. 296 | 297 | Args: 298 | p (float): probability of the image being flipped. Default value is 0.5 299 | """ 300 | 301 | def __init__(self, p=0.5): 302 | self.p = p 303 | 304 | def __call__(self, img, lbl): 305 | """ 306 | Args: 307 | img (PIL Image): Image to be flipped. 308 | lbl (PIL Image): Label to be flipped. 309 | 310 | Returns: 311 | PIL Image: Randomly flipped image. 312 | PIL Image: Randomly flipped label. 313 | """ 314 | if random.random() < self.p: 315 | if lbl is not None: 316 | return F.vflip(img), F.vflip(lbl) 317 | else: 318 | return F.vflip(img) 319 | if lbl is not None: 320 | return img, lbl 321 | else: 322 | return img 323 | 324 | def __repr__(self): 325 | return self.__class__.__name__ + '(p={})'.format(self.p) 326 | 327 | 328 | class ToTensor(object): 329 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 
330 | 331 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 332 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] 333 | if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 334 | or if the numpy.ndarray has dtype = np.uint8 335 | In the other cases, tensors are returned without scaling. 336 | 337 | """ 338 | 339 | def __call__(self, pic, lbl=None): 340 | """ 341 | Note that labels will not be normalized to [0, 1]. 342 | 343 | Args: 344 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 345 | lbl (PIL Image or numpy.ndarray): Label to be converted to tensor. 346 | Returns: 347 | Tensor: Converted image and label 348 | """ 349 | if lbl is not None: 350 | return F.to_tensor(pic), torch.from_numpy(np.array(lbl, dtype=np.uint8)) 351 | else: 352 | return F.to_tensor(pic) 353 | 354 | def __repr__(self): 355 | return self.__class__.__name__ + '()' 356 | 357 | 358 | class Normalize(object): 359 | """Normalize a tensor image with mean and standard deviation. 360 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 361 | will normalize each channel of the input ``torch.*Tensor`` i.e. 362 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 363 | 364 | Args: 365 | mean (sequence): Sequence of means for each channel. 366 | std (sequence): Sequence of standard deviations for each channel. 367 | """ 368 | 369 | def __init__(self, mean, std): 370 | self.mean = mean 371 | self.std = std 372 | 373 | def __call__(self, tensor, lbl=None): 374 | """ 375 | Args: 376 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 377 | tensor (Tensor): Tensor of label. A dummy input for ExtCompose 378 | Returns: 379 | Tensor: Normalized Tensor image. 380 | Tensor: Unchanged Tensor label 381 | """ 382 | if lbl is not None: 383 | return F.normalize(tensor, self.mean, self.std), lbl 384 | else: 385 | return F.normalize(tensor, self.mean, self.std) 386 | 387 | def __repr__(self): 388 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 389 | 390 | 391 | class RandomCrop(object): 392 | """Crop the given PIL Image at a random location. 393 | 394 | Args: 395 | size (sequence or int): Desired output size of the crop. If size is an 396 | int instead of sequence like (h, w), a square crop (size, size) is 397 | made. 398 | padding (int or sequence, optional): Optional padding on each border 399 | of the image. Default is 0, i.e no padding. If a sequence of length 400 | 4 is provided, it is used to pad left, top, right, bottom borders 401 | respectively. 402 | pad_if_needed (boolean): It will pad the image if smaller than the 403 | desired size to avoid raising an exception. 404 | """ 405 | 406 | def __init__(self, size, padding=0, pad_if_needed=False): 407 | if isinstance(size, numbers.Number): 408 | self.size = (int(size), int(size)) 409 | else: 410 | self.size = size 411 | self.padding = padding 412 | self.pad_if_needed = pad_if_needed 413 | 414 | @staticmethod 415 | def get_params(img, output_size): 416 | """Get parameters for ``crop`` for a random crop. 417 | 418 | Args: 419 | img (PIL Image): Image to be cropped. 420 | output_size (tuple): Expected output size of the crop. 421 | 422 | Returns: 423 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 
424 | """ 425 | w, h = img.size 426 | th, tw = output_size 427 | if w == tw and h == th: 428 | return 0, 0, h, w 429 | 430 | i = random.randint(0, h - th) 431 | j = random.randint(0, w - tw) 432 | return i, j, th, tw 433 | 434 | def __call__(self, img, lbl=None): 435 | """ 436 | Args: 437 | img (PIL Image): Image to be cropped. 438 | lbl (PIL Image): Label to be cropped. 439 | Returns: 440 | PIL Image: Cropped image. 441 | PIL Image: Cropped label. 442 | """ 443 | if lbl is None: 444 | if self.padding > 0: 445 | img = F.pad(img, self.padding) 446 | # pad the width if needed 447 | if self.pad_if_needed and img.size[0] < self.size[1]: 448 | img = F.pad(img, padding=int((1 + self.size[1] - img.size[0]) / 2)) 449 | # pad the height if needed 450 | if self.pad_if_needed and img.size[1] < self.size[0]: 451 | img = F.pad(img, padding=int((1 + self.size[0] - img.size[1]) / 2)) 452 | 453 | i, j, h, w = self.get_params(img, self.size) 454 | 455 | return F.crop(img, i, j, h, w) 456 | 457 | else: 458 | assert img.size == lbl.size, 'size of img and lbl should be the same. %s, %s' % (img.size, lbl.size) 459 | if self.padding > 0: 460 | img = F.pad(img, self.padding) 461 | lbl = F.pad(lbl, self.padding) 462 | 463 | # pad the width if needed 464 | if self.pad_if_needed and img.size[0] < self.size[1]: 465 | img = F.pad(img, padding=int((1 + self.size[1] - img.size[0]) / 2)) 466 | lbl = F.pad(lbl, padding=int((1 + self.size[1] - lbl.size[0]) / 2)) 467 | 468 | # pad the height if needed 469 | if self.pad_if_needed and img.size[1] < self.size[0]: 470 | img = F.pad(img, padding=int((1 + self.size[0] - img.size[1]) / 2)) 471 | lbl = F.pad(lbl, padding=int((1 + self.size[0] - lbl.size[1]) / 2)) 472 | 473 | i, j, h, w = self.get_params(img, self.size) 474 | 475 | return F.crop(img, i, j, h, w), F.crop(lbl, i, j, h, w) 476 | 477 | def __repr__(self): 478 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) 479 | 480 | 481 | class RandomResizedCrop(object): 482 | """Crop the given PIL Image to random size and aspect ratio. 483 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 484 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 485 | is finally resized to given size. 486 | This is popularly used to train the Inception networks. 487 | Args: 488 | size: expected output size of each edge 489 | scale: range of size of the origin size cropped 490 | ratio: range of aspect ratio of the origin aspect ratio cropped 491 | interpolation: Default: PIL.Image.BILINEAR 492 | """ 493 | 494 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): 495 | if isinstance(size, tuple): 496 | self.size = size 497 | else: 498 | self.size = (size, size) 499 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 500 | warnings.warn("range should be of kind (min, max)") 501 | 502 | self.interpolation = interpolation 503 | self.scale = scale 504 | self.ratio = ratio 505 | 506 | @staticmethod 507 | def get_params(img, scale, ratio): 508 | """Get parameters for ``crop`` for a random sized crop. 509 | Args: 510 | img (PIL Image): Image to be cropped. 511 | scale (tuple): range of size of the origin size cropped 512 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 513 | Returns: 514 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 515 | sized crop. 
516 | """ 517 | area = img.size[0] * img.size[1] 518 | 519 | for attempt in range(10): 520 | target_area = random.uniform(*scale) * area 521 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 522 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 523 | 524 | w = int(round(math.sqrt(target_area * aspect_ratio))) 525 | h = int(round(math.sqrt(target_area / aspect_ratio))) 526 | 527 | if w <= img.size[0] and h <= img.size[1]: 528 | i = random.randint(0, img.size[1] - h) 529 | j = random.randint(0, img.size[0] - w) 530 | return i, j, h, w 531 | 532 | # Fallback to central crop 533 | in_ratio = img.size[0] / img.size[1] 534 | if (in_ratio < min(ratio)): 535 | w = img.size[0] 536 | h = int(round(w / min(ratio))) 537 | elif (in_ratio > max(ratio)): 538 | h = img.size[1] 539 | w = int(round(h * max(ratio))) 540 | else: # whole image 541 | w = img.size[0] 542 | h = img.size[1] 543 | i = (img.size[1] - h) // 2 544 | j = (img.size[0] - w) // 2 545 | return i, j, h, w 546 | 547 | def __call__(self, img, lbl=None): 548 | """ 549 | Args: 550 | img (PIL Image): Image to be cropped and resized. 551 | Returns: 552 | PIL Image: Randomly cropped and resized image. 553 | """ 554 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 555 | if lbl is not None: 556 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), \ 557 | F.resized_crop(lbl, i, j, h, w, self.size, Image.NEAREST) 558 | else: 559 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) 560 | 561 | def __repr__(self): 562 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 563 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 564 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 565 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) 566 | format_string += ', interpolation={0})'.format(interpolate_str) 567 | return format_string 568 | 569 | 570 | class ColorJitter(object): 571 | """Randomly change the brightness, contrast and saturation of an image. 572 | Args: 573 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 574 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 575 | or the given [min, max]. Should be non negative numbers. 576 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 577 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 578 | or the given [min, max]. Should be non negative numbers. 579 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 580 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 581 | or the given [min, max]. Should be non negative numbers. 582 | hue (float or tuple of float (min, max)): How much to jitter hue. 583 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 584 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 
585 | """ 586 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 587 | self.brightness = self._check_input(brightness, 'brightness') 588 | self.contrast = self._check_input(contrast, 'contrast') 589 | self.saturation = self._check_input(saturation, 'saturation') 590 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 591 | clip_first_on_zero=False) 592 | 593 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 594 | if isinstance(value, numbers.Number): 595 | if value < 0: 596 | raise ValueError("If {} is a single number, it must be non negative.".format(name)) 597 | value = [center - value, center + value] 598 | if clip_first_on_zero: 599 | value[0] = max(value[0], 0) 600 | elif isinstance(value, (tuple, list)) and len(value) == 2: 601 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 602 | raise ValueError("{} values should be between {}".format(name, bound)) 603 | else: 604 | raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) 605 | 606 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 607 | # or (0., 0.) for hue, do nothing 608 | if value[0] == value[1] == center: 609 | value = None 610 | return value 611 | 612 | @staticmethod 613 | def get_params(brightness, contrast, saturation, hue): 614 | """Get a randomized transform to be applied on image. 615 | Arguments are same as that of __init__. 616 | Returns: 617 | Transform which randomly adjusts brightness, contrast and 618 | saturation in a random order. 619 | """ 620 | transforms = [] 621 | 622 | if brightness is not None: 623 | brightness_factor = random.uniform(brightness[0], brightness[1]) 624 | transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 625 | 626 | if contrast is not None: 627 | contrast_factor = random.uniform(contrast[0], contrast[1]) 628 | transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 629 | 630 | if saturation is not None: 631 | saturation_factor = random.uniform(saturation[0], saturation[1]) 632 | transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 633 | 634 | if hue is not None: 635 | hue_factor = random.uniform(hue[0], hue[1]) 636 | transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) 637 | 638 | random.shuffle(transforms) 639 | transform = Compose(transforms) 640 | 641 | return transform 642 | 643 | def __call__(self, img, lbl=None): 644 | """ 645 | Args: 646 | img (PIL Image): Input image. 647 | Returns: 648 | PIL Image: Color jittered image. 
649 | """ 650 | transform = self.get_params(self.brightness, self.contrast, 651 | self.saturation, self.hue) 652 | if lbl is not None: 653 | return transform(img), lbl 654 | else: 655 | return transform(img) 656 | 657 | def __repr__(self): 658 | format_string = self.__class__.__name__ + '(' 659 | format_string += 'brightness={0}'.format(self.brightness) 660 | format_string += ', contrast={0}'.format(self.contrast) 661 | format_string += ', saturation={0}'.format(self.saturation) 662 | format_string += ', hue={0})'.format(self.hue) 663 | return format_string -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def group_images(dataset, labels): 6 | # Group images based on the label in LABELS (using labels not reordered) 7 | idxs = {lab: [] for lab in labels} 8 | 9 | labels_cum = labels + [0, 255] 10 | for i in range(len(dataset)): 11 | cls = np.unique(np.array(dataset[i][1])) 12 | if all(x in labels_cum for x in cls): 13 | for x in cls: 14 | if x in labels: 15 | idxs[x].append(i) 16 | return idxs 17 | 18 | 19 | def filter_images(dataset, labels, labels_old=None, overlap=True): 20 | # Filter images without any label in LABELS (using labels not reordered) 21 | idxs = [] 22 | 23 | if 0 in labels: 24 | labels.remove(0) 25 | 26 | print(f"Filtering images...") 27 | if labels_old is None: 28 | labels_old = [] 29 | labels_cum = labels + labels_old + [0, 255] 30 | 31 | if overlap: 32 | fil = lambda c: any(x in labels for x in cls) 33 | else: 34 | fil = lambda c: any(x in labels for x in cls) and all(x in labels_cum for x in c) 35 | 36 | for i in range(len(dataset)): 37 | cls = np.unique(np.array(dataset[i][1])) 38 | if fil(cls): 39 | idxs.append(i) 40 | if i % 1000 == 0: 41 | print(f"\t{i}/{len(dataset)} ...") 42 | return idxs 43 | 44 | 45 | class Subset(torch.utils.data.Dataset): 46 | """ 47 | Subset of a dataset at specified indices. 48 | Arguments: 49 | dataset (Dataset): The whole Dataset 50 | indices (sequence): Indices in the whole set selected for subset 51 | transform (callable): way to transform the images and the targets 52 | target_transform(callable): way to transform the target labels 53 | """ 54 | 55 | def __init__(self, dataset, indices, transform=None, target_transform=None): 56 | self.dataset = dataset 57 | self.indices = indices 58 | self.transform = transform 59 | self.target_transform = target_transform 60 | 61 | def __getitem__(self, idx): 62 | try: 63 | sample, target ,name = self.dataset[self.indices[idx]] 64 | except Exception as e: 65 | raise Exception( 66 | f"dataset = {len(self.dataset)}, indices = {len(self.indices)}, idx = {idx}, msg = {str(e)}" 67 | ) 68 | 69 | if self.transform is not None: 70 | sample, target = self.transform(sample, target) 71 | 72 | if self.target_transform is not None: 73 | target = self.target_transform(target) 74 | 75 | return sample, target ,name 76 | 77 | def viz_getter(self, idx): 78 | image_path, raw_image, sample, target = self.dataset.viz_getter(self.indices[idx]) 79 | if self.transform is not None: 80 | sample, target = self.transform(sample, target) 81 | if self.target_transform is not None: 82 | target = self.target_transform(target) 83 | 84 | return image_path, raw_image, sample, target 85 | 86 | def __len__(self): 87 | return len(self.indices) 88 | 89 | 90 | class MaskLabels: 91 | """ 92 | Use this class to mask labels that you don't want in your dataset. 
93 | Arguments: 94 | labels_to_keep (list): The list of labels to keep in the target images 95 | mask_value (int): The value to replace ignored values (def: 0) 96 | """ 97 | 98 | def __init__(self, labels_to_keep, mask_value=0): 99 | self.labels = labels_to_keep 100 | self.value = torch.tensor(mask_value, dtype=torch.uint8) 101 | 102 | def __call__(self, sample): 103 | # sample must be a tensor 104 | assert isinstance(sample, torch.Tensor), "Sample must be a tensor" 105 | 106 | sample.apply_(lambda t: t.apply_(lambda x: x if x in self.labels else self.value)) 107 | 108 | return sample 109 | -------------------------------------------------------------------------------- /dataset/voc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import copy 4 | 5 | import numpy as np 6 | import torch.utils.data as data 7 | import torchvision as tv 8 | from PIL import Image 9 | from torch import distributed 10 | 11 | from .utils import Subset, filter_images, group_images 12 | 13 | classes = { 14 | 0: 'background', 15 | 1: 'aeroplane', 16 | 2: 'bicycle', 17 | 3: 'bird', 18 | 4: 'boat', 19 | 5: 'bottle', 20 | 6: 'bus', 21 | 7: 'car', 22 | 8: 'cat', 23 | 9: 'chair', 24 | 10: 'cow', 25 | 11: 'diningtable', 26 | 12: 'dog', 27 | 13: 'horse', 28 | 14: 'motorbike', 29 | 15: 'person', 30 | 16: 'pottedplant', 31 | 17: 'sheep', 32 | 18: 'sofa', 33 | 19: 'train', 34 | 20: 'tvmonitor' 35 | } 36 | 37 | 38 | class VOCSegmentation(data.Dataset): 39 | """`Pascal VOC `_ Segmentation Dataset. 40 | Args: 41 | root (string): Root directory of the VOC Dataset. 42 | image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` 43 | is_aug (bool, optional): If you want to use the augmented train set or not (default is True) 44 | transform (callable, optional): A function/transform that takes in an PIL image 45 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 46 | """ 47 | 48 | def __init__(self, root, image_set='train', is_aug=True, transform=None): 49 | 50 | self.root = os.path.expanduser(root) 51 | self.year = "2012" 52 | 53 | self.transform = transform 54 | 55 | self.image_set = image_set 56 | voc_root = self.root 57 | splits_dir = os.path.join(voc_root, 'list') 58 | 59 | if not os.path.isdir(voc_root): 60 | raise RuntimeError( 61 | 'Dataset not found or corrupted.' + ' You can use download=True to download it' 62 | f'at location = {voc_root}' 63 | ) 64 | 65 | if is_aug and image_set == 'train': 66 | mask_dir = os.path.join(voc_root, 'SegmentationClassAug') 67 | assert os.path.exists(mask_dir), "SegmentationClassAug not found" 68 | split_f = os.path.join(splits_dir, 'train_aug.txt') 69 | else: 70 | split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') 71 | 72 | if not os.path.exists(split_f): 73 | raise ValueError( 74 | 'Wrong image_set entered! 
Please use image_set="train" ' 75 | 'or image_set="trainval" or image_set="val" ' 76 | f'{split_f}' 77 | ) 78 | 79 | # remove leading \n 80 | with open(os.path.join(split_f), "r") as f: 81 | file_names = [x[:-1].split(' ') for x in f.readlines()] 82 | 83 | # REMOVE FIRST SLASH OTHERWISE THE JOIN WILL start from root 84 | self.images = [ 85 | ( 86 | os.path.join(voc_root, "VOCdevkit/VOC2012", 87 | x[0][1:]), os.path.join(voc_root, x[1][1:]) 88 | ) for x in file_names 89 | ] 90 | # print(file_names) 91 | # self.sal_maps = [os.path.join(voc_root, "saliency_map", x + "")] 92 | 93 | def __getitem__(self, index): 94 | """ 95 | Args: 96 | index (int): Index 97 | Returns: 98 | tuple: (image, target) where target is the image segmentation. 99 | """ 100 | img = Image.open(self.images[index][0]).convert('RGB') 101 | target = Image.open(self.images[index][1]) 102 | if self.transform is not None: 103 | img, target = self.transform(img, target) 104 | 105 | return img, target 106 | 107 | def viz_getter(self, index): 108 | image_path = self.images[index][0] 109 | raw_image = Image.open(self.images[index][0]).convert('RGB') 110 | target = Image.open(self.images[index][1]) 111 | if self.transform is not None: 112 | img, target = self.transform(raw_image, target) 113 | else: 114 | img = copy.deepcopy(raw_image) 115 | return image_path, raw_image, img, target 116 | 117 | def __len__(self): 118 | return len(self.images) 119 | 120 | 121 | class VOCSegmentationIncremental(data.Dataset): 122 | 123 | def __init__( 124 | self, 125 | root, 126 | train=True, 127 | transform=None, 128 | labels=None, 129 | labels_old=None, 130 | idxs_path=None, 131 | masking=True, 132 | overlap=True, 133 | data_masking="current", 134 | test_on_val=False, 135 | **kwargs 136 | ): 137 | 138 | full_voc = VOCSegmentation(root, 'train' if train else 'val', is_aug=True, transform=None) 139 | 140 | self.labels = [] 141 | self.labels_old = [] 142 | 143 | if labels is not None: 144 | # store the labels 145 | labels_old = labels_old if labels_old is not None else [] 146 | 147 | self.__strip_zero(labels) 148 | self.__strip_zero(labels_old) 149 | 150 | assert not any( 151 | l in labels_old for l in labels 152 | ), "labels and labels_old must be disjoint sets" 153 | 154 | self.labels = [0] + labels 155 | self.labels_old = [0] + labels_old 156 | 157 | self.order = [0] + labels_old + labels 158 | 159 | # take index of images with at least one class in labels and all classes in labels+labels_old+[0,255] 160 | if idxs_path is not None and os.path.exists(idxs_path): 161 | idxs = np.load(idxs_path).tolist() 162 | else: 163 | idxs = filter_images(full_voc, labels, labels_old, overlap=overlap) 164 | if idxs_path is not None and distributed.get_rank() == 0: 165 | np.save(idxs_path, np.array(idxs, dtype=int)) 166 | 167 | if test_on_val: 168 | rnd = np.random.RandomState(1) 169 | rnd.shuffle(idxs) 170 | train_len = int(0.8 * len(idxs)) 171 | if train: 172 | idxs = idxs[:train_len] 173 | else: 174 | idxs = idxs[train_len:] 175 | 176 | #if train: 177 | # masking_value = 0 178 | #else: 179 | # masking_value = 255 180 | 181 | #self.inverted_order = {label: self.order.index(label) for label in self.order} 182 | #self.inverted_order[255] = masking_value 183 | 184 | masking_value = 0 # Future classes will be considered as background. 
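# Worked example (illustrative only; the label values below are hypothetical, not a setting taken from this repo):
# for a step with labels_old = [1, 2, 3] and labels = [5, 7], self.order = [0, 1, 2, 3, 5, 7] and the mapping built
# just below is inverted_order = {0: 0, 1: 1, 2: 2, 3: 3, 5: 4, 7: 5, 255: 255}. Under the default
# data_masking="current" branch further down, a pixel originally labelled 7 is remapped to class index 5,
# the ignore index 255 is kept as 255, and every other value (old classes, unseen/future classes) falls back
# to masking_value = 0, i.e. it is treated as background for this step.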
185 | self.inverted_order = {label: self.order.index(label) for label in self.order} 186 | self.inverted_order[255] = 255 187 | 188 | reorder_transform = tv.transforms.Lambda( 189 | lambda t: t.apply_( 190 | lambda x: self.inverted_order[x] if x in self.inverted_order else masking_value 191 | ) 192 | ) 193 | 194 | if masking: 195 | if data_masking == "current": 196 | tmp_labels = self.labels + [255] 197 | elif data_masking == "current+old": 198 | tmp_labels = labels_old + self.labels + [255] 199 | elif data_masking == "all": 200 | raise NotImplementedError( 201 | f"data_masking={data_masking} not yet implemented sorry not sorry." 202 | ) 203 | elif data_masking == "new": 204 | tmp_labels = self.labels 205 | masking_value = 255 206 | 207 | target_transform = tv.transforms.Lambda( 208 | lambda t: t. 209 | apply_(lambda x: self.inverted_order[x] if x in tmp_labels else masking_value) 210 | ) 211 | else: 212 | assert False 213 | target_transform = reorder_transform 214 | 215 | # make the subset of the dataset 216 | self.dataset = Subset(full_voc, idxs, transform, target_transform) 217 | else: 218 | self.dataset = full_voc 219 | 220 | def __getitem__(self, index): 221 | """ 222 | Args: 223 | index (int): Index 224 | Returns: 225 | tuple: (image, target) where target is the image segmentation. 226 | """ 227 | 228 | return self.dataset[index] 229 | 230 | def viz_getter(self, index): 231 | return self.dataset.viz_getter(index) 232 | 233 | def __len__(self): 234 | return len(self.dataset) 235 | 236 | @staticmethod 237 | def __strip_zero(labels): 238 | while 0 in labels: 239 | labels.remove(0) 240 | -------------------------------------------------------------------------------- /docs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/why19991/InSeg/d45ebeb795776d60373c4031b163e2efd693000b/docs/framework.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | from functools import partial 4 | 5 | import torch.nn as nn 6 | 7 | from modules import GlobalAvgPool2d, ResidualBlock 8 | 9 | from .util import try_index 10 | 11 | 12 | class ResNet(nn.Module): 13 | """Standard residual network 14 | 15 | Parameters 16 | ---------- 17 | structure : list of int 18 | Number of residual blocks in each of the four modules of the network 19 | bottleneck : bool 20 | If `True` use "bottleneck" residual blocks with 3 convolutions, otherwise use standard blocks 21 | norm_act : callable 22 | Function to create normalization / activation Module 23 | classes : int 24 | If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end 25 | of the network 26 | dilation : int or list of int 27 | List of dilation factors for the four modules of the network, or `1` to ignore dilation 28 | keep_outputs : bool 29 | If `True` output a list with the outputs of all modules 30 | """ 31 | 32 | def __init__( 33 | self, 34 | structure, 35 | bottleneck, 36 | norm_act=nn.BatchNorm2d, 37 | classes=0, 38 | output_stride=16, 39 | keep_outputs=False 40 | ): 41 | super(ResNet, self).__init__() 42 | self.structure = structure 43 | 
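# Note: keep_outputs is re-assigned to True two lines below, so forward() always returns the full list of
# per-stage feature maps together with the per-stage attention tensors, regardless of the value passed to
# the constructor; the U-Net style decoders in modules/unet.py consume that list as x[0]..x[4].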
self.bottleneck = bottleneck 44 | self.keep_outputs = keep_outputs 45 | self.keep_outputs = True 46 | 47 | if len(structure) != 4: 48 | raise ValueError("Expected a structure with four values") 49 | if output_stride != 8 and output_stride != 16: 50 | raise ValueError("Output stride must be 8 or 16") 51 | 52 | if output_stride == 16: 53 | dilation = [1, 1, 1, 2] # dilated conv for last 3 blocks (9 layers) 54 | elif output_stride == 8: 55 | dilation = [1, 1, 2, 4] # 23+3 blocks (78 layers) 56 | else: 57 | raise NotImplementedError 58 | 59 | self.dilation = dilation 60 | 61 | # Initial layers 62 | layers = [ 63 | ("conv1", nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)), ("bn1", norm_act(64)) 64 | ] 65 | if try_index(dilation, 0) == 1: 66 | layers.append(("pool1", nn.MaxPool2d(3, stride=2, padding=1))) 67 | self.mod1 = nn.Sequential(OrderedDict(layers)) 68 | 69 | # Groups of residual blocks 70 | in_channels = 64 71 | if self.bottleneck: 72 | channels = (64, 64, 256) 73 | else: 74 | channels = (64, 64) 75 | for mod_id, num in enumerate(structure): 76 | # Create blocks for module 77 | blocks = [] 78 | for block_id in range(num): 79 | stride, dil = self._stride_dilation(dilation, mod_id, block_id) 80 | blocks.append( 81 | ( 82 | "block%d" % (block_id + 1), 83 | ResidualBlock( 84 | in_channels, 85 | channels, 86 | norm_act=norm_act, 87 | stride=stride, 88 | dilation=dil, 89 | last=block_id == num - 1, 90 | block_id=block_id 91 | ) 92 | ) 93 | ) 94 | 95 | # Update channels and p_keep 96 | in_channels = channels[-1] 97 | 98 | # Create module 99 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) 100 | 101 | # Double the number of channels for the next module 102 | channels = [c * 2 for c in channels] 103 | 104 | self.out_channels = in_channels 105 | 106 | # Pooling and predictor 107 | if classes != 0: 108 | self.classifier = nn.Sequential( 109 | OrderedDict( 110 | [("avg_pool", GlobalAvgPool2d()), ("fc", nn.Linear(in_channels, classes))] 111 | ) 112 | ) 113 | 114 | @staticmethod 115 | def _stride_dilation(dilation, mod_id, block_id): 116 | d = try_index(dilation, mod_id) 117 | s = 2 if d == 1 and block_id == 0 and mod_id > 0 else 1 118 | return s, d 119 | 120 | def forward(self, x): 121 | outs = [] 122 | attentions = [] 123 | 124 | x = self.mod1(x) 125 | #attentions.append(x) 126 | outs.append(x) 127 | 128 | x, att = self.mod2(x) 129 | attentions.append(att) 130 | outs.append(x) 131 | 132 | x, att = self.mod3(x) 133 | attentions.append(att) 134 | outs.append(x) 135 | 136 | x, att = self.mod4(x) 137 | attentions.append(att) 138 | outs.append(x) 139 | 140 | x, att = self.mod5(x) 141 | attentions.append(att) 142 | outs.append(x) 143 | 144 | if hasattr(self, "classifier"): 145 | outs.append(self.classifier(outs[-1])) 146 | 147 | if self.keep_outputs: 148 | return outs, attentions 149 | else: 150 | return outs[-1], attentions 151 | 152 | 153 | _NETS = { 154 | "18": { 155 | "structure": [2, 2, 2, 2], 156 | "bottleneck": False 157 | }, 158 | "34": { 159 | "structure": [3, 4, 6, 3], 160 | "bottleneck": False 161 | }, 162 | "50": { 163 | "structure": [3, 4, 6, 3], 164 | "bottleneck": True 165 | }, 166 | "101": { 167 | "structure": [3, 4, 23, 3], 168 | "bottleneck": True 169 | }, 170 | "152": { 171 | "structure": [3, 8, 36, 3], 172 | "bottleneck": True 173 | }, 174 | } 175 | 176 | __all__ = [] 177 | for name, params in _NETS.items(): 178 | net_name = "net_resnet" + name 179 | setattr(sys.modules[__name__], net_name, partial(ResNet, **params)) 180 | __all__.append(net_name) 181 
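A minimal smoke-test sketch of how the net_resnet* factories defined above can be instantiated (an illustration under stated assumptions, not a script from this repository). ResidualBlock.forward reads `.activation` and `.activation_param` from the normalization module, so `norm_act` is expected to behave like an activated batch norm (e.g. InPlaceABN); the tiny `ABN` class below is a hypothetical stand-in used only to keep the sketch self-contained.

import torch
import torch.nn as nn

from models import net_resnet101


class ABN(nn.BatchNorm2d):
    # Hypothetical stand-in for an activated batch norm: a plain BatchNorm2d that simply
    # records which activation ResidualBlock.forward should apply after the residual sum.
    def __init__(self, num_features, activation="leaky_relu", activation_param=0.01):
        super().__init__(num_features)
        self.activation = activation
        self.activation_param = activation_param


body = net_resnet101(norm_act=ABN, output_stride=16)   # structure=[3, 4, 23, 3], bottleneck=True
outs, attentions = body(torch.randn(1, 3, 512, 512))   # keep_outputs is forced to True in __init__
# outs should contain five feature maps (strides 4, 4, 8, 16, 16 for output_stride=16);
# attentions holds the pre-activation outputs of the last block in mod2..mod5.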
| -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from functools import partial 7 | 8 | 9 | def try_index(scalar_or_list, i): 10 | try: 11 | return scalar_or_list[i] 12 | except TypeError: 13 | return scalar_or_list 14 | 15 | class Upsample(nn.Module): 16 | def __init__(self, size=None, scale_factor=None, mode='bilinear', align_corners=False): 17 | super(Upsample, self).__init__() 18 | self.align_corners = align_corners 19 | self.mode = mode 20 | self.scale_factor = scale_factor 21 | self.size = size 22 | 23 | def forward(self, x): 24 | return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, 25 | align_corners=self.align_corners) 26 | 27 | 28 | class DecoderBlock(nn.Module): 29 | def __init__(self, in_channels, n_filters, norm_act,use_transpose=True): 30 | super(DecoderBlock, self).__init__() 31 | if use_transpose: 32 | self.up_op = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 3, stride=2, padding=1, output_padding=1) 33 | else: 34 | self.up_op = Upsample(scale_factor=2, align_corners=True) 35 | 36 | self.conv1 = nn.Conv2d(in_channels, in_channels // 2, 1) 37 | self.norm1 = norm_act(in_channels//2) 38 | # self.norm1 = nn.BatchNorm2d(in_channels // 2) 39 | # self.relu1 = nn.ReLU(inplace=True) 40 | 41 | # self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1) 42 | self.norm2 = norm_act(in_channels // 2) 43 | 44 | self.conv3 = nn.Conv2d(in_channels // 2, n_filters, 1) 45 | self.norm3 = norm_act(n_filters) 46 | 47 | def forward(self, x): 48 | x = self.conv1(x) 49 | x = self.norm1(x) 50 | x = self.up_op(x) 51 | x = self.norm2(x) 52 | x = self.conv3(x) 53 | x = self.norm3(x) 54 | return x -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .deeplab import DeeplabV3 2 | from .residual import IdentityResidualBlock, ResidualBlock 3 | from .misc import GlobalAvgPool2d 4 | -------------------------------------------------------------------------------- /modules/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as functional 4 | 5 | from models.util import try_index 6 | 7 | 8 | class DeeplabV3(nn.Module): 9 | def __init__(self, 10 | in_channels, 11 | out_channels, 12 | hidden_channels=256, 13 | out_stride=16, 14 | norm_act=nn.BatchNorm2d, 15 | pooling_size=None): 16 | super(DeeplabV3, self).__init__() 17 | self.pooling_size = pooling_size 18 | 19 | if out_stride == 16: 20 | dilations = [6, 12, 18] 21 | elif out_stride == 8: 22 | dilations = [12, 24, 32] 23 | 24 | self.map_convs = nn.ModuleList([ 25 | nn.Conv2d(in_channels, hidden_channels, 1, bias=False), 26 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]), 27 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]), 28 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2]) 29 | ]) 30 | self.map_bn = norm_act(hidden_channels * 4) 31 | 32 | self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, 
bias=False) 33 | self.global_pooling_bn = norm_act(hidden_channels) 34 | 35 | self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False) 36 | self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False) 37 | self.red_bn = norm_act(out_channels) 38 | 39 | self.reset_parameters(self.map_bn.activation, self.map_bn.activation_param) 40 | 41 | def reset_parameters(self, activation, slope): 42 | gain = nn.init.calculate_gain(activation, slope) 43 | for m in self.modules(): 44 | if isinstance(m, nn.Conv2d): 45 | nn.init.xavier_normal_(m.weight.data, gain) 46 | if hasattr(m, "bias") and m.bias is not None: 47 | nn.init.constant_(m.bias, 0) 48 | elif isinstance(m, nn.BatchNorm2d): 49 | if hasattr(m, "weight") and m.weight is not None: 50 | nn.init.constant_(m.weight, 1) 51 | if hasattr(m, "bias") and m.bias is not None: 52 | nn.init.constant_(m.bias, 0) 53 | 54 | def forward(self, x): 55 | # Map convolutions 56 | out = torch.cat([m(x) for m in self.map_convs], dim=1) 57 | out = self.map_bn(out) 58 | out = self.red_conv(out) 59 | 60 | # Global pooling 61 | pool = self._global_pooling(x) # if training is global avg pooling 1x1, else use larger pool size 62 | pool = self.global_pooling_conv(pool) 63 | pool = self.global_pooling_bn(pool) 64 | pool = self.pool_red_conv(pool) 65 | if self.training or self.pooling_size is None: 66 | pool = pool.repeat(1, 1, x.size(2), x.size(3)) 67 | 68 | out += pool 69 | out = self.red_bn(out) 70 | return out 71 | 72 | def _global_pooling(self, x): 73 | if self.training or self.pooling_size is None: 74 | # this is like Adaptive Average Pooling (1,1) 75 | pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1) 76 | pool = pool.view(x.size(0), x.size(1), 1, 1) 77 | else: 78 | pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]), 79 | min(try_index(self.pooling_size, 1), x.shape[3])) 80 | padding = ( 81 | (pooling_size[1] - 1) // 2, 82 | (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1, 83 | (pooling_size[0] - 1) // 2, 84 | (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1 85 | ) 86 | 87 | pool = functional.avg_pool2d(x, pooling_size, stride=1) 88 | pool = functional.pad(pool, pad=padding, mode="replicate") 89 | return pool 90 | -------------------------------------------------------------------------------- /modules/fca_cid.py: -------------------------------------------------------------------------------- 1 | def FCA(labels=None, features=None, centroids=None, unknown_label=None, features_counter=None,strategy='running'): 2 | b, h, w = labels.shape 3 | labels_down = labels.unsqueeze(dim=1) 4 | cl_present = torch.unique(input=labels_down) 5 | for cl in cl_present: 6 | if cl > 0 and cl != 255: 7 | features_cl = features[(labels_down == cl).expand(-1, features.shape[1], -1, -1)].view(-1, features.shape[1]) 8 | if strategy == 'running': 9 | features_counter[cl] = features_counter[cl] + features_cl.shape[0] 10 | centroids[cl] = (centroids[cl] * features_counter[cl] + features_cl.detach().sum(dim=0)) / features_counter[cl] 11 | else: 12 | centroids[cl] = centroids[cl] * 0.999 + (1 - 0.999) * features_cl.detach().mean(dim=0) 13 | centroids = F.normalize(centroids, p=2, dim=1) 14 | features_bg = features[(labels_down == 0).expand(-1, features.shape[1], -1, -1)].view(-1, features.shape[1]) 15 | features_bg = F.normalize(features_bg, p=2, dim=1) 16 | re_labels = labels_down.view(-1) 17 | similarity = torch.matmul(features_bg.detach(), centroids.T) 18 | 
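# This module relies on `torch` and `torch.nn.functional` (as `F`), which are assumed to be imported at the
# top of the file. Each candidate background pixel's feature has been compared above against the class
# centroids (a cosine similarity, since both sides are L2-normalised); below, the similarities are averaged
# over the centroids, and pixels whose mean similarity is >= 0.8 are re-labelled with `unknown_label`,
# i.e. treated as the future class C_t+1 instead of background.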
similarity = similarity.mean(dim=-1) 19 | value, index = torch.sort(similarity, descending=True) 20 | fill_mask = torch.zeros_like(index) 21 | value_index = value >= 0.8 22 | fill_index = index[value_index] 23 | fill_mask[fill_index] = unknown_label # we set the label of unknown class as C_t+1 in our experiment 24 | re_labels[re_labels==0] = fill_mask 25 | re_labels = re_labels.view(b, h , w ) 26 | return re_labels, centroids 27 | 28 | def CID(outputs=None, outputs_old=None, nb_old_classes=None, nb_current_classes=None, nb_future_classes=None, labels=None): 29 | 30 | outputs = outputs.permute(0, 2, 3, 1).contiguous() 31 | b, h, w, c = outputs.shape 32 | outputs_old = outputs_old.permute(0, 2, 3, 1).contiguous() 33 | out_old = torch.zeros_like(outputs) 34 | labels_unique = torch.unique(labels) 35 | out_old[..., :nb_old_classes + nb_future_classes] = outputs_old[..., :] 36 | for cl in range(nb_old_classes, nb_current_classes): 37 | out_old[..., cl] = outputs_old[..., 0] * (labels==cl).squeeze(dim=-1) 38 | for j in range(nb_future_classes): 39 | out_old[..., cl] = out_old[..., cl] + outputs_old[..., nb_old_classes + j] * (labels==cl).squeeze(dim=-1) 40 | # out_old[..., nb_old_classes+j] = out_old[..., cl] + outputs_old[..., nb_old_classes + j] * (labels == cl).squeeze(dim=-1) 41 | out_old[..., 0] = (labels != cl).squeeze(dim=-1) * out_old[..., 0] 42 | # out_old[..., :nb_old_classes+nb_future_classes] = outputs_old[..., :]* ((labels < nb_old_classes) + (labels == 255)).unsqueeze(dim=-1) 43 | out_old = torch.log_softmax(out_old, dim=-1) 44 | outputs = torch.softmax(outputs, dim=-1) 45 | # out = (out_old * outputs * ((labels= nb_old_classes)).unsqueeze(dim=-1).expand(-1, -1, -1,c)).sum(dim=-1) / c 49 | 50 | return - torch.mean(out) 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /modules/misc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class GlobalAvgPool2d(nn.Module): 5 | def __init__(self): 6 | """Global average pooling over the input's spatial dimensions""" 7 | super(GlobalAvgPool2d, self).__init__() 8 | 9 | def forward(self, inputs): 10 | in_size = inputs.size() 11 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) 12 | 13 | -------------------------------------------------------------------------------- /modules/residual.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as functional 5 | import torch 6 | 7 | class GroupBatchnorm2d(nn.Module): 8 | def __init__(self, c_num: int, 9 | group_num: int = 16, 10 | eps: float = 1e-10 11 | ): 12 | super(GroupBatchnorm2d, self).__init__() 13 | assert c_num >= group_num 14 | self.group_num = group_num 15 | self.weight = nn.Parameter(torch.randn(c_num, 1, 1)) 16 | self.bias = nn.Parameter(torch.zeros(c_num, 1, 1)) 17 | self.eps = eps 18 | 19 | def forward(self, x): 20 | N, C, H, W = x.size() 21 | x = x.view(N, self.group_num, -1) 22 | mean = x.mean(dim=2, keepdim=True) 23 | std = x.std(dim=2, keepdim=True) 24 | x = (x - mean) / (std + self.eps) 25 | x = x.view(N, C, H, W) 26 | return x * self.weight + self.bias 27 | 28 | 29 | class SRU(nn.Module): 30 | def __init__(self, 31 | oup_channels: int, 32 | group_num: int = 16, 33 | gate_treshold: float = 0.5, 34 | torch_gn: bool = False 35 | ): 36 | super().__init__() 37 | 38 | self.gn = nn.GroupNorm(num_channels=oup_channels, 
num_groups=group_num) if torch_gn else GroupBatchnorm2d( 39 | c_num=oup_channels, group_num=group_num) 40 | self.gate_treshold = gate_treshold 41 | self.sigomid = nn.Sigmoid() 42 | 43 | def forward(self, x): 44 | gn_x = self.gn(x) 45 | w_gamma = self.gn.weight / torch.sum(self.gn.weight) 46 | w_gamma = w_gamma.view(1, -1, 1, 1) 47 | reweigts = self.sigomid(gn_x * w_gamma) 48 | # Gate 49 | info_mask = reweigts >= self.gate_treshold 50 | noninfo_mask = reweigts < self.gate_treshold 51 | x_1 = info_mask * gn_x 52 | x_2 = noninfo_mask * gn_x 53 | x = self.reconstruct(x_1, x_2) 54 | return x 55 | 56 | def reconstruct(self, x_1, x_2): 57 | x_11, x_12 = torch.split(x_1, x_1.size(1) // 2, dim=1) 58 | x_21, x_22 = torch.split(x_2, x_2.size(1) // 2, dim=1) 59 | return torch.cat([x_11 + x_22, x_12 + x_21], dim=1) 60 | 61 | class ResidualBlock(nn.Module): 62 | """Configurable residual block 63 | 64 | Parameters 65 | ---------- 66 | in_channels : int 67 | Number of input channels. 68 | channels : list of int 69 | Number of channels in the internal feature maps. Can either have two or three elements: if three construct 70 | a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then 71 | `3 x 3` then `1 x 1` convolutions. 72 | stride : int 73 | Stride of the first `3 x 3` convolution 74 | dilation : int 75 | Dilation to apply to the `3 x 3` convolutions. 76 | groups : int 77 | Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with 78 | bottleneck blocks. 79 | norm_act : callable 80 | Function to create normalization / activation Module. 81 | dropout: callable 82 | Function to create Dropout Module. 83 | """ 84 | 85 | def __init__( 86 | self, 87 | in_channels, 88 | channels, 89 | stride=1, 90 | dilation=1, 91 | groups=1, 92 | norm_act=nn.BatchNorm2d, 93 | dropout=None, 94 | last=False, 95 | block_id=None 96 | ): 97 | super(ResidualBlock, self).__init__() 98 | 99 | self.block_id = block_id 100 | 101 | # Check parameters for inconsistencies 102 | if len(channels) != 2 and len(channels) != 3: 103 | raise ValueError("channels must contain either two or three values") 104 | if len(channels) == 2 and groups != 1: 105 | raise ValueError("groups > 1 are only valid if len(channels) == 3") 106 | 107 | is_bottleneck = len(channels) == 3 108 | need_proj_conv = stride != 1 or in_channels != channels[-1] 109 | 110 | if not is_bottleneck: 111 | bn2 = norm_act(channels[1]) 112 | bn2.activation = "identity" 113 | layers = [ 114 | ( 115 | "conv1", 116 | nn.Conv2d( 117 | in_channels, 118 | channels[0], 119 | 3, 120 | stride=stride, 121 | padding=dilation, 122 | bias=False, 123 | dilation=dilation 124 | ) 125 | ), ("bn1", norm_act(channels[0])), 126 | ( 127 | "conv2", 128 | nn.Conv2d( 129 | channels[0], 130 | channels[1], 131 | 3, 132 | stride=1, 133 | padding=dilation, 134 | bias=False, 135 | dilation=dilation 136 | ) 137 | ), ("bn2", bn2) 138 | ] 139 | if dropout is not None: 140 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:] 141 | else: 142 | bn3 = norm_act(channels[2]) 143 | bn3.activation = "identity" 144 | layers = [ 145 | ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=1, padding=0, bias=False)), 146 | ("bn1", norm_act(channels[0])), 147 | ( 148 | "conv2", 149 | nn.Conv2d(channels[0],channels[1],3,stride=stride,padding=dilation,bias=False,groups=groups,dilation=dilation) 150 | ), 151 | ("bn2", norm_act(channels[1])), 152 | ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, 
bias=False)), 153 | ("bn3", bn3) 154 | ] 155 | 156 | if dropout is not None: 157 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:] 158 | self.convs = nn.Sequential(OrderedDict(layers)) 159 | 160 | if need_proj_conv: 161 | self.proj_conv = nn.Conv2d( 162 | in_channels, channels[-1], 1, stride=stride, padding=0, bias=False 163 | ) 164 | self.proj_bn = norm_act(channels[-1]) 165 | self.proj_bn.activation = "identity" 166 | 167 | self._last = last 168 | 169 | def forward(self, x): 170 | if hasattr(self, "proj_conv"): 171 | residual = self.proj_conv(x) 172 | residual = self.proj_bn(residual) 173 | else: 174 | residual = x 175 | 176 | x = self.convs(x) + residual 177 | 178 | if self.convs.bn1.activation == "leaky_relu": 179 | act = functional.leaky_relu( 180 | x, negative_slope=self.convs.bn1.activation_param, inplace=not self._last 181 | ) 182 | elif self.convs.bn1.activation == "elu": 183 | act = functional.elu(x, alpha=self.convs.bn1.activation_param, inplace=not self._last) 184 | elif self.convs.bn1.activation == "identity": 185 | act = x 186 | 187 | if self._last: 188 | return act, x 189 | return act 190 | 191 | 192 | class IdentityResidualBlock(nn.Module): 193 | 194 | def __init__( 195 | self, 196 | in_channels, 197 | channels, 198 | stride=1, 199 | dilation=1, 200 | groups=1, 201 | norm_act=nn.BatchNorm2d, 202 | dropout=None 203 | ): 204 | """Configurable identity-mapping residual block 205 | 206 | Parameters 207 | ---------- 208 | in_channels : int 209 | Number of input channels. 210 | channels : list of int 211 | Number of channels in the internal feature maps. Can either have two or three elements: if three construct 212 | a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then 213 | `3 x 3` then `1 x 1` convolutions. 214 | stride : int 215 | Stride of the first `3 x 3` convolution 216 | dilation : int 217 | Dilation to apply to the `3 x 3` convolutions. 218 | groups : int 219 | Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with 220 | bottleneck blocks. 221 | norm_act : callable 222 | Function to create normalization / activation Module. 223 | dropout: callable 224 | Function to create Dropout Module. 
225 | """ 226 | super(IdentityResidualBlock, self).__init__() 227 | 228 | # Check parameters for inconsistencies 229 | if len(channels) != 2 and len(channels) != 3: 230 | raise ValueError("channels must contain either two or three values") 231 | if len(channels) == 2 and groups != 1: 232 | raise ValueError("groups > 1 are only valid if len(channels) == 3") 233 | 234 | is_bottleneck = len(channels) == 3 235 | need_proj_conv = stride != 1 or in_channels != channels[-1] 236 | 237 | self.bn1 = norm_act(in_channels) 238 | if not is_bottleneck: 239 | layers = [ 240 | ( 241 | "conv1", 242 | nn.Conv2d( 243 | in_channels, 244 | channels[0], 245 | 3, 246 | stride=stride, 247 | padding=dilation, 248 | bias=False, 249 | dilation=dilation 250 | ) 251 | ), ("bn2", norm_act(channels[0])), 252 | ( 253 | "conv2", 254 | nn.Conv2d( 255 | channels[0], 256 | channels[1], 257 | 3, 258 | stride=1, 259 | padding=dilation, 260 | bias=False, 261 | dilation=dilation 262 | ) 263 | ) 264 | ] 265 | if dropout is not None: 266 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:] 267 | else: 268 | layers = [ 269 | ( 270 | "conv1", 271 | nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False) 272 | ), ("bn2", norm_act(channels[0])), 273 | ( 274 | "conv2", 275 | nn.Conv2d( 276 | channels[0], 277 | channels[1], 278 | 3, 279 | stride=1, 280 | padding=dilation, 281 | bias=False, 282 | groups=groups, 283 | dilation=dilation 284 | ) 285 | ), ("bn3", norm_act(channels[1])), 286 | ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False)) 287 | ] 288 | if dropout is not None: 289 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:] 290 | self.convs = nn.Sequential(OrderedDict(layers)) 291 | 292 | if need_proj_conv: 293 | self.proj_conv = nn.Conv2d( 294 | in_channels, channels[-1], 1, stride=stride, padding=0, bias=False 295 | ) 296 | 297 | def forward(self, x): 298 | if hasattr(self, "proj_conv"): 299 | bn1 = self.bn1(x) 300 | shortcut = self.proj_conv(bn1) 301 | else: 302 | shortcut = x.clone() 303 | bn1 = self.bn1(x) 304 | 305 | out = self.convs(bn1) 306 | out.add_(shortcut) 307 | 308 | return out 309 | -------------------------------------------------------------------------------- /modules/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as functional 4 | 5 | from models.util import try_index, DecoderBlock 6 | 7 | 8 | class Unet(nn.Module): 9 | def __init__(self, 10 | in_channels, 11 | filters, 12 | norm_act=nn.BatchNorm2d, 13 | ): 14 | 15 | super(Unet, self).__init__() 16 | 17 | 18 | self.u_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=filters[3], kernel_size=1) 19 | self.u_conv2 = nn.Conv2d(in_channels=filters[1], out_channels=filters[0], kernel_size=1) 20 | 21 | self.u_decoder4 = DecoderBlock(filters[3], filters[2],norm_act=norm_act) 22 | self.u_decoder3 = DecoderBlock(filters[2], filters[1],norm_act=norm_act) 23 | self.u_decoder2 = DecoderBlock(filters[0], filters[0],norm_act=norm_act) 24 | self.u_decoder1 = DecoderBlock(filters[0], filters[0],norm_act=norm_act) 25 | 26 | self.u_conv1_new = nn.Conv2d(in_channels=in_channels, out_channels=filters[3], kernel_size=1) 27 | self.u_conv2_new = nn.Conv2d(in_channels=filters[1], out_channels=filters[0], kernel_size=1) 28 | 29 | self.u_decoder4_new = DecoderBlock(filters[3], filters[2], norm_act=norm_act) 30 | self.u_decoder3_new = DecoderBlock(filters[2], filters[1], norm_act=norm_act) 31 | 
self.u_decoder2_new = DecoderBlock(filters[0], filters[0], norm_act=norm_act) 32 | self.u_decoder1_new = DecoderBlock(filters[0], filters[0], norm_act=norm_act) 33 | 34 | def forward(self, x): 35 | 36 | e = functional.leaky_relu_(self.u_conv1(x[4]),negative_slope=0.01) 37 | d4 = self.u_decoder4(e+x[3]) + x[2] 38 | d3 = self.u_decoder3(d4) + x[1] 39 | e = functional.leaky_relu_(self.u_conv2(d3),negative_slope=0.01) 40 | d2 = self.u_decoder2(e+x[0]) 41 | d1 = self.u_decoder1(d2) 42 | return functional.leaky_relu(d1,negative_slope=0.01) 43 | 44 | 45 | 46 | class Unet_2(nn.Module): 47 | def __init__(self, 48 | in_channels, 49 | filters, 50 | norm_act=nn.BatchNorm2d, 51 | ): 52 | 53 | super(Unet_2, self).__init__() 54 | 55 | 56 | self.u_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=filters[3], kernel_size=1) 57 | self.u_conv2 = nn.Conv2d(in_channels=filters[1], out_channels=filters[0], kernel_size=1) 58 | 59 | self.u_decoder4 = DecoderBlock(filters[3], filters[2],norm_act=norm_act) 60 | self.u_decoder3 = DecoderBlock(filters[2], filters[1],norm_act=norm_act) 61 | self.u_decoder2 = DecoderBlock(filters[0], filters[0],norm_act=norm_act) 62 | self.u_decoder1 = DecoderBlock(filters[0], filters[0],norm_act=norm_act) 63 | 64 | self.u_conv1_new = nn.Conv2d(in_channels=in_channels, out_channels=filters[3], kernel_size=1) 65 | self.u_conv2_new = nn.Conv2d(in_channels=filters[1], out_channels=filters[0], kernel_size=1) 66 | 67 | self.u_decoder4_new = DecoderBlock(filters[3], filters[2], norm_act=norm_act) 68 | self.u_decoder3_new = DecoderBlock(filters[2], filters[1], norm_act=norm_act) 69 | self.u_decoder2_new = DecoderBlock(filters[0], filters[0], norm_act=norm_act) 70 | self.u_decoder1_new = DecoderBlock(filters[0], filters[0], norm_act=norm_act) 71 | 72 | def forward(self, x): 73 | 74 | # Decoder 75 | e = functional.leaky_relu_(self.u_conv1(x[4]),negative_slope=0.01) 76 | d4 = self.u_decoder4(e+x[3]) + x[2] 77 | d3 = self.u_decoder3(d4) + x[1] 78 | e = functional.leaky_relu_(self.u_conv2(d3),negative_slope=0.01) 79 | d2 = self.u_decoder2(e+x[0]) 80 | d1 = self.u_decoder1(d2) 81 | 82 | e_new = functional.leaky_relu_(self.u_conv1_new(x[4]), negative_slope=0.01) 83 | d4_new = self.u_decoder4_new(e_new + x[3]) + x[2] 84 | d3_new = self.u_decoder3_new(d4_new) + x[1] 85 | e_new = functional.leaky_relu_(self.u_conv2_new(d3_new), negative_slope=0.01) 86 | d2_new = self.u_decoder2_new(e_new + x[0]) 87 | d1_new = self.u_decoder1_new(d2_new) 88 | 89 | # r = torch.rand(1, d1.shape[1], 1, 1, dtype=torch.float32) 90 | # if self.training == False: 91 | # r[:, :, :, :] = 1.0 92 | # weight_out_branch = torch.zeros_like(r) 93 | # weight_out_new_branch = torch.zeros_like(r) 94 | # weight_out_branch[r < 0.33] = 2. 95 | # weight_out_new_branch[r < 0.33] = 0. 96 | # weight_out_branch[(r < 0.66) * (r >= 0.33)] = 0. 97 | # weight_out_new_branch[(r < 0.66) * (r >= 0.33)] = 2. 98 | # weight_out_branch[r >= 0.66] = 1. 99 | # weight_out_new_branch[r >= 0.66] = 1. 
100 | # out = d1 * weight_out_branch.to(d1.device) * 0.5 + d1_new * weight_out_new_branch.to(d1_new.device) * 0.5 101 | 102 | out = d1 * 0.5 + d1_new * 0.5 103 | return functional.leaky_relu(out,negative_slope=0.01) 104 | 105 | class Unet_Intermediate(nn.Module): 106 | def __init__(self, 107 | in_channels, 108 | filters, 109 | norm_act=nn.BatchNorm2d, 110 | ): 111 | 112 | super(Unet_Intermediate, self).__init__() 113 | 114 | 115 | # self.u_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=filters[3], kernel_size=1) 116 | self.u_conv2 = nn.Conv2d(in_channels=filters[1], out_channels=filters[0], kernel_size=1) 117 | 118 | self.u_decoder4 = DecoderBlock(filters[3], filters[2],norm_act=norm_act) 119 | self.u_decoder3 = DecoderBlock(filters[2], filters[1],norm_act=norm_act) 120 | self.u_decoder2 = DecoderBlock(filters[0], filters[0],norm_act=norm_act) 121 | self.u_decoder1 = DecoderBlock(filters[0], filters[0],norm_act=norm_act) 122 | 123 | # self.u_conv1_new = nn.Conv2d(in_channels=in_channels, out_channels=filters[3], kernel_size=1) 124 | # self.u_conv2_new = nn.Conv2d(in_channels=filters[1], out_channels=filters[0], kernel_size=1) 125 | 126 | # self.u_decoder4_new = DecoderBlock(filters[3], filters[2], norm_act=norm_act) 127 | # self.u_decoder3_new = DecoderBlock(filters[2], filters[1], norm_act=norm_act) 128 | # self.u_decoder2_new = DecoderBlock(filters[0], filters[0], norm_act=norm_act) 129 | # self.u_decoder1_new = DecoderBlock(filters[0], filters[0], norm_act=norm_act) 130 | 131 | def forward(self, x): 132 | # print(x[0].shape) 133 | # print(x[1].shape) 134 | # print(x[2].shape) 135 | # print(x[3].shape) 136 | # print(x[4].shape) 137 | # Decoder 138 | # e = functional.leaky_relu_(self.u_conv1(x[4]),negative_slope=0.01) 139 | d4 = self.u_decoder4(x[3]) + x[2] 140 | d3 = self.u_decoder3(d4) + x[1] 141 | e = functional.leaky_relu_(self.u_conv2(d3),negative_slope=0.01) 142 | d2 = self.u_decoder2(e+x[0]) 143 | d1 = self.u_decoder1(d2) 144 | 145 | # e_new = functional.leaky_relu_(self.u_conv1_new(x[4]), negative_slope=0.01) 146 | # d4_new = self.u_decoder4_new(e_new + x[3]) + x[2] 147 | # d3_new = self.u_decoder3_new(d4_new) + x[1] 148 | # e_new = functional.leaky_relu_(self.u_conv2_new(d3_new), negative_slope=0.01) 149 | # d2_new = self.u_decoder2_new(e_new + x[0]) 150 | # d1_new = self.u_decoder1_new(d2_new) 151 | # 152 | # r = torch.rand(1, d1.shape[1], 1, 1, dtype=torch.float32) 153 | # if self.training == False: 154 | # r[:, :, :, :] = 1.0 155 | # weight_out_branch = torch.zeros_like(r) 156 | # weight_out_new_branch = torch.zeros_like(r) 157 | # weight_out_branch[r < 0.33] = 2. 158 | # weight_out_new_branch[r < 0.33] = 0. 159 | # weight_out_branch[(r < 0.66) * (r >= 0.33)] = 0. 160 | # weight_out_new_branch[(r < 0.66) * (r >= 0.33)] = 2. 161 | # weight_out_branch[r >= 0.66] = 1. 162 | # weight_out_new_branch[r >= 0.66] = 1. 163 | # out = d1 * weight_out_branch.to(d1.device) * 0.5 + d1_new * weight_out_new_branch.to(d1_new.device) * 0.5 164 | 165 | 166 | return functional.leaky_relu(d1,negative_slope=0.01) --------------------------------------------------------------------------------
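A minimal wiring sketch for the plain Unet decoder above (an illustration under stated assumptions, not the repository's training configuration). Unet.forward indexes a list of five backbone feature maps x[0]..x[4] and adds x[3], x[2], x[1], x[0] back as skip connections, so in_channels must match the channels of x[4] and filters must match the skip channels. Assuming a ResNet-101 style backbone whose stage outputs have 64/256/512/1024/2048 channels, one consistent configuration is:

import torch
import torch.nn as nn

from modules.unet import Unet

# Hypothetical channel/stride layout of the five backbone feature maps (for illustration only):
# x[0]: 64 ch @ stride 4, x[1]: 256 ch @ stride 4, x[2]: 512 ch @ stride 8,
# x[3]: 1024 ch @ stride 16, x[4]: 2048 ch @ stride 16
decoder = Unet(in_channels=2048, filters=[64, 256, 512, 1024], norm_act=nn.BatchNorm2d)

feats = [
    torch.randn(1, 64, 128, 128),    # x[0]
    torch.randn(1, 256, 128, 128),   # x[1]
    torch.randn(1, 512, 64, 64),     # x[2]
    torch.randn(1, 1024, 32, 32),    # x[3]
    torch.randn(1, 2048, 32, 32),    # x[4]
]
out = decoder(feats)   # each DecoderBlock upsamples by 2x, so out has shape (1, 64, 512, 512)

Unet_2 follows the same layout but runs a second `_new` branch with identical shapes and averages the two branch outputs (out = d1 * 0.5 + d1_new * 0.5).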