├── .gitignore ├── AAF.py ├── LICENSE ├── OnlineRetraining ├── LICENSE └── segm │ ├── __init__.py │ ├── config.py │ ├── config.yml │ ├── data │ ├── __init__.py │ ├── ade20k.py │ ├── base.py │ ├── cityscapes.py │ ├── coco.py │ ├── config │ │ ├── ade20k.py │ │ ├── ade20k.yml │ │ ├── cityscapes.py │ │ ├── cityscapes.yml │ │ ├── coco.py │ │ ├── coco.yml │ │ ├── pascal_context.py │ │ ├── pascal_context.yml │ │ ├── pascal_voc.py │ │ └── pascal_voc.yml │ ├── factory.py │ ├── imagenet.py │ ├── loader.py │ ├── pascal_context.py │ ├── pascal_voc.py │ └── utils.py │ ├── datasets │ └── coco.py │ ├── dist_test.sh │ ├── dist_train.sh │ ├── engine.py │ ├── eval │ ├── accuracy.py │ ├── densecrf.py │ ├── make_crf.py │ └── miou.py │ ├── inference.py │ ├── metrics.py │ ├── model │ ├── blocks.py │ ├── decoder.py │ ├── eva02.py │ ├── factory.py │ ├── rope.py │ ├── segmenter.py │ ├── utils.py │ └── vit.py │ ├── optim │ ├── factory.py │ ├── optim_factory.py │ └── scheduler.py │ ├── scripts │ ├── prepare_ade20k.py │ ├── prepare_cityscapes.py │ ├── prepare_pcontext.py │ └── show_attn_map.py │ ├── train.py │ └── utils │ ├── download.py │ ├── lines.py │ ├── logger.py │ ├── logs.py │ └── torch.py ├── README.md ├── coco ├── cls_labels.npy ├── train_1250_id.txt ├── train_id.txt ├── val_5000.txt └── val_id.txt ├── data ├── download_and_convert_coco.sh └── download_and_convert_voc12.sh ├── datasets.py ├── docs ├── Evaluate.md ├── Install.md ├── Training.md └── prepare_dataset.md ├── engine.py ├── evaluation.py ├── img ├── WeakTr.png ├── clip_grad_decoder.png └── miou_compare.png ├── main.py ├── models.py ├── requirements.txt ├── tool └── imutils.py ├── utils.py ├── vision_transformer.py └── voc12 ├── cls_labels.npy ├── train_aug_id.txt ├── train_id.txt └── val_id.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/
161 | .DS_Store
162 | 
163 | # run scripts
164 | run_coco.sh
165 | run.sh
166 | run_affinitynet.sh
167 | CenterPoints/exp_run.sh
168 | colorful.py
169 | *.json
170 | data/coco
171 | data/voc12
172 | mlruns/*
173 | *results*
174 | OnlineRetraining/mlruns
175 | OnlineRetraining/start*
176 | OnlineRetraining/seg*
177 | !OnlineRetraining/segm
-------------------------------------------------------------------------------- /AAF.py: --------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | 
4 | 
5 | # 1. use attention features to generate a dynamic weight
6 | class AAF(nn.Module):
7 |     def __init__(self, channel, reduction=16, feats_channel=64, feat_reduction=8, pool="avg"):
8 |         super().__init__()
9 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
10 |         if pool == "max":
11 |             self.avg_pool = nn.AdaptiveMaxPool2d(1)
12 |         self.attn_head_ffn = nn.Sequential(
13 |             nn.Linear(channel, int(channel / reduction), bias=False),
14 |             nn.ReLU(inplace=True),  # inplace=True can slightly decrease memory usage
15 |             # nn.Sigmoid(),
16 |             nn.Linear(int(channel / reduction), channel, bias=False),
17 |             nn.Sigmoid()
18 |         )
19 |         self.attn_feat_ffn = nn.Sequential(
20 |             nn.Linear(feats_channel, int(feats_channel / feat_reduction)),
21 |             nn.Linear(int(feats_channel / feat_reduction), 1),
22 |         )
23 | 
24 |     def forward_weight(self, x):
25 |         b, c, n, m = x.size()  # batch size, attn heads num=72, class token + patch tokens, embedding_dim=64
26 | 
27 |         # 1. pool over the tokens
28 |         x = x.permute(0, 1, 3, 2).contiguous().view(b, c*m, n, 1)
29 |         attn_feat_pool = self.avg_pool(x)
30 | 
31 |         # 2. FFN over the channels generates the dynamic weight
32 |         attn_feat_pool = attn_feat_pool.view(b*c, m)
33 |         attn_weight = self.attn_feat_ffn(attn_feat_pool)
34 | 
35 |         # 3. FFN over the attention heads generates the final weight
36 |         attn_weight = attn_weight.view(b, c)
37 |         attn_weight = self.attn_head_ffn(attn_weight).view(b, c, -1, 1)
38 | 
39 |         return attn_weight
40 | 
41 |     def forward(self, attn_feat, x):
42 |         weight = self.forward_weight(attn_feat)
43 |         return x * weight.expand_as(x), x * weight.expand_as(x)
44 | 
45 | 
46 | # 2. use a randomly initialized weight to generate the dynamic weight
47 | class AAF_RandWeight(AAF):
48 |     def __init__(self, channel, *args, **kwargs):
49 |         super().__init__(channel, *args, **kwargs)
50 |         self.query = torch.randn(1, channel, requires_grad=False).cuda()
51 | 
52 |     def forward_weight(self, x):
53 |         b, c, n, m = x.size()  # batch size, attn heads num=72, class token + patch tokens, embedding_dim=64
54 | 
55 |         attn_weight = self.attn_head_ffn(self.query.expand(b, -1)).unsqueeze(2).unsqueeze(3)
56 | 
57 |         return attn_weight
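The comments in AAF.py describe the module in terms of a stack of 72 attention heads with 64-dimensional head features. The minimal usage sketch below only illustrates the expected tensor shapes; the import path, the batch size, the 72-map stack (e.g. 12 layers x 6 heads), the 197-token sequence (1 class token + 14 x 14 patches), and the assumption that the second input is a stack of attention maps are illustrative choices, not values fixed by the module itself. Note that forward returns the weighted maps twice, so the second value is discarded here.

import torch

from AAF import AAF  # assumes AAF.py at the repository root is importable

batch, heads, tokens, head_dim = 2, 72, 1 + 14 * 14, 64

# channel = number of stacked attention heads, feats_channel = per-head embedding dim
aaf = AAF(channel=heads, reduction=16, feats_channel=head_dim, feat_reduction=8)

attn_feat = torch.randn(batch, heads, tokens, head_dim)  # per-head token features used to predict the weight
attn_maps = torch.randn(batch, heads, tokens, tokens)    # maps to be re-weighted (assumed attention-map stack)

# forward_weight yields a (batch, heads, 1, 1) weight that is broadcast over the maps
weighted, _ = aaf(attn_feat, attn_maps)
print(weighted.shape)  # torch.Size([2, 72, 197, 197])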
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 BAAI-Vision
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /OnlineRetraining/LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Robin Strudel
4 | Copyright (c) INRIA
5 | 
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 | 
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 | 
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | 
-------------------------------------------------------------------------------- /OnlineRetraining/segm/__init__.py: --------------------------------------------------------------------------------
1 | from segm.datasets.coco import COCODataset
-------------------------------------------------------------------------------- /OnlineRetraining/segm/config.py: --------------------------------------------------------------------------------
1 | import yaml
2 | from pathlib import Path
3 | 
4 | import os
5 | 
6 | 
7 | def load_config():
8 |     return yaml.load(
9 |         open(Path(__file__).parent / "config.yml", "r"), Loader=yaml.FullLoader
10 |     )
11 | 
12 | 
13 | def check_os_environ(key, use):
14 |     if key not in os.environ:
15 |         raise ValueError(
16 |             f"{key} is not defined in the os variables, it is required for {use}."
17 | ) 18 | 19 | 20 | def dataset_dir(): 21 | check_os_environ("DATASET", "data loading") 22 | return os.environ["DATASET"] 23 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/config.yml: -------------------------------------------------------------------------------- 1 | model: 2 | # deit 3 | deit_tiny_distilled_patch16_224: 4 | image_size: 224 5 | patch_size: 16 6 | d_model: 192 7 | n_heads: 3 8 | n_layers: 12 9 | normalization: deit 10 | distilled: true 11 | deit_small_distilled_patch16_224: 12 | image_size: 224 13 | patch_size: 16 14 | d_model: 384 15 | n_heads: 6 16 | n_layers: 12 17 | normalization: deit 18 | distilled: true 19 | deit_base_distilled_patch16_224: 20 | image_size: 224 21 | patch_size: 16 22 | d_model: 768 23 | n_heads: 12 24 | n_layers: 12 25 | normalization: deit 26 | distilled: true 27 | deit_base_distilled_patch16_384: 28 | image_size: 384 29 | patch_size: 16 30 | d_model: 768 31 | n_heads: 12 32 | n_layers: 12 33 | normalization: deit 34 | distilled: true 35 | deit_small_patch16_224: 36 | image_size: 224 37 | patch_size: 16 38 | d_model: 384 39 | n_heads: 6 40 | n_layers: 12 41 | normalization: deit 42 | distilled: false 43 | # vit 44 | vit_base_patch8_384: 45 | image_size: 384 46 | patch_size: 8 47 | d_model: 768 48 | n_heads: 12 49 | n_layers: 12 50 | normalization: vit 51 | distilled: false 52 | vit_tiny_patch16_384: 53 | image_size: 384 54 | patch_size: 16 55 | d_model: 192 56 | n_heads: 3 57 | n_layers: 12 58 | normalization: vit 59 | distilled: false 60 | vit_small_patch16_384: 61 | image_size: 384 62 | patch_size: 16 63 | d_model: 384 64 | n_heads: 6 65 | n_layers: 12 66 | normalization: vit 67 | distilled: false 68 | vit_base_patch16_384: 69 | image_size: 384 70 | patch_size: 16 71 | d_model: 768 72 | n_heads: 12 73 | n_layers: 12 74 | normalization: vit 75 | distilled: false 76 | vit_large_patch16_384: 77 | image_size: 384 78 | patch_size: 16 79 | d_model: 1024 80 | n_heads: 16 81 | n_layers: 24 82 | normalization: vit 83 | vit_small_patch32_384: 84 | image_size: 384 85 | patch_size: 32 86 | d_model: 384 87 | n_heads: 6 88 | n_layers: 12 89 | normalization: vit 90 | distilled: false 91 | vit_base_patch32_384: 92 | image_size: 384 93 | patch_size: 32 94 | d_model: 768 95 | n_heads: 12 96 | n_layers: 12 97 | normalization: vit 98 | vit_large_patch32_384: 99 | image_size: 384 100 | patch_size: 32 101 | d_model: 1024 102 | n_heads: 16 103 | n_layers: 24 104 | normalization: vit 105 | 106 | # dino 107 | dino_small_patch16_224: 108 | image_size: 224 109 | patch_size: 16 110 | d_model: 384 111 | n_heads: 6 112 | n_layers: 12 113 | normalization: deit 114 | distilled: false 115 | dinov2_small_patch16_224: 116 | image_size: 224 117 | patch_size: 16 118 | d_model: 384 119 | n_heads: 6 120 | n_layers: 12 121 | normalization: deit 122 | distilled: false 123 | dinov2_small_patch14_224: 124 | image_size: 224 125 | patch_size: 14 126 | d_model: 384 127 | n_heads: 6 128 | n_layers: 12 129 | normalization: deit 130 | distilled: false 131 | # eva 132 | eva02_tiny_patch16_224: 133 | image_size: 224 134 | patch_size: 16 135 | d_model: 192 136 | n_heads: 3 137 | n_layers: 12 138 | normalization: eva02 139 | distilled: false 140 | # eva 141 | eva02_small_patch16_224: 142 | image_size: 224 143 | patch_size: 16 144 | d_model: 384 145 | n_heads: 6 146 | n_layers: 12 147 | normalization: eva02 148 | distilled: false 149 | decoder: 150 | linear: {} 151 | deeplab_dec: 152 | encoder_layer: -1 153 | mask_transformer: 154 | 
drop_path_rate: 0.0 155 | dropout: 0.1 156 | n_layers: 2 157 | dataset: 158 | ade20k: 159 | epochs: 64 160 | eval_freq: 2 161 | batch_size: 8 162 | learning_rate: 0.001 163 | im_size: 512 164 | crop_size: 512 165 | window_size: 512 166 | window_stride: 512 167 | pascal_context: 168 | epochs: 256 169 | eval_freq: 8 170 | batch_size: 16 171 | learning_rate: 0.001 172 | im_size: 520 173 | crop_size: 480 174 | window_size: 480 175 | window_stride: 320 176 | pascal_voc: 177 | epochs: 100 178 | eval_freq: 8 179 | batch_size: 4 180 | learning_rate: 0.0001 181 | im_size: 520 182 | crop_size: 480 183 | window_size: 480 184 | window_stride: 320 185 | coco: 186 | epochs: 256 187 | eval_freq: 1 188 | batch_size: 16 189 | learning_rate: 0.001 190 | im_size: 520 191 | crop_size: 480 192 | window_size: 480 193 | window_stride: 320 194 | cityscapes: 195 | epochs: 216 196 | eval_freq: 4 197 | batch_size: 8 198 | learning_rate: 0.01 199 | im_size: 1024 200 | crop_size: 768 201 | window_size: 768 202 | window_stride: 512 203 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from segm.data.loader import Loader 2 | 3 | from segm.data.imagenet import ImagenetDataset 4 | from segm.data.ade20k import ADE20KSegmentation 5 | from segm.data.pascal_context import PascalContextDataset 6 | from segm.data.cityscapes import CityscapesDataset 7 | from segm.data.coco import COCODataset 8 | from segm.data.pascal_voc import PascalVOCDataset 9 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/ade20k.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from segm.data.base import BaseMMSeg 4 | from segm.data import utils 5 | from segm.config import dataset_dir 6 | 7 | 8 | ADE20K_CONFIG_PATH = Path(__file__).parent / "config" / "ade20k.py" 9 | ADE20K_CATS_PATH = Path(__file__).parent / "config" / "ade20k.yml" 10 | 11 | 12 | class ADE20KSegmentation(BaseMMSeg): 13 | def __init__(self, image_size, crop_size, split, **kwargs): 14 | super().__init__( 15 | image_size, 16 | crop_size, 17 | split, 18 | ADE20K_CONFIG_PATH, 19 | **kwargs, 20 | ) 21 | self.names, self.colors = utils.dataset_cat_description(ADE20K_CATS_PATH) 22 | self.n_cls = 150 23 | self.ignore_label = 0 24 | self.reduce_zero_label = True 25 | 26 | def update_default_config(self, config): 27 | root_dir = dataset_dir() 28 | path = Path(root_dir) / "ade20k" 29 | config.data_root = path 30 | if self.split == "train": 31 | config.data.train.data_root = path / "ADEChallengeData2016" 32 | elif self.split == "trainval": 33 | config.data.trainval.data_root = path / "ADEChallengeData2016" 34 | elif self.split == "val": 35 | config.data.val.data_root = path / "ADEChallengeData2016" 36 | elif self.split == "test": 37 | config.data.test.data_root = path / "release_test" 38 | config = super().update_default_config(config) 39 | return config 40 | 41 | def test_post_process(self, labels): 42 | return labels + 1 43 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from PIL import Image, ImageOps, ImageFilter 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | import 
torchvision.transforms.functional as F 8 | 9 | from mmseg.datasets import build_dataset 10 | import mmcv 11 | from mmcv.utils import Config 12 | 13 | 14 | from segm.data.utils import STATS, IGNORE_LABEL 15 | 16 | from segm.utils.logger import printd 17 | 18 | class BaseMMSeg(Dataset): 19 | def __init__( 20 | self, 21 | image_size, 22 | crop_size, 23 | split, 24 | config_path, 25 | normalization, 26 | max_ratio=None, 27 | **kwargs, 28 | ): 29 | super().__init__() 30 | self.image_size = image_size 31 | self.crop_size = crop_size 32 | self.split = split 33 | self.normalization = STATS[normalization].copy() 34 | self.ignore_label = None 35 | for k, v in self.normalization.items(): 36 | v = np.round(255 * np.array(v), 2) 37 | self.normalization[k] = tuple(v) 38 | printd(f"Use normalization: {self.normalization}") 39 | 40 | config = Config.fromfile(config_path) 41 | 42 | if max_ratio is not None: 43 | self.ratio = max_ratio 44 | else: 45 | self.ratio = config.max_ratio 46 | self.dataset = None 47 | self.config = self.update_default_config(config) 48 | self.dataset = build_dataset(getattr(self.config.data, f"{self.split}")) 49 | 50 | def update_default_config(self, config): 51 | 52 | train_splits = ["train", "trainval"] 53 | if self.split in train_splits: 54 | config_pipeline = getattr(config, f"train_pipeline") 55 | else: 56 | config_pipeline = getattr(config, f"{self.split}_pipeline") 57 | 58 | img_scale = (self.ratio * self.image_size, self.image_size) 59 | if self.split not in train_splits: 60 | assert config_pipeline[1]["type"] == "MultiScaleFlipAug" 61 | config_pipeline = config_pipeline[1]["transforms"] 62 | for i, op in enumerate(config_pipeline): 63 | op_type = op["type"] 64 | if op_type == "Resize": 65 | op["img_scale"] = img_scale 66 | elif op_type == "RandomCrop": 67 | op["crop_size"] = ( 68 | self.crop_size, 69 | self.crop_size, 70 | ) 71 | elif op_type == "Normalize": 72 | op["mean"] = self.normalization["mean"] 73 | op["std"] = self.normalization["std"] 74 | elif op_type == "Pad": 75 | op["size"] = (self.crop_size, self.crop_size) 76 | config_pipeline[i] = op 77 | if self.split == "train": 78 | config.data.train.pipeline = config_pipeline 79 | elif self.split == "trainval": 80 | config.data.trainval.pipeline = config_pipeline 81 | elif self.split == "val": 82 | config.data.val.pipeline[1]["img_scale"] = img_scale 83 | config.data.val.pipeline[1]["transforms"] = config_pipeline 84 | elif self.split == "test": 85 | config.data.test.pipeline[1]["img_scale"] = img_scale 86 | config.data.test.pipeline[1]["transforms"] = config_pipeline 87 | config.data.test.test_mode = True 88 | else: 89 | raise ValueError(f"Unknown split: {self.split}") 90 | return config 91 | 92 | def set_multiscale_mode(self): 93 | self.config.data.val.pipeline[1]["img_ratios"] = [ 94 | 0.5, 95 | 0.75, 96 | 1.0, 97 | 1.25, 98 | 1.5, 99 | 1.75, 100 | ] 101 | self.config.data.val.pipeline[1]["flip"] = True 102 | self.config.data.test.pipeline[1]["img_ratios"] = [ 103 | 0.5, 104 | 0.75, 105 | 1.0, 106 | 1.25, 107 | 1.5, 108 | 1.75, 109 | ] 110 | self.config.data.test.pipeline[1]["flip"] = True 111 | self.dataset = build_dataset(getattr(self.config.data, f"{self.split}")) 112 | 113 | def __getitem__(self, idx): 114 | data = self.dataset[idx] 115 | 116 | train_splits = ["train", "trainval"] 117 | 118 | if self.split in train_splits: 119 | im = data["img"].data 120 | seg = data["gt_semantic_seg"].data.squeeze(0) 121 | else: 122 | im = [im.data for im in data["img"]] 123 | seg = None 124 | 125 | out = dict(im=im) 126 | if 
self.split in train_splits: 127 | out["segmentation"] = seg 128 | else: 129 | im_metas = [meta.data for meta in data["img_metas"]] 130 | out["im_metas"] = im_metas 131 | out["colors"] = self.colors 132 | 133 | return out 134 | 135 | def get_gt_seg_maps(self): 136 | dataset = self.dataset 137 | gt_seg_maps = {} 138 | for img_info in dataset.img_infos: 139 | seg_map = Path(dataset.ann_dir) / img_info["ann"]["seg_map"] 140 | gt_seg_map = mmcv.imread(seg_map, flag="unchanged", backend="pillow") 141 | gt_seg_map[gt_seg_map == self.ignore_label] = IGNORE_LABEL 142 | if self.reduce_zero_label: 143 | gt_seg_map[gt_seg_map != IGNORE_LABEL] -= 1 144 | gt_seg_maps[img_info["filename"]] = gt_seg_map 145 | return gt_seg_maps 146 | 147 | def __len__(self): 148 | return len(self.dataset) 149 | 150 | @property 151 | def unwrapped(self): 152 | return self 153 | 154 | def set_epoch(self, epoch): 155 | pass 156 | 157 | def get_diagnostics(self, logger): 158 | pass 159 | 160 | def get_snapshot(self): 161 | return {} 162 | 163 | def end_epoch(self, epoch): 164 | return 165 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/cityscapes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | try: 4 | import cityscapesscripts.helpers.labels as CSLabels 5 | except: 6 | pass 7 | 8 | from pathlib import Path 9 | from segm.data.base import BaseMMSeg 10 | from segm.data import utils 11 | from segm.config import dataset_dir 12 | 13 | CITYSCAPES_CONFIG_PATH = Path(__file__).parent / "config" / "cityscapes.py" 14 | CITYSCAPES_CATS_PATH = Path(__file__).parent / "config" / "cityscapes.yml" 15 | 16 | 17 | class CityscapesDataset(BaseMMSeg): 18 | def __init__(self, image_size, crop_size, split, **kwargs): 19 | super().__init__(image_size, crop_size, split, CITYSCAPES_CONFIG_PATH, **kwargs) 20 | self.names, self.colors = utils.dataset_cat_description(CITYSCAPES_CATS_PATH) 21 | self.n_cls = 19 22 | self.ignore_label = 255 23 | self.reduce_zero_label = False 24 | 25 | def update_default_config(self, config): 26 | 27 | root_dir = dataset_dir() 28 | path = Path(root_dir) / "cityscapes" 29 | config.data_root = path 30 | 31 | config.data[self.split]["data_root"] = path 32 | config = super().update_default_config(config) 33 | 34 | return config 35 | 36 | def test_post_process(self, labels): 37 | labels_copy = np.copy(labels) 38 | cats = np.unique(labels_copy) 39 | for cat in cats: 40 | labels_copy[labels == cat] = CSLabels.trainId2label[cat].id 41 | return labels_copy 42 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/coco.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from segm.data.base import BaseMMSeg 4 | from segm.data import utils 5 | from segm.config import dataset_dir 6 | from mmseg.datasets import DATASETS 7 | 8 | COCO_CONTEXT_CONFIG_PATH = Path(__file__).parent / "config" / "coco.py" 9 | COCO_CONTEXT_CATS_PATH = Path(__file__).parent / "config" / "coco.yml" 10 | 11 | 12 | class COCODataset(BaseMMSeg): 13 | def __init__(self, image_size, crop_size, split, ann_dir=None, eval_split=None, **kwargs): 14 | self.names, self.colors = utils.dataset_cat_description( 15 | COCO_CONTEXT_CATS_PATH 16 | ) 17 | self.n_cls = 91 18 | self.ignore_label = 255 19 | self.reduce_zero_label = False 20 | self.ann_dir = ann_dir 21 | self.eval_split = eval_split 22 | super().__init__( 23 | 
image_size, crop_size, split, COCO_CONTEXT_CONFIG_PATH, **kwargs
24 |         )
25 | 
26 |     def update_default_config(self, config):
27 |         root_dir = dataset_dir()
28 |         path = Path(root_dir) / "coco"
29 |         config.data_root = path
30 |         if self.split == "train":
31 |             config.data.train.data_root = path
32 |             if self.ann_dir is not None:
33 |                 config.data.train.ann_dir = self.ann_dir
34 |         elif self.split == "val":
35 |             config.data.val.data_root = path
36 |             if self.eval_split is not None:
37 |                 config.data.val.split = self.eval_split
38 |         elif self.split == "test":
39 |             raise ValueError("Test split is not valid for the COCO dataset")
40 |         config = super().update_default_config(config)
41 |         return config
42 | 
43 |     def test_post_process(self, labels):
44 |         return labels
45 | 
-------------------------------------------------------------------------------- /OnlineRetraining/segm/data/config/ade20k.py: --------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = "ADE20KDataset"
3 | data_root = "data/ade/ADEChallengeData2016"
4 | img_norm_cfg = dict(
5 |     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
6 | )
7 | crop_size = (512, 512)
8 | max_ratio = 4
9 | train_pipeline = [
10 |     dict(type="LoadImageFromFile"),
11 |     dict(type="LoadAnnotations", reduce_zero_label=True),
12 |     dict(type="Resize", img_scale=(512 * max_ratio, 512), ratio_range=(0.5, 2.0)),
13 |     dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75),
14 |     dict(type="RandomFlip", prob=0.5),
15 |     dict(type="PhotoMetricDistortion"),
16 |     dict(type="Normalize", **img_norm_cfg),
17 |     dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255),
18 |     dict(type="DefaultFormatBundle"),
19 |     dict(type="Collect", keys=["img", "gt_semantic_seg"]),
20 | ]
21 | val_pipeline = [
22 |     dict(type="LoadImageFromFile"),
23 |     dict(
24 |         type="MultiScaleFlipAug",
25 |         img_scale=(512 * max_ratio, 512),
26 |         flip=False,
27 |         transforms=[
28 |             dict(type="Resize", keep_ratio=True),
29 |             dict(type="RandomFlip"),
30 |             dict(type="Normalize", **img_norm_cfg),
31 |             dict(type="ImageToTensor", keys=["img"]),
32 |             dict(type="Collect", keys=["img"]),
33 |         ],
34 |     ),
35 | ]
36 | test_pipeline = [
37 |     dict(type="LoadImageFromFile"),
38 |     dict(
39 |         type="MultiScaleFlipAug",
40 |         img_scale=(512 * max_ratio, 512),
41 |         flip=False,
42 |         transforms=[
43 |             dict(type="Resize", keep_ratio=True),
44 |             dict(type="RandomFlip"),
45 |             dict(type="Normalize", **img_norm_cfg),
46 |             dict(type="ImageToTensor", keys=["img"]),
47 |             dict(type="Collect", keys=["img"]),
48 |         ],
49 |     ),
50 | ]
51 | data = dict(
52 |     samples_per_gpu=4,
53 |     workers_per_gpu=4,
54 |     train=dict(
55 |         type=dataset_type,
56 |         data_root=data_root,
57 |         img_dir="images/training",
58 |         ann_dir="annotations/training",
59 |         pipeline=train_pipeline,
60 |     ),
61 |     trainval=dict(
62 |         type=dataset_type,
63 |         data_root=data_root,
64 |         img_dir=["images/training", "images/validation"],
65 |         ann_dir=["annotations/training", "annotations/validation"],
66 |         pipeline=train_pipeline,
67 |     ),
68 |     val=dict(
69 |         type=dataset_type,
70 |         data_root=data_root,
71 |         img_dir="images/validation",
72 |         ann_dir="annotations/validation",
73 |         pipeline=val_pipeline,
74 |     ),
75 |     test=dict(
76 |         type=dataset_type,
77 |         data_root=data_root,
78 |         img_dir="testing",
79 |         pipeline=test_pipeline,
80 |     ),
81 | )
82 | 
-------------------------------------------------------------------------------- /OnlineRetraining/segm/data/config/cityscapes.py:
-------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "CityscapesDataset" 3 | data_root = "data/cityscapes/" 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True 6 | ) 7 | crop_size = (768, 768) 8 | max_ratio = 2 9 | train_pipeline = [ 10 | dict(type="LoadImageFromFile"), 11 | dict(type="LoadAnnotations"), 12 | dict(type="Resize", img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), 13 | dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75), 14 | dict(type="RandomFlip", prob=0.5), 15 | dict(type="PhotoMetricDistortion"), 16 | dict(type="Normalize", **img_norm_cfg), 17 | dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255), 18 | dict(type="DefaultFormatBundle"), 19 | dict(type="Collect", keys=["img", "gt_semantic_seg"]), 20 | ] 21 | val_pipeline = [ 22 | dict(type="LoadImageFromFile"), 23 | dict( 24 | type="MultiScaleFlipAug", 25 | img_scale=(1024 * max_ratio, 1024), 26 | flip=False, 27 | transforms=[ 28 | dict(type="Resize", keep_ratio=True), 29 | dict(type="RandomFlip"), 30 | dict(type="Normalize", **img_norm_cfg), 31 | dict(type="ImageToTensor", keys=["img"]), 32 | dict(type="Collect", keys=["img"]), 33 | ], 34 | ), 35 | ] 36 | test_pipeline = [ 37 | dict(type="LoadImageFromFile"), 38 | dict( 39 | type="MultiScaleFlipAug", 40 | img_scale=(1024 * max_ratio, 1024), 41 | flip=False, 42 | transforms=[ 43 | dict(type="Resize", keep_ratio=True), 44 | dict(type="RandomFlip"), 45 | dict(type="Normalize", **img_norm_cfg), 46 | dict(type="ImageToTensor", keys=["img"]), 47 | dict(type="Collect", keys=["img"]), 48 | ], 49 | ), 50 | ] 51 | data = dict( 52 | samples_per_gpu=2, 53 | workers_per_gpu=2, 54 | train=dict( 55 | type=dataset_type, 56 | data_root=data_root, 57 | img_dir="leftImg8bit/train", 58 | ann_dir="gtFine/train", 59 | pipeline=train_pipeline, 60 | ), 61 | trainval=dict( 62 | type=dataset_type, 63 | data_root=data_root, 64 | img_dir=["leftImg8bit/train", "leftImg8bit/val"], 65 | ann_dir=["gtFine/train", "gtFine/val"], 66 | pipeline=train_pipeline, 67 | ), 68 | val=dict( 69 | type=dataset_type, 70 | data_root=data_root, 71 | img_dir="leftImg8bit/val", 72 | ann_dir="gtFine/val", 73 | pipeline=test_pipeline, 74 | ), 75 | test=dict( 76 | type=dataset_type, 77 | data_root=data_root, 78 | img_dir="leftImg8bit/test", 79 | ann_dir="gtFine/test", 80 | pipeline=test_pipeline, 81 | ), 82 | ) 83 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/config/cityscapes.yml: -------------------------------------------------------------------------------- 1 | - color: 2 | - 128 3 | - 64 4 | - 128 5 | id: 0 6 | isthing: false 7 | name: road 8 | - color: 9 | - 244 10 | - 35 11 | - 232 12 | id: 1 13 | isthing: false 14 | name: sidewalk 15 | - color: 16 | - 70 17 | - 70 18 | - 70 19 | id: 2 20 | isthing: false 21 | name: building 22 | - color: 23 | - 102 24 | - 102 25 | - 156 26 | id: 3 27 | isthing: false 28 | name: wall 29 | - color: 30 | - 190 31 | - 153 32 | - 153 33 | id: 4 34 | isthing: false 35 | name: fence 36 | - color: 37 | - 153 38 | - 153 39 | - 153 40 | id: 5 41 | isthing: false 42 | name: pole 43 | - color: 44 | - 250 45 | - 170 46 | - 30 47 | id: 6 48 | isthing: false 49 | name: traffic light 50 | - color: 51 | - 220 52 | - 220 53 | - 0 54 | id: 7 55 | isthing: false 56 | name: traffic sign 57 | - color: 58 | - 107 59 | - 142 60 | - 35 61 | id: 8 62 | isthing: false 63 | name: vegetation 64 | - color: 65 | 
- 152 66 | - 251 67 | - 152 68 | id: 9 69 | isthing: false 70 | name: terrain 71 | - color: 72 | - 70 73 | - 130 74 | - 180 75 | id: 10 76 | isthing: false 77 | name: sky 78 | - color: 79 | - 220 80 | - 20 81 | - 60 82 | id: 11 83 | isthing: true 84 | name: person 85 | - color: 86 | - 255 87 | - 0 88 | - 0 89 | id: 12 90 | isthing: true 91 | name: rider 92 | - color: 93 | - 0 94 | - 0 95 | - 142 96 | id: 13 97 | isthing: true 98 | name: car 99 | - color: 100 | - 0 101 | - 0 102 | - 70 103 | id: 14 104 | isthing: true 105 | name: truck 106 | - color: 107 | - 0 108 | - 60 109 | - 100 110 | id: 15 111 | isthing: true 112 | name: bus 113 | - color: 114 | - 0 115 | - 80 116 | - 100 117 | id: 16 118 | isthing: true 119 | name: train 120 | - color: 121 | - 0 122 | - 0 123 | - 230 124 | id: 17 125 | isthing: true 126 | name: motorcycle 127 | - color: 128 | - 119 129 | - 11 130 | - 32 131 | id: 18 132 | isthing: true 133 | name: bicycle 134 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/config/coco.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "COCODataset" 3 | data_root = "data/COCO14" 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True 6 | ) 7 | 8 | img_scale = (512, 512) 9 | crop_size = (512, 512) 10 | max_ratio = 8 11 | train_pipeline = [ 12 | dict(type="LoadImageFromFile"), 13 | dict(type="LoadAnnotations"), 14 | dict(type="Resize", img_scale=img_scale, ratio_range=(0.5, 2.0)), 15 | dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75), 16 | dict(type="RandomFlip", prob=0.5), 17 | dict(type="PhotoMetricDistortion"), 18 | dict(type="Normalize", **img_norm_cfg), 19 | dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255), 20 | dict(type="DefaultFormatBundle"), 21 | dict(type="Collect", keys=["img", "gt_semantic_seg"]), 22 | ] 23 | val_pipeline = [ 24 | dict(type="LoadImageFromFile"), 25 | dict( 26 | type="MultiScaleFlipAug", 27 | img_scale=(512 * max_ratio, 512), 28 | flip=False, 29 | transforms=[ 30 | dict(type="Resize", keep_ratio=True), 31 | dict(type="RandomFlip"), 32 | dict(type="Normalize", **img_norm_cfg), 33 | dict(type="ImageToTensor", keys=["img"]), 34 | dict(type="Collect", keys=["img"]), 35 | ], 36 | ), 37 | ] 38 | test_pipeline = [ 39 | dict(type="LoadImageFromFile"), 40 | dict( 41 | type="MultiScaleFlipAug", 42 | img_scale=(512 * max_ratio, 512), 43 | flip=False, 44 | transforms=[ 45 | dict(type="Resize", keep_ratio=True), 46 | dict(type="RandomFlip"), 47 | dict(type="Normalize", **img_norm_cfg), 48 | dict(type="ImageToTensor", keys=["img"]), 49 | dict(type="Collect", keys=["img"]), 50 | ], 51 | ), 52 | ] 53 | data = dict( 54 | samples_per_gpu=4, 55 | workers_per_gpu=2, 56 | train=dict( 57 | type=dataset_type, 58 | data_root=data_root, 59 | img_dir="images", 60 | ann_dir="voc_format/class_labels", 61 | split="voc_format/train.txt", 62 | pipeline=train_pipeline, 63 | ), 64 | val=dict( 65 | type=dataset_type, 66 | data_root=data_root, 67 | img_dir="images", 68 | ann_dir="voc_format/class_labels", 69 | split="voc_format/val_5000.txt", 70 | pipeline=val_pipeline, 71 | ), 72 | test=dict( 73 | type=dataset_type, 74 | data_root=data_root, 75 | img_dir="images", 76 | ann_dir="voc_format/class_labels", 77 | split="voc_format/val_5000.txt", 78 | pipeline=test_pipeline, 79 | ), 80 | ) 81 | -------------------------------------------------------------------------------- 
/OnlineRetraining/segm/data/config/coco.yml: -------------------------------------------------------------------------------- 1 | - color: 2 | - 120 3 | - 120 4 | - 120 5 | id: 0 6 | name: background 7 | - color: 8 | - 180 9 | - 120 10 | - 120 11 | id: 1 12 | name: person 13 | - color: 14 | - 6 15 | - 230 16 | - 230 17 | id: 2 18 | name: bicycle 19 | - color: 20 | - 80 21 | - 50 22 | - 50 23 | id: 3 24 | name: car 25 | - color: 26 | - 4 27 | - 200 28 | - 3 29 | id: 4 30 | name: motorcycle 31 | - color: 32 | - 120 33 | - 120 34 | - 80 35 | id: 5 36 | name: airplane 37 | - color: 38 | - 140 39 | - 140 40 | - 140 41 | id: 6 42 | name: bus 43 | - color: 44 | - 204 45 | - 5 46 | - 255 47 | id: 7 48 | name: train 49 | - color: 50 | - 230 51 | - 230 52 | - 230 53 | id: 8 54 | name: truck 55 | - color: 56 | - 4 57 | - 250 58 | - 7 59 | id: 9 60 | name: boat 61 | - color: 62 | - 224 63 | - 5 64 | - 255 65 | id: 10 66 | name: traffic light 67 | - color: 68 | - 235 69 | - 255 70 | - 7 71 | id: 11 72 | name: fire hydrant 73 | - color: 74 | - 150 75 | - 5 76 | - 61 77 | id: 12 78 | name: street sign 79 | - color: 80 | - 120 81 | - 120 82 | - 70 83 | id: 13 84 | name: stop sign 85 | - color: 86 | - 8 87 | - 255 88 | - 51 89 | id: 14 90 | name: parking meter 91 | - color: 92 | - 255 93 | - 6 94 | - 82 95 | id: 15 96 | name: bench 97 | - color: 98 | - 143 99 | - 255 100 | - 140 101 | id: 16 102 | name: bird 103 | - color: 104 | - 204 105 | - 255 106 | - 4 107 | id: 17 108 | name: cat 109 | - color: 110 | - 255 111 | - 51 112 | - 7 113 | id: 18 114 | name: dog 115 | - color: 116 | - 204 117 | - 70 118 | - 3 119 | id: 19 120 | name: horse 121 | - color: 122 | - 0 123 | - 102 124 | - 200 125 | id: 20 126 | name: sheep 127 | - color: 128 | - 61 129 | - 230 130 | - 250 131 | id: 21 132 | name: cow 133 | - color: 134 | - 255 135 | - 6 136 | - 51 137 | id: 22 138 | name: elephant 139 | - color: 140 | - 11 141 | - 102 142 | - 255 143 | id: 23 144 | name: bear 145 | - color: 146 | - 255 147 | - 7 148 | - 71 149 | id: 24 150 | name: zebra 151 | - color: 152 | - 255 153 | - 9 154 | - 224 155 | id: 25 156 | name: giraffe 157 | - color: 158 | - 9 159 | - 7 160 | - 230 161 | id: 26 162 | name: hat 163 | - color: 164 | - 220 165 | - 220 166 | - 220 167 | id: 27 168 | name: backpack 169 | - color: 170 | - 255 171 | - 9 172 | - 92 173 | id: 28 174 | name: umbrella 175 | - color: 176 | - 112 177 | - 9 178 | - 255 179 | id: 29 180 | name: shoe 181 | - color: 182 | - 8 183 | - 255 184 | - 214 185 | id: 30 186 | name: eye glasses 187 | - color: 188 | - 7 189 | - 255 190 | - 224 191 | id: 31 192 | name: handbag 193 | - color: 194 | - 255 195 | - 184 196 | - 6 197 | id: 32 198 | name: tie 199 | - color: 200 | - 10 201 | - 255 202 | - 71 203 | id: 33 204 | name: suitcase 205 | - color: 206 | - 255 207 | - 41 208 | - 10 209 | id: 34 210 | name: frisbee 211 | - color: 212 | - 7 213 | - 255 214 | - 255 215 | id: 35 216 | name: skis 217 | - color: 218 | - 224 219 | - 255 220 | - 8 221 | id: 36 222 | name: snowboard 223 | - color: 224 | - 102 225 | - 8 226 | - 255 227 | id: 37 228 | name: sports ball 229 | - color: 230 | - 255 231 | - 61 232 | - 6 233 | id: 38 234 | name: kite 235 | - color: 236 | - 255 237 | - 194 238 | - 7 239 | id: 39 240 | name: baseball bat 241 | - color: 242 | - 255 243 | - 122 244 | - 8 245 | id: 40 246 | name: baseball glove 247 | - color: 248 | - 0 249 | - 255 250 | - 20 251 | id: 41 252 | name: skateboard 253 | - color: 254 | - 255 255 | - 8 256 | - 41 257 | id: 42 258 | name: surfboard 259 | - color: 260 | 
- 255 261 | - 5 262 | - 153 263 | id: 43 264 | name: tennis racket 265 | - color: 266 | - 6 267 | - 51 268 | - 255 269 | id: 44 270 | name: bottle 271 | - color: 272 | - 235 273 | - 12 274 | - 255 275 | id: 45 276 | name: plate 277 | - color: 278 | - 160 279 | - 150 280 | - 20 281 | id: 46 282 | name: wine glass 283 | - color: 284 | - 0 285 | - 163 286 | - 255 287 | id: 47 288 | name: cup 289 | - color: 290 | - 140 291 | - 140 292 | - 140 293 | id: 48 294 | name: fork 295 | - color: 296 | - 250 297 | - 10 298 | - 15 299 | id: 49 300 | name: knife 301 | - color: 302 | - 20 303 | - 255 304 | - 0 305 | id: 50 306 | name: spoon 307 | - color: 308 | - 31 309 | - 255 310 | - 0 311 | id: 51 312 | name: bowl 313 | - color: 314 | - 255 315 | - 31 316 | - 0 317 | id: 52 318 | name: banana 319 | - color: 320 | - 255 321 | - 224 322 | - 0 323 | id: 53 324 | name: apple 325 | - color: 326 | - 153 327 | - 255 328 | - 0 329 | id: 54 330 | name: sandwich 331 | - color: 332 | - 0 333 | - 0 334 | - 255 335 | id: 55 336 | name: orange 337 | - color: 338 | - 255 339 | - 71 340 | - 0 341 | id: 56 342 | name: broccoli 343 | - color: 344 | - 0 345 | - 235 346 | - 255 347 | id: 57 348 | name: carrot 349 | - color: 350 | - 0 351 | - 173 352 | - 255 353 | id: 58 354 | name: hot dog 355 | - color: 356 | - 31 357 | - 0 358 | - 255 359 | id: 59 360 | name: pizza 361 | - color: 362 | - 120 363 | - 120 364 | - 120 365 | id: 60 366 | name: donut 367 | - color: 368 | - 180 369 | - 120 370 | - 120 371 | id: 61 372 | name: cake 373 | - color: 374 | - 6 375 | - 230 376 | - 230 377 | id: 62 378 | name: chair 379 | - color: 380 | - 80 381 | - 50 382 | - 50 383 | id: 63 384 | name: couch 385 | - color: 386 | - 4 387 | - 200 388 | - 3 389 | id: 64 390 | name: potted plant 391 | - color: 392 | - 120 393 | - 120 394 | - 80 395 | id: 65 396 | name: bed 397 | - color: 398 | - 140 399 | - 140 400 | - 140 401 | id: 66 402 | name: mirror 403 | - color: 404 | - 204 405 | - 5 406 | - 255 407 | id: 67 408 | name: dining table 409 | - color: 410 | - 230 411 | - 230 412 | - 230 413 | id: 68 414 | name: window 415 | - color: 416 | - 4 417 | - 250 418 | - 7 419 | id: 69 420 | name: desk 421 | - color: 422 | - 224 423 | - 5 424 | - 255 425 | id: 70 426 | name: toilet 427 | - color: 428 | - 235 429 | - 255 430 | - 7 431 | id: 71 432 | name: door 433 | - color: 434 | - 150 435 | - 5 436 | - 61 437 | id: 72 438 | name: tv 439 | - color: 440 | - 120 441 | - 120 442 | - 70 443 | id: 73 444 | name: laptop 445 | - color: 446 | - 8 447 | - 255 448 | - 51 449 | id: 74 450 | name: mouse 451 | - color: 452 | - 255 453 | - 6 454 | - 82 455 | id: 75 456 | name: remote 457 | - color: 458 | - 143 459 | - 255 460 | - 140 461 | id: 76 462 | name: keyboard 463 | - color: 464 | - 204 465 | - 255 466 | - 4 467 | id: 77 468 | name: cell phone 469 | - color: 470 | - 255 471 | - 51 472 | - 7 473 | id: 78 474 | name: microwave 475 | - color: 476 | - 204 477 | - 70 478 | - 3 479 | id: 79 480 | name: oven 481 | - color: 482 | - 0 483 | - 102 484 | - 200 485 | id: 80 486 | name: toaster 487 | - color: 488 | - 61 489 | - 230 490 | - 250 491 | id: 81 492 | name: sink 493 | - color: 494 | - 255 495 | - 6 496 | - 51 497 | id: 82 498 | name: refrigerator 499 | - color: 500 | - 11 501 | - 102 502 | - 255 503 | id: 83 504 | name: blender 505 | - color: 506 | - 255 507 | - 7 508 | - 71 509 | id: 84 510 | name: book 511 | - color: 512 | - 255 513 | - 9 514 | - 224 515 | id: 85 516 | name: clock 517 | - color: 518 | - 9 519 | - 7 520 | - 230 521 | id: 86 522 | name: vase 523 | - 
color: 524 | - 220 525 | - 220 526 | - 220 527 | id: 87 528 | name: scissors 529 | - color: 530 | - 255 531 | - 9 532 | - 92 533 | id: 88 534 | name: teddy bear 535 | - color: 536 | - 112 537 | - 9 538 | - 255 539 | id: 89 540 | name: hair drier 541 | - color: 542 | - 8 543 | - 255 544 | - 214 545 | id: 90 546 | name: toothbrush 547 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/config/pascal_context.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "PascalContextDataset" 3 | data_root = "data/VOCdevkit/VOC2010/" 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True 6 | ) 7 | 8 | img_scale = (512, 512) 9 | crop_size = (512, 512) 10 | max_ratio = 8 11 | train_pipeline = [ 12 | dict(type="LoadImageFromFile"), 13 | dict(type="LoadAnnotations"), 14 | dict(type="Resize", img_scale=img_scale, ratio_range=(0.5, 2.0)), 15 | dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75), 16 | dict(type="RandomFlip", prob=0.5), 17 | dict(type="PhotoMetricDistortion"), 18 | dict(type="Normalize", **img_norm_cfg), 19 | dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255), 20 | dict(type="DefaultFormatBundle"), 21 | dict(type="Collect", keys=["img", "gt_semantic_seg"]), 22 | ] 23 | val_pipeline = [ 24 | dict(type="LoadImageFromFile"), 25 | dict( 26 | type="MultiScaleFlipAug", 27 | img_scale=(512 * max_ratio, 512), 28 | flip=False, 29 | transforms=[ 30 | dict(type="Resize", keep_ratio=True), 31 | dict(type="RandomFlip"), 32 | dict(type="Normalize", **img_norm_cfg), 33 | dict(type="ImageToTensor", keys=["img"]), 34 | dict(type="Collect", keys=["img"]), 35 | ], 36 | ), 37 | ] 38 | test_pipeline = [ 39 | dict(type="LoadImageFromFile"), 40 | dict( 41 | type="MultiScaleFlipAug", 42 | img_scale=(512 * max_ratio, 512), 43 | flip=False, 44 | transforms=[ 45 | dict(type="Resize", keep_ratio=True), 46 | dict(type="RandomFlip"), 47 | dict(type="Normalize", **img_norm_cfg), 48 | dict(type="ImageToTensor", keys=["img"]), 49 | dict(type="Collect", keys=["img"]), 50 | ], 51 | ), 52 | ] 53 | data = dict( 54 | samples_per_gpu=4, 55 | workers_per_gpu=4, 56 | train=dict( 57 | type=dataset_type, 58 | data_root=data_root, 59 | img_dir="JPEGImages", 60 | ann_dir="SegmentationClassContext", 61 | split="ImageSets/SegmentationContext/train.txt", 62 | pipeline=train_pipeline, 63 | ), 64 | val=dict( 65 | type=dataset_type, 66 | data_root=data_root, 67 | img_dir="JPEGImages", 68 | ann_dir="SegmentationClassContext", 69 | split="ImageSets/SegmentationContext/val.txt", 70 | pipeline=val_pipeline, 71 | ), 72 | test=dict( 73 | type=dataset_type, 74 | data_root=data_root, 75 | img_dir="JPEGImages", 76 | ann_dir="SegmentationClassContext", 77 | split="ImageSets/SegmentationContext/val.txt", 78 | pipeline=test_pipeline, 79 | ), 80 | ) 81 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/config/pascal_context.yml: -------------------------------------------------------------------------------- 1 | - color: 2 | - 120 3 | - 120 4 | - 120 5 | id: 0 6 | name: background 7 | - color: 8 | - 180 9 | - 120 10 | - 120 11 | id: 1 12 | name: aeroplane 13 | - color: 14 | - 6 15 | - 230 16 | - 230 17 | id: 2 18 | name: bicycle 19 | - color: 20 | - 80 21 | - 50 22 | - 50 23 | id: 3 24 | name: bird 25 | - color: 26 | - 4 27 | - 200 28 | - 3 29 | id: 4 30 | name: boat 31 | - color: 32 | - 120 33 | - 
120 34 | - 80 35 | id: 5 36 | name: bottle 37 | - color: 38 | - 140 39 | - 140 40 | - 140 41 | id: 6 42 | name: bus 43 | - color: 44 | - 204 45 | - 5 46 | - 255 47 | id: 7 48 | name: car 49 | - color: 50 | - 230 51 | - 230 52 | - 230 53 | id: 8 54 | name: cat 55 | - color: 56 | - 4 57 | - 250 58 | - 7 59 | id: 9 60 | name: chair 61 | - color: 62 | - 224 63 | - 5 64 | - 255 65 | id: 10 66 | name: cow 67 | - color: 68 | - 235 69 | - 255 70 | - 7 71 | id: 11 72 | name: table 73 | - color: 74 | - 150 75 | - 5 76 | - 61 77 | id: 12 78 | name: dog 79 | - color: 80 | - 120 81 | - 120 82 | - 70 83 | id: 13 84 | name: horse 85 | - color: 86 | - 8 87 | - 255 88 | - 51 89 | id: 14 90 | name: motorbike 91 | - color: 92 | - 255 93 | - 6 94 | - 82 95 | id: 15 96 | name: person 97 | - color: 98 | - 143 99 | - 255 100 | - 140 101 | id: 16 102 | name: pottedplant 103 | - color: 104 | - 204 105 | - 255 106 | - 4 107 | id: 17 108 | name: sheep 109 | - color: 110 | - 255 111 | - 51 112 | - 7 113 | id: 18 114 | name: sofa 115 | - color: 116 | - 204 117 | - 70 118 | - 3 119 | id: 19 120 | name: train 121 | - color: 122 | - 0 123 | - 102 124 | - 200 125 | id: 20 126 | name: tvmonitor 127 | - color: 128 | - 61 129 | - 230 130 | - 250 131 | id: 21 132 | name: bag 133 | - color: 134 | - 255 135 | - 6 136 | - 51 137 | id: 22 138 | name: bed 139 | - color: 140 | - 11 141 | - 102 142 | - 255 143 | id: 23 144 | name: bench 145 | - color: 146 | - 255 147 | - 7 148 | - 71 149 | id: 24 150 | name: book 151 | - color: 152 | - 255 153 | - 9 154 | - 224 155 | id: 25 156 | name: building 157 | - color: 158 | - 9 159 | - 7 160 | - 230 161 | id: 26 162 | name: cabinet 163 | - color: 164 | - 220 165 | - 220 166 | - 220 167 | id: 27 168 | name: ceiling 169 | - color: 170 | - 255 171 | - 9 172 | - 92 173 | id: 28 174 | name: cloth 175 | - color: 176 | - 112 177 | - 9 178 | - 255 179 | id: 29 180 | name: computer 181 | - color: 182 | - 8 183 | - 255 184 | - 214 185 | id: 30 186 | name: cup 187 | - color: 188 | - 7 189 | - 255 190 | - 224 191 | id: 31 192 | name: door 193 | - color: 194 | - 255 195 | - 184 196 | - 6 197 | id: 32 198 | name: fence 199 | - color: 200 | - 10 201 | - 255 202 | - 71 203 | id: 33 204 | name: floor 205 | - color: 206 | - 255 207 | - 41 208 | - 10 209 | id: 34 210 | name: flower 211 | - color: 212 | - 7 213 | - 255 214 | - 255 215 | id: 35 216 | name: food 217 | - color: 218 | - 224 219 | - 255 220 | - 8 221 | id: 36 222 | name: grass 223 | - color: 224 | - 102 225 | - 8 226 | - 255 227 | id: 37 228 | name: ground 229 | - color: 230 | - 255 231 | - 61 232 | - 6 233 | id: 38 234 | name: keyboard 235 | - color: 236 | - 255 237 | - 194 238 | - 7 239 | id: 39 240 | name: light 241 | - color: 242 | - 255 243 | - 122 244 | - 8 245 | id: 40 246 | name: mountain 247 | - color: 248 | - 0 249 | - 255 250 | - 20 251 | id: 41 252 | name: mouse 253 | - color: 254 | - 255 255 | - 8 256 | - 41 257 | id: 42 258 | name: curtain 259 | - color: 260 | - 255 261 | - 5 262 | - 153 263 | id: 43 264 | name: platform 265 | - color: 266 | - 6 267 | - 51 268 | - 255 269 | id: 44 270 | name: sign 271 | - color: 272 | - 235 273 | - 12 274 | - 255 275 | id: 45 276 | name: plate 277 | - color: 278 | - 160 279 | - 150 280 | - 20 281 | id: 46 282 | name: road 283 | - color: 284 | - 0 285 | - 163 286 | - 255 287 | id: 47 288 | name: rock 289 | - color: 290 | - 140 291 | - 140 292 | - 140 293 | id: 48 294 | name: shelves 295 | - color: 296 | - 250 297 | - 10 298 | - 15 299 | id: 49 300 | name: sidewalk 301 | - color: 302 | - 20 303 | - 255 
304 | - 0 305 | id: 50 306 | name: sky 307 | - color: 308 | - 31 309 | - 255 310 | - 0 311 | id: 51 312 | name: snow 313 | - color: 314 | - 255 315 | - 31 316 | - 0 317 | id: 52 318 | name: bedclothes 319 | - color: 320 | - 255 321 | - 224 322 | - 0 323 | id: 53 324 | name: track 325 | - color: 326 | - 153 327 | - 255 328 | - 0 329 | id: 54 330 | name: tree 331 | - color: 332 | - 0 333 | - 0 334 | - 255 335 | id: 55 336 | name: truck 337 | - color: 338 | - 255 339 | - 71 340 | - 0 341 | id: 56 342 | name: wall 343 | - color: 344 | - 0 345 | - 235 346 | - 255 347 | id: 57 348 | name: water 349 | - color: 350 | - 0 351 | - 173 352 | - 255 353 | id: 58 354 | name: window 355 | - color: 356 | - 31 357 | - 0 358 | - 255 359 | id: 59 360 | name: wood 361 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/config/pascal_voc.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "PascalVOCDataset" 3 | data_root = "data/VOCdevkit/VOC2012/" 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True 6 | ) 7 | 8 | img_scale = (512, 512) 9 | crop_size = (512, 512) 10 | max_ratio = 8 11 | train_pipeline = [ 12 | dict(type="LoadImageFromFile"), 13 | dict(type="LoadAnnotations"), 14 | dict(type="Resize", img_scale=img_scale, ratio_range=(0.5, 2.0)), 15 | dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75), 16 | dict(type="RandomFlip", prob=0.5), 17 | dict(type="PhotoMetricDistortion"), 18 | dict(type="Normalize", **img_norm_cfg), 19 | dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255), 20 | dict(type="DefaultFormatBundle"), 21 | dict(type="Collect", keys=["img", "gt_semantic_seg"]), 22 | ] 23 | val_pipeline = [ 24 | dict(type="LoadImageFromFile"), 25 | dict( 26 | type="MultiScaleFlipAug", 27 | img_scale=(512 * max_ratio, 512), 28 | flip=False, 29 | transforms=[ 30 | dict(type="Resize", keep_ratio=True), 31 | dict(type="RandomFlip"), 32 | dict(type="Normalize", **img_norm_cfg), 33 | dict(type="ImageToTensor", keys=["img"]), 34 | dict(type="Collect", keys=["img"]), 35 | ], 36 | ), 37 | ] 38 | test_pipeline = [ 39 | dict(type="LoadImageFromFile"), 40 | dict( 41 | type="MultiScaleFlipAug", 42 | img_scale=(512 * max_ratio, 512), 43 | flip=False, 44 | transforms=[ 45 | dict(type="Resize", keep_ratio=True), 46 | dict(type="RandomFlip"), 47 | dict(type="Normalize", **img_norm_cfg), 48 | dict(type="ImageToTensor", keys=["img"]), 49 | dict(type="Collect", keys=["img"]), 50 | ], 51 | ), 52 | ] 53 | data = dict( 54 | samples_per_gpu=4, 55 | workers_per_gpu=4, 56 | train=dict( 57 | type=dataset_type, 58 | data_root=data_root, 59 | img_dir="JPEGImages", 60 | ann_dir="SegmentationClassAug", 61 | split="ImageSets/Segmentation/trainaug.txt", 62 | pipeline=train_pipeline, 63 | ), 64 | val=dict( 65 | type=dataset_type, 66 | data_root=data_root, 67 | img_dir="JPEGImages", 68 | ann_dir="SegmentationClassAug", 69 | split="ImageSets/Segmentation/val.txt", 70 | pipeline=val_pipeline, 71 | ), 72 | test=dict( 73 | type=dataset_type, 74 | data_root=data_root, 75 | img_dir="JPEGImages", 76 | ann_dir="SegmentationClassAug", 77 | split="ImageSets/Segmentation/test.txt", 78 | pipeline=test_pipeline, 79 | ), 80 | ) 81 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/config/pascal_voc.yml: -------------------------------------------------------------------------------- 1 | - color: 2 
| - 120 3 | - 120 4 | - 120 5 | id: 0 6 | name: background 7 | - color: 8 | - 180 9 | - 120 10 | - 120 11 | id: 1 12 | name: aeroplane 13 | - color: 14 | - 6 15 | - 230 16 | - 230 17 | id: 2 18 | name: bicycle 19 | - color: 20 | - 80 21 | - 50 22 | - 50 23 | id: 3 24 | name: bird 25 | - color: 26 | - 4 27 | - 200 28 | - 3 29 | id: 4 30 | name: boat 31 | - color: 32 | - 120 33 | - 120 34 | - 80 35 | id: 5 36 | name: bottle 37 | - color: 38 | - 140 39 | - 140 40 | - 140 41 | id: 6 42 | name: bus 43 | - color: 44 | - 204 45 | - 5 46 | - 255 47 | id: 7 48 | name: car 49 | - color: 50 | - 230 51 | - 230 52 | - 230 53 | id: 8 54 | name: cat 55 | - color: 56 | - 4 57 | - 250 58 | - 7 59 | id: 9 60 | name: chair 61 | - color: 62 | - 224 63 | - 5 64 | - 255 65 | id: 10 66 | name: cow 67 | - color: 68 | - 235 69 | - 255 70 | - 7 71 | id: 11 72 | name: table 73 | - color: 74 | - 150 75 | - 5 76 | - 61 77 | id: 12 78 | name: dog 79 | - color: 80 | - 120 81 | - 120 82 | - 70 83 | id: 13 84 | name: horse 85 | - color: 86 | - 8 87 | - 255 88 | - 51 89 | id: 14 90 | name: motorbike 91 | - color: 92 | - 255 93 | - 6 94 | - 82 95 | id: 15 96 | name: person 97 | - color: 98 | - 143 99 | - 255 100 | - 140 101 | id: 16 102 | name: pottedplant 103 | - color: 104 | - 204 105 | - 255 106 | - 4 107 | id: 17 108 | name: sheep 109 | - color: 110 | - 255 111 | - 51 112 | - 7 113 | id: 18 114 | name: sofa 115 | - color: 116 | - 204 117 | - 70 118 | - 3 119 | id: 19 120 | name: train 121 | - color: 122 | - 0 123 | - 102 124 | - 200 125 | id: 20 126 | name: tvmonitor 127 | 128 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/factory.py: -------------------------------------------------------------------------------- 1 | import segm.utils.torch as ptu 2 | 3 | from segm.data import ImagenetDataset 4 | from segm.data import ADE20KSegmentation 5 | from segm.data import PascalContextDataset 6 | from segm.data import PascalVOCDataset 7 | from segm.data import CityscapesDataset 8 | from segm.data import Loader 9 | from segm.data import COCODataset 10 | 11 | 12 | def create_dataset(dataset_kwargs): 13 | dataset_kwargs = dataset_kwargs.copy() 14 | dataset_name = dataset_kwargs.pop("dataset") 15 | batch_size = dataset_kwargs.pop("batch_size") 16 | num_workers = dataset_kwargs.pop("num_workers") 17 | split = dataset_kwargs.pop("split") 18 | 19 | # load dataset_name 20 | if dataset_name == "imagenet": 21 | dataset_kwargs.pop("patch_size") 22 | dataset = ImagenetDataset(split=split, **dataset_kwargs) 23 | elif dataset_name == "ade20k": 24 | dataset = ADE20KSegmentation(split=split, **dataset_kwargs) 25 | elif dataset_name == "pascal_context": 26 | dataset = PascalContextDataset(split=split, **dataset_kwargs) 27 | elif dataset_name == "pascal_voc": 28 | dataset = PascalVOCDataset(split=split, **dataset_kwargs) 29 | elif dataset_name == "coco": 30 | # print('dataset_kwargs: ', **dataset_kwargs) 31 | dataset = COCODataset(split=split, **dataset_kwargs) 32 | elif dataset_name == "cityscapes": 33 | dataset = CityscapesDataset(split=split, **dataset_kwargs) 34 | else: 35 | raise ValueError(f"Dataset {dataset_name} is unknown.") 36 | 37 | dataset = Loader( 38 | dataset=dataset, 39 | batch_size=batch_size, 40 | num_workers=num_workers, 41 | distributed=ptu.distributed, 42 | split=split, 43 | ) 44 | return dataset 45 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/imagenet.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from pathlib import Path 4 | 5 | from torch.utils.data import Dataset 6 | from torchvision import datasets 7 | from torchvision import transforms 8 | from PIL import Image 9 | 10 | from segm.data import utils 11 | from segm.config import dataset_dir 12 | 13 | 14 | class ImagenetDataset(Dataset): 15 | def __init__( 16 | self, 17 | root_dir, 18 | image_size=224, 19 | crop_size=224, 20 | split="train", 21 | normalization="vit", 22 | ): 23 | super().__init__() 24 | assert image_size[0] == image_size[1] 25 | 26 | self.path = Path(root_dir) / split 27 | self.crop_size = crop_size 28 | self.image_size = image_size 29 | self.split = split 30 | self.normalization = normalization 31 | 32 | if split == "train": 33 | self.transform = transforms.Compose( 34 | [ 35 | transforms.RandomResizedCrop(self.crop_size, interpolation=3), 36 | transforms.RandomHorizontalFlip(), 37 | transforms.ToTensor(), 38 | ] 39 | ) 40 | else: 41 | self.transform = transforms.Compose( 42 | [ 43 | transforms.Resize(image_size[0] + 32, interpolation=3), 44 | transforms.CenterCrop(self.crop_size), 45 | transforms.ToTensor(), 46 | ] 47 | ) 48 | 49 | self.base_dataset = datasets.ImageFolder(self.path, self.transform) 50 | self.n_cls = 1000 51 | 52 | @property 53 | def unwrapped(self): 54 | return self 55 | 56 | def __len__(self): 57 | return len(self.base_dataset) 58 | 59 | def __getitem__(self, idx): 60 | im, target = self.base_dataset[idx] 61 | im = utils.rgb_normalize(im, self.normalization) 62 | return dict(im=im, target=target) 63 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torch.utils.data.distributed import DistributedSampler 3 | 4 | import segm.utils.torch as ptu 5 | 6 | 7 | class Loader(DataLoader): 8 | def __init__(self, dataset, batch_size, num_workers, distributed, split): 9 | if distributed: 10 | sampler = DistributedSampler(dataset, shuffle=True) 11 | super().__init__( 12 | dataset, 13 | batch_size=batch_size, 14 | shuffle=False, 15 | num_workers=num_workers, 16 | pin_memory=True, 17 | sampler=sampler, 18 | ) 19 | else: 20 | super().__init__( 21 | dataset, 22 | batch_size=batch_size, 23 | shuffle=True, 24 | num_workers=num_workers, 25 | pin_memory=True, 26 | ) 27 | 28 | self.base_dataset = self.dataset 29 | 30 | @property 31 | def unwrapped(self): 32 | return self.base_dataset.unwrapped 33 | 34 | def set_epoch(self, epoch): 35 | if isinstance(self.sampler, DistributedSampler): 36 | self.sampler.set_epoch(epoch) 37 | 38 | def get_diagnostics(self, logger): 39 | return self.base_dataset.get_diagnostics(logger) 40 | 41 | def get_snapshot(self): 42 | return self.base_dataset.get_snapshot() 43 | 44 | def end_epoch(self, epoch): 45 | return self.base_dataset.end_epoch(epoch) 46 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/pascal_context.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from segm.data.base import BaseMMSeg 4 | from segm.data import utils 5 | from segm.config import dataset_dir 6 | 7 | PASCAL_CONTEXT_CONFIG_PATH = Path(__file__).parent / "config" / "pascal_context.py" 8 | PASCAL_CONTEXT_CATS_PATH = Path(__file__).parent / "config" / 
"pascal_context.yml" 9 | 10 | 11 | class PascalContextDataset(BaseMMSeg): 12 | def __init__(self, image_size, crop_size, split, **kwargs): 13 | super().__init__( 14 | image_size, crop_size, split, PASCAL_CONTEXT_CONFIG_PATH, **kwargs 15 | ) 16 | self.names, self.colors = utils.dataset_cat_description( 17 | PASCAL_CONTEXT_CATS_PATH 18 | ) 19 | self.n_cls = 60 20 | self.ignore_label = 255 21 | self.reduce_zero_label = False 22 | 23 | def update_default_config(self, config): 24 | root_dir = dataset_dir() 25 | path = Path(root_dir) / "pcontext" 26 | config.data_root = path 27 | if self.split == "train": 28 | config.data.train.data_root = path / "VOCdevkit/VOC2010/" 29 | elif self.split == "val": 30 | config.data.val.data_root = path / "VOCdevkit/VOC2010/" 31 | elif self.split == "test": 32 | raise ValueError("Test split is not valid for Pascal Context dataset") 33 | config = super().update_default_config(config) 34 | return config 35 | 36 | def test_post_process(self, labels): 37 | return labels 38 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/pascal_voc.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from segm.data.base import BaseMMSeg 4 | from segm.data import utils 5 | from segm.config import dataset_dir 6 | 7 | PASCAL_VOC_CONFIG_PATH = Path(__file__).parent / "config" / "pascal_voc.py" 8 | PASCAL_VOC_CATS_PATH = Path(__file__).parent / "config" / "pascal_voc.yml" 9 | 10 | 11 | class PascalVOCDataset(BaseMMSeg): 12 | def __init__(self, image_size, crop_size, split, ann_dir=None, eval_split=None, **kwargs): 13 | self.ann_dir = ann_dir 14 | self.eval_split = eval_split 15 | super().__init__( 16 | image_size, crop_size, split, PASCAL_VOC_CONFIG_PATH, **kwargs 17 | ) 18 | self.names, self.colors = utils.dataset_cat_description( 19 | PASCAL_VOC_CATS_PATH 20 | ) 21 | self.n_cls = 21 22 | self.ignore_label = 255 23 | self.reduce_zero_label = False 24 | 25 | def update_default_config(self, config): 26 | root_dir = dataset_dir() 27 | path = Path(root_dir) / "voc12" 28 | config.data_root = path 29 | if self.split == "train": 30 | config.data.train.data_root = path / "VOCdevkit/VOC2012/" 31 | if self.ann_dir: 32 | config.data.train.ann_dir = self.ann_dir 33 | elif self.split == "val": 34 | config.data.val.data_root = path / "VOCdevkit/VOC2012/" 35 | if self.eval_split is not None: 36 | config.data.val.split = self.eval_split 37 | elif self.split == "test": 38 | raise ValueError("Test split is not valid for Pascal Context dataset") 39 | config = super().update_default_config(config) 40 | return config 41 | 42 | def test_post_process(self, labels): 43 | return labels 44 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/data/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import numpy as np 4 | import yaml 5 | from pathlib import Path 6 | 7 | IGNORE_LABEL = 255 8 | STATS = { 9 | "vit": {"mean": (0.5, 0.5, 0.5), "std": (0.5, 0.5, 0.5)}, 10 | "deit": {"mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)}, 11 | "eva02": {"mean": (0.48145466, 0.4578275, 0.40821073), "std": (0.26862954, 0.26130258, 0.27577711)}, 12 | } 13 | 14 | 15 | def seg_to_rgb(seg, colors): 16 | im = torch.zeros((seg.shape[0], seg.shape[1], seg.shape[2], 3)).float() 17 | cls = torch.unique(seg) 18 | for cl in cls: 19 | color = 
colors[int(cl)] 20 | if len(color.shape) > 1: 21 | color = color[0] 22 | im[seg == cl] = color 23 | return im 24 | 25 | 26 | def dataset_cat_description(path, cmap=None): 27 | desc = yaml.load(open(path, "r"), Loader=yaml.FullLoader) 28 | colors = {} 29 | names = [] 30 | for i, cat in enumerate(desc): 31 | names.append(cat["name"]) 32 | if "color" in cat: 33 | colors[cat["id"]] = torch.tensor(cat["color"]).float() / 255 34 | else: 35 | colors[cat["id"]] = torch.tensor(cmap[cat["id"]]).float() 36 | colors[IGNORE_LABEL] = torch.tensor([0.0, 0.0, 0.0]).float() 37 | return names, colors 38 | 39 | 40 | def rgb_normalize(x, stats): 41 | """ 42 | x : C x * 43 | x \in [0, 1] 44 | """ 45 | return F.normalize(x, stats["mean"], stats["std"]) 46 | 47 | 48 | def rgb_denormalize(x, stats): 49 | """ 50 | x : N x C x * 51 | x \in [-1, 1] 52 | """ 53 | mean = torch.tensor(stats["mean"]) 54 | std = torch.tensor(stats["std"]) 55 | for i in range(3): 56 | x[:, i, :, :] = x[:, i, :, :] * std[i] + mean[i] 57 | return x 58 | 59 | 60 | def reduce_loss(loss, reduction): 61 | """Reduce loss as specified. 62 | 63 | Args: 64 | loss (Tensor): Elementwise loss tensor. 65 | reduction (str): Options are "none", "mean" and "sum". 66 | 67 | Return: 68 | Tensor: Reduced loss tensor. 69 | """ 70 | reduction_enum = torch.nn.functional._Reduction.get_enum(reduction) 71 | # none: 0, elementwise_mean:1, sum: 2 72 | if reduction_enum == 0: 73 | return loss 74 | elif reduction_enum == 1: 75 | return loss.mean() 76 | elif reduction_enum == 2: 77 | return loss.sum() 78 | 79 | 80 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 81 | """Apply element-wise weight and reduce loss. 82 | 83 | Args: 84 | loss (Tensor): Element-wise loss. 85 | weight (Tensor): Element-wise weights. 86 | reduction (str): Same as built-in losses of PyTorch. 87 | avg_factor (float): Avarage factor when computing the mean of losses. 88 | 89 | Returns: 90 | Tensor: Processed loss values. 91 | """ 92 | # if weight is specified, apply element-wise weight 93 | if weight is not None: 94 | assert weight.dim() == loss.dim() 95 | if weight.dim() > 1: 96 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 97 | loss = loss * weight 98 | 99 | # if avg_factor is not specified, just reduce the loss 100 | if avg_factor is None: 101 | loss = reduce_loss(loss, reduction) 102 | else: 103 | # if reduction is mean, then average the loss by avg_factor 104 | if reduction == 'mean': 105 | loss = loss.sum() / avg_factor 106 | # if reduction is 'none', then do nothing, otherwise raise an error 107 | elif reduction != 'none': 108 | raise ValueError('avg_factor can not be used with reduction="sum"') 109 | return loss -------------------------------------------------------------------------------- /OnlineRetraining/segm/datasets/coco.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from mmseg.datasets.builder import DATASETS 4 | # from .builder import DATASETS 5 | from mmseg.datasets.custom import CustomDataset 6 | 7 | 8 | # from .custom import CustomDataset 9 | 10 | 11 | @DATASETS.register_module() 12 | class COCODataset(CustomDataset): 13 | """PascalContext dataset. 14 | 15 | In segmentation map annotation for PascalContext, 0 stands for background, 16 | which is included in 60 categories. ``reduce_zero_label`` is fixed to 17 | False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is 18 | fixed to '.png'. 
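A toy run of the loss-reduction helpers above (arbitrary values, shown only to make the weight and avg_factor semantics concrete):

import torch
from segm.data.utils import reduce_loss, weight_reduce_loss

loss = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(reduce_loss(loss, "mean"))                                   # tensor(2.5000)
print(weight_reduce_loss(loss, reduction="sum"))                   # tensor(10.)
print(weight_reduce_loss(loss, reduction="mean", avg_factor=2.0))  # sum / avg_factor = tensor(5.)
mask = torch.tensor([1.0, 1.0, 0.0, 0.0])                          # zero out the last two elements
print(weight_reduce_loss(loss, weight=mask, reduction="mean"))     # tensor(0.7500)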
19 | 20 | Args: 21 | split (str): Split txt file for PascalContext. 22 | """ 23 | 24 | # CLASSES = tuple([str(i) for i in range(91)]) 25 | 26 | CLASSES = ('background', 27 | 'person', 28 | 'bicycle', 29 | 'car', 30 | 'motorcycle', 31 | 'airplane', 32 | 'bus', 33 | 'train', 34 | 'truck', 35 | 'boat', 36 | 'traffic light', 37 | 'fire hydrant', 38 | 'street sign', 39 | 'stop sign', 40 | 'parking meter', 41 | 'bench', 42 | 'bird', 43 | 'cat', 44 | 'dog', 45 | 'horse', 46 | 'sheep', 47 | 'cow', 48 | 'elephant', 49 | 'bear', 50 | 'zebra', 51 | 'giraffe', 52 | 'hat', 53 | 'backpack', 54 | 'umbrella', 55 | 'shoe', 56 | 'eye glasses', 57 | 'handbag', 58 | 'tie', 59 | 'suitcase', 60 | 'frisbee', 61 | 'skis', 62 | 'snowboard', 63 | 'sports ball', 64 | 'kite', 65 | 'baseball bat', 66 | 'baseball glove', 67 | 'skateboard', 68 | 'surfboard', 69 | 'tennis racket', 70 | 'bottle', 71 | 'plate', 72 | 'wine glass', 73 | 'cup', 74 | 'fork', 75 | 'knife', 76 | 'spoon', 77 | 'bowl', 78 | 'banana', 79 | 'apple', 80 | 'sandwich', 81 | 'orange', 82 | 'broccoli', 83 | 'carrot', 84 | 'hot dog', 85 | 'pizza', 86 | 'donut', 87 | 'cake', 88 | 'chair', 89 | 'couch', 90 | 'potted plant', 91 | 'bed', 92 | 'mirror', 93 | 'dining table', 94 | 'window', 95 | 'desk', 96 | 'toilet', 97 | 'door', 98 | 'tv', 99 | 'laptop', 100 | 'mouse', 101 | 'remote', 102 | 'keyboard', 103 | 'cell phone', 104 | 'microwave', 105 | 'oven', 106 | 'toaster', 107 | 'sink', 108 | 'refrigerator', 109 | 'blender', 110 | 'book', 111 | 'clock', 112 | 'vase', 113 | 'scissors', 114 | 'teddy bear', 115 | 'hair drier', 116 | 'toothbrush') 117 | 118 | PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 119 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 120 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 121 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 122 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 123 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 124 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 125 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 126 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 127 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 128 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 129 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 130 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 131 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 132 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]] * 5 133 | 134 | def __init__(self, split, **kwargs): 135 | super(COCODataset, self).__init__( 136 | img_suffix='.jpg', 137 | seg_map_suffix='.png', 138 | split=split, 139 | reduce_zero_label=False, 140 | **kwargs) 141 | assert osp.exists(self.img_dir) and self.split is not None 142 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/dist_test.sh: -------------------------------------------------------------------------------- 1 | GPUS=$1 2 | NNODES=${NNODES:-1} 3 | NODE_RANK=${NODE_RANK:-0} 4 | PORT=${PORT:-29500} 5 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | DATASET="$(dirname $0)/../../data" \ 9 | python -m torch.distributed.launch \ 10 | --nnodes=$NNODES \ 11 | --node_rank=$NODE_RANK \ 12 | --master_addr=$MASTER_ADDR \ 13 | --nproc_per_node=$GPUS \ 14 | --master_port=$PORT \ 15 | $(dirname "$0")/eval/miou.py \ 16 | ${@:2} 
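Since COCODataset registers itself in mmseg's DATASETS registry, it can be built from a plain config dict once the module has been imported; the paths and pipeline below are placeholders, not the repository's actual layout:

import segm.datasets.coco  # noqa: F401  -- registers COCODataset with mmseg
from mmseg.datasets import build_dataset

coco_train = build_dataset(dict(
    type="COCODataset",
    data_root="data/coco",                # placeholder paths
    img_dir="images/train2014",
    ann_dir="voc_format/class_labels",
    split="voc_format/train.txt",
    pipeline=[dict(type="LoadImageFromFile"), dict(type="LoadAnnotations")],
))
print(len(coco_train.CLASSES))            # 91: 'background' plus the 90 category names listed above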
-------------------------------------------------------------------------------- /OnlineRetraining/segm/dist_train.sh: -------------------------------------------------------------------------------- 1 | GPUS=$1 2 | NNODES=${NNODES:-1} 3 | NODE_RANK=${NODE_RANK:-0} 4 | PORT=${PORT:-29500} 5 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | DATASET="$(dirname $0)/../../data" \ 9 | python -m torch.distributed.launch \ 10 | --nnodes=$NNODES \ 11 | --node_rank=$NODE_RANK \ 12 | --master_addr=$MASTER_ADDR \ 13 | --nproc_per_node=$GPUS \ 14 | --master_port=$PORT \ 15 | $(dirname "$0")/train.py \ 16 | ${@:2} -------------------------------------------------------------------------------- /OnlineRetraining/segm/engine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | from segm.utils.logger import MetricLogger 5 | from segm.metrics import gather_data, compute_metrics 6 | from segm.model import utils 7 | from segm.data.utils import IGNORE_LABEL, weight_reduce_loss 8 | import segm.utils.torch as ptu 9 | 10 | 11 | def train_one_epoch( 12 | model, 13 | data_loader, 14 | optimizer, 15 | lr_scheduler, 16 | epoch, 17 | amp_autocast, 18 | loss_scaler, 19 | GradientClipping=None, 20 | ): 21 | criterion = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_LABEL, reduction="none") 22 | logger = MetricLogger(delimiter=" ") 23 | header = f"Epoch: [{epoch}]" 24 | print_freq = 100 25 | 26 | model.train() 27 | data_loader.set_epoch(epoch) 28 | num_updates = epoch * len(data_loader) 29 | for batch in logger.log_every(data_loader, print_freq, header): 30 | for param_group in optimizer.param_groups: 31 | if "lr_scale" in param_group: 32 | param_group["lr"] = param_group["lr"] * param_group["lr_scale"] 33 | 34 | im = batch["im"].to(ptu.device) 35 | seg_gt = batch["segmentation"].long().to(ptu.device) 36 | 37 | with amp_autocast(): 38 | seg_pred = model.forward(im) 39 | if GradientClipping is not None: 40 | ori_loss, loss = GradientClipping(seg_pred, seg_gt, criterion) 41 | logger.update( 42 | mean_loss=ori_loss.mean().item(), 43 | ) 44 | else: 45 | loss = criterion(seg_pred, seg_gt) 46 | 47 | loss = weight_reduce_loss( 48 | loss, weight=None, reduction="mean", avg_factor=None) 49 | 50 | loss_value = loss.item() 51 | if not math.isfinite(loss_value): 52 | print("Loss is {}, stopping training".format(loss_value), force=True) 53 | 54 | optimizer.zero_grad() 55 | if loss_scaler is not None: 56 | loss_scaler( 57 | loss, 58 | optimizer, 59 | parameters=model.parameters(), 60 | ) 61 | else: 62 | loss.backward() 63 | optimizer.step() 64 | 65 | num_updates += 1 66 | lr_scheduler.step_update(num_updates=num_updates) 67 | 68 | torch.cuda.synchronize() 69 | 70 | logger.update( 71 | loss=loss.item(), 72 | ) 73 | 74 | return logger 75 | 76 | 77 | @torch.no_grad() 78 | def evaluate( 79 | model, 80 | data_loader, 81 | val_seg_gt, 82 | window_size, 83 | window_stride, 84 | amp_autocast, 85 | ): 86 | model_without_ddp = model 87 | if hasattr(model, "module"): 88 | model_without_ddp = model.module 89 | logger = MetricLogger(delimiter=" ") 90 | header = "Eval:" 91 | print_freq = 50 92 | 93 | val_seg_pred = {} 94 | model.eval() 95 | for batch in logger.log_every(data_loader, print_freq, header): 96 | ims = [im.to(ptu.device) for im in batch["im"]] 97 | ims_metas = batch["im_metas"] 98 | ori_shape = ims_metas[0]["ori_shape"] 99 | ori_shape = (ori_shape[0].item(), ori_shape[1].item()) 100 | filename = 
batch["im_metas"][0]["ori_filename"][0] 101 | 102 | with amp_autocast(): 103 | seg_pred = utils.inference( 104 | model_without_ddp, 105 | ims, 106 | ims_metas, 107 | ori_shape, 108 | window_size, 109 | window_stride, 110 | batch_size=1, 111 | ) 112 | seg_pred = seg_pred.argmax(0) 113 | 114 | seg_pred = seg_pred.cpu().numpy() 115 | val_seg_pred[filename] = seg_pred 116 | 117 | val_seg_pred = gather_data(val_seg_pred) 118 | scores = compute_metrics( 119 | val_seg_pred, 120 | val_seg_gt, 121 | data_loader.unwrapped.n_cls, 122 | ignore_index=IGNORE_LABEL, 123 | distributed=ptu.distributed, 124 | ) 125 | 126 | for k, v in scores.items(): 127 | logger.update(**{f"{k}": v, "n": 1}) 128 | 129 | return logger, scores['mean_iou'] 130 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/eval/accuracy.py: -------------------------------------------------------------------------------- 1 | import click 2 | import torch 3 | 4 | import segm.utils.torch as ptu 5 | 6 | from segm.utils.logger import MetricLogger 7 | 8 | from segm.model.factory import create_vit 9 | from segm.data.factory import create_dataset 10 | from segm.data.utils import STATS 11 | from segm.metrics import accuracy 12 | from segm import config 13 | 14 | 15 | def compute_labels(model, batch): 16 | im = batch["im"] 17 | target = batch["target"] 18 | 19 | with torch.no_grad(): 20 | with torch.cuda.amp.autocast(): 21 | output = model.forward(im) 22 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 23 | 24 | return acc1.item(), acc5.item() 25 | 26 | 27 | def eval_dataset(model, dataset_kwargs): 28 | db = create_dataset(dataset_kwargs) 29 | print_freq = 20 30 | header = "" 31 | logger = MetricLogger(delimiter=" ") 32 | 33 | for batch in logger.log_every(db, print_freq, header): 34 | for k, v in batch.items(): 35 | batch[k] = v.to(ptu.device) 36 | acc1, acc5 = compute_labels(model, batch) 37 | batch_size = batch["im"].size(0) 38 | logger.update(acc1=acc1, n=batch_size) 39 | logger.update(acc5=acc5, n=batch_size) 40 | print(f"Imagenet accuracy: {logger}") 41 | 42 | 43 | @click.command() 44 | @click.argument("backbone", type=str) 45 | @click.option("--imagenet-dir", type=str) 46 | @click.option("-bs", "--batch-size", default=32, type=int) 47 | @click.option("-nw", "--num-workers", default=10, type=int) 48 | @click.option("-gpu", "--gpu/--no-gpu", default=True, is_flag=True) 49 | def main(backbone, imagenet_dir, batch_size, num_workers, gpu): 50 | ptu.set_gpu_mode(gpu) 51 | cfg = config.load_config() 52 | cfg = cfg["model"][backbone] 53 | cfg["backbone"] = backbone 54 | cfg["image_size"] = (cfg["image_size"], cfg["image_size"]) 55 | 56 | dataset_kwargs = dict( 57 | dataset="imagenet", 58 | root_dir=imagenet_dir, 59 | image_size=cfg["image_size"], 60 | crop_size=cfg["image_size"], 61 | patch_size=cfg["patch_size"], 62 | batch_size=batch_size, 63 | num_workers=num_workers, 64 | split="val", 65 | normalization=STATS[cfg["normalization"]], 66 | ) 67 | 68 | model = create_vit(cfg) 69 | model.to(ptu.device) 70 | model.eval() 71 | eval_dataset(model, dataset_kwargs) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/eval/densecrf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def crf_inference_voc12(img, probs, t=10, scale_factor=1, labels=21): 5 | import pydensecrf.densecrf as dcrf 6 | from pydensecrf.utils import 
unary_from_softmax 7 | 8 | h, w = img.shape[:2] 9 | n_labels = labels 10 | 11 | d = dcrf.DenseCRF2D(w, h, n_labels) 12 | 13 | unary = unary_from_softmax(probs) 14 | unary = np.ascontiguousarray(unary) 15 | 16 | img_c = np.ascontiguousarray(img) 17 | 18 | d.setUnaryEnergy(unary) 19 | ## voc12 20 | d.addPairwiseGaussian(sxy=3 / scale_factor, compat=3) 21 | d.addPairwiseBilateral(sxy=83 / scale_factor, srgb=5, rgbim=np.ascontiguousarray(np.copy(img_c)), compat=3) 22 | # d.addPairwiseGaussian(sxy=3 / scale_factor, compat=3) 23 | # d.addPairwiseBilateral(sxy=80 / scale_factor, srgb=13, rgbim=np.copy(img_c), compat=10) 24 | # d.addPairwiseGaussian(sxy=3/scale_factor, compat=3) 25 | # d.addPairwiseBilateral(sxy=80/scale_factor, srgb=13, rgbim=np.ascontiguousarray(np.copy(img_c)), compat=10) 26 | # d.addPairwiseGaussian(sxy=3/scale_factor, compat=3) 27 | # d.addPairwiseBilateral(sxy=32/scale_factor, srgb=13, rgbim=np.copy(img_c), compat=10) 28 | # d.addPairwiseGaussian(sxy=1 / scale_factor, compat=3) 29 | # d.addPairwiseBilateral(sxy=67 / scale_factor, srgb=3, rgbim=np.copy(img_c), compat=4) 30 | 31 | ## coco 32 | # d.addPairwiseGaussian(sxy=3, compat=3) 33 | # d.addPairwiseBilateral(sxy=50, srgb=5, rgbim=np.ascontiguousarray(np.copy(img_c)), compat=10) 34 | 35 | Q = d.inference(t) 36 | 37 | return np.array(Q).reshape((n_labels, h, w)) 38 | 39 | 40 | def crf_inference_coco(img, probs, t=10, scale_factor=1, labels=21): 41 | import pydensecrf.densecrf as dcrf 42 | from pydensecrf.utils import unary_from_softmax 43 | 44 | h, w = img.shape[:2] 45 | n_labels = labels 46 | 47 | d = dcrf.DenseCRF2D(w, h, n_labels) 48 | 49 | unary = unary_from_softmax(probs) 50 | unary = np.ascontiguousarray(unary) 51 | 52 | img_c = np.ascontiguousarray(img) 53 | 54 | d.setUnaryEnergy(unary) 55 | ## voc12 56 | # d.addPairwiseGaussian(sxy=4 / scale_factor, compat=3) 57 | # d.addPairwiseBilateral(sxy=83 / scale_factor, srgb=5, rgbim=np.copy(img_c), compat=3) 58 | # d.addPairwiseGaussian(sxy=3/scale_factor, compat=3) 59 | # d.addPairwiseBilateral(sxy=80/scale_factor, srgb=13, rgbim=np.copy(img_c), compat=10) 60 | # d.addPairwiseGaussian(sxy=3/scale_factor, compat=3) 61 | # d.addPairwiseBilateral(sxy=32/scale_factor, srgb=13, rgbim=np.copy(img_c), compat=10) 62 | ## coco 63 | d.addPairwiseGaussian(sxy=1 / scale_factor, compat=3) 64 | d.addPairwiseBilateral(sxy=67 / scale_factor, srgb=3, rgbim=np.copy(img_c), compat=4) 65 | # d.addPairwiseGaussian(sxy=3, compat=3) 66 | # d.addPairwiseBilateral(sxy=50, srgb=5, rgbim=np.copy(img_c), compat=10) 67 | 68 | Q = d.inference(t) 69 | 70 | return np.array(Q).reshape((n_labels, h, w)) 71 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/inference.py: -------------------------------------------------------------------------------- 1 | import click 2 | from tqdm import tqdm 3 | from pathlib import Path 4 | from PIL import Image 5 | import numpy as np 6 | import torchvision.transforms.functional as F 7 | 8 | import segm.utils.torch as ptu 9 | 10 | from segm.data.utils import STATS 11 | from segm.data.ade20k import ADE20K_CATS_PATH 12 | from segm.data.utils import dataset_cat_description, seg_to_rgb 13 | 14 | from segm.model.factory import load_model 15 | from segm.model.utils import inference 16 | 17 | 18 | @click.command() 19 | @click.option("--model-path", type=str) 20 | @click.option("--input-dir", "-i", type=str, help="folder with input images") 21 | @click.option("--output-dir", "-o", type=str, help="folder with output images") 22 
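A sketch of calling the VOC-flavoured CRF above on a single image, assuming pydensecrf is installed; random arrays stand in for a real image and softmax map:

import numpy as np
from segm.eval.densecrf import crf_inference_voc12

h, w, n_cls = 120, 160, 21
img = np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)     # RGB image, H x W x 3
probs = np.random.rand(n_cls, h, w).astype(np.float32)
probs /= probs.sum(axis=0, keepdims=True)                      # per-pixel class probabilities

refined = crf_inference_voc12(img, probs, t=10, labels=n_cls)  # (n_cls, h, w) after 10 CRF steps
pred = refined.argmax(axis=0)                                  # refined label map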
| @click.option("--gpu/--cpu", default=True, is_flag=True) 23 | def main(model_path, input_dir, output_dir, gpu): 24 | ptu.set_gpu_mode(gpu) 25 | 26 | model_dir = Path(model_path).parent 27 | model, variant = load_model(model_path) 28 | model.to(ptu.device) 29 | 30 | normalization_name = variant["dataset_kwargs"]["normalization"] 31 | normalization = STATS[normalization_name] 32 | cat_names, cat_colors = dataset_cat_description(ADE20K_CATS_PATH) 33 | 34 | input_dir = Path(input_dir) 35 | output_dir = Path(output_dir) 36 | output_dir.mkdir(exist_ok=True) 37 | 38 | list_dir = list(input_dir.iterdir()) 39 | for filename in tqdm(list_dir, ncols=80): 40 | pil_im = Image.open(filename).copy() 41 | im = F.pil_to_tensor(pil_im).float() / 255 42 | im = F.normalize(im, normalization["mean"], normalization["std"]) 43 | im = im.to(ptu.device).unsqueeze(0) 44 | 45 | im_meta = dict(flip=False) 46 | logits = inference( 47 | model, 48 | [im], 49 | [im_meta], 50 | ori_shape=im.shape[2:4], 51 | window_size=variant["inference_kwargs"]["window_size"], 52 | window_stride=variant["inference_kwargs"]["window_stride"], 53 | batch_size=2, 54 | ) 55 | seg_map = logits.argmax(0, keepdim=True) 56 | seg_rgb = seg_to_rgb(seg_map, cat_colors) 57 | seg_rgb = (255 * seg_rgb.cpu().numpy()).astype(np.uint8) 58 | pil_seg = Image.fromarray(seg_rgb[0]) 59 | 60 | pil_blend = Image.blend(pil_im, pil_seg, 0.5).convert("RGB") 61 | pil_blend.save(output_dir / filename.name) 62 | 63 | 64 | if __name__ == "__main__": 65 | main() 66 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.distributed as dist 4 | import segm.utils.torch as ptu 5 | 6 | import os 7 | import pickle as pkl 8 | from pathlib import Path 9 | import tempfile 10 | import shutil 11 | from mmseg.core import mean_iou 12 | 13 | """ 14 | ImageNet classifcation accuracy 15 | """ 16 | 17 | 18 | def accuracy(output, target, topk=(1,)): 19 | """ 20 | https://github.com/pytorch/examples/blob/master/imagenet/main.py 21 | Computes the accuracy over the k top predictions for the specified values of k 22 | """ 23 | with torch.no_grad(): 24 | maxk = max(topk) 25 | batch_size = target.size(0) 26 | 27 | _, pred = output.topk(maxk, 1, True, True) 28 | pred = pred.t() 29 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 30 | 31 | res = [] 32 | for k in topk: 33 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 34 | correct_k /= batch_size 35 | res.append(correct_k) 36 | return res 37 | 38 | 39 | """ 40 | Segmentation mean IoU 41 | based on collect_results_cpu 42 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/apis/test.py#L160-L200 43 | """ 44 | 45 | 46 | def gather_data(seg_pred, tmp_dir=None): 47 | """ 48 | distributed data gathering 49 | prediction and ground truth are stored in a common tmp directory 50 | and loaded on the master node to compute metrics 51 | """ 52 | if tmp_dir is None: 53 | tmpprefix = os.path.expandvars("temp") 54 | else: 55 | tmpprefix = os.path.expandvars(tmp_dir) 56 | MAX_LEN = 512 57 | # 32 is whitespace 58 | dir_tensor = torch.full((MAX_LEN,), 32, dtype=torch.uint8, device=ptu.device) 59 | if ptu.dist_rank == 0: 60 | tmpdir = tempfile.mkdtemp(prefix=tmpprefix) 61 | tmpdir = torch.tensor( 62 | bytearray(tmpdir.encode()), dtype=torch.uint8, device=ptu.device 63 | ) 64 | dir_tensor[: len(tmpdir)] = tmpdir 65 | # 
broadcast tmpdir from 0 to to the other nodes 66 | dist.broadcast(dir_tensor, 0) 67 | tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() 68 | tmpdir = Path(tmpdir) 69 | """ 70 | Save results in temp file and load them on main process 71 | """ 72 | tmp_file = tmpdir / f"part_{ptu.dist_rank}.pkl" 73 | pkl.dump(seg_pred, open(tmp_file, "wb")) 74 | dist.barrier() 75 | seg_pred = {} 76 | if ptu.dist_rank == 0: 77 | for i in range(ptu.world_size): 78 | part_seg_pred = pkl.load(open(tmpdir / f"part_{i}.pkl", "rb")) 79 | seg_pred.update(part_seg_pred) 80 | shutil.rmtree(tmpdir) 81 | return seg_pred 82 | 83 | 84 | def compute_metrics( 85 | seg_pred, 86 | seg_gt, 87 | n_cls, 88 | ignore_index=None, 89 | ret_cat_iou=False, 90 | tmp_dir=None, 91 | distributed=False, 92 | ): 93 | ret_metrics_mean = torch.zeros(3, dtype=float, device=ptu.device) 94 | if ptu.dist_rank == 0: 95 | list_seg_pred = [] 96 | list_seg_gt = [] 97 | keys = sorted(seg_pred.keys()) 98 | for k in keys: 99 | list_seg_pred.append(np.asarray(seg_pred[k])) 100 | list_seg_gt.append(np.asarray(seg_gt[k])) 101 | ret_metrics = mean_iou( 102 | results=list_seg_pred, 103 | gt_seg_maps=list_seg_gt, 104 | num_classes=n_cls, 105 | ignore_index=ignore_index, 106 | ) 107 | ret_metrics = [ret_metrics["aAcc"], ret_metrics["Acc"], ret_metrics["IoU"]] 108 | ret_metrics_mean = torch.tensor( 109 | [ 110 | np.round(np.nanmean(ret_metric.astype(np.float)) * 100, 2) 111 | for ret_metric in ret_metrics 112 | ], 113 | dtype=float, 114 | device=ptu.device, 115 | ) 116 | cat_iou = ret_metrics[2] 117 | # broadcast metrics from 0 to all nodes 118 | if distributed: 119 | dist.broadcast(ret_metrics_mean, 0) 120 | pix_acc, mean_acc, miou = ret_metrics_mean 121 | ret = dict(pixel_accuracy=pix_acc, mean_accuracy=mean_acc, mean_iou=miou) 122 | if ret_cat_iou and ptu.dist_rank == 0: 123 | ret["cat_iou"] = cat_iou 124 | return ret 125 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/model/blocks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from 2020 Ross Wightman 3 | https://github.com/rwightman/pytorch-image-models 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from einops import rearrange 9 | from pathlib import Path 10 | 11 | import torch.nn.functional as F 12 | 13 | from timm.models.layers import DropPath 14 | from torch import Tensor 15 | from typing import Union 16 | 17 | import xformers.ops as xops 18 | 19 | class FeedForward(nn.Module): 20 | def __init__(self, dim, hidden_dim, dropout, out_dim=None): 21 | super().__init__() 22 | self.fc1 = nn.Linear(dim, hidden_dim) 23 | self.act = nn.GELU() 24 | if out_dim is None: 25 | out_dim = dim 26 | self.fc2 = nn.Linear(hidden_dim, out_dim) 27 | self.drop = nn.Dropout(dropout) 28 | 29 | @property 30 | def unwrapped(self): 31 | return self 32 | 33 | def forward(self, x): 34 | x = self.fc1(x) 35 | x = self.act(x) 36 | x = self.drop(x) 37 | x = self.fc2(x) 38 | x = self.drop(x) 39 | return x 40 | 41 | 42 | class Attention(nn.Module): 43 | def __init__(self, dim, heads, dropout): 44 | super().__init__() 45 | self.heads = heads 46 | head_dim = dim // heads 47 | self.scale = head_dim ** -0.5 48 | self.attn = None 49 | 50 | self.qkv = nn.Linear(dim, dim * 3) 51 | self.attn_drop = nn.Dropout(dropout) 52 | self.proj = nn.Linear(dim, dim) 53 | self.proj_drop = nn.Dropout(dropout) 54 | 55 | @property 56 | def unwrapped(self): 57 | return self 58 | 59 | def forward(self, x, mask=None): 60 | B, N, 
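A toy call of the mmseg metric that compute_metrics above relies on; the returned keys match the ones indexed there:

import numpy as np
from mmseg.core import mean_iou

pred = [np.array([[0, 1], [1, 2]])]                  # one predicted map
gt = [np.array([[0, 1], [2, 2]])]                    # matching ground truth
ret = mean_iou(results=pred, gt_seg_maps=gt, num_classes=3, ignore_index=255)
print(ret["aAcc"])                                   # 0.75 -- 3 of 4 pixels correct
print(ret["Acc"], ret["IoU"])                        # per-class accuracy and IoU arrays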
C = x.shape 61 | qkv = ( 62 | self.qkv(x) 63 | .reshape(B, N, 3, self.heads, C // self.heads) 64 | .permute(2, 0, 3, 1, 4) 65 | ) 66 | q, k, v = ( 67 | qkv[0], 68 | qkv[1], 69 | qkv[2], 70 | ) 71 | 72 | attn = (q @ k.transpose(-2, -1)) * self.scale 73 | attn = attn.softmax(dim=-1) 74 | attn = self.attn_drop(attn) 75 | 76 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 77 | x = self.proj(x) 78 | x = self.proj_drop(x) 79 | 80 | return x, attn 81 | 82 | 83 | class Block(nn.Module): 84 | def __init__(self, dim, heads, mlp_dim, dropout, drop_path, norm_layer=nn.LayerNorm): 85 | super().__init__() 86 | self.norm1 = norm_layer(dim) 87 | self.norm2 = norm_layer(dim) 88 | self.attn = Attention(dim, heads, dropout) 89 | self.mlp = FeedForward(dim, mlp_dim, dropout) 90 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 91 | 92 | def forward(self, x, mask=None, return_attention=False): 93 | y, attn = self.attn(self.norm1(x), mask) 94 | if return_attention: 95 | return attn 96 | x = x + self.drop_path(y) 97 | x = x + self.drop_path(self.mlp(self.norm2(x))) 98 | return x 99 | 100 | 101 | 102 | class LayerScale(nn.Module): 103 | def __init__( 104 | self, 105 | dim: int, 106 | init_values: Union[float, Tensor] = 1e-5, 107 | inplace: bool = False, 108 | ) -> None: 109 | super().__init__() 110 | self.inplace = inplace 111 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 112 | 113 | def forward(self, x: Tensor) -> Tensor: 114 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 115 | 116 | 117 | class DINOBlock(Block): 118 | def __init__(self, dim, heads, mlp_dim, dropout, drop_path, 119 | init_values=1e-5, use_bn=False): 120 | super().__init__(dim, heads, mlp_dim, dropout, drop_path) 121 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 122 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 123 | 124 | def forward(self, x, mask=None, return_attention=False): 125 | y, attn = self.attn(self.norm1(x), mask) 126 | y = self.ls1(y) 127 | if return_attention: 128 | return attn 129 | x = x + self.drop_path(y) 130 | x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x)))) 131 | return x 132 | 133 | 134 | class EVA02Attention(Attention): 135 | def __init__(self, dim, *args, rope=None): 136 | super().__init__(dim, *args) 137 | self.qkv = nn.Linear(dim, dim * 3, bias=False) 138 | self.q_bias = nn.Parameter(torch.zeros(dim)) 139 | self.v_bias = nn.Parameter(torch.zeros(dim)) 140 | 141 | self.rope = rope 142 | 143 | def forward(self, x, mask=None): 144 | B, N, C = x.shape 145 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 146 | qkv = ( 147 | (self.qkv(x) + qkv_bias) 148 | .reshape(B, N, 3, self.heads, C // self.heads) 149 | .permute(2, 0, 3, 1, 4) 150 | ) 151 | q, k, v = ( 152 | qkv[0], 153 | qkv[1], 154 | qkv[2], 155 | ) 156 | 157 | if self.rope: 158 | q_t = q[:, :, 1:, :] 159 | ro_q_t = self.rope(q_t) 160 | q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) 161 | 162 | k_t = k[:, :, 1:, :] 163 | ro_k_t = self.rope(k_t) 164 | k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v) 165 | 166 | attn = (q @ k.transpose(-2, -1)) * self.scale 167 | attn = attn.softmax(dim=-1) 168 | attn = self.attn_drop(attn) 169 | 170 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 171 | x = self.proj(x) 172 | x = self.proj_drop(x) 173 | 174 | return x, attn 175 | 176 | class EVA02Block(Block): 177 | def __init__(self, dim, heads, mlp_dim, dropout, 
drop_path, 178 | init_values=1e-5, use_bn=False, rope=None, norm_layer=nn.LayerNorm): 179 | super().__init__(dim, heads, mlp_dim, dropout, drop_path, norm_layer=norm_layer) 180 | self.attn = EVA02Attention(dim, heads, dropout, rope=rope) 181 | self.mlp = xops.SwiGLU( 182 | in_features=dim, 183 | hidden_features=mlp_dim 184 | ) # hidden_features: 2/3 -------------------------------------------------------------------------------- /OnlineRetraining/segm/model/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | from timm.models.layers import trunc_normal_ 8 | 9 | from segm.model.blocks import Block, FeedForward 10 | from segm.model.utils import init_weights 11 | 12 | 13 | class DecoderLinear(nn.Module): 14 | def __init__(self, n_cls, patch_size, d_encoder): 15 | super().__init__() 16 | 17 | self.d_encoder = d_encoder 18 | self.patch_size = patch_size 19 | self.n_cls = n_cls 20 | 21 | self.head = nn.Linear(self.d_encoder, n_cls) 22 | self.apply(init_weights) 23 | 24 | @torch.jit.ignore 25 | def no_weight_decay(self): 26 | return set() 27 | 28 | def forward(self, x, im_size): 29 | H, W = im_size 30 | GS = H // self.patch_size 31 | x = self.head(x) 32 | x = rearrange(x, "b (h w) c -> b c h w", h=GS) 33 | 34 | return x 35 | 36 | 37 | class MaskTransformer(nn.Module): 38 | def __init__( 39 | self, 40 | n_cls, 41 | patch_size, 42 | d_encoder, 43 | n_layers, 44 | n_heads, 45 | d_model, 46 | d_ff, 47 | drop_path_rate, 48 | dropout, 49 | ): 50 | super().__init__() 51 | self.d_encoder = d_encoder 52 | self.patch_size = patch_size 53 | self.n_layers = n_layers 54 | self.n_cls = n_cls 55 | self.d_model = d_model 56 | self.d_ff = d_ff 57 | self.scale = d_model ** -0.5 58 | 59 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)] 60 | self.blocks = nn.ModuleList( 61 | [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)] 62 | ) 63 | 64 | self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model)) 65 | self.proj_dec = nn.Linear(d_encoder, d_model) 66 | 67 | self.proj_patch = nn.Parameter(self.scale * torch.randn(d_model, d_model)) 68 | self.proj_classes = nn.Parameter(self.scale * torch.randn(d_model, d_model)) 69 | 70 | self.decoder_norm = nn.LayerNorm(d_model) 71 | self.mask_norm = nn.LayerNorm(n_cls) 72 | 73 | self.apply(init_weights) 74 | trunc_normal_(self.cls_emb, std=0.02) 75 | 76 | @torch.jit.ignore 77 | def no_weight_decay(self): 78 | return {"cls_emb"} 79 | 80 | def forward(self, x, im_size): 81 | H, W = im_size 82 | GS = H // self.patch_size 83 | 84 | x = self.proj_dec(x) 85 | cls_emb = self.cls_emb.expand(x.size(0), -1, -1) 86 | x = torch.cat((x, cls_emb), 1) 87 | for blk in self.blocks: 88 | x = blk(x) 89 | x = self.decoder_norm(x) 90 | 91 | patches, cls_seg_feat = x[:, : -self.n_cls], x[:, -self.n_cls:] 92 | patches = patches @ self.proj_patch 93 | cls_seg_feat = cls_seg_feat @ self.proj_classes 94 | 95 | patches = patches / patches.norm(dim=-1, keepdim=True) 96 | cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True) 97 | 98 | masks = patches @ cls_seg_feat.transpose(1, 2) 99 | masks = self.mask_norm(masks) 100 | masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS)) 101 | 102 | return masks 103 | 104 | def get_attention_map(self, x, layer_id): 105 | if layer_id >= self.n_layers or layer_id < 0: 106 | raise ValueError( 107 | f"Provided layer_id: {layer_id} is 
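A shape walk-through of the MaskTransformer decoder above, with small arbitrary sizes; the output has one channel per class at patch resolution:

import torch
from segm.model.decoder import MaskTransformer

dec = MaskTransformer(
    n_cls=21, patch_size=16, d_encoder=384,
    n_layers=2, n_heads=6, d_model=384, d_ff=4 * 384,
    drop_path_rate=0.0, dropout=0.0,
)
x = torch.randn(2, (128 // 16) ** 2, 384)   # (batch, num_patches, d_encoder) from the encoder
masks = dec(x, (128, 128))                  # im_size is the padded input resolution
print(masks.shape)                          # torch.Size([2, 21, 8, 8]); upsampling happens in Segmenter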
not valid. 0 <= {layer_id} < {self.n_layers}." 108 | ) 109 | x = self.proj_dec(x) 110 | cls_emb = self.cls_emb.expand(x.size(0), -1, -1) 111 | x = torch.cat((x, cls_emb), 1) 112 | for i, blk in enumerate(self.blocks): 113 | if i < layer_id: 114 | x = blk(x) 115 | else: 116 | return blk(x, return_attention=True) 117 | 118 | 119 | class GradientClipping(nn.Module): 120 | def __init__(self, start_value, patch_size): 121 | super().__init__() 122 | self.start_value = start_value 123 | self.patch_size = patch_size 124 | 125 | def forward(self, seg_pred, seg_gt, criterion): 126 | ori_loss = criterion(seg_pred, seg_gt) 127 | detach_loss = ori_loss.detach().clone() 128 | 129 | mean_loss = detach_loss.mean() 130 | 131 | # set start loss clamp threshold 132 | if mean_loss > self.start_value: 133 | return ori_loss, ori_loss 134 | 135 | b, h, w = detach_loss.shape 136 | 137 | # all batch average 138 | detach_loss = detach_loss.mean(dim=0).unsqueeze(0) 139 | local_mean = F.avg_pool2d(detach_loss.unsqueeze(1), kernel_size=self.patch_size, 140 | stride=self.patch_size, padding=h % self.patch_size, 141 | count_include_pad=False).squeeze(1) 142 | local_mean = torch.maximum(local_mean, mean_loss) 143 | local_mean = torch.repeat_interleave(local_mean, b, dim=0) 144 | local_mean = torch.repeat_interleave(local_mean, self.patch_size, dim=1) 145 | local_mean = torch.repeat_interleave(local_mean, self.patch_size, dim=2) 146 | 147 | clamp_loss = ori_loss - local_mean 148 | clamp_loss = torch.clamp(clamp_loss, None, 0) 149 | loss = clamp_loss + local_mean 150 | 151 | return ori_loss, loss 152 | 153 | 154 | class MultiMaskTransformer(MaskTransformer): 155 | def __init__(self, 156 | n_cls, 157 | patch_size, 158 | d_encoder, 159 | n_layers, 160 | n_heads, 161 | d_model, 162 | d_ff, 163 | drop_path_rate, 164 | dropout): 165 | super(MultiMaskTransformer, self).__init__(n_cls, 166 | patch_size, 167 | d_encoder, 168 | n_layers, 169 | n_heads, 170 | d_model, 171 | d_ff, 172 | drop_path_rate, 173 | dropout, ) 174 | 175 | def forward(self, x, im_size): 176 | H, W = im_size 177 | GS = H // self.patch_size 178 | 179 | x = self.proj_dec(x) 180 | cls_emb = self.cls_emb.expand(x.size(0), -1, -1) 181 | x = torch.cat((x, cls_emb), 1) 182 | for blk in self.blocks[:-1]: 183 | x = blk(x) 184 | x2 = x 185 | for blk in self.blocks[-1:]: 186 | x2 = blk(x2) 187 | 188 | masks1 = self.cls_forward(x, GS) 189 | masks2 = self.cls_forward(x2, GS) 190 | 191 | return masks1, masks2 192 | 193 | def cls_forward(self, x, GS): 194 | x = self.decoder_norm(x) 195 | 196 | patches, cls_seg_feat = x[:, : -self.n_cls], x[:, -self.n_cls:] 197 | patches = patches @ self.proj_patch 198 | cls_seg_feat = cls_seg_feat @ self.proj_classes 199 | 200 | patches = patches / patches.norm(dim=-1, keepdim=True) 201 | cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True) 202 | 203 | masks = patches @ cls_seg_feat.transpose(1, 2) 204 | masks = self.mask_norm(masks) 205 | masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS)) 206 | 207 | return masks 208 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/model/factory.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import yaml 3 | import torch 4 | import math 5 | import os 6 | import torch.nn as nn 7 | 8 | from timm.models.helpers import load_pretrained, load_custom_pretrained 9 | from timm.models.vision_transformer import default_cfgs 10 | from timm.models.registry import 
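A toy check of the GradientClipping module above, mirroring how engine.py calls it with a reduction="none" cross-entropy; with untrained logits the mean loss stays above start_value, so the clamping branch is not yet active:

import torch
from segm.model.decoder import GradientClipping

criterion = torch.nn.CrossEntropyLoss(ignore_index=255, reduction="none")
clip = GradientClipping(start_value=0.3, patch_size=16)

seg_pred = torch.randn(2, 21, 64, 64)            # (batch, n_cls, H, W) logits
seg_gt = torch.randint(0, 21, (2, 64, 64))       # (batch, H, W) labels
ori_loss, loss = clip(seg_pred, seg_gt, criterion)
print(ori_loss.shape, loss.shape)                # both torch.Size([2, 64, 64])
# Once the mean loss drops below start_value, pixels whose loss exceeds their
# patch-local mean (floored by the global mean) are capped there and stop
# contributing gradient.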
register_model 11 | from timm.models.vision_transformer import _create_vision_transformer 12 | 13 | from segm.model.vit import VisionTransformer, DINOV2VisionTransformer, EVA02VisionTransformer 14 | from segm.model.utils import checkpoint_filter_fn 15 | from segm.model.decoder import DecoderLinear 16 | from segm.model.decoder import MaskTransformer, MultiMaskTransformer 17 | from segm.model.segmenter import Segmenter, MultiSegmenter 18 | import segm.utils.torch as ptu 19 | 20 | from apex.normalization import FusedLayerNorm 21 | 22 | # 添加多个键值对 23 | 24 | default_cfgs.update({ 25 | "dino_small_patch16_224": {"url": "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"}, 26 | "dinov2_small_patch16_224": {"url": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth"}, 27 | "eva02_small_patch16_224": {"url": "eva02_S_pt_in21k_p14.pt"}, 28 | "eva02_tiny_patch16_224": {"url": "eva02_T_pt_in21k_p14.pt"}, 29 | }) 30 | 31 | @register_model 32 | def vit_base_patch8_384(pretrained=False, **kwargs): 33 | """ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 34 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 35 | """ 36 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) 37 | model = _create_vision_transformer( 38 | "vit_base_patch8_384", 39 | pretrained=pretrained, 40 | default_cfg=dict( 41 | url="", 42 | input_size=(3, 384, 384), 43 | mean=(0.5, 0.5, 0.5), 44 | std=(0.5, 0.5, 0.5), 45 | num_classes=1000, 46 | ), 47 | **model_kwargs, 48 | ) 49 | return model 50 | 51 | 52 | def create_vit(model_cfg): 53 | model_cfg = model_cfg.copy() 54 | backbone = model_cfg.pop("backbone") 55 | 56 | normalization = model_cfg.pop("normalization") 57 | model_cfg["n_cls"] = 1000 58 | mlp_expansion_ratio = 4 59 | model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"] 60 | 61 | if backbone in default_cfgs: 62 | default_cfg = default_cfgs[backbone] 63 | else: 64 | default_cfg = dict( 65 | pretrained=False, 66 | num_classes=1000, 67 | drop_rate=0.0, 68 | drop_path_rate=0.0, 69 | drop_block_rate=None, 70 | ) 71 | 72 | default_cfg["input_size"] = ( 73 | 3, 74 | model_cfg["image_size"][0], 75 | model_cfg["image_size"][1], 76 | ) 77 | if "dinov2" in backbone: 78 | model = DINOV2VisionTransformer(**model_cfg) 79 | elif "eva02" in backbone: 80 | mlp_expansion_ratio = 4*2/3 81 | model_cfg["d_ff"] = int(mlp_expansion_ratio * model_cfg["d_model"]) 82 | model_cfg["norm_layer"] = FusedLayerNorm 83 | model = EVA02VisionTransformer(**model_cfg) 84 | else: 85 | model = VisionTransformer(**model_cfg) 86 | 87 | from torch.hub import get_dir 88 | hub_dir = get_dir() 89 | model_dir = os.path.join(hub_dir, 'checkpoints') 90 | if backbone == "vit_base_patch8_384": 91 | path = os.path.join(model_dir, "vit_base_patch8_384.pth") 92 | state_dict = torch.load(path, map_location="cpu") 93 | filtered_dict = checkpoint_filter_fn(state_dict, model) 94 | model.load_state_dict(filtered_dict, strict=True) 95 | elif "deit" in backbone: 96 | load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn) 97 | elif "dino" in backbone: 98 | # without head 99 | load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn, strict=False) 100 | elif "eva02" in backbone: 101 | path = os.path.join(model_dir, default_cfg["url"]) 102 | state_dict = torch.load(path, map_location="cpu") 103 | filtered_dict = checkpoint_filter_fn(state_dict, model) 104 | 
model.load_state_dict(filtered_dict, strict=False) 105 | else: 106 | load_custom_pretrained(model, default_cfg) 107 | 108 | return model 109 | 110 | 111 | def create_decoder(encoder, decoder_cfg): 112 | decoder_cfg = decoder_cfg.copy() 113 | name = decoder_cfg.pop("name") 114 | decoder_cfg["d_encoder"] = encoder.d_model 115 | decoder_cfg["patch_size"] = encoder.patch_size 116 | 117 | if "linear" in name: 118 | decoder = DecoderLinear(**decoder_cfg) 119 | elif name == "mask_transformer": 120 | dim = encoder.d_model 121 | n_heads = dim // 64 122 | decoder_cfg["n_heads"] = n_heads 123 | decoder_cfg["d_model"] = dim 124 | decoder_cfg["d_ff"] = 4 * dim 125 | decoder = MaskTransformer(**decoder_cfg) 126 | elif name == "multi_mask_transformer": 127 | dim = encoder.d_model 128 | n_heads = dim // 64 129 | decoder_cfg["n_heads"] = n_heads 130 | decoder_cfg["d_model"] = dim 131 | decoder_cfg["d_ff"] = 4 * dim 132 | decoder = MultiMaskTransformer(**decoder_cfg) 133 | else: 134 | raise ValueError(f"Unknown decoder: {name}") 135 | return decoder 136 | 137 | 138 | def create_segmenter(model_cfg): 139 | model_cfg = model_cfg.copy() 140 | decoder_cfg = model_cfg.pop("decoder") 141 | decoder_cfg["n_cls"] = model_cfg["n_cls"] 142 | 143 | encoder = create_vit(model_cfg) 144 | decoder = create_decoder(encoder, decoder_cfg) 145 | model = Segmenter(encoder, decoder, n_cls=model_cfg["n_cls"]) 146 | 147 | return model 148 | 149 | 150 | def create_multi_segmenter(model_cfg): 151 | model_cfg = model_cfg.copy() 152 | decoder_cfg = model_cfg.pop("decoder") 153 | decoder_cfg["n_cls"] = model_cfg["n_cls"] 154 | 155 | encoder = create_vit(model_cfg) 156 | decoder = create_decoder(encoder, decoder_cfg) 157 | model = MultiSegmenter(encoder, decoder, n_cls=model_cfg["n_cls"]) 158 | return model 159 | 160 | 161 | def load_model(model_path, backbone=None): 162 | variant_path = Path(model_path).parent / "variant.yml" 163 | with open(variant_path, "r") as f: 164 | variant = yaml.load(f, Loader=yaml.FullLoader) 165 | net_kwargs = variant["net_kwargs"] 166 | if backbone is None: 167 | backbone = net_kwargs["backbone"] 168 | if "multi" in backbone: 169 | net_kwargs["decoder"]["name"] = "multi_mask_transformer" 170 | model = create_multi_segmenter(net_kwargs) 171 | else: 172 | model = create_segmenter(net_kwargs) 173 | data = torch.load(model_path, map_location="cpu") 174 | checkpoint = data["model"] 175 | 176 | model.load_state_dict(checkpoint, strict=False) 177 | 178 | return model, variant 179 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/model/rope.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # EVA-02: A Visual Representation for Neon Genesis 3 | # Github source: https://github.com/baaivision/EVA/EVA02 4 | # Copyright (c) 2023 Beijing Academy of Artificial Intelligence (BAAI) 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Yuxin Fang 7 | # 8 | # Based on https://github.com/lucidrains/rotary-embedding-torch 9 | # --------------------------------------------------------' 10 | 11 | from math import pi 12 | 13 | import torch 14 | from torch import nn 15 | 16 | from einops import rearrange, repeat 17 | 18 | 19 | 20 | def broadcat(tensors, dim = -1): 21 | num_tensors = len(tensors) 22 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 23 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' 24 | 
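A sketch of restoring a trained segmenter with load_model above; the checkpoint path is a placeholder, and load_model expects a variant.yml next to the checkpoint:

import segm.utils.torch as ptu
from segm.model.factory import load_model

ptu.set_gpu_mode(True)
model, variant = load_model("runs/pascal_voc/checkpoint.pth")   # placeholder path
model.to(ptu.device).eval()
print(variant["net_kwargs"]["backbone"])
print(variant["inference_kwargs"]["window_size"])               # used by segm.model.utils.inference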
shape_len = list(shape_lens)[0] 25 | dim = (dim + shape_len) if dim < 0 else dim 26 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 27 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 28 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' 29 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 30 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 31 | expanded_dims.insert(dim, (dim, dims[dim])) 32 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 33 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 34 | return torch.cat(tensors, dim = dim) 35 | 36 | 37 | 38 | def rotate_half(x): 39 | x = rearrange(x, '... (d r) -> ... d r', r = 2) 40 | x1, x2 = x.unbind(dim = -1) 41 | x = torch.stack((-x2, x1), dim = -1) 42 | return rearrange(x, '... d r -> ... (d r)') 43 | 44 | 45 | 46 | class VisionRotaryEmbedding(nn.Module): 47 | def __init__( 48 | self, 49 | dim, 50 | pt_seq_len, 51 | ft_seq_len=None, 52 | custom_freqs = None, 53 | freqs_for = 'lang', 54 | theta = 10000, 55 | max_freq = 10, 56 | num_freqs = 1, 57 | ): 58 | super().__init__() 59 | if custom_freqs: 60 | freqs = custom_freqs 61 | elif freqs_for == 'lang': 62 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 63 | elif freqs_for == 'pixel': 64 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 65 | elif freqs_for == 'constant': 66 | freqs = torch.ones(num_freqs).float() 67 | else: 68 | raise ValueError(f'unknown modality {freqs_for}') 69 | 70 | if ft_seq_len is None: ft_seq_len = pt_seq_len 71 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 72 | 73 | freqs_h = torch.einsum('..., f -> ... f', t, freqs) 74 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) 75 | 76 | freqs_w = torch.einsum('..., f -> ... f', t, freqs) 77 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2) 78 | 79 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) 80 | 81 | self.register_buffer("freqs_cos", freqs.cos()) 82 | self.register_buffer("freqs_sin", freqs.sin()) 83 | 84 | print('======== shape of rope freq', self.freqs_cos.shape, '========') 85 | 86 | def forward(self, t, start_index = 0): 87 | rot_dim = self.freqs_cos.shape[-1] 88 | end_index = start_index + rot_dim 89 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' 90 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 91 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 92 | return torch.cat((t_left, t, t_right), dim = -1) 93 | 94 | 95 | 96 | class VisionRotaryEmbeddingFast(nn.Module): 97 | def __init__( 98 | self, 99 | dim, 100 | pt_seq_len=16, 101 | ft_seq_len=None, 102 | custom_freqs = None, 103 | freqs_for = 'lang', 104 | theta = 10000, 105 | max_freq = 10, 106 | num_freqs = 1, 107 | ): 108 | super().__init__() 109 | if custom_freqs: 110 | freqs = custom_freqs 111 | elif freqs_for == 'lang': 112 | freqs = 1. 
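rotate_half above pairs up the last dimension and maps each (x1, x2) to (-x2, x1), the 2-D rotation that rotary embeddings rely on; a tiny check:

import torch
from segm.model.rope import rotate_half

print(rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0])))   # tensor([-2., 1., -4., 3.])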
/ (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 113 | elif freqs_for == 'pixel': 114 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 115 | elif freqs_for == 'constant': 116 | freqs = torch.ones(num_freqs).float() 117 | else: 118 | raise ValueError(f'unknown modality {freqs_for}') 119 | 120 | if ft_seq_len is None: ft_seq_len = pt_seq_len 121 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 122 | 123 | freqs = torch.einsum('..., f -> ... f', t, freqs) 124 | freqs = repeat(freqs, '... n -> ... (n r)', r = 2) 125 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1) 126 | 127 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 128 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 129 | 130 | self.register_buffer("freqs_cos", freqs_cos) 131 | self.register_buffer("freqs_sin", freqs_sin) 132 | 133 | print('======== shape of rope freq', self.freqs_cos.shape, '========') 134 | 135 | def forward(self, t): return t * self.freqs_cos + rotate_half(t) * self.freqs_sin 136 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/model/segmenter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from segm.model.utils import padding, unpadding 6 | from timm.models.layers import trunc_normal_ 7 | 8 | 9 | class Segmenter(nn.Module): 10 | def __init__( 11 | self, 12 | encoder, 13 | decoder, 14 | n_cls, 15 | ): 16 | super().__init__() 17 | self.n_cls = n_cls 18 | self.patch_size = encoder.patch_size 19 | self.encoder = encoder 20 | self.decoder = decoder 21 | 22 | @torch.jit.ignore 23 | def no_weight_decay(self): 24 | def append_prefix_no_weight_decay(prefix, module): 25 | return set(map(lambda x: prefix + x, module.no_weight_decay())) 26 | 27 | nwd_params = append_prefix_no_weight_decay("encoder.", self.encoder).union( 28 | append_prefix_no_weight_decay("decoder.", self.decoder) 29 | ) 30 | return nwd_params 31 | 32 | def forward(self, im): 33 | H_ori, W_ori = im.size(2), im.size(3) 34 | im = padding(im, self.patch_size) 35 | H, W = im.size(2), im.size(3) 36 | 37 | x = self.encoder(im, return_features=True) 38 | 39 | # remove CLS/DIST tokens for decoding 40 | num_extra_tokens = 1 + self.encoder.distilled 41 | x = x[:, num_extra_tokens:] 42 | 43 | masks = self.decoder(x, (H, W)) 44 | 45 | masks = F.interpolate(masks, size=(H, W), mode="bilinear") 46 | masks = unpadding(masks, (H_ori, W_ori)) 47 | 48 | return masks 49 | 50 | def get_attention_map_enc(self, im, layer_id): 51 | return self.encoder.get_attention_map(im, layer_id) 52 | 53 | def get_attention_map_dec(self, im, layer_id): 54 | x = self.encoder(im, return_features=True) 55 | 56 | # remove CLS/DIST tokens for decoding 57 | num_extra_tokens = 1 + self.encoder.distilled 58 | x = x[:, num_extra_tokens:] 59 | 60 | return self.decoder.get_attention_map(x, layer_id) 61 | 62 | 63 | class MultiSegmenter(Segmenter): 64 | def __init__(self, encoder, 65 | decoder, 66 | n_cls): 67 | super(MultiSegmenter, self).__init__(encoder, 68 | decoder, 69 | n_cls,) 70 | 71 | def forward(self, im): 72 | H_ori, W_ori = im.size(2), im.size(3) 73 | im = padding(im, self.patch_size) 74 | H, W = im.size(2), im.size(3) 75 | 76 | x = self.encoder(im, return_features=True) 77 | 78 | # remove CLS/DIST tokens for decoding 79 | num_extra_tokens = 1 + self.encoder.distilled 80 | # cls_token = x[:, 0] 81 | # cls_pred = self.cls_pred(cls_token) 82 | 83 | x = x[:, 
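A shape check for VisionRotaryEmbeddingFast above; the sizes are arbitrary, but the last dimension of the input must equal twice the rope dim and the token count must equal ft_seq_len squared (EVA02Attention applies it to every token except CLS):

import torch
from segm.model.rope import VisionRotaryEmbeddingFast

rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16, ft_seq_len=8)
q = torch.randn(2, 6, 8 * 8, 64)      # (batch, heads, 8x8 patch tokens, head_dim = 2 * dim)
print(rope(q).shape)                  # torch.Size([2, 6, 64, 64]) -- rotation preserves the shape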
num_extra_tokens:] 84 | 85 | masks1, masks2 = self.decoder(x, (H, W)) 86 | 87 | masks1 = F.interpolate(masks1, size=(H, W), mode="bilinear") 88 | masks1 = unpadding(masks1, (H_ori, W_ori)) 89 | masks2 = F.interpolate(masks2, size=(H, W), mode="bilinear") 90 | masks2 = unpadding(masks2, (H_ori, W_ori)) 91 | 92 | return masks1, masks2 93 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from collections import defaultdict 6 | 7 | from timm.models.layers import trunc_normal_ 8 | 9 | import segm.utils.torch as ptu 10 | 11 | 12 | def init_weights(m): 13 | if isinstance(m, nn.Linear): 14 | trunc_normal_(m.weight, std=0.02) 15 | if isinstance(m, nn.Linear) and m.bias is not None: 16 | nn.init.constant_(m.bias, 0) 17 | elif isinstance(m, nn.LayerNorm): 18 | nn.init.constant_(m.bias, 0) 19 | nn.init.constant_(m.weight, 1.0) 20 | 21 | 22 | def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens): 23 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from 24 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 25 | posemb_tok, posemb_grid = ( 26 | posemb[:, :num_extra_tokens], 27 | posemb[0, num_extra_tokens:], 28 | ) 29 | if grid_old_shape is None: 30 | gs_old_h = int(math.sqrt(len(posemb_grid))) 31 | gs_old_w = gs_old_h 32 | else: 33 | gs_old_h, gs_old_w = grid_old_shape 34 | 35 | gs_h, gs_w = grid_new_shape 36 | posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2) 37 | posemb_grid = F.interpolate(posemb_grid.cuda(), size=(gs_h, gs_w), mode="bilinear").cpu() 38 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) 39 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 40 | return posemb 41 | 42 | 43 | def checkpoint_filter_fn(state_dict, model): 44 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 45 | out_dict = {} 46 | if "model" in state_dict: 47 | # For deit models 48 | state_dict = state_dict["model"] 49 | if "module" in state_dict: 50 | state_dict = state_dict["module"] 51 | num_extra_tokens = 1 + ("dist_token" in state_dict.keys()) 52 | patch_size = model.patch_size 53 | image_size = model.patch_embed.image_size 54 | for k, v in state_dict.items(): 55 | if k == "patch_embed.proj.weight": 56 | v = F.interpolate( 57 | v.float(), 58 | size=(patch_size, patch_size), 59 | mode="bicubic", align_corners=False, 60 | ) 61 | if k == "pos_embed" and v.shape != model.pos_embed.shape: 62 | # To resize pos embedding when using model at different size from pretrained weights 63 | v = resize_pos_embed( 64 | v, 65 | None, 66 | (image_size[0] // patch_size, image_size[1] // patch_size), 67 | num_extra_tokens, 68 | ) 69 | if "rope.freqs_cos" in k and v.shape != model.rope.freqs_cos.shape: 70 | print(f"Skipping loading {k} because of shape mismatch of {v.shape} vs {model.rope.freqs_cos.shape}") 71 | continue 72 | if "rope.freqs_sin" in k and v.shape != model.rope.freqs_sin.shape: 73 | print(f"Skipping loading {k} because of shape mismatch of {v.shape} vs {model.rope.freqs_sin.shape}") 74 | continue 75 | out_dict[k] = v 76 | return out_dict 77 | 78 | 79 | def padding(im, patch_size, fill_value=0): 80 | # make the image sizes divisible by patch_size 81 
| H, W = im.size(2), im.size(3) 82 | pad_h, pad_w = 0, 0 83 | if H % patch_size > 0: 84 | pad_h = patch_size - (H % patch_size) 85 | if W % patch_size > 0: 86 | pad_w = patch_size - (W % patch_size) 87 | im_padded = im 88 | if pad_h > 0 or pad_w > 0: 89 | im_padded = F.pad(im, (0, pad_w, 0, pad_h), value=fill_value) 90 | return im_padded 91 | 92 | 93 | def unpadding(y, target_size): 94 | H, W = target_size 95 | H_pad, W_pad = y.size(2), y.size(3) 96 | # crop predictions on extra pixels coming from padding 97 | extra_h = H_pad - H 98 | extra_w = W_pad - W 99 | if extra_h > 0: 100 | y = y[:, :, :-extra_h] 101 | if extra_w > 0: 102 | y = y[:, :, :, :-extra_w] 103 | return y 104 | 105 | 106 | def resize(im, smaller_size): 107 | h, w = im.shape[2:] 108 | if h < w: 109 | ratio = w / h 110 | h_res, w_res = smaller_size, ratio * smaller_size 111 | else: 112 | ratio = h / w 113 | h_res, w_res = ratio * smaller_size, smaller_size 114 | if min(h, w) < smaller_size: 115 | im_res = F.interpolate(im, (int(h_res), int(w_res)), mode="bilinear") 116 | else: 117 | im_res = im 118 | return im_res 119 | 120 | 121 | def sliding_window(im, flip, window_size, window_stride): 122 | B, C, H, W = im.shape 123 | ws = window_size 124 | 125 | windows = {"crop": [], "anchors": []} 126 | h_anchors = torch.arange(0, H, window_stride) 127 | w_anchors = torch.arange(0, W, window_stride) 128 | h_anchors = [h.item() for h in h_anchors if h < H - ws] + [H - ws] 129 | w_anchors = [w.item() for w in w_anchors if w < W - ws] + [W - ws] 130 | for ha in h_anchors: 131 | for wa in w_anchors: 132 | window = im[:, :, ha : ha + ws, wa : wa + ws] 133 | windows["crop"].append(window) 134 | windows["anchors"].append((ha, wa)) 135 | windows["flip"] = flip 136 | windows["shape"] = (H, W) 137 | return windows 138 | 139 | 140 | def merge_windows(windows, window_size, ori_shape): 141 | ws = window_size 142 | im_windows = windows["seg_maps"] 143 | anchors = windows["anchors"] 144 | C = im_windows[0].shape[0] 145 | H, W = windows["shape"] 146 | flip = windows["flip"] 147 | 148 | logit = torch.zeros((C, H, W), device=im_windows.device) 149 | count = torch.zeros((1, H, W), device=im_windows.device) 150 | for window, (ha, wa) in zip(im_windows, anchors): 151 | logit[:, ha : ha + ws, wa : wa + ws] += window 152 | count[:, ha : ha + ws, wa : wa + ws] += 1 153 | logit = logit / count 154 | logit = F.interpolate( 155 | logit.unsqueeze(0), 156 | ori_shape, 157 | mode="bilinear", 158 | )[0] 159 | if flip: 160 | logit = torch.flip(logit, (2,)) 161 | result = F.softmax(logit, 0) 162 | return result 163 | 164 | 165 | def inference( 166 | model, 167 | ims, 168 | ims_metas, 169 | ori_shape, 170 | window_size, 171 | window_stride, 172 | batch_size, 173 | ): 174 | C = model.n_cls 175 | seg_map = torch.zeros((C, ori_shape[0], ori_shape[1]), device=ptu.device) 176 | seg_map1 = torch.zeros((C, ori_shape[0], ori_shape[1]), device=ptu.device) 177 | seg_map2 = torch.zeros((C, ori_shape[0], ori_shape[1]), device=ptu.device) 178 | flag = 0 179 | for im, im_metas in zip(ims, ims_metas): 180 | im = im.to(ptu.device) 181 | im = resize(im, window_size) 182 | flip = im_metas["flip"] 183 | windows = sliding_window(im, flip, window_size, window_stride) 184 | windows1 = sliding_window(im, flip, window_size, window_stride) 185 | windows2 = sliding_window(im, flip, window_size, window_stride) 186 | crops = torch.stack(windows.pop("crop"))[:, 0] 187 | B = len(crops) 188 | WB = batch_size 189 | seg_maps = torch.zeros((B, C, window_size, window_size), device=im.device) 190 | 
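        # The segmenter may return either a single map or a pair (masks1, masks2);
        # `flag` records which case occurred so the matching buffers
        # (seg_maps vs. seg_maps1/seg_maps2) are merged and returned at the end.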
seg_maps1 = torch.zeros((B, C, window_size, window_size), device=im.device) 191 | seg_maps2 = torch.zeros((B, C, window_size, window_size), device=im.device) 192 | with torch.no_grad(): 193 | for i in range(0, B, WB): 194 | map = model.forward(crops[i: i + WB]) 195 | if len(map) == 2: 196 | flag = 1 197 | seg_maps1[i: i + WB], seg_maps2[i: i + WB] = map[0], map[1] 198 | else: 199 | seg_maps[i: i + WB] = map 200 | 201 | windows["seg_maps"] = seg_maps 202 | im_seg_map = merge_windows(windows, window_size, ori_shape) 203 | seg_map += im_seg_map 204 | 205 | windows1["seg_maps"] = seg_maps1 206 | im_seg_map = merge_windows(windows1, window_size, ori_shape) 207 | seg_map1 += im_seg_map 208 | 209 | windows2["seg_maps"] = seg_maps2 210 | im_seg_map = merge_windows(windows2, window_size, ori_shape) 211 | seg_map2 += im_seg_map 212 | if flag == 0: 213 | seg_map /= len(ims) 214 | return seg_map 215 | seg_map1 /= len(ims) 216 | seg_map2 /= len(ims) 217 | return seg_map1, seg_map2 218 | 219 | 220 | def num_params(model): 221 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 222 | n_params = sum([torch.prod(torch.tensor(p.size())) for p in model_parameters]) 223 | return n_params.item() 224 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/model/vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from 2020 Ross Wightman 3 | https://github.com/rwightman/pytorch-image-models 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from segm.model.utils import init_weights, resize_pos_embed 10 | from segm.model.blocks import Block, DINOBlock, EVA02Block 11 | 12 | from timm.models.layers import DropPath 13 | from timm.models.layers import trunc_normal_ 14 | from timm.models.vision_transformer import _load_weights 15 | 16 | from segm.model.rope import * 17 | class PatchEmbedding(nn.Module): 18 | def __init__(self, image_size, patch_size, embed_dim, channels): 19 | super().__init__() 20 | 21 | self.image_size = image_size 22 | if image_size[0] % patch_size != 0 or image_size[1] % patch_size != 0: 23 | raise ValueError("image dimensions must be divisible by the patch size") 24 | self.grid_size = image_size[0] // patch_size, image_size[1] // patch_size 25 | self.num_patches = self.grid_size[0] * self.grid_size[1] 26 | self.patch_size = patch_size 27 | 28 | self.proj = nn.Conv2d( 29 | channels, embed_dim, kernel_size=patch_size, stride=patch_size 30 | ) 31 | 32 | def forward(self, im): 33 | B, C, H, W = im.shape 34 | x = self.proj(im).flatten(2).transpose(1, 2) 35 | return x 36 | 37 | 38 | class VisionTransformer(nn.Module): 39 | def __init__( 40 | self, 41 | image_size, 42 | patch_size, 43 | n_layers, 44 | d_model, 45 | d_ff, 46 | n_heads, 47 | n_cls, 48 | dropout=0.1, 49 | drop_path_rate=0.0, 50 | distilled=False, 51 | channels=3, 52 | ): 53 | super().__init__() 54 | self.patch_embed = PatchEmbedding( 55 | image_size, 56 | patch_size, 57 | d_model, 58 | channels, 59 | ) 60 | self.patch_size = patch_size 61 | self.n_layers = n_layers 62 | self.d_model = d_model 63 | self.d_ff = d_ff 64 | self.n_heads = n_heads 65 | self.dropout = nn.Dropout(dropout) 66 | self.n_cls = n_cls 67 | 68 | # cls and pos tokens 69 | self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) 70 | self.distilled = distilled 71 | if self.distilled: 72 | self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model)) 73 | self.pos_embed = nn.Parameter( 74 | torch.randn(1, self.patch_embed.num_patches + 
2, d_model) 75 | ) 76 | self.head_dist = nn.Linear(d_model, n_cls) 77 | else: 78 | self.pos_embed = nn.Parameter( 79 | torch.randn(1, self.patch_embed.num_patches + 1, d_model) 80 | ) 81 | 82 | # transformer blocks 83 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)] 84 | self.blocks = nn.ModuleList( 85 | [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)] 86 | ) 87 | 88 | # output head 89 | self.norm = nn.LayerNorm(d_model) 90 | self.head = nn.Linear(d_model, n_cls) 91 | 92 | trunc_normal_(self.pos_embed, std=0.02) 93 | trunc_normal_(self.cls_token, std=0.02) 94 | if self.distilled: 95 | trunc_normal_(self.dist_token, std=0.02) 96 | self.pre_logits = nn.Identity() 97 | 98 | self.apply(init_weights) 99 | 100 | @torch.jit.ignore 101 | def no_weight_decay(self): 102 | return {"pos_embed", "cls_token", "dist_token"} 103 | 104 | @torch.jit.ignore() 105 | def load_pretrained(self, checkpoint_path, prefix=""): 106 | _load_weights(self, checkpoint_path, prefix) 107 | 108 | def get_num_layers(self): 109 | return len(self.blocks) 110 | 111 | def forward(self, im, return_features=False): 112 | B, _, H, W = im.shape 113 | PS = self.patch_size 114 | 115 | x = self.patch_embed(im) 116 | cls_tokens = self.cls_token.expand(B, -1, -1) 117 | if self.distilled: 118 | dist_tokens = self.dist_token.expand(B, -1, -1) 119 | x = torch.cat((cls_tokens, dist_tokens, x), dim=1) 120 | else: 121 | x = torch.cat((cls_tokens, x), dim=1) 122 | 123 | pos_embed = self.pos_embed 124 | num_extra_tokens = 1 + self.distilled 125 | if x.shape[1] != pos_embed.shape[1]: 126 | pos_embed = resize_pos_embed( 127 | pos_embed, 128 | self.patch_embed.grid_size, 129 | (H // PS, W // PS), 130 | num_extra_tokens, 131 | ) 132 | x = x + pos_embed 133 | x = self.dropout(x) 134 | 135 | for blk in self.blocks: 136 | x = blk(x) 137 | x = self.norm(x) 138 | 139 | if return_features: 140 | return x 141 | 142 | if self.distilled: 143 | x, x_dist = x[:, 0], x[:, 1] 144 | x = self.head(x) 145 | x_dist = self.head_dist(x_dist) 146 | x = (x + x_dist) / 2 147 | else: 148 | x = x[:, 0] 149 | x = self.head(x) 150 | return x 151 | 152 | def get_attention_map(self, im, layer_id): 153 | if layer_id >= self.n_layers or layer_id < 0: 154 | raise ValueError( 155 | f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}." 
156 | ) 157 | B, _, H, W = im.shape 158 | PS = self.patch_size 159 | 160 | x = self.patch_embed(im) 161 | cls_tokens = self.cls_token.expand(B, -1, -1) 162 | if self.distilled: 163 | dist_tokens = self.dist_token.expand(B, -1, -1) 164 | x = torch.cat((cls_tokens, dist_tokens, x), dim=1) 165 | else: 166 | x = torch.cat((cls_tokens, x), dim=1) 167 | 168 | pos_embed = self.pos_embed 169 | num_extra_tokens = 1 + self.distilled 170 | if x.shape[1] != pos_embed.shape[1]: 171 | pos_embed = resize_pos_embed( 172 | pos_embed, 173 | self.patch_embed.grid_size, 174 | (H // PS, W // PS), 175 | num_extra_tokens, 176 | ) 177 | x = x + pos_embed 178 | 179 | for i, blk in enumerate(self.blocks): 180 | if i < layer_id: 181 | x = blk(x) 182 | else: 183 | return blk(x, return_attention=True) 184 | 185 | 186 | class DINOV2VisionTransformer(VisionTransformer): 187 | def __init__( 188 | self, 189 | image_size, 190 | patch_size, 191 | n_layers, 192 | d_model, 193 | d_ff, 194 | n_heads, 195 | n_cls, 196 | dropout=0.1, 197 | drop_path_rate=0.0, 198 | distilled=False, 199 | channels=3, 200 | init_values: float = 1.0, 201 | ): 202 | super().__init__( 203 | image_size, 204 | patch_size, 205 | n_layers, 206 | d_model, 207 | d_ff, 208 | n_heads, 209 | n_cls, 210 | dropout, 211 | drop_path_rate, 212 | distilled, 213 | channels, 214 | ) 215 | 216 | # transformer blocks 217 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)] 218 | self.blocks = nn.ModuleList( 219 | [DINOBlock(d_model, n_heads, d_ff, dropout, dpr[i], init_values) for i in range(n_layers)] 220 | ) 221 | 222 | class EVA02VisionTransformer(VisionTransformer): 223 | def __init__( 224 | self, 225 | image_size, 226 | patch_size, 227 | n_layers, 228 | d_model, 229 | d_ff, 230 | n_heads, 231 | n_cls, 232 | dropout=0.1, 233 | drop_path_rate=0.0, 234 | distilled=False, 235 | channels=3, 236 | init_values: float = 1.0, 237 | intp_freq = True, 238 | xattn=True, 239 | norm_layer=nn.LayerNorm, 240 | ): 241 | super().__init__( 242 | image_size, 243 | patch_size, 244 | n_layers, 245 | d_model, 246 | d_ff, 247 | n_heads, 248 | n_cls, 249 | dropout, 250 | drop_path_rate, 251 | distilled, 252 | channels, 253 | ) 254 | 255 | # transformer blocks 256 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)] 257 | half_head_dim = d_model // n_heads // 2 258 | hw_seq_len = image_size[0] // patch_size 259 | self.rope = VisionRotaryEmbeddingFast( 260 | dim=half_head_dim, 261 | pt_seq_len=16, 262 | ft_seq_len=hw_seq_len if intp_freq else None, 263 | ) 264 | self.blocks = nn.ModuleList( 265 | [EVA02Block(d_model, n_heads, d_ff, dropout, dpr[i], init_values, 266 | rope=self.rope, norm_layer=norm_layer) for i in range(n_layers)] 267 | ) 268 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/optim/factory.py: -------------------------------------------------------------------------------- 1 | from timm import scheduler 2 | from timm import optim 3 | 4 | from segm.optim.scheduler import PolynomialLR 5 | 6 | 7 | def create_scheduler(opt_args, optimizer): 8 | if opt_args.sched == "polynomial": 9 | lr_scheduler = PolynomialLR( 10 | optimizer, 11 | opt_args.poly_step_size, 12 | opt_args.iter_warmup, 13 | opt_args.iter_max, 14 | opt_args.poly_power, 15 | opt_args.min_lr, 16 | ) 17 | else: 18 | lr_scheduler, _ = scheduler.create_scheduler(opt_args, optimizer) 19 | return lr_scheduler 20 | 21 | 22 | def create_optimizer(opt_args, model): 23 | return optim.create_optimizer(opt_args, model) 24 | 
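# A minimal usage sketch of the "polynomial" branch above, assuming a toy model
# and optimizer; every numeric value here is illustrative, not a recommended
# setting.
if __name__ == "__main__":
    from types import SimpleNamespace

    import torch

    _model = torch.nn.Linear(8, 2)
    _optimizer = torch.optim.SGD(_model.parameters(), lr=1e-3, momentum=0.9)
    _opt_args = SimpleNamespace(
        sched="polynomial",  # selects segm's PolynomialLR instead of a timm scheduler
        poly_step_size=1,
        iter_warmup=0,
        iter_max=1000,
        poly_power=0.9,
        min_lr=1e-5,
    )
    _lr_scheduler = create_scheduler(_opt_args, _optimizer)
    for _ in range(10):
        _optimizer.step()
        _lr_scheduler.step_update(num_updates=1)  # PolynomialLR forwards this to step()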
-------------------------------------------------------------------------------- /OnlineRetraining/segm/optim/optim_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Copy-paste from BEiT library: 9 | https://github.com/microsoft/unilm/tree/master/beit 10 | """ 11 | 12 | import torch 13 | import json 14 | 15 | try: 16 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 17 | 18 | has_apex = True 19 | except ImportError: 20 | has_apex = False 21 | 22 | from torch import optim as optim 23 | from timm.optim.adafactor import Adafactor 24 | from timm.optim.adahessian import Adahessian 25 | from timm.optim.adamp import AdamP 26 | from timm.optim.lookahead import Lookahead 27 | from timm.optim.nadam import Nadam 28 | # from timm.optim.novograd import NovoGrad 29 | from timm.optim.nvnovograd import NvNovoGrad 30 | from timm.optim.radam import RAdam 31 | from timm.optim.rmsprop_tf import RMSpropTF 32 | from timm.optim.sgdp import SGDP 33 | 34 | from segm.utils.logger import printd 35 | 36 | 37 | def get_num_layer_for_vit(var_name, num_max_layer): 38 | if var_name.startswith("encoder"): 39 | name_list = ["cls_token", "pos_embed", "dist_token"] 40 | for list in name_list: 41 | if list in var_name: 42 | return 0 43 | if "patch_embed" in var_name: 44 | return 0 45 | elif "blocks" in var_name: 46 | layer_id = int(var_name.split('.')[2]) 47 | return layer_id + 1 48 | else: 49 | return num_max_layer - 2 50 | return num_max_layer - 1 51 | 52 | 53 | class LayerDecayValueAssigner(object): 54 | def __init__(self, values, is_swin=False, is_se=False, depths=None): 55 | self.values = values 56 | self.is_swin = is_swin 57 | self.depths = depths 58 | self.is_se = is_se 59 | 60 | def get_scale(self, layer_id): 61 | return self.values[layer_id] 62 | 63 | def get_layer_id(self, var_name): 64 | return get_num_layer_for_vit(var_name, len(self.values)) 65 | 66 | 67 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 68 | parameter_group_names = {} 69 | parameter_group_vars = {} 70 | 71 | for name, param in model.named_parameters(): 72 | if not param.requires_grad: 73 | continue # frozen weights 74 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 75 | group_name = "no_decay" 76 | this_weight_decay = 0. 77 | else: 78 | group_name = "decay" 79 | this_weight_decay = weight_decay 80 | if get_num_layer is not None: 81 | layer_id = get_num_layer(name) 82 | group_name = "layer_%d_%s" % (layer_id, group_name) 83 | else: 84 | layer_id = None 85 | 86 | if group_name not in parameter_group_names: 87 | if get_layer_scale is not None: 88 | scale = get_layer_scale(layer_id) 89 | else: 90 | scale = 1. 
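            # Each (layer_id, decay/no_decay) bucket is created once and carries an
            # "lr_scale" taken from LayerDecayValueAssigner.values, so a layer-decay
            # schedule (when one is configured) lowers the effective learning rate of
            # the shallower transformer blocks relative to the deeper ones.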
91 | 92 | parameter_group_names[group_name] = { 93 | "weight_decay": this_weight_decay, 94 | "params": [], 95 | "lr_scale": scale 96 | } 97 | parameter_group_vars[group_name] = { 98 | "weight_decay": this_weight_decay, 99 | "params": [], 100 | "lr_scale": scale 101 | } 102 | 103 | parameter_group_vars[group_name]["params"].append(param) 104 | parameter_group_names[group_name]["params"].append(name) 105 | printd("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 106 | return list(parameter_group_vars.values()) 107 | 108 | 109 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 110 | opt_lower = args.opt.lower() 111 | weight_decay = args.weight_decay 112 | if filter_bias_and_bn: 113 | skip = {} 114 | if skip_list is not None: 115 | skip = skip_list 116 | elif hasattr(model, 'no_weight_decay'): 117 | skip = model.no_weight_decay() 118 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 119 | weight_decay = 0. 120 | else: 121 | parameters = model.parameters() 122 | 123 | if 'fused' in opt_lower: 124 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 125 | 126 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 127 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 128 | opt_args['eps'] = args.opt_eps 129 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 130 | opt_args['betas'] = args.opt_betas 131 | 132 | opt_split = opt_lower.split('_') 133 | opt_lower = opt_split[-1] 134 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 135 | opt_args.pop('eps', None) 136 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 137 | elif opt_lower == 'momentum': 138 | opt_args.pop('eps', None) 139 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 140 | elif opt_lower == 'adam': 141 | optimizer = optim.Adam(parameters, **opt_args) 142 | elif opt_lower == 'adamw': 143 | optimizer = optim.AdamW(parameters, **opt_args) 144 | elif opt_lower == 'nadam': 145 | optimizer = Nadam(parameters, **opt_args) 146 | elif opt_lower == 'radam': 147 | optimizer = RAdam(parameters, **opt_args) 148 | elif opt_lower == 'adamp': 149 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 150 | elif opt_lower == 'sgdp': 151 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 152 | elif opt_lower == 'adadelta': 153 | optimizer = optim.Adadelta(parameters, **opt_args) 154 | elif opt_lower == 'adafactor': 155 | if not args.lr: 156 | opt_args['lr'] = None 157 | optimizer = Adafactor(parameters, **opt_args) 158 | elif opt_lower == 'adahessian': 159 | optimizer = Adahessian(parameters, **opt_args) 160 | elif opt_lower == 'rmsprop': 161 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 162 | elif opt_lower == 'rmsproptf': 163 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 164 | elif opt_lower == 'novograd': 165 | optimizer = NovoGrad(parameters, **opt_args) 166 | elif opt_lower == 'nvnovograd': 167 | optimizer = NvNovoGrad(parameters, **opt_args) 168 | elif opt_lower == 'fusedsgd': 169 | opt_args.pop('eps', None) 170 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 171 | elif opt_lower == 'fusedmomentum': 172 | opt_args.pop('eps', None) 173 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, 
**opt_args) 174 | elif opt_lower == 'fusedadam': 175 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 176 | elif opt_lower == 'fusedadamw': 177 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 178 | elif opt_lower == 'fusedlamb': 179 | optimizer = FusedLAMB(parameters, **opt_args) 180 | elif opt_lower == 'fusednovograd': 181 | opt_args.setdefault('betas', (0.95, 0.98)) 182 | optimizer = FusedNovoGrad(parameters, **opt_args) 183 | else: 184 | assert False and "Invalid optimizer" 185 | raise ValueError 186 | 187 | if len(opt_split) > 1: 188 | if opt_split[0] == 'lookahead': 189 | optimizer = Lookahead(optimizer) 190 | 191 | return optimizer 192 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/optim/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | from timm.scheduler.scheduler import Scheduler 4 | 5 | 6 | class PolynomialLR(_LRScheduler): 7 | def __init__( 8 | self, 9 | optimizer, 10 | step_size, 11 | iter_warmup, 12 | iter_max, 13 | power, 14 | min_lr=0, 15 | last_epoch=-1, 16 | ): 17 | self.step_size = step_size 18 | self.iter_warmup = int(iter_warmup) 19 | self.iter_max = int(iter_max) 20 | self.power = power 21 | self.min_lr = min_lr 22 | super(PolynomialLR, self).__init__(optimizer, last_epoch) 23 | 24 | def polynomial_decay(self, lr): 25 | iter_cur = float(self.last_epoch) 26 | if iter_cur < self.iter_warmup: 27 | coef = iter_cur / self.iter_warmup 28 | coef *= (1 - self.iter_warmup / self.iter_max) ** self.power 29 | else: 30 | coef = (1 - iter_cur / self.iter_max) ** self.power 31 | return (lr - self.min_lr) * coef + self.min_lr 32 | 33 | def get_lr(self): 34 | if ( 35 | (self.last_epoch == 0) 36 | or (self.last_epoch % self.step_size != 0) 37 | or (self.last_epoch > self.iter_max) 38 | ): 39 | return [group["lr"] for group in self.optimizer.param_groups] 40 | return [self.polynomial_decay(lr) for lr in self.base_lrs] 41 | 42 | def step_update(self, num_updates): 43 | self.step() 44 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/scripts/prepare_ade20k.py: -------------------------------------------------------------------------------- 1 | """Prepare ADE20K dataset""" 2 | import click 3 | import zipfile 4 | 5 | from pathlib import Path 6 | from segm.utils.download import download 7 | 8 | 9 | def download_ade(path, overwrite=False): 10 | _AUG_DOWNLOAD_URLS = [ 11 | ( 12 | "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip", 13 | "219e1696abb36c8ba3a3afe7fb2f4b4606a897c7", 14 | ), 15 | ( 16 | "http://data.csail.mit.edu/places/ADEchallenge/release_test.zip", 17 | "e05747892219d10e9243933371a497e905a4860c", 18 | ), 19 | ] 20 | download_dir = path / "downloads" 21 | download_dir.mkdir(parents=True, exist_ok=True) 22 | for url, checksum in _AUG_DOWNLOAD_URLS: 23 | filename = download( 24 | url, path=str(download_dir), overwrite=overwrite, sha1_hash=checksum 25 | ) 26 | # extract 27 | with zipfile.ZipFile(filename, "r") as zip_ref: 28 | zip_ref.extractall(path=str(path)) 29 | 30 | 31 | @click.command(help="Initialize ADE20K dataset.") 32 | @click.argument("download_dir", type=str) 33 | def main(download_dir): 34 | dataset_dir = Path(download_dir) / "ade20k" 35 | download_ade(dataset_dir, overwrite=False) 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | 
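# Usage sketch (hypothetical dataset root), assuming the package is invoked from
# the OnlineRetraining/ directory:
#
#   python -m segm.scripts.prepare_ade20k /path/to/datasets
#
# Both archives are downloaded into /path/to/datasets/ade20k/downloads and
# extracted under /path/to/datasets/ade20k.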
-------------------------------------------------------------------------------- /OnlineRetraining/segm/scripts/prepare_cityscapes.py: -------------------------------------------------------------------------------- 1 | """Prepare Cityscapes dataset""" 2 | import click 3 | import os 4 | import shutil 5 | import mmcv 6 | import zipfile 7 | 8 | from pathlib import Path 9 | from segm.utils.download import download 10 | 11 | USERNAME = None 12 | PASSWORD = None 13 | 14 | 15 | def download_cityscapes(path, username, password, overwrite=False): 16 | _CITY_DOWNLOAD_URLS = [ 17 | ("gtFine_trainvaltest.zip", "99f532cb1af174f5fcc4c5bc8feea8c66246ddbc"), 18 | ("leftImg8bit_trainvaltest.zip", "2c0b77ce9933cc635adda307fbba5566f5d9d404"), 19 | ] 20 | download_dir = path / "downloads" 21 | download_dir.mkdir(parents=True, exist_ok=True) 22 | 23 | os.system( 24 | f"wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username={username}&password={password}&submit=Login' https://www.cityscapes-dataset.com/login/ -P {download_dir}" 25 | ) 26 | 27 | if not (download_dir / "gtFine_trainvaltest.zip").is_file(): 28 | os.system( 29 | f"wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1 -P {download_dir}" 30 | ) 31 | 32 | if not (download_dir / "leftImg8bit_trainvaltest.zip").is_file(): 33 | os.system( 34 | f"wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3 -P {download_dir}" 35 | ) 36 | 37 | for filename, checksum in _CITY_DOWNLOAD_URLS: 38 | # extract 39 | with zipfile.ZipFile(str(download_dir / filename), "r") as zip_ref: 40 | zip_ref.extractall(path=path) 41 | print("Extracted", filename) 42 | 43 | 44 | def install_cityscapes_api(): 45 | os.system("pip install cityscapesscripts") 46 | try: 47 | import cityscapesscripts 48 | except Exception: 49 | print( 50 | "Installing Cityscapes API failed, please install it manually %s" 51 | % (repo_url) 52 | ) 53 | 54 | 55 | def convert_json_to_label(json_file): 56 | from cityscapesscripts.preparation.json2labelImg import json2labelImg 57 | 58 | label_file = json_file.replace("_polygons.json", "_labelTrainIds.png") 59 | json2labelImg(json_file, label_file, "trainIds") 60 | 61 | 62 | @click.command(help="Initialize Cityscapes dataset.") 63 | @click.argument("download_dir", type=str) 64 | @click.option("--username", default=USERNAME, type=str) 65 | @click.option("--password", default=PASSWORD, type=str) 66 | @click.option("--nproc", default=10, type=int) 67 | def main( 68 | download_dir, 69 | username, 70 | password, 71 | nproc, 72 | ): 73 | 74 | dataset_dir = Path(download_dir) / "cityscapes" 75 | 76 | if username is None or password is None: 77 | raise ValueError( 78 | "You must indicate your username and password either in the script variables or by passing options --username and --pasword." 
79 | ) 80 | 81 | download_cityscapes(dataset_dir, username, password, overwrite=False) 82 | 83 | install_cityscapes_api() 84 | 85 | gt_dir = dataset_dir / "gtFine" 86 | 87 | poly_files = [] 88 | for poly in mmcv.scandir(str(gt_dir), "_polygons.json", recursive=True): 89 | poly_file = str(gt_dir / poly) 90 | poly_files.append(poly_file) 91 | mmcv.track_parallel_progress(convert_json_to_label, poly_files, nproc) 92 | 93 | split_names = ["train", "val", "test"] 94 | 95 | for split in split_names: 96 | filenames = [] 97 | for poly in mmcv.scandir(str(gt_dir / split), "_polygons.json", recursive=True): 98 | filenames.append(poly.replace("_gtFine_polygons.json", "")) 99 | with open(str(dataset_dir / f"{split}.txt"), "w") as f: 100 | f.writelines(f + "\n" for f in filenames) 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/scripts/prepare_pcontext.py: -------------------------------------------------------------------------------- 1 | """Prepare PASCAL Context dataset""" 2 | import click 3 | import shutil 4 | import tarfile 5 | import torch 6 | 7 | from tqdm import tqdm 8 | from pathlib import Path 9 | 10 | from segm.utils.download import download 11 | 12 | 13 | def download_pcontext(path, overwrite=False): 14 | _AUG_DOWNLOAD_URLS = [ 15 | ( 16 | "https://www.dropbox.com/s/wtdibo9lb2fur70/VOCtrainval_03-May-2010.tar?dl=1", 17 | "VOCtrainval_03-May-2010.tar", 18 | "bf9985e9f2b064752bf6bd654d89f017c76c395a", 19 | ), 20 | ( 21 | "https://codalabuser.blob.core.windows.net/public/trainval_merged.json", 22 | "", 23 | "169325d9f7e9047537fedca7b04de4dddf10b881", 24 | ), 25 | ( 26 | "https://hangzh.s3.amazonaws.com/encoding/data/pcontext/train.pth", 27 | "", 28 | "4bfb49e8c1cefe352df876c9b5434e655c9c1d07", 29 | ), 30 | ( 31 | "https://hangzh.s3.amazonaws.com/encoding/data/pcontext/val.pth", 32 | "", 33 | "ebedc94247ec616c57b9a2df15091784826a7b0c", 34 | ), 35 | ] 36 | download_dir = path / "downloads" 37 | 38 | download_dir.mkdir(parents=True, exist_ok=True) 39 | 40 | for url, filename, checksum in _AUG_DOWNLOAD_URLS: 41 | filename = download( 42 | url, 43 | path=str(download_dir / filename), 44 | overwrite=overwrite, 45 | sha1_hash=checksum, 46 | ) 47 | # extract 48 | if Path(filename).suffix == ".tar": 49 | with tarfile.open(filename) as tar: 50 | tar.extractall(path=str(path)) 51 | else: 52 | shutil.move( 53 | filename, 54 | str(path / "VOCdevkit" / "VOC2010" / Path(filename).name), 55 | ) 56 | 57 | 58 | @click.command(help="Initialize PASCAL Context dataset.") 59 | @click.argument("download_dir", type=str) 60 | def main(download_dir): 61 | 62 | dataset_dir = Path(download_dir) / "pcontext" 63 | 64 | download_pcontext(dataset_dir, overwrite=False) 65 | 66 | devkit_path = dataset_dir / "VOCdevkit" 67 | out_dir = devkit_path / "VOC2010" / "SegmentationClassContext" 68 | imageset_dir = devkit_path / "VOC2010" / "ImageSets" / "SegmentationContext" 69 | 70 | out_dir.mkdir(parents=True, exist_ok=True) 71 | imageset_dir.mkdir(parents=True, exist_ok=True) 72 | 73 | train_torch_path = devkit_path / "VOC2010" / "train.pth" 74 | val_torch_path = devkit_path / "VOC2010" / "val.pth" 75 | 76 | train_dict = torch.load(str(train_torch_path)) 77 | 78 | train_list = [] 79 | for idx, label in tqdm(train_dict.items()): 80 | idx = str(idx) 81 | new_idx = idx[:4] + "_" + idx[4:] 82 | train_list.append(new_idx) 83 | label_path = out_dir / f"{new_idx}.png" 84 | label.save(str(label_path)) 85 | 86 | with 
open(str(imageset_dir / "train.txt"), "w") as f: 87 | f.writelines(line + "\n" for line in sorted(train_list)) 88 | 89 | val_dict = torch.load(str(val_torch_path)) 90 | 91 | val_list = [] 92 | for idx, label in tqdm(val_dict.items()): 93 | idx = str(idx) 94 | new_idx = idx[:4] + "_" + idx[4:] 95 | val_list.append(new_idx) 96 | label_path = out_dir / f"{new_idx}.png" 97 | label.save(str(label_path)) 98 | 99 | with open(str(imageset_dir / "val.txt"), "w") as f: 100 | f.writelines(line + "\n" for line in sorted(val_list)) 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/scripts/show_attn_map.py: -------------------------------------------------------------------------------- 1 | import click 2 | import einops 3 | import torch 4 | import torchvision 5 | 6 | import matplotlib.pyplot as plt 7 | import segm.utils.torch as ptu 8 | import torch.nn.functional as F 9 | 10 | from pathlib import Path 11 | from PIL import Image 12 | from segm import config 13 | from segm.data.utils import STATS 14 | from segm.model.decoder import MaskTransformer 15 | from segm.model.factory import load_model 16 | from torchvision import transforms 17 | 18 | 19 | @click.command() 20 | @click.argument("model-path", type=str) 21 | @click.argument("image-path", type=str) 22 | @click.argument("output-dir", type=str) 23 | @click.option("--layer-id", default=0, type=int) 24 | @click.option("--x-patch", default=0, type=int) 25 | @click.option("--y-patch", default=0, type=int) 26 | @click.option("--cmap", default="viridis", type=str) 27 | @click.option("--enc/--dec", default=True, is_flag=True) 28 | @click.option("--cls/--patch", default=False, is_flag=True) 29 | def visualize( 30 | model_path, 31 | image_path, 32 | output_dir, 33 | layer_id, 34 | x_patch, 35 | y_patch, 36 | cmap, 37 | enc, 38 | cls, 39 | ): 40 | 41 | output_dir = Path(output_dir) 42 | model_dir = Path(model_path).parent 43 | 44 | ptu.set_gpu_mode(True) 45 | 46 | # Build model 47 | model, variant = load_model(model_path) 48 | for p in model.parameters(): 49 | p.requires_grad = False 50 | 51 | model.eval() 52 | model.to(ptu.device) 53 | 54 | # Get model config 55 | patch_size = model.patch_size 56 | normalization = variant["dataset_kwargs"]["normalization"] 57 | image_size = variant["dataset_kwargs"]["image_size"] 58 | n_cls = variant["net_kwargs"]["n_cls"] 59 | stats = STATS[normalization] 60 | 61 | # Open image and process it 62 | try: 63 | with open(image_path, "rb") as f: 64 | img = Image.open(f) 65 | img = img.convert("RGB") 66 | except: 67 | raise ValueError(f"Provided image path {image_path} is not a valid image file.") 68 | 69 | # Normalize and resize 70 | transform = transforms.Compose( 71 | [ 72 | transforms.Resize(image_size), 73 | transforms.ToTensor(), 74 | transforms.Normalize(stats["mean"], stats["std"]), 75 | ] 76 | ) 77 | 78 | img = transform(img) 79 | 80 | # Make the image divisible by the patch size 81 | w, h = ( 82 | image_size - image_size % patch_size, 83 | image_size - image_size % patch_size, 84 | ) 85 | 86 | # Crop to image size 87 | img = img[:, :w, :h].unsqueeze(0) 88 | 89 | w_featmap = img.shape[-2] // patch_size 90 | h_featmap = img.shape[-1] // patch_size 91 | 92 | # Sanity checks 93 | if not enc and not isinstance(model.decoder, MaskTransformer): 94 | raise ValueError( 95 | f"Attention maps for decoder are only availabe for MaskTransformer. Provided model with decoder type: {model.decoder}." 
96 | ) 97 | 98 | if not cls: 99 | if x_patch > w_featmap or y_patch > h_featmap: 100 | raise ValueError( 101 | f"Provided patch x: {x_patch} y: {y_patch} is not valid. Patch should be in the range x: [0, {w_featmap}), y: [0, {h_featmap})" 102 | ) 103 | num_patch = w_featmap * y_patch + x_patch 104 | 105 | if layer_id < 0: 106 | raise ValueError("Provided layer_id should be positive.") 107 | 108 | if enc and model.encoder.n_layers <= layer_id: 109 | raise ValueError( 110 | f"Provided layer_id: {layer_id} is not valid for encoder with {model.encoder.n_layers}." 111 | ) 112 | 113 | if not enc and model.decoder.n_layers <= layer_id: 114 | raise ValueError( 115 | f"Provided layer_id: {layer_id} is not valid for decoder with {model.decoder.n_layers}." 116 | ) 117 | 118 | Path.mkdir(output_dir, exist_ok=True) 119 | 120 | # Process input and extract attention maps 121 | if enc: 122 | print(f"Generating Attention Mapping for Encoder Layer Id {layer_id}") 123 | attentions = model.get_attention_map_enc(img.to(ptu.device), layer_id) 124 | num_extra_tokens = 1 + model.encoder.distilled 125 | if cls: 126 | attentions = attentions[0, :, 0, num_extra_tokens:] 127 | else: 128 | attentions = attentions[ 129 | 0, :, num_patch + num_extra_tokens, num_extra_tokens: 130 | ] 131 | else: 132 | print(f"Generating Attention Mapping for Decoder Layer Id {layer_id}") 133 | attentions = model.get_attention_map_dec(img.to(ptu.device), layer_id) 134 | if cls: 135 | attentions = attentions[0, :, -n_cls:, :-n_cls] 136 | else: 137 | attentions = attentions[0, :, num_patch, :-n_cls] 138 | 139 | # Reshape into image shape 140 | nh = attentions.shape[0] # Number of heads 141 | attentions = attentions.reshape(nh, -1) 142 | 143 | if cls and not enc: 144 | attentions = attentions.reshape(nh, n_cls, w_featmap, h_featmap) 145 | else: 146 | attentions = attentions.reshape(nh, 1, w_featmap, h_featmap) 147 | 148 | # Resize attention maps to match input size 149 | attentions = ( 150 | F.interpolate(attentions, scale_factor=patch_size, mode="nearest").cpu().numpy() 151 | ) 152 | 153 | # Save Attention map for each head 154 | for i in range(nh): 155 | base_name = "enc" if enc else "dec" 156 | head_name = f"{base_name}_layer{layer_id}_attn-head{i}" 157 | attention_maps_list = attentions[i] 158 | for j in range(attention_maps_list.shape[0]): 159 | attention_map = attention_maps_list[j] 160 | file_name = head_name 161 | dir_path = output_dir / f"{base_name}_layer{layer_id}" 162 | Path.mkdir(dir_path, exist_ok=True) 163 | if cls: 164 | if enc: 165 | file_name = f"{file_name}_cls" 166 | dir_path /= "cls" 167 | else: 168 | file_name = f"{file_name}_{j}" 169 | dir_path /= f"cls_{j}" 170 | Path.mkdir(dir_path, exist_ok=True) 171 | else: 172 | dir_path /= f"patch_{x_patch}_{y_patch}" 173 | Path.mkdir(dir_path, exist_ok=True) 174 | 175 | file_path = dir_path / f"{file_name}.png" 176 | plt.imsave(fname=str(file_path), arr=attention_map, format="png", cmap=cmap) 177 | print(f"{file_path} saved.") 178 | 179 | # Save input image showing selected patch 180 | if not cls: 181 | im_n = torchvision.utils.make_grid(img, normalize=True, scale_each=True) 182 | 183 | # Compute corresponding X and Y px in the original image 184 | x_px = x_patch * patch_size 185 | y_px = y_patch * patch_size 186 | px_v = einops.repeat( 187 | torch.tensor([1, 0, 0]), 188 | "c -> 1 c h w", 189 | h=patch_size, 190 | w=patch_size, 191 | ) 192 | 193 | # Draw pixels for selected patch 194 | im_n[:, y_px : y_px + patch_size, x_px : x_px + patch_size] = px_v 195 | 
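        # px_v is a solid red patch, so the saved input image marks exactly which
        # patch the attention maps above were computed for.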
torchvision.utils.save_image( 196 | im_n, 197 | str(dir_path / "input_img.png"), 198 | ) 199 | 200 | 201 | if __name__ == "__main__": 202 | visualize() 203 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import hashlib 4 | from tqdm import tqdm 5 | 6 | 7 | def check_sha1(filename, sha1_hash): 8 | """Check whether the sha1 hash of the file content matches the expected hash. 9 | Parameters 10 | ---------- 11 | filename : str 12 | Path to the file. 13 | sha1_hash : str 14 | Expected sha1 hash in hexadecimal digits. 15 | Returns 16 | ------- 17 | bool 18 | Whether the file content matches the expected hash. 19 | """ 20 | sha1 = hashlib.sha1() 21 | with open(filename, "rb") as f: 22 | while True: 23 | data = f.read(1048576) 24 | if not data: 25 | break 26 | sha1.update(data) 27 | 28 | return sha1.hexdigest() == sha1_hash 29 | 30 | 31 | def download(url, path=None, overwrite=False, sha1_hash=None): 32 | """ 33 | https://github.com/junfu1115/DANet/blob/master/encoding/utils/files.py 34 | Download a given URL 35 | Parameters 36 | ---------- 37 | url : str 38 | URL to download 39 | path : str, optional 40 | Destination path to store downloaded file. By default stores to the 41 | current directory with same name as in url. 42 | overwrite : bool, optional 43 | Whether to overwrite destination file if already exists. 44 | sha1_hash : str, optional 45 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified 46 | but doesn't match. 47 | Returns 48 | ------- 49 | str 50 | The file path of the downloaded file. 51 | """ 52 | if path is None: 53 | fname = url.split("/")[-1] 54 | else: 55 | path = os.path.expanduser(path) 56 | if os.path.isdir(path): 57 | fname = os.path.join(path, url.split("/")[-1]) 58 | else: 59 | fname = path 60 | 61 | if ( 62 | overwrite 63 | or not os.path.exists(fname) 64 | or (sha1_hash and not check_sha1(fname, sha1_hash)) 65 | ): 66 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 67 | if not os.path.exists(dirname): 68 | os.makedirs(dirname) 69 | 70 | print("Downloading %s from %s..." % (fname, url)) 71 | r = requests.get(url, stream=True) 72 | if r.status_code != 200: 73 | raise RuntimeError("Failed downloading url %s" % url) 74 | total_length = r.headers.get("content-length") 75 | with open(fname, "wb") as f: 76 | if total_length is None: # no content length header 77 | for chunk in r.iter_content(chunk_size=1024): 78 | if chunk: # filter out keep-alive new chunks 79 | f.write(chunk) 80 | else: 81 | total_length = int(total_length) 82 | for chunk in tqdm( 83 | r.iter_content(chunk_size=1024), 84 | total=int(total_length / 1024.0 + 0.5), 85 | unit="KB", 86 | unit_scale=False, 87 | dynamic_ncols=True, 88 | ): 89 | f.write(chunk) 90 | 91 | if sha1_hash and not check_sha1(fname, sha1_hash): 92 | raise UserWarning( 93 | "File {} is downloaded but the content hash does not match. " 94 | "The repo may be outdated or download may be incomplete. 
" 95 | 'If the "repo_url" is overridden, consider switching to ' 96 | "the default repo.".format(fname) 97 | ) 98 | 99 | return fname 100 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/utils/lines.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import cycle 3 | 4 | 5 | class Lines: 6 | def __init__(self, resolution=20, smooth=None): 7 | self.COLORS = cycle( 8 | [ 9 | "#377eb8", 10 | "#e41a1c", 11 | "#4daf4a", 12 | "#984ea3", 13 | "#ff7f00", 14 | "#ffff33", 15 | "#a65628", 16 | "#f781bf", 17 | ] 18 | ) 19 | self.MARKERS = cycle("os^Dp>d<") 20 | self.LEGEND = dict(fontsize="medium", labelspacing=0, numpoints=1) 21 | self._resolution = resolution 22 | self._smooth_weight = smooth 23 | 24 | def __call__(self, ax, domains, lines, labels): 25 | assert len(domains) == len(lines) == len(labels) 26 | colors = [] 27 | for index, (label, color, marker) in enumerate( 28 | zip(labels, self.COLORS, self.MARKERS) 29 | ): 30 | domain, line = domains[index], lines[index] 31 | line = self.smooth(line, self._smooth_weight) 32 | ax.plot(domain, line[:, 0], color=color, label=label) 33 | 34 | last_x, last_y = domain[-1], line[-1, 0] 35 | ax.scatter(last_x, last_y, color=color, marker="x") 36 | ax.annotate( 37 | f"{last_y:.2f}", 38 | xy=(last_x, last_y), 39 | xytext=(last_x, last_y + 0.1), 40 | ) 41 | colors.append(color) 42 | 43 | self._plot_legend(ax, lines, labels) 44 | return colors 45 | 46 | def _plot_legend(self, ax, lines, labels): 47 | scores = {label: -np.nanmedian(line) for label, line in zip(labels, lines)} 48 | handles, labels = ax.get_legend_handles_labels() 49 | # handles, labels = zip(*sorted( 50 | # zip(handles, labels), key=lambda x: scores[x[1]])) 51 | legend = ax.legend(handles, labels, **self.LEGEND) 52 | legend.get_frame().set_edgecolor("white") 53 | for line in legend.get_lines(): 54 | line.set_alpha(1) 55 | 56 | def smooth(self, scalars, weight): 57 | """ 58 | weight in [0, 1] 59 | exponential moving average, same as tensorboard 60 | """ 61 | assert weight >= 0 and weight <= 1 62 | last = scalars[0] 63 | smoothed = np.asarray(scalars) 64 | for i, point in enumerate(scalars): 65 | last = last * weight + (1 - weight) * point 66 | smoothed[i] = last 67 | 68 | return smoothed 69 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/facebookresearch/deit/blob/main/utils.py 3 | """ 4 | 5 | import io 6 | import os 7 | import time 8 | from collections import defaultdict, deque 9 | import datetime 10 | 11 | import torch 12 | import torch.distributed as dist 13 | 14 | import segm.utils.torch as ptu 15 | 16 | 17 | class SmoothedValue(object): 18 | """Track a series of values and provide access to smoothed values over a 19 | window or the global series average. 20 | """ 21 | 22 | def __init__(self, window_size=20, fmt=None): 23 | if fmt is None: 24 | fmt = "{median:.4f} ({global_avg:.4f})" 25 | self.deque = deque(maxlen=window_size) 26 | self.total = 0.0 27 | self.count = 0 28 | self.fmt = fmt 29 | 30 | def update(self, value, n=1): 31 | self.deque.append(value) 32 | self.count += n 33 | self.total += value * n 34 | 35 | def synchronize_between_processes(self): 36 | """ 37 | Warning: does not synchronize the deque! 
38 | """ 39 | if not is_dist_avail_and_initialized(): 40 | return 41 | t = torch.tensor( 42 | [self.count, self.total], dtype=torch.float64, device=ptu.device 43 | ) 44 | dist.barrier() 45 | dist.all_reduce(t) 46 | t = t.tolist() 47 | self.count = int(t[0]) 48 | self.total = t[1] 49 | 50 | @property 51 | def median(self): 52 | d = torch.tensor(list(self.deque)) 53 | return d.median().item() 54 | 55 | @property 56 | def avg(self): 57 | d = torch.tensor(list(self.deque), dtype=torch.float32) 58 | return d.mean().item() 59 | 60 | @property 61 | def global_avg(self): 62 | return self.total / self.count 63 | 64 | @property 65 | def max(self): 66 | return max(self.deque) 67 | 68 | @property 69 | def value(self): 70 | return self.deque[-1] 71 | 72 | def __str__(self): 73 | return self.fmt.format( 74 | median=self.median, 75 | avg=self.avg, 76 | global_avg=self.global_avg, 77 | max=self.max, 78 | value=self.value, 79 | ) 80 | 81 | 82 | class MetricLogger(object): 83 | def __init__(self, delimiter="\t"): 84 | self.meters = defaultdict(SmoothedValue) 85 | self.delimiter = delimiter 86 | 87 | def update(self, n=1, **kwargs): 88 | for k, v in kwargs.items(): 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v, n) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError( 100 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 101 | ) 102 | 103 | def __str__(self): 104 | loss_str = [] 105 | for name, meter in self.meters.items(): 106 | loss_str.append("{}: {}".format(name, str(meter))) 107 | return self.delimiter.join(loss_str) 108 | 109 | def synchronize_between_processes(self): 110 | for meter in self.meters.values(): 111 | meter.synchronize_between_processes() 112 | 113 | def add_meter(self, name, meter): 114 | self.meters[name] = meter 115 | 116 | def log_every(self, iterable, print_freq, header=None): 117 | i = 0 118 | if not header: 119 | header = "" 120 | start_time = time.time() 121 | end = time.time() 122 | iter_time = SmoothedValue(fmt="{avg:.4f}") 123 | data_time = SmoothedValue(fmt="{avg:.4f}") 124 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 125 | log_msg = [ 126 | header, 127 | "[{0" + space_fmt + "}/{1}]", 128 | "eta: {eta}", 129 | "{meters}", 130 | "time: {time}", 131 | "data: {data}", 132 | ] 133 | if torch.cuda.is_available(): 134 | log_msg.append("max mem: {memory:.0f}") 135 | log_msg = self.delimiter.join(log_msg) 136 | MB = 1024.0 * 1024.0 137 | for obj in iterable: 138 | data_time.update(time.time() - end) 139 | yield obj 140 | iter_time.update(time.time() - end) 141 | if i % print_freq == 0 or i == len(iterable) - 1: 142 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 143 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 144 | if torch.cuda.is_available(): 145 | printd( 146 | log_msg.format( 147 | i, 148 | len(iterable), 149 | eta=eta_string, 150 | meters=str(self), 151 | time=str(iter_time), 152 | data=str(data_time), 153 | memory=torch.cuda.max_memory_allocated() / MB, 154 | ), 155 | flush=True, 156 | ) 157 | else: 158 | printd( 159 | log_msg.format( 160 | i, 161 | len(iterable), 162 | eta=eta_string, 163 | meters=str(self), 164 | time=str(iter_time), 165 | data=str(data_time), 166 | ), 167 | flush=True, 168 | ) 169 | i += 1 170 | end = time.time() 171 | total_time = time.time() - start_time 172 | total_time_str = 
str(datetime.timedelta(seconds=int(total_time))) 173 | printd( 174 | "{} Total time: {} ({:.4f} s / it)".format( 175 | header, total_time_str, total_time / len(iterable) 176 | ) 177 | ) 178 | 179 | 180 | def is_dist_avail_and_initialized(): 181 | if not dist.is_available(): 182 | return False 183 | if not dist.is_initialized(): 184 | return False 185 | return True 186 | 187 | 188 | def printd(x, *args, **kwargs): 189 | if ptu.dist_rank == 0: 190 | print(x, *args, **kwargs) 191 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/utils/logs.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | import numpy as np 4 | import yaml 5 | import matplotlib.pyplot as plt 6 | import click 7 | from collections import OrderedDict 8 | 9 | from segm.utils.lines import Lines 10 | 11 | 12 | def plot_logs(logs, x_key, y_key, size, vmin, vmax, epochs): 13 | m = np.inf 14 | M = -np.inf 15 | domains = [] 16 | lines = [] 17 | y_keys = y_key.split("/") 18 | for name, log in logs.items(): 19 | logs[name] = log[:epochs] 20 | for name, log in logs.items(): 21 | domain = [x[x_key] for x in log if y_keys[0] in x] 22 | if y_keys[0] not in log[0]: 23 | continue 24 | log_plot = [x[y_keys[0]] for x in log if y_keys[0] in x] 25 | for y_key in y_keys[1:]: 26 | if y_key in log_plot[0]: 27 | log_plot = [x[y_key] for x in log_plot if y_key in x] 28 | domains.append(domain) 29 | lines.append(np.array(log_plot)[:, None]) 30 | m = np.min((m, min(log_plot))) 31 | M = np.max((M, max(log_plot))) 32 | if vmin is not None: 33 | m = vmin 34 | if vmax is not None: 35 | M = vmax 36 | delta = 0.1 * (M - m) 37 | 38 | ratio = 0.6 39 | figsizes = {"tight": (4, 3), "large": (16 * ratio, 10 * ratio)} 40 | figsize = figsizes[size] 41 | 42 | # plot parameters 43 | fig, ax = plt.subplots(figsize=figsize) 44 | ax.set_xlabel(x_key) 45 | ax.set_ylabel(y_key) 46 | plot_lines = Lines(resolution=50, smooth=0.0) 47 | plot_lines.LEGEND["loc"] = "upper left" 48 | # plot_lines.LEGEND["fontsize"] = "large" 49 | plot_lines.LEGEND["bbox_to_anchor"] = (0.75, 0.2) 50 | labels_logs = list(logs.keys()) 51 | colors = plot_lines(ax, domains, lines, labels_logs) 52 | ax.grid(True, alpha=0.5) 53 | ax.set_ylim(m - delta, M + delta) 54 | 55 | plt.show() 56 | fig.savefig( 57 | "plot.png", bbox_inches="tight", pad_inches=0.1, transparent=False, dpi=300 58 | ) 59 | plt.close(fig) 60 | 61 | 62 | def print_logs(logs, x_key, y_key, last_log_idx=None): 63 | delim = " " 64 | s = "" 65 | keys = [] 66 | y_keys = y_key.split("/") 67 | for name, log in logs.items(): 68 | log_idx = last_log_idx 69 | if log_idx is None: 70 | log_idx = len(log) - 1 71 | while y_keys[0] not in log[log_idx]: 72 | log_idx -= 1 73 | last_log = log[log_idx] 74 | log_x = last_log[x_key] 75 | log_y = last_log[y_keys[0]] 76 | for y_key in y_keys[1:]: 77 | log_y = log_y[y_key] 78 | s += f"{name}:\n" 79 | # s += f"{delim}{x_key}: {log_x}\n" 80 | s += f"{delim}{y_key}: {log_y:.4f}\n" 81 | keys += list(last_log.keys()) 82 | keys = list(set(keys)) 83 | keys = ", ".join(keys) 84 | s = f"keys: {keys}\n" + s 85 | print(s) 86 | 87 | 88 | def read_logs(root, logs_path): 89 | logs = {} 90 | for name, path in logs_path.items(): 91 | path = root / path 92 | if not path.exists(): 93 | print(f"Skipping {name} that has no log file") 94 | continue 95 | logs[name] = [] 96 | with open(path, "r") as f: 97 | for line in f.readlines(): 98 | d = json.loads(line) 99 | logs[name].append(d) 100 | 
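    # Each log file is read as JSON-lines (one dict per logged epoch/eval step);
    # the resulting {name: [entry, ...]} mapping is what print_logs/plot_logs consume.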
return logs 101 | 102 | 103 | @click.command() 104 | @click.argument("log_path", type=str) 105 | @click.option("--x-key", default="epoch", type=str) 106 | @click.option("--y-key", default="val_mean_iou", type=str) 107 | @click.option("-s", "--size", default="large", type=str) 108 | @click.option("-ep", "--epoch", default=-1, type=int) 109 | @click.option("-plot", "--plot/--no-plot", default=True, is_flag=True) 110 | def main(log_path, x_key, y_key, size, epoch, plot): 111 | abs_path = Path(__file__).parent / log_path 112 | if abs_path.exists(): 113 | log_path = abs_path 114 | config = yaml.load(open(log_path, "r"), Loader=yaml.FullLoader) 115 | root = Path(config["root"]) 116 | logs_path = OrderedDict(config["logs"]) 117 | vmin = config.get("vmin", None) 118 | vmax = config.get("vmax", None) 119 | epochs = config.get("epochs", None) 120 | 121 | logs = read_logs(root, logs_path) 122 | if not logs: 123 | return 124 | print_logs(logs, x_key, y_key, epoch) 125 | if plot: 126 | plot_logs(logs, x_key, y_key, size, vmin, vmax, epochs) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /OnlineRetraining/segm/utils/torch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | from torch import distributed as dist 6 | 7 | """ 8 | GPU wrappers 9 | """ 10 | 11 | use_gpu = False 12 | gpu_id = 0 13 | device = None 14 | 15 | distributed = False 16 | dist_rank = 0 17 | world_size = 1 18 | 19 | def set_gpu_dist_mode(mode): 20 | global use_gpu 21 | global device 22 | global gpu_id 23 | global distributed 24 | global dist_rank 25 | global world_size 26 | 27 | if dist.is_available() and dist.is_initialized(): 28 | dist_rank = dist.get_rank() 29 | world_size = dist.get_world_size() 30 | else: 31 | dist_rank = 0 32 | world_size = 1 33 | 34 | distributed = world_size > 1 35 | use_gpu = mode 36 | 37 | device = dist_rank % torch.cuda.device_count() 38 | torch.backends.cudnn.benchmark = True 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |