├── .gitignore ├── LICENSE ├── README.md ├── assests ├── Helvetica.ttf ├── ade │ ├── ADE_val_00000001.jpg │ └── ADE_val_00000049.jpg ├── banner.jpg ├── image_labels │ ├── Seq05VD_f05100.png │ └── Seq05VD_f05100_L.png └── infer_result.png ├── configs ├── ade20k.yaml ├── cityscapes.yaml ├── custom.yaml └── helen.yaml ├── docs ├── BACKBONES.md ├── DATASETS.md ├── MODELS.md └── OTHER_DATASETS.md ├── notebooks ├── aug_test.ipynb └── tutorial.ipynb ├── scripts ├── calc_class_weights.py ├── export_data.py ├── onnx_infer.py ├── openvino_infer.py ├── preprocess_celebamaskhq.py └── tflite_infer.py ├── semseg ├── __init__.py ├── augmentations.py ├── datasets │ ├── __init__.py │ ├── ade20k.py │ ├── atr.py │ ├── camvid.py │ ├── celebamaskhq.py │ ├── cihp.py │ ├── cityscapes.py │ ├── cocostuff.py │ ├── facesynthetics.py │ ├── helen.py │ ├── ibugmask.py │ ├── lapa.py │ ├── lip.py │ ├── mapillary.py │ ├── mhpv1.py │ ├── mhpv2.py │ ├── pascalcontext.py │ ├── suim.py │ └── sunrgbd.py ├── losses.py ├── metrics.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── convnext.py │ │ ├── micronet.py │ │ ├── mit.py │ │ ├── mobilenetv2.py │ │ ├── mobilenetv3.py │ │ ├── poolformer.py │ │ ├── pvt.py │ │ ├── resnet.py │ │ ├── resnetd.py │ │ ├── rest.py │ │ └── uniformer.py │ ├── base.py │ ├── bisenetv1.py │ ├── bisenetv2.py │ ├── custom_cnn.py │ ├── custom_vit.py │ ├── ddrnet.py │ ├── fchardnet.py │ ├── heads │ │ ├── __init__.py │ │ ├── condnet.py │ │ ├── fapn.py │ │ ├── fcn.py │ │ ├── fpn.py │ │ ├── lawin.py │ │ ├── segformer.py │ │ ├── sfnet.py │ │ └── upernet.py │ ├── lawin.py │ ├── layers │ │ ├── __init__.py │ │ ├── common.py │ │ └── initialize.py │ ├── modules │ │ ├── __init__.py │ │ ├── ppm.py │ │ └── psa.py │ ├── segformer.py │ └── sfnet.py ├── optimizers.py ├── schedulers.py └── utils │ ├── __init__.py │ ├── utils.py │ └── visualize.py ├── setup.py └── tools ├── benchmark.py ├── export.py ├── infer.py ├── train.py └── val.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Repo-specific GitIgnore ---------------------------------------------------------------------------------------------- 2 | *.jpg 3 | *.jpeg 4 | *.png 5 | *.bmp 6 | *.tif 7 | *.tiff 8 | *.heic 9 | *.JPG 10 | *.JPEG 11 | *.PNG 12 | *.BMP 13 | *.TIF 14 | *.TIFF 15 | *.HEIC 16 | *.mp4 17 | *.mov 18 | *.MOV 19 | *.avi 20 | *.data 21 | *.json 22 | 23 | *.cfg 24 | !cfg/yolov3*.cfg 25 | 26 | storage.googleapis.com 27 | runs/* 28 | data/* 29 | !data/images/zidane.jpg 30 | !data/images/bus.jpg 31 | !data/coco.names 32 | !data/coco_paper.names 33 | !data/coco.data 34 | !data/coco_*.data 35 | !data/coco_*.txt 36 | !data/trainvalno5k.shapes 37 | !data/*.sh 38 | 39 | test.py 40 | test_imgs/ 41 | 42 | pycocotools/* 43 | results*.txt 44 | gcp_test*.sh 45 | 46 | checkpoints/ 47 | output/ 48 | assests/*/ 49 | 50 | # Datasets ------------------------------------------------------------------------------------------------------------- 51 | coco/ 52 | coco128/ 53 | VOC/ 54 | 55 | # MATLAB GitIgnore ----------------------------------------------------------------------------------------------------- 56 | *.m~ 57 | *.mat 58 | !targets*.mat 59 | 60 | # Neural Network weights ----------------------------------------------------------------------------------------------- 61 | *.weights 62 | *.pt 63 | *.onnx 64 | *.mlmodel 65 | *.torchscript 66 | darknet53.conv.74 67 | yolov3-tiny.conv.15 68 | 69 | # GitHub Python GitIgnore 
---------------------------------------------------------------------------------------------- 70 | # Byte-compiled / optimized / DLL files 71 | __pycache__/ 72 | *.py[cod] 73 | *$py.class 74 | 75 | # C extensions 76 | *.so 77 | 78 | # Distribution / packaging 79 | .Python 80 | env/ 81 | build/ 82 | develop-eggs/ 83 | dist/ 84 | downloads/ 85 | eggs/ 86 | .eggs/ 87 | lib/ 88 | lib64/ 89 | parts/ 90 | sdist/ 91 | var/ 92 | wheels/ 93 | *.egg-info/ 94 | wandb/ 95 | .installed.cfg 96 | *.egg 97 | 98 | 99 | # PyInstaller 100 | # Usually these files are written by a python script from a template 101 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 102 | *.manifest 103 | *.spec 104 | 105 | # Installer logs 106 | pip-log.txt 107 | pip-delete-this-directory.txt 108 | 109 | # Unit test / coverage reports 110 | htmlcov/ 111 | .tox/ 112 | .coverage 113 | .coverage.* 114 | .cache 115 | nosetests.xml 116 | coverage.xml 117 | *.cover 118 | .hypothesis/ 119 | 120 | # Translations 121 | *.mo 122 | *.pot 123 | 124 | # Django stuff: 125 | *.log 126 | local_settings.py 127 | 128 | # Flask stuff: 129 | instance/ 130 | .webassets-cache 131 | 132 | # Scrapy stuff: 133 | .scrapy 134 | 135 | # Sphinx documentation 136 | docs/_build/ 137 | 138 | # PyBuilder 139 | target/ 140 | 141 | # Jupyter Notebook 142 | .ipynb_checkpoints 143 | 144 | # pyenv 145 | .python-version 146 | 147 | # celery beat schedule file 148 | celerybeat-schedule 149 | 150 | # SageMath parsed files 151 | *.sage.py 152 | 153 | # dotenv 154 | .env 155 | 156 | # virtualenv 157 | .venv* 158 | venv*/ 159 | ENV*/ 160 | 161 | # Spyder project settings 162 | .spyderproject 163 | .spyproject 164 | 165 | # Rope project settings 166 | .ropeproject 167 | 168 | # mkdocs documentation 169 | /site 170 | 171 | # mypy 172 | .mypy_cache/ 173 | 174 | 175 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore ----------------------------------------------- 176 | 177 | # General 178 | .DS_Store 179 | .AppleDouble 180 | .LSOverride 181 | 182 | # Icon must end with two \r 183 | Icon 184 | Icon? 
185 | 186 | # Thumbnails 187 | ._* 188 | 189 | # Files that might appear in the root of a volume 190 | .DocumentRevisions-V100 191 | .fseventsd 192 | .Spotlight-V100 193 | .TemporaryItems 194 | .Trashes 195 | .VolumeIcon.icns 196 | .com.apple.timemachine.donotpresent 197 | 198 | # Directories potentially created on remote AFP share 199 | .AppleDB 200 | .AppleDesktop 201 | Network Trash Folder 202 | Temporary Items 203 | .apdisk 204 | 205 | 206 | # https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore 207 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 208 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 209 | 210 | # User-specific stuff: 211 | .idea/* 212 | .idea/**/workspace.xml 213 | .idea/**/tasks.xml 214 | .idea/dictionaries 215 | .html # Bokeh Plots 216 | .pg # TensorFlow Frozen Graphs 217 | .avi # videos 218 | 219 | # Sensitive or high-churn files: 220 | .idea/**/dataSources/ 221 | .idea/**/dataSources.ids 222 | .idea/**/dataSources.local.xml 223 | .idea/**/sqlDataSources.xml 224 | .idea/**/dynamic.xml 225 | .idea/**/uiDesigner.xml 226 | 227 | # Gradle: 228 | .idea/**/gradle.xml 229 | .idea/**/libraries 230 | 231 | # CMake 232 | cmake-build-debug/ 233 | cmake-build-release/ 234 | 235 | # Mongo Explorer plugin: 236 | .idea/**/mongoSettings.xml 237 | 238 | ## File-based project format: 239 | *.iws 240 | 241 | ## Plugin-specific files: 242 | 243 | # IntelliJ 244 | out/ 245 | 246 | # mpeltonen/sbt-idea plugin 247 | .idea_modules/ 248 | 249 | # JIRA plugin 250 | atlassian-ide-plugin.xml 251 | 252 | # Cursive Clojure plugin 253 | .idea/replstate.xml 254 | 255 | # Crashlytics plugin (for Android Studio and IntelliJ) 256 | com_crashlytics_export_strings.xml 257 | crashlytics.properties 258 | crashlytics-build.properties 259 | fabric.properties 260 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 sithu3 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assests/Helvetica.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/Helvetica.ttf -------------------------------------------------------------------------------- /assests/ade/ADE_val_00000001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/ade/ADE_val_00000001.jpg -------------------------------------------------------------------------------- /assests/ade/ADE_val_00000049.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/ade/ADE_val_00000049.jpg -------------------------------------------------------------------------------- /assests/banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/banner.jpg -------------------------------------------------------------------------------- /assests/image_labels/Seq05VD_f05100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/image_labels/Seq05VD_f05100.png -------------------------------------------------------------------------------- /assests/image_labels/Seq05VD_f05100_L.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/image_labels/Seq05VD_f05100_L.png -------------------------------------------------------------------------------- /assests/infer_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/assests/infer_result.png -------------------------------------------------------------------------------- /configs/ade20k.yaml: -------------------------------------------------------------------------------- 1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...) 
2 | SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results 3 | 4 | MODEL: 5 | NAME : SegFormer # name of the model you are using 6 | BACKBONE : MiT-B2 # model variant 7 | PRETRAINED : 'checkpoints/backbones/mit/mit_b2.pth' # backbone model's weight 8 | 9 | DATASET: 10 | NAME : ADE20K # dataset name to be trained with (camvid, cityscapes, ade20k) 11 | ROOT : 'data/ADEChallengeData2016' # dataset root path 12 | IGNORE_LABEL : -1 13 | 14 | TRAIN: 15 | IMAGE_SIZE : [512, 512] # training image size in (h, w) 16 | BATCH_SIZE : 8 # batch size used to train 17 | EPOCHS : 500 # number of epochs to train 18 | EVAL_INTERVAL : 50 # evaluation interval during training 19 | AMP : false # use AMP in training 20 | DDP : false # use DDP training 21 | 22 | LOSS: 23 | NAME : OhemCrossEntropy # loss function name (ohemce, ce, dice) 24 | CLS_WEIGHTS : false # use class weights in loss calculation 25 | 26 | OPTIMIZER: 27 | NAME : adamw # optimizer name 28 | LR : 0.001 # initial learning rate used in optimizer 29 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer 30 | 31 | SCHEDULER: 32 | NAME : warmuppolylr # scheduler name 33 | POWER : 0.9 # scheduler power 34 | WARMUP : 10 # warmup epochs used in scheduler 35 | WARMUP_RATIO : 0.1 # warmup ratio 36 | 37 | 38 | EVAL: 39 | MODEL_PATH : 'checkpoints/pretrained/segformer/segformer.b2.ade.pth' # trained model file path 40 | IMAGE_SIZE : [512, 512] # evaluation image size in (h, w) 41 | MSF: 42 | ENABLE : false # multi-scale and flip evaluation 43 | FLIP : true # use flip in evaluation 44 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation 45 | 46 | 47 | TEST: 48 | MODEL_PATH : 'checkpoints/pretrained/segformer/segformer.b2.ade.pth' # trained model file path 49 | FILE : 'assests/ade' # filename or foldername 50 | IMAGE_SIZE : [512, 512] # inference image size in (h, w) 51 | OVERLAY : true # save the overlay result (image_alpha+label_alpha) -------------------------------------------------------------------------------- /configs/cityscapes.yaml: -------------------------------------------------------------------------------- 1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...) 
2 | SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results 3 | 4 | MODEL: 5 | NAME : DDRNet # name of the model you are using 6 | BACKBONE : DDRNet-23slim # model variant 7 | PRETRAINED : 'checkpoints/backbones/ddrnet/ddrnet_23slim.pth' # backbone model's weight 8 | 9 | DATASET: 10 | NAME : CityScapes # dataset name to be trained with (camvid, cityscapes, ade20k) 11 | ROOT : 'data/CityScapes' # dataset root path 12 | IGNORE_LABEL : 255 13 | 14 | TRAIN: 15 | IMAGE_SIZE : [1024, 1024] # training image size in (h, w) 16 | BATCH_SIZE : 8 # batch size used to train 17 | EPOCHS : 500 # number of epochs to train 18 | EVAL_INTERVAL : 20 # evaluation interval during training 19 | AMP : false # use AMP in training 20 | DDP : false # use DDP training 21 | 22 | LOSS: 23 | NAME : OhemCrossEntropy # loss function name (ohemce, ce, dice) 24 | CLS_WEIGHTS : false # use class weights in loss calculation 25 | 26 | OPTIMIZER: 27 | NAME : adamw # optimizer name 28 | LR : 0.001 # initial learning rate used in optimizer 29 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer 30 | 31 | SCHEDULER: 32 | NAME : warmuppolylr # scheduler name 33 | POWER : 0.9 # scheduler power 34 | WARMUP : 10 # warmup epochs used in scheduler 35 | WARMUP_RATIO : 0.1 # warmup ratio 36 | 37 | 38 | EVAL: 39 | MODEL_PATH : 'checkpoints/pretrained/ddrnet/ddrnet_23slim_city.pth' # trained model file path 40 | IMAGE_SIZE : [1024, 1024] # evaluation image size in (h, w) 41 | MSF: 42 | ENABLE : false # multi-scale and flip evaluation 43 | FLIP : true # use flip in evaluation 44 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation 45 | 46 | 47 | TEST: 48 | MODEL_PATH : 'checkpoints/pretrained/ddrnet/ddrnet_23slim_city.pth' # trained model file path 49 | FILE : 'assests/cityscapes' # filename or foldername 50 | IMAGE_SIZE : [1024, 1024] # inference image size in (h, w) 51 | OVERLAY : true # save the overlay result (image_alpha+label_alpha) -------------------------------------------------------------------------------- /configs/custom.yaml: -------------------------------------------------------------------------------- 1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...) 
2 | SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results 3 | 4 | MODEL: 5 | NAME : DDRNet # name of the model you are using 6 | BACKBONE : DDRNet-23slim # model variant 7 | PRETRAINED : 'checkpoints/backbones/ddrnet/ddrnet_23slim.pth' # backbone model's weight 8 | 9 | DATASET: 10 | NAME : CityScapes # dataset name to be trained with (camvid, cityscapes, ade20k) 11 | ROOT : 'data/CityScapes' # dataset root path 12 | IGNORE_LABEL : 255 13 | 14 | TRAIN: 15 | IMAGE_SIZE : [512, 512] # training image size in (h, w) 16 | BATCH_SIZE : 2 # batch size used to train 17 | EPOCHS : 100 # number of epochs to train 18 | EVAL_INTERVAL : 20 # evaluation interval during training 19 | AMP : false # use AMP in training 20 | DDP : false # use DDP training 21 | 22 | LOSS: 23 | NAME : OhemCrossEntropy # loss function name (ohemce, ce, dice) 24 | CLS_WEIGHTS : false # use class weights in loss calculation 25 | 26 | OPTIMIZER: 27 | NAME : adamw # optimizer name 28 | LR : 0.001 # initial learning rate used in optimizer 29 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer 30 | 31 | SCHEDULER: 32 | NAME : warmuppolylr # scheduler name 33 | POWER : 0.9 # scheduler power 34 | WARMUP : 10 # warmup epochs used in scheduler 35 | WARMUP_RATIO : 0.1 # warmup ratio 36 | 37 | 38 | EVAL: 39 | MODEL_PATH : 'checkpoints/pretrained/ddrnet/ddrnet_23slim_city.pth' # trained model file path 40 | IMAGE_SIZE : [1024, 1024] # evaluation image size in (h, w) 41 | MSF: 42 | ENABLE : false # multi-scale and flip evaluation 43 | FLIP : true # use flip in evaluation 44 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation 45 | 46 | 47 | TEST: 48 | MODEL_PATH : 'checkpoints/pretrained/ddrnet/ddrnet_23slim_city.pth' # trained model file path 49 | FILE : 'assests/cityscapes' # filename or foldername 50 | IMAGE_SIZE : [1024, 1024] # inference image size in (h, w) 51 | OVERLAY : true # save the overlay result (image_alpha+label_alpha) -------------------------------------------------------------------------------- /configs/helen.yaml: -------------------------------------------------------------------------------- 1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...) 
2 | SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results 3 | 4 | MODEL: 5 | NAME : DDRNet # name of the model you are using 6 | BACKBONE : DDRNet-23slim # model variant 7 | PRETRAINED : 'checkpoints/backbones/ddrnet/ddrnet_23slim.pth' # backbone model's weight 8 | 9 | DATASET: 10 | NAME : HELEN # dataset name to be trained with (camvid, cityscapes, ade20k) 11 | ROOT : '/home/sithu/datasets/SmithCVPR2013_dataset_resized' # dataset root path 12 | IGNORE_LABEL : 255 13 | 14 | TRAIN: 15 | IMAGE_SIZE : [512, 512] # training image size in (h, w) 16 | BATCH_SIZE : 16 # batch size used to train 17 | EPOCHS : 200 # number of epochs to train 18 | EVAL_INTERVAL : 10 # evaluation interval during training 19 | AMP : false # use AMP in training 20 | DDP : false # use DDP training 21 | 22 | LOSS: 23 | NAME : OhemCrossEntropy # loss function name (OhemCrossEntropy, CrossEntropy, Dice) 24 | CLS_WEIGHTS : false # use class weights in loss calculation 25 | 26 | OPTIMIZER: 27 | NAME : adamw # optimizer name 28 | LR : 0.001 # initial learning rate used in optimizer 29 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer 30 | 31 | SCHEDULER: 32 | NAME : warmuppolylr # scheduler name 33 | POWER : 0.9 # scheduler power 34 | WARMUP : 5 # warmup epochs used in scheduler 35 | WARMUP_RATIO : 0.1 # warmup ratio 36 | 37 | 38 | EVAL: 39 | MODEL_PATH : 'output/DDRNet_DDRNet-23slim_HELEN_61_11.pth' # trained model file path 40 | IMAGE_SIZE : [512, 512] # evaluation image size in (h, w) 41 | MSF: 42 | ENABLE : false # multi-scale and flip evaluation 43 | FLIP : true # use flip in evaluation 44 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation 45 | 46 | 47 | TEST: 48 | MODEL_PATH : 'output/DDRNet_DDRNet-23slim_HELEN_61_11.pth' # trained model file path 49 | FILE : 'assests/faces' # filename or foldername 50 | IMAGE_SIZE : [512, 512] # inference image size in (h, w) 51 | OVERLAY : true # save the overlay result (image_alpha+label_alpha) 52 | -------------------------------------------------------------------------------- /docs/BACKBONES.md: -------------------------------------------------------------------------------- 1 | ## Supported Backbones 2 | 3 | Backbone | Variants | ImageNet-1k Top-1 Acc (%) | Params (M) | GFLOPs | Weights 4 | --- | --- | --- | --- | --- | --- 5 | MicroNet | M1\|M2\|M3 | 51.4`\|`59.4`\|`62.5 | 1`\|`2`\|`3 | 7M`\|`14M`\|`23M | [download][micronetw] 6 | MobileNetV2 | 1.0 | 71.9 | 3 | 300M | [download][mobilenetv2w] 7 | MobileNetV3 | S\|L | 67.7`\|`74.0 | 3`\|`5 | 56M`\|`219M | [S][mobilenetv3s]\|[L][mobilenetv3l] 8 | DDRNet | 23slim | 73.7 | 5 | 860M | [download][ddrnet23slim] 9 | || 10 | ResNet | 18\|50\|101 | 71.5`\|`80.4`\|`81.5 | 12`\|`26`\|`45 | 2`\|`4`\|`8 | [download][resnetw] 11 | ResNetD | 18\|50\|101 | - | 12`\|`25`\|`44 | 2`\|`4`\|`8 | [download][resnetdw] 12 | MiT | B1\|B2\|B3 | - | 14`\|`25`\|`45 | 2`\|`4`\|`8 | [download][mitw] 13 | PVTv2 | B1\|B2\|B4 | 78.7`\|`82.0`\|`83.6 | 14`\|`25`\|`63 | 2`\|`4`\|`10 | [download][pvtv2w] 14 | ResT | S\|B\|L | 79.6`\|`81.6`\|`83.6 | 14`\|`30`\|`52 | 2`\|`4`\|`8 | [download][restw] 15 | PoolFormer | S24\|S36\|M36 | 80.3`\|`81.4`\|`82.1 | 21`\|`31`\|`56 | 4`\|`5`\|`9 | [download][poolformerw] 16 | ConvNeXt | T\|S\|B | 82.1`\|`83.1`\|`83.8 | 28`\|`50`\|`89 | 5`\|`9`\|`15 | [download][convnextw] 17 | UniFormer | S\|B | 82.9`\|`83.8 | 22`\|`50 | 4`\|`8 | [download][uniformerw] 18 | VAN | S\|B\|L | 81.1`\|`82.8`\|`83.9 | 14`\|`27`\|`45 | 3`\|`5`\|`9 | - 19 | DaViT | T\|S\|B | 
82.8`\|`84.2`\|`84.6 | 28`\|`50`\|`88 | 5`\|`9`\|`16 | - 20 | 21 | 22 | [micronetw]: https://drive.google.com/drive/folders/1j4JSTcAh94U2k-7jCl_3nwbNi0eduM2P?usp=sharing 23 | [mobilenetv2w]: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth 24 | [mobilenetv3s]: https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth 25 | [mobilenetv3l]: https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth 26 | [resnetw]: https://drive.google.com/drive/folders/1MXP3Qx51c91PL9P52Tv89t90SaiTYuaC?usp=sharing 27 | [resnetdw]: https://drive.google.com/drive/folders/1sVyewBDkePlw3kbvhUD4PvUxjro4iKFy?usp=sharing 28 | [mitw]: https://drive.google.com/drive/folders/1b7bwrInTW4VLEm27YawHOAMSMikga2Ia 29 | [pvtv2w]: https://drive.google.com/drive/folders/10Dd9BEe4wv71dC5BXhsL_C6KeI_Rcxm3?usp=sharing 30 | [restw]: https://drive.google.com/drive/folders/1R2cewgHo6sYcQnRGBBIndjNomumBwekr?usp=sharing 31 | [ddrnet23slim]: https://drive.google.com/file/d/1tUcUCCsEZ7qKaF_bHHHECTonp4vbh-a9/view?usp=sharing 32 | [poolformerw]: https://drive.google.com/drive/folders/18OyxHHpVq-9pMMG2eu1jot7n-po4dUpD?usp=sharing 33 | [convnextw]: https://drive.google.com/drive/folders/1Oe50_zY4QKFZ0_22mSHKuNav0GiRcgWA?usp=sharing 34 | [uniformerw]: https://drive.google.com/drive/folders/175C4Je4kZoBb5x8HkwH4-VhtG_a5zQnX?usp=sharing -------------------------------------------------------------------------------- /docs/DATASETS.md: -------------------------------------------------------------------------------- 1 | ##
Supported Datasets
2 | 3 | [ade20k]: http://sceneparsing.csail.mit.edu/ 4 | [cityscapes]: https://www.cityscapes-dataset.com/ 5 | [camvid]: http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/ 6 | [cocostuff]: https://github.com/nightrome/cocostuff 7 | [mhp]: https://lv-mhp.github.io/ 8 | [lip]: http://sysu-hcp.net/lip/index.php 9 | [atr]: https://github.com/lemondan/HumanParsing-Dataset 10 | [pascalcontext]: https://cs.stanford.edu/~roozbeh/pascal-context/ 11 | [pcannos]: https://drive.google.com/file/d/1hOQnuTVYE9s7iRdo-6iARWkN2-qCAoVz/view?usp=sharing 12 | [suim]: http://irvlab.cs.umn.edu/resources/suim-dataset 13 | [mv]: https://www.mapillary.com/dataset/vistas 14 | [sunrgbd]: https://rgbd.cs.princeton.edu/ 15 | [helen]: https://www.sifeiliu.net/face-parsing 16 | [celeba]: https://github.com/switchablenorms/CelebAMask-HQ 17 | [lapa]: https://github.com/JDAI-CV/lapa-dataset 18 | [ibugmask]: https://github.com/hhj1897/face_parsing 19 | [facesynthetics]: https://github.com/microsoft/FaceSynthetics 20 | [ccihp]: https://kalisteo.cea.fr/wp-content/uploads/2021/09/README.html 21 | 22 | Dataset | Type | Categories | Train
Images | Val
Images | Test
Images | Image Size
(HxW) 23 | --- | --- | --- | --- | --- | --- | --- 24 | [COCO-Stuff][cocostuff] | General Scene Parsing | 171 | 118,000 | 5,000 | 20,000 | - 25 | [ADE20K][ade20k] | General Scene Parsing | 150 | 20,210 | 2,000 | 3,352 | - 26 | [PASCALContext][pascalcontext] | General Scene Parsing | 59 | 4,996 | 5,104 | 9,637 | - 27 | || 28 | [SUN RGB-D][sunrgbd] | Indoor Scene Parsing | 37 | 2,666 | 2,619 | 5,050+labels | - 29 | || 30 | [Mapillary Vistas][mv] | Street Scene Parsing | 65 | 18,000 | 2,000 | 5,000 | 1080x1920 31 | [CityScapes][cityscapes] | Street Scene Parsing | 19 | 2,975 | 500 | 1,525+labels | 1024x2048 32 | [CamVid][camvid] | Street Scene Parsing | 11 | 367 | 101 | 233+labels | 720x960 33 | || 34 | [MHPv2][mhp] | Multi-Human Parsing | 59 | 15,403 | 5,000 | 5,000 | - 35 | [MHPv1][mhp] | Multi-Human Parsing | 19 | 3,000 | 1,000 | 980+labels | - 36 | [LIP][lip] | Multi-Human Parsing | 20 | 30,462 | 10,000 | - | - 37 | [CCIHP][ccihp] | Multi-Human Parsing | 22 | 28,280 | 5,000 | 5,000 | - 38 | [CIHP][lip] | Multi-Human Parsing | 20 | 28,280 | 5,000 | 5,000 | - 39 | [ATR][atr] | Single-Human Parsing | 18 | 16,000 | 700 | 1,000+labels | - 40 | || 41 | [HELEN][helen] | Face Parsing | 11 | 2,000 | 230 | 100+labels | - 42 | [LaPa][lapa] | Face Parsing | 11 | 18,176 | 2,000 | 2,000+labels | - 43 | [iBugMask][ibugmask] | Face Parsing | 11 | 21,866 | - | 1,000+labels | - 44 | [CelebAMaskHQ][celeba] | Face Parsing | 19 | 24,183 | 2,993 | 2,824+labels | 512x512 45 | [FaceSynthetics][facesynthetics] | Face Parsing (Synthetic) | 19 | 100,000 | 1,000 | 100+labels | 512x512 46 | || 47 | [SUIM][suim] | Underwater Imagery | 8 | 1,525 | - | 110+labels | - 48 | 49 | Check [OTHER_DATASETS](./OTHER_DATASETS.md) to find more segmentation datasets. 50 | 51 | 
52 | Datasets Structure (click to expand) 53 | 54 | Datasets should have the following structure: 55 | 56 | ``` 57 | data 58 | |__ ADEChallenge 59 | |__ ADEChallengeData2016 60 | |__ images 61 | |__ training 62 | |__ validation 63 | |__ annotations 64 | |__ training 65 | |__ validation 66 | 67 | |__ CityScapes 68 | |__ leftImg8bit 69 | |__ train 70 | |__ val 71 | |__ test 72 | |__ gtFine 73 | |__ train 74 | |__ val 75 | |__ test 76 | 77 | |__ CamVid 78 | |__ train 79 | |__ val 80 | |__ test 81 | |__ train_labels 82 | |__ val_labels 83 | |__ test_labels 84 | 85 | |__ VOCdevkit 86 | |__ VOC2010 87 | |__ JPEGImages 88 | |__ SegmentationClassContext 89 | |__ ImageSets 90 | |__ SegmentationContext 91 | |__ train.txt 92 | |__ val.txt 93 | 94 | |__ COCO 95 | |__ images 96 | |__ train2017 97 | |__ val2017 98 | |__ labels 99 | |__ train2017 100 | |__ val2017 101 | 102 | |__ MHPv1 103 | |__ images 104 | |__ annotations 105 | |__ train_list.txt 106 | |__ test_list.txt 107 | 108 | |__ MHPv2 109 | |__ train 110 | |__ images 111 | |__ parsing_annos 112 | |__ val 113 | |__ images 114 | |__ parsing_annos 115 | 116 | |__ LIP 117 | |__ LIP 118 | |__ TrainVal_images 119 | |__ train_images 120 | |__ val_images 121 | |__ TrainVal_parsing_annotations 122 | |__ train_segmentations 123 | |__ val_segmentations 124 | 125 | |__ CIHP/CCIHP 126 | |__ instance-leve_human_parsing 127 | |__ Training 128 | |__ Images 129 | |__ Category_ids 130 | |__ Validation 131 | |__ Images 132 | |__ Category_ids 133 | 134 | |__ ATR 135 | |__ humanparsing 136 | |__ JPEGImages 137 | |__ SegmentationClassAug 138 | 139 | |__ SUIM 140 | |__ train_val 141 | |__ images 142 | |__ masks 143 | |__ TEST 144 | |__ images 145 | |__ masks 146 | 147 | |__ SunRGBD 148 | |__ SUNRGBD 149 | |__ kv1/kv2/realsense/xtion 150 | |__ SUNRGBDtoolbox 151 | |__ traintestSUNRGBD 152 | |__ allsplit.mat 153 | 154 | |__ Mapillary 155 | |__ training 156 | |__ images 157 | |__ labels 158 | |__ validation 159 | |__ images 160 | |__ labels 161 | 162 | |__ SmithCVPR2013_dataset_resized (HELEN) 163 | |__ images 164 | |__ labels 165 | |__ exemplars.txt 166 | |__ testing.txt 167 | |__ tuning.txt 168 | 169 | |__ CelebAMask-HQ 170 | |__ CelebA-HQ-img 171 | |__ CelebAMask-HQ-mask-anno 172 | |__ CelebA-HQ-to-CelebA-mapping.txt 173 | 174 | |__ LaPa 175 | |__ train 176 | |__ images 177 | |__ labels 178 | |__ val 179 | |__ images 180 | |__ labels 181 | |__ test 182 | |__ images 183 | |__ labels 184 | 185 | |__ ibugmask_release 186 | |__ train 187 | |__ test 188 | 189 | |__ FaceSynthetics 190 | |__ dataset_100000 191 | |__ dataset_1000 192 | |__ dataset_100 193 | ``` 194 | 195 | > Note: For PASCALContext, download the annotations from [here](pcannos) and put it in VOC2010. 196 | 197 | > Note: For CelebAMask-HQ, run the preprocess script. `python3 scripts/preprocess_celebamaskhq.py --root `. 198 | 199 |
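A quick way to confirm that a dataset is laid out as expected is to instantiate its class directly. The sketch below is illustrative only: it assumes the `semseg` package is importable and that ADE20K is extracted under `data/ADEChallengeData2016` (the root used in `configs/ade20k.yaml`); adjust the path to your own copy.

```python
# Minimal sanity check of a dataset layout (sketch; adjust `root` to your copy).
from semseg.datasets import ADE20K

dataset = ADE20K(root='data/ADEChallengeData2016', split='val', transform=None)
image, label = dataset[0]
print(len(dataset))                           # number of validation images found
print(image.shape, image.dtype)               # e.g. torch.Size([3, H, W]) torch.uint8
print(label.shape, label.min(), label.max())  # (H, W), values typically in -1..149 (-1 = ignore)
```

If the class raises "No images found", the folder structure above does not match what the dataset loader expects.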
200 | -------------------------------------------------------------------------------- /docs/MODELS.md: -------------------------------------------------------------------------------- 1 | ## Scene Parsing 2 | 3 | Accurate Models 4 | 5 | Method | Backbone | ADE20K
(mIoU) | Cityscapes
(mIoU) | COCO-Stuff
(mIoU) |Params
(M) | GFLOPs
(512x512) | GFLOPs
(1024x1024) | Weights 6 | --- | --- | --- | --- | --- | --- | --- | --- | --- 7 | SegFormer | MiT-B1 | 42.2 | 78.5 | 40.2 | 14 | 16 | 244 | [ade][segformerb1] 8 | || MiT-B2 | 46.5 | 81.0 | 44.6 | 28 | 62 | 717 | [ade][segformerb2] 9 | || MiT-B3 | 49.4 | 81.7 | 45.5 | 47 | 79 | 963 | [ade][segformerb3] 10 | || 11 | Light-Ham | VAN-S | 45.7 | - | - | 15 | 21 | - | - 12 | || VAN-B | 49.6 | - | - | 27 | 34 | - | - 13 | || VAN-L | 51.0 | - | - | 46 | 55 | - | - 14 | || 15 | Lawin | MiT-B1 | 42.1 | 79.0 | 40.5 | 14 | 13 | 218 | - 16 | || MiT-B2 | 47.8 | 81.7 | 45.2 | 30 | 45 | 563 | - 17 | || MiT-B3 | 50.3 | 82.5 | 46.6 | 50 | 62 | 809 | - 18 | || 19 | TopFormer | TopFormer-T | 34.6 | - | - | 1.4 | 0.6 | - | - 20 | || TopFormer-S | 37.0 | - | - | 3.1 | 1.2 | - | - 21 | || TopFormer-B | 39.2 | - | - | 5.1 | 1.8 | - | - 22 | 23 | * mIoU results are with a single scale from official papers. 24 | * ADE20K image size = 512x512 25 | * Cityscapes image size = 1024x1024 26 | * COCO-Stuff image size = 512x512 27 | 28 | Real-time Models 29 | 30 | Method | Backbone | CityScapes-val
(mIoU) | CamVid
(mIoU) | Params (M) | GFLOPs
(1024x2048) | Weights 31 | --- | --- | --- | --- | --- | --- | --- 32 | BiSeNetv1 | ResNet-18 | 74.8 | 68.7 | 14 | 49 | - 33 | BiSeNetv2 | - | 73.4 | 72.4 | 18 | 21 | - 34 | SFNet | ResNetD-18 | 79.0 | - | 13 | - | - 35 | DDRNet | DDRNet-23slim | 77.8 | 74.7 | 6 | 36 | [city][ddrnet] 36 | 37 | * mIoU results are with a single scale from official papers. 38 | * Cityscapes image size = 1024x2048 (except BiSeNetv1 & 2 which uses 512x1024) 39 | * CamVid image size = 960x720 40 | 41 | 42 | ## Face Parsing 43 | 44 | Method | Backbone | HELEN-val
(mIoU) | Params
(M) | GFLOPs
(512x512) | FPS
(GTX1660ti) | Weights 45 | --- | --- | --- | --- | --- | --- | --- 46 | BiSeNetv1 | ResNet-18 | 58.50 | 14 | 13 | 263 | [HELEN](https://drive.google.com/file/d/1HMC6OiFPc-aYwhlHlPYoXa-VCR3r2WPQ/view?usp=sharing) 47 | BiSeNetv2 | - | 58.58 | 18 | 15 | 195 | [HELEN](https://drive.google.com/file/d/1cf-W_2m-vfxMRZ0mFQjEwhOglURpH7m6/view?usp=sharing) 48 | DDRNet | DDRNet-23slim | 61.11 | 6 | 5 | 180 | [HELEN](https://drive.google.com/file/d/1SdOgVvgYrp8UFztHWN6dHH0MhP8zqnyh/view?usp=sharing) 49 | SFNet | ResNetD-18 | 61.00 | 14 | 31 | 56 | [HELEN](https://drive.google.com/file/d/13w42DgI4PJ05bkWY9XCK_skSGMsmXroj/view?usp=sharing) 50 | 51 | 52 | [ddrnet]: https://drive.google.com/file/d/1VdE3OkrIlIzLRPuT-2So-Xq_5gPaxm0t/view?usp=sharing 53 | [segformerb3]: https://drive.google.com/file/d/1-OmW3xRD3WAbJTzktPC-VMOF5WMsN8XT/view?usp=sharing 54 | [segformerb2]: https://drive.google.com/file/d/1AcgEK5aWMJzpe8tsfauqhragR0nBHyPh/view?usp=sharing 55 | [segformerb1]: https://drive.google.com/file/d/18PN_P3ajcJi_5Q2v8b4BP9O4VdNCpt6m/view?usp=sharing 56 | [topformert]: https://drive.google.com/file/d/1OnS3_PwjJuNMWCKisreNxw_Lma8uR8bV/view?usp=sharing 57 | [topformers]: https://drive.google.com/file/d/19041fMb4HuDyNhIYdW1r5612FyzpexP0/view?usp=sharing 58 | [topformerb]: https://drive.google.com/file/d/1m7CxYKWAyJzl5W3cj1vwsW4DfqAb_rqz/view?usp=sharing -------------------------------------------------------------------------------- /docs/OTHER_DATASETS.md: -------------------------------------------------------------------------------- 1 | # Semantic Segmentation Datasets 2 | 3 | ## General 4 | 5 | * [COCO-Stuff](https://github.com/nightrome/cocostuff) 6 | * [PASCAL-Context](https://cs.stanford.edu/~roozbeh/pascal-context/) 7 | * [PASCAL-VOC](http://host.robots.ox.ac.uk/pascal/VOC/) 8 | * [MSeg](https://github.com/mseg-dataset/mseg-api) 9 | * [ADE20K](http://groups.csail.mit.edu/vision/datasets/ADE20K/) 10 | * [Places365](http://places2.csail.mit.edu/) 11 | 12 | ## Outdoor 13 | 14 | * [CityScapes](https://www.cityscapes-dataset.com/) 15 | * [KITTI](http://www.cvlibs.net/datasets/kitti/) 16 | * [Mapillary Vistas](https://www.mapillary.com/dataset/vistas?lat=20&lng=0&z=1.5&pKey=xyW6a0ZmrJtjLw2iJ71Oqg) 17 | * [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) 18 | * [Standford Background](http://dags.stanford.edu/projects/scenedataset.html) 19 | * [ApolloScape](http://apolloscape.auto/) 20 | * [BDD100K](https://bdd-data.berkeley.edu/) 21 | * [WoodScape](https://github.com/valeoai/WoodScape) 22 | * [IDD](http://idd.insaan.iiit.ac.in/) 23 | * [DADA-2000](https://github.com/JWFangit/LOTVS-DADA) 24 | * [Street Hazards](https://github.com/hendrycks/anomaly-seg) 25 | * [UNDD](https://github.com/sauradip/night_image_semantic_segmentation) 26 | * [WildDash](https://wilddash.cc/) 27 | * [A2D2](https://www.a2d2.audi/a2d2/en/dataset.html) 28 | 29 | ## Indoor 30 | 31 | * [ScanNet](http://www.scan-net.org/) 32 | * [Sun-RGBD](https://rgbd.cs.princeton.edu/) 33 | * [SceneNet](https://robotvault.bitbucket.io/) 34 | * [2D-3D-Semantics](https://github.com/alexsax/2D-3D-Semantics) 35 | * [NYUDepthv2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) 36 | * [SUN3D](http://sun3d.cs.princeton.edu/) 37 | 38 | ## Human Parts 39 | 40 | * [LIP/CIHP](http://sysu-hcp.net/lip/index.php) 41 | * [MHP](https://github.com/ZhaoJ9014/Multi-Human-Parsing) 42 | * [DeepFashion2](https://github.com/switchablenorms/DeepFashion2) 43 | * [PASCAL-Person-Part](http://roozbehm.info/pascal-parts/pascal-parts.html) 44 | * 
[PIC](http://picdataset.com/challenge/task/download/) 45 | * [iMat](https://github.com/visipedia/imat_comp) 46 | 47 | ## Food 48 | 49 | * [FoodSeg103](https://xiongweiwu.github.io/foodseg103.html) 50 | 51 | ## Binary 52 | 53 | * [SBCoseg](http://www.mlmrlab.com/cosegmentation_dataset_downloadC.html) 54 | * [DeepFish](https://github.com/alzayats/DeepFish) 55 | * [MVTecAD](https://www.mvtec.com/company/research/datasets/mvtec-ad/) 56 | * [LLAMAS](https://unsupervised-llamas.com/llamas/) 57 | 58 | ## Boundary Segmentation 59 | 60 | * [SBD](http://home.bharathh.info/pubs/codes/SBD/download.html) 61 | * [SketchyScene](https://github.com/SketchyScene/SketchyScene) 62 | * [TextSeg](https://github.com/SHI-Labs/Rethinking-Text-Segmentation) 63 | 64 | ## Synthetic 65 | 66 | * [EDEN](https://lhoangan.github.io/eden/) 67 | * [Synscapes](https://7dlabs.com/synscapes-overview) 68 | * [SYNTHIA](https://synthia-dataset.net/) 69 | * [GTA5](https://download.visinf.tu-darmstadt.de/data/from_games/) 70 | 71 | ## Robot-view 72 | 73 | * [Robot Home](http://mapir.isa.uma.es/mapirwebsite/index.php/mapir-downloads/203-robot-at-home-dataset.html) 74 | * [RobotriX](https://github.com/3dperceptionlab/therobotrix) 75 | * [Gibson Env](http://gibsonenv.stanford.edu/) 76 | 77 | ## Medical 78 | 79 | * [BraTS2015](https://www.smir.ch/BRATS/Start2015) 80 | * [Medical-Decathlon](http://medicaldecathlon.com/) 81 | * [PROMISE12](https://promise12.grand-challenge.org/) 82 | * [REFUGE](https://bitbucket.org/woalsdnd/refuge/src/master/) 83 | * [BIMCV-COVID-19](https://github.com/BIMCV-CSUSP/BIMCV-COVID-19) 84 | * [OpenEDS](https://research.fb.com/programs/openeds-challenge) 85 | * [Retinal-Microsurgery](https://sites.google.com/site/sznitr/home) 86 | * [CoNSeP](https://warwick.ac.uk/fac/sci/dcs/research/tia/data/hovernet/) 87 | * [ISIC-2018-Task1](https://challenge2018.isic-archive.com/task1/) 88 | * [Cata7](https://github.com/nizhenliang/RAUNet) 89 | * [ROSE](https://imed.nimte.ac.cn/dataofrose.html) 90 | * [SegTHOR](https://competitions.codalab.org/competitions/21145) 91 | * [CAMEL](https://github.com/ThoroughImages/CAMEL) 92 | * [CryoNuSeg](https://github.com/masih4/CryoNuSeg) 93 | * [OpenEDS2020](https://research.fb.com/programs/openeds-2020-challenge/) 94 | * [VocalFolds](https://github.com/imesluh/vocalfolds) 95 | * [Medico](https://multimediaeval.github.io/editions/2020/tasks/medico/) 96 | * [20MioEyeDS](https://unitc-my.sharepoint.com/personal/iitfu01_cloud_uni-tuebingen_de/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fiitfu01%5Fcloud%5Funi%2Dtuebingen%5Fde%2FDocuments%2F20MioEyeDS&originalPath=aHR0cHM6Ly91bml0Yy1teS5zaGFyZXBvaW50LmNvbS86ZjovZy9wZXJzb25hbC9paXRmdTAxX2Nsb3VkX3VuaS10dWViaW5nZW5fZGUvRXZyTlBkdGlnRlZIdENNZUZLU3lMbFVCZXBPY2JYMG5Fa2Ftd2VlWmEwczlTUT9ydGltZT1zcWtvTV9CYzJVZw) 97 | * [BrainMRI](https://www.kaggle.com/mateuszbuda/lgg-mri-segmentation) 98 | * [Liver Tumor](https://www.kaggle.com/andrewmvd/liver-tumor-segmentation) 99 | * [MRI Hippocampus](https://www.kaggle.com/sabermalek/mrihs) 100 | 101 | ## Aerial 102 | 103 | * [RIT-18](https://github.com/rmkemker/RIT-18) 104 | * [PolSF](https://github.com/liuxuvip/PolSF) 105 | * [AIRS](https://www.airs-dataset.com/) 106 | * [UOPNOA](https://zenodo.org/record/4648002) 107 | * [LandCover](https://landcover.ai/) 108 | * [ICG](https://www.kaggle.com/bulentsiyah/semantic-drone-dataset) 109 | 110 | ## Video 111 | 112 | * [DAVIS](https://davischallenge.org/) 113 | * [SESIV](https://sites.google.com/view/ltnghia/research/sesiv) 114 | * [YouTube-VOS](https://youtube-vos.org/) 
115 | 116 | ## Others 117 | 118 | * [SUIM](http://irvlab.cs.umn.edu/resources/suim-dataset) 119 | * [Cam2BEV](https://github.com/ika-rwth-aachen/Cam2BEV) 120 | * [LabPics](https://www.kaggle.com/sagieppel/labpics-chemistry-labpics-medical) 121 | * [CreativeFlow+](https://www.cs.toronto.edu/creativeflow/) 122 | * [RoadAnomaly21](https://segmentmeifyoucan.com/datasets) 123 | * [RoadObstacle21](https://segmentmeifyoucan.com/datasets) 124 | * [HouseExpo](https://github.com/teaganli/houseexpo/) 125 | * [D2S](https://www.mvtec.com/company/research/datasets/mvtec-d2s/) -------------------------------------------------------------------------------- /scripts/calc_class_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import io 3 | 4 | 5 | def calc_class_weights(files, n_classes): 6 | pixels = {} 7 | for file in files: 8 | lbl_path = str(file).split('.')[0].replace('images', 'labels') 9 | label = io.read_image(lbl_path) 10 | for i in range(n_classes): 11 | if pixels.get(i) is not None: 12 | pixels[i] += [(label == i).sum()] 13 | else: 14 | pixels[i] = [(label == i).sum()] 15 | 16 | class_freq = torch.tensor([sum(v).item() for v in pixels.values()]) 17 | weights = 1 / torch.log1p(class_freq) 18 | weights *= n_classes 19 | weights /= weights.sum() 20 | return weights -------------------------------------------------------------------------------- /scripts/export_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from PIL import Image 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | 7 | 8 | def create_calibrate_data(image_folder, save_path): 9 | dataset = [] 10 | mean = np.array([0.485, 0.456, 0.406])[None, None, :] 11 | std = np.array([0.229, 0.224, 0.225])[None, None, :] 12 | files = list(Path(image_folder).glob('*.jpg'))[:100] 13 | for file in tqdm(files): 14 | image = Image.open(file).convert('RGB') 15 | image = image.resize((512, 512)) 16 | image = np.array(image, dtype=np.float32) 17 | image /= 255 18 | image -= mean 19 | image /= std 20 | dataset.append(image) 21 | dataset = np.stack(dataset, axis=0) 22 | np.save(save_path, dataset) 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--dataset-path', type=str, default='/home/sithu/datasets/SmithCVPR2013_dataset_resized/images') 28 | parser.add_argument('--save-path', type=str, default='output/calibrate_data') 29 | args = parser.parse_args() 30 | 31 | create_calibrate_data(args.dataset_path, args.save_path) 32 | 33 | -------------------------------------------------------------------------------- /scripts/onnx_infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import onnxruntime 4 | from PIL import Image 5 | from semseg.utils.visualize import generate_palette 6 | from semseg.utils.utils import timer 7 | 8 | 9 | class Inference: 10 | def __init__(self, model: str) -> None: 11 | self.session = onnxruntime.InferenceSession(model) 12 | self.input_details = self.session.get_inputs()[0] 13 | self.palette = generate_palette(self.session.get_outputs()[0].shape[1], background=True) 14 | self.img_size = self.input_details.shape[-2:] 15 | self.mean = np.array([0.485, 0.456, 0.406]).reshape(-1, 1, 1) 16 | self.std = np.array([0.229, 0.224, 0.225]).reshape(-1, 1, 1) 17 | 18 | def preprocess(self, image: Image.Image) -> np.ndarray: 19 | image = 
image.resize(self.img_size) 20 | image = np.array(image, dtype=np.float32).transpose(2, 0, 1) 21 | image /= 255 22 | image -= self.mean 23 | image /= self.std 24 | image = image[np.newaxis, ...] 25 | return image 26 | 27 | def postprocess(self, seg_map: np.ndarray) -> np.ndarray: 28 | seg_map = np.argmax(seg_map, axis=1).astype(int) 29 | seg_map = self.palette[seg_map] 30 | return seg_map.squeeze() 31 | 32 | @timer 33 | def model_forward(self, img: np.ndarray) -> np.ndarray: 34 | return self.session.run(None, {self.input_details.name: img})[0] 35 | 36 | def predict(self, img_path: str) -> Image.Image: 37 | image = Image.open(img_path).convert('RGB') 38 | image = self.preprocess(image) 39 | seg_map = self.model_forward(image) 40 | seg_map = self.postprocess(seg_map) 41 | return seg_map.astype(np.uint8) 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--model', type=str, default='output/DDRNet_23slim_HELEN_59_75.onnx') 47 | parser.add_argument('--img-path', type=str, default='assests/faces/27409477_1.jpg') 48 | args = parser.parse_args() 49 | 50 | session = Inference(args.model) 51 | seg_map = session.predict(args.img_path) 52 | seg_map = Image.fromarray(seg_map) 53 | seg_map.save(f"{args.img_path.split('.')[0]}_out.png") -------------------------------------------------------------------------------- /scripts/openvino_infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from PIL import Image 4 | from pathlib import Path 5 | from openvino.inference_engine import IECore 6 | from semseg.utils.visualize import generate_palette 7 | from semseg.utils.utils import timer 8 | 9 | 10 | class Inference: 11 | def __init__(self, model: str) -> None: 12 | files = Path(model).iterdir() 13 | 14 | for file in files: 15 | if file.suffix == '.xml': 16 | model = str(file) 17 | elif file.suffix == '.bin': 18 | weights = str(file) 19 | ie = IECore() 20 | model = ie.read_network(model=model, weights=weights) 21 | self.input_info = next(iter(model.input_info)) 22 | self.output_info = next(iter(model.outputs)) 23 | self.img_size = model.input_info['input'].input_data.shape[-2:] 24 | self.palette = generate_palette(11, background=True) 25 | self.engine = ie.load_network(network=model, device_name='CPU') 26 | 27 | self.mean = np.array([0.485, 0.456, 0.406]).reshape(-1, 1, 1) 28 | self.std = np.array([0.229, 0.224, 0.225]).reshape(-1, 1, 1) 29 | 30 | def preprocess(self, image: Image.Image) -> np.ndarray: 31 | image = image.resize(self.img_size) 32 | image = np.array(image, dtype=np.float32).transpose(2, 0, 1) 33 | image /= 255 34 | image -= self.mean 35 | image /= self.std 36 | image = image[np.newaxis, ...] 
37 | return image 38 | 39 | def postprocess(self, seg_map: np.ndarray) -> np.ndarray: 40 | seg_map = np.argmax(seg_map, axis=1).astype(int) 41 | seg_map = self.palette[seg_map] 42 | return seg_map.squeeze() 43 | 44 | @timer 45 | def model_forward(self, img: np.ndarray) -> np.ndarray: 46 | return self.engine.infer(inputs={self.input_info: img})[self.output_info] 47 | 48 | def predict(self, img_path: str) -> Image.Image: 49 | image = Image.open(img_path).convert('RGB') 50 | image = self.preprocess(image) 51 | seg_map = self.model_forward(image) 52 | seg_map = self.postprocess(seg_map) 53 | return seg_map.astype(np.uint8) 54 | 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('--model', type=str, default='output/ddrnet_openvino') 59 | parser.add_argument('--img-path', type=str, default='assests/faces/27409477_1.jpg') 60 | args = parser.parse_args() 61 | 62 | session = Inference(args.model) 63 | seg_map = session.predict(args.img_path) 64 | seg_map = Image.fromarray(seg_map) 65 | seg_map.save(f"{args.img_path.split('.')[0]}_out.png") 66 | -------------------------------------------------------------------------------- /scripts/preprocess_celebamaskhq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from tqdm import tqdm 4 | from pathlib import Path 5 | from PIL import Image 6 | 7 | 8 | def main(root): 9 | root = Path(root) 10 | annot_dir = root / 'CelebAMask-HQ-label' 11 | annot_dir.mkdir(exist_ok=True) 12 | 13 | train_lists = [] 14 | test_lists = [] 15 | val_lists = [] 16 | 17 | names = [ 18 | 'skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 19 | 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth' 20 | ] 21 | num_images = 30000 22 | 23 | for folder in root.iterdir(): 24 | if folder.is_dir(): 25 | if folder.name == 'CelebAMask-HQ-mask-anno': 26 | print("Transforming separate masks into one-hot mask...") 27 | for i in tqdm(range(num_images)): 28 | folder_num = i // 2000 29 | label = np.zeros((512, 512)) 30 | for idx, name in enumerate(names): 31 | fname = folder / f"{folder_num}" / f"{str(i).rjust(5, '0')}_{name}.png" 32 | if fname.exists(): 33 | img = Image.open(fname).convert('P') 34 | img = np.array(img) 35 | label[img != 0] = idx + 1 36 | 37 | label = Image.fromarray(label.astype(np.uint8)) 38 | label.save(annot_dir / f"{i}.png") 39 | 40 | print("Splitting into train/val/test...") 41 | 42 | with open(root / "CelebA-HQ-to-CelebA-mapping.txt") as f: 43 | lines = f.read().splitlines()[1:] 44 | image_list = [int(line.split()[1]) for line in lines] 45 | 46 | 47 | for idx, fname in enumerate(image_list): 48 | if fname >= 162771 and fname < 182638: 49 | val_lists.append(f"{idx}\n") 50 | 51 | elif fname >= 182638: 52 | test_lists.append(f"{idx}\n") 53 | 54 | else: 55 | train_lists.append(f"{idx}\n") 56 | 57 | print(f"Train Size: {len(train_lists)}") 58 | print(f"Val Size: {len(val_lists)}") 59 | print(f"Test Size: {len(test_lists)}") 60 | 61 | with open(root / 'train_list.txt', 'w') as f: 62 | f.writelines(train_lists) 63 | 64 | with open(root / 'val_list.txt', 'w') as f: 65 | f.writelines(val_lists) 66 | 67 | with open(root / 'test_list.txt', 'w') as f: 68 | f.writelines(test_lists) 69 | 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--root', type=str, default='/home/sithu/datasets/CelebAMask-HQ') 74 | args = parser.parse_args() 75 | main(args.root) 
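For reference, each mask that this script writes to `CelebAMask-HQ-label` is a single-channel 512x512 PNG whose pixel values are class indices (0 = background, then 1 = skin through 18 = cloth, following the `names` list above). A minimal check of one generated file (a sketch; the path assumes the script's default `--root` and should be adjusted to your dataset copy):

```python
# Inspect one mask produced by preprocess_celebamaskhq.py
# (sketch; path assumes the script's default --root).
import numpy as np
from PIL import Image

mask = np.array(Image.open('/home/sithu/datasets/CelebAMask-HQ/CelebAMask-HQ-label/0.png'))
print(mask.shape, mask.dtype)  # (512, 512) uint8
print(np.unique(mask))         # subset of 0..18 (0 = background)
```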
-------------------------------------------------------------------------------- /scripts/tflite_infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import tflite_runtime.interpreter as tflite 4 | from PIL import Image 5 | from semseg.utils.visualize import generate_palette 6 | from semseg.utils.utils import timer 7 | 8 | 9 | class Inference: 10 | def __init__(self, model: str) -> None: 11 | self.interpreter = tflite.Interpreter(model) 12 | self.interpreter.allocate_tensors() 13 | 14 | self.input_details = self.interpreter.get_input_details()[0] 15 | self.output_details = self.interpreter.get_output_details()[0] 16 | self.palette = generate_palette(self.output_details['shape'][-1], background=True) 17 | self.img_size = self.input_details['shape'][1:3] 18 | self.mean = np.array([0.485, 0.456, 0.406])[None, None, :] 19 | self.std = np.array([0.229, 0.224, 0.225])[None, None, :] 20 | 21 | def preprocess(self, image: Image.Image) -> np.ndarray: 22 | image = image.resize(self.img_size) 23 | image = np.array(image, dtype=np.float32) 24 | image /= 255 25 | image -= self.mean 26 | image /= self.std 27 | if self.input_details['dtype'] == np.int8 or self.input_details['dtype'] == np.uint8: 28 | scale, zero_point = self.input_details['quantization'] 29 | image /= scale 30 | image += zero_point 31 | image = image.astype(self.input_details['dtype']) 32 | return image[np.newaxis, ...] 33 | 34 | def postprocess(self, seg_map: np.ndarray) -> np.ndarray: 35 | if self.output_details['dtype'] == np.int8 or self.output_details['dtype'] == np.uint8: 36 | scale, zero_point = self.output_details['quantization'] 37 | seg_map = scale * (seg_map - zero_point) 38 | seg_map = np.argmax(seg_map, axis=-1).astype(int) 39 | seg_map = self.palette[seg_map] 40 | return seg_map.squeeze() 41 | 42 | @timer 43 | def model_forward(self, img: np.ndarray) -> np.ndarray: 44 | self.interpreter.set_tensor(self.input_details['index'], img) 45 | self.interpreter.invoke() 46 | return self.interpreter.get_tensor(self.output_details['index']) 47 | 48 | def predict(self, img_path: str) -> Image.Image: 49 | image = Image.open(img_path).convert('RGB') 50 | image = self.preprocess(image) 51 | seg_map = self.model_forward(image) 52 | seg_map = self.postprocess(seg_map) 53 | return seg_map.astype(np.uint8) 54 | 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('--model', type=str, default='output/ddrnet_tflite2/ddrnet_float16.tflite') 59 | parser.add_argument('--img-path', type=str, default='assests/faces/27409477_1.jpg') 60 | args = parser.parse_args() 61 | 62 | session = Inference(args.model) 63 | seg_map = session.predict(args.img_path) 64 | seg_map = Image.fromarray(seg_map) 65 | seg_map.save(f"{args.img_path.split('.')[0]}_out.png") -------------------------------------------------------------------------------- /semseg/__init__.py: -------------------------------------------------------------------------------- 1 | from tabulate import tabulate 2 | from semseg import models 3 | from semseg import datasets 4 | from semseg.models import backbones, heads 5 | 6 | 7 | def show_models(): 8 | model_names = models.__all__ 9 | numbers = list(range(1, len(model_names)+1)) 10 | print(tabulate({'No.': numbers, 'Model Names': model_names}, headers='keys')) 11 | 12 | 13 | def show_backbones(): 14 | backbone_names = backbones.__all__ 15 | variants = [] 16 | for name in backbone_names: 17 | try: 18 | 
variants.append(list(eval(f"backbones.{name.lower()}_settings").keys())) 19 | except: 20 | variants.append('-') 21 | print(tabulate({'Backbone Names': backbone_names, 'Variants': variants}, headers='keys')) 22 | 23 | 24 | def show_heads(): 25 | head_names = heads.__all__ 26 | numbers = list(range(1, len(head_names)+1)) 27 | print(tabulate({'No.': numbers, 'Heads': head_names}, headers='keys')) 28 | 29 | 30 | def show_datasets(): 31 | dataset_names = datasets.__all__ 32 | numbers = list(range(1, len(dataset_names)+1)) 33 | print(tabulate({'No.': numbers, 'Datasets': dataset_names}, headers='keys')) 34 | -------------------------------------------------------------------------------- /semseg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .ade20k import ADE20K 2 | from .camvid import CamVid 3 | from .cityscapes import CityScapes 4 | from .pascalcontext import PASCALContext 5 | from .cocostuff import COCOStuff 6 | from .sunrgbd import SunRGBD 7 | from .mapillary import MapillaryVistas 8 | from .mhpv1 import MHPv1 9 | from .mhpv2 import MHPv2 10 | from .lip import LIP 11 | from .cihp import CIHP, CCIHP 12 | from .atr import ATR 13 | from .suim import SUIM 14 | from .helen import HELEN 15 | from .lapa import LaPa 16 | from .ibugmask import iBugMask 17 | from .celebamaskhq import CelebAMaskHQ 18 | from .facesynthetics import FaceSynthetics 19 | 20 | 21 | __all__ = [ 22 | 'CamVid', 23 | 'CityScapes', 24 | 'ADE20K', 25 | 'MHPv1', 26 | 'MHPv2', 27 | 'LIP', 28 | 'CIHP', 29 | 'CCIHP', 30 | 'ATR', 31 | 'PASCALContext', 32 | 'COCOStuff', 33 | 'SUIM', 34 | 'SunRGBD', 35 | 'MapillaryVistas', 36 | 'HELEN', 37 | 'LaPa', 38 | 'iBugMask', 39 | 'CelebAMaskHQ', 40 | 'FaceSynthetics', 41 | ] -------------------------------------------------------------------------------- /semseg/datasets/ade20k.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data import Dataset 4 | from torchvision import io 5 | from pathlib import Path 6 | from typing import Tuple 7 | 8 | 9 | class ADE20K(Dataset): 10 | CLASSES = [ 11 | 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 12 | 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 13 | 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', 14 | 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 15 | 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 16 | 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 17 | 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', 18 | 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 19 | 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', 20 | 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 
'animal', 'bicycle', 'lake', 21 | 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 22 | 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag' 23 | ] 24 | 25 | PALETTE = torch.tensor([ 26 | [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 27 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 28 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 29 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 30 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 31 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 32 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 33 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], 34 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], 35 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], 36 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], 37 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], 38 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], 39 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], 40 | [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], 41 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], 42 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], 43 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], 44 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0], [92, 0, 255] 45 | ]) 46 | 47 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 48 | super().__init__() 49 | assert split in ['train', 'val'] 50 | split = 'training' if split == 'train' else 'validation' 51 | self.transform = transform 52 | self.n_classes = len(self.CLASSES) 53 | self.ignore_label = -1 54 | 55 | img_path = Path(root) / 'images' / split 56 | self.files = list(img_path.glob('*.jpg')) 57 | 58 | if not self.files: 59 | raise Exception(f"No images found in {img_path}") 60 | print(f"Found {len(self.files)} {split} images.") 61 | 62 | def __len__(self) -> int: 63 | return len(self.files) 64 | 65 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 66 | img_path = str(self.files[index]) 67 | lbl_path = str(self.files[index]).replace('images', 'annotations').replace('.jpg', '.png') 68 | 69 | image = io.read_image(img_path) 70 | label = io.read_image(lbl_path) 71 | 72 | if 
self.transform: 73 | image, label = self.transform(image, label) 74 | return image, label.squeeze().long() - 1 75 | 76 | 77 | if __name__ == '__main__': 78 | from semseg.utils.visualize import visualize_dataset_sample 79 | visualize_dataset_sample(ADE20K, '/home/sithu/datasets/ADEChallenge/ADEChallengeData2016') -------------------------------------------------------------------------------- /semseg/datasets/atr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data import Dataset 4 | from torchvision import io 5 | from pathlib import Path 6 | from typing import Tuple 7 | 8 | 9 | class ATR(Dataset): 10 | """Single Person Fashion Dataset 11 | https://openaccess.thecvf.com/content_iccv_2015/papers/Liang_Human_Parsing_With_ICCV_2015_paper.pdf 12 | 13 | https://github.com/lemondan/HumanParsing-Dataset 14 | num_classes: 17+background 15 | 16000 train images 16 | 700 val images 17 | 1000 test images with labels 18 | """ 19 | CLASSES = ['background', 'hat', 'hair', 'sunglass', 'upper-clothes', 'skirt', 'pants', 'dress', 'belt', 'left-shoe', 'right-shoe', 'face', 'left-leg', 'right-leg', 'left-arm', 'right-arm', 'bag', 'scarf'] 20 | PALETTE = torch.tensor([[0, 0, 0], [127, 0, 0], [254, 0, 0], [0, 84, 0], [169, 0, 50], [254, 84, 0], [255, 0, 84], [0, 118, 220], [84, 84, 0], [0, 84, 84], [84, 50, 0], [51, 85, 127], [0, 127, 0], [0, 0, 254], [50, 169, 220], [0, 254, 254], [84, 254, 169], [169, 254, 84]]) 21 | 22 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 23 | super().__init__() 24 | assert split in ['train', 'val', 'test'] 25 | self.transform = transform 26 | self.n_classes = len(self.CLASSES) 27 | self.ignore_label = 255 28 | 29 | img_path = Path(root) / 'humanparsing' / 'JPEGImages' 30 | self.files = list(img_path.glob('*.jpg')) 31 | if split == 'train': 32 | self.files = self.files[:16000] 33 | elif split == 'val': 34 | self.files = self.files[16000:16700] 35 | else: 36 | self.files = self.files[16700:17700] 37 | 38 | if not self.files: 39 | raise Exception(f"No images found in {img_path}") 40 | print(f"Found {len(self.files)} {split} images.") 41 | 42 | def __len__(self) -> int: 43 | return len(self.files) 44 | 45 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 46 | img_path = str(self.files[index]) 47 | lbl_path = str(self.files[index]).replace('JPEGImages', 'SegmentationClassAug').replace('.jpg', '.png') 48 | 49 | image = io.read_image(img_path) 50 | label = io.read_image(lbl_path) 51 | 52 | if self.transform: 53 | image, label = self.transform(image, label) 54 | return image, label.squeeze().long() 55 | 56 | 57 | if __name__ == '__main__': 58 | from semseg.utils.visualize import visualize_dataset_sample 59 | visualize_dataset_sample(ATR, '/home/sithu/datasets/LIP/ATR') -------------------------------------------------------------------------------- /semseg/datasets/camvid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data import Dataset 4 | from torchvision import io 5 | from pathlib import Path 6 | from typing import Tuple 7 | 8 | 9 | class CamVid(Dataset): 10 | """ 11 | num_classes: 11 12 | all_num_classes: 31 13 | """ 14 | CLASSES = ['Sky', 'Building', 'Pole', 'Road', 'Pavement', 'Tree', 'SignSymbol', 'Fence', 'Car', 'Pedestrian', 'Bicyclist'] 15 | CLASSES_ALL = ['Wall', 'Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car', 
'CarLuggage', 'Child', 'Pole', 'Fence', 'LaneDrive', 'LaneNonDrive', 'MiscText', 'Motorcycle/Scooter', 'OtherMoving', 'ParkingBlock', 'Pedestrian', 'Road', 'RoadShoulder', 'Sidewalk', 'SignSymbol', 'Sky', 'SUV/PickupTruck', 'TrafficCone', 'TrafficLight', 'Train', 'Tree', 'Truck/Bus', 'Tunnel', 'VegetationMisc'] 16 | PALETTE = torch.tensor([[128, 128, 128], [128, 0, 0], [192, 192, 128], [128, 64, 128], [0, 0, 192], [128, 128, 0], [192, 128, 128], [64, 64, 128], [64, 0, 128], [64, 64, 0], [0, 128, 192]]) 17 | PALETTE_ALL = torch.tensor([[64, 192, 0], [64, 128, 64], [192, 0, 128], [0, 128, 192], [0, 128, 64], [128, 0, 0], [64, 0, 128], [64, 0, 192], [192, 128, 64], [192, 192, 128], [64, 64, 128], [128, 0, 192], [192, 0, 64], [128, 128, 64], [192, 0, 192], [128, 64, 64], [64, 192, 128], [64, 64, 0], [128, 64, 128], [128, 128, 192], [0, 0, 192], [192, 128, 128], [128, 128, 128], [64, 128, 192], [0, 0, 64], [0, 64, 64], [192, 64, 128], [128, 128, 0], [192, 128, 192], [64, 0, 64], [192, 192, 0]]) 18 | 19 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 20 | super().__init__() 21 | assert split in ['train', 'val', 'test'] 22 | self.split = split 23 | self.transform = transform 24 | self.n_classes = len(self.CLASSES) 25 | self.ignore_label = -1 26 | 27 | img_path = Path(root) / split 28 | self.files = list(img_path.glob("*.png")) 29 | 30 | if not self.files: 31 | raise Exception(f"No images found in {img_path}") 32 | print(f"Found {len(self.files)} {split} images.") 33 | 34 | def __len__(self) -> int: 35 | return len(self.files) 36 | 37 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 38 | img_path = str(self.files[index]) 39 | lbl_path = str(self.files[index]).replace(self.split, self.split + '_labels').replace('.png', '_L.png') 40 | 41 | image = io.read_image(img_path) 42 | label = io.read_image(lbl_path) 43 | 44 | if self.transform: 45 | image, label = self.transform(image, label) 46 | return image, self.encode(label).long() - 1 47 | 48 | def encode(self, label: Tensor) -> Tensor: 49 | label = label.permute(1, 2, 0) 50 | mask = torch.zeros(label.shape[:-1]) 51 | 52 | for index, color in enumerate(self.PALETTE): 53 | bool_mask = torch.eq(label, color) 54 | class_map = torch.all(bool_mask, dim=-1) 55 | mask[class_map] = index + 1 56 | return mask 57 | 58 | 59 | if __name__ == '__main__': 60 | from semseg.utils.visualize import visualize_dataset_sample 61 | visualize_dataset_sample(CamVid, '/home/sithu/datasets/CamVid') -------------------------------------------------------------------------------- /semseg/datasets/celebamaskhq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data import Dataset 4 | from torchvision import io 5 | from pathlib import Path 6 | from typing import Tuple 7 | from torchvision import transforms as T 8 | 9 | 10 | class CelebAMaskHQ(Dataset): 11 | CLASSES = [ 12 | 'background', 'skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 13 | 'r_ear', 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth' 14 | ] 15 | PALETTE = torch.tensor([ 16 | [0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], 17 | [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0] 18 | ]) 19 | 20 | def __init__(self, root: str, split: str = 'train', transform = 
None) -> None: 21 | super().__init__() 22 | assert split in ['train', 'val', 'test'] 23 | self.root = Path(root) 24 | self.transform = transform 25 | self.n_classes = len(self.CLASSES) 26 | self.ignore_label = 255 27 | self.resize = T.Resize((512, 512)) 28 | 29 | with open(self.root / f'{split}_list.txt') as f: 30 | self.files = f.read().splitlines() 31 | 32 | if not self.files: 33 | raise Exception(f"No images found in {root}") 34 | print(f"Found {len(self.files)} {split} images.") 35 | 36 | def __len__(self) -> int: 37 | return len(self.files) 38 | 39 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 40 | img_path = self.root / 'CelebA-HQ-img' / f"{self.files[index]}.jpg" 41 | lbl_path = self.root / 'CelebAMask-HQ-label' / f"{self.files[index]}.png" 42 | image = io.read_image(str(img_path)) 43 | image = self.resize(image) 44 | label = io.read_image(str(lbl_path)) 45 | 46 | if self.transform: 47 | image, label = self.transform(image, label) 48 | return image, label.squeeze().long() 49 | 50 | 51 | if __name__ == '__main__': 52 | from semseg.utils.visualize import visualize_dataset_sample 53 | visualize_dataset_sample(CelebAMaskHQ, '/home/sithu/datasets/CelebAMask-HQ') -------------------------------------------------------------------------------- /semseg/datasets/cihp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data import Dataset 4 | from torchvision import io 5 | from pathlib import Path 6 | from typing import Tuple 7 | 8 | 9 | class CIHP(Dataset): 10 | """This has Best Human Parsing Labels 11 | num_classes: 19+background 12 | 28280 train images 13 | 5000 val images 14 | """ 15 | CLASSES = ['background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes', 'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt', 'face', 'left-arm', 'right-arm', 'left-leg', 'right-leg', 'left-shoe', 'right-shoe'] 16 | PALETTE = torch.tensor([[120, 120, 120], [127, 0, 0], [254, 0, 0], [0, 84, 0], [169, 0, 50], [254, 84, 0], [255, 0, 84], [0, 118, 220], [84, 84, 0], [0, 84, 84], [84, 50, 0], [51, 85, 127], [0, 127, 0], [0, 0, 254], [50, 169, 220], [0, 254, 254], [84, 254, 169], [169, 254, 84], [254, 254, 0], [254, 169, 0]]) 17 | 18 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 19 | super().__init__() 20 | assert split in ['train', 'val'] 21 | split = 'Training' if split == 'train' else 'Validation' 22 | self.transform = transform 23 | self.n_classes = len(self.CLASSES) 24 | self.ignore_label = 255 25 | 26 | img_path = Path(root) / 'instance-level_human_parsing' / split / 'Images' 27 | self.files = list(img_path.glob('*.jpg')) 28 | 29 | if not self.files: 30 | raise Exception(f"No images found in {img_path}") 31 | print(f"Found {len(self.files)} {split} images.") 32 | 33 | def __len__(self) -> int: 34 | return len(self.files) 35 | 36 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 37 | img_path = str(self.files[index]) 38 | lbl_path = str(self.files[index]).replace('Images', 'Category_ids').replace('.jpg', '.png') 39 | 40 | image = io.read_image(img_path) 41 | label = io.read_image(lbl_path) 42 | 43 | if self.transform: 44 | image, label = self.transform(image, label) 45 | return image, label.squeeze().long() 46 | 47 | 48 | class CCIHP(CIHP): 49 | CLASSES = ['background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes', 'facemask', 'coat', 'socks', 'pants', 'torso-skin', 'scarf', 'skirt', 'face', 'left-arm', 'right-arm', 
'left-leg', 'right-leg', 'left-shoe', 'right-shoe', 'bag', 'others'] 50 | PALETTE = torch.tensor([[120, 120, 120], [127, 0, 0], [254, 0, 0], [0, 84, 0], [169, 0, 50], [254, 84, 0], [255, 0, 84], [0, 118, 220], [84, 84, 0], [0, 84, 84], [84, 50, 0], [51, 85, 127], [0, 127, 0], [0, 0, 254], [50, 169, 220], [0, 254, 254], [84, 254, 169], [169, 254, 84], [254, 254, 0], [254, 169, 0], [102, 254, 0], [182, 255, 0]]) 51 | 52 | 53 | if __name__ == '__main__': 54 | import sys 55 | sys.path.insert(0, '.') 56 | from semseg.utils.visualize import visualize_dataset_sample 57 | visualize_dataset_sample(CCIHP, 'C:\\Users\\sithu\\Documents\\Datasets\\LIP\\CIHP') -------------------------------------------------------------------------------- /semseg/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import Tensor 4 | from torch.utils.data import Dataset 5 | from torchvision import io 6 | from pathlib import Path 7 | from typing import Tuple 8 | 9 | 10 | class CityScapes(Dataset): 11 | """ 12 | num_classes: 19 13 | """ 14 | CLASSES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation', 15 | 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle'] 16 | 17 | PALETTE = torch.tensor([[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], 18 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]]) 19 | 20 | ID2TRAINID = {0: 255, 1: 255, 2: 255, 3: 255, 4: 255, 5: 255, 6: 255, 7: 0, 8: 1, 9: 255, 10: 255, 11: 2, 12: 3, 13: 4, 14: 255, 15: 255, 16: 255, 21 | 17: 5, 18: 255, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 29: 255, 30: 255, 31: 16, 32: 17, 33: 18, -1: -1} 22 | 23 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 24 | super().__init__() 25 | assert split in ['train', 'val', 'test'] 26 | self.transform = transform 27 | self.n_classes = len(self.CLASSES) 28 | self.ignore_label = 255 29 | 30 | self.label_map = np.arange(256) 31 | for id, trainid in self.ID2TRAINID.items(): 32 | self.label_map[id] = trainid 33 | 34 | img_path = Path(root) / 'leftImg8bit' / split 35 | self.files = list(img_path.rglob('*.png')) 36 | 37 | if not self.files: 38 | raise Exception(f"No images found in {img_path}") 39 | print(f"Found {len(self.files)} {split} images.") 40 | 41 | def __len__(self) -> int: 42 | return len(self.files) 43 | 44 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 45 | img_path = str(self.files[index]) 46 | lbl_path = str(self.files[index]).replace('leftImg8bit', 'gtFine').replace('.png', '_labelIds.png') 47 | 48 | image = io.read_image(img_path) 49 | label = io.read_image(lbl_path) 50 | 51 | if self.transform: 52 | image, label = self.transform(image, label) 53 | return image, self.encode(label.squeeze().numpy()).long() 54 | 55 | def encode(self, label: Tensor) -> Tensor: 56 | label = self.label_map[label] 57 | return torch.from_numpy(label) 58 | 59 | 60 | if __name__ == '__main__': 61 | from semseg.utils.visualize import visualize_dataset_sample 62 | visualize_dataset_sample(CityScapes, '/home/sithu/datasets/CityScapes') 63 | -------------------------------------------------------------------------------- /semseg/datasets/facesynthetics.py: 
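A quick aside on the CityScapes loader just above: rather than looking up ID2TRAINID per pixel, it bakes the mapping into a 256-entry array in __init__ and remaps the whole label image with a single fancy-indexing call (self.label_map[label]). The following is a minimal sketch of that trick, not repo code, using a made-up 2x2 labelIds image and only a three-entry subset of the real table (unmapped raw IDs are sent to the ignore value here, whereas the real table covers every ID from -1 to 33):

import numpy as np
import torch

# Toy subset of CityScapes.ID2TRAINID: road, sidewalk, building.
ID2TRAINID = {7: 0, 8: 1, 11: 2}
label_map = np.full(256, 255, dtype=np.int64)      # unmapped raw IDs -> ignore (255)
for raw_id, train_id in ID2TRAINID.items():
    label_map[raw_id] = train_id

raw = np.array([[7, 8], [11, 3]], dtype=np.uint8)  # fake 2x2 labelIds image
train = torch.from_numpy(label_map[raw])           # one vectorised remap, no pixel loop
print(train)                                       # tensor([[0, 1], [2, 255]])

Building the table once per dataset instance instead of once per sample keeps __getitem__ cheap; the same pattern works for any integer relabelling.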
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data import Dataset 4 | from torchvision import io 5 | from pathlib import Path 6 | from typing import Tuple 7 | 8 | 9 | class FaceSynthetics(Dataset): 10 | CLASSES = ['background', 'skin', 'nose', 'r-eye', 'l-eye', 'r-brow', 'l-brow', 'r-ear', 'l-ear', 'i-mouth', 't-lip', 'b-lip', 'neck', 'hair', 'beard', 'clothing', 'glasses', 'headwear', 'facewear'] 11 | PALETTE = torch.tensor([ 12 | [0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], 13 | [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0] 14 | ]) 15 | 16 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 17 | super().__init__() 18 | assert split in ['train', 'val', 'test'] 19 | if split == 'train': 20 | split = 'dataset_100000' 21 | elif split == 'val': 22 | split = 'dataset_1000' 23 | else: 24 | split = 'dataset_100' 25 | 26 | self.transform = transform 27 | self.n_classes = len(self.CLASSES) 28 | self.ignore_label = 255 29 | 30 | img_path = Path(root) / split 31 | images = img_path.glob('*.png') 32 | self.files = [path for path in images if '_seg' not in path.name] 33 | 34 | if not self.files: raise Exception(f"No images found in {root}") 35 | print(f"Found {len(self.files)} {split} images.") 36 | 37 | def __len__(self) -> int: 38 | return len(self.files) 39 | 40 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 41 | img_path = str(self.files[index]) 42 | lbl_path = str(self.files[index]).replace('.png', '_seg.png') 43 | image = io.read_image(str(img_path)) 44 | label = io.read_image(str(lbl_path)) 45 | 46 | if self.transform: 47 | image, label = self.transform(image, label) 48 | return image, label.squeeze().long() 49 | 50 | 51 | if __name__ == '__main__': 52 | import sys 53 | sys.path.insert(0, '.') 54 | from semseg.utils.visualize import visualize_dataset_sample 55 | visualize_dataset_sample(FaceSynthetics, 'C:\\Users\\sithu\\Documents\\Datasets') -------------------------------------------------------------------------------- /semseg/datasets/helen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data import Dataset 4 | from torchvision import io 5 | from pathlib import Path 6 | from typing import Tuple 7 | 8 | 9 | class HELEN(Dataset): 10 | CLASSES = ['background', 'skin', 'l-brow', 'r-brow', 'l-eye', 'r-eye', 'nose', 'u-lip', 'i-mouth', 'l-lip', 'hair'] 11 | PALETTE = torch.tensor([[0, 0 ,0], [127, 0, 0], [254, 0, 0], [0, 84, 0], [169, 0, 50], [254, 84, 0], [255, 0, 84], [0, 118, 220], [84, 84, 0], [0, 84, 84], [84, 50, 0]]) 12 | 13 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 14 | super().__init__() 15 | assert split in ['train', 'val', 'test'] 16 | self.transform = transform 17 | self.n_classes = len(self.CLASSES) 18 | self.ignore_label = 255 19 | 20 | self.files = self.get_files(root, split) 21 | if not self.files: raise Exception(f"No images found in {root}") 22 | print(f"Found {len(self.files)} {split} images.") 23 | 24 | def get_files(self, root: str, split: str): 25 | root = Path(root) 26 | if split == 'train': 27 | split = 'exemplars' 28 | elif split == 'val': 29 | split = 'tuning' 30 | else: 31 | split = 'testing' 32 | with open(root 
/ f'{split}.txt') as f: 33 | lines = f.read().splitlines() 34 | 35 | split_names = [line.split(',')[-1].strip() for line in lines if line != ''] 36 | files = (root / 'images').glob("*.jpg") 37 | files = list(filter(lambda x: x.stem in split_names, files)) 38 | return files 39 | 40 | def __len__(self) -> int: 41 | return len(self.files) 42 | 43 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 44 | img_path = str(self.files[index]) 45 | lbl_path = str(self.files[index]).split('.')[0].replace('images', 'labels') 46 | image = io.read_image(img_path) 47 | label = self.encode(lbl_path) 48 | 49 | if self.transform: 50 | image, label = self.transform(image, label) 51 | return image, label.squeeze().long() 52 | 53 | def encode(self, label_path: str) -> Tensor: 54 | mask_paths = sorted(list(Path(label_path).glob('*.png'))) 55 | for i, mask_path in enumerate(mask_paths): 56 | mask = io.read_image(str(mask_path)).squeeze() 57 | if i == 0: 58 | label = torch.zeros(self.n_classes, *mask.shape) 59 | label[i, ...] = mask 60 | label = label.argmax(dim=0).unsqueeze(0) 61 | return label 62 | 63 | 64 | if __name__ == '__main__': 65 | from semseg.utils.visualize import visualize_dataset_sample 66 | visualize_dataset_sample(HELEN, '/home/sithu/datasets/SmithCVPR2013_dataset_resized') -------------------------------------------------------------------------------- /semseg/datasets/ibugmask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data import Dataset 4 | from torchvision import io 5 | from pathlib import Path 6 | from typing import Tuple 7 | 8 | 9 | class iBugMask(Dataset): 10 | CLASSES = ['background', 'skin', 'l-brow', 'r-brow', 'l-eye', 'r-eye', 'nose', 'u-lip', 'i-mouth', 'l-lip', 'hair'] 11 | PALETTE = torch.tensor([[0, 0, 0], [255, 255, 0], [139, 76, 57], [139, 54, 38], [0, 205, 0], [0, 138, 0], [154, 50, 205], [72, 118, 255], [255, 165, 0], [0, 0, 139], [255, 0, 0]]) 12 | 13 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 14 | super().__init__() 15 | assert split in ['train', 'val', 'test'] 16 | split = 'train' if split == 'train' else 'test' 17 | self.transform = transform 18 | self.n_classes = len(self.CLASSES) 19 | self.ignore_label = 255 20 | 21 | img_path = Path(root) / split 22 | self.files = list(img_path.glob('*.jpg')) 23 | 24 | if not self.files: raise Exception(f"No images found in {root}") 25 | print(f"Found {len(self.files)} {split} images.") 26 | 27 | def __len__(self) -> int: 28 | return len(self.files) 29 | 30 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 31 | img_path = str(self.files[index]) 32 | lbl_path = str(self.files[index]).replace('.jpg', '.png') 33 | image = io.read_image(str(img_path)) 34 | label = io.read_image(str(lbl_path)) 35 | 36 | if self.transform: 37 | image, label = self.transform(image, label) 38 | return image, label.squeeze().long() 39 | 40 | 41 | if __name__ == '__main__': 42 | from semseg.utils.visualize import visualize_dataset_sample 43 | visualize_dataset_sample(iBugMask, '/home/sithu/datasets/ibugmask_release') -------------------------------------------------------------------------------- /semseg/datasets/lapa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data import Dataset 4 | from torchvision import io 5 | from pathlib import Path 6 | from typing import Tuple 7 | 8 | 9 | class 
LaPa(Dataset): 10 | CLASSES = ['background', 'skin', 'l-brow', 'r-brow', 'l-eye', 'r-eye', 'nose', 'u-lip', 'i-mouth', 'l-lip', 'hair'] 11 | PALETTE = torch.tensor([[0, 0, 0], [0, 153, 255], [102, 255, 153], [0, 204, 153], [255, 255, 102], [255, 255, 204], [255, 153, 0], [255, 102, 255], [102, 0, 51], [255, 204, 255], [255, 0, 102]]) 12 | 13 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 14 | super().__init__() 15 | assert split in ['train', 'val', 'test'] 16 | self.transform = transform 17 | self.n_classes = len(self.CLASSES) 18 | self.ignore_label = 255 19 | 20 | img_path = Path(root) / split / 'images' 21 | self.files = list(img_path.glob('*.jpg')) 22 | 23 | if not self.files: raise Exception(f"No images found in {root}") 24 | print(f"Found {len(self.files)} {split} images.") 25 | 26 | def __len__(self) -> int: 27 | return len(self.files) 28 | 29 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 30 | img_path = str(self.files[index]) 31 | lbl_path = str(self.files[index]).replace('images', 'labels').replace('.jpg', '.png') 32 | image = io.read_image(str(img_path)) 33 | label = io.read_image(str(lbl_path)) 34 | 35 | if self.transform: 36 | image, label = self.transform(image, label) 37 | return image, label.squeeze().long() 38 | 39 | 40 | if __name__ == '__main__': 41 | from semseg.utils.visualize import visualize_dataset_sample 42 | visualize_dataset_sample(LaPa, '/home/sithu/datasets/LaPa') -------------------------------------------------------------------------------- /semseg/datasets/lip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data import Dataset 4 | from torchvision import io 5 | from pathlib import Path 6 | from typing import Tuple 7 | 8 | 9 | class LIP(Dataset): 10 | """ 11 | num_classes: 19+background 12 | 30462 train images 13 | 10000 val images 14 | """ 15 | CLASSES = ['background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes', 'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt', 'face', 'left-arm', 'right-arm', 'left-leg', 'right-leg', 'left-shoe', 'right-shoe'] 16 | PALETTE = torch.tensor([[0, 0, 0], [127, 0, 0], [254, 0, 0], [0, 84, 0], [169, 0, 50], [254, 84, 0], [255, 0, 84], [0, 118, 220], [84, 84, 0], [0, 84, 84], [84, 50, 0], [51, 85, 127], [0, 127, 0], [0, 0, 254], [50, 169, 220], [0, 254, 254], [84, 254, 169], [169, 254, 84], [254, 254, 0], [254, 169, 0]]) 17 | 18 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 19 | super().__init__() 20 | assert split in ['train', 'val'] 21 | self.split = split 22 | self.transform = transform 23 | self.n_classes = len(self.CLASSES) 24 | self.ignore_label = 255 25 | 26 | img_path = Path(root) / 'TrainVal_images' / f'{split}_images' 27 | self.files = list(img_path.glob('*.jpg')) 28 | 29 | if not self.files: 30 | raise Exception(f"No images found in {img_path}") 31 | print(f"Found {len(self.files)} {split} images.") 32 | 33 | def __len__(self) -> int: 34 | return len(self.files) 35 | 36 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 37 | img_path = str(self.files[index]) 38 | lbl_path = str(self.files[index]).replace('TrainVal_images', 'TrainVal_parsing_annotations').replace(f'{self.split}_images', f'{self.split}_segmentations').replace('.jpg', '.png') 39 | 40 | image = io.read_image(img_path) 41 | label = io.read_image(lbl_path) 42 | 43 | if self.transform: 44 | image, label = self.transform(image, label) 45 | 
return image, label.squeeze().long() 46 | 47 | 48 | if __name__ == '__main__': 49 | from semseg.utils.visualize import visualize_dataset_sample 50 | visualize_dataset_sample(LIP, '/home/sithu/datasets/LIP/LIP') -------------------------------------------------------------------------------- /semseg/datasets/mapillary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data import Dataset 4 | from torchvision import io 5 | from pathlib import Path 6 | from typing import Tuple 7 | 8 | 9 | class MapillaryVistas(Dataset): 10 | CLASSES = [ 11 | 'Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist', 'Other Rider', 'Lane Marking - Crosswalk', 'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water', 'Banner', 12 | 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera', 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Phone Booth', 'Pothole', 'Street Light', 'Pole', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Light', 'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer', 'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle' 13 | ] 14 | PALETTE = torch.tensor([ 15 | [165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153], [180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255], [140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96], [230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128], [255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180], [190, 255, 255], [152, 251, 152], [107, 142, 35], 16 | [0, 170, 30], [255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220], [220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], [0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10] 17 | ]) 18 | 19 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 20 | super().__init__() 21 | assert split in ['train', 'val'] 22 | split = 'training' if split == 'train' else 'validation' 23 | self.transform = transform 24 | self.n_classes = len(self.CLASSES) 25 | self.ignore_label = 65 26 | 27 | img_path = Path(root) / split / 'images' 28 | self.files = list(img_path.glob("*.jpg")) 29 | 30 | if not self.files: 31 | raise Exception(f"No images found in {img_path}") 32 | print(f"Found {len(self.files)} {split} images.") 33 | 34 | def __len__(self) -> int: 35 | return len(self.files) 36 | 37 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 38 | img_path = str(self.files[index]) 39 | lbl_path = str(self.files[index]).replace('images', 'labels').replace('.jpg', '.png') 40 | 41 | image = io.read_image(img_path, io.ImageReadMode.RGB) 42 | label = io.read_image(lbl_path) 43 | 44 | if self.transform: 45 | image, label = self.transform(image, label) 
46 | return image, label.squeeze().long() 47 | 48 | 49 | if __name__ == '__main__': 50 | from semseg.utils.visualize import visualize_dataset_sample 51 | visualize_dataset_sample(MapillaryVistas, '/home/sithu/datasets/Mapillary') -------------------------------------------------------------------------------- /semseg/datasets/mhpv1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import Tensor 4 | from torch.utils.data import Dataset 5 | from torchvision import io 6 | from pathlib import Path 7 | from typing import Tuple 8 | 9 | 10 | class MHPv1(Dataset): 11 | """ 12 | 4980 images each with at least 2 persons (average 3) 13 | 3000 images for training 14 | 1000 images for validation 15 | 980 images for testing 16 | num_classes: 18+background 17 | """ 18 | CLASSES = ['background', 'hat', 'hair', 'sunglass', 'upper-clothes', 'skirt', 'pants', 'dress', 'belt', 'left-shoe', 'right-shoe', 'face', 'left-leg', 'right-leg', 'left-arm', 'right-arm', 'bag', 'sacrf', 'torso-skin'] 19 | PALETTE = torch.tensor([[0, 0, 0], [128, 0, 0], [254, 0, 0], [0, 85, 0], [169, 0, 51], [254, 85, 0], [255, 0, 85], [0, 119, 220], [85, 85, 0], [190, 153, 153], [85, 51, 0], [52, 86, 128], [0, 128, 0], [0, 0, 254], [51, 169, 220], [0, 254, 254], [85, 254, 169], [169, 254, 85], [254, 254, 0]]) 20 | 21 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 22 | super().__init__() 23 | assert split in ['train', 'val', 'test'] 24 | self.transform = transform 25 | self.n_classes = len(self.CLASSES) 26 | self.ignore_label = 255 27 | 28 | self.images, self.labels = self.get_files(root, split) 29 | print(f"Found {len(self.images)} {split} images.") 30 | 31 | def get_files(self, root: str, split: str): 32 | root = Path(root) 33 | all_labels = list((root / 'annotations').rglob('*.png')) 34 | images, labels = [], [] 35 | 36 | flist = 'test_list.txt' if split == 'test' else 'train_list.txt' 37 | with open(root / flist) as f: 38 | all_files = f.read().splitlines() 39 | 40 | if split == 'train': 41 | files = all_files[:3000] 42 | elif split == 'val': 43 | files = all_files[3000:] 44 | else: 45 | files = all_files 46 | 47 | for f in files: 48 | images.append(root / 'images' / f) 49 | img_name = f.split('.')[0] 50 | labels_per_images = list(filter(lambda x: x.stem.startswith(img_name), all_labels)) 51 | assert labels_per_images != [] 52 | labels.append(labels_per_images) 53 | 54 | assert len(images) == len(labels) 55 | return images, labels 56 | 57 | def __len__(self) -> int: 58 | return len(self.images) 59 | 60 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 61 | img_path = str(self.images[index]) 62 | lbl_paths = self.labels[index] 63 | 64 | image = io.read_image(img_path) 65 | label = self.read_label(lbl_paths) 66 | 67 | if self.transform: 68 | image, label = self.transform(image, label) 69 | return image, label.squeeze().long() 70 | 71 | def read_label(self, lbl_paths: list) -> Tensor: 72 | labels = None 73 | label_idx = None 74 | 75 | for lbl_path in lbl_paths: 76 | label = io.read_image(str(lbl_path)).squeeze().numpy() 77 | 78 | if label_idx is None: 79 | label_idx = np.zeros(label.shape, dtype=np.uint8) 80 | label = np.ma.masked_array(label, mask=label_idx) 81 | label_idx += np.minimum(label, 1) 82 | if labels is None: 83 | labels = label 84 | else: 85 | labels += label 86 | return torch.from_numpy(labels.data).unsqueeze(0).to(torch.uint8) 87 | 88 | 89 | if __name__ == '__main__': 90 | from 
semseg.utils.visualize import visualize_dataset_sample 91 | visualize_dataset_sample(MHPv1, '/home/sithu/datasets/LV-MHP-v1') 92 | -------------------------------------------------------------------------------- /semseg/datasets/mhpv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import Tensor 4 | from torch.utils.data import Dataset 5 | from torchvision import io 6 | from pathlib import Path 7 | from typing import Tuple 8 | 9 | 10 | class MHPv2(Dataset): 11 | """ 12 | 25,403 images each with at least 2 persons (average 3) 13 | 15,403 images for training 14 | 5000 images for validation 15 | 5000 images for testing 16 | num_classes: 58+background 17 | """ 18 | CLASSES = ['background', 'cap/hat', 'helmet', 'face', 'hair', 'left-arm', 'right-arm', 'left-hand', 'right-hand', 'protector', 'bikini/bra', 'jacket/windbreaker/hoodie', 't-shirt', 'polo-shirt', 'sweater', 'singlet', 'torso-skin', 'pants', 'shorts/swim-shorts', 'skirt', 'stockings', 'socks', 'left-boot', 'right-boot', 'left-shoe', 'right-shoe', 'left-highheel', 'right-highheel', 'left-sandal', 'right-sandal', 'left-leg', 'right-leg', 'left-foot', 'right-foot', 'coat', 'dress', 'robe', 'jumpsuits', 'other-full-body-clothes', 'headware', 'backpack', 'ball', 'bats', 'belt', 'bottle', 'carrybag', 'cases', 'sunglasses', 'eyeware', 'gloves', 'scarf', 'umbrella', 'wallet/purse', 'watch', 'wristband', 'tie', 'other-accessories', 'other-upper-body-clothes', 'other-lower-body-clothes'] 19 | PALETTE = torch.tensor([[0, 0, 0], [255, 114, 196], [63, 31, 34], [253, 1, 0], [254, 26, 1], [253, 54, 0], [253, 82, 0], [252, 110, 0], [253, 137, 0], [253, 166, 1], [254, 191, 0], [253, 219, 0], [252, 248, 0], [238, 255, 1], [209, 255, 0], [182, 255, 0], [155, 255, 0], [133, 254, 0], [102, 254, 0], [78, 255, 0], [55, 254, 1], [38, 255, 0], [30, 255, 13], [34, 255, 35], [35, 254, 64], [36, 254, 87], [37, 252, 122], [37, 255, 143], [35, 255, 172], [35, 255, 200], [40, 253, 228], [40, 255, 255], [37, 228, 255], [33, 198, 254], [31, 170, 254], [22, 145, 255], [26, 112, 255], [20, 86, 253], [22, 53, 255], [19, 12, 253], [19, 1, 246], [30, 1, 252], [52, 0, 254], [72, 0, 255], [102, 0, 255], [121, 1, 252], [157, 1, 245], [182, 0, 253], [210, 0, 254], [235, 0, 255], [253, 1, 246], [254, 0, 220], [255, 0, 191], [254, 0, 165], [252, 0, 137], [248, 2, 111], [253, 0, 81], [255, 0, 54], [253, 1, 26]]) 20 | 21 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 22 | super().__init__() 23 | assert split in ['train', 'val'] 24 | self.transform = transform 25 | self.n_classes = len(self.CLASSES) 26 | self.ignore_label = 255 27 | 28 | self.images, self.labels = self.get_files(root, split) 29 | print(f"Found {len(self.images)} {split} images.") 30 | 31 | def get_files(self, root: str, split: str): 32 | root = Path(root) 33 | all_labels = list((root / split / 'parsing_annos').rglob('*.png')) 34 | images = list((root / split / 'images').rglob('*.jpg')) 35 | labels = [] 36 | 37 | for f in images: 38 | labels_per_images = list(filter(lambda x: x.stem.split('_', maxsplit=1)[0] == f.stem, all_labels)) 39 | assert labels_per_images != [] 40 | labels.append(labels_per_images) 41 | 42 | assert len(images) == len(labels) 43 | return images, labels 44 | 45 | def __len__(self) -> int: 46 | return len(self.images) 47 | 48 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 49 | img_path = str(self.images[index]) 50 | lbl_paths = self.labels[index] 51 | 52 | 
image = io.read_image(img_path) 53 | label = self.read_label(lbl_paths) 54 | 55 | if self.transform: 56 | image, label = self.transform(image, label) 57 | return image, label.squeeze().long() 58 | 59 | def read_label(self, lbl_paths: list) -> Tensor: 60 | labels = None 61 | label_idx = None 62 | 63 | for lbl_path in lbl_paths: 64 | label = io.read_image(str(lbl_path)).squeeze().numpy() 65 | if label.ndim != 2: 66 | label = label[0] 67 | if label_idx is None: 68 | label_idx = np.zeros(label.shape, dtype=np.uint8) 69 | label = np.ma.masked_array(label, mask=label_idx) 70 | label_idx += np.minimum(label, 1) 71 | if labels is None: 72 | labels = label 73 | else: 74 | labels += label 75 | return torch.from_numpy(labels.data).unsqueeze(0).to(torch.uint8) 76 | 77 | 78 | if __name__ == '__main__': 79 | from semseg.utils.visualize import visualize_dataset_sample 80 | visualize_dataset_sample(MHPv2, '/home/sithu/datasets/LV-MHP-v2') -------------------------------------------------------------------------------- /semseg/datasets/pascalcontext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data import Dataset 4 | from torchvision import io 5 | from pathlib import Path 6 | from typing import Tuple 7 | 8 | 9 | class PASCALContext(Dataset): 10 | """ 11 | https://cs.stanford.edu/~roozbeh/pascal-context/ 12 | based on PASCAL VOC 2010 13 | num_classes: 59 14 | 10,100 train+val 15 | 9,637 test 16 | """ 17 | CLASSES = [ 18 | 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 19 | 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 20 | 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 21 | 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence', 22 | 'floor', 'flower', 'food', 'grass', 'ground', 'horse', 23 | 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', 24 | 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep', 25 | 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 26 | 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', 27 | 'window', 'wood' 28 | ] 29 | 30 | PALETTE = torch.tensor([ 31 | [180, 120, 120], [6, 230, 230], [80, 50, 50], 32 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 33 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 34 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 35 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 36 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 37 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 38 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 39 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 40 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 41 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 42 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 43 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 44 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 45 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255] 46 | ]) 47 | 48 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 49 | super().__init__() 50 | assert split in ['train', 'val'] 51 | self.transform = transform 52 | self.n_classes = len(self.CLASSES) 53 | self.ignore_label = -1 54 | 55 | self.images, self.labels = self.get_files(root, split) 56 | print(f"Found {len(self.images)} {split} images.") 57 | 58 | def get_files(self, root: str, 
split: str): 59 | root = Path(root) 60 | flist = root / 'ImageSets' / 'SegmentationContext' / f'{split}.txt' 61 | with open(flist) as f: 62 | files = f.read().splitlines() 63 | images, labels = [], [] 64 | 65 | for fi in files: 66 | images.append(str(root / 'JPEGImages' / f'{fi}.jpg')) 67 | labels.append(str(root / 'SegmentationClassContext' / f'{fi}.png')) 68 | return images, labels 69 | 70 | def __len__(self) -> int: 71 | return len(self.images) 72 | 73 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 74 | img_path = self.images[index] 75 | lbl_path = self.labels[index] 76 | 77 | image = io.read_image(img_path) 78 | label = io.read_image(lbl_path) 79 | 80 | if self.transform: 81 | image, label = self.transform(image, label) 82 | return image, label.squeeze().long() - 1 # remove background class 83 | 84 | 85 | if __name__ == '__main__': 86 | from semseg.utils.visualize import visualize_dataset_sample 87 | visualize_dataset_sample(PASCALContext, '/home/sithu/datasets/VOCdevkit/VOC2010') -------------------------------------------------------------------------------- /semseg/datasets/suim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data import Dataset 4 | from torchvision import io 5 | from pathlib import Path 6 | from typing import Tuple 7 | from PIL import Image 8 | from torchvision.transforms import functional as TF 9 | 10 | 11 | class SUIM(Dataset): 12 | CLASSES = ['water', 'human divers', 'aquatic plants and sea-grass', 'wrecks and ruins', 'robots (AUVs/ROVs/instruments)', 'reefs and invertebrates', 'fish and vertebrates', 'sea-floor and rocks'] 13 | PALETTE = torch.tensor([[0, 0, 0], [0, 0, 255], [0, 255, 0], [0, 255, 255], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 255, 255]]) 14 | 15 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 16 | super().__init__() 17 | assert split in ['train', 'val'] 18 | self.split = 'train_val' if split == 'train' else 'TEST' 19 | self.transform = transform 20 | self.n_classes = len(self.CLASSES) 21 | self.ignore_label = 255 22 | 23 | img_path = Path(root) / self.split / 'images' 24 | self.files = list(img_path.glob("*.jpg")) 25 | 26 | if not self.files: 27 | raise Exception(f"No images found in {img_path}") 28 | print(f"Found {len(self.files)} {split} images.") 29 | 30 | def __len__(self) -> int: 31 | return len(self.files) 32 | 33 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 34 | img_path = str(self.files[index]) 35 | lbl_path = str(self.files[index]).replace('images', 'masks').replace('.jpg', '.bmp') 36 | 37 | image = io.read_image(img_path) 38 | label = TF.pil_to_tensor(Image.open(lbl_path).convert('RGB')) 39 | 40 | if self.transform: 41 | image, label = self.transform(image, label) 42 | return image, self.encode(label).long() 43 | 44 | def encode(self, label: Tensor) -> Tensor: 45 | label = label.permute(1, 2, 0) 46 | mask = torch.zeros(label.shape[:-1]) 47 | 48 | for index, color in enumerate(self.PALETTE): 49 | bool_mask = torch.eq(label, color) 50 | class_map = torch.all(bool_mask, dim=-1) 51 | mask[class_map] = index 52 | return mask 53 | 54 | 55 | if __name__ == '__main__': 56 | from semseg.utils.visualize import visualize_dataset_sample 57 | visualize_dataset_sample(SUIM, '/home/sithu/datasets/SUIM') -------------------------------------------------------------------------------- /semseg/datasets/sunrgbd.py: 
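Before moving on: SUIM above (like CamVid earlier) stores its labels as RGB colour masks rather than index maps, so encode() has to match every pixel against each palette colour to recover class IDs. Here is a minimal sketch of that matching, not repo code, with a made-up three-colour palette and a 2x2 mask:

import torch

# Toy palette and mask; the real ones live in SUIM.PALETTE / CamVid.PALETTE.
PALETTE = torch.tensor([[0, 0, 0], [0, 0, 255], [0, 255, 0]])
rgb = torch.tensor([[[0, 0, 0], [0, 0, 255]],
                    [[0, 255, 0], [0, 0, 0]]])           # fake HxWx3 colour mask
mask = torch.zeros(rgb.shape[:-1])                       # HxW index map being built
for index, color in enumerate(PALETTE):
    class_map = torch.all(torch.eq(rgb, color), dim=-1)  # True where the pixel matches this colour
    mask[class_map] = index
print(mask)                                              # tensor([[0., 1.], [2., 0.]])

One small difference between the two loaders above: SUIM writes index directly, while CamVid writes index + 1 and subtracts 1 after the transform, so that pixels matching no palette colour end up at its ignore value -1.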
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import Tensor 4 | from torch.utils.data import Dataset 5 | from torchvision import io 6 | from scipy import io as sio 7 | from pathlib import Path 8 | from typing import Tuple 9 | 10 | 11 | class SunRGBD(Dataset): 12 | """ 13 | num_classes: 37 14 | """ 15 | CLASSES = [ 16 | 'wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 17 | 'floor mat', 'clothes', 'ceiling', 'books', 'fridge', 'tv', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag' 18 | ] 19 | 20 | PALETTE = torch.tensor([ 21 | (119, 119, 119), (244, 243, 131), (137, 28, 157), (150, 255, 255), (54, 114, 113), (0, 0, 176), (255, 69, 0), (87, 112, 255), (0, 163, 33), 22 | (255, 150, 255), (255, 180, 10), (101, 70, 86), (38, 230, 0), (255, 120, 70), (117, 41, 121), (150, 255, 0), (132, 0, 255), (24, 209, 255), 23 | (191, 130, 35), (219, 200, 109), (154, 62, 86), (255, 190, 190), (255, 0, 255), (152, 163, 55), (192, 79, 212), (230, 230, 230), (53, 130, 64), 24 | (155, 249, 152), (87, 64, 34), (214, 209, 175), (170, 0, 59), (255, 0, 0), (193, 195, 234), (70, 72, 115), (255, 255, 0), (52, 57, 131), (12, 83, 45) 25 | ]) 26 | 27 | def __init__(self, root: str, split: str = 'train', transform = None) -> None: 28 | super().__init__() 29 | assert split in ['alltrain', 'train', 'val', 'test'] 30 | self.transform = transform 31 | self.n_classes = len(self.CLASSES) 32 | self.ignore_label = -1 33 | self.files, self.labels = self.get_data(root, split) 34 | print(f"Found {len(self.files)} {split} images.") 35 | 36 | def get_data(self, root: str, split: str): 37 | root = Path(root) 38 | files, labels = [], [] 39 | split_path = root / 'SUNRGBDtoolbox' / 'traintestSUNRGBD' / 'allsplit.mat' 40 | split_mat = sio.loadmat(split_path, squeeze_me=True, struct_as_record=False) 41 | if split == 'train': 42 | file_lists = split_mat['trainvalsplit'].train 43 | elif split == 'val': 44 | file_lists = split_mat['trainvalsplit'].val 45 | elif split == 'test': 46 | file_lists = split_mat['alltest'] 47 | else: 48 | file_lists = split_mat['alltrain'] 49 | 50 | for fl in file_lists: 51 | real_fl = root / fl.split('/n/fs/sun3d/data/')[-1] 52 | files.append(str(list((real_fl / 'image').glob('*.jpg'))[0])) 53 | labels.append(real_fl / 'seg.mat') 54 | 55 | assert len(files) == len(labels) 56 | return files, labels 57 | 58 | def __len__(self) -> int: 59 | return len(self.files) 60 | 61 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 62 | image = io.read_image(self.files[index], io.ImageReadMode.RGB) 63 | label = sio.loadmat(self.labels[index], squeeze_me=True, struct_as_record=False)['seglabel'] 64 | label = torch.from_numpy(label.astype(np.uint8)).unsqueeze(0) 65 | 66 | if self.transform: 67 | image, label = self.transform(image, label) 68 | return image, self.encode(label.squeeze()).long() - 1 # subtract -1 to remove void class 69 | 70 | def encode(self, label: Tensor) -> Tensor: 71 | label[label > self.n_classes] = 0 72 | return label 73 | 74 | 75 | if __name__ == '__main__': 76 | from semseg.utils.visualize import visualize_dataset_sample 77 | visualize_dataset_sample(SunRGBD, '/home/sithu/datasets/sunrgbd') -------------------------------------------------------------------------------- /semseg/losses.py: 
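A note that ties the datasets above to the losses that follow: ADE20K, PASCALContext and SunRGBD all return label.squeeze().long() - 1 and set ignore_label = -1. The raw annotations reserve 0 for void/background, so the shift puts the real classes at 0..C-1 and void at -1, which CrossEntropyLoss then skips through its ignore_index argument. A minimal sketch of that interaction, with made-up shapes rather than repo code:

import torch
from torch import nn

# Fake 2x2 label map for a 3-class problem; 0 marks void pixels in the raw file.
raw = torch.tensor([[0, 1], [2, 3]])
target = raw.long() - 1                         # void -> -1, classes -> 0..2
logits = torch.randn(1, 3, 2, 2)                # [B, C, H, W] predictions
loss_fn = nn.CrossEntropyLoss(ignore_index=-1)  # mirrors ignore_label = -1
print(loss_fn(logits, target.unsqueeze(0)))     # the void pixel contributes nothing

get_loss() below wires this up by passing the dataset's ignore_label straight into the loss constructor.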
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | 5 | 6 | class CrossEntropy(nn.Module): 7 | def __init__(self, ignore_label: int = 255, weight: Tensor = None, aux_weights: list = [1, 0.4, 0.4]) -> None: 8 | super().__init__() 9 | self.aux_weights = aux_weights 10 | self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label) 11 | 12 | def _forward(self, preds: Tensor, labels: Tensor) -> Tensor: 13 | # preds in shape [B, C, H, W] and labels in shape [B, H, W] 14 | return self.criterion(preds, labels) 15 | 16 | def forward(self, preds, labels: Tensor) -> Tensor: 17 | if isinstance(preds, tuple): 18 | return sum([w * self._forward(pred, labels) for (pred, w) in zip(preds, self.aux_weights)]) 19 | return self._forward(preds, labels) 20 | 21 | 22 | class OhemCrossEntropy(nn.Module): 23 | def __init__(self, ignore_label: int = 255, weight: Tensor = None, thresh: float = 0.7, aux_weights: list = [1, 1]) -> None: 24 | super().__init__() 25 | self.ignore_label = ignore_label 26 | self.aux_weights = aux_weights 27 | self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)) 28 | self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label, reduction='none') 29 | 30 | def _forward(self, preds: Tensor, labels: Tensor) -> Tensor: 31 | # preds in shape [B, C, H, W] and labels in shape [B, H, W] 32 | n_min = labels[labels != self.ignore_label].numel() // 16 33 | loss = self.criterion(preds, labels).view(-1) 34 | loss_hard = loss[loss > self.thresh] 35 | 36 | if loss_hard.numel() < n_min: 37 | loss_hard, _ = loss.topk(n_min) 38 | 39 | return torch.mean(loss_hard) 40 | 41 | def forward(self, preds, labels: Tensor) -> Tensor: 42 | if isinstance(preds, tuple): 43 | return sum([w * self._forward(pred, labels) for (pred, w) in zip(preds, self.aux_weights)]) 44 | return self._forward(preds, labels) 45 | 46 | 47 | class Dice(nn.Module): 48 | def __init__(self, delta: float = 0.5, aux_weights: list = [1, 0.4, 0.4]): 49 | """ 50 | delta: Controls weight given to FP and FN. 
This equals to dice score when delta=0.5 51 | """ 52 | super().__init__() 53 | self.delta = delta 54 | self.aux_weights = aux_weights 55 | 56 | def _forward(self, preds: Tensor, labels: Tensor) -> Tensor: 57 | # preds in shape [B, C, H, W] and labels in shape [B, H, W] 58 | num_classes = preds.shape[1] 59 | labels = F.one_hot(labels, num_classes).permute(0, 3, 1, 2) 60 | tp = torch.sum(labels*preds, dim=(2, 3)) 61 | fn = torch.sum(labels*(1-preds), dim=(2, 3)) 62 | fp = torch.sum((1-labels)*preds, dim=(2, 3)) 63 | 64 | dice_score = (tp + 1e-6) / (tp + self.delta * fn + (1 - self.delta) * fp + 1e-6) 65 | dice_score = torch.sum(1 - dice_score, dim=-1) 66 | 67 | dice_score = dice_score / num_classes 68 | return dice_score.mean() 69 | 70 | def forward(self, preds, targets: Tensor) -> Tensor: 71 | if isinstance(preds, tuple): 72 | return sum([w * self._forward(pred, targets) for (pred, w) in zip(preds, self.aux_weights)]) 73 | return self._forward(preds, targets) 74 | 75 | 76 | __all__ = ['CrossEntropy', 'OhemCrossEntropy', 'Dice'] 77 | 78 | 79 | def get_loss(loss_fn_name: str = 'CrossEntropy', ignore_label: int = 255, cls_weights: Tensor = None): 80 | assert loss_fn_name in __all__, f"Unavailable loss function name >> {loss_fn_name}.\nAvailable loss functions: {__all__}" 81 | if loss_fn_name == 'Dice': 82 | return Dice() 83 | return eval(loss_fn_name)(ignore_label, cls_weights) 84 | 85 | 86 | if __name__ == '__main__': 87 | pred = torch.randint(0, 19, (2, 19, 480, 640), dtype=torch.float) 88 | label = torch.randint(0, 19, (2, 480, 640), dtype=torch.long) 89 | loss_fn = Dice() 90 | y = loss_fn(pred, label) 91 | print(y) -------------------------------------------------------------------------------- /semseg/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Tuple 4 | 5 | 6 | class Metrics: 7 | def __init__(self, num_classes: int, ignore_label: int, device) -> None: 8 | self.ignore_label = ignore_label 9 | self.num_classes = num_classes 10 | self.hist = torch.zeros(num_classes, num_classes).to(device) 11 | 12 | def update(self, pred: Tensor, target: Tensor) -> None: 13 | pred = pred.argmax(dim=1) 14 | keep = target != self.ignore_label 15 | self.hist += torch.bincount(target[keep] * self.num_classes + pred[keep], minlength=self.num_classes**2).view(self.num_classes, self.num_classes) 16 | 17 | def compute_iou(self) -> Tuple[Tensor, Tensor]: 18 | ious = self.hist.diag() / (self.hist.sum(0) + self.hist.sum(1) - self.hist.diag()) 19 | miou = ious[~ious.isnan()].mean().item() 20 | ious *= 100 21 | miou *= 100 22 | return ious.cpu().numpy().round(2).tolist(), round(miou, 2) 23 | 24 | def compute_f1(self) -> Tuple[Tensor, Tensor]: 25 | f1 = 2 * self.hist.diag() / (self.hist.sum(0) + self.hist.sum(1)) 26 | mf1 = f1[~f1.isnan()].mean().item() 27 | f1 *= 100 28 | mf1 *= 100 29 | return f1.cpu().numpy().round(2).tolist(), round(mf1, 2) 30 | 31 | def compute_pixel_acc(self) -> Tuple[Tensor, Tensor]: 32 | acc = self.hist.diag() / self.hist.sum(1) 33 | macc = acc[~acc.isnan()].mean().item() 34 | acc *= 100 35 | macc *= 100 36 | return acc.cpu().numpy().round(2).tolist(), round(macc, 2) 37 | 38 | -------------------------------------------------------------------------------- /semseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .segformer import SegFormer 2 | from .ddrnet import DDRNet 3 | from .fchardnet import FCHarDNet 4 | from 
.sfnet import SFNet 5 | from .bisenetv1 import BiSeNetv1 6 | from .bisenetv2 import BiSeNetv2 7 | from .lawin import Lawin 8 | 9 | 10 | __all__ = [ 11 | 'SegFormer', 12 | 'Lawin', 13 | 'SFNet', 14 | 'BiSeNetv1', 15 | 16 | # Standalone Models 17 | 'DDRNet', 18 | 'FCHarDNet', 19 | 'BiSeNetv2' 20 | ] -------------------------------------------------------------------------------- /semseg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import ResNet, resnet_settings 2 | from .resnetd import ResNetD, resnetd_settings 3 | from .micronet import MicroNet, micronet_settings 4 | from .mobilenetv2 import MobileNetV2, mobilenetv2_settings 5 | from .mobilenetv3 import MobileNetV3, mobilenetv3_settings 6 | 7 | from .mit import MiT, mit_settings 8 | from .pvt import PVTv2, pvtv2_settings 9 | from .rest import ResT, rest_settings 10 | from .poolformer import PoolFormer, poolformer_settings 11 | from .convnext import ConvNeXt, convnext_settings 12 | 13 | 14 | __all__ = [ 15 | 'ResNet', 16 | 'ResNetD', 17 | 'MicroNet', 18 | 'MobileNetV2', 19 | 'MobileNetV3', 20 | 21 | 'MiT', 22 | 'PVTv2', 23 | 'ResT', 24 | 'PoolFormer', 25 | 'ConvNeXt', 26 | ] -------------------------------------------------------------------------------- /semseg/models/backbones/convnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from semseg.models.layers import DropPath 4 | 5 | 6 | class LayerNorm(nn.Module): 7 | """Channel first layer norm 8 | """ 9 | def __init__(self, normalized_shape, eps=1e-6) -> None: 10 | super().__init__() 11 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 12 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 13 | self.eps = eps 14 | 15 | def forward(self, x: Tensor) -> Tensor: 16 | u = x.mean(1, keepdim=True) 17 | s = (x - u).pow(2).mean(1, keepdim=True) 18 | x = (x - u) / torch.sqrt(s + self.eps) 19 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 20 | return x 21 | 22 | 23 | class Block(nn.Module): 24 | def __init__(self, dim, dpr=0., init_value=1e-6): 25 | super().__init__() 26 | self.dwconv = nn.Conv2d(dim, dim, 7, 1, 3, groups=dim) 27 | self.norm = nn.LayerNorm(dim, eps=1e-6) 28 | self.pwconv1 = nn.Linear(dim, 4*dim) 29 | self.act = nn.GELU() 30 | self.pwconv2 = nn.Linear(4*dim, dim) 31 | self.gamma = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True) if init_value > 0 else None 32 | self.drop_path = DropPath(dpr) if dpr > 0. 
else nn.Identity() 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | input = x 36 | x = self.dwconv(x) 37 | x = x.permute(0, 2, 3, 1) # NCHW to NHWC 38 | x = self.norm(x) 39 | x = self.pwconv1(x) 40 | x = self.act(x) 41 | x = self.pwconv2(x) 42 | 43 | if self.gamma is not None: 44 | x = self.gamma * x 45 | 46 | x = x.permute(0, 3, 1, 2) 47 | x = input + self.drop_path(x) 48 | return x 49 | 50 | 51 | class Stem(nn.Sequential): 52 | def __init__(self, c1, c2, k, s): 53 | super().__init__( 54 | nn.Conv2d(c1, c2, k, s), 55 | LayerNorm(c2) 56 | ) 57 | 58 | 59 | class Downsample(nn.Sequential): 60 | def __init__(self, c1, c2, k, s): 61 | super().__init__( 62 | LayerNorm(c1), 63 | nn.Conv2d(c1, c2, k, s) 64 | ) 65 | 66 | 67 | convnext_settings = { 68 | 'T': [[3, 3, 9, 3], [96, 192, 384, 768], 0.0], # [depths, dims, dpr] 69 | 'S': [[3, 3, 27, 3], [96, 192, 384, 768], 0.0], 70 | 'B': [[3, 3, 27, 3], [128, 256, 512, 1024], 0.0] 71 | } 72 | 73 | 74 | class ConvNeXt(nn.Module): 75 | def __init__(self, model_name: str = 'T') -> None: 76 | super().__init__() 77 | assert model_name in convnext_settings.keys(), f"ConvNeXt model name should be in {list(convnext_settings.keys())}" 78 | depths, embed_dims, drop_path_rate = convnext_settings[model_name] 79 | self.channels = embed_dims 80 | 81 | self.downsample_layers = nn.ModuleList([ 82 | Stem(3, embed_dims[0], 4, 4), 83 | *[Downsample(embed_dims[i], embed_dims[i+1], 2, 2) for i in range(3)] 84 | ]) 85 | 86 | self.stages = nn.ModuleList() 87 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 88 | cur = 0 89 | 90 | for i in range(4): 91 | stage = nn.Sequential(*[ 92 | Block(embed_dims[i], dpr[cur+j]) 93 | for j in range(depths[i])]) 94 | self.stages.append(stage) 95 | cur += depths[i] 96 | 97 | for i in range(4): 98 | self.add_module(f"norm{i}", LayerNorm(embed_dims[i])) 99 | 100 | def forward(self, x: Tensor): 101 | outs = [] 102 | 103 | for i in range(4): 104 | x = self.downsample_layers[i](x) 105 | x = self.stages[i](x) 106 | norm_layer = getattr(self, f"norm{i}") 107 | outs.append(norm_layer(x)) 108 | return outs 109 | 110 | 111 | if __name__ == '__main__': 112 | model = ConvNeXt('T') 113 | # model.load_state_dict(torch.load('C:\\Users\\sithu\\Documents\\weights\\backbones\\convnext\\convnext_tiny_1k_224_ema.pth', map_location='cpu')['model'], strict=False) 114 | x = torch.randn(1, 3, 224, 224) 115 | feats = model(x) 116 | for y in feats: 117 | print(y.shape) -------------------------------------------------------------------------------- /semseg/models/backbones/mit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from semseg.models.layers import DropPath 5 | 6 | 7 | class Attention(nn.Module): 8 | def __init__(self, dim, head, sr_ratio): 9 | super().__init__() 10 | self.head = head 11 | self.sr_ratio = sr_ratio 12 | self.scale = (dim // head) ** -0.5 13 | self.q = nn.Linear(dim, dim) 14 | self.kv = nn.Linear(dim, dim*2) 15 | self.proj = nn.Linear(dim, dim) 16 | 17 | if sr_ratio > 1: 18 | self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio) 19 | self.norm = nn.LayerNorm(dim) 20 | 21 | def forward(self, x: Tensor, H, W) -> Tensor: 22 | B, N, C = x.shape 23 | q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3) 24 | 25 | if self.sr_ratio > 1: 26 | x = x.permute(0, 2, 1).reshape(B, C, H, W) 27 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) 28 | x = self.norm(x) 29 | 30 | k, v 
= self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4) 31 | 32 | attn = (q @ k.transpose(-2, -1)) * self.scale 33 | attn = attn.softmax(dim=-1) 34 | 35 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 36 | x = self.proj(x) 37 | return x 38 | 39 | 40 | class DWConv(nn.Module): 41 | def __init__(self, dim): 42 | super().__init__() 43 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim) 44 | 45 | def forward(self, x: Tensor, H, W) -> Tensor: 46 | B, _, C = x.shape 47 | x = x.transpose(1, 2).view(B, C, H, W) 48 | x = self.dwconv(x) 49 | return x.flatten(2).transpose(1, 2) 50 | 51 | 52 | class MLP(nn.Module): 53 | def __init__(self, c1, c2): 54 | super().__init__() 55 | self.fc1 = nn.Linear(c1, c2) 56 | self.dwconv = DWConv(c2) 57 | self.fc2 = nn.Linear(c2, c1) 58 | 59 | def forward(self, x: Tensor, H, W) -> Tensor: 60 | return self.fc2(F.gelu(self.dwconv(self.fc1(x), H, W))) 61 | 62 | 63 | class PatchEmbed(nn.Module): 64 | def __init__(self, c1=3, c2=32, patch_size=7, stride=4): 65 | super().__init__() 66 | self.proj = nn.Conv2d(c1, c2, patch_size, stride, patch_size//2) # padding=(ps[0]//2, ps[1]//2) 67 | self.norm = nn.LayerNorm(c2) 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | x = self.proj(x) 71 | _, _, H, W = x.shape 72 | x = x.flatten(2).transpose(1, 2) 73 | x = self.norm(x) 74 | return x, H, W 75 | 76 | 77 | class Block(nn.Module): 78 | def __init__(self, dim, head, sr_ratio=1, dpr=0.): 79 | super().__init__() 80 | self.norm1 = nn.LayerNorm(dim) 81 | self.attn = Attention(dim, head, sr_ratio) 82 | self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity() 83 | self.norm2 = nn.LayerNorm(dim) 84 | self.mlp = MLP(dim, int(dim*4)) 85 | 86 | def forward(self, x: Tensor, H, W) -> Tensor: 87 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 88 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 89 | return x 90 | 91 | 92 | mit_settings = { 93 | 'B0': [[32, 64, 160, 256], [2, 2, 2, 2]], # [embed_dims, depths] 94 | 'B1': [[64, 128, 320, 512], [2, 2, 2, 2]], 95 | 'B2': [[64, 128, 320, 512], [3, 4, 6, 3]], 96 | 'B3': [[64, 128, 320, 512], [3, 4, 18, 3]], 97 | 'B4': [[64, 128, 320, 512], [3, 8, 27, 3]], 98 | 'B5': [[64, 128, 320, 512], [3, 6, 40, 3]] 99 | } 100 | 101 | 102 | class MiT(nn.Module): 103 | def __init__(self, model_name: str = 'B0'): 104 | super().__init__() 105 | assert model_name in mit_settings.keys(), f"MiT model name should be in {list(mit_settings.keys())}" 106 | embed_dims, depths = mit_settings[model_name] 107 | drop_path_rate = 0.1 108 | self.channels = embed_dims 109 | 110 | # patch_embed 111 | self.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4) 112 | self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2) 113 | self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2) 114 | self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2) 115 | 116 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 117 | 118 | cur = 0 119 | self.block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, dpr[cur+i]) for i in range(depths[0])]) 120 | self.norm1 = nn.LayerNorm(embed_dims[0]) 121 | 122 | cur += depths[0] 123 | self.block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, dpr[cur+i]) for i in range(depths[1])]) 124 | self.norm2 = nn.LayerNorm(embed_dims[1]) 125 | 126 | cur += depths[1] 127 | self.block3 = nn.ModuleList([Block(embed_dims[2], 5, 2, dpr[cur+i]) for i in range(depths[2])]) 128 | self.norm3 = nn.LayerNorm(embed_dims[2]) 129 | 130 | cur += depths[2] 131 | self.block4 = 
nn.ModuleList([Block(embed_dims[3], 8, 1, dpr[cur+i]) for i in range(depths[3])]) 132 | self.norm4 = nn.LayerNorm(embed_dims[3]) 133 | 134 | 135 | def forward(self, x: Tensor) -> Tensor: 136 | B = x.shape[0] 137 | # stage 1 138 | x, H, W = self.patch_embed1(x) 139 | for blk in self.block1: 140 | x = blk(x, H, W) 141 | x1 = self.norm1(x).reshape(B, H, W, -1).permute(0, 3, 1, 2) 142 | 143 | # stage 2 144 | x, H, W = self.patch_embed2(x1) 145 | for blk in self.block2: 146 | x = blk(x, H, W) 147 | x2 = self.norm2(x).reshape(B, H, W, -1).permute(0, 3, 1, 2) 148 | 149 | # stage 3 150 | x, H, W = self.patch_embed3(x2) 151 | for blk in self.block3: 152 | x = blk(x, H, W) 153 | x3 = self.norm3(x).reshape(B, H, W, -1).permute(0, 3, 1, 2) 154 | 155 | # stage 4 156 | x, H, W = self.patch_embed4(x3) 157 | for blk in self.block4: 158 | x = blk(x, H, W) 159 | x4 = self.norm4(x).reshape(B, H, W, -1).permute(0, 3, 1, 2) 160 | 161 | return x1, x2, x3, x4 162 | 163 | 164 | if __name__ == '__main__': 165 | model = MiT('B0') 166 | x = torch.zeros(1, 3, 224, 224) 167 | outs = model(x) 168 | for y in outs: 169 | print(y.shape) 170 | 171 | 172 | -------------------------------------------------------------------------------- /semseg/models/backbones/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | 5 | class ConvModule(nn.Sequential): 6 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1): 7 | super().__init__( 8 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False), 9 | nn.BatchNorm2d(c2), 10 | nn.ReLU6(True) 11 | ) 12 | 13 | 14 | class InvertedResidual(nn.Module): 15 | def __init__(self, c1, c2, s, expand_ratio): 16 | super().__init__() 17 | ch = int(round(c1 * expand_ratio)) 18 | self.use_res_connect = s == 1 and c1 == c2 19 | 20 | layers = [] 21 | 22 | if expand_ratio != 1: 23 | layers.append(ConvModule(c1, ch, 1)) 24 | 25 | layers.extend([ 26 | ConvModule(ch, ch, 3, s, 1, g=ch), 27 | nn.Conv2d(ch, c2, 1, bias=False), 28 | nn.BatchNorm2d(c2) 29 | ]) 30 | 31 | self.conv = nn.Sequential(*layers) 32 | 33 | def forward(self, x: Tensor) -> Tensor: 34 | if self.use_res_connect: 35 | return x + self.conv(x) 36 | else: 37 | return self.conv(x) 38 | 39 | 40 | mobilenetv2_settings = { 41 | '1.0': [] 42 | } 43 | 44 | 45 | class MobileNetV2(nn.Module): 46 | def __init__(self, variant: str = None): 47 | super().__init__() 48 | self.out_indices = [3, 6, 13, 17] 49 | self.channels = [24, 32, 96, 320] 50 | input_channel = 32 51 | 52 | inverted_residual_setting = [ 53 | # t, c, n, s 54 | [1, 16, 1, 1], 55 | [6, 24, 2, 2], 56 | [6, 32, 3, 2], 57 | [6, 64, 4, 2], 58 | [6, 96, 3, 1], 59 | [6, 160, 3, 2], 60 | [6, 320, 1, 1], 61 | ] 62 | 63 | self.features = nn.ModuleList([ConvModule(3, input_channel, 3, 2, 1)]) 64 | 65 | for t, c, n, s in inverted_residual_setting: 66 | output_channel = c 67 | for i in range(n): 68 | stride = s if i == 0 else 1 69 | self.features.append(InvertedResidual(input_channel, output_channel, stride, t)) 70 | input_channel = output_channel 71 | 72 | def forward(self, x: Tensor) -> Tensor: 73 | outs = [] 74 | for i, m in enumerate(self.features): 75 | x = m(x) 76 | if i in self.out_indices: 77 | outs.append(x) 78 | return outs 79 | 80 | 81 | if __name__ == '__main__': 82 | model = MobileNetV2() 83 | # model.load_state_dict(torch.load('checkpoints/backbones/mobilenet_v2.pth', map_location='cpu'), strict=False) 84 | model.eval() 85 | x = torch.randn(1, 3, 224, 224) 86 | # outs = model(x) 87 | # for y in outs: 88 
| # print(y.shape) 89 | 90 | from fvcore.nn import flop_count_table, FlopCountAnalysis 91 | print(flop_count_table(FlopCountAnalysis(model, x))) -------------------------------------------------------------------------------- /semseg/models/backbones/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from typing import Optional 5 | 6 | 7 | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: 8 | """ 9 | This function is taken from the original tf repo. 10 | It ensures that all layers have a channel number that is divisible by 8 11 | It can be seen here: 12 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 13 | """ 14 | if min_value is None: 15 | min_value = divisor 16 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 17 | # Make sure that round down does not go down by more than 10%. 18 | if new_v < 0.9 * v: 19 | new_v += divisor 20 | return new_v 21 | 22 | 23 | class ConvModule(nn.Sequential): 24 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1): 25 | super().__init__( 26 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False), 27 | nn.BatchNorm2d(c2), 28 | nn.ReLU6(True) 29 | ) 30 | 31 | 32 | class SqueezeExcitation(nn.Module): 33 | def __init__(self, ch, squeeze_factor=4): 34 | super().__init__() 35 | squeeze_ch = _make_divisible(ch // squeeze_factor, 8) 36 | self.fc1 = nn.Conv2d(ch, squeeze_ch, 1) 37 | self.relu = nn.ReLU(True) 38 | self.fc2 = nn.Conv2d(squeeze_ch, ch, 1) 39 | 40 | def _scale(self, x: Tensor) -> Tensor: 41 | scale = F.adaptive_avg_pool2d(x, 1) 42 | scale = self.fc2(self.relu(self.fc1(scale))) 43 | return F.hardsigmoid(scale, True) 44 | 45 | def forward(self, x: Tensor) -> Tensor: 46 | scale = self._scale(x) 47 | return scale * x 48 | 49 | 50 | class InvertedResidualConfig: 51 | def __init__(self, c1, c2, k, expanded_ch, use_se) -> None: 52 | pass 53 | 54 | 55 | class InvertedResidual(nn.Module): 56 | def __init__(self, c1, c2, s, expand_ratio): 57 | super().__init__() 58 | ch = int(round(c1 * expand_ratio)) 59 | self.use_res_connect = s == 1 and c1 == c2 60 | 61 | layers = [] 62 | 63 | if expand_ratio != 1: 64 | layers.append(ConvModule(c1, ch, 1)) 65 | 66 | layers.extend([ 67 | ConvModule(ch, ch, 3, s, 1, g=ch), 68 | nn.Conv2d(ch, c2, 1, bias=False), 69 | nn.BatchNorm2d(c2) 70 | ]) 71 | 72 | self.conv = nn.Sequential(*layers) 73 | 74 | def forward(self, x: Tensor) -> Tensor: 75 | if self.use_res_connect: 76 | return x + self.conv(x) 77 | else: 78 | return self.conv(x) 79 | 80 | 81 | mobilenetv3_settings = { 82 | 'S': [], 83 | 'L': [] 84 | } 85 | 86 | 87 | class MobileNetV3(nn.Module): 88 | def __init__(self, variant: str = None): 89 | super().__init__() 90 | self.out_indices = [3, 6, 13, 17] 91 | self.channels = [24, 32, 96, 320] 92 | input_channel = 32 93 | 94 | inverted_residual_setting = [ 95 | # t, c, n, s 96 | [1, 16, 1, 1], 97 | [6, 24, 2, 2], 98 | [6, 32, 3, 2], 99 | [6, 64, 4, 2], 100 | [6, 96, 3, 1], 101 | [6, 160, 3, 2], 102 | [6, 320, 1, 1], 103 | ] 104 | 105 | self.features = nn.ModuleList([ConvModule(3, input_channel, 3, 2, 1)]) 106 | 107 | for t, c, n, s in inverted_residual_setting: 108 | output_channel = c 109 | for i in range(n): 110 | stride = s if i == 0 else 1 111 | self.features.append(InvertedResidual(input_channel, output_channel, stride, t)) 112 | input_channel = output_channel 113 | 114 | def forward(self, x: Tensor) -> 
Tensor: 115 | outs = [] 116 | for i, m in enumerate(self.features): 117 | x = m(x) 118 | if i in self.out_indices: 119 | outs.append(x) 120 | return outs 121 | 122 | 123 | if __name__ == '__main__': 124 | model = MobileNetV3() 125 | # model.load_state_dict(torch.load('checkpoints/backbones/mobilenet_v2.pth', map_location='cpu'), strict=False) 126 | model.eval() 127 | x = torch.randn(1, 3, 224, 224) 128 | # outs = model(x) 129 | # for y in outs: 130 | # print(y.shape) 131 | 132 | from fvcore.nn import flop_count_table, FlopCountAnalysis 133 | print(flop_count_table(FlopCountAnalysis(model, x))) -------------------------------------------------------------------------------- /semseg/models/backbones/poolformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from semseg.models.layers import DropPath 4 | 5 | 6 | class PatchEmbed(nn.Module): 7 | """Image to Patch Embedding with overlapping 8 | """ 9 | def __init__(self, patch_size=16, stride=16, padding=0, in_ch=3, embed_dim=768): 10 | super().__init__() 11 | self.proj = nn.Conv2d(in_ch, embed_dim, patch_size, stride, padding) 12 | 13 | def forward(self, x: torch.Tensor) -> Tensor: 14 | x = self.proj(x) # b x hidden_dim x 14 x 14 15 | return x 16 | 17 | 18 | class Pooling(nn.Module): 19 | def __init__(self, pool_size=3) -> None: 20 | super().__init__() 21 | self.pool = nn.AvgPool2d(pool_size, 1, pool_size//2, count_include_pad=False) 22 | 23 | def forward(self, x: Tensor) -> Tensor: 24 | return self.pool(x) - x 25 | 26 | 27 | class MLP(nn.Module): 28 | def __init__(self, dim, hidden_dim, out_dim=None) -> None: 29 | super().__init__() 30 | out_dim = out_dim or dim 31 | self.fc1 = nn.Conv2d(dim, hidden_dim, 1) 32 | self.act = nn.GELU() 33 | self.fc2 = nn.Conv2d(hidden_dim, out_dim, 1) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | return self.fc2(self.act(self.fc1(x))) 37 | 38 | 39 | class PoolFormerBlock(nn.Module): 40 | def __init__(self, dim, pool_size=3, dpr=0., layer_scale_init_value=1e-5): 41 | super().__init__() 42 | self.norm1 = nn.GroupNorm(1, dim) 43 | self.token_mixer = Pooling(pool_size) 44 | self.norm2 = nn.GroupNorm(1, dim) 45 | self.drop_path = DropPath(dpr) if dpr > 0. 
else nn.Identity() 46 | self.mlp = MLP(dim, int(dim*4)) 47 | 48 | self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 49 | self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 50 | 51 | def forward(self, x: Tensor) -> Tensor: 52 | x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.token_mixer(self.norm1(x))) 53 | x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x))) 54 | return x 55 | 56 | poolformer_settings = { 57 | 'S24': [[4, 4, 12, 4], [64, 128, 320, 512], 0.1], # [layers, embed_dims, drop_path_rate] 58 | 'S36': [[6, 6, 18, 6], [64, 128, 320, 512], 0.2], 59 | 'M36': [[6, 6, 18, 6], [96, 192, 384, 768], 0.3] 60 | } 61 | 62 | 63 | class PoolFormer(nn.Module): 64 | def __init__(self, model_name: str = 'S24') -> None: 65 | super().__init__() 66 | assert model_name in poolformer_settings.keys(), f"PoolFormer model name should be in {list(poolformer_settings.keys())}" 67 | layers, embed_dims, drop_path_rate = poolformer_settings[model_name] 68 | self.channels = embed_dims 69 | 70 | self.patch_embed = PatchEmbed(7, 4, 2, 3, embed_dims[0]) 71 | 72 | network = [] 73 | 74 | for i in range(len(layers)): 75 | blocks = [] 76 | for j in range(layers[i]): 77 | dpr = drop_path_rate * (j + sum(layers[:i])) / (sum(layers) - 1) 78 | blocks.append(PoolFormerBlock(embed_dims[i], 3, dpr)) 79 | 80 | network.append(nn.Sequential(*blocks)) 81 | if i >= len(layers) - 1: break 82 | network.append(PatchEmbed(3, 2, 1, embed_dims[i], embed_dims[i+1])) 83 | 84 | self.network = nn.ModuleList(network) 85 | 86 | self.out_indices = [0, 2, 4, 6] 87 | for i, index in enumerate(self.out_indices): 88 | self.add_module(f"norm{index}", nn.GroupNorm(1, embed_dims[i])) 89 | 90 | def forward(self, x: Tensor): 91 | x = self.patch_embed(x) 92 | outs = [] 93 | 94 | for i, blk in enumerate(self.network): 95 | x = blk(x) 96 | 97 | if i in self.out_indices: 98 | out = getattr(self, f"norm{i}")(x) 99 | outs.append(out) 100 | return outs 101 | 102 | 103 | if __name__ == '__main__': 104 | model = PoolFormer('S24') 105 | model.load_state_dict(torch.load('C:\\Users\\sithu\\Documents\\weights\\backbones\\poolformer\\poolformer_s24.pth.tar', map_location='cpu'), strict=False) 106 | x = torch.randn(1, 3, 224, 224) 107 | feats = model(x) 108 | for y in feats: 109 | print(y.shape) 110 | -------------------------------------------------------------------------------- /semseg/models/backbones/pvt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from semseg.models.layers import DropPath 5 | 6 | 7 | class DWConv(nn.Module): 8 | def __init__(self, dim): 9 | super().__init__() 10 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim) 11 | 12 | def forward(self, x: Tensor, H: int, W: int) -> Tensor: 13 | B, _, C = x.shape 14 | x = x.transpose(1, 2).view(B, C, H, W) 15 | x = self.dwconv(x) 16 | return x.flatten(2).transpose(1, 2) 17 | 18 | 19 | class MLP(nn.Module): 20 | def __init__(self, dim, hidden_dim, out_dim=None) -> None: 21 | super().__init__() 22 | out_dim = out_dim or dim 23 | self.fc1 = nn.Linear(dim, hidden_dim) 24 | self.fc2 = nn.Linear(hidden_dim, out_dim) 25 | self.dwconv = DWConv(hidden_dim) 26 | 27 | def forward(self, x: Tensor, H: int, W: int) -> Tensor: 28 | return self.fc2(F.gelu(self.dwconv(self.fc1(x), H, W))) 29 | 30 | 31 | class 
Attention(nn.Module): 32 | def __init__(self, dim, head, sr_ratio): 33 | super().__init__() 34 | self.head = head 35 | self.sr_ratio = sr_ratio 36 | self.scale = (dim // head) ** -0.5 37 | self.q = nn.Linear(dim, dim, bias=True) 38 | self.kv = nn.Linear(dim, dim*2, bias=True) 39 | self.proj = nn.Linear(dim, dim) 40 | 41 | if sr_ratio > 1: 42 | self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio) 43 | self.norm = nn.LayerNorm(dim) 44 | 45 | def forward(self, x: Tensor, H, W) -> Tensor: 46 | B, N, C = x.shape 47 | q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3) 48 | 49 | if self.sr_ratio > 1: 50 | x = x.permute(0, 2, 1).reshape(B, C, H, W) 51 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) 52 | x = self.norm(x) 53 | k, v = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4) 54 | 55 | attn = (q @ k.transpose(-2, -1)) * self.scale 56 | attn = attn.softmax(dim=-1) 57 | 58 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 59 | x = self.proj(x) 60 | return x 61 | 62 | 63 | class Block(nn.Module): 64 | def __init__(self, dim, head, sr_ratio=1, mlp_ratio=4, dpr=0.): 65 | super().__init__() 66 | self.norm1 = nn.LayerNorm(dim) 67 | self.attn = Attention(dim, head, sr_ratio) 68 | self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity() 69 | self.norm2 = nn.LayerNorm(dim) 70 | self.mlp = MLP(dim, int(dim*mlp_ratio)) 71 | 72 | def forward(self, x: Tensor, H, W) -> Tensor: 73 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 74 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 75 | return x 76 | 77 | 78 | class PatchEmbed(nn.Module): 79 | def __init__(self, c1=3, c2=64, patch_size=7, stride=4): 80 | super().__init__() 81 | self.proj = nn.Conv2d(c1, c2, patch_size, stride, patch_size//2) 82 | self.norm = nn.LayerNorm(c2) 83 | 84 | def forward(self, x: Tensor) -> Tensor: 85 | x = self.proj(x) 86 | _, _, H, W = x.shape 87 | x = x.flatten(2).transpose(1, 2) 88 | x = self.norm(x) 89 | return x, H, W 90 | 91 | 92 | pvtv2_settings = { 93 | 'B1': [2, 2, 2, 2], # depths 94 | 'B2': [3, 4, 6, 3], 95 | 'B3': [3, 4, 18, 3], 96 | 'B4': [3, 8, 27, 3], 97 | 'B5': [3, 6, 40, 3] 98 | } 99 | 100 | 101 | class PVTv2(nn.Module): 102 | def __init__(self, model_name: str = 'B1') -> None: 103 | super().__init__() 104 | assert model_name in pvtv2_settings.keys(), f"PVTv2 model name should be in {list(pvtv2_settings.keys())}" 105 | depths = pvtv2_settings[model_name] 106 | embed_dims = [64, 128, 320, 512] 107 | drop_path_rate = 0.1 108 | self.channels = embed_dims 109 | # patch_embed 110 | self.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4) 111 | self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2) 112 | self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2) 113 | self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2) 114 | 115 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 116 | # transformer encoder 117 | cur = 0 118 | self.block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, 8, dpr[cur+i]) for i in range(depths[0])]) 119 | self.norm1 = nn.LayerNorm(embed_dims[0]) 120 | 121 | cur += depths[0] 122 | self.block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, 8, dpr[cur+i]) for i in range(depths[1])]) 123 | self.norm2 = nn.LayerNorm(embed_dims[1]) 124 | 125 | cur += depths[1] 126 | self.block3 = nn.ModuleList([Block(embed_dims[2], 5, 2, 4, dpr[cur+i]) for i in range(depths[2])]) 127 | self.norm3 = nn.LayerNorm(embed_dims[2]) 128 | 129 | cur += depths[2] 130 | self.block4 = 
nn.ModuleList([Block(embed_dims[3], 8, 1, 4, dpr[cur+i]) for i in range(depths[3])]) 131 | self.norm4 = nn.LayerNorm(embed_dims[3]) 132 | 133 | def forward(self, x: Tensor) -> Tensor: 134 | B = x.shape[0] 135 | # stage 1 136 | x, H, W = self.patch_embed1(x) 137 | for blk in self.block1: 138 | x = blk(x, H, W) 139 | x1 = self.norm1(x).reshape(B, H, W, -1).permute(0, 3, 1, 2) 140 | 141 | # stage 2 142 | x, H, W = self.patch_embed2(x1) 143 | for blk in self.block2: 144 | x = blk(x, H, W) 145 | x2 = self.norm2(x).reshape(B, H, W, -1).permute(0, 3, 1, 2) 146 | 147 | # stage 3 148 | x, H, W = self.patch_embed3(x2) 149 | for blk in self.block3: 150 | x = blk(x, H, W) 151 | x3 = self.norm3(x).reshape(B, H, W, -1).permute(0, 3, 1, 2) 152 | 153 | # stage 4 154 | x, H, W = self.patch_embed4(x3) 155 | for blk in self.block4: 156 | x = blk(x, H, W) 157 | x4 = self.norm4(x).reshape(B, H, W, -1).permute(0, 3, 1, 2) 158 | 159 | return x1, x2, x3, x4 160 | 161 | 162 | if __name__ == '__main__': 163 | model = PVTv2('B1') 164 | model.load_state_dict(torch.load('checkpoints/backbones/pvtv2/pvt_v2_b1.pth', map_location='cpu'), strict=False) 165 | x = torch.zeros(1, 3, 224, 224) 166 | outs = model(x) 167 | for y in outs: 168 | print(y.shape) -------------------------------------------------------------------------------- /semseg/models/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | 5 | 6 | class BasicBlock(nn.Module): 7 | """2 Layer No Expansion Block 8 | """ 9 | expansion: int = 1 10 | def __init__(self, c1, c2, s=1, downsample= None) -> None: 11 | super().__init__() 12 | self.conv1 = nn.Conv2d(c1, c2, 3, s, 1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(c2) 14 | self.conv2 = nn.Conv2d(c2, c2, 3, 1, 1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(c2) 16 | self.downsample = downsample 17 | 18 | def forward(self, x: Tensor) -> Tensor: 19 | identity = x 20 | out = F.relu(self.bn1(self.conv1(x))) 21 | out = self.bn2(self.conv2(out)) 22 | if self.downsample is not None: identity = self.downsample(x) 23 | out += identity 24 | return F.relu(out) 25 | 26 | 27 | class Bottleneck(nn.Module): 28 | """3 Layer 4x Expansion Block 29 | """ 30 | expansion: int = 4 31 | def __init__(self, c1, c2, s=1, downsample=None) -> None: 32 | super().__init__() 33 | self.conv1 = nn.Conv2d(c1, c2, 1, 1, 0, bias=False) 34 | self.bn1 = nn.BatchNorm2d(c2) 35 | self.conv2 = nn.Conv2d(c2, c2, 3, s, 1, bias=False) 36 | self.bn2 = nn.BatchNorm2d(c2) 37 | self.conv3 = nn.Conv2d(c2, c2 * self.expansion, 1, 1, 0, bias=False) 38 | self.bn3 = nn.BatchNorm2d(c2 * self.expansion) 39 | self.downsample = downsample 40 | 41 | def forward(self, x: Tensor) -> Tensor: 42 | identity = x 43 | out = F.relu(self.bn1(self.conv1(x))) 44 | out = F.relu(self.bn2(self.conv2(out))) 45 | out = self.bn3(self.conv3(out)) 46 | if self.downsample is not None: identity = self.downsample(x) 47 | out += identity 48 | return F.relu(out) 49 | 50 | 51 | resnet_settings = { 52 | '18': [BasicBlock, [2, 2, 2, 2], [64, 128, 256, 512]], 53 | '34': [BasicBlock, [3, 4, 6, 3], [64, 128, 256, 512]], 54 | '50': [Bottleneck, [3, 4, 6, 3], [256, 512, 1024, 2048]], 55 | '101': [Bottleneck, [3, 4, 23, 3], [256, 512, 1024, 2048]], 56 | '152': [Bottleneck, [3, 8, 36, 3], [256, 512, 1024, 2048]] 57 | } 58 | 59 | 60 | class ResNet(nn.Module): 61 | def __init__(self, model_name: str = '50') -> None: 62 | super().__init__() 63 | assert model_name in 
resnet_settings.keys(), f"ResNet model name should be in {list(resnet_settings.keys())}" 64 | block, depths, channels = resnet_settings[model_name] 65 | 66 | self.inplanes = 64 67 | self.channels = channels 68 | self.conv1 = nn.Conv2d(3, self.inplanes, 7, 2, 3, bias=False) 69 | self.bn1 = nn.BatchNorm2d(self.inplanes) 70 | self.maxpool = nn.MaxPool2d(3, 2, 1) 71 | 72 | self.layer1 = self._make_layer(block, 64, depths[0], s=1) 73 | self.layer2 = self._make_layer(block, 128, depths[1], s=2) 74 | self.layer3 = self._make_layer(block, 256, depths[2], s=2) 75 | self.layer4 = self._make_layer(block, 512, depths[3], s=2) 76 | 77 | 78 | def _make_layer(self, block, planes, depth, s=1) -> nn.Sequential: 79 | downsample = None 80 | if s != 1 or self.inplanes != planes * block.expansion: 81 | downsample = nn.Sequential( 82 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, s, bias=False), 83 | nn.BatchNorm2d(planes * block.expansion) 84 | ) 85 | layers = nn.Sequential( 86 | block(self.inplanes, planes, s, downsample), 87 | *[block(planes * block.expansion, planes) for _ in range(1, depth)] 88 | ) 89 | self.inplanes = planes * block.expansion 90 | return layers 91 | 92 | 93 | def forward(self, x: Tensor) -> Tensor: 94 | x = self.maxpool(F.relu(self.bn1(self.conv1(x)))) # [1, 64, H/4, W/4] 95 | x1 = self.layer1(x) # [1, 64/256, H/4, W/4] 96 | x2 = self.layer2(x1) # [1, 128/512, H/8, W/8] 97 | x3 = self.layer3(x2) # [1, 256/1024, H/16, W/16] 98 | x4 = self.layer4(x3) # [1, 512/2048, H/32, W/32] 99 | return x1, x2, x3, x4 100 | 101 | 102 | if __name__ == '__main__': 103 | model = ResNet('18') 104 | # model.load_state_dict(torch.load('C:\\Users\\sithu\\Documents\\weights\\backbones\\resnet\\resnet18_a1.pth', map_location='cpu'), strict=False) 105 | x = torch.zeros(1, 3, 224, 224) 106 | outs = model(x) 107 | for y in outs: 108 | print(y.shape) -------------------------------------------------------------------------------- /semseg/models/backbones/resnetd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | 5 | 6 | class BasicBlock(nn.Module): 7 | """2 Layer No Expansion Block 8 | """ 9 | expansion: int = 1 10 | def __init__(self, c1, c2, s=1, d=1, downsample= None) -> None: 11 | super().__init__() 12 | self.conv1 = nn.Conv2d(c1, c2, 3, s, 1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(c2) 14 | self.conv2 = nn.Conv2d(c2, c2, 3, 1, d if d != 1 else 1, d, bias=False) 15 | self.bn2 = nn.BatchNorm2d(c2) 16 | self.downsample = downsample 17 | 18 | def forward(self, x: Tensor) -> Tensor: 19 | identity = x 20 | out = F.relu(self.bn1(self.conv1(x))) 21 | out = self.bn2(self.conv2(out)) 22 | if self.downsample is not None: identity = self.downsample(x) 23 | out += identity 24 | return F.relu(out) 25 | 26 | 27 | class Bottleneck(nn.Module): 28 | """3 Layer 4x Expansion Block 29 | """ 30 | expansion: int = 4 31 | def __init__(self, c1, c2, s=1, d=1, downsample=None) -> None: 32 | super().__init__() 33 | self.conv1 = nn.Conv2d(c1, c2, 1, 1, 0, bias=False) 34 | self.bn1 = nn.BatchNorm2d(c2) 35 | self.conv2 = nn.Conv2d(c2, c2, 3, s, d if d != 1 else 1, d, bias=False) 36 | self.bn2 = nn.BatchNorm2d(c2) 37 | self.conv3 = nn.Conv2d(c2, c2 * self.expansion, 1, 1, 0, bias=False) 38 | self.bn3 = nn.BatchNorm2d(c2 * self.expansion) 39 | self.downsample = downsample 40 | 41 | def forward(self, x: Tensor) -> Tensor: 42 | identity = x 43 | out = F.relu(self.bn1(self.conv1(x))) 44 | out = 
F.relu(self.bn2(self.conv2(out))) 45 | out = self.bn3(self.conv3(out)) 46 | if self.downsample is not None: identity = self.downsample(x) 47 | out += identity 48 | return F.relu(out) 49 | 50 | 51 | class Stem(nn.Sequential): 52 | def __init__(self, c1, ch, c2): 53 | super().__init__( 54 | nn.Conv2d(c1, ch, 3, 2, 1, bias=False), 55 | nn.BatchNorm2d(ch), 56 | nn.ReLU(True), 57 | nn.Conv2d(ch, ch, 3, 1, 1, bias=False), 58 | nn.BatchNorm2d(ch), 59 | nn.ReLU(True), 60 | nn.Conv2d(ch, c2, 3, 1, 1, bias=False), 61 | nn.BatchNorm2d(c2), 62 | nn.ReLU(True), 63 | nn.MaxPool2d(3, 2, 1) 64 | ) 65 | 66 | 67 | resnetd_settings = { 68 | '18': [BasicBlock, [2, 2, 2, 2], [64, 128, 256, 512]], 69 | '50': [Bottleneck, [3, 4, 6, 3], [256, 512, 1024, 2048]], 70 | '101': [Bottleneck, [3, 4, 23, 3], [256, 512, 1024, 2048]] 71 | } 72 | 73 | 74 | class ResNetD(nn.Module): 75 | def __init__(self, model_name: str = '50') -> None: 76 | super().__init__() 77 | assert model_name in resnetd_settings.keys(), f"ResNetD model name should be in {list(resnetd_settings.keys())}" 78 | block, depths, channels = resnetd_settings[model_name] 79 | 80 | self.inplanes = 128 81 | self.channels = channels 82 | self.stem = Stem(3, 64, self.inplanes) 83 | self.layer1 = self._make_layer(block, 64, depths[0], s=1) 84 | self.layer2 = self._make_layer(block, 128, depths[1], s=2) 85 | self.layer3 = self._make_layer(block, 256, depths[2], s=2, d=2) 86 | self.layer4 = self._make_layer(block, 512, depths[3], s=2, d=4) 87 | 88 | 89 | def _make_layer(self, block, planes, depth, s=1, d=1) -> nn.Sequential: 90 | downsample = None 91 | 92 | if s != 1 or self.inplanes != planes * block.expansion: 93 | downsample = nn.Sequential( 94 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, s, bias=False), 95 | nn.BatchNorm2d(planes * block.expansion) 96 | ) 97 | layers = nn.Sequential( 98 | block(self.inplanes, planes, s, d, downsample=downsample), 99 | *[block(planes * block.expansion, planes, d=d) for _ in range(1, depth)] 100 | ) 101 | self.inplanes = planes * block.expansion 102 | return layers 103 | 104 | 105 | def forward(self, x: Tensor) -> Tensor: 106 | x = self.stem(x) # [1, 128, H/4, W/4] 107 | x1 = self.layer1(x) # [1, 64/256, H/4, W/4] 108 | x2 = self.layer2(x1) # [1, 128/512, H/8, W/8] 109 | x3 = self.layer3(x2) # [1, 256/1024, H/16, W/16] 110 | x4 = self.layer4(x3) # [1, 512/2048, H/32, W/32] 111 | return x1, x2, x3, x4 112 | 113 | 114 | if __name__ == '__main__': 115 | model = ResNetD('18') 116 | model.load_state_dict(torch.load('checkpoints/backbones/resnetd/resnetd18.pth', map_location='cpu'), strict=False) 117 | x = torch.zeros(1, 3, 224, 224) 118 | outs = model(x) 119 | for y in outs: 120 | print(y.shape) -------------------------------------------------------------------------------- /semseg/models/backbones/uniformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from semseg.models.layers import DropPath 4 | 5 | 6 | class MLP(nn.Module): 7 | def __init__(self, dim, hidden_dim, out_dim=None) -> None: 8 | super().__init__() 9 | out_dim = out_dim or dim 10 | self.fc1 = nn.Linear(dim, hidden_dim) 11 | self.act = nn.GELU() 12 | self.fc2 = nn.Linear(hidden_dim, out_dim) 13 | 14 | def forward(self, x: Tensor) -> Tensor: 15 | return self.fc2(self.act(self.fc1(x))) 16 | 17 | 18 | class CMLP(nn.Module): 19 | def __init__(self, dim, hidden_dim, out_dim=None) -> None: 20 | super().__init__() 21 | out_dim = out_dim or dim 22 | self.fc1 = nn.Conv2d(dim, 
hidden_dim, 1) 23 | self.act = nn.GELU() 24 | self.fc2 = nn.Conv2d(hidden_dim, out_dim, 1) 25 | 26 | def forward(self, x: Tensor) -> Tensor: 27 | return self.fc2(self.act(self.fc1(x))) 28 | 29 | 30 | class Attention(nn.Module): 31 | def __init__(self, dim, num_heads=8) -> None: 32 | super().__init__() 33 | self.num_heads = num_heads 34 | self.scale = (dim // num_heads) ** -0.5 35 | self.qkv = nn.Linear(dim, dim*3) 36 | self.proj = nn.Linear(dim, dim) 37 | 38 | def forward(self, x: Tensor) -> Tensor: 39 | B, N, C = x.shape 40 | q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, C//self.num_heads).permute(2, 0, 3, 1, 4) 41 | attn = (q @ k.transpose(-2, -1)) * self.scale 42 | attn = attn.softmax(dim=-1) 43 | 44 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 45 | x = self.proj(x) 46 | return x 47 | 48 | 49 | class CBlock(nn.Module): 50 | def __init__(self, dim, dpr=0.): 51 | super().__init__() 52 | self.pos_embed = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim) 53 | self.norm1 = nn.BatchNorm2d(dim) 54 | self.conv1 = nn.Conv2d(dim, dim, 1) 55 | self.conv2 = nn.Conv2d(dim, dim, 1) 56 | self.attn = nn.Conv2d(dim, dim, 5, 1, 2, groups=dim) 57 | self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity() 58 | self.norm2 = nn.BatchNorm2d(dim) 59 | self.mlp = CMLP(dim, int(dim*4)) 60 | 61 | def forward(self, x: Tensor) -> Tensor: 62 | x = x + self.pos_embed(x) 63 | x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x))))) 64 | x = x + self.drop_path(self.mlp(self.norm2(x))) 65 | return x 66 | 67 | 68 | class SABlock(nn.Module): 69 | def __init__(self, dim, num_heads, dpr=0.) -> None: 70 | super().__init__() 71 | self.pos_embed = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim) 72 | self.norm1 = nn.LayerNorm(dim) 73 | self.attn = Attention(dim, num_heads) 74 | self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity() 75 | self.norm2 = nn.LayerNorm(dim) 76 | self.mlp = MLP(dim, int(dim*4)) 77 | 78 | def forward(self, x: Tensor) -> Tensor: 79 | x = x + self.pos_embed(x) 80 | B, N, H, W = x.shape 81 | x = x.flatten(2).transpose(1, 2) 82 | x = x + self.drop_path(self.attn(self.norm1(x))) 83 | x = x + self.drop_path(self.mlp(self.norm2(x))) 84 | x = x.transpose(1, 2).reshape(B, N, H, W) 85 | return x 86 | 87 | 88 | class PatchEmbed(nn.Module): 89 | def __init__(self, patch_size=16, in_ch=3, embed_dim=768) -> None: 90 | super().__init__() 91 | self.norm = nn.LayerNorm(embed_dim) 92 | self.proj = nn.Conv2d(in_ch, embed_dim, patch_size, patch_size) 93 | 94 | def forward(self, x: Tensor) -> Tensor: 95 | x = self.proj(x) 96 | B, C, H, W = x.shape 97 | x = x.flatten(2).transpose(1, 2) 98 | x = self.norm(x) 99 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 100 | return x 101 | 102 | 103 | uniformer_settings = { 104 | 'S': [3, 4, 8, 3], # [depth] 105 | 'B': [5, 8, 20, 7] 106 | } 107 | 108 | 109 | class UniFormer(nn.Module): 110 | def __init__(self, model_name: str = 'S') -> None: 111 | super().__init__() 112 | assert model_name in uniformer_settings.keys(), f"UniFormer model name should be in {list(uniformer_settings.keys())}" 113 | depth = uniformer_settings[model_name] 114 | 115 | head_dim = 64 116 | drop_path_rate = 0. 117 | embed_dims = [64, 128, 320, 512] 118 | 119 | for i in range(4): 120 | self.add_module(f"patch_embed{i+1}", PatchEmbed(4 if i == 0 else 2, 3 if i == 0 else embed_dims[i-1], embed_dims[i])) 121 | self.add_module(f"norm{i+1}", nn.LayerNorm(embed_dims[i])) 122 | 123 | self.pos_drop = nn.Dropout(0.) 
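        # (added note) The lines below set up the schedule shared by the four UniFormer stages:
        # stochastic-depth (DropPath) rates are spaced linearly from 0 to drop_path_rate over
        # the total block count, and attention head counts are derived from a fixed 64-channel
        # head size. Stages 1-2 use convolutional CBlocks; stages 3-4 use self-attention SABlocks.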
124 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))] 125 | num_heads = [dim // head_dim for dim in embed_dims] 126 | 127 | self.blocks1 = nn.ModuleList([ 128 | CBlock(embed_dims[0], dpr[i]) 129 | for i in range(depth[0])]) 130 | 131 | self.blocks2 = nn.ModuleList([ 132 | CBlock(embed_dims[1], dpr[i+depth[0]]) 133 | for i in range(depth[1])]) 134 | 135 | self.blocks3 = nn.ModuleList([ 136 | SABlock(embed_dims[2], num_heads[2], dpr[i+depth[0]+depth[1]]) 137 | for i in range(depth[2])]) 138 | 139 | self.blocks4 = nn.ModuleList([ 140 | SABlock(embed_dims[3], num_heads[3], dpr[i+depth[0]+depth[1]+depth[2]]) 141 | for i in range(depth[3])]) 142 | 143 | 144 | def forward(self, x: torch.Tensor): 145 | outs = [] 146 | 147 | x = self.patch_embed1(x) 148 | x = self.pos_drop(x) 149 | for blk in self.blocks1: 150 | x = blk(x) 151 | x_out = self.norm1(x.permute(0, 2, 3, 1)) 152 | outs.append(x_out.permute(0, 3, 1, 2)) 153 | 154 | x = self.patch_embed2(x) 155 | for blk in self.blocks2: 156 | x = blk(x) 157 | x_out = self.norm2(x.permute(0, 2, 3, 1)) 158 | outs.append(x_out.permute(0, 3, 1, 2)) 159 | 160 | x = self.patch_embed3(x) 161 | for blk in self.blocks3: 162 | x = blk(x) 163 | x_out = self.norm3(x.permute(0, 2, 3, 1)) 164 | outs.append(x_out.permute(0, 3, 1, 2)) 165 | 166 | x = self.patch_embed4(x) 167 | for blk in self.blocks4: 168 | x = blk(x) 169 | x_out = self.norm4(x.permute(0, 2, 3, 1)) 170 | outs.append(x_out.permute(0, 3, 1, 2)) 171 | 172 | return outs 173 | 174 | if __name__ == '__main__': 175 | model = UniFormer('S') 176 | model.load_state_dict(torch.load('C:\\Users\\sithu\\Documents\\weights\\backbones\\uniformer\\uniformer_small_in1k.pth', map_location='cpu')['model'], strict=False) 177 | x = torch.randn(1, 3, 224, 224) 178 | feats = model(x) 179 | for y in feats: 180 | print(y.shape) 181 | -------------------------------------------------------------------------------- /semseg/models/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch import nn 4 | from semseg.models.backbones import * 5 | from semseg.models.layers import trunc_normal_ 6 | 7 | 8 | class BaseModel(nn.Module): 9 | def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19) -> None: 10 | super().__init__() 11 | backbone, variant = backbone.split('-') 12 | self.backbone = eval(backbone)(variant) 13 | 14 | def _init_weights(self, m: nn.Module) -> None: 15 | if isinstance(m, nn.Linear): 16 | trunc_normal_(m.weight, std=.02) 17 | if m.bias is not None: 18 | nn.init.zeros_(m.bias) 19 | elif isinstance(m, nn.Conv2d): 20 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 21 | fan_out // m.groups 22 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 23 | if m.bias is not None: 24 | nn.init.zeros_(m.bias) 25 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): 26 | nn.init.ones_(m.weight) 27 | nn.init.zeros_(m.bias) 28 | 29 | def init_pretrained(self, pretrained: str = None) -> None: 30 | if pretrained: 31 | self.backbone.load_state_dict(torch.load(pretrained, map_location='cpu'), strict=False) -------------------------------------------------------------------------------- /semseg/models/custom_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import functional as F 4 | from semseg.models.base import BaseModel 5 | from semseg.models.heads import UPerHead 6 | 7 | 8 | class CustomCNN(BaseModel): 9 
| def __init__(self, backbone: str = 'ResNet-50', num_classes: int = 19): 10 | super().__init__(backbone, num_classes) 11 | self.decode_head = UPerHead(self.backbone.channels, 256, num_classes) 12 | self.apply(self._init_weights) 13 | 14 | def forward(self, x: Tensor) -> Tensor: 15 | y = self.backbone(x) 16 | y = self.decode_head(y) # 4x reduction in image size 17 | y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False) # to original image shape 18 | return y 19 | 20 | 21 | if __name__ == '__main__': 22 | model = CustomCNN('ResNet-18', 19) 23 | model.init_pretrained('checkpoints/backbones/resnet/resnet18.pth') 24 | x = torch.randn(2, 3, 224, 224) 25 | y = model(x) 26 | print(y.shape) -------------------------------------------------------------------------------- /semseg/models/custom_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import functional as F 4 | from semseg.models.base import BaseModel 5 | from semseg.models.heads import UPerHead 6 | 7 | 8 | class CustomVIT(BaseModel): 9 | def __init__(self, backbone: str = 'ResT-S', num_classes: int = 19) -> None: 10 | super().__init__(backbone, num_classes) 11 | self.decode_head = UPerHead(self.backbone.channels, 128, num_classes) 12 | self.apply(self._init_weights) 13 | 14 | def forward(self, x: Tensor) -> Tensor: 15 | y = self.backbone(x) 16 | y = self.decode_head(y) # 4x reduction in image size 17 | y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False) # to original image shape 18 | return y 19 | 20 | 21 | if __name__ == '__main__': 22 | model = CustomVIT('ResT-S', 19) 23 | model.init_pretrained('checkpoints/backbones/rest/rest_small.pth') 24 | x = torch.zeros(2, 3, 512, 512) 25 | y = model(x) 26 | print(y.shape) 27 | 28 | 29 | -------------------------------------------------------------------------------- /semseg/models/fchardnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | 5 | 6 | class ConvModule(nn.Module): 7 | def __init__(self, c1, c2, k=3, s=1): 8 | super().__init__() 9 | self.conv = nn.Conv2d(c1, c2, k, s, k//2, bias=False) 10 | self.norm = nn.BatchNorm2d(c2) 11 | self.relu = nn.ReLU6(True) 12 | 13 | def forward(self, x: Tensor) -> Tensor: 14 | return self.relu(self.norm(self.conv(x))) 15 | 16 | 17 | def get_link(layer, base_ch, growth_rate): 18 | if layer == 0: 19 | return base_ch, 0, [] 20 | 21 | link = [] 22 | out_channels = growth_rate 23 | 24 | for i in range(10): 25 | dv = 2 ** i 26 | if layer % dv == 0: 27 | link.append(layer - dv) 28 | 29 | if i > 0: out_channels *= 1.7 30 | 31 | out_channels = int((out_channels + 1) / 2) * 2 32 | in_channels = 0 33 | 34 | for i in link: 35 | ch, _, _ = get_link(i, base_ch, growth_rate) 36 | in_channels += ch 37 | 38 | return out_channels, in_channels, link 39 | 40 | 41 | class HarDBlock(nn.Module): 42 | def __init__(self, c1, growth_rate, n_layers): 43 | super().__init__() 44 | self.links = [] 45 | layers = [] 46 | self.out_channels = 0 47 | 48 | for i in range(n_layers): 49 | out_ch, in_ch, link = get_link(i+1, c1, growth_rate) 50 | self.links.append(link) 51 | 52 | layers.append(ConvModule(in_ch, out_ch)) 53 | 54 | if (i % 2 == 0) or (i == n_layers - 1): 55 | self.out_channels += out_ch 56 | 57 | self.layers = nn.ModuleList(layers) 58 | 59 | def forward(self, x: Tensor) -> Tensor: 60 | layers = [x] 61 | 62 
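        # (added note) Each step gathers its inputs from earlier outputs selected by the
        # precomputed `links` (the sparse power-of-two pattern from get_link), concatenates
        # them, runs the corresponding conv layer, and appends the result so later layers can
        # reuse it. At the end, only every other intermediate output plus the final one are
        # kept and concatenated into the block output.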
| for layer in range(len(self.layers)): 63 | link = self.links[layer] 64 | tin = [] 65 | 66 | for i in link: 67 | tin.append(layers[i]) 68 | 69 | if len(tin) > 1: 70 | x = torch.cat(tin, dim=1) 71 | else: 72 | x = tin[0] 73 | 74 | out = self.layers[layer](x) 75 | layers.append(out) 76 | 77 | t = len(layers) 78 | outs = [] 79 | for i in range(t): 80 | if (i == t - 1) or (i % 2 == 1): 81 | outs.append(layers[i]) 82 | 83 | out = torch.cat(outs, dim=1) 84 | return out 85 | 86 | 87 | class FCHarDNet(nn.Module): 88 | def __init__(self, backbone: str = None, num_classes: int = 19) -> None: 89 | super().__init__() 90 | first_ch, ch_list, gr, n_layers = [16, 24, 32, 48], [64, 96, 160, 224, 320], [10, 16, 18, 24, 32], [4, 4, 8, 8, 8] 91 | 92 | self.base = nn.ModuleList([]) 93 | 94 | # stem 95 | self.base.append(ConvModule(3, first_ch[0], 3, 2)) 96 | self.base.append(ConvModule(first_ch[0], first_ch[1], 3)) 97 | self.base.append(ConvModule(first_ch[1], first_ch[2], 3, 2)) 98 | self.base.append(ConvModule(first_ch[2], first_ch[3], 3)) 99 | 100 | self.shortcut_layers = [] 101 | skip_connection_channel_counts = [] 102 | ch = first_ch[-1] 103 | 104 | for i in range(len(n_layers)): 105 | blk = HarDBlock(ch, gr[i], n_layers[i]) 106 | ch = blk.out_channels 107 | 108 | skip_connection_channel_counts.append(ch) 109 | self.base.append(blk) 110 | 111 | if i < len(n_layers) - 1: 112 | self.shortcut_layers.append(len(self.base) - 1) 113 | 114 | self.base.append(ConvModule(ch, ch_list[i], k=1)) 115 | ch = ch_list[i] 116 | 117 | if i < len(n_layers) - 1: 118 | self.base.append(nn.AvgPool2d(2, 2)) 119 | 120 | prev_block_channels = ch 121 | self.n_blocks = len(n_layers) - 1 122 | 123 | self.denseBlocksUp = nn.ModuleList([]) 124 | self.conv1x1_up = nn.ModuleList([]) 125 | 126 | for i in range(self.n_blocks-1, -1, -1): 127 | cur_channels_count = prev_block_channels + skip_connection_channel_counts[i] 128 | blk = HarDBlock(cur_channels_count // 2, gr[i], n_layers[i]) 129 | prev_block_channels = blk.out_channels 130 | 131 | self.conv1x1_up.append(ConvModule(cur_channels_count, cur_channels_count//2, 1)) 132 | self.denseBlocksUp.append(blk) 133 | 134 | self.finalConv = nn.Conv2d(prev_block_channels, num_classes, 1, 1, 0) 135 | 136 | self.apply(self._init_weights) 137 | 138 | def _init_weights(self, m: nn.Module) -> None: 139 | if isinstance(m, nn.Conv2d): 140 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 141 | elif isinstance(m, nn.BatchNorm2d): 142 | nn.init.constant_(m.weight, 1) 143 | nn.init.constant_(m.bias, 0) 144 | 145 | def init_pretrained(self, pretrained: str = None) -> None: 146 | if pretrained: 147 | self.load_state_dict(torch.load(pretrained, map_location='cpu'), strict=False) 148 | 149 | def forward(self, x: Tensor) -> Tensor: 150 | H, W = x.shape[-2:] 151 | skip_connections = [] 152 | for i, layer in enumerate(self.base): 153 | x = layer(x) 154 | if i in self.shortcut_layers: 155 | skip_connections.append(x) 156 | 157 | out = x 158 | 159 | for i in range(self.n_blocks): 160 | skip = skip_connections.pop() 161 | out = F.interpolate(out, size=skip.shape[-2:], mode='bilinear', align_corners=True) 162 | out = torch.cat([out, skip], dim=1) 163 | out = self.conv1x1_up[i](out) 164 | out = self.denseBlocksUp[i](out) 165 | 166 | out = self.finalConv(out) 167 | out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=True) 168 | return out 169 | 170 | 171 | if __name__ == '__main__': 172 | model = FCHarDNet() 173 | # 
model.init_pretrained('checkpoints/backbones/hardnet/hardnet_70.pth') 174 | # model.load_state_dict(torch.load('checkpoints/pretrained/hardnet/hardnet70_cityscapes.pth', map_location='cpu')) 175 | x = torch.zeros(1, 3, 224, 224) 176 | outs = model(x) 177 | print(outs.shape) 178 | -------------------------------------------------------------------------------- /semseg/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .upernet import UPerHead 2 | from .segformer import SegFormerHead 3 | from .sfnet import SFHead 4 | from .fpn import FPNHead 5 | from .fapn import FaPNHead 6 | from .fcn import FCNHead 7 | from .condnet import CondHead 8 | from .lawin import LawinHead 9 | 10 | __all__ = ['UPerHead', 'SegFormerHead', 'SFHead', 'FPNHead', 'FaPNHead', 'FCNHead', 'CondHead', 'LawinHead'] -------------------------------------------------------------------------------- /semseg/models/heads/condnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from semseg.models.layers import ConvModule 5 | 6 | 7 | class CondHead(nn.Module): 8 | def __init__(self, in_channel: int = 2048, channel: int = 512, num_classes: int = 19): 9 | super().__init__() 10 | self.num_classes = num_classes 11 | self.weight_num = channel * num_classes 12 | self.bias_num = num_classes 13 | 14 | self.conv = ConvModule(in_channel, channel, 1) 15 | self.dropout = nn.Dropout2d(0.1) 16 | 17 | self.guidance_project = nn.Conv2d(channel, num_classes, 1) 18 | self.filter_project = nn.Conv2d(channel*num_classes, self.weight_num + self.bias_num, 1, groups=num_classes) 19 | 20 | def forward(self, features) -> Tensor: 21 | x = self.dropout(self.conv(features[-1])) 22 | B, C, H, W = x.shape 23 | guidance_mask = self.guidance_project(x) 24 | cond_logit = guidance_mask 25 | 26 | key = x 27 | value = x 28 | guidance_mask = guidance_mask.softmax(dim=1).view(*guidance_mask.shape[:2], -1) 29 | key = key.view(B, C, -1).permute(0, 2, 1) 30 | 31 | cond_filters = torch.matmul(guidance_mask, key) 32 | cond_filters /= H * W 33 | cond_filters = cond_filters.view(B, -1, 1, 1) 34 | cond_filters = self.filter_project(cond_filters) 35 | cond_filters = cond_filters.view(B, -1) 36 | 37 | weight, bias = torch.split(cond_filters, [self.weight_num, self.bias_num], dim=1) 38 | weight = weight.reshape(B * self.num_classes, -1, 1, 1) 39 | bias = bias.reshape(B * self.num_classes) 40 | 41 | value = value.view(-1, H, W).unsqueeze(0) 42 | seg_logit = F.conv2d(value, weight, bias, 1, 0, groups=B).view(B, self.num_classes, H, W) 43 | 44 | if self.training: 45 | return cond_logit, seg_logit 46 | return seg_logit 47 | 48 | 49 | if __name__ == '__main__': 50 | from semseg.models.backbones import ResNetD 51 | backbone = ResNetD('50') 52 | head = CondHead() 53 | x = torch.randn(2, 3, 224, 224) 54 | features = backbone(x) 55 | outs = head(features) 56 | for out in outs: 57 | out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False) 58 | print(out.shape) -------------------------------------------------------------------------------- /semseg/models/heads/fapn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from torchvision.ops import DeformConv2d 5 | from semseg.models.layers import ConvModule 6 | 7 | 8 | class DCNv2(nn.Module): 9 | def 
__init__(self, c1, c2, k, s, p, g=1): 10 | super().__init__() 11 | self.dcn = DeformConv2d(c1, c2, k, s, p, groups=g) 12 | self.offset_mask = nn.Conv2d(c2, g* 3 * k * k, k, s, p) 13 | self._init_offset() 14 | 15 | def _init_offset(self): 16 | self.offset_mask.weight.data.zero_() 17 | self.offset_mask.bias.data.zero_() 18 | 19 | def forward(self, x, offset): 20 | out = self.offset_mask(offset) 21 | o1, o2, mask = torch.chunk(out, 3, dim=1) 22 | offset = torch.cat([o1, o2], dim=1) 23 | mask = mask.sigmoid() 24 | return self.dcn(x, offset, mask) 25 | 26 | 27 | class FSM(nn.Module): 28 | def __init__(self, c1, c2): 29 | super().__init__() 30 | self.conv_atten = nn.Conv2d(c1, c1, 1, bias=False) 31 | self.conv = nn.Conv2d(c1, c2, 1, bias=False) 32 | 33 | def forward(self, x: Tensor) -> Tensor: 34 | atten = self.conv_atten(F.avg_pool2d(x, x.shape[2:])).sigmoid() 35 | feat = torch.mul(x, atten) 36 | x = x + feat 37 | return self.conv(x) 38 | 39 | 40 | class FAM(nn.Module): 41 | def __init__(self, c1, c2): 42 | super().__init__() 43 | self.lateral_conv = FSM(c1, c2) 44 | self.offset = nn.Conv2d(c2*2, c2, 1, bias=False) 45 | self.dcpack_l2 = DCNv2(c2, c2, 3, 1, 1, 8) 46 | 47 | def forward(self, feat_l, feat_s): 48 | feat_up = feat_s 49 | if feat_l.shape[2:] != feat_s.shape[2:]: 50 | feat_up = F.interpolate(feat_s, size=feat_l.shape[2:], mode='bilinear', align_corners=False) 51 | 52 | feat_arm = self.lateral_conv(feat_l) 53 | offset = self.offset(torch.cat([feat_arm, feat_up*2], dim=1)) 54 | 55 | feat_align = F.relu(self.dcpack_l2(feat_up, offset)) 56 | return feat_align + feat_arm 57 | 58 | 59 | class FaPNHead(nn.Module): 60 | def __init__(self, in_channels, channel=128, num_classes=19): 61 | super().__init__() 62 | in_channels = in_channels[::-1] 63 | self.align_modules = nn.ModuleList([ConvModule(in_channels[0], channel, 1)]) 64 | self.output_convs = nn.ModuleList([]) 65 | 66 | for ch in in_channels[1:]: 67 | self.align_modules.append(FAM(ch, channel)) 68 | self.output_convs.append(ConvModule(channel, channel, 3, 1, 1)) 69 | 70 | self.conv_seg = nn.Conv2d(channel, num_classes, 1) 71 | self.dropout = nn.Dropout2d(0.1) 72 | 73 | def forward(self, features) -> Tensor: 74 | features = features[::-1] 75 | out = self.align_modules[0](features[0]) 76 | 77 | for feat, align_module, output_conv in zip(features[1:], self.align_modules[1:], self.output_convs): 78 | out = align_module(feat, out) 79 | out = output_conv(out) 80 | out = self.conv_seg(self.dropout(out)) 81 | return out 82 | 83 | 84 | if __name__ == '__main__': 85 | from semseg.models.backbones import ResNet 86 | backbone = ResNet('50') 87 | head = FaPNHead([256, 512, 1024, 2048], 128, 19) 88 | x = torch.randn(2, 3, 224, 224) 89 | features = backbone(x) 90 | out = head(features) 91 | out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False) 92 | print(out.shape) -------------------------------------------------------------------------------- /semseg/models/heads/fcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from semseg.models.layers import ConvModule 5 | 6 | 7 | class FCNHead(nn.Module): 8 | def __init__(self, c1, c2, num_classes: int = 19): 9 | super().__init__() 10 | self.conv = ConvModule(c1, c2, 1) 11 | self.cls = nn.Conv2d(c2, num_classes, 1) 12 | 13 | def forward(self, features) -> Tensor: 14 | x = self.conv(features[-1]) 15 | x = self.cls(x) 16 | return x 17 | 18 | 19 | if __name__ 
== '__main__': 20 | from semseg.models.backbones import ResNet 21 | backbone = ResNet('50') 22 | head = FCNHead(2048, 256, 19) 23 | x = torch.randn(2, 3, 224, 224) 24 | features = backbone(x) 25 | out = head(features) 26 | out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False) 27 | print(out.shape) 28 | -------------------------------------------------------------------------------- /semseg/models/heads/fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from semseg.models.layers import ConvModule 5 | 6 | 7 | class FPNHead(nn.Module): 8 | """Panoptic Feature Pyramid Networks 9 | https://arxiv.org/abs/1901.02446 10 | """ 11 | def __init__(self, in_channels, channel=128, num_classes=19): 12 | super().__init__() 13 | self.lateral_convs = nn.ModuleList([]) 14 | self.output_convs = nn.ModuleList([]) 15 | 16 | for ch in in_channels[::-1]: 17 | self.lateral_convs.append(ConvModule(ch, channel, 1)) 18 | self.output_convs.append(ConvModule(channel, channel, 3, 1, 1)) 19 | 20 | self.conv_seg = nn.Conv2d(channel, num_classes, 1) 21 | self.dropout = nn.Dropout2d(0.1) 22 | 23 | def forward(self, features) -> Tensor: 24 | features = features[::-1] 25 | out = self.lateral_convs[0](features[0]) 26 | 27 | for i in range(1, len(features)): 28 | out = F.interpolate(out, scale_factor=2.0, mode='nearest') 29 | out = out + self.lateral_convs[i](features[i]) 30 | out = self.output_convs[i](out) 31 | out = self.conv_seg(self.dropout(out)) 32 | return out 33 | 34 | 35 | if __name__ == '__main__': 36 | from semseg.models.backbones import ResNet 37 | backbone = ResNet('50') 38 | head = FPNHead([256, 512, 1024, 2048], 128, 19) 39 | x = torch.randn(2, 3, 224, 224) 40 | features = backbone(x) 41 | out = head(features) 42 | out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False) 43 | print(out.shape) -------------------------------------------------------------------------------- /semseg/models/heads/segformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from typing import Tuple 4 | from torch.nn import functional as F 5 | 6 | 7 | class MLP(nn.Module): 8 | def __init__(self, dim, embed_dim): 9 | super().__init__() 10 | self.proj = nn.Linear(dim, embed_dim) 11 | 12 | def forward(self, x: Tensor) -> Tensor: 13 | x = x.flatten(2).transpose(1, 2) 14 | x = self.proj(x) 15 | return x 16 | 17 | 18 | class ConvModule(nn.Module): 19 | def __init__(self, c1, c2): 20 | super().__init__() 21 | self.conv = nn.Conv2d(c1, c2, 1, bias=False) 22 | self.bn = nn.BatchNorm2d(c2) # use SyncBN in original 23 | self.activate = nn.ReLU(True) 24 | 25 | def forward(self, x: Tensor) -> Tensor: 26 | return self.activate(self.bn(self.conv(x))) 27 | 28 | 29 | class SegFormerHead(nn.Module): 30 | def __init__(self, dims: list, embed_dim: int = 256, num_classes: int = 19): 31 | super().__init__() 32 | for i, dim in enumerate(dims): 33 | self.add_module(f"linear_c{i+1}", MLP(dim, embed_dim)) 34 | 35 | self.linear_fuse = ConvModule(embed_dim*4, embed_dim) 36 | self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1) 37 | self.dropout = nn.Dropout2d(0.1) 38 | 39 | def forward(self, features: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tensor: 40 | B, _, H, W = features[0].shape 41 | outs = [self.linear_c1(features[0]).permute(0, 2, 1).reshape(B, -1, *features[0].shape[-2:])] 42 | 43 
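        # (added note) The remaining stage features are projected to the shared embedding
        # width by their per-stage MLPs and upsampled to the stride-4 resolution of the
        # first stage, so all four maps can be concatenated (deepest stage first) and
        # fused by a 1x1 conv before the final class prediction.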
| for i, feature in enumerate(features[1:]): 44 | cf = eval(f"self.linear_c{i+2}")(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:]) 45 | outs.append(F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False)) 46 | 47 | seg = self.linear_fuse(torch.cat(outs[::-1], dim=1)) 48 | seg = self.linear_pred(self.dropout(seg)) 49 | return seg -------------------------------------------------------------------------------- /semseg/models/heads/sfnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from semseg.models.layers import ConvModule 5 | from semseg.models.modules import PPM 6 | 7 | 8 | class AlignedModule(nn.Module): 9 | def __init__(self, c1, c2, k=3): 10 | super().__init__() 11 | self.down_h = nn.Conv2d(c1, c2, 1, bias=False) 12 | self.down_l = nn.Conv2d(c1, c2, 1, bias=False) 13 | self.flow_make = nn.Conv2d(c2 * 2, 2, k, 1, 1, bias=False) 14 | 15 | def forward(self, low_feature: Tensor, high_feature: Tensor) -> Tensor: 16 | high_feature_origin = high_feature 17 | H, W = low_feature.shape[-2:] 18 | low_feature = self.down_l(low_feature) 19 | high_feature = self.down_h(high_feature) 20 | high_feature = F.interpolate(high_feature, size=(H, W), mode='bilinear', align_corners=True) 21 | flow = self.flow_make(torch.cat([high_feature, low_feature], dim=1)) 22 | high_feature = self.flow_warp(high_feature_origin, flow, (H, W)) 23 | return high_feature 24 | 25 | def flow_warp(self, x: Tensor, flow: Tensor, size: tuple) -> Tensor: 26 | # norm = torch.tensor(size).reshape(1, 1, 1, -1) 27 | norm = torch.tensor([[[[*size]]]]).type_as(x).to(x.device) 28 | H = torch.linspace(-1.0, 1.0, size[0]).view(-1, 1).repeat(1, size[1]) 29 | W = torch.linspace(-1.0, 1.0, size[1]).repeat(size[0], 1) 30 | grid = torch.cat((W.unsqueeze(2), H.unsqueeze(2)), dim=2) 31 | grid = grid.repeat(x.shape[0], 1, 1, 1).type_as(x).to(x.device) 32 | grid = grid + flow.permute(0, 2, 3, 1) / norm 33 | output = F.grid_sample(x, grid, align_corners=False) 34 | return output 35 | 36 | 37 | class SFHead(nn.Module): 38 | def __init__(self, in_channels, channel=256, num_classes=19, scales=(1, 2, 3, 6)): 39 | super().__init__() 40 | self.ppm = PPM(in_channels[-1], channel, scales) 41 | 42 | self.fpn_in = nn.ModuleList([]) 43 | self.fpn_out = nn.ModuleList([]) 44 | self.fpn_out_align = nn.ModuleList([]) 45 | 46 | for in_ch in in_channels[:-1]: 47 | self.fpn_in.append(ConvModule(in_ch, channel, 1)) 48 | self.fpn_out.append(ConvModule(channel, channel, 3, 1, 1)) 49 | self.fpn_out_align.append(AlignedModule(channel, channel//2)) 50 | 51 | self.bottleneck = ConvModule(len(in_channels) * channel, channel, 3, 1, 1) 52 | self.dropout = nn.Dropout2d(0.1) 53 | self.conv_seg = nn.Conv2d(channel, num_classes, 1) 54 | 55 | def forward(self, features: list) -> Tensor: 56 | f = self.ppm(features[-1]) 57 | fpn_features = [f] 58 | 59 | for i in reversed(range(len(features) - 1)): 60 | feature = self.fpn_in[i](features[i]) 61 | f = feature + self.fpn_out_align[i](feature, f) 62 | fpn_features.append(self.fpn_out[i](f)) 63 | 64 | fpn_features.reverse() 65 | 66 | for i in range(1, len(fpn_features)): 67 | fpn_features[i] = F.interpolate(fpn_features[i], size=fpn_features[0].shape[-2:], mode='bilinear', align_corners=True) 68 | 69 | output = self.bottleneck(torch.cat(fpn_features, dim=1)) 70 | output = self.conv_seg(self.dropout(output)) 71 | return output 72 | 73 | 
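# (added usage sketch, not part of the original sfnet.py) Unlike the other head files, SFHead
# has no __main__ smoke test; here is a minimal one following the same pattern. It assumes a
# ResNet-50 backbone whose stage channels are [256, 512, 1024, 2048].
if __name__ == '__main__':
    from semseg.models.backbones import ResNet
    backbone = ResNet('50')
    head = SFHead([256, 512, 1024, 2048], 256, 19)
    x = torch.randn(2, 3, 224, 224)
    features = backbone(x)   # four feature maps at strides 4, 8, 16, 32
    out = head(features)     # class logits at stride 4
    out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False)
    print(out.shape)         # expected: torch.Size([2, 19, 224, 224])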
-------------------------------------------------------------------------------- /semseg/models/heads/upernet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from typing import Tuple 5 | from semseg.models.layers import ConvModule 6 | from semseg.models.modules import PPM 7 | 8 | 9 | class UPerHead(nn.Module): 10 | """Unified Perceptual Parsing for Scene Understanding 11 | https://arxiv.org/abs/1807.10221 12 | scales: Pooling scales used in PPM module applied on the last feature 13 | """ 14 | def __init__(self, in_channels, channel=128, num_classes: int = 19, scales=(1, 2, 3, 6)): 15 | super().__init__() 16 | # PPM Module 17 | self.ppm = PPM(in_channels[-1], channel, scales) 18 | 19 | # FPN Module 20 | self.fpn_in = nn.ModuleList() 21 | self.fpn_out = nn.ModuleList() 22 | 23 | for in_ch in in_channels[:-1]: # skip the top layer 24 | self.fpn_in.append(ConvModule(in_ch, channel, 1)) 25 | self.fpn_out.append(ConvModule(channel, channel, 3, 1, 1)) 26 | 27 | self.bottleneck = ConvModule(len(in_channels)*channel, channel, 3, 1, 1) 28 | self.dropout = nn.Dropout2d(0.1) 29 | self.conv_seg = nn.Conv2d(channel, num_classes, 1) 30 | 31 | 32 | def forward(self, features: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tensor: 33 | f = self.ppm(features[-1]) 34 | fpn_features = [f] 35 | 36 | for i in reversed(range(len(features)-1)): 37 | feature = self.fpn_in[i](features[i]) 38 | f = feature + F.interpolate(f, size=feature.shape[-2:], mode='bilinear', align_corners=False) 39 | fpn_features.append(self.fpn_out[i](f)) 40 | 41 | fpn_features.reverse() 42 | for i in range(1, len(features)): 43 | fpn_features[i] = F.interpolate(fpn_features[i], size=fpn_features[0].shape[-2:], mode='bilinear', align_corners=False) 44 | 45 | output = self.bottleneck(torch.cat(fpn_features, dim=1)) 46 | output = self.conv_seg(self.dropout(output)) 47 | return output 48 | 49 | 50 | if __name__ == '__main__': 51 | model = UPerHead([64, 128, 256, 512], 128) 52 | x1 = torch.randn(2, 64, 56, 56) 53 | x2 = torch.randn(2, 128, 28, 28) 54 | x3 = torch.randn(2, 256, 14, 14) 55 | x4 = torch.randn(2, 512, 7, 7) 56 | y = model([x1, x2, x3, x4]) 57 | print(y.shape) -------------------------------------------------------------------------------- /semseg/models/lawin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import functional as F 4 | from semseg.models.base import BaseModel 5 | from semseg.models.heads import LawinHead 6 | 7 | 8 | class Lawin(BaseModel): 9 | """ 10 | Notes::::: This implementation has larger params and FLOPs than the results reported in the paper. 11 | Will update the code and weights if the original author releases the full code. 
12 | """ 13 | def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19) -> None: 14 | super().__init__(backbone, num_classes) 15 | self.decode_head = LawinHead(self.backbone.channels, 256 if 'B0' in backbone else 512, num_classes) 16 | self.apply(self._init_weights) 17 | 18 | def forward(self, x: Tensor) -> Tensor: 19 | y = self.backbone(x) 20 | y = self.decode_head(y) # 4x reduction in image size 21 | y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False) # to original image shape 22 | return y 23 | 24 | 25 | if __name__ == '__main__': 26 | model = Lawin('MiT-B1') 27 | model.eval() 28 | x = torch.zeros(1, 3, 512, 512) 29 | y = model(x) 30 | print(y.shape) 31 | from fvcore.nn import flop_count_table, FlopCountAnalysis 32 | print(flop_count_table(FlopCountAnalysis(model, x))) -------------------------------------------------------------------------------- /semseg/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | from .initialize import * -------------------------------------------------------------------------------- /semseg/models/layers/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | 5 | class ConvModule(nn.Sequential): 6 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1): 7 | super().__init__( 8 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False), 9 | nn.BatchNorm2d(c2), 10 | nn.ReLU(True) 11 | ) 12 | 13 | 14 | class DropPath(nn.Module): 15 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 16 | Copied from timm 17 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 18 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 19 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 20 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 21 | 'survival rate' as the argument. 22 | """ 23 | def __init__(self, p: float = None): 24 | super().__init__() 25 | self.p = p 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | if self.p == 0. or not self.training: 29 | return x 30 | kp = 1 - self.p 31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 32 | random_tensor = kp + torch.rand(shape, dtype=x.dtype, device=x.device) 33 | random_tensor.floor_() # binarize 34 | return x.div(kp) * random_tensor -------------------------------------------------------------------------------- /semseg/models/layers/initialize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | from torch import nn, Tensor 5 | 6 | 7 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 8 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 9 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 10 | def norm_cdf(x): 11 | # Computes standard normal cumulative distribution function 12 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 13 | 14 | if (mean < a - 2 * std) or (mean > b + 2 * std): 15 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" 16 | "The distribution of values may be incorrect.", 17 | stacklevel=2) 18 | 19 | with torch.no_grad(): 20 | # Values are generated by using a truncated uniform distribution and 21 | # then using the inverse CDF for the normal distribution. 22 | # Get upper and lower cdf values 23 | l = norm_cdf((a - mean) / std) 24 | u = norm_cdf((b - mean) / std) 25 | 26 | # Uniformly fill tensor with values from [l, u], then translate to 27 | # [2l-1, 2u-1]. 28 | tensor.uniform_(2 * l - 1, 2 * u - 1) 29 | 30 | # Use inverse cdf transform for normal distribution to get truncated 31 | # standard normal 32 | tensor.erfinv_() 33 | 34 | # Transform to proper mean, std 35 | tensor.mul_(std * math.sqrt(2.)) 36 | tensor.add_(mean) 37 | 38 | # Clamp to ensure it's in the proper range 39 | tensor.clamp_(min=a, max=b) 40 | return tensor 41 | 42 | 43 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 44 | # type: (Tensor, float, float, float, float) -> Tensor 45 | r"""Fills the input Tensor with values drawn from a truncated 46 | normal distribution. The values are effectively drawn from the 47 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 48 | with values outside :math:`[a, b]` redrawn until they are within 49 | the bounds. The method used for generating the random values works 50 | best when :math:`a \leq \text{mean} \leq b`. 51 | Args: 52 | tensor: an n-dimensional `torch.Tensor` 53 | mean: the mean of the normal distribution 54 | std: the standard deviation of the normal distribution 55 | a: the minimum cutoff value 56 | b: the maximum cutoff value 57 | Examples: 58 | >>> w = torch.empty(3, 5) 59 | >>> nn.init.trunc_normal_(w) 60 | """ 61 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 62 | -------------------------------------------------------------------------------- /semseg/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .ppm import PPM 2 | from .psa import PSAP, PSAS 3 | 4 | __all__ = ['PPM', 'PSAP', 'PSAS'] -------------------------------------------------------------------------------- /semseg/models/modules/ppm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from semseg.models.layers import ConvModule 5 | 6 | 7 | class PPM(nn.Module): 8 | """Pyramid Pooling Module in PSPNet 9 | """ 10 | def __init__(self, c1, c2=128, scales=(1, 2, 3, 6)): 11 | super().__init__() 12 | self.stages = nn.ModuleList([ 13 | nn.Sequential( 14 | nn.AdaptiveAvgPool2d(scale), 15 | ConvModule(c1, c2, 1) 16 | ) 17 | for scale in scales]) 18 | 19 | self.bottleneck = ConvModule(c1 + c2 * len(scales), c2, 3, 1, 1) 20 | 21 | def forward(self, x: Tensor) -> Tensor: 22 | outs = [] 23 | for stage in self.stages: 24 | outs.append(F.interpolate(stage(x), size=x.shape[-2:], mode='bilinear', align_corners=True)) 25 | 26 | outs = [x] + outs[::-1] 27 | out = self.bottleneck(torch.cat(outs, dim=1)) 28 | return out 29 | 30 | 31 | if __name__ == '__main__': 32 | model = PPM(512, 128) 33 | x = torch.randn(2, 512, 7, 7) 34 | y = model(x) 35 | print(y.shape) # [2, 128, 7, 7] -------------------------------------------------------------------------------- /semseg/models/segformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import functional as F 4 | from semseg.models.base import BaseModel 5 | from 
semseg.models.heads import SegFormerHead 6 | 7 | 8 | class SegFormer(BaseModel): 9 | def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19) -> None: 10 | super().__init__(backbone, num_classes) 11 | self.decode_head = SegFormerHead(self.backbone.channels, 256 if 'B0' in backbone or 'B1' in backbone else 768, num_classes) 12 | self.apply(self._init_weights) 13 | 14 | def forward(self, x: Tensor) -> Tensor: 15 | y = self.backbone(x) 16 | y = self.decode_head(y) # 4x reduction in image size 17 | y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False) # to original image shape 18 | return y 19 | 20 | 21 | if __name__ == '__main__': 22 | model = SegFormer('MiT-B0') 23 | # model.load_state_dict(torch.load('checkpoints/pretrained/segformer/segformer.b0.ade.pth', map_location='cpu')) 24 | x = torch.zeros(1, 3, 512, 512) 25 | y = model(x) 26 | print(y.shape) -------------------------------------------------------------------------------- /semseg/models/sfnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import functional as F 4 | from semseg.models.base import BaseModel 5 | from semseg.models.heads import SFHead 6 | 7 | 8 | class SFNet(BaseModel): 9 | def __init__(self, backbone: str = 'ResNetD-18', num_classes: int = 19): 10 | assert 'ResNet' in backbone 11 | super().__init__(backbone, num_classes) 12 | self.head = SFHead(self.backbone.channels, 128 if '18' in backbone else 256, num_classes) 13 | self.apply(self._init_weights) 14 | 15 | def forward(self, x: Tensor) -> Tensor: 16 | outs = self.backbone(x) 17 | out = self.head(outs) 18 | out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=True) 19 | return out 20 | 21 | 22 | if __name__ == '__main__': 23 | model = SFNet('ResNetD-18') 24 | model.init_pretrained('checkpoints/backbones/resnetd/resnetd18.pth') 25 | x = torch.randn(2, 3, 224, 224) 26 | y = model(x) 27 | print(y.shape) -------------------------------------------------------------------------------- /semseg/optimizers.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.optim import AdamW, SGD 3 | 4 | 5 | def get_optimizer(model: nn.Module, optimizer: str, lr: float, weight_decay: float = 0.01): 6 | wd_params, nwd_params = [], [] 7 | for p in model.parameters(): 8 | if p.dim() == 1: 9 | nwd_params.append(p) 10 | else: 11 | wd_params.append(p) 12 | 13 | params = [ 14 | {"params": wd_params}, 15 | {"params": nwd_params, "weight_decay": 0} 16 | ] 17 | 18 | if optimizer == 'adamw': 19 | return AdamW(params, lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=weight_decay) 20 | else: 21 | return SGD(params, lr, momentum=0.9, weight_decay=weight_decay) -------------------------------------------------------------------------------- /semseg/schedulers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class PolyLR(_LRScheduler): 7 | def __init__(self, optimizer, max_iter, decay_iter=1, power=0.9, last_epoch=-1) -> None: 8 | self.decay_iter = decay_iter 9 | self.max_iter = max_iter 10 | self.power = power 11 | super().__init__(optimizer, last_epoch=last_epoch) 12 | 13 | def get_lr(self): 14 | if self.last_epoch % self.decay_iter or self.last_epoch % self.max_iter: 15 | return self.base_lrs 16 | else: 17 | factor = (1 - self.last_epoch / 
float(self.max_iter)) ** self.power 18 | return [factor*lr for lr in self.base_lrs] 19 | 20 | 21 | class WarmupLR(_LRScheduler): 22 | def __init__(self, optimizer, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None: 23 | self.warmup_iter = warmup_iter 24 | self.warmup_ratio = warmup_ratio 25 | self.warmup = warmup 26 | super().__init__(optimizer, last_epoch) 27 | 28 | def get_lr(self): 29 | ratio = self.get_lr_ratio() 30 | return [ratio * lr for lr in self.base_lrs] 31 | 32 | def get_lr_ratio(self): 33 | return self.get_warmup_ratio() if self.last_epoch < self.warmup_iter else self.get_main_ratio() 34 | 35 | def get_main_ratio(self): 36 | raise NotImplementedError 37 | 38 | def get_warmup_ratio(self): 39 | assert self.warmup in ['linear', 'exp'] 40 | alpha = self.last_epoch / self.warmup_iter 41 | 42 | return self.warmup_ratio + (1. - self.warmup_ratio) * alpha if self.warmup == 'linear' else self.warmup_ratio ** (1. - alpha) 43 | 44 | 45 | class WarmupPolyLR(WarmupLR): 46 | def __init__(self, optimizer, power, max_iter, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None: 47 | self.power = power 48 | self.max_iter = max_iter 49 | super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 50 | 51 | def get_main_ratio(self): 52 | real_iter = self.last_epoch - self.warmup_iter 53 | real_max_iter = self.max_iter - self.warmup_iter 54 | alpha = real_iter / real_max_iter 55 | 56 | return (1 - alpha) ** self.power 57 | 58 | 59 | class WarmupExpLR(WarmupLR): 60 | def __init__(self, optimizer, gamma, interval=1, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None: 61 | self.gamma = gamma 62 | self.interval = interval 63 | super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 64 | 65 | def get_main_ratio(self): 66 | real_iter = self.last_epoch - self.warmup_iter 67 | return self.gamma ** (real_iter // self.interval) 68 | 69 | 70 | class WarmupCosineLR(WarmupLR): 71 | def __init__(self, optimizer, max_iter, eta_ratio=0, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None: 72 | self.eta_ratio = eta_ratio 73 | self.max_iter = max_iter 74 | super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 75 | 76 | def get_main_ratio(self): 77 | real_iter = self.last_epoch - self.warmup_iter 78 | real_max_iter = self.max_iter - self.warmup_iter 79 | 80 | return self.eta_ratio + (1 - self.eta_ratio) * (1 + math.cos(math.pi * self.last_epoch / real_max_iter)) / 2 81 | 82 | 83 | 84 | __all__ = ['polylr', 'warmuppolylr', 'warmupcosinelr', 'warmupsteplr'] 85 | 86 | 87 | def get_scheduler(scheduler_name: str, optimizer, max_iter: int, power: int, warmup_iter: int, warmup_ratio: float): 88 | assert scheduler_name in __all__, f"Unavailable scheduler name >> {scheduler_name}.\nAvailable schedulers: {__all__}" 89 | if scheduler_name == 'warmuppolylr': 90 | return WarmupPolyLR(optimizer, power, max_iter, warmup_iter, warmup_ratio, warmup='linear') 91 | elif scheduler_name == 'warmupcosinelr': 92 | return WarmupCosineLR(optimizer, max_iter, warmup_iter=warmup_iter, warmup_ratio=warmup_ratio) 93 | return PolyLR(optimizer, max_iter) 94 | 95 | 96 | if __name__ == '__main__': 97 | model = torch.nn.Conv2d(3, 16, 3, 1, 1) 98 | optim = torch.optim.SGD(model.parameters(), lr=1e-3) 99 | 100 | max_iter = 20000 101 | sched = WarmupPolyLR(optim, power=0.9, max_iter=max_iter, warmup_iter=200, warmup_ratio=0.1, warmup='exp', last_epoch=-1) 102 | 103 | lrs = [] 104 | 105 | for _ in range(max_iter): 
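# (added note) each iteration records the current LR, then steps the optimizer before the scheduler, as PyTorch expects;
# with the settings above, the plot should show the exponential warmup over the first 200 iterations, followed by polynomial decay (power 0.9) towards zero at max_iter.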
106 | lr = sched.get_lr()[0] 107 | lrs.append(lr) 108 | optim.step() 109 | sched.step() 110 | 111 | import matplotlib.pyplot as plt 112 | import numpy as np 113 | 114 | plt.plot(np.arange(len(lrs)), np.array(lrs)) 115 | plt.grid() 116 | plt.show() -------------------------------------------------------------------------------- /semseg/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/semantic-segmentation/c2843c9cf42a66a96a4173a0ef0059975a6e387c/semseg/utils/__init__.py -------------------------------------------------------------------------------- /semseg/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import time 5 | import os 6 | import functools 7 | from pathlib import Path 8 | from torch.backends import cudnn 9 | from torch import nn, Tensor 10 | from torch.autograd import profiler 11 | from typing import Union 12 | from torch import distributed as dist 13 | from tabulate import tabulate 14 | from semseg import models 15 | 16 | 17 | def fix_seeds(seed: int = 3407) -> None: 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | np.random.seed(seed) 21 | random.seed(seed) 22 | 23 | def setup_cudnn() -> None: 24 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 25 | cudnn.benchmark = True 26 | cudnn.deterministic = False 27 | 28 | def time_sync() -> float: 29 | if torch.cuda.is_available(): 30 | torch.cuda.synchronize() 31 | return time.time() 32 | 33 | def get_model_size(model: Union[nn.Module, torch.jit.ScriptModule]): 34 | tmp_model_path = Path('temp.p') 35 | if isinstance(model, torch.jit.ScriptModule): 36 | torch.jit.save(model, tmp_model_path) 37 | else: 38 | torch.save(model.state_dict(), tmp_model_path) 39 | size = tmp_model_path.stat().st_size 40 | os.remove(tmp_model_path) 41 | return size / 1e6 # in MB 42 | 43 | @torch.no_grad() 44 | def test_model_latency(model: nn.Module, inputs: torch.Tensor, use_cuda: bool = False) -> float: 45 | with profiler.profile(use_cuda=use_cuda) as prof: 46 | _ = model(inputs) 47 | return prof.self_cpu_time_total / 1000 # ms 48 | 49 | def count_parameters(model: nn.Module) -> float: 50 | return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 # in M 51 | 52 | def setup_ddp() -> int: 53 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 54 | rank = int(os.environ['RANK']) 55 | world_size = int(os.environ['WORLD_SIZE']) 56 | gpu = int(os.environ['LOCAL_RANK']) 57 | torch.cuda.set_device(gpu) 58 | dist.init_process_group('nccl', init_method="env://", world_size=world_size, rank=rank) 59 | dist.barrier() 60 | else: 61 | gpu = 0 62 | return gpu 63 | 64 | def cleanup_ddp(): 65 | if dist.is_initialized(): 66 | dist.destroy_process_group() 67 | 68 | def reduce_tensor(tensor: Tensor) -> Tensor: 69 | rt = tensor.clone() 70 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 71 | rt /= dist.get_world_size() 72 | return rt 73 | 74 | @torch.no_grad() 75 | def throughput(dataloader, model: nn.Module, times: int = 30): 76 | model.eval() 77 | images, _ = next(iter(dataloader)) 78 | images = images.cuda(non_blocking=True) 79 | B = images.shape[0] 80 | print(f"Throughput averaged over {times} runs") 81 | start = time_sync() 82 | for _ in range(times): 83 | model(images) 84 | end = time_sync() 85 | 86 | print(f"Batch Size {B} throughput {times * B / (end - start)} images/s") 87 | 88 | 89 | def 
show_models(): 90 | model_names = models.__all__ 91 | model_variants = [list(eval(f'models.{name.lower()}_settings').keys()) for name in model_names] 92 | 93 | print(tabulate({'Model Names': model_names, 'Model Variants': model_variants}, headers='keys')) 94 | 95 | 96 | def timer(func): 97 | @functools.wraps(func) 98 | def wrapper_timer(*args, **kwargs): 99 | tic = time.perf_counter() 100 | value = func(*args, **kwargs) 101 | toc = time.perf_counter() 102 | elapsed_time = toc - tic 103 | print(f"Elapsed time: {elapsed_time * 1000:.2f}ms") 104 | return value 105 | return wrapper_timer -------------------------------------------------------------------------------- /semseg/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | from torch.utils.data import DataLoader 7 | from torchvision import transforms as T 8 | from torchvision.utils import make_grid 9 | from semseg.augmentations import Compose, Normalize, RandomResizedCrop 10 | from PIL import Image, ImageDraw, ImageFont 11 | 12 | 13 | def visualize_dataset_sample(dataset, root, split='val', batch_size=4): 14 | transform = Compose([ 15 | RandomResizedCrop((512, 512), scale=(1.0, 1.0)), 16 | Normalize() 17 | ]) 18 | 19 | dataset = dataset(root, split=split, transform=transform) 20 | dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size) 21 | image, label = next(iter(dataloader)) 22 | 23 | print(f"Image Shape\t: {image.shape}") 24 | print(f"Label Shape\t: {label.shape}") 25 | print(f"Classes\t\t: {label.unique().tolist()}") 26 | 27 | label[label == -1] = 0 28 | label[label == 255] = 0 29 | labels = [dataset.PALETTE[lbl.to(int)].permute(2, 0, 1) for lbl in label] 30 | labels = torch.stack(labels) 31 | 32 | inv_normalize = T.Normalize( 33 | mean=(-0.485/0.229, -0.456/0.224, -0.406/0.225), 34 | std=(1/0.229, 1/0.224, 1/0.225) 35 | ) 36 | image = inv_normalize(image) 37 | image *= 255 38 | images = torch.vstack([image, labels]) 39 | 40 | plt.imshow(make_grid(images, nrow=4).to(torch.uint8).numpy().transpose((1, 2, 0))) 41 | plt.show() 42 | 43 | 44 | colors = [ 45 | [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 46 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 47 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 48 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 49 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 50 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 51 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 52 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], 53 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], 54 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], 55 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184], [0, 31, 255], [0, 
255, 61], [0, 71, 255], 56 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], 57 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], 58 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], 59 | [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], 60 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], 61 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], 62 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], 63 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0], [92, 0, 255] 64 | ] 65 | 66 | 67 | def generate_palette(num_classes, background: bool = False): 68 | random.shuffle(colors) 69 | if background: 70 | palette = [[0, 0, 0]] 71 | palette += colors[:num_classes-1] 72 | else: 73 | palette = colors[:num_classes] 74 | return np.array(palette) 75 | 76 | 77 | def draw_text(image: torch.Tensor, seg_map: torch.Tensor, labels: list, fontsize: int = 15): 78 | image = image.to(torch.uint8) 79 | font = ImageFont.truetype("assests/Helvetica.ttf", fontsize) 80 | pil_image = Image.fromarray(image.numpy()) 81 | draw = ImageDraw.Draw(pil_image) 82 | 83 | indices = seg_map.unique().tolist() 84 | classes = [labels[index] for index in indices] 85 | 86 | for idx, cls in zip(indices, classes): 87 | mask = seg_map == idx 88 | mask = mask.squeeze().numpy() 89 | center = np.median((mask == 1).nonzero(), axis=1)[::-1] 90 | bbox = draw.textbbox(center, cls, font=font) 91 | bbox = (bbox[0]-3, bbox[1]-3, bbox[2]+3, bbox[3]+3) 92 | draw.rectangle(bbox, fill=(255, 255, 255), width=1) 93 | draw.text(center, cls, fill=(0, 0, 0), font=font) 94 | return pil_image -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='semseg', 5 | version='0.4.1', 6 | description='SOTA Semantic Segmentation Models', 7 | url='https://github.com/sithu31296/semantic-segmentation', 8 | author='Sithu Aung', 9 | author_email='sithu31296@gmail.com', 10 | license='MIT', 11 | packages=find_packages(include=['semseg']), 12 | install_requires=[ 13 | 'tqdm', 14 | 'tabulate', 15 | 'numpy', 16 | 'scipy', 17 | 'matplotlib', 18 | 'tensorboard', 19 | 'fvcore', 20 | 'einops', 21 | 'rich', 22 | ] 23 | ) -------------------------------------------------------------------------------- /tools/benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import time 4 | from fvcore.nn import flop_count_table, FlopCountAnalysis 5 | from semseg.models import * 6 | 7 | 8 | def main( 9 | model_name: str, 10 | backbone_name: str, 11 | image_size: list, 12 | num_classes: int, 13 | device: str, 14 | ): 15 | device = torch.device('cuda' if torch.cuda.is_available() and device == 'cuda' else 'cpu') 16 | inputs = torch.randn(1, 3, *image_size).to(device) 17 | model = eval(model_name)(backbone_name, num_classes) 18 | model = model.to(device) 19 | model.eval() 20 | 21 | print(flop_count_table(FlopCountAnalysis(model, 
inputs))) 22 | 23 | total_time = 0.0 24 | for _ in range(10): 25 | tic = time.perf_counter() 26 | model(inputs) 27 | toc = time.perf_counter() 28 | total_time += toc - tic 29 | total_time /= 10 30 | print(f"Inference time: {total_time*1000:.2f}ms") 31 | print(f"FPS: {1/total_time}") 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--model-name', type=str, default='SegFormer') 37 | parser.add_argument('--backbone-name', type=str, default='MiT-B0') 38 | parser.add_argument('--image-size', type=list, default=[512, 512]) 39 | parser.add_argument('--num-classes', type=int, default=11) 40 | parser.add_argument('--device', type=str, default='cuda') 41 | args = parser.parse_args() 42 | 43 | main(args.model_name, args.backbone_name, args.image_size, args.num_classes, args.device) -------------------------------------------------------------------------------- /tools/export.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import yaml 4 | import onnx 5 | from pathlib import Path 6 | from onnxsim import simplify 7 | from semseg.models import * 8 | from semseg.datasets import * 9 | 10 | 11 | def export_onnx(model, inputs, file): 12 | torch.onnx.export( 13 | model, 14 | inputs, 15 | f"{cfg['TEST']['MODEL_PATH'].split('.')[0]}.onnx", 16 | input_names=['input'], 17 | output_names=['output'], 18 | opset_version=13 19 | ) 20 | onnx_model = onnx.load(f"{file}.onnx") 21 | onnx.checker.check_model(onnx_model) 22 | 23 | onnx_model, check = simplify(onnx_model) 24 | onnx.save(onnx_model, f"{file}.onnx") 25 | assert check, "Simplified ONNX model could not be validated" 26 | print(f"ONNX model saved to {file}.onnx") 27 | 28 | 29 | def export_coreml(model, inputs, file): 30 | try: 31 | import coremltools as ct 32 | ts_model = torch.jit.trace(model, inputs, strict=True) 33 | ct_model = ct.convert( 34 | ts_model, 35 | inputs=[ct.ImageType('image', shape=inputs.shape, scale=1/255.0, bias=[0, 0, 0])] 36 | ) 37 | ct_model.save(f"{file}.mlmodel") 38 | print(f"CoreML model saved to {file}.mlmodel") 39 | except: 40 | print("Please install coremltools to export to CoreML.\n`pip install coremltools`") 41 | 42 | 43 | def main(cfg): 44 | model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], len(eval(cfg['DATASET']['NAME']).PALETTE)) 45 | model.load_state_dict(torch.load(cfg['TEST']['MODEL_PATH'], map_location='cpu')) 46 | model.eval() 47 | 48 | inputs = torch.randn(1, 3, *cfg['TEST']['IMAGE_SIZE']) 49 | file = cfg['TEST']['MODEL_PATH'].split('.')[0] 50 | 51 | export_onnx(model, inputs, file) 52 | export_coreml(model, inputs, file) 53 | print(f"Finished converting.") 54 | 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('--cfg', type=str, default='configs/custom.yaml') 59 | args = parser.parse_args() 60 | 61 | with open(args.cfg) as f: 62 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 63 | 64 | save_dir = Path(cfg['SAVE_DIR']) 65 | save_dir.mkdir(exist_ok=True) 66 | 67 | main(cfg) -------------------------------------------------------------------------------- /tools/infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import yaml 4 | import math 5 | from torch import Tensor 6 | from torch.nn import functional as F 7 | from pathlib import Path 8 | from torchvision import io 9 | from torchvision import transforms as T 10 | from semseg.models import * 11 | from 
semseg.datasets import * 12 | from semseg.utils.utils import timer 13 | from semseg.utils.visualize import draw_text 14 | 15 | from rich.console import Console 16 | console = Console() 17 | 18 | 19 | class SemSeg: 20 | def __init__(self, cfg) -> None: 21 | # inference device cuda or cpu 22 | self.device = torch.device(cfg['DEVICE']) 23 | 24 | # get dataset classes' colors and labels 25 | self.palette = eval(cfg['DATASET']['NAME']).PALETTE 26 | self.labels = eval(cfg['DATASET']['NAME']).CLASSES 27 | 28 | # initialize the model and load weights and send to device 29 | self.model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], len(self.palette)) 30 | self.model.load_state_dict(torch.load(cfg['TEST']['MODEL_PATH'], map_location='cpu')) 31 | self.model = self.model.to(self.device) 32 | self.model.eval() 33 | 34 | # preprocess parameters and transformation pipeline 35 | self.size = cfg['TEST']['IMAGE_SIZE'] 36 | self.tf_pipeline = T.Compose([ 37 | T.Lambda(lambda x: x / 255), 38 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 39 | T.Lambda(lambda x: x.unsqueeze(0)) 40 | ]) 41 | 42 | def preprocess(self, image: Tensor) -> Tensor: 43 | H, W = image.shape[1:] 44 | console.print(f"Original Image Size > [red]{H}x{W}[/red]") 45 | # scale the short side of image to target size 46 | scale_factor = self.size[0] / min(H, W) 47 | nH, nW = round(H*scale_factor), round(W*scale_factor) 48 | # make it divisible by model stride 49 | nH, nW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32 50 | console.print(f"Inference Image Size > [red]{nH}x{nW}[/red]") 51 | # resize the image 52 | image = T.Resize((nH, nW))(image) 53 | # divide by 255, norm and add batch dim 54 | image = self.tf_pipeline(image).to(self.device) 55 | return image 56 | 57 | def postprocess(self, orig_img: Tensor, seg_map: Tensor, overlay: bool) -> Tensor: 58 | # resize to original image size 59 | seg_map = F.interpolate(seg_map, size=orig_img.shape[-2:], mode='bilinear', align_corners=True) 60 | # get segmentation map (value being 0 to num_classes) 61 | seg_map = seg_map.softmax(dim=1).argmax(dim=1).cpu().to(int) 62 | 63 | # convert segmentation map to color map 64 | seg_image = self.palette[seg_map].squeeze() 65 | if overlay: 66 | seg_image = (orig_img.permute(1, 2, 0) * 0.4) + (seg_image * 0.6) 67 | 68 | image = draw_text(seg_image, seg_map, self.labels) 69 | return image 70 | 71 | @torch.inference_mode() 72 | @timer 73 | def model_forward(self, img: Tensor) -> Tensor: 74 | return self.model(img) 75 | 76 | def predict(self, img_fname: str, overlay: bool) -> Tensor: 77 | image = io.read_image(img_fname) 78 | img = self.preprocess(image) 79 | seg_map = self.model_forward(img) 80 | seg_map = self.postprocess(image, seg_map, overlay) 81 | return seg_map 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--cfg', type=str, default='configs/ade20k.yaml') 87 | args = parser.parse_args() 88 | 89 | with open(args.cfg) as f: 90 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 91 | 92 | test_file = Path(cfg['TEST']['FILE']) 93 | if not test_file.exists(): 94 | raise FileNotFoundError(test_file) 95 | 96 | console.print(f"Model > [red]{cfg['MODEL']['NAME']} {cfg['MODEL']['BACKBONE']}[/red]") 97 | console.print(f"Model > [red]{cfg['DATASET']['NAME']}[/red]") 98 | 99 | save_dir = Path(cfg['SAVE_DIR']) / 'test_results' 100 | save_dir.mkdir(exist_ok=True) 101 | 102 | semseg = SemSeg(cfg) 103 | 104 | with console.status("[bright_green]Processing..."): 105 | if 
test_file.is_file(): 106 | console.rule(f'[green]{test_file}') 107 | segmap = semseg.predict(str(test_file), cfg['TEST']['OVERLAY']) 108 | segmap.save(save_dir / f"{str(test_file.stem)}.png") 109 | else: 110 | files = test_file.glob('*.*') 111 | for file in files: 112 | console.rule(f'[green]{file}') 113 | segmap = semseg.predict(str(file), cfg['TEST']['OVERLAY']) 114 | segmap.save(save_dir / f"{str(file.stem)}.png") 115 | 116 | console.rule(f"[cyan]Segmentation results are saved in `{save_dir}`") -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import yaml 4 | import time 5 | import multiprocessing as mp 6 | from tabulate import tabulate 7 | from tqdm import tqdm 8 | from torch.utils.data import DataLoader 9 | from pathlib import Path 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torch.cuda.amp import GradScaler, autocast 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | from torch.utils.data import DistributedSampler, RandomSampler 14 | from torch import distributed as dist 15 | from semseg.models import * 16 | from semseg.datasets import * 17 | from semseg.augmentations import get_train_augmentation, get_val_augmentation 18 | from semseg.losses import get_loss 19 | from semseg.schedulers import get_scheduler 20 | from semseg.optimizers import get_optimizer 21 | from semseg.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp 22 | from val import evaluate 23 | 24 | 25 | def main(cfg, gpu, save_dir): 26 | start = time.time() 27 | best_mIoU = 0.0 28 | num_workers = mp.cpu_count() 29 | device = torch.device(cfg['DEVICE']) 30 | train_cfg, eval_cfg = cfg['TRAIN'], cfg['EVAL'] 31 | dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL'] 32 | loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER'] 33 | epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR'] 34 | 35 | traintransform = get_train_augmentation(train_cfg['IMAGE_SIZE'], seg_fill=dataset_cfg['IGNORE_LABEL']) 36 | valtransform = get_val_augmentation(eval_cfg['IMAGE_SIZE']) 37 | 38 | trainset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'train', traintransform) 39 | valset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'val', valtransform) 40 | 41 | model = eval(model_cfg['NAME'])(model_cfg['BACKBONE'], trainset.n_classes) 42 | model.init_pretrained(model_cfg['PRETRAINED']) 43 | model = model.to(device) 44 | 45 | if train_cfg['DDP']: 46 | sampler = DistributedSampler(trainset, dist.get_world_size(), dist.get_rank(), shuffle=True) 47 | model = DDP(model, device_ids=[gpu]) 48 | else: 49 | sampler = RandomSampler(trainset) 50 | 51 | trainloader = DataLoader(trainset, batch_size=train_cfg['BATCH_SIZE'], num_workers=num_workers, drop_last=True, pin_memory=True, sampler=sampler) 52 | valloader = DataLoader(valset, batch_size=1, num_workers=1, pin_memory=True) 53 | 54 | iters_per_epoch = len(trainset) // train_cfg['BATCH_SIZE'] 55 | # class_weights = trainset.class_weights.to(device) 56 | loss_fn = get_loss(loss_cfg['NAME'], trainset.ignore_label, None) 57 | optimizer = get_optimizer(model, optim_cfg['NAME'], lr, optim_cfg['WEIGHT_DECAY']) 58 | scheduler = get_scheduler(sched_cfg['NAME'], optimizer, epochs * iters_per_epoch, sched_cfg['POWER'], iters_per_epoch * sched_cfg['WARMUP'], sched_cfg['WARMUP_RATIO']) 59 | scaler = GradScaler(enabled=train_cfg['AMP']) 60 | writer = SummaryWriter(str(save_dir / 
'logs')) 61 | 62 | for epoch in range(epochs): 63 | model.train() 64 | if train_cfg['DDP']: sampler.set_epoch(epoch) 65 | 66 | train_loss = 0.0 67 | pbar = tqdm(enumerate(trainloader), total=iters_per_epoch, desc=f"Epoch: [{epoch+1}/{epochs}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.8f}") 68 | 69 | for iter, (img, lbl) in pbar: 70 | optimizer.zero_grad(set_to_none=True) 71 | 72 | img = img.to(device) 73 | lbl = lbl.to(device) 74 | 75 | with autocast(enabled=train_cfg['AMP']): 76 | logits = model(img) 77 | loss = loss_fn(logits, lbl) 78 | 79 | scaler.scale(loss).backward() 80 | scaler.step(optimizer) 81 | scaler.update() 82 | scheduler.step() 83 | torch.cuda.synchronize() 84 | 85 | lr = scheduler.get_lr() 86 | lr = sum(lr) / len(lr) 87 | train_loss += loss.item() 88 | 89 | pbar.set_description(f"Epoch: [{epoch+1}/{epochs}] Iter: [{iter+1}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss / (iter+1):.8f}") 90 | 91 | train_loss /= iter+1 92 | writer.add_scalar('train/loss', train_loss, epoch) 93 | torch.cuda.empty_cache() 94 | 95 | if (epoch+1) % train_cfg['EVAL_INTERVAL'] == 0 or (epoch+1) == epochs: 96 | miou = evaluate(model, valloader, device)[-1] 97 | writer.add_scalar('val/mIoU', miou, epoch) 98 | 99 | if miou > best_mIoU: 100 | best_mIoU = miou 101 | torch.save(model.module.state_dict() if train_cfg['DDP'] else model.state_dict(), save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}.pth") 102 | print(f"Current mIoU: {miou} Best mIoU: {best_mIoU}") 103 | 104 | writer.close() 105 | pbar.close() 106 | end = time.gmtime(time.time() - start) 107 | 108 | table = [ 109 | ['Best mIoU', f"{best_mIoU:.2f}"], 110 | ['Total Training Time', time.strftime("%H:%M:%S", end)] 111 | ] 112 | print(tabulate(table, numalign='right')) 113 | 114 | 115 | if __name__ == '__main__': 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('--cfg', type=str, default='configs/custom.yaml', help='Configuration file to use') 118 | args = parser.parse_args() 119 | 120 | with open(args.cfg) as f: 121 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 122 | 123 | fix_seeds(3407) 124 | setup_cudnn() 125 | gpu = setup_ddp() 126 | save_dir = Path(cfg['SAVE_DIR']) 127 | save_dir.mkdir(exist_ok=True) 128 | main(cfg, gpu, save_dir) 129 | cleanup_ddp() -------------------------------------------------------------------------------- /tools/val.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import yaml 4 | import math 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | from tabulate import tabulate 8 | from torch.utils.data import DataLoader 9 | from torch.nn import functional as F 10 | from semseg.models import * 11 | from semseg.datasets import * 12 | from semseg.augmentations import get_val_augmentation 13 | from semseg.metrics import Metrics 14 | from semseg.utils.utils import setup_cudnn 15 | 16 | 17 | @torch.no_grad() 18 | def evaluate(model, dataloader, device): 19 | print('Evaluating...') 20 | model.eval() 21 | metrics = Metrics(dataloader.dataset.n_classes, dataloader.dataset.ignore_label, device) 22 | 23 | for images, labels in tqdm(dataloader): 24 | images = images.to(device) 25 | labels = labels.to(device) 26 | preds = model(images).softmax(dim=1) 27 | metrics.update(preds, labels) 28 | 29 | ious, miou = metrics.compute_iou() 30 | acc, macc = metrics.compute_pixel_acc() 31 | f1, mf1 = metrics.compute_f1() 32 | 33 | return acc, macc, f1, mf1, ious, miou 34 | 35 | 36 | 
@torch.no_grad() 37 | def evaluate_msf(model, dataloader, device, scales, flip): 38 | model.eval() 39 | 40 | n_classes = dataloader.dataset.n_classes 41 | metrics = Metrics(n_classes, dataloader.dataset.ignore_label, device) 42 | 43 | for images, labels in tqdm(dataloader): 44 | labels = labels.to(device) 45 | B, H, W = labels.shape 46 | scaled_logits = torch.zeros(B, n_classes, H, W).to(device) 47 | 48 | for scale in scales: 49 | new_H, new_W = int(scale * H), int(scale * W) 50 | new_H, new_W = int(math.ceil(new_H / 32)) * 32, int(math.ceil(new_W / 32)) * 32 51 | scaled_images = F.interpolate(images, size=(new_H, new_W), mode='bilinear', align_corners=True) 52 | scaled_images = scaled_images.to(device) 53 | logits = model(scaled_images) 54 | logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True) 55 | scaled_logits += logits.softmax(dim=1) 56 | 57 | if flip: 58 | scaled_images = torch.flip(scaled_images, dims=(3,)) 59 | logits = model(scaled_images) 60 | logits = torch.flip(logits, dims=(3,)) 61 | logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True) 62 | scaled_logits += logits.softmax(dim=1) 63 | 64 | metrics.update(scaled_logits, labels) 65 | 66 | acc, macc = metrics.compute_pixel_acc() 67 | f1, mf1 = metrics.compute_f1() 68 | ious, miou = metrics.compute_iou() 69 | return acc, macc, f1, mf1, ious, miou 70 | 71 | 72 | def main(cfg): 73 | device = torch.device(cfg['DEVICE']) 74 | 75 | eval_cfg = cfg['EVAL'] 76 | transform = get_val_augmentation(eval_cfg['IMAGE_SIZE']) 77 | dataset = eval(cfg['DATASET']['NAME'])(cfg['DATASET']['ROOT'], 'val', transform) 78 | dataloader = DataLoader(dataset, 1, num_workers=1, pin_memory=True) 79 | 80 | model_path = Path(eval_cfg['MODEL_PATH']) 81 | if not model_path.exists(): model_path = Path(cfg['SAVE_DIR']) / f"{cfg['MODEL']['NAME']}_{cfg['MODEL']['BACKBONE']}_{cfg['DATASET']['NAME']}.pth" 82 | print(f"Evaluating {model_path}...") 83 | 84 | model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], dataset.n_classes) 85 | model.load_state_dict(torch.load(str(model_path), map_location='cpu')) 86 | model = model.to(device) 87 | 88 | if eval_cfg['MSF']['ENABLE']: 89 | acc, macc, f1, mf1, ious, miou = evaluate_msf(model, dataloader, device, eval_cfg['MSF']['SCALES'], eval_cfg['MSF']['FLIP']) 90 | else: 91 | acc, macc, f1, mf1, ious, miou = evaluate(model, dataloader, device) 92 | 93 | table = { 94 | 'Class': list(dataset.CLASSES) + ['Mean'], 95 | 'IoU': ious + [miou], 96 | 'F1': f1 + [mf1], 97 | 'Acc': acc + [macc] 98 | } 99 | 100 | print(tabulate(table, headers='keys')) 101 | 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument('--cfg', type=str, default='configs/custom.yaml') 106 | args = parser.parse_args() 107 | 108 | with open(args.cfg) as f: 109 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 110 | 111 | setup_cudnn() 112 | main(cfg) --------------------------------------------------------------------------------
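A minimal usage sketch (not a file from this repository): it wires together the same pieces that main() in tools/val.py reads from the YAML config, using the single-scale evaluate() path. The dataset class name (ADE20K), the dataset root, and the checkpoint path below are assumptions/placeholders, and the import of evaluate assumes the script sits next to tools/val.py, mirroring the import in tools/train.py.

import torch
from torch.utils.data import DataLoader
from semseg.models import SegFormer
from semseg.datasets import ADE20K                      # assumed class name for semseg/datasets/ade20k.py
from semseg.augmentations import get_val_augmentation
from val import evaluate                                # resolves when run from inside tools/, as in train.py

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = get_val_augmentation([512, 512])
dataset = ADE20K('data/ADEChallengeData2016', 'val', transform)    # placeholder dataset root
dataloader = DataLoader(dataset, batch_size=1, num_workers=1, pin_memory=True)

model = SegFormer('MiT-B0', dataset.n_classes)
model.load_state_dict(torch.load('output/SegFormer_MiT-B0_ADE20K.pth', map_location='cpu'))    # placeholder checkpoint
model = model.to(device)

acc, macc, f1, mf1, ious, miou = evaluate(model, dataloader, device)
print(f"mIoU: {miou}")

Swapping evaluate for evaluate_msf, passing the scales and flip options, mirrors the multi-scale-with-flip branch that main() enables when EVAL MSF is turned on in the config.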