├── .idea
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── libraries
│   │   └── R_User_Library.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── parse2.iml
│   ├── vcs.xml
│   └── workspace.xml
├── .vscode
│   └── .ropeproject
│       └── config.py
├── README.md
├── checkpoints
│   ├── exp
│   │   └── result checkpoints.txt
│   └── init
│       └── pretrained models.txt
├── dataset
│   ├── ATR
│   │   ├── README.MD
│   │   ├── select_id.txt
│   │   ├── test_id.txt
│   │   └── train_id.txt
│   ├── CCF
│   │   ├── select_id.txt
│   │   ├── test_id.txt
│   │   └── train_id.txt
│   ├── CIHP
│   │   ├── README.md
│   │   ├── all_id.txt
│   │   ├── human_colormap.mat
│   │   ├── test_id.txt
│   │   ├── train_id.txt
│   │   ├── trainval_id.txt
│   │   └── val_id.txt
│   ├── LIP
│   │   ├── README.md
│   │   ├── hard_id.txt
│   │   ├── train_id.txt
│   │   ├── train_val.txt
│   │   ├── val.txt
│   │   └── val_id.txt
│   ├── PPSS
│   │   ├── test_id.txt
│   │   └── train_id.txt
│   ├── Pascal
│   │   ├── README.MD
│   │   ├── train_id.txt
│   │   └── val_id.txt
│   ├── __init__.py
│   ├── data_CIHP.py
│   ├── data_atr.py
│   ├── data_ccf.py
│   ├── data_lip.py
│   ├── data_pascal.py
│   ├── data_ppss.py
│   ├── data_transforms.py
│   ├── transforms.py
│   └── weights.py
├── doc
│   └── architecture.png
├── evaluate_pascal.py
├── evaluate_pascal.sh
├── inplace_abn
│   ├── __init__.py
│   ├── bn.py
│   ├── functions.py
│   └── src
│       ├── checks.h
│       ├── common.h
│       ├── inplace_abn.cpp
│       ├── inplace_abn.h
│       ├── inplace_abn_cpu.cpp
│       ├── inplace_abn_cuda.cu
│       ├── inplace_abn_cuda_half.cu
│       └── utils
│           ├── checks.h
│           ├── common.h
│           └── cuda.cuh
├── modules
│   ├── __init__.py
│   ├── com_mod.py
│   ├── convGRU.py
│   ├── inits.py
│   └── parse_mod.py
├── network
│   ├── ResNet_stem_converter.py
│   ├── __init__.py
│   ├── baseline.py
│   └── gnn_parse.py
├── requirements.txt
├── train
│   ├── train_atr.py
│   ├── train_ccf.py
│   ├── train_lip.py
│   ├── train_pascal.py
│   └── train_ppss.py
├── train_baseline.py
├── train_pascal.sh
├── utils
│   ├── __init__.py
│   ├── aaf
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── layers.cpython-36.pyc
│   │   │   └── losses.cpython-36.pyc
│   │   ├── layers.py
│   │   └── losses.py
│   ├── best
│   │   └── lovasz_loss.py
│   ├── gnn_loss.py
│   ├── learning_policy.py
│   ├── lovasz_loss.py
│   ├── metric.py
│   ├── parallel.py
│   └── visualize.py
└── val
    ├── evaluate_atr.py
    ├── evaluate_ccf.py
    ├── evaluate_lip.py
    ├── evaluate_pascal.py
    ├── evaluate_ppss.py
    ├── f1_eval.py
    └── f1_eval_atr.py
/.vscode/.ropeproject/config.py:
--------------------------------------------------------------------------------
1 | # The default ``config.py``
2 | # flake8: noqa
3 |
4 |
5 | def set_prefs(prefs):
6 | """This function is called before opening the project"""
7 |
8 | # Specify which files and folders to ignore in the project.
9 | # Changes to ignored resources are not added to the history and
10 | # VCSs. Also they are not returned in `Project.get_files()`.
11 | # Note that ``?`` and ``*`` match all characters but slashes.
12 | # '*.pyc': matches 'test.pyc' and 'pkg/test.pyc'
13 | # 'mod*.pyc': matches 'test/mod1.pyc' but not 'mod/1.pyc'
14 | # '.svn': matches 'pkg/.svn' and all of its children
15 | # 'build/*.o': matches 'build/lib.o' but not 'build/sub/lib.o'
16 | # 'build//*.o': matches 'build/lib.o' and 'build/sub/lib.o'
17 | prefs['ignored_resources'] = ['*.pyc', '*~', '.ropeproject',
18 | '.hg', '.svn', '_svn', '.git', '.tox']
19 |
20 | # Specifies which files should be considered python files. It is
21 | # useful when you have scripts inside your project. Only files
22 | # ending with ``.py`` are considered to be python files by
23 | # default.
24 | # prefs['python_files'] = ['*.py']
25 |
26 | # Custom source folders: By default rope searches the project
27 | # for finding source folders (folders that should be searched
28 | # for finding modules). You can add paths to that list. Note
29 | # that rope guesses project source folders correctly most of the
30 | # time; use this if you have any problems.
31 | # The folders should be relative to project root and use '/' for
32 | # separating folders regardless of the platform rope is running on.
33 | # 'src/my_source_folder' for instance.
34 | # prefs.add('source_folders', 'src')
35 |
36 | # You can extend python path for looking up modules
37 | # prefs.add('python_path', '~/python/')
38 |
39 | # Should rope save object information or not.
40 | prefs['save_objectdb'] = True
41 | prefs['compress_objectdb'] = False
42 |
43 | # If `True`, rope analyzes each module when it is being saved.
44 | prefs['automatic_soa'] = True
45 | # The depth of calls to follow in static object analysis
46 | prefs['soa_followed_calls'] = 0
47 |
48 |     # If `False`, "dynamic object analysis" is turned off when running
49 |     # modules or unit tests. This makes them much faster.
50 | prefs['perform_doa'] = True
51 |
52 | # Rope can check the validity of its object DB when running.
53 | prefs['validate_objectdb'] = True
54 |
55 | # How many undos to hold?
56 | prefs['max_history_items'] = 32
57 |
58 | # Shows whether to save history across sessions.
59 | prefs['save_history'] = True
60 | prefs['compress_history'] = False
61 |
62 |     # Set the number of spaces used for indenting. According to
63 | # :PEP:`8`, it is best to use 4 spaces. Since most of rope's
64 | # unit-tests use 4 spaces it is more reliable, too.
65 | prefs['indent_size'] = 4
66 |
67 | # Builtin and c-extension modules that are allowed to be imported
68 | # and inspected by rope.
69 | prefs['extension_modules'] = []
70 |
71 | # Add all standard c-extensions to extension_modules list.
72 | prefs['import_dynload_stdmods'] = True
73 |
74 | # If `True` modules with syntax errors are considered to be empty.
75 |     # The default value is `False`; when `False`, syntax errors raise a
76 | # `rope.base.exceptions.ModuleSyntaxError` exception.
77 | prefs['ignore_syntax_errors'] = False
78 |
79 | # If `True`, rope ignores unresolvable imports. Otherwise, they
80 | # appear in the importing namespace.
81 | prefs['ignore_bad_imports'] = False
82 |
83 | # If `True`, rope will insert new module imports as
84 |     # `from <package> import <module>` by default.
85 | prefs['prefer_module_from_imports'] = False
86 |
87 | # If `True`, rope will transform a comma list of imports into
88 | # multiple separate import statements when organizing
89 | # imports.
90 | prefs['split_imports'] = False
91 |
92 | # If `True`, rope will remove all top-level import statements and
93 | # reinsert them at the top of the module when making changes.
94 | prefs['pull_imports_to_top'] = True
95 |
96 | # If `True`, rope will sort imports alphabetically by module name instead
97 | # of alphabetically by import statement, with from imports after normal
98 | # imports.
99 | prefs['sort_imports_alphabetically'] = False
100 |
101 |     # Location of the implementation of
102 |     # rope.base.oi.type_hinting.interfaces.ITypeHintingFactory. In the general
103 |     # case, you don't have to change this value unless you're a rope expert.
104 |     # Change this value to inject your own implementations of the interfaces
105 |     # listed in module rope.base.oi.type_hinting.providers.interfaces.
106 |     # For example, you can add your own providers for Django Models, or disable
107 |     # the search for type hints in a class hierarchy, etc.
108 | prefs['type_hinting_factory'] = (
109 | 'rope.base.oi.type_hinting.factory.default_type_hinting_factory')
110 |
111 |
112 | def project_opened(project):
113 | """This function is called after opening the project"""
114 | # Do whatever you like here!
115 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hierarchical Human Parsing with Typed Part-Relation Reasoning (CVPR2020)
2 |
3 | ## Introduction
4 | The algorithm is described in the [CVPR 2020 paper: Hierarchical Human Parsing with Typed Part-Relation Reasoning](https://openaccess.thecvf.com/content_CVPR_2020/papers/Wang_Hierarchical_Human_Parsing_With_Typed_Part-Relation_Reasoning_CVPR_2020_paper.pdf).
5 |
6 | 
7 | ***
8 |
9 | ## Environment and installation
10 | This repository was developed with **CUDA 10.0** and **PyTorch 1.2.0** on **Python 3.6**. The required packages can be installed with:
11 | ```bash
12 | pip install -r requirements.txt
13 | ```
14 |
15 | ## Structure of repo
16 | ```bash
17 | $HierarchicalHumanParsing
18 | ├── checkpoints
19 | │   └── init
20 | ├── dataset
21 | │   └── list
22 | ├── doc
23 | ├── inplace_abn
24 | │   └── src
25 | ├── modules
26 | ├── network
27 | └── utils
28 | ```
29 |
30 | ## Running the code
31 | ```bash
32 | python evaluate_pascal.py
33 | ```
34 |
35 | ***
36 | ## Citation
37 | If you find this code useful, please cite the related works with the following BibTeX entries:
38 | ```
39 | @InProceedings{Wang_2020_CVPR,
40 | author = {Wang, Wenguan and Zhu, Hailong and Dai, Jifeng and Pang, Yanwei and Shen, Jianbing and Shao, Ling},
41 | title = {Hierarchical Human Parsing With Typed Part-Relation Reasoning},
42 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
43 | month = {June},
44 | year = {2020}
45 | }
46 |
47 | @InProceedings{Wang_2019_ICCV,
48 | author = {Wang, Wenguan and Zhang, Zhijie and Qi, Siyuan and Shen, Jianbing and Pang, Yanwei and Shao, Ling},
49 | title = {Learning Compositional Neural Information Fusion for Human Parsing},
50 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
51 | month = {October},
52 | year = {2019}
53 | }
54 | ```
55 |
--------------------------------------------------------------------------------
/checkpoints/exp/result checkpoints.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/checkpoints/exp/result checkpoints.txt
--------------------------------------------------------------------------------
/checkpoints/init/pretrained models.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/checkpoints/init/pretrained models.txt
--------------------------------------------------------------------------------
/dataset/ATR/README.MD:
--------------------------------------------------------------------------------
1 | background 0
2 | hat 1
3 | hair 2
4 | sunglass 3
5 | upper-clothes 4
6 | skirt 5
7 | pants 6
8 | dress 7
9 | belt 8
10 | left-shoe 9
11 | right-shoe 10
12 | face 11
13 | left-leg 12
14 | right-leg 13
15 | left-arm 14
16 | right-arm 15
17 | bag 16
18 | scarf 17
19 |
20 | 0 background
21 | 1-4 up
22 | 5-10 down
23 | 11 up
24 | 12-13 down
25 | 14-17 up
26 |
--------------------------------------------------------------------------------
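The second block above groups the ATR part labels into upper/lower body. A minimal numpy sketch of that grouping, mirroring the eval branch of dataset/data_atr.py (the function name is ours):

```python
import numpy as np

def atr_half_body(seg: np.ndarray) -> np.ndarray:
    """Collapse ATR part labels into 0 = background, 1 = upper body, 2 = lower body."""
    half = seg.copy()
    half[(half > 0) & (half <= 4)] = 1    # hat, hair, sunglass, upper-clothes
    half[(half > 4) & (half <= 10)] = 2   # skirt, pants, dress, belt, shoes
    half[half == 11] = 1                  # face
    half[(half > 11) & (half <= 13)] = 2  # legs
    half[(half > 13) & (half < 255)] = 1  # arms, bag, scarf (255 = ignore)
    return half
```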
/dataset/ATR/select_id.txt:
--------------------------------------------------------------------------------
1 | 997_392
2 | 997_586
3 | 2500_175
4 | 2500_198
5 | 2500_434
6 | 2500_520
7 | 2500_534
8 | 2500_783
9 | 4565_1329
10 | 4565_1414
11 | 4565_1427
12 | 4565_1572
13 | 4565_1575
14 | 4565_1624
15 | 4565_1730
16 | 4565_1991
17 | 4565_2002
18 | 4565_2046
19 | 4565_2672
20 | 4565_2732
--------------------------------------------------------------------------------
/dataset/CCF/select_id.txt:
--------------------------------------------------------------------------------
1 | 0770
2 | 0913
3 | 0474
4 | 0840
5 | 0662
6 | 0315
7 | 0102
8 | 0425
--------------------------------------------------------------------------------
/dataset/CCF/test_id.txt:
--------------------------------------------------------------------------------
1 | 0012
2 | 0018
3 | 0020
4 | 0022
5 | 0023
6 | 0028
7 | 0040
8 | 0045
9 | 0047
10 | 0050
11 | 0052
12 | 0055
13 | 0060
14 | 0061
15 | 0062
16 | 0063
17 | 0066
18 | 0084
19 | 0085
20 | 0087
21 | 0104
22 | 0106
23 | 0108
24 | 0110
25 | 0113
26 | 0114
27 | 0115
28 | 0119
29 | 0120
30 | 0126
31 | 0129
32 | 0147
33 | 0154
34 | 0160
35 | 0162
36 | 0163
37 | 0171
38 | 0175
39 | 0195
40 | 0197
41 | 0200
42 | 0202
43 | 0225
44 | 0240
45 | 0245
46 | 0251
47 | 0252
48 | 0254
49 | 0263
50 | 0277
51 | 0281
52 | 0286
53 | 0291
54 | 0292
55 | 0301
56 | 0304
57 | 0307
58 | 0325
59 | 0327
60 | 0330
61 | 0346
62 | 0366
63 | 0378
64 | 0383
65 | 0390
66 | 0399
67 | 0408
68 | 0433
69 | 0440
70 | 0459
71 | 0463
72 | 0465
73 | 0467
74 | 0479
75 | 0487
76 | 0490
77 | 0522
78 | 0554
79 | 0593
80 | 0625
81 | 0646
82 | 0699
83 | 0721
84 | 0790
85 | 0832
86 | 0848
87 | 0861
88 | 0871
89 | 0873
90 | 0881
91 | 0936
92 | 0993
--------------------------------------------------------------------------------
/dataset/CCF/train_id.txt:
--------------------------------------------------------------------------------
1 | 0001
2 | 0002
3 | 0003
4 | 0004
5 | 0006
6 | 0007
7 | 0008
8 | 0009
9 | 0010
10 | 0011
11 | 0013
12 | 0014
13 | 0015
14 | 0016
15 | 0017
16 | 0019
17 | 0021
18 | 0025
19 | 0026
20 | 0029
21 | 0030
22 | 0031
23 | 0032
24 | 0033
25 | 0034
26 | 0035
27 | 0036
28 | 0037
29 | 0039
30 | 0042
31 | 0043
32 | 0044
33 | 0046
34 | 0048
35 | 0049
36 | 0051
37 | 0053
38 | 0054
39 | 0057
40 | 0058
41 | 0059
42 | 0065
43 | 0067
44 | 0068
45 | 0069
46 | 0070
47 | 0071
48 | 0072
49 | 0073
50 | 0074
51 | 0075
52 | 0076
53 | 0077
54 | 0078
55 | 0079
56 | 0080
57 | 0081
58 | 0082
59 | 0083
60 | 0086
61 | 0088
62 | 0089
63 | 0091
64 | 0092
65 | 0093
66 | 0094
67 | 0095
68 | 0096
69 | 0097
70 | 0099
71 | 0100
72 | 0101
73 | 0102
74 | 0103
75 | 0105
76 | 0107
77 | 0111
78 | 0112
79 | 0116
80 | 0117
81 | 0118
82 | 0122
83 | 0123
84 | 0124
85 | 0125
86 | 0128
87 | 0131
88 | 0132
89 | 0133
90 | 0134
91 | 0135
92 | 0136
93 | 0139
94 | 0140
95 | 0143
96 | 0144
97 | 0148
98 | 0149
99 | 0150
100 | 0151
101 | 0152
102 | 0153
103 | 0155
104 | 0156
105 | 0157
106 | 0158
107 | 0159
108 | 0161
109 | 0164
110 | 0165
111 | 0166
112 | 0167
113 | 0168
114 | 0169
115 | 0172
116 | 0173
117 | 0174
118 | 0176
119 | 0177
120 | 0178
121 | 0179
122 | 0180
123 | 0181
124 | 0182
125 | 0183
126 | 0184
127 | 0185
128 | 0186
129 | 0187
130 | 0188
131 | 0189
132 | 0190
133 | 0191
134 | 0192
135 | 0193
136 | 0194
137 | 0196
138 | 0198
139 | 0201
140 | 0203
141 | 0204
142 | 0205
143 | 0206
144 | 0207
145 | 0208
146 | 0209
147 | 0210
148 | 0211
149 | 0212
150 | 0213
151 | 0214
152 | 0215
153 | 0216
154 | 0217
155 | 0218
156 | 0219
157 | 0220
158 | 0221
159 | 0222
160 | 0223
161 | 0224
162 | 0226
163 | 0227
164 | 0228
165 | 0229
166 | 0230
167 | 0231
168 | 0233
169 | 0234
170 | 0235
171 | 0237
172 | 0238
173 | 0239
174 | 0241
175 | 0243
176 | 0244
177 | 0246
178 | 0247
179 | 0248
180 | 0249
181 | 0250
182 | 0253
183 | 0255
184 | 0256
185 | 0257
186 | 0258
187 | 0259
188 | 0260
189 | 0261
190 | 0262
191 | 0265
192 | 0267
193 | 0268
194 | 0269
195 | 0270
196 | 0271
197 | 0272
198 | 0273
199 | 0274
200 | 0275
201 | 0276
202 | 0278
203 | 0279
204 | 0280
205 | 0282
206 | 0284
207 | 0285
208 | 0287
209 | 0288
210 | 0289
211 | 0290
212 | 0293
213 | 0294
214 | 0295
215 | 0297
216 | 0298
217 | 0299
218 | 0300
219 | 0302
220 | 0303
221 | 0305
222 | 0306
223 | 0308
224 | 0309
225 | 0310
226 | 0311
227 | 0312
228 | 0313
229 | 0314
230 | 0315
231 | 0316
232 | 0317
233 | 0318
234 | 0319
235 | 0320
236 | 0321
237 | 0322
238 | 0323
239 | 0324
240 | 0326
241 | 0328
242 | 0329
243 | 0331
244 | 0332
245 | 0333
246 | 0334
247 | 0335
248 | 0336
249 | 0337
250 | 0338
251 | 0339
252 | 0340
253 | 0341
254 | 0342
255 | 0344
256 | 0345
257 | 0347
258 | 0348
259 | 0349
260 | 0350
261 | 0351
262 | 0352
263 | 0353
264 | 0354
265 | 0355
266 | 0356
267 | 0357
268 | 0358
269 | 0359
270 | 0360
271 | 0361
272 | 0362
273 | 0363
274 | 0364
275 | 0365
276 | 0367
277 | 0369
278 | 0370
279 | 0371
280 | 0372
281 | 0373
282 | 0374
283 | 0375
284 | 0376
285 | 0377
286 | 0380
287 | 0381
288 | 0382
289 | 0384
290 | 0386
291 | 0388
292 | 0389
293 | 0391
294 | 0392
295 | 0393
296 | 0394
297 | 0395
298 | 0396
299 | 0398
300 | 0400
301 | 0401
302 | 0402
303 | 0403
304 | 0405
305 | 0406
306 | 0407
307 | 0409
308 | 0410
309 | 0411
310 | 0412
311 | 0413
312 | 0414
313 | 0415
314 | 0416
315 | 0417
316 | 0419
317 | 0420
318 | 0421
319 | 0422
320 | 0423
321 | 0424
322 | 0425
323 | 0426
324 | 0427
325 | 0428
326 | 0429
327 | 0431
328 | 0432
329 | 0434
330 | 0435
331 | 0436
332 | 0437
333 | 0438
334 | 0441
335 | 0442
336 | 0443
337 | 0444
338 | 0445
339 | 0446
340 | 0447
341 | 0448
342 | 0449
343 | 0450
344 | 0451
345 | 0452
346 | 0453
347 | 0454
348 | 0455
349 | 0456
350 | 0457
351 | 0458
352 | 0460
353 | 0461
354 | 0462
355 | 0464
356 | 0466
357 | 0468
358 | 0469
359 | 0470
360 | 0471
361 | 0472
362 | 0473
363 | 0474
364 | 0475
365 | 0476
366 | 0477
367 | 0478
368 | 0480
369 | 0481
370 | 0482
371 | 0483
372 | 0484
373 | 0485
374 | 0486
375 | 0488
376 | 0489
377 | 0491
378 | 0492
379 | 0493
380 | 0494
381 | 0495
382 | 0497
383 | 0498
384 | 0499
385 | 0500
386 | 0501
387 | 0502
388 | 0503
389 | 0504
390 | 0505
391 | 0506
392 | 0507
393 | 0508
394 | 0509
395 | 0510
396 | 0511
397 | 0513
398 | 0514
399 | 0515
400 | 0516
401 | 0517
402 | 0518
403 | 0519
404 | 0520
405 | 0521
406 | 0523
407 | 0524
408 | 0525
409 | 0526
410 | 0527
411 | 0528
412 | 0529
413 | 0530
414 | 0531
415 | 0532
416 | 0533
417 | 0534
418 | 0535
419 | 0536
420 | 0537
421 | 0538
422 | 0539
423 | 0540
424 | 0541
425 | 0542
426 | 0543
427 | 0544
428 | 0545
429 | 0546
430 | 0547
431 | 0548
432 | 0549
433 | 0551
434 | 0552
435 | 0553
436 | 0555
437 | 0556
438 | 0558
439 | 0559
440 | 0560
441 | 0561
442 | 0562
443 | 0563
444 | 0564
445 | 0565
446 | 0566
447 | 0567
448 | 0568
449 | 0569
450 | 0570
451 | 0571
452 | 0572
453 | 0573
454 | 0574
455 | 0575
456 | 0577
457 | 0578
458 | 0579
459 | 0580
460 | 0581
461 | 0582
462 | 0583
463 | 0584
464 | 0585
465 | 0586
466 | 0587
467 | 0588
468 | 0589
469 | 0590
470 | 0591
471 | 0592
472 | 0594
473 | 0595
474 | 0596
475 | 0597
476 | 0598
477 | 0600
478 | 0601
479 | 0602
480 | 0603
481 | 0604
482 | 0605
483 | 0606
484 | 0607
485 | 0608
486 | 0610
487 | 0611
488 | 0612
489 | 0613
490 | 0614
491 | 0615
492 | 0616
493 | 0617
494 | 0618
495 | 0619
496 | 0620
497 | 0621
498 | 0622
499 | 0623
500 | 0624
501 | 0626
502 | 0627
503 | 0628
504 | 0629
505 | 0630
506 | 0631
507 | 0632
508 | 0633
509 | 0634
510 | 0635
511 | 0636
512 | 0637
513 | 0638
514 | 0639
515 | 0640
516 | 0641
517 | 0642
518 | 0643
519 | 0644
520 | 0645
521 | 0647
522 | 0648
523 | 0649
524 | 0650
525 | 0651
526 | 0653
527 | 0654
528 | 0655
529 | 0656
530 | 0657
531 | 0659
532 | 0660
533 | 0661
534 | 0662
535 | 0663
536 | 0664
537 | 0665
538 | 0666
539 | 0668
540 | 0669
541 | 0670
542 | 0671
543 | 0672
544 | 0673
545 | 0675
546 | 0677
547 | 0678
548 | 0679
549 | 0680
550 | 0681
551 | 0682
552 | 0684
553 | 0686
554 | 0687
555 | 0688
556 | 0689
557 | 0690
558 | 0691
559 | 0692
560 | 0693
561 | 0694
562 | 0695
563 | 0696
564 | 0697
565 | 0698
566 | 0700
567 | 0701
568 | 0702
569 | 0703
570 | 0704
571 | 0705
572 | 0706
573 | 0707
574 | 0708
575 | 0709
576 | 0711
577 | 0712
578 | 0713
579 | 0714
580 | 0715
581 | 0716
582 | 0717
583 | 0718
584 | 0719
585 | 0720
586 | 0722
587 | 0723
588 | 0725
589 | 0726
590 | 0727
591 | 0728
592 | 0729
593 | 0730
594 | 0731
595 | 0733
596 | 0734
597 | 0736
598 | 0737
599 | 0738
600 | 0739
601 | 0741
602 | 0742
603 | 0743
604 | 0744
605 | 0745
606 | 0746
607 | 0747
608 | 0748
609 | 0749
610 | 0750
611 | 0751
612 | 0752
613 | 0753
614 | 0755
615 | 0756
616 | 0757
617 | 0758
618 | 0759
619 | 0760
620 | 0761
621 | 0762
622 | 0763
623 | 0764
624 | 0765
625 | 0766
626 | 0767
627 | 0768
628 | 0769
629 | 0770
630 | 0771
631 | 0772
632 | 0773
633 | 0774
634 | 0775
635 | 0776
636 | 0777
637 | 0778
638 | 0779
639 | 0780
640 | 0781
641 | 0782
642 | 0783
643 | 0784
644 | 0785
645 | 0786
646 | 0787
647 | 0788
648 | 0789
649 | 0791
650 | 0792
651 | 0793
652 | 0794
653 | 0795
654 | 0798
655 | 0799
656 | 0800
657 | 0801
658 | 0802
659 | 0803
660 | 0804
661 | 0805
662 | 0806
663 | 0807
664 | 0808
665 | 0809
666 | 0810
667 | 0811
668 | 0812
669 | 0813
670 | 0814
671 | 0815
672 | 0816
673 | 0817
674 | 0818
675 | 0819
676 | 0820
677 | 0821
678 | 0822
679 | 0823
680 | 0824
681 | 0825
682 | 0826
683 | 0827
684 | 0828
685 | 0829
686 | 0831
687 | 0834
688 | 0835
689 | 0836
690 | 0837
691 | 0838
692 | 0839
693 | 0840
694 | 0842
695 | 0843
696 | 0844
697 | 0845
698 | 0847
699 | 0849
700 | 0850
701 | 0851
702 | 0852
703 | 0853
704 | 0854
705 | 0856
706 | 0857
707 | 0858
708 | 0859
709 | 0860
710 | 0862
711 | 0863
712 | 0864
713 | 0865
714 | 0866
715 | 0867
716 | 0868
717 | 0869
718 | 0870
719 | 0872
720 | 0874
721 | 0875
722 | 0876
723 | 0877
724 | 0878
725 | 0879
726 | 0880
727 | 0882
728 | 0883
729 | 0884
730 | 0885
731 | 0886
732 | 0887
733 | 0888
734 | 0889
735 | 0890
736 | 0891
737 | 0892
738 | 0893
739 | 0894
740 | 0895
741 | 0897
742 | 0898
743 | 0899
744 | 0900
745 | 0901
746 | 0902
747 | 0903
748 | 0904
749 | 0905
750 | 0906
751 | 0907
752 | 0909
753 | 0910
754 | 0911
755 | 0912
756 | 0913
757 | 0914
758 | 0915
759 | 0916
760 | 0917
761 | 0918
762 | 0919
763 | 0920
764 | 0921
765 | 0922
766 | 0925
767 | 0926
768 | 0927
769 | 0928
770 | 0929
771 | 0930
772 | 0931
773 | 0932
774 | 0933
775 | 0934
776 | 0935
777 | 0937
778 | 0938
779 | 0939
780 | 0940
781 | 0941
782 | 0943
783 | 0944
784 | 0945
785 | 0946
786 | 0947
787 | 0948
788 | 0949
789 | 0950
790 | 0951
791 | 0952
792 | 0953
793 | 0954
794 | 0956
795 | 0957
796 | 0958
797 | 0959
798 | 0960
799 | 0961
800 | 0962
801 | 0964
802 | 0965
803 | 0966
804 | 0967
805 | 0968
806 | 0969
807 | 0970
808 | 0971
809 | 0972
810 | 0973
811 | 0974
812 | 0975
813 | 0976
814 | 0977
815 | 0978
816 | 0979
817 | 0980
818 | 0981
819 | 0983
820 | 0984
821 | 0986
822 | 0987
823 | 0988
824 | 0989
825 | 0990
826 | 0991
827 | 0992
828 | 0994
829 | 0995
830 | 0996
831 | 0999
832 | 1001
833 | 1002
834 | 1003
835 | 1004
--------------------------------------------------------------------------------
/dataset/CIHP/README.md:
--------------------------------------------------------------------------------
1 |
2 | Images: images
3 | Category_ids: semantic part segmentation labels; Categories: visualized semantic part segmentation labels
4 | Human_ids: semantic person segmentation labels; Human: visualized semantic person segmentation labels
5 | Instance_ids: instance-level human parsing labels; Instances: visualized instance-level human parsing labels
6 |
7 |
8 | Label order of semantic part segmentation:
9 |
10 | 1.Hat
11 | 2.Hair
12 | 3.Glove
13 | 4.Sunglasses
14 | 5.UpperClothes
15 | 6.Dress
16 | 7.Coat
17 | 8.Socks
18 | 9.Pants
19 | 10.Torso-skin
20 | 11.Scarf
21 | 12.Skirt
22 | 13.Face
23 | 14.Left-arm
24 | 15.Right-arm
25 | 16.Left-leg
26 | 17.Right-leg
27 | 18.Left-shoe
28 | 19.Right-shoe
--------------------------------------------------------------------------------
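For convenience, the label order above as a Python mapping (the `CIHP_LABELS` name is ours; 0 denotes background and 255 the ignore index used by the loaders):

```python
CIHP_LABELS = {
    0: 'Background', 1: 'Hat', 2: 'Hair', 3: 'Glove', 4: 'Sunglasses',
    5: 'UpperClothes', 6: 'Dress', 7: 'Coat', 8: 'Socks', 9: 'Pants',
    10: 'Torso-skin', 11: 'Scarf', 12: 'Skirt', 13: 'Face',
    14: 'Left-arm', 15: 'Right-arm', 16: 'Left-leg', 17: 'Right-leg',
    18: 'Left-shoe', 19: 'Right-shoe',
}
```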
/dataset/CIHP/human_colormap.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/dataset/CIHP/human_colormap.mat
--------------------------------------------------------------------------------
/dataset/LIP/README.md:
--------------------------------------------------------------------------------
1 |
2 | Images: images
3 | Category_ids: semantic part segmentation labels; Categories: visualized semantic part segmentation labels
4 | Human_ids: semantic person segmentation labels; Human: visualized semantic person segmentation labels
5 | Instance_ids: instance-level human parsing labels; Instances: visualized instance-level human parsing labels
6 |
7 |
8 | Label order of semantic part segmentation:
9 |
10 | 1.Hat
11 | 2.Hair
12 | 3.Glove
13 | 4.Sunglasses
14 | 5.UpperClothes
15 | 6.Dress
16 | 7.Coat
17 | 8.Socks
18 | 9.Pants
19 | 10.Torso-skin
20 | 11.Scarf
21 | 12.Skirt
22 | 13.Face
23 | 14.Left-arm
24 | 15.Right-arm
25 | 16.Left-leg
26 | 17.Right-leg
27 | 18.Left-shoe
28 | 19.Right-shoe
--------------------------------------------------------------------------------
/dataset/Pascal/README.MD:
--------------------------------------------------------------------------------
1 | background 0
2 | head 1
3 | torso 2
4 | upper-arm 3
5 | lower-arm 4
6 | upper-leg 5
7 | lower-leg 6
8 |
--------------------------------------------------------------------------------
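dataset/data_pascal.py collapses these seven labels into a half-body map; a minimal sketch of that rule (the function name is ours):

```python
import numpy as np

def pascal_half_body(seg: np.ndarray) -> np.ndarray:
    """0 = background, 1 = upper body (head/torso/arms), 2 = lower body (legs); 255 = ignore."""
    half = seg.copy()
    half[(half > 0) & (half <= 4)] = 1   # head, torso, upper-arm, lower-arm
    half[(half > 4) & (half < 255)] = 2  # upper-leg, lower-leg
    return half
```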
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/data_CIHP.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import random
4 |
5 | import cv2
6 | import numpy as np
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 | from .data_transforms import RandomRotate
10 | from PIL import Image
11 |
12 |
13 | # ###### Data loading #######
14 | def make_dataset(root, lst):
15 | # append all index
16 | fid = open(lst, 'r')
17 | imgs, segs, segs_rev = [], [], []
18 | for line in fid.readlines():
19 | idx = line.strip().split(' ')[0]
20 | image_path = os.path.join(root, 'JPEGImages/' + str(idx) + '.jpg')
21 | seg_path = os.path.join(root, 'Segmentations/' + str(idx) + '.png')
22 | seg_rev_path = os.path.join(root, 'Segmentations_rev/' + str(idx) + '.png')
23 | imgs.append(image_path)
24 | segs.append(seg_path)
25 | segs_rev.append(seg_rev_path)
26 | return imgs, segs, segs_rev
27 |
28 |
29 | # ###### val resize & crop ######
30 | def scale_crop(img, seg, crop_size):
31 | oh, ow = seg.shape
32 | pad_h = max(0, crop_size - oh)
33 | pad_w = max(0, crop_size - ow)
34 | if pad_h > 0 or pad_w > 0:
35 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
36 | value=(0.0, 0.0, 0.0))
37 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
38 | value=255)
39 | else:
40 | img_pad, seg_pad = img, seg
41 |
42 | img = np.asarray(img_pad[0: crop_size, 0: crop_size], np.float32)
43 | seg = np.asarray(seg_pad[0: crop_size, 0: crop_size], np.float32)
44 |
45 | return img, seg
46 |
47 |
48 | class DatasetGenerator(data.Dataset):
49 | def __init__(self, root, list_path, crop_size, training=True):
50 |
51 | imgs, segs, segs_rev = make_dataset(root, list_path)
52 |
53 | self.root = root
54 | self.imgs = imgs
55 | self.segs = segs
56 | self.segs_rev = segs_rev
57 | self.crop_size = crop_size
58 | self.training = training
59 | self.colorjitter = transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.5, hue=0.1)
60 |         self.random_rotate = RandomRotate(20)
61 |
62 | def __getitem__(self, index):
63 | # load data
64 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
65 | name = self.imgs[index].split('/')[-1][:-4]
66 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR)
67 | seg_in = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE)
68 | seg_rev_in = cv2.imread(self.segs_rev[index], cv2.IMREAD_GRAYSCALE)
69 |
70 | if self.training:
71 |             # random mirror first (Segmentations_rev stores label maps with left/right classes swapped)
72 |             flip = np.random.choice(2) * 2 - 1
73 |             img = img[:, ::flip, :]
74 |             if flip == -1:
75 |                 seg = seg_rev_in
76 |             else:
77 |                 seg = seg_in
78 |             # colorjitter and rotate, applied after the mirror so img and seg stay aligned
79 |             # (the original order referenced `seg` before it was assigned)
80 |             if random.random() < 0.5:
81 |                 img = Image.fromarray(np.ascontiguousarray(img))  # flipped view is not contiguous
82 |                 seg = Image.fromarray(seg)
83 |                 img = self.colorjitter(img)
84 |                 img, seg = self.random_rotate(img, seg)
85 |                 img, seg = np.array(img).astype(np.uint8), np.array(seg).astype(np.uint8)
86 | # random scale
87 | ratio = random.uniform(0.5, 2.0)
88 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
89 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
90 | img = np.array(img).astype(np.float32) - mean
91 |
92 | # pad & crop
93 | img_h, img_w = seg.shape
94 | pad_h = max(self.crop_size - img_h, 0)
95 | pad_w = max(self.crop_size - img_w, 0)
96 | if pad_h > 0 or pad_w > 0:
97 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
98 | value=(0.0, 0.0, 0.0))
99 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
100 | value=(255,))
101 | else:
102 | img_pad, seg_pad = img, seg
103 |
104 | img_h, img_w = seg_pad.shape
105 | h_off = random.randint(0, img_h - self.crop_size)
106 | w_off = random.randint(0, img_w - self.crop_size)
107 | img = np.asarray(img_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32)
108 | seg = np.asarray(seg_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32)
109 | img = img.transpose((2, 0, 1))
110 | # generate body masks
111 | seg_half = seg.copy()
112 | seg_half[(seg_half > 0) & (seg_half <= 7)] = 1
113 | seg_half[(seg_half > 7) & (seg_half <= 10)] = 2
114 | seg_half[seg_half == 11] = 1
115 | seg_half[seg_half == 12] = 2
116 | seg_half[(seg_half > 12) & (seg_half <= 15)] = 1
117 | seg_half[(seg_half > 15) & (seg_half < 255)] = 2
118 | seg_full = seg.copy()
119 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1
120 |
121 | else:
122 | h, w = seg_in.shape
123 | max_size = max(w, h)
124 | ratio = self.crop_size / max_size
125 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
126 | seg = cv2.resize(seg_in, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
127 | img = np.array(img).astype(np.float32) - mean
128 | img, seg = scale_crop(img, seg, crop_size=self.crop_size)
129 | img = img.transpose((2, 0, 1))
130 | # generate body masks
131 | seg_half = seg.copy()
132 | seg_half[(seg_half > 0) & (seg_half <= 7)] = 1
133 | seg_half[(seg_half > 7) & (seg_half <= 10)] = 2
134 | seg_half[seg_half == 11] = 1
135 | seg_half[seg_half == 12] = 2
136 | seg_half[(seg_half > 12) & (seg_half <= 15)] = 1
137 | seg_half[(seg_half > 15) & (seg_half < 255)] = 2
138 | seg_full = seg.copy()
139 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1
140 |
141 | images = img.copy()
142 | segmentations = seg.copy()
143 | segmentations_half = seg_half.copy()
144 | segmentations_full = seg_full.copy()
145 |
146 | return images, segmentations, segmentations_half, segmentations_full, name
147 |
148 | def __len__(self):
149 | return len(self.imgs)
150 |
151 |
152 | class ValidationLoader(data.Dataset):
153 |     """evaluate on CIHP val set"""
154 |
155 | def __init__(self, root, list_path, crop_size):
156 | fid = open(list_path, 'r')
157 | imgs, segs = [], []
158 | for line in fid.readlines():
159 | idx = line.strip().split(' ')[0]
160 | image_path = os.path.join(root, 'images/' + str(idx) + '.jpg')
161 | seg_path = os.path.join(root, 'segmentations/' + str(idx) + '.png')
162 | imgs.append(image_path)
163 | segs.append(seg_path)
164 |
165 | self.root = root
166 | self.imgs = imgs
167 | self.segs = segs
168 | self.crop_size = crop_size
169 |
170 | def __getitem__(self, index):
171 | # load data
172 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
173 | name = self.imgs[index].split('/')[-1][:-4]
174 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR)
175 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE)
176 |
177 | h, w = seg.shape
178 | max_size = max(w, h)
179 | ratio = self.crop_size / max_size
180 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
181 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
182 | img = np.array(img).astype(np.float32) - mean
183 | img, seg = scale_crop(img, seg, crop_size=self.crop_size)
184 | img = img.transpose((2, 0, 1))
185 |
186 | images = img.copy()
187 | segmentations = seg.copy()
188 |
189 | return images, segmentations, name
190 |
191 | def __len__(self):
192 | return len(self.imgs)
193 |
--------------------------------------------------------------------------------
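A hypothetical usage sketch for the loader above; the dataset root, list path, batch size, and crop size are placeholders, not values prescribed by the repo:

```python
import torch.utils.data as data
from dataset.data_CIHP import DatasetGenerator

train_set = DatasetGenerator(root='/path/to/CIHP', list_path='dataset/CIHP/train_id.txt',
                             crop_size=473, training=True)
train_loader = data.DataLoader(train_set, batch_size=8, shuffle=True,
                               num_workers=4, pin_memory=True, drop_last=True)

for images, parts, halves, fulls, names in train_loader:
    # images: (B, 3, H, W) float32 BGR with the Caffe ImageNet mean subtracted;
    # parts/halves/fulls: part-level, half-body, and foreground/background label maps
    break
```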
/dataset/data_atr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import random
4 |
5 | import cv2
6 | import numpy as np
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 | from .data_transforms import RandomRotate
10 | from PIL import Image
11 |
12 |
13 | # ###### Data loading #######
14 | def make_dataset(root, lst):
15 | # append all index
16 | fid = open(lst, 'r')
17 | imgs, segs, segs_rev = [], [], []
18 | for line in fid.readlines():
19 | idx = line.strip().split(' ')[0]
20 | image_path = os.path.join(root, 'JPEGImages/' + str(idx) + '.jpg')
21 | seg_path = os.path.join(root, 'Segmentations/' + str(idx) + '.png')
22 | seg_rev_path = os.path.join(root, 'SegmentationsRev/' + str(idx) + '_rev.png')
23 | imgs.append(image_path)
24 | segs.append(seg_path)
25 | segs_rev.append(seg_rev_path)
26 | return imgs, segs, segs_rev
27 |
28 |
29 | # ###### val resize & crop ######
30 | def scale_crop(img, seg, crop_size):
31 | oh, ow = seg.shape
32 | pad_h = max(0, crop_size - oh)
33 | pad_w = max(0, crop_size - ow)
34 | if pad_h > 0 or pad_w > 0:
35 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
36 | value=(0.0, 0.0, 0.0))
37 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
38 | value=255)
39 | else:
40 | img_pad, seg_pad = img, seg
41 |
42 | img = np.asarray(img_pad[0: crop_size, 0: crop_size], np.float32)
43 | seg = np.asarray(seg_pad[0: crop_size, 0: crop_size], np.float32)
44 |
45 | return img, seg
46 |
47 |
48 | class DatasetGenerator(data.Dataset):
49 | def __init__(self, root, list_path, crop_size, training=True):
50 |
51 | imgs, segs, segs_rev = make_dataset(root, list_path)
52 |
53 | self.root = root
54 | self.imgs = imgs
55 | self.segs = segs
56 | self.segs_rev = segs_rev
57 | self.crop_size = crop_size
58 | self.training = training
59 | self.colorjitter = transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.5, hue=0.1)
60 |         self.random_rotate = RandomRotate(20)
61 | def __getitem__(self, index):
62 | # load data
63 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
64 | name = self.imgs[index].split('/')[-1][:-4]
65 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR)
66 | seg_in = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE)
67 | seg_rev_in = cv2.imread(self.segs_rev[index], cv2.IMREAD_GRAYSCALE)
68 |
69 | if self.training:
70 |             # random mirror first (SegmentationsRev stores label maps with left/right classes swapped)
71 |             flip = np.random.choice(2) * 2 - 1
72 |             img = img[:, ::flip, :]
73 |             if flip == -1:
74 |                 seg = seg_rev_in
75 |             else:
76 |                 seg = seg_in
77 |             # colorjitter and rotate, applied after the mirror so img and seg stay aligned
78 |             # (the original order referenced `seg` before it was assigned)
79 |             if random.random() < 0.5:
80 |                 img = Image.fromarray(np.ascontiguousarray(img))  # flipped view is not contiguous
81 |                 seg = Image.fromarray(seg)
82 |                 img = self.colorjitter(img)
83 |                 img, seg = self.random_rotate(img, seg)
84 |                 img, seg = np.array(img).astype(np.uint8), np.array(seg).astype(np.uint8)
85 | # random scale
86 | ratio = random.uniform(0.5, 2.0)
87 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
88 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
89 | img = np.array(img).astype(np.float32) - mean
90 |
91 | # pad & crop
92 | img_h, img_w = seg.shape
93 | pad_h = max(self.crop_size - img_h, 0)
94 | pad_w = max(self.crop_size - img_w, 0)
95 | if pad_h > 0 or pad_w > 0:
96 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
97 | value=(0.0, 0.0, 0.0))
98 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
99 | value=(255,))
100 | else:
101 | img_pad, seg_pad = img, seg
102 |
103 | img_h, img_w = seg_pad.shape
104 | h_off = random.randint(0, img_h - self.crop_size)
105 | w_off = random.randint(0, img_w - self.crop_size)
106 | img = np.asarray(img_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32)
107 | seg = np.asarray(seg_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.uint8)
108 | img = img.transpose((2, 0, 1))
109 | # generate body masks
110 | seg_half = seg.copy()
111 |             # ATR grouping per README.MD: 1-4 up, 5-10 down, 11 up, 12-13 down, 14-17 up (matches eval branch)
112 |             seg_half[(seg_half > 0) & (seg_half <= 4)] = 1
113 |             seg_half[(seg_half > 4) & (seg_half <= 10)] = 2
114 |             seg_half[seg_half == 11] = 1
115 |             seg_half[(seg_half > 11) & (seg_half <= 13)] = 2
116 |             seg_half[(seg_half > 13) & (seg_half < 255)] = 1
117 | seg_full = seg.copy()
118 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1
119 |
120 | else:
121 | h, w = seg_in.shape
122 | max_size = max(w, h)
123 | ratio = self.crop_size / max_size
124 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
125 | seg = cv2.resize(seg_in, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
126 | img = np.array(img).astype(np.float32) - mean
127 | img, seg = scale_crop(img, seg, crop_size=self.crop_size)
128 | img = img.transpose((2, 0, 1))
129 | # generate body masks
130 | # 0 background, 1-4 up, 5-10 down, 11 up, 12-13 down, 14-17 up
131 | seg_half = seg.copy()
132 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1
133 | seg_half[(seg_half > 4) & (seg_half <= 10)] = 2
134 | seg_half[seg_half == 11] = 1
135 | seg_half[(seg_half > 11) & (seg_half <= 13)] = 2
136 | seg_half[(seg_half > 13) & (seg_half < 255)] = 1
137 | seg_full = seg.copy()
138 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1
139 |
140 | images = img.copy()
141 | segmentations = seg.copy()
142 | segmentations_half = seg_half.copy()
143 | segmentations_full = seg_full.copy()
144 |
145 | return images, segmentations, segmentations_half, segmentations_full, name
146 |
147 | def __len__(self):
148 | return len(self.imgs)
149 |
150 |
151 | class ATRTestGenerator(data.Dataset):
152 | def __init__(self, root, list_path, crop_size):
153 |
154 | fid = open(list_path, 'r')
155 | imgs, segs = [], []
156 | for line in fid.readlines():
157 | idx = line.strip().split(' ')[0]
158 | image_path = os.path.join(root, 'JPEGImages/' + str(idx) + '.jpg')
159 | seg_path = os.path.join(root, 'Segmentations/' + str(idx) + '.png')
160 | imgs.append(image_path)
161 | segs.append(seg_path)
162 |
163 | self.root = root
164 | self.imgs = imgs
165 | self.segs = segs
166 | self.crop_size = crop_size
167 |
168 | def __getitem__(self, index):
169 | # load data
170 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
171 | name = self.imgs[index].split('/')[-1][:-4]
172 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR)
173 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE)
174 | ori_size = img.shape
175 |
176 | h, w = seg.shape
177 | length = max(w, h)
178 | ratio = self.crop_size / length
179 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
180 | img = np.array(img).astype(np.float32) - mean
181 | img = img.transpose((2, 0, 1))
182 |
183 | images = img.copy()
184 | segmentations = seg.copy()
185 |
186 | return images, segmentations, np.array(ori_size), name
187 |
188 | def __len__(self):
189 | return len(self.imgs)
190 |
--------------------------------------------------------------------------------
/dataset/data_ccf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import random
4 |
5 | import cv2
6 | import numpy as np
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 | from .data_transforms import RandomRotate
10 | from PIL import Image
11 |
12 |
13 | # ###### Data loading #######
14 | def make_dataset(root, lst):
15 | # append all index
16 | fid = open(lst, 'r')
17 | imgs, segs = [], []
18 | for line in fid.readlines():
19 | idx = line.strip().split(' ')[0]
20 | image_path = os.path.join(root, 'JPEGImages/' + str(idx) + '.jpg')
21 | seg_path = os.path.join(root, 'Segmentations/' + str(idx) + '.png')
22 | imgs.append(image_path)
23 | segs.append(seg_path)
24 | return imgs, segs
25 |
26 |
27 | # ###### val resize & crop ######
28 | def scale_crop(img, seg, crop_size):
29 | oh, ow = seg.shape
30 | pad_h = max(crop_size - oh, 0)
31 | pad_ht, pad_hb = pad_h // 2, pad_h - pad_h // 2
32 | pad_w = max(crop_size - ow, 0)
33 | pad_wl, pad_wr = pad_w // 2, pad_w - pad_w // 2
34 | if pad_h > 0 or pad_w > 0:
35 | img_pad = cv2.copyMakeBorder(img, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT,
36 | value=(0.0, 0.0, 0.0))
37 | seg_pad = cv2.copyMakeBorder(seg, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT,
38 | value=(255,))
39 | else:
40 | img_pad, seg_pad = img, seg
41 |
42 | return img_pad, seg_pad
43 |
44 |
45 | class DatasetGenerator(data.Dataset):
46 | def __init__(self, root, list_path, crop_size, training=True):
47 |
48 | imgs, segs = make_dataset(root, list_path)
49 |
50 | self.root = root
51 | self.imgs = imgs
52 | self.segs = segs
53 | self.crop_size = crop_size
54 | self.training = training
55 | self.colorjitter = transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.5, hue=0.1)
56 |         self.random_rotate = RandomRotate(20)
57 |
58 | def __getitem__(self, index):
59 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
60 | # load data
61 | name = self.imgs[index].split('/')[-1][:-4]
62 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR)
63 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE)
64 |
65 | if self.training:
66 | #colorjitter and rotate
67 | if random.random() < 0.5:
68 | img = Image.fromarray(img)
69 | seg = Image.fromarray(seg)
70 | img = self.colorjitter(img)
71 | img, seg = self.random_rotate(img, seg)
72 | img = np.array(img).astype(np.uint8)
73 | seg = np.array(seg).astype(np.uint8)
74 | # random scale
75 | ratio = random.uniform(0.5, 2.0)
76 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
77 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
78 | img = np.array(img).astype(np.float32) - mean
79 |
80 | # pad & crop
81 | img_h, img_w = seg.shape[:2]
82 | pad_h = max(self.crop_size - img_h, 0)
83 | pad_ht, pad_hb = pad_h // 2, pad_h - pad_h // 2
84 | pad_w = max(self.crop_size - img_w, 0)
85 | pad_wl, pad_wr = pad_w // 2, pad_w - pad_w // 2
86 | if pad_h > 0 or pad_w > 0:
87 | img_pad = cv2.copyMakeBorder(img, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT,
88 | value=(0.0, 0.0, 0.0))
89 | seg_pad = cv2.copyMakeBorder(seg, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT,
90 | value=(255,))
91 | else:
92 | img_pad, seg_pad = img, seg
93 |
94 | seg_pad_h, seg_pad_w = seg_pad.shape
95 | h_off = random.randint(0, seg_pad_h - self.crop_size)
96 | w_off = random.randint(0, seg_pad_w - self.crop_size)
97 | img = np.asarray(img_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32)
98 | seg = np.asarray(seg_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.uint8)
99 | # random mirror
100 | flip = np.random.choice(2) * 2 - 1
101 | img = img[:, ::flip, :]
102 | seg = seg[:, ::flip]
103 | # Generate target maps
104 | img = img.transpose((2, 0, 1))
105 | seg_half = seg.copy()
106 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1
107 | seg_half[seg_half == 5] = 2
108 | seg_half[(seg_half > 5) & (seg_half <= 7)] = 1
109 | seg_half[(seg_half > 7) & (seg_half <= 9)] = 2
110 | seg_half[(seg_half > 9) & (seg_half <= 11)] = 1
111 | seg_half[seg_half == 12] = 2
112 | seg_half[(seg_half > 12) & (seg_half <= 15)] = 1
113 | seg_half[seg_half == 16] = 2
114 | seg_half[(seg_half > 16) & (seg_half < 255)] = 1
115 | seg_full = seg.copy()
116 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1
117 |
118 | else:
119 | h, w = seg.shape
120 | max_size = max(w, h)
121 | ratio = self.crop_size / max_size
122 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
123 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
124 | img = np.array(img).astype(np.float32) - mean
125 | img, seg = scale_crop(img, seg, crop_size=self.crop_size)
126 | img = img.transpose((2, 0, 1))
127 | seg_half = seg.copy()
128 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1
129 | seg_half[seg_half == 5] = 2
130 | seg_half[(seg_half > 5) & (seg_half <= 7)] = 1
131 | seg_half[(seg_half > 7) & (seg_half <= 9)] = 2
132 | seg_half[(seg_half > 9) & (seg_half <= 11)] = 1
133 | seg_half[seg_half == 12] = 2
134 | seg_half[(seg_half > 12) & (seg_half <= 15)] = 1
135 | seg_half[seg_half == 16] = 2
136 | seg_half[(seg_half > 16) & (seg_half < 255)] = 1
137 | seg_full = seg.copy()
138 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1
139 |
140 | images = img.copy()
141 | segmentations = seg.copy()
142 | segmentations_half = seg_half.copy()
143 | segmentations_full = seg_full.copy()
144 |
145 | return images, segmentations, segmentations_half, segmentations_full, name
146 |
147 | def __len__(self):
148 | return len(self.imgs)
149 |
150 |
151 | class TestGenerator(data.Dataset):
152 |
153 | def __init__(self, root, list_path, crop_size):
154 |
155 | imgs, segs = make_dataset(root, list_path)
156 | self.root = root
157 | self.imgs = imgs
158 | self.segs = segs
159 | self.crop_size = crop_size
160 |
161 | def __getitem__(self, index):
162 | # load data
163 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
164 | name = self.imgs[index].split('/')[-1][:-4]
165 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR)
166 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE)
167 | ori_size = img.shape
168 |
169 | h, w = seg.shape
170 | length = max(w, h)
171 | ratio = self.crop_size / length
172 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
173 | img = np.array(img).astype(np.float32) - mean
174 | img = img.transpose((2, 0, 1))
175 |
176 | images = img.copy()
177 | segmentations = seg.copy()
178 |
179 | return images, segmentations, np.array(ori_size), name
180 |
181 | def __len__(self):
182 | return len(self.imgs)
183 |
--------------------------------------------------------------------------------
/dataset/data_lip.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import random
4 |
5 | import cv2
6 | import numpy as np
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 | from .data_transforms import RandomRotate
10 | from PIL import Image
11 |
12 |
13 | # ###### Data loading #######
14 | def make_dataset(root, lst):
15 | # append all index
16 | fid = open(lst, 'r')
17 | imgs, segs, segs_rev = [], [], []
18 | for line in fid.readlines():
19 | idx = line.strip().split(' ')[0]
20 | image_path = os.path.join(root, 'images/' + str(idx) + '.jpg')
21 | seg_path = os.path.join(root, 'segmentations/' + str(idx) + '.png')
22 | seg_rev_path = os.path.join(root, 'segmentations_rev/' + str(idx) + '.png')
23 | imgs.append(image_path)
24 | segs.append(seg_path)
25 | segs_rev.append(seg_rev_path)
26 | return imgs, segs, segs_rev
27 |
28 |
29 | # ###### val resize & crop ######
30 | def scale_crop(img, seg, crop_size):
31 | oh, ow = seg.shape
32 | pad_h = max(0, crop_size - oh)
33 | pad_w = max(0, crop_size - ow)
34 | if pad_h > 0 or pad_w > 0:
35 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
36 | value=(0.0, 0.0, 0.0))
37 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
38 | value=255)
39 | else:
40 | img_pad, seg_pad = img, seg
41 |
42 | img = np.asarray(img_pad[0: crop_size, 0: crop_size], np.float32)
43 | seg = np.asarray(seg_pad[0: crop_size, 0: crop_size], np.float32)
44 |
45 | return img, seg
46 |
47 |
48 | class DatasetGenerator(data.Dataset):
49 | def __init__(self, root, list_path, crop_size, training=True):
50 |
51 | imgs, segs, segs_rev = make_dataset(root, list_path)
52 |
53 | self.root = root
54 | self.imgs = imgs
55 | self.segs = segs
56 | self.segs_rev = segs_rev
57 | self.crop_size = crop_size
58 | self.training = training
59 | self.colorjitter = transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.5, hue=0.1)
60 |         self.random_rotate = RandomRotate(20)
61 | def __getitem__(self, index):
62 | # load data
63 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
64 | name = self.imgs[index].split('/')[-1][:-4]
65 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR)
66 | seg_in = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE)
67 | seg_rev_in = cv2.imread(self.segs_rev[index], cv2.IMREAD_GRAYSCALE)
68 |
69 | if self.training:
70 |             # random mirror first (segmentations_rev stores label maps with left/right classes swapped)
71 |             flip = np.random.choice(2) * 2 - 1
72 |             img = img[:, ::flip, :]
73 |             if flip == -1:
74 |                 seg = seg_rev_in
75 |             else:
76 |                 seg = seg_in
77 |             # colorjitter and rotate, applied after the mirror so img and seg stay aligned
78 |             # (the original order referenced `seg` before it was assigned)
79 |             if random.random() < 0.5:
80 |                 img = Image.fromarray(np.ascontiguousarray(img))  # flipped view is not contiguous
81 |                 seg = Image.fromarray(seg)
82 |                 img = self.colorjitter(img)
83 |                 img, seg = self.random_rotate(img, seg)
84 |                 img, seg = np.array(img).astype(np.uint8), np.array(seg).astype(np.uint8)
85 | # random scale
86 | ratio = random.uniform(0.5, 1.5)
87 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
88 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
89 | img = np.array(img).astype(np.float32) - mean
90 |
91 | # pad & crop
92 | img_h, img_w = seg.shape
93 | pad_h = max(self.crop_size - img_h, 0)
94 | pad_w = max(self.crop_size - img_w, 0)
95 | if pad_h > 0 or pad_w > 0:
96 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
97 | value=(0.0, 0.0, 0.0))
98 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
99 | value=(255,))
100 | else:
101 | img_pad, seg_pad = img, seg
102 |
103 | img_h, img_w = seg_pad.shape
104 | h_off = random.randint(0, img_h - self.crop_size)
105 | w_off = random.randint(0, img_w - self.crop_size)
106 | img = np.asarray(img_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32)
107 | seg = np.asarray(seg_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32)
108 | img = img.transpose((2, 0, 1))
109 | # generate body masks
110 | seg_half = seg.copy()
111 | seg_half[(seg_half > 0) & (seg_half <= 7)] = 1
112 | seg_half[(seg_half > 7) & (seg_half <= 10)] = 2
113 | seg_half[seg_half == 11] = 1
114 | seg_half[seg_half == 12] = 2
115 | seg_half[(seg_half > 12) & (seg_half <= 15)] = 1
116 | seg_half[(seg_half > 15) & (seg_half < 255)] = 2
117 | seg_full = seg.copy()
118 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1
119 |
120 | else:
121 | h, w = seg_in.shape
122 | max_size = max(w, h)
123 | ratio = self.crop_size / max_size
124 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
125 | seg = cv2.resize(seg_in, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
126 | img = np.array(img).astype(np.float32) - mean
127 | img, seg = scale_crop(img, seg, crop_size=self.crop_size)
128 | img = img.transpose((2, 0, 1))
129 | # generate body masks
130 | seg_half = seg.copy()
131 | seg_half[(seg_half > 0) & (seg_half <= 7)] = 1
132 | seg_half[(seg_half > 7) & (seg_half <= 10)] = 2
133 | seg_half[seg_half == 11] = 1
134 | seg_half[seg_half == 12] = 2
135 | seg_half[(seg_half > 12) & (seg_half <= 15)] = 1
136 | seg_half[(seg_half > 15) & (seg_half < 255)] = 2
137 | seg_full = seg.copy()
138 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1
139 |
140 | images = img.copy()
141 | segmentations = seg.copy()
142 | segmentations_half = seg_half.copy()
143 | segmentations_full = seg_full.copy()
144 |
145 | return images, segmentations, segmentations_half, segmentations_full, name
146 |
147 | def __len__(self):
148 | return len(self.imgs)
149 |
150 |
151 | class LIPValGenerator(data.Dataset):
152 | def __init__(self, root, list_path, crop_size):
153 |
154 | fid = open(list_path, 'r')
155 | imgs, segs = [], []
156 | for line in fid.readlines():
157 | idx = line.strip().split(' ')[0]
158 | image_path = os.path.join(root, 'images/' + str(idx) + '.jpg')
159 | seg_path = os.path.join(root, 'segmentations/' + str(idx) + '.png')
160 | imgs.append(image_path)
161 | segs.append(seg_path)
162 |
163 | self.root = root
164 | self.imgs = imgs
165 | self.segs = segs
166 | self.crop_size = crop_size
167 |
168 | def __getitem__(self, index):
169 | # load data
170 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
171 | name = self.imgs[index].split('/')[-1][:-4]
172 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR)
173 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE)
174 | ori_size = img.shape
175 |
176 | h, w = seg.shape
177 | length = max(w, h)
178 | ratio = self.crop_size / length
179 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
180 | img = np.array(img).astype(np.float32) - mean
181 | img = img.transpose((2, 0, 1))
182 |
183 | images = img.copy()
184 | segmentations = seg.copy()
185 |
186 | return images, segmentations, np.array(ori_size), name
187 |
188 | def __len__(self):
189 | return len(self.imgs)
190 |
--------------------------------------------------------------------------------
/dataset/data_pascal.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import random
4 |
5 | import cv2
6 | import numpy as np
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 | from .data_transforms import RandomRotate
10 | from PIL import Image
11 | # ###### Data loading #######
12 | def make_dataset(root, lst):
13 | # append all index
14 | fid = open(lst, 'r')
15 | imgs, segs = [], []
16 | for line in fid.readlines():
17 | idx = line.strip().split(' ')[0]
18 | image_path = os.path.join(root, 'JPEGImages/' + str(idx) + '.jpg')
19 | # image_path = os.path.join(root, str(idx) + '.jpg')
20 | seg_path = os.path.join(root, 'SegmentationPart/' + str(idx) + '.png')
21 | # seg_path = os.path.join(root, str(idx) + '.jpg')
22 | imgs.append(image_path)
23 | segs.append(seg_path)
24 | return imgs, segs
25 |
26 |
27 | # ###### val resize & crop ######
28 | def scale_crop(img, seg, crop_size):
29 | oh, ow = seg.shape
30 | pad_h = max(crop_size - oh, 0)
31 | pad_ht, pad_hb = pad_h // 2, pad_h - pad_h // 2
32 | pad_w = max(crop_size - ow, 0)
33 | pad_wl, pad_wr = pad_w // 2, pad_w - pad_w // 2
34 | if pad_h > 0 or pad_w > 0:
35 | img_pad = cv2.copyMakeBorder(img, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT,
36 | value=(0.0, 0.0, 0.0))
37 | seg_pad = cv2.copyMakeBorder(seg, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT,
38 | value=(255,))
39 | else:
40 | img_pad, seg_pad = img, seg
41 |
42 | return img_pad, seg_pad
43 |
44 |
45 | class DatasetGenerator(data.Dataset):
46 | def __init__(self, root, list_path, crop_size, training=True):
47 |
48 | imgs, segs = make_dataset(root, list_path)
49 |
50 | self.root = root
51 | self.imgs = imgs
52 | self.segs = segs
53 | self.crop_size = crop_size
54 | self.training = training
55 | self.colorjitter = transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.5, hue=0.1)
56 |         self.random_rotate = RandomRotate(20)
57 |
58 | def __getitem__(self, index):
59 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
60 | # load data
61 | name = self.imgs[index].split('/')[-1][:-4]
62 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR)
63 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE)
64 |
65 | if self.training:
66 |
67 | #colorjitter and rotate
68 | if random.random() < 0.5:
69 | img = Image.fromarray(img)
70 | seg = Image.fromarray(seg)
71 | img = self.colorjitter(img)
72 | img, seg = self.random_rotate(img, seg)
73 | img = np.array(img).astype(np.uint8)
74 | seg = np.array(seg).astype(np.uint8)
75 |
76 | # random scale
77 | ratio = random.uniform(0.5, 2.0)
78 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
79 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
80 | img = np.array(img).astype(np.float32) - mean
81 |
82 | # pad & crop
83 | img_h, img_w = seg.shape[:2]
84 | pad_h = max(self.crop_size - img_h, 0)
85 | pad_ht, pad_hb = pad_h // 2, pad_h - pad_h // 2
86 | pad_w = max(self.crop_size - img_w, 0)
87 | pad_wl, pad_wr = pad_w // 2, pad_w - pad_w // 2
88 | if pad_h > 0 or pad_w > 0:
89 | img_pad = cv2.copyMakeBorder(img, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT,
90 | value=(0.0, 0.0, 0.0))
91 | seg_pad = cv2.copyMakeBorder(seg, pad_ht, pad_hb, pad_wl, pad_wr, cv2.BORDER_CONSTANT,
92 | value=(255,))
93 | else:
94 | img_pad, seg_pad = img, seg
95 |
96 | seg_pad_h, seg_pad_w = seg_pad.shape
97 | h_off = random.randint(0, seg_pad_h - self.crop_size)
98 | w_off = random.randint(0, seg_pad_w - self.crop_size)
99 | img = np.asarray(img_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.float32)
100 | seg = np.asarray(seg_pad[h_off: h_off + self.crop_size, w_off: w_off + self.crop_size], np.uint8)
101 | # random mirror
102 | flip = np.random.choice(2) * 2 - 1
103 | img = img[:, ::flip, :]
104 | seg = seg[:, ::flip]
105 | # Generate target maps
106 | img = img.transpose((2, 0, 1))
107 | seg_half = seg.copy()
108 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1
109 | seg_half[(seg_half > 4) & (seg_half < 255)] = 2
110 | seg_full = seg.copy()
111 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1
112 |
113 | else:
114 | h, w = seg.shape
115 | max_size = max(w, h)
116 | ratio = self.crop_size / max_size
117 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
118 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
119 | img = np.array(img).astype(np.float32) - mean
120 | img, seg = scale_crop(img, seg, crop_size=self.crop_size)
121 | img = img.transpose((2, 0, 1))
122 | seg_half = seg.copy()
123 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1
124 | seg_half[(seg_half > 4) & (seg_half < 255)] = 2
125 | seg_full = seg.copy()
126 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1
127 |
128 | images = img.copy()
129 | segmentations = seg.copy()
130 | segmentations_half = seg_half.copy()
131 | segmentations_full = seg_full.copy()
132 |
133 | return images, segmentations, segmentations_half, segmentations_full, name
134 |
135 | def __len__(self):
136 | return len(self.imgs)
137 |
138 |
139 | class TestGenerator(data.Dataset):
140 |
141 | def __init__(self, root, list_path, crop_size):
142 |
143 | imgs, segs = make_dataset(root, list_path)
144 | self.root = root
145 | self.imgs = imgs
146 | self.segs = segs
147 | self.crop_size = crop_size
148 |
149 | def __getitem__(self, index):
150 | # load data
151 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
152 | name = self.imgs[index].split('/')[-1][:-4]
153 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR)
154 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE)
155 | ori_size = img.shape
156 |
157 | h, w = seg.shape
158 | length = max(w, h)
159 | ratio = self.crop_size / length
160 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
161 | img = np.array(img).astype(np.float32) - mean
162 | img = img.transpose((2, 0, 1))
163 |
164 | images = img.copy()
165 | segmentations = seg.copy()
166 |
167 | return images, segmentations, np.array(ori_size), name
168 |
169 | def __len__(self):
170 | return len(self.imgs)
171 |
172 |
173 | class ReportGenerator(data.Dataset):
174 |
175 | def __init__(self, root, list_path, crop_size):
176 |
177 | imgs, segs = make_dataset(root, list_path)
178 | self.root = root
179 | self.imgs = imgs
180 | self.segs = segs
181 | self.crop_size = crop_size
182 |
183 | def __getitem__(self, index):
184 | # load data
185 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
186 | name = self.imgs[index].split('/')[-1][:-4]
187 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR)
188 | seg = cv2.imread(self.segs[index], cv2.IMREAD_GRAYSCALE)
189 | ori_size = img.shape
190 |
191 | h, w = seg.shape
192 | length = max(w, h)
193 | ratio = self.crop_size / length
194 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
195 | img = np.array(img).astype(np.float32) - mean
196 | img = img.transpose((2, 0, 1))
197 | seg_half = seg.copy()
198 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1
199 | seg_half[(seg_half > 4) & (seg_half < 255)] = 2
200 | seg_full = seg.copy()
201 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1
202 |
203 | images = img.copy()
204 | segmentations = seg.copy()
205 | segmentations_half = seg_half.copy()
206 | segmentations_full = seg_full.copy()
207 |
208 | return images, segmentations, segmentations_half, segmentations_full, np.array(ori_size), name
209 |
210 | def __len__(self):
211 | return len(self.imgs)
212 |
213 |
214 | if __name__ == '__main__':
215 |     dl = DatasetGenerator('/media/jzzz/Data/Dataset/PascalPersonPart/', './pascal/train_id.txt',
216 | crop_size=512, training=True)
217 |
218 | item = iter(dl)
219 | for i in range(len(dl)):
220 | imgs, segs, segs_half, segs_full, idx = next(item)
221 | pass
222 |
--------------------------------------------------------------------------------
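Note: the loaders above return three aligned label maps per image: the 7-class part segmentation, a 3-class half-body map (background / upper body / lower body), and a 2-class full-body map, with 255 kept as the ignore label throughout. A minimal numpy sketch of that collapse on a toy label map (the part values below are illustrative, not taken from a real annotation):

    import numpy as np

    # Toy Pascal-Person-Part labels: 0 bg, 1 head, 2 torso, 3 upper-arm,
    # 4 lower-arm, 5 upper-leg, 6 lower-leg, 255 ignore.
    seg = np.array([[0, 1, 4, 255],
                    [2, 3, 5, 6]], dtype=np.uint8)

    seg_half = seg.copy()                              # parts 1-4 -> upper body (1)
    seg_half[(seg_half > 0) & (seg_half <= 4)] = 1
    seg_half[(seg_half > 4) & (seg_half < 255)] = 2    # parts 5-6 -> lower body (2)

    seg_full = seg.copy()                              # any part -> person (1)
    seg_full[(seg_full > 0) & (seg_full < 255)] = 1

    print(seg_half)  # [[0 1 1 255] [1 1 2 2]]
    print(seg_full)  # [[0 1 1 255] [1 1 1 1]]
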
/dataset/data_ppss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import random
4 |
5 | import cv2
6 | import numpy as np
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 | from .data_transforms import RandomRotate
10 | from PIL import Image
11 | map_idx = [0, 9, 19, 29, 50, 39, 60, 62]
12 | # 0 background, 1 hair, 2 face, 3 upper clothes, 4 arms, 5 lower clothes, 6 legs, 7 shoes
13 |
14 |
15 | # ###### Data loading #######
16 | def make_dataset(root, lst):
17 | # append all index
18 | fid = open(lst, 'r')
19 | imgs, segs = [], []
20 | for line in fid.readlines():
21 | idx = line.strip()
22 | image_path = os.path.join(root, str(idx) + '.jpg')
23 | seg_path = os.path.join(root, str(idx) + '_m.png')
24 | imgs.append(image_path)
25 | segs.append(seg_path)
26 | return imgs, segs
27 |
28 |
29 | # ###### val resize & crop ######
30 | def scale_crop(img, seg, crop_size):
31 | oh, ow = seg.shape
32 | pad_h = max(0, crop_size[0] - oh)
33 | pad_w = max(0, crop_size[1] - ow)
34 | if pad_h > 0 or pad_w > 0:
35 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
36 | value=(0.0, 0.0, 0.0))
37 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
38 | value=255)
39 | else:
40 | img_pad, seg_pad = img, seg
41 |
42 | img = np.asarray(img_pad[0: crop_size[0], 0: crop_size[1]], np.float32)
43 | seg = np.asarray(seg_pad[0: crop_size[0], 0: crop_size[1]], np.float32)
44 |
45 | return img, seg
46 |
47 |
48 | class DatasetGenerator(data.Dataset):
49 | def __init__(self, root, list_path, crop_size, training=True):
50 |
51 | imgs, segs = make_dataset(root, list_path)
52 |
53 | self.root = root
54 | self.imgs = imgs
55 | self.segs = segs
56 | self.crop_size = crop_size
57 | self.training = training
58 | self.colorjitter = transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.5, hue=0.1)
 59 |         self.random_rotate = RandomRotate(20)
60 |
61 | def __getitem__(self, index):
62 | # load data
63 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
64 | name = self.imgs[index].split('/')[-1][:-4]
65 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR)
66 | seg = np.array(Image.open(self.segs[index]))
67 | # seg_h, seg_w = seg.shape
68 | seg_h, seg_w, _ = img.shape
69 | # img = cv2.resize(img, (seg_w, seg_h), interpolation=cv2.INTER_LINEAR)
70 | seg = cv2.resize(seg, (seg_w, seg_h), interpolation=cv2.INTER_NEAREST)
71 | new_seg = (np.ones_like(seg)*255).astype(np.uint8)
72 | for i in range(len(map_idx)):
73 | new_seg[seg == map_idx[i]] = i
74 | seg = new_seg
75 | if self.training:
76 | #colorjitter and rotate
77 | if random.random() < 0.5:
78 | img = Image.fromarray(img)
79 | seg = Image.fromarray(seg)
80 | img = self.colorjitter(img)
81 | img, seg = self.random_rotate(img, seg)
82 | img = np.array(img).astype(np.uint8)
83 | seg = np.array(seg).astype(np.uint8)
84 | # random mirror
85 | flip = np.random.choice(2) * 2 - 1
86 | img = img[:, ::flip, :]
87 | seg = seg[:, ::flip]
88 | # random scale
89 | ratio = random.uniform(0.75, 2.5)
90 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
91 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
92 | img = np.array(img).astype(np.float32) - mean
93 |
94 | # pad & crop
95 | img_h, img_w = seg.shape
96 | assert img_w < img_h
97 | pad_h = max(self.crop_size[0] - img_h, 0)
98 | pad_w = max(self.crop_size[1] - img_w, 0)
99 | if pad_h > 0 or pad_w > 0:
100 | img_pad = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
101 | value=(0.0, 0.0, 0.0))
102 | seg_pad = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
103 | value=(255,))
104 | else:
105 | img_pad, seg_pad = img, seg
106 |
107 | img_h, img_w = seg_pad.shape
108 | h_off = random.randint(0, img_h - self.crop_size[0])
109 | w_off = random.randint(0, img_w - self.crop_size[1])
110 | img = np.asarray(img_pad[h_off: h_off + self.crop_size[0], w_off: w_off + self.crop_size[1]], np.float32)
111 | seg = np.asarray(seg_pad[h_off: h_off + self.crop_size[0], w_off: w_off + self.crop_size[1]], np.uint8)
112 | img = img.transpose((2, 0, 1))
113 | # generate body masks
114 | seg_half = seg.copy()
115 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1
116 | seg_half[(seg_half > 4) & (seg_half < 255)] = 2
117 | seg_full = seg.copy()
118 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1
119 |
120 | else:
121 | h, w = seg.shape
122 | max_size = max(w, h)
123 | ratio = self.crop_size[0] / max_size
124 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
125 | seg = cv2.resize(seg, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
126 | img = np.array(img).astype(np.float32) - mean
127 | img, seg = scale_crop(img, seg, crop_size=self.crop_size)
128 | img = img.transpose((2, 0, 1))
129 | # generate body masks
130 | seg_half = seg.copy()
131 | seg_half[(seg_half > 0) & (seg_half <= 4)] = 1
132 | seg_half[(seg_half > 4) & (seg_half < 255)] = 2
133 | seg_full = seg.copy()
134 | seg_full[(seg_full > 0) & (seg_full < 255)] = 1
135 |
136 | images = img.copy()
137 | segmentations = seg.copy()
138 | segmentations_half = seg_half.copy()
139 | segmentations_full = seg_full.copy()
140 |
141 | return images, segmentations, segmentations_half, segmentations_full, name
142 |
143 | def __len__(self):
144 | return len(self.imgs)
145 |
146 | class TestGenerator(data.Dataset):
147 |
148 | def __init__(self, root, list_path, crop_size):
149 |
150 | imgs, segs = make_dataset(root, list_path)
151 | self.root = root
152 | self.imgs = imgs
153 | self.segs = segs
154 | self.crop_size = crop_size
155 |
156 | def __getitem__(self, index):
157 | # load data
158 | mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
159 | name = self.imgs[index].split('/')[-1][:-4]
160 | img = cv2.imread(self.imgs[index], cv2.IMREAD_COLOR)
161 | seg = np.array(Image.open(self.segs[index]))
162 | seg_h, seg_w = seg.shape
163 | img = cv2.resize(img, (seg_w, seg_h), interpolation=cv2.INTER_LINEAR)
164 | for i in range(len(map_idx)):
165 | seg[seg == map_idx[i]] = i
166 | ori_size = img.shape
167 |
168 | h, w = seg.shape
169 | length = max(w, h)
170 | ratio = self.crop_size[0] / length
171 | img = cv2.resize(img, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
172 | img = np.array(img).astype(np.float32) - mean
173 | img = img.transpose((2, 0, 1))
174 |
175 | images = img.copy()
176 | segmentations = seg.copy()
177 |
178 | return images, segmentations, np.array(ori_size), name
179 |
180 | def __len__(self):
181 | return len(self.imgs)
182 |
183 | if __name__ == '__main__':
184 | dl = DatasetGenerator('/media/jzzz/Data/Dataset/PPSS/TrainData/', './PPSS/train_id.txt',
185 | crop_size=(321, 161), training=False)
186 |
187 | item = iter(dl)
188 | for i in range(len(dl)):
189 | imgs, segs, segs_half, segs_full, idx = next(item)
190 | pass
191 |
--------------------------------------------------------------------------------
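Note: PPSS masks encode each class as a distinct gray value, so `DatasetGenerator` first remaps raw pixel values through `map_idx` into contiguous class ids and sends every unlisted value to the ignore label 255. A minimal sketch of that remapping on a toy mask (the stray value 7 below is an assumed example of an unlisted pixel):

    import numpy as np

    map_idx = [0, 9, 19, 29, 50, 39, 60, 62]  # raw gray value per class

    raw = np.array([[0, 9, 19],
                    [50, 62, 7]], dtype=np.uint8)

    new_seg = np.full_like(raw, 255)          # default: ignore
    for cls, gray in enumerate(map_idx):
        new_seg[raw == gray] = cls

    print(new_seg)  # [[0 1 2] [4 7 255]]
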
/dataset/transforms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # Author: Donny You (youansheng@gmail.com)
4 |
5 |
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | import numpy as np
11 | import torch
12 | from PIL import Image
13 |
14 |
15 | class Normalize(object):
16 | """Normalize a ``torch.tensor``
17 |
18 | Args:
19 | inputs (torch.tensor): tensor to be normalized.
20 |         mean (list): the per-channel mean of RGB.
21 |         std (list): the per-channel std of RGB.
22 |
23 | Returns:
24 | Tensor: Normalized tensor.
25 | """
26 | def __init__(self, div_value, mean, std):
27 | self.div_value = div_value
28 | self.mean = mean
29 |         self.std = std
30 |
31 | def __call__(self, inputs):
32 | inputs = inputs.div(self.div_value)
33 | for t, m, s in zip(inputs, self.mean, self.std):
34 | t.sub_(m).div_(s)
35 |
36 | return inputs
37 |
38 |
39 | class DeNormalize(object):
40 | """DeNormalize a ``torch.tensor``
41 |
42 | Args:
43 | inputs (torch.tensor): tensor to be normalized.
44 |         mean (list): the per-channel mean of RGB.
45 |         std (list): the per-channel std of RGB.
46 |
47 |     Returns:
48 |         Tensor: Denormalized tensor.
49 | """
50 | def __init__(self, div_value, mean, std):
51 | self.div_value = div_value
52 | self.mean = mean
53 |         self.std = std
54 |
55 | def __call__(self, inputs):
56 | result = inputs.clone()
57 | for i in range(result.size(0)):
58 | result[i, :, :] = result[i, :, :] * self.std[i] + self.mean[i]
59 |
60 | return result.mul_(self.div_value)
61 |
62 |
63 | class ToTensor(object):
64 | """Convert a ``numpy.ndarray or Image`` to tensor.
65 |
66 | See ``ToTensor`` for more details.
67 |
68 | Args:
69 | inputs (numpy.ndarray or Image): Image to be converted to tensor.
70 |
71 | Returns:
72 | Tensor: Converted image.
73 | """
74 | def __call__(self, inputs):
75 | if isinstance(inputs, Image.Image):
76 | channels = len(inputs.mode)
77 | inputs = np.array(inputs)
78 | inputs = inputs.reshape(inputs.shape[0], inputs.shape[1], channels)
79 | inputs = torch.from_numpy(inputs.transpose(2, 0, 1))
80 | else:
81 | inputs = torch.from_numpy(inputs.transpose(2, 0, 1))
82 |
83 | return inputs.float()
84 |
85 |
86 | class ToLabel(object):
87 | def __call__(self, inputs):
88 | return torch.from_numpy(np.array(inputs)).long()
89 |
90 |
91 | class ReLabel(object):
92 | """
93 |     255 indicates the background; relabel 255 to some other value.
94 | """
95 | def __init__(self, olabel, nlabel):
96 | self.olabel = olabel
97 | self.nlabel = nlabel
98 |
99 | def __call__(self, inputs):
100 | assert isinstance(inputs, torch.LongTensor), 'tensor needs to be LongTensor'
101 |
102 | inputs[inputs == self.olabel] = self.nlabel
103 | return inputs
104 |
105 |
106 | class Compose(object):
107 |
108 | def __init__(self, transforms):
109 | self.transforms = transforms
110 |
111 | def __call__(self, inputs):
112 | for t in self.transforms:
113 | inputs = t(inputs)
114 |
115 | return inputs
116 |
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
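Note: these transforms are designed to be chained, torchvision-style. A hypothetical usage sketch (the ImageNet mean/std values are illustrative assumptions, not taken from this repo's training scripts):

    import numpy as np
    from dataset.transforms import Compose, Normalize, ToTensor, ToLabel, ReLabel

    input_transform = Compose([
        ToTensor(),                        # HWC ndarray -> CHW float tensor
        Normalize(div_value=255.0,
                  mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]),
    ])
    label_transform = Compose([
        ToLabel(),                         # ndarray -> LongTensor
        ReLabel(255, -1),                  # remap the ignore value for the loss
    ])

    img = np.random.randint(0, 256, (4, 4, 3)).astype(np.float32)
    lbl = np.random.randint(0, 7, (4, 4)).astype(np.int64)
    x = input_transform(img)               # FloatTensor, shape (3, 4, 4)
    y = label_transform(lbl)               # LongTensor, shape (4, 4)
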
/dataset/weights.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | from PIL import Image
5 |
6 | # data_root = '../data/CCF/Segmentations/'
7 | # fid = open('./CCF/train_id.txt', 'r')
8 | # num_cls = 18
9 |
10 |
11 | # data_root = '../data/Person/SegmentationPart/'
12 | # fid = open('./Pascal/train_id.txt', 'r')
13 | # num_cls = 7
14 | data_root = '../data/LIP/train_set/segmentations/'
15 | fid = open('./LIP/train_id.txt', 'r')
16 | num_cls = 20
17 |
18 | cls_pix_num = np.zeros(num_cls)
19 | cls_hbody_num = np.zeros(3)
20 | cls_fbody_num = np.zeros(2)
21 |
22 | map_idx = [0, 9, 19, 29, 50, 39, 60, 62]
23 |
24 | for line in fid.readlines():
25 | img_path = os.path.join(data_root, line.strip() + '.png')
26 | # img_data = np.asarray(Image.open(img_path).convert('L'))
27 | img_data = np.array(Image.open(img_path))
28 | # for i in range(len(map_idx)):
29 | # img_data[img_data == map_idx[i]] = i
30 | # img_size = img_data.size
31 | for i in range(num_cls):
32 | cls_pix_num[i] += (img_data == i).astype(int).sum(axis=None)
33 |
34 | # # half body
35 | # cls_hbody_num[0] = cls_pix_num[0]
36 | # for i in range(1, 5):
37 | # cls_hbody_num[1] += cls_pix_num[i]
38 | # for i in range(5, 8):
39 | # cls_hbody_num[2] += cls_pix_num[i]
40 | #
41 | # # full body
42 | # cls_fbody_num[0] = cls_pix_num[0]
43 | # for i in range(1, 8):
44 | # cls_fbody_num[1] += cls_pix_num[i]
45 |
46 | weight = np.log(cls_pix_num)
47 | weight_norm = np.zeros(num_cls)
48 | for i in range(num_cls):
49 | weight_norm[i] = 16 / weight[i]
50 | print(weight_norm)
51 |
52 |
53 | # [0.8373, 0.918, 0.866, 1.0345, 1.0166,
54 | # 0.9969, 0.9754, 1.0489, 0.8786, 1.0023,
55 | # 0.9539, 0.9843, 1.1116, 0.9037, 1.0865,
56 | # 1.0955, 1.0865, 1.1529, 1.0507]
57 |
58 | # 0.93237515, 1.01116892, 1.11201307
59 |
60 | # 0.98417377, 1.05657165
61 |
62 | # ATR training
63 | # [0.85978634, 1.19630769, 1.02639146, 1.30664970, 0.97220603, 1.04885815,
64 | # 1.01745278, 1.01481690, 1.27155077, 1.12947663, 1.13016390, 1.06514227,
65 | # 1.08384483, 1.08506841, 1.09560942, 1.09565198, 1.07504567, 1.20411509]
66 |
67 | #CCF
68 | # [0.82073458, 1.23651165, 1.0366326, 0.97076566, 1.2802332, 0.98860602,
69 | # 1.29035071, 1.03882453, 0.96725283, 1.05142434, 1.0075884, 0.98630539,
70 | # 1.06208869, 1.0160915, 1.1613597, 1.17624919, 1.1701143, 1.24720215]
71 |
72 | #PPSS
73 | # [0.89680465, 1.14352656, 1.20982646, 0.99269248,
74 | # 1.17911144, 1.00641032, 1.47017195, 1.16447113]
75 |
76 | #Pascal
77 | # [0.82877791, 0.95688253, 0.94921949, 1.00538108, 1.0201687, 1.01665831, 1.05470914]
78 |
79 | #Lip
80 | # [0.7602572, 0.94236198, 0.85644457, 1.04346266, 1.10627293, 0.80980162,
81 | # 0.95168713, 0.8403769, 1.05798412, 0.85746254, 1.01274366, 1.05854692,
82 | # 1.03430773, 0.84867818, 0.88027721, 0.87580925, 0.98747462, 0.9876475,
83 | # 1.00016535, 1.00108882]
84 |
--------------------------------------------------------------------------------
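Note: the script computes inverse-log-frequency class weights, weight[i] = 16 / log(pixel_count[i]), so rare classes receive weights above 1 and dominant classes below 1; the constant 16 simply centers the weights near 1 for pixel counts of this magnitude. A two-class worked example:

    import numpy as np

    cls_pix_num = np.array([1e8, 1e6])     # background vs. a rare part
    weight_norm = 16.0 / np.log(cls_pix_num)
    print(weight_norm)                     # [0.8686 1.1581]: rare class upweighted
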
/doc/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/doc/architecture.png
--------------------------------------------------------------------------------
/evaluate_pascal.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import cv2
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.autograd import Variable
10 | from torch.utils import data
11 |
12 | from dataset.data_pascal import TestGenerator
13 | from network.baseline import get_model
14 |
15 |
16 | def get_arguments():
17 | """Parse all the arguments provided from the CLI.
18 |
19 | Returns:
20 |       The parsed arguments as an argparse.Namespace.
21 | """
22 | parser = argparse.ArgumentParser(description="Pytorch Segmentation")
23 | parser.add_argument('--root', default='./data/Person', type=str)
24 | parser.add_argument("--data-list", type=str, default='./dataset/Pascal/val_id.txt')
25 | parser.add_argument("--crop-size", type=int, default=473)
26 | parser.add_argument("--num-classes", type=int, default=7)
27 | parser.add_argument("--ignore-label", type=int, default=255)
28 | parser.add_argument('--restore-from', default='./checkpoints/exp/baseline_pascal.pth', type=str)
29 |
30 | parser.add_argument("--is-mirror", action="store_true")
31 | parser.add_argument("--ms", action="store_true")
32 | # parser.add_argument('--eval-scale', nargs='+', type=float, default=[1.0])
33 | # parser.add_argument('--eval-scale', nargs='+', type=float, default=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75])
34 |
35 | parser.add_argument("--save-dir", type=str)
36 | parser.add_argument("--gpu", type=str, default='0')
37 | return parser.parse_args()
38 |
39 |
40 | def main():
41 | """Create the model and start the evaluation process."""
42 | args = get_arguments()
43 |
44 | # initialization
45 | print("Input arguments:")
46 | for key, val in vars(args).items():
47 | print("{:16} {}".format(key, val))
48 |
49 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
50 |
51 | model = get_model(num_classes=args.num_classes)
52 |
53 | # if not os.path.exists(args.save_dir):
54 | # os.makedirs(args.save_dir)
55 |
56 | palette = get_palette()
57 | restore_from = args.restore_from
58 | saved_state_dict = torch.load(restore_from)
59 | model.load_state_dict(saved_state_dict)
60 |
61 | model.eval()
62 | model.cuda()
63 |
64 | testloader = data.DataLoader(TestGenerator(args.root, args.data_list, crop_size=args.crop_size),
65 | batch_size=1, shuffle=False, pin_memory=True)
66 |
67 | confusion_matrix = np.zeros((args.num_classes, args.num_classes))
68 |
69 | for index, batch in enumerate(testloader):
70 | if index % 100 == 0:
71 |             print('%d images have been processed' % index)
72 | image, label, ori_size, name = batch
73 |
74 | ori_size = ori_size[0].numpy()
75 | if args.ms:
76 | eval_scale=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75]
77 | else:
78 | eval_scale=[1.0]
79 |
80 | output = predict(model, image.numpy(), (np.asscalar(ori_size[0]), np.asscalar(ori_size[1])),
81 | is_mirror=args.is_mirror, scales=eval_scale)
82 | seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
83 |
84 | # output_im = PILImage.fromarray(seg_pred)
85 | # output_im.putpalette(palette)
86 | # output_im.save(args.save_dir + name[0] + '.png')
87 |
88 | seg_gt = np.asarray(label[0].numpy(), dtype=np.int)
89 | ignore_index = seg_gt != 255
90 | seg_gt = seg_gt[ignore_index]
91 | seg_pred = seg_pred[ignore_index]
92 |
93 | confusion_matrix += get_confusion_matrix(seg_gt, seg_pred, args.num_classes)
94 |
95 | pos = confusion_matrix.sum(1)
96 | res = confusion_matrix.sum(0)
97 | tp = np.diag(confusion_matrix)
98 |
99 | pixel_accuracy = tp.sum() / pos.sum()
100 | mean_accuracy = (tp / np.maximum(1.0, pos)).mean()
101 | IU_array = (tp / np.maximum(1.0, pos + res - tp))
102 | mean_IU = IU_array.mean()
103 |
104 | # get_confusion_matrix_plot()
105 |
106 | print('Pixel accuracy: %f \n' % pixel_accuracy)
107 | print('Mean accuracy: %f \n' % mean_accuracy)
108 | print('Mean IU: %f \n' % mean_IU)
109 | for index, IU in enumerate(IU_array):
110 |         print('%f' % IU)
111 |
112 |
113 | def scale_image(image, scale):
114 | image = image[0, :, :, :]
115 | image = image.transpose((1, 2, 0))
116 | image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
117 | image = image.transpose((2, 0, 1))
118 | return image
119 |
120 |
121 | def predict(net, image, output_size, is_mirror=True, scales=[1]):
122 | if is_mirror:
123 | image_rev = image[:, :, :, ::-1]
124 |
125 | interp = nn.Upsample(size=output_size, mode='bilinear', align_corners=True)
126 |
127 | outputs = []
128 | if is_mirror:
129 | for scale in scales:
130 | if scale != 1:
131 | image_scale = scale_image(image=image, scale=scale)
132 | image_rev_scale = scale_image(image=image_rev, scale=scale)
133 | else:
134 | image_scale = image[0, :, :, :]
135 | image_rev_scale = image_rev[0, :, :, :]
136 |
137 | image_scale = np.stack((image_scale, image_rev_scale))
138 |
139 | with torch.no_grad():
140 | prediction = net(Variable(torch.from_numpy(image_scale)).cuda())
141 | prediction = interp(prediction[0]).cpu().data.numpy()
142 |
143 | prediction_rev = prediction[1, :, :, :].copy()
144 | prediction_rev = prediction_rev[:, :, ::-1]
145 | prediction = prediction[0, :, :, :]
146 | prediction = np.mean([prediction, prediction_rev], axis=0)
147 |
148 | outputs.append(prediction)
149 |
150 | outputs = np.mean(outputs, axis=0)
151 | outputs = outputs.transpose(1, 2, 0)
152 | else:
153 | for scale in scales:
154 | if scale != 1:
155 | image_scale = scale_image(image=image, scale=scale)
156 | else:
157 | image_scale = image[0, :, :, :]
158 |
159 | with torch.no_grad():
160 | prediction = net(Variable(torch.from_numpy(image_scale).unsqueeze(0)).cuda())
161 | prediction = interp(prediction[0]).cpu().data.numpy()
162 | outputs.append(prediction[0, :, :, :])
163 |
164 | outputs = np.mean(outputs, axis=0)
165 | outputs = outputs.transpose(1, 2, 0)
166 |
167 | return outputs
168 |
169 |
170 | def get_confusion_matrix(gt_label, pred_label, class_num):
171 | """
172 |     Calculate the confusion matrix from the given ground-truth and predicted labels
173 |     :param gt_label: the ground truth label
174 |     :param pred_label: the predicted label
175 |     :param class_num: the number of classes
176 | """
177 | index = (gt_label * class_num + pred_label).astype('int32')
178 | label_count = np.bincount(index)
179 | confusion_matrix = np.zeros((class_num, class_num))
180 |
181 | for i_label in range(class_num):
182 | for i_pred_label in range(class_num):
183 | cur_index = i_label * class_num + i_pred_label
184 | if cur_index < len(label_count):
185 | confusion_matrix[i_label, i_pred_label] = label_count[cur_index]
186 |
187 | return confusion_matrix
188 |
189 |
190 | def get_confusion_matrix_plot(conf_arr):
191 | norm_conf = []
192 | for i in conf_arr:
193 | tmp_arr = []
194 | a = sum(i, 0)
195 | for j in i:
196 | tmp_arr.append(float(j) / max(1.0, float(a)))
197 | norm_conf.append(tmp_arr)
198 |
199 | fig = plt.figure()
200 | plt.clf()
201 | ax = fig.add_subplot(111)
202 | ax.set_aspect(1)
203 | res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, interpolation='nearest')
204 |
205 | width, height = conf_arr.shape
206 |
207 | cb = fig.colorbar(res)
208 | alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
209 | plt.xticks(range(width), alphabet[:width])
210 | plt.yticks(range(height), alphabet[:height])
211 | plt.savefig('confusion_matrix.png', format='png')
212 |
213 |
214 | def get_palette():
215 | palette = [0, 0, 0,
216 | 128, 0, 0,
217 | 0, 128, 0,
218 | 128, 128, 0,
219 | 0, 0, 128,
220 | 128, 0, 128,
221 | 0, 128, 128]
222 | return palette
223 |
224 | if __name__ == '__main__':
225 | main()
226 |
--------------------------------------------------------------------------------
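Note: get_confusion_matrix builds a joint histogram over (gt, pred) index pairs; the double loop can be collapsed into a single np.bincount over the flattened joint index. A minimal vectorized equivalent (the helper name fast_confusion_matrix is hypothetical), together with the IoU arithmetic used in main:

    import numpy as np

    def fast_confusion_matrix(gt, pred, num_classes):
        # Joint index gt * C + pred, histogrammed into a C x C matrix.
        idx = (gt.astype(np.int64) * num_classes + pred).ravel()
        return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

    gt = np.array([0, 0, 1, 1, 2])
    pred = np.array([0, 1, 1, 1, 0])
    cm = fast_confusion_matrix(gt, pred, 3)
    tp = np.diag(cm)
    iou = tp / np.maximum(1.0, cm.sum(1) + cm.sum(0) - tp)
    print(iou)  # [0.3333 0.6667 0.]
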
/evaluate_pascal.sh:
--------------------------------------------------------------------------------
1 | python evaluate_pascal.py --root ./data/Person --data-list ./dataset/Pascal/val_id.txt --crop-size 473 --restore-from [checkpoint path] --ms
--------------------------------------------------------------------------------
/inplace_abn/__init__.py:
--------------------------------------------------------------------------------
1 | from .bn import ABN, InPlaceABN, InPlaceABNSync
2 | from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
3 |
--------------------------------------------------------------------------------
/inplace_abn/bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as functional
4 |
5 | try:
6 | from queue import Queue
7 | except ImportError:
8 | from Queue import Queue
9 |
10 | from .functions import *
11 |
12 |
13 | class ABN(nn.Module):
14 | """Activated Batch Normalization
15 |
16 | This gathers a `BatchNorm2d` and an activation function in a single module
17 | """
18 |
19 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
20 | """Creates an Activated Batch Normalization module
21 |
22 | Parameters
23 | ----------
24 | num_features : int
25 | Number of feature channels in the input and output.
26 | eps : float
27 | Small constant to prevent numerical issues.
28 | momentum : float
29 |             Momentum factor applied to compute running statistics.
30 | affine : bool
31 | If `True` apply learned scale and shift transformation after normalization.
32 | activation : str
33 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
34 | slope : float
35 | Negative slope for the `leaky_relu` activation.
36 | """
37 | super(ABN, self).__init__()
38 | self.num_features = num_features
39 | self.affine = affine
40 | self.eps = eps
41 | self.momentum = momentum
42 | self.activation = activation
43 | self.slope = slope
44 | if self.affine:
45 | self.weight = nn.Parameter(torch.ones(num_features))
46 | self.bias = nn.Parameter(torch.zeros(num_features))
47 | else:
48 | self.register_parameter('weight', None)
49 | self.register_parameter('bias', None)
50 | self.register_buffer('running_mean', torch.zeros(num_features))
51 | self.register_buffer('running_var', torch.ones(num_features))
52 | self.reset_parameters()
53 |
54 | def reset_parameters(self):
55 | nn.init.constant_(self.running_mean, 0)
56 | nn.init.constant_(self.running_var, 1)
57 | if self.affine:
58 | nn.init.constant_(self.weight, 1)
59 | nn.init.constant_(self.bias, 0)
60 |
61 | def forward(self, x):
62 | x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
63 | self.training, self.momentum, self.eps)
64 |
65 | if self.activation == ACT_RELU:
66 | return functional.relu(x, inplace=True)
67 | elif self.activation == ACT_LEAKY_RELU:
68 | return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
69 | elif self.activation == ACT_ELU:
70 | return functional.elu(x, inplace=True)
71 | else:
72 | return x
73 |
74 | def __repr__(self):
75 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
76 | ' affine={affine}, activation={activation}'
77 | if self.activation == "leaky_relu":
78 | rep += ', slope={slope})'
79 | else:
80 | rep += ')'
81 | return rep.format(name=self.__class__.__name__, **self.__dict__)
82 |
83 |
84 | class InPlaceABN(ABN):
85 | """InPlace Activated Batch Normalization"""
86 |
87 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
88 | """Creates an InPlace Activated Batch Normalization module
89 |
90 | Parameters
91 | ----------
92 | num_features : int
93 | Number of feature channels in the input and output.
94 | eps : float
95 | Small constant to prevent numerical issues.
96 | momentum : float
97 |             Momentum factor applied to compute running statistics.
98 | affine : bool
99 | If `True` apply learned scale and shift transformation after normalization.
100 | activation : str
101 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
102 | slope : float
103 | Negative slope for the `leaky_relu` activation.
104 | """
105 | super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)
106 |
107 | def forward(self, x):
108 | return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
109 | self.training, self.momentum, self.eps, self.activation, self.slope)
110 |
111 |
112 | class InPlaceABNSync(ABN):
113 | """InPlace Activated Batch Normalization with cross-GPU synchronization
114 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`.
115 | """
116 |
117 | def forward(self, x):
118 | return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
119 | self.training, self.momentum, self.eps, self.activation, self.slope)
120 |
121 | def __repr__(self):
122 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
123 | ' affine={affine}, activation={activation}'
124 | if self.activation == "leaky_relu":
125 | rep += ', slope={slope})'
126 | else:
127 | rep += ')'
128 | return rep.format(name=self.__class__.__name__, **self.__dict__)
129 |
130 |
131 |
--------------------------------------------------------------------------------
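Note: ABN fuses BatchNorm2d and the activation into one module; the InPlaceABN variants additionally overwrite their input to save memory. A hypothetical usage sketch (plain ABN is pure PyTorch, though importing the package JIT-compiles the CUDA extension via functions.py):

    import torch
    import torch.nn as nn
    from inplace_abn import ABN

    # One module instead of the usual BatchNorm2d + LeakyReLU pair.
    block = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        ABN(64, activation="leaky_relu", slope=0.01),
    )
    y = block(torch.randn(2, 3, 32, 32))   # (2, 64, 32, 32), already activated
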
/inplace_abn/functions.py:
--------------------------------------------------------------------------------
1 | from os import path
2 | import torch
3 | import torch.distributed as dist
4 | import torch.autograd as autograd
5 | import torch.cuda.comm as comm
6 | from torch.autograd.function import once_differentiable
7 | from torch.utils.cpp_extension import load
8 |
9 | _src_path = path.join(path.dirname(path.abspath(__file__)), "src")
10 | _backend = load(name="inplace_abn",
11 | extra_cflags=["-O3"],
12 | sources=[path.join(_src_path, f) for f in [
13 | "inplace_abn.cpp",
14 | "inplace_abn_cpu.cpp",
15 | "inplace_abn_cuda.cu",
16 | "inplace_abn_cuda_half.cu"
17 | ]],
18 | extra_cuda_cflags=["--expt-extended-lambda"])
19 |
20 | # Activation names
21 | ACT_RELU = "relu"
22 | ACT_LEAKY_RELU = "leaky_relu"
23 | ACT_ELU = "elu"
24 | ACT_NONE = "none"
25 |
26 |
27 | def _check(fn, *args, **kwargs):
28 | success = fn(*args, **kwargs)
29 | if not success:
30 | raise RuntimeError("CUDA Error encountered in {}".format(fn))
31 |
32 |
33 | def _broadcast_shape(x):
34 | out_size = []
35 | for i, s in enumerate(x.size()):
36 | if i != 1:
37 | out_size.append(1)
38 | else:
39 | out_size.append(s)
40 | return out_size
41 |
42 |
43 | def _reduce(x):
44 | if len(x.size()) == 2:
45 | return x.sum(dim=0)
46 | else:
47 | n, c = x.size()[0:2]
48 | return x.contiguous().view((n, c, -1)).sum(2).sum(0)
49 |
50 |
51 | def _count_samples(x):
52 | count = 1
53 | for i, s in enumerate(x.size()):
54 | if i != 1:
55 | count *= s
56 | return count
57 |
58 |
59 | def _act_forward(ctx, x):
60 | if ctx.activation == ACT_LEAKY_RELU:
61 | _backend.leaky_relu_forward(x, ctx.slope)
62 | elif ctx.activation == ACT_ELU:
63 | _backend.elu_forward(x)
64 | elif ctx.activation == ACT_NONE:
65 | pass
66 |
67 |
68 | def _act_backward(ctx, x, dx):
69 | if ctx.activation == ACT_LEAKY_RELU:
70 | _backend.leaky_relu_backward(x, dx, ctx.slope)
71 | elif ctx.activation == ACT_ELU:
72 | _backend.elu_backward(x, dx)
73 | elif ctx.activation == ACT_NONE:
74 | pass
75 |
76 |
77 | class InPlaceABN(autograd.Function):
78 | @staticmethod
79 | def forward(ctx, x, weight, bias, running_mean, running_var,
80 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
81 | # Save context
82 | ctx.training = training
83 | ctx.momentum = momentum
84 | ctx.eps = eps
85 | ctx.activation = activation
86 | ctx.slope = slope
87 | ctx.affine = weight is not None and bias is not None
88 |
89 | # Prepare inputs
90 | count = _count_samples(x)
91 | x = x.contiguous()
92 | weight = weight.contiguous() if ctx.affine else x.new_empty(0, dtype=torch.float32)
93 | bias = bias.contiguous() if ctx.affine else x.new_empty(0, dtype=torch.float32)
94 |
95 | if ctx.training:
96 | mean, var = _backend.mean_var(x)
97 |
98 | # Update running stats
99 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
100 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
101 |
102 | # Mark in-place modified tensors
103 | ctx.mark_dirty(x, running_mean, running_var)
104 | else:
105 | mean, var = running_mean.contiguous(), running_var.contiguous()
106 | ctx.mark_dirty(x)
107 |
108 | # BN forward + activation
109 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
110 | _act_forward(ctx, x)
111 |
112 | # Output
113 | ctx.var = var
114 | ctx.save_for_backward(x, var, weight, bias)
115 | return x
116 |
117 | @staticmethod
118 | @once_differentiable
119 | def backward(ctx, dz):
120 | z, var, weight, bias = ctx.saved_tensors
121 | dz = dz.contiguous()
122 |
123 | # Undo activation
124 | _act_backward(ctx, z, dz)
125 |
126 | if ctx.training:
127 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
128 | else:
129 | # TODO: implement simplified CUDA backward for inference mode
130 | edz = dz.new_zeros(dz.size(1))
131 | eydz = dz.new_zeros(dz.size(1))
132 |
133 | dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
134 | # dweight = eydz * weight.sign() if ctx.affine else None
135 | dweight = eydz if ctx.affine else None
136 | if dweight is not None:
137 | dweight[weight < 0] *= -1
138 | dbias = edz if ctx.affine else None
139 |
140 | return dx, dweight, dbias, None, None, None, None, None, None, None
141 |
142 |
143 | class InPlaceABNSync(autograd.Function):
144 | @classmethod
145 | def forward(cls, ctx, x, weight, bias, running_mean, running_var,
146 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True):
147 | # Save context
148 | ctx.training = training
149 | ctx.momentum = momentum
150 | ctx.eps = eps
151 | ctx.activation = activation
152 | ctx.slope = slope
153 | ctx.affine = weight is not None and bias is not None
154 |
155 | # Prepare inputs
156 | ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1
157 |
158 | # count = _count_samples(x)
159 | batch_size = x.new_tensor([x.shape[0]], dtype=torch.long)
160 |
161 | x = x.contiguous()
162 | weight = weight.contiguous() if ctx.affine else x.new_empty(0, dtype=torch.float32)
163 | bias = bias.contiguous() if ctx.affine else x.new_empty(0, dtype=torch.float32)
164 |
165 | if ctx.training:
166 | mean, var = _backend.mean_var(x)
167 | if ctx.world_size > 1:
168 | # get global batch size
169 | if equal_batches:
170 | batch_size *= ctx.world_size
171 | else:
172 | dist.all_reduce(batch_size, dist.ReduceOp.SUM)
173 |
174 | ctx.factor = x.shape[0] / float(batch_size.item())
175 |
176 | mean_all = mean.clone() * ctx.factor
177 | dist.all_reduce(mean_all, dist.ReduceOp.SUM)
178 |
179 | var_all = (var + (mean - mean_all) ** 2) * ctx.factor
180 | dist.all_reduce(var_all, dist.ReduceOp.SUM)
181 |
182 | mean = mean_all
183 | var = var_all
184 |
185 | # Update running stats
186 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
187 | count = batch_size.item() * x.view(x.shape[0], x.shape[1], -1).shape[-1]
188 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1)))
189 |
190 | # Mark in-place modified tensors
191 | ctx.mark_dirty(x, running_mean, running_var)
192 | else:
193 | mean, var = running_mean.contiguous(), running_var.contiguous()
194 | ctx.mark_dirty(x)
195 |
196 | # BN forward + activation
197 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
198 | _act_forward(ctx, x)
199 |
200 | # Output
201 | ctx.var = var
202 | ctx.save_for_backward(x, var, weight, bias)
203 | return x
204 |
205 | @staticmethod
206 | @once_differentiable
207 | def backward(ctx, dz):
208 | z, var, weight, bias = ctx.saved_tensors
209 | dz = dz.contiguous()
210 |
211 | # Undo activation
212 | _act_backward(ctx, z, dz)
213 |
214 | if ctx.training:
215 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
216 | edz_local = edz.clone()
217 | eydz_local = eydz.clone()
218 |
219 | if ctx.world_size > 1:
220 | edz *= ctx.factor
221 | dist.all_reduce(edz, dist.ReduceOp.SUM)
222 |
223 | eydz *= ctx.factor
224 | dist.all_reduce(eydz, dist.ReduceOp.SUM)
225 | else:
226 | edz_local = edz = dz.new_zeros(dz.size(1))
227 | eydz_local = eydz = dz.new_zeros(dz.size(1))
228 |
229 | dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
230 | # dweight = eydz_local * weight.sign() if ctx.affine else None
231 | dweight = eydz_local if ctx.affine else None
232 | if dweight is not None:
233 | dweight[weight < 0] *= -1
234 | dbias = edz_local if ctx.affine else None
235 |
236 | return dx, dweight, dbias, None, None, None, None, None, None, None
237 |
238 |
239 | inplace_abn = InPlaceABN.apply
240 | inplace_abn_sync = InPlaceABNSync.apply
241 |
242 | __all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
243 |
--------------------------------------------------------------------------------
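Note: the synchronized path never gathers raw activations across GPUs; InPlaceABNSync.forward all-reduces the per-GPU means and the shifted variances var + (mean - mean_all)^2, each scaled by factor (local batch / global batch), which by the law of total variance reproduces the exact global statistics. A numpy check of that identity for two equally sized "GPU" batches:

    import numpy as np

    x1 = np.random.randn(8, 4)             # batch on "GPU 0", 4 channels
    x2 = np.random.randn(8, 4)             # batch on "GPU 1"
    factor = 8 / 16.0                      # local batch / global batch

    mean_all = factor * x1.mean(0) + factor * x2.mean(0)
    var_all = (factor * (x1.var(0) + (x1.mean(0) - mean_all) ** 2)
               + factor * (x2.var(0) + (x2.mean(0) - mean_all) ** 2))

    ref = np.concatenate([x1, x2], 0)
    assert np.allclose(mean_all, ref.mean(0))
    assert np.allclose(var_all, ref.var(0))
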
/inplace_abn/src/checks.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
6 | #ifndef AT_CHECK
7 | #define AT_CHECK AT_ASSERT
8 | #endif
9 |
10 | #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
12 | #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
13 |
14 | #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
15 | #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
--------------------------------------------------------------------------------
/inplace_abn/src/common.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | /*
6 | * General settings
7 | */
8 | const int WARP_SIZE = 32;
9 | const int MAX_BLOCK_SIZE = 512;
10 |
11 | template<typename T>
12 | struct Pair {
13 | T v1, v2;
14 | __device__ Pair() {}
15 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
16 | __device__ Pair(T v) : v1(v), v2(v) {}
17 | __device__ Pair(int v) : v1(v), v2(v) {}
18 | __device__ Pair &operator+=(const Pair &a) {
19 | v1 += a.v1;
20 | v2 += a.v2;
21 | return *this;
22 | }
23 | };
24 |
25 | /*
26 | * Utility functions
27 | */
28 | template<typename T>
29 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
30 | unsigned int mask = 0xffffffff) {
31 | #if CUDART_VERSION >= 9000
32 | return __shfl_xor_sync(mask, value, laneMask, width);
33 | #else
34 | return __shfl_xor(value, laneMask, width);
35 | #endif
36 | }
37 |
38 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
39 |
40 | static int getNumThreads(int nElem) {
41 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
42 | for (int i = 0; i != 5; ++i) {
43 | if (nElem <= threadSizes[i]) {
44 | return threadSizes[i];
45 | }
46 | }
47 | return MAX_BLOCK_SIZE;
48 | }
49 |
50 | template<typename T>
51 | static __device__ __forceinline__ T warpSum(T val) {
52 | #if __CUDA_ARCH__ >= 300
53 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
54 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
55 | }
56 | #else
57 | __shared__ T values[MAX_BLOCK_SIZE];
58 | values[threadIdx.x] = val;
59 | __threadfence_block();
60 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
61 | for (int i = 1; i < WARP_SIZE; i++) {
62 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
63 | }
64 | #endif
65 | return val;
66 | }
67 |
68 | template<typename T>
69 | static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
70 | value.v1 = warpSum(value.v1);
71 | value.v2 = warpSum(value.v2);
72 | return value;
73 | }
74 |
75 | template<typename T, typename Op>
76 | __device__ T reduce(Op op, int plane, int N, int C, int S) {
77 | T sum = (T)0;
78 | for (int batch = 0; batch < N; ++batch) {
79 | for (int x = threadIdx.x; x < S; x += blockDim.x) {
80 | sum += op(batch, plane, x);
81 | }
82 | }
83 |
84 | // sum over NumThreads within a warp
85 | sum = warpSum(sum);
86 |
87 | // 'transpose', and reduce within warp again
88 | __shared__ T shared[32];
89 | __syncthreads();
90 | if (threadIdx.x % WARP_SIZE == 0) {
91 | shared[threadIdx.x / WARP_SIZE] = sum;
92 | }
93 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
94 | // zero out the other entries in shared
95 | shared[threadIdx.x] = (T)0;
96 | }
97 | __syncthreads();
98 | if (threadIdx.x / WARP_SIZE == 0) {
99 | sum = warpSum(shared[threadIdx.x]);
100 | if (threadIdx.x == 0) {
101 | shared[0] = sum;
102 | }
103 | }
104 | __syncthreads();
105 |
106 | // Everyone picks it up, should be broadcast into the whole gradInput
107 | return shared[0];
108 | }
--------------------------------------------------------------------------------
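Note: warpSum is a butterfly reduction: at step i each lane adds the value held by the lane whose index differs in bit i (the XOR shuffle), so after getMSB(WARP_SIZE) = 5 steps every lane holds the full 32-lane sum without shared memory. A toy Python simulation of the access pattern:

    import numpy as np

    WARP_SIZE = 32
    vals = np.arange(WARP_SIZE, dtype=np.float64)   # one value per lane
    for i in range(5):                              # log2(32) butterfly steps
        vals = vals + vals[np.arange(WARP_SIZE) ^ (1 << i)]
    print(vals)                                     # every lane holds 496 = sum(0..31)
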
/inplace_abn/src/inplace_abn.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include <vector>
4 |
5 | #include "inplace_abn.h"
6 |
7 | std::vector<at::Tensor> mean_var(at::Tensor x) {
8 | if (x.is_cuda()) {
9 | if (x.type().scalarType() == at::ScalarType::Half) {
10 | return mean_var_cuda_h(x);
11 | } else {
12 | return mean_var_cuda(x);
13 | }
14 | } else {
15 | return mean_var_cpu(x);
16 | }
17 | }
18 |
19 | at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
20 | bool affine, float eps) {
21 | if (x.is_cuda()) {
22 | if (x.type().scalarType() == at::ScalarType::Half) {
23 | return forward_cuda_h(x, mean, var, weight, bias, affine, eps);
24 | } else {
25 | return forward_cuda(x, mean, var, weight, bias, affine, eps);
26 | }
27 | } else {
28 | return forward_cpu(x, mean, var, weight, bias, affine, eps);
29 | }
30 | }
31 |
32 | std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
33 | bool affine, float eps) {
34 | if (z.is_cuda()) {
35 | if (z.type().scalarType() == at::ScalarType::Half) {
36 | return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps);
37 | } else {
38 | return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
39 | }
40 | } else {
41 | return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
42 | }
43 | }
44 |
45 | at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
46 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
47 | if (z.is_cuda()) {
48 | if (z.type().scalarType() == at::ScalarType::Half) {
49 | return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps);
50 | } else {
51 | return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
52 | }
53 | } else {
54 | return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
55 | }
56 | }
57 |
58 | void leaky_relu_forward(at::Tensor z, float slope) {
59 | at::leaky_relu_(z, slope);
60 | }
61 |
62 | void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
63 | if (z.is_cuda()) {
64 | if (z.type().scalarType() == at::ScalarType::Half) {
65 | return leaky_relu_backward_cuda_h(z, dz, slope);
66 | } else {
67 | return leaky_relu_backward_cuda(z, dz, slope);
68 | }
69 | } else {
70 | return leaky_relu_backward_cpu(z, dz, slope);
71 | }
72 | }
73 |
74 | void elu_forward(at::Tensor z) {
75 | at::elu_(z);
76 | }
77 |
78 | void elu_backward(at::Tensor z, at::Tensor dz) {
79 | if (z.is_cuda()) {
80 | return elu_backward_cuda(z, dz);
81 | } else {
82 | return elu_backward_cpu(z, dz);
83 | }
84 | }
85 |
86 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
87 | m.def("mean_var", &mean_var, "Mean and variance computation");
88 | m.def("forward", &forward, "In-place forward computation");
89 | m.def("edz_eydz", &edz_eydz, "First part of backward computation");
90 | m.def("backward", &backward, "Second part of backward computation");
91 | m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
92 | m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
93 | m.def("elu_forward", &elu_forward, "Elu forward computation");
94 | m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
95 | }
96 |
--------------------------------------------------------------------------------
/inplace_abn/src/inplace_abn.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | #include <vector>
6 |
7 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
8 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
9 | std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x);
10 |
11 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
12 | bool affine, float eps);
13 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
14 | bool affine, float eps);
15 | at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
16 | bool affine, float eps);
17 |
18 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
19 |                                      bool affine, float eps);
20 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
21 |                                       bool affine, float eps);
22 | std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
23 |                                         bool affine, float eps);
24 |
25 | at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
26 | at::Tensor edz, at::Tensor eydz, bool affine, float eps);
27 | at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
28 | at::Tensor edz, at::Tensor eydz, bool affine, float eps);
29 | at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
30 | at::Tensor edz, at::Tensor eydz, bool affine, float eps);
31 |
32 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
33 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
34 | void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope);
35 |
36 | void elu_backward_cpu(at::Tensor z, at::Tensor dz);
37 | void elu_backward_cuda(at::Tensor z, at::Tensor dz);
38 |
39 | static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
40 | num = x.size(0);
41 | chn = x.size(1);
42 | sp = 1;
43 | for (int64_t i = 2; i < x.ndimension(); ++i)
44 | sp *= x.size(i);
45 | }
46 |
47 | /*
48 | * Specialized CUDA reduction functions for BN
49 | */
50 | #ifdef __CUDACC__
51 |
52 | #include "utils/cuda.cuh"
53 |
54 | template<typename T, typename Op>
55 | __device__ T reduce(Op op, int plane, int N, int S) {
56 | T sum = (T)0;
57 | for (int batch = 0; batch < N; ++batch) {
58 | for (int x = threadIdx.x; x < S; x += blockDim.x) {
59 | sum += op(batch, plane, x);
60 | }
61 | }
62 |
63 | // sum over NumThreads within a warp
64 | sum = warpSum(sum);
65 |
66 | // 'transpose', and reduce within warp again
67 | __shared__ T shared[32];
68 | __syncthreads();
69 | if (threadIdx.x % WARP_SIZE == 0) {
70 | shared[threadIdx.x / WARP_SIZE] = sum;
71 | }
72 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
73 | // zero out the other entries in shared
74 | shared[threadIdx.x] = (T)0;
75 | }
76 | __syncthreads();
77 | if (threadIdx.x / WARP_SIZE == 0) {
78 | sum = warpSum(shared[threadIdx.x]);
79 | if (threadIdx.x == 0) {
80 | shared[0] = sum;
81 | }
82 | }
83 | __syncthreads();
84 |
85 | // Everyone picks it up, should be broadcast into the whole gradInput
86 | return shared[0];
87 | }
88 | #endif
89 |
--------------------------------------------------------------------------------
/inplace_abn/src/inplace_abn_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 |
3 | #include <vector>
4 |
5 | #include "utils/checks.h"
6 | #include "inplace_abn.h"
7 |
8 | at::Tensor reduce_sum(at::Tensor x) {
9 | if (x.ndimension() == 2) {
10 | return x.sum(0);
11 | } else {
12 | auto x_view = x.view({x.size(0), x.size(1), -1});
13 | return x_view.sum(-1).sum(0);
14 | }
15 | }
16 |
17 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
18 | if (x.ndimension() == 2) {
19 | return v;
20 | } else {
21 |     std::vector<int64_t> broadcast_size = {1, -1};
22 | for (int64_t i = 2; i < x.ndimension(); ++i)
23 | broadcast_size.push_back(1);
24 |
25 | return v.view(broadcast_size);
26 | }
27 | }
28 |
29 | int64_t count(at::Tensor x) {
30 | int64_t count = x.size(0);
31 | for (int64_t i = 2; i < x.ndimension(); ++i)
32 | count *= x.size(i);
33 |
34 | return count;
35 | }
36 |
37 | at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
38 | if (affine) {
39 | return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
40 | } else {
41 | return z;
42 | }
43 | }
44 |
45 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
46 | auto num = count(x);
47 | auto mean = reduce_sum(x) / num;
48 | auto diff = x - broadcast_to(mean, x);
49 | auto var = reduce_sum(diff.pow(2)) / num;
50 |
51 | return {mean, var};
52 | }
53 |
54 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
55 | bool affine, float eps) {
56 | auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
57 | auto mul = at::rsqrt(var + eps) * gamma;
58 |
59 | x.sub_(broadcast_to(mean, x));
60 | x.mul_(broadcast_to(mul, x));
61 | if (affine) x.add_(broadcast_to(bias, x));
62 |
63 | return x;
64 | }
65 |
66 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
67 | bool affine, float eps) {
68 | auto edz = reduce_sum(dz);
69 | auto y = invert_affine(z, weight, bias, affine, eps);
70 | auto eydz = reduce_sum(y * dz);
71 |
72 | return {edz, eydz};
73 | }
74 |
75 | at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
76 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
77 | auto y = invert_affine(z, weight, bias, affine, eps);
78 | auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
79 |
80 | auto num = count(z);
81 | auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
82 | return dx;
83 | }
84 |
85 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
86 | CHECK_CPU_INPUT(z);
87 | CHECK_CPU_INPUT(dz);
88 |
89 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
90 | int64_t count = z.numel();
91 |     auto *_z = z.data<scalar_t>();
92 |     auto *_dz = dz.data<scalar_t>();
93 |
94 | for (int64_t i = 0; i < count; ++i) {
95 | if (_z[i] < 0) {
96 | _z[i] *= 1 / slope;
97 | _dz[i] *= slope;
98 | }
99 | }
100 | }));
101 | }
102 |
103 | void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
104 | CHECK_CPU_INPUT(z);
105 | CHECK_CPU_INPUT(dz);
106 |
107 | AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
108 | int64_t count = z.numel();
109 |     auto *_z = z.data<scalar_t>();
110 |     auto *_dz = dz.data<scalar_t>();
111 |
112 | for (int64_t i = 0; i < count; ++i) {
113 | if (_z[i] < 0) {
114 | _z[i] = log1p(_z[i]);
115 | _dz[i] *= (_z[i] + 1.f);
116 | }
117 | }
118 | }));
119 | }
120 |
--------------------------------------------------------------------------------
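Note: the backward entry points receive the post-activation output z instead of a saved input, so leaky_relu_backward_cpu inverts the activation in place (multiplying negative values by 1/slope) while applying the chain rule to the incoming gradient; this is what lets the forward pass run fully in place without stashing extra tensors. A numpy sketch of that inversion:

    import numpy as np

    slope = 0.01
    x = np.array([-2.0, -0.5, 1.0, 3.0])
    z = np.where(x < 0, x * slope, x)     # forward leaky ReLU output
    dz = np.ones_like(z)                  # incoming gradient

    neg = z < 0
    z[neg] *= 1.0 / slope                 # invert: recover the pre-activation x
    dz[neg] *= slope                      # chain rule on the negative side

    assert np.allclose(z, x)
    print(dz)                             # [0.01 0.01 1.   1.  ]
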
/inplace_abn/src/inplace_abn_cuda_half.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 |
3 | #include <cuda_fp16.h>
4 |
5 | #include <vector>
6 |
7 | #include "utils/checks.h"
8 | #include "utils/cuda.cuh"
9 | #include "inplace_abn.h"
10 |
11 | #include <ATen/cuda/CUDAContext.h>
12 |
13 | // Operations for reduce
14 | struct SumOpH {
15 | __device__ SumOpH(const half *t, int c, int s)
16 | : tensor(t), chn(c), sp(s) {}
17 | __device__ __forceinline__ float operator()(int batch, int plane, int n) {
18 | return __half2float(tensor[(batch * chn + plane) * sp + n]);
19 | }
20 | const half *tensor;
21 | const int chn;
22 | const int sp;
23 | };
24 |
25 | struct VarOpH {
26 | __device__ VarOpH(float m, const half *t, int c, int s)
27 | : mean(m), tensor(t), chn(c), sp(s) {}
28 | __device__ __forceinline__ float operator()(int batch, int plane, int n) {
29 | const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]);
30 | return (t - mean) * (t - mean);
31 | }
32 | const float mean;
33 | const half *tensor;
34 | const int chn;
35 | const int sp;
36 | };
37 |
38 | struct GradOpH {
39 | __device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s)
40 | : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
41 |   __device__ __forceinline__ Pair<float> operator()(int batch, int plane, int n) {
42 | float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight;
43 | float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
44 |     return Pair<float>(_dz, _y * _dz);
45 | }
46 | const float weight;
47 | const float bias;
48 | const half *z;
49 | const half *dz;
50 | const int chn;
51 | const int sp;
52 | };
53 |
54 | /***********
55 | * mean_var
56 | ***********/
57 |
58 | __global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) {
59 | int plane = blockIdx.x;
60 |   float norm = 1.f / static_cast<float>(num * sp);
61 |
62 |   float _mean = reduce<float, SumOpH>(SumOpH(x, chn, sp), plane, num, sp) * norm;
63 | __syncthreads();
64 |   float _var = reduce<float, VarOpH>(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm;
65 |
66 | if (threadIdx.x == 0) {
67 | mean[plane] = _mean;
68 | var[plane] = _var;
69 | }
70 | }
71 |
72 | std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x) {
73 | CHECK_CUDA_INPUT(x);
74 |
75 | // Extract dimensions
76 | int64_t num, chn, sp;
77 | get_dims(x, num, chn, sp);
78 |
79 | // Prepare output tensors
80 | auto mean = at::empty({chn},x.options().dtype(at::kFloat));
81 | auto var = at::empty({chn},x.options().dtype(at::kFloat));
82 |
83 | // Run kernel
84 | dim3 blocks(chn);
85 | dim3 threads(getNumThreads(sp));
86 | auto stream = at::cuda::getCurrentCUDAStream();
87 | mean_var_kernel_h<<<blocks, threads, 0, stream>>>(
88 | reinterpret_cast<half*>(x.data<at::Half>()),
89 | mean.data<float>(),
90 | var.data<float>(),
91 | num, chn, sp);
92 |
93 | return {mean, var};
94 | }
95 |
96 | /**********
97 | * forward
98 | **********/
99 |
100 | __global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias,
101 | bool affine, float eps, int num, int chn, int sp) {
102 | int plane = blockIdx.x;
103 |
104 | const float _mean = mean[plane];
105 | const float _var = var[plane];
106 | const float _weight = affine ? abs(weight[plane]) + eps : 1.f;
107 | const float _bias = affine ? bias[plane] : 0.f;
108 |
109 | const float mul = rsqrt(_var + eps) * _weight;
110 |
111 | for (int batch = 0; batch < num; ++batch) {
112 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
113 | half *x_ptr = x + (batch * chn + plane) * sp + n;
114 | float _x = __half2float(*x_ptr);
115 | float _y = (_x - _mean) * mul + _bias;
116 |
117 | *x_ptr = __float2half(_y);
118 | }
119 | }
120 | }
121 |
122 | at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
123 | bool affine, float eps) {
124 | CHECK_CUDA_INPUT(x);
125 | CHECK_CUDA_INPUT(mean);
126 | CHECK_CUDA_INPUT(var);
127 | CHECK_CUDA_INPUT(weight);
128 | CHECK_CUDA_INPUT(bias);
129 |
130 | // Extract dimensions
131 | int64_t num, chn, sp;
132 | get_dims(x, num, chn, sp);
133 |
134 | // Run kernel
135 | dim3 blocks(chn);
136 | dim3 threads(getNumThreads(sp));
137 | auto stream = at::cuda::getCurrentCUDAStream();
138 | forward_kernel_h<<<blocks, threads, 0, stream>>>(
139 | reinterpret_cast<half*>(x.data<at::Half>()),
140 | mean.data<float>(),
141 | var.data<float>(),
142 | weight.data<float>(),
143 | bias.data<float>(),
144 | affine, eps, num, chn, sp);
145 |
146 | return x;
147 | }
148 |
149 | __global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias,
150 | float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) {
151 | int plane = blockIdx.x;
152 |
153 | float _weight = affine ? abs(weight[plane]) + eps : 1.f;
154 | float _bias = affine ? bias[plane] : 0.f;
155 |
156 | Pair<float> res = reduce<Pair<float>, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp);
157 | __syncthreads();
158 |
159 | if (threadIdx.x == 0) {
160 | edz[plane] = res.v1;
161 | eydz[plane] = res.v2;
162 | }
163 | }
164 |
165 | std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
166 | bool affine, float eps) {
167 | CHECK_CUDA_INPUT(z);
168 | CHECK_CUDA_INPUT(dz);
169 | CHECK_CUDA_INPUT(weight);
170 | CHECK_CUDA_INPUT(bias);
171 |
172 | // Extract dimensions
173 | int64_t num, chn, sp;
174 | get_dims(z, num, chn, sp);
175 |
176 | auto edz = at::empty({chn},z.options().dtype(at::kFloat));
177 | auto eydz = at::empty({chn},z.options().dtype(at::kFloat));
178 |
179 | // Run kernel
180 | dim3 blocks(chn);
181 | dim3 threads(getNumThreads(sp));
182 | auto stream = at::cuda::getCurrentCUDAStream();
183 | edz_eydz_kernel_h<<<blocks, threads, 0, stream>>>(
184 | reinterpret_cast<half*>(z.data<at::Half>()),
185 | reinterpret_cast<half*>(dz.data<at::Half>()),
186 | weight.data<float>(),
187 | bias.data<float>(),
188 | edz.data<float>(),
189 | eydz.data<float>(),
190 | affine, eps, num, chn, sp);
191 |
192 | return {edz, eydz};
193 | }
194 |
195 | __global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz,
196 | const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) {
197 | int plane = blockIdx.x;
198 |
199 | float _weight = affine ? abs(weight[plane]) + eps : 1.f;
200 | float _bias = affine ? bias[plane] : 0.f;
201 | float _var = var[plane];
202 | float _edz = edz[plane];
203 | float _eydz = eydz[plane];
204 |
205 | float _mul = _weight * rsqrt(_var + eps);
206 | float count = float(num * sp);
207 |
208 | for (int batch = 0; batch < num; ++batch) {
209 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
210 | float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
211 | float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight;
212 |
213 | dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul);
214 | }
215 | }
216 | }
217 |
218 | at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
219 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
220 | CHECK_CUDA_INPUT(z);
221 | CHECK_CUDA_INPUT(dz);
222 | CHECK_CUDA_INPUT(var);
223 | CHECK_CUDA_INPUT(weight);
224 | CHECK_CUDA_INPUT(bias);
225 | CHECK_CUDA_INPUT(edz);
226 | CHECK_CUDA_INPUT(eydz);
227 |
228 | // Extract dimensions
229 | int64_t num, chn, sp;
230 | get_dims(z, num, chn, sp);
231 |
232 | auto dx = at::zeros_like(z);
233 |
234 | // Run kernel
235 | dim3 blocks(chn);
236 | dim3 threads(getNumThreads(sp));
237 | auto stream = at::cuda::getCurrentCUDAStream();
238 | backward_kernel_h<<<blocks, threads, 0, stream>>>(
239 | reinterpret_cast<half*>(z.data<at::Half>()),
240 | reinterpret_cast<half*>(dz.data<at::Half>()),
241 | var.data<float>(),
242 | weight.data<float>(),
243 | bias.data<float>(),
244 | edz.data<float>(),
245 | eydz.data<float>(),
246 | reinterpret_cast<half*>(dx.data<at::Half>()),
247 | affine, eps, num, chn, sp);
248 |
249 | return dx;
250 | }
251 |
252 | __global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) {
253 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x){
254 | float _z = __half2float(z[i]);
255 | if (_z < 0) {
256 | dz[i] = __float2half(__half2float(dz[i]) * slope);
257 | z[i] = __float2half(_z / slope);
258 | }
259 | }
260 | }
261 |
262 | void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) {
263 | CHECK_CUDA_INPUT(z);
264 | CHECK_CUDA_INPUT(dz);
265 |
266 | int64_t count = z.numel();
267 | dim3 threads(getNumThreads(count));
268 | dim3 blocks = (count + threads.x - 1) / threads.x;
269 | auto stream = at::cuda::getCurrentCUDAStream();
270 | leaky_relu_backward_impl_h<<<blocks, threads, 0, stream>>>(
271 | reinterpret_cast<half*>(z.data<at::Half>()),
272 | reinterpret_cast<half*>(dz.data<at::Half>()),
273 | slope, count);
274 | }
275 |
276 |
--------------------------------------------------------------------------------
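The half-precision kernels in this file convert each element to fp32 with __half2float, do the arithmetic in fp32, and write back with __float2half; only the storage is fp16. A rough PyTorch equivalent of forward_kernel_h, for illustration only (the function name and NCHW layout are assumptions of this sketch):

    import torch

    def abn_forward_half_ref(x, mean, var, weight, bias, eps=1e-5):
        # x: fp16 NCHW tensor; mean/var/weight/bias: fp32 per-channel vectors
        xf = x.float()
        mul = (weight.abs() + eps) * torch.rsqrt(var + eps)
        y = (xf - mean[None, :, None, None]) * mul[None, :, None, None] \
            + bias[None, :, None, None]
        return y.half()  # stored back at half precision, as in the kernel
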
/inplace_abn/src/utils/checks.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
6 | #ifndef AT_CHECK
7 | #define AT_CHECK AT_ASSERT
8 | #endif
9 |
10 | #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
12 | #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
13 |
14 | #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
15 | #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
--------------------------------------------------------------------------------
/inplace_abn/src/utils/common.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | /*
6 | * Functions to share code between CPU and GPU
7 | */
8 |
9 | #ifdef __CUDACC__
10 | // CUDA versions
11 |
12 | #define HOST_DEVICE __host__ __device__
13 | #define INLINE_HOST_DEVICE __host__ __device__ inline
14 | #define FLOOR(x) floor(x)
15 |
16 | #if __CUDA_ARCH__ >= 600
17 | // Recent compute capabilities have block-level atomicAdd for all data types, so we use that
18 | #define ACCUM(x,y) atomicAdd_block(&(x),(y))
19 | #else
20 | // Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float
21 | // and use the known atomicCAS-based implementation for double
22 | template<typename data_t>
23 | __device__ inline data_t atomic_add(data_t *address, data_t val) {
24 | return atomicAdd(address, val);
25 | }
26 |
27 | template<>
28 | __device__ inline double atomic_add(double *address, double val) {
29 | unsigned long long int* address_as_ull = (unsigned long long int*)address;
30 | unsigned long long int old = *address_as_ull, assumed;
31 | do {
32 | assumed = old;
33 | old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
34 | } while (assumed != old);
35 | return __longlong_as_double(old);
36 | }
37 |
38 | #define ACCUM(x,y) atomic_add(&(x),(y))
39 | #endif // #if __CUDA_ARCH__ >= 600
40 |
41 | #else
42 | // CPU versions
43 |
44 | #define HOST_DEVICE
45 | #define INLINE_HOST_DEVICE inline
46 | #define FLOOR(x) std::floor(x)
47 | #define ACCUM(x,y) (x) += (y)
48 |
49 | #endif // #ifdef __CUDACC__
--------------------------------------------------------------------------------
/inplace_abn/src/utils/cuda.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | /*
4 | * General settings and functions
5 | */
6 | const int WARP_SIZE = 32;
7 | const int MAX_BLOCK_SIZE = 1024;
8 |
9 | static int getNumThreads(int nElem) {
10 | int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE};
11 | for (int i = 0; i < 6; ++i) {
12 | if (nElem <= threadSizes[i]) {
13 | return threadSizes[i];
14 | }
15 | }
16 | return MAX_BLOCK_SIZE;
17 | }
18 |
19 | /*
20 | * Reduction utilities
21 | */
22 | template<typename T>
23 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
24 | unsigned int mask = 0xffffffff) {
25 | #if CUDART_VERSION >= 9000
26 | return __shfl_xor_sync(mask, value, laneMask, width);
27 | #else
28 | return __shfl_xor(value, laneMask, width);
29 | #endif
30 | }
31 |
32 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
33 |
34 | template<typename T>
35 | struct Pair {
36 | T v1, v2;
37 | __device__ Pair() {}
38 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
39 | __device__ Pair(T v) : v1(v), v2(v) {}
40 | __device__ Pair(int v) : v1(v), v2(v) {}
41 | __device__ Pair &operator+=(const Pair &a) {
42 | v1 += a.v1;
43 | v2 += a.v2;
44 | return *this;
45 | }
46 | };
47 |
48 | template<typename T>
49 | static __device__ __forceinline__ T warpSum(T val) {
50 | #if __CUDA_ARCH__ >= 300
51 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
52 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
53 | }
54 | #else
55 | __shared__ T values[MAX_BLOCK_SIZE];
56 | values[threadIdx.x] = val;
57 | __threadfence_block();
58 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
59 | for (int i = 1; i < WARP_SIZE; i++) {
60 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
61 | }
62 | #endif
63 | return val;
64 | }
65 |
66 | template<typename T>
67 | static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
68 | value.v1 = warpSum(value.v1);
69 | value.v2 = warpSum(value.v2);
70 | return value;
71 | }
--------------------------------------------------------------------------------
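warpSum above is a butterfly reduction: at each step every lane adds the value held by the lane whose index differs in exactly one bit, so after log2(32) = 5 XOR-shuffle rounds all 32 lanes hold the full sum. A small pure-Python simulation of that pattern (illustrative only):

    def warp_sum_sim(vals):
        # vals: 32 lane values; each round mirrors WARP_SHFL_XOR with mask 1, 2, 4, 8, 16
        lanes, step = list(vals), 1
        while step < len(lanes):
            lanes = [lanes[i] + lanes[i ^ step] for i in range(len(lanes))]
            step <<= 1
        return lanes  # every lane now holds the total

    assert warp_sum_sim(list(range(32)))[0] == sum(range(32))
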
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/modules/__init__.py
--------------------------------------------------------------------------------
/modules/inits.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | def uniform(size, tensor):
5 | bound = 1.0 / math.sqrt(size)
6 | if tensor is not None:
7 | tensor.data.uniform_(-bound, bound)
8 |
9 |
10 | def kaiming_uniform(tensor, fan, a):
11 | if tensor is not None:
12 | bound = math.sqrt(6 / ((1 + a**2) * fan))
13 | tensor.data.uniform_(-bound, bound)
14 |
15 |
16 | def glorot(tensor):
17 | if tensor is not None:
18 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
19 | tensor.data.uniform_(-stdv, stdv)
20 |
21 |
22 | def zeros(tensor):
23 | if tensor is not None:
24 | tensor.data.fill_(0)
25 |
26 |
27 | def ones(tensor):
28 | if tensor is not None:
29 | tensor.data.fill_(1)
30 |
31 |
32 | def normal(tensor, mean, std):
33 | if tensor is not None:
34 | tensor.data.normal_(mean, std)
35 |
36 |
37 | def reset(nn):
38 | def _reset(item):
39 | if hasattr(item, 'reset_parameters'):
40 | item.reset_parameters()
41 |
42 | if nn is not None:
43 | if hasattr(nn, 'children') and len(list(nn.children())) > 0:
44 | for item in nn.children():
45 | _reset(item)
46 | else:
47 | _reset(nn)
48 |
--------------------------------------------------------------------------------
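All of these initializers mutate the tensor in place and are no-ops when passed None. A short usage sketch (the Linear sizes are arbitrary; assumes the repo root is on PYTHONPATH):

    import torch.nn as nn
    from modules.inits import glorot, zeros, reset

    lin = nn.Linear(64, 32)
    glorot(lin.weight)  # uniform in +/- sqrt(6 / (fan_in + fan_out))
    zeros(lin.bias)
    reset(nn.Sequential(nn.Linear(8, 8), nn.ReLU()))  # calls reset_parameters() on children
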
/modules/parse_mod.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from inplace_abn.bn import InPlaceABNSync
8 | from modules.com_mod import SEModule, ContextContrastedModule
9 |
10 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
11 |
12 | class ASPPModule(nn.Module):
13 | """ASPP"""
14 |
15 | def __init__(self, in_dim, out_dim, scale=1):
16 | super(ASPPModule, self).__init__()
17 | self.gap = nn.Sequential(nn.AdaptiveAvgPool2d(1),
18 | nn.Conv2d(in_dim, out_dim, 1, bias=False), InPlaceABNSync(out_dim))
19 |
20 | self.dilation_0 = nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=1, padding=0, dilation=1, bias=False),
21 | InPlaceABNSync(out_dim), SEModule(out_dim, reduction=16))
22 |
23 | self.dilation_1 = nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=1, padding=0, dilation=1, bias=False),
24 | InPlaceABNSync(out_dim),
25 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=6, dilation=6, bias=False),
26 | InPlaceABNSync(out_dim), SEModule(out_dim, reduction=16))
27 |
28 | self.dilation_2 = nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=1, padding=0, dilation=1, bias=False),
29 | InPlaceABNSync(out_dim),
30 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=12, dilation=12, bias=False),
31 | InPlaceABNSync(out_dim), SEModule(out_dim, reduction=16))
32 |
33 | self.dilation_3 = nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=1, padding=0, dilation=1, bias=False),
34 | InPlaceABNSync(out_dim),
35 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=18, dilation=18, bias=False),
36 | InPlaceABNSync(out_dim), SEModule(out_dim, reduction=16))
37 |
38 | self.psaa_conv = nn.Sequential(nn.Conv2d(in_dim + 5 * out_dim, out_dim, 1, padding=0, bias=False),
39 | InPlaceABNSync(out_dim),
40 | nn.Conv2d(out_dim, 5, 1, bias=True),
41 | nn.Sigmoid())
42 |
43 | self.project = nn.Sequential(nn.Conv2d(out_dim * 5, out_dim, kernel_size=1, padding=0, bias=False),
44 | InPlaceABNSync(out_dim))
45 |
46 | def forward(self, x):
47 | # parallel branch
48 | feat0 = self.dilation_0(x)
49 | feat1 = self.dilation_1(x)
50 | feat2 = self.dilation_2(x)
51 | feat3 = self.dilation_3(x)
52 | n, c, h, w = feat0.size()
53 | gp = self.gap(x)
54 |
55 | feat4 = gp.expand(n, c, h, w)
56 | # psaa
57 | y1 = torch.cat((x, feat0, feat1, feat2, feat3, feat4), 1)
58 |
59 | psaa_att = self.psaa_conv(y1)
60 |
61 | psaa_att_list = torch.split(psaa_att, 1, dim=1)
62 |
63 | y2 = torch.cat((psaa_att_list[0] * feat0, psaa_att_list[1] * feat1, psaa_att_list[2] * feat2, psaa_att_list[3] * feat3, psaa_att_list[4]*feat4), 1)
64 | out = self.project(y2)
65 | return out
66 |
--------------------------------------------------------------------------------
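The psaa branch above predicts five per-pixel sigmoid weights, one for each parallel branch (four dilation rates plus global pooling), and rescales the branches before the final 1x1 projection. The core of that aggregation, isolated as a runnable sketch (shapes and tensor names here are made up for illustration):

    import torch

    n, c, h, w = 2, 16, 8, 8
    feats = [torch.randn(n, c, h, w) for _ in range(5)]  # four dilated branches + pooled branch
    att = torch.sigmoid(torch.randn(n, 5, h, w))         # stands in for psaa_conv's output
    y2 = torch.cat([att[:, i:i + 1] * f for i, f in enumerate(feats)], dim=1)
    print(y2.shape)  # torch.Size([2, 80, 8, 8]) -> input to the 1x1 'project' conv
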
/network/ResNet_stem_converter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import OrderedDict
3 | # Remap a ResNet-101 stem checkpoint from Sequential 'conv1.*' keys to flat conv1/bn1 ... conv3/bn3 names.
4 | saved_state_dict = torch.load('../checkpoints/init/resnet101_stem.pth')
5 |
6 | new_state=OrderedDict()
7 | for k, v in saved_state_dict.items():
8 | if k=='conv1.0.weight':
9 | new_state.update({'conv1.weight':v})
10 | elif k=='conv1.1.weight':
11 | new_state.update({'bn1.weight': v})
12 | elif k=='conv1.1.bias':
13 | new_state.update({'bn1.bias': v})
14 | elif k=='conv1.1.running_mean':
15 | new_state.update({'bn1.running_mean': v})
16 | elif k=='conv1.1.running_var':
17 | new_state.update({'bn1.running_var': v})
18 | elif k=='conv1.3.weight':
19 | new_state.update({'conv2.weight': v})
20 | elif k=='conv1.4.weight':
21 | new_state.update({'bn2.weight':v})
22 | elif k=='conv1.4.bias':
23 | new_state.update({'bn2.bias': v})
24 | elif k=='conv1.4.running_mean':
25 | new_state.update({'bn2.running_mean': v})
26 | elif k=='conv1.4.running_var':
27 | new_state.update({'bn2.running_var': v})
28 | elif k=='conv1.6.weight':
29 | new_state.update({'conv3.weight':v})
30 | elif k=='bn1.weight':
31 | new_state.update({'bn3.weight': v})
32 | elif k=='bn1.bias':
33 | new_state.update({'bn3.bias': v})
34 | elif k=='bn1.running_mean':
35 | new_state.update({'bn3.running_mean': v})
36 | elif k=='bn1.running_var':
37 | new_state.update({'bn3.running_var': v})
38 | else:
39 | new_state.update({k: v})
40 |
41 |
42 |
43 | torch.save(new_state, '../checkpoints/init/new_resnet101_stem.pth')
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/network/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/network/__init__.py
--------------------------------------------------------------------------------
/network/baseline.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 | from inplace_abn.bn import InPlaceABNSync
8 | from modules.com_mod import Bottleneck, ResGridNet, SEModule
9 | from modules.parse_mod import ASPPModule
10 |
11 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
12 |
13 | class DecoderModule(nn.Module):
14 | # num_classes is unused in this reduced decoder; it only outputs 256-d features.
15 | def __init__(self, num_classes):
16 | super(DecoderModule, self).__init__()
17 | self.conv0 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False),
18 | BatchNorm2d(256), nn.ReLU(inplace=False))
19 | self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1),
20 | nn.Conv2d(256, 256, 1, bias=False),
21 | nn.ReLU(True),
22 | nn.Conv2d(256, 256, 1, bias=True),
23 | nn.Sigmoid())
24 | def forward(self, x):
25 | out=self.conv0(x)
26 | out = out + self.se(out)*out
27 | return out
28 |
29 |
30 | class GNN_infer(nn.Module):
31 | def __init__(self, adj_matrix, upper_half_node=[1, 2, 3, 4], lower_half_node=[5, 6], in_dim=256, hidden_dim=64,
32 | cls_p=7, cls_h=3, cls_f=2):
33 | super(GNN_infer, self).__init__()
34 | self.cls_p = cls_p
35 | self.cls_h = cls_h
36 | self.cls_f = cls_f
37 | self.in_dim = in_dim
38 | self.hidden_dim = hidden_dim
39 |
40 | # node feature transform
41 | self.p_conv = nn.Sequential(
42 | nn.Conv2d(in_dim, hidden_dim * cls_p, kernel_size=1, padding=0, stride=1, bias=False),
43 | BatchNorm2d(hidden_dim * cls_p), nn.ReLU(inplace=False))
44 | self.h_conv = nn.Sequential(
45 | nn.Conv2d(in_dim, hidden_dim * cls_h, kernel_size=1, padding=0, stride=1, bias=False),
46 | BatchNorm2d(hidden_dim * cls_h), nn.ReLU(inplace=False))
47 | self.f_conv = nn.Sequential(
48 | nn.Conv2d(in_dim, hidden_dim * cls_f, kernel_size=1, padding=0, stride=1, bias=False),
49 | BatchNorm2d(hidden_dim * cls_f), nn.ReLU(inplace=False))
50 |
51 | # node supervision
52 | self.node_seg = nn.Conv2d(hidden_dim, 1, 1)
53 |
54 | def forward(self, xp, xh, xf):
55 | # gnn inference at stride 8
56 | # feature transform
57 | f_node_list = list(torch.split(self.f_conv(xf), self.hidden_dim, dim=1))
58 | h_node_list = list(torch.split(self.h_conv(xh), self.hidden_dim, dim=1))
59 | p_node_list = list(torch.split(self.p_conv(xp), self.hidden_dim, dim=1))
60 |
61 | # node supervision
62 | f_seg = torch.cat([self.node_seg(node) for node in f_node_list], dim=1)
63 | h_seg = torch.cat([self.node_seg(node) for node in h_node_list], dim=1)
64 | p_seg = torch.cat([self.node_seg(node) for node in p_node_list], dim=1)
65 |
66 | return [p_seg], [h_seg], [f_seg], [], [], \
67 | [], [], [], [], []
68 |
69 |
70 |
71 | class Decoder(nn.Module):
72 | def __init__(self, num_classes=7, hbody_cls=3, fbody_cls=2):
73 | super(Decoder, self).__init__()
74 | self.layer5 = ASPPModule(2048, 512)
75 | self.layer_part = DecoderModule(num_classes)
76 | self.layer_half = DecoderModule(hbody_cls)
77 | self.layer_full = DecoderModule(fbody_cls)
78 |
79 | self.layer_dsn = nn.Sequential(nn.Conv2d(1024, 256, kernel_size=3, stride=1, padding=1),
80 | BatchNorm2d(256), nn.ReLU(inplace=False),
81 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1, padding=0, bias=True))
82 |
83 | self.skip = nn.Sequential(nn.Conv2d(512, 512, kernel_size=1, padding=0, bias=False),
84 | BatchNorm2d(512), nn.ReLU(inplace=False),
85 | )
86 | self.fuse = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=3, padding=1, bias=False),
87 | BatchNorm2d(512), nn.ReLU(inplace=False))
88 |
89 |
90 | # adjacent matrix for pascal person
91 | self.adj_matrix = torch.tensor(
92 | [[0, 1, 0, 0, 0, 0], [1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0], [0, 1, 0, 0, 0, 1],
93 | [0, 0, 0, 0, 1, 0]], requires_grad=False)
94 |
95 | # infer with hierarchical person graph
96 | self.gnn_infer = GNN_infer(adj_matrix=self.adj_matrix, upper_half_node=[1, 2, 3, 4], lower_half_node=[5, 6],
97 | in_dim=256, hidden_dim=32, cls_p=7, cls_h=3, cls_f=2)
98 | # aux layer (duplicate of self.layer_dsn above; this assignment overrides the first)
99 | self.layer_dsn = nn.Sequential(nn.Conv2d(1024, 256, kernel_size=3, stride=1, padding=1),
100 | BatchNorm2d(256), nn.ReLU(inplace=False),
101 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1, padding=0, bias=True))
102 |
103 | def forward(self, x):
104 | x_dsn = self.layer_dsn(x[-2])
105 | _,_,h,w = x[1].size()
106 | context = self.layer5(x[-1])
107 | context = F.interpolate(context, size=(h, w), mode='bilinear', align_corners=True)
108 | context = self.fuse(torch.cat([self.skip(x[1]), context], dim=1))
109 |
110 | p_fea = self.layer_part(context)
111 | h_fea = self.layer_half(context)
112 | f_fea = self.layer_full(context)
113 |
114 | # gnn infer
115 | p_seg, h_seg, f_seg, decomp_map_f, decomp_map_u, decomp_map_l, comp_map_f, comp_map_u, comp_map_l, \
116 | Fdep_att_list = self.gnn_infer(p_fea, h_fea, f_fea)
117 |
118 | return p_seg, h_seg, f_seg, decomp_map_f, decomp_map_u, decomp_map_l, comp_map_f, comp_map_u, comp_map_l, \
119 | Fdep_att_list, x_dsn
120 |
121 | class OCNet(nn.Module):
122 | def __init__(self, block, layers, num_classes):
123 | super(OCNet, self).__init__()
124 | self.encoder = ResGridNet(block, layers)
125 | self.decoder = Decoder(num_classes=num_classes)
126 |
127 | for m in self.modules():
128 | if isinstance(m, nn.Conv2d):
129 | nn.init.kaiming_normal_(m.weight.data)
130 | elif isinstance(m, InPlaceABNSync):
131 | m.weight.data.fill_(1)
132 | m.bias.data.zero_()
133 |
134 | def forward(self, x):
135 | x = self.encoder(x)
136 | x = self.decoder(x)
137 | return x
138 |
139 | def get_model(num_classes=20):
140 | model = OCNet(Bottleneck, [3, 4, 23, 3], num_classes) # 101
141 | return model
142 |
--------------------------------------------------------------------------------
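get_model wires a ResNet-101 encoder to the decoder above; the forward pass returns lists of part, half-body and full-body logits plus several (here empty) graph-related slots and the auxiliary dsn logits. A usage sketch (needs a GPU and the compiled InPlaceABNSync extension, so treat this as illustrative):

    import torch
    from network.baseline import get_model

    model = get_model(num_classes=7).cuda().eval()  # 7 classes for Pascal-Person-Part
    x = torch.randn(1, 3, 473, 473).cuda()
    with torch.no_grad():
        p_seg, h_seg, f_seg, *rest, x_dsn = model(x)
    print(p_seg[0].shape)  # part logits at the decoder's (low-level feature) resolution
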
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.2.0
2 | torchvision==0.4.0
3 | numpy
4 | opencv-python
5 | tqdm
6 |
--------------------------------------------------------------------------------
/train_pascal.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_baseline.py --init --method baseline --crop-size 473 --batch-size 20 --learning-rate 1e-2
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/utils/__init__.py
--------------------------------------------------------------------------------
/utils/aaf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/utils/aaf/__init__.py
--------------------------------------------------------------------------------
/utils/aaf/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/utils/aaf/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/aaf/__pycache__/layers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/utils/aaf/__pycache__/layers.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/aaf/__pycache__/losses.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hlzhu09/Hierarchical-Human-Parsing/58737dbb7763f5a7baafd4801c5a287facf536cd/utils/aaf/__pycache__/losses.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/aaf/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | def eightway_activation(x):
6 | """Retrieves neighboring pixels/features on the eight corners from
7 | a 3x3 patch.
8 |
9 | Args:
10 | x: A tensor of size [batch_size, height_in, width_in, channels]
11 |
12 | Returns:
13 | A tensor of size [batch_size, height_in, width_in, channels, 8]
14 | """
15 | # Get the number of channels in the input.
16 | shape_x = list(x.shape)
17 | if len(shape_x) != 4:
18 | raise ValueError('Only support for 4-D tensors!')
19 |
20 | # Pad at the margin.
21 | x = F.pad(x,
22 | pad=(0,0,1,1,1,1,0,0),
23 | mode='reflect')
24 | # Get eight neighboring pixels/features.
25 | x_groups = [
26 | x[:, 1:-1, :-2, :].clone(), # left
27 | x[:, 1:-1, 2:, :].clone(), # right
28 | x[:, :-2, 1:-1, :].clone(), # up
29 | x[:, 2:, 1:-1, :].clone(), # down
30 | x[:, :-2, :-2, :].clone(), # left-up
31 | x[:, 2:, :-2, :].clone(), # left-down
32 | x[:, :-2, 2:, :].clone(), # right-up
33 | x[:, 2:, 2:, :].clone() # right-down
34 | ]
35 | output = [
36 | torch.unsqueeze(c, dim=-1) for c in x_groups
37 | ]
38 | output = torch.cat(output, dim=-1)
39 |
40 | return output
41 |
42 |
43 | def eightcorner_activation(x, size):
44 | """Retrieves neighboring pixels one the eight corners from a
45 | (2*size+1)x(2*size+1) patch.
46 |
47 | Args:
48 | x: A tensor of size [batch_size, height_in, width_in, channels]
49 | size: A number indicating the half size of a patch.
50 |
51 | Returns:
52 | A tensor of size [batch_size, height_in, width_in, channels, 8]
53 | """
54 | # Get the number of channels in the input.
55 | shape_x = list(x.shape)
56 | if len(shape_x) != 4:
57 | raise ValueError('Only support for 4-D tensors!')
58 | n, c, h, w = shape_x
59 |
60 | # Pad at the margin.
61 | p = size
62 | x_pad = F.pad(x,
63 | pad=(p,p,p,p,0,0,0,0),
64 | mode='constant',
65 | value=0)
66 |
67 | # Get eight corner pixels/features in the patch.
68 | x_groups = []
69 | for st_y in range(0,2*size+1,size):
70 | for st_x in range(0,2*size+1,size):
71 | if st_y == size and st_x == size:
72 | # Ignore the center pixel/feature.
73 | continue
74 |
75 | x_neighbor = x_pad[:, :, st_y:st_y+h, st_x:st_x+w].clone()
76 | x_groups.append(x_neighbor)
77 |
78 | output = [torch.unsqueeze(c, dim=-1) for c in x_groups]
79 | output = torch.cat(output, dim=-1)
80 |
81 | return output
82 |
83 |
84 | def ignores_from_label(labels, num_classes, size, ignore_index):
85 | """Retrieves ignorable pixels from the ground-truth labels.
86 |
87 | This function returns a binary map in which 1 denotes ignored pixels
88 | and 0 means not ignored ones. For those ignored pixels, they are not
89 | only the pixels with label value >= num_classes, but also the
90 | corresponding neighboring pixels, which are on the eight corners
91 | from a (2*size+1)x(2*size+1) patch.
92 |
93 | Args:
94 | labels: A tensor of size [batch_size, height_in, width_in], indicating
95 | semantic segmentation ground-truth labels.
96 | num_classes: A number indicating the total number of valid classes. The
97 | labels range from 0 to (num_classes-1), and any value >= num_classes
98 | would be ignored.
99 | size: A number indicating the half size of a patch.
100 | ignore_index: A number indicating the label value to ignore.
101 | Return:
102 | A tensor of size [batch_size, height_in, width_in, 8]
103 | """
104 | # Get the number of channels in the input.
105 | shape_lab = list(labels.shape)
106 | if len(shape_lab) != 3:
107 | raise ValueError('Only support for 3-D label tensors!')
108 | n, h, w = shape_lab
109 |
110 | # Retrieve ignored pixels with label value >= num_classes.
111 | # ignore = labels>num_classes-1 # NxHxW
112 | ignore = (labels==ignore_index)
113 |
114 | # Pad at the margin.
115 | p = size
116 | ignore_pad = F.pad(ignore,
117 | pad=(p,p,p,p,0,0),
118 | mode='constant',
119 | value=1)
120 |
121 | # Retrieve eight corner pixels from the center, where the center
122 | # is ignored. Note that it should be bi-directional. For example,
123 | # when computing AAF loss with top-left pixels, the ignored pixels
124 | # might be the center or the top-left ones.
125 | ignore_groups= []
126 | for st_y in range(2*size,-1,-size):
127 | for st_x in range(2*size,-1,-size):
128 | if st_y == size and st_x == size:
129 | continue
130 | ignore_neighbor = ignore_pad[:,st_y:st_y+h,st_x:st_x+w].clone()
131 | mask = ignore_neighbor | ignore
132 | ignore_groups.append(mask)
133 |
134 | ig = 0
135 | for st_y in range(0,2*size+1,size):
136 | for st_x in range(0,2*size+1,size):
137 | if st_y == size and st_x == size:
138 | continue
139 | ignore_neighbor = ignore_pad[:,st_y:st_y+h,st_x:st_x+w].clone()
140 | mask = ignore_neighbor | ignore_groups[ig]
141 | ignore_groups[ig] = mask
142 | ig += 1
143 |
144 | ignore_groups = [
145 | torch.unsqueeze(c, dim=-1) for c in ignore_groups
146 | ] # NxHxWx1
147 | ignore = torch.cat(ignore_groups, dim=-1) #NxHxWx8
148 |
149 | return ignore
150 |
151 |
152 | def edges_from_label(labels, size, ignore_class=255):
153 | """Retrieves edge positions from the ground-truth labels.
154 |
155 | This function computes the edge map by considering if the pixel values
156 | are equal between the center and the neighboring pixels on the eight
157 | corners from a (2*size+1)*(2*size+1) patch. Edges where either of the
158 | paired pixels has a label value >= num_classes are ignored.
159 |
160 | Args:
161 | labels: A tensor of size [batch_size, height_in, width_in, num_classes],
162 | indicating one-hot semantic segmentation ground-truth labels.
163 | size: A number indicating the half size of a patch.
164 | ignore_class: A number indicating the label value to ignore.
165 |
166 | Return:
167 | A tensor of size [batch_size, height_in, width_in, 1, 8]
168 | """
169 | # Get the number of channels in the input.
170 | shape_lab = list(labels.shape)
171 | if len(shape_lab) != 4:
172 | raise ValueError('Only support for 4-D label tensors!')
173 | n, h, w, c = shape_lab
174 |
175 | # Pad at the margin.
176 | p = size
177 | labels_pad = F.pad(
178 | labels, pad=(0,0,p,p,p,p,0,0),
179 | mode='constant',
180 | value=ignore_class)
181 |
182 | # Get the edge by comparing label value of the center and it paired pixels.
183 | edge_groups= []
184 | for st_y in range(0,2*size+1,size):
185 | for st_x in range(0,2*size+1,size):
186 | if st_y == size and st_x == size:
187 | continue
188 | labels_neighbor = labels_pad[:,st_y:st_y+h,st_x:st_x+w]
189 | edge = labels_neighbor!=labels
190 | edge_groups.append(edge)
191 |
192 | edge_groups = [
193 | torch.unsqueeze(c, dim=-1) for c in edge_groups
194 | ] # NxHxWx1x1
195 | edge = torch.cat(edge_groups, dim=-1) #NxHxWx1x8
196 |
197 | return edge
198 |
--------------------------------------------------------------------------------
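eightcorner_activation stacks, for every pixel, its eight neighbours at offset size into a trailing dimension, zero-padding at the border. A tiny runnable check:

    import torch
    from utils.aaf.layers import eightcorner_activation

    x = torch.arange(16.).reshape(1, 1, 4, 4)  # NCHW
    nbrs = eightcorner_activation(x, size=1)   # -> shape [1, 1, 4, 4, 8]
    print(nbrs[0, 0, 1, 1])  # neighbours of pixel (1, 1): 0., 1., 2., 4., 6., 8., 9., 10.
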
/utils/aaf/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import utils.aaf.layers as nnx
4 | import numpy as np
5 |
6 | def affinity_loss(labels,
7 | probs,
8 | num_classes,
9 | kld_margin):
10 | """Affinity Field (AFF) loss.
11 |
12 | This function computes AFF loss. There are several components in the
13 | function:
14 | 1) extracts edges from the ground-truth labels.
15 | 2) extracts ignored pixels and their paired pixels (the neighboring
16 | pixels on the eight corners).
17 | 3) extracts neighboring pixels on the eight corners from a 3x3 patch.
18 | 4) computes KL-Divergence between center pixels and their neighboring
19 | pixels from the eight corners.
20 |
21 | Args:
22 | labels: A tensor of size [batch_size, height_in, width_in], indicating
23 | semantic segmentation ground-truth labels.
24 | probs: A tensor of size [batch_size, height_in, width_in, num_classes],
25 | indicating segmentation predictions.
26 | num_classes: A number indicating the total number of valid classes.
27 | kld_margin: A number indicating the margin for KL-Divergence at edge.
28 |
29 | Returns:
30 | Two 1-D tensors value indicating the loss at edge and non-edge.
31 | """
32 | # Compute ignore map (e.g, label of 255 and their paired pixels).
33 |
34 | labels = torch.squeeze(labels, dim=1) # NxHxW
35 | ignore = nnx.ignores_from_label(labels, num_classes, 1, 255) # NxHxWx8
36 | not_ignore = ~ignore
37 | not_ignore = torch.unsqueeze(not_ignore, dim=3) # NxHxWx1x8
38 |
39 | # Compute edge map.
40 | one_hot_lab = F.one_hot(labels, num_classes)
41 | edge = nnx.edges_from_label(one_hot_lab, 1, 255) # NxHxWxCx8
42 |
43 | # Remove ignored pixels from the edge/non-edge.
44 | edge = edge & not_ignore
45 | not_edge = ~edge & not_ignore
46 |
47 | edge_indices = torch.nonzero(torch.reshape(edge, (-1,)))
48 | not_edge_indices = torch.nonzero(torch.reshape(not_edge, (-1,)))
49 |
50 | # Extract eight corner from the center in a patch as paired pixels.
51 | probs_paired = nnx.eightcorner_activation(probs, 1) # NxHxWxCx8
52 | probs = torch.unsqueeze(probs, dim=-1) # NxHxWxCx1
53 | bot_epsilon = 1e-4
54 | top_epsilon = 1.0
55 |
56 | neg_probs = torch.clamp(
57 | 1-probs, bot_epsilon, top_epsilon)
58 | neg_probs_paired = torch.clamp(
59 | 1-probs_paired, bot_epsilon, top_epsilon)
60 | probs = torch.clamp(
61 | probs, bot_epsilon, top_epsilon)
62 | probs_paired = torch.clamp(
63 | probs_paired, bot_epsilon, top_epsilon)
64 |
65 | # Compute KL-Divergence.
66 | kldiv = probs_paired*torch.log(probs_paired/probs)
67 | kldiv += neg_probs_paired*torch.log(neg_probs_paired/neg_probs)
68 | edge_loss = torch.clamp(kld_margin - kldiv, min=0.0)
69 | not_edge_loss = kldiv
70 |
71 |
72 | not_edge_loss = torch.reshape(not_edge_loss, (-1,))
73 | not_edge_loss = torch.gather(not_edge_loss, 0, not_edge_indices)
74 | edge_loss = torch.reshape(edge_loss, (-1,))
75 | edge_loss = torch.gather(edge_loss, 0, edge_indices)
76 |
77 | return edge_loss, not_edge_loss
78 |
79 |
80 | def adaptive_affinity_loss(labels,
81 | one_hot_lab,
82 | probs,
83 | size,
84 | num_classes,
85 | kld_margin,
86 | w_edge,
87 | w_not_edge,
88 | ignore_index=255):
89 | """Adaptive affinity field (AAF) loss.
90 |
91 | This function computes AAF loss. There are three components in the function:
92 | 1) extracts edges from the ground-truth labels.
93 | 2) extracts ignored pixels and their paired pixels (usually the eight corner
94 | pixels).
95 | 3) extracts eight corner pixels/predictions from the center in a
96 | (2*size+1)x(2*size+1) patch
97 | 4) computes KL-Divergence between center pixels and their paired pixels (the
98 | eight corner).
99 | 5) imposes adaptive weightings on the loss.
100 |
101 | Args:
102 | labels: A tensor of size [batch_size, height_in, width_in], indicating
103 | semantic segmentation ground-truth labels.
104 | one_hot_lab: A tensor of size [batch_size, num_classes, height_in, width_in]
105 | which is the ground-truth labels in the form of one-hot vector.
106 | probs: A tensor of size [batch_size, num_classes, height_in, width_in],
107 | indicating segmentation predictions.
108 | size: A number indicating the half size of a patch.
109 | num_classes: A number indicating the total number of valid classes.
110 | kld_margin: A number indicating the margin for KL-Divergence at edge.
111 | w_edge: A number indicating the weighting for KL-Divergence at edge.
112 | w_not_edge: A number indicating the weighting for KL-Divergence at non-edge.
113 | ignore_index: ignore index
114 |
115 | Returns:
116 | Two 1-D tensors value indicating the loss at edge and non-edge.
117 | """
118 | # Compute ignore map (e.g, label of 255 and their paired pixels).
119 | labels = torch.squeeze(labels, dim=1) # NxHxW
120 | ignore = nnx.ignores_from_label(labels, num_classes, size, ignore_index) # NxHxWx8
121 | not_ignore = ~ignore
122 | not_ignore = torch.unsqueeze(not_ignore, dim=3) # NxHxWx1x8
123 |
124 | # Compute edge map.
125 | edge = nnx.edges_from_label(one_hot_lab, size, ignore_index) # NxHxWxCx8
126 |
127 | # Remove ignored pixels from the edge/non-edge.
128 | edge = edge & not_ignore
129 | not_edge = ~edge & not_ignore
130 |
131 | edge_indices = torch.nonzero(torch.reshape(edge, (-1,)))
132 | # print(edge_indices.size())
133 | if edge_indices.size()[0]==0:
134 | edge_loss=torch.tensor(0.0, requires_grad=False).cuda()
135 | not_edge_loss=torch.tensor(0.0, requires_grad=False).cuda()
136 | return edge_loss, not_edge_loss
137 |
138 | not_edge_indices = torch.nonzero(torch.reshape(not_edge, (-1,)))
139 |
140 | # Extract eight corner from the center in a patch as paired pixels.
141 | probs_paired = nnx.eightcorner_activation(probs, size) # NxHxWxCx8
142 | probs = torch.unsqueeze(probs, dim=-1) # NxHxWxCx1
143 | bot_epsilon = torch.tensor(1e-4, requires_grad=False).cuda()
144 | top_epsilon = torch.tensor(1.0, requires_grad=False).cuda()
145 |
146 | neg_probs = torch.where(1-probs < bot_epsilon, bot_epsilon, 1-probs)
147 | neg_probs = torch.where(neg_probs > top_epsilon, top_epsilon, neg_probs)
148 |
149 | neg_probs_paired = torch.where(1 - probs_paired < bot_epsilon, bot_epsilon, 1 - probs_paired)
150 | neg_probs_paired = torch.where(neg_probs_paired > top_epsilon, top_epsilon, neg_probs_paired)
151 |
152 | probs = torch.where(probs < bot_epsilon, bot_epsilon, probs)
153 | probs = torch.where(probs > top_epsilon, top_epsilon, probs)
154 |
155 | probs_paired = torch.where(probs_paired < bot_epsilon, bot_epsilon, probs_paired)
156 | probs_paired = torch.where(probs_paired > top_epsilon, top_epsilon, probs_paired)
157 |
158 | # neg_probs = np.clip(
159 | # 1-probs, bot_epsilon, top_epsilon)
160 | # neg_probs_paired = np.clip(
161 | # 1-probs_paired, bot_epsilon, top_epsilon)
162 | # probs = np.clip(
163 | # probs, bot_epsilon, top_epsilon)
164 | # probs_paired = np.clip(
165 | # probs_paired, bot_epsilon, top_epsilon)
166 |
167 | # Compute KL-Divergence.
168 | kldiv = probs_paired*torch.log(probs_paired/probs)
169 | kldiv += neg_probs_paired*torch.log(neg_probs_paired/neg_probs)
170 | edge_loss = torch.max(torch.tensor(0.0, requires_grad=False).cuda(), kld_margin-kldiv)
171 | not_edge_loss = kldiv
172 |
173 | # Impose weights on edge/non-edge losses.
174 | one_hot_lab = torch.unsqueeze(one_hot_lab, dim=-1)
175 | w_edge = torch.sum(w_edge*one_hot_lab.float(), dim=3, keepdim=True) # NxHxWx1x1
176 | w_not_edge = torch.sum(w_not_edge*one_hot_lab.float(), dim=3, keepdim=True) # NxHxWx1x1
177 |
178 | edge_loss *= w_edge.permute(0,3,1,2,4)
179 | not_edge_loss *= w_not_edge.permute(0,3,1,2,4)
180 |
181 | not_edge_loss = torch.reshape(not_edge_loss, (-1,1))
182 | not_edge_loss = torch.gather(not_edge_loss, 0, not_edge_indices)
183 | edge_loss = torch.reshape(edge_loss, (-1,1))
184 | edge_loss = torch.gather(edge_loss, 0, edge_indices)
185 |
186 | return edge_loss, not_edge_loss
187 |
--------------------------------------------------------------------------------
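Both losses hinge on a symmetric binary KL term between the centre prediction and each of its eight paired predictions: off edges the term itself is minimized (pulling pairs together), while at edges max(0, margin - KL) is minimized (pushing pairs apart). A single-pair numeric illustration with made-up probabilities:

    import torch

    p, p_pair, margin = torch.tensor(0.9), torch.tensor(0.6), 3.0
    kldiv = p_pair * torch.log(p_pair / p) \
        + (1 - p_pair) * torch.log((1 - p_pair) / (1 - p))
    print(float(kldiv))                                 # ~0.31
    print(float(torch.clamp(margin - kldiv, min=0.0)))  # edge loss ~2.69
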
/utils/best/lovasz_loss.py:
--------------------------------------------------------------------------------
1 | from itertools import filterfalse as ifilterfalse
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 |
8 |
9 | class LovaszSoftmaxLoss(nn.Module):
10 | """Multi-class Lovasz-Softmax loss.
11 | :param only_present: average only on classes present in ground truth.
12 | :param per_image: calculate the loss in image separately.
13 | :param ignore_index:
14 | """
15 |
16 | def __init__(self, ignore_index=None, only_present=False, per_image=False):
17 | super(LovaszSoftmaxLoss, self).__init__()
18 | self.ignore_index = ignore_index
19 | self.only_present = only_present
20 | self.per_image = per_image
21 | self.weight = torch.FloatTensor([0.80777327, 1.00125961, 0.90997236, 1.10867908, 1.17541499,
22 | 0.86041422, 1.01116758, 0.89290045, 1.12410812, 0.91105395,
23 | 1.07604013, 1.12470610, 1.09895196, 0.90172057, 0.93529453,
24 | 0.93054733, 1.04919178, 1.04937547, 1.06267568, 1.06365688])
25 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, weight=self.weight)
26 |
27 | def forward(self, preds, targets):
28 | h, w = targets.size(1), targets.size(2)
29 | # seg loss
30 | pred = F.interpolate(input=preds[0], size=(h, w), mode='bilinear', align_corners=True)
31 | pred = F.softmax(input=pred, dim=1)
32 | if self.per_image:
33 | loss = mean(lovasz_softmax_flat(*flatten_probas(pre.unsqueeze(0), tar.unsqueeze(0), self.ignore_index),
34 | only_present=self.only_present) for pre, tar in zip(pred, targets))
35 | else:
36 | loss = lovasz_softmax_flat(*flatten_probas(pred, targets, self.ignore_index),
37 | only_present=self.only_present)
38 | # dsn loss
39 | pred_dsn = F.interpolate(input=preds[1], size=(h, w), mode='bilinear', align_corners=True)
40 | loss_dsn = self.criterion(pred_dsn, targets)
41 | return loss + 0.4 * loss_dsn
42 |
43 |
44 | def lovasz_softmax_flat(preds, targets, only_present=False):
45 | """
46 | Multi-class Lovasz-Softmax loss
47 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
48 | labels: [P] Tensor, ground truth labels (between 0 and C - 1)
49 | only_present: average only on classes present in ground truth
50 | """
51 | if preds.numel() == 0:
52 | # only void pixels, the gradients should be 0
53 | return preds * 0.
54 |
55 | C = preds.size(1)
56 | losses = []
57 | for c in range(C):
58 | fg = (targets == c).float() # foreground for class c
59 | if only_present and fg.sum() == 0:
60 | continue
61 | errors = (Variable(fg) - preds[:, c]).abs()
62 | errors_sorted, perm = torch.sort(errors, 0, descending=True)
63 | perm = perm.data
64 | fg_sorted = fg[perm]
65 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
66 | return mean(losses)
67 |
68 |
69 | def lovasz_grad(gt_sorted):
70 | """
71 | Computes gradient of the Lovasz extension w.r.t sorted errors
72 | See Alg. 1 in paper
73 | """
74 | p = len(gt_sorted)
75 | gts = gt_sorted.sum()
76 | intersection = gts - gt_sorted.float().cumsum(0)
77 | union = gts + (1 - gt_sorted).float().cumsum(0)
78 | jaccard = 1. - intersection / union
79 | if p > 1: # cover 1-pixel case
80 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
81 | return jaccard
82 |
83 |
84 | def flatten_probas(preds, targets, ignore=None):
85 | """
86 | Flattens predictions in the batch
87 | """
88 | B, C, H, W = preds.size()
89 | preds = preds.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
90 | targets = targets.view(-1)
91 | if ignore is None:
92 | return preds, targets
93 | valid = (targets != ignore)
94 | vprobas = preds[valid.nonzero().squeeze()]
95 | vlabels = targets[valid]
96 | return vprobas, vlabels
97 |
98 |
99 | def mean(l, ignore_nan=True, empty=0):
100 | """
101 | nan mean compatible with generators.
102 | """
103 | l = iter(l)
104 | if ignore_nan:
105 | l = ifilterfalse(isnan, l)
106 | try:
107 | n = 1
108 | acc = next(l)
109 | except StopIteration:
110 | if empty == 'raise':
111 | raise ValueError('Empty mean')
112 | return empty
113 | for n, v in enumerate(l, 2):
114 | acc += v
115 | if n == 1:
116 | return acc
117 | return acc / n
118 |
119 |
120 | def isnan(x):
121 | return x != x
122 |
--------------------------------------------------------------------------------
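lovasz_grad converts the error-sorted foreground indicator vector into the discrete gradient of the Jaccard loss; dotting it with the sorted errors yields the Lovasz extension. A worked example:

    import torch
    from utils.best.lovasz_loss import lovasz_grad

    gt_sorted = torch.tensor([1., 1., 0., 1.])  # foreground flags, sorted by error (desc.)
    print(lovasz_grad(gt_sorted))
    # intersection = 3 - cumsum(gt) = [2, 1, 1, 0]; union = 3 + cumsum(1 - gt) = [3, 3, 4, 4]
    # jaccard = 1 - I/U = [1/3, 2/3, 3/4, 1]; differenced -> [1/3, 1/3, 1/12, 1/4]
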
/utils/learning_policy.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | # poly lr
4 | # def adjust_learning_rate(optimizer, epoch, i_iter, iters_per_epoch, method='poly'):
5 | # if method == 'poly':
6 | # current_step = epoch * iters_per_epoch + i_iter
7 | # max_step = args.epochs * iters_per_epoch
8 | # lr = args.learning_rate * ((1 - current_step / max_step) ** 0.9)
9 | # else:
10 | # lr = args.learning_rate
11 | # optimizer.param_groups[0]['lr'] = lr
12 | # return lr
13 |
14 | def cosine_decay(base_learning_rate, global_step, warm_step, decay_steps, alpha=0.0001):
15 | # warm_step = 5 * iters_per_epoch
16 | # warm_lr = 0.01 * learning_rate
17 | # current_step = epoch * iters_per_epoch + i_iter
18 | alpha = alpha/base_learning_rate
19 | if global_step < warm_step:
20 | lr = base_learning_rate*global_step/warm_step
21 | # lr = base_learning_rate
22 | else:
23 | global_step = min(global_step, decay_steps)-warm_step
24 | cosine_decay = 0.5 * (1 + math.cos(math.pi * global_step / (decay_steps-warm_step)))
25 | decayed = (1 - alpha) * cosine_decay + alpha
26 | lr = base_learning_rate * decayed
27 | return lr
28 |
29 |
30 | def restart_cosine_decay(base_learning_rate, global_step, warm_step, decay_steps, alpha=0.0001):
31 | # warm_step = 5 * iters_per_epoch
32 | # warm_lr = 0.01 * learning_rate
33 | # current_step = epoch * iters_per_epoch + i_iter
34 | alpha = alpha/base_learning_rate
35 | restart_step = int((warm_step+decay_steps)/2)
36 | if global_step < warm_step:
37 | lr = base_learning_rate*global_step/warm_step
37 | elif global_step
--------------------------------------------------------------------------------
/utils/metric.py:
--------------------------------------------------------------------------------
50 | def fast_hist(label, pred, n):
51 | k = (label >= 0) & (label < n)
52 | return np.bincount(
53 | n * label[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)
54 |
55 |
56 | def per_class_iu(hist):
57 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
58 |
--------------------------------------------------------------------------------
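cosine_decay ramps the rate linearly from 0 to base_learning_rate over warm_step iterations, then follows a cosine down to roughly alpha. Sampling a hypothetical schedule:

    from utils.learning_policy import cosine_decay

    base_lr, warm, total = 0.01, 500, 10000
    for step in (0, 250, 500, 5000, 10000):
        print(step, round(cosine_decay(base_lr, step, warm, total), 6))
    # 0.0 and 0.005 during warm-up, 0.01 at step 500, ~0.0055 mid-run, 0.0001 at the end
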
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | # colour map
5 | label_colours = [(0, 0, 0)
6 | # 0=Background
7 | , (128, 0, 0), (255, 0, 0), (0, 85, 0), (170, 0, 51), (255, 85, 0)
8 | # 1=Hat, 2=Hair, 3=Glove, 4=Sunglasses, 5=UpperClothes
9 | , (0, 0, 85), (0, 119, 221), (85, 85, 0), (0, 85, 85), (85, 51, 0)
10 | # 6=Dress, 7=Coat, 8=Socks, 9=Pants, 10=Jumpsuits
11 | , (52, 86, 128), (0, 128, 0), (0, 0, 255), (51, 170, 221), (0, 255, 255)
12 | # 11=Scarf, 12=Skirt, 13=Face, 14=LeftArm, 15=RightArm
13 | , (85, 255, 170), (170, 255, 85), (255, 255, 0), (255, 170, 0)]
14 | # 16=LeftLeg, 17=RightLeg, 18=LeftShoe, 19=RightShoe
15 |
16 |
17 | pascal_person = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128)]
18 |
19 |
20 | def decode_predictions(preds, num_images=4, num_classes=20):
21 | """Decode batch of segmentation masks.
22 | """
23 | preds = preds.data.cpu().numpy()
24 | n, h, w = preds.shape
25 | assert n >= num_images
26 | outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
27 | for i in range(num_images):
28 | img = Image.new('RGB', (len(preds[i, 0]), len(preds[i])))
29 | pixels = img.load()
30 | for j_, j in enumerate(preds[i, :, :]):
31 | for k_, k in enumerate(j):
32 | if k < num_classes:
33 | pixels[k_, j_] = label_colours[k]
34 | outputs[i] = np.array(img)
35 | return outputs
36 |
37 |
38 | def inv_preprocess(imgs, num_images=4):
39 | """Inverse preprocessing of the batch of images.
40 | """
41 | mean = (104.00698793, 116.66876762, 122.67891434)
42 | imgs = imgs.data.cpu().numpy()
43 | n, c, h, w = imgs.shape
44 | assert n >= num_images
45 | outputs = np.zeros((num_images, h, w, c), dtype=np.uint8)
46 | for i in range(num_images):
47 | outputs[i] = (np.transpose(imgs[i], (1, 2, 0)) + mean)[:, :, ::-1].astype(np.uint8)
48 | return outputs
49 |
--------------------------------------------------------------------------------
/val/evaluate_atr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import cv2
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.autograd import Variable
10 | from torch.utils import data
11 |
12 | from dataset.data_atr import ATRTestGenerator as TestGenerator
13 | from network.baseline import get_model
14 |
15 |
16 | def get_arguments():
17 | """Parse all the arguments provided from the CLI.
18 |
19 | Returns:
20 | A list of parsed arguments.
21 | """
22 | parser = argparse.ArgumentParser(description="Pytorch Segmentation")
23 | parser.add_argument("--root", type=str, default='./data/ATR/test_set/')
24 | parser.add_argument("--data-list", type=str, default='./dataset/ATR/test_id.txt')
25 | parser.add_argument("--crop-size", type=int, default=513)
26 | parser.add_argument("--num-classes", type=int, default=18)
27 | parser.add_argument("--ignore-label", type=int, default=255)
28 | parser.add_argument("--restore-from", type=str,
29 | default='./checkpoints/exp/model_best.pth')
30 | parser.add_argument("--is-mirror", action="store_true")
31 | parser.add_argument('--eval-scale', nargs='+', type=float, default=[1.0])
32 | # parser.add_argument('--eval-scale', nargs='+', type=float, default=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75])
33 |
34 | parser.add_argument("--save-dir", type=str)
35 | parser.add_argument("--gpu", type=str, default='0')
36 | return parser.parse_args()
37 |
38 |
39 | def main():
40 | """Create the model and start the evaluation process."""
41 | args = get_arguments()
42 |
43 | # initialization
44 | print("Input arguments:")
45 | for key, val in vars(args).items():
46 | print("{:16} {}".format(key, val))
47 |
48 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
49 |
50 | model = get_model(num_classes=args.num_classes)
51 |
52 | # if not os.path.exists(args.save_dir):
53 | # os.makedirs(args.save_dir)
54 |
55 | palette = get_lip_palette()
56 | restore_from = args.restore_from
57 | saved_state_dict = torch.load(restore_from)
58 | model.load_state_dict(saved_state_dict)
59 |
60 | model.eval()
61 | model.cuda()
62 |
63 | testloader = data.DataLoader(TestGenerator(args.root, args.data_list, crop_size=args.crop_size),
64 | batch_size=1, shuffle=False, pin_memory=True)
65 |
66 | confusion_matrix = np.zeros((args.num_classes, args.num_classes))
67 |
68 | for index, batch in enumerate(testloader):
69 | if index % 100 == 0:
70 | print('%d images have been processed' % index)
71 | image, label, ori_size, name = batch
72 |
73 | ori_size = ori_size[0].numpy()
74 |
75 | output = predict(model, image.numpy(), (np.asscalar(ori_size[0]), np.asscalar(ori_size[1])),
76 | is_mirror=args.is_mirror, scales=args.eval_scale)
77 | seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
78 |
79 | # output_im = PILImage.fromarray(seg_pred)
80 | # output_im.putpalette(palette)
81 | # output_im.save(args.save_dir + name[0] + '.png')
82 |
83 | seg_gt = np.asarray(label[0].numpy(), dtype=np.int)
84 | ignore_index = seg_gt != 255
85 | seg_gt = seg_gt[ignore_index]
86 | seg_pred = seg_pred[ignore_index]
87 |
88 | confusion_matrix += get_confusion_matrix(seg_gt, seg_pred, args.num_classes)
89 |
90 | pos = confusion_matrix.sum(1)
91 | res = confusion_matrix.sum(0)
92 | tp = np.diag(confusion_matrix)
93 |
94 | pixel_accuracy = tp.sum() / pos.sum()
95 | mean_accuracy = (tp / np.maximum(1.0, pos)).mean()
96 | IU_array = (tp / np.maximum(1.0, pos + res - tp))
97 | mean_IU = IU_array.mean()
98 |
99 | # get_confusion_matrix_plot()
100 |
101 | print('Pixel accuracy: %f \n' % pixel_accuracy)
102 | print('Mean accuracy: %f \n' % mean_accuracy)
103 | print('Mean IU: %f \n' % mean_IU)
104 | for index, IU in enumerate(IU_array):
105 | print('%f' % IU)
106 |
107 |
108 | def scale_image(image, scale):
109 | image = image[0, :, :, :]
110 | image = image.transpose((1, 2, 0))
111 | image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
112 | image = image.transpose((2, 0, 1))
113 | return image
114 |
115 |
116 | def predict(net, image, output_size, is_mirror=True, scales=[1]):
117 | if is_mirror:
118 | image_rev = image[:, :, :, ::-1]
119 |
120 | interp = nn.Upsample(size=output_size, mode='bilinear', align_corners=True)
121 |
122 | outputs = []
123 | if is_mirror:
124 | for scale in scales:
125 | if scale != 1:
126 | image_scale = scale_image(image=image, scale=scale)
127 | image_rev_scale = scale_image(image=image_rev, scale=scale)
128 | else:
129 | image_scale = image[0, :, :, :]
130 | image_rev_scale = image_rev[0, :, :, :]
131 |
132 | image_scale = np.stack((image_scale, image_rev_scale))
133 |
134 | with torch.no_grad():
135 | prediction = net(Variable(torch.from_numpy(image_scale)).cuda())
136 | prediction = interp(prediction[0]).cpu().data.numpy()
137 |
138 | prediction_rev = prediction[1, :, :, :].copy()
139 | prediction_rev[9, :, :] = prediction[1, 10, :, :]
140 | prediction_rev[10, :, :] = prediction[1, 9, :, :]
141 | prediction_rev[12, :, :] = prediction[1, 13, :, :]
142 | prediction_rev[13, :, :] = prediction[1, 12, :, :]
143 | prediction_rev[14, :, :] = prediction[1, 15, :, :]
144 | prediction_rev[15, :, :] = prediction[1, 14, :, :]
145 | prediction_rev = prediction_rev[:, :, ::-1]
146 | prediction = prediction[0, :, :, :]
147 | prediction = np.mean([prediction, prediction_rev], axis=0)
148 |
149 | outputs.append(prediction)
150 |
151 | outputs = np.mean(outputs, axis=0)
152 | outputs = outputs.transpose(1, 2, 0)
153 | else:
154 | for scale in scales:
155 | if scale != 1:
156 | image_scale = scale_image(image=image, scale=scale)
157 | else:
158 | image_scale = image[0, :, :, :]
159 |
160 | with torch.no_grad():
161 | prediction = net(Variable(torch.from_numpy(image_scale).unsqueeze(0)).cuda())
162 | prediction = interp(prediction[0]).cpu().data.numpy()
163 | outputs.append(prediction[0, :, :, :])
164 |
165 | outputs = np.mean(outputs, axis=0)
166 | outputs = outputs.transpose(1, 2, 0)
167 |
168 | return outputs
169 |
170 |
171 | def get_confusion_matrix(gt_label, pred_label, class_num):
172 | """
173 | Calculate the confusion matrix by given label and pred
174 | :param gt_label: the ground truth label
175 | :param pred_label: the pred label
176 | :param class_num: the nunber of class
177 | """
178 | index = (gt_label * class_num + pred_label).astype('int32')
179 | label_count = np.bincount(index)
180 | confusion_matrix = np.zeros((class_num, class_num))
181 |
182 | for i_label in range(class_num):
183 | for i_pred_label in range(class_num):
184 | cur_index = i_label * class_num + i_pred_label
185 | if cur_index < len(label_count):
186 | confusion_matrix[i_label, i_pred_label] = label_count[cur_index]
187 |
188 | return confusion_matrix
189 |
190 |
191 | def get_confusion_matrix_plot(conf_arr):
192 | norm_conf = []
193 | for i in conf_arr:
194 | tmp_arr = []
195 | a = sum(i, 0)
196 | for j in i:
197 | tmp_arr.append(float(j) / max(1.0, float(a)))
198 | norm_conf.append(tmp_arr)
199 |
200 | fig = plt.figure()
201 | plt.clf()
202 | ax = fig.add_subplot(111)
203 | ax.set_aspect(1)
204 | res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, interpolation='nearest')
205 |
206 | width, height = conf_arr.shape
207 |
208 | cb = fig.colorbar(res)
209 | alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
210 | plt.xticks(range(width), alphabet[:width])
211 | plt.yticks(range(height), alphabet[:height])
212 | plt.savefig('confusion_matrix.png', format='png')
213 |
214 |
215 | def get_lip_palette():
216 | palette = [0, 0, 0,
217 | 128, 0, 0,
218 | 255, 0, 0,
219 | 0, 85, 0,
220 | 170, 0, 51,
221 | 255, 85, 0,
222 | 0, 0, 85,
223 | 0, 119, 221,
224 | 85, 85, 0,
225 | 0, 85, 85,
226 | 85, 51, 0,
227 | 52, 86, 128,
228 | 0, 128, 0,
229 | 0, 0, 255,
230 | 51, 170, 221,
231 | 0, 255, 255,
232 | 85, 255, 170,
233 | 170, 255, 85,
234 | 255, 255, 0,
235 | 255, 170, 0]
236 | return palette
237 |
238 |
239 | if __name__ == '__main__':
240 | main()
241 |
--------------------------------------------------------------------------------
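
The mirror branch of predict() above hard-codes the channel swaps for this dataset's left/right paired classes. Below is a minimal sketch of the same flip-averaging step with the pairs factored out into an argument; merge_flipped and the example pair list are illustrative names, not part of this repo:

    import numpy as np

    def merge_flipped(pred, pred_flipped, lr_pairs):
        # pred, pred_flipped: (C, H, W) score maps; pred_flipped was computed
        # on the horizontally mirrored image. lr_pairs lists the (left, right)
        # channel indices to exchange so the flipped scores refer to the
        # correct side.
        merged = pred_flipped.copy()
        for left, right in lr_pairs:
            merged[left], merged[right] = pred_flipped[right], pred_flipped[left]
        merged = merged[:, :, ::-1]  # undo the horizontal flip
        return np.mean([pred, merged], axis=0)

Calling merge_flipped(pred, pred_flipped, [(9, 10), (12, 13), (14, 15)]) reproduces the swaps in the script above.
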
/val/evaluate_ccf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import cv2
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from PIL import Image
10 | from torch.autograd import Variable
11 | from torch.utils import data
12 |
13 | from dataset.data_ccf import TestGenerator
14 | from network.baseline import get_model
15 |
16 |
17 | def get_arguments():
18 | """Parse all the arguments provided from the CLI.
19 |
20 | Returns:
21 | A list of parsed arguments.
22 | """
23 | parser = argparse.ArgumentParser(description="PyTorch Segmentation")
24 | parser.add_argument('--root', default='./data/CCF', type=str)
25 | parser.add_argument("--data-list", type=str, default='./dataset/CCF/test_id.txt')
26 | parser.add_argument("--crop-size", type=int, default=513)
27 | parser.add_argument("--num-classes", type=int, default=18)
28 | parser.add_argument("--ignore-label", type=int, default=255)
29 | parser.add_argument('--restore-from', default='./checkpoints/exp/model_best.pth', type=str)
30 |
31 | parser.add_argument("--is-mirror", action="store_true")
32 | parser.add_argument('--eval-scale', nargs='+', type=float, default=[1.0])
33 | # parser.add_argument('--eval-scale', nargs='+', type=float, default=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75])
34 | parser.add_argument("--save-dir", type=str)
35 | parser.add_argument("--gpu", type=str, default='0')
36 | return parser.parse_args()
37 |
38 |
39 | def main():
40 | """Create the model and start the evaluation process."""
41 | args = get_arguments()
42 |
43 | # initialization
44 | print("Input arguments:")
45 | for key, val in vars(args).items():
46 | print("{:16} {}".format(key, val))
47 |
48 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
49 |
50 |
51 | # if not os.path.exists(args.save_dir):
52 | # os.makedirs(args.save_dir)
53 |
54 | # obtain the color map
55 | palette = get_lip_palette()
56 |
57 | # conduct model & load pre-trained weights
58 | model = get_model(num_classes=args.num_classes)
59 | restore_from = args.restore_from
60 | saved_state_dict = torch.load(restore_from)
61 | model.load_state_dict(saved_state_dict)
62 |
63 | model.eval()
64 | model.cuda()
65 | # data loader
66 | testloader = data.DataLoader(TestGenerator(args.root, args.data_list, crop_size=args.crop_size),
67 | batch_size=1, shuffle=False, pin_memory=True)
68 |
69 | confusion_matrix = np.zeros((args.num_classes, args.num_classes))
70 |
71 | for index, batch in enumerate(testloader):
72 | if index % 100 == 0:
73 | print('%d images have been processed' % index)
74 | image, label, ori_size, name = batch
75 |
76 | ori_size = ori_size[0].numpy()
77 | output = predict(model, image.numpy(), (np.asscalar(ori_size[0]), np.asscalar(ori_size[1])),
78 | is_mirror=args.is_mirror, scales=args.eval_scale)
79 | seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
80 |
81 | # output_im = PILImage.fromarray(seg_pred)
82 | # output_im.putpalette(palette)
83 | # output_im.save(args.save_dir + name[0] + '.png')
84 |
85 | seg_gt = np.asarray(label[0].numpy(), dtype=np.int)
86 | ignore_index = seg_gt != 255
87 | seg_gt = seg_gt[ignore_index]
88 | seg_pred = seg_pred[ignore_index]
89 |
90 | confusion_matrix += get_confusion_matrix(seg_gt, seg_pred, args.num_classes)
91 |
92 | pos = confusion_matrix.sum(1)
93 | res = confusion_matrix.sum(0)
94 | tp = np.diag(confusion_matrix)
95 |
96 | pixel_accuracy = tp.sum() / pos.sum()
97 | mean_accuracy = (tp / np.maximum(1.0, pos)).mean()
98 | IU_array = (tp / np.maximum(1.0, pos + res - tp))
99 | mean_IU = IU_array.mean()
100 |
101 | # get_confusion_matrix_plot()
102 |
103 | print('Pixel accuracy: %f \n' % pixel_accuracy)
104 | print('Mean accuracy: %f \n' % mean_accuracy)
105 | print('Mean IU: %f \n' % mean_IU)
106 | for index, IU in enumerate(IU_array):
107 | print('%f' % IU)
108 |
109 |
110 | def scale_image(image, scale):
111 | image = image[0, :, :, :]
112 | image = image.transpose((1, 2, 0))
113 | image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
114 | image = image.transpose((2, 0, 1))
115 | return image
116 |
117 |
118 | def predict(net, image, output_size, is_mirror=True, scales=[1]):
119 | if is_mirror:
120 | image_rev = image[:, :, :, ::-1]
121 |
122 | interp = nn.Upsample(size=output_size, mode='bilinear', align_corners=True)
123 |
124 | outputs = []
125 | if is_mirror:
126 | for scale in scales:
127 | if scale != 1:
128 | image_scale = scale_image(image=image, scale=scale)
129 | image_rev_scale = scale_image(image=image_rev, scale=scale)
130 | else:
131 | image_scale = image[0, :, :, :]
132 | image_rev_scale = image_rev[0, :, :, :]
133 |
134 | image_scale = np.stack((image_scale, image_rev_scale))
135 |
136 | with torch.no_grad():
137 | prediction = net(Variable(torch.from_numpy(image_scale)).cuda())
138 | prediction = interp(prediction[0]).cpu().data.numpy()
139 |
140 | prediction_rev = prediction[1, :, :, :].copy()  # scores for the mirrored input
141 | prediction_rev = prediction_rev[:, :, ::-1]  # undo the flip; unlike the ATR/LIP scripts, no left/right channel swap is applied here
142 | prediction = prediction[0, :, :, :]
143 | prediction = np.mean([prediction, prediction_rev], axis=0)
144 |
145 | outputs.append(prediction)
146 |
147 | outputs = np.mean(outputs, axis=0)
148 | outputs = outputs.transpose(1, 2, 0)
149 | else:
150 | for scale in scales:
151 | if scale != 1:
152 | image_scale = scale_image(image=image, scale=scale)
153 | else:
154 | image_scale = image[0, :, :, :]
155 |
156 | with torch.no_grad():
157 | prediction = net(Variable(torch.from_numpy(image_scale).unsqueeze(0)).cuda())
158 | prediction = interp(prediction[0]).cpu().data.numpy()
159 | outputs.append(prediction[0, :, :, :])
160 |
161 | outputs = np.mean(outputs, axis=0)
162 | outputs = outputs.transpose(1, 2, 0)
163 |
164 | return outputs
165 |
166 |
167 | def get_confusion_matrix(gt_label, pred_label, class_num):
168 | """
169 | Calculate the confusion matrix from the given ground-truth and predicted labels
170 | :param gt_label: the ground truth label
171 | :param pred_label: the predicted label
172 | :param class_num: the number of classes
173 | """
174 | index = (gt_label * class_num + pred_label).astype('int32')
175 | label_count = np.bincount(index)
176 | confusion_matrix = np.zeros((class_num, class_num))
177 |
178 | for i_label in range(class_num):
179 | for i_pred_label in range(class_num):
180 | cur_index = i_label * class_num + i_pred_label
181 | if cur_index < len(label_count):
182 | confusion_matrix[i_label, i_pred_label] = label_count[cur_index]
183 |
184 | return confusion_matrix
185 |
186 |
187 | def get_confusion_matrix_plot(conf_arr):
188 | norm_conf = []
189 | for i in conf_arr:
190 | tmp_arr = []
191 | a = sum(i, 0)
192 | for j in i:
193 | tmp_arr.append(float(j) / max(1.0, float(a)))
194 | norm_conf.append(tmp_arr)
195 |
196 | fig = plt.figure()
197 | plt.clf()
198 | ax = fig.add_subplot(111)
199 | ax.set_aspect(1)
200 | res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, interpolation='nearest')
201 |
202 | width, height = conf_arr.shape
203 |
204 | cb = fig.colorbar(res)
205 | alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
206 | plt.xticks(range(width), alphabet[:width])
207 | plt.yticks(range(height), alphabet[:height])
208 | plt.savefig('confusion_matrix.png', format='png')
209 |
210 |
211 | def get_lip_palette():
212 | palette = [0, 0, 0,
213 | 128, 0, 0,
214 | 255, 0, 0,
215 | 0, 85, 0,
216 | 170, 0, 51,
217 | 255, 85, 0,
218 | 0, 0, 85,
219 | 0, 119, 221,
220 | 85, 85, 0,
221 | 0, 85, 85,
222 | 85, 51, 0,
223 | 52, 86, 128,
224 | 0, 128, 0,
225 | 0, 0, 255,
226 | 51, 170, 221,
227 | 0, 255, 255,
228 | 85, 255, 170,
229 | 170, 255, 85,
230 | 255, 255, 0,
231 | 255, 170, 0]
232 | return palette
233 |
234 |
235 | if __name__ == '__main__':
236 | main()
237 |
--------------------------------------------------------------------------------
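
Each of these evaluation scripts derives its metrics from the same accumulated confusion matrix, with rows indexing ground truth and columns indexing predictions. A standalone sketch of that block (summarize is an illustrative name):

    import numpy as np

    def summarize(confusion_matrix):
        pos = confusion_matrix.sum(1)   # ground-truth pixels per class
        res = confusion_matrix.sum(0)   # predicted pixels per class
        tp = np.diag(confusion_matrix)  # correctly classified pixels
        pixel_accuracy = tp.sum() / pos.sum()
        mean_accuracy = (tp / np.maximum(1.0, pos)).mean()
        IU_array = tp / np.maximum(1.0, pos + res - tp)
        return pixel_accuracy, mean_accuracy, IU_array.mean(), IU_array
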
/val/evaluate_lip.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import cv2
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.autograd import Variable
10 | from torch.utils import data
11 |
12 | from dataset.data_lip import LIPValGenerator  # the dataset package provides data_lip.py; there is no datasets module
13 | from network.baseline import get_model
14 |
15 |
16 | def get_arguments():
17 | """Parse all the arguments provided from the CLI.
18 |
19 | Returns:
20 | A list of parsed arguments.
21 | """
22 | parser = argparse.ArgumentParser(description="PyTorch Segmentation")
23 | parser.add_argument("--root", type=str, default='./data/LIP/val_set/')
24 | parser.add_argument("--data-list", type=str, default='./dataset/LIP/val_id.txt')
25 | parser.add_argument("--crop-size", type=int, default=473)
26 | parser.add_argument("--num-classes", type=int, default=20)
27 | parser.add_argument("--ignore-label", type=int, default=255)
28 | parser.add_argument("--restore-from", type=str,
29 | default='./checkpoints/exp/model_best.pth')
30 | parser.add_argument("--is-mirror", action="store_true")
31 | parser.add_argument('--eval-scale', nargs='+', type=float, default=[1.0])
32 | # parser.add_argument('--eval-scale', nargs='+', type=float, default=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75])
33 |
34 | parser.add_argument("--save-dir", type=str)
35 | parser.add_argument("--gpu", type=str, default='0')
36 | return parser.parse_args()
37 |
38 |
39 | def main():
40 | """Create the model and start the evaluation process."""
41 | args = get_arguments()
42 |
43 | # initialization
44 | print("Input arguments:")
45 | for key, val in vars(args).items():
46 | print("{:16} {}".format(key, val))
47 |
48 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
49 |
50 | model = get_model(num_classes=args.num_classes)
51 |
52 | # if not os.path.exists(args.save_dir):
53 | # os.makedirs(args.save_dir)
54 |
55 | palette = get_lip_palette()
56 | restore_from = args.restore_from
57 | saved_state_dict = torch.load(restore_from)
58 | model.load_state_dict(saved_state_dict)
59 |
60 | model.eval()
61 | model.cuda()
62 |
63 | testloader = data.DataLoader(LIPValGenerator(args.root, args.data_list, crop_size=args.crop_size),
64 | batch_size=1, shuffle=False, pin_memory=True)
65 |
66 | confusion_matrix = np.zeros((args.num_classes, args.num_classes))
67 |
68 | for index, batch in enumerate(testloader):
69 | if index % 100 == 0:
70 | print('%d images have been processed' % index)
71 | image, label, ori_size, name = batch
72 |
73 | ori_size = ori_size[0].numpy()
74 |
75 | output = predict(model, image.numpy(), (np.asscalar(ori_size[0]), np.asscalar(ori_size[1])),
76 | is_mirror=args.is_mirror, scales=args.eval_scale)
77 | seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
78 |
79 | # output_im = PILImage.fromarray(seg_pred)
80 | # output_im.putpalette(palette)
81 | # output_im.save(args.save_dir + name[0] + '.png')
82 |
83 | seg_gt = np.asarray(label[0].numpy(), dtype=np.int)
84 | ignore_index = seg_gt != 255
85 | seg_gt = seg_gt[ignore_index]
86 | seg_pred = seg_pred[ignore_index]
87 |
88 | confusion_matrix += get_confusion_matrix(seg_gt, seg_pred, args.num_classes)
89 |
90 | pos = confusion_matrix.sum(1)
91 | res = confusion_matrix.sum(0)
92 | tp = np.diag(confusion_matrix)
93 |
94 | pixel_accuracy = tp.sum() / pos.sum()
95 | mean_accuracy = (tp / np.maximum(1.0, pos)).mean()
96 | IU_array = (tp / np.maximum(1.0, pos + res - tp))
97 | mean_IU = IU_array.mean()
98 |
99 | # get_confusion_matrix_plot()
100 |
101 | print('Pixel accuracy: %f \n' % pixel_accuracy)
102 | print('Mean accuracy: %f \n' % mean_accuracy)
103 | print('Mean IU: %f \n' % mean_IU)
104 | for index, IU in enumerate(IU_array):
105 | print('%f' % IU)
106 |
107 |
108 | def scale_image(image, scale):
109 | image = image[0, :, :, :]
110 | image = image.transpose((1, 2, 0))
111 | image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
112 | image = image.transpose((2, 0, 1))
113 | return image
114 |
115 |
116 | def predict(net, image, output_size, is_mirror=True, scales=[1]):
117 | if is_mirror:
118 | image_rev = image[:, :, :, ::-1]
119 |
120 | interp = nn.Upsample(size=output_size, mode='bilinear', align_corners=True)
121 |
122 | outputs = []
123 | if is_mirror:
124 | for scale in scales:
125 | if scale != 1:
126 | image_scale = scale_image(image=image, scale=scale)
127 | image_rev_scale = scale_image(image=image_rev, scale=scale)
128 | else:
129 | image_scale = image[0, :, :, :]
130 | image_rev_scale = image_rev[0, :, :, :]
131 |
132 | image_scale = np.stack((image_scale, image_rev_scale))
133 |
134 | with torch.no_grad():
135 | prediction = net(Variable(torch.from_numpy(image_scale)).cuda())
136 | prediction = interp(prediction[0]).cpu().data.numpy()
137 |
138 | prediction_rev = prediction[1, :, :, :].copy()  # scores for the mirrored input
139 | prediction_rev[14, :, :] = prediction[1, 15, :, :]  # swap left/right paired channels (arms, legs, shoes)
140 | prediction_rev[15, :, :] = prediction[1, 14, :, :]
141 | prediction_rev[16, :, :] = prediction[1, 17, :, :]
142 | prediction_rev[17, :, :] = prediction[1, 16, :, :]
143 | prediction_rev[18, :, :] = prediction[1, 19, :, :]
144 | prediction_rev[19, :, :] = prediction[1, 18, :, :]
145 | prediction_rev = prediction_rev[:, :, ::-1]  # undo the horizontal flip
146 | prediction = prediction[0, :, :, :]
147 | prediction = np.mean([prediction, prediction_rev], axis=0)
148 |
149 | outputs.append(prediction)
150 |
151 | outputs = np.mean(outputs, axis=0)
152 | outputs = outputs.transpose(1, 2, 0)
153 | else:
154 | for scale in scales:
155 | if scale != 1:
156 | image_scale = scale_image(image=image, scale=scale)
157 | else:
158 | image_scale = image[0, :, :, :]
159 |
160 | with torch.no_grad():
161 | prediction = net(Variable(torch.from_numpy(image_scale).unsqueeze(0)).cuda())
162 | prediction = interp(prediction[0]).cpu().data.numpy()
163 | outputs.append(prediction[0, :, :, :])
164 |
165 | outputs = np.mean(outputs, axis=0)
166 | outputs = outputs.transpose(1, 2, 0)
167 |
168 | return outputs
169 |
170 |
171 | def get_confusion_matrix(gt_label, pred_label, class_num):
172 | """
173 | Calculate the confusion matrix from the given ground-truth and predicted labels
174 | :param gt_label: the ground truth label
175 | :param pred_label: the predicted label
176 | :param class_num: the number of classes
177 | """
178 | index = (gt_label * class_num + pred_label).astype('int32')
179 | label_count = np.bincount(index)
180 | confusion_matrix = np.zeros((class_num, class_num))
181 |
182 | for i_label in range(class_num):
183 | for i_pred_label in range(class_num):
184 | cur_index = i_label * class_num + i_pred_label
185 | if cur_index < len(label_count):
186 | confusion_matrix[i_label, i_pred_label] = label_count[cur_index]
187 |
188 | return confusion_matrix
189 |
190 |
191 | def get_confusion_matrix_plot(conf_arr):
192 | norm_conf = []
193 | for i in conf_arr:
194 | tmp_arr = []
195 | a = sum(i, 0)
196 | for j in i:
197 | tmp_arr.append(float(j) / max(1.0, float(a)))
198 | norm_conf.append(tmp_arr)
199 |
200 | fig = plt.figure()
201 | plt.clf()
202 | ax = fig.add_subplot(111)
203 | ax.set_aspect(1)
204 | res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, interpolation='nearest')
205 |
206 | width, height = conf_arr.shape
207 |
208 | cb = fig.colorbar(res)
209 | alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
210 | plt.xticks(range(width), alphabet[:width])
211 | plt.yticks(range(height), alphabet[:height])
212 | plt.savefig('confusion_matrix.png', format='png')
213 |
214 |
215 | def get_lip_palette():
216 | palette = [0, 0, 0,
217 | 128, 0, 0,
218 | 255, 0, 0,
219 | 0, 85, 0,
220 | 170, 0, 51,
221 | 255, 85, 0,
222 | 0, 0, 85,
223 | 0, 119, 221,
224 | 85, 85, 0,
225 | 0, 85, 85,
226 | 85, 51, 0,
227 | 52, 86, 128,
228 | 0, 128, 0,
229 | 0, 0, 255,
230 | 51, 170, 221,
231 | 0, 255, 255,
232 | 85, 255, 170,
233 | 170, 255, 85,
234 | 255, 255, 0,
235 | 255, 170, 0]
236 | return palette
237 |
238 |
239 | if __name__ == '__main__':
240 | main()
241 |
--------------------------------------------------------------------------------
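
get_confusion_matrix() above already flattens each (gt, pred) pair into the single index gt_label * class_num + pred_label before counting, so the double loop that scatters label_count into the matrix can be collapsed into one bincount call with minlength. An equivalent, loop-free sketch (get_confusion_matrix_fast is an illustrative name):

    import numpy as np

    def get_confusion_matrix_fast(gt_label, pred_label, class_num):
        index = (gt_label * class_num + pred_label).astype('int32')
        counts = np.bincount(index.ravel(), minlength=class_num * class_num)
        return counts.reshape(class_num, class_num).astype(np.float64)
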
/val/evaluate_pascal.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import cv2
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.autograd import Variable
10 | from torch.utils import data
11 |
12 | from dataset.data_pascal import TestGenerator
13 | from network.baseline import get_model
14 |
15 |
16 | def get_arguments():
17 | """Parse all the arguments provided from the CLI.
18 |
19 | Returns:
20 | A list of parsed arguments.
21 | """
22 | parser = argparse.ArgumentParser(description="PyTorch Segmentation")
23 | parser.add_argument('--root', default='./data/Person', type=str)
24 | parser.add_argument("--data-list", type=str, default='./dataset/Pascal/val_id.txt')
25 | parser.add_argument("--crop-size", type=int, default=473)
26 | parser.add_argument("--num-classes", type=int, default=7)
27 | parser.add_argument("--ignore-label", type=int, default=255)
28 | parser.add_argument('--restore-from', default='./checkpoints/exp/model_best.pth', type=str)
29 |
30 | parser.add_argument("--is-mirror", action="store_true")
31 | parser.add_argument('--eval-scale', nargs='+', type=float, default=[1.0])
32 | # parser.add_argument('--eval-scale', nargs='+', type=float, default=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75])
33 | parser.add_argument("--save-dir", type=str)
34 | parser.add_argument("--gpu", type=str, default='0')
35 | return parser.parse_args()
36 |
37 |
38 | def main():
39 | """Create the model and start the evaluation process."""
40 | args = get_arguments()
41 |
42 | # initialization
43 | print("Input arguments:")
44 | for key, val in vars(args).items():
45 | print("{:16} {}".format(key, val))
46 |
47 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
48 |
49 | model = get_model(num_classes=args.num_classes)
50 |
51 | # if not os.path.exists(args.save_dir):
52 | # os.makedirs(args.save_dir)
53 |
54 | palette = get_lip_palette()
55 | restore_from = args.restore_from
56 | saved_state_dict = torch.load(restore_from)
57 | model.load_state_dict(saved_state_dict)
58 |
59 | model.eval()
60 | model.cuda()
61 |
62 | testloader = data.DataLoader(TestGenerator(args.root, args.data_list, crop_size=args.crop_size),
63 | batch_size=1, shuffle=False, pin_memory=True)
64 |
65 | confusion_matrix = np.zeros((args.num_classes, args.num_classes))
66 |
67 | for index, batch in enumerate(testloader):
68 | if index % 100 == 0:
69 | print('%d images have been processed' % index)
70 | image, label, ori_size, name = batch
71 |
72 | # img_name = "/home/hlzhu/hlzhu/Iter_ParseNet_final/data/Person/JPEGImages/"+name[0]+'.jpg'
73 | # print(img_name)
74 | # ori_img = cv2.imread(img_name)
75 | # cv2.imshow('image',ori_img)
76 | # cv2.waitKey(1)
77 | # 2008_000195 multi person
78 | # 2008_002829 single person
79 | if name[0] == "2008_002829":
80 | print("2008_002829.jpg")
81 | ori_size = ori_size[0].numpy()
82 |
83 | output = predict(model, image.numpy(), (np.asscalar(ori_size[0]), np.asscalar(ori_size[1])),
84 | is_mirror=args.is_mirror, scales=args.eval_scale)
85 | seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
86 |
87 | # output_im = PILImage.fromarray(seg_pred)
88 | # output_im.putpalette(palette)
89 | # output_im.save(args.save_dir + name[0] + '.png')
90 |
91 | seg_gt = np.asarray(label[0].numpy(), dtype=np.int)
92 | ignore_index = seg_gt != 255
93 | seg_gt = seg_gt[ignore_index]
94 | seg_pred = seg_pred[ignore_index]
95 |
96 | confusion_matrix += get_confusion_matrix(seg_gt, seg_pred, args.num_classes)
97 |
98 | pos = confusion_matrix.sum(1)
99 | res = confusion_matrix.sum(0)
100 | tp = np.diag(confusion_matrix)
101 |
102 | pixel_accuracy = tp.sum() / pos.sum()
103 | mean_accuracy = (tp / np.maximum(1.0, pos)).mean()
104 | IU_array = (tp / np.maximum(1.0, pos + res - tp))
105 | mean_IU = IU_array.mean()
106 |
107 | # get_confusion_matrix_plot()
108 |
109 | print('Pixel accuracy: %f \n' % pixel_accuracy)
110 | print('Mean accuracy: %f \n' % mean_accuracy)
111 | print('Mean IU: %f \n' % mean_IU)
112 | for index, IU in enumerate(IU_array):
113 | print('%f' % IU)
114 |
115 |
116 | def scale_image(image, scale):
117 | image = image[0, :, :, :]
118 | image = image.transpose((1, 2, 0))
119 | image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
120 | image = image.transpose((2, 0, 1))
121 | return image
122 |
123 |
124 | def predict(net, image, output_size, is_mirror=True, scales=[1]):
125 | if is_mirror:
126 | image_rev = image[:, :, :, ::-1]
127 |
128 | interp = nn.Upsample(size=output_size, mode='bilinear', align_corners=True)
129 |
130 | outputs = []
131 | if is_mirror:
132 | for scale in scales:
133 | if scale != 1:
134 | image_scale = scale_image(image=image, scale=scale)
135 | image_rev_scale = scale_image(image=image_rev, scale=scale)
136 | else:
137 | image_scale = image[0, :, :, :]
138 | image_rev_scale = image_rev[0, :, :, :]
139 |
140 | image_scale = np.stack((image_scale, image_rev_scale))
141 |
142 | with torch.no_grad():
143 | prediction = net(Variable(torch.from_numpy(image_scale)).cuda())
144 | prediction = interp(prediction[0]).cpu().data.numpy()
145 |
146 | prediction_rev = prediction[1, :, :, :].copy()  # scores for the mirrored input
147 | prediction_rev = prediction_rev[:, :, ::-1]  # undo the flip; Pascal-Person-Part classes have no left/right split, so no channel swap is needed
148 | prediction = prediction[0, :, :, :]
149 | prediction = np.mean([prediction, prediction_rev], axis=0)
150 |
151 | outputs.append(prediction)
152 |
153 | outputs = np.mean(outputs, axis=0)
154 | outputs = outputs.transpose(1, 2, 0)
155 | else:
156 | for scale in scales:
157 | if scale != 1:
158 | image_scale = scale_image(image=image, scale=scale)
159 | else:
160 | image_scale = image[0, :, :, :]
161 |
162 | with torch.no_grad():
163 | prediction = net(Variable(torch.from_numpy(image_scale).unsqueeze(0)).cuda())
164 | prediction = interp(prediction[0]).cpu().data.numpy()
165 | outputs.append(prediction[0, :, :, :])
166 |
167 | outputs = np.mean(outputs, axis=0)
168 | outputs = outputs.transpose(1, 2, 0)
169 |
170 | return outputs
171 |
172 |
173 | def get_confusion_matrix(gt_label, pred_label, class_num):
174 | """
175 | Calculate the confusion matrix from the given ground-truth and predicted labels
176 | :param gt_label: the ground truth label
177 | :param pred_label: the predicted label
178 | :param class_num: the number of classes
179 | """
180 | index = (gt_label * class_num + pred_label).astype('int32')
181 | label_count = np.bincount(index)
182 | confusion_matrix = np.zeros((class_num, class_num))
183 |
184 | for i_label in range(class_num):
185 | for i_pred_label in range(class_num):
186 | cur_index = i_label * class_num + i_pred_label
187 | if cur_index < len(label_count):
188 | confusion_matrix[i_label, i_pred_label] = label_count[cur_index]
189 |
190 | return confusion_matrix
191 |
192 |
193 | def get_confusion_matrix_plot(conf_arr):
194 | norm_conf = []
195 | for i in conf_arr:
196 | tmp_arr = []
197 | a = sum(i, 0)
198 | for j in i:
199 | tmp_arr.append(float(j) / max(1.0, float(a)))
200 | norm_conf.append(tmp_arr)
201 |
202 | fig = plt.figure()
203 | plt.clf()
204 | ax = fig.add_subplot(111)
205 | ax.set_aspect(1)
206 | res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, interpolation='nearest')
207 |
208 | width, height = conf_arr.shape
209 |
210 | cb = fig.colorbar(res)
211 | alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
212 | plt.xticks(range(width), alphabet[:width])
213 | plt.yticks(range(height), alphabet[:height])
214 | plt.savefig('confusion_matrix.png', format='png')
215 |
216 |
217 | def get_lip_palette():
218 | palette = [0, 0, 0,
219 | 128, 0, 0,
220 | 255, 0, 0,
221 | 0, 85, 0,
222 | 170, 0, 51,
223 | 255, 85, 0,
224 | 0, 0, 85,
225 | 0, 119, 221,
226 | 85, 85, 0,
227 | 0, 85, 85,
228 | 85, 51, 0,
229 | 52, 86, 128,
230 | 0, 128, 0,
231 | 0, 0, 255,
232 | 51, 170, 221,
233 | 0, 255, 255,
234 | 85, 255, 170,
235 | 170, 255, 85,
236 | 255, 255, 0,
237 | 255, 170, 0]
238 | return palette
239 |
240 |
241 | if __name__ == '__main__':
242 | main()
243 |
--------------------------------------------------------------------------------
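
predict() above runs the network once per evaluation scale, bilinearly resizes every score map back to the original image size, and averages the results. Below is a condensed sketch of the non-mirrored path using torch.nn.functional.interpolate; predict_multiscale is an illustrative name, and net is assumed to return a tuple whose first element is the score map, as in these scripts:

    import numpy as np
    import torch
    import torch.nn.functional as F

    def predict_multiscale(net, image, output_size, scales=(0.75, 1.0, 1.25)):
        # image: (1, C, H, W) float32 array, as produced by the test loaders
        outputs = []
        for scale in scales:
            x = torch.from_numpy(image).cuda()
            x = F.interpolate(x, scale_factor=scale, mode='bilinear',
                              align_corners=True)
            with torch.no_grad():
                score = net(x)[0]
                score = F.interpolate(score, size=output_size,
                                      mode='bilinear', align_corners=True)
            outputs.append(score.cpu().numpy()[0])
        return np.mean(outputs, axis=0).transpose(1, 2, 0)  # (H, W, C)
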
/val/evaluate_ppss.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import cv2
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.autograd import Variable
10 | from torch.utils import data
11 |
12 | from dataset.data_ppss import TestGenerator  # the dataset package provides data_ppss.py; there is no datappss module
13 | from network.baseline import get_model
14 |
15 |
16 | def get_arguments():
17 | """Parse all the arguments provided from the CLI.
18 |
19 | Returns:
20 | A list of parsed arguments.
21 | """
22 | parser = argparse.ArgumentParser(description="PyTorch Segmentation")
23 | parser.add_argument('--root', default='./data/PPSS/TestData/', type=str)
24 | parser.add_argument("--data-list", type=str, default='./dataset/PPSS/test_id.txt')
25 | parser.add_argument("--crop-size", type=tuple, default=(321, 321))
26 | parser.add_argument("--num-classes", type=int, default=8)
27 | parser.add_argument("--ignore-label", type=int, default=255)
28 | parser.add_argument('--restore-from', default='./checkpoints/exp/model_best.pth', type=str)
29 |
30 | parser.add_argument("--is-mirror", action="store_true")
31 | parser.add_argument('--eval-scale', nargs='+', type=float, default=[1.0])
32 | # parser.add_argument('--eval-scale', nargs='+', type=float, default=[0.50, 0.75, 1.0, 1.25, 1.50, 1.75])
33 | parser.add_argument("--save-dir", type=str)
34 | parser.add_argument("--gpu", type=str, default='0')
35 | return parser.parse_args()
36 |
37 |
38 | def main():
39 | """Create the model and start the evaluation process."""
40 | args = get_arguments()
41 |
42 | # initialization
43 | print("Input arguments:")
44 | for key, val in vars(args).items():
45 | print("{:16} {}".format(key, val))
46 |
47 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
48 |
49 | model = get_model(num_classes=args.num_classes)
50 |
51 | # if not os.path.exists(args.save_dir):
52 | # os.makedirs(args.save_dir)
53 |
54 | palette = get_lip_palette()
55 | restore_from = args.restore_from
56 | saved_state_dict = torch.load(restore_from)
57 | model.load_state_dict(saved_state_dict)
58 |
59 | model.eval()
60 | model.cuda()
61 |
62 | testloader = data.DataLoader(TestGenerator(args.root, args.data_list, crop_size=args.crop_size),
63 | batch_size=1, shuffle=False, pin_memory=True)
64 |
65 | confusion_matrix = np.zeros((args.num_classes, args.num_classes))
66 |
67 | for index, batch in enumerate(testloader):
68 | if index % 100 == 0:
69 | print('%d images have been processed' % index)
70 | image, label, ori_size, name = batch
71 |
72 | # img_name = "/home/hlzhu/hlzhu/Iter_ParseNet_final/data/Person/JPEGImages/"+name[0]+'.jpg'
73 | # print(img_name)
74 | # ori_img = cv2.imread(img_name)
75 | # cv2.imshow('image',ori_img)
76 | # cv2.waitKey(1)
77 | # 2008_000195 multi person
78 | # 2008_002829 single person
79 | if name[0] == "2008_002829":  # leftover debug check from the Pascal script; PPSS image ids never match this
80 | print("2008_002829.jpg")
81 | ori_size = ori_size[0].numpy()
82 |
83 | output = predict(model, image.numpy(), (np.asscalar(ori_size[0]), np.asscalar(ori_size[1])),
84 | is_mirror=args.is_mirror, scales=args.eval_scale)
85 | seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
86 |
87 | # output_im = PILImage.fromarray(seg_pred)
88 | # output_im.putpalette(palette)
89 | # output_im.save(args.save_dir + name[0] + '.png')
90 |
91 | seg_gt = np.asarray(label[0].numpy(), dtype=np.int)
92 | ignore_index = seg_gt != 255
93 | seg_gt = seg_gt[ignore_index]
94 | seg_pred = seg_pred[ignore_index]
95 |
96 | confusion_matrix += get_confusion_matrix(seg_gt, seg_pred, args.num_classes)
97 |
98 | pos = confusion_matrix.sum(1)
99 | res = confusion_matrix.sum(0)
100 | tp = np.diag(confusion_matrix)
101 |
102 | pixel_accuracy = tp.sum() / pos.sum()
103 | mean_accuracy = (tp / np.maximum(1.0, pos)).mean()
104 | IU_array = (tp / np.maximum(1.0, pos + res - tp))
105 | mean_IU = IU_array.mean()
106 |
107 | # get_confusion_matrix_plot()
108 |
109 | print('Pixel accuracy: %f \n' % pixel_accuracy)
110 | print('Mean accuracy: %f \n' % mean_accuracy)
111 | print('Mean IU: %f \n' % mean_IU)
112 | for index, IU in enumerate(IU_array):
113 | print('%f' % IU)
114 |
115 |
116 | def scale_image(image, scale):
117 | image = image[0, :, :, :]
118 | image = image.transpose((1, 2, 0))
119 | image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
120 | image = image.transpose((2, 0, 1))
121 | return image
122 |
123 |
124 | def predict(net, image, output_size, is_mirror=True, scales=[1]):
125 | if is_mirror:
126 | image_rev = image[:, :, :, ::-1]
127 |
128 | interp = nn.Upsample(size=output_size, mode='bilinear', align_corners=True)
129 |
130 | outputs = []
131 | if is_mirror:
132 | for scale in scales:
133 | if scale != 1:
134 | image_scale = scale_image(image=image, scale=scale)
135 | image_rev_scale = scale_image(image=image_rev, scale=scale)
136 | else:
137 | image_scale = image[0, :, :, :]
138 | image_rev_scale = image_rev[0, :, :, :]
139 |
140 | image_scale = np.stack((image_scale, image_rev_scale))
141 |
142 | with torch.no_grad():
143 | prediction = net(Variable(torch.from_numpy(image_scale)).cuda())
144 | prediction = interp(prediction[0]).cpu().data.numpy()
145 |
146 | prediction_rev = prediction[1, :, :, :].copy()  # scores for the mirrored input
147 | prediction_rev = prediction_rev[:, :, ::-1]  # undo the flip; no left/right channel swap is applied in this script
148 | prediction = prediction[0, :, :, :]
149 | prediction = np.mean([prediction, prediction_rev], axis=0)
150 |
151 | outputs.append(prediction)
152 |
153 | outputs = np.mean(outputs, axis=0)
154 | outputs = outputs.transpose(1, 2, 0)
155 | else:
156 | for scale in scales:
157 | if scale != 1:
158 | image_scale = scale_image(image=image, scale=scale)
159 | else:
160 | image_scale = image[0, :, :, :]
161 |
162 | with torch.no_grad():
163 | prediction = net(Variable(torch.from_numpy(image_scale).unsqueeze(0)).cuda())
164 | prediction = interp(prediction[0]).cpu().data.numpy()
165 | outputs.append(prediction[0, :, :, :])
166 |
167 | outputs = np.mean(outputs, axis=0)
168 | outputs = outputs.transpose(1, 2, 0)
169 |
170 | return outputs
171 |
172 |
173 | def get_confusion_matrix(gt_label, pred_label, class_num):
174 | """
175 | Calculate the confusion matrix from the given ground-truth and predicted labels
176 | :param gt_label: the ground truth label
177 | :param pred_label: the predicted label
178 | :param class_num: the number of classes
179 | """
180 | index = (gt_label * class_num + pred_label).astype('int32')
181 | label_count = np.bincount(index)
182 | confusion_matrix = np.zeros((class_num, class_num))
183 |
184 | for i_label in range(class_num):
185 | for i_pred_label in range(class_num):
186 | cur_index = i_label * class_num + i_pred_label
187 | if cur_index < len(label_count):
188 | confusion_matrix[i_label, i_pred_label] = label_count[cur_index]
189 |
190 | return confusion_matrix
191 |
192 |
193 | def get_confusion_matrix_plot(conf_arr):
194 | norm_conf = []
195 | for i in conf_arr:
196 | tmp_arr = []
197 | a = sum(i, 0)
198 | for j in i:
199 | tmp_arr.append(float(j) / max(1.0, float(a)))
200 | norm_conf.append(tmp_arr)
201 |
202 | fig = plt.figure()
203 | plt.clf()
204 | ax = fig.add_subplot(111)
205 | ax.set_aspect(1)
206 | res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, interpolation='nearest')
207 |
208 | width, height = conf_arr.shape
209 |
210 | cb = fig.colorbar(res)
211 | alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
212 | plt.xticks(range(width), alphabet[:width])
213 | plt.yticks(range(height), alphabet[:height])
214 | plt.savefig('confusion_matrix.png', format='png')
215 |
216 |
217 | def get_lip_palette():
218 | palette = [0, 0, 0,
219 | 128, 0, 0,
220 | 255, 0, 0,
221 | 0, 85, 0,
222 | 170, 0, 51,
223 | 255, 85, 0,
224 | 0, 0, 85,
225 | 0, 119, 221,
226 | 85, 85, 0,
227 | 0, 85, 85,
228 | 85, 51, 0,
229 | 52, 86, 128,
230 | 0, 128, 0,
231 | 0, 0, 255,
232 | 51, 170, 221,
233 | 0, 255, 255,
234 | 85, 255, 170,
235 | 170, 255, 85,
236 | 255, 255, 0,
237 | 255, 170, 0]
238 | return palette
239 |
240 |
241 | if __name__ == '__main__':
242 | main()
243 |
--------------------------------------------------------------------------------
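
On the --crop-size flag in evaluate_ppss.py above: argparse applies type to the raw string, so type=tuple would split a command-line value such as "321,321" into single characters, and only the (321, 321) default behaves as intended. A sketch of the usual two-integer pattern, assuming callers pass the two values separately:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--crop-size", type=int, nargs=2,
                        default=[321, 321], metavar=("H", "W"))
    args = parser.parse_args(["--crop-size", "321", "321"])
    assert tuple(args.crop_size) == (321, 321)
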