├── .gitignore
├── LICENSE
├── README.md
├── attn_maps_dec.png
├── attn_maps_enc.png
├── gifs
│   ├── breakdance-flare.gif
│   ├── car-competition.gif
│   ├── choreography.gif
│   └── city-ride.gif
├── overview.png
├── requirements.txt
├── segm
│   ├── config.py
│   ├── config.yml
│   ├── data
│   │   ├── __init__.py
│   │   ├── ade20k.py
│   │   ├── base.py
│   │   ├── cityscapes.py
│   │   ├── config
│   │   │   ├── ade20k.py
│   │   │   ├── ade20k.yml
│   │   │   ├── cityscapes.py
│   │   │   ├── cityscapes.yml
│   │   │   ├── pascal_context.py
│   │   │   └── pascal_context.yml
│   │   ├── factory.py
│   │   ├── imagenet.py
│   │   ├── loader.py
│   │   ├── pascal_context.py
│   │   └── utils.py
│   ├── engine.py
│   ├── eval
│   │   ├── accuracy.py
│   │   └── miou.py
│   ├── inference.py
│   ├── metrics.py
│   ├── model
│   │   ├── blocks.py
│   │   ├── decoder.py
│   │   ├── factory.py
│   │   ├── segmenter.py
│   │   ├── utils.py
│   │   └── vit.py
│   ├── optim
│   │   ├── factory.py
│   │   └── scheduler.py
│   ├── scripts
│   │   ├── prepare_ade20k.py
│   │   ├── prepare_cityscapes.py
│   │   ├── prepare_pcontext.py
│   │   └── show_attn_map.py
│   ├── train.py
│   └── utils
│       ├── distributed.py
│       ├── download.py
│       ├── lines.py
│       ├── logger.py
│       ├── logs.py
│       └── torch.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/python
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python
4 |
5 | ### Python ###
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 | pytestdebug.log
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 | doc/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .venv
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 | pythonenv*
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 | .dmypy.json
134 | dmypy.json
135 |
136 | # Pyre type checker
137 | .pyre/
138 |
139 | # pytype static type analyzer
140 | .pytype/
141 |
142 | # profiling data
143 | .prof
144 |
145 | # End of https://www.toptal.com/developers/gitignore/api/python
146 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Robin Strudel
4 | Copyright (c) INRIA
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Segmenter: Transformer for Semantic Segmentation
2 |
 3 | ![Segmenter overview](./overview.png)
4 |
5 | [Segmenter: Transformer for Semantic Segmentation](https://arxiv.org/abs/2105.05633)
6 | by Robin Strudel*, Ricardo Garcia*, Ivan Laptev and Cordelia Schmid, ICCV 2021.
7 |
8 | *Equal Contribution
9 |
10 | 🔥 **Segmenter is now available on [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/segmenter).**
11 |
12 | ## Installation
13 |
 14 | Define an environment variable pointing to your dataset directory and add it to your `.bashrc`:
15 | ```sh
16 | export DATASET=/path/to/dataset/dir
17 | ```
18 |
 19 | Install [PyTorch 1.9](https://pytorch.org/), then run `pip install .` at the root of this repository.
20 |
21 | To download ADE20K, use the following command:
 22 | ```sh
23 | python -m segm.scripts.prepare_ade20k $DATASET
24 | ```
25 |
26 | ## Model Zoo
27 | We release models with a Vision Transformer backbone initialized from the [improved ViT](https://arxiv.org/abs/2106.10270) models.
28 |
29 | ### ADE20K
30 |
31 | Segmenter models with ViT backbone:
32 |
 33 | | Name          | mIoU (SS/MS) | # params | Resolution | FPS  | Download             |
 34 | | ------------- | ------------ | -------- | ---------- | ---- | -------------------- |
 35 | | Seg-T-Mask/16 | 38.1 / 38.8  | 7M       | 512x512    | 52.4 | model / config / log |
 36 | | Seg-S-Mask/16 | 45.3 / 46.9  | 27M      | 512x512    | 34.8 | model / config / log |
 37 | | Seg-B-Mask/16 | 48.5 / 50.0  | 106M     | 512x512    | 24.1 | model / config / log |
 38 | | Seg-B/8       | 49.5 / 50.5  | 89M      | 512x512    | 4.2  | model / config / log |
 39 | | Seg-L-Mask/16 | 51.8 / 53.6  | 334M     | 640x640    | -    | model / config / log |
92 |
93 | Segmenter models with DeiT backbone:
94 |
 95 | | Name           | mIoU (SS/MS) | # params | Resolution | FPS  | Download             |
 96 | | -------------- | ------------ | -------- | ---------- | ---- | -------------------- |
 97 | | Seg-B†/16      | 47.1 / 48.1  | 87M      | 512x512    | 27.3 | model / config / log |
 98 | | Seg-B†-Mask/16 | 48.7 / 50.1  | 106M     | 512x512    | 24.1 | model / config / log |
125 |
126 | ### Pascal Context
127 |
 128 | | Name          | mIoU (SS/MS) | # params | Resolution | FPS | Download             |
 129 | | ------------- | ------------ | -------- | ---------- | --- | -------------------- |
 130 | | Seg-L-Mask/16 | 58.1 / 59.0  | 334M     | 480x480    | -   | model / config / log |
147 |
148 | ### Cityscapes
149 |
 150 | | Name          | mIoU (SS/MS) | # params | Resolution | FPS | Download             |
 151 | | ------------- | ------------ | -------- | ---------- | --- | -------------------- |
 152 | | Seg-L-Mask/16 | 79.1 / 81.3  | 322M     | 768x768    | -   | model / config / log |
169 |
170 | ## Inference
171 |
 172 | Download a checkpoint together with its configuration into a common folder, for example `seg_tiny_mask`.
173 |
174 | You can generate segmentation maps from your own data with:
 175 | ```sh
176 | python -m segm.inference --model-path seg_tiny_mask/checkpoint.pth -i images/ -o segmaps/
177 | ```
178 |
179 | To evaluate on ADE20K, run the command:
 180 | ```sh
181 | # single-scale evaluation:
182 | python -m segm.eval.miou seg_tiny_mask/checkpoint.pth ade20k --singlescale
183 | # multi-scale evaluation:
184 | python -m segm.eval.miou seg_tiny_mask/checkpoint.pth ade20k --multiscale
185 | ```
186 |
187 | ## Train
188 |
189 | Train `Seg-T-Mask/16` on ADE20K on a single GPU:
 190 | ```sh
191 | python -m segm.train --log-dir seg_tiny_mask --dataset ade20k \
192 | --backbone vit_tiny_patch16_384 --decoder mask_transformer
193 | ```
194 |
 195 | To train `Seg-B-Mask/16`, set `vit_base_patch16_384` as the backbone and launch the above command on a minimum of 4 V100 GPUs (~12 minutes per epoch) and up to 8 V100 GPUs (~7 minutes per epoch). The code reads [SLURM](https://slurm.schedmd.com/documentation.html) environment variables to set up distributed training.
196 |
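As a concrete starting point for the multi-GPU setting, here is a minimal SLURM batch-script sketch. The `#SBATCH` resources, job name and environment activation are placeholders to adapt to your cluster, and whether you launch through `srun` or another site-specific wrapper may differ; only the final training command comes from this README.

```sh
#!/bin/bash
#SBATCH --job-name=seg_base_mask
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=4   # one task per GPU (placeholder)
#SBATCH --gres=gpu:4          # e.g. 4 V100s, as suggested above
#SBATCH --cpus-per-task=10
#SBATCH --time=48:00:00

# activate your Python environment (placeholder name)
source activate segmenter

srun python -m segm.train --log-dir seg_base_mask --dataset ade20k \
  --backbone vit_base_patch16_384 --decoder mask_transformer
```
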
197 | ## Logs
198 |
199 | To plot the logs of your experiments, you can use
 200 | ```sh
201 | python -m segm.utils.logs logs.yml
202 | ```
203 |
 204 | with `logs.yml` located in `utils/` and containing the paths to your experiment logs:
205 | ```yaml
206 | root: /path/to/checkpoints/
207 | logs:
 208 |   seg-t: seg_tiny_mask/log.txt
 209 |   seg-b: seg_base_mask/log.txt
210 | ```
211 |
212 | ## Attention Maps
213 |
214 | To visualize the attention maps for `Seg-T-Mask/16` encoder layer 0 and patch `(0, 21)`, you can use:
215 |
 216 | ```sh
217 | python -m segm.scripts.show_attn_map seg_tiny_mask/checkpoint.pth \
218 | images/im0.jpg output_dir/ --layer-id 0 --x-patch 0 --y-patch 21 --enc
219 | ```
220 |
221 | Different options are provided to select the generated attention maps:
222 | * `--enc` or `--dec`: Select encoder or decoder attention maps respectively.
223 | * `--patch` or `--cls`: `--patch` generates attention maps for the patch with coordinates `(x_patch, y_patch)`. `--cls` combined with `--enc` generates attention maps for the CLS token of the encoder. `--cls` combined with `--dec` generates maps for each class embedding of the decoder.
224 | * `--x-patch` and `--y-patch`: Coordinates of the patch to draw attention maps from. This flag is ignored when `--cls` is used.
225 | * `--layer-id`: Select the layer for which the attention maps are generated.
226 |
227 | For example, to generate attention maps for the decoder class embeddings, you can use:
228 |
 229 | ```sh
230 | python -m segm.scripts.show_attn_map seg_tiny_mask/checkpoint.pth \
231 | images/im0.jpg output_dir/ --layer-id 0 --dec --cls
232 | ```
233 |
234 | Attention maps for patch `(0, 21)` in `Seg-L-Mask/16` encoder layers 1, 4, 8, 12 and 16:
235 |
 236 | ![Attention maps of the encoder](./attn_maps_enc.png)
237 |
238 | Attention maps for the class embeddings in `Seg-L-Mask/16` decoder layer 0:
239 |
 240 | ![Attention maps of the decoder](./attn_maps_dec.png)
241 |
242 | ## Video Segmentation
243 |
 244 | Zero-shot video segmentation on the [DAVIS](https://davischallenge.org/) video dataset with a Seg-B-Mask/16 model trained on [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/).
245 |
 246 | ![Zero-shot segmentation: breakdance-flare](./gifs/breakdance-flare.gif)
 247 | ![Zero-shot segmentation: car-competition](./gifs/car-competition.gif)
 248 | ![Zero-shot segmentation: choreography](./gifs/choreography.gif)
 249 | ![Zero-shot segmentation: city-ride](./gifs/city-ride.gif)
254 |
 255 | ## BibTeX
256 |
257 | ```
258 | @article{strudel2021,
 259 |   title={Segmenter: Transformer for Semantic Segmentation},
 260 |   author={Strudel, Robin and Garcia, Ricardo and Laptev, Ivan and Schmid, Cordelia},
 261 |   journal={arXiv preprint arXiv:2105.05633},
 262 |   year={2021}
263 | }
264 | ```
265 |
266 |
267 | ## Acknowledgements
268 |
 269 | The Vision Transformer code is based on the [timm](https://github.com/rwightman/pytorch-image-models) library, and the semantic segmentation training and evaluation pipeline
 270 | uses [mmsegmentation](https://github.com/open-mmlab/mmsegmentation).
271 |
--------------------------------------------------------------------------------
/attn_maps_dec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/attn_maps_dec.png
--------------------------------------------------------------------------------
/attn_maps_enc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/attn_maps_enc.png
--------------------------------------------------------------------------------
/gifs/breakdance-flare.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/gifs/breakdance-flare.gif
--------------------------------------------------------------------------------
/gifs/car-competition.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/gifs/car-competition.gif
--------------------------------------------------------------------------------
/gifs/choreography.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/gifs/choreography.gif
--------------------------------------------------------------------------------
/gifs/city-ride.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/gifs/city-ride.gif
--------------------------------------------------------------------------------
/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rstrudel/segmenter/20d1bfad354165ee45c3f65972a4d9c131f58d53/overview.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | click
3 | numpy
4 | einops
5 | python-hostlist
6 | tqdm
7 | requests
8 | pyyaml
 9 | timm==0.4.12
10 | mmcv==1.3.8
11 | mmsegmentation==0.14.1
12 |
--------------------------------------------------------------------------------
/segm/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from pathlib import Path
3 |
4 | import os
5 |
6 |
 7 | def load_config():
 8 |     return yaml.load(
 9 |         open(Path(__file__).parent / "config.yml", "r"), Loader=yaml.FullLoader
 10 |     )
 11 | 
 12 | 
 13 | def check_os_environ(key, use):
 14 |     if key not in os.environ:
 15 |         raise ValueError(
 16 |             f"{key} is not defined in the os variables, it is required for {use}."
 17 |         )
 18 | 
 19 | 
 20 | def dataset_dir():
 21 |     check_os_environ("DATASET", "data loading")
 22 |     return os.environ["DATASET"]
23 |
--------------------------------------------------------------------------------
/segm/config.yml:
--------------------------------------------------------------------------------
1 | model:
2 | # deit
3 | deit_tiny_distilled_patch16_224:
4 | image_size: 224
5 | patch_size: 16
6 | d_model: 192
7 | n_heads: 3
8 | n_layers: 12
9 | normalization: deit
10 | distilled: true
11 | deit_small_distilled_patch16_224:
12 | image_size: 224
13 | patch_size: 16
14 | d_model: 384
15 | n_heads: 6
16 | n_layers: 12
17 | normalization: deit
18 | distilled: true
19 | deit_base_distilled_patch16_224:
20 | image_size: 224
21 | patch_size: 16
22 | d_model: 768
23 | n_heads: 12
24 | n_layers: 12
25 | normalization: deit
26 | distilled: true
27 | deit_base_distilled_patch16_384:
28 | image_size: 384
29 | patch_size: 16
30 | d_model: 768
31 | n_heads: 12
32 | n_layers: 12
33 | normalization: deit
34 | distilled: true
35 | # vit
36 | vit_base_patch8_384:
37 | image_size: 384
38 | patch_size: 8
39 | d_model: 768
40 | n_heads: 12
41 | n_layers: 12
42 | normalization: vit
43 | distilled: false
44 | vit_tiny_patch16_384:
45 | image_size: 384
46 | patch_size: 16
47 | d_model: 192
48 | n_heads: 3
49 | n_layers: 12
50 | normalization: vit
51 | distilled: false
52 | vit_small_patch16_384:
53 | image_size: 384
54 | patch_size: 16
55 | d_model: 384
56 | n_heads: 6
57 | n_layers: 12
58 | normalization: vit
59 | distilled: false
60 | vit_base_patch16_384:
61 | image_size: 384
62 | patch_size: 16
63 | d_model: 768
64 | n_heads: 12
65 | n_layers: 12
66 | normalization: vit
67 | distilled: false
68 | vit_large_patch16_384:
69 | image_size: 384
70 | patch_size: 16
71 | d_model: 1024
72 | n_heads: 16
73 | n_layers: 24
74 | normalization: vit
75 | vit_small_patch32_384:
76 | image_size: 384
77 | patch_size: 32
78 | d_model: 384
79 | n_heads: 6
80 | n_layers: 12
81 | normalization: vit
82 | distilled: false
83 | vit_base_patch32_384:
84 | image_size: 384
85 | patch_size: 32
86 | d_model: 768
87 | n_heads: 12
88 | n_layers: 12
89 | normalization: vit
90 | vit_large_patch32_384:
91 | image_size: 384
92 | patch_size: 32
93 | d_model: 1024
94 | n_heads: 16
95 | n_layers: 24
96 | normalization: vit
97 | decoder:
98 | linear: {}
99 | deeplab_dec:
100 | encoder_layer: -1
101 | mask_transformer:
102 | drop_path_rate: 0.0
103 | dropout: 0.1
104 | n_layers: 2
105 | dataset:
106 | ade20k:
107 | epochs: 64
108 | eval_freq: 2
109 | batch_size: 8
110 | learning_rate: 0.001
111 | im_size: 512
112 | crop_size: 512
113 | window_size: 512
114 | window_stride: 512
115 | pascal_context:
116 | epochs: 256
117 | eval_freq: 8
118 | batch_size: 16
119 | learning_rate: 0.001
120 | im_size: 520
121 | crop_size: 480
122 | window_size: 480
123 | window_stride: 320
124 | cityscapes:
125 | epochs: 216
126 | eval_freq: 4
127 | batch_size: 8
128 | learning_rate: 0.01
129 | im_size: 1024
130 | crop_size: 768
131 | window_size: 768
132 | window_stride: 512
133 |
--------------------------------------------------------------------------------
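A minimal usage sketch of how this file is consumed through `segm.config` (it assumes the package is installed, and `dataset_dir` additionally requires the `DATASET` variable described in the README):

```python
from segm.config import load_config, dataset_dir

cfg = load_config()  # parses segm/config.yml into nested dicts

backbone = cfg["model"]["vit_tiny_patch16_384"]
print(backbone["d_model"], backbone["n_heads"])  # 192 3

decoder = cfg["decoder"]["mask_transformer"]
print(decoder["n_layers"])                       # 2

print(cfg["dataset"]["ade20k"]["crop_size"])     # 512
print(dataset_dir())  # value of $DATASET; raises ValueError if unset
```
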
/segm/data/__init__.py:
--------------------------------------------------------------------------------
1 | from segm.data.loader import Loader
2 |
3 | from segm.data.imagenet import ImagenetDataset
4 | from segm.data.ade20k import ADE20KSegmentation
5 | from segm.data.pascal_context import PascalContextDataset
6 | from segm.data.cityscapes import CityscapesDataset
7 |
--------------------------------------------------------------------------------
/segm/data/ade20k.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from segm.data.base import BaseMMSeg
4 | from segm.data import utils
5 | from segm.config import dataset_dir
6 |
7 |
8 | ADE20K_CONFIG_PATH = Path(__file__).parent / "config" / "ade20k.py"
9 | ADE20K_CATS_PATH = Path(__file__).parent / "config" / "ade20k.yml"
10 |
11 |
 12 | class ADE20KSegmentation(BaseMMSeg):
 13 |     def __init__(self, image_size, crop_size, split, **kwargs):
 14 |         super().__init__(
 15 |             image_size,
 16 |             crop_size,
 17 |             split,
 18 |             ADE20K_CONFIG_PATH,
 19 |             **kwargs,
 20 |         )
 21 |         self.names, self.colors = utils.dataset_cat_description(ADE20K_CATS_PATH)
 22 |         self.n_cls = 150
 23 |         self.ignore_label = 0
 24 |         self.reduce_zero_label = True
 25 | 
 26 |     def update_default_config(self, config):
 27 |         root_dir = dataset_dir()
 28 |         path = Path(root_dir) / "ade20k"
 29 |         config.data_root = path
 30 |         if self.split == "train":
 31 |             config.data.train.data_root = path / "ADEChallengeData2016"
 32 |         elif self.split == "trainval":
 33 |             config.data.trainval.data_root = path / "ADEChallengeData2016"
 34 |         elif self.split == "val":
 35 |             config.data.val.data_root = path / "ADEChallengeData2016"
 36 |         elif self.split == "test":
 37 |             config.data.test.data_root = path / "release_test"
 38 |         config = super().update_default_config(config)
 39 |         return config
 40 | 
 41 |     def test_post_process(self, labels):
 42 |         return labels + 1
43 |
--------------------------------------------------------------------------------
/segm/data/base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pathlib import Path
3 | from PIL import Image, ImageOps, ImageFilter
4 |
5 | import torch
6 | from torch.utils.data import Dataset
7 | import torchvision.transforms.functional as F
8 |
9 | from mmseg.datasets import build_dataset
10 | import mmcv
11 | from mmcv.utils import Config
12 |
13 |
14 | from segm.data.utils import STATS, IGNORE_LABEL
15 | from segm.data import utils
16 |
17 |
18 | class BaseMMSeg(Dataset):
19 | def __init__(
20 | self,
21 | image_size,
22 | crop_size,
23 | split,
24 | config_path,
25 | normalization,
26 | **kwargs,
27 | ):
28 | super().__init__()
29 | self.image_size = image_size
30 | self.crop_size = crop_size
31 | self.split = split
32 | self.normalization = STATS[normalization].copy()
33 | self.ignore_label = None
34 | for k, v in self.normalization.items():
35 | v = np.round(255 * np.array(v), 2)
36 | self.normalization[k] = tuple(v)
37 | print(f"Use normalization: {self.normalization}")
38 |
39 | config = Config.fromfile(config_path)
40 |
41 | self.ratio = config.max_ratio
42 | self.dataset = None
43 | self.config = self.update_default_config(config)
44 | self.dataset = build_dataset(getattr(self.config.data, f"{self.split}"))
45 |
46 | def update_default_config(self, config):
47 |
48 | train_splits = ["train", "trainval"]
49 | if self.split in train_splits:
50 | config_pipeline = getattr(config, f"train_pipeline")
51 | else:
52 | config_pipeline = getattr(config, f"{self.split}_pipeline")
53 |
54 | img_scale = (self.ratio * self.image_size, self.image_size)
55 | if self.split not in train_splits:
56 | assert config_pipeline[1]["type"] == "MultiScaleFlipAug"
57 | config_pipeline = config_pipeline[1]["transforms"]
58 | for i, op in enumerate(config_pipeline):
59 | op_type = op["type"]
60 | if op_type == "Resize":
61 | op["img_scale"] = img_scale
62 | elif op_type == "RandomCrop":
63 | op["crop_size"] = (
64 | self.crop_size,
65 | self.crop_size,
66 | )
67 | elif op_type == "Normalize":
68 | op["mean"] = self.normalization["mean"]
69 | op["std"] = self.normalization["std"]
70 | elif op_type == "Pad":
71 | op["size"] = (self.crop_size, self.crop_size)
72 | config_pipeline[i] = op
73 | if self.split == "train":
74 | config.data.train.pipeline = config_pipeline
75 | elif self.split == "trainval":
76 | config.data.trainval.pipeline = config_pipeline
77 | elif self.split == "val":
78 | config.data.val.pipeline[1]["img_scale"] = img_scale
79 | config.data.val.pipeline[1]["transforms"] = config_pipeline
80 | elif self.split == "test":
81 | config.data.test.pipeline[1]["img_scale"] = img_scale
82 | config.data.test.pipeline[1]["transforms"] = config_pipeline
83 | config.data.test.test_mode = True
84 | else:
85 | raise ValueError(f"Unknown split: {self.split}")
86 | return config
87 |
88 | def set_multiscale_mode(self):
89 | self.config.data.val.pipeline[1]["img_ratios"] = [
90 | 0.5,
91 | 0.75,
92 | 1.0,
93 | 1.25,
94 | 1.5,
95 | 1.75,
96 | ]
97 | self.config.data.val.pipeline[1]["flip"] = True
98 | self.config.data.test.pipeline[1]["img_ratios"] = [
99 | 0.5,
100 | 0.75,
101 | 1.0,
102 | 1.25,
103 | 1.5,
104 | 1.75,
105 | ]
106 | self.config.data.test.pipeline[1]["flip"] = True
107 | self.dataset = build_dataset(getattr(self.config.data, f"{self.split}"))
108 |
109 | def __getitem__(self, idx):
110 | data = self.dataset[idx]
111 |
112 | train_splits = ["train", "trainval"]
113 |
114 | if self.split in train_splits:
115 | im = data["img"].data
116 | seg = data["gt_semantic_seg"].data.squeeze(0)
117 | else:
118 | im = [im.data for im in data["img"]]
119 | seg = None
120 |
121 | out = dict(im=im)
122 | if self.split in train_splits:
123 | out["segmentation"] = seg
124 | else:
125 | im_metas = [meta.data for meta in data["img_metas"]]
126 | out["im_metas"] = im_metas
127 | out["colors"] = self.colors
128 |
129 | return out
130 |
131 | def get_gt_seg_maps(self):
132 | dataset = self.dataset
133 | gt_seg_maps = {}
134 | for img_info in dataset.img_infos:
135 | seg_map = Path(dataset.ann_dir) / img_info["ann"]["seg_map"]
136 | gt_seg_map = mmcv.imread(seg_map, flag="unchanged", backend="pillow")
137 | gt_seg_map[gt_seg_map == self.ignore_label] = IGNORE_LABEL
138 | if self.reduce_zero_label:
139 | gt_seg_map[gt_seg_map != IGNORE_LABEL] -= 1
140 | gt_seg_maps[img_info["filename"]] = gt_seg_map
141 | return gt_seg_maps
142 |
143 | def __len__(self):
144 | return len(self.dataset)
145 |
146 | @property
147 | def unwrapped(self):
148 | return self
149 |
150 | def set_epoch(self, epoch):
151 | pass
152 |
153 | def get_diagnostics(self, logger):
154 | pass
155 |
156 | def get_snapshot(self):
157 | return {}
158 |
159 | def end_epoch(self, epoch):
160 | return
161 |
--------------------------------------------------------------------------------
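To make the split-dependent output above concrete, here is a small usage sketch with the ADE20K subclass defined earlier. It assumes `DATASET` is exported and ADE20K has been prepared with `segm.scripts.prepare_ade20k`, and that `"vit"` is one of the normalization keys in `segm.data.utils.STATS` (as `segm/config.yml` suggests):

```python
from segm.data import ADE20KSegmentation

# image_size/crop_size follow the ade20k entry of segm/config.yml
ds = ADE20KSegmentation(image_size=512, crop_size=512, split="train", normalization="vit")
print(ds.n_cls, len(ds))      # 150 classes, number of training images

sample = ds[0]
im = sample["im"]             # CHW float tensor, normalized by the mmseg pipeline
seg = sample["segmentation"]  # HW tensor of class indices (padded regions use 255)

# "val"/"test" splits instead return a list of images per scale plus "im_metas" and "colors"
```
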
/segm/data/cityscapes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
 3 | try:
 4 |     import cityscapesscripts.helpers.labels as CSLabels
 5 | except ImportError:
 6 |     pass
7 |
8 | from pathlib import Path
9 | from segm.data.base import BaseMMSeg
10 | from segm.data import utils
11 | from segm.config import dataset_dir
12 |
13 | CITYSCAPES_CONFIG_PATH = Path(__file__).parent / "config" / "cityscapes.py"
14 | CITYSCAPES_CATS_PATH = Path(__file__).parent / "config" / "cityscapes.yml"
15 |
16 |
 17 | class CityscapesDataset(BaseMMSeg):
 18 |     def __init__(self, image_size, crop_size, split, **kwargs):
 19 |         super().__init__(image_size, crop_size, split, CITYSCAPES_CONFIG_PATH, **kwargs)
 20 |         self.names, self.colors = utils.dataset_cat_description(CITYSCAPES_CATS_PATH)
 21 |         self.n_cls = 19
 22 |         self.ignore_label = 255
 23 |         self.reduce_zero_label = False
 24 | 
 25 |     def update_default_config(self, config):
 26 | 
 27 |         root_dir = dataset_dir()
 28 |         path = Path(root_dir) / "cityscapes"
 29 |         config.data_root = path
 30 | 
 31 |         config.data[self.split]["data_root"] = path
 32 |         config = super().update_default_config(config)
 33 | 
 34 |         return config
 35 | 
 36 |     def test_post_process(self, labels):
 37 |         labels_copy = np.copy(labels)
 38 |         cats = np.unique(labels_copy)
 39 |         for cat in cats:
 40 |             labels_copy[labels == cat] = CSLabels.trainId2label[cat].id
 41 |         return labels_copy
42 |
--------------------------------------------------------------------------------
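As a concrete illustration of `test_post_process` above (it assumes `cityscapesscripts` is installed, which the guarded import also expects): predictions in trainId space are mapped back to the official Cityscapes label ids used by the evaluation server.

```python
import numpy as np
import cityscapesscripts.helpers.labels as CSLabels

pred = np.array([[0, 13]], dtype=np.uint8)  # toy prediction: road (trainId 0), car (trainId 13)
ids = np.copy(pred)
for cat in np.unique(pred):                 # same loop as test_post_process
    ids[pred == cat] = CSLabels.trainId2label[cat].id
print(ids)                                  # [[ 7 26]] -> official label ids
```
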
/segm/data/config/ade20k.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = "ADE20KDataset"
3 | data_root = "data/ade/ADEChallengeData2016"
4 | img_norm_cfg = dict(
5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
6 | )
7 | crop_size = (512, 512)
8 | max_ratio = 4
9 | train_pipeline = [
10 | dict(type="LoadImageFromFile"),
11 | dict(type="LoadAnnotations", reduce_zero_label=True),
12 | dict(type="Resize", img_scale=(512 * max_ratio, 512), ratio_range=(0.5, 2.0)),
13 | dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75),
14 | dict(type="RandomFlip", prob=0.5),
15 | dict(type="PhotoMetricDistortion"),
16 | dict(type="Normalize", **img_norm_cfg),
17 | dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255),
18 | dict(type="DefaultFormatBundle"),
19 | dict(type="Collect", keys=["img", "gt_semantic_seg"]),
20 | ]
21 | val_pipeline = [
22 | dict(type="LoadImageFromFile"),
23 | dict(
24 | type="MultiScaleFlipAug",
25 | img_scale=(512 * max_ratio, 512),
26 | flip=False,
27 | transforms=[
28 | dict(type="Resize", keep_ratio=True),
29 | dict(type="RandomFlip"),
30 | dict(type="Normalize", **img_norm_cfg),
31 | dict(type="ImageToTensor", keys=["img"]),
32 | dict(type="Collect", keys=["img"]),
33 | ],
34 | ),
35 | ]
36 | test_pipeline = [
37 | dict(type="LoadImageFromFile"),
38 | dict(
39 | type="MultiScaleFlipAug",
40 | img_scale=(512 * max_ratio, 512),
41 | flip=False,
42 | transforms=[
43 | dict(type="Resize", keep_ratio=True),
44 | dict(type="RandomFlip"),
45 | dict(type="Normalize", **img_norm_cfg),
46 | dict(type="ImageToTensor", keys=["img"]),
47 | dict(type="Collect", keys=["img"]),
48 | ],
49 | ),
50 | ]
51 | data = dict(
52 | samples_per_gpu=4,
53 | workers_per_gpu=4,
54 | train=dict(
55 | type=dataset_type,
56 | data_root=data_root,
57 | img_dir="images/training",
58 | ann_dir="annotations/training",
59 | pipeline=train_pipeline,
60 | ),
61 | trainval=dict(
62 | type=dataset_type,
63 | data_root=data_root,
64 | img_dir=["images/training", "images/validation"],
65 | ann_dir=["annotations/training", "annotations/validation"],
66 | pipeline=train_pipeline,
67 | ),
68 | val=dict(
69 | type=dataset_type,
70 | data_root=data_root,
71 | img_dir="images/validation",
72 | ann_dir="annotations/validation",
73 | pipeline=val_pipeline,
74 | ),
75 | test=dict(
76 | type=dataset_type,
77 | data_root=data_root,
78 | img_dir="testing",
79 | pipeline=test_pipeline,
80 | ),
81 | )
82 |
--------------------------------------------------------------------------------
/segm/data/config/ade20k.yml:
--------------------------------------------------------------------------------
1 | - color:
2 | - 120
3 | - 120
4 | - 120
5 | id: 0
6 | isthing: 0
7 | name: wall
8 | - color:
9 | - 180
10 | - 120
11 | - 120
12 | id: 1
13 | isthing: 0
14 | name: building, edifice
15 | - color:
16 | - 6
17 | - 230
18 | - 230
19 | id: 2
20 | isthing: 0
21 | name: sky
22 | - color:
23 | - 80
24 | - 50
25 | - 50
26 | id: 3
27 | isthing: 0
28 | name: floor, flooring
29 | - color:
30 | - 4
31 | - 200
32 | - 3
33 | id: 4
34 | isthing: 0
35 | name: tree
36 | - color:
37 | - 120
38 | - 120
39 | - 80
40 | id: 5
41 | isthing: 0
42 | name: ceiling
43 | - color:
44 | - 140
45 | - 140
46 | - 140
47 | id: 6
48 | isthing: 0
49 | name: road, route
50 | - color:
51 | - 204
52 | - 5
53 | - 255
54 | id: 7
55 | isthing: 0
56 | name: bed
57 | - color:
58 | - 230
59 | - 230
60 | - 230
61 | id: 8
62 | isthing: 0
63 | name: windowpane, window
64 | - color:
65 | - 4
66 | - 250
67 | - 7
68 | id: 9
69 | isthing: 0
70 | name: grass
71 | - color:
72 | - 224
73 | - 5
74 | - 255
75 | id: 10
76 | isthing: 0
77 | name: cabinet
78 | - color:
79 | - 235
80 | - 255
81 | - 7
82 | id: 11
83 | isthing: 0
84 | name: sidewalk, pavement
85 | - color:
86 | - 150
87 | - 5
88 | - 61
89 | id: 12
90 | isthing: 0
91 | name: person, individual, someone, somebody, mortal, soul
92 | - color:
93 | - 120
94 | - 120
95 | - 70
96 | id: 13
97 | isthing: 0
98 | name: earth, ground
99 | - color:
100 | - 8
101 | - 255
102 | - 51
103 | id: 14
104 | isthing: 0
105 | name: door, double door
106 | - color:
107 | - 255
108 | - 6
109 | - 82
110 | id: 15
111 | isthing: 0
112 | name: table
113 | - color:
114 | - 143
115 | - 255
116 | - 140
117 | id: 16
118 | isthing: 0
119 | name: mountain, mount
120 | - color:
121 | - 204
122 | - 255
123 | - 4
124 | id: 17
125 | isthing: 0
126 | name: plant, flora, plant life
127 | - color:
128 | - 255
129 | - 51
130 | - 7
131 | id: 18
132 | isthing: 0
133 | name: curtain, drape, drapery, mantle, pall
134 | - color:
135 | - 204
136 | - 70
137 | - 3
138 | id: 19
139 | isthing: 0
140 | name: chair
141 | - color:
142 | - 0
143 | - 102
144 | - 200
145 | id: 20
146 | isthing: 0
147 | name: car, auto, automobile, machine, motorcar
148 | - color:
149 | - 61
150 | - 230
151 | - 250
152 | id: 21
153 | isthing: 0
154 | name: water
155 | - color:
156 | - 255
157 | - 6
158 | - 51
159 | id: 22
160 | isthing: 0
161 | name: painting, picture
162 | - color:
163 | - 11
164 | - 102
165 | - 255
166 | id: 23
167 | isthing: 0
168 | name: sofa, couch, lounge
169 | - color:
170 | - 255
171 | - 7
172 | - 71
173 | id: 24
174 | isthing: 0
175 | name: shelf
176 | - color:
177 | - 255
178 | - 9
179 | - 224
180 | id: 25
181 | isthing: 0
182 | name: house
183 | - color:
184 | - 9
185 | - 7
186 | - 230
187 | id: 26
188 | isthing: 0
189 | name: sea
190 | - color:
191 | - 220
192 | - 220
193 | - 220
194 | id: 27
195 | isthing: 0
196 | name: mirror
197 | - color:
198 | - 255
199 | - 9
200 | - 92
201 | id: 28
202 | isthing: 0
203 | name: rug, carpet, carpeting
204 | - color:
205 | - 112
206 | - 9
207 | - 255
208 | id: 29
209 | isthing: 0
210 | name: field
211 | - color:
212 | - 8
213 | - 255
214 | - 214
215 | id: 30
216 | isthing: 0
217 | name: armchair
218 | - color:
219 | - 7
220 | - 255
221 | - 224
222 | id: 31
223 | isthing: 0
224 | name: seat
225 | - color:
226 | - 255
227 | - 184
228 | - 6
229 | id: 32
230 | isthing: 0
231 | name: fence, fencing
232 | - color:
233 | - 10
234 | - 255
235 | - 71
236 | id: 33
237 | isthing: 0
238 | name: desk
239 | - color:
240 | - 255
241 | - 41
242 | - 10
243 | id: 34
244 | isthing: 0
245 | name: rock, stone
246 | - color:
247 | - 7
248 | - 255
249 | - 255
250 | id: 35
251 | isthing: 0
252 | name: wardrobe, closet, press
253 | - color:
254 | - 224
255 | - 255
256 | - 8
257 | id: 36
258 | isthing: 0
259 | name: lamp
260 | - color:
261 | - 102
262 | - 8
263 | - 255
264 | id: 37
265 | isthing: 0
266 | name: bathtub, bathing tub, bath, tub
267 | - color:
268 | - 255
269 | - 61
270 | - 6
271 | id: 38
272 | isthing: 0
273 | name: railing, rail
274 | - color:
275 | - 255
276 | - 194
277 | - 7
278 | id: 39
279 | isthing: 0
280 | name: cushion
281 | - color:
282 | - 255
283 | - 122
284 | - 8
285 | id: 40
286 | isthing: 0
287 | name: base, pedestal, stand
288 | - color:
289 | - 0
290 | - 255
291 | - 20
292 | id: 41
293 | isthing: 0
294 | name: box
295 | - color:
296 | - 255
297 | - 8
298 | - 41
299 | id: 42
300 | isthing: 0
301 | name: column, pillar
302 | - color:
303 | - 255
304 | - 5
305 | - 153
306 | id: 43
307 | isthing: 0
308 | name: signboard, sign
309 | - color:
310 | - 6
311 | - 51
312 | - 255
313 | id: 44
314 | isthing: 0
315 | name: chest of drawers, chest, bureau, dresser
316 | - color:
317 | - 235
318 | - 12
319 | - 255
320 | id: 45
321 | isthing: 0
322 | name: counter
323 | - color:
324 | - 160
325 | - 150
326 | - 20
327 | id: 46
328 | isthing: 0
329 | name: sand
330 | - color:
331 | - 0
332 | - 163
333 | - 255
334 | id: 47
335 | isthing: 0
336 | name: sink
337 | - color:
338 | - 140
339 | - 140
340 | - 140
341 | id: 48
342 | isthing: 0
343 | name: skyscraper
344 | - color:
345 | - 250
346 | - 10
347 | - 15
348 | id: 49
349 | isthing: 0
350 | name: fireplace, hearth, open fireplace
351 | - color:
352 | - 20
353 | - 255
354 | - 0
355 | id: 50
356 | isthing: 0
357 | name: refrigerator, icebox
358 | - color:
359 | - 31
360 | - 255
361 | - 0
362 | id: 51
363 | isthing: 0
364 | name: grandstand, covered stand
365 | - color:
366 | - 255
367 | - 31
368 | - 0
369 | id: 52
370 | isthing: 0
371 | name: path
372 | - color:
373 | - 255
374 | - 224
375 | - 0
376 | id: 53
377 | isthing: 0
378 | name: stairs, steps
379 | - color:
380 | - 153
381 | - 255
382 | - 0
383 | id: 54
384 | isthing: 0
385 | name: runway
386 | - color:
387 | - 0
388 | - 0
389 | - 255
390 | id: 55
391 | isthing: 0
392 | name: case, display case, showcase, vitrine
393 | - color:
394 | - 255
395 | - 71
396 | - 0
397 | id: 56
398 | isthing: 0
399 | name: pool table, billiard table, snooker table
400 | - color:
401 | - 0
402 | - 235
403 | - 255
404 | id: 57
405 | isthing: 0
406 | name: pillow
407 | - color:
408 | - 0
409 | - 173
410 | - 255
411 | id: 58
412 | isthing: 0
413 | name: screen door, screen
414 | - color:
415 | - 31
416 | - 0
417 | - 255
418 | id: 59
419 | isthing: 0
420 | name: stairway, staircase
421 | - color:
422 | - 11
423 | - 200
424 | - 200
425 | id: 60
426 | isthing: 0
427 | name: river
428 | - color:
429 | - 255
430 | - 82
431 | - 0
432 | id: 61
433 | isthing: 0
434 | name: bridge, span
435 | - color:
436 | - 0
437 | - 255
438 | - 245
439 | id: 62
440 | isthing: 0
441 | name: bookcase
442 | - color:
443 | - 0
444 | - 61
445 | - 255
446 | id: 63
447 | isthing: 0
448 | name: blind, screen
449 | - color:
450 | - 0
451 | - 255
452 | - 112
453 | id: 64
454 | isthing: 0
455 | name: coffee table, cocktail table
456 | - color:
457 | - 0
458 | - 255
459 | - 133
460 | id: 65
461 | isthing: 0
462 | name: toilet, can, commode, crapper, pot, potty, stool, throne
463 | - color:
464 | - 255
465 | - 0
466 | - 0
467 | id: 66
468 | isthing: 0
469 | name: flower
470 | - color:
471 | - 255
472 | - 163
473 | - 0
474 | id: 67
475 | isthing: 0
476 | name: book
477 | - color:
478 | - 255
479 | - 102
480 | - 0
481 | id: 68
482 | isthing: 0
483 | name: hill
484 | - color:
485 | - 194
486 | - 255
487 | - 0
488 | id: 69
489 | isthing: 0
490 | name: bench
491 | - color:
492 | - 0
493 | - 143
494 | - 255
495 | id: 70
496 | isthing: 0
497 | name: countertop
498 | - color:
499 | - 51
500 | - 255
501 | - 0
502 | id: 71
503 | isthing: 0
504 | name: stove, kitchen stove, range, kitchen range, cooking stove
505 | - color:
506 | - 0
507 | - 82
508 | - 255
509 | id: 72
510 | isthing: 0
511 | name: palm, palm tree
512 | - color:
513 | - 0
514 | - 255
515 | - 41
516 | id: 73
517 | isthing: 0
518 | name: kitchen island
519 | - color:
520 | - 0
521 | - 255
522 | - 173
523 | id: 74
524 | isthing: 0
525 | name: computer, computing machine, computing device, data processor, electronic
526 | computer, information processing system
527 | - color:
528 | - 10
529 | - 0
530 | - 255
531 | id: 75
532 | isthing: 0
533 | name: swivel chair
534 | - color:
535 | - 173
536 | - 255
537 | - 0
538 | id: 76
539 | isthing: 0
540 | name: boat
541 | - color:
542 | - 0
543 | - 255
544 | - 153
545 | id: 77
546 | isthing: 0
547 | name: bar
548 | - color:
549 | - 255
550 | - 92
551 | - 0
552 | id: 78
553 | isthing: 0
554 | name: arcade machine
555 | - color:
556 | - 255
557 | - 0
558 | - 255
559 | id: 79
560 | isthing: 0
561 | name: hovel, hut, hutch, shack, shanty
562 | - color:
563 | - 255
564 | - 0
565 | - 245
566 | id: 80
567 | isthing: 0
568 | name: bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach,
569 | omnibus, passenger vehicle
570 | - color:
571 | - 255
572 | - 0
573 | - 102
574 | id: 81
575 | isthing: 0
576 | name: towel
577 | - color:
578 | - 255
579 | - 173
580 | - 0
581 | id: 82
582 | isthing: 0
583 | name: light, light source
584 | - color:
585 | - 255
586 | - 0
587 | - 20
588 | id: 83
589 | isthing: 0
590 | name: truck, motortruck
591 | - color:
592 | - 255
593 | - 184
594 | - 184
595 | id: 84
596 | isthing: 0
597 | name: tower
598 | - color:
599 | - 0
600 | - 31
601 | - 255
602 | id: 85
603 | isthing: 0
604 | name: chandelier, pendant, pendent
605 | - color:
606 | - 0
607 | - 255
608 | - 61
609 | id: 86
610 | isthing: 0
611 | name: awning, sunshade, sunblind
612 | - color:
613 | - 0
614 | - 71
615 | - 255
616 | id: 87
617 | isthing: 0
618 | name: streetlight, street lamp
619 | - color:
620 | - 255
621 | - 0
622 | - 204
623 | id: 88
624 | isthing: 0
625 | name: booth, cubicle, stall, kiosk
626 | - color:
627 | - 0
628 | - 255
629 | - 194
630 | id: 89
631 | isthing: 0
632 | name: television, television receiver, television set, tv, tv set, idiot box, boob
633 | tube, telly, goggle box
634 | - color:
635 | - 0
636 | - 255
637 | - 82
638 | id: 90
639 | isthing: 0
640 | name: airplane, aeroplane, plane
641 | - color:
642 | - 0
643 | - 10
644 | - 255
645 | id: 91
646 | isthing: 0
647 | name: dirt track
648 | - color:
649 | - 0
650 | - 112
651 | - 255
652 | id: 92
653 | isthing: 0
654 | name: apparel, wearing apparel, dress, clothes
655 | - color:
656 | - 51
657 | - 0
658 | - 255
659 | id: 93
660 | isthing: 0
661 | name: pole
662 | - color:
663 | - 0
664 | - 194
665 | - 255
666 | id: 94
667 | isthing: 0
668 | name: land, ground, soil
669 | - color:
670 | - 0
671 | - 122
672 | - 255
673 | id: 95
674 | isthing: 0
675 | name: bannister, banister, balustrade, balusters, handrail
676 | - color:
677 | - 0
678 | - 255
679 | - 163
680 | id: 96
681 | isthing: 0
682 | name: escalator, moving staircase, moving stairway
683 | - color:
684 | - 255
685 | - 153
686 | - 0
687 | id: 97
688 | isthing: 0
689 | name: ottoman, pouf, pouffe, puff, hassock
690 | - color:
691 | - 0
692 | - 255
693 | - 10
694 | id: 98
695 | isthing: 0
696 | name: bottle
697 | - color:
698 | - 255
699 | - 112
700 | - 0
701 | id: 99
702 | isthing: 0
703 | name: buffet, counter, sideboard
704 | - color:
705 | - 143
706 | - 255
707 | - 0
708 | id: 100
709 | isthing: 0
710 | name: poster, posting, placard, notice, bill, card
711 | - color:
712 | - 82
713 | - 0
714 | - 255
715 | id: 101
716 | isthing: 0
717 | name: stage
718 | - color:
719 | - 163
720 | - 255
721 | - 0
722 | id: 102
723 | isthing: 0
724 | name: van
725 | - color:
726 | - 255
727 | - 235
728 | - 0
729 | id: 103
730 | isthing: 0
731 | name: ship
732 | - color:
733 | - 8
734 | - 184
735 | - 170
736 | id: 104
737 | isthing: 0
738 | name: fountain
739 | - color:
740 | - 133
741 | - 0
742 | - 255
743 | id: 105
744 | isthing: 0
745 | name: conveyer belt, conveyor belt, conveyer, conveyor, transporter
746 | - color:
747 | - 0
748 | - 255
749 | - 92
750 | id: 106
751 | isthing: 0
752 | name: canopy
753 | - color:
754 | - 184
755 | - 0
756 | - 255
757 | id: 107
758 | isthing: 0
759 | name: washer, automatic washer, washing machine
760 | - color:
761 | - 255
762 | - 0
763 | - 31
764 | id: 108
765 | isthing: 0
766 | name: plaything, toy
767 | - color:
768 | - 0
769 | - 184
770 | - 255
771 | id: 109
772 | isthing: 0
773 | name: swimming pool, swimming bath, natatorium
774 | - color:
775 | - 0
776 | - 214
777 | - 255
778 | id: 110
779 | isthing: 0
780 | name: stool
781 | - color:
782 | - 255
783 | - 0
784 | - 112
785 | id: 111
786 | isthing: 0
787 | name: barrel, cask
788 | - color:
789 | - 92
790 | - 255
791 | - 0
792 | id: 112
793 | isthing: 0
794 | name: basket, handbasket
795 | - color:
796 | - 0
797 | - 224
798 | - 255
799 | id: 113
800 | isthing: 0
801 | name: waterfall, falls
802 | - color:
803 | - 112
804 | - 224
805 | - 255
806 | id: 114
807 | isthing: 0
808 | name: tent, collapsible shelter
809 | - color:
810 | - 70
811 | - 184
812 | - 160
813 | id: 115
814 | isthing: 0
815 | name: bag
816 | - color:
817 | - 163
818 | - 0
819 | - 255
820 | id: 116
821 | isthing: 0
822 | name: minibike, motorbike
823 | - color:
824 | - 153
825 | - 0
826 | - 255
827 | id: 117
828 | isthing: 0
829 | name: cradle
830 | - color:
831 | - 71
832 | - 255
833 | - 0
834 | id: 118
835 | isthing: 0
836 | name: oven
837 | - color:
838 | - 255
839 | - 0
840 | - 163
841 | id: 119
842 | isthing: 0
843 | name: ball
844 | - color:
845 | - 255
846 | - 204
847 | - 0
848 | id: 120
849 | isthing: 0
850 | name: food, solid food
851 | - color:
852 | - 255
853 | - 0
854 | - 143
855 | id: 121
856 | isthing: 0
857 | name: step, stair
858 | - color:
859 | - 0
860 | - 255
861 | - 235
862 | id: 122
863 | isthing: 0
864 | name: tank, storage tank
865 | - color:
866 | - 133
867 | - 255
868 | - 0
869 | id: 123
870 | isthing: 0
871 | name: trade name, brand name, brand, marque
872 | - color:
873 | - 255
874 | - 0
875 | - 235
876 | id: 124
877 | isthing: 0
878 | name: microwave, microwave oven
879 | - color:
880 | - 245
881 | - 0
882 | - 255
883 | id: 125
884 | isthing: 0
885 | name: pot, flowerpot
886 | - color:
887 | - 255
888 | - 0
889 | - 122
890 | id: 126
891 | isthing: 0
892 | name: animal, animate being, beast, brute, creature, fauna
893 | - color:
894 | - 255
895 | - 245
896 | - 0
897 | id: 127
898 | isthing: 0
899 | name: bicycle, bike, wheel, cycle
900 | - color:
901 | - 10
902 | - 190
903 | - 212
904 | id: 128
905 | isthing: 0
906 | name: lake
907 | - color:
908 | - 214
909 | - 255
910 | - 0
911 | id: 129
912 | isthing: 0
913 | name: dishwasher, dish washer, dishwashing machine
914 | - color:
915 | - 0
916 | - 204
917 | - 255
918 | id: 130
919 | isthing: 0
920 | name: screen, silver screen, projection screen
921 | - color:
922 | - 20
923 | - 0
924 | - 255
925 | id: 131
926 | isthing: 0
927 | name: blanket, cover
928 | - color:
929 | - 255
930 | - 255
931 | - 0
932 | id: 132
933 | isthing: 0
934 | name: sculpture
935 | - color:
936 | - 0
937 | - 153
938 | - 255
939 | id: 133
940 | isthing: 0
941 | name: hood, exhaust hood
942 | - color:
943 | - 0
944 | - 41
945 | - 255
946 | id: 134
947 | isthing: 0
948 | name: sconce
949 | - color:
950 | - 0
951 | - 255
952 | - 204
953 | id: 135
954 | isthing: 0
955 | name: vase
956 | - color:
957 | - 41
958 | - 0
959 | - 255
960 | id: 136
961 | isthing: 0
962 | name: traffic light, traffic signal, stoplight
963 | - color:
964 | - 41
965 | - 255
966 | - 0
967 | id: 137
968 | isthing: 0
969 | name: tray
970 | - color:
971 | - 173
972 | - 0
973 | - 255
974 | id: 138
975 | isthing: 0
976 | name: ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin,
977 | trash barrel, trash bin
978 | - color:
979 | - 0
980 | - 245
981 | - 255
982 | id: 139
983 | isthing: 0
984 | name: fan
985 | - color:
986 | - 71
987 | - 0
988 | - 255
989 | id: 140
990 | isthing: 0
991 | name: pier, wharf, wharfage, dock
992 | - color:
993 | - 122
994 | - 0
995 | - 255
996 | id: 141
997 | isthing: 0
998 | name: crt screen
999 | - color:
1000 | - 0
1001 | - 255
1002 | - 184
1003 | id: 142
1004 | isthing: 0
1005 | name: plate
1006 | - color:
1007 | - 0
1008 | - 92
1009 | - 255
1010 | id: 143
1011 | isthing: 0
1012 | name: monitor, monitoring device
1013 | - color:
1014 | - 184
1015 | - 255
1016 | - 0
1017 | id: 144
1018 | isthing: 0
1019 | name: bulletin board, notice board
1020 | - color:
1021 | - 0
1022 | - 133
1023 | - 255
1024 | id: 145
1025 | isthing: 0
1026 | name: shower
1027 | - color:
1028 | - 255
1029 | - 214
1030 | - 0
1031 | id: 146
1032 | isthing: 0
1033 | name: radiator
1034 | - color:
1035 | - 25
1036 | - 194
1037 | - 194
1038 | id: 147
1039 | isthing: 0
1040 | name: glass, drinking glass
1041 | - color:
1042 | - 102
1043 | - 255
1044 | - 0
1045 | id: 148
1046 | isthing: 0
1047 | name: clock
1048 | - color:
1049 | - 92
1050 | - 0
1051 | - 255
1052 | id: 149
1053 | isthing: 0
1054 | name: flag
1055 |
--------------------------------------------------------------------------------
/segm/data/config/cityscapes.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = "CityscapesDataset"
3 | data_root = "data/cityscapes/"
4 | img_norm_cfg = dict(
5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
6 | )
7 | crop_size = (768, 768)
8 | max_ratio = 2
9 | train_pipeline = [
10 | dict(type="LoadImageFromFile"),
11 | dict(type="LoadAnnotations"),
12 | dict(type="Resize", img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
13 | dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75),
14 | dict(type="RandomFlip", prob=0.5),
15 | dict(type="PhotoMetricDistortion"),
16 | dict(type="Normalize", **img_norm_cfg),
17 | dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255),
18 | dict(type="DefaultFormatBundle"),
19 | dict(type="Collect", keys=["img", "gt_semantic_seg"]),
20 | ]
21 | val_pipeline = [
22 | dict(type="LoadImageFromFile"),
23 | dict(
24 | type="MultiScaleFlipAug",
25 | img_scale=(1024 * max_ratio, 1024),
26 | flip=False,
27 | transforms=[
28 | dict(type="Resize", keep_ratio=True),
29 | dict(type="RandomFlip"),
30 | dict(type="Normalize", **img_norm_cfg),
31 | dict(type="ImageToTensor", keys=["img"]),
32 | dict(type="Collect", keys=["img"]),
33 | ],
34 | ),
35 | ]
36 | test_pipeline = [
37 | dict(type="LoadImageFromFile"),
38 | dict(
39 | type="MultiScaleFlipAug",
40 | img_scale=(1024 * max_ratio, 1024),
41 | flip=False,
42 | transforms=[
43 | dict(type="Resize", keep_ratio=True),
44 | dict(type="RandomFlip"),
45 | dict(type="Normalize", **img_norm_cfg),
46 | dict(type="ImageToTensor", keys=["img"]),
47 | dict(type="Collect", keys=["img"]),
48 | ],
49 | ),
50 | ]
51 | data = dict(
52 | samples_per_gpu=2,
53 | workers_per_gpu=2,
54 | train=dict(
55 | type=dataset_type,
56 | data_root=data_root,
57 | img_dir="leftImg8bit/train",
58 | ann_dir="gtFine/train",
59 | pipeline=train_pipeline,
60 | ),
61 | trainval=dict(
62 | type=dataset_type,
63 | data_root=data_root,
64 | img_dir=["leftImg8bit/train", "leftImg8bit/val"],
65 | ann_dir=["gtFine/train", "gtFine/val"],
66 | pipeline=train_pipeline,
67 | ),
68 | val=dict(
69 | type=dataset_type,
70 | data_root=data_root,
71 | img_dir="leftImg8bit/val",
72 | ann_dir="gtFine/val",
73 | pipeline=test_pipeline,
74 | ),
75 | test=dict(
76 | type=dataset_type,
77 | data_root=data_root,
78 | img_dir="leftImg8bit/test",
79 | ann_dir="gtFine/test",
80 | pipeline=test_pipeline,
81 | ),
82 | )
83 |
--------------------------------------------------------------------------------
/segm/data/config/cityscapes.yml:
--------------------------------------------------------------------------------
1 | - color:
2 | - 128
3 | - 64
4 | - 128
5 | id: 0
6 | isthing: false
7 | name: road
8 | - color:
9 | - 244
10 | - 35
11 | - 232
12 | id: 1
13 | isthing: false
14 | name: sidewalk
15 | - color:
16 | - 70
17 | - 70
18 | - 70
19 | id: 2
20 | isthing: false
21 | name: building
22 | - color:
23 | - 102
24 | - 102
25 | - 156
26 | id: 3
27 | isthing: false
28 | name: wall
29 | - color:
30 | - 190
31 | - 153
32 | - 153
33 | id: 4
34 | isthing: false
35 | name: fence
36 | - color:
37 | - 153
38 | - 153
39 | - 153
40 | id: 5
41 | isthing: false
42 | name: pole
43 | - color:
44 | - 250
45 | - 170
46 | - 30
47 | id: 6
48 | isthing: false
49 | name: traffic light
50 | - color:
51 | - 220
52 | - 220
53 | - 0
54 | id: 7
55 | isthing: false
56 | name: traffic sign
57 | - color:
58 | - 107
59 | - 142
60 | - 35
61 | id: 8
62 | isthing: false
63 | name: vegetation
64 | - color:
65 | - 152
66 | - 251
67 | - 152
68 | id: 9
69 | isthing: false
70 | name: terrain
71 | - color:
72 | - 70
73 | - 130
74 | - 180
75 | id: 10
76 | isthing: false
77 | name: sky
78 | - color:
79 | - 220
80 | - 20
81 | - 60
82 | id: 11
83 | isthing: true
84 | name: person
85 | - color:
86 | - 255
87 | - 0
88 | - 0
89 | id: 12
90 | isthing: true
91 | name: rider
92 | - color:
93 | - 0
94 | - 0
95 | - 142
96 | id: 13
97 | isthing: true
98 | name: car
99 | - color:
100 | - 0
101 | - 0
102 | - 70
103 | id: 14
104 | isthing: true
105 | name: truck
106 | - color:
107 | - 0
108 | - 60
109 | - 100
110 | id: 15
111 | isthing: true
112 | name: bus
113 | - color:
114 | - 0
115 | - 80
116 | - 100
117 | id: 16
118 | isthing: true
119 | name: train
120 | - color:
121 | - 0
122 | - 0
123 | - 230
124 | id: 17
125 | isthing: true
126 | name: motorcycle
127 | - color:
128 | - 119
129 | - 11
130 | - 32
131 | id: 18
132 | isthing: true
133 | name: bicycle
134 |
--------------------------------------------------------------------------------
/segm/data/config/pascal_context.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = "PascalContextDataset"
3 | data_root = "data/VOCdevkit/VOC2010/"
4 | img_norm_cfg = dict(
5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
6 | )
7 |
8 | img_scale = (512, 512)
9 | crop_size = (512, 512)
10 | max_ratio = 8
11 | train_pipeline = [
12 | dict(type="LoadImageFromFile"),
13 | dict(type="LoadAnnotations"),
14 | dict(type="Resize", img_scale=img_scale, ratio_range=(0.5, 2.0)),
15 | dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75),
16 | dict(type="RandomFlip", prob=0.5),
17 | dict(type="PhotoMetricDistortion"),
18 | dict(type="Normalize", **img_norm_cfg),
19 | dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255),
20 | dict(type="DefaultFormatBundle"),
21 | dict(type="Collect", keys=["img", "gt_semantic_seg"]),
22 | ]
23 | val_pipeline = [
24 | dict(type="LoadImageFromFile"),
25 | dict(
26 | type="MultiScaleFlipAug",
27 | img_scale=(512 * max_ratio, 512),
28 | flip=False,
29 | transforms=[
30 | dict(type="Resize", keep_ratio=True),
31 | dict(type="RandomFlip"),
32 | dict(type="Normalize", **img_norm_cfg),
33 | dict(type="ImageToTensor", keys=["img"]),
34 | dict(type="Collect", keys=["img"]),
35 | ],
36 | ),
37 | ]
38 | test_pipeline = [
39 | dict(type="LoadImageFromFile"),
40 | dict(
41 | type="MultiScaleFlipAug",
42 | img_scale=(512 * max_ratio, 512),
43 | flip=False,
44 | transforms=[
45 | dict(type="Resize", keep_ratio=True),
46 | dict(type="RandomFlip"),
47 | dict(type="Normalize", **img_norm_cfg),
48 | dict(type="ImageToTensor", keys=["img"]),
49 | dict(type="Collect", keys=["img"]),
50 | ],
51 | ),
52 | ]
53 | data = dict(
54 | samples_per_gpu=4,
55 | workers_per_gpu=4,
56 | train=dict(
57 | type=dataset_type,
58 | data_root=data_root,
59 | img_dir="JPEGImages",
60 | ann_dir="SegmentationClassContext",
61 | split="ImageSets/SegmentationContext/train.txt",
62 | pipeline=train_pipeline,
63 | ),
64 | val=dict(
65 | type=dataset_type,
66 | data_root=data_root,
67 | img_dir="JPEGImages",
68 | ann_dir="SegmentationClassContext",
69 | split="ImageSets/SegmentationContext/val.txt",
70 | pipeline=val_pipeline,
71 | ),
72 | test=dict(
73 | type=dataset_type,
74 | data_root=data_root,
75 | img_dir="JPEGImages",
76 | ann_dir="SegmentationClassContext",
77 | split="ImageSets/SegmentationContext/val.txt",
78 | pipeline=test_pipeline,
79 | ),
80 | )
81 |
--------------------------------------------------------------------------------
/segm/data/config/pascal_context.yml:
--------------------------------------------------------------------------------
1 | - color:
2 | - 120
3 | - 120
4 | - 120
5 | id: 0
6 | name: background
7 | - color:
8 | - 180
9 | - 120
10 | - 120
11 | id: 1
12 | name: aeroplane
13 | - color:
14 | - 6
15 | - 230
16 | - 230
17 | id: 2
18 | name: bicycle
19 | - color:
20 | - 80
21 | - 50
22 | - 50
23 | id: 3
24 | name: bird
25 | - color:
26 | - 4
27 | - 200
28 | - 3
29 | id: 4
30 | name: boat
31 | - color:
32 | - 120
33 | - 120
34 | - 80
35 | id: 5
36 | name: bottle
37 | - color:
38 | - 140
39 | - 140
40 | - 140
41 | id: 6
42 | name: bus
43 | - color:
44 | - 204
45 | - 5
46 | - 255
47 | id: 7
48 | name: car
49 | - color:
50 | - 230
51 | - 230
52 | - 230
53 | id: 8
54 | name: cat
55 | - color:
56 | - 4
57 | - 250
58 | - 7
59 | id: 9
60 | name: chair
61 | - color:
62 | - 224
63 | - 5
64 | - 255
65 | id: 10
66 | name: cow
67 | - color:
68 | - 235
69 | - 255
70 | - 7
71 | id: 11
72 | name: table
73 | - color:
74 | - 150
75 | - 5
76 | - 61
77 | id: 12
78 | name: dog
79 | - color:
80 | - 120
81 | - 120
82 | - 70
83 | id: 13
84 | name: horse
85 | - color:
86 | - 8
87 | - 255
88 | - 51
89 | id: 14
90 | name: motorbike
91 | - color:
92 | - 255
93 | - 6
94 | - 82
95 | id: 15
96 | name: person
97 | - color:
98 | - 143
99 | - 255
100 | - 140
101 | id: 16
102 | name: pottedplant
103 | - color:
104 | - 204
105 | - 255
106 | - 4
107 | id: 17
108 | name: sheep
109 | - color:
110 | - 255
111 | - 51
112 | - 7
113 | id: 18
114 | name: sofa
115 | - color:
116 | - 204
117 | - 70
118 | - 3
119 | id: 19
120 | name: train
121 | - color:
122 | - 0
123 | - 102
124 | - 200
125 | id: 20
126 | name: tvmonitor
127 | - color:
128 | - 61
129 | - 230
130 | - 250
131 | id: 21
132 | name: bag
133 | - color:
134 | - 255
135 | - 6
136 | - 51
137 | id: 22
138 | name: bed
139 | - color:
140 | - 11
141 | - 102
142 | - 255
143 | id: 23
144 | name: bench
145 | - color:
146 | - 255
147 | - 7
148 | - 71
149 | id: 24
150 | name: book
151 | - color:
152 | - 255
153 | - 9
154 | - 224
155 | id: 25
156 | name: building
157 | - color:
158 | - 9
159 | - 7
160 | - 230
161 | id: 26
162 | name: cabinet
163 | - color:
164 | - 220
165 | - 220
166 | - 220
167 | id: 27
168 | name: ceiling
169 | - color:
170 | - 255
171 | - 9
172 | - 92
173 | id: 28
174 | name: cloth
175 | - color:
176 | - 112
177 | - 9
178 | - 255
179 | id: 29
180 | name: computer
181 | - color:
182 | - 8
183 | - 255
184 | - 214
185 | id: 30
186 | name: cup
187 | - color:
188 | - 7
189 | - 255
190 | - 224
191 | id: 31
192 | name: door
193 | - color:
194 | - 255
195 | - 184
196 | - 6
197 | id: 32
198 | name: fence
199 | - color:
200 | - 10
201 | - 255
202 | - 71
203 | id: 33
204 | name: floor
205 | - color:
206 | - 255
207 | - 41
208 | - 10
209 | id: 34
210 | name: flower
211 | - color:
212 | - 7
213 | - 255
214 | - 255
215 | id: 35
216 | name: food
217 | - color:
218 | - 224
219 | - 255
220 | - 8
221 | id: 36
222 | name: grass
223 | - color:
224 | - 102
225 | - 8
226 | - 255
227 | id: 37
228 | name: ground
229 | - color:
230 | - 255
231 | - 61
232 | - 6
233 | id: 38
234 | name: keyboard
235 | - color:
236 | - 255
237 | - 194
238 | - 7
239 | id: 39
240 | name: light
241 | - color:
242 | - 255
243 | - 122
244 | - 8
245 | id: 40
246 | name: mountain
247 | - color:
248 | - 0
249 | - 255
250 | - 20
251 | id: 41
252 | name: mouse
253 | - color:
254 | - 255
255 | - 8
256 | - 41
257 | id: 42
258 | name: curtain
259 | - color:
260 | - 255
261 | - 5
262 | - 153
263 | id: 43
264 | name: platform
265 | - color:
266 | - 6
267 | - 51
268 | - 255
269 | id: 44
270 | name: sign
271 | - color:
272 | - 235
273 | - 12
274 | - 255
275 | id: 45
276 | name: plate
277 | - color:
278 | - 160
279 | - 150
280 | - 20
281 | id: 46
282 | name: road
283 | - color:
284 | - 0
285 | - 163
286 | - 255
287 | id: 47
288 | name: rock
289 | - color:
290 | - 140
291 | - 140
292 | - 140
293 | id: 48
294 | name: shelves
295 | - color:
296 | - 250
297 | - 10
298 | - 15
299 | id: 49
300 | name: sidewalk
301 | - color:
302 | - 20
303 | - 255
304 | - 0
305 | id: 50
306 | name: sky
307 | - color:
308 | - 31
309 | - 255
310 | - 0
311 | id: 51
312 | name: snow
313 | - color:
314 | - 255
315 | - 31
316 | - 0
317 | id: 52
318 | name: bedclothes
319 | - color:
320 | - 255
321 | - 224
322 | - 0
323 | id: 53
324 | name: track
325 | - color:
326 | - 153
327 | - 255
328 | - 0
329 | id: 54
330 | name: tree
331 | - color:
332 | - 0
333 | - 0
334 | - 255
335 | id: 55
336 | name: truck
337 | - color:
338 | - 255
339 | - 71
340 | - 0
341 | id: 56
342 | name: wall
343 | - color:
344 | - 0
345 | - 235
346 | - 255
347 | id: 57
348 | name: water
349 | - color:
350 | - 0
351 | - 173
352 | - 255
353 | id: 58
354 | name: window
355 | - color:
356 | - 31
357 | - 0
358 | - 255
359 | id: 59
360 | name: wood
361 |
--------------------------------------------------------------------------------
/segm/data/factory.py:
--------------------------------------------------------------------------------
1 | import segm.utils.torch as ptu
2 |
3 | from segm.data import ImagenetDataset
4 | from segm.data import ADE20KSegmentation
5 | from segm.data import PascalContextDataset
6 | from segm.data import CityscapesDataset
7 | from segm.data import Loader
8 |
9 |
10 | def create_dataset(dataset_kwargs):
11 | dataset_kwargs = dataset_kwargs.copy()
12 | dataset_name = dataset_kwargs.pop("dataset")
13 | batch_size = dataset_kwargs.pop("batch_size")
14 | num_workers = dataset_kwargs.pop("num_workers")
15 | split = dataset_kwargs.pop("split")
16 |
17 |     # instantiate the requested dataset
18 | if dataset_name == "imagenet":
19 | dataset_kwargs.pop("patch_size")
20 | dataset = ImagenetDataset(split=split, **dataset_kwargs)
21 | elif dataset_name == "ade20k":
22 | dataset = ADE20KSegmentation(split=split, **dataset_kwargs)
23 | elif dataset_name == "pascal_context":
24 | dataset = PascalContextDataset(split=split, **dataset_kwargs)
25 | elif dataset_name == "cityscapes":
26 | dataset = CityscapesDataset(split=split, **dataset_kwargs)
27 | else:
28 | raise ValueError(f"Dataset {dataset_name} is unknown.")
29 |
30 | dataset = Loader(
31 | dataset=dataset,
32 | batch_size=batch_size,
33 | num_workers=num_workers,
34 | distributed=ptu.distributed,
35 | split=split,
36 | )
37 | return dataset
38 |
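For illustration, a minimal sketch of how create_dataset is called with the same keyword arguments the evaluation scripts build; the ImageNet path and worker count are placeholder assumptions, and "dataset", "batch_size", "num_workers" and "split" are popped before the rest is forwarded to the dataset class:

import segm.utils.torch as ptu
from segm.data.factory import create_dataset
from segm.data.utils import STATS

ptu.set_gpu_mode(False)  # assumption: CPU-only run, no distributed setup
dataset_kwargs = dict(
    dataset="imagenet",
    root_dir="/path/to/imagenet",   # placeholder: ImageFolder-style train/val layout
    image_size=(224, 224),
    crop_size=(224, 224),
    patch_size=16,                  # popped and unused for the imagenet dataset
    batch_size=32,
    num_workers=4,
    split="val",
    normalization=STATS["vit"],
)
val_loader = create_dataset(dataset_kwargs)  # a Loader wrapping ImagenetDataset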
--------------------------------------------------------------------------------
/segm/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from pathlib import Path
4 |
5 | from torch.utils.data import Dataset
6 | from torchvision import datasets
7 | from torchvision import transforms
8 | from PIL import Image
9 |
10 | from segm.data import utils
11 | from segm.config import dataset_dir
12 |
13 |
14 | class ImagenetDataset(Dataset):
15 | def __init__(
16 | self,
17 | root_dir,
18 | image_size=224,
19 | crop_size=224,
20 | split="train",
21 | normalization="vit",
22 | ):
23 | super().__init__()
24 | assert image_size[0] == image_size[1]
25 |
26 | self.path = Path(root_dir) / split
27 | self.crop_size = crop_size
28 | self.image_size = image_size
29 | self.split = split
30 | self.normalization = normalization
31 |
32 | if split == "train":
33 | self.transform = transforms.Compose(
34 | [
35 | transforms.RandomResizedCrop(self.crop_size, interpolation=3),
36 | transforms.RandomHorizontalFlip(),
37 | transforms.ToTensor(),
38 | ]
39 | )
40 | else:
41 | self.transform = transforms.Compose(
42 | [
43 | transforms.Resize(image_size[0] + 32, interpolation=3),
44 | transforms.CenterCrop(self.crop_size),
45 | transforms.ToTensor(),
46 | ]
47 | )
48 |
49 | self.base_dataset = datasets.ImageFolder(self.path, self.transform)
50 | self.n_cls = 1000
51 |
52 | @property
53 | def unwrapped(self):
54 | return self
55 |
56 | def __len__(self):
57 | return len(self.base_dataset)
58 |
59 | def __getitem__(self, idx):
60 | im, target = self.base_dataset[idx]
61 | im = utils.rgb_normalize(im, self.normalization)
62 | return dict(im=im, target=target)
63 |
--------------------------------------------------------------------------------
/segm/data/loader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from torch.utils.data.distributed import DistributedSampler
3 |
4 | import segm.utils.torch as ptu
5 |
6 |
7 | class Loader(DataLoader):
8 | def __init__(self, dataset, batch_size, num_workers, distributed, split):
9 | if distributed:
10 | sampler = DistributedSampler(dataset, shuffle=True)
11 | super().__init__(
12 | dataset,
13 | batch_size=batch_size,
14 | shuffle=False,
15 | num_workers=num_workers,
16 | pin_memory=True,
17 | sampler=sampler,
18 | )
19 | else:
20 | super().__init__(
21 | dataset,
22 | batch_size=batch_size,
23 | shuffle=True,
24 | num_workers=num_workers,
25 | pin_memory=True,
26 | )
27 |
28 | self.base_dataset = self.dataset
29 |
30 | @property
31 | def unwrapped(self):
32 | return self.base_dataset.unwrapped
33 |
34 | def set_epoch(self, epoch):
35 | if isinstance(self.sampler, DistributedSampler):
36 | self.sampler.set_epoch(epoch)
37 |
38 | def get_diagnostics(self, logger):
39 | return self.base_dataset.get_diagnostics(logger)
40 |
41 | def get_snapshot(self):
42 | return self.base_dataset.get_snapshot()
43 |
44 | def end_epoch(self, epoch):
45 | return self.base_dataset.end_epoch(epoch)
46 |
--------------------------------------------------------------------------------
/segm/data/pascal_context.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from segm.data.base import BaseMMSeg
4 | from segm.data import utils
5 | from segm.config import dataset_dir
6 |
7 | PASCAL_CONTEXT_CONFIG_PATH = Path(__file__).parent / "config" / "pascal_context.py"
8 | PASCAL_CONTEXT_CATS_PATH = Path(__file__).parent / "config" / "pascal_context.yml"
9 |
10 |
11 | class PascalContextDataset(BaseMMSeg):
12 | def __init__(self, image_size, crop_size, split, **kwargs):
13 | super().__init__(
14 | image_size, crop_size, split, PASCAL_CONTEXT_CONFIG_PATH, **kwargs
15 | )
16 | self.names, self.colors = utils.dataset_cat_description(
17 | PASCAL_CONTEXT_CATS_PATH
18 | )
19 | self.n_cls = 60
20 | self.ignore_label = 255
21 | self.reduce_zero_label = False
22 |
23 | def update_default_config(self, config):
24 | root_dir = dataset_dir()
25 | path = Path(root_dir) / "pcontext"
26 | config.data_root = path
27 | if self.split == "train":
28 | config.data.train.data_root = path / "VOCdevkit/VOC2010/"
29 | elif self.split == "val":
30 | config.data.val.data_root = path / "VOCdevkit/VOC2010/"
31 | elif self.split == "test":
32 | raise ValueError("Test split is not valid for Pascal Context dataset")
33 | config = super().update_default_config(config)
34 | return config
35 |
36 | def test_post_process(self, labels):
37 | return labels
38 |
--------------------------------------------------------------------------------
/segm/data/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | import numpy as np
4 | import yaml
5 | from pathlib import Path
6 |
7 | IGNORE_LABEL = 255
8 | STATS = {
9 | "vit": {"mean": (0.5, 0.5, 0.5), "std": (0.5, 0.5, 0.5)},
10 | "deit": {"mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)},
11 | }
12 |
13 |
14 | def seg_to_rgb(seg, colors):
15 | im = torch.zeros((seg.shape[0], seg.shape[1], seg.shape[2], 3)).float()
16 | cls = torch.unique(seg)
17 | for cl in cls:
18 | color = colors[int(cl)]
19 | if len(color.shape) > 1:
20 | color = color[0]
21 | im[seg == cl] = color
22 | return im
23 |
24 |
25 | def dataset_cat_description(path, cmap=None):
26 | desc = yaml.load(open(path, "r"), Loader=yaml.FullLoader)
27 | colors = {}
28 | names = []
29 | for i, cat in enumerate(desc):
30 | names.append(cat["name"])
31 | if "color" in cat:
32 | colors[cat["id"]] = torch.tensor(cat["color"]).float() / 255
33 | else:
34 | colors[cat["id"]] = torch.tensor(cmap[cat["id"]]).float()
35 | colors[IGNORE_LABEL] = torch.tensor([0.0, 0.0, 0.0]).float()
36 | return names, colors
37 |
38 |
39 | def rgb_normalize(x, stats):
40 | """
41 | x : C x *
42 |     x in [0, 1]
43 | """
44 | return F.normalize(x, stats["mean"], stats["std"])
45 |
46 |
47 | def rgb_denormalize(x, stats):
48 | """
49 | x : N x C x *
50 |     x in [-1, 1]
51 | """
52 | mean = torch.tensor(stats["mean"])
53 | std = torch.tensor(stats["std"])
54 | for i in range(3):
55 | x[:, i, :, :] = x[:, i, :, :] * std[i] + mean[i]
56 | return x
57 |
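A small self-contained check of how the normalization helpers compose, using the "vit" stats defined above; the round trip recovers the input up to floating point error:

import torch
from segm.data.utils import STATS, rgb_normalize, rgb_denormalize

im = torch.rand(3, 64, 64)                   # C x H x W, values in [0, 1]
im_norm = rgb_normalize(im, STATS["vit"])    # (x - mean) / std per channel
im_back = rgb_denormalize(im_norm.unsqueeze(0).clone(), STATS["vit"])  # expects N x C x H x W
assert torch.allclose(im_back[0], im, atol=1e-6)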
--------------------------------------------------------------------------------
/segm/engine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 | from segm.utils.logger import MetricLogger
5 | from segm.metrics import gather_data, compute_metrics
6 | from segm.model import utils
7 | from segm.data.utils import IGNORE_LABEL
8 | import segm.utils.torch as ptu
9 |
10 |
11 | def train_one_epoch(
12 | model,
13 | data_loader,
14 | optimizer,
15 | lr_scheduler,
16 | epoch,
17 | amp_autocast,
18 | loss_scaler,
19 | ):
20 | criterion = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_LABEL)
21 | logger = MetricLogger(delimiter=" ")
22 | header = f"Epoch: [{epoch}]"
23 | print_freq = 100
24 |
25 | model.train()
26 | data_loader.set_epoch(epoch)
27 | num_updates = epoch * len(data_loader)
28 | for batch in logger.log_every(data_loader, print_freq, header):
29 | im = batch["im"].to(ptu.device)
30 | seg_gt = batch["segmentation"].long().to(ptu.device)
31 |
32 | with amp_autocast():
33 | seg_pred = model.forward(im)
34 | loss = criterion(seg_pred, seg_gt)
35 |
36 | loss_value = loss.item()
37 | if not math.isfinite(loss_value):
38 | print("Loss is {}, stopping training".format(loss_value), force=True)
39 |
40 | optimizer.zero_grad()
41 | if loss_scaler is not None:
42 | loss_scaler(
43 | loss,
44 | optimizer,
45 | parameters=model.parameters(),
46 | )
47 | else:
48 | loss.backward()
49 | optimizer.step()
50 |
51 | num_updates += 1
52 | lr_scheduler.step_update(num_updates=num_updates)
53 |
54 | torch.cuda.synchronize()
55 |
56 | logger.update(
57 | loss=loss.item(),
58 | learning_rate=optimizer.param_groups[0]["lr"],
59 | )
60 |
61 | return logger
62 |
63 |
64 | @torch.no_grad()
65 | def evaluate(
66 | model,
67 | data_loader,
68 | val_seg_gt,
69 | window_size,
70 | window_stride,
71 | amp_autocast,
72 | ):
73 | model_without_ddp = model
74 | if hasattr(model, "module"):
75 | model_without_ddp = model.module
76 | logger = MetricLogger(delimiter=" ")
77 | header = "Eval:"
78 | print_freq = 50
79 |
80 | val_seg_pred = {}
81 | model.eval()
82 | for batch in logger.log_every(data_loader, print_freq, header):
83 | ims = [im.to(ptu.device) for im in batch["im"]]
84 | ims_metas = batch["im_metas"]
85 | ori_shape = ims_metas[0]["ori_shape"]
86 | ori_shape = (ori_shape[0].item(), ori_shape[1].item())
87 | filename = batch["im_metas"][0]["ori_filename"][0]
88 |
89 | with amp_autocast():
90 | seg_pred = utils.inference(
91 | model_without_ddp,
92 | ims,
93 | ims_metas,
94 | ori_shape,
95 | window_size,
96 | window_stride,
97 | batch_size=1,
98 | )
99 | seg_pred = seg_pred.argmax(0)
100 |
101 | seg_pred = seg_pred.cpu().numpy()
102 | val_seg_pred[filename] = seg_pred
103 |
104 | val_seg_pred = gather_data(val_seg_pred)
105 | scores = compute_metrics(
106 | val_seg_pred,
107 | val_seg_gt,
108 | data_loader.unwrapped.n_cls,
109 | ignore_index=IGNORE_LABEL,
110 | distributed=ptu.distributed,
111 | )
112 |
113 | for k, v in scores.items():
114 | logger.update(**{f"{k}": v, "n": 1})
115 |
116 | return logger
117 |
--------------------------------------------------------------------------------
/segm/eval/accuracy.py:
--------------------------------------------------------------------------------
1 | import click
2 | import torch
3 |
4 | import segm.utils.torch as ptu
5 |
6 | from segm.utils.logger import MetricLogger
7 |
8 | from segm.model.factory import create_vit
9 | from segm.data.factory import create_dataset
10 | from segm.data.utils import STATS
11 | from segm.metrics import accuracy
12 | from segm import config
13 |
14 |
15 | def compute_labels(model, batch):
16 | im = batch["im"]
17 | target = batch["target"]
18 |
19 | with torch.no_grad():
20 | with torch.cuda.amp.autocast():
21 | output = model.forward(im)
22 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
23 |
24 | return acc1.item(), acc5.item()
25 |
26 |
27 | def eval_dataset(model, dataset_kwargs):
28 | db = create_dataset(dataset_kwargs)
29 | print_freq = 20
30 | header = ""
31 | logger = MetricLogger(delimiter=" ")
32 |
33 | for batch in logger.log_every(db, print_freq, header):
34 | for k, v in batch.items():
35 | batch[k] = v.to(ptu.device)
36 | acc1, acc5 = compute_labels(model, batch)
37 | batch_size = batch["im"].size(0)
38 | logger.update(acc1=acc1, n=batch_size)
39 | logger.update(acc5=acc5, n=batch_size)
40 | print(f"Imagenet accuracy: {logger}")
41 |
42 |
43 | @click.command()
44 | @click.argument("backbone", type=str)
45 | @click.option("--imagenet-dir", type=str)
46 | @click.option("-bs", "--batch-size", default=32, type=int)
47 | @click.option("-nw", "--num-workers", default=10, type=int)
48 | @click.option("-gpu", "--gpu/--no-gpu", default=True, is_flag=True)
49 | def main(backbone, imagenet_dir, batch_size, num_workers, gpu):
50 | ptu.set_gpu_mode(gpu)
51 | cfg = config.load_config()
52 | cfg = cfg["model"][backbone]
53 | cfg["backbone"] = backbone
54 | cfg["image_size"] = (cfg["image_size"], cfg["image_size"])
55 |
56 | dataset_kwargs = dict(
57 | dataset="imagenet",
58 | root_dir=imagenet_dir,
59 | image_size=cfg["image_size"],
60 | crop_size=cfg["image_size"],
61 | patch_size=cfg["patch_size"],
62 | batch_size=batch_size,
63 | num_workers=num_workers,
64 | split="val",
65 | normalization=STATS[cfg["normalization"]],
66 | )
67 |
68 | model = create_vit(cfg)
69 | model.to(ptu.device)
70 | model.eval()
71 | eval_dataset(model, dataset_kwargs)
72 |
73 |
74 | if __name__ == "__main__":
75 | main()
76 |
--------------------------------------------------------------------------------
/segm/eval/miou.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import click
3 | from pathlib import Path
4 | import yaml
5 | import numpy as np
6 | from PIL import Image
7 | import shutil
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.nn.parallel import DistributedDataParallel as DDP
12 |
13 | from segm.utils import distributed
14 | from segm.utils.logger import MetricLogger
15 | import segm.utils.torch as ptu
16 |
17 | from segm.model.factory import load_model
18 | from segm.data.factory import create_dataset
19 | from segm.metrics import gather_data, compute_metrics
20 |
21 | from segm.model.utils import inference
22 | from segm.data.utils import seg_to_rgb, rgb_denormalize, IGNORE_LABEL
23 | from segm import config
24 |
25 |
26 | def blend_im(im, seg, alpha=0.5):
27 | pil_im = Image.fromarray(im)
28 | pil_seg = Image.fromarray(seg)
29 | im_blend = Image.blend(pil_im, pil_seg, alpha).convert("RGB")
30 | return np.asarray(im_blend)
31 |
32 |
33 | def save_im(save_dir, save_name, im, seg_pred, seg_gt, colors, blend, normalization):
34 | seg_rgb = seg_to_rgb(seg_gt[None], colors)
35 | pred_rgb = seg_to_rgb(seg_pred[None], colors)
36 | im_unnorm = rgb_denormalize(im, normalization)
37 | save_dir = Path(save_dir)
38 |
39 | # save images
40 | im_uint = (im_unnorm.permute(0, 2, 3, 1).cpu().numpy()).astype(np.uint8)
41 | seg_rgb_uint = (255 * seg_rgb.cpu().numpy()).astype(np.uint8)
42 | seg_pred_uint = (255 * pred_rgb.cpu().numpy()).astype(np.uint8)
43 | for i in range(pred_rgb.shape[0]):
44 | if blend:
45 | blend_pred = blend_im(im_uint[i], seg_pred_uint[i])
46 | blend_gt = blend_im(im_uint[i], seg_rgb_uint[i])
47 | ims = (im_uint[i], blend_pred, blend_gt)
48 | else:
49 | ims = (im_uint[i], seg_pred_uint[i], seg_rgb_uint[i])
50 | for im, im_dir in zip(
51 | ims, (save_dir / "input", save_dir / "pred", save_dir / "gt"),
52 | ):
53 | pil_out = Image.fromarray(im)
54 | im_dir.mkdir(exist_ok=True)
55 | pil_out.save(im_dir / save_name)
56 |
57 |
58 | def process_batch(
59 | model, batch, window_size, window_stride, window_batch_size,
60 | ):
61 | ims = batch["im"]
62 | ims_metas = batch["im_metas"]
63 | ori_shape = ims_metas[0]["ori_shape"]
64 | ori_shape = (ori_shape[0].item(), ori_shape[1].item())
65 | filename = batch["im_metas"][0]["ori_filename"][0]
66 |
67 | model_without_ddp = model
68 | if ptu.distributed:
69 | model_without_ddp = model.module
70 | seg_pred = inference(
71 | model_without_ddp,
72 | ims,
73 | ims_metas,
74 | ori_shape,
75 | window_size,
76 | window_stride,
77 | window_batch_size,
78 | )
79 | seg_pred = seg_pred.argmax(0)
80 | im = F.interpolate(ims[-1], ori_shape, mode="bilinear")
81 |
82 | return filename, im.cpu(), seg_pred.cpu()
83 |
84 |
85 | def eval_dataset(
86 | model,
87 | multiscale,
88 | model_dir,
89 | blend,
90 | window_size,
91 | window_stride,
92 | window_batch_size,
93 | save_images,
94 | frac_dataset,
95 | dataset_kwargs,
96 | ):
97 | db = create_dataset(dataset_kwargs)
98 | normalization = db.dataset.normalization
99 | dataset_name = dataset_kwargs["dataset"]
100 | im_size = dataset_kwargs["image_size"]
101 | cat_names = db.base_dataset.names
102 | n_cls = db.unwrapped.n_cls
103 | if multiscale:
104 | db.dataset.set_multiscale_mode()
105 |
106 | logger = MetricLogger(delimiter=" ")
107 | header = ""
108 | print_freq = 50
109 |
110 | ims = {}
111 | seg_pred_maps = {}
112 | idx = 0
113 | for batch in logger.log_every(db, print_freq, header):
114 | colors = batch["colors"]
115 | filename, im, seg_pred = process_batch(
116 | model, batch, window_size, window_stride, window_batch_size,
117 | )
118 | ims[filename] = im
119 | seg_pred_maps[filename] = seg_pred
120 | idx += 1
121 | if idx > len(db) * frac_dataset:
122 | break
123 |
124 | seg_gt_maps = db.dataset.get_gt_seg_maps()
125 | if save_images:
126 | save_dir = model_dir / "images"
127 | if ptu.dist_rank == 0:
128 | if save_dir.exists():
129 | shutil.rmtree(save_dir)
130 | save_dir.mkdir()
131 | if ptu.distributed:
132 | torch.distributed.barrier()
133 |
134 | for name in sorted(ims):
135 | instance_dir = save_dir
136 | filename = name
137 |
138 | if dataset_name == "cityscapes":
139 | filename_list = name.split("/")
140 | instance_dir = instance_dir / filename_list[0]
141 | filename = filename_list[-1]
142 | if not instance_dir.exists():
143 | instance_dir.mkdir()
144 |
145 | save_im(
146 | instance_dir,
147 | filename,
148 | ims[name],
149 | seg_pred_maps[name],
150 | torch.tensor(seg_gt_maps[name]),
151 | colors,
152 | blend,
153 | normalization,
154 | )
155 | if ptu.dist_rank == 0:
156 | shutil.make_archive(save_dir, "zip", save_dir)
157 | # shutil.rmtree(save_dir)
158 | print(f"Saved eval images in {save_dir}.zip")
159 |
160 | if ptu.distributed:
161 | torch.distributed.barrier()
162 | seg_pred_maps = gather_data(seg_pred_maps)
163 |
164 | scores = compute_metrics(
165 | seg_pred_maps,
166 | seg_gt_maps,
167 | n_cls,
168 | ignore_index=IGNORE_LABEL,
169 | ret_cat_iou=True,
170 | distributed=ptu.distributed,
171 | )
172 |
173 | if ptu.dist_rank == 0:
174 | scores["inference"] = "single_scale" if not multiscale else "multi_scale"
175 | suffix = "ss" if not multiscale else "ms"
176 | scores["cat_iou"] = np.round(100 * scores["cat_iou"], 2).tolist()
177 | for k, v in scores.items():
178 | if k != "cat_iou" and k != "inference":
179 | scores[k] = v.item()
180 | if k != "cat_iou":
181 | print(f"{k}: {scores[k]}")
182 | scores_str = yaml.dump(scores)
183 | with open(model_dir / f"scores_{suffix}.yml", "w") as f:
184 | f.write(scores_str)
185 |
186 |
187 | @click.command()
188 | @click.argument("model_path", type=str)
189 | @click.argument("dataset_name", type=str)
190 | @click.option("--im-size", default=None, type=int)
191 | @click.option("--multiscale/--singlescale", default=False, is_flag=True)
192 | @click.option("--blend/--no-blend", default=True, is_flag=True)
193 | @click.option("--window-size", default=None, type=int)
194 | @click.option("--window-stride", default=None, type=int)
195 | @click.option("--window-batch-size", default=4, type=int)
196 | @click.option("--save-images/--no-save-images", default=False, is_flag=True)
197 | @click.option("-frac-dataset", "--frac-dataset", default=1.0, type=float)
198 | def main(
199 | model_path,
200 | dataset_name,
201 | im_size,
202 | multiscale,
203 | blend,
204 | window_size,
205 | window_stride,
206 | window_batch_size,
207 | save_images,
208 | frac_dataset,
209 | ):
210 |
211 | model_dir = Path(model_path).parent
212 |
213 | # start distributed mode
214 | ptu.set_gpu_mode(True)
215 | distributed.init_process()
216 |
217 | model, variant = load_model(model_path)
218 | patch_size = model.patch_size
219 | model.eval()
220 | model.to(ptu.device)
221 | if ptu.distributed:
222 | model = DDP(model, device_ids=[ptu.device], find_unused_parameters=True)
223 |
224 | cfg = config.load_config()
225 | dataset_cfg = cfg["dataset"][dataset_name]
226 | normalization = variant["dataset_kwargs"]["normalization"]
227 | if im_size is None:
228 | im_size = dataset_cfg.get("im_size", variant["dataset_kwargs"]["image_size"])
229 | if window_size is None:
230 | window_size = variant["dataset_kwargs"]["crop_size"]
231 | if window_stride is None:
232 | window_stride = variant["dataset_kwargs"]["crop_size"] - 32
233 |
234 | dataset_kwargs = dict(
235 | dataset=dataset_name,
236 | image_size=im_size,
237 | crop_size=im_size,
238 | patch_size=patch_size,
239 | batch_size=1,
240 | num_workers=10,
241 | split="val",
242 | normalization=normalization,
243 | crop=False,
244 | rep_aug=False,
245 | )
246 |
247 | eval_dataset(
248 | model,
249 | multiscale,
250 | model_dir,
251 | blend,
252 | window_size,
253 | window_stride,
254 | window_batch_size,
255 | save_images,
256 | frac_dataset,
257 | dataset_kwargs,
258 | )
259 |
260 | distributed.barrier()
261 | distributed.destroy_process()
262 |     sys.exit(0)
263 |
264 |
265 | if __name__ == "__main__":
266 | main()
267 |
--------------------------------------------------------------------------------
/segm/inference.py:
--------------------------------------------------------------------------------
1 | import click
2 | from tqdm import tqdm
3 | from pathlib import Path
4 | from PIL import Image
5 | import numpy as np
6 | import torchvision.transforms.functional as F
7 |
8 | import segm.utils.torch as ptu
9 |
10 | from segm.data.utils import STATS
11 | from segm.data.ade20k import ADE20K_CATS_PATH
12 | from segm.data.utils import dataset_cat_description, seg_to_rgb
13 |
14 | from segm.model.factory import load_model
15 | from segm.model.utils import inference
16 |
17 |
18 | @click.command()
19 | @click.option("--model-path", type=str)
20 | @click.option("--input-dir", "-i", type=str, help="folder with input images")
21 | @click.option("--output-dir", "-o", type=str, help="folder with output images")
22 | @click.option("--gpu/--cpu", default=True, is_flag=True)
23 | def main(model_path, input_dir, output_dir, gpu):
24 | ptu.set_gpu_mode(gpu)
25 |
26 | model_dir = Path(model_path).parent
27 | model, variant = load_model(model_path)
28 | model.to(ptu.device)
29 |
30 | normalization_name = variant["dataset_kwargs"]["normalization"]
31 | normalization = STATS[normalization_name]
32 | cat_names, cat_colors = dataset_cat_description(ADE20K_CATS_PATH)
33 |
34 | input_dir = Path(input_dir)
35 | output_dir = Path(output_dir)
36 | output_dir.mkdir(exist_ok=True)
37 |
38 | list_dir = list(input_dir.iterdir())
39 | for filename in tqdm(list_dir, ncols=80):
40 |         pil_im = Image.open(filename).convert("RGB")
41 | im = F.pil_to_tensor(pil_im).float() / 255
42 | im = F.normalize(im, normalization["mean"], normalization["std"])
43 | im = im.to(ptu.device).unsqueeze(0)
44 |
45 | im_meta = dict(flip=False)
46 | logits = inference(
47 | model,
48 | [im],
49 | [im_meta],
50 | ori_shape=im.shape[2:4],
51 | window_size=variant["inference_kwargs"]["window_size"],
52 | window_stride=variant["inference_kwargs"]["window_stride"],
53 | batch_size=2,
54 | )
55 | seg_map = logits.argmax(0, keepdim=True)
56 | seg_rgb = seg_to_rgb(seg_map, cat_colors)
57 | seg_rgb = (255 * seg_rgb.cpu().numpy()).astype(np.uint8)
58 | pil_seg = Image.fromarray(seg_rgb[0])
59 |
60 | pil_blend = Image.blend(pil_im, pil_seg, 0.5).convert("RGB")
61 | pil_blend.save(output_dir / filename.name)
62 |
63 |
64 | if __name__ == "__main__":
65 | main()
66 |
--------------------------------------------------------------------------------
/segm/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.distributed as dist
4 | import segm.utils.torch as ptu
5 |
6 | import os
7 | import pickle as pkl
8 | from pathlib import Path
9 | import tempfile
10 | import shutil
11 | from mmseg.core import mean_iou
12 |
13 | """
14 | ImageNet classification accuracy
15 | """
16 |
17 |
18 | def accuracy(output, target, topk=(1,)):
19 | """
20 | https://github.com/pytorch/examples/blob/master/imagenet/main.py
21 | Computes the accuracy over the k top predictions for the specified values of k
22 | """
23 | with torch.no_grad():
24 | maxk = max(topk)
25 | batch_size = target.size(0)
26 |
27 | _, pred = output.topk(maxk, 1, True, True)
28 | pred = pred.t()
29 | correct = pred.eq(target.view(1, -1).expand_as(pred))
30 |
31 | res = []
32 | for k in topk:
33 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
34 | correct_k /= batch_size
35 | res.append(correct_k)
36 | return res
37 |
38 |
39 | """
40 | Segmentation mean IoU
41 | based on collect_results_cpu
42 | https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/apis/test.py#L160-L200
43 | """
44 |
45 |
46 | def gather_data(seg_pred, tmp_dir=None):
47 | """
48 | distributed data gathering
49 | prediction and ground truth are stored in a common tmp directory
50 | and loaded on the master node to compute metrics
51 | """
52 | if tmp_dir is None:
53 | tmpprefix = os.path.expandvars("$DATASET/temp")
54 | else:
55 | tmpprefix = os.path.expandvars(tmp_dir)
56 | MAX_LEN = 512
57 | # 32 is whitespace
58 | dir_tensor = torch.full((MAX_LEN,), 32, dtype=torch.uint8, device=ptu.device)
59 | if ptu.dist_rank == 0:
60 | tmpdir = tempfile.mkdtemp(prefix=tmpprefix)
61 | tmpdir = torch.tensor(
62 | bytearray(tmpdir.encode()), dtype=torch.uint8, device=ptu.device
63 | )
64 | dir_tensor[: len(tmpdir)] = tmpdir
65 |     # broadcast tmpdir from rank 0 to the other nodes
66 | dist.broadcast(dir_tensor, 0)
67 | tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
68 | tmpdir = Path(tmpdir)
69 | """
70 | Save results in temp file and load them on main process
71 | """
72 | tmp_file = tmpdir / f"part_{ptu.dist_rank}.pkl"
73 | pkl.dump(seg_pred, open(tmp_file, "wb"))
74 | dist.barrier()
75 | seg_pred = {}
76 | if ptu.dist_rank == 0:
77 | for i in range(ptu.world_size):
78 | part_seg_pred = pkl.load(open(tmpdir / f"part_{i}.pkl", "rb"))
79 | seg_pred.update(part_seg_pred)
80 | shutil.rmtree(tmpdir)
81 | return seg_pred
82 |
83 |
84 | def compute_metrics(
85 | seg_pred,
86 | seg_gt,
87 | n_cls,
88 | ignore_index=None,
89 | ret_cat_iou=False,
90 | tmp_dir=None,
91 | distributed=False,
92 | ):
93 | ret_metrics_mean = torch.zeros(3, dtype=float, device=ptu.device)
94 | if ptu.dist_rank == 0:
95 | list_seg_pred = []
96 | list_seg_gt = []
97 | keys = sorted(seg_pred.keys())
98 | for k in keys:
99 | list_seg_pred.append(np.asarray(seg_pred[k]))
100 | list_seg_gt.append(np.asarray(seg_gt[k]))
101 | ret_metrics = mean_iou(
102 | results=list_seg_pred,
103 | gt_seg_maps=list_seg_gt,
104 | num_classes=n_cls,
105 | ignore_index=ignore_index,
106 | )
107 | ret_metrics = [ret_metrics["aAcc"], ret_metrics["Acc"], ret_metrics["IoU"]]
108 | ret_metrics_mean = torch.tensor(
109 | [
110 |             np.round(np.nanmean(ret_metric.astype(float)) * 100, 2)
111 | for ret_metric in ret_metrics
112 | ],
113 | dtype=float,
114 | device=ptu.device,
115 | )
116 | cat_iou = ret_metrics[2]
117 | # broadcast metrics from 0 to all nodes
118 | if distributed:
119 | dist.broadcast(ret_metrics_mean, 0)
120 | pix_acc, mean_acc, miou = ret_metrics_mean
121 | ret = dict(pixel_accuracy=pix_acc, mean_accuracy=mean_acc, mean_iou=miou)
122 | if ret_cat_iou and ptu.dist_rank == 0:
123 | ret["cat_iou"] = cat_iou
124 | return ret
125 |
--------------------------------------------------------------------------------
/segm/model/blocks.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from 2020 Ross Wightman
3 | https://github.com/rwightman/pytorch-image-models
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | from einops import rearrange
9 | from pathlib import Path
10 |
11 | import torch.nn.functional as F
12 |
13 | from timm.models.layers import DropPath
14 |
15 |
16 | class FeedForward(nn.Module):
17 | def __init__(self, dim, hidden_dim, dropout, out_dim=None):
18 | super().__init__()
19 | self.fc1 = nn.Linear(dim, hidden_dim)
20 | self.act = nn.GELU()
21 | if out_dim is None:
22 | out_dim = dim
23 | self.fc2 = nn.Linear(hidden_dim, out_dim)
24 | self.drop = nn.Dropout(dropout)
25 |
26 | @property
27 | def unwrapped(self):
28 | return self
29 |
30 | def forward(self, x):
31 | x = self.fc1(x)
32 | x = self.act(x)
33 | x = self.drop(x)
34 | x = self.fc2(x)
35 | x = self.drop(x)
36 | return x
37 |
38 |
39 | class Attention(nn.Module):
40 | def __init__(self, dim, heads, dropout):
41 | super().__init__()
42 | self.heads = heads
43 | head_dim = dim // heads
44 | self.scale = head_dim ** -0.5
45 | self.attn = None
46 |
47 | self.qkv = nn.Linear(dim, dim * 3)
48 | self.attn_drop = nn.Dropout(dropout)
49 | self.proj = nn.Linear(dim, dim)
50 | self.proj_drop = nn.Dropout(dropout)
51 |
52 | @property
53 | def unwrapped(self):
54 | return self
55 |
56 | def forward(self, x, mask=None):
57 | B, N, C = x.shape
58 | qkv = (
59 | self.qkv(x)
60 | .reshape(B, N, 3, self.heads, C // self.heads)
61 | .permute(2, 0, 3, 1, 4)
62 | )
63 | q, k, v = (
64 | qkv[0],
65 | qkv[1],
66 | qkv[2],
67 | )
68 |
69 | attn = (q @ k.transpose(-2, -1)) * self.scale
70 | attn = attn.softmax(dim=-1)
71 | attn = self.attn_drop(attn)
72 |
73 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
74 | x = self.proj(x)
75 | x = self.proj_drop(x)
76 |
77 | return x, attn
78 |
79 |
80 | class Block(nn.Module):
81 | def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
82 | super().__init__()
83 | self.norm1 = nn.LayerNorm(dim)
84 | self.norm2 = nn.LayerNorm(dim)
85 | self.attn = Attention(dim, heads, dropout)
86 | self.mlp = FeedForward(dim, mlp_dim, dropout)
87 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
88 |
89 | def forward(self, x, mask=None, return_attention=False):
90 | y, attn = self.attn(self.norm1(x), mask)
91 | if return_attention:
92 | return attn
93 | x = x + self.drop_path(y)
94 | x = x + self.drop_path(self.mlp(self.norm2(x)))
95 | return x
96 |
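A quick self-contained check of the block interface with toy dimensions; passing return_attention=True returns the per-head attention map instead of the updated tokens:

import torch
from segm.model.blocks import Block

blk = Block(dim=192, heads=3, mlp_dim=768, dropout=0.0, drop_path=0.0)
x = torch.randn(2, 65, 192)            # (batch, tokens, dim), e.g. 64 patches + 1 CLS token
out = blk(x)                           # residual update, same shape as x
attn = blk(x, return_attention=True)   # attention weights of shape (2, 3, 65, 65)
print(out.shape, attn.shape)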
--------------------------------------------------------------------------------
/segm/model/decoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 |
7 | from timm.models.layers import trunc_normal_
8 |
9 | from segm.model.blocks import Block, FeedForward
10 | from segm.model.utils import init_weights
11 |
12 |
13 | class DecoderLinear(nn.Module):
14 | def __init__(self, n_cls, patch_size, d_encoder):
15 | super().__init__()
16 |
17 | self.d_encoder = d_encoder
18 | self.patch_size = patch_size
19 | self.n_cls = n_cls
20 |
21 | self.head = nn.Linear(self.d_encoder, n_cls)
22 | self.apply(init_weights)
23 |
24 | @torch.jit.ignore
25 | def no_weight_decay(self):
26 | return set()
27 |
28 | def forward(self, x, im_size):
29 | H, W = im_size
30 | GS = H // self.patch_size
31 | x = self.head(x)
32 | x = rearrange(x, "b (h w) c -> b c h w", h=GS)
33 |
34 | return x
35 |
36 |
37 | class MaskTransformer(nn.Module):
38 | def __init__(
39 | self,
40 | n_cls,
41 | patch_size,
42 | d_encoder,
43 | n_layers,
44 | n_heads,
45 | d_model,
46 | d_ff,
47 | drop_path_rate,
48 | dropout,
49 | ):
50 | super().__init__()
51 | self.d_encoder = d_encoder
52 | self.patch_size = patch_size
53 | self.n_layers = n_layers
54 | self.n_cls = n_cls
55 | self.d_model = d_model
56 | self.d_ff = d_ff
57 | self.scale = d_model ** -0.5
58 |
59 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
60 | self.blocks = nn.ModuleList(
61 | [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
62 | )
63 |
64 | self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))
65 | self.proj_dec = nn.Linear(d_encoder, d_model)
66 |
67 | self.proj_patch = nn.Parameter(self.scale * torch.randn(d_model, d_model))
68 | self.proj_classes = nn.Parameter(self.scale * torch.randn(d_model, d_model))
69 |
70 | self.decoder_norm = nn.LayerNorm(d_model)
71 | self.mask_norm = nn.LayerNorm(n_cls)
72 |
73 | self.apply(init_weights)
74 | trunc_normal_(self.cls_emb, std=0.02)
75 |
76 | @torch.jit.ignore
77 | def no_weight_decay(self):
78 | return {"cls_emb"}
79 |
80 | def forward(self, x, im_size):
81 | H, W = im_size
82 | GS = H // self.patch_size
83 |
84 | x = self.proj_dec(x)
85 | cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
86 | x = torch.cat((x, cls_emb), 1)
87 | for blk in self.blocks:
88 | x = blk(x)
89 | x = self.decoder_norm(x)
90 |
91 | patches, cls_seg_feat = x[:, : -self.n_cls], x[:, -self.n_cls :]
92 | patches = patches @ self.proj_patch
93 | cls_seg_feat = cls_seg_feat @ self.proj_classes
94 |
95 | patches = patches / patches.norm(dim=-1, keepdim=True)
96 | cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)
97 |
98 | masks = patches @ cls_seg_feat.transpose(1, 2)
99 | masks = self.mask_norm(masks)
100 | masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))
101 |
102 | return masks
103 |
104 | def get_attention_map(self, x, layer_id):
105 | if layer_id >= self.n_layers or layer_id < 0:
106 | raise ValueError(
107 | f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}."
108 | )
109 | x = self.proj_dec(x)
110 | cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
111 | x = torch.cat((x, cls_emb), 1)
112 | for i, blk in enumerate(self.blocks):
113 | if i < layer_id:
114 | x = blk(x)
115 | else:
116 | return blk(x, return_attention=True)
117 |
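A shape walk-through of MaskTransformer with tiny illustrative dimensions: the decoder appends n_cls learnable class embeddings, runs joint self-attention, and scores every patch against every class embedding to produce one mask per class at patch resolution:

import torch
from segm.model.decoder import MaskTransformer

decoder = MaskTransformer(
    n_cls=21, patch_size=16, d_encoder=192, n_layers=2, n_heads=3,
    d_model=192, d_ff=768, drop_path_rate=0.0, dropout=0.0,
)
x = torch.randn(2, 64, 192)              # (batch, patches, d_encoder) for a 128x128 image
masks = decoder(x, im_size=(128, 128))   # (batch, n_cls, H // patch_size, W // patch_size)
print(masks.shape)                       # torch.Size([2, 21, 8, 8])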
--------------------------------------------------------------------------------
/segm/model/factory.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import yaml
3 | import torch
4 | import math
5 | import os
6 | import torch.nn as nn
7 |
8 | from timm.models.helpers import load_pretrained, load_custom_pretrained
9 | from timm.models.vision_transformer import default_cfgs
10 | from timm.models.registry import register_model
11 | from timm.models.vision_transformer import _create_vision_transformer
12 |
13 | from segm.model.vit import VisionTransformer
14 | from segm.model.utils import checkpoint_filter_fn
15 | from segm.model.decoder import DecoderLinear
16 | from segm.model.decoder import MaskTransformer
17 | from segm.model.segmenter import Segmenter
18 | import segm.utils.torch as ptu
19 |
20 |
21 | @register_model
22 | def vit_base_patch8_384(pretrained=False, **kwargs):
23 | """ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
24 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
25 | """
26 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
27 | model = _create_vision_transformer(
28 | "vit_base_patch8_384",
29 | pretrained=pretrained,
30 | default_cfg=dict(
31 | url="",
32 | input_size=(3, 384, 384),
33 | mean=(0.5, 0.5, 0.5),
34 | std=(0.5, 0.5, 0.5),
35 | num_classes=1000,
36 | ),
37 | **model_kwargs,
38 | )
39 | return model
40 |
41 |
42 | def create_vit(model_cfg):
43 | model_cfg = model_cfg.copy()
44 | backbone = model_cfg.pop("backbone")
45 |
46 | normalization = model_cfg.pop("normalization")
47 | model_cfg["n_cls"] = 1000
48 | mlp_expansion_ratio = 4
49 | model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"]
50 |
51 | if backbone in default_cfgs:
52 | default_cfg = default_cfgs[backbone]
53 | else:
54 | default_cfg = dict(
55 | pretrained=False,
56 | num_classes=1000,
57 | drop_rate=0.0,
58 | drop_path_rate=0.0,
59 | drop_block_rate=None,
60 | )
61 |
62 | default_cfg["input_size"] = (
63 | 3,
64 | model_cfg["image_size"][0],
65 | model_cfg["image_size"][1],
66 | )
67 | model = VisionTransformer(**model_cfg)
68 | if backbone == "vit_base_patch8_384":
69 | path = os.path.expandvars("$TORCH_HOME/hub/checkpoints/vit_base_patch8_384.pth")
70 | state_dict = torch.load(path, map_location="cpu")
71 | filtered_dict = checkpoint_filter_fn(state_dict, model)
72 | model.load_state_dict(filtered_dict, strict=True)
73 | elif "deit" in backbone:
74 | load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn)
75 | else:
76 | load_custom_pretrained(model, default_cfg)
77 |
78 | return model
79 |
80 |
81 | def create_decoder(encoder, decoder_cfg):
82 | decoder_cfg = decoder_cfg.copy()
83 | name = decoder_cfg.pop("name")
84 | decoder_cfg["d_encoder"] = encoder.d_model
85 | decoder_cfg["patch_size"] = encoder.patch_size
86 |
87 | if "linear" in name:
88 | decoder = DecoderLinear(**decoder_cfg)
89 | elif name == "mask_transformer":
90 | dim = encoder.d_model
91 | n_heads = dim // 64
92 | decoder_cfg["n_heads"] = n_heads
93 | decoder_cfg["d_model"] = dim
94 | decoder_cfg["d_ff"] = 4 * dim
95 | decoder = MaskTransformer(**decoder_cfg)
96 | else:
97 | raise ValueError(f"Unknown decoder: {name}")
98 | return decoder
99 |
100 |
101 | def create_segmenter(model_cfg):
102 | model_cfg = model_cfg.copy()
103 | decoder_cfg = model_cfg.pop("decoder")
104 | decoder_cfg["n_cls"] = model_cfg["n_cls"]
105 |
106 | encoder = create_vit(model_cfg)
107 | decoder = create_decoder(encoder, decoder_cfg)
108 | model = Segmenter(encoder, decoder, n_cls=model_cfg["n_cls"])
109 |
110 | return model
111 |
112 |
113 | def load_model(model_path):
114 | variant_path = Path(model_path).parent / "variant.yml"
115 | with open(variant_path, "r") as f:
116 | variant = yaml.load(f, Loader=yaml.FullLoader)
117 | net_kwargs = variant["net_kwargs"]
118 |
119 | model = create_segmenter(net_kwargs)
120 | data = torch.load(model_path, map_location=ptu.device)
121 | checkpoint = data["model"]
122 |
123 | model.load_state_dict(checkpoint, strict=True)
124 |
125 | return model, variant
126 |
--------------------------------------------------------------------------------
/segm/model/segmenter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from segm.model.utils import padding, unpadding
6 | from timm.models.layers import trunc_normal_
7 |
8 |
9 | class Segmenter(nn.Module):
10 | def __init__(
11 | self,
12 | encoder,
13 | decoder,
14 | n_cls,
15 | ):
16 | super().__init__()
17 | self.n_cls = n_cls
18 | self.patch_size = encoder.patch_size
19 | self.encoder = encoder
20 | self.decoder = decoder
21 |
22 | @torch.jit.ignore
23 | def no_weight_decay(self):
24 | def append_prefix_no_weight_decay(prefix, module):
25 | return set(map(lambda x: prefix + x, module.no_weight_decay()))
26 |
27 | nwd_params = append_prefix_no_weight_decay("encoder.", self.encoder).union(
28 | append_prefix_no_weight_decay("decoder.", self.decoder)
29 | )
30 | return nwd_params
31 |
32 | def forward(self, im):
33 | H_ori, W_ori = im.size(2), im.size(3)
34 | im = padding(im, self.patch_size)
35 | H, W = im.size(2), im.size(3)
36 |
37 | x = self.encoder(im, return_features=True)
38 |
39 | # remove CLS/DIST tokens for decoding
40 | num_extra_tokens = 1 + self.encoder.distilled
41 | x = x[:, num_extra_tokens:]
42 |
43 | masks = self.decoder(x, (H, W))
44 |
45 | masks = F.interpolate(masks, size=(H, W), mode="bilinear")
46 | masks = unpadding(masks, (H_ori, W_ori))
47 |
48 | return masks
49 |
50 | def get_attention_map_enc(self, im, layer_id):
51 | return self.encoder.get_attention_map(im, layer_id)
52 |
53 | def get_attention_map_dec(self, im, layer_id):
54 | x = self.encoder(im, return_features=True)
55 |
56 | # remove CLS/DIST tokens for decoding
57 | num_extra_tokens = 1 + self.encoder.distilled
58 | x = x[:, num_extra_tokens:]
59 |
60 | return self.decoder.get_attention_map(x, layer_id)
61 |
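A self-contained sketch that wires an encoder and a mask-transformer decoder together the way create_segmenter does, with tiny illustrative dimensions (real configurations come from segm/config.yml); the input size is deliberately not a multiple of the patch size to exercise padding/unpadding:

import torch
from segm.model.vit import VisionTransformer
from segm.model.decoder import MaskTransformer
from segm.model.segmenter import Segmenter

encoder = VisionTransformer(
    image_size=(128, 128), patch_size=16, n_layers=2, d_model=192,
    d_ff=768, n_heads=3, n_cls=1000,
)
decoder = MaskTransformer(
    n_cls=21, patch_size=16, d_encoder=192, n_layers=2, n_heads=3,
    d_model=192, d_ff=768, drop_path_rate=0.0, dropout=0.0,
)
model = Segmenter(encoder, decoder, n_cls=21)

im = torch.randn(2, 3, 120, 130)
masks = model(im)      # logits upsampled and cropped back to the input size
print(masks.shape)     # torch.Size([2, 21, 120, 130])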
--------------------------------------------------------------------------------
/segm/model/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from collections import defaultdict
6 |
7 | from timm.models.layers import trunc_normal_
8 |
9 | import segm.utils.torch as ptu
10 |
11 |
12 | def init_weights(m):
13 | if isinstance(m, nn.Linear):
14 | trunc_normal_(m.weight, std=0.02)
15 | if isinstance(m, nn.Linear) and m.bias is not None:
16 | nn.init.constant_(m.bias, 0)
17 | elif isinstance(m, nn.LayerNorm):
18 | nn.init.constant_(m.bias, 0)
19 | nn.init.constant_(m.weight, 1.0)
20 |
21 |
22 | def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
23 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
24 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
25 | posemb_tok, posemb_grid = (
26 | posemb[:, :num_extra_tokens],
27 | posemb[0, num_extra_tokens:],
28 | )
29 | if grid_old_shape is None:
30 | gs_old_h = int(math.sqrt(len(posemb_grid)))
31 | gs_old_w = gs_old_h
32 | else:
33 | gs_old_h, gs_old_w = grid_old_shape
34 |
35 | gs_h, gs_w = grid_new_shape
36 | posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2)
37 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
38 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
39 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
40 | return posemb
41 |
42 |
43 | def checkpoint_filter_fn(state_dict, model):
44 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
45 | out_dict = {}
46 | if "model" in state_dict:
47 | # For deit models
48 | state_dict = state_dict["model"]
49 | num_extra_tokens = 1 + ("dist_token" in state_dict.keys())
50 | patch_size = model.patch_size
51 | image_size = model.patch_embed.image_size
52 | for k, v in state_dict.items():
53 | if k == "pos_embed" and v.shape != model.pos_embed.shape:
54 | # To resize pos embedding when using model at different size from pretrained weights
55 | v = resize_pos_embed(
56 | v,
57 | None,
58 | (image_size[0] // patch_size, image_size[1] // patch_size),
59 | num_extra_tokens,
60 | )
61 | out_dict[k] = v
62 | return out_dict
63 |
64 |
65 | def padding(im, patch_size, fill_value=0):
66 | # make the image sizes divisible by patch_size
67 | H, W = im.size(2), im.size(3)
68 | pad_h, pad_w = 0, 0
69 | if H % patch_size > 0:
70 | pad_h = patch_size - (H % patch_size)
71 | if W % patch_size > 0:
72 | pad_w = patch_size - (W % patch_size)
73 | im_padded = im
74 | if pad_h > 0 or pad_w > 0:
75 | im_padded = F.pad(im, (0, pad_w, 0, pad_h), value=fill_value)
76 | return im_padded
77 |
78 |
79 | def unpadding(y, target_size):
80 | H, W = target_size
81 | H_pad, W_pad = y.size(2), y.size(3)
82 | # crop predictions on extra pixels coming from padding
83 | extra_h = H_pad - H
84 | extra_w = W_pad - W
85 | if extra_h > 0:
86 | y = y[:, :, :-extra_h]
87 | if extra_w > 0:
88 | y = y[:, :, :, :-extra_w]
89 | return y
90 |
91 |
92 | def resize(im, smaller_size):
93 | h, w = im.shape[2:]
94 | if h < w:
95 | ratio = w / h
96 | h_res, w_res = smaller_size, ratio * smaller_size
97 | else:
98 | ratio = h / w
99 | h_res, w_res = ratio * smaller_size, smaller_size
100 | if min(h, w) < smaller_size:
101 | im_res = F.interpolate(im, (int(h_res), int(w_res)), mode="bilinear")
102 | else:
103 | im_res = im
104 | return im_res
105 |
106 |
107 | def sliding_window(im, flip, window_size, window_stride):
108 | B, C, H, W = im.shape
109 | ws = window_size
110 |
111 | windows = {"crop": [], "anchors": []}
112 | h_anchors = torch.arange(0, H, window_stride)
113 | w_anchors = torch.arange(0, W, window_stride)
114 | h_anchors = [h.item() for h in h_anchors if h < H - ws] + [H - ws]
115 | w_anchors = [w.item() for w in w_anchors if w < W - ws] + [W - ws]
116 | for ha in h_anchors:
117 | for wa in w_anchors:
118 | window = im[:, :, ha : ha + ws, wa : wa + ws]
119 | windows["crop"].append(window)
120 | windows["anchors"].append((ha, wa))
121 | windows["flip"] = flip
122 | windows["shape"] = (H, W)
123 | return windows
124 |
125 |
126 | def merge_windows(windows, window_size, ori_shape):
127 | ws = window_size
128 | im_windows = windows["seg_maps"]
129 | anchors = windows["anchors"]
130 | C = im_windows[0].shape[0]
131 | H, W = windows["shape"]
132 | flip = windows["flip"]
133 |
134 | logit = torch.zeros((C, H, W), device=im_windows.device)
135 | count = torch.zeros((1, H, W), device=im_windows.device)
136 | for window, (ha, wa) in zip(im_windows, anchors):
137 | logit[:, ha : ha + ws, wa : wa + ws] += window
138 | count[:, ha : ha + ws, wa : wa + ws] += 1
139 | logit = logit / count
140 | logit = F.interpolate(
141 | logit.unsqueeze(0),
142 | ori_shape,
143 | mode="bilinear",
144 | )[0]
145 | if flip:
146 | logit = torch.flip(logit, (2,))
147 | result = F.softmax(logit, 0)
148 | return result
149 |
150 |
151 | def inference(
152 | model,
153 | ims,
154 | ims_metas,
155 | ori_shape,
156 | window_size,
157 | window_stride,
158 | batch_size,
159 | ):
160 | C = model.n_cls
161 | seg_map = torch.zeros((C, ori_shape[0], ori_shape[1]), device=ptu.device)
162 | for im, im_metas in zip(ims, ims_metas):
163 | im = im.to(ptu.device)
164 | im = resize(im, window_size)
165 | flip = im_metas["flip"]
166 | windows = sliding_window(im, flip, window_size, window_stride)
167 | crops = torch.stack(windows.pop("crop"))[:, 0]
168 | B = len(crops)
169 | WB = batch_size
170 | seg_maps = torch.zeros((B, C, window_size, window_size), device=im.device)
171 | with torch.no_grad():
172 | for i in range(0, B, WB):
173 | seg_maps[i : i + WB] = model.forward(crops[i : i + WB])
174 | windows["seg_maps"] = seg_maps
175 | im_seg_map = merge_windows(windows, window_size, ori_shape)
176 | seg_map += im_seg_map
177 | seg_map /= len(ims)
178 | return seg_map
179 |
180 |
181 | def num_params(model):
182 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
183 | n_params = sum([torch.prod(torch.tensor(p.size())) for p in model_parameters])
184 | return n_params.item()
185 |
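A quick self-contained illustration of the padding/unpadding round trip and of how sliding_window lays out overlapping crops (sizes are arbitrary):

import torch
from segm.model.utils import padding, unpadding, sliding_window

im = torch.randn(1, 3, 100, 130)
im_pad = padding(im, patch_size=16)    # zero-padded to (112, 144), multiples of 16
out = unpadding(im_pad, (100, 130))    # crop the extra rows/columns back off
assert out.shape == im.shape

windows = sliding_window(im_pad, flip=False, window_size=64, window_stride=48)
print(len(windows["crop"]), windows["anchors"])  # 6 overlapping 64x64 crops and their top-left anchors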
--------------------------------------------------------------------------------
/segm/model/vit.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from 2020 Ross Wightman
3 | https://github.com/rwightman/pytorch-image-models
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from segm.model.utils import init_weights, resize_pos_embed
10 | from segm.model.blocks import Block
11 |
12 | from timm.models.layers import DropPath
13 | from timm.models.layers import trunc_normal_
14 | from timm.models.vision_transformer import _load_weights
15 |
16 |
17 | class PatchEmbedding(nn.Module):
18 | def __init__(self, image_size, patch_size, embed_dim, channels):
19 | super().__init__()
20 |
21 | self.image_size = image_size
22 | if image_size[0] % patch_size != 0 or image_size[1] % patch_size != 0:
23 | raise ValueError("image dimensions must be divisible by the patch size")
24 | self.grid_size = image_size[0] // patch_size, image_size[1] // patch_size
25 | self.num_patches = self.grid_size[0] * self.grid_size[1]
26 | self.patch_size = patch_size
27 |
28 | self.proj = nn.Conv2d(
29 | channels, embed_dim, kernel_size=patch_size, stride=patch_size
30 | )
31 |
32 | def forward(self, im):
33 | B, C, H, W = im.shape
34 | x = self.proj(im).flatten(2).transpose(1, 2)
35 | return x
36 |
37 |
38 | class VisionTransformer(nn.Module):
39 | def __init__(
40 | self,
41 | image_size,
42 | patch_size,
43 | n_layers,
44 | d_model,
45 | d_ff,
46 | n_heads,
47 | n_cls,
48 | dropout=0.1,
49 | drop_path_rate=0.0,
50 | distilled=False,
51 | channels=3,
52 | ):
53 | super().__init__()
54 | self.patch_embed = PatchEmbedding(
55 | image_size,
56 | patch_size,
57 | d_model,
58 | channels,
59 | )
60 | self.patch_size = patch_size
61 | self.n_layers = n_layers
62 | self.d_model = d_model
63 | self.d_ff = d_ff
64 | self.n_heads = n_heads
65 | self.dropout = nn.Dropout(dropout)
66 | self.n_cls = n_cls
67 |
68 | # cls and pos tokens
69 | self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
70 | self.distilled = distilled
71 | if self.distilled:
72 | self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
73 | self.pos_embed = nn.Parameter(
74 | torch.randn(1, self.patch_embed.num_patches + 2, d_model)
75 | )
76 | self.head_dist = nn.Linear(d_model, n_cls)
77 | else:
78 | self.pos_embed = nn.Parameter(
79 | torch.randn(1, self.patch_embed.num_patches + 1, d_model)
80 | )
81 |
82 | # transformer blocks
83 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
84 | self.blocks = nn.ModuleList(
85 | [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
86 | )
87 |
88 | # output head
89 | self.norm = nn.LayerNorm(d_model)
90 | self.head = nn.Linear(d_model, n_cls)
91 |
92 | trunc_normal_(self.pos_embed, std=0.02)
93 | trunc_normal_(self.cls_token, std=0.02)
94 | if self.distilled:
95 | trunc_normal_(self.dist_token, std=0.02)
96 | self.pre_logits = nn.Identity()
97 |
98 | self.apply(init_weights)
99 |
100 | @torch.jit.ignore
101 | def no_weight_decay(self):
102 | return {"pos_embed", "cls_token", "dist_token"}
103 |
104 | @torch.jit.ignore()
105 | def load_pretrained(self, checkpoint_path, prefix=""):
106 | _load_weights(self, checkpoint_path, prefix)
107 |
108 | def forward(self, im, return_features=False):
109 | B, _, H, W = im.shape
110 | PS = self.patch_size
111 |
112 | x = self.patch_embed(im)
113 | cls_tokens = self.cls_token.expand(B, -1, -1)
114 | if self.distilled:
115 | dist_tokens = self.dist_token.expand(B, -1, -1)
116 | x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
117 | else:
118 | x = torch.cat((cls_tokens, x), dim=1)
119 |
120 | pos_embed = self.pos_embed
121 | num_extra_tokens = 1 + self.distilled
122 | if x.shape[1] != pos_embed.shape[1]:
123 | pos_embed = resize_pos_embed(
124 | pos_embed,
125 | self.patch_embed.grid_size,
126 | (H // PS, W // PS),
127 | num_extra_tokens,
128 | )
129 | x = x + pos_embed
130 | x = self.dropout(x)
131 |
132 | for blk in self.blocks:
133 | x = blk(x)
134 | x = self.norm(x)
135 |
136 | if return_features:
137 | return x
138 |
139 | if self.distilled:
140 | x, x_dist = x[:, 0], x[:, 1]
141 | x = self.head(x)
142 | x_dist = self.head_dist(x_dist)
143 | x = (x + x_dist) / 2
144 | else:
145 | x = x[:, 0]
146 | x = self.head(x)
147 | return x
148 |
149 | def get_attention_map(self, im, layer_id):
150 | if layer_id >= self.n_layers or layer_id < 0:
151 | raise ValueError(
152 | f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}."
153 | )
154 | B, _, H, W = im.shape
155 | PS = self.patch_size
156 |
157 | x = self.patch_embed(im)
158 | cls_tokens = self.cls_token.expand(B, -1, -1)
159 | if self.distilled:
160 | dist_tokens = self.dist_token.expand(B, -1, -1)
161 | x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
162 | else:
163 | x = torch.cat((cls_tokens, x), dim=1)
164 |
165 | pos_embed = self.pos_embed
166 | num_extra_tokens = 1 + self.distilled
167 | if x.shape[1] != pos_embed.shape[1]:
168 | pos_embed = resize_pos_embed(
169 | pos_embed,
170 | self.patch_embed.grid_size,
171 | (H // PS, W // PS),
172 | num_extra_tokens,
173 | )
174 | x = x + pos_embed
175 |
176 | for i, blk in enumerate(self.blocks):
177 | if i < layer_id:
178 | x = blk(x)
179 | else:
180 | return blk(x, return_attention=True)
181 |
--------------------------------------------------------------------------------
/segm/optim/factory.py:
--------------------------------------------------------------------------------
1 | from timm import scheduler
2 | from timm import optim
3 |
4 | from segm.optim.scheduler import PolynomialLR
5 |
6 |
7 | def create_scheduler(opt_args, optimizer):
8 | if opt_args.sched == "polynomial":
9 | lr_scheduler = PolynomialLR(
10 | optimizer,
11 | opt_args.poly_step_size,
12 | opt_args.iter_warmup,
13 | opt_args.iter_max,
14 | opt_args.poly_power,
15 | opt_args.min_lr,
16 | )
17 | else:
18 | lr_scheduler, _ = scheduler.create_scheduler(opt_args, optimizer)
19 | return lr_scheduler
20 |
21 |
22 | def create_optimizer(opt_args, model):
23 | return optim.create_optimizer(opt_args, model)
24 |
--------------------------------------------------------------------------------
/segm/optim/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch import optim
2 | from torch.optim.lr_scheduler import _LRScheduler
3 | from timm.scheduler.scheduler import Scheduler
4 |
5 |
6 | class PolynomialLR(_LRScheduler):
7 | def __init__(
8 | self,
9 | optimizer,
10 | step_size,
11 | iter_warmup,
12 | iter_max,
13 | power,
14 | min_lr=0,
15 | last_epoch=-1,
16 | ):
17 | self.step_size = step_size
18 | self.iter_warmup = int(iter_warmup)
19 | self.iter_max = int(iter_max)
20 | self.power = power
21 | self.min_lr = min_lr
22 | super(PolynomialLR, self).__init__(optimizer, last_epoch)
23 |
24 | def polynomial_decay(self, lr):
25 | iter_cur = float(self.last_epoch)
26 | if iter_cur < self.iter_warmup:
27 | coef = iter_cur / self.iter_warmup
28 | coef *= (1 - self.iter_warmup / self.iter_max) ** self.power
29 | else:
30 | coef = (1 - iter_cur / self.iter_max) ** self.power
31 | return (lr - self.min_lr) * coef + self.min_lr
32 |
33 | def get_lr(self):
34 | if (
35 | (self.last_epoch == 0)
36 | or (self.last_epoch % self.step_size != 0)
37 | or (self.last_epoch > self.iter_max)
38 | ):
39 | return [group["lr"] for group in self.optimizer.param_groups]
40 | return [self.polynomial_decay(lr) for lr in self.base_lrs]
41 |
42 | def step_update(self, num_updates):
43 | self.step()
44 |
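A minimal usage sketch of the polynomial schedule with arbitrary values; step_update is called once per optimizer update, as in engine.train_one_epoch:

import torch
from segm.optim.scheduler import PolynomialLR

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = PolynomialLR(
    optimizer, step_size=1, iter_warmup=0, iter_max=100, power=0.9, min_lr=1e-5,
)
for it in range(3):
    optimizer.step()
    scheduler.step_update(num_updates=it + 1)   # lr decays as (1 - iter / iter_max) ** power
    print(optimizer.param_groups[0]["lr"])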
--------------------------------------------------------------------------------
/segm/scripts/prepare_ade20k.py:
--------------------------------------------------------------------------------
1 | """Prepare ADE20K dataset"""
2 | import click
3 | import zipfile
4 |
5 | from pathlib import Path
6 | from segm.utils.download import download
7 |
8 |
9 | def download_ade(path, overwrite=False):
10 | _AUG_DOWNLOAD_URLS = [
11 | (
12 | "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip",
13 | "219e1696abb36c8ba3a3afe7fb2f4b4606a897c7",
14 | ),
15 | (
16 | "http://data.csail.mit.edu/places/ADEchallenge/release_test.zip",
17 | "e05747892219d10e9243933371a497e905a4860c",
18 | ),
19 | ]
20 | download_dir = path / "downloads"
21 | download_dir.mkdir(parents=True, exist_ok=True)
22 | for url, checksum in _AUG_DOWNLOAD_URLS:
23 | filename = download(
24 | url, path=str(download_dir), overwrite=overwrite, sha1_hash=checksum
25 | )
26 | # extract
27 | with zipfile.ZipFile(filename, "r") as zip_ref:
28 | zip_ref.extractall(path=str(path))
29 |
30 |
31 | @click.command(help="Initialize ADE20K dataset.")
32 | @click.argument("download_dir", type=str)
33 | def main(download_dir):
34 | dataset_dir = Path(download_dir) / "ade20k"
35 | download_ade(dataset_dir, overwrite=False)
36 |
37 |
38 | if __name__ == "__main__":
39 | main()
40 |
--------------------------------------------------------------------------------
/segm/scripts/prepare_cityscapes.py:
--------------------------------------------------------------------------------
1 | """Prepare Cityscapes dataset"""
2 | import click
3 | import os
4 | import shutil
5 | import mmcv
6 | import zipfile
7 |
8 | from pathlib import Path
9 | from segm.utils.download import download
10 |
11 | USERNAME = None
12 | PASSWORD = None
13 |
14 |
15 | def download_cityscapes(path, username, password, overwrite=False):
16 | _CITY_DOWNLOAD_URLS = [
17 | ("gtFine_trainvaltest.zip", "99f532cb1af174f5fcc4c5bc8feea8c66246ddbc"),
18 | ("leftImg8bit_trainvaltest.zip", "2c0b77ce9933cc635adda307fbba5566f5d9d404"),
19 | ]
20 | download_dir = path / "downloads"
21 | download_dir.mkdir(parents=True, exist_ok=True)
22 |
23 | os.system(
24 | f"wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username={username}&password={password}&submit=Login' https://www.cityscapes-dataset.com/login/ -P {download_dir}"
25 | )
26 |
27 | if not (download_dir / "gtFine_trainvaltest.zip").is_file():
28 | os.system(
29 | f"wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1 -P {download_dir}"
30 | )
31 |
32 | if not (download_dir / "leftImg8bit_trainvaltest.zip").is_file():
33 | os.system(
34 | f"wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3 -P {download_dir}"
35 | )
36 |
37 | for filename, checksum in _CITY_DOWNLOAD_URLS:
38 | # extract
39 | with zipfile.ZipFile(str(download_dir / filename), "r") as zip_ref:
40 | zip_ref.extractall(path=path)
41 | print("Extracted", filename)
42 |
43 |
44 | def install_cityscapes_api():
45 | os.system("pip install cityscapesscripts")
46 | try:
47 | import cityscapesscripts
48 | except Exception:
49 | print(
50 | "Installing Cityscapes API failed, please install it manually %s"
51 | % (repo_url)
52 | )
53 |
54 |
55 | def convert_json_to_label(json_file):
56 | from cityscapesscripts.preparation.json2labelImg import json2labelImg
57 |
58 | label_file = json_file.replace("_polygons.json", "_labelTrainIds.png")
59 | json2labelImg(json_file, label_file, "trainIds")
60 |
61 |
62 | @click.command(help="Initialize Cityscapes dataset.")
63 | @click.argument("download_dir", type=str)
64 | @click.option("--username", default=USERNAME, type=str)
65 | @click.option("--password", default=PASSWORD, type=str)
66 | @click.option("--nproc", default=10, type=int)
67 | def main(
68 | download_dir,
69 | username,
70 | password,
71 | nproc,
72 | ):
73 |
74 | dataset_dir = Path(download_dir) / "cityscapes"
75 |
76 | if username is None or password is None:
77 | raise ValueError(
78 | "You must indicate your username and password either in the script variables or by passing options --username and --pasword."
79 | )
80 |
81 | download_cityscapes(dataset_dir, username, password, overwrite=False)
82 |
83 | install_cityscapes_api()
84 |
85 | gt_dir = dataset_dir / "gtFine"
86 |
87 | poly_files = []
88 | for poly in mmcv.scandir(str(gt_dir), "_polygons.json", recursive=True):
89 | poly_file = str(gt_dir / poly)
90 | poly_files.append(poly_file)
91 | mmcv.track_parallel_progress(convert_json_to_label, poly_files, nproc)
92 |
93 | split_names = ["train", "val", "test"]
94 |
95 | for split in split_names:
96 | filenames = []
97 | for poly in mmcv.scandir(str(gt_dir / split), "_polygons.json", recursive=True):
98 | filenames.append(poly.replace("_gtFine_polygons.json", ""))
99 | with open(str(dataset_dir / f"{split}.txt"), "w") as f:
100 |             f.writelines(name + "\n" for name in filenames)
101 |
102 |
103 | if __name__ == "__main__":
104 | main()
105 |
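Cityscapes requires registered credentials, supplied either through the USERNAME/PASSWORD module variables or on the command line; one possible invocation (path illustrative):

    python -m segm.scripts.prepare_cityscapes /path/to/datasets --username <user> --password <password>

The archives are fetched with wget through a session cookie, the *_gtFine_polygons.json annotations are converted to *_labelTrainIds.png masks via cityscapesscripts, and train/val/test image lists are written to cityscapes/{split}.txt.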
--------------------------------------------------------------------------------
/segm/scripts/prepare_pcontext.py:
--------------------------------------------------------------------------------
1 | """Prepare PASCAL Context dataset"""
2 | import click
3 | import shutil
4 | import tarfile
5 | import torch
6 |
7 | from tqdm import tqdm
8 | from pathlib import Path
9 |
10 | from segm.utils.download import download
11 |
12 |
13 | def download_pcontext(path, overwrite=False):
14 | _AUG_DOWNLOAD_URLS = [
15 | (
16 | "https://www.dropbox.com/s/wtdibo9lb2fur70/VOCtrainval_03-May-2010.tar?dl=1",
17 | "VOCtrainval_03-May-2010.tar",
18 | "bf9985e9f2b064752bf6bd654d89f017c76c395a",
19 | ),
20 | (
21 | "https://codalabuser.blob.core.windows.net/public/trainval_merged.json",
22 | "",
23 | "169325d9f7e9047537fedca7b04de4dddf10b881",
24 | ),
25 | (
26 | "https://hangzh.s3.amazonaws.com/encoding/data/pcontext/train.pth",
27 | "",
28 | "4bfb49e8c1cefe352df876c9b5434e655c9c1d07",
29 | ),
30 | (
31 | "https://hangzh.s3.amazonaws.com/encoding/data/pcontext/val.pth",
32 | "",
33 | "ebedc94247ec616c57b9a2df15091784826a7b0c",
34 | ),
35 | ]
36 | download_dir = path / "downloads"
37 |
38 | download_dir.mkdir(parents=True, exist_ok=True)
39 |
40 | for url, filename, checksum in _AUG_DOWNLOAD_URLS:
41 | filename = download(
42 | url,
43 | path=str(download_dir / filename),
44 | overwrite=overwrite,
45 | sha1_hash=checksum,
46 | )
47 | # extract
48 | if Path(filename).suffix == ".tar":
49 | with tarfile.open(filename) as tar:
50 | tar.extractall(path=str(path))
51 | else:
52 | shutil.move(
53 | filename,
54 | str(path / "VOCdevkit" / "VOC2010" / Path(filename).name),
55 | )
56 |
57 |
58 | @click.command(help="Initialize PASCAL Context dataset.")
59 | @click.argument("download_dir", type=str)
60 | def main(download_dir):
61 |
62 | dataset_dir = Path(download_dir) / "pcontext"
63 |
64 | download_pcontext(dataset_dir, overwrite=False)
65 |
66 | devkit_path = dataset_dir / "VOCdevkit"
67 | out_dir = devkit_path / "VOC2010" / "SegmentationClassContext"
68 | imageset_dir = devkit_path / "VOC2010" / "ImageSets" / "SegmentationContext"
69 |
70 | out_dir.mkdir(parents=True, exist_ok=True)
71 | imageset_dir.mkdir(parents=True, exist_ok=True)
72 |
73 | train_torch_path = devkit_path / "VOC2010" / "train.pth"
74 | val_torch_path = devkit_path / "VOC2010" / "val.pth"
75 |
76 | train_dict = torch.load(str(train_torch_path))
77 |
78 | train_list = []
79 | for idx, label in tqdm(train_dict.items()):
80 | idx = str(idx)
81 | new_idx = idx[:4] + "_" + idx[4:]
82 | train_list.append(new_idx)
83 | label_path = out_dir / f"{new_idx}.png"
84 | label.save(str(label_path))
85 |
86 | with open(str(imageset_dir / "train.txt"), "w") as f:
87 | f.writelines(line + "\n" for line in sorted(train_list))
88 |
89 | val_dict = torch.load(str(val_torch_path))
90 |
91 | val_list = []
92 | for idx, label in tqdm(val_dict.items()):
93 | idx = str(idx)
94 | new_idx = idx[:4] + "_" + idx[4:]
95 | val_list.append(new_idx)
96 | label_path = out_dir / f"{new_idx}.png"
97 | label.save(str(label_path))
98 |
99 | with open(str(imageset_dir / "val.txt"), "w") as f:
100 | f.writelines(line + "\n" for line in sorted(val_list))
101 |
102 |
103 | if __name__ == "__main__":
104 | main()
105 |
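As for the other preparation scripts, one possible invocation (path illustrative):

    python -m segm.scripts.prepare_pcontext /path/to/datasets

It downloads VOC2010 and the PASCAL Context annotations, writes the per-image label maps to VOCdevkit/VOC2010/SegmentationClassContext, and produces the train/val split lists under ImageSets/SegmentationContext.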
--------------------------------------------------------------------------------
/segm/scripts/show_attn_map.py:
--------------------------------------------------------------------------------
1 | import click
2 | import einops
3 | import torch
4 | import torchvision
5 |
6 | import matplotlib.pyplot as plt
7 | import segm.utils.torch as ptu
8 | import torch.nn.functional as F
9 |
10 | from pathlib import Path
11 | from PIL import Image
12 | from segm import config
13 | from segm.data.utils import STATS
14 | from segm.model.decoder import MaskTransformer
15 | from segm.model.factory import load_model
16 | from torchvision import transforms
17 |
18 |
19 | @click.command()
20 | @click.argument("model-path", type=str)
21 | @click.argument("image-path", type=str)
22 | @click.argument("output-dir", type=str)
23 | @click.option("--layer-id", default=0, type=int)
24 | @click.option("--x-patch", default=0, type=int)
25 | @click.option("--y-patch", default=0, type=int)
26 | @click.option("--cmap", default="viridis", type=str)
27 | @click.option("--enc/--dec", default=True, is_flag=True)
28 | @click.option("--cls/--patch", default=False, is_flag=True)
29 | def visualize(
30 | model_path,
31 | image_path,
32 | output_dir,
33 | layer_id,
34 | x_patch,
35 | y_patch,
36 | cmap,
37 | enc,
38 | cls,
39 | ):
40 |
41 | output_dir = Path(output_dir)
42 | model_dir = Path(model_path).parent
43 |
44 | ptu.set_gpu_mode(True)
45 |
46 | # Build model
47 | model, variant = load_model(model_path)
48 | for p in model.parameters():
49 | p.requires_grad = False
50 |
51 | model.eval()
52 | model.to(ptu.device)
53 |
54 | # Get model config
55 | patch_size = model.patch_size
56 | normalization = variant["dataset_kwargs"]["normalization"]
57 | image_size = variant["dataset_kwargs"]["image_size"]
58 | n_cls = variant["net_kwargs"]["n_cls"]
59 | stats = STATS[normalization]
60 |
61 | # Open image and process it
62 | try:
63 | with open(image_path, "rb") as f:
64 | img = Image.open(f)
65 | img = img.convert("RGB")
66 |     except Exception as e:
67 |         raise ValueError(f"Provided image path {image_path} is not a valid image file.") from e
68 |
69 | # Normalize and resize
70 | transform = transforms.Compose(
71 | [
72 | transforms.Resize(image_size),
73 | transforms.ToTensor(),
74 | transforms.Normalize(stats["mean"], stats["std"]),
75 | ]
76 | )
77 |
78 | img = transform(img)
79 |
80 | # Make the image divisible by the patch size
81 | w, h = (
82 | image_size - image_size % patch_size,
83 | image_size - image_size % patch_size,
84 | )
85 |
86 | # Crop to image size
87 | img = img[:, :w, :h].unsqueeze(0)
88 |
89 | w_featmap = img.shape[-2] // patch_size
90 | h_featmap = img.shape[-1] // patch_size
91 |
92 | # Sanity checks
93 | if not enc and not isinstance(model.decoder, MaskTransformer):
94 | raise ValueError(
95 | f"Attention maps for decoder are only availabe for MaskTransformer. Provided model with decoder type: {model.decoder}."
96 | )
97 |
98 | if not cls:
99 |         if x_patch >= w_featmap or y_patch >= h_featmap:
100 | raise ValueError(
101 | f"Provided patch x: {x_patch} y: {y_patch} is not valid. Patch should be in the range x: [0, {w_featmap}), y: [0, {h_featmap})"
102 | )
103 | num_patch = w_featmap * y_patch + x_patch
104 |
105 | if layer_id < 0:
106 | raise ValueError("Provided layer_id should be positive.")
107 |
108 | if enc and model.encoder.n_layers <= layer_id:
109 | raise ValueError(
110 | f"Provided layer_id: {layer_id} is not valid for encoder with {model.encoder.n_layers}."
111 | )
112 |
113 | if not enc and model.decoder.n_layers <= layer_id:
114 | raise ValueError(
115 | f"Provided layer_id: {layer_id} is not valid for decoder with {model.decoder.n_layers}."
116 | )
117 |
118 | Path.mkdir(output_dir, exist_ok=True)
119 |
120 | # Process input and extract attention maps
121 | if enc:
122 | print(f"Generating Attention Mapping for Encoder Layer Id {layer_id}")
123 | attentions = model.get_attention_map_enc(img.to(ptu.device), layer_id)
124 | num_extra_tokens = 1 + model.encoder.distilled
125 | if cls:
126 | attentions = attentions[0, :, 0, num_extra_tokens:]
127 | else:
128 | attentions = attentions[
129 | 0, :, num_patch + num_extra_tokens, num_extra_tokens:
130 | ]
131 | else:
132 | print(f"Generating Attention Mapping for Decoder Layer Id {layer_id}")
133 | attentions = model.get_attention_map_dec(img.to(ptu.device), layer_id)
134 | if cls:
135 | attentions = attentions[0, :, -n_cls:, :-n_cls]
136 | else:
137 | attentions = attentions[0, :, num_patch, :-n_cls]
138 |
139 | # Reshape into image shape
140 | nh = attentions.shape[0] # Number of heads
141 | attentions = attentions.reshape(nh, -1)
142 |
143 | if cls and not enc:
144 | attentions = attentions.reshape(nh, n_cls, w_featmap, h_featmap)
145 | else:
146 | attentions = attentions.reshape(nh, 1, w_featmap, h_featmap)
147 |
148 | # Resize attention maps to match input size
149 | attentions = (
150 | F.interpolate(attentions, scale_factor=patch_size, mode="nearest").cpu().numpy()
151 | )
152 |
153 | # Save Attention map for each head
154 | for i in range(nh):
155 | base_name = "enc" if enc else "dec"
156 | head_name = f"{base_name}_layer{layer_id}_attn-head{i}"
157 | attention_maps_list = attentions[i]
158 | for j in range(attention_maps_list.shape[0]):
159 | attention_map = attention_maps_list[j]
160 | file_name = head_name
161 | dir_path = output_dir / f"{base_name}_layer{layer_id}"
162 | Path.mkdir(dir_path, exist_ok=True)
163 | if cls:
164 | if enc:
165 | file_name = f"{file_name}_cls"
166 | dir_path /= "cls"
167 | else:
168 | file_name = f"{file_name}_{j}"
169 | dir_path /= f"cls_{j}"
170 | Path.mkdir(dir_path, exist_ok=True)
171 | else:
172 | dir_path /= f"patch_{x_patch}_{y_patch}"
173 | Path.mkdir(dir_path, exist_ok=True)
174 |
175 | file_path = dir_path / f"{file_name}.png"
176 | plt.imsave(fname=str(file_path), arr=attention_map, format="png", cmap=cmap)
177 | print(f"{file_path} saved.")
178 |
179 | # Save input image showing selected patch
180 | if not cls:
181 | im_n = torchvision.utils.make_grid(img, normalize=True, scale_each=True)
182 |
183 | # Compute corresponding X and Y px in the original image
184 | x_px = x_patch * patch_size
185 | y_px = y_patch * patch_size
186 | px_v = einops.repeat(
187 | torch.tensor([1, 0, 0]),
188 | "c -> 1 c h w",
189 | h=patch_size,
190 | w=patch_size,
191 | )
192 |
193 | # Draw pixels for selected patch
194 | im_n[:, y_px : y_px + patch_size, x_px : x_px + patch_size] = px_v
195 | torchvision.utils.save_image(
196 | im_n,
197 | str(dir_path / "input_img.png"),
198 | )
199 |
200 |
201 | if __name__ == "__main__":
202 | visualize()
203 |
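A possible invocation (paths illustrative) that dumps the decoder attention of layer 0 for the patch at (x, y) = (12, 8):

    python -m segm.scripts.show_attn_map checkpoint.pth im.jpg attn_maps --layer-id 0 --x-patch 12 --y-patch 8 --dec --patch

One map per head is written under attn_maps/dec_layer0/patch_12_8/, together with input_img.png in which the selected patch is marked in red; with --enc --cls, the maps show the attention of the encoder's class token instead.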
--------------------------------------------------------------------------------
/segm/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 | import yaml
4 | import json
5 | import numpy as np
6 | import torch
7 | import click
8 | import argparse
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 |
11 | from segm.utils import distributed
12 | import segm.utils.torch as ptu
13 | from segm import config
14 |
15 | from segm.model.factory import create_segmenter
16 | from segm.optim.factory import create_optimizer, create_scheduler
17 | from segm.data.factory import create_dataset
18 | from segm.model.utils import num_params
19 |
20 | from timm.utils import NativeScaler
21 | from contextlib import suppress
22 |
23 | from segm.utils.distributed import sync_model
24 | from segm.engine import train_one_epoch, evaluate
25 |
26 |
27 | @click.command(help="")
28 | @click.option("--log-dir", type=str, help="logging directory")
29 | @click.option("--dataset", type=str)
30 | @click.option("--im-size", default=None, type=int, help="dataset resize size")
31 | @click.option("--crop-size", default=None, type=int)
32 | @click.option("--window-size", default=None, type=int)
33 | @click.option("--window-stride", default=None, type=int)
34 | @click.option("--backbone", default="", type=str)
35 | @click.option("--decoder", default="", type=str)
36 | @click.option("--optimizer", default="sgd", type=str)
37 | @click.option("--scheduler", default="polynomial", type=str)
38 | @click.option("--weight-decay", default=0.0, type=float)
39 | @click.option("--dropout", default=0.0, type=float)
40 | @click.option("--drop-path", default=0.1, type=float)
41 | @click.option("--batch-size", default=None, type=int)
42 | @click.option("--epochs", default=None, type=int)
43 | @click.option("-lr", "--learning-rate", default=None, type=float)
44 | @click.option("--normalization", default=None, type=str)
45 | @click.option("--eval-freq", default=None, type=int)
46 | @click.option("--amp/--no-amp", default=False, is_flag=True)
47 | @click.option("--resume/--no-resume", default=True, is_flag=True)
48 | def main(
49 | log_dir,
50 | dataset,
51 | im_size,
52 | crop_size,
53 | window_size,
54 | window_stride,
55 | backbone,
56 | decoder,
57 | optimizer,
58 | scheduler,
59 | weight_decay,
60 | dropout,
61 | drop_path,
62 | batch_size,
63 | epochs,
64 | learning_rate,
65 | normalization,
66 | eval_freq,
67 | amp,
68 | resume,
69 | ):
70 | # start distributed mode
71 | ptu.set_gpu_mode(True)
72 | distributed.init_process()
73 |
74 | # set up configuration
75 | cfg = config.load_config()
76 | model_cfg = cfg["model"][backbone]
77 | dataset_cfg = cfg["dataset"][dataset]
78 | if "mask_transformer" in decoder:
79 | decoder_cfg = cfg["decoder"]["mask_transformer"]
80 | else:
81 | decoder_cfg = cfg["decoder"][decoder]
82 |
83 | # model config
84 | if not im_size:
85 | im_size = dataset_cfg["im_size"]
86 | if not crop_size:
87 | crop_size = dataset_cfg.get("crop_size", im_size)
88 | if not window_size:
89 | window_size = dataset_cfg.get("window_size", im_size)
90 | if not window_stride:
91 | window_stride = dataset_cfg.get("window_stride", im_size)
92 |
93 | model_cfg["image_size"] = (crop_size, crop_size)
94 | model_cfg["backbone"] = backbone
95 | model_cfg["dropout"] = dropout
96 | model_cfg["drop_path_rate"] = drop_path
97 | decoder_cfg["name"] = decoder
98 | model_cfg["decoder"] = decoder_cfg
99 |
100 | # dataset config
101 | world_batch_size = dataset_cfg["batch_size"]
102 | num_epochs = dataset_cfg["epochs"]
103 | lr = dataset_cfg["learning_rate"]
104 | if batch_size:
105 | world_batch_size = batch_size
106 | if epochs:
107 | num_epochs = epochs
108 | if learning_rate:
109 | lr = learning_rate
110 | if eval_freq is None:
111 | eval_freq = dataset_cfg.get("eval_freq", 1)
112 |
113 | if normalization:
114 | model_cfg["normalization"] = normalization
115 |
116 | # experiment config
117 | batch_size = world_batch_size // ptu.world_size
118 | variant = dict(
119 | world_batch_size=world_batch_size,
120 | version="normal",
121 | resume=resume,
122 | dataset_kwargs=dict(
123 | dataset=dataset,
124 | image_size=im_size,
125 | crop_size=crop_size,
126 | batch_size=batch_size,
127 | normalization=model_cfg["normalization"],
128 | split="train",
129 | num_workers=10,
130 | ),
131 | algorithm_kwargs=dict(
132 | batch_size=batch_size,
133 | start_epoch=0,
134 | num_epochs=num_epochs,
135 | eval_freq=eval_freq,
136 | ),
137 | optimizer_kwargs=dict(
138 | opt=optimizer,
139 | lr=lr,
140 | weight_decay=weight_decay,
141 | momentum=0.9,
142 | clip_grad=None,
143 | sched=scheduler,
144 | epochs=num_epochs,
145 | min_lr=1e-5,
146 | poly_power=0.9,
147 | poly_step_size=1,
148 | ),
149 | net_kwargs=model_cfg,
150 | amp=amp,
151 | log_dir=log_dir,
152 | inference_kwargs=dict(
153 | im_size=im_size,
154 | window_size=window_size,
155 | window_stride=window_stride,
156 | ),
157 | )
158 |
159 | log_dir = Path(log_dir)
160 | log_dir.mkdir(parents=True, exist_ok=True)
161 | checkpoint_path = log_dir / "checkpoint.pth"
162 |
163 | # dataset
164 | dataset_kwargs = variant["dataset_kwargs"]
165 |
166 | train_loader = create_dataset(dataset_kwargs)
167 | val_kwargs = dataset_kwargs.copy()
168 | val_kwargs["split"] = "val"
169 | val_kwargs["batch_size"] = 1
170 | val_kwargs["crop"] = False
171 | val_loader = create_dataset(val_kwargs)
172 | n_cls = train_loader.unwrapped.n_cls
173 |
174 | # model
175 | net_kwargs = variant["net_kwargs"]
176 | net_kwargs["n_cls"] = n_cls
177 | model = create_segmenter(net_kwargs)
178 | model.to(ptu.device)
179 |
180 | # optimizer
181 | optimizer_kwargs = variant["optimizer_kwargs"]
182 | optimizer_kwargs["iter_max"] = len(train_loader) * optimizer_kwargs["epochs"]
183 | optimizer_kwargs["iter_warmup"] = 0.0
184 | opt_args = argparse.Namespace()
185 | opt_vars = vars(opt_args)
186 | for k, v in optimizer_kwargs.items():
187 | opt_vars[k] = v
188 | optimizer = create_optimizer(opt_args, model)
189 | lr_scheduler = create_scheduler(opt_args, optimizer)
190 | num_iterations = 0
191 | amp_autocast = suppress
192 | loss_scaler = None
193 | if amp:
194 | amp_autocast = torch.cuda.amp.autocast
195 | loss_scaler = NativeScaler()
196 |
197 | # resume
198 | if resume and checkpoint_path.exists():
199 | print(f"Resuming training from checkpoint: {checkpoint_path}")
200 | checkpoint = torch.load(checkpoint_path, map_location="cpu")
201 | model.load_state_dict(checkpoint["model"])
202 | optimizer.load_state_dict(checkpoint["optimizer"])
203 | if loss_scaler and "loss_scaler" in checkpoint:
204 | loss_scaler.load_state_dict(checkpoint["loss_scaler"])
205 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
206 | variant["algorithm_kwargs"]["start_epoch"] = checkpoint["epoch"] + 1
207 | else:
208 | sync_model(log_dir, model)
209 |
210 | if ptu.distributed:
211 | model = DDP(model, device_ids=[ptu.device], find_unused_parameters=True)
212 |
213 | # save config
214 | variant_str = yaml.dump(variant)
215 | print(f"Configuration:\n{variant_str}")
216 | variant["net_kwargs"] = net_kwargs
217 | variant["dataset_kwargs"] = dataset_kwargs
218 | log_dir.mkdir(parents=True, exist_ok=True)
219 | with open(log_dir / "variant.yml", "w") as f:
220 | f.write(variant_str)
221 |
222 | # train
223 | start_epoch = variant["algorithm_kwargs"]["start_epoch"]
224 | num_epochs = variant["algorithm_kwargs"]["num_epochs"]
225 | eval_freq = variant["algorithm_kwargs"]["eval_freq"]
226 |
227 | model_without_ddp = model
228 | if hasattr(model, "module"):
229 | model_without_ddp = model.module
230 |
231 | val_seg_gt = val_loader.dataset.get_gt_seg_maps()
232 |
233 | print(f"Train dataset length: {len(train_loader.dataset)}")
234 | print(f"Val dataset length: {len(val_loader.dataset)}")
235 | print(f"Encoder parameters: {num_params(model_without_ddp.encoder)}")
236 | print(f"Decoder parameters: {num_params(model_without_ddp.decoder)}")
237 |
238 | for epoch in range(start_epoch, num_epochs):
239 | # train for one epoch
240 | train_logger = train_one_epoch(
241 | model,
242 | train_loader,
243 | optimizer,
244 | lr_scheduler,
245 | epoch,
246 | amp_autocast,
247 | loss_scaler,
248 | )
249 |
250 | # save checkpoint
251 | if ptu.dist_rank == 0:
252 | snapshot = dict(
253 | model=model_without_ddp.state_dict(),
254 | optimizer=optimizer.state_dict(),
255 | n_cls=model_without_ddp.n_cls,
256 | lr_scheduler=lr_scheduler.state_dict(),
257 | )
258 | if loss_scaler is not None:
259 | snapshot["loss_scaler"] = loss_scaler.state_dict()
260 | snapshot["epoch"] = epoch
261 | torch.save(snapshot, checkpoint_path)
262 |
263 | # evaluate
264 | eval_epoch = epoch % eval_freq == 0 or epoch == num_epochs - 1
265 | if eval_epoch:
266 | eval_logger = evaluate(
267 | model,
268 | val_loader,
269 | val_seg_gt,
270 | window_size,
271 | window_stride,
272 | amp_autocast,
273 | )
274 | print(f"Stats [{epoch}]:", eval_logger, flush=True)
275 | print("")
276 |
277 | # log stats
278 | if ptu.dist_rank == 0:
279 | train_stats = {
280 | k: meter.global_avg for k, meter in train_logger.meters.items()
281 | }
282 | val_stats = {}
283 | if eval_epoch:
284 | val_stats = {
285 | k: meter.global_avg for k, meter in eval_logger.meters.items()
286 | }
287 |
288 | log_stats = {
289 | **{f"train_{k}": v for k, v in train_stats.items()},
290 | **{f"val_{k}": v for k, v in val_stats.items()},
291 | "epoch": epoch,
292 | "num_updates": (epoch + 1) * len(train_loader),
293 | }
294 |
295 | with open(log_dir / "log.txt", "a") as f:
296 | f.write(json.dumps(log_stats) + "\n")
297 |
298 | distributed.barrier()
299 | distributed.destroy_process()
300 | sys.exit(1)
301 |
302 |
303 | if __name__ == "__main__":
304 | main()
305 |
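A representative single-node run (the backbone and decoder names must match keys in segm/config.yml; the ones below are illustrative):

    python -m segm.train --log-dir seg_tiny_mask --dataset ade20k --backbone vit_tiny_patch16_384 --decoder mask_transformer

Checkpoints, variant.yml and log.txt are written to --log-dir, and training resumes automatically from checkpoint.pth unless --no-resume is given. Multi-GPU runs rely on the SLURM environment variables (SLURM_PROCID, SLURM_NTASKS, SLURM_LOCALID) read in segm/utils/torch.py to set up the distributed process group.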
--------------------------------------------------------------------------------
/segm/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import hostlist
3 | from pathlib import Path
4 | import torch
5 | import torch.distributed as dist
6 |
7 | import segm.utils.torch as ptu
8 |
9 |
10 | def init_process(backend="nccl"):
11 | print(f"Starting process with rank {ptu.dist_rank}...", flush=True)
12 |
13 | if "SLURM_STEPS_GPUS" in os.environ:
14 | gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",")
15 | os.environ["MASTER_PORT"] = str(12345 + int(min(gpu_ids)))
16 | else:
17 | os.environ["MASTER_PORT"] = str(12345)
18 |
19 | if "SLURM_JOB_NODELIST" in os.environ:
20 | hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])
21 | os.environ["MASTER_ADDR"] = hostnames[0]
22 | else:
23 | os.environ["MASTER_ADDR"] = "127.0.0.1"
24 |
25 | dist.init_process_group(
26 | backend,
27 | rank=ptu.dist_rank,
28 | world_size=ptu.world_size,
29 | )
30 | print(f"Process {ptu.dist_rank} is connected.", flush=True)
31 | dist.barrier()
32 |
33 | silence_print(ptu.dist_rank == 0)
34 | if ptu.dist_rank == 0:
35 | print(f"All processes are connected.", flush=True)
36 |
37 |
38 | def silence_print(is_master):
39 | """
40 | This function disables printing when not in master process
41 | """
42 | import builtins as __builtin__
43 |
44 | builtin_print = __builtin__.print
45 |
46 | def print(*args, **kwargs):
47 | force = kwargs.pop("force", False)
48 | if is_master or force:
49 | builtin_print(*args, **kwargs)
50 |
51 | __builtin__.print = print
52 |
53 |
54 | def sync_model(sync_dir, model):
55 | # https://github.com/ylabbe/cosypose/blob/master/cosypose/utils/distributed.py
56 | sync_path = Path(sync_dir).resolve() / "sync_model.pkl"
57 | if ptu.dist_rank == 0 and ptu.world_size > 1:
58 | torch.save(model.state_dict(), sync_path)
59 | dist.barrier()
60 | if ptu.dist_rank > 0:
61 | model.load_state_dict(torch.load(sync_path))
62 | dist.barrier()
63 | if ptu.dist_rank == 0 and ptu.world_size > 1:
64 | sync_path.unlink()
65 | return model
66 |
67 |
68 | def barrier():
69 | dist.barrier()
70 |
71 |
72 | def destroy_process():
73 | dist.destroy_process_group()
74 |
--------------------------------------------------------------------------------
/segm/utils/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import hashlib
4 | from tqdm import tqdm
5 |
6 |
7 | def check_sha1(filename, sha1_hash):
8 | """Check whether the sha1 hash of the file content matches the expected hash.
9 | Parameters
10 | ----------
11 | filename : str
12 | Path to the file.
13 | sha1_hash : str
14 | Expected sha1 hash in hexadecimal digits.
15 | Returns
16 | -------
17 | bool
18 | Whether the file content matches the expected hash.
19 | """
20 | sha1 = hashlib.sha1()
21 | with open(filename, "rb") as f:
22 | while True:
23 | data = f.read(1048576)
24 | if not data:
25 | break
26 | sha1.update(data)
27 |
28 | return sha1.hexdigest() == sha1_hash
29 |
30 |
31 | def download(url, path=None, overwrite=False, sha1_hash=None):
32 | """
33 | https://github.com/junfu1115/DANet/blob/master/encoding/utils/files.py
34 | Download a given URL
35 | Parameters
36 | ----------
37 | url : str
38 | URL to download
39 | path : str, optional
40 | Destination path to store downloaded file. By default stores to the
41 | current directory with same name as in url.
42 | overwrite : bool, optional
43 | Whether to overwrite destination file if already exists.
44 | sha1_hash : str, optional
45 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
46 | but doesn't match.
47 | Returns
48 | -------
49 | str
50 | The file path of the downloaded file.
51 | """
52 | if path is None:
53 | fname = url.split("/")[-1]
54 | else:
55 | path = os.path.expanduser(path)
56 | if os.path.isdir(path):
57 | fname = os.path.join(path, url.split("/")[-1])
58 | else:
59 | fname = path
60 |
61 | if (
62 | overwrite
63 | or not os.path.exists(fname)
64 | or (sha1_hash and not check_sha1(fname, sha1_hash))
65 | ):
66 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
67 | if not os.path.exists(dirname):
68 | os.makedirs(dirname)
69 |
70 | print("Downloading %s from %s..." % (fname, url))
71 | r = requests.get(url, stream=True)
72 | if r.status_code != 200:
73 | raise RuntimeError("Failed downloading url %s" % url)
74 | total_length = r.headers.get("content-length")
75 | with open(fname, "wb") as f:
76 | if total_length is None: # no content length header
77 | for chunk in r.iter_content(chunk_size=1024):
78 | if chunk: # filter out keep-alive new chunks
79 | f.write(chunk)
80 | else:
81 | total_length = int(total_length)
82 | for chunk in tqdm(
83 | r.iter_content(chunk_size=1024),
84 | total=int(total_length / 1024.0 + 0.5),
85 | unit="KB",
86 | unit_scale=False,
87 | dynamic_ncols=True,
88 | ):
89 | f.write(chunk)
90 |
91 | if sha1_hash and not check_sha1(fname, sha1_hash):
92 | raise UserWarning(
93 | "File {} is downloaded but the content hash does not match. "
94 | "The repo may be outdated or download may be incomplete. "
95 | 'If the "repo_url" is overridden, consider switching to '
96 | "the default repo.".format(fname)
97 | )
98 |
99 | return fname
100 |
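A minimal sketch of how the preparation scripts use download(): fetch a file into a target directory and verify its sha1 digest. The URL and destination below are placeholders, not real artifacts:

    from pathlib import Path

    from segm.utils.download import check_sha1, download

    dest = Path("/tmp/segm_downloads")
    dest.mkdir(parents=True, exist_ok=True)

    # With sha1_hash set, an existing file with a mismatching digest is re-downloaded.
    fname = download(
        "https://example.com/archive.zip",  # placeholder URL
        path=str(dest),
        overwrite=False,
        sha1_hash=None,  # or the expected 40-character hex digest
    )
    print(fname, check_sha1(fname, "0" * 40))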
--------------------------------------------------------------------------------
/segm/utils/lines.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from itertools import cycle
3 |
4 |
5 | class Lines:
6 | def __init__(self, resolution=20, smooth=None):
7 | self.COLORS = cycle(
8 | [
9 | "#377eb8",
10 | "#e41a1c",
11 | "#4daf4a",
12 | "#984ea3",
13 | "#ff7f00",
14 | "#ffff33",
15 | "#a65628",
16 | "#f781bf",
17 | ]
18 | )
19 | self.MARKERS = cycle("os^Dp>d<")
20 | self.LEGEND = dict(fontsize="medium", labelspacing=0, numpoints=1)
21 | self._resolution = resolution
22 | self._smooth_weight = smooth
23 |
24 | def __call__(self, ax, domains, lines, labels):
25 | assert len(domains) == len(lines) == len(labels)
26 | colors = []
27 | for index, (label, color, marker) in enumerate(
28 | zip(labels, self.COLORS, self.MARKERS)
29 | ):
30 | domain, line = domains[index], lines[index]
31 | line = self.smooth(line, self._smooth_weight)
32 | ax.plot(domain, line[:, 0], color=color, label=label)
33 |
34 | last_x, last_y = domain[-1], line[-1, 0]
35 | ax.scatter(last_x, last_y, color=color, marker="x")
36 | ax.annotate(
37 | f"{last_y:.2f}",
38 | xy=(last_x, last_y),
39 | xytext=(last_x, last_y + 0.1),
40 | )
41 | colors.append(color)
42 |
43 | self._plot_legend(ax, lines, labels)
44 | return colors
45 |
46 | def _plot_legend(self, ax, lines, labels):
47 | scores = {label: -np.nanmedian(line) for label, line in zip(labels, lines)}
48 | handles, labels = ax.get_legend_handles_labels()
49 | # handles, labels = zip(*sorted(
50 | # zip(handles, labels), key=lambda x: scores[x[1]]))
51 | legend = ax.legend(handles, labels, **self.LEGEND)
52 | legend.get_frame().set_edgecolor("white")
53 | for line in legend.get_lines():
54 | line.set_alpha(1)
55 |
56 | def smooth(self, scalars, weight):
57 | """
58 | weight in [0, 1]
59 | exponential moving average, same as tensorboard
60 | """
61 |         assert weight is not None and 0 <= weight <= 1
62 |         last = scalars[0]
63 |         smoothed = np.array(scalars, dtype=float)  # copy so the caller's array is not modified in place
64 | for i, point in enumerate(scalars):
65 | last = last * weight + (1 - weight) * point
66 | smoothed[i] = last
67 |
68 | return smoothed
69 |
--------------------------------------------------------------------------------
/segm/utils/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/facebookresearch/deit/blob/main/utils.py
3 | """
4 |
5 | import io
6 | import os
7 | import time
8 | from collections import defaultdict, deque
9 | import datetime
10 |
11 | import torch
12 | import torch.distributed as dist
13 |
14 | import segm.utils.torch as ptu
15 |
16 |
17 | class SmoothedValue(object):
18 | """Track a series of values and provide access to smoothed values over a
19 | window or the global series average.
20 | """
21 |
22 | def __init__(self, window_size=20, fmt=None):
23 | if fmt is None:
24 | fmt = "{median:.4f} ({global_avg:.4f})"
25 | self.deque = deque(maxlen=window_size)
26 | self.total = 0.0
27 | self.count = 0
28 | self.fmt = fmt
29 |
30 | def update(self, value, n=1):
31 | self.deque.append(value)
32 | self.count += n
33 | self.total += value * n
34 |
35 | def synchronize_between_processes(self):
36 | """
37 | Warning: does not synchronize the deque!
38 | """
39 | if not is_dist_avail_and_initialized():
40 | return
41 | t = torch.tensor(
42 | [self.count, self.total], dtype=torch.float64, device=ptu.device
43 | )
44 | dist.barrier()
45 | dist.all_reduce(t)
46 | t = t.tolist()
47 | self.count = int(t[0])
48 | self.total = t[1]
49 |
50 | @property
51 | def median(self):
52 | d = torch.tensor(list(self.deque))
53 | return d.median().item()
54 |
55 | @property
56 | def avg(self):
57 | d = torch.tensor(list(self.deque), dtype=torch.float32)
58 | return d.mean().item()
59 |
60 | @property
61 | def global_avg(self):
62 | return self.total / self.count
63 |
64 | @property
65 | def max(self):
66 | return max(self.deque)
67 |
68 | @property
69 | def value(self):
70 | return self.deque[-1]
71 |
72 | def __str__(self):
73 | return self.fmt.format(
74 | median=self.median,
75 | avg=self.avg,
76 | global_avg=self.global_avg,
77 | max=self.max,
78 | value=self.value,
79 | )
80 |
81 |
82 | class MetricLogger(object):
83 | def __init__(self, delimiter="\t"):
84 | self.meters = defaultdict(SmoothedValue)
85 | self.delimiter = delimiter
86 |
87 | def update(self, n=1, **kwargs):
88 | for k, v in kwargs.items():
89 | if isinstance(v, torch.Tensor):
90 | v = v.item()
91 | assert isinstance(v, (float, int))
92 | self.meters[k].update(v, n)
93 |
94 | def __getattr__(self, attr):
95 | if attr in self.meters:
96 | return self.meters[attr]
97 | if attr in self.__dict__:
98 | return self.__dict__[attr]
99 | raise AttributeError(
100 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101 | )
102 |
103 | def __str__(self):
104 | loss_str = []
105 | for name, meter in self.meters.items():
106 | loss_str.append("{}: {}".format(name, str(meter)))
107 | return self.delimiter.join(loss_str)
108 |
109 | def synchronize_between_processes(self):
110 | for meter in self.meters.values():
111 | meter.synchronize_between_processes()
112 |
113 | def add_meter(self, name, meter):
114 | self.meters[name] = meter
115 |
116 | def log_every(self, iterable, print_freq, header=None):
117 | i = 0
118 | if not header:
119 | header = ""
120 | start_time = time.time()
121 | end = time.time()
122 | iter_time = SmoothedValue(fmt="{avg:.4f}")
123 | data_time = SmoothedValue(fmt="{avg:.4f}")
124 | space_fmt = ":" + str(len(str(len(iterable)))) + "d"
125 | log_msg = [
126 | header,
127 | "[{0" + space_fmt + "}/{1}]",
128 | "eta: {eta}",
129 | "{meters}",
130 | "time: {time}",
131 | "data: {data}",
132 | ]
133 | if torch.cuda.is_available():
134 | log_msg.append("max mem: {memory:.0f}")
135 | log_msg = self.delimiter.join(log_msg)
136 | MB = 1024.0 * 1024.0
137 | for obj in iterable:
138 | data_time.update(time.time() - end)
139 | yield obj
140 | iter_time.update(time.time() - end)
141 | if i % print_freq == 0 or i == len(iterable) - 1:
142 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
143 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
144 | if torch.cuda.is_available():
145 | print(
146 | log_msg.format(
147 | i,
148 | len(iterable),
149 | eta=eta_string,
150 | meters=str(self),
151 | time=str(iter_time),
152 | data=str(data_time),
153 | memory=torch.cuda.max_memory_allocated() / MB,
154 | ),
155 | flush=True,
156 | )
157 | else:
158 | print(
159 | log_msg.format(
160 | i,
161 | len(iterable),
162 | eta=eta_string,
163 | meters=str(self),
164 | time=str(iter_time),
165 | data=str(data_time),
166 | ),
167 | flush=True,
168 | )
169 | i += 1
170 | end = time.time()
171 | total_time = time.time() - start_time
172 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
173 | print(
174 | "{} Total time: {} ({:.4f} s / it)".format(
175 | header, total_time_str, total_time / len(iterable)
176 | )
177 | )
178 |
179 |
180 | def is_dist_avail_and_initialized():
181 | if not dist.is_available():
182 | return False
183 | if not dist.is_initialized():
184 | return False
185 | return True
186 |
--------------------------------------------------------------------------------
/segm/utils/logs.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | import numpy as np
4 | import yaml
5 | import matplotlib.pyplot as plt
6 | import click
7 | from collections import OrderedDict
8 |
9 | from segm.utils.lines import Lines
10 |
11 |
12 | def plot_logs(logs, x_key, y_key, size, vmin, vmax, epochs):
13 | m = np.inf
14 | M = -np.inf
15 | domains = []
16 | lines = []
17 | y_keys = y_key.split("/")
18 | for name, log in logs.items():
19 | logs[name] = log[:epochs]
20 | for name, log in logs.items():
21 | domain = [x[x_key] for x in log if y_keys[0] in x]
22 | if y_keys[0] not in log[0]:
23 | continue
24 | log_plot = [x[y_keys[0]] for x in log if y_keys[0] in x]
25 | for y_key in y_keys[1:]:
26 | if y_key in log_plot[0]:
27 | log_plot = [x[y_key] for x in log_plot if y_key in x]
28 | domains.append(domain)
29 | lines.append(np.array(log_plot)[:, None])
30 | m = np.min((m, min(log_plot)))
31 | M = np.max((M, max(log_plot)))
32 | if vmin is not None:
33 | m = vmin
34 | if vmax is not None:
35 | M = vmax
36 | delta = 0.1 * (M - m)
37 |
38 | ratio = 0.6
39 | figsizes = {"tight": (4, 3), "large": (16 * ratio, 10 * ratio)}
40 | figsize = figsizes[size]
41 |
42 | # plot parameters
43 | fig, ax = plt.subplots(figsize=figsize)
44 | ax.set_xlabel(x_key)
45 | ax.set_ylabel(y_key)
46 | plot_lines = Lines(resolution=50, smooth=0.0)
47 | plot_lines.LEGEND["loc"] = "upper left"
48 | # plot_lines.LEGEND["fontsize"] = "large"
49 | plot_lines.LEGEND["bbox_to_anchor"] = (0.75, 0.2)
50 | labels_logs = list(logs.keys())
51 | colors = plot_lines(ax, domains, lines, labels_logs)
52 | ax.grid(True, alpha=0.5)
53 | ax.set_ylim(m - delta, M + delta)
54 |
55 | plt.show()
56 | fig.savefig(
57 | "plot.png", bbox_inches="tight", pad_inches=0.1, transparent=False, dpi=300
58 | )
59 | plt.close(fig)
60 |
61 |
62 | def print_logs(logs, x_key, y_key, last_log_idx=None):
63 | delim = " "
64 | s = ""
65 | keys = []
66 | y_keys = y_key.split("/")
67 | for name, log in logs.items():
68 | log_idx = last_log_idx
69 | if log_idx is None:
70 | log_idx = len(log) - 1
71 | while y_keys[0] not in log[log_idx]:
72 | log_idx -= 1
73 | last_log = log[log_idx]
74 | log_x = last_log[x_key]
75 | log_y = last_log[y_keys[0]]
76 | for y_key in y_keys[1:]:
77 | log_y = log_y[y_key]
78 | s += f"{name}:\n"
79 | # s += f"{delim}{x_key}: {log_x}\n"
80 | s += f"{delim}{y_key}: {log_y:.4f}\n"
81 | keys += list(last_log.keys())
82 | keys = list(set(keys))
83 | keys = ", ".join(keys)
84 | s = f"keys: {keys}\n" + s
85 | print(s)
86 |
87 |
88 | def read_logs(root, logs_path):
89 | logs = {}
90 | for name, path in logs_path.items():
91 | path = root / path
92 | if not path.exists():
93 | print(f"Skipping {name} that has no log file")
94 | continue
95 | logs[name] = []
96 | with open(path, "r") as f:
97 | for line in f.readlines():
98 | d = json.loads(line)
99 | logs[name].append(d)
100 | return logs
101 |
102 |
103 | @click.command()
104 | @click.argument("log_path", type=str)
105 | @click.option("--x-key", default="epoch", type=str)
106 | @click.option("--y-key", default="val_mean_iou", type=str)
107 | @click.option("-s", "--size", default="large", type=str)
108 | @click.option("-ep", "--epoch", default=-1, type=int)
109 | @click.option("-plot", "--plot/--no-plot", default=True, is_flag=True)
110 | def main(log_path, x_key, y_key, size, epoch, plot):
111 | abs_path = Path(__file__).parent / log_path
112 | if abs_path.exists():
113 | log_path = abs_path
114 | config = yaml.load(open(log_path, "r"), Loader=yaml.FullLoader)
115 | root = Path(config["root"])
116 | logs_path = OrderedDict(config["logs"])
117 | vmin = config.get("vmin", None)
118 | vmax = config.get("vmax", None)
119 | epochs = config.get("epochs", None)
120 |
121 | logs = read_logs(root, logs_path)
122 | if not logs:
123 | return
124 | print_logs(logs, x_key, y_key, epoch)
125 | if plot:
126 | plot_logs(logs, x_key, y_key, size, vmin, vmax, epochs)
127 |
128 |
129 | if __name__ == "__main__":
130 | main()
131 |
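main() expects a small YAML file describing which runs to compare; the keys are inferred from how the config is read ("root", "logs", and optional "vmin" / "vmax" / "epochs"). A sketch of producing one (paths and run names illustrative):

    import yaml

    config = {
        "root": "/path/to/runs",
        "logs": {
            "run_a": "run_a/log.txt",  # log.txt as written by segm/train.py
            "run_b": "run_b/log.txt",
        },
        "vmin": 0.0,
        "vmax": 1.0,
        "epochs": 64,
    }

    with open("compare_runs.yml", "w") as f:
        yaml.safe_dump(config, f)

    # then: python -m segm.utils.logs compare_runs.yml --y-key val_mean_iou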
--------------------------------------------------------------------------------
/segm/utils/torch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 |
5 | """
6 | GPU wrappers
7 | """
8 |
9 | use_gpu = False
10 | gpu_id = 0
11 | device = None
12 |
13 | distributed = False
14 | dist_rank = 0
15 | world_size = 1
16 |
17 |
18 | def set_gpu_mode(mode):
19 | global use_gpu
20 | global device
21 | global gpu_id
22 | global distributed
23 | global dist_rank
24 | global world_size
25 | gpu_id = int(os.environ.get("SLURM_LOCALID", 0))
26 | dist_rank = int(os.environ.get("SLURM_PROCID", 0))
27 | world_size = int(os.environ.get("SLURM_NTASKS", 1))
28 |
29 | distributed = world_size > 1
30 | use_gpu = mode
31 | device = torch.device(f"cuda:{gpu_id}" if use_gpu else "cpu")
32 | torch.backends.cudnn.benchmark = True
33 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import shutil
2 |
3 | from setuptools import setup, find_packages
4 | from codecs import open
5 | from os import path
6 |
7 | here = path.abspath(path.dirname(__file__))
8 |
9 |
10 | def read_requirements_file(filename):
11 | req_file_path = path.join(path.dirname(path.realpath(__file__)), filename)
12 | with open(req_file_path) as f:
13 | return [line.strip() for line in f]
14 |
15 |
16 | setup(
17 | name="segm",
18 | version="0.0.1",
19 | description="Segmenter: Transformer for Semantic Segmentation",
20 | packages=find_packages(),
21 | install_requires=read_requirements_file("requirements.txt"),
22 | )
23 |
--------------------------------------------------------------------------------