├── .idea
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── libraries
│   │   └── R_User_Library.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── parse2.iml
│   ├── vcs.xml
│   └── workspace.xml
├── .vscode
│   └── .ropeproject
│       └── config.py
├── README.md
├── checkpoints
│   ├── exp
│   │   └── result checkpoints.txt
│   └── init
│       └── pretrained models.txt
├── dataset
│   ├── ATR
│   │   ├── README.MD
│   │   ├── select_id.txt
│   │   ├── test_id.txt
│   │   └── train_id.txt
│   ├── CCF
│   │   ├── select_id.txt
│   │   ├── test_id.txt
│   │   └── train_id.txt
│   ├── CIHP
│   │   ├── README.md
│   │   ├── all_id.txt
│   │   ├── human_colormap.mat
│   │   ├── test_id.txt
│   │   ├── train_id.txt
│   │   ├── trainval_id.txt
│   │   └── val_id.txt
│   ├── LIP
│   │   ├── README.md
│   │   ├── hard_id.txt
│   │   ├── train_id.txt
│   │   ├── train_val.txt
│   │   ├── val.txt
│   │   └── val_id.txt
│   ├── PPSS
│   │   ├── test_id.txt
│   │   └── train_id.txt
│   ├── Pascal
│   │   ├── README.MD
│   │   ├── train_id.txt
│   │   └── val_id.txt
│   ├── __init__.py
│   ├── data_CIHP.py
│   ├── data_atr.py
│   ├── data_ccf.py
│   ├── data_lip.py
│   ├── data_pascal.py
│   ├── data_ppss.py
│   ├── data_transforms.py
│   ├── transforms.py
│   └── weights.py
├── doc
│   └── architecture.png
├── evaluate_pascal.py
├── evaluate_pascal.sh
├── inplace_abn
│   ├── __init__.py
│   ├── bn.py
│   ├── functions.py
│   └── src
│       ├── checks.h
│       ├── common.h
│       ├── inplace_abn.cpp
│       ├── inplace_abn.h
│       ├── inplace_abn_cpu.cpp
│       ├── inplace_abn_cuda.cu
│       ├── inplace_abn_cuda_half.cu
│       └── utils
│           ├── checks.h
│           ├── common.h
│           └── cuda.cuh
├── modules
│   ├── __init__.py
│   ├── com_mod.py
│   ├── convGRU.py
│   ├── inits.py
│   └── parse_mod.py
├── network
│   ├── ResNet_stem_converter.py
│   ├── __init__.py
│   ├── baseline.py
│   └── gnn_parse.py
├── requirements.txt
├── train
│   ├── train_atr.py
│   ├── train_ccf.py
│   ├── train_lip.py
│   ├── train_pascal.py
│   └── train_ppss.py
├── train_baseline.py
├── train_pascal.sh
├── utils
│   ├── __init__.py
│   ├── aaf
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── layers.cpython-36.pyc
│   │   │   └── losses.cpython-36.pyc
│   │   ├── layers.py
│   │   └── losses.py
│   ├── best
│   │   └── lovasz_loss.py
│   ├── gnn_loss.py
│   ├── learning_policy.py
│   ├── lovasz_loss.py
│   ├── metric.py
│   ├── parallel.py
│   └── visualize.py
└── val
    ├── evaluate_atr.py
    ├── evaluate_ccf.py
    ├── evaluate_lip.py
    ├── evaluate_pascal.py
    ├── evaluate_ppss.py
    ├── f1_eval.py
    └── f1_eval_atr.py
--------------------------------------------------------------------------------
/.vscode/.ropeproject/config.py:
--------------------------------------------------------------------------------
# The default ``config.py``
# flake8: noqa


def set_prefs(prefs):
    """This function is called before opening the project"""

    # Specify which files and folders to ignore in the project.
    # Changes to ignored resources are not added to the history and
    # VCSs. Also they are not returned in `Project.get_files()`.
    # Note that ``?`` and ``*`` match all characters but slashes.
    # '*.pyc': matches 'test.pyc' and 'pkg/test.pyc'
    # 'mod*.pyc': matches 'test/mod1.pyc' but not 'mod/1.pyc'
    # '.svn': matches 'pkg/.svn' and all of its children
    # 'build/*.o': matches 'build/lib.o' but not 'build/sub/lib.o'
    # 'build//*.o': matches 'build/lib.o' and 'build/sub/lib.o'
    prefs['ignored_resources'] = ['*.pyc', '*~', '.ropeproject',
                                  '.hg', '.svn', '_svn', '.git', '.tox']

    # Specifies which files should be considered python files. It is
    # useful when you have scripts inside your project. Only files
    # ending with ``.py`` are considered to be python files by
    # default.
    # prefs['python_files'] = ['*.py']

    # Custom source folders: By default rope searches the project
    # for finding source folders (folders that should be searched
    # for finding modules). You can add paths to that list. Note
    # that rope guesses project source folders correctly most of the
    # time; use this if you have any problems.
    # The folders should be relative to project root and use '/' for
    # separating folders regardless of the platform rope is running on.
    # 'src/my_source_folder' for instance.
    # prefs.add('source_folders', 'src')

    # You can extend python path for looking up modules
    # prefs.add('python_path', '~/python/')

    # Should rope save object information or not.
    prefs['save_objectdb'] = True
    prefs['compress_objectdb'] = False

    # If `True`, rope analyzes each module when it is being saved.
    prefs['automatic_soa'] = True
    # The depth of calls to follow in static object analysis
    prefs['soa_followed_calls'] = 0

    # If `False` when running modules or unit tests "dynamic object
    # analysis" is turned off. This makes them much faster.
    prefs['perform_doa'] = True

    # Rope can check the validity of its object DB when running.
    prefs['validate_objectdb'] = True

    # How many undos to hold?
    prefs['max_history_items'] = 32

    # Shows whether to save history across sessions.
    prefs['save_history'] = True
    prefs['compress_history'] = False

    # Set the number of spaces used for indenting. According to
    # :PEP:`8`, it is best to use 4 spaces. Since most of rope's
    # unit-tests use 4 spaces it is more reliable, too.
    prefs['indent_size'] = 4

    # Builtin and c-extension modules that are allowed to be imported
    # and inspected by rope.
    prefs['extension_modules'] = []

    # Add all standard c-extensions to extension_modules list.
    prefs['import_dynload_stdmods'] = True

    # If `True` modules with syntax errors are considered to be empty.
    # The default value is `False`; when `False` syntax errors raise a
    # `rope.base.exceptions.ModuleSyntaxError` exception.
    prefs['ignore_syntax_errors'] = False

    # If `True`, rope ignores unresolvable imports. Otherwise, they
    # appear in the importing namespace.
    prefs['ignore_bad_imports'] = False

    # If `True`, rope will insert new module imports as
    # `from <package> import <module>` by default.
    prefs['prefer_module_from_imports'] = False

    # If `True`, rope will transform a comma list of imports into
    # multiple separate import statements when organizing imports.
    prefs['split_imports'] = False

    # If `True`, rope will remove all top-level import statements and
    # reinsert them at the top of the module when making changes.
    prefs['pull_imports_to_top'] = True

    # If `True`, rope will sort imports alphabetically by module name instead
    # of alphabetically by import statement, with from imports after normal
    # imports.
    prefs['sort_imports_alphabetically'] = False

    # Location of the implementation of
    # rope.base.oi.type_hinting.interfaces.ITypeHintingFactory. In general,
    # you don't have to change this value unless you're a rope expert.
    # Change this value to inject your own implementations of the interfaces
    # listed in module rope.base.oi.type_hinting.providers.interfaces.
    # For example, you can add your own providers for Django models, or
    # disable type-hinting search in a class hierarchy, etc.
    prefs['type_hinting_factory'] = (
        'rope.base.oi.type_hinting.factory.default_type_hinting_factory')


def project_opened(project):
    """This function is called after opening the project"""
    # Do whatever you like here!
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Hierarchical Human Parsing with Typed Part-Relation Reasoning (CVPR2020)

## Introduction
The algorithm is described in the [CVPR 2020 paper: Hierarchical Human Parsing with Typed Part-Relation Reasoning](https://openaccess.thecvf.com/content_CVPR_2020/papers/Wang_Hierarchical_Human_Parsing_With_Typed_Part-Relation_Reasoning_CVPR_2020_paper.pdf).

![network](doc/architecture.png)
***

## Environment and installation
This repository is developed under **CUDA-10.0** and **PyTorch-1.2.0** in **Python 3.6**.
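Once the environment is set up, a minimal sanity check (standard PyTorch calls only; the expected values are the versions named above):

```python
import torch

print(torch.__version__)           # developed against 1.2.0
print(torch.version.cuda)          # developed against CUDA '10.0'
print(torch.cuda.is_available())   # a visible GPU is needed for the inplace_abn CUDA kernels
```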
The required packages can be installed by:
```bash
pip install -r requirements.txt
```

## Structure of the repo
````bash
$HierarchicalHumanParsing
├── checkpoints
│   ├── init
├── dataset
│   ├── list
├── doc
├── inplace_abn
│   ├── src
├── modules
├── network
├── utils
````

## Running the code
```bash
python evaluate_pascal.py
```

***
## Citation
If you find this code useful, please cite the related works with the following BibTeX entries:
```
@InProceedings{Wang_2020_CVPR,
    author = {Wang, Wenguan and Zhu, Hailong and Dai, Jifeng and Pang, Yanwei and Shen, Jianbing and Shao, Ling},
    title = {Hierarchical Human Parsing With Typed Part-Relation Reasoning},
    booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month = {June},
    year = {2020}
}

@InProceedings{Wang_2019_ICCV,
    author = {Wang, Wenguan and Zhang, Zhijie and Qi, Siyuan and Shen, Jianbing and Pang, Yanwei and Shao, Ling},
    title = {Learning Compositional Neural Information Fusion for Human Parsing},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month = {October},
    year = {2019}
}
```
--------------------------------------------------------------------------------
/checkpoints/exp/result checkpoints.txt:
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/checkpoints/exp/result checkpoints.txt
--------------------------------------------------------------------------------
/checkpoints/init/pretrained models.txt:
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/checkpoints/init/pretrained models.txt
--------------------------------------------------------------------------------
/dataset/ATR/README.MD:
background 0
hat 1
hair 2
sunglass 3
upper-clothes 4
skirt 5
pants 6
dress 7
belt 8
left-shoe 9
right-shoe 10
face 11
left-leg 12
right-leg 13
left-arm 14
right-arm 15
bag 16
scarf 17

Upper/lower-body grouping (up = upper body, down = lower body):
0      background
1-4    up
5-10   down
11     up
12-13  down
14-17  up
--------------------------------------------------------------------------------
/dataset/ATR/select_id.txt:
997_392
997_586
2500_175
2500_198
2500_434
2500_520
2500_534
2500_783
4565_1329
4565_1414
4565_1427
4565_1572
4565_1575
4565_1624
4565_1730
4565_1991
4565_2002
4565_2046
4565_2672
4565_2732
--------------------------------------------------------------------------------
/dataset/CCF/select_id.txt:
0770
0913
0474
0840
0662
0315
0102
0425
--------------------------------------------------------------------------------
/dataset/CCF/test_id.txt:
0012
0018
0020
0022
0023
6 | 0028 7 | 0040 8 | 0045 9 | 0047 10 | 0050 11 | 0052 12 | 0055 13 | 0060 14 | 0061 15 | 0062 16 | 0063 17 | 0066 18 | 0084 19 | 0085 20 | 0087 21 | 0104 22 | 0106 23 | 0108 24 | 0110 25 | 0113 26 | 0114 27 | 0115 28 | 0119 29 | 0120 30 | 0126 31 | 0129 32 | 0147 33 | 0154 34 | 0160 35 | 0162 36 | 0163 37 | 0171 38 | 0175 39 | 0195 40 | 0197 41 | 0200 42 | 0202 43 | 0225 44 | 0240 45 | 0245 46 | 0251 47 | 0252 48 | 0254 49 | 0263 50 | 0277 51 | 0281 52 | 0286 53 | 0291 54 | 0292 55 | 0301 56 | 0304 57 | 0307 58 | 0325 59 | 0327 60 | 0330 61 | 0346 62 | 0366 63 | 0378 64 | 0383 65 | 0390 66 | 0399 67 | 0408 68 | 0433 69 | 0440 70 | 0459 71 | 0463 72 | 0465 73 | 0467 74 | 0479 75 | 0487 76 | 0490 77 | 0522 78 | 0554 79 | 0593 80 | 0625 81 | 0646 82 | 0699 83 | 0721 84 | 0790 85 | 0832 86 | 0848 87 | 0861 88 | 0871 89 | 0873 90 | 0881 91 | 0936 92 | 0993 -------------------------------------------------------------------------------- /dataset/CCF/train_id.txt: -------------------------------------------------------------------------------- 1 | 0001 2 | 0002 3 | 0003 4 | 0004 5 | 0006 6 | 0007 7 | 0008 8 | 0009 9 | 0010 10 | 0011 11 | 0013 12 | 0014 13 | 0015 14 | 0016 15 | 0017 16 | 0019 17 | 0021 18 | 0025 19 | 0026 20 | 0029 21 | 0030 22 | 0031 23 | 0032 24 | 0033 25 | 0034 26 | 0035 27 | 0036 28 | 0037 29 | 0039 30 | 0042 31 | 0043 32 | 0044 33 | 0046 34 | 0048 35 | 0049 36 | 0051 37 | 0053 38 | 0054 39 | 0057 40 | 0058 41 | 0059 42 | 0065 43 | 0067 44 | 0068 45 | 0069 46 | 0070 47 | 0071 48 | 0072 49 | 0073 50 | 0074 51 | 0075 52 | 0076 53 | 0077 54 | 0078 55 | 0079 56 | 0080 57 | 0081 58 | 0082 59 | 0083 60 | 0086 61 | 0088 62 | 0089 63 | 0091 64 | 0092 65 | 0093 66 | 0094 67 | 0095 68 | 0096 69 | 0097 70 | 0099 71 | 0100 72 | 0101 73 | 0102 74 | 0103 75 | 0105 76 | 0107 77 | 0111 78 | 0112 79 | 0116 80 | 0117 81 | 0118 82 | 0122 83 | 0123 84 | 0124 85 | 0125 86 | 0128 87 | 0131 88 | 0132 89 | 0133 90 | 0134 91 | 0135 92 | 0136 93 | 0139 94 | 0140 95 | 0143 96 | 0144 97 | 0148 98 | 0149 99 | 0150 100 | 0151 101 | 0152 102 | 0153 103 | 0155 104 | 0156 105 | 0157 106 | 0158 107 | 0159 108 | 0161 109 | 0164 110 | 0165 111 | 0166 112 | 0167 113 | 0168 114 | 0169 115 | 0172 116 | 0173 117 | 0174 118 | 0176 119 | 0177 120 | 0178 121 | 0179 122 | 0180 123 | 0181 124 | 0182 125 | 0183 126 | 0184 127 | 0185 128 | 0186 129 | 0187 130 | 0188 131 | 0189 132 | 0190 133 | 0191 134 | 0192 135 | 0193 136 | 0194 137 | 0196 138 | 0198 139 | 0201 140 | 0203 141 | 0204 142 | 0205 143 | 0206 144 | 0207 145 | 0208 146 | 0209 147 | 0210 148 | 0211 149 | 0212 150 | 0213 151 | 0214 152 | 0215 153 | 0216 154 | 0217 155 | 0218 156 | 0219 157 | 0220 158 | 0221 159 | 0222 160 | 0223 161 | 0224 162 | 0226 163 | 0227 164 | 0228 165 | 0229 166 | 0230 167 | 0231 168 | 0233 169 | 0234 170 | 0235 171 | 0237 172 | 0238 173 | 0239 174 | 0241 175 | 0243 176 | 0244 177 | 0246 178 | 0247 179 | 0248 180 | 0249 181 | 0250 182 | 0253 183 | 0255 184 | 0256 185 | 0257 186 | 0258 187 | 0259 188 | 0260 189 | 0261 190 | 0262 191 | 0265 192 | 0267 193 | 0268 194 | 0269 195 | 0270 196 | 0271 197 | 0272 198 | 0273 199 | 0274 200 | 0275 201 | 0276 202 | 0278 203 | 0279 204 | 0280 205 | 0282 206 | 0284 207 | 0285 208 | 0287 209 | 0288 210 | 0289 211 | 0290 212 | 0293 213 | 0294 214 | 0295 215 | 0297 216 | 0298 217 | 0299 218 | 0300 219 | 0302 220 | 0303 221 | 0305 222 | 0306 223 | 0308 224 | 0309 225 | 0310 226 | 0311 227 | 0312 228 | 0313 229 | 0314 230 | 0315 231 | 0316 232 | 0317 233 | 0318 234 | 0319 235 | 0320 236 | 0321 237 | 0322 
238 | 0323 239 | 0324 240 | 0326 241 | 0328 242 | 0329 243 | 0331 244 | 0332 245 | 0333 246 | 0334 247 | 0335 248 | 0336 249 | 0337 250 | 0338 251 | 0339 252 | 0340 253 | 0341 254 | 0342 255 | 0344 256 | 0345 257 | 0347 258 | 0348 259 | 0349 260 | 0350 261 | 0351 262 | 0352 263 | 0353 264 | 0354 265 | 0355 266 | 0356 267 | 0357 268 | 0358 269 | 0359 270 | 0360 271 | 0361 272 | 0362 273 | 0363 274 | 0364 275 | 0365 276 | 0367 277 | 0369 278 | 0370 279 | 0371 280 | 0372 281 | 0373 282 | 0374 283 | 0375 284 | 0376 285 | 0377 286 | 0380 287 | 0381 288 | 0382 289 | 0384 290 | 0386 291 | 0388 292 | 0389 293 | 0391 294 | 0392 295 | 0393 296 | 0394 297 | 0395 298 | 0396 299 | 0398 300 | 0400 301 | 0401 302 | 0402 303 | 0403 304 | 0405 305 | 0406 306 | 0407 307 | 0409 308 | 0410 309 | 0411 310 | 0412 311 | 0413 312 | 0414 313 | 0415 314 | 0416 315 | 0417 316 | 0419 317 | 0420 318 | 0421 319 | 0422 320 | 0423 321 | 0424 322 | 0425 323 | 0426 324 | 0427 325 | 0428 326 | 0429 327 | 0431 328 | 0432 329 | 0434 330 | 0435 331 | 0436 332 | 0437 333 | 0438 334 | 0441 335 | 0442 336 | 0443 337 | 0444 338 | 0445 339 | 0446 340 | 0447 341 | 0448 342 | 0449 343 | 0450 344 | 0451 345 | 0452 346 | 0453 347 | 0454 348 | 0455 349 | 0456 350 | 0457 351 | 0458 352 | 0460 353 | 0461 354 | 0462 355 | 0464 356 | 0466 357 | 0468 358 | 0469 359 | 0470 360 | 0471 361 | 0472 362 | 0473 363 | 0474 364 | 0475 365 | 0476 366 | 0477 367 | 0478 368 | 0480 369 | 0481 370 | 0482 371 | 0483 372 | 0484 373 | 0485 374 | 0486 375 | 0488 376 | 0489 377 | 0491 378 | 0492 379 | 0493 380 | 0494 381 | 0495 382 | 0497 383 | 0498 384 | 0499 385 | 0500 386 | 0501 387 | 0502 388 | 0503 389 | 0504 390 | 0505 391 | 0506 392 | 0507 393 | 0508 394 | 0509 395 | 0510 396 | 0511 397 | 0513 398 | 0514 399 | 0515 400 | 0516 401 | 0517 402 | 0518 403 | 0519 404 | 0520 405 | 0521 406 | 0523 407 | 0524 408 | 0525 409 | 0526 410 | 0527 411 | 0528 412 | 0529 413 | 0530 414 | 0531 415 | 0532 416 | 0533 417 | 0534 418 | 0535 419 | 0536 420 | 0537 421 | 0538 422 | 0539 423 | 0540 424 | 0541 425 | 0542 426 | 0543 427 | 0544 428 | 0545 429 | 0546 430 | 0547 431 | 0548 432 | 0549 433 | 0551 434 | 0552 435 | 0553 436 | 0555 437 | 0556 438 | 0558 439 | 0559 440 | 0560 441 | 0561 442 | 0562 443 | 0563 444 | 0564 445 | 0565 446 | 0566 447 | 0567 448 | 0568 449 | 0569 450 | 0570 451 | 0571 452 | 0572 453 | 0573 454 | 0574 455 | 0575 456 | 0577 457 | 0578 458 | 0579 459 | 0580 460 | 0581 461 | 0582 462 | 0583 463 | 0584 464 | 0585 465 | 0586 466 | 0587 467 | 0588 468 | 0589 469 | 0590 470 | 0591 471 | 0592 472 | 0594 473 | 0595 474 | 0596 475 | 0597 476 | 0598 477 | 0600 478 | 0601 479 | 0602 480 | 0603 481 | 0604 482 | 0605 483 | 0606 484 | 0607 485 | 0608 486 | 0610 487 | 0611 488 | 0612 489 | 0613 490 | 0614 491 | 0615 492 | 0616 493 | 0617 494 | 0618 495 | 0619 496 | 0620 497 | 0621 498 | 0622 499 | 0623 500 | 0624 501 | 0626 502 | 0627 503 | 0628 504 | 0629 505 | 0630 506 | 0631 507 | 0632 508 | 0633 509 | 0634 510 | 0635 511 | 0636 512 | 0637 513 | 0638 514 | 0639 515 | 0640 516 | 0641 517 | 0642 518 | 0643 519 | 0644 520 | 0645 521 | 0647 522 | 0648 523 | 0649 524 | 0650 525 | 0651 526 | 0653 527 | 0654 528 | 0655 529 | 0656 530 | 0657 531 | 0659 532 | 0660 533 | 0661 534 | 0662 535 | 0663 536 | 0664 537 | 0665 538 | 0666 539 | 0668 540 | 0669 541 | 0670 542 | 0671 543 | 0672 544 | 0673 545 | 0675 546 | 0677 547 | 0678 548 | 0679 549 | 0680 550 | 0681 551 | 0682 552 | 0684 553 | 0686 554 | 0687 555 | 0688 556 | 0689 557 | 0690 558 | 0691 559 | 0692 560 | 0693 
561 | 0694 562 | 0695 563 | 0696 564 | 0697 565 | 0698 566 | 0700 567 | 0701 568 | 0702 569 | 0703 570 | 0704 571 | 0705 572 | 0706 573 | 0707 574 | 0708 575 | 0709 576 | 0711 577 | 0712 578 | 0713 579 | 0714 580 | 0715 581 | 0716 582 | 0717 583 | 0718 584 | 0719 585 | 0720 586 | 0722 587 | 0723 588 | 0725 589 | 0726 590 | 0727 591 | 0728 592 | 0729 593 | 0730 594 | 0731 595 | 0733 596 | 0734 597 | 0736 598 | 0737 599 | 0738 600 | 0739 601 | 0741 602 | 0742 603 | 0743 604 | 0744 605 | 0745 606 | 0746 607 | 0747 608 | 0748 609 | 0749 610 | 0750 611 | 0751 612 | 0752 613 | 0753 614 | 0755 615 | 0756 616 | 0757 617 | 0758 618 | 0759 619 | 0760 620 | 0761 621 | 0762 622 | 0763 623 | 0764 624 | 0765 625 | 0766 626 | 0767 627 | 0768 628 | 0769 629 | 0770 630 | 0771 631 | 0772 632 | 0773 633 | 0774 634 | 0775 635 | 0776 636 | 0777 637 | 0778 638 | 0779 639 | 0780 640 | 0781 641 | 0782 642 | 0783 643 | 0784 644 | 0785 645 | 0786 646 | 0787 647 | 0788 648 | 0789 649 | 0791 650 | 0792 651 | 0793 652 | 0794 653 | 0795 654 | 0798 655 | 0799 656 | 0800 657 | 0801 658 | 0802 659 | 0803 660 | 0804 661 | 0805 662 | 0806 663 | 0807 664 | 0808 665 | 0809 666 | 0810 667 | 0811 668 | 0812 669 | 0813 670 | 0814 671 | 0815 672 | 0816 673 | 0817 674 | 0818 675 | 0819 676 | 0820 677 | 0821 678 | 0822 679 | 0823 680 | 0824 681 | 0825 682 | 0826 683 | 0827 684 | 0828 685 | 0829 686 | 0831 687 | 0834 688 | 0835 689 | 0836 690 | 0837 691 | 0838 692 | 0839 693 | 0840 694 | 0842 695 | 0843 696 | 0844 697 | 0845 698 | 0847 699 | 0849 700 | 0850 701 | 0851 702 | 0852 703 | 0853 704 | 0854 705 | 0856 706 | 0857 707 | 0858 708 | 0859 709 | 0860 710 | 0862 711 | 0863 712 | 0864 713 | 0865 714 | 0866 715 | 0867 716 | 0868 717 | 0869 718 | 0870 719 | 0872 720 | 0874 721 | 0875 722 | 0876 723 | 0877 724 | 0878 725 | 0879 726 | 0880 727 | 0882 728 | 0883 729 | 0884 730 | 0885 731 | 0886 732 | 0887 733 | 0888 734 | 0889 735 | 0890 736 | 0891 737 | 0892 738 | 0893 739 | 0894 740 | 0895 741 | 0897 742 | 0898 743 | 0899 744 | 0900 745 | 0901 746 | 0902 747 | 0903 748 | 0904 749 | 0905 750 | 0906 751 | 0907 752 | 0909 753 | 0910 754 | 0911 755 | 0912 756 | 0913 757 | 0914 758 | 0915 759 | 0916 760 | 0917 761 | 0918 762 | 0919 763 | 0920 764 | 0921 765 | 0922 766 | 0925 767 | 0926 768 | 0927 769 | 0928 770 | 0929 771 | 0930 772 | 0931 773 | 0932 774 | 0933 775 | 0934 776 | 0935 777 | 0937 778 | 0938 779 | 0939 780 | 0940 781 | 0941 782 | 0943 783 | 0944 784 | 0945 785 | 0946 786 | 0947 787 | 0948 788 | 0949 789 | 0950 790 | 0951 791 | 0952 792 | 0953 793 | 0954 794 | 0956 795 | 0957 796 | 0958 797 | 0959 798 | 0960 799 | 0961 800 | 0962 801 | 0964 802 | 0965 803 | 0966 804 | 0967 805 | 0968 806 | 0969 807 | 0970 808 | 0971 809 | 0972 810 | 0973 811 | 0974 812 | 0975 813 | 0976 814 | 0977 815 | 0978 816 | 0979 817 | 0980 818 | 0981 819 | 0983 820 | 0984 821 | 0986 822 | 0987 823 | 0988 824 | 0989 825 | 0990 826 | 0991 827 | 0992 828 | 0994 829 | 0995 830 | 0996 831 | 0999 832 | 1001 833 | 1002 834 | 1003 835 | 1004 -------------------------------------------------------------------------------- /dataset/CIHP/README.md: -------------------------------------------------------------------------------- 1 | 2 | Images: images 3 | Category_ids: semantic part segmentation labels Categories: visualized semantic part segmentation labels 4 | Human_ids: semantic person segmentation labels Human: visualized semantic person segmentation labels 5 | Instance_ids: instance-level human parsing labels Instances: visualized instance-level human parsing 
labels

Label order of semantic part segmentation:

1. Hat
2. Hair
3. Glove
4. Sunglasses
5. UpperClothes
6. Dress
7. Coat
8. Socks
9. Pants
10. Torso-skin
11. Scarf
12. Skirt
13. Face
14. Left-arm
15. Right-arm
16. Left-leg
17. Right-leg
18. Left-shoe
19. Right-shoe
--------------------------------------------------------------------------------
/dataset/CIHP/human_colormap.mat:
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/dataset/CIHP/human_colormap.mat
--------------------------------------------------------------------------------
/dataset/LIP/README.md:
Images: images
Category_ids: semantic part segmentation labels
Categories: visualized semantic part segmentation labels
Human_ids: semantic person segmentation labels
Human: visualized semantic person segmentation labels
Instance_ids: instance-level human parsing labels
Instances: visualized instance-level human parsing labels

Label order of semantic part segmentation:

1. Hat
2. Hair
3. Glove
4. Sunglasses
5. UpperClothes
6. Dress
7. Coat
8. Socks
9. Pants
10. Torso-skin
11. Scarf
12. Skirt
13. Face
14. Left-arm
15. Right-arm
16. Left-leg
17. Right-leg
18. Left-shoe
19. Right-shoe
--------------------------------------------------------------------------------
/dataset/Pascal/README.MD:
background 0
head 1
torso 2
upper-arm 3
lower-arm 4
upper-leg 5
lower-leg 6
--------------------------------------------------------------------------------
/dataset/__init__.py:
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/data_CIHP.py:
--------------------------------------------------------------------------------
import os
import os.path
import random

import cv2
import numpy as np
import torch.utils.data as data
import torchvision.transforms as transforms
from .data_transforms import RandomRotate
from PIL import Image


# ###### Data loading #######
def make_dataset(root, lst):
    # collect image / label / horizontally-flipped-label paths for every id in the list file
    fid = open(lst, 'r')
    imgs, segs, segs_rev = [], [], []
    for line in fid.readlines():
        idx = line.strip().split(' ')[0]
        image_path = os.path.join(root, 'JPEGImages/' + str(idx) + '.jpg')
        seg_path = os.path.join(root, 'Segmentations/' + str(idx) + '.png')
        seg_rev_path = os.path.join(root, 'Segmentations_rev/' + str(idx) + '.png')
        imgs.append(image_path)
        segs.append(seg_path)
        segs_rev.append(seg_rev_path)
    return imgs, segs, segs_rev


# ###### val resize & crop ######
def scale_crop(img, seg, crop_size):
    # pad bottom/right up to crop_size, then take the top-left crop
    oh, ow = seg.shape
    pad_h = max(0, crop_size - oh)
    pad_w = max(0, crop_size - ow)
    if pad_h > 0 or pad_w > 0:
        img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
                                     value=(0.0, 0.0, 0.0))
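        # NOTE: in DatasetGenerator.__getitem__ below, the colorjitter/rotate
        # branch calls `Image.fromarray(seg)` before `seg` is assigned (the
        # labels are still held in `seg_in`/`seg_rev_in` at that point), so it
        # raises a NameError whenever the branch triggers; the rotation would
        # have to operate on `seg_in`, and its result is in any case overwritten
        # by the later flip selection. The same pattern appears in data_atr.py
        # and data_lip.py.
        # Pad the label map the same way as the image, but fill the border with
        # 255, the ignore index, so padded pixels do not contribute to the loss: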
seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 38 | value=255) 39 | else: 40 | img_pad, seg_pad = img, seg 41 | 42 | img = np.asarray(img_pad[0: crop_size, 0: crop_size], np.float32) 43 | seg = np.asarray(seg_pad[0: crop_size, 0: crop_size], np.float32) 44 | 45 | return img, seg 46 | 47 | 48 | class DatasetGenerator(data.Dataset): 49 | def __init__(self, root, list_path, crop_size, training=True): 50 | 51 | imgs, segs, segs_rev = make_dataset(root, list_path) 52 | 53 | self.root = root 54 | self.imgs = imgs 55 | self.segs = segs 56 | self.segs_rev = segs_rev 57 | self.crop_size = crop_size 58 | self.training = training 59 | self.colorjitter = transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.5, hue=0.1) 60 | self.random_rotate=RandomRotate(20) 61 | 62 | def __getitem__(self, index): 63 | # load data 64 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 65 | name = self.imgs[index].split('/')[-1][:-4] 66 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR) 67 | seg_in = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE) 68 | seg_rev_in = cv2.imread(self.segs_rev[index], cv2.IMREAD_GRAYSCALE) 69 | 70 | if self.training: 71 | #colorjitter and rotate 72 | if random.random() < 0.5: 73 | img = Image.fromarray(img) 74 | seg = Image.fromarray(seg) 75 | img = self.colorjitter(img) 76 | img, seg = self.random_rotate(img, seg) 77 | img = np.array(img).astype(np.uint8) 78 | seg = np.array(seg).astype(np.uint8) 79 | # random mirror 80 | flip = np.random.choice(2) * 2 - 1 81 | img = img[:, ::flip, :] 82 | if flip == -1: 83 | seg = seg_rev_in 84 | else: 85 | seg = seg_in 86 | # random scale 87 | ratio = random.uniform(0.5, 2.0) 88 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 89 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST) 90 | img = np.array(img).astype(np.float32) - mean 91 | 92 | # pad & crop 93 | img_h, img_w = seg.shape 94 | pad_h = max(self.crop_size - img_h, 0) 95 | pad_w = max(self.crop_size - img_w, 0) 96 | if pad_h > 0 or pad_w > 0: 97 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 98 | value=(0.0, 0.0, 0.0)) 99 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 100 | value=(255,)) 101 | else: 102 | img_pad, seg_pad = img, seg 103 | 104 | img_h, img_w = seg_pad.shape 105 | h_off = random.randint(0, img_h - self.crop_size) 106 | w_off = random.randint(0, img_w - self.crop_size) 107 | img = np.asarray(img_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32) 108 | seg = np.asarray(seg_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32) 109 | img = img.transpose((2, 0, 1)) 110 | # generate body masks 111 | seg_half = seg.copy() 112 | seg_half[(seg_half > 0) & (seg_half <= 7)] = 1 113 | seg_half[(seg_half > 7) & (seg_half <= 10)] = 2 114 | seg_half[seg_half == 11] = 1 115 | seg_half[seg_half == 12] = 2 116 | seg_half[(seg_half > 12) & (seg_half <= 15)] = 1 117 | seg_half[(seg_half > 15) & (seg_half < 255)] = 2 118 | seg_full = seg.copy() 119 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1 120 | 121 | else: 122 | h, w = seg_in.shape 123 | max_size = max(w, h) 124 | ratio = self.crop_size / max_size 125 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 126 | seg = cv2.resize(seg_in, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST) 127 | img = np.array(img).astype(np.float32) - mean 128 | img, seg = 
scale_crop(img, seg, crop_size=self.crop_size) 129 | img = img.transpose((2, 0, 1)) 130 | # generate body masks 131 | seg_half = seg.copy() 132 | seg_half[(seg_half > 0) & (seg_half <= 7)] = 1 133 | seg_half[(seg_half > 7) & (seg_half <= 10)] = 2 134 | seg_half[seg_half == 11] = 1 135 | seg_half[seg_half == 12] = 2 136 | seg_half[(seg_half > 12) & (seg_half <= 15)] = 1 137 | seg_half[(seg_half > 15) & (seg_half < 255)] = 2 138 | seg_full = seg.copy() 139 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1 140 | 141 | images = img.copy() 142 | segmentations = seg.copy() 143 | segmentations_half = seg_half.copy() 144 | segmentations_full = seg_full.copy() 145 | 146 | return images, segmentations, segmentations_half, segmentations_full, name 147 | 148 | def __len__(self): 149 | return len(self.imgs) 150 | 151 | 152 | class ValidationLoader(data.Dataset): 153 | """evaluate on LIP val set""" 154 | 155 | def __init__(self, root, list_path, crop_size): 156 | fid = open(list_path, 'r') 157 | imgs, segs = [], [] 158 | for line in fid.readlines(): 159 | idx = line.strip().split(' ')[0] 160 | image_path = os.path.join(root, 'images/' + str(idx) + '.jpg') 161 | seg_path = os.path.join(root, 'segmentations/' + str(idx) + '.png') 162 | imgs.append(image_path) 163 | segs.append(seg_path) 164 | 165 | self.root = root 166 | self.imgs = imgs 167 | self.segs = segs 168 | self.crop_size = crop_size 169 | 170 | def __getitem__(self, index): 171 | # load data 172 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 173 | name = self.imgs[index].split('/')[-1][:-4] 174 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR) 175 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE) 176 | 177 | h, w = seg.shape 178 | max_size = max(w, h) 179 | ratio = self.crop_size / max_size 180 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 181 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST) 182 | img = np.array(img).astype(np.float32) - mean 183 | img, seg = scale_crop(img, seg, crop_size=self.crop_size) 184 | img = img.transpose((2, 0, 1)) 185 | 186 | images = img.copy() 187 | segmentations = seg.copy() 188 | 189 | return images, segmentations, name 190 | 191 | def __len__(self): 192 | return len(self.imgs) 193 | -------------------------------------------------------------------------------- /dataset/data_atr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from .data_transforms import RandomRotate 10 | from PIL import Image 11 | 12 | 13 | # ###### Data loading ####### 14 | def make_dataset(root, lst): 15 | # append all index 16 | fid = open(lst, 'r') 17 | imgs, segs, segs_rev = [], [], [] 18 | for line in fid.readlines(): 19 | idx = line.strip().split(' ')[0] 20 | image_path = os.path.join(root, 'JPEGImages/' + str(idx) + '.jpg') 21 | seg_path = os.path.join(root, 'Segmentations/' + str(idx) + '.png') 22 | seg_rev_path = os.path.join(root, 'SegmentationsRev/' + str(idx) + '_rev.png') 23 | imgs.append(image_path) 24 | segs.append(seg_path) 25 | segs_rev.append(seg_rev_path) 26 | return imgs, segs, segs_rev 27 | 28 | 29 | # ###### val resize & crop ###### 30 | def scale_crop(img, seg, crop_size): 31 | oh, ow = seg.shape 32 | pad_h = max(0, crop_size - oh) 33 | pad_w = max(0, crop_size - ow) 34 | if pad_h > 0 or 
pad_w > 0: 35 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 36 | value=(0.0, 0.0, 0.0)) 37 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 38 | value=255) 39 | else: 40 | img_pad, seg_pad = img, seg 41 | 42 | img = np.asarray(img_pad[0: crop_size, 0: crop_size], np.float32) 43 | seg = np.asarray(seg_pad[0: crop_size, 0: crop_size], np.float32) 44 | 45 | return img, seg 46 | 47 | 48 | class DatasetGenerator(data.Dataset): 49 | def __init__(self, root, list_path, crop_size, training=True): 50 | 51 | imgs, segs, segs_rev = make_dataset(root, list_path) 52 | 53 | self.root = root 54 | self.imgs = imgs 55 | self.segs = segs 56 | self.segs_rev = segs_rev 57 | self.crop_size = crop_size 58 | self.training = training 59 | self.colorjitter = transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.5, hue=0.1) 60 | self.random_rotate=RandomRotate(20) 61 | def __getitem__(self, index): 62 | # load data 63 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 64 | name = self.imgs[index].split('/')[-1][:-4] 65 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR) 66 | seg_in = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE) 67 | seg_rev_in = cv2.imread(self.segs_rev[index], cv2.IMREAD_GRAYSCALE) 68 | 69 | if self.training: 70 | #colorjitter and rotate 71 | if random.random() < 0.5: 72 | img = Image.fromarray(img) 73 | seg = Image.fromarray(seg) 74 | img = self.colorjitter(img) 75 | img, seg = self.random_rotate(img, seg) 76 | img = np.array(img).astype(np.uint8) 77 | seg = np.array(seg).astype(np.uint8) 78 | # random mirror 79 | flip = np.random.choice(2) * 2 - 1 80 | img = img[:, ::flip, :] 81 | if flip == -1: 82 | seg = seg_rev_in 83 | else: 84 | seg = seg_in 85 | # random scale 86 | ratio = random.uniform(0.5, 2.0) 87 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 88 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST) 89 | img = np.array(img).astype(np.float32) - mean 90 | 91 | # pad & crop 92 | img_h, img_w = seg.shape 93 | pad_h = max(self.crop_size - img_h, 0) 94 | pad_w = max(self.crop_size - img_w, 0) 95 | if pad_h > 0 or pad_w > 0: 96 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 97 | value=(0.0, 0.0, 0.0)) 98 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 99 | value=(255,)) 100 | else: 101 | img_pad, seg_pad = img, seg 102 | 103 | img_h, img_w = seg_pad.shape 104 | h_off = random.randint(0, img_h - self.crop_size) 105 | w_off = random.randint(0, img_w - self.crop_size) 106 | img = np.asarray(img_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32) 107 | seg = np.asarray(seg_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.uint8) 108 | img = img.transpose((2, 0, 1)) 109 | # generate body masks 110 | seg_half = seg.copy() 111 | seg_half[(seg_half > 0) & (seg_half <= 7)] = 1 112 | seg_half[(seg_half > 7) & (seg_half <= 10)] = 2 113 | seg_half[seg_half == 11] = 1 114 | seg_half[seg_half == 12] = 2 115 | seg_half[(seg_half > 12) & (seg_half <= 15)] = 1 116 | seg_half[(seg_half > 15) & (seg_half < 255)] = 2 117 | seg_full = seg.copy() 118 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1 119 | 120 | else: 121 | h, w = seg_in.shape 122 | max_size = max(w, h) 123 | ratio = self.crop_size / max_size 124 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 125 | seg = cv2.resize(seg_in, None, fx=ratio, 
fy=ratio, interpolation=cv2.INTER_NEAREST) 126 | img = np.array(img).astype(np.float32) - mean 127 | img, seg = scale_crop(img, seg, crop_size=self.crop_size) 128 | img = img.transpose((2, 0, 1)) 129 | # generate body masks 130 | # 0 background, 1-4 up, 5-10 down, 11 up, 12-13 down, 14-17 up 131 | seg_half = seg.copy() 132 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1 133 | seg_half[(seg_half > 4) & (seg_half <= 10)] = 2 134 | seg_half[seg_half == 11] = 1 135 | seg_half[(seg_half > 11) & (seg_half <= 13)] = 2 136 | seg_half[(seg_half > 13) & (seg_half < 255)] = 1 137 | seg_full = seg.copy() 138 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1 139 | 140 | images = img.copy() 141 | segmentations = seg.copy() 142 | segmentations_half = seg_half.copy() 143 | segmentations_full = seg_full.copy() 144 | 145 | return images, segmentations, segmentations_half, segmentations_full, name 146 | 147 | def __len__(self): 148 | return len(self.imgs) 149 | 150 | 151 | class ATRTestGenerator(data.Dataset): 152 | def __init__(self, root, list_path, crop_size): 153 | 154 | fid = open(list_path, 'r') 155 | imgs, segs = [], [] 156 | for line in fid.readlines(): 157 | idx = line.strip().split(' ')[0] 158 | image_path = os.path.join(root, 'JPEGImages/' + str(idx) + '.jpg') 159 | seg_path = os.path.join(root, 'Segmentations/' + str(idx) + '.png') 160 | imgs.append(image_path) 161 | segs.append(seg_path) 162 | 163 | self.root = root 164 | self.imgs = imgs 165 | self.segs = segs 166 | self.crop_size = crop_size 167 | 168 | def __getitem__(self, index): 169 | # load data 170 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 171 | name = self.imgs[index].split('/')[-1][:-4] 172 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR) 173 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE) 174 | ori_size = img.shape 175 | 176 | h, w = seg.shape 177 | length = max(w, h) 178 | ratio = self.crop_size / length 179 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 180 | img = np.array(img).astype(np.float32) - mean 181 | img = img.transpose((2, 0, 1)) 182 | 183 | images = img.copy() 184 | segmentations = seg.copy() 185 | 186 | return images, segmentations, np.array(ori_size), name 187 | 188 | def __len__(self): 189 | return len(self.imgs) 190 | -------------------------------------------------------------------------------- /dataset/data_ccf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from .data_transforms import RandomRotate 10 | from PIL import Image 11 | 12 | 13 | # ###### Data loading ####### 14 | def make_dataset(root, lst): 15 | # append all index 16 | fid = open(lst, 'r') 17 | imgs, segs = [], [] 18 | for line in fid.readlines(): 19 | idx = line.strip().split(' ')[0] 20 | image_path = os.path.join(root, 'JPEGImages/' + str(idx) + '.jpg') 21 | seg_path = os.path.join(root, 'Segmentations/' + str(idx) + '.png') 22 | imgs.append(image_path) 23 | segs.append(seg_path) 24 | return imgs, segs 25 | 26 | 27 | # ###### val resize & crop ###### 28 | def scale_crop(img, seg, crop_size): 29 | oh, ow = seg.shape 30 | pad_h = max(crop_size - oh, 0) 31 | pad_ht, pad_hb = pad_h // 2, pad_h - pad_h // 2 32 | pad_w = max(crop_size - ow, 0) 33 | pad_wl, pad_wr = pad_w // 2, pad_w - pad_w // 2 34 | if pad_h > 0 or pad_w > 0: 35 | img_pad = 
cv2.copyMakeBorder(img, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT, 36 | value=(0.0, 0.0, 0.0)) 37 | seg_pad = cv2.copyMakeBorder(seg, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT, 38 | value=(255,)) 39 | else: 40 | img_pad, seg_pad = img, seg 41 | 42 | return img_pad, seg_pad 43 | 44 | 45 | class DatasetGenerator(data.Dataset): 46 | def __init__(self, root, list_path, crop_size, training=True): 47 | 48 | imgs, segs = make_dataset(root, list_path) 49 | 50 | self.root = root 51 | self.imgs = imgs 52 | self.segs = segs 53 | self.crop_size = crop_size 54 | self.training = training 55 | self.colorjitter = transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.5, hue=0.1) 56 | self.random_rotate=RandomRotate(20) 57 | 58 | def __getitem__(self, index): 59 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 60 | # load data 61 | name = self.imgs[index].split('/')[-1][:-4] 62 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR) 63 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE) 64 | 65 | if self.training: 66 | #colorjitter and rotate 67 | if random.random() < 0.5: 68 | img = Image.fromarray(img) 69 | seg = Image.fromarray(seg) 70 | img = self.colorjitter(img) 71 | img, seg = self.random_rotate(img, seg) 72 | img = np.array(img).astype(np.uint8) 73 | seg = np.array(seg).astype(np.uint8) 74 | # random scale 75 | ratio = random.uniform(0.5, 2.0) 76 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 77 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST) 78 | img = np.array(img).astype(np.float32) - mean 79 | 80 | # pad & crop 81 | img_h, img_w = seg.shape[:2] 82 | pad_h = max(self.crop_size - img_h, 0) 83 | pad_ht, pad_hb = pad_h // 2, pad_h - pad_h // 2 84 | pad_w = max(self.crop_size - img_w, 0) 85 | pad_wl, pad_wr = pad_w // 2, pad_w - pad_w // 2 86 | if pad_h > 0 or pad_w > 0: 87 | img_pad = cv2.copyMakeBorder(img, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT, 88 | value=(0.0, 0.0, 0.0)) 89 | seg_pad = cv2.copyMakeBorder(seg, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT, 90 | value=(255,)) 91 | else: 92 | img_pad, seg_pad = img, seg 93 | 94 | seg_pad_h, seg_pad_w = seg_pad.shape 95 | h_off = random.randint(0, seg_pad_h - self.crop_size) 96 | w_off = random.randint(0, seg_pad_w - self.crop_size) 97 | img = np.asarray(img_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32) 98 | seg = np.asarray(seg_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.uint8) 99 | # random mirror 100 | flip = np.random.choice(2) * 2 - 1 101 | img = img[:, ::flip, :] 102 | seg = seg[:, ::flip] 103 | # Generate target maps 104 | img = img.transpose((2, 0, 1)) 105 | seg_half = seg.copy() 106 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1 107 | seg_half[seg_half == 5] = 2 108 | seg_half[(seg_half > 5) & (seg_half <= 7)] = 1 109 | seg_half[(seg_half > 7) & (seg_half <= 9)] = 2 110 | seg_half[(seg_half > 9) & (seg_half <= 11)] = 1 111 | seg_half[seg_half == 12] = 2 112 | seg_half[(seg_half > 12) & (seg_half <= 15)] = 1 113 | seg_half[seg_half == 16] = 2 114 | seg_half[(seg_half > 16) & (seg_half < 255)] = 1 115 | seg_full = seg.copy() 116 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1 117 | 118 | else: 119 | h, w = seg.shape 120 | max_size = max(w, h) 121 | ratio = self.crop_size / max_size 122 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 123 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, 
interpolation=cv2.INTER_NEAREST) 124 | img = np.array(img).astype(np.float32) - mean 125 | img, seg = scale_crop(img, seg, crop_size=self.crop_size) 126 | img = img.transpose((2, 0, 1)) 127 | seg_half = seg.copy() 128 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1 129 | seg_half[seg_half == 5] = 2 130 | seg_half[(seg_half > 5) & (seg_half <= 7)] = 1 131 | seg_half[(seg_half > 7) & (seg_half <= 9)] = 2 132 | seg_half[(seg_half > 9) & (seg_half <= 11)] = 1 133 | seg_half[seg_half == 12] = 2 134 | seg_half[(seg_half > 12) & (seg_half <= 15)] = 1 135 | seg_half[seg_half == 16] = 2 136 | seg_half[(seg_half > 16) & (seg_half < 255)] = 1 137 | seg_full = seg.copy() 138 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1 139 | 140 | images = img.copy() 141 | segmentations = seg.copy() 142 | segmentations_half = seg_half.copy() 143 | segmentations_full = seg_full.copy() 144 | 145 | return images, segmentations, segmentations_half, segmentations_full, name 146 | 147 | def __len__(self): 148 | return len(self.imgs) 149 | 150 | 151 | class TestGenerator(data.Dataset): 152 | 153 | def __init__(self, root, list_path, crop_size): 154 | 155 | imgs, segs = make_dataset(root, list_path) 156 | self.root = root 157 | self.imgs = imgs 158 | self.segs = segs 159 | self.crop_size = crop_size 160 | 161 | def __getitem__(self, index): 162 | # load data 163 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 164 | name = self.imgs[index].split('/')[-1][:-4] 165 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR) 166 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE) 167 | ori_size = img.shape 168 | 169 | h, w = seg.shape 170 | length = max(w, h) 171 | ratio = self.crop_size / length 172 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 173 | img = np.array(img).astype(np.float32) - mean 174 | img = img.transpose((2, 0, 1)) 175 | 176 | images = img.copy() 177 | segmentations = seg.copy() 178 | 179 | return images, segmentations, np.array(ori_size), name 180 | 181 | def __len__(self): 182 | return len(self.imgs) 183 | -------------------------------------------------------------------------------- /dataset/data_lip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from .data_transforms import RandomRotate 10 | from PIL import Image 11 | 12 | 13 | # ###### Data loading ####### 14 | def make_dataset(root, lst): 15 | # append all index 16 | fid = open(lst, 'r') 17 | imgs, segs, segs_rev = [], [], [] 18 | for line in fid.readlines(): 19 | idx = line.strip().split(' ')[0] 20 | image_path = os.path.join(root, 'images/' + str(idx) + '.jpg') 21 | seg_path = os.path.join(root, 'segmentations/' + str(idx) + '.png') 22 | seg_rev_path = os.path.join(root, 'segmentations_rev/' + str(idx) + '.png') 23 | imgs.append(image_path) 24 | segs.append(seg_path) 25 | segs_rev.append(seg_rev_path) 26 | return imgs, segs, segs_rev 27 | 28 | 29 | # ###### val resize & crop ###### 30 | def scale_crop(img, seg, crop_size): 31 | oh, ow = seg.shape 32 | pad_h = max(0, crop_size - oh) 33 | pad_w = max(0, crop_size - ow) 34 | if pad_h > 0 or pad_w > 0: 35 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 36 | value=(0.0, 0.0, 0.0)) 37 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 38 | value=255) 39 | else: 40 
| img_pad, seg_pad = img, seg 41 | 42 | img = np.asarray(img_pad[0: crop_size, 0: crop_size], np.float32) 43 | seg = np.asarray(seg_pad[0: crop_size, 0: crop_size], np.float32) 44 | 45 | return img, seg 46 | 47 | 48 | class DatasetGenerator(data.Dataset): 49 | def __init__(self, root, list_path, crop_size, training=True): 50 | 51 | imgs, segs, segs_rev = make_dataset(root, list_path) 52 | 53 | self.root = root 54 | self.imgs = imgs 55 | self.segs = segs 56 | self.segs_rev = segs_rev 57 | self.crop_size = crop_size 58 | self.training = training 59 | self.colorjitter = transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.5, hue=0.1) 60 | self.random_rotate=RandomRotate(20) 61 | def __getitem__(self, index): 62 | # load data 63 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 64 | name = self.imgs[index].split('/')[-1][:-4] 65 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR) 66 | seg_in = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE) 67 | seg_rev_in = cv2.imread(self.segs_rev[index], cv2.IMREAD_GRAYSCALE) 68 | 69 | if self.training: 70 | #colorjitter and rotate 71 | if random.random() < 0.5: 72 | img = Image.fromarray(img) 73 | seg = Image.fromarray(seg) 74 | img = self.colorjitter(img) 75 | img, seg = self.random_rotate(img, seg) 76 | img = np.array(img).astype(np.uint8) 77 | seg = np.array(seg).astype(np.uint8) 78 | # random mirror 79 | flip = np.random.choice(2) * 2 - 1 80 | img = img[:, ::flip, :] 81 | if flip == -1: 82 | seg = seg_rev_in 83 | else: 84 | seg = seg_in 85 | # random scale 86 | ratio = random.uniform(0.5, 1.5) 87 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 88 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST) 89 | img = np.array(img).astype(np.float32) - mean 90 | 91 | # pad & crop 92 | img_h, img_w = seg.shape 93 | pad_h = max(self.crop_size - img_h, 0) 94 | pad_w = max(self.crop_size - img_w, 0) 95 | if pad_h > 0 or pad_w > 0: 96 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 97 | value=(0.0, 0.0, 0.0)) 98 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 99 | value=(255,)) 100 | else: 101 | img_pad, seg_pad = img, seg 102 | 103 | img_h, img_w = seg_pad.shape 104 | h_off = random.randint(0, img_h - self.crop_size) 105 | w_off = random.randint(0, img_w - self.crop_size) 106 | img = np.asarray(img_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32) 107 | seg = np.asarray(seg_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32) 108 | img = img.transpose((2, 0, 1)) 109 | # generate body masks 110 | seg_half = seg.copy() 111 | seg_half[(seg_half > 0) & (seg_half <= 7)] = 1 112 | seg_half[(seg_half > 7) & (seg_half <= 10)] = 2 113 | seg_half[seg_half == 11] = 1 114 | seg_half[seg_half == 12] = 2 115 | seg_half[(seg_half > 12) & (seg_half <= 15)] = 1 116 | seg_half[(seg_half > 15) & (seg_half < 255)] = 2 117 | seg_full = seg.copy() 118 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1 119 | 120 | else: 121 | h, w = seg_in.shape 122 | max_size = max(w, h) 123 | ratio = self.crop_size / max_size 124 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 125 | seg = cv2.resize(seg_in, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST) 126 | img = np.array(img).astype(np.float32) - mean 127 | img, seg = scale_crop(img, seg, crop_size=self.crop_size) 128 | img = img.transpose((2, 0, 1)) 129 | # generate body masks 130 | 
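# upper/lower grouping: part ids 1-7, 11 and 13-15 -> 1 (upper body);
# ids 8-10, 12 and 16 upward -> 2 (lower body); 255 stays the ignore index.
# seg_full collapses all parts into a binary person/background mask.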
seg_half = seg.copy() 131 | seg_half[(seg_half > 0) & (seg_half <= 7)] = 1 132 | seg_half[(seg_half > 7) & (seg_half <= 10)] = 2 133 | seg_half[seg_half == 11] = 1 134 | seg_half[seg_half == 12] = 2 135 | seg_half[(seg_half > 12) & (seg_half <= 15)] = 1 136 | seg_half[(seg_half > 15) & (seg_half < 255)] = 2 137 | seg_full = seg.copy() 138 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1 139 | 140 | images = img.copy() 141 | segmentations = seg.copy() 142 | segmentations_half = seg_half.copy() 143 | segmentations_full = seg_full.copy() 144 | 145 | return images, segmentations, segmentations_half, segmentations_full, name 146 | 147 | def __len__(self): 148 | return len(self.imgs) 149 | 150 | 151 | class LIPValGenerator(data.Dataset): 152 | def __init__(self, root, list_path, crop_size): 153 | 154 | fid = open(list_path, 'r') 155 | imgs, segs = [], [] 156 | for line in fid.readlines(): 157 | idx = line.strip().split(' ')[0] 158 | image_path = os.path.join(root, 'images/' + str(idx) + '.jpg') 159 | seg_path = os.path.join(root, 'segmentations/' + str(idx) + '.png') 160 | imgs.append(image_path) 161 | segs.append(seg_path) 162 | 163 | self.root = root 164 | self.imgs = imgs 165 | self.segs = segs 166 | self.crop_size = crop_size 167 | 168 | def __getitem__(self, index): 169 | # load data 170 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 171 | name = self.imgs[index].split('/')[-1][:-4] 172 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR) 173 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE) 174 | ori_size = img.shape 175 | 176 | h, w = seg.shape 177 | length = max(w, h) 178 | ratio = self.crop_size / length 179 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 180 | img = np.array(img).astype(np.float32) - mean 181 | img = img.transpose((2, 0, 1)) 182 | 183 | images = img.copy() 184 | segmentations = seg.copy() 185 | 186 | return images, segmentations, np.array(ori_size), name 187 | 188 | def __len__(self): 189 | return len(self.imgs) 190 | -------------------------------------------------------------------------------- /dataset/data_pascal.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from .data_transforms import RandomRotate 10 | from PIL import Image 11 | # ###### Data loading ####### 12 | def make_dataset(root, lst): 13 | # append all index 14 | fid = open(lst, 'r') 15 | imgs, segs = [], [] 16 | for line in fid.readlines(): 17 | idx = line.strip().split(' ')[0] 18 | image_path = os.path.join(root, 'JPEGImages/' + str(idx) + '.jpg') 19 | # image_path = os.path.join(root, str(idx) + '.jpg') 20 | seg_path = os.path.join(root, 'SegmentationPart/' + str(idx) + '.png') 21 | # seg_path = os.path.join(root, str(idx) + '.jpg') 22 | imgs.append(image_path) 23 | segs.append(seg_path) 24 | return imgs, segs 25 | 26 | 27 | # ###### val resize & crop ###### 28 | def scale_crop(img, seg, crop_size): 29 | oh, ow = seg.shape 30 | pad_h = max(crop_size - oh, 0) 31 | pad_ht, pad_hb = pad_h // 2, pad_h - pad_h // 2 32 | pad_w = max(crop_size - ow, 0) 33 | pad_wl, pad_wr = pad_w // 2, pad_w - pad_w // 2 34 | if pad_h > 0 or pad_w > 0: 35 | img_pad = cv2.copyMakeBorder(img, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT, 36 | value=(0.0, 0.0, 0.0)) 37 | seg_pad = cv2.copyMakeBorder(seg, pad_ht, pad_hb, pad_wl, 
pad_wr, cv2.BORDER_CONSTANT, 38 | value=(255,)) 39 | else: 40 | img_pad, seg_pad = img, seg 41 | 42 | return img_pad, seg_pad 43 | 44 | 45 | class DatasetGenerator(data.Dataset): 46 | def __init__(self, root, list_path, crop_size, training=True): 47 | 48 | imgs, segs = make_dataset(root, list_path) 49 | 50 | self.root = root 51 | self.imgs = imgs 52 | self.segs = segs 53 | self.crop_size = crop_size 54 | self.training = training 55 | self.colorjitter = transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.5, hue=0.1) 56 | self.random_rotate=RandomRotate(20) 57 | 58 | def __getitem__(self, index): 59 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 60 | # load data 61 | name = self.imgs[index].split('/')[-1][:-4] 62 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR) 63 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE) 64 | 65 | if self.training: 66 | 67 | #colorjitter and rotate 68 | if random.random() < 0.5: 69 | img = Image.fromarray(img) 70 | seg = Image.fromarray(seg) 71 | img = self.colorjitter(img) 72 | img, seg = self.random_rotate(img, seg) 73 | img = np.array(img).astype(np.uint8) 74 | seg = np.array(seg).astype(np.uint8) 75 | 76 | # random scale 77 | ratio = random.uniform(0.5, 2.0) 78 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 79 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST) 80 | img = np.array(img).astype(np.float32) - mean 81 | 82 | # pad & crop 83 | img_h, img_w = seg.shape[:2] 84 | pad_h = max(self.crop_size - img_h, 0) 85 | pad_ht, pad_hb = pad_h // 2, pad_h - pad_h // 2 86 | pad_w = max(self.crop_size - img_w, 0) 87 | pad_wl, pad_wr = pad_w // 2, pad_w - pad_w // 2 88 | if pad_h > 0 or pad_w > 0: 89 | img_pad = cv2.copyMakeBorder(img, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT, 90 | value=(0.0, 0.0, 0.0)) 91 | seg_pad = cv2.copyMakeBorder(seg, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT, 92 | value=(255,)) 93 | else: 94 | img_pad, seg_pad = img, seg 95 | 96 | seg_pad_h, seg_pad_w = seg_pad.shape 97 | h_off = random.randint(0, seg_pad_h - self.crop_size) 98 | w_off = random.randint(0, seg_pad_w - self.crop_size) 99 | img = np.asarray(img_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32) 100 | seg = np.asarray(seg_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.uint8) 101 | # random mirror 102 | flip = np.random.choice(2) * 2 - 1 103 | img = img[:, ::flip, :] 104 | seg = seg[:, ::flip] 105 | # Generate target maps 106 | img = img.transpose((2, 0, 1)) 107 | seg_half = seg.copy() 108 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1 109 | seg_half[(seg_half > 4) & (seg_half < 255)] = 2 110 | seg_full = seg.copy() 111 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1 112 | 113 | else: 114 | h, w = seg.shape 115 | max_size = max(w, h) 116 | ratio = self.crop_size / max_size 117 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 118 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST) 119 | img = np.array(img).astype(np.float32) - mean 120 | img, seg = scale_crop(img, seg, crop_size=self.crop_size) 121 | img = img.transpose((2, 0, 1)) 122 | seg_half = seg.copy() 123 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1 124 | seg_half[(seg_half > 4) & (seg_half < 255)] = 2 125 | seg_full = seg.copy() 126 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1 127 | 128 | images = img.copy() 129 | segmentations = seg.copy() 130 | 
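# the .copy() calls make the arrays C-contiguous again: after the ::flip
# slicing and the transpose they may carry negative strides, which torch's
# default collate_fn cannot convert to tensors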
segmentations_half = seg_half.copy() 131 | segmentations_full = seg_full.copy() 132 | 133 | return images, segmentations, segmentations_half, segmentations_full, name 134 | 135 | def __len__(self): 136 | return len(self.imgs) 137 | 138 | 139 | class TestGenerator(data.Dataset): 140 | 141 | def __init__(self, root, list_path, crop_size): 142 | 143 | imgs, segs = make_dataset(root, list_path) 144 | self.root = root 145 | self.imgs = imgs 146 | self.segs = segs 147 | self.crop_size = crop_size 148 | 149 | def __getitem__(self, index): 150 | # load data 151 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 152 | name = self.imgs[index].split('/')[-1][:-4] 153 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR) 154 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE) 155 | ori_size = img.shape 156 | 157 | h, w = seg.shape 158 | length = max(w, h) 159 | ratio = self.crop_size / length 160 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 161 | img = np.array(img).astype(np.float32) - mean 162 | img = img.transpose((2, 0, 1)) 163 | 164 | images = img.copy() 165 | segmentations = seg.copy() 166 | 167 | return images, segmentations, np.array(ori_size), name 168 | 169 | def __len__(self): 170 | return len(self.imgs) 171 | 172 | 173 | class ReportGenerator(data.Dataset): 174 | 175 | def __init__(self, root, list_path, crop_size): 176 | 177 | imgs, segs = make_dataset(root, list_path) 178 | self.root = root 179 | self.imgs = imgs 180 | self.segs = segs 181 | self.crop_size = crop_size 182 | 183 | def __getitem__(self, index): 184 | # load data 185 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 186 | name = self.imgs[index].split('/')[-1][:-4] 187 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR) 188 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE) 189 | ori_size = img.shape 190 | 191 | h, w = seg.shape 192 | length = max(w, h) 193 | ratio = self.crop_size / length 194 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 195 | img = np.array(img).astype(np.float32) - mean 196 | img = img.transpose((2, 0, 1)) 197 | seg_half = seg.copy() 198 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1 199 | seg_half[(seg_half > 4) & (seg_half < 255)] = 2 200 | seg_full = seg.copy() 201 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1 202 | 203 | images = img.copy() 204 | segmentations = seg.copy() 205 | segmentations_half = seg_half.copy() 206 | segmentations_full = seg_full.copy() 207 | 208 | return images, segmentations, segmentations_half, segmentations_full, np.array(ori_size), name 209 | 210 | def __len__(self): 211 | return len(self.imgs) 212 | 213 | 214 | if __name__ == '__main__': 215 | dl = DatasetGenerator('/media/jzzz/Data/Dataset/PascalPersonPart/', './Pascal/train_id.txt', 216 | crop_size=512, training=True) 217 | 218 | item = iter(dl) 219 | for i in range(len(dl)): 220 | imgs, segs, segs_half, segs_full, idx = next(item) 221 | pass 222 | -------------------------------------------------------------------------------- /dataset/data_ppss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from .data_transforms import RandomRotate 10 | from PIL import Image 11 | map_idx = [0, 9, 19, 29, 50, 39, 60, 62] 12 | # 0background, 1hair, 2face, 3upper clothes,
4arms, 5lower clothes, 6legs, 7shoes 13 | 14 | 15 | # ###### Data loading ####### 16 | def make_dataset(root, lst): 17 | # append all index 18 | fid = open(lst, 'r') 19 | imgs, segs = [], [] 20 | for line in fid.readlines(): 21 | idx = line.strip() 22 | image_path = os.path.join(root, str(idx) + '.jpg') 23 | seg_path = os.path.join(root, str(idx) + '_m.png') 24 | imgs.append(image_path) 25 | segs.append(seg_path) 26 | return imgs, segs 27 | 28 | 29 | # ###### val resize & crop ###### 30 | def scale_crop(img, seg, crop_size): 31 | oh, ow = seg.shape 32 | pad_h = max(0, crop_size[0] - oh) 33 | pad_w = max(0, crop_size[1] - ow) 34 | if pad_h > 0 or pad_w > 0: 35 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 36 | value=(0.0, 0.0, 0.0)) 37 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 38 | value=255) 39 | else: 40 | img_pad, seg_pad = img, seg 41 | 42 | img = np.asarray(img_pad[0: crop_size[0], 0: crop_size[1]], np.float32) 43 | seg = np.asarray(seg_pad[0: crop_size[0], 0: crop_size[1]], np.float32) 44 | 45 | return img, seg 46 | 47 | 48 | class DatasetGenerator(data.Dataset): 49 | def __init__(self, root, list_path, crop_size, training=True): 50 | 51 | imgs, segs = make_dataset(root, list_path) 52 | 53 | self.root = root 54 | self.imgs = imgs 55 | self.segs = segs 56 | self.crop_size = crop_size 57 | self.training = training 58 | self.colorjitter = transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.5, hue=0.1) 59 | self.random_rotate=RandomRotate(20) 60 | 61 | def __getitem__(self, index): 62 | # load data 63 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 64 | name = self.imgs[index].split('/')[-1][:-4] 65 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR) 66 | seg = np.array(Image.open(self.segs[index])) 67 | # seg_h, seg_w = seg.shape 68 | seg_h, seg_w, _ = img.shape 69 | # img = cv2.resize(img, (seg_w, seg_h), interpolation=cv2.INTER_LINEAR) 70 | seg = cv2.resize(seg, (seg_w, seg_h), interpolation=cv2.INTER_NEAREST) 71 | new_seg = (np.ones_like(seg)*255).astype(np.uint8) 72 | for i in range(len(map_idx)): 73 | new_seg[seg == map_idx[i]] = i 74 | seg = new_seg 75 | if self.training: 76 | #colorjitter and rotate 77 | if random.random() < 0.5: 78 | img = Image.fromarray(img) 79 | seg = Image.fromarray(seg) 80 | img = self.colorjitter(img) 81 | img, seg = self.random_rotate(img, seg) 82 | img = np.array(img).astype(np.uint8) 83 | seg = np.array(seg).astype(np.uint8) 84 | # random mirror 85 | flip = np.random.choice(2) * 2 - 1 86 | img = img[:, ::flip, :] 87 | seg = seg[:, ::flip] 88 | # random scale 89 | ratio = random.uniform(0.75, 2.5) 90 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 91 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST) 92 | img = np.array(img).astype(np.float32) - mean 93 | 94 | # pad & crop 95 | img_h, img_w = seg.shape 96 | assert img_w < img_h 97 | pad_h = max(self.crop_size[0] - img_h, 0) 98 | pad_w = max(self.crop_size[1] - img_w, 0) 99 | if pad_h > 0 or pad_w > 0: 100 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 101 | value=(0.0, 0.0, 0.0)) 102 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, 103 | value=(255,)) 104 | else: 105 | img_pad, seg_pad = img, seg 106 | 107 | img_h, img_w = seg_pad.shape 108 | h_off = random.randint(0, img_h - self.crop_size[0]) 109 | w_off = random.randint(0, img_w - self.crop_size[1]) 110 | img = 
np.asarray(img_pad[h_off: h_off + self.crop_size[0], w_off: w_off + self.crop_size[1]], np.float32) 111 | seg = np.asarray(seg_pad[h_off: h_off + self.crop_size[0], w_off: w_off + self.crop_size[1]], np.uint8) 112 | img = img.transpose((2, 0, 1)) 113 | # generate body masks 114 | seg_half = seg.copy() 115 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1 116 | seg_half[(seg_half > 4) & (seg_half < 255)] = 2 117 | seg_full = seg.copy() 118 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1 119 | 120 | else: 121 | h, w = seg.shape 122 | max_size = max(w, h) 123 | ratio = self.crop_size[0] / max_size 124 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 125 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST) 126 | img = np.array(img).astype(np.float32) - mean 127 | img, seg = scale_crop(img, seg, crop_size=self.crop_size) 128 | img = img.transpose((2, 0, 1)) 129 | # generate body masks 130 | seg_half = seg.copy() 131 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1 132 | seg_half[(seg_half > 4) & (seg_half < 255)] = 2 133 | seg_full = seg.copy() 134 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1 135 | 136 | images = img.copy() 137 | segmentations = seg.copy() 138 | segmentations_half = seg_half.copy() 139 | segmentations_full = seg_full.copy() 140 | 141 | return images, segmentations, segmentations_half, segmentations_full, name 142 | 143 | def __len__(self): 144 | return len(self.imgs) 145 | 146 | class TestGenerator(data.Dataset): 147 | 148 | def __init__(self, root, list_path, crop_size): 149 | 150 | imgs, segs = make_dataset(root, list_path) 151 | self.root = root 152 | self.imgs = imgs 153 | self.segs = segs 154 | self.crop_size = crop_size 155 | 156 | def __getitem__(self, index): 157 | # load data 158 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 159 | name = self.imgs[index].split('/')[-1][:-4] 160 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR) 161 | seg = np.array(Image.open(self.segs[index])) 162 | seg_h, seg_w = seg.shape 163 | img = cv2.resize(img, (seg_w, seg_h), interpolation=cv2.INTER_LINEAR) 164 | for i in range(len(map_idx)): 165 | seg[seg == map_idx[i]] = i 166 | ori_size = img.shape 167 | 168 | h, w = seg.shape 169 | length = max(w, h) 170 | ratio = self.crop_size[0] / length 171 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR) 172 | img = np.array(img).astype(np.float32) - mean 173 | img = img.transpose((2, 0, 1)) 174 | 175 | images = img.copy() 176 | segmentations = seg.copy() 177 | 178 | return images, segmentations, np.array(ori_size), name 179 | 180 | def __len__(self): 181 | return len(self.imgs) 182 | 183 | if __name__ == '__main__': 184 | dl = DatasetGenerator('/media/jzzz/Data/Dataset/PPSS/TrainData/', './PPSS/train_id.txt', 185 | crop_size=(321, 161), training=False) 186 | 187 | item = iter(dl) 188 | for i in range(len(dl)): 189 | imgs, segs, segs_half, segs_full, idx = next(item) 190 | pass 191 | -------------------------------------------------------------------------------- /dataset/transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Author: Donny You (youansheng@gmail.com) 4 | 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import numpy as np 11 | import torch 12 | from PIL import Image 13 | 14 | 15 | class Normalize(object): 16 | 
"""Normalize a ``torch.tensor`` 17 | 18 | Args: 19 | inputs (torch.tensor): tensor to be normalized. 20 | mean: (list): the mean of RGB 21 | std: (list): the std of RGB 22 | 23 | Returns: 24 | Tensor: Normalized tensor. 25 | """ 26 | def __init__(self, div_value, mean, std): 27 | self.div_value = div_value 28 | self.mean = mean 29 | self.std =std 30 | 31 | def __call__(self, inputs): 32 | inputs = inputs.div(self.div_value) 33 | for t, m, s in zip(inputs, self.mean, self.std): 34 | t.sub_(m).div_(s) 35 | 36 | return inputs 37 | 38 | 39 | class DeNormalize(object): 40 | """DeNormalize a ``torch.tensor`` 41 | 42 | Args: 43 | inputs (torch.tensor): tensor to be normalized. 44 | mean: (list): the mean of RGB 45 | std: (list): the std of RGB 46 | 47 | Returns: 48 | Tensor: Normalized tensor. 49 | """ 50 | def __init__(self, div_value, mean, std): 51 | self.div_value = div_value 52 | self.mean = mean 53 | self.std =std 54 | 55 | def __call__(self, inputs): 56 | result = inputs.clone() 57 | for i in range(result.size(0)): 58 | result[i, :, :] = result[i, :, :] * self.std[i] + self.mean[i] 59 | 60 | return result.mul_(self.div_value) 61 | 62 | 63 | class ToTensor(object): 64 | """Convert a ``numpy.ndarray or Image`` to tensor. 65 | 66 | See ``ToTensor`` for more details. 67 | 68 | Args: 69 | inputs (numpy.ndarray or Image): Image to be converted to tensor. 70 | 71 | Returns: 72 | Tensor: Converted image. 73 | """ 74 | def __call__(self, inputs): 75 | if isinstance(inputs, Image.Image): 76 | channels = len(inputs.mode) 77 | inputs = np.array(inputs) 78 | inputs = inputs.reshape(inputs.shape[0], inputs.shape[1], channels) 79 | inputs = torch.from_numpy(inputs.transpose(2, 0, 1)) 80 | else: 81 | inputs = torch.from_numpy(inputs.transpose(2, 0, 1)) 82 | 83 | return inputs.float() 84 | 85 | 86 | class ToLabel(object): 87 | def __call__(self, inputs): 88 | return torch.from_numpy(np.array(inputs)).long() 89 | 90 | 91 | class ReLabel(object): 92 | """ 93 | 255 indicate the background, relabel 255 to some value. 
94 | """ 95 | def __init__(self, olabel, nlabel): 96 | self.olabel = olabel 97 | self.nlabel = nlabel 98 | 99 | def __call__(self, inputs): 100 | assert isinstance(inputs, torch.LongTensor), 'tensor needs to be LongTensor' 101 | 102 | inputs[inputs == self.olabel] = self.nlabel 103 | return inputs 104 | 105 | 106 | class Compose(object): 107 | 108 | def __init__(self, transforms): 109 | self.transforms = transforms 110 | 111 | def __call__(self, inputs): 112 | for t in self.transforms: 113 | inputs = t(inputs) 114 | 115 | return inputs 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /dataset/weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from PIL import Image 5 | 6 | # data_root = '../data/CCF/Segmentations/' 7 | # fid = open('./CCF/train_id.txt', 'r') 8 | # num_cls = 18 9 | 10 | 11 | # data_root = '../data/Person/SegmentationPart/' 12 | # fid = open('./Pascal/train_id.txt', 'r') 13 | # num_cls = 7 14 | data_root = '../data/LIP/train_set/segmentations/' 15 | fid = open('./LIP/train_id.txt', 'r') 16 | num_cls = 20 17 | 18 | cls_pix_num = np.zeros(num_cls) 19 | cls_hbody_num = np.zeros(3) 20 | cls_fbody_num = np.zeros(2) 21 | 22 | map_idx = [0, 9, 19, 29, 50, 39, 60, 62] 23 | 24 | for line in fid.readlines(): 25 | img_path = os.path.join(data_root, line.strip() + '.png') 26 | # img_data = np.asarray(Image.open(img_path).convert('L')) 27 | img_data = np.array(Image.open(img_path)) 28 | # for i in range(len(map_idx)): 29 | # img_data[img_data == map_idx[i]] = i 30 | # img_size = img_data.size 31 | for i in range(num_cls): 32 | cls_pix_num[i] += (img_data == i).astype(int).sum(axis=None) 33 | 34 | # # half body 35 | # cls_hbody_num[0] = cls_pix_num[0] 36 | # for i in range(1, 5): 37 | # cls_hbody_num[1] += cls_pix_num[i] 38 | # for i in range(5, 8): 39 | # cls_hbody_num[2] += cls_pix_num[i] 40 | # 41 | # # full body 42 | # cls_fbody_num[0] = cls_pix_num[0] 43 | # for i in range(1, 8): 44 | # cls_fbody_num[1] += cls_pix_num[i] 45 | 46 | weight = np.log(cls_pix_num) 47 | weight_norm = np.zeros(num_cls) 48 | for i in range(num_cls): 49 | weight_norm[i] = 16 / weight[i] 50 | print(weight_norm) 51 | 52 | 53 | # [0.8373, 0.918, 0.866, 1.0345, 1.0166, 54 | # 0.9969, 0.9754, 1.0489, 0.8786, 1.0023, 55 | # 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 56 | # 1.0955, 1.0865, 1.1529, 1.0507] 57 | 58 | # 0.93237515, 1.01116892, 1.11201307 59 | 60 | # 0.98417377, 1.05657165 61 | 62 | # ATR training 63 | # [0.85978634, 1.19630769, 1.02639146, 1.30664970, 0.97220603, 1.04885815, 64 | # 1.01745278, 1.01481690, 1.27155077, 1.12947663, 1.13016390, 1.06514227, 65 | # 1.08384483, 1.08506841, 1.09560942, 1.09565198, 1.07504567, 1.20411509] 66 | 67 | #CCF 68 | # [0.82073458, 1.23651165, 1.0366326, 0.97076566, 1.2802332, 0.98860602, 69 | # 1.29035071, 1.03882453, 0.96725283, 1.05142434, 1.0075884, 0.98630539, 70 | # 1.06208869, 1.0160915, 1.1613597, 1.17624919, 1.1701143, 1.24720215] 71 | 72 | #PPSS 73 | # [0.89680465, 1.14352656, 1.20982646, 0.99269248, 74 | # 1.17911144, 1.00641032, 1.47017195, 1.16447113] 75 | 76 | #Pascal 77 | # [0.82877791, 0.95688253, 0.94921949, 1.00538108, 1.0201687, 1.01665831, 1.05470914] 78 | 79 | #Lip 80 | # [0.7602572, 0.94236198, 0.85644457, 1.04346266, 1.10627293, 0.80980162, 81 | # 0.95168713, 0.8403769, 1.05798412, 0.85746254, 1.01274366, 1.05854692, 82 | # 1.03430773, 0.84867818, 0.88027721, 0.87580925, 0.98747462, 0.9876475, 83 | # 
1.00016535, 1.00108882] 84 | -------------------------------------------------------------------------------- /doc/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/doc/architecture.png -------------------------------------------------------------------------------- /evaluate_pascal.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import cv2 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import Variable 10 | from torch.utils import data 11 | 12 | from dataset.data_pascal import TestGenerator 13 | from network.baseline import get_model 14 | 15 | 16 | def get_arguments(): 17 | """Parse all the arguments provided from the CLI. 18 | 19 | Returns: 20 | A list of parsed arguments. 21 | """ 22 | parser = argparse.ArgumentParser(description="Pytorch Segmentation") 23 | parser.add_argument('--root', default='./data/Person', type=str) 24 | parser.add_argument("--data-list", type=str, default='./dataset/Pascal/val_id.txt') 25 | parser.add_argument("--crop-size", type=int, default=473) 26 | parser.add_argument("--num-classes", type=int, default=7) 27 | parser.add_argument("--ignore-label", type=int, default=255) 28 | parser.add_argument('--restore-from', default='./checkpoints/exp/baseline_pascal.pth', type=str) 29 | 30 | parser.add_argument("--is-mirror", action="store_true") 31 | parser.add_argument("--ms", action="store_true") 32 | # parser.add_argument('--eval-scale', nargs='+', type=float, default=[1.0]) 33 | # parser.add_argument('--eval-scale', nargs='+', type=float, default=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75]) 34 | 35 | parser.add_argument("--save-dir", type=str) 36 | parser.add_argument("--gpu", type=str, default='0') 37 | return parser.parse_args() 38 | 39 | 40 | def main(): 41 | """Create the model and start the evaluation process.""" 42 | args = get_arguments() 43 | 44 | # initialization 45 | print("Input arguments:") 46 | for key, val in vars(args).items(): 47 | print("{:16} {}".format(key, val)) 48 | 49 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 50 | 51 | model = get_model(num_classes=args.num_classes) 52 | 53 | # if not os.path.exists(args.save_dir): 54 | # os.makedirs(args.save_dir) 55 | 56 | palette = get_palette() 57 | restore_from = args.restore_from 58 | saved_state_dict = torch.load(restore_from) 59 | model.load_state_dict(saved_state_dict) 60 | 61 | model.eval() 62 | model.cuda() 63 | 64 | testloader = data.DataLoader(TestGenerator(args.root, args.data_list, crop_size=args.crop_size), 65 | batch_size=1, shuffle=False, pin_memory=True) 66 | 67 | confusion_matrix = np.zeros((args.num_classes, args.num_classes)) 68 | 69 | for index, batch in enumerate(testloader): 70 | if index % 100 == 0: 71 | print('%d images have been proceeded' % index) 72 | image, label, ori_size, name = batch 73 | 74 | ori_size = ori_size[0].numpy() 75 | if args.ms: 76 | eval_scale=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75] 77 | else: 78 | eval_scale=[1.0] 79 | 80 | output = predict(model, image.numpy(), (np.asscalar(ori_size[0]), np.asscalar(ori_size[1])), 81 | is_mirror=args.is_mirror, scales=eval_scale) 82 | seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8) 83 | 84 | # output_im = PILImage.fromarray(seg_pred) 85 | # output_im.putpalette(palette) 86 | # output_im.save(args.save_dir + 
name[0] + '.png') 87 | 88 | seg_gt = np.asarray(label[0].numpy(), dtype=np.int64) 89 | ignore_index = seg_gt != 255 90 | seg_gt = seg_gt[ignore_index] 91 | seg_pred = seg_pred[ignore_index] 92 | 93 | confusion_matrix += get_confusion_matrix(seg_gt, seg_pred, args.num_classes) 94 | 95 | pos = confusion_matrix.sum(1) 96 | res = confusion_matrix.sum(0) 97 | tp = np.diag(confusion_matrix) 98 | 99 | pixel_accuracy = tp.sum() / pos.sum() 100 | mean_accuracy = (tp / np.maximum(1.0, pos)).mean() 101 | IU_array = (tp / np.maximum(1.0, pos + res - tp)) 102 | mean_IU = IU_array.mean() 103 | 104 | # get_confusion_matrix_plot() 105 | 106 | print('Pixel accuracy: %f \n' % pixel_accuracy) 107 | print('Mean accuracy: %f \n' % mean_accuracy) 108 | print('Mean IU: %f \n' % mean_IU) 109 | for index, IU in enumerate(IU_array): 110 | print('%f' % IU) 111 | 112 | 113 | def scale_image(image, scale): 114 | image = image[0, :, :, :] 115 | image = image.transpose((1, 2, 0)) 116 | image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) 117 | image = image.transpose((2, 0, 1)) 118 | return image 119 | 120 | 121 | def predict(net, image, output_size, is_mirror=True, scales=[1]): 122 | if is_mirror: 123 | image_rev = image[:, :, :, ::-1] 124 | 125 | interp = nn.Upsample(size=output_size, mode='bilinear', align_corners=True) 126 | 127 | outputs = [] 128 | if is_mirror: 129 | for scale in scales: 130 | if scale != 1: 131 | image_scale = scale_image(image=image, scale=scale) 132 | image_rev_scale = scale_image(image=image_rev, scale=scale) 133 | else: 134 | image_scale = image[0, :, :, :] 135 | image_rev_scale = image_rev[0, :, :, :] 136 | 137 | image_scale = np.stack((image_scale, image_rev_scale)) 138 | 139 | with torch.no_grad(): 140 | prediction = net(Variable(torch.from_numpy(image_scale)).cuda()) 141 | prediction = interp(prediction[0]).cpu().data.numpy() 142 | 143 | prediction_rev = prediction[1, :, :, :].copy() 144 | prediction_rev = prediction_rev[:, :, ::-1] 145 | prediction = prediction[0, :, :, :] 146 | prediction = np.mean([prediction, prediction_rev], axis=0) 147 | 148 | outputs.append(prediction) 149 | 150 | outputs = np.mean(outputs, axis=0) 151 | outputs = outputs.transpose(1, 2, 0) 152 | else: 153 | for scale in scales: 154 | if scale != 1: 155 | image_scale = scale_image(image=image, scale=scale) 156 | else: 157 | image_scale = image[0, :, :, :] 158 | 159 | with torch.no_grad(): 160 | prediction = net(Variable(torch.from_numpy(image_scale).unsqueeze(0)).cuda()) 161 | prediction = interp(prediction[0]).cpu().data.numpy() 162 | outputs.append(prediction[0, :, :, :]) 163 | 164 | outputs = np.mean(outputs, axis=0) 165 | outputs = outputs.transpose(1, 2, 0) 166 | 167 | return outputs 168 | 169 | 170 | def get_confusion_matrix(gt_label, pred_label, class_num): 171 | """ 172 | Compute the confusion matrix for the given ground-truth and predicted labels 173 | :param gt_label: the ground truth label 174 | :param pred_label: the predicted label 175 | :param class_num: the number of classes 176 | """ 177 | index = (gt_label * class_num + pred_label).astype('int32') 178 | label_count = np.bincount(index) 179 | confusion_matrix = np.zeros((class_num, class_num)) 180 | 181 | for i_label in range(class_num): 182 | for i_pred_label in range(class_num): 183 | cur_index = i_label * class_num + i_pred_label 184 | if cur_index < len(label_count): 185 | confusion_matrix[i_label, i_pred_label] = label_count[cur_index] 186 | 187 | return confusion_matrix 188 | 189 | 190 | def get_confusion_matrix_plot(conf_arr): 191 | norm_conf
= [] 192 | for i in conf_arr: 193 | tmp_arr = [] 194 | a = sum(i, 0) 195 | for j in i: 196 | tmp_arr.append(float(j) / max(1.0, float(a))) 197 | norm_conf.append(tmp_arr) 198 | 199 | fig = plt.figure() 200 | plt.clf() 201 | ax = fig.add_subplot(111) 202 | ax.set_aspect(1) 203 | res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, interpolation='nearest') 204 | 205 | width, height = conf_arr.shape 206 | 207 | cb = fig.colorbar(res) 208 | alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 209 | plt.xticks(range(width), alphabet[:width]) 210 | plt.yticks(range(height), alphabet[:height]) 211 | plt.savefig('confusion_matrix.png', format='png') 212 | 213 | 214 | def get_palette(): 215 | palette = [0, 0, 0, 216 | 128, 0, 0, 217 | 0, 128, 0, 218 | 128, 128, 0, 219 | 0, 0, 128, 220 | 128, 0, 128, 221 | 0, 128, 128] 222 | return palette 223 | 224 | if __name__ == '__main__': 225 | main() 226 | -------------------------------------------------------------------------------- /evaluate_pascal.sh: -------------------------------------------------------------------------------- 1 | python evaluate_pascal.py --root ./data/Person --data-list ./dataset/Pascal/val_id.txt --crop-size 473 --restore-from [checkpoint path] --ms -------------------------------------------------------------------------------- /inplace_abn/__init__.py: -------------------------------------------------------------------------------- 1 | from .bn import ABN, InPlaceABN, InPlaceABNSync 2 | from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE 3 | -------------------------------------------------------------------------------- /inplace_abn/bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as functional 4 | 5 | try: 6 | from queue import Queue 7 | except ImportError: 8 | from Queue import Queue 9 | 10 | from .functions import * 11 | 12 | 13 | class ABN(nn.Module): 14 | """Activated Batch Normalization 15 | 16 | This gathers a `BatchNorm2d` and an activation function in a single module 17 | """ 18 | 19 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 20 | """Creates an Activated Batch Normalization module 21 | 22 | Parameters 23 | ---------- 24 | num_features : int 25 | Number of feature channels in the input and output. 26 | eps : float 27 | Small constant to prevent numerical issues. 28 | momentum : float 29 | Momentum factor applied to compute running statistics as. 30 | affine : bool 31 | If `True` apply learned scale and shift transformation after normalization. 32 | activation : str 33 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 34 | slope : float 35 | Negative slope for the `leaky_relu` activation. 
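        Example (illustrative sketch, not part of the original docstring):
            >>> abn = ABN(64, activation="leaky_relu", slope=0.01)
            >>> y = abn(torch.randn(2, 64, 32, 32))  # BN and activation in one call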
36 | """ 37 | super(ABN, self).__init__() 38 | self.num_features = num_features 39 | self.affine = affine 40 | self.eps = eps 41 | self.momentum = momentum 42 | self.activation = activation 43 | self.slope = slope 44 | if self.affine: 45 | self.weight = nn.Parameter(torch.ones(num_features)) 46 | self.bias = nn.Parameter(torch.zeros(num_features)) 47 | else: 48 | self.register_parameter('weight', None) 49 | self.register_parameter('bias', None) 50 | self.register_buffer('running_mean', torch.zeros(num_features)) 51 | self.register_buffer('running_var', torch.ones(num_features)) 52 | self.reset_parameters() 53 | 54 | def reset_parameters(self): 55 | nn.init.constant_(self.running_mean, 0) 56 | nn.init.constant_(self.running_var, 1) 57 | if self.affine: 58 | nn.init.constant_(self.weight, 1) 59 | nn.init.constant_(self.bias, 0) 60 | 61 | def forward(self, x): 62 | x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 63 | self.training, self.momentum, self.eps) 64 | 65 | if self.activation == ACT_RELU: 66 | return functional.relu(x, inplace=True) 67 | elif self.activation == ACT_LEAKY_RELU: 68 | return functional.leaky_relu(x, negative_slope=self.slope, inplace=True) 69 | elif self.activation == ACT_ELU: 70 | return functional.elu(x, inplace=True) 71 | else: 72 | return x 73 | 74 | def __repr__(self): 75 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 76 | ' affine={affine}, activation={activation}' 77 | if self.activation == "leaky_relu": 78 | rep += ', slope={slope})' 79 | else: 80 | rep += ')' 81 | return rep.format(name=self.__class__.__name__, **self.__dict__) 82 | 83 | 84 | class InPlaceABN(ABN): 85 | """InPlace Activated Batch Normalization""" 86 | 87 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 88 | """Creates an InPlace Activated Batch Normalization module 89 | 90 | Parameters 91 | ---------- 92 | num_features : int 93 | Number of feature channels in the input and output. 94 | eps : float 95 | Small constant to prevent numerical issues. 96 | momentum : float 97 | Momentum factor applied to compute running statistics as. 98 | affine : bool 99 | If `True` apply learned scale and shift transformation after normalization. 100 | activation : str 101 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 102 | slope : float 103 | Negative slope for the `leaky_relu` activation. 104 | """ 105 | super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope) 106 | 107 | def forward(self, x): 108 | return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var, 109 | self.training, self.momentum, self.eps, self.activation, self.slope) 110 | 111 | 112 | class InPlaceABNSync(ABN): 113 | """InPlace Activated Batch Normalization with cross-GPU synchronization 114 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`. 
115 | """ 116 | 117 | def forward(self, x): 118 | return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var, 119 | self.training, self.momentum, self.eps, self.activation, self.slope) 120 | 121 | def __repr__(self): 122 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 123 | ' affine={affine}, activation={activation}' 124 | if self.activation == "leaky_relu": 125 | rep += ', slope={slope})' 126 | else: 127 | rep += ')' 128 | return rep.format(name=self.__class__.__name__, **self.__dict__) 129 | 130 | 131 | -------------------------------------------------------------------------------- /inplace_abn/functions.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import torch 3 | import torch.distributed as dist 4 | import torch.autograd as autograd 5 | import torch.cuda.comm as comm 6 | from torch.autograd.function import once_differentiable 7 | from torch.utils.cpp_extension import load 8 | 9 | _src_path = path.join(path.dirname(path.abspath(__file__)), "src") 10 | _backend = load(name="inplace_abn", 11 | extra_cflags=["-O3"], 12 | sources=[path.join(_src_path, f) for f in [ 13 | "inplace_abn.cpp", 14 | "inplace_abn_cpu.cpp", 15 | "inplace_abn_cuda.cu", 16 | "inplace_abn_cuda_half.cu" 17 | ]], 18 | extra_cuda_cflags=["--expt-extended-lambda"]) 19 | 20 | # Activation names 21 | ACT_RELU = "relu" 22 | ACT_LEAKY_RELU = "leaky_relu" 23 | ACT_ELU = "elu" 24 | ACT_NONE = "none" 25 | 26 | 27 | def _check(fn, *args, **kwargs): 28 | success = fn(*args, **kwargs) 29 | if not success: 30 | raise RuntimeError("CUDA Error encountered in {}".format(fn)) 31 | 32 | 33 | def _broadcast_shape(x): 34 | out_size = [] 35 | for i, s in enumerate(x.size()): 36 | if i != 1: 37 | out_size.append(1) 38 | else: 39 | out_size.append(s) 40 | return out_size 41 | 42 | 43 | def _reduce(x): 44 | if len(x.size()) == 2: 45 | return x.sum(dim=0) 46 | else: 47 | n, c = x.size()[0:2] 48 | return x.contiguous().view((n, c, -1)).sum(2).sum(0) 49 | 50 | 51 | def _count_samples(x): 52 | count = 1 53 | for i, s in enumerate(x.size()): 54 | if i != 1: 55 | count *= s 56 | return count 57 | 58 | 59 | def _act_forward(ctx, x): 60 | if ctx.activation == ACT_LEAKY_RELU: 61 | _backend.leaky_relu_forward(x, ctx.slope) 62 | elif ctx.activation == ACT_ELU: 63 | _backend.elu_forward(x) 64 | elif ctx.activation == ACT_NONE: 65 | pass 66 | 67 | 68 | def _act_backward(ctx, x, dx): 69 | if ctx.activation == ACT_LEAKY_RELU: 70 | _backend.leaky_relu_backward(x, dx, ctx.slope) 71 | elif ctx.activation == ACT_ELU: 72 | _backend.elu_backward(x, dx) 73 | elif ctx.activation == ACT_NONE: 74 | pass 75 | 76 | 77 | class InPlaceABN(autograd.Function): 78 | @staticmethod 79 | def forward(ctx, x, weight, bias, running_mean, running_var, 80 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): 81 | # Save context 82 | ctx.training = training 83 | ctx.momentum = momentum 84 | ctx.eps = eps 85 | ctx.activation = activation 86 | ctx.slope = slope 87 | ctx.affine = weight is not None and bias is not None 88 | 89 | # Prepare inputs 90 | count = _count_samples(x) 91 | x = x.contiguous() 92 | weight = weight.contiguous() if ctx.affine else x.new_empty(0, dtype=torch.float32) 93 | bias = bias.contiguous() if ctx.affine else x.new_empty(0, dtype=torch.float32) 94 | 95 | if ctx.training: 96 | mean, var = _backend.mean_var(x) 97 | 98 | # Update running stats 99 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 100 | 
running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1)) 101 | 102 | # Mark in-place modified tensors 103 | ctx.mark_dirty(x, running_mean, running_var) 104 | else: 105 | mean, var = running_mean.contiguous(), running_var.contiguous() 106 | ctx.mark_dirty(x) 107 | 108 | # BN forward + activation 109 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) 110 | _act_forward(ctx, x) 111 | 112 | # Output 113 | ctx.var = var 114 | ctx.save_for_backward(x, var, weight, bias) 115 | return x 116 | 117 | @staticmethod 118 | @once_differentiable 119 | def backward(ctx, dz): 120 | z, var, weight, bias = ctx.saved_tensors 121 | dz = dz.contiguous() 122 | 123 | # Undo activation 124 | _act_backward(ctx, z, dz) 125 | 126 | if ctx.training: 127 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) 128 | else: 129 | # TODO: implement simplified CUDA backward for inference mode 130 | edz = dz.new_zeros(dz.size(1)) 131 | eydz = dz.new_zeros(dz.size(1)) 132 | 133 | dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) 134 | # dweight = eydz * weight.sign() if ctx.affine else None 135 | dweight = eydz if ctx.affine else None 136 | if dweight is not None: 137 | dweight[weight < 0] *= -1 138 | dbias = edz if ctx.affine else None 139 | 140 | return dx, dweight, dbias, None, None, None, None, None, None, None 141 | 142 | 143 | class InPlaceABNSync(autograd.Function): 144 | @classmethod 145 | def forward(cls, ctx, x, weight, bias, running_mean, running_var, 146 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True): 147 | # Save context 148 | ctx.training = training 149 | ctx.momentum = momentum 150 | ctx.eps = eps 151 | ctx.activation = activation 152 | ctx.slope = slope 153 | ctx.affine = weight is not None and bias is not None 154 | 155 | # Prepare inputs 156 | ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1 157 | 158 | # count = _count_samples(x) 159 | batch_size = x.new_tensor([x.shape[0]], dtype=torch.long) 160 | 161 | x = x.contiguous() 162 | weight = weight.contiguous() if ctx.affine else x.new_empty(0, dtype=torch.float32) 163 | bias = bias.contiguous() if ctx.affine else x.new_empty(0, dtype=torch.float32) 164 | 165 | if ctx.training: 166 | mean, var = _backend.mean_var(x) 167 | if ctx.world_size > 1: 168 | # get global batch size 169 | if equal_batches: 170 | batch_size *= ctx.world_size 171 | else: 172 | dist.all_reduce(batch_size, dist.ReduceOp.SUM) 173 | 174 | ctx.factor = x.shape[0] / float(batch_size.item()) 175 | 176 | mean_all = mean.clone() * ctx.factor 177 | dist.all_reduce(mean_all, dist.ReduceOp.SUM) 178 | 179 | var_all = (var + (mean - mean_all) ** 2) * ctx.factor 180 | dist.all_reduce(var_all, dist.ReduceOp.SUM) 181 | 182 | mean = mean_all 183 | var = var_all 184 | 185 | # Update running stats 186 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 187 | count = batch_size.item() * x.view(x.shape[0], x.shape[1], -1).shape[-1] 188 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1))) 189 | 190 | # Mark in-place modified tensors 191 | ctx.mark_dirty(x, running_mean, running_var) 192 | else: 193 | mean, var = running_mean.contiguous(), running_var.contiguous() 194 | ctx.mark_dirty(x) 195 | 196 | # BN forward + activation 197 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) 198 | _act_forward(ctx, x) 199 | 200 | # Output 201 | ctx.var = var 202 | 
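        # Only the already-activated output x, the variance, and the affine
        # parameters are saved; backward reconstructs the pre-activation values
        # by inverting the activation, which is what allows the in-place update.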
ctx.save_for_backward(x, var, weight, bias) 203 | return x 204 | 205 | @staticmethod 206 | @once_differentiable 207 | def backward(ctx, dz): 208 | z, var, weight, bias = ctx.saved_tensors 209 | dz = dz.contiguous() 210 | 211 | # Undo activation 212 | _act_backward(ctx, z, dz) 213 | 214 | if ctx.training: 215 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) 216 | edz_local = edz.clone() 217 | eydz_local = eydz.clone() 218 | 219 | if ctx.world_size > 1: 220 | edz *= ctx.factor 221 | dist.all_reduce(edz, dist.ReduceOp.SUM) 222 | 223 | eydz *= ctx.factor 224 | dist.all_reduce(eydz, dist.ReduceOp.SUM) 225 | else: 226 | edz_local = edz = dz.new_zeros(dz.size(1)) 227 | eydz_local = eydz = dz.new_zeros(dz.size(1)) 228 | 229 | dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) 230 | # dweight = eydz_local * weight.sign() if ctx.affine else None 231 | dweight = eydz_local if ctx.affine else None 232 | if dweight is not None: 233 | dweight[weight < 0] *= -1 234 | dbias = edz_local if ctx.affine else None 235 | 236 | return dx, dweight, dbias, None, None, None, None, None, None, None 237 | 238 | 239 | inplace_abn = InPlaceABN.apply 240 | inplace_abn_sync = InPlaceABNSync.apply 241 | 242 | __all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"] 243 | -------------------------------------------------------------------------------- /inplace_abn/src/checks.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT 6 | #ifndef AT_CHECK 7 | #define AT_CHECK AT_ASSERT 8 | #endif 9 | 10 | #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor") 12 | #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous") 13 | 14 | #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 15 | #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x) -------------------------------------------------------------------------------- /inplace_abn/src/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | /* 6 | * General settings 7 | */ 8 | const int WARP_SIZE = 32; 9 | const int MAX_BLOCK_SIZE = 512; 10 | 11 | template 12 | struct Pair { 13 | T v1, v2; 14 | __device__ Pair() {} 15 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {} 16 | __device__ Pair(T v) : v1(v), v2(v) {} 17 | __device__ Pair(int v) : v1(v), v2(v) {} 18 | __device__ Pair &operator+=(const Pair &a) { 19 | v1 += a.v1; 20 | v2 += a.v2; 21 | return *this; 22 | } 23 | }; 24 | 25 | /* 26 | * Utility functions 27 | */ 28 | template 29 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, 30 | unsigned int mask = 0xffffffff) { 31 | #if CUDART_VERSION >= 9000 32 | return __shfl_xor_sync(mask, value, laneMask, width); 33 | #else 34 | return __shfl_xor(value, laneMask, width); 35 | #endif 36 | } 37 | 38 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } 39 | 40 | static int getNumThreads(int nElem) { 41 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE}; 42 | for (int i = 0; i != 5; ++i) { 43 | if (nElem <= threadSizes[i]) { 44 | return threadSizes[i]; 45 | } 46 | } 47 | return MAX_BLOCK_SIZE; 48 | } 49 | 50 | template 51 | static 
__device__ __forceinline__ T warpSum(T val) { 52 | #if __CUDA_ARCH__ >= 300 53 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 54 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); 55 | } 56 | #else 57 | __shared__ T values[MAX_BLOCK_SIZE]; 58 | values[threadIdx.x] = val; 59 | __threadfence_block(); 60 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; 61 | for (int i = 1; i < WARP_SIZE; i++) { 62 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; 63 | } 64 | #endif 65 | return val; 66 | } 67 | 68 | template 69 | static __device__ __forceinline__ Pair warpSum(Pair value) { 70 | value.v1 = warpSum(value.v1); 71 | value.v2 = warpSum(value.v2); 72 | return value; 73 | } 74 | 75 | template 76 | __device__ T reduce(Op op, int plane, int N, int C, int S) { 77 | T sum = (T)0; 78 | for (int batch = 0; batch < N; ++batch) { 79 | for (int x = threadIdx.x; x < S; x += blockDim.x) { 80 | sum += op(batch, plane, x); 81 | } 82 | } 83 | 84 | // sum over NumThreads within a warp 85 | sum = warpSum(sum); 86 | 87 | // 'transpose', and reduce within warp again 88 | __shared__ T shared[32]; 89 | __syncthreads(); 90 | if (threadIdx.x % WARP_SIZE == 0) { 91 | shared[threadIdx.x / WARP_SIZE] = sum; 92 | } 93 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { 94 | // zero out the other entries in shared 95 | shared[threadIdx.x] = (T)0; 96 | } 97 | __syncthreads(); 98 | if (threadIdx.x / WARP_SIZE == 0) { 99 | sum = warpSum(shared[threadIdx.x]); 100 | if (threadIdx.x == 0) { 101 | shared[0] = sum; 102 | } 103 | } 104 | __syncthreads(); 105 | 106 | // Everyone picks it up, should be broadcast into the whole gradInput 107 | return shared[0]; 108 | } -------------------------------------------------------------------------------- /inplace_abn/src/inplace_abn.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "inplace_abn.h" 6 | 7 | std::vector mean_var(at::Tensor x) { 8 | if (x.is_cuda()) { 9 | if (x.type().scalarType() == at::ScalarType::Half) { 10 | return mean_var_cuda_h(x); 11 | } else { 12 | return mean_var_cuda(x); 13 | } 14 | } else { 15 | return mean_var_cpu(x); 16 | } 17 | } 18 | 19 | at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 20 | bool affine, float eps) { 21 | if (x.is_cuda()) { 22 | if (x.type().scalarType() == at::ScalarType::Half) { 23 | return forward_cuda_h(x, mean, var, weight, bias, affine, eps); 24 | } else { 25 | return forward_cuda(x, mean, var, weight, bias, affine, eps); 26 | } 27 | } else { 28 | return forward_cpu(x, mean, var, weight, bias, affine, eps); 29 | } 30 | } 31 | 32 | std::vector edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 33 | bool affine, float eps) { 34 | if (z.is_cuda()) { 35 | if (z.type().scalarType() == at::ScalarType::Half) { 36 | return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps); 37 | } else { 38 | return edz_eydz_cuda(z, dz, weight, bias, affine, eps); 39 | } 40 | } else { 41 | return edz_eydz_cpu(z, dz, weight, bias, affine, eps); 42 | } 43 | } 44 | 45 | at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 46 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) { 47 | if (z.is_cuda()) { 48 | if (z.type().scalarType() == at::ScalarType::Half) { 49 | return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps); 50 | } else { 51 | return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps); 
52 | } 53 | } else { 54 | return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps); 55 | } 56 | } 57 | 58 | void leaky_relu_forward(at::Tensor z, float slope) { 59 | at::leaky_relu_(z, slope); 60 | } 61 | 62 | void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) { 63 | if (z.is_cuda()) { 64 | if (z.type().scalarType() == at::ScalarType::Half) { 65 | return leaky_relu_backward_cuda_h(z, dz, slope); 66 | } else { 67 | return leaky_relu_backward_cuda(z, dz, slope); 68 | } 69 | } else { 70 | return leaky_relu_backward_cpu(z, dz, slope); 71 | } 72 | } 73 | 74 | void elu_forward(at::Tensor z) { 75 | at::elu_(z); 76 | } 77 | 78 | void elu_backward(at::Tensor z, at::Tensor dz) { 79 | if (z.is_cuda()) { 80 | return elu_backward_cuda(z, dz); 81 | } else { 82 | return elu_backward_cpu(z, dz); 83 | } 84 | } 85 | 86 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 87 | m.def("mean_var", &mean_var, "Mean and variance computation"); 88 | m.def("forward", &forward, "In-place forward computation"); 89 | m.def("edz_eydz", &edz_eydz, "First part of backward computation"); 90 | m.def("backward", &backward, "Second part of backward computation"); 91 | m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation"); 92 | m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion"); 93 | m.def("elu_forward", &elu_forward, "Elu forward computation"); 94 | m.def("elu_backward", &elu_backward, "Elu backward computation and inversion"); 95 | } 96 | -------------------------------------------------------------------------------- /inplace_abn/src/inplace_abn.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | std::vector mean_var_cpu(at::Tensor x); 8 | std::vector mean_var_cuda(at::Tensor x); 9 | std::vector mean_var_cuda_h(at::Tensor x); 10 | 11 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 12 | bool affine, float eps); 13 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 14 | bool affine, float eps); 15 | at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 16 | bool affine, float eps); 17 | 18 | std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 19 | bool affine, float eps); 20 | std::vector edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 21 | bool affine, float eps); 22 | std::vector edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 23 | bool affine, float eps); 24 | 25 | at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 26 | at::Tensor edz, at::Tensor eydz, bool affine, float eps); 27 | at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 28 | at::Tensor edz, at::Tensor eydz, bool affine, float eps); 29 | at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 30 | at::Tensor edz, at::Tensor eydz, bool affine, float eps); 31 | 32 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope); 33 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope); 34 | void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope); 35 | 36 | void elu_backward_cpu(at::Tensor z, at::Tensor 
dz); 37 | void elu_backward_cuda(at::Tensor z, at::Tensor dz); 38 | 39 | static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) { 40 | num = x.size(0); 41 | chn = x.size(1); 42 | sp = 1; 43 | for (int64_t i = 2; i < x.ndimension(); ++i) 44 | sp *= x.size(i); 45 | } 46 | 47 | /* 48 | * Specialized CUDA reduction functions for BN 49 | */ 50 | #ifdef __CUDACC__ 51 | 52 | #include "utils/cuda.cuh" 53 | 54 | template 55 | __device__ T reduce(Op op, int plane, int N, int S) { 56 | T sum = (T)0; 57 | for (int batch = 0; batch < N; ++batch) { 58 | for (int x = threadIdx.x; x < S; x += blockDim.x) { 59 | sum += op(batch, plane, x); 60 | } 61 | } 62 | 63 | // sum over NumThreads within a warp 64 | sum = warpSum(sum); 65 | 66 | // 'transpose', and reduce within warp again 67 | __shared__ T shared[32]; 68 | __syncthreads(); 69 | if (threadIdx.x % WARP_SIZE == 0) { 70 | shared[threadIdx.x / WARP_SIZE] = sum; 71 | } 72 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { 73 | // zero out the other entries in shared 74 | shared[threadIdx.x] = (T)0; 75 | } 76 | __syncthreads(); 77 | if (threadIdx.x / WARP_SIZE == 0) { 78 | sum = warpSum(shared[threadIdx.x]); 79 | if (threadIdx.x == 0) { 80 | shared[0] = sum; 81 | } 82 | } 83 | __syncthreads(); 84 | 85 | // Everyone picks it up, should be broadcast into the whole gradInput 86 | return shared[0]; 87 | } 88 | #endif 89 | -------------------------------------------------------------------------------- /inplace_abn/src/inplace_abn_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "utils/checks.h" 6 | #include "inplace_abn.h" 7 | 8 | at::Tensor reduce_sum(at::Tensor x) { 9 | if (x.ndimension() == 2) { 10 | return x.sum(0); 11 | } else { 12 | auto x_view = x.view({x.size(0), x.size(1), -1}); 13 | return x_view.sum(-1).sum(0); 14 | } 15 | } 16 | 17 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) { 18 | if (x.ndimension() == 2) { 19 | return v; 20 | } else { 21 | std::vector broadcast_size = {1, -1}; 22 | for (int64_t i = 2; i < x.ndimension(); ++i) 23 | broadcast_size.push_back(1); 24 | 25 | return v.view(broadcast_size); 26 | } 27 | } 28 | 29 | int64_t count(at::Tensor x) { 30 | int64_t count = x.size(0); 31 | for (int64_t i = 2; i < x.ndimension(); ++i) 32 | count *= x.size(i); 33 | 34 | return count; 35 | } 36 | 37 | at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) { 38 | if (affine) { 39 | return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z); 40 | } else { 41 | return z; 42 | } 43 | } 44 | 45 | std::vector mean_var_cpu(at::Tensor x) { 46 | auto num = count(x); 47 | auto mean = reduce_sum(x) / num; 48 | auto diff = x - broadcast_to(mean, x); 49 | auto var = reduce_sum(diff.pow(2)) / num; 50 | 51 | return {mean, var}; 52 | } 53 | 54 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 55 | bool affine, float eps) { 56 | auto gamma = affine ? 
at::abs(weight) + eps : at::ones_like(var); 57 | auto mul = at::rsqrt(var + eps) * gamma; 58 | 59 | x.sub_(broadcast_to(mean, x)); 60 | x.mul_(broadcast_to(mul, x)); 61 | if (affine) x.add_(broadcast_to(bias, x)); 62 | 63 | return x; 64 | } 65 | 66 | std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 67 | bool affine, float eps) { 68 | auto edz = reduce_sum(dz); 69 | auto y = invert_affine(z, weight, bias, affine, eps); 70 | auto eydz = reduce_sum(y * dz); 71 | 72 | return {edz, eydz}; 73 | } 74 | 75 | at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 76 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) { 77 | auto y = invert_affine(z, weight, bias, affine, eps); 78 | auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps); 79 | 80 | auto num = count(z); 81 | auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz); 82 | return dx; 83 | } 84 | 85 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) { 86 | CHECK_CPU_INPUT(z); 87 | CHECK_CPU_INPUT(dz); 88 | 89 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] { 90 | int64_t count = z.numel(); 91 | auto *_z = z.data(); 92 | auto *_dz = dz.data(); 93 | 94 | for (int64_t i = 0; i < count; ++i) { 95 | if (_z[i] < 0) { 96 | _z[i] *= 1 / slope; 97 | _dz[i] *= slope; 98 | } 99 | } 100 | })); 101 | } 102 | 103 | void elu_backward_cpu(at::Tensor z, at::Tensor dz) { 104 | CHECK_CPU_INPUT(z); 105 | CHECK_CPU_INPUT(dz); 106 | 107 | AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] { 108 | int64_t count = z.numel(); 109 | auto *_z = z.data(); 110 | auto *_dz = dz.data(); 111 | 112 | for (int64_t i = 0; i < count; ++i) { 113 | if (_z[i] < 0) { 114 | _z[i] = log1p(_z[i]); 115 | _dz[i] *= (_z[i] + 1.f); 116 | } 117 | } 118 | })); 119 | } 120 | -------------------------------------------------------------------------------- /inplace_abn/src/inplace_abn_cuda_half.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include "utils/checks.h" 8 | #include "utils/cuda.cuh" 9 | #include "inplace_abn.h" 10 | 11 | #include 12 | 13 | // Operations for reduce 14 | struct SumOpH { 15 | __device__ SumOpH(const half *t, int c, int s) 16 | : tensor(t), chn(c), sp(s) {} 17 | __device__ __forceinline__ float operator()(int batch, int plane, int n) { 18 | return __half2float(tensor[(batch * chn + plane) * sp + n]); 19 | } 20 | const half *tensor; 21 | const int chn; 22 | const int sp; 23 | }; 24 | 25 | struct VarOpH { 26 | __device__ VarOpH(float m, const half *t, int c, int s) 27 | : mean(m), tensor(t), chn(c), sp(s) {} 28 | __device__ __forceinline__ float operator()(int batch, int plane, int n) { 29 | const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]); 30 | return (t - mean) * (t - mean); 31 | } 32 | const float mean; 33 | const half *tensor; 34 | const int chn; 35 | const int sp; 36 | }; 37 | 38 | struct GradOpH { 39 | __device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s) 40 | : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {} 41 | __device__ __forceinline__ Pair operator()(int batch, int plane, int n) { 42 | float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight; 43 | float _dz = __half2float(dz[(batch * chn + plane) * sp + n]); 44 | return Pair(_dz, _y * 
_dz); 45 | } 46 | const float weight; 47 | const float bias; 48 | const half *z; 49 | const half *dz; 50 | const int chn; 51 | const int sp; 52 | }; 53 | 54 | /*********** 55 | * mean_var 56 | ***********/ 57 | 58 | __global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) { 59 | int plane = blockIdx.x; 60 | float norm = 1.f / static_cast(num * sp); 61 | 62 | float _mean = reduce(SumOpH(x, chn, sp), plane, num, sp) * norm; 63 | __syncthreads(); 64 | float _var = reduce(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm; 65 | 66 | if (threadIdx.x == 0) { 67 | mean[plane] = _mean; 68 | var[plane] = _var; 69 | } 70 | } 71 | 72 | std::vector mean_var_cuda_h(at::Tensor x) { 73 | CHECK_CUDA_INPUT(x); 74 | 75 | // Extract dimensions 76 | int64_t num, chn, sp; 77 | get_dims(x, num, chn, sp); 78 | 79 | // Prepare output tensors 80 | auto mean = at::empty({chn},x.options().dtype(at::kFloat)); 81 | auto var = at::empty({chn},x.options().dtype(at::kFloat)); 82 | 83 | // Run kernel 84 | dim3 blocks(chn); 85 | dim3 threads(getNumThreads(sp)); 86 | auto stream = at::cuda::getCurrentCUDAStream(); 87 | mean_var_kernel_h<<>>( 88 | reinterpret_cast(x.data()), 89 | mean.data(), 90 | var.data(), 91 | num, chn, sp); 92 | 93 | return {mean, var}; 94 | } 95 | 96 | /********** 97 | * forward 98 | **********/ 99 | 100 | __global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias, 101 | bool affine, float eps, int num, int chn, int sp) { 102 | int plane = blockIdx.x; 103 | 104 | const float _mean = mean[plane]; 105 | const float _var = var[plane]; 106 | const float _weight = affine ? abs(weight[plane]) + eps : 1.f; 107 | const float _bias = affine ? bias[plane] : 0.f; 108 | 109 | const float mul = rsqrt(_var + eps) * _weight; 110 | 111 | for (int batch = 0; batch < num; ++batch) { 112 | for (int n = threadIdx.x; n < sp; n += blockDim.x) { 113 | half *x_ptr = x + (batch * chn + plane) * sp + n; 114 | float _x = __half2float(*x_ptr); 115 | float _y = (_x - _mean) * mul + _bias; 116 | 117 | *x_ptr = __float2half(_y); 118 | } 119 | } 120 | } 121 | 122 | at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 123 | bool affine, float eps) { 124 | CHECK_CUDA_INPUT(x); 125 | CHECK_CUDA_INPUT(mean); 126 | CHECK_CUDA_INPUT(var); 127 | CHECK_CUDA_INPUT(weight); 128 | CHECK_CUDA_INPUT(bias); 129 | 130 | // Extract dimensions 131 | int64_t num, chn, sp; 132 | get_dims(x, num, chn, sp); 133 | 134 | // Run kernel 135 | dim3 blocks(chn); 136 | dim3 threads(getNumThreads(sp)); 137 | auto stream = at::cuda::getCurrentCUDAStream(); 138 | forward_kernel_h<<>>( 139 | reinterpret_cast(x.data()), 140 | mean.data(), 141 | var.data(), 142 | weight.data(), 143 | bias.data(), 144 | affine, eps, num, chn, sp); 145 | 146 | return x; 147 | } 148 | 149 | __global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias, 150 | float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) { 151 | int plane = blockIdx.x; 152 | 153 | float _weight = affine ? abs(weight[plane]) + eps : 1.f; 154 | float _bias = affine ? 
bias[plane] : 0.f; 155 | 156 | Pair res = reduce, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp); 157 | __syncthreads(); 158 | 159 | if (threadIdx.x == 0) { 160 | edz[plane] = res.v1; 161 | eydz[plane] = res.v2; 162 | } 163 | } 164 | 165 | std::vector edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 166 | bool affine, float eps) { 167 | CHECK_CUDA_INPUT(z); 168 | CHECK_CUDA_INPUT(dz); 169 | CHECK_CUDA_INPUT(weight); 170 | CHECK_CUDA_INPUT(bias); 171 | 172 | // Extract dimensions 173 | int64_t num, chn, sp; 174 | get_dims(z, num, chn, sp); 175 | 176 | auto edz = at::empty({chn},z.options().dtype(at::kFloat)); 177 | auto eydz = at::empty({chn},z.options().dtype(at::kFloat)); 178 | 179 | // Run kernel 180 | dim3 blocks(chn); 181 | dim3 threads(getNumThreads(sp)); 182 | auto stream = at::cuda::getCurrentCUDAStream(); 183 | edz_eydz_kernel_h<<>>( 184 | reinterpret_cast(z.data()), 185 | reinterpret_cast(dz.data()), 186 | weight.data(), 187 | bias.data(), 188 | edz.data(), 189 | eydz.data(), 190 | affine, eps, num, chn, sp); 191 | 192 | return {edz, eydz}; 193 | } 194 | 195 | __global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz, 196 | const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) { 197 | int plane = blockIdx.x; 198 | 199 | float _weight = affine ? abs(weight[plane]) + eps : 1.f; 200 | float _bias = affine ? bias[plane] : 0.f; 201 | float _var = var[plane]; 202 | float _edz = edz[plane]; 203 | float _eydz = eydz[plane]; 204 | 205 | float _mul = _weight * rsqrt(_var + eps); 206 | float count = float(num * sp); 207 | 208 | for (int batch = 0; batch < num; ++batch) { 209 | for (int n = threadIdx.x; n < sp; n += blockDim.x) { 210 | float _dz = __half2float(dz[(batch * chn + plane) * sp + n]); 211 | float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight; 212 | 213 | dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul); 214 | } 215 | } 216 | } 217 | 218 | at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 219 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) { 220 | CHECK_CUDA_INPUT(z); 221 | CHECK_CUDA_INPUT(dz); 222 | CHECK_CUDA_INPUT(var); 223 | CHECK_CUDA_INPUT(weight); 224 | CHECK_CUDA_INPUT(bias); 225 | CHECK_CUDA_INPUT(edz); 226 | CHECK_CUDA_INPUT(eydz); 227 | 228 | // Extract dimensions 229 | int64_t num, chn, sp; 230 | get_dims(z, num, chn, sp); 231 | 232 | auto dx = at::zeros_like(z); 233 | 234 | // Run kernel 235 | dim3 blocks(chn); 236 | dim3 threads(getNumThreads(sp)); 237 | auto stream = at::cuda::getCurrentCUDAStream(); 238 | backward_kernel_h<<>>( 239 | reinterpret_cast(z.data()), 240 | reinterpret_cast(dz.data()), 241 | var.data(), 242 | weight.data(), 243 | bias.data(), 244 | edz.data(), 245 | eydz.data(), 246 | reinterpret_cast(dx.data()), 247 | affine, eps, num, chn, sp); 248 | 249 | return dx; 250 | } 251 | 252 | __global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) { 253 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x){ 254 | float _z = __half2float(z[i]); 255 | if (_z < 0) { 256 | dz[i] = __float2half(__half2float(dz[i]) * slope); 257 | z[i] = __float2half(_z / slope); 258 | } 259 | } 260 | } 261 | 262 | void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) { 263 | 
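    // Grid-stride launch: the kernel rescales dz and simultaneously inverts the
    // leaky ReLU on z in place, so no buffer of pre-activation values is needed.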
262 | void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) {
263 |   CHECK_CUDA_INPUT(z);
264 |   CHECK_CUDA_INPUT(dz);
265 |
266 |   int64_t count = z.numel();
267 |   dim3 threads(getNumThreads(count));
268 |   dim3 blocks = (count + threads.x - 1) / threads.x;
269 |   auto stream = at::cuda::getCurrentCUDAStream();
270 |   leaky_relu_backward_impl_h<<<blocks, threads, 0, stream>>>(
271 |       reinterpret_cast<half*>(z.data<at::Half>()),
272 |       reinterpret_cast<half*>(dz.data<at::Half>()),
273 |       slope, count);
274 | }
275 |
276 |
-------------------------------------------------------------------------------- /inplace_abn/src/utils/checks.h: --------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
6 | #ifndef AT_CHECK
7 |   #define AT_CHECK AT_ASSERT
8 | #endif
9 |
10 | #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
12 | #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
13 |
14 | #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
15 | #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
-------------------------------------------------------------------------------- /inplace_abn/src/utils/common.h: --------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | /*
6 |  * Functions to share code between CPU and GPU
7 |  */
8 |
9 | #ifdef __CUDACC__
10 | // CUDA versions
11 |
12 | #define HOST_DEVICE __host__ __device__
13 | #define INLINE_HOST_DEVICE __host__ __device__ inline
14 | #define FLOOR(x) floor(x)
15 |
16 | #if __CUDA_ARCH__ >= 600
17 | // Recent compute capabilities have block-level atomicAdd for all data types, so we use that
18 | #define ACCUM(x,y) atomicAdd_block(&(x),(y))
19 | #else
20 | // Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float
21 | // and use the known atomicCAS-based implementation for double
22 | template<typename data_t>
23 | __device__ inline data_t atomic_add(data_t *address, data_t val) {
24 |   return atomicAdd(address, val);
25 | }
26 |
27 | template<>
28 | __device__ inline double atomic_add(double *address, double val) {
29 |   unsigned long long int* address_as_ull = (unsigned long long int*)address;
30 |   unsigned long long int old = *address_as_ull, assumed;
31 |   do {
32 |     assumed = old;
33 |     old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
34 |   } while (assumed != old);
35 |   return __longlong_as_double(old);
36 | }
37 |
38 | #define ACCUM(x,y) atomic_add(&(x),(y))
39 | #endif // #if __CUDA_ARCH__ >= 600
40 |
41 | #else
42 | // CPU versions
43 |
44 | #define HOST_DEVICE
45 | #define INLINE_HOST_DEVICE inline
46 | #define FLOOR(x) std::floor(x)
47 | #define ACCUM(x,y) (x) += (y)
48 |
49 | #endif // #ifdef __CUDACC__
-------------------------------------------------------------------------------- /inplace_abn/src/utils/cuda.cuh: --------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | /*
4 |  * General settings and functions
5 |  */
6 | const int WARP_SIZE = 32;
7 | const int MAX_BLOCK_SIZE = 1024;
8 |
9 | static int getNumThreads(int nElem) {
10 |   int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE};
11 |   for (int i = 0; i < 6; ++i) {
12 |     if (nElem <= threadSizes[i]) {
13 |       return threadSizes[i];
14 |     }
15 |   }
16 |   return MAX_BLOCK_SIZE;
17 | }
18 |
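// For example, getNumThreads(200) returns 256 while getNumThreads(5000) caps at
// MAX_BLOCK_SIZE = 1024: each launch uses the smallest power-of-two block size
// (at least one warp) that covers the per-channel spatial extent, so small
// feature maps do not pay for idle threads.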
19 | /*
20 |  * Reduction utilities
21 |  */
22 | template<typename T>
23 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
24 |                                            unsigned int mask = 0xffffffff) {
25 | #if CUDART_VERSION >= 9000
26 |   return __shfl_xor_sync(mask, value, laneMask, width);
27 | #else
28 |   return __shfl_xor(value, laneMask, width);
29 | #endif
30 | }
31 |
32 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
33 |
34 | template<typename T>
35 | struct Pair {
36 |   T v1, v2;
37 |   __device__ Pair() {}
38 |   __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
39 |   __device__ Pair(T v) : v1(v), v2(v) {}
40 |   __device__ Pair(int v) : v1(v), v2(v) {}
41 |   __device__ Pair<T> &operator+=(const Pair<T> &a) {
42 |     v1 += a.v1;
43 |     v2 += a.v2;
44 |     return *this;
45 |   }
46 | };
47 |
48 | template<typename T>
49 | static __device__ __forceinline__ T warpSum(T val) {
50 | #if __CUDA_ARCH__ >= 300
51 |   for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
52 |     val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
53 |   }
54 | #else
55 |   __shared__ T values[MAX_BLOCK_SIZE];
56 |   values[threadIdx.x] = val;
57 |   __threadfence_block();
58 |   const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
59 |   for (int i = 1; i < WARP_SIZE; i++) {
60 |     val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
61 |   }
62 | #endif
63 |   return val;
64 | }
65 |
66 | template<typename T>
67 | static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
68 |   value.v1 = warpSum(value.v1);
69 |   value.v2 = warpSum(value.v2);
70 |   return value;
71 | }
-------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/modules/__init__.py -------------------------------------------------------------------------------- /modules/inits.py: --------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | def uniform(size, tensor):
5 |     bound = 1.0 / math.sqrt(size)
6 |     if tensor is not None:
7 |         tensor.data.uniform_(-bound, bound)
8 |
9 |
10 | def kaiming_uniform(tensor, fan, a):
11 |     if tensor is not None:
12 |         bound = math.sqrt(6 / ((1 + a**2) * fan))
13 |         tensor.data.uniform_(-bound, bound)
14 |
15 |
16 | def glorot(tensor):
17 |     if tensor is not None:
18 |         stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
19 |         tensor.data.uniform_(-stdv, stdv)
20 |
21 |
22 | def zeros(tensor):
23 |     if tensor is not None:
24 |         tensor.data.fill_(0)
25 |
26 |
27 | def ones(tensor):
28 |     if tensor is not None:
29 |         tensor.data.fill_(1)
30 |
31 |
32 | def normal(tensor, mean, std):
33 |     if tensor is not None:
34 |         tensor.data.normal_(mean, std)
35 |
36 |
37 | def reset(nn):
38 |     def _reset(item):
39 |         if hasattr(item, 'reset_parameters'):
40 |             item.reset_parameters()
41 |
42 |     if nn is not None:
43 |         if hasattr(nn, 'children') and len(list(nn.children())) > 0:
44 |             for item in nn.children():
45 |                 _reset(item)
46 |         else:
47 |             _reset(nn)
48 |
-------------------------------------------------------------------------------- /modules/parse_mod.py: --------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from inplace_abn.bn import InPlaceABNSync
8 | from modules.com_mod import SEModule, ContextContrastedModule
9 |
10 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
11 |
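# The ASPP variant defined below augments the standard dilated branches with
# point-wise spatial attention (PSAA): five parallel features (a 1x1 branch,
# three SE-gated 3x3 branches at dilations 6/12/18, and image-level pooling)
# are concatenated with the input, a small conv head predicts a 5-channel
# sigmoid attention map, and each branch is reweighted by its own channel
# before the final 1x1 projection back to out_dim.
#
# A minimal shape sketch (hypothetical usage, assuming a working
# InPlaceABNSync build and a CUDA device):
#
#   aspp = ASPPModule(2048, 512).cuda()
#   y = aspp(torch.randn(2, 2048, 60, 60).cuda())  # -> (2, 512, 60, 60)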
class ASPPModule(nn.Module): 13 | """ASPP""" 14 | 15 | def __init__(self, in_dim, out_dim, scale=1): 16 | super(ASPPModule, self).__init__() 17 | self.gap = nn.Sequential(nn.AdaptiveAvgPool2d(1), 18 | nn.Conv2d(in_dim, out_dim, 1, bias=False), InPlaceABNSync(out_dim)) 19 | 20 | self.dilation_0 = nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=1, padding=0, dilation=1, bias=False), 21 | InPlaceABNSync(out_dim), SEModule(out_dim, reduction=16)) 22 | 23 | self.dilation_1 = nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=1, padding=0, dilation=1, bias=False), 24 | InPlaceABNSync(out_dim), 25 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=6, dilation=6, bias=False), 26 | InPlaceABNSync(out_dim),SEModule(out_dim, reduction=16)) 27 | 28 | self.dilation_2 = nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=1, padding=0, dilation=1, bias=False), 29 | InPlaceABNSync(out_dim), 30 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=12, dilation=12, bias=False), 31 | InPlaceABNSync(out_dim), SEModule(out_dim, reduction=16)) 32 | 33 | self.dilation_3 = nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=1, padding=0, dilation=1, bias=False), 34 | InPlaceABNSync(out_dim), 35 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=18, dilation=18, bias=False), 36 | InPlaceABNSync(out_dim), SEModule(out_dim, reduction=16)) 37 | 38 | self.psaa_conv = nn.Sequential(nn.Conv2d(in_dim + 5 * out_dim, out_dim, 1, padding=0, bias=False), 39 | InPlaceABNSync(out_dim), 40 | nn.Conv2d(out_dim, 5, 1, bias=True), 41 | nn.Sigmoid()) 42 | 43 | self.project = nn.Sequential(nn.Conv2d(out_dim * 5, out_dim, kernel_size=1, padding=0, bias=False), 44 | InPlaceABNSync(out_dim)) 45 | 46 | def forward(self, x): 47 | # parallel branch 48 | feat0 = self.dilation_0(x) 49 | feat1 = self.dilation_1(x) 50 | feat2 = self.dilation_2(x) 51 | feat3 = self.dilation_3(x) 52 | n, c, h, w = feat0.size() 53 | gp = self.gap(x) 54 | 55 | feat4 = gp.expand(n, c, h, w) 56 | # psaa 57 | y1 = torch.cat((x, feat0, feat1, feat2, feat3, feat4), 1) 58 | 59 | psaa_att = self.psaa_conv(y1) 60 | 61 | psaa_att_list = torch.split(psaa_att, 1, dim=1) 62 | 63 | y2 = torch.cat((psaa_att_list[0] * feat0, psaa_att_list[1] * feat1, psaa_att_list[2] * feat2, psaa_att_list[3] * feat3, psaa_att_list[4]*feat4), 1) 64 | out = self.project(y2) 65 | return out 66 | -------------------------------------------------------------------------------- /network/ResNet_stem_converter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | saved_state_dict = torch.load('../checkpoints/init/resnet101_stem.pth') 5 | 6 | new_state=OrderedDict() 7 | for k, v in saved_state_dict.items(): 8 | if k=='conv1.0.weight': 9 | new_state.update({'conv1.weight':v}) 10 | elif k=='conv1.1.weight': 11 | new_state.update({'bn1.weight': v}) 12 | elif k=='conv1.1.bias': 13 | new_state.update({'bn1.bias': v}) 14 | elif k=='conv1.1.running_mean': 15 | new_state.update({'bn1.running_mean': v}) 16 | elif k=='conv1.1.running_var': 17 | new_state.update({'bn1.running_var': v}) 18 | elif k=='conv1.3.weight': 19 | new_state.update({'conv2.weight': v}) 20 | elif k=='conv1.4.weight': 21 | new_state.update({'bn2.weight':v}) 22 | elif k=='conv1.4.bias': 23 | new_state.update({'bn2.bias': v}) 24 | elif k=='conv1.4.running_mean': 25 | new_state.update({'bn2.running_mean': v}) 26 | elif k=='conv1.4.running_var': 27 | new_state.update({'bn2.running_var': v}) 28 | elif k=='conv1.6.weight': 29 | 
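        # 'conv1.6.weight' is the third conv of the three-conv ResNet stem: the
        # if/elif chain renames the checkpoint's conv1.0/conv1.1, conv1.3/conv1.4,
        # conv1.6 and trailing bn1.* stem keys to the flat conv1/bn1, conv2/bn2,
        # conv3/bn3 names expected by this repo's encoder, and copies every other
        # key through unchanged.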
new_state.update({'conv3.weight':v}) 30 | elif k=='bn1.weight': 31 | new_state.update({'bn3.weight': v}) 32 | elif k=='bn1.bias': 33 | new_state.update({'bn3.bias': v}) 34 | elif k=='bn1.running_mean': 35 | new_state.update({'bn3.running_mean': v}) 36 | elif k=='bn1.running_var': 37 | new_state.update({'bn3.running_var': v}) 38 | else: 39 | new_state.update({k: v}) 40 | 41 | 42 | 43 | torch.save(new_state, '../checkpoints/init/new_resnet101_stem.pth') 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/network/__init__.py -------------------------------------------------------------------------------- /network/baseline.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | from inplace_abn.bn import InPlaceABNSync 8 | from modules.com_mod import Bottleneck, ResGridNet, SEModule 9 | from modules.parse_mod import ASPPModule 10 | 11 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none') 12 | 13 | class DecoderModule(nn.Module): 14 | 15 | def __init__(self, num_classes): 16 | super(DecoderModule, self).__init__() 17 | self.conv0 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False), 18 | BatchNorm2d(256), nn.ReLU(inplace=False)) 19 | self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1), 20 | nn.Conv2d(256, 256, 1, bias=False), 21 | nn.ReLU(True), 22 | nn.Conv2d(256, 256, 1, bias=True), 23 | nn.Sigmoid()) 24 | def forward(self, x): 25 | out=self.conv0(x) 26 | out = out + self.se(out)*out 27 | return out 28 | 29 | 30 | class GNN_infer(nn.Module): 31 | def __init__(self, adj_matrix, upper_half_node=[1, 2, 3, 4], lower_half_node=[5, 6], in_dim=256, hidden_dim=64, 32 | cls_p=7, cls_h=3, cls_f=2): 33 | super(GNN_infer, self).__init__() 34 | self.cls_p = cls_p 35 | self.cls_h = cls_h 36 | self.cls_f = cls_f 37 | self.in_dim = in_dim 38 | self.hidden_dim = hidden_dim 39 | 40 | # node feature transform 41 | self.p_conv = nn.Sequential( 42 | nn.Conv2d(in_dim, hidden_dim * cls_p, kernel_size=1, padding=0, stride=1, bias=False), 43 | BatchNorm2d(hidden_dim * cls_p), nn.ReLU(inplace=False)) 44 | self.h_conv = nn.Sequential( 45 | nn.Conv2d(in_dim, hidden_dim * cls_h, kernel_size=1, padding=0, stride=1, bias=False), 46 | BatchNorm2d(hidden_dim * cls_h), nn.ReLU(inplace=False)) 47 | self.f_conv = nn.Sequential( 48 | nn.Conv2d(in_dim, hidden_dim * cls_f, kernel_size=1, padding=0, stride=1, bias=False), 49 | BatchNorm2d(hidden_dim * cls_f), nn.ReLU(inplace=False)) 50 | 51 | # node supervision 52 | self.node_seg = nn.Conv2d(hidden_dim, 1, 1) 53 | 54 | def forward(self, xp, xh, xf): 55 | # gnn inference at stride 8 56 | # feature transform 57 | f_node_list = list(torch.split(self.f_conv(xf), self.hidden_dim, dim=1)) 58 | h_node_list = list(torch.split(self.h_conv(xh), self.hidden_dim, dim=1)) 59 | p_node_list = list(torch.split(self.p_conv(xp), self.hidden_dim, dim=1)) 60 | 61 | # node supervision 62 | f_seg = torch.cat([self.node_seg(node) for node in f_node_list], dim=1) 63 | h_seg = torch.cat([self.node_seg(node) for node in h_node_list], dim=1) 64 | p_seg = torch.cat([self.node_seg(node) for node in p_node_list], dim=1) 65 | 66 | return [p_seg], [h_seg], 
[f_seg], [], [], [ 67 | ], [], [], [], [] 68 | 69 | 70 | 71 | class Decoder(nn.Module): 72 | def __init__(self, num_classes=7, hbody_cls=3, fbody_cls=2): 73 | super(Decoder, self).__init__() 74 | self.layer5 = ASPPModule(2048, 512) 75 | self.layer_part = DecoderModule(num_classes) 76 | self.layer_half = DecoderModule(hbody_cls) 77 | self.layer_full = DecoderModule(fbody_cls) 78 | 79 | self.layer_dsn = nn.Sequential(nn.Conv2d(1024, 256, kernel_size=3, stride=1, padding=1), 80 | BatchNorm2d(256), nn.ReLU(inplace=False), 81 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1, padding=0, bias=True)) 82 | 83 | self.skip = nn.Sequential(nn.Conv2d(512, 512, kernel_size=1, padding=0, bias=False), 84 | BatchNorm2d(512), nn.ReLU(inplace=False), 85 | ) 86 | self.fuse = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=3, padding=1, bias=False), 87 | BatchNorm2d(512), nn.ReLU(inplace=False)) 88 | 89 | 90 | # adjacent matrix for pascal person 91 | self.adj_matrix = torch.tensor( 92 | [[0, 1, 0, 0, 0, 0], [1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0], [0, 1, 0, 0, 0, 1], 93 | [0, 0, 0, 0, 1, 0]], requires_grad=False) 94 | 95 | # infer with hierarchical person graph 96 | self.gnn_infer = GNN_infer(adj_matrix=self.adj_matrix, upper_half_node=[1, 2, 3, 4], lower_half_node=[5, 6], 97 | in_dim=256, hidden_dim=32, cls_p=7, cls_h=3, cls_f=2) 98 | # aux layer 99 | self.layer_dsn = nn.Sequential(nn.Conv2d(1024, 256, kernel_size=3, stride=1, padding=1), 100 | BatchNorm2d(256), nn.ReLU(inplace=False), 101 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1, padding=0, bias=True)) 102 | 103 | def forward(self, x): 104 | x_dsn = self.layer_dsn(x[-2]) 105 | _,_,h,w = x[1].size() 106 | context = self.layer5(x[-1]) 107 | context = F.interpolate(context, size=(h, w), mode='bilinear', align_corners=True) 108 | context = self.fuse(torch.cat([self.skip(x[1]), context], dim=1)) 109 | 110 | p_fea = self.layer_part(context) 111 | h_fea = self.layer_half(context) 112 | f_fea = self.layer_full(context) 113 | 114 | # gnn infer 115 | p_seg, h_seg, f_seg, decomp_map_f, decomp_map_u, decomp_map_l, comp_map_f, comp_map_u, comp_map_l, \ 116 | Fdep_att_list= self.gnn_infer(p_fea, h_fea, f_fea) 117 | 118 | return p_seg, h_seg, f_seg, decomp_map_f, decomp_map_u, decomp_map_l, comp_map_f, comp_map_u, comp_map_l, \ 119 | Fdep_att_list, x_dsn 120 | 121 | class OCNet(nn.Module): 122 | def __init__(self, block, layers, num_classes): 123 | super(OCNet, self).__init__() 124 | self.encoder = ResGridNet(block, layers) 125 | self.decoder = Decoder(num_classes=num_classes) 126 | 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | nn.init.kaiming_normal_(m.weight.data) 130 | elif isinstance(m, InPlaceABNSync): 131 | m.weight.data.fill_(1) 132 | m.bias.data.zero_() 133 | 134 | def forward(self, x): 135 | x = self.encoder(x) 136 | x = self.decoder(x) 137 | return x 138 | 139 | def get_model(num_classes=20): 140 | model = OCNet(Bottleneck, [3, 4, 23, 3], num_classes) # 101 141 | return model 142 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.2.0 2 | torchvision==0.4.0 3 | numpy 4 | opencv-python 5 | tqdm 6 | -------------------------------------------------------------------------------- /train_pascal.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_baseline.py --init --method baseline 
--crop-size 473 --batch-size 20 --learning-rate 1e-2 -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/utils/__init__.py -------------------------------------------------------------------------------- /utils/aaf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/utils/aaf/__init__.py -------------------------------------------------------------------------------- /utils/aaf/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/utils/aaf/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/aaf/__pycache__/layers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/utils/aaf/__pycache__/layers.cpython-36.pyc -------------------------------------------------------------------------------- /utils/aaf/__pycache__/losses.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/utils/aaf/__pycache__/losses.cpython-36.pyc -------------------------------------------------------------------------------- /utils/aaf/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | def eightway_activation(x): 6 | """Retrieves neighboring pixels/features on the eight corners from 7 | a 3x3 patch. 8 | 9 | Args: 10 | x: A tensor of size [batch_size, height_in, width_in, channels] 11 | 12 | Returns: 13 | A tensor of size [batch_size, height_in, width_in, channels, 8] 14 | """ 15 | # Get the number of channels in the input. 16 | shape_x = list(x.shape) 17 | if len(shape_x) != 4: 18 | raise ValueError('Only support for 4-D tensors!') 19 | 20 | # Pad at the margin. 21 | x = F.pad(x, 22 | pad=(0,0,1,1,1,1,0,0), 23 | mode='reflect') 24 | # Get eight neighboring pixels/features. 25 | x_groups = [ 26 | x[:, 1:-1, :-2, :].clone(), # left 27 | x[:, 1:-1, 2:, :].clone(), # right 28 | x[:, :-2, 1:-1, :].clone(), # up 29 | x[:, 2:, 1:-1, :].clone(), # down 30 | x[:, :-2, :-2, :].clone(), # left-up 31 | x[:, 2:, :-2, :].clone(), # left-down 32 | x[:, :-2, 2:, :].clone(), # right-up 33 | x[:, 2:, 2:, :].clone() # right-down 34 | ] 35 | output = [ 36 | torch.unsqueeze(c, dim=-1) for c in x_groups 37 | ] 38 | output = torch.cat(output, dim=-1) 39 | 40 | return output 41 | 42 | 43 | def eightcorner_activation(x, size): 44 | """Retrieves neighboring pixels one the eight corners from a 45 | (2*size+1)x(2*size+1) patch. 46 | 47 | Args: 48 | x: A tensor of size [batch_size, height_in, width_in, channels] 49 | size: A number indicating the half size of a patch. 
50 | 51 | Returns: 52 | A tensor of size [batch_size, height_in, width_in, channels, 8] 53 | """ 54 | # Get the number of channels in the input. 55 | shape_x = list(x.shape) 56 | if len(shape_x) != 4: 57 | raise ValueError('Only support for 4-D tensors!') 58 | n, c, h, w = shape_x 59 | 60 | # Pad at the margin. 61 | p = size 62 | x_pad = F.pad(x, 63 | pad=(p,p,p,p,0,0,0,0), 64 | mode='constant', 65 | value=0) 66 | 67 | # Get eight corner pixels/features in the patch. 68 | x_groups = [] 69 | for st_y in range(0,2*size+1,size): 70 | for st_x in range(0,2*size+1,size): 71 | if st_y == size and st_x == size: 72 | # Ignore the center pixel/feature. 73 | continue 74 | 75 | x_neighbor = x_pad[:, :, st_y:st_y+h, st_x:st_x+w].clone() 76 | x_groups.append(x_neighbor) 77 | 78 | output = [torch.unsqueeze(c, dim=-1) for c in x_groups] 79 | output = torch.cat(output, dim=-1) 80 | 81 | return output 82 | 83 | 84 | def ignores_from_label(labels, num_classes, size, ignore_index): 85 | """Retrieves ignorable pixels from the ground-truth labels. 86 | 87 | This function returns a binary map in which 1 denotes ignored pixels 88 | and 0 means not ignored ones. For those ignored pixels, they are not 89 | only the pixels with label value >= num_classes, but also the 90 | corresponding neighboring pixels, which are on the the eight cornerls 91 | from a (2*size+1)x(2*size+1) patch. 92 | 93 | Args: 94 | labels: A tensor of size [batch_size, height_in, width_in], indicating 95 | semantic segmentation ground-truth labels. 96 | num_classes: A number indicating the total number of valid classes. The 97 | labels ranges from 0 to (num_classes-1), and any value >= num_classes 98 | would be ignored. 99 | size: A number indicating the half size of a patch. 100 | 101 | Return: 102 | A tensor of size [batch_size, height_in, width_in, 8] 103 | """ 104 | # Get the number of channels in the input. 105 | shape_lab = list(labels.shape) 106 | if len(shape_lab) != 3: 107 | raise ValueError('Only support for 3-D label tensors!') 108 | n, h, w = shape_lab 109 | 110 | # Retrieve ignored pixels with label value >= num_classes. 111 | # ignore = labels>num_classes-1 # NxHxW 112 | ignore = (labels==ignore_index) 113 | 114 | # Pad at the margin. 115 | p = size 116 | ignore_pad = F.pad(ignore, 117 | pad=(p,p,p,p,0,0), 118 | mode='constant', 119 | value=1) 120 | 121 | # Retrieve eight corner pixels from the center, where the center 122 | # is ignored. Note that it should be bi-directional. For example, 123 | # when computing AAF loss with top-left pixels, the ignored pixels 124 | # might be the center or the top-left ones. 
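  # With size=1, for example, the paired pixels are the eight offsets of a 3x3
  # window around each centre. The two loops below make the mask symmetric: the
  # first (reversed) pass flags a pair when the mirrored neighbour or the centre
  # is ignored, and the second (forward) pass additionally ORs in the neighbour
  # on the same side, so a pair is dropped whenever either endpoint is ignored.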
125 | ignore_groups= [] 126 | for st_y in range(2*size,-1,-size): 127 | for st_x in range(2*size,-1,-size): 128 | if st_y == size and st_x == size: 129 | continue 130 | ignore_neighbor = ignore_pad[:,st_y:st_y+h,st_x:st_x+w].clone() 131 | mask = ignore_neighbor | ignore 132 | ignore_groups.append(mask) 133 | 134 | ig = 0 135 | for st_y in range(0,2*size+1,size): 136 | for st_x in range(0,2*size+1,size): 137 | if st_y == size and st_x == size: 138 | continue 139 | ignore_neighbor = ignore_pad[:,st_y:st_y+h,st_x:st_x+w].clone() 140 | mask = ignore_neighbor | ignore_groups[ig] 141 | ignore_groups[ig] = mask 142 | ig += 1 143 | 144 | ignore_groups = [ 145 | torch.unsqueeze(c, dim=-1) for c in ignore_groups 146 | ] # NxHxWx1 147 | ignore = torch.cat(ignore_groups, dim=-1) #NxHxWx8 148 | 149 | return ignore 150 | 151 | 152 | def edges_from_label(labels, size, ignore_class=255): 153 | """Retrieves edge positions from the ground-truth labels. 154 | 155 | This function computes the edge map by considering if the pixel values 156 | are equal between the center and the neighboring pixels on the eight 157 | corners from a (2*size+1)*(2*size+1) patch. Ignore edges where the any 158 | of the paired pixels with label value >= num_classes. 159 | 160 | Args: 161 | labels: A tensor of size [batch_size, height_in, width_in], indicating 162 | semantic segmentation ground-truth labels. 163 | size: A number indicating the half size of a patch. 164 | ignore_class: A number indicating the label value to ignore. 165 | 166 | Return: 167 | A tensor of size [batch_size, height_in, width_in, 1, 8] 168 | """ 169 | # Get the number of channels in the input. 170 | shape_lab = list(labels.shape) 171 | if len(shape_lab) != 4: 172 | raise ValueError('Only support for 4-D label tensors!') 173 | n, h, w, c = shape_lab 174 | 175 | # Pad at the margin. 176 | p = size 177 | labels_pad = F.pad( 178 | labels, pad=(0,0,p,p,p,p,0,0), 179 | mode='constant', 180 | value=ignore_class) 181 | 182 | # Get the edge by comparing label value of the center and it paired pixels. 183 | edge_groups= [] 184 | for st_y in range(0,2*size+1,size): 185 | for st_x in range(0,2*size+1,size): 186 | if st_y == size and st_x == size: 187 | continue 188 | labels_neighbor = labels_pad[:,st_y:st_y+h,st_x:st_x+w] 189 | edge = labels_neighbor!=labels 190 | edge_groups.append(edge) 191 | 192 | edge_groups = [ 193 | torch.unsqueeze(c, dim=-1) for c in edge_groups 194 | ] # NxHxWx1x1 195 | edge = torch.cat(edge_groups, dim=-1) #NxHxWx1x8 196 | 197 | return edge 198 | -------------------------------------------------------------------------------- /utils/aaf/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import utils.aaf.layers as nnx 4 | import numpy as np 5 | 6 | def affinity_loss(labels, 7 | probs, 8 | num_classes, 9 | kld_margin): 10 | """Affinity Field (AFF) loss. 11 | 12 | This function computes AFF loss. There are several components in the 13 | function: 14 | 1) extracts edges from the ground-truth labels. 15 | 2) extracts ignored pixels and their paired pixels (the neighboring 16 | pixels on the eight corners). 17 | 3) extracts neighboring pixels on the eight corners from a 3x3 patch. 18 | 4) computes KL-Divergence between center pixels and their neighboring 19 | pixels from the eight corners. 20 | 21 | Args: 22 | labels: A tensor of size [batch_size, height_in, width_in], indicating 23 | semantic segmentation ground-truth labels. 
24 |     probs: A tensor of size [batch_size, height_in, width_in, num_classes],
25 |       indicating segmentation predictions.
26 |     num_classes: A number indicating the total number of valid classes.
27 |     kld_margin: A number indicating the margin for KL-Divergence at edge.
28 |
29 |   Returns:
30 |     Two 1-D tensors value indicating the loss at edge and non-edge.
31 |   """
32 |   # Compute ignore map (e.g. label of 255 and their paired pixels).
33 |
34 |   labels = torch.squeeze(labels, dim=1)  # NxHxW
35 |   ignore = nnx.ignores_from_label(labels, num_classes, 1, 255)  # NxHxWx8
36 |   not_ignore = ~ignore
37 |   not_ignore = torch.unsqueeze(not_ignore, dim=3)  # NxHxWx1x8
38 |
39 |   # Compute edge map.
40 |   one_hot_lab = F.one_hot(labels.clamp(0, num_classes - 1), num_classes)  # ignored labels are masked out below
41 |   edge = nnx.edges_from_label(one_hot_lab, 1, 255)  # NxHxWxCx8
42 |
43 |   # Remove ignored pixels from the edge/non-edge.
44 |   edge = edge & not_ignore
45 |   not_edge = ~edge & not_ignore
46 |
47 |   edge_indices = torch.nonzero(torch.reshape(edge, (-1,))).squeeze(1)
48 |   not_edge_indices = torch.nonzero(torch.reshape(not_edge, (-1,))).squeeze(1)
49 |
50 |   # Extract eight corner from the center in a patch as paired pixels.
51 |   probs_paired = nnx.eightcorner_activation(probs, 1)  # NxHxWxCx8
52 |   probs = torch.unsqueeze(probs, dim=-1)  # NxHxWxCx1
53 |   bot_epsilon = 1e-4
54 |   top_epsilon = 1.0
55 |
56 |   neg_probs = torch.clamp(
57 |       1 - probs, bot_epsilon, top_epsilon)
58 |   neg_probs_paired = torch.clamp(
59 |       1 - probs_paired, bot_epsilon, top_epsilon)
60 |   probs = torch.clamp(
61 |       probs, bot_epsilon, top_epsilon)
62 |   probs_paired = torch.clamp(
63 |       probs_paired, bot_epsilon, top_epsilon)
64 |
65 |   # Compute KL-Divergence.
66 |   kldiv = probs_paired * torch.log(probs_paired / probs)
67 |   kldiv += neg_probs_paired * torch.log(neg_probs_paired / neg_probs)
68 |   edge_loss = torch.clamp(kld_margin - kldiv, min=0.0)
69 |   not_edge_loss = kldiv
70 |
71 |
72 |   not_edge_loss = torch.reshape(not_edge_loss, (-1,))
73 |   not_edge_loss = torch.gather(not_edge_loss, 0, not_edge_indices)
74 |   edge_loss = torch.reshape(edge_loss, (-1,))
75 |   edge_loss = torch.gather(edge_loss, 0, edge_indices)
76 |
77 |   return edge_loss, not_edge_loss
78 |
79 |
80 | def adaptive_affinity_loss(labels,
81 |                            one_hot_lab,
82 |                            probs,
83 |                            size,
84 |                            num_classes,
85 |                            kld_margin,
86 |                            w_edge,
87 |                            w_not_edge,
88 |                            ignore_index=255):
89 |   """Adaptive affinity field (AAF) loss.
90 |
91 |   This function computes AAF loss. There are three components in the function:
92 |   1) extracts edges from the ground-truth labels.
93 |   2) extracts ignored pixels and their paired pixels (usually the eight corner
94 |      pixels).
95 |   3) extracts eight corner pixels/predictions from the center in a
96 |      (2*size+1)x(2*size+1) patch
97 |   4) computes KL-Divergence between center pixels and their paired pixels (the
98 |      eight corner).
99 |   5) imposes adaptive weightings on the loss.
100 |
101 |   Args:
102 |     labels: A tensor of size [batch_size, height_in, width_in], indicating
103 |       semantic segmentation ground-truth labels.
104 |     one_hot_lab: A tensor of size [batch_size, num_classes, height_in, width_in]
105 |       which is the ground-truth labels in the form of one-hot vector.
106 |     probs: A tensor of size [batch_size, num_classes, height_in, width_in],
107 |       indicating segmentation predictions.
108 |     size: A number indicating the half size of a patch.
109 |     num_classes: A number indicating the total number of valid classes.
110 |     kld_margin: A number indicating the margin for KL-Divergence at edge.
111 | w_edge: A number indicating the weighting for KL-Divergence at edge. 112 | w_not_edge: A number indicating the weighting for KL-Divergence at non-edge. 113 | ignore_index: ignore index 114 | 115 | Returns: 116 | Two 1-D tensors value indicating the loss at edge and non-edge. 117 | """ 118 | # Compute ignore map (e.g, label of 255 and their paired pixels). 119 | labels = torch.squeeze(labels, dim=1) # NxHxW 120 | ignore = nnx.ignores_from_label(labels, num_classes, size, ignore_index) # NxHxWx8 121 | not_ignore = ~ignore 122 | not_ignore = torch.unsqueeze(not_ignore, dim=3) # NxHxWx1x8 123 | 124 | # Compute edge map. 125 | edge = nnx.edges_from_label(one_hot_lab, size, ignore_index) # NxHxWxCx8 126 | 127 | # Remove ignored pixels from the edge/non-edge. 128 | edge = edge & not_ignore 129 | not_edge = ~edge & not_ignore 130 | 131 | edge_indices = torch.nonzero(torch.reshape(edge, (-1,))) 132 | # print(edge_indices.size()) 133 | if edge_indices.size()[0]==0: 134 | edge_loss=torch.tensor(0.0, requires_grad=False).cuda() 135 | not_edge_loss=torch.tensor(0.0, requires_grad=False).cuda() 136 | return edge_loss, not_edge_loss 137 | 138 | not_edge_indices = torch.nonzero(torch.reshape(not_edge, (-1,))) 139 | 140 | # Extract eight corner from the center in a patch as paired pixels. 141 | probs_paired = nnx.eightcorner_activation(probs, size) # NxHxWxCx8 142 | probs = torch.unsqueeze(probs, dim=-1) # NxHxWxCx1 143 | bot_epsilon = torch.tensor(1e-4, requires_grad=False).cuda() 144 | top_epsilon = torch.tensor(1.0, requires_grad=False).cuda() 145 | 146 | neg_probs = torch.where(1-probs < bot_epsilon, bot_epsilon, 1-probs) 147 | neg_probs = torch.where(neg_probs > top_epsilon, top_epsilon, neg_probs) 148 | 149 | neg_probs_paired = torch.where(1 - probs_paired < bot_epsilon, bot_epsilon, 1 - probs_paired) 150 | neg_probs_paired = torch.where(neg_probs_paired > top_epsilon, top_epsilon, neg_probs_paired) 151 | 152 | probs = torch.where(probs < bot_epsilon, bot_epsilon, probs) 153 | probs = torch.where(probs > top_epsilon, top_epsilon, probs) 154 | 155 | probs_paired = torch.where(probs_paired < bot_epsilon, bot_epsilon, probs_paired) 156 | probs_paired = torch.where(probs_paired > top_epsilon, top_epsilon, probs_paired) 157 | 158 | # neg_probs = np.clip( 159 | # 1-probs, bot_epsilon, top_epsilon) 160 | # neg_probs_paired = np.clip( 161 | # 1-probs_paired, bot_epsilon, top_epsilon) 162 | # probs = np.clip( 163 | # probs, bot_epsilon, top_epsilon) 164 | # probs_paired = np.clip( 165 | # probs_paired, bot_epsilon, top_epsilon) 166 | 167 | # Compute KL-Divergence. 168 | kldiv = probs_paired*torch.log(probs_paired/probs) 169 | kldiv += neg_probs_paired*torch.log(neg_probs_paired/neg_probs) 170 | edge_loss = torch.max(torch.tensor(0.0, requires_grad=False).cuda(), kld_margin-kldiv) 171 | not_edge_loss = kldiv 172 | 173 | # Impose weights on edge/non-edge losses. 
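  # w_edge / w_not_edge carry one weight per class; summing them against the
  # one-hot ground truth collapses the class dimension so that every pixel picks
  # up the weight of its own class, yielding a map that broadcasts over the
  # eight neighbour channels when multiplied into the losses below.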
174 | one_hot_lab = torch.unsqueeze(one_hot_lab, dim=-1) 175 | w_edge = torch.sum(w_edge*one_hot_lab.float(), dim=3, keepdim=True) # NxHxWx1x1 176 | w_not_edge = torch.sum(w_not_edge*one_hot_lab.float(), dim=3, keepdim=True) # NxHxWx1x1 177 | 178 | edge_loss *= w_edge.permute(0,3,1,2,4) 179 | not_edge_loss *= w_not_edge.permute(0,3,1,2,4) 180 | 181 | not_edge_loss = torch.reshape(not_edge_loss, (-1,1)) 182 | not_edge_loss = torch.gather(not_edge_loss, 0, not_edge_indices) 183 | edge_loss = torch.reshape(edge_loss, (-1,1)) 184 | edge_loss = torch.gather(edge_loss, 0, edge_indices) 185 | 186 | return edge_loss, not_edge_loss 187 | -------------------------------------------------------------------------------- /utils/best/lovasz_loss.py: -------------------------------------------------------------------------------- 1 | from itertools import filterfalse as ifilterfalse 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class LovaszSoftmaxLoss(nn.Module): 10 | """Multi-class Lovasz-Softmax loss. 11 | :param only_present: average only on classes present in ground truth. 12 | :param per_image: calculate the loss in image separately. 13 | :param ignore_index: 14 | """ 15 | 16 | def __init__(self, ignore_index=None, only_present=False, per_image=False): 17 | super(LovaszSoftmaxLoss, self).__init__() 18 | self.ignore_index = ignore_index 19 | self.only_present = only_present 20 | self.per_image = per_image 21 | self.weight = torch.FloatTensor([0.80777327, 1.00125961, 0.90997236, 1.10867908, 1.17541499, 22 | 0.86041422, 1.01116758, 0.89290045, 1.12410812, 0.91105395, 23 | 1.07604013, 1.12470610, 1.09895196, 0.90172057, 0.93529453, 24 | 0.93054733, 1.04919178, 1.04937547, 1.06267568, 1.06365688]) 25 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, weight=self.weight) 26 | 27 | def forward(self, preds, targets): 28 | h, w = targets.size(1), targets.size(2) 29 | # seg loss 30 | pred = F.interpolate(input=preds[0], size=(h, w), mode='bilinear', align_corners=True) 31 | pred = F.softmax(input=pred, dim=1) 32 | if self.per_image: 33 | loss = mean(lovasz_softmax_flat(*flatten_probas(pre.unsqueeze(0), tar.unsqueeze(0), self.ignore_index), 34 | only_present=self.only_present) for pre, tar in zip(pred, targets)) 35 | else: 36 | loss = lovasz_softmax_flat(*flatten_probas(pred, targets, self.ignore_index), 37 | only_present=self.only_present) 38 | # dsn loss 39 | pred_dsn = F.interpolate(input=preds[1], size=(h, w), mode='bilinear', align_corners=True) 40 | loss_dsn = self.criterion(pred_dsn, targets) 41 | return loss + 0.4 * loss_dsn 42 | 43 | 44 | def lovasz_softmax_flat(preds, targets, only_present=False): 45 | """ 46 | Multi-class Lovasz-Softmax loss 47 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 48 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 49 | only_present: average only on classes present in ground truth 50 | """ 51 | if preds.numel() == 0: 52 | # only void pixels, the gradients should be 0 53 | return preds * 0. 
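    # For each class below: sort the absolute errors |fg - p_c| in descending
    # order and take their dot product with the gradient of the Lovasz extension
    # of the Jaccard index, a piecewise-linear surrogate that directly optimises
    # per-class IoU.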
54 | 55 | C = preds.size(1) 56 | losses = [] 57 | for c in range(C): 58 | fg = (targets == c).float() # foreground for class c 59 | if only_present and fg.sum() == 0: 60 | continue 61 | errors = (Variable(fg) - preds[:, c]).abs() 62 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 63 | perm = perm.data 64 | fg_sorted = fg[perm] 65 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 66 | return mean(losses) 67 | 68 | 69 | def lovasz_grad(gt_sorted): 70 | """ 71 | Computes gradient of the Lovasz extension w.r.t sorted errors 72 | See Alg. 1 in paper 73 | """ 74 | p = len(gt_sorted) 75 | gts = gt_sorted.sum() 76 | intersection = gts - gt_sorted.float().cumsum(0) 77 | union = gts + (1 - gt_sorted).float().cumsum(0) 78 | jaccard = 1. - intersection / union 79 | if p > 1: # cover 1-pixel case 80 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 81 | return jaccard 82 | 83 | 84 | def flatten_probas(preds, targets, ignore=None): 85 | """ 86 | Flattens predictions in the batch 87 | """ 88 | B, C, H, W = preds.size() 89 | preds = preds.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 90 | targets = targets.view(-1) 91 | if ignore is None: 92 | return preds, targets 93 | valid = (targets != ignore) 94 | vprobas = preds[valid.nonzero().squeeze()] 95 | vlabels = targets[valid] 96 | return vprobas, vlabels 97 | 98 | 99 | def mean(l, ignore_nan=True, empty=0): 100 | """ 101 | nan mean compatible with generators. 102 | """ 103 | l = iter(l) 104 | if ignore_nan: 105 | l = ifilterfalse(isnan, l) 106 | try: 107 | n = 1 108 | acc = next(l) 109 | except StopIteration: 110 | if empty == 'raise': 111 | raise ValueError('Empty mean') 112 | return empty 113 | for n, v in enumerate(l, 2): 114 | acc += v 115 | if n == 1: 116 | return acc 117 | return acc / n 118 | 119 | 120 | def isnan(x): 121 | return x != x 122 | -------------------------------------------------------------------------------- /utils/learning_policy.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | # poly lr 4 | # def adjust_learning_rate(optimizer, epoch, i_iter, iters_per_epoch, method='poly'): 5 | # if method == 'poly': 6 | # current_step = epoch * iters_per_epoch + i_iter 7 | # max_step = args.epochs * iters_per_epoch 8 | # lr = args.learning_rate * ((1 - current_step / max_step) ** 0.9) 9 | # else: 10 | # lr = args.learning_rate 11 | # optimizer.param_groups[0]['lr'] = lr 12 | # return lr 13 | 14 | def cosine_decay(base_learning_rate, global_step, warm_step, decay_steps, alpha=0.0001): 15 | # warm_step = 5 * iters_per_epoch 16 | # warm_lr = 0.01 * learning_rate 17 | # current_step = epoch * iters_per_epoch + i_iter 18 | alpha = alpha/base_learning_rate 19 | if global_step < warm_step: 20 | lr = base_learning_rate*global_step/warm_step 21 | # lr = base_learning_rate 22 | else: 23 | global_step = min(global_step, decay_steps)-warm_step 24 | cosine_decay = 0.5 * (1 + math.cos(math.pi * global_step / (decay_steps-warm_step))) 25 | decayed = (1 - alpha) * cosine_decay + alpha 26 | lr = base_learning_rate * decayed 27 | return lr 28 | 29 | 30 | def restart_cosine_decay(base_learning_rate, global_step, warm_step, decay_steps, alpha=0.0001): 31 | # warm_step = 5 * iters_per_epoch 32 | # warm_lr = 0.01 * learning_rate 33 | # current_step = epoch * iters_per_epoch + i_iter 34 | alpha = alpha/base_learning_rate 35 | restart_step = int((warm_step+decay_steps)/2) 36 | if global_step < warm_step: 37 | lr = base_learning_rate*global_step/warm_step 
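    # Shape of the schedule (e.g. base_lr=1e-2, alpha=1e-4): lr climbs linearly
    # to base_lr over warm_step iterations, then follows half a cosine down to
    # roughly alpha at decay_steps; this restart variant runs two such cosine
    # segments, splitting at restart_step = (warm_step + decay_steps) / 2.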
38 |     elif global_step < restart_step:
            # reconstructed: the original lines are truncated in this copy; the
            # branch is assumed to mirror cosine_decay above, running a first
            # cosine segment up to the warm restart at restart_step
            global_step = min(global_step, restart_step) - warm_step
            cosine_decay = 0.5 * (1 + math.cos(math.pi * global_step / (restart_step - warm_step)))
            decayed = (1 - alpha) * cosine_decay + alpha
            lr = base_learning_rate * decayed
        else:
            # reconstructed: assumed second cosine segment after the restart
            global_step = min(global_step, decay_steps) - restart_step
            cosine_decay = 0.5 * (1 + math.cos(math.pi * global_step / (decay_steps - restart_step)))
            decayed = (1 - alpha) * cosine_decay + alpha
            lr = base_learning_rate * decayed
        return lr
-------------------------------------------------------------------------------- /utils/metric.py: --------------------------------------------------------------------------------
import numpy as np  # reconstructed import; the lines preceding fast_hist are missing in this copy


def fast_hist(pred, label, n):  # signature reconstructed; the body below is as printed
    k = (label >= 0) & (label < n)
    return np.bincount(
        n * label[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)


def per_class_iu(hist):
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
-------------------------------------------------------------------------------- /utils/visualize.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | # colour map
5 | label_colours = [(0, 0, 0)
6 |                  # 0=Background
7 |     , (128, 0, 0), (255, 0, 0), (0, 85, 0), (170, 0, 51), (255, 85, 0)
8 |                  # 1=Hat, 2=Hair, 3=Glove, 4=Sunglasses, 5=UpperClothes
9 |     , (0, 0, 85), (0, 119, 221), (85, 85, 0), (0, 85, 85), (85, 51, 0)
10 |                  # 6=Dress, 7=Coat, 8=Socks, 9=Pants, 10=Jumpsuits
11 |     , (52, 86, 128), (0, 128, 0), (0, 0, 255), (51, 170, 221), (0, 255, 255)
12 |                  # 11=Scarf, 12=Skirt, 13=Face, 14=LeftArm, 15=RightArm
13 |     , (85, 255, 170), (170, 255, 85), (255, 255, 0), (255, 170, 0)]
14 |                  # 16=LeftLeg, 17=RightLeg, 18=LeftShoe, 19=RightShoe
15 |
16 |
17 | pascal_person = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128)]
18 |
19 |
20 | def decode_predictions(preds, num_images=4, num_classes=20):
21 |     """Decode batch of segmentation masks.
22 |     """
23 |     preds = preds.data.cpu().numpy()
24 |     n, h, w = preds.shape
25 |     assert n >= num_images
26 |     outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
27 |     for i in range(num_images):
28 |         img = Image.new('RGB', (len(preds[i, 0]), len(preds[i])))
29 |         pixels = img.load()
30 |         for j_, j in enumerate(preds[i, :, :]):
31 |             for k_, k in enumerate(j):
32 |                 if k < num_classes:
33 |                     pixels[k_, j_] = label_colours[k]
34 |         outputs[i] = np.array(img)
35 |     return outputs
36 |
37 |
38 | def inv_preprocess(imgs, num_images=4):
39 |     """Inverse preprocessing of the batch of images.
40 |     """
41 |     mean = (104.00698793, 116.66876762, 122.67891434)
42 |     imgs = imgs.data.cpu().numpy()
43 |     n, c, h, w = imgs.shape
44 |     assert n >= num_images
45 |     outputs = np.zeros((num_images, h, w, c), dtype=np.uint8)
46 |     for i in range(num_images):
47 |         outputs[i] = (np.transpose(imgs[i], (1, 2, 0)) + mean)[:, :, ::-1].astype(np.uint8)
48 |     return outputs
49 |
-------------------------------------------------------------------------------- /val/evaluate_atr.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import cv2
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.autograd import Variable
10 | from torch.utils import data
11 |
12 | from dataset.data_atr import ATRTestGenerator as TestGenerator
13 | from network.baseline import get_model
14 |
15 |
16 | def get_arguments():
17 |     """Parse all the arguments provided from the CLI.
18 |
19 |     Returns:
20 |       A list of parsed arguments.
21 | """ 22 | parser = argparse.ArgumentParser(description="Pytorch Segmentation") 23 | parser.add_argument("--root", type=str, default='./data/ATR/test_set/') 24 | parser.add_argument("--data-list", type=str, default='./dataset/ATR/test_id.txt') 25 | parser.add_argument("--crop-size", type=int, default=513) 26 | parser.add_argument("--num-classes", type=int, default=18) 27 | parser.add_argument("--ignore-label", type=int, default=255) 28 | parser.add_argument("--restore-from", type=str, 29 | default='./checkpoints/exp/model_best.pth') 30 | parser.add_argument("--is-mirror", action="store_true") 31 | parser.add_argument('--eval-scale', nargs='+', type=float, default=[1.0]) 32 | # parser.add_argument('--eval-scale', nargs='+', type=float, default=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75]) 33 | 34 | parser.add_argument("--save-dir", type=str) 35 | parser.add_argument("--gpu", type=str, default='0') 36 | return parser.parse_args() 37 | 38 | 39 | def main(): 40 | """Create the model and start the evaluation process.""" 41 | args = get_arguments() 42 | 43 | # initialization 44 | print("Input arguments:") 45 | for key, val in vars(args).items(): 46 | print("{:16} {}".format(key, val)) 47 | 48 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 49 | 50 | model = get_model(num_classes=args.num_classes) 51 | 52 | # if not os.path.exists(args.save_dir): 53 | # os.makedirs(args.save_dir) 54 | 55 | palette = get_lip_palette() 56 | restore_from = args.restore_from 57 | saved_state_dict = torch.load(restore_from) 58 | model.load_state_dict(saved_state_dict) 59 | 60 | model.eval() 61 | model.cuda() 62 | 63 | testloader = data.DataLoader(TestGenerator(args.root, args.data_list, crop_size=args.crop_size), 64 | batch_size=1, shuffle=False, pin_memory=True) 65 | 66 | confusion_matrix = np.zeros((args.num_classes, args.num_classes)) 67 | 68 | for index, batch in enumerate(testloader): 69 | if index % 100 == 0: 70 | print('%d images have been proceeded' % index) 71 | image, label, ori_size, name = batch 72 | 73 | ori_size = ori_size[0].numpy() 74 | 75 | output = predict(model, image.numpy(), (np.asscalar(ori_size[0]), np.asscalar(ori_size[1])), 76 | is_mirror=args.is_mirror, scales=args.eval_scale) 77 | seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8) 78 | 79 | # output_im = PILImage.fromarray(seg_pred) 80 | # output_im.putpalette(palette) 81 | # output_im.save(args.save_dir + name[0] + '.png') 82 | 83 | seg_gt = np.asarray(label[0].numpy(), dtype=np.int) 84 | ignore_index = seg_gt != 255 85 | seg_gt = seg_gt[ignore_index] 86 | seg_pred = seg_pred[ignore_index] 87 | 88 | confusion_matrix += get_confusion_matrix(seg_gt, seg_pred, args.num_classes) 89 | 90 | pos = confusion_matrix.sum(1) 91 | res = confusion_matrix.sum(0) 92 | tp = np.diag(confusion_matrix) 93 | 94 | pixel_accuracy = tp.sum() / pos.sum() 95 | mean_accuracy = (tp / np.maximum(1.0, pos)).mean() 96 | IU_array = (tp / np.maximum(1.0, pos + res - tp)) 97 | mean_IU = IU_array.mean() 98 | 99 | # get_confusion_matrix_plot() 100 | 101 | print('Pixel accuracy: %f \n' % pixel_accuracy) 102 | print('Mean accuracy: %f \n' % mean_accuracy) 103 | print('Mean IU: %f \n' % mean_IU) 104 | for index, IU in enumerate(IU_array): 105 | print('%f ', IU) 106 | 107 | 108 | def scale_image(image, scale): 109 | image = image[0, :, :, :] 110 | image = image.transpose((1, 2, 0)) 111 | image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) 112 | image = image.transpose((2, 0, 1)) 113 | return image 114 | 115 | 116 | def predict(net, image, 
output_size, is_mirror=True, scales=[1]): 117 | if is_mirror: 118 | image_rev = image[:, :, :, ::-1] 119 | 120 | interp = nn.Upsample(size=output_size, mode='bilinear', align_corners=True) 121 | 122 | outputs = [] 123 | if is_mirror: 124 | for scale in scales: 125 | if scale != 1: 126 | image_scale = scale_image(image=image, scale=scale) 127 | image_rev_scale = scale_image(image=image_rev, scale=scale) 128 | else: 129 | image_scale = image[0, :, :, :] 130 | image_rev_scale = image_rev[0, :, :, :] 131 | 132 | image_scale = np.stack((image_scale, image_rev_scale)) 133 | 134 | with torch.no_grad(): 135 | prediction = net(Variable(torch.from_numpy(image_scale)).cuda()) 136 | prediction = interp(prediction[0]).cpu().data.numpy() 137 | 138 | prediction_rev = prediction[1, :, :, :].copy() 139 | prediction_rev[9, :, :] = prediction[1, 10, :, :] 140 | prediction_rev[10, :, :] = prediction[1, 9, :, :] 141 | prediction_rev[12, :, :] = prediction[1, 13, :, :] 142 | prediction_rev[13, :, :] = prediction[1, 12, :, :] 143 | prediction_rev[14, :, :] = prediction[1, 15, :, :] 144 | prediction_rev[15, :, :] = prediction[1, 14, :, :] 145 | prediction_rev = prediction_rev[:, :, ::-1] 146 | prediction = prediction[0, :, :, :] 147 | prediction = np.mean([prediction, prediction_rev], axis=0) 148 | 149 | outputs.append(prediction) 150 | 151 | outputs = np.mean(outputs, axis=0) 152 | outputs = outputs.transpose(1, 2, 0) 153 | else: 154 | for scale in scales: 155 | if scale != 1: 156 | image_scale = scale_image(image=image, scale=scale) 157 | else: 158 | image_scale = image[0, :, :, :] 159 | 160 | with torch.no_grad(): 161 | prediction = net(Variable(torch.from_numpy(image_scale).unsqueeze(0)).cuda()) 162 | prediction = interp(prediction[0]).cpu().data.numpy() 163 | outputs.append(prediction[0, :, :, :]) 164 | 165 | outputs = np.mean(outputs, axis=0) 166 | outputs = outputs.transpose(1, 2, 0) 167 | 168 | return outputs 169 | 170 | 171 | def get_confusion_matrix(gt_label, pred_label, class_num): 172 | """ 173 | Calculate the confusion matrix by given label and pred 174 | :param gt_label: the ground truth label 175 | :param pred_label: the pred label 176 | :param class_num: the nunber of class 177 | """ 178 | index = (gt_label * class_num + pred_label).astype('int32') 179 | label_count = np.bincount(index) 180 | confusion_matrix = np.zeros((class_num, class_num)) 181 | 182 | for i_label in range(class_num): 183 | for i_pred_label in range(class_num): 184 | cur_index = i_label * class_num + i_pred_label 185 | if cur_index < len(label_count): 186 | confusion_matrix[i_label, i_pred_label] = label_count[cur_index] 187 | 188 | return confusion_matrix 189 | 190 | 191 | def get_confusion_matrix_plot(conf_arr): 192 | norm_conf = [] 193 | for i in conf_arr: 194 | tmp_arr = [] 195 | a = sum(i, 0) 196 | for j in i: 197 | tmp_arr.append(float(j) / max(1.0, float(a))) 198 | norm_conf.append(tmp_arr) 199 | 200 | fig = plt.figure() 201 | plt.clf() 202 | ax = fig.add_subplot(111) 203 | ax.set_aspect(1) 204 | res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, interpolation='nearest') 205 | 206 | width, height = conf_arr.shape 207 | 208 | cb = fig.colorbar(res) 209 | alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 210 | plt.xticks(range(width), alphabet[:width]) 211 | plt.yticks(range(height), alphabet[:height]) 212 | plt.savefig('confusion_matrix.png', format='png') 213 | 214 | 215 | def get_lip_palette(): 216 | palette = [0, 0, 0, 217 | 128, 0, 0, 218 | 255, 0, 0, 219 | 0, 85, 0, 220 | 170, 0, 51, 221 | 255, 85, 0, 222 | 0, 0, 85, 223 | 
0, 119, 221, 224 | 85, 85, 0, 225 | 0, 85, 85, 226 | 85, 51, 0, 227 | 52, 86, 128, 228 | 0, 128, 0, 229 | 0, 0, 255, 230 | 51, 170, 221, 231 | 0, 255, 255, 232 | 85, 255, 170, 233 | 170, 255, 85, 234 | 255, 255, 0, 235 | 255, 170, 0] 236 | return palette 237 | 238 | 239 | if __name__ == '__main__': 240 | main() 241 | -------------------------------------------------------------------------------- /val/evaluate_ccf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import cv2 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from PIL import Image 10 | from torch.autograd import Variable 11 | from torch.utils import data 12 | 13 | from dataset.data_ccf import TestGenerator 14 | from network.baseline import get_model 15 | 16 | 17 | def get_arguments(): 18 | """Parse all the arguments provided from the CLI. 19 | 20 | Returns: 21 | A list of parsed arguments. 22 | """ 23 | parser = argparse.ArgumentParser(description="Pytorch Segmentation") 24 | parser.add_argument('--root', default='./data/CCF', type=str) 25 | parser.add_argument("--data-list", type=str, default='./dataset/CCF/test_id.txt') 26 | parser.add_argument("--crop-size", type=int, default=513) 27 | parser.add_argument("--num-classes", type=int, default=18) 28 | parser.add_argument("--ignore-label", type=int, default=255) 29 | parser.add_argument('--restore-from', default='./checkpoints/exp/model_best.pth', type=str) 30 | 31 | parser.add_argument("--is-mirror", action="store_true") 32 | parser.add_argument('--eval-scale', nargs='+', type=float, default=[1.0]) 33 | # parser.add_argument('--eval-scale', nargs='+', type=float, default=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75]) 34 | parser.add_argument("--save-dir", type=str) 35 | parser.add_argument("--gpu", type=str, default='0') 36 | return parser.parse_args() 37 | 38 | 39 | def main(): 40 | """Create the model and start the evaluation process.""" 41 | args = get_arguments() 42 | 43 | # initialization 44 | print("Input arguments:") 45 | for key, val in vars(args).items(): 46 | print("{:16} {}".format(key, val)) 47 | 48 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 49 | 50 | 51 | # if not os.path.exists(args.save_dir): 52 | # os.makedirs(args.save_dir) 53 | 54 | # obtain the color map 55 | palette = get_lip_palette() 56 | 57 | # conduct model & load pre-trained weights 58 | model = get_model(num_classes=args.num_classes) 59 | restore_from = args.restore_from 60 | saved_state_dict = torch.load(restore_from) 61 | model.load_state_dict(saved_state_dict) 62 | 63 | model.eval() 64 | model.cuda() 65 | # data loader 66 | testloader = data.DataLoader(TestGenerator(args.root, args.data_list, crop_size=args.crop_size), 67 | batch_size=1, shuffle=False, pin_memory=True) 68 | 69 | confusion_matrix = np.zeros((args.num_classes, args.num_classes)) 70 | 71 | for index, batch in enumerate(testloader): 72 | if index % 100 == 0: 73 | print('%d images have been proceeded' % index) 74 | image, label, ori_size, name = batch 75 | 76 | ori_size = ori_size[0].numpy() 77 | output = predict(model, image.numpy(), (np.asscalar(ori_size[0]), np.asscalar(ori_size[1])), 78 | is_mirror=args.is_mirror, scales=args.eval_scale) 79 | seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8) 80 | 81 | # output_im = PILImage.fromarray(seg_pred) 82 | # output_im.putpalette(palette) 83 | # output_im.save(args.save_dir + name[0] + '.png') 84 | 85 | seg_gt = np.asarray(label[0].numpy(), 
dtype=np.int64)  # np.int was removed in NumPy 1.24; use an explicit integer dtype
        ignore_index = seg_gt != 255
        seg_gt = seg_gt[ignore_index]
        seg_pred = seg_pred[ignore_index]

        confusion_matrix += get_confusion_matrix(seg_gt, seg_pred, args.num_classes)

    pos = confusion_matrix.sum(1)   # ground-truth pixels per class
    res = confusion_matrix.sum(0)   # predicted pixels per class
    tp = np.diag(confusion_matrix)  # correctly classified pixels per class

    pixel_accuracy = tp.sum() / pos.sum()
    mean_accuracy = (tp / np.maximum(1.0, pos)).mean()
    IU_array = tp / np.maximum(1.0, pos + res - tp)
    mean_IU = IU_array.mean()

    # get_confusion_matrix_plot()

    print('Pixel accuracy: %f' % pixel_accuracy)
    print('Mean accuracy: %f' % mean_accuracy)
    print('Mean IU: %f' % mean_IU)
    for index, IU in enumerate(IU_array):
        print('class %d IU: %f' % (index, IU))


def scale_image(image, scale):
    """Rescale a 1xCxHxW numpy image by `scale`, returning a CxHxW array."""
    image = image[0, :, :, :]
    image = image.transpose((1, 2, 0))  # CHW -> HWC for cv2.resize
    image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    image = image.transpose((2, 0, 1))  # back to CHW
    return image


def predict(net, image, output_size, is_mirror=True, scales=(1,)):
    if is_mirror:
        image_rev = image[:, :, :, ::-1]

    interp = nn.Upsample(size=output_size, mode='bilinear', align_corners=True)

    outputs = []
    if is_mirror:
        for scale in scales:
            if scale != 1:
                image_scale = scale_image(image=image, scale=scale)
                image_rev_scale = scale_image(image=image_rev, scale=scale)
            else:
                image_scale = image[0, :, :, :]
                image_rev_scale = image_rev[0, :, :, :]

            # Batch the original and the flipped image together.
            image_scale = np.stack((image_scale, image_rev_scale))

            with torch.no_grad():
                prediction = net(torch.from_numpy(image_scale).cuda())
                prediction = interp(prediction[0]).cpu().numpy()

            prediction_rev = prediction[1, :, :, :].copy()
            prediction_rev = prediction_rev[:, :, ::-1]  # flip the mirrored prediction back
            prediction = prediction[0, :, :, :]
            prediction = np.mean([prediction, prediction_rev], axis=0)

            outputs.append(prediction)

        outputs = np.mean(outputs, axis=0)
        outputs = outputs.transpose(1, 2, 0)
    else:
        for scale in scales:
            if scale != 1:
                image_scale = scale_image(image=image, scale=scale)
            else:
                image_scale = image[0, :, :, :]

            with torch.no_grad():
                prediction = net(torch.from_numpy(image_scale).unsqueeze(0).cuda())
                prediction = interp(prediction[0]).cpu().numpy()
            outputs.append(prediction[0, :, :, :])

        outputs = np.mean(outputs, axis=0)
        outputs = outputs.transpose(1, 2, 0)

    return outputs


def get_confusion_matrix(gt_label, pred_label, class_num):
    """
    Calculate the confusion matrix for the given labels and predictions.
    :param gt_label: the ground-truth label map
    :param pred_label: the predicted label map
    :param class_num: the number of classes
    """
    index = (gt_label * class_num + pred_label).astype('int32')
    label_count = np.bincount(index)
    confusion_matrix = np.zeros((class_num, class_num))

    for i_label in range(class_num):
        for i_pred_label in range(class_num):
            cur_index = i_label * class_num + i_pred_label
            if cur_index < len(label_count):
                confusion_matrix[i_label, i_pred_label] = label_count[cur_index]

    return confusion_matrix


def get_confusion_matrix_plot(conf_arr):
    norm_conf = []
    for row in conf_arr:
        row_sum = float(row.sum())
        norm_conf.append([float(v) / max(1.0, row_sum) for v in row])

    fig = plt.figure()
    plt.clf()
    ax = fig.add_subplot(111)
    ax.set_aspect(1)
    res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, interpolation='nearest')

    width, height = conf_arr.shape

    fig.colorbar(res)
    alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    plt.xticks(range(width), alphabet[:width])
    plt.yticks(range(height), alphabet[:height])
    plt.savefig('confusion_matrix.png', format='png')


def get_lip_palette():
    palette = [0, 0, 0,
               128, 0, 0,
               255, 0, 0,
               0, 85, 0,
               170, 0, 51,
               255, 85, 0,
               0, 0, 85,
               0, 119, 221,
               85, 85, 0,
               0, 85, 85,
               85, 51, 0,
               52, 86, 128,
               0, 128, 0,
               0, 0, 255,
               51, 170, 221,
               0, 255, 255,
               85, 255, 170,
               170, 255, 85,
               255, 255, 0,
               255, 170, 0]
    return palette


if __name__ == '__main__':
    main()
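A quick sanity check of the metric block above: pixel accuracy is tp.sum() / pos.sum(), and per-class IoU is TP / (GT + Pred - TP), all read off the accumulated confusion matrix. A minimal sketch on a made-up 3-class matrix (the numbers are illustrative only):

import numpy as np

# Hypothetical confusion matrix: rows = ground truth, columns = prediction.
cm = np.array([[50., 2., 3.],
               [4., 40., 1.],
               [2., 2., 30.]])

pos = cm.sum(1)                              # ground-truth pixels per class
res = cm.sum(0)                              # predicted pixels per class
tp = np.diag(cm)                             # true positives per class

pixel_acc = tp.sum() / pos.sum()             # 120 / 134 ~= 0.8955
iou = tp / np.maximum(1.0, pos + res - tp)   # per-class IoU
print(pixel_acc, iou.mean())                 # mean IoU ~= 0.8085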
--------------------------------------------------------------------------------
/val/evaluate_lip.py:
--------------------------------------------------------------------------------
import argparse
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data

from dataset.data_lip import LIPValGenerator  # the original `dataset.datasets` module does not exist in this tree
from network.baseline import get_model


def get_arguments():
    """Parse all the arguments provided from the CLI.

    Returns:
      The parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Pytorch Segmentation")
    parser.add_argument("--root", type=str, default='./data/LIP/val_set/')
    parser.add_argument("--data-list", type=str, default='./dataset/LIP/val_id.txt')
    parser.add_argument("--crop-size", type=int, default=473)
    parser.add_argument("--num-classes", type=int, default=20)
    parser.add_argument("--ignore-label", type=int, default=255)
    parser.add_argument("--restore-from", type=str,
                        default='./checkpoints/exp/model_best.pth')
    parser.add_argument("--is-mirror", action="store_true")
    parser.add_argument('--eval-scale', nargs='+', type=float, default=[1.0])
    # parser.add_argument('--eval-scale', nargs='+', type=float, default=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75])

    parser.add_argument("--save-dir", type=str)
    parser.add_argument("--gpu", type=str, default='0')
    return parser.parse_args()


def main():
    """Create the model and start the evaluation process."""
    args = get_arguments()

    # initialization
    print("Input arguments:")
    for key, val in vars(args).items():
        print("{:16} {}".format(key, val))

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    model = get_model(num_classes=args.num_classes)

    # if not os.path.exists(args.save_dir):
    #     os.makedirs(args.save_dir)

    palette = get_lip_palette()
    restore_from = args.restore_from
    saved_state_dict = torch.load(restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda()

    testloader = data.DataLoader(LIPValGenerator(args.root, args.data_list, crop_size=args.crop_size),
                                 batch_size=1, shuffle=False, pin_memory=True)

    confusion_matrix = np.zeros((args.num_classes, args.num_classes))

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d images have been processed' % index)
        image, label, ori_size, name = batch

        ori_size = ori_size[0].numpy()

        # np.asscalar was removed in NumPy 1.23; plain int() does the same job.
        output = predict(model, image.numpy(), (int(ori_size[0]), int(ori_size[1])),
                         is_mirror=args.is_mirror, scales=args.eval_scale)
        seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        # output_im = PILImage.fromarray(seg_pred)
        # output_im.putpalette(palette)
        # output_im.save(args.save_dir + name[0] + '.png')

        seg_gt = np.asarray(label[0].numpy(), dtype=np.int64)  # np.int was removed in NumPy 1.24
        ignore_index = seg_gt != 255
        seg_gt = seg_gt[ignore_index]
        seg_pred = seg_pred[ignore_index]

        confusion_matrix += get_confusion_matrix(seg_gt, seg_pred, args.num_classes)

    pos = confusion_matrix.sum(1)   # ground-truth pixels per class
    res = confusion_matrix.sum(0)   # predicted pixels per class
    tp = np.diag(confusion_matrix)  # correctly classified pixels per class

    pixel_accuracy = tp.sum() / pos.sum()
    mean_accuracy = (tp / np.maximum(1.0, pos)).mean()
    IU_array = tp / np.maximum(1.0, pos + res - tp)
    mean_IU = IU_array.mean()

    # get_confusion_matrix_plot()

    print('Pixel accuracy: %f' % pixel_accuracy)
    print('Mean accuracy: %f' % mean_accuracy)
    print('Mean IU: %f' % mean_IU)
    for index, IU in enumerate(IU_array):
        print('class %d IU: %f' % (index, IU))


def scale_image(image, scale):
    """Rescale a 1xCxHxW numpy image by `scale`, returning a CxHxW array."""
    image = image[0, :, :, :]
    image = image.transpose((1, 2, 0))  # CHW -> HWC for cv2.resize
    image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    image = image.transpose((2, 0, 1))  # back to CHW
    return image


def predict(net, image, output_size, is_mirror=True, scales=(1,)):
    if is_mirror:
        image_rev = image[:, :, :, ::-1]

    interp = nn.Upsample(size=output_size, mode='bilinear', align_corners=True)

    outputs = []
    if is_mirror:
        for scale in scales:
            if scale != 1:
                image_scale = scale_image(image=image, scale=scale)
                image_rev_scale = scale_image(image=image_rev, scale=scale)
            else:
                image_scale = image[0, :, :, :]
                image_rev_scale = image_rev[0, :, :, :]

            # Batch the original and the flipped image together.
            image_scale = np.stack((image_scale, image_rev_scale))

            with torch.no_grad():
                prediction = net(torch.from_numpy(image_scale).cuda())
                prediction = interp(prediction[0]).cpu().numpy()

            prediction_rev = prediction[1, :, :, :].copy()
            # Undo the horizontal flip: swap the left/right paired LIP classes
            # (14/15 arms, 16/17 legs, 18/19 shoes) before flipping back.
            prediction_rev[14, :, :] = prediction[1, 15, :, :]
            prediction_rev[15, :, :] = prediction[1, 14, :, :]
            prediction_rev[16, :, :] = prediction[1, 17, :, :]
            prediction_rev[17, :, :] = prediction[1, 16, :, :]
            prediction_rev[18, :, :] = prediction[1, 19, :, :]
            prediction_rev[19, :, :] = prediction[1, 18, :, :]
            prediction_rev = prediction_rev[:, :, ::-1]
            prediction = prediction[0, :, :, :]
            prediction = np.mean([prediction, prediction_rev], axis=0)

            outputs.append(prediction)

        outputs = np.mean(outputs, axis=0)
        outputs = outputs.transpose(1, 2, 0)
    else:
        for scale in scales:
            if scale != 1:
                image_scale = scale_image(image=image, scale=scale)
            else:
                image_scale = image[0, :, :, :]

            with torch.no_grad():
                prediction = net(torch.from_numpy(image_scale).unsqueeze(0).cuda())
                prediction = interp(prediction[0]).cpu().numpy()
            outputs.append(prediction[0, :, :, :])

        outputs = np.mean(outputs, axis=0)
        outputs = outputs.transpose(1, 2, 0)

    return outputs


def get_confusion_matrix(gt_label, pred_label, class_num):
    """
    Calculate the confusion matrix for the given labels and predictions.
    :param gt_label: the ground-truth label map
    :param pred_label: the predicted label map
    :param class_num: the number of classes
    """
    index = (gt_label * class_num + pred_label).astype('int32')
    label_count = np.bincount(index)
    confusion_matrix = np.zeros((class_num, class_num))

    for i_label in range(class_num):
        for i_pred_label in range(class_num):
            cur_index = i_label * class_num + i_pred_label
            if cur_index < len(label_count):
                confusion_matrix[i_label, i_pred_label] = label_count[cur_index]

    return confusion_matrix


def get_confusion_matrix_plot(conf_arr):
    norm_conf = []
    for row in conf_arr:
        row_sum = float(row.sum())
        norm_conf.append([float(v) / max(1.0, row_sum) for v in row])

    fig = plt.figure()
    plt.clf()
    ax = fig.add_subplot(111)
    ax.set_aspect(1)
    res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, interpolation='nearest')

    width, height = conf_arr.shape

    fig.colorbar(res)
    alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    plt.xticks(range(width), alphabet[:width])
    plt.yticks(range(height), alphabet[:height])
    plt.savefig('confusion_matrix.png', format='png')


def get_lip_palette():
    palette = [0, 0, 0,
               128, 0, 0,
               255, 0, 0,
               0, 85, 0,
               170, 0, 51,
               255, 85, 0,
               0, 0, 85,
               0, 119, 221,
               85, 85, 0,
               0, 85, 85,
               85, 51, 0,
               52, 86, 128,
               0, 128, 0,
               0, 0, 255,
               51, 170, 221,
               0, 255, 255,
               85, 255, 170,
               170, 255, 85,
               255, 255, 0,
               255, 170, 0]
    return palette


if __name__ == '__main__':
    main()
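Unlike the other evaluation scripts, predict() in evaluate_lip.py must repair the flipped prediction before averaging: LIP distinguishes left from right (arms 14/15, legs 16/17, shoes 18/19), so a horizontal flip also exchanges those labels. A sketch of the same idea as a reusable helper; the helper name and pair-list constant are ours, matching the indices hard-coded above:

import numpy as np

# Class channels that exchange under a horizontal flip in the LIP label set.
LIP_FLIP_PAIRS = [(14, 15), (16, 17), (18, 19)]

def unflip_prediction(pred, pairs=LIP_FLIP_PAIRS):
    """Undo a horizontal flip on a CxHxW score map, swapping paired channels."""
    out = pred.copy()
    for a, b in pairs:
        out[a] = pred[b]
        out[b] = pred[a]
    return out[:, :, ::-1]  # flip back along the width axis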
--------------------------------------------------------------------------------
/val/evaluate_pascal.py:
--------------------------------------------------------------------------------
import argparse
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data

from dataset.data_pascal import TestGenerator
from network.baseline import get_model


def get_arguments():
    """Parse all the arguments provided from the CLI.

    Returns:
      The parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Pytorch Segmentation")
    parser.add_argument('--root', default='./data/Person', type=str)
    parser.add_argument("--data-list", type=str, default='./dataset/Pascal/val_id.txt')
    parser.add_argument("--crop-size", type=int, default=473)
    parser.add_argument("--num-classes", type=int, default=7)
    parser.add_argument("--ignore-label", type=int, default=255)
    parser.add_argument('--restore-from', default='./checkpoints/exp/model_best.pth', type=str)

    parser.add_argument("--is-mirror", action="store_true")
    parser.add_argument('--eval-scale', nargs='+', type=float, default=[1.0])
    # parser.add_argument('--eval-scale', nargs='+', type=float, default=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75])
    parser.add_argument("--save-dir", type=str)
    parser.add_argument("--gpu", type=str, default='0')
    return parser.parse_args()


def main():
    """Create the model and start the evaluation process."""
    args = get_arguments()

    # initialization
    print("Input arguments:")
    for key, val in vars(args).items():
        print("{:16} {}".format(key, val))

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    model = get_model(num_classes=args.num_classes)

    # if not os.path.exists(args.save_dir):
    #     os.makedirs(args.save_dir)

    palette = get_lip_palette()
    restore_from = args.restore_from
    saved_state_dict = torch.load(restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda()

    testloader = data.DataLoader(TestGenerator(args.root, args.data_list, crop_size=args.crop_size),
                                 batch_size=1, shuffle=False, pin_memory=True)

    confusion_matrix = np.zeros((args.num_classes, args.num_classes))

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d images have been processed' % index)
        image, label, ori_size, name = batch

        # Debug leftovers from development:
        # img_name = "/home/hlzhu/hlzhu/Iter_ParseNet_final/data/Person/JPEGImages/" + name[0] + '.jpg'
        # print(img_name)
        # ori_img = cv2.imread(img_name)
        # cv2.imshow('image', ori_img)
        # cv2.waitKey(1)
        # 2008_000195 multi person
        # 2008_002829 single person
        if name[0] == "2008_002829":
            print("2008_002829.jpg")
        ori_size = ori_size[0].numpy()

        # np.asscalar was removed in NumPy 1.23; plain int() does the same job.
        output = predict(model, image.numpy(), (int(ori_size[0]), int(ori_size[1])),
                         is_mirror=args.is_mirror, scales=args.eval_scale)
        seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        # output_im = PILImage.fromarray(seg_pred)
        # output_im.putpalette(palette)
        # output_im.save(args.save_dir + name[0] + '.png')

        seg_gt = np.asarray(label[0].numpy(), dtype=np.int64)  # np.int was removed in NumPy 1.24
        ignore_index = seg_gt != 255
        seg_gt = seg_gt[ignore_index]
        seg_pred = seg_pred[ignore_index]

        confusion_matrix += get_confusion_matrix(seg_gt, seg_pred, args.num_classes)

    pos = confusion_matrix.sum(1)   # ground-truth pixels per class
    res = confusion_matrix.sum(0)   # predicted pixels per class
    tp = np.diag(confusion_matrix)  # correctly classified pixels per class

    pixel_accuracy = tp.sum() / pos.sum()
    mean_accuracy = (tp / np.maximum(1.0, pos)).mean()
    IU_array = tp / np.maximum(1.0, pos + res - tp)
    mean_IU = IU_array.mean()

    # get_confusion_matrix_plot()

    print('Pixel accuracy: %f' % pixel_accuracy)
    print('Mean accuracy: %f' % mean_accuracy)
    print('Mean IU: %f' % mean_IU)
    for index, IU in enumerate(IU_array):
        print('class %d IU: %f' % (index, IU))


def scale_image(image, scale):
    """Rescale a 1xCxHxW numpy image by `scale`, returning a CxHxW array."""
    image = image[0, :, :, :]
    image = image.transpose((1, 2, 0))  # CHW -> HWC for cv2.resize
    image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    image = image.transpose((2, 0, 1))  # back to CHW
    return image


def predict(net, image, output_size, is_mirror=True, scales=(1,)):
    if is_mirror:
        image_rev = image[:, :, :, ::-1]

    interp = nn.Upsample(size=output_size, mode='bilinear', align_corners=True)

    outputs = []
    if is_mirror:
        for scale in scales:
            if scale != 1:
                image_scale = scale_image(image=image, scale=scale)
                image_rev_scale = scale_image(image=image_rev, scale=scale)
            else:
                image_scale = image[0, :, :, :]
                image_rev_scale = image_rev[0, :, :, :]

            # Batch the original and the flipped image together.
            image_scale = np.stack((image_scale, image_rev_scale))

            with torch.no_grad():
                prediction = net(torch.from_numpy(image_scale).cuda())
                prediction = interp(prediction[0]).cpu().numpy()

            prediction_rev = prediction[1, :, :, :].copy()
            prediction_rev = prediction_rev[:, :, ::-1]  # flip the mirrored prediction back
            prediction = prediction[0, :, :, :]
            prediction = np.mean([prediction, prediction_rev], axis=0)

            outputs.append(prediction)

        outputs = np.mean(outputs, axis=0)
        outputs = outputs.transpose(1, 2, 0)
    else:
        for scale in scales:
            if scale != 1:
                image_scale = scale_image(image=image, scale=scale)
            else:
                image_scale = image[0, :, :, :]

            with torch.no_grad():
                prediction = net(torch.from_numpy(image_scale).unsqueeze(0).cuda())
                prediction = interp(prediction[0]).cpu().numpy()
            outputs.append(prediction[0, :, :, :])

        outputs = np.mean(outputs, axis=0)
        outputs = outputs.transpose(1, 2, 0)

    return outputs


def get_confusion_matrix(gt_label, pred_label, class_num):
    """
    Calculate the confusion matrix for the given labels and predictions.
    :param gt_label: the ground-truth label map
    :param pred_label: the predicted label map
    :param class_num: the number of classes
    """
    index = (gt_label * class_num + pred_label).astype('int32')
    label_count = np.bincount(index)
    confusion_matrix = np.zeros((class_num, class_num))

    for i_label in range(class_num):
        for i_pred_label in range(class_num):
            cur_index = i_label * class_num + i_pred_label
            if cur_index < len(label_count):
                confusion_matrix[i_label, i_pred_label] = label_count[cur_index]

    return confusion_matrix


def get_confusion_matrix_plot(conf_arr):
    norm_conf = []
    for row in conf_arr:
        row_sum = float(row.sum())
        norm_conf.append([float(v) / max(1.0, row_sum) for v in row])

    fig = plt.figure()
    plt.clf()
    ax = fig.add_subplot(111)
    ax.set_aspect(1)
    res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, interpolation='nearest')

    width, height = conf_arr.shape

    fig.colorbar(res)
    alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    plt.xticks(range(width), alphabet[:width])
    plt.yticks(range(height), alphabet[:height])
    plt.savefig('confusion_matrix.png', format='png')


def get_lip_palette():
    palette = [0, 0, 0,
               128, 0, 0,
               255, 0, 0,
               0, 85, 0,
               170, 0, 51,
               255, 85, 0,
               0, 0, 85,
               0, 119, 221,
               85, 85, 0,
               0, 85, 85,
               85, 51, 0,
               52, 86, 128,
               0, 128, 0,
               0, 0, 255,
               51, 170, 221,
               0, 255, 255,
               85, 255, 170,
               170, 255, 85,
               255, 255, 0,
               255, 170, 0]
    return palette


if __name__ == '__main__':
    main()
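get_confusion_matrix() already bins (gt, pred) pairs with np.bincount, so the Python double loop that scatters the counts can be replaced by fixing bincount's output length and reshaping. A sketch of an equivalent vectorized version (assuming, as these scripts do, that all labels lie in [0, class_num)):

import numpy as np

def get_confusion_matrix_vectorized(gt_label, pred_label, class_num):
    # Each (gt, pred) pair maps to a unique bin in [0, class_num ** 2).
    index = (gt_label * class_num + pred_label).astype('int32')
    counts = np.bincount(index, minlength=class_num ** 2)
    return counts.reshape(class_num, class_num).astype(np.float64)

For the class counts used here (7 to 20 classes) the speed difference is minor; the gain is mostly brevity.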
--------------------------------------------------------------------------------
/val/evaluate_ppss.py:
--------------------------------------------------------------------------------
import argparse
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data

from dataset.data_ppss import TestGenerator  # fixed typo: the module is data_ppss.py, not datappss
from network.baseline import get_model


def get_arguments():
    """Parse all the arguments provided from the CLI.

    Returns:
      The parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Pytorch Segmentation")
    parser.add_argument('--root', default='./data/PPSS/TestData/', type=str)
    parser.add_argument("--data-list", type=str, default='./dataset/PPSS/test_id.txt')
    # NOTE: type=tuple would mangle a value passed on the CLI; only the default is safe here.
    parser.add_argument("--crop-size", type=tuple, default=(321, 321))
    parser.add_argument("--num-classes", type=int, default=8)
    parser.add_argument("--ignore-label", type=int, default=255)
    parser.add_argument('--restore-from', default='./checkpoints/exp/model_best.pth', type=str)

    parser.add_argument("--is-mirror", action="store_true")
    parser.add_argument('--eval-scale', nargs='+', type=float, default=[1.0])
    # parser.add_argument('--eval-scale', nargs='+', type=float, default=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75])
    parser.add_argument("--save-dir", type=str)
    parser.add_argument("--gpu", type=str, default='0')
    return parser.parse_args()


def main():
    """Create the model and start the evaluation process."""
    args = get_arguments()

    # initialization
    print("Input arguments:")
    for key, val in vars(args).items():
        print("{:16} {}".format(key, val))

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    model = get_model(num_classes=args.num_classes)

    # if not os.path.exists(args.save_dir):
    #     os.makedirs(args.save_dir)

    palette = get_lip_palette()
    restore_from = args.restore_from
    saved_state_dict = torch.load(restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda()

    testloader = data.DataLoader(TestGenerator(args.root, args.data_list, crop_size=args.crop_size),
                                 batch_size=1, shuffle=False, pin_memory=True)

    confusion_matrix = np.zeros((args.num_classes, args.num_classes))

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d images have been processed' % index)
        image, label, ori_size, name = batch

        ori_size = ori_size[0].numpy()

        # np.asscalar was removed in NumPy 1.23; plain int() does the same job.
        output = predict(model, image.numpy(), (int(ori_size[0]), int(ori_size[1])),
                         is_mirror=args.is_mirror, scales=args.eval_scale)
        seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        # output_im = PILImage.fromarray(seg_pred)
        # output_im.putpalette(palette)
        # output_im.save(args.save_dir + name[0] + '.png')

        seg_gt = np.asarray(label[0].numpy(), dtype=np.int64)  # np.int was removed in NumPy 1.24
        ignore_index = seg_gt != 255
        seg_gt = seg_gt[ignore_index]
        seg_pred = seg_pred[ignore_index]

        confusion_matrix += get_confusion_matrix(seg_gt, seg_pred, args.num_classes)

    pos = confusion_matrix.sum(1)   # ground-truth pixels per class
    res = confusion_matrix.sum(0)   # predicted pixels per class
    tp = np.diag(confusion_matrix)  # correctly classified pixels per class

    pixel_accuracy = tp.sum() / pos.sum()
    mean_accuracy = (tp / np.maximum(1.0, pos)).mean()
    IU_array = tp / np.maximum(1.0, pos + res - tp)
    mean_IU = IU_array.mean()

    # get_confusion_matrix_plot()

    print('Pixel accuracy: %f' % pixel_accuracy)
    print('Mean accuracy: %f' % mean_accuracy)
    print('Mean IU: %f' % mean_IU)
    for index, IU in enumerate(IU_array):
        print('class %d IU: %f' % (index, IU))


def scale_image(image, scale):
    """Rescale a 1xCxHxW numpy image by `scale`, returning a CxHxW array."""
    image = image[0, :, :, :]
    image = image.transpose((1, 2, 0))  # CHW -> HWC for cv2.resize
    image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    image = image.transpose((2, 0, 1))  # back to CHW
    return image


def predict(net, image, output_size, is_mirror=True, scales=(1,)):
    if is_mirror:
        image_rev = image[:, :, :, ::-1]

    interp = nn.Upsample(size=output_size, mode='bilinear', align_corners=True)

    outputs = []
    if is_mirror:
        for scale in scales:
            if scale != 1:
                image_scale = scale_image(image=image, scale=scale)
                image_rev_scale = scale_image(image=image_rev, scale=scale)
            else:
                image_scale = image[0, :, :, :]
                image_rev_scale = image_rev[0, :, :, :]

            # Batch the original and the flipped image together.
            image_scale = np.stack((image_scale, image_rev_scale))

            with torch.no_grad():
                prediction = net(torch.from_numpy(image_scale).cuda())
                prediction = interp(prediction[0]).cpu().numpy()

            prediction_rev = prediction[1, :, :, :].copy()
            prediction_rev = prediction_rev[:, :, ::-1]  # flip the mirrored prediction back
            prediction = prediction[0, :, :, :]
            prediction = np.mean([prediction, prediction_rev], axis=0)

            outputs.append(prediction)

        outputs = np.mean(outputs, axis=0)
        outputs = outputs.transpose(1, 2, 0)
    else:
        for scale in scales:
            if scale != 1:
                image_scale = scale_image(image=image, scale=scale)
            else:
                image_scale = image[0, :, :, :]

            with torch.no_grad():
                prediction = net(torch.from_numpy(image_scale).unsqueeze(0).cuda())
                prediction = interp(prediction[0]).cpu().numpy()
            outputs.append(prediction[0, :, :, :])

        outputs = np.mean(outputs, axis=0)
        outputs = outputs.transpose(1, 2, 0)

    return outputs


def get_confusion_matrix(gt_label, pred_label, class_num):
    """
    Calculate the confusion matrix for the given labels and predictions.
    :param gt_label: the ground-truth label map
    :param pred_label: the predicted label map
    :param class_num: the number of classes
    """
    index = (gt_label * class_num + pred_label).astype('int32')
    label_count = np.bincount(index)
    confusion_matrix = np.zeros((class_num, class_num))

    for i_label in range(class_num):
        for i_pred_label in range(class_num):
            cur_index = i_label * class_num + i_pred_label
            if cur_index < len(label_count):
                confusion_matrix[i_label, i_pred_label] = label_count[cur_index]

    return confusion_matrix


def get_confusion_matrix_plot(conf_arr):
    norm_conf = []
    for row in conf_arr:
        row_sum = float(row.sum())
        norm_conf.append([float(v) / max(1.0, row_sum) for v in row])

    fig = plt.figure()
    plt.clf()
    ax = fig.add_subplot(111)
    ax.set_aspect(1)
    res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, interpolation='nearest')

    width, height = conf_arr.shape

    fig.colorbar(res)
    alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    plt.xticks(range(width), alphabet[:width])
    plt.yticks(range(height), alphabet[:height])
    plt.savefig('confusion_matrix.png', format='png')


def get_lip_palette():
    palette = [0, 0, 0,
               128, 0, 0,
               255, 0, 0,
               0, 85, 0,
               170, 0, 51,
               255, 85, 0,
               0, 0, 85,
               0, 119, 221,
               85, 85, 0,
               0, 85, 85,
               85, 51, 0,
               52, 86, 128,
               0, 128, 0,
               0, 0, 255,
               51, 170, 221,
               0, 255, 255,
               85, 255, 170,
               170, 255, 85,
               255, 255, 0,
               255, 170, 0]
    return palette


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
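All four scripts default to single-scale, un-mirrored evaluation; the commented-out --eval-scale line shows the intended multi-scale protocol. A typical invocation from the repository root might look like the following (paths are the argparse defaults above; this assumes the repo root is on PYTHONPATH so the dataset/ and network/ packages resolve):

python val/evaluate_lip.py \
    --root ./data/LIP/val_set/ \
    --data-list ./dataset/LIP/val_id.txt \
    --restore-from ./checkpoints/exp/model_best.pth \
    --is-mirror \
    --eval-scale 0.50 0.75 1.0 1.25 1.50 1.75 \
    --gpu 0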