├── .gitignore
├── LICENSE
├── README.md
├── assests
│   ├── Helvetica.ttf
│   ├── ade
│   │   ├── ADE_val_00000001.jpg
│   │   └── ADE_val_00000049.jpg
│   ├── banner.jpg
│   ├── image_labels
│   │   ├── Seq05VD_f05100.png
│   │   └── Seq05VD_f05100_L.png
│   └── infer_result.png
├── configs
│   ├── ade20k.yaml
│   ├── cityscapes.yaml
│   ├── custom.yaml
│   └── helen.yaml
├── docs
│   ├── BACKBONES.md
│   ├── DATASETS.md
│   ├── MODELS.md
│   └── OTHER_DATASETS.md
├── notebooks
│   ├── aug_test.ipynb
│   └── tutorial.ipynb
├── scripts
│   ├── calc_class_weights.py
│   ├── export_data.py
│   ├── onnx_infer.py
│   ├── openvino_infer.py
│   ├── preprocess_celebamaskhq.py
│   └── tflite_infer.py
├── semseg
│   ├── __init__.py
│   ├── augmentations.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── ade20k.py
│   │   ├── atr.py
│   │   ├── camvid.py
│   │   ├── celebamaskhq.py
│   │   ├── cihp.py
│   │   ├── cityscapes.py
│   │   ├── cocostuff.py
│   │   ├── facesynthetics.py
│   │   ├── helen.py
│   │   ├── ibugmask.py
│   │   ├── lapa.py
│   │   ├── lip.py
│   │   ├── mapillary.py
│   │   ├── mhpv1.py
│   │   ├── mhpv2.py
│   │   ├── pascalcontext.py
│   │   ├── suim.py
│   │   └── sunrgbd.py
│   ├── losses.py
│   ├── metrics.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── convnext.py
│   │   │   ├── micronet.py
│   │   │   ├── mit.py
│   │   │   ├── mobilenetv2.py
│   │   │   ├── mobilenetv3.py
│   │   │   ├── poolformer.py
│   │   │   ├── pvt.py
│   │   │   ├── resnet.py
│   │   │   ├── resnetd.py
│   │   │   ├── rest.py
│   │   │   └── uniformer.py
│   │   ├── base.py
│   │   ├── bisenetv1.py
│   │   ├── bisenetv2.py
│   │   ├── custom_cnn.py
│   │   ├── custom_vit.py
│   │   ├── ddrnet.py
│   │   ├── fchardnet.py
│   │   ├── heads
│   │   │   ├── __init__.py
│   │   │   ├── condnet.py
│   │   │   ├── fapn.py
│   │   │   ├── fcn.py
│   │   │   ├── fpn.py
│   │   │   ├── lawin.py
│   │   │   ├── segformer.py
│   │   │   ├── sfnet.py
│   │   │   └── upernet.py
│   │   ├── lawin.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   └── initialize.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── ppm.py
│   │   │   └── psa.py
│   │   ├── segformer.py
│   │   └── sfnet.py
│   ├── optimizers.py
│   ├── schedulers.py
│   └── utils
│       ├── __init__.py
│       ├── utils.py
│       └── visualize.py
├── setup.py
└── tools
    ├── benchmark.py
    ├── export.py
    ├── infer.py
    ├── train.py
    └── val.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Repo-specific GitIgnore ----------------------------------------------------------------------------------------------
2 | *.jpg
3 | *.jpeg
4 | *.png
5 | *.bmp
6 | *.tif
7 | *.tiff
8 | *.heic
9 | *.JPG
10 | *.JPEG
11 | *.PNG
12 | *.BMP
13 | *.TIF
14 | *.TIFF
15 | *.HEIC
16 | *.mp4
17 | *.mov
18 | *.MOV
19 | *.avi
20 | *.data
21 | *.json
22 |
23 | *.cfg
24 | !cfg/yolov3*.cfg
25 |
26 | storage.googleapis.com
27 | runs/*
28 | data/*
29 | !data/images/zidane.jpg
30 | !data/images/bus.jpg
31 | !data/coco.names
32 | !data/coco_paper.names
33 | !data/coco.data
34 | !data/coco_*.data
35 | !data/coco_*.txt
36 | !data/trainvalno5k.shapes
37 | !data/*.sh
38 |
39 | test.py
40 | test_imgs/
41 |
42 | pycocotools/*
43 | results*.txt
44 | gcp_test*.sh
45 |
46 | checkpoints/
47 | output/
48 | assests/*/
49 |
50 | # Datasets -------------------------------------------------------------------------------------------------------------
51 | coco/
52 | coco128/
53 | VOC/
54 |
55 | # MATLAB GitIgnore -----------------------------------------------------------------------------------------------------
56 | *.m~
57 | *.mat
58 | !targets*.mat
59 |
60 | # Neural Network weights -----------------------------------------------------------------------------------------------
61 | *.weights
62 | *.pt
63 | *.onnx
64 | *.mlmodel
65 | *.torchscript
66 | darknet53.conv.74
67 | yolov3-tiny.conv.15
68 |
69 | # GitHub Python GitIgnore ----------------------------------------------------------------------------------------------
70 | # Byte-compiled / optimized / DLL files
71 | __pycache__/
72 | *.py[cod]
73 | *$py.class
74 |
75 | # C extensions
76 | *.so
77 |
78 | # Distribution / packaging
79 | .Python
80 | env/
81 | build/
82 | develop-eggs/
83 | dist/
84 | downloads/
85 | eggs/
86 | .eggs/
87 | lib/
88 | lib64/
89 | parts/
90 | sdist/
91 | var/
92 | wheels/
93 | *.egg-info/
94 | wandb/
95 | .installed.cfg
96 | *.egg
97 |
98 |
99 | # PyInstaller
100 | # Usually these files are written by a python script from a template
101 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
102 | *.manifest
103 | *.spec
104 |
105 | # Installer logs
106 | pip-log.txt
107 | pip-delete-this-directory.txt
108 |
109 | # Unit test / coverage reports
110 | htmlcov/
111 | .tox/
112 | .coverage
113 | .coverage.*
114 | .cache
115 | nosetests.xml
116 | coverage.xml
117 | *.cover
118 | .hypothesis/
119 |
120 | # Translations
121 | *.mo
122 | *.pot
123 |
124 | # Django stuff:
125 | *.log
126 | local_settings.py
127 |
128 | # Flask stuff:
129 | instance/
130 | .webassets-cache
131 |
132 | # Scrapy stuff:
133 | .scrapy
134 |
135 | # Sphinx documentation
136 | docs/_build/
137 |
138 | # PyBuilder
139 | target/
140 |
141 | # Jupyter Notebook
142 | .ipynb_checkpoints
143 |
144 | # pyenv
145 | .python-version
146 |
147 | # celery beat schedule file
148 | celerybeat-schedule
149 |
150 | # SageMath parsed files
151 | *.sage.py
152 |
153 | # dotenv
154 | .env
155 |
156 | # virtualenv
157 | .venv*
158 | venv*/
159 | ENV*/
160 |
161 | # Spyder project settings
162 | .spyderproject
163 | .spyproject
164 |
165 | # Rope project settings
166 | .ropeproject
167 |
168 | # mkdocs documentation
169 | /site
170 |
171 | # mypy
172 | .mypy_cache/
173 |
174 |
175 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore -----------------------------------------------
176 |
177 | # General
178 | .DS_Store
179 | .AppleDouble
180 | .LSOverride
181 |
182 | # Icon must end with two \r
183 | Icon
184 | Icon?
185 |
186 | # Thumbnails
187 | ._*
188 |
189 | # Files that might appear in the root of a volume
190 | .DocumentRevisions-V100
191 | .fseventsd
192 | .Spotlight-V100
193 | .TemporaryItems
194 | .Trashes
195 | .VolumeIcon.icns
196 | .com.apple.timemachine.donotpresent
197 |
198 | # Directories potentially created on remote AFP share
199 | .AppleDB
200 | .AppleDesktop
201 | Network Trash Folder
202 | Temporary Items
203 | .apdisk
204 |
205 |
206 | # https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore
207 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
208 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
209 |
210 | # User-specific stuff:
211 | .idea/*
212 | .idea/**/workspace.xml
213 | .idea/**/tasks.xml
214 | .idea/dictionaries
215 | .html # Bokeh Plots
216 | .pg # TensorFlow Frozen Graphs
217 | .avi # videos
218 |
219 | # Sensitive or high-churn files:
220 | .idea/**/dataSources/
221 | .idea/**/dataSources.ids
222 | .idea/**/dataSources.local.xml
223 | .idea/**/sqlDataSources.xml
224 | .idea/**/dynamic.xml
225 | .idea/**/uiDesigner.xml
226 |
227 | # Gradle:
228 | .idea/**/gradle.xml
229 | .idea/**/libraries
230 |
231 | # CMake
232 | cmake-build-debug/
233 | cmake-build-release/
234 |
235 | # Mongo Explorer plugin:
236 | .idea/**/mongoSettings.xml
237 |
238 | ## File-based project format:
239 | *.iws
240 |
241 | ## Plugin-specific files:
242 |
243 | # IntelliJ
244 | out/
245 |
246 | # mpeltonen/sbt-idea plugin
247 | .idea_modules/
248 |
249 | # JIRA plugin
250 | atlassian-ide-plugin.xml
251 |
252 | # Cursive Clojure plugin
253 | .idea/replstate.xml
254 |
255 | # Crashlytics plugin (for Android Studio and IntelliJ)
256 | com_crashlytics_export_strings.xml
257 | crashlytics.properties
258 | crashlytics-build.properties
259 | fabric.properties
260 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 sithu3
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/assests/Helvetica.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/Helvetica.ttf
--------------------------------------------------------------------------------
/assests/ade/ADE_val_00000001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/ade/ADE_val_00000001.jpg
--------------------------------------------------------------------------------
/assests/ade/ADE_val_00000049.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/ade/ADE_val_00000049.jpg
--------------------------------------------------------------------------------
/assests/banner.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/banner.jpg
--------------------------------------------------------------------------------
/assests/image_labels/Seq05VD_f05100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/image_labels/Seq05VD_f05100.png
--------------------------------------------------------------------------------
/assests/image_labels/Seq05VD_f05100_L.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/image_labels/Seq05VD_f05100_L.png
--------------------------------------------------------------------------------
/assests/infer_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/infer_result.png
--------------------------------------------------------------------------------
/configs/ade20k.yaml:
--------------------------------------------------------------------------------
1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)
2 | SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results
3 |
4 | MODEL:
5 | NAME : SegFormer # name of the model you are using
6 | BACKBONE : MiT-B2 # model variant
7 | PRETRAINED : 'checkpoints/backbones/mit/mit_b2.pth' # backbone model's weight
8 |
9 | DATASET:
10 | NAME : ADE20K # dataset name to be trained with (camvid, cityscapes, ade20k)
11 | ROOT : 'data/ADEChallengeData2016' # dataset root path
12 | IGNORE_LABEL : -1
13 |
14 | TRAIN:
15 | IMAGE_SIZE : [512, 512] # training image size in (h, w)
16 | BATCH_SIZE : 8 # batch size used to train
17 | EPOCHS : 500 # number of epochs to train
18 | EVAL_INTERVAL : 50 # evaluation interval during training
19 | AMP : false # use AMP in training
20 | DDP : false # use DDP training
21 |
22 | LOSS:
23 | NAME : OhemCrossEntropy # loss function name (ohemce, ce, dice)
24 | CLS_WEIGHTS : false # use class weights in loss calculation
25 |
26 | OPTIMIZER:
27 | NAME : adamw # optimizer name
28 | LR : 0.001 # initial learning rate used in optimizer
29 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer
30 |
31 | SCHEDULER:
32 | NAME : warmuppolylr # scheduler name
33 | POWER : 0.9 # scheduler power
34 | WARMUP : 10 # warmup epochs used in scheduler
35 | WARMUP_RATIO : 0.1 # warmup ratio
36 |
37 |
38 | EVAL:
39 | MODEL_PATH : 'checkpoints/pretrained/segformer/segformer.b2.ade.pth' # trained model file path
40 | IMAGE_SIZE : [512, 512] # evaluation image size in (h, w)
41 | MSF:
42 | ENABLE : false # multi-scale and flip evaluation
43 | FLIP : true # use flip in evaluation
44 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation
45 |
46 |
47 | TEST:
48 | MODEL_PATH : 'checkpoints/pretrained/segformer/segformer.b2.ade.pth' # trained model file path
49 | FILE : 'assests/ade' # filename or foldername
50 | IMAGE_SIZE : [512, 512] # inference image size in (h, w)
51 | OVERLAY : true # save the overlay result (image_alpha+label_alpha)
--------------------------------------------------------------------------------
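These YAML files drive training, evaluation, and inference; once loaded they are just nested dictionaries. As a minimal sketch (the repository's actual config loader is not shown in this dump, so the code below is an assumption rather than the project's API, and the same pattern applies to the other configs):

```python
# Minimal sketch: read a config such as configs/ade20k.yaml into a dict.
# Assumes PyYAML is installed; the repo's own loader may differ.
import yaml

with open('configs/ade20k.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['MODEL']['NAME'], cfg['MODEL']['BACKBONE'])  # SegFormer MiT-B2
print(cfg['TRAIN']['IMAGE_SIZE'])                      # [512, 512]
```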
/configs/cityscapes.yaml:
--------------------------------------------------------------------------------
1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)
2 | SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results
3 |
4 | MODEL:
5 | NAME : DDRNet # name of the model you are using
6 | BACKBONE : DDRNet-23slim # model variant
7 | PRETRAINED : 'checkpoints/backbones/ddrnet/ddrnet_23slim.pth' # backbone model's weight
8 |
9 | DATASET:
10 | NAME : CityScapes # dataset name to be trained with (camvid, cityscapes, ade20k)
11 | ROOT : 'data/CityScapes' # dataset root path
12 | IGNORE_LABEL : 255
13 |
14 | TRAIN:
15 | IMAGE_SIZE : [1024, 1024] # training image size in (h, w)
16 | BATCH_SIZE : 8 # batch size used to train
17 | EPOCHS : 500 # number of epochs to train
18 | EVAL_INTERVAL : 20 # evaluation interval during training
19 | AMP : false # use AMP in training
20 | DDP : false # use DDP training
21 |
22 | LOSS:
23 | NAME : OhemCrossEntropy # loss function name (ohemce, ce, dice)
24 | CLS_WEIGHTS : false # use class weights in loss calculation
25 |
26 | OPTIMIZER:
27 | NAME : adamw # optimizer name
28 | LR : 0.001 # initial learning rate used in optimizer
29 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer
30 |
31 | SCHEDULER:
32 | NAME : warmuppolylr # scheduler name
33 | POWER : 0.9 # scheduler power
34 | WARMUP : 10 # warmup epochs used in scheduler
35 | WARMUP_RATIO : 0.1 # warmup ratio
36 |
37 |
38 | EVAL:
39 | MODEL_PATH : 'checkpoints/pretrained/ddrnet/ddrnet_23slim_city.pth' # trained model file path
40 | IMAGE_SIZE : [1024, 1024] # evaluation image size in (h, w)
41 | MSF:
42 | ENABLE : false # multi-scale and flip evaluation
43 | FLIP : true # use flip in evaluation
44 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation
45 |
46 |
47 | TEST:
48 | MODEL_PATH : 'checkpoints/pretrained/ddrnet/ddrnet_23slim_city.pth' # trained model file path
49 | FILE : 'assests/cityscapes' # filename or foldername
50 | IMAGE_SIZE : [1024, 1024] # inference image size in (h, w)
51 | OVERLAY : true # save the overlay result (image_alpha+label_alpha)
--------------------------------------------------------------------------------
/configs/custom.yaml:
--------------------------------------------------------------------------------
1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)
2 | SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results
3 |
4 | MODEL:
5 | NAME : DDRNet # name of the model you are using
6 | BACKBONE : DDRNet-23slim # model variant
7 | PRETRAINED : 'checkpoints/backbones/ddrnet/ddrnet_23slim.pth' # backbone model's weight
8 |
9 | DATASET:
10 | NAME : CityScapes # dataset name to be trained with (camvid, cityscapes, ade20k)
11 | ROOT : 'data/CityScapes' # dataset root path
12 | IGNORE_LABEL : 255
13 |
14 | TRAIN:
15 | IMAGE_SIZE : [512, 512] # training image size in (h, w)
16 | BATCH_SIZE : 2 # batch size used to train
17 | EPOCHS : 100 # number of epochs to train
18 | EVAL_INTERVAL : 20 # evaluation interval during training
19 | AMP : false # use AMP in training
20 | DDP : false # use DDP training
21 |
22 | LOSS:
23 | NAME : OhemCrossEntropy # loss function name (ohemce, ce, dice)
24 | CLS_WEIGHTS : false # use class weights in loss calculation
25 |
26 | OPTIMIZER:
27 | NAME : adamw # optimizer name
28 | LR : 0.001 # initial learning rate used in optimizer
29 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer
30 |
31 | SCHEDULER:
32 | NAME : warmuppolylr # scheduler name
33 | POWER : 0.9 # scheduler power
34 | WARMUP : 10 # warmup epochs used in scheduler
35 | WARMUP_RATIO : 0.1 # warmup ratio
36 |
37 |
38 | EVAL:
39 | MODEL_PATH : 'checkpoints/pretrained/ddrnet/ddrnet_23slim_city.pth' # trained model file path
40 | IMAGE_SIZE : [1024, 1024] # evaluation image size in (h, w)
41 | MSF:
42 | ENABLE : false # multi-scale and flip evaluation
43 | FLIP : true # use flip in evaluation
44 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation
45 |
46 |
47 | TEST:
48 | MODEL_PATH : 'checkpoints/pretrained/ddrnet/ddrnet_23slim_city.pth' # trained model file path
49 | FILE : 'assests/cityscapes' # filename or foldername
50 | IMAGE_SIZE : [1024, 1024] # inference image size in (h, w)
51 | OVERLAY : true # save the overlay result (image_alpha+label_alpha)
--------------------------------------------------------------------------------
/configs/helen.yaml:
--------------------------------------------------------------------------------
1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)
2 | SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results
3 |
4 | MODEL:
5 | NAME : DDRNet # name of the model you are using
6 | BACKBONE : DDRNet-23slim # model variant
7 | PRETRAINED : 'checkpoints/backbones/ddrnet/ddrnet_23slim.pth' # backbone model's weight
8 |
9 | DATASET:
10 | NAME : HELEN # dataset name to be trained with (camvid, cityscapes, ade20k)
11 | ROOT : '/home/sithu/datasets/SmithCVPR2013_dataset_resized' # dataset root path
12 | IGNORE_LABEL : 255
13 |
14 | TRAIN:
15 | IMAGE_SIZE : [512, 512] # training image size in (h, w)
16 | BATCH_SIZE : 16 # batch size used to train
17 | EPOCHS : 200 # number of epochs to train
18 | EVAL_INTERVAL : 10 # evaluation interval during training
19 | AMP : false # use AMP in training
20 | DDP : false # use DDP training
21 |
22 | LOSS:
23 | NAME : OhemCrossEntropy # loss function name (OhemCrossEntropy, CrossEntropy, Dice)
24 | CLS_WEIGHTS : false # use class weights in loss calculation
25 |
26 | OPTIMIZER:
27 | NAME : adamw # optimizer name
28 | LR : 0.001 # initial learning rate used in optimizer
29 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer
30 |
31 | SCHEDULER:
32 | NAME : warmuppolylr # scheduler name
33 | POWER : 0.9 # scheduler power
34 | WARMUP : 5 # warmup epochs used in scheduler
35 | WARMUP_RATIO : 0.1 # warmup ratio
36 |
37 |
38 | EVAL:
39 | MODEL_PATH : 'output/DDRNet_DDRNet-23slim_HELEN_61_11.pth' # trained model file path
40 | IMAGE_SIZE : [512, 512] # evaluation image size in (h, w)
41 | MSF:
42 | ENABLE : false # multi-scale and flip evaluation
43 | FLIP : true # use flip in evaluation
44 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation
45 |
46 |
47 | TEST:
48 | MODEL_PATH : 'output/DDRNet_DDRNet-23slim_HELEN_61_11.pth' # trained model file path
49 | FILE : 'assests/faces' # filename or foldername
50 | IMAGE_SIZE : [512, 512] # inference image size in (h, w)
51 | OVERLAY : true # save the overlay result (image_alpha+label_alpha)
52 |
--------------------------------------------------------------------------------
/docs/BACKBONES.md:
--------------------------------------------------------------------------------
1 | ## Supported Backbones
2 |
3 | Backbone | Variants | ImageNet-1k Top-1 Acc (%) | Params (M) | GFLOPs | Weights
4 | --- | --- | --- | --- | --- | ---
5 | MicroNet | M1\|M2\|M3 | 51.4`\|`59.4`\|`62.5 | 1`\|`2`\|`3 | 7M`\|`14M`\|`23M | [download][micronetw]
6 | MobileNetV2 | 1.0 | 71.9 | 3 | 300M | [download][mobilenetv2w]
7 | MobileNetV3 | S\|L | 67.7`\|`74.0 | 3`\|`5 | 56M`\|`219M | [S][mobilenetv3s]\|[L][mobilenetv3l]
8 | DDRNet | 23slim | 73.7 | 5 | 860M | [download][ddrnet23slim]
9 | ||
10 | ResNet | 18\|50\|101 | 71.5`\|`80.4`\|`81.5 | 12`\|`26`\|`45 | 2`\|`4`\|`8 | [download][resnetw]
11 | ResNetD | 18\|50\|101 | - | 12`\|`25`\|`44 | 2`\|`4`\|`8 | [download][resnetdw]
12 | MiT | B1\|B2\|B3 | - | 14`\|`25`\|`45 | 2`\|`4`\|`8 | [download][mitw]
13 | PVTv2 | B1\|B2\|B4 | 78.7`\|`82.0`\|`83.6 | 14`\|`25`\|`63 | 2`\|`4`\|`10 | [download][pvtv2w]
14 | ResT | S\|B\|L | 79.6`\|`81.6`\|`83.6 | 14`\|`30`\|`52 | 2`\|`4`\|`8 | [download][restw]
15 | PoolFormer | S24\|S36\|M36 | 80.3`\|`81.4`\|`82.1 | 21`\|`31`\|`56 | 4`\|`5`\|`9 | [download][poolformerw]
16 | ConvNeXt | T\|S\|B | 82.1`\|`83.1`\|`83.8 | 28`\|`50`\|`89 | 5`\|`9`\|`15 | [download][convnextw]
17 | UniFormer | S\|B | 82.9`\|`83.8 | 22`\|`50 | 4`\|`8 | [download][uniformerw]
18 | VAN | S\|B\|L | 81.1`\|`82.8`\|`83.9 | 14`\|`27`\|`45 | 3`\|`5`\|`9 | -
19 | DaViT | T\|S\|B | 82.8`\|`84.2`\|`84.6 | 28`\|`50`\|`88 | 5`\|`9`\|`16 | -
20 |
21 |
22 | [micronetw]: https://drive.google.com/drive/folders/1j4JSTcAh94U2k-7jCl_3nwbNi0eduM2P?usp=sharing
23 | [mobilenetv2w]: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
24 | [mobilenetv3s]: https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth
25 | [mobilenetv3l]: https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth
26 | [resnetw]: https://drive.google.com/drive/folders/1MXP3Qx51c91PL9P52Tv89t90SaiTYuaC?usp=sharing
27 | [resnetdw]: https://drive.google.com/drive/folders/1sVyewBDkePlw3kbvhUD4PvUxjro4iKFy?usp=sharing
28 | [mitw]: https://drive.google.com/drive/folders/1b7bwrInTW4VLEm27YawHOAMSMikga2Ia
29 | [pvtv2w]: https://drive.google.com/drive/folders/10Dd9BEe4wv71dC5BXhsL_C6KeI_Rcxm3?usp=sharing
30 | [restw]: https://drive.google.com/drive/folders/1R2cewgHo6sYcQnRGBBIndjNomumBwekr?usp=sharing
31 | [ddrnet23slim]: https://drive.google.com/file/d/1tUcUCCsEZ7qKaF_bHHHECTonp4vbh-a9/view?usp=sharing
32 | [poolformerw]: https://drive.google.com/drive/folders/18OyxHHpVq-9pMMG2eu1jot7n-po4dUpD?usp=sharing
33 | [convnextw]: https://drive.google.com/drive/folders/1Oe50_zY4QKFZ0_22mSHKuNav0GiRcgWA?usp=sharing
34 | [uniformerw]: https://drive.google.com/drive/folders/175C4Je4kZoBb5x8HkwH4-VhtG_a5zQnX?usp=sharing
--------------------------------------------------------------------------------
/docs/DATASETS.md:
--------------------------------------------------------------------------------
1 | ## Supported Datasets
2 |
3 | [ade20k]: http://sceneparsing.csail.mit.edu/
4 | [cityscapes]: https://www.cityscapes-dataset.com/
5 | [camvid]: http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/
6 | [cocostuff]: https://github.com/nightrome/cocostuff
7 | [mhp]: https://lv-mhp.github.io/
8 | [lip]: http://sysu-hcp.net/lip/index.php
9 | [atr]: https://github.com/lemondan/HumanParsing-Dataset
10 | [pascalcontext]: https://cs.stanford.edu/~roozbeh/pascal-context/
11 | [pcannos]: https://drive.google.com/file/d/1hOQnuTVYE9s7iRdo-6iARWkN2-qCAoVz/view?usp=sharing
12 | [suim]: http://irvlab.cs.umn.edu/resources/suim-dataset
13 | [mv]: https://www.mapillary.com/dataset/vistas
14 | [sunrgbd]: https://rgbd.cs.princeton.edu/
15 | [helen]: https://www.sifeiliu.net/face-parsing
16 | [celeba]: https://github.com/switchablenorms/CelebAMask-HQ
17 | [lapa]: https://github.com/JDAI-CV/lapa-dataset
18 | [ibugmask]: https://github.com/hhj1897/face_parsing
19 | [facesynthetics]: https://github.com/microsoft/FaceSynthetics
20 | [ccihp]: https://kalisteo.cea.fr/wp-content/uploads/2021/09/README.html
21 |
22 | Dataset | Type | Categories | Train Images | Val Images | Test Images | Image Size (HxW)
23 | --- | --- | --- | --- | --- | --- | ---
24 | [COCO-Stuff][cocostuff] | General Scene Parsing | 171 | 118,000 | 5,000 | 20,000 | -
25 | [ADE20K][ade20k] | General Scene Parsing | 150 | 20,210 | 2,000 | 3,352 | -
26 | [PASCALContext][pascalcontext] | General Scene Parsing | 59 | 4,996 | 5,104 | 9,637 | -
27 | ||
28 | [SUN RGB-D][sunrgbd] | Indoor Scene Parsing | 37 | 2,666 | 2,619 | 5,050 (with labels) | -
29 | ||
30 | [Mapillary Vistas][mv] | Street Scene Parsing | 65 | 18,000 | 2,000 | 5,000 | 1080x1920
31 | [CityScapes][cityscapes] | Street Scene Parsing | 19 | 2,975 | 500 | 1,525 (with labels) | 1024x2048
32 | [CamVid][camvid] | Street Scene Parsing | 11 | 367 | 101 | 233 (with labels) | 720x960
33 | ||
34 | [MHPv2][mhp] | Multi-Human Parsing | 59 | 15,403 | 5,000 | 5,000 | -
35 | [MHPv1][mhp] | Multi-Human Parsing | 19 | 3,000 | 1,000 | 980 (with labels) | -
36 | [LIP][lip] | Multi-Human Parsing | 20 | 30,462 | 10,000 | - | -
37 | [CCIHP][ccihp] | Multi-Human Parsing | 22 | 28,280 | 5,000 | 5,000 | -
38 | [CIHP][lip] | Multi-Human Parsing | 20 | 28,280 | 5,000 | 5,000 | -
39 | [ATR][atr] | Single-Human Parsing | 18 | 16,000 | 700 | 1,000 (with labels) | -
40 | ||
41 | [HELEN][helen] | Face Parsing | 11 | 2,000 | 230 | 100 (with labels) | -
42 | [LaPa][lapa] | Face Parsing | 11 | 18,176 | 2,000 | 2,000 (with labels) | -
43 | [iBugMask][ibugmask] | Face Parsing | 11 | 21,866 | - | 1,000 (with labels) | -
44 | [CelebAMaskHQ][celeba] | Face Parsing | 19 | 24,183 | 2,993 | 2,824 (with labels) | 512x512
45 | [FaceSynthetics][facesynthetics] | Face Parsing (Synthetic) | 19 | 100,000 | 1,000 | 100 (with labels) | 512x512
46 | ||
47 | [SUIM][suim] | Underwater Imagery | 8 | 1,525 | - | 110 (with labels) | -
48 |
49 | Check [OTHER_DATASETS](./OTHER_DATASETS.md) for more segmentation datasets.
50 |
51 |
52 | ### Datasets Structure
53 |
54 | Datasets should have the following structure:
55 |
56 | ```
57 | data
58 | |__ ADEChallenge
59 | |__ ADEChallengeData2016
60 | |__ images
61 | |__ training
62 | |__ validation
63 | |__ annotations
64 | |__ training
65 | |__ validation
66 |
67 | |__ CityScapes
68 | |__ leftImg8bit
69 | |__ train
70 | |__ val
71 | |__ test
72 | |__ gtFine
73 | |__ train
74 | |__ val
75 | |__ test
76 |
77 | |__ CamVid
78 | |__ train
79 | |__ val
80 | |__ test
81 | |__ train_labels
82 | |__ val_labels
83 | |__ test_labels
84 |
85 | |__ VOCdevkit
86 | |__ VOC2010
87 | |__ JPEGImages
88 | |__ SegmentationClassContext
89 | |__ ImageSets
90 | |__ SegmentationContext
91 | |__ train.txt
92 | |__ val.txt
93 |
94 | |__ COCO
95 | |__ images
96 | |__ train2017
97 | |__ val2017
98 | |__ labels
99 | |__ train2017
100 | |__ val2017
101 |
102 | |__ MHPv1
103 | |__ images
104 | |__ annotations
105 | |__ train_list.txt
106 | |__ test_list.txt
107 |
108 | |__ MHPv2
109 | |__ train
110 | |__ images
111 | |__ parsing_annos
112 | |__ val
113 | |__ images
114 | |__ parsing_annos
115 |
116 | |__ LIP
117 | |__ LIP
118 | |__ TrainVal_images
119 | |__ train_images
120 | |__ val_images
121 | |__ TrainVal_parsing_annotations
122 | |__ train_segmentations
123 | |__ val_segmentations
124 |
125 | |__ CIHP/CCIHP
126 | |__ instance-leve_human_parsing
127 | |__ Training
128 | |__ Images
129 | |__ Category_ids
130 | |__ Validation
131 | |__ Images
132 | |__ Category_ids
133 |
134 | |__ ATR
135 | |__ humanparsing
136 | |__ JPEGImages
137 | |__ SegmentationClassAug
138 |
139 | |__ SUIM
140 | |__ train_val
141 | |__ images
142 | |__ masks
143 | |__ TEST
144 | |__ images
145 | |__ masks
146 |
147 | |__ SunRGBD
148 | |__ SUNRGBD
149 | |__ kv1/kv2/realsense/xtion
150 | |__ SUNRGBDtoolbox
151 | |__ traintestSUNRGBD
152 | |__ allsplit.mat
153 |
154 | |__ Mapillary
155 | |__ training
156 | |__ images
157 | |__ labels
158 | |__ validation
159 | |__ images
160 | |__ labels
161 |
162 | |__ SmithCVPR2013_dataset_resized (HELEN)
163 | |__ images
164 | |__ labels
165 | |__ exemplars.txt
166 | |__ testing.txt
167 | |__ tuning.txt
168 |
169 | |__ CelebAMask-HQ
170 | |__ CelebA-HQ-img
171 | |__ CelebAMask-HQ-mask-anno
172 | |__ CelebA-HQ-to-CelebA-mapping.txt
173 |
174 | |__ LaPa
175 | |__ train
176 | |__ images
177 | |__ labels
178 | |__ val
179 | |__ images
180 | |__ labels
181 | |__ test
182 | |__ images
183 | |__ labels
184 |
185 | |__ ibugmask_release
186 | |__ train
187 | |__ test
188 |
189 | |__ FaceSynthetics
190 | |__ dataset_100000
191 | |__ dataset_1000
192 | |__ dataset_100
193 | ```
194 |
195 | > Note: For PASCALContext, download the annotations from [here][pcannos] and put them in VOC2010.
196 |
197 | > Note: For CelebAMask-HQ, run the preprocessing script: `python3 scripts/preprocess_celebamaskhq.py --root <path_to_CelebAMask-HQ>`.
198 |
199 |
200 |
--------------------------------------------------------------------------------
/docs/MODELS.md:
--------------------------------------------------------------------------------
1 | ## Scene Parsing
2 |
3 | ### Accurate Models
4 |
5 | Method | Backbone | ADE20K (mIoU) | Cityscapes (mIoU) | COCO-Stuff (mIoU) | Params (M) | GFLOPs (512x512) | GFLOPs (1024x1024) | Weights
6 | --- | --- | --- | --- | --- | --- | --- | --- | ---
7 | SegFormer | MiT-B1 | 42.2 | 78.5 | 40.2 | 14 | 16 | 244 | [ade][segformerb1]
8 | || MiT-B2 | 46.5 | 81.0 | 44.6 | 28 | 62 | 717 | [ade][segformerb2]
9 | || MiT-B3 | 49.4 | 81.7 | 45.5 | 47 | 79 | 963 | [ade][segformerb3]
10 | ||
11 | Light-Ham | VAN-S | 45.7 | - | - | 15 | 21 | - | -
12 | || VAN-B | 49.6 | - | - | 27 | 34 | - | -
13 | || VAN-L | 51.0 | - | - | 46 | 55 | - | -
14 | ||
15 | Lawin | MiT-B1 | 42.1 | 79.0 | 40.5 | 14 | 13 | 218 | -
16 | || MiT-B2 | 47.8 | 81.7 | 45.2 | 30 | 45 | 563 | -
17 | || MiT-B3 | 50.3 | 82.5 | 46.6 | 50 | 62 | 809 | -
18 | ||
19 | TopFormer | TopFormer-T | 34.6 | - | - | 1.4 | 0.6 | - | -
20 | || TopFormer-S | 37.0 | - | - | 3.1 | 1.2 | - | -
21 | || TopFormer-B | 39.2 | - | - | 5.1 | 1.8 | - | -
22 |
23 | * mIoU results are single-scale, as reported in the official papers.
24 | * ADE20K image size = 512x512
25 | * Cityscapes image size = 1024x1024
26 | * COCO-Stuff image size = 512x512
27 |
28 | ### Real-time Models
29 |
30 | Method | Backbone | CityScapes-val (mIoU) | CamVid (mIoU) | Params (M) | GFLOPs (1024x2048) | Weights
31 | --- | --- | --- | --- | --- | --- | ---
32 | BiSeNetv1 | ResNet-18 | 74.8 | 68.7 | 14 | 49 | -
33 | BiSeNetv2 | - | 73.4 | 72.4 | 18 | 21 | -
34 | SFNet | ResNetD-18 | 79.0 | - | 13 | - | -
35 | DDRNet | DDRNet-23slim | 77.8 | 74.7 | 6 | 36 | [city][ddrnet]
36 |
37 | * mIoU results are single-scale, as reported in the official papers.
38 | * Cityscapes image size = 1024x2048 (except BiSeNetv1 & v2, which use 512x1024)
39 | * CamVid image size = 960x720
40 |
41 |
42 | ## Face Parsing
43 |
44 | Method | Backbone | HELEN-val (mIoU) | Params (M) | GFLOPs (512x512) | FPS (GTX1660ti) | Weights
45 | --- | --- | --- | --- | --- | --- | ---
46 | BiSeNetv1 | ResNet-18 | 58.50 | 14 | 13 | 263 | [HELEN](https://drive.google.com/file/d/1HMC6OiFPc-aYwhlHlPYoXa-VCR3r2WPQ/view?usp=sharing)
47 | BiSeNetv2 | - | 58.58 | 18 | 15 | 195 | [HELEN](https://drive.google.com/file/d/1cf-W_2m-vfxMRZ0mFQjEwhOglURpH7m6/view?usp=sharing)
48 | DDRNet | DDRNet-23slim | 61.11 | 6 | 5 | 180 | [HELEN](https://drive.google.com/file/d/1SdOgVvgYrp8UFztHWN6dHH0MhP8zqnyh/view?usp=sharing)
49 | SFNet | ResNetD-18 | 61.00 | 14 | 31 | 56 | [HELEN](https://drive.google.com/file/d/13w42DgI4PJ05bkWY9XCK_skSGMsmXroj/view?usp=sharing)
50 |
51 |
52 | [ddrnet]: https://drive.google.com/file/d/1VdE3OkrIlIzLRPuT-2So-Xq_5gPaxm0t/view?usp=sharing
53 | [segformerb3]: https://drive.google.com/file/d/1-OmW3xRD3WAbJTzktPC-VMOF5WMsN8XT/view?usp=sharing
54 | [segformerb2]: https://drive.google.com/file/d/1AcgEK5aWMJzpe8tsfauqhragR0nBHyPh/view?usp=sharing
55 | [segformerb1]: https://drive.google.com/file/d/18PN_P3ajcJi_5Q2v8b4BP9O4VdNCpt6m/view?usp=sharing
56 | [topformert]: https://drive.google.com/file/d/1OnS3_PwjJuNMWCKisreNxw_Lma8uR8bV/view?usp=sharing
57 | [topformers]: https://drive.google.com/file/d/19041fMb4HuDyNhIYdW1r5612FyzpexP0/view?usp=sharing
58 | [topformerb]: https://drive.google.com/file/d/1m7CxYKWAyJzl5W3cj1vwsW4DfqAb_rqz/view?usp=sharing
--------------------------------------------------------------------------------
/docs/OTHER_DATASETS.md:
--------------------------------------------------------------------------------
1 | # Semantic Segmentation Datasets
2 |
3 | ## General
4 |
5 | * [COCO-Stuff](https://github.com/nightrome/cocostuff)
6 | * [PASCAL-Context](https://cs.stanford.edu/~roozbeh/pascal-context/)
7 | * [PASCAL-VOC](http://host.robots.ox.ac.uk/pascal/VOC/)
8 | * [MSeg](https://github.com/mseg-dataset/mseg-api)
9 | * [ADE20K](http://groups.csail.mit.edu/vision/datasets/ADE20K/)
10 | * [Places365](http://places2.csail.mit.edu/)
11 |
12 | ## Outdoor
13 |
14 | * [CityScapes](https://www.cityscapes-dataset.com/)
15 | * [KITTI](http://www.cvlibs.net/datasets/kitti/)
16 | * [Mapillary Vistas](https://www.mapillary.com/dataset/vistas?lat=20&lng=0&z=1.5&pKey=xyW6a0ZmrJtjLw2iJ71Oqg)
17 | * [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/)
18 | * [Standford Background](http://dags.stanford.edu/projects/scenedataset.html)
19 | * [ApolloScape](http://apolloscape.auto/)
20 | * [BDD100K](https://bdd-data.berkeley.edu/)
21 | * [WoodScape](https://github.com/valeoai/WoodScape)
22 | * [IDD](http://idd.insaan.iiit.ac.in/)
23 | * [DADA-2000](https://github.com/JWFangit/LOTVS-DADA)
24 | * [Street Hazards](https://github.com/hendrycks/anomaly-seg)
25 | * [UNDD](https://github.com/sauradip/night_image_semantic_segmentation)
26 | * [WildDash](https://wilddash.cc/)
27 | * [A2D2](https://www.a2d2.audi/a2d2/en/dataset.html)
28 |
29 | ## Indoor
30 |
31 | * [ScanNet](http://www.scan-net.org/)
32 | * [Sun-RGBD](https://rgbd.cs.princeton.edu/)
33 | * [SceneNet](https://robotvault.bitbucket.io/)
34 | * [2D-3D-Semantics](https://github.com/alexsax/2D-3D-Semantics)
35 | * [NYUDepthv2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html)
36 | * [SUN3D](http://sun3d.cs.princeton.edu/)
37 |
38 | ## Human Parts
39 |
40 | * [LIP/CIHP](http://sysu-hcp.net/lip/index.php)
41 | * [MHP](https://github.com/ZhaoJ9014/Multi-Human-Parsing)
42 | * [DeepFashion2](https://github.com/switchablenorms/DeepFashion2)
43 | * [PASCAL-Person-Part](http://roozbehm.info/pascal-parts/pascal-parts.html)
44 | * [PIC](http://picdataset.com/challenge/task/download/)
45 | * [iMat](https://github.com/visipedia/imat_comp)
46 |
47 | ## Food
48 |
49 | * [FoodSeg103](https://xiongweiwu.github.io/foodseg103.html)
50 |
51 | ## Binary
52 |
53 | * [SBCoseg](http://www.mlmrlab.com/cosegmentation_dataset_downloadC.html)
54 | * [DeepFish](https://github.com/alzayats/DeepFish)
55 | * [MVTecAD](https://www.mvtec.com/company/research/datasets/mvtec-ad/)
56 | * [LLAMAS](https://unsupervised-llamas.com/llamas/)
57 |
58 | ## Boundary Segmentation
59 |
60 | * [SBD](http://home.bharathh.info/pubs/codes/SBD/download.html)
61 | * [SketchyScene](https://github.com/SketchyScene/SketchyScene)
62 | * [TextSeg](https://github.com/SHI-Labs/Rethinking-Text-Segmentation)
63 |
64 | ## Synthetic
65 |
66 | * [EDEN](https://lhoangan.github.io/eden/)
67 | * [Synscapes](https://7dlabs.com/synscapes-overview)
68 | * [SYNTHIA](https://synthia-dataset.net/)
69 | * [GTA5](https://download.visinf.tu-darmstadt.de/data/from_games/)
70 |
71 | ## Robot-view
72 |
73 | * [Robot Home](http://mapir.isa.uma.es/mapirwebsite/index.php/mapir-downloads/203-robot-at-home-dataset.html)
74 | * [RobotriX](https://github.com/3dperceptionlab/therobotrix)
75 | * [Gibson Env](http://gibsonenv.stanford.edu/)
76 |
77 | ## Medical
78 |
79 | * [BraTS2015](https://www.smir.ch/BRATS/Start2015)
80 | * [Medical-Decathlon](http://medicaldecathlon.com/)
81 | * [PROMISE12](https://promise12.grand-challenge.org/)
82 | * [REFUGE](https://bitbucket.org/woalsdnd/refuge/src/master/)
83 | * [BIMCV-COVID-19](https://github.com/BIMCV-CSUSP/BIMCV-COVID-19)
84 | * [OpenEDS](https://research.fb.com/programs/openeds-challenge)
85 | * [Retinal-Microsurgery](https://sites.google.com/site/sznitr/home)
86 | * [CoNSeP](https://warwick.ac.uk/fac/sci/dcs/research/tia/data/hovernet/)
87 | * [ISIC-2018-Task1](https://challenge2018.isic-archive.com/task1/)
88 | * [Cata7](https://github.com/nizhenliang/RAUNet)
89 | * [ROSE](https://imed.nimte.ac.cn/dataofrose.html)
90 | * [SegTHOR](https://competitions.codalab.org/competitions/21145)
91 | * [CAMEL](https://github.com/ThoroughImages/CAMEL)
92 | * [CryoNuSeg](https://github.com/masih4/CryoNuSeg)
93 | * [OpenEDS2020](https://research.fb.com/programs/openeds-2020-challenge/)
94 | * [VocalFolds](https://github.com/imesluh/vocalfolds)
95 | * [Medico](https://multimediaeval.github.io/editions/2020/tasks/medico/)
96 | * [20MioEyeDS](https://unitc-my.sharepoint.com/personal/iitfu01_cloud_uni-tuebingen_de/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fiitfu01%5Fcloud%5Funi%2Dtuebingen%5Fde%2FDocuments%2F20MioEyeDS&originalPath=aHR0cHM6Ly91bml0Yy1teS5zaGFyZXBvaW50LmNvbS86ZjovZy9wZXJzb25hbC9paXRmdTAxX2Nsb3VkX3VuaS10dWViaW5nZW5fZGUvRXZyTlBkdGlnRlZIdENNZUZLU3lMbFVCZXBPY2JYMG5Fa2Ftd2VlWmEwczlTUT9ydGltZT1zcWtvTV9CYzJVZw)
97 | * [BrainMRI](https://www.kaggle.com/mateuszbuda/lgg-mri-segmentation)
98 | * [Liver Tumor](https://www.kaggle.com/andrewmvd/liver-tumor-segmentation)
99 | * [MRI Hippocampus](https://www.kaggle.com/sabermalek/mrihs)
100 |
101 | ## Aerial
102 |
103 | * [RIT-18](https://github.com/rmkemker/RIT-18)
104 | * [PolSF](https://github.com/liuxuvip/PolSF)
105 | * [AIRS](https://www.airs-dataset.com/)
106 | * [UOPNOA](https://zenodo.org/record/4648002)
107 | * [LandCover](https://landcover.ai/)
108 | * [ICG](https://www.kaggle.com/bulentsiyah/semantic-drone-dataset)
109 |
110 | ## Video
111 |
112 | * [DAVIS](https://davischallenge.org/)
113 | * [SESIV](https://sites.google.com/view/ltnghia/research/sesiv)
114 | * [YouTube-VOS](https://youtube-vos.org/)
115 |
116 | ## Others
117 |
118 | * [SUIM](http://irvlab.cs.umn.edu/resources/suim-dataset)
119 | * [Cam2BEV](https://github.com/ika-rwth-aachen/Cam2BEV)
120 | * [LabPics](https://www.kaggle.com/sagieppel/labpics-chemistry-labpics-medical)
121 | * [CreativeFlow+](https://www.cs.toronto.edu/creativeflow/)
122 | * [RoadAnomaly21](https://segmentmeifyoucan.com/datasets)
123 | * [RoadObstacle21](https://segmentmeifyoucan.com/datasets)
124 | * [HouseExpo](https://github.com/teaganli/houseexpo/)
125 | * [D2S](https://www.mvtec.com/company/research/datasets/mvtec-d2s/)
--------------------------------------------------------------------------------
/scripts/calc_class_weights.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import io
3 |
4 |
5 | def calc_class_weights(files, n_classes):
6 | pixels = {}
7 | for file in files:
8 | lbl_path = str(file).replace('images', 'labels').rsplit('.', 1)[0] + '.png'  # assumes .png label maps in a parallel 'labels' folder
9 | label = io.read_image(lbl_path)
10 | for i in range(n_classes):
11 | if pixels.get(i) is not None:
12 | pixels[i] += [(label == i).sum()]
13 | else:
14 | pixels[i] = [(label == i).sum()]
15 |
16 | class_freq = torch.tensor([sum(v).item() for v in pixels.values()])
17 | weights = 1 / torch.log1p(class_freq)
18 | weights *= n_classes
19 | weights /= weights.sum()
20 | return weights
--------------------------------------------------------------------------------
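A hypothetical usage sketch for the helper above. The dataset path, glob pattern, and class count are illustrative only; it assumes images under an `images/` folder with matching `.png` label maps under a parallel `labels/` folder, as the path substitution in the function implies:

```python
# Hypothetical usage; path, glob pattern and class count are illustrative only.
from pathlib import Path
from calc_class_weights import calc_class_weights  # e.g. when run from the scripts/ folder

files = sorted(Path('data/MyDataset/images').glob('*.jpg'))  # assumed layout: images/*.jpg, labels/*.png
weights = calc_class_weights(files, n_classes=19)
print(weights)  # rarer classes receive larger inverse-log-frequency weights (normalized to sum to 1)
```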
/scripts/export_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | from PIL import Image
4 | from pathlib import Path
5 | from tqdm import tqdm
6 |
7 |
8 | def create_calibrate_data(image_folder, save_path):
9 | dataset = []
10 | mean = np.array([0.485, 0.456, 0.406])[None, None, :]
11 | std = np.array([0.229, 0.224, 0.225])[None, None, :]
12 | files = list(Path(image_folder).glob('*.jpg'))[:100]
13 | for file in tqdm(files):
14 | image = Image.open(file).convert('RGB')
15 | image = image.resize((512, 512))
16 | image = np.array(image, dtype=np.float32)
17 | image /= 255
18 | image -= mean
19 | image /= std
20 | dataset.append(image)
21 | dataset = np.stack(dataset, axis=0)
22 | np.save(save_path, dataset)
23 |
24 |
25 | if __name__ == '__main__':
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--dataset-path', type=str, default='/home/sithu/datasets/SmithCVPR2013_dataset_resized/images')
28 | parser.add_argument('--save-path', type=str, default='output/calibrate_data')
29 | args = parser.parse_args()
30 |
31 | create_calibrate_data(args.dataset_path, args.save_path)
32 |
33 |
--------------------------------------------------------------------------------
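Since `np.save` appends the `.npy` suffix to the given path, the calibration array produced with the default arguments can be reloaded as follows (a small sanity check, assuming the defaults above):

```python
import numpy as np

# np.save('output/calibrate_data', dataset) writes output/calibrate_data.npy
calib = np.load('output/calibrate_data.npy')
print(calib.shape, calib.dtype)  # (N, 512, 512, 3) float32, ImageNet mean/std normalized, N <= 100
```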
/scripts/onnx_infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import onnxruntime
4 | from PIL import Image
5 | from semseg.utils.visualize import generate_palette
6 | from semseg.utils.utils import timer
7 |
8 |
9 | class Inference:
10 | def __init__(self, model: str) -> None:
11 | self.session = onnxruntime.InferenceSession(model)
12 | self.input_details = self.session.get_inputs()[0]
13 | self.palette = generate_palette(self.session.get_outputs()[0].shape[1], background=True)
14 | self.img_size = self.input_details.shape[-2:]
15 | self.mean = np.array([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
16 | self.std = np.array([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
17 |
18 | def preprocess(self, image: Image.Image) -> np.ndarray:
19 | image = image.resize(self.img_size)
20 | image = np.array(image, dtype=np.float32).transpose(2, 0, 1)
21 | image /= 255
22 | image -= self.mean
23 | image /= self.std
24 | image = image[np.newaxis, ...]
25 | return image
26 |
27 | def postprocess(self, seg_map: np.ndarray) -> np.ndarray:
28 | seg_map = np.argmax(seg_map, axis=1).astype(int)
29 | seg_map = self.palette[seg_map]
30 | return seg_map.squeeze()
31 |
32 | @timer
33 | def model_forward(self, img: np.ndarray) -> np.ndarray:
34 | return self.session.run(None, {self.input_details.name: img})[0]
35 |
36 | def predict(self, img_path: str) -> np.ndarray:
37 | image = Image.open(img_path).convert('RGB')
38 | image = self.preprocess(image)
39 | seg_map = self.model_forward(image)
40 | seg_map = self.postprocess(seg_map)
41 | return seg_map.astype(np.uint8)
42 |
43 |
44 | if __name__ == '__main__':
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument('--model', type=str, default='output/DDRNet_23slim_HELEN_59_75.onnx')
47 | parser.add_argument('--img-path', type=str, default='assests/faces/27409477_1.jpg')
48 | args = parser.parse_args()
49 |
50 | session = Inference(args.model)
51 | seg_map = session.predict(args.img_path)
52 | seg_map = Image.fromarray(seg_map)
53 | seg_map.save(f"{args.img_path.split('.')[0]}_out.png")
--------------------------------------------------------------------------------
/scripts/openvino_infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | from PIL import Image
4 | from pathlib import Path
5 | from openvino.inference_engine import IECore
6 | from semseg.utils.visualize import generate_palette
7 | from semseg.utils.utils import timer
8 |
9 |
10 | class Inference:
11 | def __init__(self, model: str) -> None:
12 | files = Path(model).iterdir()
13 |
14 | for file in files:
15 | if file.suffix == '.xml':
16 | model = str(file)
17 | elif file.suffix == '.bin':
18 | weights = str(file)
19 | ie = IECore()
20 | model = ie.read_network(model=model, weights=weights)
21 | self.input_info = next(iter(model.input_info))
22 | self.output_info = next(iter(model.outputs))
23 | self.img_size = model.input_info['input'].input_data.shape[-2:]
24 | self.palette = generate_palette(11, background=True)
25 | self.engine = ie.load_network(network=model, device_name='CPU')
26 |
27 | self.mean = np.array([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
28 | self.std = np.array([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
29 |
30 | def preprocess(self, image: Image.Image) -> np.ndarray:
31 | image = image.resize(self.img_size)
32 | image = np.array(image, dtype=np.float32).transpose(2, 0, 1)
33 | image /= 255
34 | image -= self.mean
35 | image /= self.std
36 | image = image[np.newaxis, ...]
37 | return image
38 |
39 | def postprocess(self, seg_map: np.ndarray) -> np.ndarray:
40 | seg_map = np.argmax(seg_map, axis=1).astype(int)
41 | seg_map = self.palette[seg_map]
42 | return seg_map.squeeze()
43 |
44 | @timer
45 | def model_forward(self, img: np.ndarray) -> np.ndarray:
46 | return self.engine.infer(inputs={self.input_info: img})[self.output_info]
47 |
48 | def predict(self, img_path: str) -> np.ndarray:
49 | image = Image.open(img_path).convert('RGB')
50 | image = self.preprocess(image)
51 | seg_map = self.model_forward(image)
52 | seg_map = self.postprocess(seg_map)
53 | return seg_map.astype(np.uint8)
54 |
55 |
56 | if __name__ == '__main__':
57 | parser = argparse.ArgumentParser()
58 | parser.add_argument('--model', type=str, default='output/ddrnet_openvino')
59 | parser.add_argument('--img-path', type=str, default='assests/faces/27409477_1.jpg')
60 | args = parser.parse_args()
61 |
62 | session = Inference(args.model)
63 | seg_map = session.predict(args.img_path)
64 | seg_map = Image.fromarray(seg_map)
65 | seg_map.save(f"{args.img_path.split('.')[0]}_out.png")
66 |
--------------------------------------------------------------------------------
/scripts/preprocess_celebamaskhq.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | from tqdm import tqdm
4 | from pathlib import Path
5 | from PIL import Image
6 |
7 |
8 | def main(root):
9 | root = Path(root)
10 | annot_dir = root / 'CelebAMask-HQ-label'
11 | annot_dir.mkdir(exist_ok=True)
12 |
13 | train_lists = []
14 | test_lists = []
15 | val_lists = []
16 |
17 | names = [
18 | 'skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear',
19 | 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth'
20 | ]
21 | num_images = 30000
22 |
23 | for folder in root.iterdir():
24 | if folder.is_dir():
25 | if folder.name == 'CelebAMask-HQ-mask-anno':
26 | print("Transforming separate masks into one-hot mask...")
27 | for i in tqdm(range(num_images)):
28 | folder_num = i // 2000
29 | label = np.zeros((512, 512))
30 | for idx, name in enumerate(names):
31 | fname = folder / f"{folder_num}" / f"{str(i).rjust(5, '0')}_{name}.png"
32 | if fname.exists():
33 | img = Image.open(fname).convert('P')
34 | img = np.array(img)
35 | label[img != 0] = idx + 1
36 |
37 | label = Image.fromarray(label.astype(np.uint8))
38 | label.save(annot_dir / f"{i}.png")
39 |
40 | print("Splitting into train/val/test...")
41 |
42 | with open(root / "CelebA-HQ-to-CelebA-mapping.txt") as f:
43 | lines = f.read().splitlines()[1:]
44 | image_list = [int(line.split()[1]) for line in lines]
45 |
46 |
47 | for idx, fname in enumerate(image_list):
48 | if fname >= 162771 and fname < 182638:
49 | val_lists.append(f"{idx}\n")
50 |
51 | elif fname >= 182638:
52 | test_lists.append(f"{idx}\n")
53 |
54 | else:
55 | train_lists.append(f"{idx}\n")
56 |
57 | print(f"Train Size: {len(train_lists)}")
58 | print(f"Val Size: {len(val_lists)}")
59 | print(f"Test Size: {len(test_lists)}")
60 |
61 | with open(root / 'train_list.txt', 'w') as f:
62 | f.writelines(train_lists)
63 |
64 | with open(root / 'val_list.txt', 'w') as f:
65 | f.writelines(val_lists)
66 |
67 | with open(root / 'test_list.txt', 'w') as f:
68 | f.writelines(test_lists)
69 |
70 |
71 | if __name__ == '__main__':
72 | parser = argparse.ArgumentParser()
73 | parser.add_argument('--root', type=str, default='/home/sithu/datasets/CelebAMask-HQ')
74 | args = parser.parse_args()
75 | main(args.root)
--------------------------------------------------------------------------------
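After this script finishes, the dataset root should contain a `CelebAMask-HQ-label` folder with 30,000 merged masks plus `train_list.txt`, `val_list.txt`, and `test_list.txt`. A hypothetical sanity check (the root path is illustrative):

```python
from pathlib import Path

root = Path('/path/to/CelebAMask-HQ')  # illustrative; use the same value as --root
print(len(list((root / 'CelebAMask-HQ-label').glob('*.png'))))  # expect 30000 merged masks
splits = [len(open(root / f'{s}_list.txt').readlines()) for s in ('train', 'val', 'test')]
print(splits, sum(splits))  # the three splits together cover all 30000 images
```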
/scripts/tflite_infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import tflite_runtime.interpreter as tflite
4 | from PIL import Image
5 | from semseg.utils.visualize import generate_palette
6 | from semseg.utils.utils import timer
7 |
8 |
9 | class Inference:
10 | def __init__(self, model: str) -> None:
11 | self.interpreter = tflite.Interpreter(model)
12 | self.interpreter.allocate_tensors()
13 |
14 | self.input_details = self.interpreter.get_input_details()[0]
15 | self.output_details = self.interpreter.get_output_details()[0]
16 | self.palette = generate_palette(self.output_details['shape'][-1], background=True)
17 | self.img_size = self.input_details['shape'][1:3]
18 | self.mean = np.array([0.485, 0.456, 0.406])[None, None, :]
19 | self.std = np.array([0.229, 0.224, 0.225])[None, None, :]
20 |
21 | def preprocess(self, image: Image.Image) -> np.ndarray:
22 | image = image.resize(self.img_size)
23 | image = np.array(image, dtype=np.float32)
24 | image /= 255
25 | image -= self.mean
26 | image /= self.std
27 | if self.input_details['dtype'] == np.int8 or self.input_details['dtype'] == np.uint8:
28 | scale, zero_point = self.input_details['quantization']
29 | image /= scale
30 | image += zero_point
31 | image = image.astype(self.input_details['dtype'])
32 | return image[np.newaxis, ...]
33 |
34 | def postprocess(self, seg_map: np.ndarray) -> np.ndarray:
35 | if self.output_details['dtype'] == np.int8 or self.output_details['dtype'] == np.uint8:
36 | scale, zero_point = self.output_details['quantization']
37 | seg_map = scale * (seg_map - zero_point)
38 | seg_map = np.argmax(seg_map, axis=-1).astype(int)
39 | seg_map = self.palette[seg_map]
40 | return seg_map.squeeze()
41 |
42 | @timer
43 | def model_forward(self, img: np.ndarray) -> np.ndarray:
44 | self.interpreter.set_tensor(self.input_details['index'], img)
45 | self.interpreter.invoke()
46 | return self.interpreter.get_tensor(self.output_details['index'])
47 |
48 | def predict(self, img_path: str) -> np.ndarray:
49 | image = Image.open(img_path).convert('RGB')
50 | image = self.preprocess(image)
51 | seg_map = self.model_forward(image)
52 | seg_map = self.postprocess(seg_map)
53 | return seg_map.astype(np.uint8)
54 |
55 |
56 | if __name__ == '__main__':
57 | parser = argparse.ArgumentParser()
58 | parser.add_argument('--model', type=str, default='output/ddrnet_tflite2/ddrnet_float16.tflite')
59 | parser.add_argument('--img-path', type=str, default='assests/faces/27409477_1.jpg')
60 | args = parser.parse_args()
61 |
62 | session = Inference(args.model)
63 | seg_map = session.predict(args.img_path)
64 | seg_map = Image.fromarray(seg_map)
65 | seg_map.save(f"{args.img_path.split('.')[0]}_out.png")
--------------------------------------------------------------------------------
/semseg/__init__.py:
--------------------------------------------------------------------------------
1 | from tabulate import tabulate
2 | from semseg import models
3 | from semseg import datasets
4 | from semseg.models import backbones, heads
5 |
6 |
7 | def show_models():
8 | model_names = models.__all__
9 | numbers = list(range(1, len(model_names)+1))
10 | print(tabulate({'No.': numbers, 'Model Names': model_names}, headers='keys'))
11 |
12 |
13 | def show_backbones():
14 | backbone_names = backbones.__all__
15 | variants = []
16 | for name in backbone_names:
17 | try:
18 | variants.append(list(eval(f"backbones.{name.lower()}_settings").keys()))
19 | except AttributeError:
20 | variants.append('-')
21 | print(tabulate({'Backbone Names': backbone_names, 'Variants': variants}, headers='keys'))
22 |
23 |
24 | def show_heads():
25 | head_names = heads.__all__
26 | numbers = list(range(1, len(head_names)+1))
27 | print(tabulate({'No.': numbers, 'Heads': head_names}, headers='keys'))
28 |
29 |
30 | def show_datasets():
31 | dataset_names = datasets.__all__
32 | numbers = list(range(1, len(dataset_names)+1))
33 | print(tabulate({'No.': numbers, 'Datasets': dataset_names}, headers='keys'))
34 |
--------------------------------------------------------------------------------
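These helpers are handy for checking what the package registers before writing a config; for example:

```python
import semseg

semseg.show_models()     # tabulates models.__all__
semseg.show_backbones()  # backbone names with their available variants
semseg.show_heads()      # registered decode heads
semseg.show_datasets()   # registered dataset classes
```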
/semseg/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .ade20k import ADE20K
2 | from .camvid import CamVid
3 | from .cityscapes import CityScapes
4 | from .pascalcontext import PASCALContext
5 | from .cocostuff import COCOStuff
6 | from .sunrgbd import SunRGBD
7 | from .mapillary import MapillaryVistas
8 | from .mhpv1 import MHPv1
9 | from .mhpv2 import MHPv2
10 | from .lip import LIP
11 | from .cihp import CIHP, CCIHP
12 | from .atr import ATR
13 | from .suim import SUIM
14 | from .helen import HELEN
15 | from .lapa import LaPa
16 | from .ibugmask import iBugMask
17 | from .celebamaskhq import CelebAMaskHQ
18 | from .facesynthetics import FaceSynthetics
19 |
20 |
21 | __all__ = [
22 | 'CamVid',
23 | 'CityScapes',
24 | 'ADE20K',
25 | 'MHPv1',
26 | 'MHPv2',
27 | 'LIP',
28 | 'CIHP',
29 | 'CCIHP',
30 | 'ATR',
31 | 'PASCALContext',
32 | 'COCOStuff',
33 | 'SUIM',
34 | 'SunRGBD',
35 | 'MapillaryVistas',
36 | 'HELEN',
37 | 'LaPa',
38 | 'iBugMask',
39 | 'CelebAMaskHQ',
40 | 'FaceSynthetics',
41 | ]
--------------------------------------------------------------------------------
/semseg/datasets/ade20k.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.utils.data import Dataset
4 | from torchvision import io
5 | from pathlib import Path
6 | from typing import Tuple
7 |
8 |
9 | class ADE20K(Dataset):
10 | CLASSES = [
11 | 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
12 | 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
13 | 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
14 | 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
15 | 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
16 | 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
17 | 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
18 | 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
19 | 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
20 | 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
21 | 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
22 | 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag'
23 | ]
24 |
25 | PALETTE = torch.tensor([
26 | [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
27 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
28 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
29 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
30 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
31 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
32 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
33 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
34 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
35 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
36 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
37 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
38 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
39 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
40 | [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
41 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
42 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
43 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
44 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0], [92, 0, 255]
45 | ])
46 |
47 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
48 | super().__init__()
49 | assert split in ['train', 'val']
50 | split = 'training' if split == 'train' else 'validation'
51 | self.transform = transform
52 | self.n_classes = len(self.CLASSES)
53 | self.ignore_label = -1
54 |
55 | img_path = Path(root) / 'images' / split
56 | self.files = list(img_path.glob('*.jpg'))
57 |
58 | if not self.files:
59 | raise Exception(f"No images found in {img_path}")
60 | print(f"Found {len(self.files)} {split} images.")
61 |
62 | def __len__(self) -> int:
63 | return len(self.files)
64 |
65 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
66 | img_path = str(self.files[index])
67 | lbl_path = str(self.files[index]).replace('images', 'annotations').replace('.jpg', '.png')
68 |
69 | image = io.read_image(img_path)
70 | label = io.read_image(lbl_path)
71 |
72 | if self.transform:
73 | image, label = self.transform(image, label)
74 | return image, label.squeeze().long() - 1
75 |
76 |
77 | if __name__ == '__main__':
78 | from semseg.utils.visualize import visualize_dataset_sample
79 | visualize_dataset_sample(ADE20K, '/home/sithu/datasets/ADEChallenge/ADEChallengeData2016')
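
Note on the label handling above: raw ADE20K masks store 0 for unlabeled pixels and 1..150 for classes, so the `- 1` in `__getitem__` maps classes to 0..149 and unlabeled pixels to -1, matching `ignore_label = -1`. A minimal sketch of how that shift pairs with a loss configured to skip -1 (tensor shapes and values here are invented for illustration):

import torch
from torch import nn

raw = torch.tensor([[0, 1, 150]])           # one row of an ADE20K mask: 0 = unlabeled
target = (raw.long() - 1).unsqueeze(0)      # [B, H, W] -> values -1, 0, 149
logits = torch.randn(1, 150, 1, 3)          # [B, C, H, W] with 150 classes
loss = nn.CrossEntropyLoss(ignore_index=-1)(logits, target)
print(loss)                                 # the unlabeled pixel is excluded from the mean
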
--------------------------------------------------------------------------------
/semseg/datasets/atr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.utils.data import Dataset
4 | from torchvision import io
5 | from pathlib import Path
6 | from typing import Tuple
7 |
8 |
9 | class ATR(Dataset):
10 | """Single Person Fashion Dataset
11 | https://openaccess.thecvf.com/content_iccv_2015/papers/Liang_Human_Parsing_With_ICCV_2015_paper.pdf
12 |
13 | https://github.com/lemondan/HumanParsing-Dataset
14 | num_classes: 17+background
15 | 16000 train images
16 | 700 val images
17 | 1000 test images with labels
18 | """
19 | CLASSES = ['background', 'hat', 'hair', 'sunglass', 'upper-clothes', 'skirt', 'pants', 'dress', 'belt', 'left-shoe', 'right-shoe', 'face', 'left-leg', 'right-leg', 'left-arm', 'right-arm', 'bag', 'scarf']
20 | PALETTE = torch.tensor([[0, 0, 0], [127, 0, 0], [254, 0, 0], [0, 84, 0], [169, 0, 50], [254, 84, 0], [255, 0, 84], [0, 118, 220], [84, 84, 0], [0, 84, 84], [84, 50, 0], [51, 85, 127], [0, 127, 0], [0, 0, 254], [50, 169, 220], [0, 254, 254], [84, 254, 169], [169, 254, 84]])
21 |
22 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
23 | super().__init__()
24 | assert split in ['train', 'val', 'test']
25 | self.transform = transform
26 | self.n_classes = len(self.CLASSES)
27 | self.ignore_label = 255
28 |
29 | img_path = Path(root) / 'humanparsing' / 'JPEGImages'
30 | self.files = list(img_path.glob('*.jpg'))
31 | if split == 'train':
32 | self.files = self.files[:16000]
33 | elif split == 'val':
34 | self.files = self.files[16000:16700]
35 | else:
36 | self.files = self.files[16700:17700]
37 |
38 | if not self.files:
39 | raise Exception(f"No images found in {img_path}")
40 | print(f"Found {len(self.files)} {split} images.")
41 |
42 | def __len__(self) -> int:
43 | return len(self.files)
44 |
45 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
46 | img_path = str(self.files[index])
47 | lbl_path = str(self.files[index]).replace('JPEGImages', 'SegmentationClassAug').replace('.jpg', '.png')
48 |
49 | image = io.read_image(img_path)
50 | label = io.read_image(lbl_path)
51 |
52 | if self.transform:
53 | image, label = self.transform(image, label)
54 | return image, label.squeeze().long()
55 |
56 |
57 | if __name__ == '__main__':
58 | from semseg.utils.visualize import visualize_dataset_sample
59 | visualize_dataset_sample(ATR, '/home/sithu/datasets/LIP/ATR')
--------------------------------------------------------------------------------
/semseg/datasets/camvid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.utils.data import Dataset
4 | from torchvision import io
5 | from pathlib import Path
6 | from typing import Tuple
7 |
8 |
9 | class CamVid(Dataset):
10 | """
11 | num_classes: 11
12 | all_num_classes: 31
13 | """
14 | CLASSES = ['Sky', 'Building', 'Pole', 'Road', 'Pavement', 'Tree', 'SignSymbol', 'Fence', 'Car', 'Pedestrian', 'Bicyclist']
15 | CLASSES_ALL = ['Wall', 'Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car', 'CarLuggage', 'Child', 'Pole', 'Fence', 'LaneDrive', 'LaneNonDrive', 'MiscText', 'Motorcycle/Scooter', 'OtherMoving', 'ParkingBlock', 'Pedestrian', 'Road', 'RoadShoulder', 'Sidewalk', 'SignSymbol', 'Sky', 'SUV/PickupTruck', 'TrafficCone', 'TrafficLight', 'Train', 'Tree', 'Truck/Bus', 'Tunnel', 'VegetationMisc']
16 | PALETTE = torch.tensor([[128, 128, 128], [128, 0, 0], [192, 192, 128], [128, 64, 128], [0, 0, 192], [128, 128, 0], [192, 128, 128], [64, 64, 128], [64, 0, 128], [64, 64, 0], [0, 128, 192]])
17 | PALETTE_ALL = torch.tensor([[64, 192, 0], [64, 128, 64], [192, 0, 128], [0, 128, 192], [0, 128, 64], [128, 0, 0], [64, 0, 128], [64, 0, 192], [192, 128, 64], [192, 192, 128], [64, 64, 128], [128, 0, 192], [192, 0, 64], [128, 128, 64], [192, 0, 192], [128, 64, 64], [64, 192, 128], [64, 64, 0], [128, 64, 128], [128, 128, 192], [0, 0, 192], [192, 128, 128], [128, 128, 128], [64, 128, 192], [0, 0, 64], [0, 64, 64], [192, 64, 128], [128, 128, 0], [192, 128, 192], [64, 0, 64], [192, 192, 0]])
18 |
19 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
20 | super().__init__()
21 | assert split in ['train', 'val', 'test']
22 | self.split = split
23 | self.transform = transform
24 | self.n_classes = len(self.CLASSES)
25 | self.ignore_label = -1
26 |
27 | img_path = Path(root) / split
28 | self.files = list(img_path.glob("*.png"))
29 |
30 | if not self.files:
31 | raise Exception(f"No images found in {img_path}")
32 | print(f"Found {len(self.files)} {split} images.")
33 |
34 | def __len__(self) -> int:
35 | return len(self.files)
36 |
37 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
38 | img_path = str(self.files[index])
39 | lbl_path = str(self.files[index]).replace(self.split, self.split + '_labels').replace('.png', '_L.png')
40 |
41 | image = io.read_image(img_path)
42 | label = io.read_image(lbl_path)
43 |
44 | if self.transform:
45 | image, label = self.transform(image, label)
46 | return image, self.encode(label).long() - 1
47 |
48 | def encode(self, label: Tensor) -> Tensor:
49 | label = label.permute(1, 2, 0)
50 | mask = torch.zeros(label.shape[:-1])
51 |
52 | for index, color in enumerate(self.PALETTE):
53 | bool_mask = torch.eq(label, color)
54 | class_map = torch.all(bool_mask, dim=-1)
55 | mask[class_map] = index + 1
56 | return mask
57 |
58 |
59 | if __name__ == '__main__':
60 | from semseg.utils.visualize import visualize_dataset_sample
61 | visualize_dataset_sample(CamVid, '/home/sithu/datasets/CamVid')
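
The colour-to-index conversion in `encode` is the piece worth internalising here: every pixel is matched against each palette colour, matches are assigned `index + 1`, and the trailing `- 1` in `__getitem__` leaves unmatched pixels at -1 (the ignore value). A self-contained sketch of the same logic on a toy 1x3 mask (colours chosen only for illustration):

import torch

palette = torch.tensor([[128, 128, 128], [128, 0, 0]])             # Sky, Building (subset)
label = torch.tensor([[[128, 128, 128], [128, 0, 0], [9, 9, 9]]])  # H x W x 3 RGB mask
mask = torch.zeros(label.shape[:-1])
for index, color in enumerate(palette):
    mask[torch.all(label == color, dim=-1)] = index + 1
print(mask - 1)   # tensor([[ 0.,  1., -1.]]) -> Sky, Building, ignored
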
--------------------------------------------------------------------------------
/semseg/datasets/celebamaskhq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.utils.data import Dataset
4 | from torchvision import io
5 | from pathlib import Path
6 | from typing import Tuple
7 | from torchvision import transforms as T
8 |
9 |
10 | class CelebAMaskHQ(Dataset):
11 | CLASSES = [
12 | 'background', 'skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear',
13 | 'r_ear', 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth'
14 | ]
15 | PALETTE = torch.tensor([
16 | [0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0],
17 | [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]
18 | ])
19 |
20 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
21 | super().__init__()
22 | assert split in ['train', 'val', 'test']
23 | self.root = Path(root)
24 | self.transform = transform
25 | self.n_classes = len(self.CLASSES)
26 | self.ignore_label = 255
27 | self.resize = T.Resize((512, 512))
28 |
29 | with open(self.root / f'{split}_list.txt') as f:
30 | self.files = f.read().splitlines()
31 |
32 | if not self.files:
33 | raise Exception(f"No images found in {root}")
34 | print(f"Found {len(self.files)} {split} images.")
35 |
36 | def __len__(self) -> int:
37 | return len(self.files)
38 |
39 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
40 | img_path = self.root / 'CelebA-HQ-img' / f"{self.files[index]}.jpg"
41 | lbl_path = self.root / 'CelebAMask-HQ-label' / f"{self.files[index]}.png"
42 | image = io.read_image(str(img_path))
43 | image = self.resize(image)
44 | label = io.read_image(str(lbl_path))
45 |
46 | if self.transform:
47 | image, label = self.transform(image, label)
48 | return image, label.squeeze().long()
49 |
50 |
51 | if __name__ == '__main__':
52 | from semseg.utils.visualize import visualize_dataset_sample
53 | visualize_dataset_sample(CelebAMaskHQ, '/home/sithu/datasets/CelebAMask-HQ')
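
Only the image is resized above; the implicit assumption (stated here, not in the code) is that CelebA-HQ images ship at 1024x1024 while the CelebAMask-HQ annotation maps are already 512x512. If a label ever did need resizing, nearest-neighbour interpolation would be the safe choice because it never blends class ids. A small sketch with a fake mask, using torchvision's tensor transforms:

import torch
from torchvision import transforms as T
from torchvision.transforms import InterpolationMode

label = torch.randint(0, 19, (1, 1024, 1024), dtype=torch.uint8)            # fake 19-class mask
resize_lbl = T.Resize((512, 512), interpolation=InterpolationMode.NEAREST)  # keeps ids intact
print(resize_lbl(label).shape)                                              # torch.Size([1, 512, 512])
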
--------------------------------------------------------------------------------
/semseg/datasets/cihp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.utils.data import Dataset
4 | from torchvision import io
5 | from pathlib import Path
6 | from typing import Tuple
7 |
8 |
9 | class CIHP(Dataset):
10 |     """Crowd Instance-level Human Parsing (CIHP) dataset; has the best-quality human parsing labels of the datasets here.
11 | num_classes: 19+background
12 | 28280 train images
13 | 5000 val images
14 | """
15 | CLASSES = ['background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes', 'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt', 'face', 'left-arm', 'right-arm', 'left-leg', 'right-leg', 'left-shoe', 'right-shoe']
16 | PALETTE = torch.tensor([[120, 120, 120], [127, 0, 0], [254, 0, 0], [0, 84, 0], [169, 0, 50], [254, 84, 0], [255, 0, 84], [0, 118, 220], [84, 84, 0], [0, 84, 84], [84, 50, 0], [51, 85, 127], [0, 127, 0], [0, 0, 254], [50, 169, 220], [0, 254, 254], [84, 254, 169], [169, 254, 84], [254, 254, 0], [254, 169, 0]])
17 |
18 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
19 | super().__init__()
20 | assert split in ['train', 'val']
21 | split = 'Training' if split == 'train' else 'Validation'
22 | self.transform = transform
23 | self.n_classes = len(self.CLASSES)
24 | self.ignore_label = 255
25 |
26 | img_path = Path(root) / 'instance-level_human_parsing' / split / 'Images'
27 | self.files = list(img_path.glob('*.jpg'))
28 |
29 | if not self.files:
30 | raise Exception(f"No images found in {img_path}")
31 | print(f"Found {len(self.files)} {split} images.")
32 |
33 | def __len__(self) -> int:
34 | return len(self.files)
35 |
36 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
37 | img_path = str(self.files[index])
38 | lbl_path = str(self.files[index]).replace('Images', 'Category_ids').replace('.jpg', '.png')
39 |
40 | image = io.read_image(img_path)
41 | label = io.read_image(lbl_path)
42 |
43 | if self.transform:
44 | image, label = self.transform(image, label)
45 | return image, label.squeeze().long()
46 |
47 |
48 | class CCIHP(CIHP):
49 | CLASSES = ['background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes', 'facemask', 'coat', 'socks', 'pants', 'torso-skin', 'scarf', 'skirt', 'face', 'left-arm', 'right-arm', 'left-leg', 'right-leg', 'left-shoe', 'right-shoe', 'bag', 'others']
50 | PALETTE = torch.tensor([[120, 120, 120], [127, 0, 0], [254, 0, 0], [0, 84, 0], [169, 0, 50], [254, 84, 0], [255, 0, 84], [0, 118, 220], [84, 84, 0], [0, 84, 84], [84, 50, 0], [51, 85, 127], [0, 127, 0], [0, 0, 254], [50, 169, 220], [0, 254, 254], [84, 254, 169], [169, 254, 84], [254, 254, 0], [254, 169, 0], [102, 254, 0], [182, 255, 0]])
51 |
52 |
53 | if __name__ == '__main__':
54 | import sys
55 | sys.path.insert(0, '.')
56 | from semseg.utils.visualize import visualize_dataset_sample
57 | visualize_dataset_sample(CCIHP, 'C:\\Users\\sithu\\Documents\\Datasets\\LIP\\CIHP')
--------------------------------------------------------------------------------
/semseg/datasets/cityscapes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import Tensor
4 | from torch.utils.data import Dataset
5 | from torchvision import io
6 | from pathlib import Path
7 | from typing import Tuple
8 |
9 |
10 | class CityScapes(Dataset):
11 | """
12 | num_classes: 19
13 | """
14 | CLASSES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation',
15 | 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle']
16 |
17 | PALETTE = torch.tensor([[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35],
18 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]])
19 |
20 | ID2TRAINID = {0: 255, 1: 255, 2: 255, 3: 255, 4: 255, 5: 255, 6: 255, 7: 0, 8: 1, 9: 255, 10: 255, 11: 2, 12: 3, 13: 4, 14: 255, 15: 255, 16: 255,
21 | 17: 5, 18: 255, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 29: 255, 30: 255, 31: 16, 32: 17, 33: 18, -1: -1}
22 |
23 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
24 | super().__init__()
25 | assert split in ['train', 'val', 'test']
26 | self.transform = transform
27 | self.n_classes = len(self.CLASSES)
28 | self.ignore_label = 255
29 |
30 | self.label_map = np.arange(256)
31 | for id, trainid in self.ID2TRAINID.items():
32 | self.label_map[id] = trainid
33 |
34 | img_path = Path(root) / 'leftImg8bit' / split
35 | self.files = list(img_path.rglob('*.png'))
36 |
37 | if not self.files:
38 | raise Exception(f"No images found in {img_path}")
39 | print(f"Found {len(self.files)} {split} images.")
40 |
41 | def __len__(self) -> int:
42 | return len(self.files)
43 |
44 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
45 | img_path = str(self.files[index])
46 | lbl_path = str(self.files[index]).replace('leftImg8bit', 'gtFine').replace('.png', '_labelIds.png')
47 |
48 | image = io.read_image(img_path)
49 | label = io.read_image(lbl_path)
50 |
51 | if self.transform:
52 | image, label = self.transform(image, label)
53 | return image, self.encode(label.squeeze().numpy()).long()
54 |
55 |     def encode(self, label: np.ndarray) -> Tensor:
56 | label = self.label_map[label]
57 | return torch.from_numpy(label)
58 |
59 |
60 | if __name__ == '__main__':
61 | from semseg.utils.visualize import visualize_dataset_sample
62 | visualize_dataset_sample(CityScapes, '/home/sithu/datasets/CityScapes')
63 |
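
The `label_map` built in `__init__` is a 256-entry lookup table, so `encode` remaps an entire label image with one fancy-indexing operation instead of a per-pixel loop. A tiny sketch with a hand-picked subset of the real ID2TRAINID mapping:

import numpy as np

id2trainid = {7: 0, 8: 1, 26: 13, 4: 255}   # subset of the full Cityscapes mapping
label_map = np.arange(256)
for raw_id, train_id in id2trainid.items():
    label_map[raw_id] = train_id

raw = np.array([[7, 8, 26, 4]])             # values as found in *_labelIds.png
print(label_map[raw])                       # [[  0   1  13 255]] -> 255 is ignored downstream
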
--------------------------------------------------------------------------------
/semseg/datasets/facesynthetics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.utils.data import Dataset
4 | from torchvision import io
5 | from pathlib import Path
6 | from typing import Tuple
7 |
8 |
9 | class FaceSynthetics(Dataset):
10 | CLASSES = ['background', 'skin', 'nose', 'r-eye', 'l-eye', 'r-brow', 'l-brow', 'r-ear', 'l-ear', 'i-mouth', 't-lip', 'b-lip', 'neck', 'hair', 'beard', 'clothing', 'glasses', 'headwear', 'facewear']
11 | PALETTE = torch.tensor([
12 | [0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0],
13 | [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]
14 | ])
15 |
16 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
17 | super().__init__()
18 | assert split in ['train', 'val', 'test']
19 | if split == 'train':
20 | split = 'dataset_100000'
21 | elif split == 'val':
22 | split = 'dataset_1000'
23 | else:
24 | split = 'dataset_100'
25 |
26 | self.transform = transform
27 | self.n_classes = len(self.CLASSES)
28 | self.ignore_label = 255
29 |
30 | img_path = Path(root) / split
31 | images = img_path.glob('*.png')
32 | self.files = [path for path in images if '_seg' not in path.name]
33 |
34 | if not self.files: raise Exception(f"No images found in {root}")
35 | print(f"Found {len(self.files)} {split} images.")
36 |
37 | def __len__(self) -> int:
38 | return len(self.files)
39 |
40 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
41 | img_path = str(self.files[index])
42 | lbl_path = str(self.files[index]).replace('.png', '_seg.png')
43 | image = io.read_image(str(img_path))
44 | label = io.read_image(str(lbl_path))
45 |
46 | if self.transform:
47 | image, label = self.transform(image, label)
48 | return image, label.squeeze().long()
49 |
50 |
51 | if __name__ == '__main__':
52 | import sys
53 | sys.path.insert(0, '.')
54 | from semseg.utils.visualize import visualize_dataset_sample
55 | visualize_dataset_sample(FaceSynthetics, 'C:\\Users\\sithu\\Documents\\Datasets')
--------------------------------------------------------------------------------
/semseg/datasets/helen.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.utils.data import Dataset
4 | from torchvision import io
5 | from pathlib import Path
6 | from typing import Tuple
7 |
8 |
9 | class HELEN(Dataset):
10 | CLASSES = ['background', 'skin', 'l-brow', 'r-brow', 'l-eye', 'r-eye', 'nose', 'u-lip', 'i-mouth', 'l-lip', 'hair']
11 |     PALETTE = torch.tensor([[0, 0, 0], [127, 0, 0], [254, 0, 0], [0, 84, 0], [169, 0, 50], [254, 84, 0], [255, 0, 84], [0, 118, 220], [84, 84, 0], [0, 84, 84], [84, 50, 0]])
12 |
13 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
14 | super().__init__()
15 | assert split in ['train', 'val', 'test']
16 | self.transform = transform
17 | self.n_classes = len(self.CLASSES)
18 | self.ignore_label = 255
19 |
20 | self.files = self.get_files(root, split)
21 | if not self.files: raise Exception(f"No images found in {root}")
22 | print(f"Found {len(self.files)} {split} images.")
23 |
24 | def get_files(self, root: str, split: str):
25 | root = Path(root)
26 | if split == 'train':
27 | split = 'exemplars'
28 | elif split == 'val':
29 | split = 'tuning'
30 | else:
31 | split = 'testing'
32 | with open(root / f'{split}.txt') as f:
33 | lines = f.read().splitlines()
34 |
35 | split_names = [line.split(',')[-1].strip() for line in lines if line != '']
36 | files = (root / 'images').glob("*.jpg")
37 | files = list(filter(lambda x: x.stem in split_names, files))
38 | return files
39 |
40 | def __len__(self) -> int:
41 | return len(self.files)
42 |
43 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
44 | img_path = str(self.files[index])
45 | lbl_path = str(self.files[index]).split('.')[0].replace('images', 'labels')
46 | image = io.read_image(img_path)
47 | label = self.encode(lbl_path)
48 |
49 | if self.transform:
50 | image, label = self.transform(image, label)
51 | return image, label.squeeze().long()
52 |
53 | def encode(self, label_path: str) -> Tensor:
54 | mask_paths = sorted(list(Path(label_path).glob('*.png')))
55 | for i, mask_path in enumerate(mask_paths):
56 | mask = io.read_image(str(mask_path)).squeeze()
57 | if i == 0:
58 | label = torch.zeros(self.n_classes, *mask.shape)
59 | label[i, ...] = mask
60 | label = label.argmax(dim=0).unsqueeze(0)
61 | return label
62 |
63 |
64 | if __name__ == '__main__':
65 | from semseg.utils.visualize import visualize_dataset_sample
66 | visualize_dataset_sample(HELEN, '/home/sithu/datasets/SmithCVPR2013_dataset_resized')
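
`encode` above rebuilds a single index map from HELEN's per-class mask PNGs: one binary mask per class is stacked into a (num_classes, H, W) tensor and the per-pixel argmax picks whichever class is brightest at that pixel. The same idea on toy tensors (three fake classes, a 1x2 image):

import torch

masks = [
    torch.tensor([[255, 0]]),   # class 0 (background) mask
    torch.tensor([[0, 255]]),   # class 1 mask
    torch.tensor([[0, 0]]),     # class 2 mask
]
label = torch.zeros(len(masks), 1, 2)
for i, mask in enumerate(masks):
    label[i, ...] = mask
print(label.argmax(dim=0))      # tensor([[0, 1]])
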
--------------------------------------------------------------------------------
/semseg/datasets/ibugmask.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.utils.data import Dataset
4 | from torchvision import io
5 | from pathlib import Path
6 | from typing import Tuple
7 |
8 |
9 | class iBugMask(Dataset):
10 | CLASSES = ['background', 'skin', 'l-brow', 'r-brow', 'l-eye', 'r-eye', 'nose', 'u-lip', 'i-mouth', 'l-lip', 'hair']
11 | PALETTE = torch.tensor([[0, 0, 0], [255, 255, 0], [139, 76, 57], [139, 54, 38], [0, 205, 0], [0, 138, 0], [154, 50, 205], [72, 118, 255], [255, 165, 0], [0, 0, 139], [255, 0, 0]])
12 |
13 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
14 | super().__init__()
15 | assert split in ['train', 'val', 'test']
16 | split = 'train' if split == 'train' else 'test'
17 | self.transform = transform
18 | self.n_classes = len(self.CLASSES)
19 | self.ignore_label = 255
20 |
21 | img_path = Path(root) / split
22 | self.files = list(img_path.glob('*.jpg'))
23 |
24 | if not self.files: raise Exception(f"No images found in {root}")
25 | print(f"Found {len(self.files)} {split} images.")
26 |
27 | def __len__(self) -> int:
28 | return len(self.files)
29 |
30 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
31 | img_path = str(self.files[index])
32 | lbl_path = str(self.files[index]).replace('.jpg', '.png')
33 | image = io.read_image(str(img_path))
34 | label = io.read_image(str(lbl_path))
35 |
36 | if self.transform:
37 | image, label = self.transform(image, label)
38 | return image, label.squeeze().long()
39 |
40 |
41 | if __name__ == '__main__':
42 | from semseg.utils.visualize import visualize_dataset_sample
43 | visualize_dataset_sample(iBugMask, '/home/sithu/datasets/ibugmask_release')
--------------------------------------------------------------------------------
/semseg/datasets/lapa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.utils.data import Dataset
4 | from torchvision import io
5 | from pathlib import Path
6 | from typing import Tuple
7 |
8 |
9 | class LaPa(Dataset):
10 | CLASSES = ['background', 'skin', 'l-brow', 'r-brow', 'l-eye', 'r-eye', 'nose', 'u-lip', 'i-mouth', 'l-lip', 'hair']
11 | PALETTE = torch.tensor([[0, 0, 0], [0, 153, 255], [102, 255, 153], [0, 204, 153], [255, 255, 102], [255, 255, 204], [255, 153, 0], [255, 102, 255], [102, 0, 51], [255, 204, 255], [255, 0, 102]])
12 |
13 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
14 | super().__init__()
15 | assert split in ['train', 'val', 'test']
16 | self.transform = transform
17 | self.n_classes = len(self.CLASSES)
18 | self.ignore_label = 255
19 |
20 | img_path = Path(root) / split / 'images'
21 | self.files = list(img_path.glob('*.jpg'))
22 |
23 | if not self.files: raise Exception(f"No images found in {root}")
24 | print(f"Found {len(self.files)} {split} images.")
25 |
26 | def __len__(self) -> int:
27 | return len(self.files)
28 |
29 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
30 | img_path = str(self.files[index])
31 | lbl_path = str(self.files[index]).replace('images', 'labels').replace('.jpg', '.png')
32 | image = io.read_image(str(img_path))
33 | label = io.read_image(str(lbl_path))
34 |
35 | if self.transform:
36 | image, label = self.transform(image, label)
37 | return image, label.squeeze().long()
38 |
39 |
40 | if __name__ == '__main__':
41 | from semseg.utils.visualize import visualize_dataset_sample
42 | visualize_dataset_sample(LaPa, '/home/sithu/datasets/LaPa')
--------------------------------------------------------------------------------
/semseg/datasets/lip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.utils.data import Dataset
4 | from torchvision import io
5 | from pathlib import Path
6 | from typing import Tuple
7 |
8 |
9 | class LIP(Dataset):
10 | """
11 | num_classes: 19+background
12 | 30462 train images
13 | 10000 val images
14 | """
15 | CLASSES = ['background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes', 'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt', 'face', 'left-arm', 'right-arm', 'left-leg', 'right-leg', 'left-shoe', 'right-shoe']
16 | PALETTE = torch.tensor([[0, 0, 0], [127, 0, 0], [254, 0, 0], [0, 84, 0], [169, 0, 50], [254, 84, 0], [255, 0, 84], [0, 118, 220], [84, 84, 0], [0, 84, 84], [84, 50, 0], [51, 85, 127], [0, 127, 0], [0, 0, 254], [50, 169, 220], [0, 254, 254], [84, 254, 169], [169, 254, 84], [254, 254, 0], [254, 169, 0]])
17 |
18 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
19 | super().__init__()
20 | assert split in ['train', 'val']
21 | self.split = split
22 | self.transform = transform
23 | self.n_classes = len(self.CLASSES)
24 | self.ignore_label = 255
25 |
26 | img_path = Path(root) / 'TrainVal_images' / f'{split}_images'
27 | self.files = list(img_path.glob('*.jpg'))
28 |
29 | if not self.files:
30 | raise Exception(f"No images found in {img_path}")
31 | print(f"Found {len(self.files)} {split} images.")
32 |
33 | def __len__(self) -> int:
34 | return len(self.files)
35 |
36 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
37 | img_path = str(self.files[index])
38 | lbl_path = str(self.files[index]).replace('TrainVal_images', 'TrainVal_parsing_annotations').replace(f'{self.split}_images', f'{self.split}_segmentations').replace('.jpg', '.png')
39 |
40 | image = io.read_image(img_path)
41 | label = io.read_image(lbl_path)
42 |
43 | if self.transform:
44 | image, label = self.transform(image, label)
45 | return image, label.squeeze().long()
46 |
47 |
48 | if __name__ == '__main__':
49 | from semseg.utils.visualize import visualize_dataset_sample
50 | visualize_dataset_sample(LIP, '/home/sithu/datasets/LIP/LIP')
--------------------------------------------------------------------------------
/semseg/datasets/mapillary.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.utils.data import Dataset
4 | from torchvision import io
5 | from pathlib import Path
6 | from typing import Tuple
7 |
8 |
9 | class MapillaryVistas(Dataset):
10 | CLASSES = [
11 | 'Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist', 'Other Rider', 'Lane Marking - Crosswalk', 'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water', 'Banner',
12 | 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera', 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Phone Booth', 'Pothole', 'Street Light', 'Pole', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Light', 'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer', 'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle'
13 | ]
14 | PALETTE = torch.tensor([
15 | [165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153], [180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255], [140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96], [230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128], [255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180], [190, 255, 255], [152, 251, 152], [107, 142, 35],
16 | [0, 170, 30], [255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220], [220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], [0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10]
17 | ])
18 |
19 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
20 | super().__init__()
21 | assert split in ['train', 'val']
22 | split = 'training' if split == 'train' else 'validation'
23 | self.transform = transform
24 | self.n_classes = len(self.CLASSES)
25 | self.ignore_label = 65
26 |
27 | img_path = Path(root) / split / 'images'
28 | self.files = list(img_path.glob("*.jpg"))
29 |
30 | if not self.files:
31 | raise Exception(f"No images found in {img_path}")
32 | print(f"Found {len(self.files)} {split} images.")
33 |
34 | def __len__(self) -> int:
35 | return len(self.files)
36 |
37 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
38 | img_path = str(self.files[index])
39 | lbl_path = str(self.files[index]).replace('images', 'labels').replace('.jpg', '.png')
40 |
41 | image = io.read_image(img_path, io.ImageReadMode.RGB)
42 | label = io.read_image(lbl_path)
43 |
44 | if self.transform:
45 | image, label = self.transform(image, label)
46 | return image, label.squeeze().long()
47 |
48 |
49 | if __name__ == '__main__':
50 | from semseg.utils.visualize import visualize_dataset_sample
51 | visualize_dataset_sample(MapillaryVistas, '/home/sithu/datasets/Mapillary')
--------------------------------------------------------------------------------
/semseg/datasets/mhpv1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import Tensor
4 | from torch.utils.data import Dataset
5 | from torchvision import io
6 | from pathlib import Path
7 | from typing import Tuple
8 |
9 |
10 | class MHPv1(Dataset):
11 | """
12 | 4980 images each with at least 2 persons (average 3)
13 | 3000 images for training
14 | 1000 images for validation
15 | 980 images for testing
16 | num_classes: 18+background
17 | """
18 |     CLASSES = ['background', 'hat', 'hair', 'sunglass', 'upper-clothes', 'skirt', 'pants', 'dress', 'belt', 'left-shoe', 'right-shoe', 'face', 'left-leg', 'right-leg', 'left-arm', 'right-arm', 'bag', 'scarf', 'torso-skin']
19 | PALETTE = torch.tensor([[0, 0, 0], [128, 0, 0], [254, 0, 0], [0, 85, 0], [169, 0, 51], [254, 85, 0], [255, 0, 85], [0, 119, 220], [85, 85, 0], [190, 153, 153], [85, 51, 0], [52, 86, 128], [0, 128, 0], [0, 0, 254], [51, 169, 220], [0, 254, 254], [85, 254, 169], [169, 254, 85], [254, 254, 0]])
20 |
21 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
22 | super().__init__()
23 | assert split in ['train', 'val', 'test']
24 | self.transform = transform
25 | self.n_classes = len(self.CLASSES)
26 | self.ignore_label = 255
27 |
28 | self.images, self.labels = self.get_files(root, split)
29 | print(f"Found {len(self.images)} {split} images.")
30 |
31 | def get_files(self, root: str, split: str):
32 | root = Path(root)
33 | all_labels = list((root / 'annotations').rglob('*.png'))
34 | images, labels = [], []
35 |
36 | flist = 'test_list.txt' if split == 'test' else 'train_list.txt'
37 | with open(root / flist) as f:
38 | all_files = f.read().splitlines()
39 |
40 | if split == 'train':
41 | files = all_files[:3000]
42 | elif split == 'val':
43 | files = all_files[3000:]
44 | else:
45 | files = all_files
46 |
47 | for f in files:
48 | images.append(root / 'images' / f)
49 | img_name = f.split('.')[0]
50 | labels_per_images = list(filter(lambda x: x.stem.startswith(img_name), all_labels))
51 | assert labels_per_images != []
52 | labels.append(labels_per_images)
53 |
54 | assert len(images) == len(labels)
55 | return images, labels
56 |
57 | def __len__(self) -> int:
58 | return len(self.images)
59 |
60 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
61 | img_path = str(self.images[index])
62 | lbl_paths = self.labels[index]
63 |
64 | image = io.read_image(img_path)
65 | label = self.read_label(lbl_paths)
66 |
67 | if self.transform:
68 | image, label = self.transform(image, label)
69 | return image, label.squeeze().long()
70 |
71 | def read_label(self, lbl_paths: list) -> Tensor:
72 | labels = None
73 | label_idx = None
74 |
75 | for lbl_path in lbl_paths:
76 | label = io.read_image(str(lbl_path)).squeeze().numpy()
77 |
78 | if label_idx is None:
79 | label_idx = np.zeros(label.shape, dtype=np.uint8)
80 | label = np.ma.masked_array(label, mask=label_idx)
81 | label_idx += np.minimum(label, 1)
82 | if labels is None:
83 | labels = label
84 | else:
85 | labels += label
86 | return torch.from_numpy(labels.data).unsqueeze(0).to(torch.uint8)
87 |
88 |
89 | if __name__ == '__main__':
90 | from semseg.utils.visualize import visualize_dataset_sample
91 | visualize_dataset_sample(MHPv1, '/home/sithu/datasets/LV-MHP-v1')
92 |
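
`read_label` merges one mask per annotated person into a single part map. The masked-array bookkeeping is the subtle part: `label_idx` records which pixels are already claimed, and wrapping the next person's mask in `np.ma.masked_array` keeps those pixels from being added on top of the existing ids, so overlapping annotations do not sum into bogus classes. A toy walk-through that mirrors the same statements (two fake single-row masks; the printed result assumes NumPy's in-place masked addition, as used above):

import numpy as np

person1 = np.array([[2, 0, 0]], dtype=np.uint8)    # part id 2 on pixel 0
person2 = np.array([[5, 5, 0]], dtype=np.uint8)    # part id 5 on pixels 0 and 1

labels, label_idx = None, None
for label in (person1, person2):
    if label_idx is None:
        label_idx = np.zeros(label.shape, dtype=np.uint8)
    label = np.ma.masked_array(label, mask=label_idx)   # hide pixels already claimed
    label_idx += np.minimum(label, 1)
    if labels is None:
        labels = label
    else:
        labels += label
print(labels.data)                                  # [[2 5 0]] -> pixel 0 keeps person1's id
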
--------------------------------------------------------------------------------
/semseg/datasets/mhpv2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import Tensor
4 | from torch.utils.data import Dataset
5 | from torchvision import io
6 | from pathlib import Path
7 | from typing import Tuple
8 |
9 |
10 | class MHPv2(Dataset):
11 | """
12 | 25,403 images each with at least 2 persons (average 3)
13 | 15,403 images for training
14 | 5000 images for validation
15 | 5000 images for testing
16 | num_classes: 58+background
17 | """
18 | CLASSES = ['background', 'cap/hat', 'helmet', 'face', 'hair', 'left-arm', 'right-arm', 'left-hand', 'right-hand', 'protector', 'bikini/bra', 'jacket/windbreaker/hoodie', 't-shirt', 'polo-shirt', 'sweater', 'singlet', 'torso-skin', 'pants', 'shorts/swim-shorts', 'skirt', 'stockings', 'socks', 'left-boot', 'right-boot', 'left-shoe', 'right-shoe', 'left-highheel', 'right-highheel', 'left-sandal', 'right-sandal', 'left-leg', 'right-leg', 'left-foot', 'right-foot', 'coat', 'dress', 'robe', 'jumpsuits', 'other-full-body-clothes', 'headware', 'backpack', 'ball', 'bats', 'belt', 'bottle', 'carrybag', 'cases', 'sunglasses', 'eyeware', 'gloves', 'scarf', 'umbrella', 'wallet/purse', 'watch', 'wristband', 'tie', 'other-accessories', 'other-upper-body-clothes', 'other-lower-body-clothes']
19 | PALETTE = torch.tensor([[0, 0, 0], [255, 114, 196], [63, 31, 34], [253, 1, 0], [254, 26, 1], [253, 54, 0], [253, 82, 0], [252, 110, 0], [253, 137, 0], [253, 166, 1], [254, 191, 0], [253, 219, 0], [252, 248, 0], [238, 255, 1], [209, 255, 0], [182, 255, 0], [155, 255, 0], [133, 254, 0], [102, 254, 0], [78, 255, 0], [55, 254, 1], [38, 255, 0], [30, 255, 13], [34, 255, 35], [35, 254, 64], [36, 254, 87], [37, 252, 122], [37, 255, 143], [35, 255, 172], [35, 255, 200], [40, 253, 228], [40, 255, 255], [37, 228, 255], [33, 198, 254], [31, 170, 254], [22, 145, 255], [26, 112, 255], [20, 86, 253], [22, 53, 255], [19, 12, 253], [19, 1, 246], [30, 1, 252], [52, 0, 254], [72, 0, 255], [102, 0, 255], [121, 1, 252], [157, 1, 245], [182, 0, 253], [210, 0, 254], [235, 0, 255], [253, 1, 246], [254, 0, 220], [255, 0, 191], [254, 0, 165], [252, 0, 137], [248, 2, 111], [253, 0, 81], [255, 0, 54], [253, 1, 26]])
20 |
21 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
22 | super().__init__()
23 | assert split in ['train', 'val']
24 | self.transform = transform
25 | self.n_classes = len(self.CLASSES)
26 | self.ignore_label = 255
27 |
28 | self.images, self.labels = self.get_files(root, split)
29 | print(f"Found {len(self.images)} {split} images.")
30 |
31 | def get_files(self, root: str, split: str):
32 | root = Path(root)
33 | all_labels = list((root / split / 'parsing_annos').rglob('*.png'))
34 | images = list((root / split / 'images').rglob('*.jpg'))
35 | labels = []
36 |
37 | for f in images:
38 | labels_per_images = list(filter(lambda x: x.stem.split('_', maxsplit=1)[0] == f.stem, all_labels))
39 | assert labels_per_images != []
40 | labels.append(labels_per_images)
41 |
42 | assert len(images) == len(labels)
43 | return images, labels
44 |
45 | def __len__(self) -> int:
46 | return len(self.images)
47 |
48 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
49 | img_path = str(self.images[index])
50 | lbl_paths = self.labels[index]
51 |
52 | image = io.read_image(img_path)
53 | label = self.read_label(lbl_paths)
54 |
55 | if self.transform:
56 | image, label = self.transform(image, label)
57 | return image, label.squeeze().long()
58 |
59 | def read_label(self, lbl_paths: list) -> Tensor:
60 | labels = None
61 | label_idx = None
62 |
63 | for lbl_path in lbl_paths:
64 | label = io.read_image(str(lbl_path)).squeeze().numpy()
65 | if label.ndim != 2:
66 | label = label[0]
67 | if label_idx is None:
68 | label_idx = np.zeros(label.shape, dtype=np.uint8)
69 | label = np.ma.masked_array(label, mask=label_idx)
70 | label_idx += np.minimum(label, 1)
71 | if labels is None:
72 | labels = label
73 | else:
74 | labels += label
75 | return torch.from_numpy(labels.data).unsqueeze(0).to(torch.uint8)
76 |
77 |
78 | if __name__ == '__main__':
79 | from semseg.utils.visualize import visualize_dataset_sample
80 | visualize_dataset_sample(MHPv2, '/home/sithu/datasets/LV-MHP-v2')
--------------------------------------------------------------------------------
/semseg/datasets/pascalcontext.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.utils.data import Dataset
4 | from torchvision import io
5 | from pathlib import Path
6 | from typing import Tuple
7 |
8 |
9 | class PASCALContext(Dataset):
10 | """
11 | https://cs.stanford.edu/~roozbeh/pascal-context/
12 | based on PASCAL VOC 2010
13 | num_classes: 59
14 | 10,100 train+val
15 | 9,637 test
16 | """
17 | CLASSES = [
18 | 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
19 | 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
20 | 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
21 | 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
22 | 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
23 | 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
24 | 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
25 | 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
26 | 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
27 | 'window', 'wood'
28 | ]
29 |
30 | PALETTE = torch.tensor([
31 | [180, 120, 120], [6, 230, 230], [80, 50, 50],
32 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
33 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
34 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
35 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
36 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
37 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
38 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
39 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
40 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
41 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
42 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
43 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
44 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
45 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]
46 | ])
47 |
48 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
49 | super().__init__()
50 | assert split in ['train', 'val']
51 | self.transform = transform
52 | self.n_classes = len(self.CLASSES)
53 | self.ignore_label = -1
54 |
55 | self.images, self.labels = self.get_files(root, split)
56 | print(f"Found {len(self.images)} {split} images.")
57 |
58 | def get_files(self, root: str, split: str):
59 | root = Path(root)
60 | flist = root / 'ImageSets' / 'SegmentationContext' / f'{split}.txt'
61 | with open(flist) as f:
62 | files = f.read().splitlines()
63 | images, labels = [], []
64 |
65 | for fi in files:
66 | images.append(str(root / 'JPEGImages' / f'{fi}.jpg'))
67 | labels.append(str(root / 'SegmentationClassContext' / f'{fi}.png'))
68 | return images, labels
69 |
70 | def __len__(self) -> int:
71 | return len(self.images)
72 |
73 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
74 | img_path = self.images[index]
75 | lbl_path = self.labels[index]
76 |
77 | image = io.read_image(img_path)
78 | label = io.read_image(lbl_path)
79 |
80 | if self.transform:
81 | image, label = self.transform(image, label)
82 |         return image, label.squeeze().long() - 1 # shift labels down so background (0) becomes -1 and is ignored
83 |
84 |
85 | if __name__ == '__main__':
86 | from semseg.utils.visualize import visualize_dataset_sample
87 | visualize_dataset_sample(PASCALContext, '/home/sithu/datasets/VOCdevkit/VOC2010')
--------------------------------------------------------------------------------
/semseg/datasets/suim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.utils.data import Dataset
4 | from torchvision import io
5 | from pathlib import Path
6 | from typing import Tuple
7 | from PIL import Image
8 | from torchvision.transforms import functional as TF
9 |
10 |
11 | class SUIM(Dataset):
12 | CLASSES = ['water', 'human divers', 'aquatic plants and sea-grass', 'wrecks and ruins', 'robots (AUVs/ROVs/instruments)', 'reefs and invertebrates', 'fish and vertebrates', 'sea-floor and rocks']
13 | PALETTE = torch.tensor([[0, 0, 0], [0, 0, 255], [0, 255, 0], [0, 255, 255], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 255, 255]])
14 |
15 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
16 | super().__init__()
17 | assert split in ['train', 'val']
18 | self.split = 'train_val' if split == 'train' else 'TEST'
19 | self.transform = transform
20 | self.n_classes = len(self.CLASSES)
21 | self.ignore_label = 255
22 |
23 | img_path = Path(root) / self.split / 'images'
24 | self.files = list(img_path.glob("*.jpg"))
25 |
26 | if not self.files:
27 | raise Exception(f"No images found in {img_path}")
28 | print(f"Found {len(self.files)} {split} images.")
29 |
30 | def __len__(self) -> int:
31 | return len(self.files)
32 |
33 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
34 | img_path = str(self.files[index])
35 | lbl_path = str(self.files[index]).replace('images', 'masks').replace('.jpg', '.bmp')
36 |
37 | image = io.read_image(img_path)
38 | label = TF.pil_to_tensor(Image.open(lbl_path).convert('RGB'))
39 |
40 | if self.transform:
41 | image, label = self.transform(image, label)
42 | return image, self.encode(label).long()
43 |
44 | def encode(self, label: Tensor) -> Tensor:
45 | label = label.permute(1, 2, 0)
46 | mask = torch.zeros(label.shape[:-1])
47 |
48 | for index, color in enumerate(self.PALETTE):
49 | bool_mask = torch.eq(label, color)
50 | class_map = torch.all(bool_mask, dim=-1)
51 | mask[class_map] = index
52 | return mask
53 |
54 |
55 | if __name__ == '__main__':
56 | from semseg.utils.visualize import visualize_dataset_sample
57 | visualize_dataset_sample(SUIM, '/home/sithu/datasets/SUIM')
--------------------------------------------------------------------------------
/semseg/datasets/sunrgbd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import Tensor
4 | from torch.utils.data import Dataset
5 | from torchvision import io
6 | from scipy import io as sio
7 | from pathlib import Path
8 | from typing import Tuple
9 |
10 |
11 | class SunRGBD(Dataset):
12 | """
13 | num_classes: 37
14 | """
15 | CLASSES = [
16 | 'wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror',
17 | 'floor mat', 'clothes', 'ceiling', 'books', 'fridge', 'tv', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag'
18 | ]
19 |
20 | PALETTE = torch.tensor([
21 | (119, 119, 119), (244, 243, 131), (137, 28, 157), (150, 255, 255), (54, 114, 113), (0, 0, 176), (255, 69, 0), (87, 112, 255), (0, 163, 33),
22 | (255, 150, 255), (255, 180, 10), (101, 70, 86), (38, 230, 0), (255, 120, 70), (117, 41, 121), (150, 255, 0), (132, 0, 255), (24, 209, 255),
23 | (191, 130, 35), (219, 200, 109), (154, 62, 86), (255, 190, 190), (255, 0, 255), (152, 163, 55), (192, 79, 212), (230, 230, 230), (53, 130, 64),
24 | (155, 249, 152), (87, 64, 34), (214, 209, 175), (170, 0, 59), (255, 0, 0), (193, 195, 234), (70, 72, 115), (255, 255, 0), (52, 57, 131), (12, 83, 45)
25 | ])
26 |
27 | def __init__(self, root: str, split: str = 'train', transform = None) -> None:
28 | super().__init__()
29 | assert split in ['alltrain', 'train', 'val', 'test']
30 | self.transform = transform
31 | self.n_classes = len(self.CLASSES)
32 | self.ignore_label = -1
33 | self.files, self.labels = self.get_data(root, split)
34 | print(f"Found {len(self.files)} {split} images.")
35 |
36 | def get_data(self, root: str, split: str):
37 | root = Path(root)
38 | files, labels = [], []
39 | split_path = root / 'SUNRGBDtoolbox' / 'traintestSUNRGBD' / 'allsplit.mat'
40 | split_mat = sio.loadmat(split_path, squeeze_me=True, struct_as_record=False)
41 | if split == 'train':
42 | file_lists = split_mat['trainvalsplit'].train
43 | elif split == 'val':
44 | file_lists = split_mat['trainvalsplit'].val
45 | elif split == 'test':
46 | file_lists = split_mat['alltest']
47 | else:
48 | file_lists = split_mat['alltrain']
49 |
50 | for fl in file_lists:
51 | real_fl = root / fl.split('/n/fs/sun3d/data/')[-1]
52 | files.append(str(list((real_fl / 'image').glob('*.jpg'))[0]))
53 | labels.append(real_fl / 'seg.mat')
54 |
55 | assert len(files) == len(labels)
56 | return files, labels
57 |
58 | def __len__(self) -> int:
59 | return len(self.files)
60 |
61 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
62 | image = io.read_image(self.files[index], io.ImageReadMode.RGB)
63 | label = sio.loadmat(self.labels[index], squeeze_me=True, struct_as_record=False)['seglabel']
64 | label = torch.from_numpy(label.astype(np.uint8)).unsqueeze(0)
65 |
66 | if self.transform:
67 | image, label = self.transform(image, label)
68 |         return image, self.encode(label.squeeze()).long() - 1 # subtract 1 so the void label (0) becomes -1 and is ignored
69 |
70 | def encode(self, label: Tensor) -> Tensor:
71 | label[label > self.n_classes] = 0
72 | return label
73 |
74 |
75 | if __name__ == '__main__':
76 | from semseg.utils.visualize import visualize_dataset_sample
77 | visualize_dataset_sample(SunRGBD, '/home/sithu/datasets/sunrgbd')
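
`encode` plus the `- 1` in `__getitem__` fold two cases into the ignore label: any `seglabel` value above the 37 known classes is first zeroed, and the shift then sends both those pixels and the original void label 0 to -1 (`ignore_label`). A two-line sketch with made-up values:

import torch

n_classes = 37
label = torch.tensor([[0, 1, 37, 40]])   # void, class 1, class 37, out-of-range id
label[label > n_classes] = 0
print(label.long() - 1)                  # tensor([[-1,  0, 36, -1]])
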
--------------------------------------------------------------------------------
/semseg/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 |
5 |
6 | class CrossEntropy(nn.Module):
7 | def __init__(self, ignore_label: int = 255, weight: Tensor = None, aux_weights: list = [1, 0.4, 0.4]) -> None:
8 | super().__init__()
9 | self.aux_weights = aux_weights
10 | self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label)
11 |
12 | def _forward(self, preds: Tensor, labels: Tensor) -> Tensor:
13 | # preds in shape [B, C, H, W] and labels in shape [B, H, W]
14 | return self.criterion(preds, labels)
15 |
16 | def forward(self, preds, labels: Tensor) -> Tensor:
17 | if isinstance(preds, tuple):
18 | return sum([w * self._forward(pred, labels) for (pred, w) in zip(preds, self.aux_weights)])
19 | return self._forward(preds, labels)
20 |
21 |
22 | class OhemCrossEntropy(nn.Module):
23 | def __init__(self, ignore_label: int = 255, weight: Tensor = None, thresh: float = 0.7, aux_weights: list = [1, 1]) -> None:
24 | super().__init__()
25 | self.ignore_label = ignore_label
26 | self.aux_weights = aux_weights
27 | self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float))
28 | self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label, reduction='none')
29 |
30 | def _forward(self, preds: Tensor, labels: Tensor) -> Tensor:
31 | # preds in shape [B, C, H, W] and labels in shape [B, H, W]
32 | n_min = labels[labels != self.ignore_label].numel() // 16
33 | loss = self.criterion(preds, labels).view(-1)
34 | loss_hard = loss[loss > self.thresh]
35 |
36 | if loss_hard.numel() < n_min:
37 | loss_hard, _ = loss.topk(n_min)
38 |
39 | return torch.mean(loss_hard)
40 |
41 | def forward(self, preds, labels: Tensor) -> Tensor:
42 | if isinstance(preds, tuple):
43 | return sum([w * self._forward(pred, labels) for (pred, w) in zip(preds, self.aux_weights)])
44 | return self._forward(preds, labels)
45 |
46 |
47 | class Dice(nn.Module):
48 | def __init__(self, delta: float = 0.5, aux_weights: list = [1, 0.4, 0.4]):
49 | """
50 |         delta: Controls the weight given to FP and FN; the score reduces to the standard Dice score when delta=0.5
51 | """
52 | super().__init__()
53 | self.delta = delta
54 | self.aux_weights = aux_weights
55 |
56 | def _forward(self, preds: Tensor, labels: Tensor) -> Tensor:
57 | # preds in shape [B, C, H, W] and labels in shape [B, H, W]
58 | num_classes = preds.shape[1]
59 | labels = F.one_hot(labels, num_classes).permute(0, 3, 1, 2)
60 | tp = torch.sum(labels*preds, dim=(2, 3))
61 | fn = torch.sum(labels*(1-preds), dim=(2, 3))
62 | fp = torch.sum((1-labels)*preds, dim=(2, 3))
63 |
64 | dice_score = (tp + 1e-6) / (tp + self.delta * fn + (1 - self.delta) * fp + 1e-6)
65 | dice_score = torch.sum(1 - dice_score, dim=-1)
66 |
67 | dice_score = dice_score / num_classes
68 | return dice_score.mean()
69 |
70 | def forward(self, preds, targets: Tensor) -> Tensor:
71 | if isinstance(preds, tuple):
72 | return sum([w * self._forward(pred, targets) for (pred, w) in zip(preds, self.aux_weights)])
73 | return self._forward(preds, targets)
74 |
75 |
76 | __all__ = ['CrossEntropy', 'OhemCrossEntropy', 'Dice']
77 |
78 |
79 | def get_loss(loss_fn_name: str = 'CrossEntropy', ignore_label: int = 255, cls_weights: Tensor = None):
80 | assert loss_fn_name in __all__, f"Unavailable loss function name >> {loss_fn_name}.\nAvailable loss functions: {__all__}"
81 | if loss_fn_name == 'Dice':
82 | return Dice()
83 | return eval(loss_fn_name)(ignore_label, cls_weights)
84 |
85 |
86 | if __name__ == '__main__':
87 | pred = torch.randint(0, 19, (2, 19, 480, 640), dtype=torch.float)
88 | label = torch.randint(0, 19, (2, 480, 640), dtype=torch.long)
89 | loss_fn = Dice()
90 | y = loss_fn(pred, label)
91 | print(y)
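
A quick usage sketch for the OHEM variant, assuming the semseg package from this repo is importable: `OhemCrossEntropy` computes an unreduced per-pixel loss, keeps the pixels whose loss exceeds `-log(thresh)`, and falls back to the top `n_min` pixels (1/16 of the valid ones) when too few are "hard", so easy batches still produce a gradient. The shapes below are invented for illustration.

import torch
from semseg.losses import OhemCrossEntropy

torch.manual_seed(0)
preds = torch.randn(2, 19, 64, 64)             # [B, C, H, W] logits
labels = torch.randint(0, 19, (2, 64, 64))     # [B, H, W] class indices
loss_fn = OhemCrossEntropy(ignore_label=255, thresh=0.7)
print(loss_fn(preds, labels))                  # scalar: mean loss over the mined hard pixels
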
--------------------------------------------------------------------------------
/semseg/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from typing import Tuple
4 |
5 |
6 | class Metrics:
7 | def __init__(self, num_classes: int, ignore_label: int, device) -> None:
8 | self.ignore_label = ignore_label
9 | self.num_classes = num_classes
10 | self.hist = torch.zeros(num_classes, num_classes).to(device)
11 |
12 | def update(self, pred: Tensor, target: Tensor) -> None:
13 | pred = pred.argmax(dim=1)
14 | keep = target != self.ignore_label
15 | self.hist += torch.bincount(target[keep] * self.num_classes + pred[keep], minlength=self.num_classes**2).view(self.num_classes, self.num_classes)
16 |
17 | def compute_iou(self) -> Tuple[Tensor, Tensor]:
18 | ious = self.hist.diag() / (self.hist.sum(0) + self.hist.sum(1) - self.hist.diag())
19 | miou = ious[~ious.isnan()].mean().item()
20 | ious *= 100
21 | miou *= 100
22 | return ious.cpu().numpy().round(2).tolist(), round(miou, 2)
23 |
24 | def compute_f1(self) -> Tuple[Tensor, Tensor]:
25 | f1 = 2 * self.hist.diag() / (self.hist.sum(0) + self.hist.sum(1))
26 | mf1 = f1[~f1.isnan()].mean().item()
27 | f1 *= 100
28 | mf1 *= 100
29 | return f1.cpu().numpy().round(2).tolist(), round(mf1, 2)
30 |
31 | def compute_pixel_acc(self) -> Tuple[Tensor, Tensor]:
32 | acc = self.hist.diag() / self.hist.sum(1)
33 | macc = acc[~acc.isnan()].mean().item()
34 | acc *= 100
35 | macc *= 100
36 | return acc.cpu().numpy().round(2).tolist(), round(macc, 2)
37 |
38 |
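
The `update` trick above is worth spelling out: flattening each (target, pred) pair into the single index `target * num_classes + pred` and bincounting gives exactly a confusion matrix, and IoU, F1 and pixel accuracy are then read off its diagonal and marginals. A toy 2-class example:

import torch

num_classes = 2
target = torch.tensor([0, 0, 1, 1])
pred = torch.tensor([0, 1, 1, 1])
hist = torch.bincount(target * num_classes + pred, minlength=num_classes**2)
hist = hist.view(num_classes, num_classes)     # rows = target class, cols = predicted class
print(hist)                                    # confusion matrix [[1, 1], [0, 2]]
ious = hist.diag() / (hist.sum(0) + hist.sum(1) - hist.diag())
print(ious)                                    # tensor([0.5000, 0.6667])
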
--------------------------------------------------------------------------------
/semseg/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .segformer import SegFormer
2 | from .ddrnet import DDRNet
3 | from .fchardnet import FCHarDNet
4 | from .sfnet import SFNet
5 | from .bisenetv1 import BiSeNetv1
6 | from .bisenetv2 import BiSeNetv2
7 | from .lawin import Lawin
8 |
9 |
10 | __all__ = [
11 | 'SegFormer',
12 | 'Lawin',
13 | 'SFNet',
14 | 'BiSeNetv1',
15 |
16 | # Standalone Models
17 | 'DDRNet',
18 | 'FCHarDNet',
19 | 'BiSeNetv2'
20 | ]
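
Because every architecture is re-exported here and listed in `__all__`, a model class can be resolved from a config-file name with a plain `getattr` on this package. A minimal sketch of that lookup (this is not the repo's training code, just the idea):

import semseg.models as models

model_cls = getattr(models, 'SegFormer')    # any name from __all__ works the same way
print(model_cls)                            # <class 'semseg.models.segformer.SegFormer'>
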
--------------------------------------------------------------------------------
/semseg/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import ResNet, resnet_settings
2 | from .resnetd import ResNetD, resnetd_settings
3 | from .micronet import MicroNet, micronet_settings
4 | from .mobilenetv2 import MobileNetV2, mobilenetv2_settings
5 | from .mobilenetv3 import MobileNetV3, mobilenetv3_settings
6 |
7 | from .mit import MiT, mit_settings
8 | from .pvt import PVTv2, pvtv2_settings
9 | from .rest import ResT, rest_settings
10 | from .poolformer import PoolFormer, poolformer_settings
11 | from .convnext import ConvNeXt, convnext_settings
12 |
13 |
14 | __all__ = [
15 | 'ResNet',
16 | 'ResNetD',
17 | 'MicroNet',
18 | 'MobileNetV2',
19 | 'MobileNetV3',
20 |
21 | 'MiT',
22 | 'PVTv2',
23 | 'ResT',
24 | 'PoolFormer',
25 | 'ConvNeXt',
26 | ]
--------------------------------------------------------------------------------
/semseg/models/backbones/convnext.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from semseg.models.layers import DropPath
4 |
5 |
6 | class LayerNorm(nn.Module):
7 | """Channel first layer norm
8 | """
9 | def __init__(self, normalized_shape, eps=1e-6) -> None:
10 | super().__init__()
11 | self.weight = nn.Parameter(torch.ones(normalized_shape))
12 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
13 | self.eps = eps
14 |
15 | def forward(self, x: Tensor) -> Tensor:
16 | u = x.mean(1, keepdim=True)
17 | s = (x - u).pow(2).mean(1, keepdim=True)
18 | x = (x - u) / torch.sqrt(s + self.eps)
19 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
20 | return x
21 |
22 |
23 | class Block(nn.Module):
24 | def __init__(self, dim, dpr=0., init_value=1e-6):
25 | super().__init__()
26 | self.dwconv = nn.Conv2d(dim, dim, 7, 1, 3, groups=dim)
27 | self.norm = nn.LayerNorm(dim, eps=1e-6)
28 | self.pwconv1 = nn.Linear(dim, 4*dim)
29 | self.act = nn.GELU()
30 | self.pwconv2 = nn.Linear(4*dim, dim)
31 | self.gamma = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True) if init_value > 0 else None
32 | self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()
33 |
34 | def forward(self, x: Tensor) -> Tensor:
35 | input = x
36 | x = self.dwconv(x)
37 | x = x.permute(0, 2, 3, 1) # NCHW to NHWC
38 | x = self.norm(x)
39 | x = self.pwconv1(x)
40 | x = self.act(x)
41 | x = self.pwconv2(x)
42 |
43 | if self.gamma is not None:
44 | x = self.gamma * x
45 |
46 | x = x.permute(0, 3, 1, 2)
47 | x = input + self.drop_path(x)
48 | return x
49 |
50 |
51 | class Stem(nn.Sequential):
52 | def __init__(self, c1, c2, k, s):
53 | super().__init__(
54 | nn.Conv2d(c1, c2, k, s),
55 | LayerNorm(c2)
56 | )
57 |
58 |
59 | class Downsample(nn.Sequential):
60 | def __init__(self, c1, c2, k, s):
61 | super().__init__(
62 | LayerNorm(c1),
63 | nn.Conv2d(c1, c2, k, s)
64 | )
65 |
66 |
67 | convnext_settings = {
68 | 'T': [[3, 3, 9, 3], [96, 192, 384, 768], 0.0], # [depths, dims, dpr]
69 | 'S': [[3, 3, 27, 3], [96, 192, 384, 768], 0.0],
70 | 'B': [[3, 3, 27, 3], [128, 256, 512, 1024], 0.0]
71 | }
72 |
73 |
74 | class ConvNeXt(nn.Module):
75 | def __init__(self, model_name: str = 'T') -> None:
76 | super().__init__()
77 | assert model_name in convnext_settings.keys(), f"ConvNeXt model name should be in {list(convnext_settings.keys())}"
78 | depths, embed_dims, drop_path_rate = convnext_settings[model_name]
79 | self.channels = embed_dims
80 |
81 | self.downsample_layers = nn.ModuleList([
82 | Stem(3, embed_dims[0], 4, 4),
83 | *[Downsample(embed_dims[i], embed_dims[i+1], 2, 2) for i in range(3)]
84 | ])
85 |
86 | self.stages = nn.ModuleList()
87 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
88 | cur = 0
89 |
90 | for i in range(4):
91 | stage = nn.Sequential(*[
92 | Block(embed_dims[i], dpr[cur+j])
93 | for j in range(depths[i])])
94 | self.stages.append(stage)
95 | cur += depths[i]
96 |
97 | for i in range(4):
98 | self.add_module(f"norm{i}", LayerNorm(embed_dims[i]))
99 |
100 | def forward(self, x: Tensor):
101 | outs = []
102 |
103 | for i in range(4):
104 | x = self.downsample_layers[i](x)
105 | x = self.stages[i](x)
106 | norm_layer = getattr(self, f"norm{i}")
107 | outs.append(norm_layer(x))
108 | return outs
109 |
110 |
111 | if __name__ == '__main__':
112 | model = ConvNeXt('T')
113 | # model.load_state_dict(torch.load('C:\\Users\\sithu\\Documents\\weights\\backbones\\convnext\\convnext_tiny_1k_224_ema.pth', map_location='cpu')['model'], strict=False)
114 | x = torch.randn(1, 3, 224, 224)
115 | feats = model(x)
116 | for y in feats:
117 | print(y.shape)
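118 | 
119 |     # Expected output for ConvNeXt-T with a 224x224 input: four feature maps at
120 |     # strides 4/8/16/32, i.e. [1, 96, 56, 56], [1, 192, 28, 28], [1, 384, 14, 14]
121 |     # and [1, 768, 7, 7] (channels follow convnext_settings['T']).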
--------------------------------------------------------------------------------
/semseg/models/backbones/mit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 | from semseg.models.layers import DropPath
5 |
6 |
7 | class Attention(nn.Module):
8 | def __init__(self, dim, head, sr_ratio):
9 | super().__init__()
10 | self.head = head
11 | self.sr_ratio = sr_ratio
12 | self.scale = (dim // head) ** -0.5
13 | self.q = nn.Linear(dim, dim)
14 | self.kv = nn.Linear(dim, dim*2)
15 | self.proj = nn.Linear(dim, dim)
16 |
17 | if sr_ratio > 1:
18 | self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio)
19 | self.norm = nn.LayerNorm(dim)
20 |
21 | def forward(self, x: Tensor, H, W) -> Tensor:
22 | B, N, C = x.shape
23 | q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)
24 |
25 | if self.sr_ratio > 1:
26 | x = x.permute(0, 2, 1).reshape(B, C, H, W)
27 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
28 | x = self.norm(x)
29 |
30 | k, v = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)
31 |
32 | attn = (q @ k.transpose(-2, -1)) * self.scale
33 | attn = attn.softmax(dim=-1)
34 |
35 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
36 | x = self.proj(x)
37 | return x
38 |
39 |
40 | class DWConv(nn.Module):
41 | def __init__(self, dim):
42 | super().__init__()
43 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
44 |
45 | def forward(self, x: Tensor, H, W) -> Tensor:
46 | B, _, C = x.shape
47 | x = x.transpose(1, 2).view(B, C, H, W)
48 | x = self.dwconv(x)
49 | return x.flatten(2).transpose(1, 2)
50 |
51 |
52 | class MLP(nn.Module):
53 | def __init__(self, c1, c2):
54 | super().__init__()
55 | self.fc1 = nn.Linear(c1, c2)
56 | self.dwconv = DWConv(c2)
57 | self.fc2 = nn.Linear(c2, c1)
58 |
59 | def forward(self, x: Tensor, H, W) -> Tensor:
60 | return self.fc2(F.gelu(self.dwconv(self.fc1(x), H, W)))
61 |
62 |
63 | class PatchEmbed(nn.Module):
64 | def __init__(self, c1=3, c2=32, patch_size=7, stride=4):
65 | super().__init__()
66 |         self.proj = nn.Conv2d(c1, c2, patch_size, stride, patch_size//2)    # overlapping patch embedding: padding = patch_size // 2
67 | self.norm = nn.LayerNorm(c2)
68 |
69 | def forward(self, x: Tensor) -> Tensor:
70 | x = self.proj(x)
71 | _, _, H, W = x.shape
72 | x = x.flatten(2).transpose(1, 2)
73 | x = self.norm(x)
74 | return x, H, W
75 |
76 |
77 | class Block(nn.Module):
78 | def __init__(self, dim, head, sr_ratio=1, dpr=0.):
79 | super().__init__()
80 | self.norm1 = nn.LayerNorm(dim)
81 | self.attn = Attention(dim, head, sr_ratio)
82 | self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()
83 | self.norm2 = nn.LayerNorm(dim)
84 | self.mlp = MLP(dim, int(dim*4))
85 |
86 | def forward(self, x: Tensor, H, W) -> Tensor:
87 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
88 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
89 | return x
90 |
91 |
92 | mit_settings = {
93 | 'B0': [[32, 64, 160, 256], [2, 2, 2, 2]], # [embed_dims, depths]
94 | 'B1': [[64, 128, 320, 512], [2, 2, 2, 2]],
95 | 'B2': [[64, 128, 320, 512], [3, 4, 6, 3]],
96 | 'B3': [[64, 128, 320, 512], [3, 4, 18, 3]],
97 | 'B4': [[64, 128, 320, 512], [3, 8, 27, 3]],
98 | 'B5': [[64, 128, 320, 512], [3, 6, 40, 3]]
99 | }
100 |
101 |
102 | class MiT(nn.Module):
103 | def __init__(self, model_name: str = 'B0'):
104 | super().__init__()
105 | assert model_name in mit_settings.keys(), f"MiT model name should be in {list(mit_settings.keys())}"
106 | embed_dims, depths = mit_settings[model_name]
107 | drop_path_rate = 0.1
108 | self.channels = embed_dims
109 |
110 | # patch_embed
111 | self.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4)
112 | self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2)
113 | self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2)
114 | self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2)
115 |
116 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
117 |
118 | cur = 0
119 | self.block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, dpr[cur+i]) for i in range(depths[0])])
120 | self.norm1 = nn.LayerNorm(embed_dims[0])
121 |
122 | cur += depths[0]
123 | self.block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, dpr[cur+i]) for i in range(depths[1])])
124 | self.norm2 = nn.LayerNorm(embed_dims[1])
125 |
126 | cur += depths[1]
127 | self.block3 = nn.ModuleList([Block(embed_dims[2], 5, 2, dpr[cur+i]) for i in range(depths[2])])
128 | self.norm3 = nn.LayerNorm(embed_dims[2])
129 |
130 | cur += depths[2]
131 | self.block4 = nn.ModuleList([Block(embed_dims[3], 8, 1, dpr[cur+i]) for i in range(depths[3])])
132 | self.norm4 = nn.LayerNorm(embed_dims[3])
133 |
134 |
135 | def forward(self, x: Tensor) -> Tensor:
136 | B = x.shape[0]
137 | # stage 1
138 | x, H, W = self.patch_embed1(x)
139 | for blk in self.block1:
140 | x = blk(x, H, W)
141 | x1 = self.norm1(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)
142 |
143 | # stage 2
144 | x, H, W = self.patch_embed2(x1)
145 | for blk in self.block2:
146 | x = blk(x, H, W)
147 | x2 = self.norm2(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)
148 |
149 | # stage 3
150 | x, H, W = self.patch_embed3(x2)
151 | for blk in self.block3:
152 | x = blk(x, H, W)
153 | x3 = self.norm3(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)
154 |
155 | # stage 4
156 | x, H, W = self.patch_embed4(x3)
157 | for blk in self.block4:
158 | x = blk(x, H, W)
159 | x4 = self.norm4(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)
160 |
161 | return x1, x2, x3, x4
162 |
163 |
164 | if __name__ == '__main__':
165 | model = MiT('B0')
166 | x = torch.zeros(1, 3, 224, 224)
167 | outs = model(x)
168 | for y in outs:
169 | print(y.shape)
170 |
171 |
172 |
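173 | # Expected output for MiT-B0 with a 224x224 input: x1..x4 at strides 4/8/16/32,
174 | # i.e. [1, 32, 56, 56], [1, 64, 28, 28], [1, 160, 14, 14] and [1, 256, 7, 7]
175 | # (embed dims from mit_settings['B0']).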
--------------------------------------------------------------------------------
/semseg/models/backbones/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 |
4 |
5 | class ConvModule(nn.Sequential):
6 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1):
7 | super().__init__(
8 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False),
9 | nn.BatchNorm2d(c2),
10 | nn.ReLU6(True)
11 | )
12 |
13 |
14 | class InvertedResidual(nn.Module):
15 | def __init__(self, c1, c2, s, expand_ratio):
16 | super().__init__()
17 | ch = int(round(c1 * expand_ratio))
18 | self.use_res_connect = s == 1 and c1 == c2
19 |
20 | layers = []
21 |
22 | if expand_ratio != 1:
23 | layers.append(ConvModule(c1, ch, 1))
24 |
25 | layers.extend([
26 | ConvModule(ch, ch, 3, s, 1, g=ch),
27 | nn.Conv2d(ch, c2, 1, bias=False),
28 | nn.BatchNorm2d(c2)
29 | ])
30 |
31 | self.conv = nn.Sequential(*layers)
32 |
33 | def forward(self, x: Tensor) -> Tensor:
34 | if self.use_res_connect:
35 | return x + self.conv(x)
36 | else:
37 | return self.conv(x)
38 |
39 |
40 | mobilenetv2_settings = {
41 | '1.0': []
42 | }
43 |
44 |
45 | class MobileNetV2(nn.Module):
46 | def __init__(self, variant: str = None):
47 | super().__init__()
48 | self.out_indices = [3, 6, 13, 17]
49 | self.channels = [24, 32, 96, 320]
50 | input_channel = 32
51 |
52 | inverted_residual_setting = [
53 | # t, c, n, s
54 | [1, 16, 1, 1],
55 | [6, 24, 2, 2],
56 | [6, 32, 3, 2],
57 | [6, 64, 4, 2],
58 | [6, 96, 3, 1],
59 | [6, 160, 3, 2],
60 | [6, 320, 1, 1],
61 | ]
62 |
63 | self.features = nn.ModuleList([ConvModule(3, input_channel, 3, 2, 1)])
64 |
65 | for t, c, n, s in inverted_residual_setting:
66 | output_channel = c
67 | for i in range(n):
68 | stride = s if i == 0 else 1
69 | self.features.append(InvertedResidual(input_channel, output_channel, stride, t))
70 | input_channel = output_channel
71 |
72 | def forward(self, x: Tensor) -> Tensor:
73 | outs = []
74 | for i, m in enumerate(self.features):
75 | x = m(x)
76 | if i in self.out_indices:
77 | outs.append(x)
78 | return outs
79 |
80 |
81 | if __name__ == '__main__':
82 | model = MobileNetV2()
83 | # model.load_state_dict(torch.load('checkpoints/backbones/mobilenet_v2.pth', map_location='cpu'), strict=False)
84 | model.eval()
85 | x = torch.randn(1, 3, 224, 224)
86 | # outs = model(x)
87 | # for y in outs:
88 | # print(y.shape)
89 |
90 | from fvcore.nn import flop_count_table, FlopCountAnalysis
91 | print(flop_count_table(FlopCountAnalysis(model, x)))
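92 | 
93 |     # out_indices [3, 6, 13, 17] pick the last block of the stride-4/8/16/32 stages,
94 |     # i.e. feature maps with 24, 32, 96 and 320 channels (matching self.channels);
95 |     # for a 224x224 input that is 56/28/14/7 spatial resolution.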
--------------------------------------------------------------------------------
/semseg/models/backbones/mobilenetv3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 | from typing import Optional
5 |
6 |
7 | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
8 | """
9 | This function is taken from the original tf repo.
10 | It ensures that all layers have a channel number that is divisible by 8
11 | It can be seen here:
12 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
13 | """
14 | if min_value is None:
15 | min_value = divisor
16 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
17 | # Make sure that round down does not go down by more than 10%.
18 | if new_v < 0.9 * v:
19 | new_v += divisor
20 | return new_v
21 |
22 |
23 | class ConvModule(nn.Sequential):
24 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1):
25 | super().__init__(
26 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False),
27 | nn.BatchNorm2d(c2),
28 | nn.ReLU6(True)
29 | )
30 |
31 |
32 | class SqueezeExcitation(nn.Module):
33 | def __init__(self, ch, squeeze_factor=4):
34 | super().__init__()
35 | squeeze_ch = _make_divisible(ch // squeeze_factor, 8)
36 | self.fc1 = nn.Conv2d(ch, squeeze_ch, 1)
37 | self.relu = nn.ReLU(True)
38 | self.fc2 = nn.Conv2d(squeeze_ch, ch, 1)
39 |
40 | def _scale(self, x: Tensor) -> Tensor:
41 | scale = F.adaptive_avg_pool2d(x, 1)
42 | scale = self.fc2(self.relu(self.fc1(scale)))
43 | return F.hardsigmoid(scale, True)
44 |
45 | def forward(self, x: Tensor) -> Tensor:
46 | scale = self._scale(x)
47 | return scale * x
48 |
49 |
50 | class InvertedResidualConfig:
51 | def __init__(self, c1, c2, k, expanded_ch, use_se) -> None:
52 | pass
53 |
54 |
55 | class InvertedResidual(nn.Module):
56 | def __init__(self, c1, c2, s, expand_ratio):
57 | super().__init__()
58 | ch = int(round(c1 * expand_ratio))
59 | self.use_res_connect = s == 1 and c1 == c2
60 |
61 | layers = []
62 |
63 | if expand_ratio != 1:
64 | layers.append(ConvModule(c1, ch, 1))
65 |
66 | layers.extend([
67 | ConvModule(ch, ch, 3, s, 1, g=ch),
68 | nn.Conv2d(ch, c2, 1, bias=False),
69 | nn.BatchNorm2d(c2)
70 | ])
71 |
72 | self.conv = nn.Sequential(*layers)
73 |
74 | def forward(self, x: Tensor) -> Tensor:
75 | if self.use_res_connect:
76 | return x + self.conv(x)
77 | else:
78 | return self.conv(x)
79 |
80 |
81 | mobilenetv3_settings = {
82 | 'S': [],
83 | 'L': []
84 | }
85 |
86 |
87 | class MobileNetV3(nn.Module):
88 | def __init__(self, variant: str = None):
89 | super().__init__()
90 | self.out_indices = [3, 6, 13, 17]
91 | self.channels = [24, 32, 96, 320]
92 | input_channel = 32
93 |
94 | inverted_residual_setting = [
95 | # t, c, n, s
96 | [1, 16, 1, 1],
97 | [6, 24, 2, 2],
98 | [6, 32, 3, 2],
99 | [6, 64, 4, 2],
100 | [6, 96, 3, 1],
101 | [6, 160, 3, 2],
102 | [6, 320, 1, 1],
103 | ]
104 |
105 | self.features = nn.ModuleList([ConvModule(3, input_channel, 3, 2, 1)])
106 |
107 | for t, c, n, s in inverted_residual_setting:
108 | output_channel = c
109 | for i in range(n):
110 | stride = s if i == 0 else 1
111 | self.features.append(InvertedResidual(input_channel, output_channel, stride, t))
112 | input_channel = output_channel
113 |
114 | def forward(self, x: Tensor) -> Tensor:
115 | outs = []
116 | for i, m in enumerate(self.features):
117 | x = m(x)
118 | if i in self.out_indices:
119 | outs.append(x)
120 | return outs
121 |
122 |
123 | if __name__ == '__main__':
124 | model = MobileNetV3()
125 | # model.load_state_dict(torch.load('checkpoints/backbones/mobilenet_v2.pth', map_location='cpu'), strict=False)
126 | model.eval()
127 | x = torch.randn(1, 3, 224, 224)
128 | # outs = model(x)
129 | # for y in outs:
130 | # print(y.shape)
131 |
132 | from fvcore.nn import flop_count_table, FlopCountAnalysis
133 | print(flop_count_table(FlopCountAnalysis(model, x)))
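134 | 
135 | # Note: the feature stack above currently mirrors MobileNetV2 -- SqueezeExcitation,
136 | # InvertedResidualConfig and the 'S'/'L' settings are defined but not yet wired in.
137 | # Worked example for _make_divisible with divisor=8: v=17.2 -> int(17.2 + 4) // 8 * 8 = 16
138 | # and 16 >= 0.9 * 17.2, so 16 is returned; v=9 -> 8 < 0.9 * 9, so the guard bumps it to 16.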
--------------------------------------------------------------------------------
/semseg/models/backbones/poolformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from semseg.models.layers import DropPath
4 |
5 |
6 | class PatchEmbed(nn.Module):
7 | """Image to Patch Embedding with overlapping
8 | """
9 | def __init__(self, patch_size=16, stride=16, padding=0, in_ch=3, embed_dim=768):
10 | super().__init__()
11 | self.proj = nn.Conv2d(in_ch, embed_dim, patch_size, stride, padding)
12 |
13 | def forward(self, x: torch.Tensor) -> Tensor:
14 | x = self.proj(x) # b x hidden_dim x 14 x 14
15 | return x
16 |
17 |
18 | class Pooling(nn.Module):
19 | def __init__(self, pool_size=3) -> None:
20 | super().__init__()
21 | self.pool = nn.AvgPool2d(pool_size, 1, pool_size//2, count_include_pad=False)
22 |
23 | def forward(self, x: Tensor) -> Tensor:
24 | return self.pool(x) - x
25 |
26 |
27 | class MLP(nn.Module):
28 | def __init__(self, dim, hidden_dim, out_dim=None) -> None:
29 | super().__init__()
30 | out_dim = out_dim or dim
31 | self.fc1 = nn.Conv2d(dim, hidden_dim, 1)
32 | self.act = nn.GELU()
33 | self.fc2 = nn.Conv2d(hidden_dim, out_dim, 1)
34 |
35 | def forward(self, x: Tensor) -> Tensor:
36 | return self.fc2(self.act(self.fc1(x)))
37 |
38 |
39 | class PoolFormerBlock(nn.Module):
40 | def __init__(self, dim, pool_size=3, dpr=0., layer_scale_init_value=1e-5):
41 | super().__init__()
42 | self.norm1 = nn.GroupNorm(1, dim)
43 | self.token_mixer = Pooling(pool_size)
44 | self.norm2 = nn.GroupNorm(1, dim)
45 | self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()
46 | self.mlp = MLP(dim, int(dim*4))
47 |
48 | self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
49 | self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
50 |
51 | def forward(self, x: Tensor) -> Tensor:
52 | x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.token_mixer(self.norm1(x)))
53 | x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
54 | return x
55 |
56 | poolformer_settings = {
57 | 'S24': [[4, 4, 12, 4], [64, 128, 320, 512], 0.1], # [layers, embed_dims, drop_path_rate]
58 | 'S36': [[6, 6, 18, 6], [64, 128, 320, 512], 0.2],
59 | 'M36': [[6, 6, 18, 6], [96, 192, 384, 768], 0.3]
60 | }
61 |
62 |
63 | class PoolFormer(nn.Module):
64 | def __init__(self, model_name: str = 'S24') -> None:
65 | super().__init__()
66 | assert model_name in poolformer_settings.keys(), f"PoolFormer model name should be in {list(poolformer_settings.keys())}"
67 | layers, embed_dims, drop_path_rate = poolformer_settings[model_name]
68 | self.channels = embed_dims
69 |
70 | self.patch_embed = PatchEmbed(7, 4, 2, 3, embed_dims[0])
71 |
72 | network = []
73 |
74 | for i in range(len(layers)):
75 | blocks = []
76 | for j in range(layers[i]):
77 | dpr = drop_path_rate * (j + sum(layers[:i])) / (sum(layers) - 1)
78 | blocks.append(PoolFormerBlock(embed_dims[i], 3, dpr))
79 |
80 | network.append(nn.Sequential(*blocks))
81 | if i >= len(layers) - 1: break
82 | network.append(PatchEmbed(3, 2, 1, embed_dims[i], embed_dims[i+1]))
83 |
84 | self.network = nn.ModuleList(network)
85 |
86 | self.out_indices = [0, 2, 4, 6]
87 | for i, index in enumerate(self.out_indices):
88 | self.add_module(f"norm{index}", nn.GroupNorm(1, embed_dims[i]))
89 |
90 | def forward(self, x: Tensor):
91 | x = self.patch_embed(x)
92 | outs = []
93 |
94 | for i, blk in enumerate(self.network):
95 | x = blk(x)
96 |
97 | if i in self.out_indices:
98 | out = getattr(self, f"norm{i}")(x)
99 | outs.append(out)
100 | return outs
101 |
102 |
103 | if __name__ == '__main__':
104 | model = PoolFormer('S24')
105 | model.load_state_dict(torch.load('C:\\Users\\sithu\\Documents\\weights\\backbones\\poolformer\\poolformer_s24.pth.tar', map_location='cpu'), strict=False)
106 | x = torch.randn(1, 3, 224, 224)
107 | feats = model(x)
108 | for y in feats:
109 | print(y.shape)
110 |
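111 | # self.network alternates stage blocks and downsampling patch embeds
112 | # (0: stage1, 1: embed, 2: stage2, 3: embed, 4: stage3, 5: embed, 6: stage4),
113 | # so out_indices [0, 2, 4, 6] collect one GroupNorm-normalized map per stage.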
--------------------------------------------------------------------------------
/semseg/models/backbones/pvt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 | from semseg.models.layers import DropPath
5 |
6 |
7 | class DWConv(nn.Module):
8 | def __init__(self, dim):
9 | super().__init__()
10 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
11 |
12 | def forward(self, x: Tensor, H: int, W: int) -> Tensor:
13 | B, _, C = x.shape
14 | x = x.transpose(1, 2).view(B, C, H, W)
15 | x = self.dwconv(x)
16 | return x.flatten(2).transpose(1, 2)
17 |
18 |
19 | class MLP(nn.Module):
20 | def __init__(self, dim, hidden_dim, out_dim=None) -> None:
21 | super().__init__()
22 | out_dim = out_dim or dim
23 | self.fc1 = nn.Linear(dim, hidden_dim)
24 | self.fc2 = nn.Linear(hidden_dim, out_dim)
25 | self.dwconv = DWConv(hidden_dim)
26 |
27 | def forward(self, x: Tensor, H: int, W: int) -> Tensor:
28 | return self.fc2(F.gelu(self.dwconv(self.fc1(x), H, W)))
29 |
30 |
31 | class Attention(nn.Module):
32 | def __init__(self, dim, head, sr_ratio):
33 | super().__init__()
34 | self.head = head
35 | self.sr_ratio = sr_ratio
36 | self.scale = (dim // head) ** -0.5
37 | self.q = nn.Linear(dim, dim, bias=True)
38 | self.kv = nn.Linear(dim, dim*2, bias=True)
39 | self.proj = nn.Linear(dim, dim)
40 |
41 | if sr_ratio > 1:
42 | self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio)
43 | self.norm = nn.LayerNorm(dim)
44 |
45 | def forward(self, x: Tensor, H, W) -> Tensor:
46 | B, N, C = x.shape
47 | q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)
48 |
49 | if self.sr_ratio > 1:
50 | x = x.permute(0, 2, 1).reshape(B, C, H, W)
51 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
52 | x = self.norm(x)
53 | k, v = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)
54 |
55 | attn = (q @ k.transpose(-2, -1)) * self.scale
56 | attn = attn.softmax(dim=-1)
57 |
58 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
59 | x = self.proj(x)
60 | return x
61 |
62 |
63 | class Block(nn.Module):
64 | def __init__(self, dim, head, sr_ratio=1, mlp_ratio=4, dpr=0.):
65 | super().__init__()
66 | self.norm1 = nn.LayerNorm(dim)
67 | self.attn = Attention(dim, head, sr_ratio)
68 | self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()
69 | self.norm2 = nn.LayerNorm(dim)
70 | self.mlp = MLP(dim, int(dim*mlp_ratio))
71 |
72 | def forward(self, x: Tensor, H, W) -> Tensor:
73 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
74 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
75 | return x
76 |
77 |
78 | class PatchEmbed(nn.Module):
79 | def __init__(self, c1=3, c2=64, patch_size=7, stride=4):
80 | super().__init__()
81 | self.proj = nn.Conv2d(c1, c2, patch_size, stride, patch_size//2)
82 | self.norm = nn.LayerNorm(c2)
83 |
84 | def forward(self, x: Tensor) -> Tensor:
85 | x = self.proj(x)
86 | _, _, H, W = x.shape
87 | x = x.flatten(2).transpose(1, 2)
88 | x = self.norm(x)
89 | return x, H, W
90 |
91 |
92 | pvtv2_settings = {
93 | 'B1': [2, 2, 2, 2], # depths
94 | 'B2': [3, 4, 6, 3],
95 | 'B3': [3, 4, 18, 3],
96 | 'B4': [3, 8, 27, 3],
97 | 'B5': [3, 6, 40, 3]
98 | }
99 |
100 |
101 | class PVTv2(nn.Module):
102 | def __init__(self, model_name: str = 'B1') -> None:
103 | super().__init__()
104 | assert model_name in pvtv2_settings.keys(), f"PVTv2 model name should be in {list(pvtv2_settings.keys())}"
105 | depths = pvtv2_settings[model_name]
106 | embed_dims = [64, 128, 320, 512]
107 | drop_path_rate = 0.1
108 | self.channels = embed_dims
109 | # patch_embed
110 | self.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4)
111 | self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2)
112 | self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2)
113 | self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2)
114 |
115 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
116 | # transformer encoder
117 | cur = 0
118 | self.block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, 8, dpr[cur+i]) for i in range(depths[0])])
119 | self.norm1 = nn.LayerNorm(embed_dims[0])
120 |
121 | cur += depths[0]
122 | self.block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, 8, dpr[cur+i]) for i in range(depths[1])])
123 | self.norm2 = nn.LayerNorm(embed_dims[1])
124 |
125 | cur += depths[1]
126 | self.block3 = nn.ModuleList([Block(embed_dims[2], 5, 2, 4, dpr[cur+i]) for i in range(depths[2])])
127 | self.norm3 = nn.LayerNorm(embed_dims[2])
128 |
129 | cur += depths[2]
130 | self.block4 = nn.ModuleList([Block(embed_dims[3], 8, 1, 4, dpr[cur+i]) for i in range(depths[3])])
131 | self.norm4 = nn.LayerNorm(embed_dims[3])
132 |
133 | def forward(self, x: Tensor) -> Tensor:
134 | B = x.shape[0]
135 | # stage 1
136 | x, H, W = self.patch_embed1(x)
137 | for blk in self.block1:
138 | x = blk(x, H, W)
139 | x1 = self.norm1(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)
140 |
141 | # stage 2
142 | x, H, W = self.patch_embed2(x1)
143 | for blk in self.block2:
144 | x = blk(x, H, W)
145 | x2 = self.norm2(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)
146 |
147 | # stage 3
148 | x, H, W = self.patch_embed3(x2)
149 | for blk in self.block3:
150 | x = blk(x, H, W)
151 | x3 = self.norm3(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)
152 |
153 | # stage 4
154 | x, H, W = self.patch_embed4(x3)
155 | for blk in self.block4:
156 | x = blk(x, H, W)
157 | x4 = self.norm4(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)
158 |
159 | return x1, x2, x3, x4
160 |
161 |
162 | if __name__ == '__main__':
163 | model = PVTv2('B1')
164 | model.load_state_dict(torch.load('checkpoints/backbones/pvtv2/pvt_v2_b1.pth', map_location='cpu'), strict=False)
165 | x = torch.zeros(1, 3, 224, 224)
166 | outs = model(x)
167 | for y in outs:
168 | print(y.shape)
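169 | 
170 |     # PVTv2 mirrors the MiT stage layout (same per-stage heads and sr_ratios) with
171 |     # embed dims fixed at [64, 128, 320, 512]; the notable difference here is an MLP
172 |     # ratio of 8 in the first two stages versus MiT's constant 4. Output strides: 4/8/16/32.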
--------------------------------------------------------------------------------
/semseg/models/backbones/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 |
5 |
6 | class BasicBlock(nn.Module):
7 | """2 Layer No Expansion Block
8 | """
9 | expansion: int = 1
10 | def __init__(self, c1, c2, s=1, downsample= None) -> None:
11 | super().__init__()
12 | self.conv1 = nn.Conv2d(c1, c2, 3, s, 1, bias=False)
13 | self.bn1 = nn.BatchNorm2d(c2)
14 | self.conv2 = nn.Conv2d(c2, c2, 3, 1, 1, bias=False)
15 | self.bn2 = nn.BatchNorm2d(c2)
16 | self.downsample = downsample
17 |
18 | def forward(self, x: Tensor) -> Tensor:
19 | identity = x
20 | out = F.relu(self.bn1(self.conv1(x)))
21 | out = self.bn2(self.conv2(out))
22 | if self.downsample is not None: identity = self.downsample(x)
23 | out += identity
24 | return F.relu(out)
25 |
26 |
27 | class Bottleneck(nn.Module):
28 | """3 Layer 4x Expansion Block
29 | """
30 | expansion: int = 4
31 | def __init__(self, c1, c2, s=1, downsample=None) -> None:
32 | super().__init__()
33 | self.conv1 = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
34 | self.bn1 = nn.BatchNorm2d(c2)
35 | self.conv2 = nn.Conv2d(c2, c2, 3, s, 1, bias=False)
36 | self.bn2 = nn.BatchNorm2d(c2)
37 | self.conv3 = nn.Conv2d(c2, c2 * self.expansion, 1, 1, 0, bias=False)
38 | self.bn3 = nn.BatchNorm2d(c2 * self.expansion)
39 | self.downsample = downsample
40 |
41 | def forward(self, x: Tensor) -> Tensor:
42 | identity = x
43 | out = F.relu(self.bn1(self.conv1(x)))
44 | out = F.relu(self.bn2(self.conv2(out)))
45 | out = self.bn3(self.conv3(out))
46 | if self.downsample is not None: identity = self.downsample(x)
47 | out += identity
48 | return F.relu(out)
49 |
50 |
51 | resnet_settings = {
52 | '18': [BasicBlock, [2, 2, 2, 2], [64, 128, 256, 512]],
53 | '34': [BasicBlock, [3, 4, 6, 3], [64, 128, 256, 512]],
54 | '50': [Bottleneck, [3, 4, 6, 3], [256, 512, 1024, 2048]],
55 | '101': [Bottleneck, [3, 4, 23, 3], [256, 512, 1024, 2048]],
56 | '152': [Bottleneck, [3, 8, 36, 3], [256, 512, 1024, 2048]]
57 | }
58 |
59 |
60 | class ResNet(nn.Module):
61 | def __init__(self, model_name: str = '50') -> None:
62 | super().__init__()
63 | assert model_name in resnet_settings.keys(), f"ResNet model name should be in {list(resnet_settings.keys())}"
64 | block, depths, channels = resnet_settings[model_name]
65 |
66 | self.inplanes = 64
67 | self.channels = channels
68 | self.conv1 = nn.Conv2d(3, self.inplanes, 7, 2, 3, bias=False)
69 | self.bn1 = nn.BatchNorm2d(self.inplanes)
70 | self.maxpool = nn.MaxPool2d(3, 2, 1)
71 |
72 | self.layer1 = self._make_layer(block, 64, depths[0], s=1)
73 | self.layer2 = self._make_layer(block, 128, depths[1], s=2)
74 | self.layer3 = self._make_layer(block, 256, depths[2], s=2)
75 | self.layer4 = self._make_layer(block, 512, depths[3], s=2)
76 |
77 |
78 | def _make_layer(self, block, planes, depth, s=1) -> nn.Sequential:
79 | downsample = None
80 | if s != 1 or self.inplanes != planes * block.expansion:
81 | downsample = nn.Sequential(
82 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, s, bias=False),
83 | nn.BatchNorm2d(planes * block.expansion)
84 | )
85 | layers = nn.Sequential(
86 | block(self.inplanes, planes, s, downsample),
87 | *[block(planes * block.expansion, planes) for _ in range(1, depth)]
88 | )
89 | self.inplanes = planes * block.expansion
90 | return layers
91 |
92 |
93 | def forward(self, x: Tensor) -> Tensor:
94 | x = self.maxpool(F.relu(self.bn1(self.conv1(x)))) # [1, 64, H/4, W/4]
95 | x1 = self.layer1(x) # [1, 64/256, H/4, W/4]
96 | x2 = self.layer2(x1) # [1, 128/512, H/8, W/8]
97 | x3 = self.layer3(x2) # [1, 256/1024, H/16, W/16]
98 | x4 = self.layer4(x3) # [1, 512/2048, H/32, W/32]
99 | return x1, x2, x3, x4
100 |
101 |
102 | if __name__ == '__main__':
103 | model = ResNet('18')
104 | # model.load_state_dict(torch.load('C:\\Users\\sithu\\Documents\\weights\\backbones\\resnet\\resnet18_a1.pth', map_location='cpu'), strict=False)
105 | x = torch.zeros(1, 3, 224, 224)
106 | outs = model(x)
107 | for y in outs:
108 | print(y.shape)
--------------------------------------------------------------------------------
/semseg/models/backbones/resnetd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 |
5 |
6 | class BasicBlock(nn.Module):
7 | """2 Layer No Expansion Block
8 | """
9 | expansion: int = 1
10 | def __init__(self, c1, c2, s=1, d=1, downsample= None) -> None:
11 | super().__init__()
12 | self.conv1 = nn.Conv2d(c1, c2, 3, s, 1, bias=False)
13 | self.bn1 = nn.BatchNorm2d(c2)
14 |         self.conv2 = nn.Conv2d(c2, c2, 3, 1, d, d, bias=False)  # padding = dilation keeps the 3x3 output size
15 | self.bn2 = nn.BatchNorm2d(c2)
16 | self.downsample = downsample
17 |
18 | def forward(self, x: Tensor) -> Tensor:
19 | identity = x
20 | out = F.relu(self.bn1(self.conv1(x)))
21 | out = self.bn2(self.conv2(out))
22 | if self.downsample is not None: identity = self.downsample(x)
23 | out += identity
24 | return F.relu(out)
25 |
26 |
27 | class Bottleneck(nn.Module):
28 | """3 Layer 4x Expansion Block
29 | """
30 | expansion: int = 4
31 | def __init__(self, c1, c2, s=1, d=1, downsample=None) -> None:
32 | super().__init__()
33 | self.conv1 = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
34 | self.bn1 = nn.BatchNorm2d(c2)
35 |         self.conv2 = nn.Conv2d(c2, c2, 3, s, d, d, bias=False)  # padding = dilation for the dilated 3x3 conv
36 | self.bn2 = nn.BatchNorm2d(c2)
37 | self.conv3 = nn.Conv2d(c2, c2 * self.expansion, 1, 1, 0, bias=False)
38 | self.bn3 = nn.BatchNorm2d(c2 * self.expansion)
39 | self.downsample = downsample
40 |
41 | def forward(self, x: Tensor) -> Tensor:
42 | identity = x
43 | out = F.relu(self.bn1(self.conv1(x)))
44 | out = F.relu(self.bn2(self.conv2(out)))
45 | out = self.bn3(self.conv3(out))
46 | if self.downsample is not None: identity = self.downsample(x)
47 | out += identity
48 | return F.relu(out)
49 |
50 |
51 | class Stem(nn.Sequential):
52 | def __init__(self, c1, ch, c2):
53 | super().__init__(
54 | nn.Conv2d(c1, ch, 3, 2, 1, bias=False),
55 | nn.BatchNorm2d(ch),
56 | nn.ReLU(True),
57 | nn.Conv2d(ch, ch, 3, 1, 1, bias=False),
58 | nn.BatchNorm2d(ch),
59 | nn.ReLU(True),
60 | nn.Conv2d(ch, c2, 3, 1, 1, bias=False),
61 | nn.BatchNorm2d(c2),
62 | nn.ReLU(True),
63 | nn.MaxPool2d(3, 2, 1)
64 | )
65 |
66 |
67 | resnetd_settings = {
68 | '18': [BasicBlock, [2, 2, 2, 2], [64, 128, 256, 512]],
69 | '50': [Bottleneck, [3, 4, 6, 3], [256, 512, 1024, 2048]],
70 | '101': [Bottleneck, [3, 4, 23, 3], [256, 512, 1024, 2048]]
71 | }
72 |
73 |
74 | class ResNetD(nn.Module):
75 | def __init__(self, model_name: str = '50') -> None:
76 | super().__init__()
77 | assert model_name in resnetd_settings.keys(), f"ResNetD model name should be in {list(resnetd_settings.keys())}"
78 | block, depths, channels = resnetd_settings[model_name]
79 |
80 | self.inplanes = 128
81 | self.channels = channels
82 | self.stem = Stem(3, 64, self.inplanes)
83 | self.layer1 = self._make_layer(block, 64, depths[0], s=1)
84 | self.layer2 = self._make_layer(block, 128, depths[1], s=2)
85 | self.layer3 = self._make_layer(block, 256, depths[2], s=2, d=2)
86 | self.layer4 = self._make_layer(block, 512, depths[3], s=2, d=4)
87 |
88 |
89 | def _make_layer(self, block, planes, depth, s=1, d=1) -> nn.Sequential:
90 | downsample = None
91 |
92 | if s != 1 or self.inplanes != planes * block.expansion:
93 | downsample = nn.Sequential(
94 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, s, bias=False),
95 | nn.BatchNorm2d(planes * block.expansion)
96 | )
97 | layers = nn.Sequential(
98 | block(self.inplanes, planes, s, d, downsample=downsample),
99 | *[block(planes * block.expansion, planes, d=d) for _ in range(1, depth)]
100 | )
101 | self.inplanes = planes * block.expansion
102 | return layers
103 |
104 |
105 | def forward(self, x: Tensor) -> Tensor:
106 | x = self.stem(x) # [1, 128, H/4, W/4]
107 | x1 = self.layer1(x) # [1, 64/256, H/4, W/4]
108 | x2 = self.layer2(x1) # [1, 128/512, H/8, W/8]
109 | x3 = self.layer3(x2) # [1, 256/1024, H/16, W/16]
110 | x4 = self.layer4(x3) # [1, 512/2048, H/32, W/32]
111 | return x1, x2, x3, x4
112 |
113 |
114 | if __name__ == '__main__':
115 | model = ResNetD('18')
116 | model.load_state_dict(torch.load('checkpoints/backbones/resnetd/resnetd18.pth', map_location='cpu'), strict=False)
117 | x = torch.zeros(1, 3, 224, 224)
118 | outs = model(x)
119 | for y in outs:
120 | print(y.shape)
--------------------------------------------------------------------------------
/semseg/models/backbones/uniformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from semseg.models.layers import DropPath
4 |
5 |
6 | class MLP(nn.Module):
7 | def __init__(self, dim, hidden_dim, out_dim=None) -> None:
8 | super().__init__()
9 | out_dim = out_dim or dim
10 | self.fc1 = nn.Linear(dim, hidden_dim)
11 | self.act = nn.GELU()
12 | self.fc2 = nn.Linear(hidden_dim, out_dim)
13 |
14 | def forward(self, x: Tensor) -> Tensor:
15 | return self.fc2(self.act(self.fc1(x)))
16 |
17 |
18 | class CMLP(nn.Module):
19 | def __init__(self, dim, hidden_dim, out_dim=None) -> None:
20 | super().__init__()
21 | out_dim = out_dim or dim
22 | self.fc1 = nn.Conv2d(dim, hidden_dim, 1)
23 | self.act = nn.GELU()
24 | self.fc2 = nn.Conv2d(hidden_dim, out_dim, 1)
25 |
26 | def forward(self, x: Tensor) -> Tensor:
27 | return self.fc2(self.act(self.fc1(x)))
28 |
29 |
30 | class Attention(nn.Module):
31 | def __init__(self, dim, num_heads=8) -> None:
32 | super().__init__()
33 | self.num_heads = num_heads
34 | self.scale = (dim // num_heads) ** -0.5
35 | self.qkv = nn.Linear(dim, dim*3)
36 | self.proj = nn.Linear(dim, dim)
37 |
38 | def forward(self, x: Tensor) -> Tensor:
39 | B, N, C = x.shape
40 | q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, C//self.num_heads).permute(2, 0, 3, 1, 4)
41 | attn = (q @ k.transpose(-2, -1)) * self.scale
42 | attn = attn.softmax(dim=-1)
43 |
44 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
45 | x = self.proj(x)
46 | return x
47 |
48 |
49 | class CBlock(nn.Module):
50 | def __init__(self, dim, dpr=0.):
51 | super().__init__()
52 | self.pos_embed = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
53 | self.norm1 = nn.BatchNorm2d(dim)
54 | self.conv1 = nn.Conv2d(dim, dim, 1)
55 | self.conv2 = nn.Conv2d(dim, dim, 1)
56 | self.attn = nn.Conv2d(dim, dim, 5, 1, 2, groups=dim)
57 | self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()
58 | self.norm2 = nn.BatchNorm2d(dim)
59 | self.mlp = CMLP(dim, int(dim*4))
60 |
61 | def forward(self, x: Tensor) -> Tensor:
62 | x = x + self.pos_embed(x)
63 | x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))
64 | x = x + self.drop_path(self.mlp(self.norm2(x)))
65 | return x
66 |
67 |
68 | class SABlock(nn.Module):
69 | def __init__(self, dim, num_heads, dpr=0.) -> None:
70 | super().__init__()
71 | self.pos_embed = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
72 | self.norm1 = nn.LayerNorm(dim)
73 | self.attn = Attention(dim, num_heads)
74 | self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()
75 | self.norm2 = nn.LayerNorm(dim)
76 | self.mlp = MLP(dim, int(dim*4))
77 |
78 | def forward(self, x: Tensor) -> Tensor:
79 | x = x + self.pos_embed(x)
80 |         B, C, H, W = x.shape
81 |         x = x.flatten(2).transpose(1, 2)
82 |         x = x + self.drop_path(self.attn(self.norm1(x)))
83 |         x = x + self.drop_path(self.mlp(self.norm2(x)))
84 |         x = x.transpose(1, 2).reshape(B, C, H, W)
85 | return x
86 |
87 |
88 | class PatchEmbed(nn.Module):
89 | def __init__(self, patch_size=16, in_ch=3, embed_dim=768) -> None:
90 | super().__init__()
91 | self.norm = nn.LayerNorm(embed_dim)
92 | self.proj = nn.Conv2d(in_ch, embed_dim, patch_size, patch_size)
93 |
94 | def forward(self, x: Tensor) -> Tensor:
95 | x = self.proj(x)
96 | B, C, H, W = x.shape
97 | x = x.flatten(2).transpose(1, 2)
98 | x = self.norm(x)
99 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
100 | return x
101 |
102 |
103 | uniformer_settings = {
104 | 'S': [3, 4, 8, 3], # [depth]
105 | 'B': [5, 8, 20, 7]
106 | }
107 |
108 |
109 | class UniFormer(nn.Module):
110 | def __init__(self, model_name: str = 'S') -> None:
111 | super().__init__()
112 | assert model_name in uniformer_settings.keys(), f"UniFormer model name should be in {list(uniformer_settings.keys())}"
113 | depth = uniformer_settings[model_name]
114 |
115 | head_dim = 64
116 | drop_path_rate = 0.
117 | embed_dims = [64, 128, 320, 512]
118 |
119 | for i in range(4):
120 | self.add_module(f"patch_embed{i+1}", PatchEmbed(4 if i == 0 else 2, 3 if i == 0 else embed_dims[i-1], embed_dims[i]))
121 | self.add_module(f"norm{i+1}", nn.LayerNorm(embed_dims[i]))
122 |
123 | self.pos_drop = nn.Dropout(0.)
124 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]
125 | num_heads = [dim // head_dim for dim in embed_dims]
126 |
127 | self.blocks1 = nn.ModuleList([
128 | CBlock(embed_dims[0], dpr[i])
129 | for i in range(depth[0])])
130 |
131 | self.blocks2 = nn.ModuleList([
132 | CBlock(embed_dims[1], dpr[i+depth[0]])
133 | for i in range(depth[1])])
134 |
135 | self.blocks3 = nn.ModuleList([
136 | SABlock(embed_dims[2], num_heads[2], dpr[i+depth[0]+depth[1]])
137 | for i in range(depth[2])])
138 |
139 | self.blocks4 = nn.ModuleList([
140 | SABlock(embed_dims[3], num_heads[3], dpr[i+depth[0]+depth[1]+depth[2]])
141 | for i in range(depth[3])])
142 |
143 |
144 | def forward(self, x: torch.Tensor):
145 | outs = []
146 |
147 | x = self.patch_embed1(x)
148 | x = self.pos_drop(x)
149 | for blk in self.blocks1:
150 | x = blk(x)
151 | x_out = self.norm1(x.permute(0, 2, 3, 1))
152 | outs.append(x_out.permute(0, 3, 1, 2))
153 |
154 | x = self.patch_embed2(x)
155 | for blk in self.blocks2:
156 | x = blk(x)
157 | x_out = self.norm2(x.permute(0, 2, 3, 1))
158 | outs.append(x_out.permute(0, 3, 1, 2))
159 |
160 | x = self.patch_embed3(x)
161 | for blk in self.blocks3:
162 | x = blk(x)
163 | x_out = self.norm3(x.permute(0, 2, 3, 1))
164 | outs.append(x_out.permute(0, 3, 1, 2))
165 |
166 | x = self.patch_embed4(x)
167 | for blk in self.blocks4:
168 | x = blk(x)
169 | x_out = self.norm4(x.permute(0, 2, 3, 1))
170 | outs.append(x_out.permute(0, 3, 1, 2))
171 |
172 | return outs
173 |
174 | if __name__ == '__main__':
175 | model = UniFormer('S')
176 | model.load_state_dict(torch.load('C:\\Users\\sithu\\Documents\\weights\\backbones\\uniformer\\uniformer_small_in1k.pth', map_location='cpu')['model'], strict=False)
177 | x = torch.randn(1, 3, 224, 224)
178 | feats = model(x)
179 | for y in feats:
180 | print(y.shape)
181 |
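182 | # UniFormer mixes tokens with convolutions (CBlock: 5x5 depthwise "attention") in
183 | # the first two stages and with global self-attention (SABlock) in the last two;
184 | # each stage starts with a non-overlapping PatchEmbed (stride 4, then 2, 2, 2).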
--------------------------------------------------------------------------------
/semseg/models/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch import nn
4 | from semseg.models.backbones import *
5 | from semseg.models.layers import trunc_normal_
6 |
7 |
8 | class BaseModel(nn.Module):
9 | def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19) -> None:
10 | super().__init__()
11 | backbone, variant = backbone.split('-')
12 | self.backbone = eval(backbone)(variant)
13 |
14 | def _init_weights(self, m: nn.Module) -> None:
15 | if isinstance(m, nn.Linear):
16 | trunc_normal_(m.weight, std=.02)
17 | if m.bias is not None:
18 | nn.init.zeros_(m.bias)
19 | elif isinstance(m, nn.Conv2d):
20 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
21 |             fan_out //= m.groups
22 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
23 | if m.bias is not None:
24 | nn.init.zeros_(m.bias)
25 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
26 | nn.init.ones_(m.weight)
27 | nn.init.zeros_(m.bias)
28 |
29 | def init_pretrained(self, pretrained: str = None) -> None:
30 | if pretrained:
31 | self.backbone.load_state_dict(torch.load(pretrained, map_location='cpu'), strict=False)
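32 | 
33 | 
34 | if __name__ == '__main__':
35 |     # Minimal sketch (illustrative): BaseModel resolves 'Backbone-Variant' strings
36 |     # such as 'ResNet-50', 'MiT-B0' or 'ResT-S'; subclasses (e.g. CustomCNN) attach
37 |     # a decode head and call self.apply(self._init_weights).
38 |     model = BaseModel('ResNet-18', 19)
39 |     print(model.backbone.channels)  # [64, 128, 256, 512] for ResNet-18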
--------------------------------------------------------------------------------
/semseg/models/custom_cnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import functional as F
4 | from semseg.models.base import BaseModel
5 | from semseg.models.heads import UPerHead
6 |
7 |
8 | class CustomCNN(BaseModel):
9 | def __init__(self, backbone: str = 'ResNet-50', num_classes: int = 19):
10 | super().__init__(backbone, num_classes)
11 | self.decode_head = UPerHead(self.backbone.channels, 256, num_classes)
12 | self.apply(self._init_weights)
13 |
14 | def forward(self, x: Tensor) -> Tensor:
15 | y = self.backbone(x)
16 | y = self.decode_head(y) # 4x reduction in image size
17 | y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False) # to original image shape
18 | return y
19 |
20 |
21 | if __name__ == '__main__':
22 | model = CustomCNN('ResNet-18', 19)
23 | model.init_pretrained('checkpoints/backbones/resnet/resnet18.pth')
24 | x = torch.randn(2, 3, 224, 224)
25 | y = model(x)
26 | print(y.shape)
--------------------------------------------------------------------------------
/semseg/models/custom_vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import functional as F
4 | from semseg.models.base import BaseModel
5 | from semseg.models.heads import UPerHead
6 |
7 |
8 | class CustomVIT(BaseModel):
9 | def __init__(self, backbone: str = 'ResT-S', num_classes: int = 19) -> None:
10 | super().__init__(backbone, num_classes)
11 | self.decode_head = UPerHead(self.backbone.channels, 128, num_classes)
12 | self.apply(self._init_weights)
13 |
14 | def forward(self, x: Tensor) -> Tensor:
15 | y = self.backbone(x)
16 | y = self.decode_head(y) # 4x reduction in image size
17 | y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False) # to original image shape
18 | return y
19 |
20 |
21 | if __name__ == '__main__':
22 | model = CustomVIT('ResT-S', 19)
23 | model.init_pretrained('checkpoints/backbones/rest/rest_small.pth')
24 | x = torch.zeros(2, 3, 512, 512)
25 | y = model(x)
26 | print(y.shape)
27 |
28 |
29 |
--------------------------------------------------------------------------------
/semseg/models/fchardnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 |
5 |
6 | class ConvModule(nn.Module):
7 | def __init__(self, c1, c2, k=3, s=1):
8 | super().__init__()
9 | self.conv = nn.Conv2d(c1, c2, k, s, k//2, bias=False)
10 | self.norm = nn.BatchNorm2d(c2)
11 | self.relu = nn.ReLU6(True)
12 |
13 | def forward(self, x: Tensor) -> Tensor:
14 | return self.relu(self.norm(self.conv(x)))
15 |
16 |
17 | def get_link(layer, base_ch, growth_rate):
18 | if layer == 0:
19 | return base_ch, 0, []
20 |
21 | link = []
22 | out_channels = growth_rate
23 |
24 | for i in range(10):
25 | dv = 2 ** i
26 | if layer % dv == 0:
27 | link.append(layer - dv)
28 |
29 | if i > 0: out_channels *= 1.7
30 |
31 | out_channels = int((out_channels + 1) / 2) * 2
32 | in_channels = 0
33 |
34 | for i in link:
35 | ch, _, _ = get_link(i, base_ch, growth_rate)
36 | in_channels += ch
37 |
38 | return out_channels, in_channels, link
39 |
40 |
41 | class HarDBlock(nn.Module):
42 | def __init__(self, c1, growth_rate, n_layers):
43 | super().__init__()
44 | self.links = []
45 | layers = []
46 | self.out_channels = 0
47 |
48 | for i in range(n_layers):
49 | out_ch, in_ch, link = get_link(i+1, c1, growth_rate)
50 | self.links.append(link)
51 |
52 | layers.append(ConvModule(in_ch, out_ch))
53 |
54 | if (i % 2 == 0) or (i == n_layers - 1):
55 | self.out_channels += out_ch
56 |
57 | self.layers = nn.ModuleList(layers)
58 |
59 | def forward(self, x: Tensor) -> Tensor:
60 | layers = [x]
61 |
62 | for layer in range(len(self.layers)):
63 | link = self.links[layer]
64 | tin = []
65 |
66 | for i in link:
67 | tin.append(layers[i])
68 |
69 | if len(tin) > 1:
70 | x = torch.cat(tin, dim=1)
71 | else:
72 | x = tin[0]
73 |
74 | out = self.layers[layer](x)
75 | layers.append(out)
76 |
77 | t = len(layers)
78 | outs = []
79 | for i in range(t):
80 | if (i == t - 1) or (i % 2 == 1):
81 | outs.append(layers[i])
82 |
83 | out = torch.cat(outs, dim=1)
84 | return out
85 |
86 |
87 | class FCHarDNet(nn.Module):
88 | def __init__(self, backbone: str = None, num_classes: int = 19) -> None:
89 | super().__init__()
90 | first_ch, ch_list, gr, n_layers = [16, 24, 32, 48], [64, 96, 160, 224, 320], [10, 16, 18, 24, 32], [4, 4, 8, 8, 8]
91 |
92 | self.base = nn.ModuleList([])
93 |
94 | # stem
95 | self.base.append(ConvModule(3, first_ch[0], 3, 2))
96 | self.base.append(ConvModule(first_ch[0], first_ch[1], 3))
97 | self.base.append(ConvModule(first_ch[1], first_ch[2], 3, 2))
98 | self.base.append(ConvModule(first_ch[2], first_ch[3], 3))
99 |
100 | self.shortcut_layers = []
101 | skip_connection_channel_counts = []
102 | ch = first_ch[-1]
103 |
104 | for i in range(len(n_layers)):
105 | blk = HarDBlock(ch, gr[i], n_layers[i])
106 | ch = blk.out_channels
107 |
108 | skip_connection_channel_counts.append(ch)
109 | self.base.append(blk)
110 |
111 | if i < len(n_layers) - 1:
112 | self.shortcut_layers.append(len(self.base) - 1)
113 |
114 | self.base.append(ConvModule(ch, ch_list[i], k=1))
115 | ch = ch_list[i]
116 |
117 | if i < len(n_layers) - 1:
118 | self.base.append(nn.AvgPool2d(2, 2))
119 |
120 | prev_block_channels = ch
121 | self.n_blocks = len(n_layers) - 1
122 |
123 | self.denseBlocksUp = nn.ModuleList([])
124 | self.conv1x1_up = nn.ModuleList([])
125 |
126 | for i in range(self.n_blocks-1, -1, -1):
127 | cur_channels_count = prev_block_channels + skip_connection_channel_counts[i]
128 | blk = HarDBlock(cur_channels_count // 2, gr[i], n_layers[i])
129 | prev_block_channels = blk.out_channels
130 |
131 | self.conv1x1_up.append(ConvModule(cur_channels_count, cur_channels_count//2, 1))
132 | self.denseBlocksUp.append(blk)
133 |
134 | self.finalConv = nn.Conv2d(prev_block_channels, num_classes, 1, 1, 0)
135 |
136 | self.apply(self._init_weights)
137 |
138 | def _init_weights(self, m: nn.Module) -> None:
139 | if isinstance(m, nn.Conv2d):
140 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
141 | elif isinstance(m, nn.BatchNorm2d):
142 | nn.init.constant_(m.weight, 1)
143 | nn.init.constant_(m.bias, 0)
144 |
145 | def init_pretrained(self, pretrained: str = None) -> None:
146 | if pretrained:
147 | self.load_state_dict(torch.load(pretrained, map_location='cpu'), strict=False)
148 |
149 | def forward(self, x: Tensor) -> Tensor:
150 | H, W = x.shape[-2:]
151 | skip_connections = []
152 | for i, layer in enumerate(self.base):
153 | x = layer(x)
154 | if i in self.shortcut_layers:
155 | skip_connections.append(x)
156 |
157 | out = x
158 |
159 | for i in range(self.n_blocks):
160 | skip = skip_connections.pop()
161 | out = F.interpolate(out, size=skip.shape[-2:], mode='bilinear', align_corners=True)
162 | out = torch.cat([out, skip], dim=1)
163 | out = self.conv1x1_up[i](out)
164 | out = self.denseBlocksUp[i](out)
165 |
166 | out = self.finalConv(out)
167 | out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=True)
168 | return out
169 |
170 |
171 | if __name__ == '__main__':
172 | model = FCHarDNet()
173 | # model.init_pretrained('checkpoints/backbones/hardnet/hardnet_70.pth')
174 | # model.load_state_dict(torch.load('checkpoints/pretrained/hardnet/hardnet70_cityscapes.pth', map_location='cpu'))
175 | x = torch.zeros(1, 3, 224, 224)
176 | outs = model(x)
177 | print(outs.shape)
178 |
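179 | # get_link(l, ...) connects layer l to layers l - 2**i for every power of two that
180 | # divides l, scaling the layer width by ~1.7x per extra link; HarDBlock then
181 | # concatenates the odd-indexed layer outputs plus the final one as its output.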
--------------------------------------------------------------------------------
/semseg/models/heads/__init__.py:
--------------------------------------------------------------------------------
1 | from .upernet import UPerHead
2 | from .segformer import SegFormerHead
3 | from .sfnet import SFHead
4 | from .fpn import FPNHead
5 | from .fapn import FaPNHead
6 | from .fcn import FCNHead
7 | from .condnet import CondHead
8 | from .lawin import LawinHead
9 |
10 | __all__ = ['UPerHead', 'SegFormerHead', 'SFHead', 'FPNHead', 'FaPNHead', 'FCNHead', 'CondHead', 'LawinHead']
--------------------------------------------------------------------------------
/semseg/models/heads/condnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 | from semseg.models.layers import ConvModule
5 |
6 |
7 | class CondHead(nn.Module):
8 | def __init__(self, in_channel: int = 2048, channel: int = 512, num_classes: int = 19):
9 | super().__init__()
10 | self.num_classes = num_classes
11 | self.weight_num = channel * num_classes
12 | self.bias_num = num_classes
13 |
14 | self.conv = ConvModule(in_channel, channel, 1)
15 | self.dropout = nn.Dropout2d(0.1)
16 |
17 | self.guidance_project = nn.Conv2d(channel, num_classes, 1)
18 | self.filter_project = nn.Conv2d(channel*num_classes, self.weight_num + self.bias_num, 1, groups=num_classes)
19 |
20 | def forward(self, features) -> Tensor:
21 | x = self.dropout(self.conv(features[-1]))
22 | B, C, H, W = x.shape
23 | guidance_mask = self.guidance_project(x)
24 | cond_logit = guidance_mask
25 |
26 | key = x
27 | value = x
28 | guidance_mask = guidance_mask.softmax(dim=1).view(*guidance_mask.shape[:2], -1)
29 | key = key.view(B, C, -1).permute(0, 2, 1)
30 |
31 | cond_filters = torch.matmul(guidance_mask, key)
32 | cond_filters /= H * W
33 | cond_filters = cond_filters.view(B, -1, 1, 1)
34 | cond_filters = self.filter_project(cond_filters)
35 | cond_filters = cond_filters.view(B, -1)
36 |
37 | weight, bias = torch.split(cond_filters, [self.weight_num, self.bias_num], dim=1)
38 | weight = weight.reshape(B * self.num_classes, -1, 1, 1)
39 | bias = bias.reshape(B * self.num_classes)
40 |
41 | value = value.view(-1, H, W).unsqueeze(0)
42 | seg_logit = F.conv2d(value, weight, bias, 1, 0, groups=B).view(B, self.num_classes, H, W)
43 |
44 | if self.training:
45 | return cond_logit, seg_logit
46 | return seg_logit
47 |
48 |
49 | if __name__ == '__main__':
50 | from semseg.models.backbones import ResNetD
51 | backbone = ResNetD('50')
52 | head = CondHead()
53 | x = torch.randn(2, 3, 224, 224)
54 | features = backbone(x)
55 | outs = head(features)
56 | for out in outs:
57 | out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False)
58 | print(out.shape)
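59 | 
60 |     # CondHead in brief: guidance_project predicts a per-class spatial mask, the
61 |     # softmaxed mask pools the features into one C-dim descriptor per class,
62 |     # filter_project turns these into per-image 1x1 classifier weights/biases, and the
63 |     # grouped F.conv2d applies each image's dynamic classifier to its own feature map.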
--------------------------------------------------------------------------------
/semseg/models/heads/fapn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 | from torchvision.ops import DeformConv2d
5 | from semseg.models.layers import ConvModule
6 |
7 |
8 | class DCNv2(nn.Module):
9 | def __init__(self, c1, c2, k, s, p, g=1):
10 | super().__init__()
11 | self.dcn = DeformConv2d(c1, c2, k, s, p, groups=g)
12 | self.offset_mask = nn.Conv2d(c2, g* 3 * k * k, k, s, p)
13 | self._init_offset()
14 |
15 | def _init_offset(self):
16 | self.offset_mask.weight.data.zero_()
17 | self.offset_mask.bias.data.zero_()
18 |
19 | def forward(self, x, offset):
20 | out = self.offset_mask(offset)
21 | o1, o2, mask = torch.chunk(out, 3, dim=1)
22 | offset = torch.cat([o1, o2], dim=1)
23 | mask = mask.sigmoid()
24 | return self.dcn(x, offset, mask)
25 |
26 |
27 | class FSM(nn.Module):
28 | def __init__(self, c1, c2):
29 | super().__init__()
30 | self.conv_atten = nn.Conv2d(c1, c1, 1, bias=False)
31 | self.conv = nn.Conv2d(c1, c2, 1, bias=False)
32 |
33 | def forward(self, x: Tensor) -> Tensor:
34 | atten = self.conv_atten(F.avg_pool2d(x, x.shape[2:])).sigmoid()
35 | feat = torch.mul(x, atten)
36 | x = x + feat
37 | return self.conv(x)
38 |
39 |
40 | class FAM(nn.Module):
41 | def __init__(self, c1, c2):
42 | super().__init__()
43 | self.lateral_conv = FSM(c1, c2)
44 | self.offset = nn.Conv2d(c2*2, c2, 1, bias=False)
45 | self.dcpack_l2 = DCNv2(c2, c2, 3, 1, 1, 8)
46 |
47 | def forward(self, feat_l, feat_s):
48 | feat_up = feat_s
49 | if feat_l.shape[2:] != feat_s.shape[2:]:
50 | feat_up = F.interpolate(feat_s, size=feat_l.shape[2:], mode='bilinear', align_corners=False)
51 |
52 | feat_arm = self.lateral_conv(feat_l)
53 | offset = self.offset(torch.cat([feat_arm, feat_up*2], dim=1))
54 |
55 | feat_align = F.relu(self.dcpack_l2(feat_up, offset))
56 | return feat_align + feat_arm
57 |
58 |
59 | class FaPNHead(nn.Module):
60 | def __init__(self, in_channels, channel=128, num_classes=19):
61 | super().__init__()
62 | in_channels = in_channels[::-1]
63 | self.align_modules = nn.ModuleList([ConvModule(in_channels[0], channel, 1)])
64 | self.output_convs = nn.ModuleList([])
65 |
66 | for ch in in_channels[1:]:
67 | self.align_modules.append(FAM(ch, channel))
68 | self.output_convs.append(ConvModule(channel, channel, 3, 1, 1))
69 |
70 | self.conv_seg = nn.Conv2d(channel, num_classes, 1)
71 | self.dropout = nn.Dropout2d(0.1)
72 |
73 | def forward(self, features) -> Tensor:
74 | features = features[::-1]
75 | out = self.align_modules[0](features[0])
76 |
77 | for feat, align_module, output_conv in zip(features[1:], self.align_modules[1:], self.output_convs):
78 | out = align_module(feat, out)
79 | out = output_conv(out)
80 | out = self.conv_seg(self.dropout(out))
81 | return out
82 |
83 |
84 | if __name__ == '__main__':
85 | from semseg.models.backbones import ResNet
86 | backbone = ResNet('50')
87 | head = FaPNHead([256, 512, 1024, 2048], 128, 19)
88 | x = torch.randn(2, 3, 224, 224)
89 | features = backbone(x)
90 | out = head(features)
91 | out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False)
92 | print(out.shape)
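93 | 
94 |     # FaPN in brief: FSM re-weights each lateral feature with globally pooled channel
95 |     # attention, and FAM predicts offsets from the concatenated (lateral, upsampled)
96 |     # pair, aligning the upsampled map with modulated deformable conv (DCNv2) before
97 |     # the residual add.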
--------------------------------------------------------------------------------
/semseg/models/heads/fcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 | from semseg.models.layers import ConvModule
5 |
6 |
7 | class FCNHead(nn.Module):
8 | def __init__(self, c1, c2, num_classes: int = 19):
9 | super().__init__()
10 | self.conv = ConvModule(c1, c2, 1)
11 | self.cls = nn.Conv2d(c2, num_classes, 1)
12 |
13 | def forward(self, features) -> Tensor:
14 | x = self.conv(features[-1])
15 | x = self.cls(x)
16 | return x
17 |
18 |
19 | if __name__ == '__main__':
20 | from semseg.models.backbones import ResNet
21 | backbone = ResNet('50')
22 | head = FCNHead(2048, 256, 19)
23 | x = torch.randn(2, 3, 224, 224)
24 | features = backbone(x)
25 | out = head(features)
26 | out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False)
27 | print(out.shape)
28 |
--------------------------------------------------------------------------------
/semseg/models/heads/fpn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 | from semseg.models.layers import ConvModule
5 |
6 |
7 | class FPNHead(nn.Module):
8 | """Panoptic Feature Pyramid Networks
9 | https://arxiv.org/abs/1901.02446
10 | """
11 | def __init__(self, in_channels, channel=128, num_classes=19):
12 | super().__init__()
13 | self.lateral_convs = nn.ModuleList([])
14 | self.output_convs = nn.ModuleList([])
15 |
16 | for ch in in_channels[::-1]:
17 | self.lateral_convs.append(ConvModule(ch, channel, 1))
18 | self.output_convs.append(ConvModule(channel, channel, 3, 1, 1))
19 |
20 | self.conv_seg = nn.Conv2d(channel, num_classes, 1)
21 | self.dropout = nn.Dropout2d(0.1)
22 |
23 | def forward(self, features) -> Tensor:
24 | features = features[::-1]
25 | out = self.lateral_convs[0](features[0])
26 |
27 | for i in range(1, len(features)):
28 | out = F.interpolate(out, scale_factor=2.0, mode='nearest')
29 | out = out + self.lateral_convs[i](features[i])
30 | out = self.output_convs[i](out)
31 | out = self.conv_seg(self.dropout(out))
32 | return out
33 |
34 |
35 | if __name__ == '__main__':
36 | from semseg.models.backbones import ResNet
37 | backbone = ResNet('50')
38 | head = FPNHead([256, 512, 1024, 2048], 128, 19)
39 | x = torch.randn(2, 3, 224, 224)
40 | features = backbone(x)
41 | out = head(features)
42 | out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False)
43 | print(out.shape)
--------------------------------------------------------------------------------
/semseg/models/heads/segformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from typing import Tuple
4 | from torch.nn import functional as F
5 |
6 |
7 | class MLP(nn.Module):
8 | def __init__(self, dim, embed_dim):
9 | super().__init__()
10 | self.proj = nn.Linear(dim, embed_dim)
11 |
12 | def forward(self, x: Tensor) -> Tensor:
13 | x = x.flatten(2).transpose(1, 2)
14 | x = self.proj(x)
15 | return x
16 |
17 |
18 | class ConvModule(nn.Module):
19 | def __init__(self, c1, c2):
20 | super().__init__()
21 | self.conv = nn.Conv2d(c1, c2, 1, bias=False)
22 | self.bn = nn.BatchNorm2d(c2) # use SyncBN in original
23 | self.activate = nn.ReLU(True)
24 |
25 | def forward(self, x: Tensor) -> Tensor:
26 | return self.activate(self.bn(self.conv(x)))
27 |
28 |
29 | class SegFormerHead(nn.Module):
30 | def __init__(self, dims: list, embed_dim: int = 256, num_classes: int = 19):
31 | super().__init__()
32 | for i, dim in enumerate(dims):
33 | self.add_module(f"linear_c{i+1}", MLP(dim, embed_dim))
34 |
35 | self.linear_fuse = ConvModule(embed_dim*4, embed_dim)
36 | self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1)
37 | self.dropout = nn.Dropout2d(0.1)
38 |
39 | def forward(self, features: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tensor:
40 | B, _, H, W = features[0].shape
41 | outs = [self.linear_c1(features[0]).permute(0, 2, 1).reshape(B, -1, *features[0].shape[-2:])]
42 |
43 | for i, feature in enumerate(features[1:]):
44 |             cf = getattr(self, f"linear_c{i+2}")(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:])
45 | outs.append(F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False))
46 |
47 | seg = self.linear_fuse(torch.cat(outs[::-1], dim=1))
48 | seg = self.linear_pred(self.dropout(seg))
49 | return seg
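
A minimal usage sketch for SegFormerHead (not part of the file above), assuming four feature maps with MiT-B0-style channel dims [32, 64, 160, 256] at strides 4/8/16/32 of a 512x512 input:

import torch
from semseg.models.heads import SegFormerHead

# dummy multi-scale features at strides 4, 8, 16, 32 (channels as in MiT-B0)
feats = [torch.randn(2, c, s, s) for c, s in zip([32, 64, 160, 256], [128, 64, 32, 16])]
head = SegFormerHead([32, 64, 160, 256], embed_dim=256, num_classes=19)
out = head(feats)
print(out.shape)    # expected: [2, 19, 128, 128], i.e. stride-4 resolution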
--------------------------------------------------------------------------------
/semseg/models/heads/sfnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 | from semseg.models.layers import ConvModule
5 | from semseg.models.modules import PPM
6 |
7 |
8 | class AlignedModule(nn.Module):
9 | def __init__(self, c1, c2, k=3):
10 | super().__init__()
11 | self.down_h = nn.Conv2d(c1, c2, 1, bias=False)
12 | self.down_l = nn.Conv2d(c1, c2, 1, bias=False)
13 | self.flow_make = nn.Conv2d(c2 * 2, 2, k, 1, 1, bias=False)
14 |
15 | def forward(self, low_feature: Tensor, high_feature: Tensor) -> Tensor:
16 | high_feature_origin = high_feature
17 | H, W = low_feature.shape[-2:]
18 | low_feature = self.down_l(low_feature)
19 | high_feature = self.down_h(high_feature)
20 | high_feature = F.interpolate(high_feature, size=(H, W), mode='bilinear', align_corners=True)
21 | flow = self.flow_make(torch.cat([high_feature, low_feature], dim=1))
22 | high_feature = self.flow_warp(high_feature_origin, flow, (H, W))
23 | return high_feature
24 |
25 | def flow_warp(self, x: Tensor, flow: Tensor, size: tuple) -> Tensor:
26 | # norm = torch.tensor(size).reshape(1, 1, 1, -1)
27 | norm = torch.tensor([[[[*size]]]]).type_as(x).to(x.device)
28 | H = torch.linspace(-1.0, 1.0, size[0]).view(-1, 1).repeat(1, size[1])
29 | W = torch.linspace(-1.0, 1.0, size[1]).repeat(size[0], 1)
30 | grid = torch.cat((W.unsqueeze(2), H.unsqueeze(2)), dim=2)
31 | grid = grid.repeat(x.shape[0], 1, 1, 1).type_as(x).to(x.device)
32 | grid = grid + flow.permute(0, 2, 3, 1) / norm
33 | output = F.grid_sample(x, grid, align_corners=False)
34 | return output
35 |
36 |
37 | class SFHead(nn.Module):
38 | def __init__(self, in_channels, channel=256, num_classes=19, scales=(1, 2, 3, 6)):
39 | super().__init__()
40 | self.ppm = PPM(in_channels[-1], channel, scales)
41 |
42 | self.fpn_in = nn.ModuleList([])
43 | self.fpn_out = nn.ModuleList([])
44 | self.fpn_out_align = nn.ModuleList([])
45 |
46 | for in_ch in in_channels[:-1]:
47 | self.fpn_in.append(ConvModule(in_ch, channel, 1))
48 | self.fpn_out.append(ConvModule(channel, channel, 3, 1, 1))
49 | self.fpn_out_align.append(AlignedModule(channel, channel//2))
50 |
51 | self.bottleneck = ConvModule(len(in_channels) * channel, channel, 3, 1, 1)
52 | self.dropout = nn.Dropout2d(0.1)
53 | self.conv_seg = nn.Conv2d(channel, num_classes, 1)
54 |
55 | def forward(self, features: list) -> Tensor:
56 | f = self.ppm(features[-1])
57 | fpn_features = [f]
58 |
59 | for i in reversed(range(len(features) - 1)):
60 | feature = self.fpn_in[i](features[i])
61 | f = feature + self.fpn_out_align[i](feature, f)
62 | fpn_features.append(self.fpn_out[i](f))
63 |
64 | fpn_features.reverse()
65 |
66 | for i in range(1, len(fpn_features)):
67 | fpn_features[i] = F.interpolate(fpn_features[i], size=fpn_features[0].shape[-2:], mode='bilinear', align_corners=True)
68 |
69 | output = self.bottleneck(torch.cat(fpn_features, dim=1))
70 | output = self.conv_seg(self.dropout(output))
71 | return output
72 |
73 |
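A minimal usage sketch for SFHead (not part of the file above), assuming ResNet-18-style feature channels [64, 128, 256, 512] at strides 4/8/16/32 of a 224x224 input:

import torch
from semseg.models.heads import SFHead

# dummy multi-scale features at strides 4, 8, 16, 32
feats = [torch.randn(2, c, s, s) for c, s in zip([64, 128, 256, 512], [56, 28, 14, 7])]
head = SFHead([64, 128, 256, 512], channel=128, num_classes=19)
out = head(feats)
print(out.shape)    # expected: [2, 19, 56, 56], i.e. stride-4 resolution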
--------------------------------------------------------------------------------
/semseg/models/heads/upernet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 | from typing import Tuple
5 | from semseg.models.layers import ConvModule
6 | from semseg.models.modules import PPM
7 |
8 |
9 | class UPerHead(nn.Module):
10 | """Unified Perceptual Parsing for Scene Understanding
11 | https://arxiv.org/abs/1807.10221
12 | scales: Pooling scales used in PPM module applied on the last feature
13 | """
14 | def __init__(self, in_channels, channel=128, num_classes: int = 19, scales=(1, 2, 3, 6)):
15 | super().__init__()
16 | # PPM Module
17 | self.ppm = PPM(in_channels[-1], channel, scales)
18 |
19 | # FPN Module
20 | self.fpn_in = nn.ModuleList()
21 | self.fpn_out = nn.ModuleList()
22 |
23 | for in_ch in in_channels[:-1]: # skip the top layer
24 | self.fpn_in.append(ConvModule(in_ch, channel, 1))
25 | self.fpn_out.append(ConvModule(channel, channel, 3, 1, 1))
26 |
27 | self.bottleneck = ConvModule(len(in_channels)*channel, channel, 3, 1, 1)
28 | self.dropout = nn.Dropout2d(0.1)
29 | self.conv_seg = nn.Conv2d(channel, num_classes, 1)
30 |
31 |
32 | def forward(self, features: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tensor:
33 | f = self.ppm(features[-1])
34 | fpn_features = [f]
35 |
36 | for i in reversed(range(len(features)-1)):
37 | feature = self.fpn_in[i](features[i])
38 | f = feature + F.interpolate(f, size=feature.shape[-2:], mode='bilinear', align_corners=False)
39 | fpn_features.append(self.fpn_out[i](f))
40 |
41 | fpn_features.reverse()
42 | for i in range(1, len(features)):
43 | fpn_features[i] = F.interpolate(fpn_features[i], size=fpn_features[0].shape[-2:], mode='bilinear', align_corners=False)
44 |
45 | output = self.bottleneck(torch.cat(fpn_features, dim=1))
46 | output = self.conv_seg(self.dropout(output))
47 | return output
48 |
49 |
50 | if __name__ == '__main__':
51 | model = UPerHead([64, 128, 256, 512], 128)
52 | x1 = torch.randn(2, 64, 56, 56)
53 | x2 = torch.randn(2, 128, 28, 28)
54 | x3 = torch.randn(2, 256, 14, 14)
55 | x4 = torch.randn(2, 512, 7, 7)
56 | y = model([x1, x2, x3, x4])
57 | print(y.shape)
--------------------------------------------------------------------------------
/semseg/models/lawin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import functional as F
4 | from semseg.models.base import BaseModel
5 | from semseg.models.heads import LawinHead
6 |
7 |
8 | class Lawin(BaseModel):
9 | """
10 |     Notes: This implementation has more parameters and FLOPs than those reported in the paper.
11 |     The code and weights will be updated if the original author releases the full code.
12 | """
13 | def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19) -> None:
14 | super().__init__(backbone, num_classes)
15 | self.decode_head = LawinHead(self.backbone.channels, 256 if 'B0' in backbone else 512, num_classes)
16 | self.apply(self._init_weights)
17 |
18 | def forward(self, x: Tensor) -> Tensor:
19 | y = self.backbone(x)
20 | y = self.decode_head(y) # 4x reduction in image size
21 | y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False) # to original image shape
22 | return y
23 |
24 |
25 | if __name__ == '__main__':
26 | model = Lawin('MiT-B1')
27 | model.eval()
28 | x = torch.zeros(1, 3, 512, 512)
29 | y = model(x)
30 | print(y.shape)
31 | from fvcore.nn import flop_count_table, FlopCountAnalysis
32 | print(flop_count_table(FlopCountAnalysis(model, x)))
--------------------------------------------------------------------------------
/semseg/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import *
2 | from .initialize import *
--------------------------------------------------------------------------------
/semseg/models/layers/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 |
4 |
5 | class ConvModule(nn.Sequential):
6 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1):
7 | super().__init__(
8 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False),
9 | nn.BatchNorm2d(c2),
10 | nn.ReLU(True)
11 | )
12 |
13 |
14 | class DropPath(nn.Module):
15 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
16 | Copied from timm
17 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
18 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
19 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
20 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
21 | 'survival rate' as the argument.
22 | """
23 | def __init__(self, p: float = None):
24 | super().__init__()
25 | self.p = p
26 |
27 | def forward(self, x: Tensor) -> Tensor:
28 | if self.p == 0. or not self.training:
29 | return x
30 | kp = 1 - self.p
31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1)
32 | random_tensor = kp + torch.rand(shape, dtype=x.dtype, device=x.device)
33 | random_tensor.floor_() # binarize
34 | return x.div(kp) * random_tensor
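
A short sketch of how DropPath is typically applied to the residual branch of a block (the Block class below is a hypothetical example, not taken from this repository):

import torch
from torch import nn, Tensor
from semseg.models.layers import DropPath

class Block(nn.Module):
    """Hypothetical residual block: the branch output is dropped per sample during training."""
    def __init__(self, dim: int, drop_path: float = 0.1):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.drop_path = DropPath(drop_path)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.drop_path(self.fc(x))

block = Block(64).train()
y = block(torch.randn(8, 64))   # some samples skip the fc branch entirely
block.eval()
y = block(torch.randn(8, 64))   # at eval time DropPath is an identity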
--------------------------------------------------------------------------------
/semseg/models/layers/initialize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import warnings
4 | from torch import nn, Tensor
5 |
6 |
7 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
8 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
9 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
10 | def norm_cdf(x):
11 | # Computes standard normal cumulative distribution function
12 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
13 |
14 | if (mean < a - 2 * std) or (mean > b + 2 * std):
15 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
16 | "The distribution of values may be incorrect.",
17 | stacklevel=2)
18 |
19 | with torch.no_grad():
20 | # Values are generated by using a truncated uniform distribution and
21 | # then using the inverse CDF for the normal distribution.
22 | # Get upper and lower cdf values
23 | l = norm_cdf((a - mean) / std)
24 | u = norm_cdf((b - mean) / std)
25 |
26 | # Uniformly fill tensor with values from [l, u], then translate to
27 | # [2l-1, 2u-1].
28 | tensor.uniform_(2 * l - 1, 2 * u - 1)
29 |
30 | # Use inverse cdf transform for normal distribution to get truncated
31 | # standard normal
32 | tensor.erfinv_()
33 |
34 | # Transform to proper mean, std
35 | tensor.mul_(std * math.sqrt(2.))
36 | tensor.add_(mean)
37 |
38 | # Clamp to ensure it's in the proper range
39 | tensor.clamp_(min=a, max=b)
40 | return tensor
41 |
42 |
43 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
44 | # type: (Tensor, float, float, float, float) -> Tensor
45 | r"""Fills the input Tensor with values drawn from a truncated
46 | normal distribution. The values are effectively drawn from the
47 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
48 | with values outside :math:`[a, b]` redrawn until they are within
49 | the bounds. The method used for generating the random values works
50 | best when :math:`a \leq \text{mean} \leq b`.
51 | Args:
52 | tensor: an n-dimensional `torch.Tensor`
53 | mean: the mean of the normal distribution
54 | std: the standard deviation of the normal distribution
55 | a: the minimum cutoff value
56 | b: the maximum cutoff value
57 | Examples:
58 | >>> w = torch.empty(3, 5)
59 | >>> nn.init.trunc_normal_(w)
60 | """
61 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
62 |
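A brief sketch of how trunc_normal_ is typically wired into a module initializer (the _init_weights helper below is a hypothetical example, not taken from this repository):

from torch import nn
from semseg.models.layers import trunc_normal_

def _init_weights(m: nn.Module) -> None:
    # truncated-normal init for linear weights, zeros for biases
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
model.apply(_init_weights)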
--------------------------------------------------------------------------------
/semseg/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .ppm import PPM
2 | from .psa import PSAP, PSAS
3 |
4 | __all__ = ['PPM', 'PSAP', 'PSAS']
--------------------------------------------------------------------------------
/semseg/models/modules/ppm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 | from semseg.models.layers import ConvModule
5 |
6 |
7 | class PPM(nn.Module):
8 | """Pyramid Pooling Module in PSPNet
9 | """
10 | def __init__(self, c1, c2=128, scales=(1, 2, 3, 6)):
11 | super().__init__()
12 | self.stages = nn.ModuleList([
13 | nn.Sequential(
14 | nn.AdaptiveAvgPool2d(scale),
15 | ConvModule(c1, c2, 1)
16 | )
17 | for scale in scales])
18 |
19 | self.bottleneck = ConvModule(c1 + c2 * len(scales), c2, 3, 1, 1)
20 |
21 | def forward(self, x: Tensor) -> Tensor:
22 | outs = []
23 | for stage in self.stages:
24 | outs.append(F.interpolate(stage(x), size=x.shape[-2:], mode='bilinear', align_corners=True))
25 |
26 | outs = [x] + outs[::-1]
27 | out = self.bottleneck(torch.cat(outs, dim=1))
28 | return out
29 |
30 |
31 | if __name__ == '__main__':
32 | model = PPM(512, 128)
33 | x = torch.randn(2, 512, 7, 7)
34 | y = model(x)
35 | print(y.shape) # [2, 128, 7, 7]
--------------------------------------------------------------------------------
/semseg/models/segformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import functional as F
4 | from semseg.models.base import BaseModel
5 | from semseg.models.heads import SegFormerHead
6 |
7 |
8 | class SegFormer(BaseModel):
9 | def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19) -> None:
10 | super().__init__(backbone, num_classes)
11 | self.decode_head = SegFormerHead(self.backbone.channels, 256 if 'B0' in backbone or 'B1' in backbone else 768, num_classes)
12 | self.apply(self._init_weights)
13 |
14 | def forward(self, x: Tensor) -> Tensor:
15 | y = self.backbone(x)
16 | y = self.decode_head(y) # 4x reduction in image size
17 | y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False) # to original image shape
18 | return y
19 |
20 |
21 | if __name__ == '__main__':
22 | model = SegFormer('MiT-B0')
23 | # model.load_state_dict(torch.load('checkpoints/pretrained/segformer/segformer.b0.ade.pth', map_location='cpu'))
24 | x = torch.zeros(1, 3, 512, 512)
25 | y = model(x)
26 | print(y.shape)
--------------------------------------------------------------------------------
/semseg/models/sfnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import functional as F
4 | from semseg.models.base import BaseModel
5 | from semseg.models.heads import SFHead
6 |
7 |
8 | class SFNet(BaseModel):
9 | def __init__(self, backbone: str = 'ResNetD-18', num_classes: int = 19):
10 | assert 'ResNet' in backbone
11 | super().__init__(backbone, num_classes)
12 | self.head = SFHead(self.backbone.channels, 128 if '18' in backbone else 256, num_classes)
13 | self.apply(self._init_weights)
14 |
15 | def forward(self, x: Tensor) -> Tensor:
16 | outs = self.backbone(x)
17 | out = self.head(outs)
18 | out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=True)
19 | return out
20 |
21 |
22 | if __name__ == '__main__':
23 | model = SFNet('ResNetD-18')
24 | model.init_pretrained('checkpoints/backbones/resnetd/resnetd18.pth')
25 | x = torch.randn(2, 3, 224, 224)
26 | y = model(x)
27 | print(y.shape)
--------------------------------------------------------------------------------
/semseg/optimizers.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.optim import AdamW, SGD
3 |
4 |
5 | def get_optimizer(model: nn.Module, optimizer: str, lr: float, weight_decay: float = 0.01):
6 | wd_params, nwd_params = [], []
7 | for p in model.parameters():
8 | if p.dim() == 1:
9 | nwd_params.append(p)
10 | else:
11 | wd_params.append(p)
12 |
13 | params = [
14 | {"params": wd_params},
15 | {"params": nwd_params, "weight_decay": 0}
16 | ]
17 |
18 | if optimizer == 'adamw':
19 | return AdamW(params, lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=weight_decay)
20 | else:
21 | return SGD(params, lr, momentum=0.9, weight_decay=weight_decay)
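
A minimal usage sketch for get_optimizer (not part of the file above); parameters with dim == 1 (biases, norm weights) land in the group with weight decay disabled:

from torch import nn
from semseg.optimizers import get_optimizer

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16))
optimizer = get_optimizer(model, 'adamw', lr=1e-3, weight_decay=0.01)

# two param groups: the conv weight decays, the conv/BN biases and BN weight do not
for group in optimizer.param_groups:
    print(len(group['params']), group['weight_decay'])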
--------------------------------------------------------------------------------
/semseg/schedulers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch.optim.lr_scheduler import _LRScheduler
4 |
5 |
6 | class PolyLR(_LRScheduler):
7 | def __init__(self, optimizer, max_iter, decay_iter=1, power=0.9, last_epoch=-1) -> None:
8 | self.decay_iter = decay_iter
9 | self.max_iter = max_iter
10 | self.power = power
11 | super().__init__(optimizer, last_epoch=last_epoch)
12 |
13 |     def get_lr(self):
14 |         # polynomial decay of the base LR, updated every `decay_iter` steps and clamped at `max_iter`
15 |         step = min((self.last_epoch // self.decay_iter) * self.decay_iter, self.max_iter)
16 |         factor = (1 - step / float(self.max_iter)) ** self.power
17 |         return [factor * lr for lr in self.base_lrs]
18 |
19 |
20 |
21 | class WarmupLR(_LRScheduler):
22 | def __init__(self, optimizer, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:
23 | self.warmup_iter = warmup_iter
24 | self.warmup_ratio = warmup_ratio
25 | self.warmup = warmup
26 | super().__init__(optimizer, last_epoch)
27 |
28 | def get_lr(self):
29 | ratio = self.get_lr_ratio()
30 | return [ratio * lr for lr in self.base_lrs]
31 |
32 | def get_lr_ratio(self):
33 | return self.get_warmup_ratio() if self.last_epoch < self.warmup_iter else self.get_main_ratio()
34 |
35 | def get_main_ratio(self):
36 | raise NotImplementedError
37 |
38 | def get_warmup_ratio(self):
39 | assert self.warmup in ['linear', 'exp']
40 | alpha = self.last_epoch / self.warmup_iter
41 |
42 | return self.warmup_ratio + (1. - self.warmup_ratio) * alpha if self.warmup == 'linear' else self.warmup_ratio ** (1. - alpha)
43 |
44 |
45 | class WarmupPolyLR(WarmupLR):
46 | def __init__(self, optimizer, power, max_iter, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:
47 | self.power = power
48 | self.max_iter = max_iter
49 | super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)
50 |
51 | def get_main_ratio(self):
52 | real_iter = self.last_epoch - self.warmup_iter
53 | real_max_iter = self.max_iter - self.warmup_iter
54 | alpha = real_iter / real_max_iter
55 |
56 | return (1 - alpha) ** self.power
57 |
58 |
59 | class WarmupExpLR(WarmupLR):
60 | def __init__(self, optimizer, gamma, interval=1, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:
61 | self.gamma = gamma
62 | self.interval = interval
63 | super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)
64 |
65 | def get_main_ratio(self):
66 | real_iter = self.last_epoch - self.warmup_iter
67 | return self.gamma ** (real_iter // self.interval)
68 |
69 |
70 | class WarmupCosineLR(WarmupLR):
71 | def __init__(self, optimizer, max_iter, eta_ratio=0, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:
72 | self.eta_ratio = eta_ratio
73 | self.max_iter = max_iter
74 | super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)
75 |
76 | def get_main_ratio(self):
77 | real_iter = self.last_epoch - self.warmup_iter
78 | real_max_iter = self.max_iter - self.warmup_iter
79 |
80 |         return self.eta_ratio + (1 - self.eta_ratio) * (1 + math.cos(math.pi * real_iter / real_max_iter)) / 2
81 |
82 |
83 |
84 | __all__ = ['polylr', 'warmuppolylr', 'warmupcosinelr', 'warmupsteplr']
85 |
86 |
87 | def get_scheduler(scheduler_name: str, optimizer, max_iter: int, power: int, warmup_iter: int, warmup_ratio: float):
88 | assert scheduler_name in __all__, f"Unavailable scheduler name >> {scheduler_name}.\nAvailable schedulers: {__all__}"
89 | if scheduler_name == 'warmuppolylr':
90 | return WarmupPolyLR(optimizer, power, max_iter, warmup_iter, warmup_ratio, warmup='linear')
91 | elif scheduler_name == 'warmupcosinelr':
92 | return WarmupCosineLR(optimizer, max_iter, warmup_iter=warmup_iter, warmup_ratio=warmup_ratio)
93 | return PolyLR(optimizer, max_iter)
94 |
95 |
96 | if __name__ == '__main__':
97 | model = torch.nn.Conv2d(3, 16, 3, 1, 1)
98 | optim = torch.optim.SGD(model.parameters(), lr=1e-3)
99 |
100 | max_iter = 20000
101 | sched = WarmupPolyLR(optim, power=0.9, max_iter=max_iter, warmup_iter=200, warmup_ratio=0.1, warmup='exp', last_epoch=-1)
102 |
103 | lrs = []
104 |
105 | for _ in range(max_iter):
106 | lr = sched.get_lr()[0]
107 | lrs.append(lr)
108 | optim.step()
109 | sched.step()
110 |
111 | import matplotlib.pyplot as plt
112 | import numpy as np
113 |
114 | plt.plot(np.arange(len(lrs)), np.array(lrs))
115 | plt.grid()
116 | plt.show()
--------------------------------------------------------------------------------
/semseg/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/semseg/utils/__init__.py
--------------------------------------------------------------------------------
/semseg/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 | import time
5 | import os
6 | import functools
7 | from pathlib import Path
8 | from torch.backends import cudnn
9 | from torch import nn, Tensor
10 | from torch.autograd import profiler
11 | from typing import Union
12 | from torch import distributed as dist
13 | from tabulate import tabulate
14 | from semseg import models
15 |
16 |
17 | def fix_seeds(seed: int = 3407) -> None:
18 | torch.manual_seed(seed)
19 | torch.cuda.manual_seed(seed)
20 | np.random.seed(seed)
21 | random.seed(seed)
22 |
23 | def setup_cudnn() -> None:
24 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
25 | cudnn.benchmark = True
26 | cudnn.deterministic = False
27 |
28 | def time_sync() -> float:
29 | if torch.cuda.is_available():
30 | torch.cuda.synchronize()
31 | return time.time()
32 |
33 | def get_model_size(model: Union[nn.Module, torch.jit.ScriptModule]):
34 | tmp_model_path = Path('temp.p')
35 | if isinstance(model, torch.jit.ScriptModule):
36 | torch.jit.save(model, tmp_model_path)
37 | else:
38 | torch.save(model.state_dict(), tmp_model_path)
39 | size = tmp_model_path.stat().st_size
40 | os.remove(tmp_model_path)
41 | return size / 1e6 # in MB
42 |
43 | @torch.no_grad()
44 | def test_model_latency(model: nn.Module, inputs: torch.Tensor, use_cuda: bool = False) -> float:
45 | with profiler.profile(use_cuda=use_cuda) as prof:
46 | _ = model(inputs)
47 | return prof.self_cpu_time_total / 1000 # ms
48 |
49 | def count_parameters(model: nn.Module) -> float:
50 | return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 # in M
51 |
52 | def setup_ddp() -> int:
53 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
54 | rank = int(os.environ['RANK'])
55 | world_size = int(os.environ['WORLD_SIZE'])
56 |         gpu = int(os.environ['LOCAL_RANK'])
57 |         torch.cuda.set_device(gpu)
58 |         dist.init_process_group('nccl', init_method="env://", world_size=world_size, rank=rank)
59 | dist.barrier()
60 | else:
61 | gpu = 0
62 | return gpu
63 |
64 | def cleanup_ddp():
65 | if dist.is_initialized():
66 | dist.destroy_process_group()
67 |
68 | def reduce_tensor(tensor: Tensor) -> Tensor:
69 | rt = tensor.clone()
70 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
71 | rt /= dist.get_world_size()
72 | return rt
73 |
74 | @torch.no_grad()
75 | def throughput(dataloader, model: nn.Module, times: int = 30):
76 | model.eval()
77 | images, _ = next(iter(dataloader))
78 | images = images.cuda(non_blocking=True)
79 | B = images.shape[0]
80 | print(f"Throughput averaged with {times} times")
81 | start = time_sync()
82 | for _ in range(times):
83 | model(images)
84 | end = time_sync()
85 |
86 | print(f"Batch Size {B} throughput {times * B / (end - start)} images/s")
87 |
88 |
89 | def show_models():
90 | model_names = models.__all__
91 | model_variants = [list(eval(f'models.{name.lower()}_settings').keys()) for name in model_names]
92 |
93 | print(tabulate({'Model Names': model_names, 'Model Variants': model_variants}, headers='keys'))
94 |
95 |
96 | def timer(func):
97 | @functools.wraps(func)
98 | def wrapper_timer(*args, **kwargs):
99 | tic = time.perf_counter()
100 | value = func(*args, **kwargs)
101 | toc = time.perf_counter()
102 | elapsed_time = toc - tic
103 | print(f"Elapsed time: {elapsed_time * 1000:.2f}ms")
104 | return value
105 | return wrapper_timer
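
A small sketch of the timer decorator and count_parameters in use (not part of the file above):

import torch
from torch import nn
from semseg.utils.utils import timer, count_parameters

@timer
def forward_once(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    return model(x)

model = nn.Conv2d(3, 16, 3, padding=1)
print(f"{count_parameters(model):.4f}M parameters")
y = forward_once(model, torch.randn(1, 3, 224, 224))  # prints the elapsed time in ms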
--------------------------------------------------------------------------------
/semseg/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | import torch
5 | import matplotlib.pyplot as plt
6 | from torch.utils.data import DataLoader
7 | from torchvision import transforms as T
8 | from torchvision.utils import make_grid
9 | from semseg.augmentations import Compose, Normalize, RandomResizedCrop
10 | from PIL import Image, ImageDraw, ImageFont
11 |
12 |
13 | def visualize_dataset_sample(dataset, root, split='val', batch_size=4):
14 | transform = Compose([
15 | RandomResizedCrop((512, 512), scale=(1.0, 1.0)),
16 | Normalize()
17 | ])
18 |
19 | dataset = dataset(root, split=split, transform=transform)
20 | dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
21 | image, label = next(iter(dataloader))
22 |
23 | print(f"Image Shape\t: {image.shape}")
24 | print(f"Label Shape\t: {label.shape}")
25 | print(f"Classes\t\t: {label.unique().tolist()}")
26 |
27 | label[label == -1] = 0
28 | label[label == 255] = 0
29 | labels = [dataset.PALETTE[lbl.to(int)].permute(2, 0, 1) for lbl in label]
30 | labels = torch.stack(labels)
31 |
32 | inv_normalize = T.Normalize(
33 | mean=(-0.485/0.229, -0.456/0.224, -0.406/0.225),
34 | std=(1/0.229, 1/0.224, 1/0.225)
35 | )
36 | image = inv_normalize(image)
37 | image *= 255
38 | images = torch.vstack([image, labels])
39 |
40 | plt.imshow(make_grid(images, nrow=4).to(torch.uint8).numpy().transpose((1, 2, 0)))
41 | plt.show()
42 |
43 |
44 | colors = [
45 | [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
46 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
47 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
48 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
49 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
50 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
51 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
52 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
53 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
54 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
55 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
56 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
57 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
58 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
59 | [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
60 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
61 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
62 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
63 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0], [92, 0, 255]
64 | ]
65 |
66 |
67 | def generate_palette(num_classes, background: bool = False):
68 | random.shuffle(colors)
69 | if background:
70 | palette = [[0, 0, 0]]
71 | palette += colors[:num_classes-1]
72 | else:
73 | palette = colors[:num_classes]
74 | return np.array(palette)
75 |
76 |
77 | def draw_text(image: torch.Tensor, seg_map: torch.Tensor, labels: list, fontsize: int = 15):
78 | image = image.to(torch.uint8)
79 | font = ImageFont.truetype("assests/Helvetica.ttf", fontsize)
80 | pil_image = Image.fromarray(image.numpy())
81 | draw = ImageDraw.Draw(pil_image)
82 |
83 | indices = seg_map.unique().tolist()
84 | classes = [labels[index] for index in indices]
85 |
86 | for idx, cls in zip(indices, classes):
87 | mask = seg_map == idx
88 | mask = mask.squeeze().numpy()
89 | center = np.median((mask == 1).nonzero(), axis=1)[::-1]
90 | bbox = draw.textbbox(center, cls, font=font)
91 | bbox = (bbox[0]-3, bbox[1]-3, bbox[2]+3, bbox[3]+3)
92 | draw.rectangle(bbox, fill=(255, 255, 255), width=1)
93 | draw.text(center, cls, fill=(0, 0, 0), font=font)
94 | return pil_image
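
A short sketch of generate_palette used to colorize an integer segmentation map (not part of the file above):

import numpy as np
from semseg.utils.visualize import generate_palette

palette = generate_palette(19, background=True)     # (19, 3) array, class 0 mapped to black
seg_map = np.random.randint(0, 19, size=(64, 64))   # dummy per-pixel class ids
color_map = palette[seg_map]                        # (64, 64, 3) RGB values
print(palette.shape, color_map.shape)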
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='semseg',
5 | version='0.4.1',
6 | description='SOTA Semantic Segmentation Models',
7 | url='https://github.com/sithu31296/semantic-segmentation',
8 | author='Sithu Aung',
9 | author_email='sithu31296@gmail.com',
10 | license='MIT',
11 | packages=find_packages(include=['semseg']),
12 | install_requires=[
13 | 'tqdm',
14 | 'tabulate',
15 | 'numpy',
16 | 'scipy',
17 | 'matplotlib',
18 | 'tensorboard',
19 | 'fvcore',
20 | 'einops',
21 | 'rich',
22 | ]
23 | )
--------------------------------------------------------------------------------
/tools/benchmark.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import time
4 | from fvcore.nn import flop_count_table, FlopCountAnalysis
5 | from semseg.models import *
6 |
7 |
8 | def main(
9 | model_name: str,
10 | backbone_name: str,
11 | image_size: list,
12 | num_classes: int,
13 | device: str,
14 | ):
15 | device = torch.device('cuda' if torch.cuda.is_available() and device == 'cuda' else 'cpu')
16 | inputs = torch.randn(1, 3, *image_size).to(device)
17 | model = eval(model_name)(backbone_name, num_classes)
18 | model = model.to(device)
19 | model.eval()
20 |
21 | print(flop_count_table(FlopCountAnalysis(model, inputs)))
22 |
23 | total_time = 0.0
24 | for _ in range(10):
25 | tic = time.perf_counter()
26 | model(inputs)
27 | toc = time.perf_counter()
28 | total_time += toc - tic
29 | total_time /= 10
30 | print(f"Inference time: {total_time*1000:.2f}ms")
31 | print(f"FPS: {1/total_time}")
32 |
33 |
34 | if __name__ == '__main__':
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument('--model-name', type=str, default='SegFormer')
37 | parser.add_argument('--backbone-name', type=str, default='MiT-B0')
38 |     parser.add_argument('--image-size', type=int, nargs='+', default=[512, 512])
39 | parser.add_argument('--num-classes', type=int, default=11)
40 | parser.add_argument('--device', type=str, default='cuda')
41 | args = parser.parse_args()
42 |
43 | main(args.model_name, args.backbone_name, args.image_size, args.num_classes, args.device)
--------------------------------------------------------------------------------
/tools/export.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import yaml
4 | import onnx
5 | from pathlib import Path
6 | from onnxsim import simplify
7 | from semseg.models import *
8 | from semseg.datasets import *
9 |
10 |
11 | def export_onnx(model, inputs, file):
12 | torch.onnx.export(
13 | model,
14 | inputs,
15 |         f"{file}.onnx",
16 | input_names=['input'],
17 | output_names=['output'],
18 | opset_version=13
19 | )
20 | onnx_model = onnx.load(f"{file}.onnx")
21 | onnx.checker.check_model(onnx_model)
22 |
23 | onnx_model, check = simplify(onnx_model)
24 | onnx.save(onnx_model, f"{file}.onnx")
25 | assert check, "Simplified ONNX model could not be validated"
26 | print(f"ONNX model saved to {file}.onnx")
27 |
28 |
29 | def export_coreml(model, inputs, file):
30 | try:
31 | import coremltools as ct
32 | ts_model = torch.jit.trace(model, inputs, strict=True)
33 | ct_model = ct.convert(
34 | ts_model,
35 | inputs=[ct.ImageType('image', shape=inputs.shape, scale=1/255.0, bias=[0, 0, 0])]
36 | )
37 | ct_model.save(f"{file}.mlmodel")
38 | print(f"CoreML model saved to {file}.mlmodel")
39 |     except ImportError:
40 | print("Please install coremltools to export to CoreML.\n`pip install coremltools`")
41 |
42 |
43 | def main(cfg):
44 | model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], len(eval(cfg['DATASET']['NAME']).PALETTE))
45 | model.load_state_dict(torch.load(cfg['TEST']['MODEL_PATH'], map_location='cpu'))
46 | model.eval()
47 |
48 | inputs = torch.randn(1, 3, *cfg['TEST']['IMAGE_SIZE'])
49 | file = cfg['TEST']['MODEL_PATH'].split('.')[0]
50 |
51 | export_onnx(model, inputs, file)
52 | export_coreml(model, inputs, file)
53 |     print("Finished converting.")
54 |
55 |
56 | if __name__ == '__main__':
57 | parser = argparse.ArgumentParser()
58 | parser.add_argument('--cfg', type=str, default='configs/custom.yaml')
59 | args = parser.parse_args()
60 |
61 | with open(args.cfg) as f:
62 | cfg = yaml.load(f, Loader=yaml.SafeLoader)
63 |
64 | save_dir = Path(cfg['SAVE_DIR'])
65 | save_dir.mkdir(exist_ok=True)
66 |
67 | main(cfg)
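
A hedged sketch of sanity-checking the exported ONNX model with ONNX Runtime (not part of the file above; the .onnx path is a placeholder and a 512x512 TEST IMAGE_SIZE is assumed, while the 'input'/'output' names match those set in export_onnx):

import numpy as np
import onnxruntime as ort

# export.py writes the model next to TEST MODEL_PATH with an .onnx extension (placeholder path below)
session = ort.InferenceSession("path/to/exported_model.onnx", providers=['CPUExecutionProvider'])
x = np.random.randn(1, 3, 512, 512).astype(np.float32)
out = session.run(['output'], {'input': x})[0]
print(out.shape)    # (1, num_classes, 512, 512)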
--------------------------------------------------------------------------------
/tools/infer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import yaml
4 | import math
5 | from torch import Tensor
6 | from torch.nn import functional as F
7 | from pathlib import Path
8 | from torchvision import io
9 | from torchvision import transforms as T
10 | from semseg.models import *
11 | from semseg.datasets import *
12 | from semseg.utils.utils import timer
13 | from semseg.utils.visualize import draw_text
14 |
15 | from rich.console import Console
16 | console = Console()
17 |
18 |
19 | class SemSeg:
20 | def __init__(self, cfg) -> None:
21 | # inference device cuda or cpu
22 | self.device = torch.device(cfg['DEVICE'])
23 |
24 | # get dataset classes' colors and labels
25 | self.palette = eval(cfg['DATASET']['NAME']).PALETTE
26 | self.labels = eval(cfg['DATASET']['NAME']).CLASSES
27 |
28 | # initialize the model and load weights and send to device
29 | self.model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], len(self.palette))
30 | self.model.load_state_dict(torch.load(cfg['TEST']['MODEL_PATH'], map_location='cpu'))
31 | self.model = self.model.to(self.device)
32 | self.model.eval()
33 |
34 | # preprocess parameters and transformation pipeline
35 | self.size = cfg['TEST']['IMAGE_SIZE']
36 | self.tf_pipeline = T.Compose([
37 | T.Lambda(lambda x: x / 255),
38 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
39 | T.Lambda(lambda x: x.unsqueeze(0))
40 | ])
41 |
42 | def preprocess(self, image: Tensor) -> Tensor:
43 | H, W = image.shape[1:]
44 | console.print(f"Original Image Size > [red]{H}x{W}[/red]")
45 | # scale the short side of image to target size
46 | scale_factor = self.size[0] / min(H, W)
47 | nH, nW = round(H*scale_factor), round(W*scale_factor)
48 | # make it divisible by model stride
49 | nH, nW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32
50 | console.print(f"Inference Image Size > [red]{nH}x{nW}[/red]")
51 | # resize the image
52 | image = T.Resize((nH, nW))(image)
53 | # divide by 255, norm and add batch dim
54 | image = self.tf_pipeline(image).to(self.device)
55 | return image
56 |
57 | def postprocess(self, orig_img: Tensor, seg_map: Tensor, overlay: bool) -> Tensor:
58 | # resize to original image size
59 | seg_map = F.interpolate(seg_map, size=orig_img.shape[-2:], mode='bilinear', align_corners=True)
60 | # get segmentation map (value being 0 to num_classes)
61 | seg_map = seg_map.softmax(dim=1).argmax(dim=1).cpu().to(int)
62 |
63 | # convert segmentation map to color map
64 | seg_image = self.palette[seg_map].squeeze()
65 | if overlay:
66 | seg_image = (orig_img.permute(1, 2, 0) * 0.4) + (seg_image * 0.6)
67 |
68 | image = draw_text(seg_image, seg_map, self.labels)
69 | return image
70 |
71 | @torch.inference_mode()
72 | @timer
73 | def model_forward(self, img: Tensor) -> Tensor:
74 | return self.model(img)
75 |
76 | def predict(self, img_fname: str, overlay: bool) -> Tensor:
77 | image = io.read_image(img_fname)
78 | img = self.preprocess(image)
79 | seg_map = self.model_forward(img)
80 | seg_map = self.postprocess(image, seg_map, overlay)
81 | return seg_map
82 |
83 |
84 | if __name__ == '__main__':
85 | parser = argparse.ArgumentParser()
86 | parser.add_argument('--cfg', type=str, default='configs/ade20k.yaml')
87 | args = parser.parse_args()
88 |
89 | with open(args.cfg) as f:
90 | cfg = yaml.load(f, Loader=yaml.SafeLoader)
91 |
92 | test_file = Path(cfg['TEST']['FILE'])
93 | if not test_file.exists():
94 | raise FileNotFoundError(test_file)
95 |
96 | console.print(f"Model > [red]{cfg['MODEL']['NAME']} {cfg['MODEL']['BACKBONE']}[/red]")
97 |     console.print(f"Dataset > [red]{cfg['DATASET']['NAME']}[/red]")
98 |
99 | save_dir = Path(cfg['SAVE_DIR']) / 'test_results'
100 | save_dir.mkdir(exist_ok=True)
101 |
102 | semseg = SemSeg(cfg)
103 |
104 | with console.status("[bright_green]Processing..."):
105 | if test_file.is_file():
106 | console.rule(f'[green]{test_file}')
107 | segmap = semseg.predict(str(test_file), cfg['TEST']['OVERLAY'])
108 | segmap.save(save_dir / f"{str(test_file.stem)}.png")
109 | else:
110 | files = test_file.glob('*.*')
111 | for file in files:
112 | console.rule(f'[green]{file}')
113 | segmap = semseg.predict(str(file), cfg['TEST']['OVERLAY'])
114 | segmap.save(save_dir / f"{str(file.stem)}.png")
115 |
116 | console.rule(f"[cyan]Segmentation results are saved in `{save_dir}`")
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import yaml
4 | import time
5 | import multiprocessing as mp
6 | from tabulate import tabulate
7 | from tqdm import tqdm
8 | from torch.utils.data import DataLoader
9 | from pathlib import Path
10 | from torch.utils.tensorboard import SummaryWriter
11 | from torch.cuda.amp import GradScaler, autocast
12 | from torch.nn.parallel import DistributedDataParallel as DDP
13 | from torch.utils.data import DistributedSampler, RandomSampler
14 | from torch import distributed as dist
15 | from semseg.models import *
16 | from semseg.datasets import *
17 | from semseg.augmentations import get_train_augmentation, get_val_augmentation
18 | from semseg.losses import get_loss
19 | from semseg.schedulers import get_scheduler
20 | from semseg.optimizers import get_optimizer
21 | from semseg.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp
22 | from val import evaluate
23 |
24 |
25 | def main(cfg, gpu, save_dir):
26 | start = time.time()
27 | best_mIoU = 0.0
28 | num_workers = mp.cpu_count()
29 | device = torch.device(cfg['DEVICE'])
30 | train_cfg, eval_cfg = cfg['TRAIN'], cfg['EVAL']
31 | dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL']
32 | loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER']
33 | epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR']
34 |
35 | traintransform = get_train_augmentation(train_cfg['IMAGE_SIZE'], seg_fill=dataset_cfg['IGNORE_LABEL'])
36 | valtransform = get_val_augmentation(eval_cfg['IMAGE_SIZE'])
37 |
38 | trainset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'train', traintransform)
39 | valset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'val', valtransform)
40 |
41 | model = eval(model_cfg['NAME'])(model_cfg['BACKBONE'], trainset.n_classes)
42 | model.init_pretrained(model_cfg['PRETRAINED'])
43 | model = model.to(device)
44 |
45 | if train_cfg['DDP']:
46 | sampler = DistributedSampler(trainset, dist.get_world_size(), dist.get_rank(), shuffle=True)
47 | model = DDP(model, device_ids=[gpu])
48 | else:
49 | sampler = RandomSampler(trainset)
50 |
51 | trainloader = DataLoader(trainset, batch_size=train_cfg['BATCH_SIZE'], num_workers=num_workers, drop_last=True, pin_memory=True, sampler=sampler)
52 | valloader = DataLoader(valset, batch_size=1, num_workers=1, pin_memory=True)
53 |
54 | iters_per_epoch = len(trainset) // train_cfg['BATCH_SIZE']
55 | # class_weights = trainset.class_weights.to(device)
56 | loss_fn = get_loss(loss_cfg['NAME'], trainset.ignore_label, None)
57 | optimizer = get_optimizer(model, optim_cfg['NAME'], lr, optim_cfg['WEIGHT_DECAY'])
58 | scheduler = get_scheduler(sched_cfg['NAME'], optimizer, epochs * iters_per_epoch, sched_cfg['POWER'], iters_per_epoch * sched_cfg['WARMUP'], sched_cfg['WARMUP_RATIO'])
59 | scaler = GradScaler(enabled=train_cfg['AMP'])
60 | writer = SummaryWriter(str(save_dir / 'logs'))
61 |
62 | for epoch in range(epochs):
63 | model.train()
64 | if train_cfg['DDP']: sampler.set_epoch(epoch)
65 |
66 | train_loss = 0.0
67 | pbar = tqdm(enumerate(trainloader), total=iters_per_epoch, desc=f"Epoch: [{epoch+1}/{epochs}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.8f}")
68 |
69 | for iter, (img, lbl) in pbar:
70 | optimizer.zero_grad(set_to_none=True)
71 |
72 | img = img.to(device)
73 | lbl = lbl.to(device)
74 |
75 | with autocast(enabled=train_cfg['AMP']):
76 | logits = model(img)
77 | loss = loss_fn(logits, lbl)
78 |
79 | scaler.scale(loss).backward()
80 | scaler.step(optimizer)
81 | scaler.update()
82 | scheduler.step()
83 | torch.cuda.synchronize()
84 |
85 | lr = scheduler.get_lr()
86 | lr = sum(lr) / len(lr)
87 | train_loss += loss.item()
88 |
89 | pbar.set_description(f"Epoch: [{epoch+1}/{epochs}] Iter: [{iter+1}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss / (iter+1):.8f}")
90 |
91 | train_loss /= iter+1
92 | writer.add_scalar('train/loss', train_loss, epoch)
93 | torch.cuda.empty_cache()
94 |
95 | if (epoch+1) % train_cfg['EVAL_INTERVAL'] == 0 or (epoch+1) == epochs:
96 | miou = evaluate(model, valloader, device)[-1]
97 | writer.add_scalar('val/mIoU', miou, epoch)
98 |
99 | if miou > best_mIoU:
100 | best_mIoU = miou
101 | torch.save(model.module.state_dict() if train_cfg['DDP'] else model.state_dict(), save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}.pth")
102 | print(f"Current mIoU: {miou} Best mIoU: {best_mIoU}")
103 |
104 | writer.close()
105 | pbar.close()
106 | end = time.gmtime(time.time() - start)
107 |
108 | table = [
109 | ['Best mIoU', f"{best_mIoU:.2f}"],
110 | ['Total Training Time', time.strftime("%H:%M:%S", end)]
111 | ]
112 | print(tabulate(table, numalign='right'))
113 |
114 |
115 | if __name__ == '__main__':
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument('--cfg', type=str, default='configs/custom.yaml', help='Configuration file to use')
118 | args = parser.parse_args()
119 |
120 | with open(args.cfg) as f:
121 | cfg = yaml.load(f, Loader=yaml.SafeLoader)
122 |
123 | fix_seeds(3407)
124 | setup_cudnn()
125 | gpu = setup_ddp()
126 | save_dir = Path(cfg['SAVE_DIR'])
127 | save_dir.mkdir(exist_ok=True)
128 | main(cfg, gpu, save_dir)
129 | cleanup_ddp()
--------------------------------------------------------------------------------
/tools/val.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import yaml
4 | import math
5 | from pathlib import Path
6 | from tqdm import tqdm
7 | from tabulate import tabulate
8 | from torch.utils.data import DataLoader
9 | from torch.nn import functional as F
10 | from semseg.models import *
11 | from semseg.datasets import *
12 | from semseg.augmentations import get_val_augmentation
13 | from semseg.metrics import Metrics
14 | from semseg.utils.utils import setup_cudnn
15 |
16 |
17 | @torch.no_grad()
18 | def evaluate(model, dataloader, device):
19 | print('Evaluating...')
20 | model.eval()
21 | metrics = Metrics(dataloader.dataset.n_classes, dataloader.dataset.ignore_label, device)
22 |
23 | for images, labels in tqdm(dataloader):
24 | images = images.to(device)
25 | labels = labels.to(device)
26 | preds = model(images).softmax(dim=1)
27 | metrics.update(preds, labels)
28 |
29 | ious, miou = metrics.compute_iou()
30 | acc, macc = metrics.compute_pixel_acc()
31 | f1, mf1 = metrics.compute_f1()
32 |
33 | return acc, macc, f1, mf1, ious, miou
34 |
35 |
36 | @torch.no_grad()
37 | def evaluate_msf(model, dataloader, device, scales, flip):
38 | model.eval()
39 |
40 | n_classes = dataloader.dataset.n_classes
41 | metrics = Metrics(n_classes, dataloader.dataset.ignore_label, device)
42 |
43 | for images, labels in tqdm(dataloader):
44 | labels = labels.to(device)
45 | B, H, W = labels.shape
46 | scaled_logits = torch.zeros(B, n_classes, H, W).to(device)
47 |
48 | for scale in scales:
49 | new_H, new_W = int(scale * H), int(scale * W)
50 | new_H, new_W = int(math.ceil(new_H / 32)) * 32, int(math.ceil(new_W / 32)) * 32
51 | scaled_images = F.interpolate(images, size=(new_H, new_W), mode='bilinear', align_corners=True)
52 | scaled_images = scaled_images.to(device)
53 | logits = model(scaled_images)
54 | logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True)
55 | scaled_logits += logits.softmax(dim=1)
56 |
57 | if flip:
58 | scaled_images = torch.flip(scaled_images, dims=(3,))
59 | logits = model(scaled_images)
60 | logits = torch.flip(logits, dims=(3,))
61 | logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True)
62 | scaled_logits += logits.softmax(dim=1)
63 |
64 | metrics.update(scaled_logits, labels)
65 |
66 | acc, macc = metrics.compute_pixel_acc()
67 | f1, mf1 = metrics.compute_f1()
68 | ious, miou = metrics.compute_iou()
69 | return acc, macc, f1, mf1, ious, miou
70 |
71 |
72 | def main(cfg):
73 | device = torch.device(cfg['DEVICE'])
74 |
75 | eval_cfg = cfg['EVAL']
76 | transform = get_val_augmentation(eval_cfg['IMAGE_SIZE'])
77 | dataset = eval(cfg['DATASET']['NAME'])(cfg['DATASET']['ROOT'], 'val', transform)
78 | dataloader = DataLoader(dataset, 1, num_workers=1, pin_memory=True)
79 |
80 | model_path = Path(eval_cfg['MODEL_PATH'])
81 | if not model_path.exists(): model_path = Path(cfg['SAVE_DIR']) / f"{cfg['MODEL']['NAME']}_{cfg['MODEL']['BACKBONE']}_{cfg['DATASET']['NAME']}.pth"
82 | print(f"Evaluating {model_path}...")
83 |
84 | model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], dataset.n_classes)
85 | model.load_state_dict(torch.load(str(model_path), map_location='cpu'))
86 | model = model.to(device)
87 |
88 | if eval_cfg['MSF']['ENABLE']:
89 | acc, macc, f1, mf1, ious, miou = evaluate_msf(model, dataloader, device, eval_cfg['MSF']['SCALES'], eval_cfg['MSF']['FLIP'])
90 | else:
91 | acc, macc, f1, mf1, ious, miou = evaluate(model, dataloader, device)
92 |
93 | table = {
94 | 'Class': list(dataset.CLASSES) + ['Mean'],
95 | 'IoU': ious + [miou],
96 | 'F1': f1 + [mf1],
97 | 'Acc': acc + [macc]
98 | }
99 |
100 | print(tabulate(table, headers='keys'))
101 |
102 |
103 | if __name__ == '__main__':
104 | parser = argparse.ArgumentParser()
105 | parser.add_argument('--cfg', type=str, default='configs/custom.yaml')
106 | args = parser.parse_args()
107 |
108 | with open(args.cfg) as f:
109 | cfg = yaml.load(f, Loader=yaml.SafeLoader)
110 |
111 | setup_cudnn()
112 | main(cfg)
--------------------------------------------------------------------------------