├── .gitignore ├── LICENSE ├── README.md ├── attn_maps_dec.png ├── attn_maps_enc.png ├── gifs ├── breakdance-flare.gif ├── car-competition.gif ├── choreography.gif └── city-ride.gif ├── overview.png ├── requirements.txt ├── segm ├── config.py ├── config.yml ├── data │ ├── __init__.py │ ├── ade20k.py │ ├── base.py │ ├── cityscapes.py │ ├── config │ │ ├── ade20k.py │ │ ├── ade20k.yml │ │ ├── cityscapes.py │ │ ├── cityscapes.yml │ │ ├── pascal_context.py │ │ └── pascal_context.yml │ ├── factory.py │ ├── imagenet.py │ ├── loader.py │ ├── pascal_context.py │ └── utils.py ├── engine.py ├── eval │ ├── accuracy.py │ └── miou.py ├── inference.py ├── metrics.py ├── model │ ├── blocks.py │ ├── decoder.py │ ├── factory.py │ ├── segmenter.py │ ├── utils.py │ └── vit.py ├── optim │ ├── factory.py │ └── scheduler.py ├── scripts │ ├── prepare_ade20k.py │ ├── prepare_cityscapes.py │ ├── prepare_pcontext.py │ └── show_attn_map.py ├── train.py └── utils │ ├── distributed.py │ ├── download.py │ ├── lines.py │ ├── logger.py │ ├── logs.py │ └── torch.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | pytestdebug.log 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | doc/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | pythonenv* 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | # pytype static type analyzer 140 | .pytype/ 141 | 142 | # profiling data 143 | .prof 144 | 145 | # End of https://www.toptal.com/developers/gitignore/api/python 146 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Robin Strudel 4 | Copyright (c) INRIA 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Segmenter: Transformer for Semantic Segmentation 2 | 3 | ![Figure 1 from paper](./overview.png) 4 | 5 | [Segmenter: Transformer for Semantic Segmentation](https://arxiv.org/abs/2105.05633) 6 | by Robin Strudel*, Ricardo Garcia*, Ivan Laptev and Cordelia Schmid, ICCV 2021. 7 | 8 | *Equal Contribution 9 | 10 | 🔥 **Segmenter is now available on [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/segmenter).** 11 | 12 | ## Installation 13 | 14 | Define os environment variables pointing to your checkpoint and dataset directory, put in your `.bashrc`: 15 | ```sh 16 | export DATASET=/path/to/dataset/dir 17 | ``` 18 | 19 | Install [PyTorch 1.9](https://pytorch.org/) then `pip install .` at the root of this repository. 20 | 21 | To download ADE20K, use the following command: 22 | ```python 23 | python -m segm.scripts.prepare_ade20k $DATASET 24 | ``` 25 | 26 | ## Model Zoo 27 | We release models with a Vision Transformer backbone initialized from the [improved ViT](https://arxiv.org/abs/2106.10270) models. 
28 | 29 | ### ADE20K 30 | 31 | Segmenter models with ViT backbone: 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 |
| Name          | mIoU (SS/MS) | # params | Resolution | FPS  | Download             |
| ------------- | ------------ | -------- | ---------- | ---- | -------------------- |
| Seg-T-Mask/16 | 38.1 / 38.8  | 7M       | 512x512    | 52.4 | model / config / log |
| Seg-S-Mask/16 | 45.3 / 46.9  | 27M      | 512x512    | 34.8 | model / config / log |
| Seg-B-Mask/16 | 48.5 / 50.0  | 106M     | 512x512    | 24.1 | model / config / log |
| Seg-B/8       | 49.5 / 50.5  | 89M      | 512x512    | 4.2  | model / config / log |
| Seg-L-Mask/16 | 51.8 / 53.6  | 334M     | 640x640    | -    | model / config / log |
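The `# params` column above counts the full Segmenter model (ViT backbone plus decoder). As a rough, backbone-only sanity check, the corresponding ViT variants can be instantiated with timm (pinned to `timm==0.4.12` in `requirements.txt`); this is an illustrative sketch, not part of the released code:

```python
# Illustrative sketch: backbone-only parameter counts.
# The table above also includes the decoder, so its numbers are larger.
# Assumes timm==0.4.12 (as pinned in requirements.txt), which registers these ViT variants.
import timm

for name in ["vit_tiny_patch16_384", "vit_small_patch16_384", "vit_base_patch16_384"]:
    model = timm.create_model(name, pretrained=False)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{name}: {n_params / 1e6:.1f}M parameters")
```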
92 | 93 | Segmenter models with DeiT backbone: 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 |
| Name          | mIoU (SS/MS) | # params | Resolution | FPS  | Download             |
| ------------- | ------------ | -------- | ---------- | ---- | -------------------- |
| Seg-B/16      | 47.1 / 48.1  | 87M      | 512x512    | 27.3 | model / config / log |
| Seg-B-Mask/16 | 48.7 / 50.1  | 106M     | 512x512    | 24.1 | model / config / log |
125 | 126 | ### Pascal Context 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 |
| Name          | mIoU (SS/MS) | # params | Resolution | FPS | Download             |
| ------------- | ------------ | -------- | ---------- | --- | -------------------- |
| Seg-L-Mask/16 | 58.1 / 59.0  | 334M     | 480x480    | -   | model / config / log |
147 | 148 | ### Cityscapes 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 |
| Name          | mIoU (SS/MS) | # params | Resolution | FPS | Download             |
| ------------- | ------------ | -------- | ---------- | --- | -------------------- |
| Seg-L-Mask/16 | 79.1 / 81.3  | 322M     | 768x768    | -   | model / config / log |
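The crop sizes and training schedules behind these tables follow the per-dataset defaults stored in `segm/config.yml` (reproduced further down in this listing). A minimal sketch of inspecting those defaults with the helper from `segm/config.py`, assuming the package has been installed with `pip install .`:

```python
# Illustrative sketch: read the per-dataset training defaults shipped in segm/config.yml.
# Assumes `pip install .` has been run so that segm.config is importable.
from segm.config import load_config

cfg = load_config()
for name in ("ade20k", "pascal_context", "cityscapes"):
    d = cfg["dataset"][name]
    print(name, d["crop_size"], d["batch_size"], d["learning_rate"], d["epochs"])
```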
169 | 170 | ## Inference 171 | 172 | Download one checkpoint with its configuration in a common folder, for example `seg_tiny_mask`. 173 | 174 | You can generate segmentation maps from your own data with: 175 | ```python 176 | python -m segm.inference --model-path seg_tiny_mask/checkpoint.pth -i images/ -o segmaps/ 177 | ``` 178 | 179 | To evaluate on ADE20K, run the command: 180 | ```python 181 | # single-scale evaluation: 182 | python -m segm.eval.miou seg_tiny_mask/checkpoint.pth ade20k --singlescale 183 | # multi-scale evaluation: 184 | python -m segm.eval.miou seg_tiny_mask/checkpoint.pth ade20k --multiscale 185 | ``` 186 | 187 | ## Train 188 | 189 | Train `Seg-T-Mask/16` on ADE20K on a single GPU: 190 | ```python 191 | python -m segm.train --log-dir seg_tiny_mask --dataset ade20k \ 192 | --backbone vit_tiny_patch16_384 --decoder mask_transformer 193 | ``` 194 | 195 | To train `Seg-B-Mask/16`, simply set `vit_base_patch16_384` as backbone and launch the above command using a minimum of 4 V100 GPUs (~12 minutes per epoch) and up to 8 V100 GPUs (~7 minutes per epoch). The code uses [SLURM](https://slurm.schedmd.com/documentation.html) environment variables. 196 | 197 | ## Logs 198 | 199 | To plot the logs of your experiments, you can use 200 | ```python 201 | python -m segm.utils.logs logs.yml 202 | ``` 203 | 204 | with `logs.yml` located in `utils/` with the path to your experiments logs: 205 | ```yaml 206 | root: /path/to/checkpoints/ 207 | logs: 208 | seg-t: seg_tiny_mask/log.txt 209 | seg-b: seg_base_mask/log.txt 210 | ``` 211 | 212 | ## Attention Maps 213 | 214 | To visualize the attention maps for `Seg-T-Mask/16` encoder layer 0 and patch `(0, 21)`, you can use: 215 | 216 | ```python 217 | python -m segm.scripts.show_attn_map seg_tiny_mask/checkpoint.pth \ 218 | images/im0.jpg output_dir/ --layer-id 0 --x-patch 0 --y-patch 21 --enc 219 | ``` 220 | 221 | Different options are provided to select the generated attention maps: 222 | * `--enc` or `--dec`: Select encoder or decoder attention maps respectively. 223 | * `--patch` or `--cls`: `--patch` generates attention maps for the patch with coordinates `(x_patch, y_patch)`. `--cls` combined with `--enc` generates attention maps for the CLS token of the encoder. `--cls` combined with `--dec` generates maps for each class embedding of the decoder. 224 | * `--x-patch` and `--y-patch`: Coordinates of the patch to draw attention maps from. This flag is ignored when `--cls` is used. 225 | * `--layer-id`: Select the layer for which the attention maps are generated. 226 | 227 | For example, to generate attention maps for the decoder class embeddings, you can use: 228 | 229 | ```python 230 | python -m segm.scripts.show_attn_map seg_tiny_mask/checkpoint.pth \ 231 | images/im0.jpg output_dir/ --layer-id 0 --dec --cls 232 | ``` 233 | 234 | Attention maps for patch `(0, 21)` in `Seg-L-Mask/16` encoder layers 1, 4, 8, 12 and 16: 235 | 236 | ![Attention maps of patch x=8 and y=21 and encoder layers 1, 4, 8, 12 and 16](./attn_maps_enc.png) 237 | 238 | Attention maps for the class embeddings in `Seg-L-Mask/16` decoder layer 0: 239 | 240 | ![Attention maps of cls tokens 7, 15, 18, 22, 36 and 57 and Mask decoder layer 0](./attn_maps_dec.png) 241 | 242 | ## Video Segmentation 243 | 244 | Zero shot video segmentation on [DAVIS](https://davischallenge.org/) video dataset with Seg-B-Mask/16 model trained on [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/). 245 | 246 |

![breakdance-flare](./gifs/breakdance-flare.gif)
![car-competition](./gifs/car-competition.gif)
![choreography](./gifs/choreography.gif)
![city-ride](./gifs/city-ride.gif)

254 | 255 | ## BibTex 256 | 257 | ``` 258 | @article{strudel2021, 259 | title={Segmenter: Transformer for Semantic Segmentation}, 260 | author={Strudel, Robin and Garcia, Ricardo and Laptev, Ivan and Schmid, Cordelia}, 261 | journal={arXiv preprint arXiv:2105.05633}, 262 | year={2021} 263 | } 264 | ``` 265 | 266 | 267 | ## Acknowledgements 268 | 269 | The Vision Transformer code is based on [timm](https://github.com/rwightman/pytorch-image-models) library and the semantic segmentation training and evaluation pipeline 270 | is using [mmsegmentation](https://github.com/open-mmlab/mmsegmentation). 271 | -------------------------------------------------------------------------------- /attn_maps_dec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/attn_maps_dec.png -------------------------------------------------------------------------------- /attn_maps_enc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/attn_maps_enc.png -------------------------------------------------------------------------------- /gifs/breakdance-flare.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/gifs/breakdance-flare.gif -------------------------------------------------------------------------------- /gifs/car-competition.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/gifs/car-competition.gif -------------------------------------------------------------------------------- /gifs/choreography.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/gifs/choreography.gif -------------------------------------------------------------------------------- /gifs/city-ride.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/gifs/city-ride.gif -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | click 3 | numpy 4 | einops 5 | python-hostlist 6 | tqdm 7 | requests 8 | pyyaml 9 | timm == 0.4.12 10 | mmcv==1.3.8 11 | mmsegmentation==0.14.1 12 | -------------------------------------------------------------------------------- /segm/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from pathlib import Path 3 | 4 | import os 5 | 6 | 7 | def load_config(): 8 | return yaml.load( 9 | open(Path(__file__).parent / "config.yml", "r"), Loader=yaml.FullLoader 10 | ) 11 | 12 | 13 | def check_os_environ(key, use): 14 | if key not in os.environ: 15 | raise ValueError( 16 | f"{key} is not defined in the 
os variables, it is required for {use}." 17 | ) 18 | 19 | 20 | def dataset_dir(): 21 | check_os_environ("DATASET", "data loading") 22 | return os.environ["DATASET"] 23 | -------------------------------------------------------------------------------- /segm/config.yml: -------------------------------------------------------------------------------- 1 | model: 2 | # deit 3 | deit_tiny_distilled_patch16_224: 4 | image_size: 224 5 | patch_size: 16 6 | d_model: 192 7 | n_heads: 3 8 | n_layers: 12 9 | normalization: deit 10 | distilled: true 11 | deit_small_distilled_patch16_224: 12 | image_size: 224 13 | patch_size: 16 14 | d_model: 384 15 | n_heads: 6 16 | n_layers: 12 17 | normalization: deit 18 | distilled: true 19 | deit_base_distilled_patch16_224: 20 | image_size: 224 21 | patch_size: 16 22 | d_model: 768 23 | n_heads: 12 24 | n_layers: 12 25 | normalization: deit 26 | distilled: true 27 | deit_base_distilled_patch16_384: 28 | image_size: 384 29 | patch_size: 16 30 | d_model: 768 31 | n_heads: 12 32 | n_layers: 12 33 | normalization: deit 34 | distilled: true 35 | # vit 36 | vit_base_patch8_384: 37 | image_size: 384 38 | patch_size: 8 39 | d_model: 768 40 | n_heads: 12 41 | n_layers: 12 42 | normalization: vit 43 | distilled: false 44 | vit_tiny_patch16_384: 45 | image_size: 384 46 | patch_size: 16 47 | d_model: 192 48 | n_heads: 3 49 | n_layers: 12 50 | normalization: vit 51 | distilled: false 52 | vit_small_patch16_384: 53 | image_size: 384 54 | patch_size: 16 55 | d_model: 384 56 | n_heads: 6 57 | n_layers: 12 58 | normalization: vit 59 | distilled: false 60 | vit_base_patch16_384: 61 | image_size: 384 62 | patch_size: 16 63 | d_model: 768 64 | n_heads: 12 65 | n_layers: 12 66 | normalization: vit 67 | distilled: false 68 | vit_large_patch16_384: 69 | image_size: 384 70 | patch_size: 16 71 | d_model: 1024 72 | n_heads: 16 73 | n_layers: 24 74 | normalization: vit 75 | vit_small_patch32_384: 76 | image_size: 384 77 | patch_size: 32 78 | d_model: 384 79 | n_heads: 6 80 | n_layers: 12 81 | normalization: vit 82 | distilled: false 83 | vit_base_patch32_384: 84 | image_size: 384 85 | patch_size: 32 86 | d_model: 768 87 | n_heads: 12 88 | n_layers: 12 89 | normalization: vit 90 | vit_large_patch32_384: 91 | image_size: 384 92 | patch_size: 32 93 | d_model: 1024 94 | n_heads: 16 95 | n_layers: 24 96 | normalization: vit 97 | decoder: 98 | linear: {} 99 | deeplab_dec: 100 | encoder_layer: -1 101 | mask_transformer: 102 | drop_path_rate: 0.0 103 | dropout: 0.1 104 | n_layers: 2 105 | dataset: 106 | ade20k: 107 | epochs: 64 108 | eval_freq: 2 109 | batch_size: 8 110 | learning_rate: 0.001 111 | im_size: 512 112 | crop_size: 512 113 | window_size: 512 114 | window_stride: 512 115 | pascal_context: 116 | epochs: 256 117 | eval_freq: 8 118 | batch_size: 16 119 | learning_rate: 0.001 120 | im_size: 520 121 | crop_size: 480 122 | window_size: 480 123 | window_stride: 320 124 | cityscapes: 125 | epochs: 216 126 | eval_freq: 4 127 | batch_size: 8 128 | learning_rate: 0.01 129 | im_size: 1024 130 | crop_size: 768 131 | window_size: 768 132 | window_stride: 512 133 | -------------------------------------------------------------------------------- /segm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from segm.data.loader import Loader 2 | 3 | from segm.data.imagenet import ImagenetDataset 4 | from segm.data.ade20k import ADE20KSegmentation 5 | from segm.data.pascal_context import PascalContextDataset 6 | from segm.data.cityscapes import 
CityscapesDataset 7 | -------------------------------------------------------------------------------- /segm/data/ade20k.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from segm.data.base import BaseMMSeg 4 | from segm.data import utils 5 | from segm.config import dataset_dir 6 | 7 | 8 | ADE20K_CONFIG_PATH = Path(__file__).parent / "config" / "ade20k.py" 9 | ADE20K_CATS_PATH = Path(__file__).parent / "config" / "ade20k.yml" 10 | 11 | 12 | class ADE20KSegmentation(BaseMMSeg): 13 | def __init__(self, image_size, crop_size, split, **kwargs): 14 | super().__init__( 15 | image_size, 16 | crop_size, 17 | split, 18 | ADE20K_CONFIG_PATH, 19 | **kwargs, 20 | ) 21 | self.names, self.colors = utils.dataset_cat_description(ADE20K_CATS_PATH) 22 | self.n_cls = 150 23 | self.ignore_label = 0 24 | self.reduce_zero_label = True 25 | 26 | def update_default_config(self, config): 27 | root_dir = dataset_dir() 28 | path = Path(root_dir) / "ade20k" 29 | config.data_root = path 30 | if self.split == "train": 31 | config.data.train.data_root = path / "ADEChallengeData2016" 32 | elif self.split == "trainval": 33 | config.data.trainval.data_root = path / "ADEChallengeData2016" 34 | elif self.split == "val": 35 | config.data.val.data_root = path / "ADEChallengeData2016" 36 | elif self.split == "test": 37 | config.data.test.data_root = path / "release_test" 38 | config = super().update_default_config(config) 39 | return config 40 | 41 | def test_post_process(self, labels): 42 | return labels + 1 43 | -------------------------------------------------------------------------------- /segm/data/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from PIL import Image, ImageOps, ImageFilter 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms.functional as F 8 | 9 | from mmseg.datasets import build_dataset 10 | import mmcv 11 | from mmcv.utils import Config 12 | 13 | 14 | from segm.data.utils import STATS, IGNORE_LABEL 15 | from segm.data import utils 16 | 17 | 18 | class BaseMMSeg(Dataset): 19 | def __init__( 20 | self, 21 | image_size, 22 | crop_size, 23 | split, 24 | config_path, 25 | normalization, 26 | **kwargs, 27 | ): 28 | super().__init__() 29 | self.image_size = image_size 30 | self.crop_size = crop_size 31 | self.split = split 32 | self.normalization = STATS[normalization].copy() 33 | self.ignore_label = None 34 | for k, v in self.normalization.items(): 35 | v = np.round(255 * np.array(v), 2) 36 | self.normalization[k] = tuple(v) 37 | print(f"Use normalization: {self.normalization}") 38 | 39 | config = Config.fromfile(config_path) 40 | 41 | self.ratio = config.max_ratio 42 | self.dataset = None 43 | self.config = self.update_default_config(config) 44 | self.dataset = build_dataset(getattr(self.config.data, f"{self.split}")) 45 | 46 | def update_default_config(self, config): 47 | 48 | train_splits = ["train", "trainval"] 49 | if self.split in train_splits: 50 | config_pipeline = getattr(config, f"train_pipeline") 51 | else: 52 | config_pipeline = getattr(config, f"{self.split}_pipeline") 53 | 54 | img_scale = (self.ratio * self.image_size, self.image_size) 55 | if self.split not in train_splits: 56 | assert config_pipeline[1]["type"] == "MultiScaleFlipAug" 57 | config_pipeline = config_pipeline[1]["transforms"] 58 | for i, op in enumerate(config_pipeline): 59 | op_type = op["type"] 60 | if op_type == 
"Resize": 61 | op["img_scale"] = img_scale 62 | elif op_type == "RandomCrop": 63 | op["crop_size"] = ( 64 | self.crop_size, 65 | self.crop_size, 66 | ) 67 | elif op_type == "Normalize": 68 | op["mean"] = self.normalization["mean"] 69 | op["std"] = self.normalization["std"] 70 | elif op_type == "Pad": 71 | op["size"] = (self.crop_size, self.crop_size) 72 | config_pipeline[i] = op 73 | if self.split == "train": 74 | config.data.train.pipeline = config_pipeline 75 | elif self.split == "trainval": 76 | config.data.trainval.pipeline = config_pipeline 77 | elif self.split == "val": 78 | config.data.val.pipeline[1]["img_scale"] = img_scale 79 | config.data.val.pipeline[1]["transforms"] = config_pipeline 80 | elif self.split == "test": 81 | config.data.test.pipeline[1]["img_scale"] = img_scale 82 | config.data.test.pipeline[1]["transforms"] = config_pipeline 83 | config.data.test.test_mode = True 84 | else: 85 | raise ValueError(f"Unknown split: {self.split}") 86 | return config 87 | 88 | def set_multiscale_mode(self): 89 | self.config.data.val.pipeline[1]["img_ratios"] = [ 90 | 0.5, 91 | 0.75, 92 | 1.0, 93 | 1.25, 94 | 1.5, 95 | 1.75, 96 | ] 97 | self.config.data.val.pipeline[1]["flip"] = True 98 | self.config.data.test.pipeline[1]["img_ratios"] = [ 99 | 0.5, 100 | 0.75, 101 | 1.0, 102 | 1.25, 103 | 1.5, 104 | 1.75, 105 | ] 106 | self.config.data.test.pipeline[1]["flip"] = True 107 | self.dataset = build_dataset(getattr(self.config.data, f"{self.split}")) 108 | 109 | def __getitem__(self, idx): 110 | data = self.dataset[idx] 111 | 112 | train_splits = ["train", "trainval"] 113 | 114 | if self.split in train_splits: 115 | im = data["img"].data 116 | seg = data["gt_semantic_seg"].data.squeeze(0) 117 | else: 118 | im = [im.data for im in data["img"]] 119 | seg = None 120 | 121 | out = dict(im=im) 122 | if self.split in train_splits: 123 | out["segmentation"] = seg 124 | else: 125 | im_metas = [meta.data for meta in data["img_metas"]] 126 | out["im_metas"] = im_metas 127 | out["colors"] = self.colors 128 | 129 | return out 130 | 131 | def get_gt_seg_maps(self): 132 | dataset = self.dataset 133 | gt_seg_maps = {} 134 | for img_info in dataset.img_infos: 135 | seg_map = Path(dataset.ann_dir) / img_info["ann"]["seg_map"] 136 | gt_seg_map = mmcv.imread(seg_map, flag="unchanged", backend="pillow") 137 | gt_seg_map[gt_seg_map == self.ignore_label] = IGNORE_LABEL 138 | if self.reduce_zero_label: 139 | gt_seg_map[gt_seg_map != IGNORE_LABEL] -= 1 140 | gt_seg_maps[img_info["filename"]] = gt_seg_map 141 | return gt_seg_maps 142 | 143 | def __len__(self): 144 | return len(self.dataset) 145 | 146 | @property 147 | def unwrapped(self): 148 | return self 149 | 150 | def set_epoch(self, epoch): 151 | pass 152 | 153 | def get_diagnostics(self, logger): 154 | pass 155 | 156 | def get_snapshot(self): 157 | return {} 158 | 159 | def end_epoch(self, epoch): 160 | return 161 | -------------------------------------------------------------------------------- /segm/data/cityscapes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | try: 4 | import cityscapesscripts.helpers.labels as CSLabels 5 | except: 6 | pass 7 | 8 | from pathlib import Path 9 | from segm.data.base import BaseMMSeg 10 | from segm.data import utils 11 | from segm.config import dataset_dir 12 | 13 | CITYSCAPES_CONFIG_PATH = Path(__file__).parent / "config" / "cityscapes.py" 14 | CITYSCAPES_CATS_PATH = Path(__file__).parent / "config" / "cityscapes.yml" 15 | 16 | 17 | class 
CityscapesDataset(BaseMMSeg): 18 | def __init__(self, image_size, crop_size, split, **kwargs): 19 | super().__init__(image_size, crop_size, split, CITYSCAPES_CONFIG_PATH, **kwargs) 20 | self.names, self.colors = utils.dataset_cat_description(CITYSCAPES_CATS_PATH) 21 | self.n_cls = 19 22 | self.ignore_label = 255 23 | self.reduce_zero_label = False 24 | 25 | def update_default_config(self, config): 26 | 27 | root_dir = dataset_dir() 28 | path = Path(root_dir) / "cityscapes" 29 | config.data_root = path 30 | 31 | config.data[self.split]["data_root"] = path 32 | config = super().update_default_config(config) 33 | 34 | return config 35 | 36 | def test_post_process(self, labels): 37 | labels_copy = np.copy(labels) 38 | cats = np.unique(labels_copy) 39 | for cat in cats: 40 | labels_copy[labels == cat] = CSLabels.trainId2label[cat].id 41 | return labels_copy 42 | -------------------------------------------------------------------------------- /segm/data/config/ade20k.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "ADE20KDataset" 3 | data_root = "data/ade/ADEChallengeData2016" 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True 6 | ) 7 | crop_size = (512, 512) 8 | max_ratio = 4 9 | train_pipeline = [ 10 | dict(type="LoadImageFromFile"), 11 | dict(type="LoadAnnotations", reduce_zero_label=True), 12 | dict(type="Resize", img_scale=(512 * max_ratio, 512), ratio_range=(0.5, 2.0)), 13 | dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75), 14 | dict(type="RandomFlip", prob=0.5), 15 | dict(type="PhotoMetricDistortion"), 16 | dict(type="Normalize", **img_norm_cfg), 17 | dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255), 18 | dict(type="DefaultFormatBundle"), 19 | dict(type="Collect", keys=["img", "gt_semantic_seg"]), 20 | ] 21 | val_pipeline = [ 22 | dict(type="LoadImageFromFile"), 23 | dict( 24 | type="MultiScaleFlipAug", 25 | img_scale=(512 * max_ratio, 512), 26 | flip=False, 27 | transforms=[ 28 | dict(type="Resize", keep_ratio=True), 29 | dict(type="RandomFlip"), 30 | dict(type="Normalize", **img_norm_cfg), 31 | dict(type="ImageToTensor", keys=["img"]), 32 | dict(type="Collect", keys=["img"]), 33 | ], 34 | ), 35 | ] 36 | test_pipeline = [ 37 | dict(type="LoadImageFromFile"), 38 | dict( 39 | type="MultiScaleFlipAug", 40 | img_scale=(512 * max_ratio, 512), 41 | flip=False, 42 | transforms=[ 43 | dict(type="Resize", keep_ratio=True), 44 | dict(type="RandomFlip"), 45 | dict(type="Normalize", **img_norm_cfg), 46 | dict(type="ImageToTensor", keys=["img"]), 47 | dict(type="Collect", keys=["img"]), 48 | ], 49 | ), 50 | ] 51 | data = dict( 52 | samples_per_gpu=4, 53 | workers_per_gpu=4, 54 | train=dict( 55 | type=dataset_type, 56 | data_root=data_root, 57 | img_dir="images/training", 58 | ann_dir="annotations/training", 59 | pipeline=train_pipeline, 60 | ), 61 | trainval=dict( 62 | type=dataset_type, 63 | data_root=data_root, 64 | img_dir=["images/training", "images/validation"], 65 | ann_dir=["annotations/training", "annotations/validation"], 66 | pipeline=train_pipeline, 67 | ), 68 | val=dict( 69 | type=dataset_type, 70 | data_root=data_root, 71 | img_dir="images/validation", 72 | ann_dir="annotations/validation", 73 | pipeline=val_pipeline, 74 | ), 75 | test=dict( 76 | type=dataset_type, 77 | data_root=data_root, 78 | img_dir="testing", 79 | pipeline=test_pipeline, 80 | ), 81 | ) 82 | 
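For context, this mmseg-style config is what `BaseMMSeg` (in `segm/data/base.py` above) loads and patches before building the dataset. A minimal standalone sketch of that loading step, assuming `mmcv==1.3.8` and `mmsegmentation==0.14.1` as pinned in `requirements.txt`, with a hypothetical data root:

```python
# Illustrative sketch mirroring segm/data/base.py: load an mmseg-style config and build a dataset.
# Assumes mmcv==1.3.8 and mmsegmentation==0.14.1; the data_root below is a hypothetical path.
from mmcv.utils import Config
from mmseg.datasets import build_dataset

config = Config.fromfile("segm/data/config/ade20k.py")
config.data.train.data_root = "/path/to/dataset/dir/ade20k/ADEChallengeData2016"
train_set = build_dataset(config.data.train)
print(len(train_set))
```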
-------------------------------------------------------------------------------- /segm/data/config/ade20k.yml: -------------------------------------------------------------------------------- 1 | - color: 2 | - 120 3 | - 120 4 | - 120 5 | id: 0 6 | isthing: 0 7 | name: wall 8 | - color: 9 | - 180 10 | - 120 11 | - 120 12 | id: 1 13 | isthing: 0 14 | name: building, edifice 15 | - color: 16 | - 6 17 | - 230 18 | - 230 19 | id: 2 20 | isthing: 0 21 | name: sky 22 | - color: 23 | - 80 24 | - 50 25 | - 50 26 | id: 3 27 | isthing: 0 28 | name: floor, flooring 29 | - color: 30 | - 4 31 | - 200 32 | - 3 33 | id: 4 34 | isthing: 0 35 | name: tree 36 | - color: 37 | - 120 38 | - 120 39 | - 80 40 | id: 5 41 | isthing: 0 42 | name: ceiling 43 | - color: 44 | - 140 45 | - 140 46 | - 140 47 | id: 6 48 | isthing: 0 49 | name: road, route 50 | - color: 51 | - 204 52 | - 5 53 | - 255 54 | id: 7 55 | isthing: 0 56 | name: bed 57 | - color: 58 | - 230 59 | - 230 60 | - 230 61 | id: 8 62 | isthing: 0 63 | name: windowpane, window 64 | - color: 65 | - 4 66 | - 250 67 | - 7 68 | id: 9 69 | isthing: 0 70 | name: grass 71 | - color: 72 | - 224 73 | - 5 74 | - 255 75 | id: 10 76 | isthing: 0 77 | name: cabinet 78 | - color: 79 | - 235 80 | - 255 81 | - 7 82 | id: 11 83 | isthing: 0 84 | name: sidewalk, pavement 85 | - color: 86 | - 150 87 | - 5 88 | - 61 89 | id: 12 90 | isthing: 0 91 | name: person, individual, someone, somebody, mortal, soul 92 | - color: 93 | - 120 94 | - 120 95 | - 70 96 | id: 13 97 | isthing: 0 98 | name: earth, ground 99 | - color: 100 | - 8 101 | - 255 102 | - 51 103 | id: 14 104 | isthing: 0 105 | name: door, double door 106 | - color: 107 | - 255 108 | - 6 109 | - 82 110 | id: 15 111 | isthing: 0 112 | name: table 113 | - color: 114 | - 143 115 | - 255 116 | - 140 117 | id: 16 118 | isthing: 0 119 | name: mountain, mount 120 | - color: 121 | - 204 122 | - 255 123 | - 4 124 | id: 17 125 | isthing: 0 126 | name: plant, flora, plant life 127 | - color: 128 | - 255 129 | - 51 130 | - 7 131 | id: 18 132 | isthing: 0 133 | name: curtain, drape, drapery, mantle, pall 134 | - color: 135 | - 204 136 | - 70 137 | - 3 138 | id: 19 139 | isthing: 0 140 | name: chair 141 | - color: 142 | - 0 143 | - 102 144 | - 200 145 | id: 20 146 | isthing: 0 147 | name: car, auto, automobile, machine, motorcar 148 | - color: 149 | - 61 150 | - 230 151 | - 250 152 | id: 21 153 | isthing: 0 154 | name: water 155 | - color: 156 | - 255 157 | - 6 158 | - 51 159 | id: 22 160 | isthing: 0 161 | name: painting, picture 162 | - color: 163 | - 11 164 | - 102 165 | - 255 166 | id: 23 167 | isthing: 0 168 | name: sofa, couch, lounge 169 | - color: 170 | - 255 171 | - 7 172 | - 71 173 | id: 24 174 | isthing: 0 175 | name: shelf 176 | - color: 177 | - 255 178 | - 9 179 | - 224 180 | id: 25 181 | isthing: 0 182 | name: house 183 | - color: 184 | - 9 185 | - 7 186 | - 230 187 | id: 26 188 | isthing: 0 189 | name: sea 190 | - color: 191 | - 220 192 | - 220 193 | - 220 194 | id: 27 195 | isthing: 0 196 | name: mirror 197 | - color: 198 | - 255 199 | - 9 200 | - 92 201 | id: 28 202 | isthing: 0 203 | name: rug, carpet, carpeting 204 | - color: 205 | - 112 206 | - 9 207 | - 255 208 | id: 29 209 | isthing: 0 210 | name: field 211 | - color: 212 | - 8 213 | - 255 214 | - 214 215 | id: 30 216 | isthing: 0 217 | name: armchair 218 | - color: 219 | - 7 220 | - 255 221 | - 224 222 | id: 31 223 | isthing: 0 224 | name: seat 225 | - color: 226 | - 255 227 | - 184 228 | - 6 229 | id: 32 230 | isthing: 0 231 | name: fence, fencing 232 | - 
color: 233 | - 10 234 | - 255 235 | - 71 236 | id: 33 237 | isthing: 0 238 | name: desk 239 | - color: 240 | - 255 241 | - 41 242 | - 10 243 | id: 34 244 | isthing: 0 245 | name: rock, stone 246 | - color: 247 | - 7 248 | - 255 249 | - 255 250 | id: 35 251 | isthing: 0 252 | name: wardrobe, closet, press 253 | - color: 254 | - 224 255 | - 255 256 | - 8 257 | id: 36 258 | isthing: 0 259 | name: lamp 260 | - color: 261 | - 102 262 | - 8 263 | - 255 264 | id: 37 265 | isthing: 0 266 | name: bathtub, bathing tub, bath, tub 267 | - color: 268 | - 255 269 | - 61 270 | - 6 271 | id: 38 272 | isthing: 0 273 | name: railing, rail 274 | - color: 275 | - 255 276 | - 194 277 | - 7 278 | id: 39 279 | isthing: 0 280 | name: cushion 281 | - color: 282 | - 255 283 | - 122 284 | - 8 285 | id: 40 286 | isthing: 0 287 | name: base, pedestal, stand 288 | - color: 289 | - 0 290 | - 255 291 | - 20 292 | id: 41 293 | isthing: 0 294 | name: box 295 | - color: 296 | - 255 297 | - 8 298 | - 41 299 | id: 42 300 | isthing: 0 301 | name: column, pillar 302 | - color: 303 | - 255 304 | - 5 305 | - 153 306 | id: 43 307 | isthing: 0 308 | name: signboard, sign 309 | - color: 310 | - 6 311 | - 51 312 | - 255 313 | id: 44 314 | isthing: 0 315 | name: chest of drawers, chest, bureau, dresser 316 | - color: 317 | - 235 318 | - 12 319 | - 255 320 | id: 45 321 | isthing: 0 322 | name: counter 323 | - color: 324 | - 160 325 | - 150 326 | - 20 327 | id: 46 328 | isthing: 0 329 | name: sand 330 | - color: 331 | - 0 332 | - 163 333 | - 255 334 | id: 47 335 | isthing: 0 336 | name: sink 337 | - color: 338 | - 140 339 | - 140 340 | - 140 341 | id: 48 342 | isthing: 0 343 | name: skyscraper 344 | - color: 345 | - 250 346 | - 10 347 | - 15 348 | id: 49 349 | isthing: 0 350 | name: fireplace, hearth, open fireplace 351 | - color: 352 | - 20 353 | - 255 354 | - 0 355 | id: 50 356 | isthing: 0 357 | name: refrigerator, icebox 358 | - color: 359 | - 31 360 | - 255 361 | - 0 362 | id: 51 363 | isthing: 0 364 | name: grandstand, covered stand 365 | - color: 366 | - 255 367 | - 31 368 | - 0 369 | id: 52 370 | isthing: 0 371 | name: path 372 | - color: 373 | - 255 374 | - 224 375 | - 0 376 | id: 53 377 | isthing: 0 378 | name: stairs, steps 379 | - color: 380 | - 153 381 | - 255 382 | - 0 383 | id: 54 384 | isthing: 0 385 | name: runway 386 | - color: 387 | - 0 388 | - 0 389 | - 255 390 | id: 55 391 | isthing: 0 392 | name: case, display case, showcase, vitrine 393 | - color: 394 | - 255 395 | - 71 396 | - 0 397 | id: 56 398 | isthing: 0 399 | name: pool table, billiard table, snooker table 400 | - color: 401 | - 0 402 | - 235 403 | - 255 404 | id: 57 405 | isthing: 0 406 | name: pillow 407 | - color: 408 | - 0 409 | - 173 410 | - 255 411 | id: 58 412 | isthing: 0 413 | name: screen door, screen 414 | - color: 415 | - 31 416 | - 0 417 | - 255 418 | id: 59 419 | isthing: 0 420 | name: stairway, staircase 421 | - color: 422 | - 11 423 | - 200 424 | - 200 425 | id: 60 426 | isthing: 0 427 | name: river 428 | - color: 429 | - 255 430 | - 82 431 | - 0 432 | id: 61 433 | isthing: 0 434 | name: bridge, span 435 | - color: 436 | - 0 437 | - 255 438 | - 245 439 | id: 62 440 | isthing: 0 441 | name: bookcase 442 | - color: 443 | - 0 444 | - 61 445 | - 255 446 | id: 63 447 | isthing: 0 448 | name: blind, screen 449 | - color: 450 | - 0 451 | - 255 452 | - 112 453 | id: 64 454 | isthing: 0 455 | name: coffee table, cocktail table 456 | - color: 457 | - 0 458 | - 255 459 | - 133 460 | id: 65 461 | isthing: 0 462 | name: toilet, can, commode, crapper, pot, 
potty, stool, throne 463 | - color: 464 | - 255 465 | - 0 466 | - 0 467 | id: 66 468 | isthing: 0 469 | name: flower 470 | - color: 471 | - 255 472 | - 163 473 | - 0 474 | id: 67 475 | isthing: 0 476 | name: book 477 | - color: 478 | - 255 479 | - 102 480 | - 0 481 | id: 68 482 | isthing: 0 483 | name: hill 484 | - color: 485 | - 194 486 | - 255 487 | - 0 488 | id: 69 489 | isthing: 0 490 | name: bench 491 | - color: 492 | - 0 493 | - 143 494 | - 255 495 | id: 70 496 | isthing: 0 497 | name: countertop 498 | - color: 499 | - 51 500 | - 255 501 | - 0 502 | id: 71 503 | isthing: 0 504 | name: stove, kitchen stove, range, kitchen range, cooking stove 505 | - color: 506 | - 0 507 | - 82 508 | - 255 509 | id: 72 510 | isthing: 0 511 | name: palm, palm tree 512 | - color: 513 | - 0 514 | - 255 515 | - 41 516 | id: 73 517 | isthing: 0 518 | name: kitchen island 519 | - color: 520 | - 0 521 | - 255 522 | - 173 523 | id: 74 524 | isthing: 0 525 | name: computer, computing machine, computing device, data processor, electronic 526 | computer, information processing system 527 | - color: 528 | - 10 529 | - 0 530 | - 255 531 | id: 75 532 | isthing: 0 533 | name: swivel chair 534 | - color: 535 | - 173 536 | - 255 537 | - 0 538 | id: 76 539 | isthing: 0 540 | name: boat 541 | - color: 542 | - 0 543 | - 255 544 | - 153 545 | id: 77 546 | isthing: 0 547 | name: bar 548 | - color: 549 | - 255 550 | - 92 551 | - 0 552 | id: 78 553 | isthing: 0 554 | name: arcade machine 555 | - color: 556 | - 255 557 | - 0 558 | - 255 559 | id: 79 560 | isthing: 0 561 | name: hovel, hut, hutch, shack, shanty 562 | - color: 563 | - 255 564 | - 0 565 | - 245 566 | id: 80 567 | isthing: 0 568 | name: bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, 569 | omnibus, passenger vehicle 570 | - color: 571 | - 255 572 | - 0 573 | - 102 574 | id: 81 575 | isthing: 0 576 | name: towel 577 | - color: 578 | - 255 579 | - 173 580 | - 0 581 | id: 82 582 | isthing: 0 583 | name: light, light source 584 | - color: 585 | - 255 586 | - 0 587 | - 20 588 | id: 83 589 | isthing: 0 590 | name: truck, motortruck 591 | - color: 592 | - 255 593 | - 184 594 | - 184 595 | id: 84 596 | isthing: 0 597 | name: tower 598 | - color: 599 | - 0 600 | - 31 601 | - 255 602 | id: 85 603 | isthing: 0 604 | name: chandelier, pendant, pendent 605 | - color: 606 | - 0 607 | - 255 608 | - 61 609 | id: 86 610 | isthing: 0 611 | name: awning, sunshade, sunblind 612 | - color: 613 | - 0 614 | - 71 615 | - 255 616 | id: 87 617 | isthing: 0 618 | name: streetlight, street lamp 619 | - color: 620 | - 255 621 | - 0 622 | - 204 623 | id: 88 624 | isthing: 0 625 | name: booth, cubicle, stall, kiosk 626 | - color: 627 | - 0 628 | - 255 629 | - 194 630 | id: 89 631 | isthing: 0 632 | name: television, television receiver, television set, tv, tv set, idiot box, boob 633 | tube, telly, goggle box 634 | - color: 635 | - 0 636 | - 255 637 | - 82 638 | id: 90 639 | isthing: 0 640 | name: airplane, aeroplane, plane 641 | - color: 642 | - 0 643 | - 10 644 | - 255 645 | id: 91 646 | isthing: 0 647 | name: dirt track 648 | - color: 649 | - 0 650 | - 112 651 | - 255 652 | id: 92 653 | isthing: 0 654 | name: apparel, wearing apparel, dress, clothes 655 | - color: 656 | - 51 657 | - 0 658 | - 255 659 | id: 93 660 | isthing: 0 661 | name: pole 662 | - color: 663 | - 0 664 | - 194 665 | - 255 666 | id: 94 667 | isthing: 0 668 | name: land, ground, soil 669 | - color: 670 | - 0 671 | - 122 672 | - 255 673 | id: 95 674 | isthing: 0 675 | name: bannister, banister, 
balustrade, balusters, handrail 676 | - color: 677 | - 0 678 | - 255 679 | - 163 680 | id: 96 681 | isthing: 0 682 | name: escalator, moving staircase, moving stairway 683 | - color: 684 | - 255 685 | - 153 686 | - 0 687 | id: 97 688 | isthing: 0 689 | name: ottoman, pouf, pouffe, puff, hassock 690 | - color: 691 | - 0 692 | - 255 693 | - 10 694 | id: 98 695 | isthing: 0 696 | name: bottle 697 | - color: 698 | - 255 699 | - 112 700 | - 0 701 | id: 99 702 | isthing: 0 703 | name: buffet, counter, sideboard 704 | - color: 705 | - 143 706 | - 255 707 | - 0 708 | id: 100 709 | isthing: 0 710 | name: poster, posting, placard, notice, bill, card 711 | - color: 712 | - 82 713 | - 0 714 | - 255 715 | id: 101 716 | isthing: 0 717 | name: stage 718 | - color: 719 | - 163 720 | - 255 721 | - 0 722 | id: 102 723 | isthing: 0 724 | name: van 725 | - color: 726 | - 255 727 | - 235 728 | - 0 729 | id: 103 730 | isthing: 0 731 | name: ship 732 | - color: 733 | - 8 734 | - 184 735 | - 170 736 | id: 104 737 | isthing: 0 738 | name: fountain 739 | - color: 740 | - 133 741 | - 0 742 | - 255 743 | id: 105 744 | isthing: 0 745 | name: conveyer belt, conveyor belt, conveyer, conveyor, transporter 746 | - color: 747 | - 0 748 | - 255 749 | - 92 750 | id: 106 751 | isthing: 0 752 | name: canopy 753 | - color: 754 | - 184 755 | - 0 756 | - 255 757 | id: 107 758 | isthing: 0 759 | name: washer, automatic washer, washing machine 760 | - color: 761 | - 255 762 | - 0 763 | - 31 764 | id: 108 765 | isthing: 0 766 | name: plaything, toy 767 | - color: 768 | - 0 769 | - 184 770 | - 255 771 | id: 109 772 | isthing: 0 773 | name: swimming pool, swimming bath, natatorium 774 | - color: 775 | - 0 776 | - 214 777 | - 255 778 | id: 110 779 | isthing: 0 780 | name: stool 781 | - color: 782 | - 255 783 | - 0 784 | - 112 785 | id: 111 786 | isthing: 0 787 | name: barrel, cask 788 | - color: 789 | - 92 790 | - 255 791 | - 0 792 | id: 112 793 | isthing: 0 794 | name: basket, handbasket 795 | - color: 796 | - 0 797 | - 224 798 | - 255 799 | id: 113 800 | isthing: 0 801 | name: waterfall, falls 802 | - color: 803 | - 112 804 | - 224 805 | - 255 806 | id: 114 807 | isthing: 0 808 | name: tent, collapsible shelter 809 | - color: 810 | - 70 811 | - 184 812 | - 160 813 | id: 115 814 | isthing: 0 815 | name: bag 816 | - color: 817 | - 163 818 | - 0 819 | - 255 820 | id: 116 821 | isthing: 0 822 | name: minibike, motorbike 823 | - color: 824 | - 153 825 | - 0 826 | - 255 827 | id: 117 828 | isthing: 0 829 | name: cradle 830 | - color: 831 | - 71 832 | - 255 833 | - 0 834 | id: 118 835 | isthing: 0 836 | name: oven 837 | - color: 838 | - 255 839 | - 0 840 | - 163 841 | id: 119 842 | isthing: 0 843 | name: ball 844 | - color: 845 | - 255 846 | - 204 847 | - 0 848 | id: 120 849 | isthing: 0 850 | name: food, solid food 851 | - color: 852 | - 255 853 | - 0 854 | - 143 855 | id: 121 856 | isthing: 0 857 | name: step, stair 858 | - color: 859 | - 0 860 | - 255 861 | - 235 862 | id: 122 863 | isthing: 0 864 | name: tank, storage tank 865 | - color: 866 | - 133 867 | - 255 868 | - 0 869 | id: 123 870 | isthing: 0 871 | name: trade name, brand name, brand, marque 872 | - color: 873 | - 255 874 | - 0 875 | - 235 876 | id: 124 877 | isthing: 0 878 | name: microwave, microwave oven 879 | - color: 880 | - 245 881 | - 0 882 | - 255 883 | id: 125 884 | isthing: 0 885 | name: pot, flowerpot 886 | - color: 887 | - 255 888 | - 0 889 | - 122 890 | id: 126 891 | isthing: 0 892 | name: animal, animate being, beast, brute, creature, fauna 893 | - color: 894 | - 
255 895 | - 245 896 | - 0 897 | id: 127 898 | isthing: 0 899 | name: bicycle, bike, wheel, cycle 900 | - color: 901 | - 10 902 | - 190 903 | - 212 904 | id: 128 905 | isthing: 0 906 | name: lake 907 | - color: 908 | - 214 909 | - 255 910 | - 0 911 | id: 129 912 | isthing: 0 913 | name: dishwasher, dish washer, dishwashing machine 914 | - color: 915 | - 0 916 | - 204 917 | - 255 918 | id: 130 919 | isthing: 0 920 | name: screen, silver screen, projection screen 921 | - color: 922 | - 20 923 | - 0 924 | - 255 925 | id: 131 926 | isthing: 0 927 | name: blanket, cover 928 | - color: 929 | - 255 930 | - 255 931 | - 0 932 | id: 132 933 | isthing: 0 934 | name: sculpture 935 | - color: 936 | - 0 937 | - 153 938 | - 255 939 | id: 133 940 | isthing: 0 941 | name: hood, exhaust hood 942 | - color: 943 | - 0 944 | - 41 945 | - 255 946 | id: 134 947 | isthing: 0 948 | name: sconce 949 | - color: 950 | - 0 951 | - 255 952 | - 204 953 | id: 135 954 | isthing: 0 955 | name: vase 956 | - color: 957 | - 41 958 | - 0 959 | - 255 960 | id: 136 961 | isthing: 0 962 | name: traffic light, traffic signal, stoplight 963 | - color: 964 | - 41 965 | - 255 966 | - 0 967 | id: 137 968 | isthing: 0 969 | name: tray 970 | - color: 971 | - 173 972 | - 0 973 | - 255 974 | id: 138 975 | isthing: 0 976 | name: ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, 977 | trash barrel, trash bin 978 | - color: 979 | - 0 980 | - 245 981 | - 255 982 | id: 139 983 | isthing: 0 984 | name: fan 985 | - color: 986 | - 71 987 | - 0 988 | - 255 989 | id: 140 990 | isthing: 0 991 | name: pier, wharf, wharfage, dock 992 | - color: 993 | - 122 994 | - 0 995 | - 255 996 | id: 141 997 | isthing: 0 998 | name: crt screen 999 | - color: 1000 | - 0 1001 | - 255 1002 | - 184 1003 | id: 142 1004 | isthing: 0 1005 | name: plate 1006 | - color: 1007 | - 0 1008 | - 92 1009 | - 255 1010 | id: 143 1011 | isthing: 0 1012 | name: monitor, monitoring device 1013 | - color: 1014 | - 184 1015 | - 255 1016 | - 0 1017 | id: 144 1018 | isthing: 0 1019 | name: bulletin board, notice board 1020 | - color: 1021 | - 0 1022 | - 133 1023 | - 255 1024 | id: 145 1025 | isthing: 0 1026 | name: shower 1027 | - color: 1028 | - 255 1029 | - 214 1030 | - 0 1031 | id: 146 1032 | isthing: 0 1033 | name: radiator 1034 | - color: 1035 | - 25 1036 | - 194 1037 | - 194 1038 | id: 147 1039 | isthing: 0 1040 | name: glass, drinking glass 1041 | - color: 1042 | - 102 1043 | - 255 1044 | - 0 1045 | id: 148 1046 | isthing: 0 1047 | name: clock 1048 | - color: 1049 | - 92 1050 | - 0 1051 | - 255 1052 | id: 149 1053 | isthing: 0 1054 | name: flag 1055 | -------------------------------------------------------------------------------- /segm/data/config/cityscapes.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "CityscapesDataset" 3 | data_root = "data/cityscapes/" 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True 6 | ) 7 | crop_size = (768, 768) 8 | max_ratio = 2 9 | train_pipeline = [ 10 | dict(type="LoadImageFromFile"), 11 | dict(type="LoadAnnotations"), 12 | dict(type="Resize", img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), 13 | dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75), 14 | dict(type="RandomFlip", prob=0.5), 15 | dict(type="PhotoMetricDistortion"), 16 | dict(type="Normalize", **img_norm_cfg), 17 | dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255), 18 | dict(type="DefaultFormatBundle"), 19 | 
dict(type="Collect", keys=["img", "gt_semantic_seg"]), 20 | ] 21 | val_pipeline = [ 22 | dict(type="LoadImageFromFile"), 23 | dict( 24 | type="MultiScaleFlipAug", 25 | img_scale=(1024 * max_ratio, 1024), 26 | flip=False, 27 | transforms=[ 28 | dict(type="Resize", keep_ratio=True), 29 | dict(type="RandomFlip"), 30 | dict(type="Normalize", **img_norm_cfg), 31 | dict(type="ImageToTensor", keys=["img"]), 32 | dict(type="Collect", keys=["img"]), 33 | ], 34 | ), 35 | ] 36 | test_pipeline = [ 37 | dict(type="LoadImageFromFile"), 38 | dict( 39 | type="MultiScaleFlipAug", 40 | img_scale=(1024 * max_ratio, 1024), 41 | flip=False, 42 | transforms=[ 43 | dict(type="Resize", keep_ratio=True), 44 | dict(type="RandomFlip"), 45 | dict(type="Normalize", **img_norm_cfg), 46 | dict(type="ImageToTensor", keys=["img"]), 47 | dict(type="Collect", keys=["img"]), 48 | ], 49 | ), 50 | ] 51 | data = dict( 52 | samples_per_gpu=2, 53 | workers_per_gpu=2, 54 | train=dict( 55 | type=dataset_type, 56 | data_root=data_root, 57 | img_dir="leftImg8bit/train", 58 | ann_dir="gtFine/train", 59 | pipeline=train_pipeline, 60 | ), 61 | trainval=dict( 62 | type=dataset_type, 63 | data_root=data_root, 64 | img_dir=["leftImg8bit/train", "leftImg8bit/val"], 65 | ann_dir=["gtFine/train", "gtFine/val"], 66 | pipeline=train_pipeline, 67 | ), 68 | val=dict( 69 | type=dataset_type, 70 | data_root=data_root, 71 | img_dir="leftImg8bit/val", 72 | ann_dir="gtFine/val", 73 | pipeline=test_pipeline, 74 | ), 75 | test=dict( 76 | type=dataset_type, 77 | data_root=data_root, 78 | img_dir="leftImg8bit/test", 79 | ann_dir="gtFine/test", 80 | pipeline=test_pipeline, 81 | ), 82 | ) 83 | -------------------------------------------------------------------------------- /segm/data/config/cityscapes.yml: -------------------------------------------------------------------------------- 1 | - color: 2 | - 128 3 | - 64 4 | - 128 5 | id: 0 6 | isthing: false 7 | name: road 8 | - color: 9 | - 244 10 | - 35 11 | - 232 12 | id: 1 13 | isthing: false 14 | name: sidewalk 15 | - color: 16 | - 70 17 | - 70 18 | - 70 19 | id: 2 20 | isthing: false 21 | name: building 22 | - color: 23 | - 102 24 | - 102 25 | - 156 26 | id: 3 27 | isthing: false 28 | name: wall 29 | - color: 30 | - 190 31 | - 153 32 | - 153 33 | id: 4 34 | isthing: false 35 | name: fence 36 | - color: 37 | - 153 38 | - 153 39 | - 153 40 | id: 5 41 | isthing: false 42 | name: pole 43 | - color: 44 | - 250 45 | - 170 46 | - 30 47 | id: 6 48 | isthing: false 49 | name: traffic light 50 | - color: 51 | - 220 52 | - 220 53 | - 0 54 | id: 7 55 | isthing: false 56 | name: traffic sign 57 | - color: 58 | - 107 59 | - 142 60 | - 35 61 | id: 8 62 | isthing: false 63 | name: vegetation 64 | - color: 65 | - 152 66 | - 251 67 | - 152 68 | id: 9 69 | isthing: false 70 | name: terrain 71 | - color: 72 | - 70 73 | - 130 74 | - 180 75 | id: 10 76 | isthing: false 77 | name: sky 78 | - color: 79 | - 220 80 | - 20 81 | - 60 82 | id: 11 83 | isthing: true 84 | name: person 85 | - color: 86 | - 255 87 | - 0 88 | - 0 89 | id: 12 90 | isthing: true 91 | name: rider 92 | - color: 93 | - 0 94 | - 0 95 | - 142 96 | id: 13 97 | isthing: true 98 | name: car 99 | - color: 100 | - 0 101 | - 0 102 | - 70 103 | id: 14 104 | isthing: true 105 | name: truck 106 | - color: 107 | - 0 108 | - 60 109 | - 100 110 | id: 15 111 | isthing: true 112 | name: bus 113 | - color: 114 | - 0 115 | - 80 116 | - 100 117 | id: 16 118 | isthing: true 119 | name: train 120 | - color: 121 | - 0 122 | - 0 123 | - 230 124 | id: 17 125 | isthing: true 126 | 
name: motorcycle 127 | - color: 128 | - 119 129 | - 11 130 | - 32 131 | id: 18 132 | isthing: true 133 | name: bicycle 134 | -------------------------------------------------------------------------------- /segm/data/config/pascal_context.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = "PascalContextDataset" 3 | data_root = "data/VOCdevkit/VOC2010/" 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True 6 | ) 7 | 8 | img_scale = (512, 512) 9 | crop_size = (512, 512) 10 | max_ratio = 8 11 | train_pipeline = [ 12 | dict(type="LoadImageFromFile"), 13 | dict(type="LoadAnnotations"), 14 | dict(type="Resize", img_scale=img_scale, ratio_range=(0.5, 2.0)), 15 | dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75), 16 | dict(type="RandomFlip", prob=0.5), 17 | dict(type="PhotoMetricDistortion"), 18 | dict(type="Normalize", **img_norm_cfg), 19 | dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255), 20 | dict(type="DefaultFormatBundle"), 21 | dict(type="Collect", keys=["img", "gt_semantic_seg"]), 22 | ] 23 | val_pipeline = [ 24 | dict(type="LoadImageFromFile"), 25 | dict( 26 | type="MultiScaleFlipAug", 27 | img_scale=(512 * max_ratio, 512), 28 | flip=False, 29 | transforms=[ 30 | dict(type="Resize", keep_ratio=True), 31 | dict(type="RandomFlip"), 32 | dict(type="Normalize", **img_norm_cfg), 33 | dict(type="ImageToTensor", keys=["img"]), 34 | dict(type="Collect", keys=["img"]), 35 | ], 36 | ), 37 | ] 38 | test_pipeline = [ 39 | dict(type="LoadImageFromFile"), 40 | dict( 41 | type="MultiScaleFlipAug", 42 | img_scale=(512 * max_ratio, 512), 43 | flip=False, 44 | transforms=[ 45 | dict(type="Resize", keep_ratio=True), 46 | dict(type="RandomFlip"), 47 | dict(type="Normalize", **img_norm_cfg), 48 | dict(type="ImageToTensor", keys=["img"]), 49 | dict(type="Collect", keys=["img"]), 50 | ], 51 | ), 52 | ] 53 | data = dict( 54 | samples_per_gpu=4, 55 | workers_per_gpu=4, 56 | train=dict( 57 | type=dataset_type, 58 | data_root=data_root, 59 | img_dir="JPEGImages", 60 | ann_dir="SegmentationClassContext", 61 | split="ImageSets/SegmentationContext/train.txt", 62 | pipeline=train_pipeline, 63 | ), 64 | val=dict( 65 | type=dataset_type, 66 | data_root=data_root, 67 | img_dir="JPEGImages", 68 | ann_dir="SegmentationClassContext", 69 | split="ImageSets/SegmentationContext/val.txt", 70 | pipeline=val_pipeline, 71 | ), 72 | test=dict( 73 | type=dataset_type, 74 | data_root=data_root, 75 | img_dir="JPEGImages", 76 | ann_dir="SegmentationClassContext", 77 | split="ImageSets/SegmentationContext/val.txt", 78 | pipeline=test_pipeline, 79 | ), 80 | ) 81 | -------------------------------------------------------------------------------- /segm/data/config/pascal_context.yml: -------------------------------------------------------------------------------- 1 | - color: 2 | - 120 3 | - 120 4 | - 120 5 | id: 0 6 | name: background 7 | - color: 8 | - 180 9 | - 120 10 | - 120 11 | id: 1 12 | name: aeroplane 13 | - color: 14 | - 6 15 | - 230 16 | - 230 17 | id: 2 18 | name: bicycle 19 | - color: 20 | - 80 21 | - 50 22 | - 50 23 | id: 3 24 | name: bird 25 | - color: 26 | - 4 27 | - 200 28 | - 3 29 | id: 4 30 | name: boat 31 | - color: 32 | - 120 33 | - 120 34 | - 80 35 | id: 5 36 | name: bottle 37 | - color: 38 | - 140 39 | - 140 40 | - 140 41 | id: 6 42 | name: bus 43 | - color: 44 | - 204 45 | - 5 46 | - 255 47 | id: 7 48 | name: car 49 | - color: 50 | - 230 51 | - 230 52 | - 230 53 | id: 8 
54 | name: cat 55 | - color: 56 | - 4 57 | - 250 58 | - 7 59 | id: 9 60 | name: chair 61 | - color: 62 | - 224 63 | - 5 64 | - 255 65 | id: 10 66 | name: cow 67 | - color: 68 | - 235 69 | - 255 70 | - 7 71 | id: 11 72 | name: table 73 | - color: 74 | - 150 75 | - 5 76 | - 61 77 | id: 12 78 | name: dog 79 | - color: 80 | - 120 81 | - 120 82 | - 70 83 | id: 13 84 | name: horse 85 | - color: 86 | - 8 87 | - 255 88 | - 51 89 | id: 14 90 | name: motorbike 91 | - color: 92 | - 255 93 | - 6 94 | - 82 95 | id: 15 96 | name: person 97 | - color: 98 | - 143 99 | - 255 100 | - 140 101 | id: 16 102 | name: pottedplant 103 | - color: 104 | - 204 105 | - 255 106 | - 4 107 | id: 17 108 | name: sheep 109 | - color: 110 | - 255 111 | - 51 112 | - 7 113 | id: 18 114 | name: sofa 115 | - color: 116 | - 204 117 | - 70 118 | - 3 119 | id: 19 120 | name: train 121 | - color: 122 | - 0 123 | - 102 124 | - 200 125 | id: 20 126 | name: tvmonitor 127 | - color: 128 | - 61 129 | - 230 130 | - 250 131 | id: 21 132 | name: bag 133 | - color: 134 | - 255 135 | - 6 136 | - 51 137 | id: 22 138 | name: bed 139 | - color: 140 | - 11 141 | - 102 142 | - 255 143 | id: 23 144 | name: bench 145 | - color: 146 | - 255 147 | - 7 148 | - 71 149 | id: 24 150 | name: book 151 | - color: 152 | - 255 153 | - 9 154 | - 224 155 | id: 25 156 | name: building 157 | - color: 158 | - 9 159 | - 7 160 | - 230 161 | id: 26 162 | name: cabinet 163 | - color: 164 | - 220 165 | - 220 166 | - 220 167 | id: 27 168 | name: ceiling 169 | - color: 170 | - 255 171 | - 9 172 | - 92 173 | id: 28 174 | name: cloth 175 | - color: 176 | - 112 177 | - 9 178 | - 255 179 | id: 29 180 | name: computer 181 | - color: 182 | - 8 183 | - 255 184 | - 214 185 | id: 30 186 | name: cup 187 | - color: 188 | - 7 189 | - 255 190 | - 224 191 | id: 31 192 | name: door 193 | - color: 194 | - 255 195 | - 184 196 | - 6 197 | id: 32 198 | name: fence 199 | - color: 200 | - 10 201 | - 255 202 | - 71 203 | id: 33 204 | name: floor 205 | - color: 206 | - 255 207 | - 41 208 | - 10 209 | id: 34 210 | name: flower 211 | - color: 212 | - 7 213 | - 255 214 | - 255 215 | id: 35 216 | name: food 217 | - color: 218 | - 224 219 | - 255 220 | - 8 221 | id: 36 222 | name: grass 223 | - color: 224 | - 102 225 | - 8 226 | - 255 227 | id: 37 228 | name: ground 229 | - color: 230 | - 255 231 | - 61 232 | - 6 233 | id: 38 234 | name: keyboard 235 | - color: 236 | - 255 237 | - 194 238 | - 7 239 | id: 39 240 | name: light 241 | - color: 242 | - 255 243 | - 122 244 | - 8 245 | id: 40 246 | name: mountain 247 | - color: 248 | - 0 249 | - 255 250 | - 20 251 | id: 41 252 | name: mouse 253 | - color: 254 | - 255 255 | - 8 256 | - 41 257 | id: 42 258 | name: curtain 259 | - color: 260 | - 255 261 | - 5 262 | - 153 263 | id: 43 264 | name: platform 265 | - color: 266 | - 6 267 | - 51 268 | - 255 269 | id: 44 270 | name: sign 271 | - color: 272 | - 235 273 | - 12 274 | - 255 275 | id: 45 276 | name: plate 277 | - color: 278 | - 160 279 | - 150 280 | - 20 281 | id: 46 282 | name: road 283 | - color: 284 | - 0 285 | - 163 286 | - 255 287 | id: 47 288 | name: rock 289 | - color: 290 | - 140 291 | - 140 292 | - 140 293 | id: 48 294 | name: shelves 295 | - color: 296 | - 250 297 | - 10 298 | - 15 299 | id: 49 300 | name: sidewalk 301 | - color: 302 | - 20 303 | - 255 304 | - 0 305 | id: 50 306 | name: sky 307 | - color: 308 | - 31 309 | - 255 310 | - 0 311 | id: 51 312 | name: snow 313 | - color: 314 | - 255 315 | - 31 316 | - 0 317 | id: 52 318 | name: bedclothes 319 | - color: 320 | - 255 321 | - 224 322 | 
- 0 323 | id: 53 324 | name: track 325 | - color: 326 | - 153 327 | - 255 328 | - 0 329 | id: 54 330 | name: tree 331 | - color: 332 | - 0 333 | - 0 334 | - 255 335 | id: 55 336 | name: truck 337 | - color: 338 | - 255 339 | - 71 340 | - 0 341 | id: 56 342 | name: wall 343 | - color: 344 | - 0 345 | - 235 346 | - 255 347 | id: 57 348 | name: water 349 | - color: 350 | - 0 351 | - 173 352 | - 255 353 | id: 58 354 | name: window 355 | - color: 356 | - 31 357 | - 0 358 | - 255 359 | id: 59 360 | name: wood 361 | -------------------------------------------------------------------------------- /segm/data/factory.py: -------------------------------------------------------------------------------- 1 | import segm.utils.torch as ptu 2 | 3 | from segm.data import ImagenetDataset 4 | from segm.data import ADE20KSegmentation 5 | from segm.data import PascalContextDataset 6 | from segm.data import CityscapesDataset 7 | from segm.data import Loader 8 | 9 | 10 | def create_dataset(dataset_kwargs): 11 | dataset_kwargs = dataset_kwargs.copy() 12 | dataset_name = dataset_kwargs.pop("dataset") 13 | batch_size = dataset_kwargs.pop("batch_size") 14 | num_workers = dataset_kwargs.pop("num_workers") 15 | split = dataset_kwargs.pop("split") 16 | 17 | # load dataset_name 18 | if dataset_name == "imagenet": 19 | dataset_kwargs.pop("patch_size") 20 | dataset = ImagenetDataset(split=split, **dataset_kwargs) 21 | elif dataset_name == "ade20k": 22 | dataset = ADE20KSegmentation(split=split, **dataset_kwargs) 23 | elif dataset_name == "pascal_context": 24 | dataset = PascalContextDataset(split=split, **dataset_kwargs) 25 | elif dataset_name == "cityscapes": 26 | dataset = CityscapesDataset(split=split, **dataset_kwargs) 27 | else: 28 | raise ValueError(f"Dataset {dataset_name} is unknown.") 29 | 30 | dataset = Loader( 31 | dataset=dataset, 32 | batch_size=batch_size, 33 | num_workers=num_workers, 34 | distributed=ptu.distributed, 35 | split=split, 36 | ) 37 | return dataset 38 | -------------------------------------------------------------------------------- /segm/data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from pathlib import Path 4 | 5 | from torch.utils.data import Dataset 6 | from torchvision import datasets 7 | from torchvision import transforms 8 | from PIL import Image 9 | 10 | from segm.data import utils 11 | from segm.config import dataset_dir 12 | 13 | 14 | class ImagenetDataset(Dataset): 15 | def __init__( 16 | self, 17 | root_dir, 18 | image_size=224, 19 | crop_size=224, 20 | split="train", 21 | normalization="vit", 22 | ): 23 | super().__init__() 24 | assert image_size[0] == image_size[1] 25 | 26 | self.path = Path(root_dir) / split 27 | self.crop_size = crop_size 28 | self.image_size = image_size 29 | self.split = split 30 | self.normalization = normalization 31 | 32 | if split == "train": 33 | self.transform = transforms.Compose( 34 | [ 35 | transforms.RandomResizedCrop(self.crop_size, interpolation=3), 36 | transforms.RandomHorizontalFlip(), 37 | transforms.ToTensor(), 38 | ] 39 | ) 40 | else: 41 | self.transform = transforms.Compose( 42 | [ 43 | transforms.Resize(image_size[0] + 32, interpolation=3), 44 | transforms.CenterCrop(self.crop_size), 45 | transforms.ToTensor(), 46 | ] 47 | ) 48 | 49 | self.base_dataset = datasets.ImageFolder(self.path, self.transform) 50 | self.n_cls = 1000 51 | 52 | @property 53 | def unwrapped(self): 54 | return self 55 | 56 | def __len__(self): 57 | return len(self.base_dataset) 
58 | 59 | def __getitem__(self, idx): 60 | im, target = self.base_dataset[idx] 61 | im = utils.rgb_normalize(im, self.normalization) 62 | return dict(im=im, target=target) 63 | -------------------------------------------------------------------------------- /segm/data/loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torch.utils.data.distributed import DistributedSampler 3 | 4 | import segm.utils.torch as ptu 5 | 6 | 7 | class Loader(DataLoader): 8 | def __init__(self, dataset, batch_size, num_workers, distributed, split): 9 | if distributed: 10 | sampler = DistributedSampler(dataset, shuffle=True) 11 | super().__init__( 12 | dataset, 13 | batch_size=batch_size, 14 | shuffle=False, 15 | num_workers=num_workers, 16 | pin_memory=True, 17 | sampler=sampler, 18 | ) 19 | else: 20 | super().__init__( 21 | dataset, 22 | batch_size=batch_size, 23 | shuffle=True, 24 | num_workers=num_workers, 25 | pin_memory=True, 26 | ) 27 | 28 | self.base_dataset = self.dataset 29 | 30 | @property 31 | def unwrapped(self): 32 | return self.base_dataset.unwrapped 33 | 34 | def set_epoch(self, epoch): 35 | if isinstance(self.sampler, DistributedSampler): 36 | self.sampler.set_epoch(epoch) 37 | 38 | def get_diagnostics(self, logger): 39 | return self.base_dataset.get_diagnostics(logger) 40 | 41 | def get_snapshot(self): 42 | return self.base_dataset.get_snapshot() 43 | 44 | def end_epoch(self, epoch): 45 | return self.base_dataset.end_epoch(epoch) 46 | -------------------------------------------------------------------------------- /segm/data/pascal_context.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from segm.data.base import BaseMMSeg 4 | from segm.data import utils 5 | from segm.config import dataset_dir 6 | 7 | PASCAL_CONTEXT_CONFIG_PATH = Path(__file__).parent / "config" / "pascal_context.py" 8 | PASCAL_CONTEXT_CATS_PATH = Path(__file__).parent / "config" / "pascal_context.yml" 9 | 10 | 11 | class PascalContextDataset(BaseMMSeg): 12 | def __init__(self, image_size, crop_size, split, **kwargs): 13 | super().__init__( 14 | image_size, crop_size, split, PASCAL_CONTEXT_CONFIG_PATH, **kwargs 15 | ) 16 | self.names, self.colors = utils.dataset_cat_description( 17 | PASCAL_CONTEXT_CATS_PATH 18 | ) 19 | self.n_cls = 60 20 | self.ignore_label = 255 21 | self.reduce_zero_label = False 22 | 23 | def update_default_config(self, config): 24 | root_dir = dataset_dir() 25 | path = Path(root_dir) / "pcontext" 26 | config.data_root = path 27 | if self.split == "train": 28 | config.data.train.data_root = path / "VOCdevkit/VOC2010/" 29 | elif self.split == "val": 30 | config.data.val.data_root = path / "VOCdevkit/VOC2010/" 31 | elif self.split == "test": 32 | raise ValueError("Test split is not valid for Pascal Context dataset") 33 | config = super().update_default_config(config) 34 | return config 35 | 36 | def test_post_process(self, labels): 37 | return labels 38 | -------------------------------------------------------------------------------- /segm/data/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import numpy as np 4 | import yaml 5 | from pathlib import Path 6 | 7 | IGNORE_LABEL = 255 8 | STATS = { 9 | "vit": {"mean": (0.5, 0.5, 0.5), "std": (0.5, 0.5, 0.5)}, 10 | "deit": {"mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)}, 11 | 
} 12 | 13 | 14 | def seg_to_rgb(seg, colors): 15 | im = torch.zeros((seg.shape[0], seg.shape[1], seg.shape[2], 3)).float() 16 | cls = torch.unique(seg) 17 | for cl in cls: 18 | color = colors[int(cl)] 19 | if len(color.shape) > 1: 20 | color = color[0] 21 | im[seg == cl] = color 22 | return im 23 | 24 | 25 | def dataset_cat_description(path, cmap=None): 26 | desc = yaml.load(open(path, "r"), Loader=yaml.FullLoader) 27 | colors = {} 28 | names = [] 29 | for i, cat in enumerate(desc): 30 | names.append(cat["name"]) 31 | if "color" in cat: 32 | colors[cat["id"]] = torch.tensor(cat["color"]).float() / 255 33 | else: 34 | colors[cat["id"]] = torch.tensor(cmap[cat["id"]]).float() 35 | colors[IGNORE_LABEL] = torch.tensor([0.0, 0.0, 0.0]).float() 36 | return names, colors 37 | 38 | 39 | def rgb_normalize(x, stats): 40 | """ 41 | x : C x * 42 | x \in [0, 1] 43 | """ 44 | return F.normalize(x, stats["mean"], stats["std"]) 45 | 46 | 47 | def rgb_denormalize(x, stats): 48 | """ 49 | x : N x C x * 50 | x \in [-1, 1] 51 | """ 52 | mean = torch.tensor(stats["mean"]) 53 | std = torch.tensor(stats["std"]) 54 | for i in range(3): 55 | x[:, i, :, :] = x[:, i, :, :] * std[i] + mean[i] 56 | return x 57 | -------------------------------------------------------------------------------- /segm/engine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | from segm.utils.logger import MetricLogger 5 | from segm.metrics import gather_data, compute_metrics 6 | from segm.model import utils 7 | from segm.data.utils import IGNORE_LABEL 8 | import segm.utils.torch as ptu 9 | 10 | 11 | def train_one_epoch( 12 | model, 13 | data_loader, 14 | optimizer, 15 | lr_scheduler, 16 | epoch, 17 | amp_autocast, 18 | loss_scaler, 19 | ): 20 | criterion = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_LABEL) 21 | logger = MetricLogger(delimiter=" ") 22 | header = f"Epoch: [{epoch}]" 23 | print_freq = 100 24 | 25 | model.train() 26 | data_loader.set_epoch(epoch) 27 | num_updates = epoch * len(data_loader) 28 | for batch in logger.log_every(data_loader, print_freq, header): 29 | im = batch["im"].to(ptu.device) 30 | seg_gt = batch["segmentation"].long().to(ptu.device) 31 | 32 | with amp_autocast(): 33 | seg_pred = model.forward(im) 34 | loss = criterion(seg_pred, seg_gt) 35 | 36 | loss_value = loss.item() 37 | if not math.isfinite(loss_value): 38 | print("Loss is {}, stopping training".format(loss_value), force=True) 39 | 40 | optimizer.zero_grad() 41 | if loss_scaler is not None: 42 | loss_scaler( 43 | loss, 44 | optimizer, 45 | parameters=model.parameters(), 46 | ) 47 | else: 48 | loss.backward() 49 | optimizer.step() 50 | 51 | num_updates += 1 52 | lr_scheduler.step_update(num_updates=num_updates) 53 | 54 | torch.cuda.synchronize() 55 | 56 | logger.update( 57 | loss=loss.item(), 58 | learning_rate=optimizer.param_groups[0]["lr"], 59 | ) 60 | 61 | return logger 62 | 63 | 64 | @torch.no_grad() 65 | def evaluate( 66 | model, 67 | data_loader, 68 | val_seg_gt, 69 | window_size, 70 | window_stride, 71 | amp_autocast, 72 | ): 73 | model_without_ddp = model 74 | if hasattr(model, "module"): 75 | model_without_ddp = model.module 76 | logger = MetricLogger(delimiter=" ") 77 | header = "Eval:" 78 | print_freq = 50 79 | 80 | val_seg_pred = {} 81 | model.eval() 82 | for batch in logger.log_every(data_loader, print_freq, header): 83 | ims = [im.to(ptu.device) for im in batch["im"]] 84 | ims_metas = batch["im_metas"] 85 | ori_shape = ims_metas[0]["ori_shape"] 86 | ori_shape = 
(ori_shape[0].item(), ori_shape[1].item()) 87 | filename = batch["im_metas"][0]["ori_filename"][0] 88 | 89 | with amp_autocast(): 90 | seg_pred = utils.inference( 91 | model_without_ddp, 92 | ims, 93 | ims_metas, 94 | ori_shape, 95 | window_size, 96 | window_stride, 97 | batch_size=1, 98 | ) 99 | seg_pred = seg_pred.argmax(0) 100 | 101 | seg_pred = seg_pred.cpu().numpy() 102 | val_seg_pred[filename] = seg_pred 103 | 104 | val_seg_pred = gather_data(val_seg_pred) 105 | scores = compute_metrics( 106 | val_seg_pred, 107 | val_seg_gt, 108 | data_loader.unwrapped.n_cls, 109 | ignore_index=IGNORE_LABEL, 110 | distributed=ptu.distributed, 111 | ) 112 | 113 | for k, v in scores.items(): 114 | logger.update(**{f"{k}": v, "n": 1}) 115 | 116 | return logger 117 | -------------------------------------------------------------------------------- /segm/eval/accuracy.py: -------------------------------------------------------------------------------- 1 | import click 2 | import torch 3 | 4 | import segm.utils.torch as ptu 5 | 6 | from segm.utils.logger import MetricLogger 7 | 8 | from segm.model.factory import create_vit 9 | from segm.data.factory import create_dataset 10 | from segm.data.utils import STATS 11 | from segm.metrics import accuracy 12 | from segm import config 13 | 14 | 15 | def compute_labels(model, batch): 16 | im = batch["im"] 17 | target = batch["target"] 18 | 19 | with torch.no_grad(): 20 | with torch.cuda.amp.autocast(): 21 | output = model.forward(im) 22 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 23 | 24 | return acc1.item(), acc5.item() 25 | 26 | 27 | def eval_dataset(model, dataset_kwargs): 28 | db = create_dataset(dataset_kwargs) 29 | print_freq = 20 30 | header = "" 31 | logger = MetricLogger(delimiter=" ") 32 | 33 | for batch in logger.log_every(db, print_freq, header): 34 | for k, v in batch.items(): 35 | batch[k] = v.to(ptu.device) 36 | acc1, acc5 = compute_labels(model, batch) 37 | batch_size = batch["im"].size(0) 38 | logger.update(acc1=acc1, n=batch_size) 39 | logger.update(acc5=acc5, n=batch_size) 40 | print(f"Imagenet accuracy: {logger}") 41 | 42 | 43 | @click.command() 44 | @click.argument("backbone", type=str) 45 | @click.option("--imagenet-dir", type=str) 46 | @click.option("-bs", "--batch-size", default=32, type=int) 47 | @click.option("-nw", "--num-workers", default=10, type=int) 48 | @click.option("-gpu", "--gpu/--no-gpu", default=True, is_flag=True) 49 | def main(backbone, imagenet_dir, batch_size, num_workers, gpu): 50 | ptu.set_gpu_mode(gpu) 51 | cfg = config.load_config() 52 | cfg = cfg["model"][backbone] 53 | cfg["backbone"] = backbone 54 | cfg["image_size"] = (cfg["image_size"], cfg["image_size"]) 55 | 56 | dataset_kwargs = dict( 57 | dataset="imagenet", 58 | root_dir=imagenet_dir, 59 | image_size=cfg["image_size"], 60 | crop_size=cfg["image_size"], 61 | patch_size=cfg["patch_size"], 62 | batch_size=batch_size, 63 | num_workers=num_workers, 64 | split="val", 65 | normalization=STATS[cfg["normalization"]], 66 | ) 67 | 68 | model = create_vit(cfg) 69 | model.to(ptu.device) 70 | model.eval() 71 | eval_dataset(model, dataset_kwargs) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /segm/eval/miou.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import click 3 | from pathlib import Path 4 | import yaml 5 | import numpy as np 6 | from PIL import Image 7 | import shutil 8 | 9 | import torch 10 | import 
torch.nn.functional as F 11 | from torch.nn.parallel import DistributedDataParallel as DDP 12 | 13 | from segm.utils import distributed 14 | from segm.utils.logger import MetricLogger 15 | import segm.utils.torch as ptu 16 | 17 | from segm.model.factory import load_model 18 | from segm.data.factory import create_dataset 19 | from segm.metrics import gather_data, compute_metrics 20 | 21 | from segm.model.utils import inference 22 | from segm.data.utils import seg_to_rgb, rgb_denormalize, IGNORE_LABEL 23 | from segm import config 24 | 25 | 26 | def blend_im(im, seg, alpha=0.5): 27 | pil_im = Image.fromarray(im) 28 | pil_seg = Image.fromarray(seg) 29 | im_blend = Image.blend(pil_im, pil_seg, alpha).convert("RGB") 30 | return np.asarray(im_blend) 31 | 32 | 33 | def save_im(save_dir, save_name, im, seg_pred, seg_gt, colors, blend, normalization): 34 | seg_rgb = seg_to_rgb(seg_gt[None], colors) 35 | pred_rgb = seg_to_rgb(seg_pred[None], colors) 36 | im_unnorm = rgb_denormalize(im, normalization) 37 | save_dir = Path(save_dir) 38 | 39 | # save images 40 | im_uint = (im_unnorm.permute(0, 2, 3, 1).cpu().numpy()).astype(np.uint8) 41 | seg_rgb_uint = (255 * seg_rgb.cpu().numpy()).astype(np.uint8) 42 | seg_pred_uint = (255 * pred_rgb.cpu().numpy()).astype(np.uint8) 43 | for i in range(pred_rgb.shape[0]): 44 | if blend: 45 | blend_pred = blend_im(im_uint[i], seg_pred_uint[i]) 46 | blend_gt = blend_im(im_uint[i], seg_rgb_uint[i]) 47 | ims = (im_uint[i], blend_pred, blend_gt) 48 | else: 49 | ims = (im_uint[i], seg_pred_uint[i], seg_rgb_uint[i]) 50 | for im, im_dir in zip( 51 | ims, (save_dir / "input", save_dir / "pred", save_dir / "gt"), 52 | ): 53 | pil_out = Image.fromarray(im) 54 | im_dir.mkdir(exist_ok=True) 55 | pil_out.save(im_dir / save_name) 56 | 57 | 58 | def process_batch( 59 | model, batch, window_size, window_stride, window_batch_size, 60 | ): 61 | ims = batch["im"] 62 | ims_metas = batch["im_metas"] 63 | ori_shape = ims_metas[0]["ori_shape"] 64 | ori_shape = (ori_shape[0].item(), ori_shape[1].item()) 65 | filename = batch["im_metas"][0]["ori_filename"][0] 66 | 67 | model_without_ddp = model 68 | if ptu.distributed: 69 | model_without_ddp = model.module 70 | seg_pred = inference( 71 | model_without_ddp, 72 | ims, 73 | ims_metas, 74 | ori_shape, 75 | window_size, 76 | window_stride, 77 | window_batch_size, 78 | ) 79 | seg_pred = seg_pred.argmax(0) 80 | im = F.interpolate(ims[-1], ori_shape, mode="bilinear") 81 | 82 | return filename, im.cpu(), seg_pred.cpu() 83 | 84 | 85 | def eval_dataset( 86 | model, 87 | multiscale, 88 | model_dir, 89 | blend, 90 | window_size, 91 | window_stride, 92 | window_batch_size, 93 | save_images, 94 | frac_dataset, 95 | dataset_kwargs, 96 | ): 97 | db = create_dataset(dataset_kwargs) 98 | normalization = db.dataset.normalization 99 | dataset_name = dataset_kwargs["dataset"] 100 | im_size = dataset_kwargs["image_size"] 101 | cat_names = db.base_dataset.names 102 | n_cls = db.unwrapped.n_cls 103 | if multiscale: 104 | db.dataset.set_multiscale_mode() 105 | 106 | logger = MetricLogger(delimiter=" ") 107 | header = "" 108 | print_freq = 50 109 | 110 | ims = {} 111 | seg_pred_maps = {} 112 | idx = 0 113 | for batch in logger.log_every(db, print_freq, header): 114 | colors = batch["colors"] 115 | filename, im, seg_pred = process_batch( 116 | model, batch, window_size, window_stride, window_batch_size, 117 | ) 118 | ims[filename] = im 119 | seg_pred_maps[filename] = seg_pred 120 | idx += 1 121 | if idx > len(db) * frac_dataset: 122 | break 123 | 124 | seg_gt_maps = 
db.dataset.get_gt_seg_maps() 125 | if save_images: 126 | save_dir = model_dir / "images" 127 | if ptu.dist_rank == 0: 128 | if save_dir.exists(): 129 | shutil.rmtree(save_dir) 130 | save_dir.mkdir() 131 | if ptu.distributed: 132 | torch.distributed.barrier() 133 | 134 | for name in sorted(ims): 135 | instance_dir = save_dir 136 | filename = name 137 | 138 | if dataset_name == "cityscapes": 139 | filename_list = name.split("/") 140 | instance_dir = instance_dir / filename_list[0] 141 | filename = filename_list[-1] 142 | if not instance_dir.exists(): 143 | instance_dir.mkdir() 144 | 145 | save_im( 146 | instance_dir, 147 | filename, 148 | ims[name], 149 | seg_pred_maps[name], 150 | torch.tensor(seg_gt_maps[name]), 151 | colors, 152 | blend, 153 | normalization, 154 | ) 155 | if ptu.dist_rank == 0: 156 | shutil.make_archive(save_dir, "zip", save_dir) 157 | # shutil.rmtree(save_dir) 158 | print(f"Saved eval images in {save_dir}.zip") 159 | 160 | if ptu.distributed: 161 | torch.distributed.barrier() 162 | seg_pred_maps = gather_data(seg_pred_maps) 163 | 164 | scores = compute_metrics( 165 | seg_pred_maps, 166 | seg_gt_maps, 167 | n_cls, 168 | ignore_index=IGNORE_LABEL, 169 | ret_cat_iou=True, 170 | distributed=ptu.distributed, 171 | ) 172 | 173 | if ptu.dist_rank == 0: 174 | scores["inference"] = "single_scale" if not multiscale else "multi_scale" 175 | suffix = "ss" if not multiscale else "ms" 176 | scores["cat_iou"] = np.round(100 * scores["cat_iou"], 2).tolist() 177 | for k, v in scores.items(): 178 | if k != "cat_iou" and k != "inference": 179 | scores[k] = v.item() 180 | if k != "cat_iou": 181 | print(f"{k}: {scores[k]}") 182 | scores_str = yaml.dump(scores) 183 | with open(model_dir / f"scores_{suffix}.yml", "w") as f: 184 | f.write(scores_str) 185 | 186 | 187 | @click.command() 188 | @click.argument("model_path", type=str) 189 | @click.argument("dataset_name", type=str) 190 | @click.option("--im-size", default=None, type=int) 191 | @click.option("--multiscale/--singlescale", default=False, is_flag=True) 192 | @click.option("--blend/--no-blend", default=True, is_flag=True) 193 | @click.option("--window-size", default=None, type=int) 194 | @click.option("--window-stride", default=None, type=int) 195 | @click.option("--window-batch-size", default=4, type=int) 196 | @click.option("--save-images/--no-save-images", default=False, is_flag=True) 197 | @click.option("-frac-dataset", "--frac-dataset", default=1.0, type=float) 198 | def main( 199 | model_path, 200 | dataset_name, 201 | im_size, 202 | multiscale, 203 | blend, 204 | window_size, 205 | window_stride, 206 | window_batch_size, 207 | save_images, 208 | frac_dataset, 209 | ): 210 | 211 | model_dir = Path(model_path).parent 212 | 213 | # start distributed mode 214 | ptu.set_gpu_mode(True) 215 | distributed.init_process() 216 | 217 | model, variant = load_model(model_path) 218 | patch_size = model.patch_size 219 | model.eval() 220 | model.to(ptu.device) 221 | if ptu.distributed: 222 | model = DDP(model, device_ids=[ptu.device], find_unused_parameters=True) 223 | 224 | cfg = config.load_config() 225 | dataset_cfg = cfg["dataset"][dataset_name] 226 | normalization = variant["dataset_kwargs"]["normalization"] 227 | if im_size is None: 228 | im_size = dataset_cfg.get("im_size", variant["dataset_kwargs"]["image_size"]) 229 | if window_size is None: 230 | window_size = variant["dataset_kwargs"]["crop_size"] 231 | if window_stride is None: 232 | window_stride = variant["dataset_kwargs"]["crop_size"] - 32 233 | 234 | dataset_kwargs = dict( 235 | 
dataset=dataset_name, 236 | image_size=im_size, 237 | crop_size=im_size, 238 | patch_size=patch_size, 239 | batch_size=1, 240 | num_workers=10, 241 | split="val", 242 | normalization=normalization, 243 | crop=False, 244 | rep_aug=False, 245 | ) 246 | 247 | eval_dataset( 248 | model, 249 | multiscale, 250 | model_dir, 251 | blend, 252 | window_size, 253 | window_stride, 254 | window_batch_size, 255 | save_images, 256 | frac_dataset, 257 | dataset_kwargs, 258 | ) 259 | 260 | distributed.barrier() 261 | distributed.destroy_process() 262 | sys.exit(1) 263 | 264 | 265 | if __name__ == "__main__": 266 | main() 267 | -------------------------------------------------------------------------------- /segm/inference.py: -------------------------------------------------------------------------------- 1 | import click 2 | from tqdm import tqdm 3 | from pathlib import Path 4 | from PIL import Image 5 | import numpy as np 6 | import torchvision.transforms.functional as F 7 | 8 | import segm.utils.torch as ptu 9 | 10 | from segm.data.utils import STATS 11 | from segm.data.ade20k import ADE20K_CATS_PATH 12 | from segm.data.utils import dataset_cat_description, seg_to_rgb 13 | 14 | from segm.model.factory import load_model 15 | from segm.model.utils import inference 16 | 17 | 18 | @click.command() 19 | @click.option("--model-path", type=str) 20 | @click.option("--input-dir", "-i", type=str, help="folder with input images") 21 | @click.option("--output-dir", "-o", type=str, help="folder with output images") 22 | @click.option("--gpu/--cpu", default=True, is_flag=True) 23 | def main(model_path, input_dir, output_dir, gpu): 24 | ptu.set_gpu_mode(gpu) 25 | 26 | model_dir = Path(model_path).parent 27 | model, variant = load_model(model_path) 28 | model.to(ptu.device) 29 | 30 | normalization_name = variant["dataset_kwargs"]["normalization"] 31 | normalization = STATS[normalization_name] 32 | cat_names, cat_colors = dataset_cat_description(ADE20K_CATS_PATH) 33 | 34 | input_dir = Path(input_dir) 35 | output_dir = Path(output_dir) 36 | output_dir.mkdir(exist_ok=True) 37 | 38 | list_dir = list(input_dir.iterdir()) 39 | for filename in tqdm(list_dir, ncols=80): 40 | pil_im = Image.open(filename).copy() 41 | im = F.pil_to_tensor(pil_im).float() / 255 42 | im = F.normalize(im, normalization["mean"], normalization["std"]) 43 | im = im.to(ptu.device).unsqueeze(0) 44 | 45 | im_meta = dict(flip=False) 46 | logits = inference( 47 | model, 48 | [im], 49 | [im_meta], 50 | ori_shape=im.shape[2:4], 51 | window_size=variant["inference_kwargs"]["window_size"], 52 | window_stride=variant["inference_kwargs"]["window_stride"], 53 | batch_size=2, 54 | ) 55 | seg_map = logits.argmax(0, keepdim=True) 56 | seg_rgb = seg_to_rgb(seg_map, cat_colors) 57 | seg_rgb = (255 * seg_rgb.cpu().numpy()).astype(np.uint8) 58 | pil_seg = Image.fromarray(seg_rgb[0]) 59 | 60 | pil_blend = Image.blend(pil_im, pil_seg, 0.5).convert("RGB") 61 | pil_blend.save(output_dir / filename.name) 62 | 63 | 64 | if __name__ == "__main__": 65 | main() 66 | -------------------------------------------------------------------------------- /segm/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.distributed as dist 4 | import segm.utils.torch as ptu 5 | 6 | import os 7 | import pickle as pkl 8 | from pathlib import Path 9 | import tempfile 10 | import shutil 11 | from mmseg.core import mean_iou 12 | 13 | """ 14 | ImageNet classifcation accuracy 15 | """ 16 | 17 | 18 | def 
accuracy(output, target, topk=(1,)):
19 |     """
20 |     https://github.com/pytorch/examples/blob/master/imagenet/main.py
21 |     Computes the accuracy over the k top predictions for the specified values of k
22 |     """
23 |     with torch.no_grad():
24 |         maxk = max(topk)
25 |         batch_size = target.size(0)
26 | 
27 |         _, pred = output.topk(maxk, 1, True, True)
28 |         pred = pred.t()
29 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
30 | 
31 |         res = []
32 |         for k in topk:
33 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
34 |             correct_k /= batch_size
35 |             res.append(correct_k)
36 |         return res
37 | 
38 | 
39 | """
40 | Segmentation mean IoU
41 | based on collect_results_cpu
42 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/apis/test.py#L160-L200
43 | """
44 | 
45 | 
46 | def gather_data(seg_pred, tmp_dir=None):
47 |     """
48 |     distributed data gathering
49 |     prediction and ground truth are stored in a common tmp directory
50 |     and loaded on the master node to compute metrics
51 |     """
52 |     if tmp_dir is None:
53 |         tmpprefix = os.path.expandvars("$DATASET/temp")
54 |     else:
55 |         tmpprefix = os.path.expandvars(tmp_dir)
56 |     MAX_LEN = 512
57 |     # 32 is whitespace
58 |     dir_tensor = torch.full((MAX_LEN,), 32, dtype=torch.uint8, device=ptu.device)
59 |     if ptu.dist_rank == 0:
60 |         tmpdir = tempfile.mkdtemp(prefix=tmpprefix)
61 |         tmpdir = torch.tensor(
62 |             bytearray(tmpdir.encode()), dtype=torch.uint8, device=ptu.device
63 |         )
64 |         dir_tensor[: len(tmpdir)] = tmpdir
65 |     # broadcast tmpdir from 0 to the other nodes
66 |     dist.broadcast(dir_tensor, 0)
67 |     tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
68 |     tmpdir = Path(tmpdir)
69 |     """
70 |     Save results in temp file and load them on main process
71 |     """
72 |     tmp_file = tmpdir / f"part_{ptu.dist_rank}.pkl"
73 |     pkl.dump(seg_pred, open(tmp_file, "wb"))
74 |     dist.barrier()
75 |     seg_pred = {}
76 |     if ptu.dist_rank == 0:
77 |         for i in range(ptu.world_size):
78 |             part_seg_pred = pkl.load(open(tmpdir / f"part_{i}.pkl", "rb"))
79 |             seg_pred.update(part_seg_pred)
80 |         shutil.rmtree(tmpdir)
81 |     return seg_pred
82 | 
83 | 
84 | def compute_metrics(
85 |     seg_pred,
86 |     seg_gt,
87 |     n_cls,
88 |     ignore_index=None,
89 |     ret_cat_iou=False,
90 |     tmp_dir=None,
91 |     distributed=False,
92 | ):
93 |     ret_metrics_mean = torch.zeros(3, dtype=float, device=ptu.device)
94 |     if ptu.dist_rank == 0:
95 |         list_seg_pred = []
96 |         list_seg_gt = []
97 |         keys = sorted(seg_pred.keys())
98 |         for k in keys:
99 |             list_seg_pred.append(np.asarray(seg_pred[k]))
100 |             list_seg_gt.append(np.asarray(seg_gt[k]))
101 |         ret_metrics = mean_iou(
102 |             results=list_seg_pred,
103 |             gt_seg_maps=list_seg_gt,
104 |             num_classes=n_cls,
105 |             ignore_index=ignore_index,
106 |         )
107 |         ret_metrics = [ret_metrics["aAcc"], ret_metrics["Acc"], ret_metrics["IoU"]]
108 |         ret_metrics_mean = torch.tensor(
109 |             [
110 |                 np.round(np.nanmean(ret_metric.astype(float)) * 100, 2)
111 |                 for ret_metric in ret_metrics
112 |             ],
113 |             dtype=float,
114 |             device=ptu.device,
115 |         )
116 |         cat_iou = ret_metrics[2]
117 |     # broadcast metrics from 0 to all nodes
118 |     if distributed:
119 |         dist.broadcast(ret_metrics_mean, 0)
120 |     pix_acc, mean_acc, miou = ret_metrics_mean
121 |     ret = dict(pixel_accuracy=pix_acc, mean_accuracy=mean_acc, mean_iou=miou)
122 |     if ret_cat_iou and ptu.dist_rank == 0:
123 |         ret["cat_iou"] = cat_iou
124 |     return ret
125 | 
--------------------------------------------------------------------------------
/segm/model/blocks.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Adapted from 2020 Ross Wightman 3 | https://github.com/rwightman/pytorch-image-models 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from einops import rearrange 9 | from pathlib import Path 10 | 11 | import torch.nn.functional as F 12 | 13 | from timm.models.layers import DropPath 14 | 15 | 16 | class FeedForward(nn.Module): 17 | def __init__(self, dim, hidden_dim, dropout, out_dim=None): 18 | super().__init__() 19 | self.fc1 = nn.Linear(dim, hidden_dim) 20 | self.act = nn.GELU() 21 | if out_dim is None: 22 | out_dim = dim 23 | self.fc2 = nn.Linear(hidden_dim, out_dim) 24 | self.drop = nn.Dropout(dropout) 25 | 26 | @property 27 | def unwrapped(self): 28 | return self 29 | 30 | def forward(self, x): 31 | x = self.fc1(x) 32 | x = self.act(x) 33 | x = self.drop(x) 34 | x = self.fc2(x) 35 | x = self.drop(x) 36 | return x 37 | 38 | 39 | class Attention(nn.Module): 40 | def __init__(self, dim, heads, dropout): 41 | super().__init__() 42 | self.heads = heads 43 | head_dim = dim // heads 44 | self.scale = head_dim ** -0.5 45 | self.attn = None 46 | 47 | self.qkv = nn.Linear(dim, dim * 3) 48 | self.attn_drop = nn.Dropout(dropout) 49 | self.proj = nn.Linear(dim, dim) 50 | self.proj_drop = nn.Dropout(dropout) 51 | 52 | @property 53 | def unwrapped(self): 54 | return self 55 | 56 | def forward(self, x, mask=None): 57 | B, N, C = x.shape 58 | qkv = ( 59 | self.qkv(x) 60 | .reshape(B, N, 3, self.heads, C // self.heads) 61 | .permute(2, 0, 3, 1, 4) 62 | ) 63 | q, k, v = ( 64 | qkv[0], 65 | qkv[1], 66 | qkv[2], 67 | ) 68 | 69 | attn = (q @ k.transpose(-2, -1)) * self.scale 70 | attn = attn.softmax(dim=-1) 71 | attn = self.attn_drop(attn) 72 | 73 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 74 | x = self.proj(x) 75 | x = self.proj_drop(x) 76 | 77 | return x, attn 78 | 79 | 80 | class Block(nn.Module): 81 | def __init__(self, dim, heads, mlp_dim, dropout, drop_path): 82 | super().__init__() 83 | self.norm1 = nn.LayerNorm(dim) 84 | self.norm2 = nn.LayerNorm(dim) 85 | self.attn = Attention(dim, heads, dropout) 86 | self.mlp = FeedForward(dim, mlp_dim, dropout) 87 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 88 | 89 | def forward(self, x, mask=None, return_attention=False): 90 | y, attn = self.attn(self.norm1(x), mask) 91 | if return_attention: 92 | return attn 93 | x = x + self.drop_path(y) 94 | x = x + self.drop_path(self.mlp(self.norm2(x))) 95 | return x 96 | -------------------------------------------------------------------------------- /segm/model/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | from timm.models.layers import trunc_normal_ 8 | 9 | from segm.model.blocks import Block, FeedForward 10 | from segm.model.utils import init_weights 11 | 12 | 13 | class DecoderLinear(nn.Module): 14 | def __init__(self, n_cls, patch_size, d_encoder): 15 | super().__init__() 16 | 17 | self.d_encoder = d_encoder 18 | self.patch_size = patch_size 19 | self.n_cls = n_cls 20 | 21 | self.head = nn.Linear(self.d_encoder, n_cls) 22 | self.apply(init_weights) 23 | 24 | @torch.jit.ignore 25 | def no_weight_decay(self): 26 | return set() 27 | 28 | def forward(self, x, im_size): 29 | H, W = im_size 30 | GS = H // self.patch_size 31 | x = self.head(x) 32 | x = rearrange(x, "b (h w) c -> b c h w", h=GS) 33 | 
34 |         return x
35 | 
36 | 
37 | class MaskTransformer(nn.Module):
38 |     def __init__(
39 |         self,
40 |         n_cls,
41 |         patch_size,
42 |         d_encoder,
43 |         n_layers,
44 |         n_heads,
45 |         d_model,
46 |         d_ff,
47 |         drop_path_rate,
48 |         dropout,
49 |     ):
50 |         super().__init__()
51 |         self.d_encoder = d_encoder
52 |         self.patch_size = patch_size
53 |         self.n_layers = n_layers
54 |         self.n_cls = n_cls
55 |         self.d_model = d_model
56 |         self.d_ff = d_ff
57 |         self.scale = d_model ** -0.5
58 | 
59 |         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
60 |         self.blocks = nn.ModuleList(
61 |             [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
62 |         )
63 | 
64 |         self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))
65 |         self.proj_dec = nn.Linear(d_encoder, d_model)
66 | 
67 |         self.proj_patch = nn.Parameter(self.scale * torch.randn(d_model, d_model))
68 |         self.proj_classes = nn.Parameter(self.scale * torch.randn(d_model, d_model))
69 | 
70 |         self.decoder_norm = nn.LayerNorm(d_model)
71 |         self.mask_norm = nn.LayerNorm(n_cls)
72 | 
73 |         self.apply(init_weights)
74 |         trunc_normal_(self.cls_emb, std=0.02)
75 | 
76 |     @torch.jit.ignore
77 |     def no_weight_decay(self):
78 |         return {"cls_emb"}
79 | 
80 |     def forward(self, x, im_size):
81 |         H, W = im_size
82 |         GS = H // self.patch_size
83 | 
84 |         x = self.proj_dec(x)
85 |         cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
86 |         x = torch.cat((x, cls_emb), 1)
87 |         for blk in self.blocks:
88 |             x = blk(x)
89 |         x = self.decoder_norm(x)
90 | 
91 |         patches, cls_seg_feat = x[:, : -self.n_cls], x[:, -self.n_cls :]
92 |         patches = patches @ self.proj_patch
93 |         cls_seg_feat = cls_seg_feat @ self.proj_classes
94 | 
95 |         patches = patches / patches.norm(dim=-1, keepdim=True)
96 |         cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)
97 | 
98 |         masks = patches @ cls_seg_feat.transpose(1, 2)
99 |         masks = self.mask_norm(masks)
100 |         masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))
101 | 
102 |         return masks
103 | 
104 |     def get_attention_map(self, x, layer_id):
105 |         if layer_id >= self.n_layers or layer_id < 0:
106 |             raise ValueError(
107 |                 f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}."
108 |             )
109 |         x = self.proj_dec(x)
110 |         cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
111 |         x = torch.cat((x, cls_emb), 1)
112 |         for i, blk in enumerate(self.blocks):
113 |             if i < layer_id:
114 |                 x = blk(x)
115 |             else:
116 |                 return blk(x, return_attention=True)
117 | 
--------------------------------------------------------------------------------
/segm/model/factory.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import yaml
3 | import torch
4 | import math
5 | import os
6 | import torch.nn as nn
7 | 
8 | from timm.models.helpers import load_pretrained, load_custom_pretrained
9 | from timm.models.vision_transformer import default_cfgs
10 | from timm.models.registry import register_model
11 | from timm.models.vision_transformer import _create_vision_transformer
12 | 
13 | from segm.model.vit import VisionTransformer
14 | from segm.model.utils import checkpoint_filter_fn
15 | from segm.model.decoder import DecoderLinear
16 | from segm.model.decoder import MaskTransformer
17 | from segm.model.segmenter import Segmenter
18 | import segm.utils.torch as ptu
19 | 
20 | 
21 | @register_model
22 | def vit_base_patch8_384(pretrained=False, **kwargs):
23 |     """ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
24 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 25 | """ 26 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) 27 | model = _create_vision_transformer( 28 | "vit_base_patch8_384", 29 | pretrained=pretrained, 30 | default_cfg=dict( 31 | url="", 32 | input_size=(3, 384, 384), 33 | mean=(0.5, 0.5, 0.5), 34 | std=(0.5, 0.5, 0.5), 35 | num_classes=1000, 36 | ), 37 | **model_kwargs, 38 | ) 39 | return model 40 | 41 | 42 | def create_vit(model_cfg): 43 | model_cfg = model_cfg.copy() 44 | backbone = model_cfg.pop("backbone") 45 | 46 | normalization = model_cfg.pop("normalization") 47 | model_cfg["n_cls"] = 1000 48 | mlp_expansion_ratio = 4 49 | model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"] 50 | 51 | if backbone in default_cfgs: 52 | default_cfg = default_cfgs[backbone] 53 | else: 54 | default_cfg = dict( 55 | pretrained=False, 56 | num_classes=1000, 57 | drop_rate=0.0, 58 | drop_path_rate=0.0, 59 | drop_block_rate=None, 60 | ) 61 | 62 | default_cfg["input_size"] = ( 63 | 3, 64 | model_cfg["image_size"][0], 65 | model_cfg["image_size"][1], 66 | ) 67 | model = VisionTransformer(**model_cfg) 68 | if backbone == "vit_base_patch8_384": 69 | path = os.path.expandvars("$TORCH_HOME/hub/checkpoints/vit_base_patch8_384.pth") 70 | state_dict = torch.load(path, map_location="cpu") 71 | filtered_dict = checkpoint_filter_fn(state_dict, model) 72 | model.load_state_dict(filtered_dict, strict=True) 73 | elif "deit" in backbone: 74 | load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn) 75 | else: 76 | load_custom_pretrained(model, default_cfg) 77 | 78 | return model 79 | 80 | 81 | def create_decoder(encoder, decoder_cfg): 82 | decoder_cfg = decoder_cfg.copy() 83 | name = decoder_cfg.pop("name") 84 | decoder_cfg["d_encoder"] = encoder.d_model 85 | decoder_cfg["patch_size"] = encoder.patch_size 86 | 87 | if "linear" in name: 88 | decoder = DecoderLinear(**decoder_cfg) 89 | elif name == "mask_transformer": 90 | dim = encoder.d_model 91 | n_heads = dim // 64 92 | decoder_cfg["n_heads"] = n_heads 93 | decoder_cfg["d_model"] = dim 94 | decoder_cfg["d_ff"] = 4 * dim 95 | decoder = MaskTransformer(**decoder_cfg) 96 | else: 97 | raise ValueError(f"Unknown decoder: {name}") 98 | return decoder 99 | 100 | 101 | def create_segmenter(model_cfg): 102 | model_cfg = model_cfg.copy() 103 | decoder_cfg = model_cfg.pop("decoder") 104 | decoder_cfg["n_cls"] = model_cfg["n_cls"] 105 | 106 | encoder = create_vit(model_cfg) 107 | decoder = create_decoder(encoder, decoder_cfg) 108 | model = Segmenter(encoder, decoder, n_cls=model_cfg["n_cls"]) 109 | 110 | return model 111 | 112 | 113 | def load_model(model_path): 114 | variant_path = Path(model_path).parent / "variant.yml" 115 | with open(variant_path, "r") as f: 116 | variant = yaml.load(f, Loader=yaml.FullLoader) 117 | net_kwargs = variant["net_kwargs"] 118 | 119 | model = create_segmenter(net_kwargs) 120 | data = torch.load(model_path, map_location=ptu.device) 121 | checkpoint = data["model"] 122 | 123 | model.load_state_dict(checkpoint, strict=True) 124 | 125 | return model, variant 126 | -------------------------------------------------------------------------------- /segm/model/segmenter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from segm.model.utils import padding, unpadding 6 | from timm.models.layers import 
trunc_normal_ 7 | 8 | 9 | class Segmenter(nn.Module): 10 | def __init__( 11 | self, 12 | encoder, 13 | decoder, 14 | n_cls, 15 | ): 16 | super().__init__() 17 | self.n_cls = n_cls 18 | self.patch_size = encoder.patch_size 19 | self.encoder = encoder 20 | self.decoder = decoder 21 | 22 | @torch.jit.ignore 23 | def no_weight_decay(self): 24 | def append_prefix_no_weight_decay(prefix, module): 25 | return set(map(lambda x: prefix + x, module.no_weight_decay())) 26 | 27 | nwd_params = append_prefix_no_weight_decay("encoder.", self.encoder).union( 28 | append_prefix_no_weight_decay("decoder.", self.decoder) 29 | ) 30 | return nwd_params 31 | 32 | def forward(self, im): 33 | H_ori, W_ori = im.size(2), im.size(3) 34 | im = padding(im, self.patch_size) 35 | H, W = im.size(2), im.size(3) 36 | 37 | x = self.encoder(im, return_features=True) 38 | 39 | # remove CLS/DIST tokens for decoding 40 | num_extra_tokens = 1 + self.encoder.distilled 41 | x = x[:, num_extra_tokens:] 42 | 43 | masks = self.decoder(x, (H, W)) 44 | 45 | masks = F.interpolate(masks, size=(H, W), mode="bilinear") 46 | masks = unpadding(masks, (H_ori, W_ori)) 47 | 48 | return masks 49 | 50 | def get_attention_map_enc(self, im, layer_id): 51 | return self.encoder.get_attention_map(im, layer_id) 52 | 53 | def get_attention_map_dec(self, im, layer_id): 54 | x = self.encoder(im, return_features=True) 55 | 56 | # remove CLS/DIST tokens for decoding 57 | num_extra_tokens = 1 + self.encoder.distilled 58 | x = x[:, num_extra_tokens:] 59 | 60 | return self.decoder.get_attention_map(x, layer_id) 61 | -------------------------------------------------------------------------------- /segm/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from collections import defaultdict 6 | 7 | from timm.models.layers import trunc_normal_ 8 | 9 | import segm.utils.torch as ptu 10 | 11 | 12 | def init_weights(m): 13 | if isinstance(m, nn.Linear): 14 | trunc_normal_(m.weight, std=0.02) 15 | if isinstance(m, nn.Linear) and m.bias is not None: 16 | nn.init.constant_(m.bias, 0) 17 | elif isinstance(m, nn.LayerNorm): 18 | nn.init.constant_(m.bias, 0) 19 | nn.init.constant_(m.weight, 1.0) 20 | 21 | 22 | def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens): 23 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 24 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 25 | posemb_tok, posemb_grid = ( 26 | posemb[:, :num_extra_tokens], 27 | posemb[0, num_extra_tokens:], 28 | ) 29 | if grid_old_shape is None: 30 | gs_old_h = int(math.sqrt(len(posemb_grid))) 31 | gs_old_w = gs_old_h 32 | else: 33 | gs_old_h, gs_old_w = grid_old_shape 34 | 35 | gs_h, gs_w = grid_new_shape 36 | posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2) 37 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") 38 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) 39 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 40 | return posemb 41 | 42 | 43 | def checkpoint_filter_fn(state_dict, model): 44 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 45 | out_dict = {} 46 | if "model" in state_dict: 47 | # For deit models 48 | state_dict = state_dict["model"] 49 | num_extra_tokens = 1 + ("dist_token" in state_dict.keys()) 50 | patch_size = model.patch_size 51 | image_size = model.patch_embed.image_size 52 | for k, v in state_dict.items(): 53 | if k == "pos_embed" and v.shape != model.pos_embed.shape: 54 | # To resize pos embedding when using model at different size from pretrained weights 55 | v = resize_pos_embed( 56 | v, 57 | None, 58 | (image_size[0] // patch_size, image_size[1] // patch_size), 59 | num_extra_tokens, 60 | ) 61 | out_dict[k] = v 62 | return out_dict 63 | 64 | 65 | def padding(im, patch_size, fill_value=0): 66 | # make the image sizes divisible by patch_size 67 | H, W = im.size(2), im.size(3) 68 | pad_h, pad_w = 0, 0 69 | if H % patch_size > 0: 70 | pad_h = patch_size - (H % patch_size) 71 | if W % patch_size > 0: 72 | pad_w = patch_size - (W % patch_size) 73 | im_padded = im 74 | if pad_h > 0 or pad_w > 0: 75 | im_padded = F.pad(im, (0, pad_w, 0, pad_h), value=fill_value) 76 | return im_padded 77 | 78 | 79 | def unpadding(y, target_size): 80 | H, W = target_size 81 | H_pad, W_pad = y.size(2), y.size(3) 82 | # crop predictions on extra pixels coming from padding 83 | extra_h = H_pad - H 84 | extra_w = W_pad - W 85 | if extra_h > 0: 86 | y = y[:, :, :-extra_h] 87 | if extra_w > 0: 88 | y = y[:, :, :, :-extra_w] 89 | return y 90 | 91 | 92 | def resize(im, smaller_size): 93 | h, w = im.shape[2:] 94 | if h < w: 95 | ratio = w / h 96 | h_res, w_res = smaller_size, ratio * smaller_size 97 | else: 98 | ratio = h / w 99 | h_res, w_res = ratio * smaller_size, smaller_size 100 | if min(h, w) < smaller_size: 101 | im_res = F.interpolate(im, (int(h_res), int(w_res)), mode="bilinear") 102 | else: 103 | im_res = im 104 | return im_res 105 | 106 | 107 | def sliding_window(im, flip, window_size, window_stride): 108 | B, C, H, W = im.shape 109 | ws = window_size 110 | 111 | windows = {"crop": [], "anchors": []} 112 | h_anchors = torch.arange(0, H, window_stride) 113 | w_anchors = torch.arange(0, W, window_stride) 114 | h_anchors = [h.item() for h in h_anchors if h < H - ws] + [H - ws] 115 | w_anchors = [w.item() for w in w_anchors if w < W - ws] + [W - ws] 116 | for ha in h_anchors: 117 | for wa in w_anchors: 118 | window = im[:, :, ha : ha + ws, wa : wa + ws] 119 | windows["crop"].append(window) 120 | windows["anchors"].append((ha, wa)) 121 | windows["flip"] = flip 122 | windows["shape"] = (H, W) 123 | return windows 124 | 125 | 126 | def merge_windows(windows, window_size, ori_shape): 127 | ws = window_size 128 | im_windows = 
windows["seg_maps"] 129 | anchors = windows["anchors"] 130 | C = im_windows[0].shape[0] 131 | H, W = windows["shape"] 132 | flip = windows["flip"] 133 | 134 | logit = torch.zeros((C, H, W), device=im_windows.device) 135 | count = torch.zeros((1, H, W), device=im_windows.device) 136 | for window, (ha, wa) in zip(im_windows, anchors): 137 | logit[:, ha : ha + ws, wa : wa + ws] += window 138 | count[:, ha : ha + ws, wa : wa + ws] += 1 139 | logit = logit / count 140 | logit = F.interpolate( 141 | logit.unsqueeze(0), 142 | ori_shape, 143 | mode="bilinear", 144 | )[0] 145 | if flip: 146 | logit = torch.flip(logit, (2,)) 147 | result = F.softmax(logit, 0) 148 | return result 149 | 150 | 151 | def inference( 152 | model, 153 | ims, 154 | ims_metas, 155 | ori_shape, 156 | window_size, 157 | window_stride, 158 | batch_size, 159 | ): 160 | C = model.n_cls 161 | seg_map = torch.zeros((C, ori_shape[0], ori_shape[1]), device=ptu.device) 162 | for im, im_metas in zip(ims, ims_metas): 163 | im = im.to(ptu.device) 164 | im = resize(im, window_size) 165 | flip = im_metas["flip"] 166 | windows = sliding_window(im, flip, window_size, window_stride) 167 | crops = torch.stack(windows.pop("crop"))[:, 0] 168 | B = len(crops) 169 | WB = batch_size 170 | seg_maps = torch.zeros((B, C, window_size, window_size), device=im.device) 171 | with torch.no_grad(): 172 | for i in range(0, B, WB): 173 | seg_maps[i : i + WB] = model.forward(crops[i : i + WB]) 174 | windows["seg_maps"] = seg_maps 175 | im_seg_map = merge_windows(windows, window_size, ori_shape) 176 | seg_map += im_seg_map 177 | seg_map /= len(ims) 178 | return seg_map 179 | 180 | 181 | def num_params(model): 182 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 183 | n_params = sum([torch.prod(torch.tensor(p.size())) for p in model_parameters]) 184 | return n_params.item() 185 | -------------------------------------------------------------------------------- /segm/model/vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from 2020 Ross Wightman 3 | https://github.com/rwightman/pytorch-image-models 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from segm.model.utils import init_weights, resize_pos_embed 10 | from segm.model.blocks import Block 11 | 12 | from timm.models.layers import DropPath 13 | from timm.models.layers import trunc_normal_ 14 | from timm.models.vision_transformer import _load_weights 15 | 16 | 17 | class PatchEmbedding(nn.Module): 18 | def __init__(self, image_size, patch_size, embed_dim, channels): 19 | super().__init__() 20 | 21 | self.image_size = image_size 22 | if image_size[0] % patch_size != 0 or image_size[1] % patch_size != 0: 23 | raise ValueError("image dimensions must be divisible by the patch size") 24 | self.grid_size = image_size[0] // patch_size, image_size[1] // patch_size 25 | self.num_patches = self.grid_size[0] * self.grid_size[1] 26 | self.patch_size = patch_size 27 | 28 | self.proj = nn.Conv2d( 29 | channels, embed_dim, kernel_size=patch_size, stride=patch_size 30 | ) 31 | 32 | def forward(self, im): 33 | B, C, H, W = im.shape 34 | x = self.proj(im).flatten(2).transpose(1, 2) 35 | return x 36 | 37 | 38 | class VisionTransformer(nn.Module): 39 | def __init__( 40 | self, 41 | image_size, 42 | patch_size, 43 | n_layers, 44 | d_model, 45 | d_ff, 46 | n_heads, 47 | n_cls, 48 | dropout=0.1, 49 | drop_path_rate=0.0, 50 | distilled=False, 51 | channels=3, 52 | ): 53 | super().__init__() 54 | self.patch_embed = PatchEmbedding( 55 | 
image_size, 56 | patch_size, 57 | d_model, 58 | channels, 59 | ) 60 | self.patch_size = patch_size 61 | self.n_layers = n_layers 62 | self.d_model = d_model 63 | self.d_ff = d_ff 64 | self.n_heads = n_heads 65 | self.dropout = nn.Dropout(dropout) 66 | self.n_cls = n_cls 67 | 68 | # cls and pos tokens 69 | self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) 70 | self.distilled = distilled 71 | if self.distilled: 72 | self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model)) 73 | self.pos_embed = nn.Parameter( 74 | torch.randn(1, self.patch_embed.num_patches + 2, d_model) 75 | ) 76 | self.head_dist = nn.Linear(d_model, n_cls) 77 | else: 78 | self.pos_embed = nn.Parameter( 79 | torch.randn(1, self.patch_embed.num_patches + 1, d_model) 80 | ) 81 | 82 | # transformer blocks 83 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)] 84 | self.blocks = nn.ModuleList( 85 | [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)] 86 | ) 87 | 88 | # output head 89 | self.norm = nn.LayerNorm(d_model) 90 | self.head = nn.Linear(d_model, n_cls) 91 | 92 | trunc_normal_(self.pos_embed, std=0.02) 93 | trunc_normal_(self.cls_token, std=0.02) 94 | if self.distilled: 95 | trunc_normal_(self.dist_token, std=0.02) 96 | self.pre_logits = nn.Identity() 97 | 98 | self.apply(init_weights) 99 | 100 | @torch.jit.ignore 101 | def no_weight_decay(self): 102 | return {"pos_embed", "cls_token", "dist_token"} 103 | 104 | @torch.jit.ignore() 105 | def load_pretrained(self, checkpoint_path, prefix=""): 106 | _load_weights(self, checkpoint_path, prefix) 107 | 108 | def forward(self, im, return_features=False): 109 | B, _, H, W = im.shape 110 | PS = self.patch_size 111 | 112 | x = self.patch_embed(im) 113 | cls_tokens = self.cls_token.expand(B, -1, -1) 114 | if self.distilled: 115 | dist_tokens = self.dist_token.expand(B, -1, -1) 116 | x = torch.cat((cls_tokens, dist_tokens, x), dim=1) 117 | else: 118 | x = torch.cat((cls_tokens, x), dim=1) 119 | 120 | pos_embed = self.pos_embed 121 | num_extra_tokens = 1 + self.distilled 122 | if x.shape[1] != pos_embed.shape[1]: 123 | pos_embed = resize_pos_embed( 124 | pos_embed, 125 | self.patch_embed.grid_size, 126 | (H // PS, W // PS), 127 | num_extra_tokens, 128 | ) 129 | x = x + pos_embed 130 | x = self.dropout(x) 131 | 132 | for blk in self.blocks: 133 | x = blk(x) 134 | x = self.norm(x) 135 | 136 | if return_features: 137 | return x 138 | 139 | if self.distilled: 140 | x, x_dist = x[:, 0], x[:, 1] 141 | x = self.head(x) 142 | x_dist = self.head_dist(x_dist) 143 | x = (x + x_dist) / 2 144 | else: 145 | x = x[:, 0] 146 | x = self.head(x) 147 | return x 148 | 149 | def get_attention_map(self, im, layer_id): 150 | if layer_id >= self.n_layers or layer_id < 0: 151 | raise ValueError( 152 | f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}." 
153 | ) 154 | B, _, H, W = im.shape 155 | PS = self.patch_size 156 | 157 | x = self.patch_embed(im) 158 | cls_tokens = self.cls_token.expand(B, -1, -1) 159 | if self.distilled: 160 | dist_tokens = self.dist_token.expand(B, -1, -1) 161 | x = torch.cat((cls_tokens, dist_tokens, x), dim=1) 162 | else: 163 | x = torch.cat((cls_tokens, x), dim=1) 164 | 165 | pos_embed = self.pos_embed 166 | num_extra_tokens = 1 + self.distilled 167 | if x.shape[1] != pos_embed.shape[1]: 168 | pos_embed = resize_pos_embed( 169 | pos_embed, 170 | self.patch_embed.grid_size, 171 | (H // PS, W // PS), 172 | num_extra_tokens, 173 | ) 174 | x = x + pos_embed 175 | 176 | for i, blk in enumerate(self.blocks): 177 | if i < layer_id: 178 | x = blk(x) 179 | else: 180 | return blk(x, return_attention=True) 181 | -------------------------------------------------------------------------------- /segm/optim/factory.py: -------------------------------------------------------------------------------- 1 | from timm import scheduler 2 | from timm import optim 3 | 4 | from segm.optim.scheduler import PolynomialLR 5 | 6 | 7 | def create_scheduler(opt_args, optimizer): 8 | if opt_args.sched == "polynomial": 9 | lr_scheduler = PolynomialLR( 10 | optimizer, 11 | opt_args.poly_step_size, 12 | opt_args.iter_warmup, 13 | opt_args.iter_max, 14 | opt_args.poly_power, 15 | opt_args.min_lr, 16 | ) 17 | else: 18 | lr_scheduler, _ = scheduler.create_scheduler(opt_args, optimizer) 19 | return lr_scheduler 20 | 21 | 22 | def create_optimizer(opt_args, model): 23 | return optim.create_optimizer(opt_args, model) 24 | -------------------------------------------------------------------------------- /segm/optim/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | from timm.scheduler.scheduler import Scheduler 4 | 5 | 6 | class PolynomialLR(_LRScheduler): 7 | def __init__( 8 | self, 9 | optimizer, 10 | step_size, 11 | iter_warmup, 12 | iter_max, 13 | power, 14 | min_lr=0, 15 | last_epoch=-1, 16 | ): 17 | self.step_size = step_size 18 | self.iter_warmup = int(iter_warmup) 19 | self.iter_max = int(iter_max) 20 | self.power = power 21 | self.min_lr = min_lr 22 | super(PolynomialLR, self).__init__(optimizer, last_epoch) 23 | 24 | def polynomial_decay(self, lr): 25 | iter_cur = float(self.last_epoch) 26 | if iter_cur < self.iter_warmup: 27 | coef = iter_cur / self.iter_warmup 28 | coef *= (1 - self.iter_warmup / self.iter_max) ** self.power 29 | else: 30 | coef = (1 - iter_cur / self.iter_max) ** self.power 31 | return (lr - self.min_lr) * coef + self.min_lr 32 | 33 | def get_lr(self): 34 | if ( 35 | (self.last_epoch == 0) 36 | or (self.last_epoch % self.step_size != 0) 37 | or (self.last_epoch > self.iter_max) 38 | ): 39 | return [group["lr"] for group in self.optimizer.param_groups] 40 | return [self.polynomial_decay(lr) for lr in self.base_lrs] 41 | 42 | def step_update(self, num_updates): 43 | self.step() 44 | -------------------------------------------------------------------------------- /segm/scripts/prepare_ade20k.py: -------------------------------------------------------------------------------- 1 | """Prepare ADE20K dataset""" 2 | import click 3 | import zipfile 4 | 5 | from pathlib import Path 6 | from segm.utils.download import download 7 | 8 | 9 | def download_ade(path, overwrite=False): 10 | _AUG_DOWNLOAD_URLS = [ 11 | ( 12 | "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip", 13 | 
"219e1696abb36c8ba3a3afe7fb2f4b4606a897c7", 14 | ), 15 | ( 16 | "http://data.csail.mit.edu/places/ADEchallenge/release_test.zip", 17 | "e05747892219d10e9243933371a497e905a4860c", 18 | ), 19 | ] 20 | download_dir = path / "downloads" 21 | download_dir.mkdir(parents=True, exist_ok=True) 22 | for url, checksum in _AUG_DOWNLOAD_URLS: 23 | filename = download( 24 | url, path=str(download_dir), overwrite=overwrite, sha1_hash=checksum 25 | ) 26 | # extract 27 | with zipfile.ZipFile(filename, "r") as zip_ref: 28 | zip_ref.extractall(path=str(path)) 29 | 30 | 31 | @click.command(help="Initialize ADE20K dataset.") 32 | @click.argument("download_dir", type=str) 33 | def main(download_dir): 34 | dataset_dir = Path(download_dir) / "ade20k" 35 | download_ade(dataset_dir, overwrite=False) 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /segm/scripts/prepare_cityscapes.py: -------------------------------------------------------------------------------- 1 | """Prepare Cityscapes dataset""" 2 | import click 3 | import os 4 | import shutil 5 | import mmcv 6 | import zipfile 7 | 8 | from pathlib import Path 9 | from segm.utils.download import download 10 | 11 | USERNAME = None 12 | PASSWORD = None 13 | 14 | 15 | def download_cityscapes(path, username, password, overwrite=False): 16 | _CITY_DOWNLOAD_URLS = [ 17 | ("gtFine_trainvaltest.zip", "99f532cb1af174f5fcc4c5bc8feea8c66246ddbc"), 18 | ("leftImg8bit_trainvaltest.zip", "2c0b77ce9933cc635adda307fbba5566f5d9d404"), 19 | ] 20 | download_dir = path / "downloads" 21 | download_dir.mkdir(parents=True, exist_ok=True) 22 | 23 | os.system( 24 | f"wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username={username}&password={password}&submit=Login' https://www.cityscapes-dataset.com/login/ -P {download_dir}" 25 | ) 26 | 27 | if not (download_dir / "gtFine_trainvaltest.zip").is_file(): 28 | os.system( 29 | f"wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1 -P {download_dir}" 30 | ) 31 | 32 | if not (download_dir / "leftImg8bit_trainvaltest.zip").is_file(): 33 | os.system( 34 | f"wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3 -P {download_dir}" 35 | ) 36 | 37 | for filename, checksum in _CITY_DOWNLOAD_URLS: 38 | # extract 39 | with zipfile.ZipFile(str(download_dir / filename), "r") as zip_ref: 40 | zip_ref.extractall(path=path) 41 | print("Extracted", filename) 42 | 43 | 44 | def install_cityscapes_api(): 45 | os.system("pip install cityscapesscripts") 46 | try: 47 | import cityscapesscripts 48 | except Exception: 49 | print( 50 | "Installing Cityscapes API failed, please install it manually %s" 51 | % (repo_url) 52 | ) 53 | 54 | 55 | def convert_json_to_label(json_file): 56 | from cityscapesscripts.preparation.json2labelImg import json2labelImg 57 | 58 | label_file = json_file.replace("_polygons.json", "_labelTrainIds.png") 59 | json2labelImg(json_file, label_file, "trainIds") 60 | 61 | 62 | @click.command(help="Initialize Cityscapes dataset.") 63 | @click.argument("download_dir", type=str) 64 | @click.option("--username", default=USERNAME, type=str) 65 | @click.option("--password", default=PASSWORD, type=str) 66 | @click.option("--nproc", default=10, type=int) 67 | def main( 68 | download_dir, 69 | username, 70 | password, 71 | nproc, 72 | ): 73 | 74 | dataset_dir = Path(download_dir) / "cityscapes" 75 | 76 | if username is 
None or password is None: 77 | raise ValueError( 78 | "You must provide your username and password either in the script variables or by passing the options --username and --password." 79 | ) 80 | 81 | download_cityscapes(dataset_dir, username, password, overwrite=False) 82 | 83 | install_cityscapes_api() 84 | 85 | gt_dir = dataset_dir / "gtFine" 86 | 87 | poly_files = [] 88 | for poly in mmcv.scandir(str(gt_dir), "_polygons.json", recursive=True): 89 | poly_file = str(gt_dir / poly) 90 | poly_files.append(poly_file) 91 | mmcv.track_parallel_progress(convert_json_to_label, poly_files, nproc) 92 | 93 | split_names = ["train", "val", "test"] 94 | 95 | for split in split_names: 96 | filenames = [] 97 | for poly in mmcv.scandir(str(gt_dir / split), "_polygons.json", recursive=True): 98 | filenames.append(poly.replace("_gtFine_polygons.json", "")) 99 | with open(str(dataset_dir / f"{split}.txt"), "w") as f: 100 | f.writelines(name + "\n" for name in filenames) 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /segm/scripts/prepare_pcontext.py: -------------------------------------------------------------------------------- 1 | """Prepare PASCAL Context dataset""" 2 | import click 3 | import shutil 4 | import tarfile 5 | import torch 6 | 7 | from tqdm import tqdm 8 | from pathlib import Path 9 | 10 | from segm.utils.download import download 11 | 12 | 13 | def download_pcontext(path, overwrite=False): 14 | _AUG_DOWNLOAD_URLS = [ 15 | ( 16 | "https://www.dropbox.com/s/wtdibo9lb2fur70/VOCtrainval_03-May-2010.tar?dl=1", 17 | "VOCtrainval_03-May-2010.tar", 18 | "bf9985e9f2b064752bf6bd654d89f017c76c395a", 19 | ), 20 | ( 21 | "https://codalabuser.blob.core.windows.net/public/trainval_merged.json", 22 | "", 23 | "169325d9f7e9047537fedca7b04de4dddf10b881", 24 | ), 25 | ( 26 | "https://hangzh.s3.amazonaws.com/encoding/data/pcontext/train.pth", 27 | "", 28 | "4bfb49e8c1cefe352df876c9b5434e655c9c1d07", 29 | ), 30 | ( 31 | "https://hangzh.s3.amazonaws.com/encoding/data/pcontext/val.pth", 32 | "", 33 | "ebedc94247ec616c57b9a2df15091784826a7b0c", 34 | ), 35 | ] 36 | download_dir = path / "downloads" 37 | 38 | download_dir.mkdir(parents=True, exist_ok=True) 39 | 40 | for url, filename, checksum in _AUG_DOWNLOAD_URLS: 41 | filename = download( 42 | url, 43 | path=str(download_dir / filename), 44 | overwrite=overwrite, 45 | sha1_hash=checksum, 46 | ) 47 | # extract 48 | if Path(filename).suffix == ".tar": 49 | with tarfile.open(filename) as tar: 50 | tar.extractall(path=str(path)) 51 | else: 52 | shutil.move( 53 | filename, 54 | str(path / "VOCdevkit" / "VOC2010" / Path(filename).name), 55 | ) 56 | 57 | 58 | @click.command(help="Initialize PASCAL Context dataset.") 59 | @click.argument("download_dir", type=str) 60 | def main(download_dir): 61 | 62 | dataset_dir = Path(download_dir) / "pcontext" 63 | 64 | download_pcontext(dataset_dir, overwrite=False) 65 | 66 | devkit_path = dataset_dir / "VOCdevkit" 67 | out_dir = devkit_path / "VOC2010" / "SegmentationClassContext" 68 | imageset_dir = devkit_path / "VOC2010" / "ImageSets" / "SegmentationContext" 69 | 70 | out_dir.mkdir(parents=True, exist_ok=True) 71 | imageset_dir.mkdir(parents=True, exist_ok=True) 72 | 73 | train_torch_path = devkit_path / "VOC2010" / "train.pth" 74 | val_torch_path = devkit_path / "VOC2010" / "val.pth" 75 | 76 | train_dict = torch.load(str(train_torch_path)) 77 | 78 | train_list = [] 79 | for idx, label in tqdm(train_dict.items()): 80 | idx = str(idx) 81 | 
new_idx = idx[:4] + "_" + idx[4:] 82 | train_list.append(new_idx) 83 | label_path = out_dir / f"{new_idx}.png" 84 | label.save(str(label_path)) 85 | 86 | with open(str(imageset_dir / "train.txt"), "w") as f: 87 | f.writelines(line + "\n" for line in sorted(train_list)) 88 | 89 | val_dict = torch.load(str(val_torch_path)) 90 | 91 | val_list = [] 92 | for idx, label in tqdm(val_dict.items()): 93 | idx = str(idx) 94 | new_idx = idx[:4] + "_" + idx[4:] 95 | val_list.append(new_idx) 96 | label_path = out_dir / f"{new_idx}.png" 97 | label.save(str(label_path)) 98 | 99 | with open(str(imageset_dir / "val.txt"), "w") as f: 100 | f.writelines(line + "\n" for line in sorted(val_list)) 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /segm/scripts/show_attn_map.py: -------------------------------------------------------------------------------- 1 | import click 2 | import einops 3 | import torch 4 | import torchvision 5 | 6 | import matplotlib.pyplot as plt 7 | import segm.utils.torch as ptu 8 | import torch.nn.functional as F 9 | 10 | from pathlib import Path 11 | from PIL import Image 12 | from segm import config 13 | from segm.data.utils import STATS 14 | from segm.model.decoder import MaskTransformer 15 | from segm.model.factory import load_model 16 | from torchvision import transforms 17 | 18 | 19 | @click.command() 20 | @click.argument("model-path", type=str) 21 | @click.argument("image-path", type=str) 22 | @click.argument("output-dir", type=str) 23 | @click.option("--layer-id", default=0, type=int) 24 | @click.option("--x-patch", default=0, type=int) 25 | @click.option("--y-patch", default=0, type=int) 26 | @click.option("--cmap", default="viridis", type=str) 27 | @click.option("--enc/--dec", default=True, is_flag=True) 28 | @click.option("--cls/--patch", default=False, is_flag=True) 29 | def visualize( 30 | model_path, 31 | image_path, 32 | output_dir, 33 | layer_id, 34 | x_patch, 35 | y_patch, 36 | cmap, 37 | enc, 38 | cls, 39 | ): 40 | 41 | output_dir = Path(output_dir) 42 | model_dir = Path(model_path).parent 43 | 44 | ptu.set_gpu_mode(True) 45 | 46 | # Build model 47 | model, variant = load_model(model_path) 48 | for p in model.parameters(): 49 | p.requires_grad = False 50 | 51 | model.eval() 52 | model.to(ptu.device) 53 | 54 | # Get model config 55 | patch_size = model.patch_size 56 | normalization = variant["dataset_kwargs"]["normalization"] 57 | image_size = variant["dataset_kwargs"]["image_size"] 58 | n_cls = variant["net_kwargs"]["n_cls"] 59 | stats = STATS[normalization] 60 | 61 | # Open image and process it 62 | try: 63 | with open(image_path, "rb") as f: 64 | img = Image.open(f) 65 | img = img.convert("RGB") 66 | except: 67 | raise ValueError(f"Provided image path {image_path} is not a valid image file.") 68 | 69 | # Normalize and resize 70 | transform = transforms.Compose( 71 | [ 72 | transforms.Resize(image_size), 73 | transforms.ToTensor(), 74 | transforms.Normalize(stats["mean"], stats["std"]), 75 | ] 76 | ) 77 | 78 | img = transform(img) 79 | 80 | # Make the image divisible by the patch size 81 | w, h = ( 82 | image_size - image_size % patch_size, 83 | image_size - image_size % patch_size, 84 | ) 85 | 86 | # Crop to image size 87 | img = img[:, :w, :h].unsqueeze(0) 88 | 89 | w_featmap = img.shape[-2] // patch_size 90 | h_featmap = img.shape[-1] // patch_size 91 | 92 | # Sanity checks 93 | if not enc and not isinstance(model.decoder, MaskTransformer): 94 | raise ValueError( 95 | 
f"Attention maps for decoder are only available for MaskTransformer. Provided model with decoder type: {model.decoder}." 96 | ) 97 | 98 | if not cls: 99 | if x_patch >= w_featmap or y_patch >= h_featmap: 100 | raise ValueError( 101 | f"Provided patch x: {x_patch} y: {y_patch} is not valid. Patch should be in the range x: [0, {w_featmap}), y: [0, {h_featmap})" 102 | ) 103 | num_patch = w_featmap * y_patch + x_patch 104 | 105 | if layer_id < 0: 106 | raise ValueError("Provided layer_id should be non-negative.") 107 | 108 | if enc and model.encoder.n_layers <= layer_id: 109 | raise ValueError( 110 | f"Provided layer_id: {layer_id} is not valid for encoder with {model.encoder.n_layers} layers." 111 | ) 112 | 113 | if not enc and model.decoder.n_layers <= layer_id: 114 | raise ValueError( 115 | f"Provided layer_id: {layer_id} is not valid for decoder with {model.decoder.n_layers} layers." 116 | ) 117 | 118 | Path.mkdir(output_dir, exist_ok=True) 119 | 120 | # Process input and extract attention maps 121 | if enc: 122 | print(f"Generating Attention Mapping for Encoder Layer Id {layer_id}") 123 | attentions = model.get_attention_map_enc(img.to(ptu.device), layer_id) 124 | num_extra_tokens = 1 + model.encoder.distilled 125 | if cls: 126 | attentions = attentions[0, :, 0, num_extra_tokens:] 127 | else: 128 | attentions = attentions[ 129 | 0, :, num_patch + num_extra_tokens, num_extra_tokens: 130 | ] 131 | else: 132 | print(f"Generating Attention Mapping for Decoder Layer Id {layer_id}") 133 | attentions = model.get_attention_map_dec(img.to(ptu.device), layer_id) 134 | if cls: 135 | attentions = attentions[0, :, -n_cls:, :-n_cls] 136 | else: 137 | attentions = attentions[0, :, num_patch, :-n_cls] 138 | 139 | # Reshape into image shape 140 | nh = attentions.shape[0] # Number of heads 141 | attentions = attentions.reshape(nh, -1) 142 | 143 | if cls and not enc: 144 | attentions = attentions.reshape(nh, n_cls, w_featmap, h_featmap) 145 | else: 146 | attentions = attentions.reshape(nh, 1, w_featmap, h_featmap) 147 | 148 | # Resize attention maps to match input size 149 | attentions = ( 150 | F.interpolate(attentions, scale_factor=patch_size, mode="nearest").cpu().numpy() 151 | ) 152 | 153 | # Save Attention map for each head 154 | for i in range(nh): 155 | base_name = "enc" if enc else "dec" 156 | head_name = f"{base_name}_layer{layer_id}_attn-head{i}" 157 | attention_maps_list = attentions[i] 158 | for j in range(attention_maps_list.shape[0]): 159 | attention_map = attention_maps_list[j] 160 | file_name = head_name 161 | dir_path = output_dir / f"{base_name}_layer{layer_id}" 162 | Path.mkdir(dir_path, exist_ok=True) 163 | if cls: 164 | if enc: 165 | file_name = f"{file_name}_cls" 166 | dir_path /= "cls" 167 | else: 168 | file_name = f"{file_name}_{j}" 169 | dir_path /= f"cls_{j}" 170 | Path.mkdir(dir_path, exist_ok=True) 171 | else: 172 | dir_path /= f"patch_{x_patch}_{y_patch}" 173 | Path.mkdir(dir_path, exist_ok=True) 174 | 175 | file_path = dir_path / f"{file_name}.png" 176 | plt.imsave(fname=str(file_path), arr=attention_map, format="png", cmap=cmap) 177 | print(f"{file_path} saved.") 178 | 179 | # Save input image showing selected patch 180 | if not cls: 181 | im_n = torchvision.utils.make_grid(img, normalize=True, scale_each=True) 182 | 183 | # Compute corresponding X and Y px in the original image 184 | x_px = x_patch * patch_size 185 | y_px = y_patch * patch_size 186 | px_v = einops.repeat( 187 | torch.tensor([1, 0, 0]), 188 | "c -> 1 c h w", 189 | h=patch_size, 190 | w=patch_size, 191 | ) 192 | 193 | # Draw pixels 
for selected patch 194 | im_n[:, y_px : y_px + patch_size, x_px : x_px + patch_size] = px_v 195 | torchvision.utils.save_image( 196 | im_n, 197 | str(dir_path / "input_img.png"), 198 | ) 199 | 200 | 201 | if __name__ == "__main__": 202 | visualize() 203 | -------------------------------------------------------------------------------- /segm/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | import yaml 4 | import json 5 | import numpy as np 6 | import torch 7 | import click 8 | import argparse 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | 11 | from segm.utils import distributed 12 | import segm.utils.torch as ptu 13 | from segm import config 14 | 15 | from segm.model.factory import create_segmenter 16 | from segm.optim.factory import create_optimizer, create_scheduler 17 | from segm.data.factory import create_dataset 18 | from segm.model.utils import num_params 19 | 20 | from timm.utils import NativeScaler 21 | from contextlib import suppress 22 | 23 | from segm.utils.distributed import sync_model 24 | from segm.engine import train_one_epoch, evaluate 25 | 26 | 27 | @click.command(help="") 28 | @click.option("--log-dir", type=str, help="logging directory") 29 | @click.option("--dataset", type=str) 30 | @click.option("--im-size", default=None, type=int, help="dataset resize size") 31 | @click.option("--crop-size", default=None, type=int) 32 | @click.option("--window-size", default=None, type=int) 33 | @click.option("--window-stride", default=None, type=int) 34 | @click.option("--backbone", default="", type=str) 35 | @click.option("--decoder", default="", type=str) 36 | @click.option("--optimizer", default="sgd", type=str) 37 | @click.option("--scheduler", default="polynomial", type=str) 38 | @click.option("--weight-decay", default=0.0, type=float) 39 | @click.option("--dropout", default=0.0, type=float) 40 | @click.option("--drop-path", default=0.1, type=float) 41 | @click.option("--batch-size", default=None, type=int) 42 | @click.option("--epochs", default=None, type=int) 43 | @click.option("-lr", "--learning-rate", default=None, type=float) 44 | @click.option("--normalization", default=None, type=str) 45 | @click.option("--eval-freq", default=None, type=int) 46 | @click.option("--amp/--no-amp", default=False, is_flag=True) 47 | @click.option("--resume/--no-resume", default=True, is_flag=True) 48 | def main( 49 | log_dir, 50 | dataset, 51 | im_size, 52 | crop_size, 53 | window_size, 54 | window_stride, 55 | backbone, 56 | decoder, 57 | optimizer, 58 | scheduler, 59 | weight_decay, 60 | dropout, 61 | drop_path, 62 | batch_size, 63 | epochs, 64 | learning_rate, 65 | normalization, 66 | eval_freq, 67 | amp, 68 | resume, 69 | ): 70 | # start distributed mode 71 | ptu.set_gpu_mode(True) 72 | distributed.init_process() 73 | 74 | # set up configuration 75 | cfg = config.load_config() 76 | model_cfg = cfg["model"][backbone] 77 | dataset_cfg = cfg["dataset"][dataset] 78 | if "mask_transformer" in decoder: 79 | decoder_cfg = cfg["decoder"]["mask_transformer"] 80 | else: 81 | decoder_cfg = cfg["decoder"][decoder] 82 | 83 | # model config 84 | if not im_size: 85 | im_size = dataset_cfg["im_size"] 86 | if not crop_size: 87 | crop_size = dataset_cfg.get("crop_size", im_size) 88 | if not window_size: 89 | window_size = dataset_cfg.get("window_size", im_size) 90 | if not window_stride: 91 | window_stride = dataset_cfg.get("window_stride", im_size) 92 | 93 | model_cfg["image_size"] = (crop_size, 
crop_size) 94 | model_cfg["backbone"] = backbone 95 | model_cfg["dropout"] = dropout 96 | model_cfg["drop_path_rate"] = drop_path 97 | decoder_cfg["name"] = decoder 98 | model_cfg["decoder"] = decoder_cfg 99 | 100 | # dataset config 101 | world_batch_size = dataset_cfg["batch_size"] 102 | num_epochs = dataset_cfg["epochs"] 103 | lr = dataset_cfg["learning_rate"] 104 | if batch_size: 105 | world_batch_size = batch_size 106 | if epochs: 107 | num_epochs = epochs 108 | if learning_rate: 109 | lr = learning_rate 110 | if eval_freq is None: 111 | eval_freq = dataset_cfg.get("eval_freq", 1) 112 | 113 | if normalization: 114 | model_cfg["normalization"] = normalization 115 | 116 | # experiment config 117 | batch_size = world_batch_size // ptu.world_size 118 | variant = dict( 119 | world_batch_size=world_batch_size, 120 | version="normal", 121 | resume=resume, 122 | dataset_kwargs=dict( 123 | dataset=dataset, 124 | image_size=im_size, 125 | crop_size=crop_size, 126 | batch_size=batch_size, 127 | normalization=model_cfg["normalization"], 128 | split="train", 129 | num_workers=10, 130 | ), 131 | algorithm_kwargs=dict( 132 | batch_size=batch_size, 133 | start_epoch=0, 134 | num_epochs=num_epochs, 135 | eval_freq=eval_freq, 136 | ), 137 | optimizer_kwargs=dict( 138 | opt=optimizer, 139 | lr=lr, 140 | weight_decay=weight_decay, 141 | momentum=0.9, 142 | clip_grad=None, 143 | sched=scheduler, 144 | epochs=num_epochs, 145 | min_lr=1e-5, 146 | poly_power=0.9, 147 | poly_step_size=1, 148 | ), 149 | net_kwargs=model_cfg, 150 | amp=amp, 151 | log_dir=log_dir, 152 | inference_kwargs=dict( 153 | im_size=im_size, 154 | window_size=window_size, 155 | window_stride=window_stride, 156 | ), 157 | ) 158 | 159 | log_dir = Path(log_dir) 160 | log_dir.mkdir(parents=True, exist_ok=True) 161 | checkpoint_path = log_dir / "checkpoint.pth" 162 | 163 | # dataset 164 | dataset_kwargs = variant["dataset_kwargs"] 165 | 166 | train_loader = create_dataset(dataset_kwargs) 167 | val_kwargs = dataset_kwargs.copy() 168 | val_kwargs["split"] = "val" 169 | val_kwargs["batch_size"] = 1 170 | val_kwargs["crop"] = False 171 | val_loader = create_dataset(val_kwargs) 172 | n_cls = train_loader.unwrapped.n_cls 173 | 174 | # model 175 | net_kwargs = variant["net_kwargs"] 176 | net_kwargs["n_cls"] = n_cls 177 | model = create_segmenter(net_kwargs) 178 | model.to(ptu.device) 179 | 180 | # optimizer 181 | optimizer_kwargs = variant["optimizer_kwargs"] 182 | optimizer_kwargs["iter_max"] = len(train_loader) * optimizer_kwargs["epochs"] 183 | optimizer_kwargs["iter_warmup"] = 0.0 184 | opt_args = argparse.Namespace() 185 | opt_vars = vars(opt_args) 186 | for k, v in optimizer_kwargs.items(): 187 | opt_vars[k] = v 188 | optimizer = create_optimizer(opt_args, model) 189 | lr_scheduler = create_scheduler(opt_args, optimizer) 190 | num_iterations = 0 191 | amp_autocast = suppress 192 | loss_scaler = None 193 | if amp: 194 | amp_autocast = torch.cuda.amp.autocast 195 | loss_scaler = NativeScaler() 196 | 197 | # resume 198 | if resume and checkpoint_path.exists(): 199 | print(f"Resuming training from checkpoint: {checkpoint_path}") 200 | checkpoint = torch.load(checkpoint_path, map_location="cpu") 201 | model.load_state_dict(checkpoint["model"]) 202 | optimizer.load_state_dict(checkpoint["optimizer"]) 203 | if loss_scaler and "loss_scaler" in checkpoint: 204 | loss_scaler.load_state_dict(checkpoint["loss_scaler"]) 205 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) 206 | variant["algorithm_kwargs"]["start_epoch"] = checkpoint["epoch"] + 1 
207 | else: 208 | sync_model(log_dir, model) 209 | 210 | if ptu.distributed: 211 | model = DDP(model, device_ids=[ptu.device], find_unused_parameters=True) 212 | 213 | # save config 214 | variant_str = yaml.dump(variant) 215 | print(f"Configuration:\n{variant_str}") 216 | variant["net_kwargs"] = net_kwargs 217 | variant["dataset_kwargs"] = dataset_kwargs 218 | log_dir.mkdir(parents=True, exist_ok=True) 219 | with open(log_dir / "variant.yml", "w") as f: 220 | f.write(variant_str) 221 | 222 | # train 223 | start_epoch = variant["algorithm_kwargs"]["start_epoch"] 224 | num_epochs = variant["algorithm_kwargs"]["num_epochs"] 225 | eval_freq = variant["algorithm_kwargs"]["eval_freq"] 226 | 227 | model_without_ddp = model 228 | if hasattr(model, "module"): 229 | model_without_ddp = model.module 230 | 231 | val_seg_gt = val_loader.dataset.get_gt_seg_maps() 232 | 233 | print(f"Train dataset length: {len(train_loader.dataset)}") 234 | print(f"Val dataset length: {len(val_loader.dataset)}") 235 | print(f"Encoder parameters: {num_params(model_without_ddp.encoder)}") 236 | print(f"Decoder parameters: {num_params(model_without_ddp.decoder)}") 237 | 238 | for epoch in range(start_epoch, num_epochs): 239 | # train for one epoch 240 | train_logger = train_one_epoch( 241 | model, 242 | train_loader, 243 | optimizer, 244 | lr_scheduler, 245 | epoch, 246 | amp_autocast, 247 | loss_scaler, 248 | ) 249 | 250 | # save checkpoint 251 | if ptu.dist_rank == 0: 252 | snapshot = dict( 253 | model=model_without_ddp.state_dict(), 254 | optimizer=optimizer.state_dict(), 255 | n_cls=model_without_ddp.n_cls, 256 | lr_scheduler=lr_scheduler.state_dict(), 257 | ) 258 | if loss_scaler is not None: 259 | snapshot["loss_scaler"] = loss_scaler.state_dict() 260 | snapshot["epoch"] = epoch 261 | torch.save(snapshot, checkpoint_path) 262 | 263 | # evaluate 264 | eval_epoch = epoch % eval_freq == 0 or epoch == num_epochs - 1 265 | if eval_epoch: 266 | eval_logger = evaluate( 267 | model, 268 | val_loader, 269 | val_seg_gt, 270 | window_size, 271 | window_stride, 272 | amp_autocast, 273 | ) 274 | print(f"Stats [{epoch}]:", eval_logger, flush=True) 275 | print("") 276 | 277 | # log stats 278 | if ptu.dist_rank == 0: 279 | train_stats = { 280 | k: meter.global_avg for k, meter in train_logger.meters.items() 281 | } 282 | val_stats = {} 283 | if eval_epoch: 284 | val_stats = { 285 | k: meter.global_avg for k, meter in eval_logger.meters.items() 286 | } 287 | 288 | log_stats = { 289 | **{f"train_{k}": v for k, v in train_stats.items()}, 290 | **{f"val_{k}": v for k, v in val_stats.items()}, 291 | "epoch": epoch, 292 | "num_updates": (epoch + 1) * len(train_loader), 293 | } 294 | 295 | with open(log_dir / "log.txt", "a") as f: 296 | f.write(json.dumps(log_stats) + "\n") 297 | 298 | distributed.barrier() 299 | distributed.destroy_process() 300 | sys.exit(1) 301 | 302 | 303 | if __name__ == "__main__": 304 | main() 305 | -------------------------------------------------------------------------------- /segm/utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hostlist 3 | from pathlib import Path 4 | import torch 5 | import torch.distributed as dist 6 | 7 | import segm.utils.torch as ptu 8 | 9 | 10 | def init_process(backend="nccl"): 11 | print(f"Starting process with rank {ptu.dist_rank}...", flush=True) 12 | 13 | if "SLURM_STEP_GPUS" in os.environ: 14 | gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",") 15 | os.environ["MASTER_PORT"] = str(12345 + 
int(min(gpu_ids))) 16 | else: 17 | os.environ["MASTER_PORT"] = str(12345) 18 | 19 | if "SLURM_JOB_NODELIST" in os.environ: 20 | hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"]) 21 | os.environ["MASTER_ADDR"] = hostnames[0] 22 | else: 23 | os.environ["MASTER_ADDR"] = "127.0.0.1" 24 | 25 | dist.init_process_group( 26 | backend, 27 | rank=ptu.dist_rank, 28 | world_size=ptu.world_size, 29 | ) 30 | print(f"Process {ptu.dist_rank} is connected.", flush=True) 31 | dist.barrier() 32 | 33 | silence_print(ptu.dist_rank == 0) 34 | if ptu.dist_rank == 0: 35 | print(f"All processes are connected.", flush=True) 36 | 37 | 38 | def silence_print(is_master): 39 | """ 40 | This function disables printing when not in master process 41 | """ 42 | import builtins as __builtin__ 43 | 44 | builtin_print = __builtin__.print 45 | 46 | def print(*args, **kwargs): 47 | force = kwargs.pop("force", False) 48 | if is_master or force: 49 | builtin_print(*args, **kwargs) 50 | 51 | __builtin__.print = print 52 | 53 | 54 | def sync_model(sync_dir, model): 55 | # https://github.com/ylabbe/cosypose/blob/master/cosypose/utils/distributed.py 56 | sync_path = Path(sync_dir).resolve() / "sync_model.pkl" 57 | if ptu.dist_rank == 0 and ptu.world_size > 1: 58 | torch.save(model.state_dict(), sync_path) 59 | dist.barrier() 60 | if ptu.dist_rank > 0: 61 | model.load_state_dict(torch.load(sync_path)) 62 | dist.barrier() 63 | if ptu.dist_rank == 0 and ptu.world_size > 1: 64 | sync_path.unlink() 65 | return model 66 | 67 | 68 | def barrier(): 69 | dist.barrier() 70 | 71 | 72 | def destroy_process(): 73 | dist.destroy_process_group() 74 | -------------------------------------------------------------------------------- /segm/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import hashlib 4 | from tqdm import tqdm 5 | 6 | 7 | def check_sha1(filename, sha1_hash): 8 | """Check whether the sha1 hash of the file content matches the expected hash. 9 | Parameters 10 | ---------- 11 | filename : str 12 | Path to the file. 13 | sha1_hash : str 14 | Expected sha1 hash in hexadecimal digits. 15 | Returns 16 | ------- 17 | bool 18 | Whether the file content matches the expected hash. 19 | """ 20 | sha1 = hashlib.sha1() 21 | with open(filename, "rb") as f: 22 | while True: 23 | data = f.read(1048576) 24 | if not data: 25 | break 26 | sha1.update(data) 27 | 28 | return sha1.hexdigest() == sha1_hash 29 | 30 | 31 | def download(url, path=None, overwrite=False, sha1_hash=None): 32 | """ 33 | https://github.com/junfu1115/DANet/blob/master/encoding/utils/files.py 34 | Download a given URL 35 | Parameters 36 | ---------- 37 | url : str 38 | URL to download 39 | path : str, optional 40 | Destination path to store downloaded file. By default stores to the 41 | current directory with same name as in url. 42 | overwrite : bool, optional 43 | Whether to overwrite destination file if already exists. 44 | sha1_hash : str, optional 45 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified 46 | but doesn't match. 47 | Returns 48 | ------- 49 | str 50 | The file path of the downloaded file. 
51 | """ 52 | if path is None: 53 | fname = url.split("/")[-1] 54 | else: 55 | path = os.path.expanduser(path) 56 | if os.path.isdir(path): 57 | fname = os.path.join(path, url.split("/")[-1]) 58 | else: 59 | fname = path 60 | 61 | if ( 62 | overwrite 63 | or not os.path.exists(fname) 64 | or (sha1_hash and not check_sha1(fname, sha1_hash)) 65 | ): 66 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 67 | if not os.path.exists(dirname): 68 | os.makedirs(dirname) 69 | 70 | print("Downloading %s from %s..." % (fname, url)) 71 | r = requests.get(url, stream=True) 72 | if r.status_code != 200: 73 | raise RuntimeError("Failed downloading url %s" % url) 74 | total_length = r.headers.get("content-length") 75 | with open(fname, "wb") as f: 76 | if total_length is None: # no content length header 77 | for chunk in r.iter_content(chunk_size=1024): 78 | if chunk: # filter out keep-alive new chunks 79 | f.write(chunk) 80 | else: 81 | total_length = int(total_length) 82 | for chunk in tqdm( 83 | r.iter_content(chunk_size=1024), 84 | total=int(total_length / 1024.0 + 0.5), 85 | unit="KB", 86 | unit_scale=False, 87 | dynamic_ncols=True, 88 | ): 89 | f.write(chunk) 90 | 91 | if sha1_hash and not check_sha1(fname, sha1_hash): 92 | raise UserWarning( 93 | "File {} is downloaded but the content hash does not match. " 94 | "The repo may be outdated or download may be incomplete. " 95 | 'If the "repo_url" is overridden, consider switching to ' 96 | "the default repo.".format(fname) 97 | ) 98 | 99 | return fname 100 | -------------------------------------------------------------------------------- /segm/utils/lines.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import cycle 3 | 4 | 5 | class Lines: 6 | def __init__(self, resolution=20, smooth=None): 7 | self.COLORS = cycle( 8 | [ 9 | "#377eb8", 10 | "#e41a1c", 11 | "#4daf4a", 12 | "#984ea3", 13 | "#ff7f00", 14 | "#ffff33", 15 | "#a65628", 16 | "#f781bf", 17 | ] 18 | ) 19 | self.MARKERS = cycle("os^Dp>d<") 20 | self.LEGEND = dict(fontsize="medium", labelspacing=0, numpoints=1) 21 | self._resolution = resolution 22 | self._smooth_weight = smooth 23 | 24 | def __call__(self, ax, domains, lines, labels): 25 | assert len(domains) == len(lines) == len(labels) 26 | colors = [] 27 | for index, (label, color, marker) in enumerate( 28 | zip(labels, self.COLORS, self.MARKERS) 29 | ): 30 | domain, line = domains[index], lines[index] 31 | line = self.smooth(line, self._smooth_weight) 32 | ax.plot(domain, line[:, 0], color=color, label=label) 33 | 34 | last_x, last_y = domain[-1], line[-1, 0] 35 | ax.scatter(last_x, last_y, color=color, marker="x") 36 | ax.annotate( 37 | f"{last_y:.2f}", 38 | xy=(last_x, last_y), 39 | xytext=(last_x, last_y + 0.1), 40 | ) 41 | colors.append(color) 42 | 43 | self._plot_legend(ax, lines, labels) 44 | return colors 45 | 46 | def _plot_legend(self, ax, lines, labels): 47 | scores = {label: -np.nanmedian(line) for label, line in zip(labels, lines)} 48 | handles, labels = ax.get_legend_handles_labels() 49 | # handles, labels = zip(*sorted( 50 | # zip(handles, labels), key=lambda x: scores[x[1]])) 51 | legend = ax.legend(handles, labels, **self.LEGEND) 52 | legend.get_frame().set_edgecolor("white") 53 | for line in legend.get_lines(): 54 | line.set_alpha(1) 55 | 56 | def smooth(self, scalars, weight): 57 | """ 58 | weight in [0, 1] 59 | exponential moving average, same as tensorboard 60 | """ 61 | assert weight >= 0 and weight <= 1 62 | last = 
scalars[0] 63 | smoothed = np.asarray(scalars) 64 | for i, point in enumerate(scalars): 65 | last = last * weight + (1 - weight) * point 66 | smoothed[i] = last 67 | 68 | return smoothed 69 | -------------------------------------------------------------------------------- /segm/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/facebookresearch/deit/blob/main/utils.py 3 | """ 4 | 5 | import io 6 | import os 7 | import time 8 | from collections import defaultdict, deque 9 | import datetime 10 | 11 | import torch 12 | import torch.distributed as dist 13 | 14 | import segm.utils.torch as ptu 15 | 16 | 17 | class SmoothedValue(object): 18 | """Track a series of values and provide access to smoothed values over a 19 | window or the global series average. 20 | """ 21 | 22 | def __init__(self, window_size=20, fmt=None): 23 | if fmt is None: 24 | fmt = "{median:.4f} ({global_avg:.4f})" 25 | self.deque = deque(maxlen=window_size) 26 | self.total = 0.0 27 | self.count = 0 28 | self.fmt = fmt 29 | 30 | def update(self, value, n=1): 31 | self.deque.append(value) 32 | self.count += n 33 | self.total += value * n 34 | 35 | def synchronize_between_processes(self): 36 | """ 37 | Warning: does not synchronize the deque! 38 | """ 39 | if not is_dist_avail_and_initialized(): 40 | return 41 | t = torch.tensor( 42 | [self.count, self.total], dtype=torch.float64, device=ptu.device 43 | ) 44 | dist.barrier() 45 | dist.all_reduce(t) 46 | t = t.tolist() 47 | self.count = int(t[0]) 48 | self.total = t[1] 49 | 50 | @property 51 | def median(self): 52 | d = torch.tensor(list(self.deque)) 53 | return d.median().item() 54 | 55 | @property 56 | def avg(self): 57 | d = torch.tensor(list(self.deque), dtype=torch.float32) 58 | return d.mean().item() 59 | 60 | @property 61 | def global_avg(self): 62 | return self.total / self.count 63 | 64 | @property 65 | def max(self): 66 | return max(self.deque) 67 | 68 | @property 69 | def value(self): 70 | return self.deque[-1] 71 | 72 | def __str__(self): 73 | return self.fmt.format( 74 | median=self.median, 75 | avg=self.avg, 76 | global_avg=self.global_avg, 77 | max=self.max, 78 | value=self.value, 79 | ) 80 | 81 | 82 | class MetricLogger(object): 83 | def __init__(self, delimiter="\t"): 84 | self.meters = defaultdict(SmoothedValue) 85 | self.delimiter = delimiter 86 | 87 | def update(self, n=1, **kwargs): 88 | for k, v in kwargs.items(): 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v, n) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError( 100 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 101 | ) 102 | 103 | def __str__(self): 104 | loss_str = [] 105 | for name, meter in self.meters.items(): 106 | loss_str.append("{}: {}".format(name, str(meter))) 107 | return self.delimiter.join(loss_str) 108 | 109 | def synchronize_between_processes(self): 110 | for meter in self.meters.values(): 111 | meter.synchronize_between_processes() 112 | 113 | def add_meter(self, name, meter): 114 | self.meters[name] = meter 115 | 116 | def log_every(self, iterable, print_freq, header=None): 117 | i = 0 118 | if not header: 119 | header = "" 120 | start_time = time.time() 121 | end = time.time() 122 | iter_time = SmoothedValue(fmt="{avg:.4f}") 123 | data_time = SmoothedValue(fmt="{avg:.4f}") 124 
| space_fmt = ":" + str(len(str(len(iterable)))) + "d" 125 | log_msg = [ 126 | header, 127 | "[{0" + space_fmt + "}/{1}]", 128 | "eta: {eta}", 129 | "{meters}", 130 | "time: {time}", 131 | "data: {data}", 132 | ] 133 | if torch.cuda.is_available(): 134 | log_msg.append("max mem: {memory:.0f}") 135 | log_msg = self.delimiter.join(log_msg) 136 | MB = 1024.0 * 1024.0 137 | for obj in iterable: 138 | data_time.update(time.time() - end) 139 | yield obj 140 | iter_time.update(time.time() - end) 141 | if i % print_freq == 0 or i == len(iterable) - 1: 142 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 143 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 144 | if torch.cuda.is_available(): 145 | print( 146 | log_msg.format( 147 | i, 148 | len(iterable), 149 | eta=eta_string, 150 | meters=str(self), 151 | time=str(iter_time), 152 | data=str(data_time), 153 | memory=torch.cuda.max_memory_allocated() / MB, 154 | ), 155 | flush=True, 156 | ) 157 | else: 158 | print( 159 | log_msg.format( 160 | i, 161 | len(iterable), 162 | eta=eta_string, 163 | meters=str(self), 164 | time=str(iter_time), 165 | data=str(data_time), 166 | ), 167 | flush=True, 168 | ) 169 | i += 1 170 | end = time.time() 171 | total_time = time.time() - start_time 172 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 173 | print( 174 | "{} Total time: {} ({:.4f} s / it)".format( 175 | header, total_time_str, total_time / len(iterable) 176 | ) 177 | ) 178 | 179 | 180 | def is_dist_avail_and_initialized(): 181 | if not dist.is_available(): 182 | return False 183 | if not dist.is_initialized(): 184 | return False 185 | return True 186 | -------------------------------------------------------------------------------- /segm/utils/logs.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | import numpy as np 4 | import yaml 5 | import matplotlib.pyplot as plt 6 | import click 7 | from collections import OrderedDict 8 | 9 | from segm.utils.lines import Lines 10 | 11 | 12 | def plot_logs(logs, x_key, y_key, size, vmin, vmax, epochs): 13 | m = np.inf 14 | M = -np.inf 15 | domains = [] 16 | lines = [] 17 | y_keys = y_key.split("/") 18 | for name, log in logs.items(): 19 | logs[name] = log[:epochs] 20 | for name, log in logs.items(): 21 | domain = [x[x_key] for x in log if y_keys[0] in x] 22 | if y_keys[0] not in log[0]: 23 | continue 24 | log_plot = [x[y_keys[0]] for x in log if y_keys[0] in x] 25 | for y_key in y_keys[1:]: 26 | if y_key in log_plot[0]: 27 | log_plot = [x[y_key] for x in log_plot if y_key in x] 28 | domains.append(domain) 29 | lines.append(np.array(log_plot)[:, None]) 30 | m = np.min((m, min(log_plot))) 31 | M = np.max((M, max(log_plot))) 32 | if vmin is not None: 33 | m = vmin 34 | if vmax is not None: 35 | M = vmax 36 | delta = 0.1 * (M - m) 37 | 38 | ratio = 0.6 39 | figsizes = {"tight": (4, 3), "large": (16 * ratio, 10 * ratio)} 40 | figsize = figsizes[size] 41 | 42 | # plot parameters 43 | fig, ax = plt.subplots(figsize=figsize) 44 | ax.set_xlabel(x_key) 45 | ax.set_ylabel(y_key) 46 | plot_lines = Lines(resolution=50, smooth=0.0) 47 | plot_lines.LEGEND["loc"] = "upper left" 48 | # plot_lines.LEGEND["fontsize"] = "large" 49 | plot_lines.LEGEND["bbox_to_anchor"] = (0.75, 0.2) 50 | labels_logs = list(logs.keys()) 51 | colors = plot_lines(ax, domains, lines, labels_logs) 52 | ax.grid(True, alpha=0.5) 53 | ax.set_ylim(m - delta, M + delta) 54 | 55 | plt.show() 56 | fig.savefig( 57 | "plot.png", 
bbox_inches="tight", pad_inches=0.1, transparent=False, dpi=300 58 | ) 59 | plt.close(fig) 60 | 61 | 62 | def print_logs(logs, x_key, y_key, last_log_idx=None): 63 | delim = " " 64 | s = "" 65 | keys = [] 66 | y_keys = y_key.split("/") 67 | for name, log in logs.items(): 68 | log_idx = last_log_idx 69 | if log_idx is None: 70 | log_idx = len(log) - 1 71 | while y_keys[0] not in log[log_idx]: 72 | log_idx -= 1 73 | last_log = log[log_idx] 74 | log_x = last_log[x_key] 75 | log_y = last_log[y_keys[0]] 76 | for y_key in y_keys[1:]: 77 | log_y = log_y[y_key] 78 | s += f"{name}:\n" 79 | # s += f"{delim}{x_key}: {log_x}\n" 80 | s += f"{delim}{y_key}: {log_y:.4f}\n" 81 | keys += list(last_log.keys()) 82 | keys = list(set(keys)) 83 | keys = ", ".join(keys) 84 | s = f"keys: {keys}\n" + s 85 | print(s) 86 | 87 | 88 | def read_logs(root, logs_path): 89 | logs = {} 90 | for name, path in logs_path.items(): 91 | path = root / path 92 | if not path.exists(): 93 | print(f"Skipping {name} that has no log file") 94 | continue 95 | logs[name] = [] 96 | with open(path, "r") as f: 97 | for line in f.readlines(): 98 | d = json.loads(line) 99 | logs[name].append(d) 100 | return logs 101 | 102 | 103 | @click.command() 104 | @click.argument("log_path", type=str) 105 | @click.option("--x-key", default="epoch", type=str) 106 | @click.option("--y-key", default="val_mean_iou", type=str) 107 | @click.option("-s", "--size", default="large", type=str) 108 | @click.option("-ep", "--epoch", default=-1, type=int) 109 | @click.option("-plot", "--plot/--no-plot", default=True, is_flag=True) 110 | def main(log_path, x_key, y_key, size, epoch, plot): 111 | abs_path = Path(__file__).parent / log_path 112 | if abs_path.exists(): 113 | log_path = abs_path 114 | config = yaml.load(open(log_path, "r"), Loader=yaml.FullLoader) 115 | root = Path(config["root"]) 116 | logs_path = OrderedDict(config["logs"]) 117 | vmin = config.get("vmin", None) 118 | vmax = config.get("vmax", None) 119 | epochs = config.get("epochs", None) 120 | 121 | logs = read_logs(root, logs_path) 122 | if not logs: 123 | return 124 | print_logs(logs, x_key, y_key, epoch) 125 | if plot: 126 | plot_logs(logs, x_key, y_key, size, vmin, vmax, epochs) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /segm/utils/torch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | """ 6 | GPU wrappers 7 | """ 8 | 9 | use_gpu = False 10 | gpu_id = 0 11 | device = None 12 | 13 | distributed = False 14 | dist_rank = 0 15 | world_size = 1 16 | 17 | 18 | def set_gpu_mode(mode): 19 | global use_gpu 20 | global device 21 | global gpu_id 22 | global distributed 23 | global dist_rank 24 | global world_size 25 | gpu_id = int(os.environ.get("SLURM_LOCALID", 0)) 26 | dist_rank = int(os.environ.get("SLURM_PROCID", 0)) 27 | world_size = int(os.environ.get("SLURM_NTASKS", 1)) 28 | 29 | distributed = world_size > 1 30 | use_gpu = mode 31 | device = torch.device(f"cuda:{gpu_id}" if use_gpu else "cpu") 32 | torch.backends.cudnn.benchmark = True 33 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | from setuptools import setup, find_packages 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | 10 | def 
read_requirements_file(filename): 11 | req_file_path = path.join(path.dirname(path.realpath(__file__)), filename) 12 | with open(req_file_path) as f: 13 | return [line.strip() for line in f] 14 | 15 | 16 | setup( 17 | name="segm", 18 | version="0.0.1", 19 | description="Segmenter: Transformer for Semantic Segmentation", 20 | packages=find_packages(), 21 | install_requires=read_requirements_file("requirements.txt"), 22 | ) 23 | --------------------------------------------------------------------------------
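
For reference, below is a minimal, hypothetical usage sketch of the `download` and `check_sha1` helpers from `segm/utils/download.py` listed above; the URL and SHA-1 digest are placeholders, not real project artifacts. It follows the same pattern used by `segm/scripts/prepare_ade20k.py` and `segm/scripts/prepare_pcontext.py` when fetching their archives.

```python
# Illustrative sketch only: the URL and sha1 digest are placeholders.
# download() skips the fetch when the target file already exists and matches
# the expected hash; otherwise it streams the file into `path` and verifies
# the checksum afterwards.
from pathlib import Path

from segm.utils.download import check_sha1, download

if __name__ == "__main__":
    download_dir = Path("downloads")
    download_dir.mkdir(parents=True, exist_ok=True)

    archive = download(
        "https://example.com/dataset_archive.zip",  # placeholder URL
        path=str(download_dir),
        overwrite=False,
        sha1_hash="0000000000000000000000000000000000000000",  # placeholder
    )
    print(archive, "checksum ok:", check_sha1(archive, "0" * 40))
```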