├── .idea
│   ├── DeepLabv3_MobileNetv2.iml
│   ├── codeStyles
│   │   └── codeStyleConfig.xml
│   ├── dictionaries
│   │   └── root.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── ImageNet_pretrain.pth
├── MobileNetv2_DeepLabv3_cityscapes
│   └── checkpoints
│       └── Checkpoint_epoch_150.pth.tar
├── README.md
├── cityscapes.py
├── config.py
├── img
│   └── Screenshot from 2018-07-13 10-45-35.png
├── layers.py
├── main.py
├── network.py
├── progressbar.py
├── transfer_weights.py
└── utils.py
/.idea/DeepLabv3_MobileNetv2.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/codeStyles/codeStyleConfig.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/dictionaries/root.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | bilinear
5 | cityscapes
6 | concate
7 | conv
8 | convolutional
9 | cuda
10 | dataloader
11 | dataset
12 | datasets
13 | depthwise
14 | softmax
15 | upsample
16 |
17 |
18 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/ImageNet_pretrain.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zym1119/DeepLabv3_MobileNetv2_PyTorch/5f3ee3060fa005657e8ff3bc301197107a06444c/ImageNet_pretrain.pth
--------------------------------------------------------------------------------
/MobileNetv2_DeepLabv3_cityscapes/checkpoints/Checkpoint_epoch_150.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zym1119/DeepLabv3_MobileNetv2_PyTorch/5f3ee3060fa005657e8ff3bc301197107a06444c/MobileNetv2_DeepLabv3_cityscapes/checkpoints/Checkpoint_epoch_150.pth.tar
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepLabv3_MobileNetv2
2 | This is a PyTorch implementation of the MobileNet v2 network with a DeepLab v3 head, used for semantic segmentation.
3 |
4 | The MobileNetv2 backbone comes from the paper:
5 | >[Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation](https://arxiv.org/abs/1801.04381v3)
6 |
7 | And the DeepLabv3 segmentation head comes from the paper:
8 | >[Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)
9 |
10 | Please refer to these papers for details on Atrous Convolution, Inverted Residuals, Depthwise Convolution and ASPP if any of these blocks are unfamiliar.
11 |
12 | # Results
13 | After training for 150 epochs, without any further tuning, the first training result on the test set looks like this:
14 | 
15 |
16 | Feel free to change any config or code in this repo :-)
17 |
18 | # How to use?
19 | First, install the dependencies of this implementation.
20 | It is written for Python 3.5 with the following libraries:
21 | >torch 0.4.0
22 | torchvision 0.2.1
23 | numpy 1.14.5
24 | opencv-python 3.4.1.15
25 | tensorflow 1.8.0 (necessary for tensorboardX)
26 | tensorboardX 1.2
27 |
28 | Use `pip install <library>` to install them.
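
For example, to install the pinned versions in one shot (assuming `pip` points at your Python 3.5 environment; add `sudo` or use a virtualenv as needed):

```
pip install torch==0.4.0 torchvision==0.2.1 numpy==1.14.5 opencv-python==3.4.1.15 tensorflow==1.8.0 tensorboardX==1.2
```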
29 |
30 | Then, prepare the Cityscapes dataset or your own dataset.
31 | Currently, Cityscapes is the only dataset supported without modification.
32 |
33 | The Cityscapes dataset should have the following hierarchy:
34 | ```
35 | dataset_root
36 | |   trainImages.txt
37 | |   trainLabels.txt
38 | |   valImages.txt
39 | |   valLabels.txt
40 | |
41 | └───gtFine (Label Folder)
42 | |   └───train (train set)
43 | |   |   └───aachen (city)
44 | |   |   └───bochum
45 | |   |   └───...
46 | |   |
47 | |   └───test (test set)
48 | |   └───val (val set)
49 | |
50 | └───leftImg8bit (Image Folder)
51 |     └───train
52 |     └───test
53 |     └───val
54 | ```
55 | Don't worry if the txt files are missing; the program can generate any missing txt files automatically (see `generate_txt` in `utils.py`).
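
Each generated txt file lists one path per line, relative to the dataset root; a matching image/label pair looks roughly like this (standard Cityscapes file naming assumed):

```
# trainImages.txt
leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png
# trainLabels.txt
gtFine/train/aachen/aachen_000000_000019_gtFine_labelTrainIds.png
```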
56 |
57 | Third, modify `config.py` to fit your own training policy and configuration.
58 |
59 | Finally, run `python main.py --root /your/path/to/dataset/`, or simply `python main.py`.
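
The remaining command-line options defined in `main.py` (`--dataset`, `--epoch`, `--lr`, `--pretrain`, `--resume_from`) can be combined freely, for example (paths are placeholders):

```
python main.py --root /your/path/to/dataset/ --epoch 150 --lr 0.0002 \
               --pretrain ImageNet_pretrain.pth \
               --resume_from MobileNetv2_DeepLabv3_cityscapes/checkpoints/Checkpoint_epoch_150.pth.tar
```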
60 |
61 | After training, TensorBoard can also be used to inspect the training procedure via `tensorboard --logdir=./exp_dir/summaries`.
62 |
63 | # Tips
64 | I have made a few changes to the original MobileNetv2 and DeepLabv3 networks; here they are:
65 | ```
66 | 1. The multi-grid blocks have the same structure as the 7th layer of MobileNetv2, while
67 |    the remaining layers of MobileNetv2 are discarded.
68 | 2. The lr decay is determined by epoch rather than by iteration as in DeepLab, and the input
69 |    image is randomly cropped to 512 instead of 513 as in DeepLab.
70 | 3. During training, an input image is first resized so that its shorter side is 600 pixels,
71 |    then cropped to a 512x512 square and fed into the network.
72 | ```
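
The epoch-based lr decay in point 2 is the usual "poly" policy, applied once per epoch by `adjust_lr()` in `network.py`; a minimal sketch with the defaults from `config.py`:

```python
# poly learning-rate schedule, evaluated once per epoch (see adjust_lr in network.py)
base_lr, power, num_epoch = 0.0002, 0.9, 150  # defaults from config.py

def poly_lr(epoch):
    return base_lr * (1 - float(epoch) / num_epoch) ** power

print(poly_lr(0), poly_lr(75), poly_lr(150))  # starts at 2e-4 and decays to 0 at the final epoch
```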
73 |
74 | If you have any questions, please open an issue.
75 |
76 | ImageNet pre-trained weights are loaded from [Randl's github](https://github.com/Randl/MobileNetV2-pytorch), which was really helpful!
77 |
78 | # TO-DO
79 | 1. add cityscapes visualization tools (done)
80 | 2. fine-tune training policy
81 | 3. use ImageNet pre-trained model (done)
82 |
83 | # Logs
84 | | Date | Changes |
85 | |------|----------------------------|
86 | | 7.11 | fix bugs in network.Test(), add cityscapes output visualization function |
87 | | 7.12 | fix bugs in network.plot_curve(), add checkpoint split to avoid out of memory, add save loss in network.save_checkpoint() |
88 | | 7.13 | fix bugs in figure save, add checkpoint@150epoch |
89 | | 7.27 | upload ImageNet pre-trained weight |
90 |
--------------------------------------------------------------------------------
/cityscapes.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import matplotlib.pyplot as plt
3 | import torch
4 | from collections import namedtuple
5 | import numpy as np
6 |
7 | """###############"""
8 | """# Definitions #"""
9 | """###############"""
10 | # the following definitions are copied from the github repository:
11 | # mcordts/cityscapesScripts/cityscapesscripts/helpers/labels.py
12 | # a label and all meta information
13 | Label = namedtuple( 'Label' , [
14 |
15 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... .
16 | # We use them to uniquely name a class
17 |
18 | 'id' , # An integer ID that is associated with this label.
19 | # The IDs are used to represent the label in ground truth images
20 | # An ID of -1 means that this label does not have an ID and thus
21 | # is ignored when creating ground truth images (e.g. license plate).
22 | # Do not modify these IDs, since exactly these IDs are expected by the
23 | # evaluation server.
24 |
25 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
26 | # ground truth images with train IDs, using the tools provided in the
27 | # 'preparation' folder. However, make sure to validate or submit results
28 | # to our evaluation server using the regular IDs above!
29 | # For trainIds, multiple labels might have the same ID. Then, these labels
30 | # are mapped to the same class in the ground truth images. For the inverse
31 | # mapping, we use the label that is defined first in the list below.
32 | # For example, mapping all void-type classes to the same ID in training,
33 | # might make sense for some approaches.
34 | # Max value is 255!
35 |
36 | 'category' , # The name of the category that this label belongs to
37 |
38 | 'categoryId' , # The ID of this category. Used to create ground truth images
39 | # on category level.
40 |
41 | 'hasInstances', # Whether this label distinguishes between single instances or not
42 |
43 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
44 | # during evaluations or not
45 |
46 | 'color' , # The color of this label
47 | ] )
48 |
49 |
50 | #--------------------------------------------------------------------------------
51 | # A list of all labels
52 | #--------------------------------------------------------------------------------
53 |
54 | # Please adapt the train IDs as appropriate for your approach.
55 | # Note that you might want to ignore labels with ID 255 during training.
56 | # Further note that the current train IDs are only a suggestion. You can use whatever you like.
57 | # Make sure to provide your results using the original IDs and not the training IDs.
58 | # Note that many IDs are ignored in evaluation and thus you never need to predict these!
59 |
60 | labels = [
61 | # name id trainId category catId hasInstances ignoreInEval color
62 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
63 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
64 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
65 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
66 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
67 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
68 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
69 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
70 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
71 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
72 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
73 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
74 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
75 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
76 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
77 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
78 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
79 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ),
80 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ),
81 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
82 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
83 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
84 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
85 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
86 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
87 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
88 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
89 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
90 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
91 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
92 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
93 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
94 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
95 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
96 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ),
97 | ]
98 |
99 | """###################"""
100 | """# Transformations #"""
101 | """###################"""
102 |
103 | def logits2trainId(logits):
104 | """
105 | Transform output of network into trainId map
106 | :param logits: output tensor of network, before softmax, should be in shape (#classes, h, w)
107 | """
108 | # squeeze logits
109 | # num_classes = logits.size[1]
110 | upsample = torch.nn.Upsample(size=(1024, 2048), mode='bilinear', align_corners=False)
111 | logits = upsample(logits.unsqueeze_(0))
112 | logits.squeeze_(0)
113 | logits = torch.argmax(logits, dim=0)
114 |
115 | return logits
116 |
117 |
118 | def trainId2color(dataset_root, id_map, name):
119 | """
120 | Transform trainId map into color map
121 | :param dataset_root: the path to dataset root, eg. '/media/ubuntu/disk/cityscapes'
122 | :param id_map: torch tensor
123 | :param name: name of image, eg. 'gtFine/test/leverkusen/leverkusen_000027_000019_gtFine_labelTrainIds.png'
124 | """
125 | # transform = {label.trainId: label.color for label in labels}
126 | assert len(id_map.shape) == 2, 'Id_map must be a 2-D tensor of shape (h, w) where h, w = H, W / output_stride'
127 | h, w = id_map.shape
128 | color_map = np.zeros((h, w, 3))
129 | id_map = id_map.cpu().numpy()
130 | for label in labels:
131 | if not label.ignoreInEval:
132 | color_map[id_map == label.trainId] = np.array(label.color)
133 | color_map = color_map.astype(np.uint8)
134 | # color_map = cv2.resize(color_map, dsize=(2048, 1024), interpolation=cv2.INTER_NEAREST)
135 |
136 | # save trainIds and color
137 | cv2.imwrite(dataset_root + '/' + name, id_map)
138 | name = name.replace('labelTrainIds', 'color')
139 | cv2.imwrite(dataset_root + '/' + name, color_map)
140 |
141 | return color_map
142 |
143 |
144 | def trainId2LabelId(dataset_root, train_id, name):
145 | """
146 | Transform trainId map into labelId map
147 | :param dataset_root: the path to dataset root, eg. '/media/ubuntu/disk/cityscapes'
148 |     :param train_id: torch tensor
149 | :param name: name of image, eg. 'gtFine/test/leverkusen/leverkusen_000027_000019_gtFine_labelTrainIds.png'
150 | """
151 | assert len(train_id.shape) == 2, 'Id_map must be a 2-D tensor of shape (h, w) where h, w = H, W / output_stride'
152 | h, w = train_id.shape
153 | label_id = np.zeros((h, w, 3))
154 | train_id = train_id.cpu().numpy()
155 | for label in labels:
156 | if not label.ignoreInEval:
157 | label_id[train_id == label.trainId] = np.array([label.id]*3)
158 | label_id = label_id.astype(np.uint8)
159 | # label_id = cv2.resize(label_id, dsize=(2048, 1024), interpolation=cv2.INTER_NEAREST)
160 |
161 | name = name.replace('labelTrainIds', 'labelIds')
162 | cv2.imwrite(dataset_root + '/' + name, label_id)
163 |
164 | if __name__ == '__main__':
165 |     # quick visual check: read one ground-truth trainId map and convert it to a color map
166 |     dataset_root = '/media/ubuntu/disk/cityscapes'
167 |     name = 'gtFine/train/aachen/aachen_000000_000019_gtFine_labelTrainIds.png'
168 |     trainId = cv2.imread(dataset_root + '/' + name)
169 |     trainId2color(dataset_root, torch.from_numpy(trainId[:, :, 0]), name)
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from utils import create_train_dir
2 |
3 |
4 | """ Dataset parameters """
5 | class Params():
6 | def __init__(self):
7 | # network structure parameters
8 | self.model = 'MobileNetv2_DeepLabv3'
9 | self.dataset = 'cityscapes'
10 | self.s = [2, 1, 2, 2, 2, 1, 1] # stride of each conv stage
11 | self.t = [1, 1, 6, 6, 6, 6, 6] # expansion factor t
12 |         self.n = [1, 1, 2, 3, 4, 3, 3] # number of repeats of each conv stage
13 | self.c = [32, 16, 24, 32, 64, 96, 160] # output channel of each conv stage
14 | self.output_stride = 16
15 | self.multi_grid = (1, 2, 4)
16 | self.aspp = (6, 12, 18)
17 | self.down_sample_rate = 32 # classic down sample rate
18 |
19 | # dataset parameters
20 | self.rescale_size = 600
21 | self.image_size = 512
22 | self.num_class = 20 # 20 classes for training
23 | self.dataset_root = '/path/to/your/dataset'
24 | self.dataloader_workers = 12
25 | self.shuffle = True
26 | self.train_batch = 10
27 | self.val_batch = 2
28 | self.test_batch = 1
29 |
30 | # train parameters
31 | self.num_epoch = 150
32 | self.base_lr = 0.0002
33 | self.power = 0.9
34 | self.momentum = 0.9
35 | self.weight_decay = 0.0005
36 | self.should_val = True
37 | self.val_every = 2
38 | self.display = 1 # show train result every display epoch
39 | self.should_split = True # should split training procedure into several parts
40 | self.split = 2 # number of split
41 |
42 | # model restore parameters
43 | self.resume_from = None # None for train from scratch
44 | self.pre_trained_from = None # None for train from scratch
45 | self.should_save = True
46 | self.save_every = 10
47 |
48 | # create training dir
49 | self.summary_dir, self.ckpt_dir = create_train_dir(self)
50 |
51 | if __name__ == '__main__':
52 | aa = Params()
53 | print(aa)
--------------------------------------------------------------------------------
/img/Screenshot from 2018-07-13 10-45-35.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zym1119/DeepLabv3_MobileNetv2_PyTorch/5f3ee3060fa005657e8ff3bc301197107a06444c/img/Screenshot from 2018-07-13 10-45-35.png
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | class InvertedResidual(nn.Module):
6 | def __init__(self, in_channels, out_channels, t=6, s=1, dilation=1):
7 | """
8 | Initialization of inverted residual block
9 | :param in_channels: number of input channels
10 | :param out_channels: number of output channels
11 | :param t: the expansion factor of block
12 | :param s: stride of the first convolution
13 | :param dilation: dilation rate of 3*3 depthwise conv
14 | """
15 | super(InvertedResidual, self).__init__()
16 |
17 | self.in_ = in_channels
18 | self.out_ = out_channels
19 | self.t = t
20 | self.s = s
21 | self.dilation = dilation
22 | self.inverted_residual_block()
23 |
24 | def inverted_residual_block(self):
25 | """
26 | Build Inverted Residual Block and residual connection
27 | """
28 | block = []
29 | # pad = 1 if self.s == 3 else 0
30 | # conv 1*1
31 | block.append(nn.Conv2d(self.in_, self.in_*self.t, 1, bias=False))
32 | block.append(nn.BatchNorm2d(self.in_*self.t))
33 | block.append(nn.ReLU6())
34 |
35 | # conv 3*3 depthwise
36 | block.append(nn.Conv2d(self.in_*self.t, self.in_*self.t, 3,
37 | stride=self.s, padding=self.dilation, groups=self.in_*self.t, dilation=self.dilation,
38 | bias=False))
39 | block.append(nn.BatchNorm2d(self.in_*self.t))
40 | block.append(nn.ReLU6())
41 |
42 | # conv 1*1 linear
43 | block.append(nn.Conv2d(self.in_*self.t, self.out_, 1, bias=False))
44 | block.append(nn.BatchNorm2d(self.out_))
45 |
46 | self.block = nn.Sequential(*block)
47 |
48 |         # use a 1x1 conv on the residual branch when the channel counts differ (stride-1 blocks only)
49 | if self.in_ != self.out_ and self.s != 2:
50 | self.res_conv = nn.Sequential(nn.Conv2d(self.in_, self.out_, 1, bias=False),
51 | nn.BatchNorm2d(self.out_))
52 | else:
53 | self.res_conv = None
54 |
55 | def forward(self, x):
56 | if self.s == 1:
57 | # use residual connection
58 | if self.res_conv is None:
59 | out = x + self.block(x)
60 | else:
61 | out = self.res_conv(x) + self.block(x)
62 | else:
63 | # plain block
64 | out = self.block(x)
65 |
66 | return out
67 |
68 |
69 | def get_inverted_residual_block_arr(in_, out_, t=6, s=1, n=1):
70 | block = []
71 | block.append(InvertedResidual(in_, out_, t, s=s))
72 | for i in range(n-1):
73 | block.append(InvertedResidual(out_, out_, t, 1))
74 | return block
75 |
76 |
77 | class ASPP_plus(nn.Module):
78 | def __init__(self, params):
79 | super(ASPP_plus, self).__init__()
80 | self.conv11 = nn.Sequential(nn.Conv2d(params.c[-1], 256, 1, bias=False),
81 | nn.BatchNorm2d(256))
82 | self.conv33_1 = nn.Sequential(nn.Conv2d(params.c[-1], 256, 3,
83 | padding=params.aspp[0], dilation=params.aspp[0], bias=False),
84 | nn.BatchNorm2d(256))
85 | self.conv33_2 = nn.Sequential(nn.Conv2d(params.c[-1], 256, 3,
86 | padding=params.aspp[1], dilation=params.aspp[1], bias=False),
87 | nn.BatchNorm2d(256))
88 | self.conv33_3 = nn.Sequential(nn.Conv2d(params.c[-1], 256, 3,
89 | padding=params.aspp[2], dilation=params.aspp[2], bias=False),
90 | nn.BatchNorm2d(256))
91 | self.concate_conv = nn.Sequential(nn.Conv2d(256*5, 256, 1, bias=False),
92 | nn.BatchNorm2d(256))
93 | # self.upsample = nn.Upsample(mode='bilinear', align_corners=True)
94 | def forward(self, x):
95 | conv11 = self.conv11(x)
96 | conv33_1 = self.conv33_1(x)
97 | conv33_2 = self.conv33_2(x)
98 | conv33_3 = self.conv33_3(x)
99 |
100 | # image pool and upsample
101 | image_pool = nn.AvgPool2d(kernel_size=x.size()[2:])
102 | image_pool = image_pool(x)
103 | image_pool = self.conv11(image_pool)
104 | upsample = nn.Upsample(size=x.size()[2:], mode='bilinear', align_corners=True)
105 | upsample = upsample(image_pool)
106 |
107 | # concate
108 | concate = torch.cat([conv11, conv33_1, conv33_2, conv33_3, upsample], dim=1)
109 |
110 | return self.concate_conv(concate)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from utils import create_dataset
4 | from network import MobileNetv2_DeepLabv3
5 | from config import Params
6 | from utils import print_config
7 |
8 |
9 | LOG = lambda x: print('\033[0;31;2m' + x + '\033[0m')
10 |
11 |
12 | def main():
13 |     # parse command-line arguments
14 | parser = argparse.ArgumentParser(description='MobileNet_v2_DeepLab_v3 Pytorch Implementation')
15 | parser.add_argument('--dataset', default='cityscapes', choices=['cityscapes', 'other'],
16 | help='Dataset used in training MobileNet v2+DeepLab v3')
17 | parser.add_argument('--root', default='./data/cityscapes', help='Path to your dataset')
18 |     parser.add_argument('--epoch', default=None, type=int, help='Total number of training epochs')
19 |     parser.add_argument('--lr', default=None, type=float, help='Base learning rate')
20 | parser.add_argument('--pretrain', default=None, help='Path to a pre-trained backbone model')
21 | parser.add_argument('--resume_from', default=None, help='Path to a checkpoint to resume model')
22 |
23 | args = parser.parse_args()
24 | params = Params()
25 |
26 | # parse args
27 |     if not os.path.exists(args.root):
28 |         if params.dataset_root is None:
29 |             raise ValueError('ERROR: Root %s does not exist!' % args.root)
30 |     else:
31 |         params.dataset_root = args.root
32 | if args.epoch is not None:
33 | params.num_epoch = args.epoch
34 | if args.lr is not None:
35 | params.base_lr = args.lr
36 | if args.pretrain is not None:
37 | params.pre_trained_from = args.pretrain
38 | if args.resume_from is not None:
39 | params.resume_from = args.resume_from
40 |
41 | LOG('Network parameters:')
42 | print_config(params)
43 |
44 | # create dataset and transformation
45 | LOG('Creating Dataset and Transformation......')
46 | datasets = create_dataset(params)
47 | LOG('Creation Succeed.\n')
48 |
49 | # create model
50 | LOG('Initializing MobileNet and DeepLab......')
51 | net = MobileNetv2_DeepLabv3(params, datasets)
52 | LOG('Model Built.\n')
53 |
54 | # let's start to train!
55 | net.Train()
56 | net.Test()
57 |
58 | if __name__ == '__main__':
59 | main()
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | from torch.utils.data import DataLoader
5 | from tensorboardX import SummaryWriter
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from torch.utils.checkpoint import checkpoint_sequential
9 |
10 | import layers
11 | from progressbar import bar
12 | from cityscapes import logits2trainId, trainId2color, trainId2LabelId
13 |
14 | WARNING = lambda x: print('\033[1;31;2mWARNING: ' + x + '\033[0m')
15 | LOG = lambda x: print('\033[0;31;2m' + x + '\033[0m')
16 |
17 | # create model
18 | class MobileNetv2_DeepLabv3(nn.Module):
19 | """
20 | A Convolutional Neural Network with MobileNet v2 backbone and DeepLab v3 head
21 | used for Semantic Segmentation on Cityscapes dataset
22 | """
23 |
24 | """######################"""
25 | """# Model Construction #"""
26 | """######################"""
27 |
28 | def __init__(self, params, datasets):
29 | super(MobileNetv2_DeepLabv3, self).__init__()
30 | self.params = params
31 | self.datasets = datasets
32 | self.pb = bar() # hand-made progressbar
33 | self.epoch = 0
34 | self.init_epoch = 0
35 | self.ckpt_flag = False
36 | self.train_loss = []
37 | self.val_loss = []
38 | self.summary_writer = SummaryWriter(log_dir=self.params.summary_dir)
39 |
40 | # build network
41 | block = []
42 |
43 | # conv layer 1
44 | block.append(nn.Sequential(nn.Conv2d(3, self.params.c[0], 3, stride=self.params.s[0], padding=1, bias=False),
45 | nn.BatchNorm2d(self.params.c[0]),
46 | # nn.Dropout2d(self.params.dropout_prob, inplace=True),
47 | nn.ReLU6()))
48 |
49 | # conv layer 2-7
50 | for i in range(6):
51 | block.extend(layers.get_inverted_residual_block_arr(self.params.c[i], self.params.c[i+1],
52 | t=self.params.t[i+1], s=self.params.s[i+1],
53 | n=self.params.n[i+1]))
54 |
55 | # dilated conv layer 1-4
56 |         # the first block uses dilation=rate, the following blocks use dilation=rate*multi_grid[i]
57 | rate = self.params.down_sample_rate // self.params.output_stride
58 | block.append(layers.InvertedResidual(self.params.c[6], self.params.c[6],
59 | t=self.params.t[6], s=1, dilation=rate))
60 | for i in range(3):
61 | block.append(layers.InvertedResidual(self.params.c[6], self.params.c[6],
62 | t=self.params.t[6], s=1, dilation=rate*self.params.multi_grid[i]))
63 |
64 | # ASPP layer
65 | block.append(layers.ASPP_plus(self.params))
66 |
67 | # final conv layer
68 | block.append(nn.Conv2d(256, self.params.num_class, 1))
69 |
70 | # bilinear upsample
71 | block.append(nn.Upsample(scale_factor=self.params.output_stride, mode='bilinear', align_corners=False))
72 |
73 | self.network = nn.Sequential(*block).cuda()
74 | # print(self.network)
75 |
76 | # build loss
77 | self.loss_fn = nn.CrossEntropyLoss(ignore_index=255)
78 |
79 | # optimizer
80 | self.opt = torch.optim.RMSprop(self.network.parameters(),
81 | lr=self.params.base_lr,
82 | momentum=self.params.momentum,
83 | weight_decay=self.params.weight_decay)
84 |
85 | # initialize
86 | self.initialize()
87 |
88 | # load data
89 | self.load_checkpoint()
90 | self.load_model()
91 |
92 | """######################"""
93 | """# Train and Validate #"""
94 | """######################"""
95 |
96 | def train_one_epoch(self):
97 | """
98 | Train network in one epoch
99 | """
100 | print('Training......')
101 |
102 | # set mode train
103 | self.network.train()
104 |
105 | # prepare data
106 | train_loss = 0
107 | train_loader = DataLoader(self.datasets['train'],
108 | batch_size=self.params.train_batch,
109 | shuffle=self.params.shuffle,
110 | num_workers=self.params.dataloader_workers)
111 | train_size = len(self.datasets['train'])
112 | if train_size % self.params.train_batch != 0:
113 | total_batch = train_size // self.params.train_batch + 1
114 | else:
115 | total_batch = train_size // self.params.train_batch
116 |
117 | # train through dataset
118 | for batch_idx, batch in enumerate(train_loader):
119 | self.pb.click(batch_idx, total_batch)
120 | image, label = batch['image'], batch['label']
121 | image_cuda, label_cuda = image.cuda(), label.cuda()
122 |
123 | # checkpoint split
124 | if self.params.should_split:
125 | image_cuda.requires_grad_()
126 | out = checkpoint_sequential(self.network, self.params.split, image_cuda)
127 | else:
128 | out = self.network(image_cuda)
129 | loss = self.loss_fn(out, label_cuda)
130 |
131 | # optimize
132 | self.opt.zero_grad()
133 | loss.backward()
134 | self.opt.step()
135 |
136 | # accumulate
137 | train_loss += loss.item()
138 |
139 | # record first loss
140 | if self.train_loss == []:
141 | self.train_loss.append(train_loss)
142 | self.summary_writer.add_scalar('loss/train_loss', train_loss, 0)
143 |
144 | self.pb.close()
145 | train_loss /= total_batch
146 | self.train_loss.append(train_loss)
147 |
148 | # add to summary
149 | self.summary_writer.add_scalar('loss/train_loss', train_loss, self.epoch)
150 |
151 | def val_one_epoch(self):
152 | """
153 |         Validate the network once every m training epochs,
154 |         where m is defined in params.val_every
155 | """
156 | # TODO: add IoU compute function
157 | print('Validating:')
158 |
159 | # set mode eval
160 | self.network.eval()
161 |
162 | # prepare data
163 | val_loss = 0
164 | val_loader = DataLoader(self.datasets['val'],
165 | batch_size=self.params.val_batch,
166 | shuffle=self.params.shuffle,
167 | num_workers=self.params.dataloader_workers)
168 | val_size = len(self.datasets['val'])
169 | if val_size % self.params.val_batch != 0:
170 | total_batch = val_size // self.params.val_batch + 1
171 | else:
172 | total_batch = val_size // self.params.val_batch
173 |
174 | # validate through dataset
175 | for batch_idx, batch in enumerate(val_loader):
176 | self.pb.click(batch_idx, total_batch)
177 | image, label = batch['image'], batch['label']
178 | image_cuda, label_cuda = image.cuda(), label.cuda()
179 |
180 | # checkpoint split
181 | if self.params.should_split:
182 | image_cuda.requires_grad_()
183 | out = checkpoint_sequential(self.network, self.params.split, image_cuda)
184 | else:
185 | out = self.network(image_cuda)
186 |
187 | loss = self.loss_fn(out, label_cuda)
188 |
189 | val_loss += loss.item()
190 |
191 | # record first loss
192 | if self.val_loss == []:
193 | self.val_loss.append(val_loss)
194 | self.summary_writer.add_scalar('loss/val_loss', val_loss, 0)
195 |
196 | self.pb.close()
197 | val_loss /= total_batch
198 | self.val_loss.append(val_loss)
199 |
200 | # add to summary
201 | self.summary_writer.add_scalar('loss/val_loss', val_loss, self.epoch)
202 |
203 |
204 | def Train(self):
205 | """
206 | Train network in n epochs, n is defined in params.num_epoch
207 | """
208 | self.init_epoch = self.epoch
209 | if self.epoch >= self.params.num_epoch:
210 |             WARNING('Num_epoch should be larger than the current epoch. Skip training......\n')
211 | else:
212 | for _ in range(self.epoch, self.params.num_epoch):
213 | self.epoch += 1
214 | print('-' * 20 + 'Epoch.' + str(self.epoch) + '-' * 20)
215 |
216 | # train one epoch
217 | self.train_one_epoch()
218 |
219 | # should display
220 | if self.epoch % self.params.display == 0:
221 | print('\tTrain loss: %.4f' % self.train_loss[-1])
222 |
223 | # should save
224 | if self.params.should_save:
225 | if self.epoch % self.params.save_every == 0:
226 | self.save_checkpoint()
227 |
228 |                 # validate every params.val_every epochs
229 | if self.params.should_val:
230 | if self.epoch % self.params.val_every == 0:
231 | self.val_one_epoch()
232 | print('\tVal loss: %.4f' % self.val_loss[-1])
233 |
234 | # adjust learning rate
235 | self.adjust_lr()
236 |
237 | # save the last network state
238 | if self.params.should_save:
239 | self.save_checkpoint()
240 |
241 | # train visualization
242 | self.plot_curve()
243 |
244 | def Test(self):
245 | """
246 | Test network on test set
247 | """
248 | print('Testing:')
249 | # set mode eval
250 | torch.cuda.empty_cache()
251 | self.network.eval()
252 |
253 | # prepare test data
254 | test_loader = DataLoader(self.datasets['test'],
255 | batch_size=self.params.test_batch,
256 | shuffle=False, num_workers=self.params.dataloader_workers)
257 | test_size = len(self.datasets['test'])
258 | if test_size % self.params.test_batch != 0:
259 | total_batch = test_size // self.params.test_batch + 1
260 | else:
261 | total_batch = test_size // self.params.test_batch
262 |
263 | # test for one epoch
264 | for batch_idx, batch in enumerate(test_loader):
265 | self.pb.click(batch_idx, total_batch)
266 | image, label, name = batch['image'], batch['label'], batch['label_name']
267 | image_cuda, label_cuda = image.cuda(), label.cuda()
268 | if self.params.should_split:
269 | image_cuda.requires_grad_()
270 | out = checkpoint_sequential(self.network, self.params.split, image_cuda)
271 | else:
272 | out = self.network(image_cuda)
273 |
274 | for i in range(self.params.test_batch):
275 | idx = batch_idx*self.params.test_batch+i
276 | id_map = logits2trainId(out[i, ...])
277 | color_map = trainId2color(self.params.dataset_root, id_map, name=name[i])
278 | trainId2LabelId(self.params.dataset_root, id_map, name=name[i])
279 | image_orig = image[i].numpy().transpose(1, 2, 0)
280 | image_orig = image_orig*255
281 | image_orig = image_orig.astype(np.uint8)
282 | self.summary_writer.add_image('test/img_%d/orig' % idx, image_orig, idx)
283 | self.summary_writer.add_image('test/img_%d/seg' % idx, color_map, idx)
284 |
285 | """##########################"""
286 | """# Model Save and Restore #"""
287 | """##########################"""
288 |
289 | def save_checkpoint(self):
290 | save_dict = {'epoch' : self.epoch,
291 | 'train_loss' : self.train_loss,
292 | 'val_loss' : self.val_loss,
293 | 'state_dict' : self.network.state_dict(),
294 | 'optimizer' : self.opt.state_dict()}
295 | torch.save(save_dict, self.params.ckpt_dir+'Checkpoint_epoch_%d.pth.tar' % self.epoch)
296 | print('Checkpoint saved')
297 |
298 | def load_checkpoint(self):
299 | """
300 | Load checkpoint from given path
301 | """
302 | if self.params.resume_from is not None and os.path.exists(self.params.resume_from):
303 | try:
304 | LOG('Loading Checkpoint at %s' % self.params.resume_from)
305 | ckpt = torch.load(self.params.resume_from)
306 | self.epoch = ckpt['epoch']
307 | try:
308 | self.train_loss = ckpt['train_loss']
309 | self.val_loss = ckpt['val_loss']
310 | except:
311 | self.train_loss = []
312 | self.val_loss = []
313 | self.network.load_state_dict(ckpt['state_dict'])
314 | self.opt.load_state_dict(ckpt['optimizer'])
315 | LOG('Checkpoint Loaded!')
316 | LOG('Current Epoch: %d' % self.epoch)
317 | self.ckpt_flag = True
318 | except:
319 | WARNING('Cannot load checkpoint from %s. Start loading pre-trained model......' % self.params.resume_from)
320 | else:
321 |             WARNING('Checkpoint does not exist. Start loading pre-trained model......')
322 |
323 | def load_model(self):
324 | """
325 |         Load ImageNet pre-trained model into the MobileNetv2 backbone; this only happens when
326 | no checkpoint is loaded
327 | """
328 | if self.ckpt_flag:
329 | LOG('Skip Loading Pre-trained Model......')
330 | else:
331 | if self.params.pre_trained_from is not None and os.path.exists(self.params.pre_trained_from):
332 | try:
333 | LOG('Loading Pre-trained Model at %s' % self.params.pre_trained_from)
334 | pretrain = torch.load(self.params.pre_trained_from)
335 | self.network.load_state_dict(pretrain)
336 | LOG('Pre-trained Model Loaded!')
337 | except:
338 | WARNING('Cannot load pre-trained model. Start training......')
339 | else:
340 |                 WARNING('Pre-trained model does not exist. Start training......')
341 |
342 | """#############"""
343 | """# Utilities #"""
344 | """#############"""
345 |
346 | def initialize(self):
347 | """
348 | Initializes the model parameters
349 | """
350 | for m in self.modules():
351 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
352 | nn.init.xavier_normal_(m.weight)
353 | if m.bias is not None:
354 | nn.init.constant_(m.bias, 0)
355 | elif isinstance(m, nn.BatchNorm2d):
356 | nn.init.constant_(m.weight, 1)
357 | nn.init.constant_(m.bias, 0)
358 |
359 | def adjust_lr(self):
360 | """
361 | Adjust learning rate at each epoch
362 | """
363 | learning_rate = self.params.base_lr * (1 - float(self.epoch) / self.params.num_epoch) ** self.params.power
364 | for param_group in self.opt.param_groups:
365 | param_group['lr'] = learning_rate
366 | print('Change learning rate into %f' % (learning_rate))
367 | self.summary_writer.add_scalar('learning_rate', learning_rate, self.epoch)
368 |
369 | def plot_curve(self):
370 | """
371 | Plot train/val loss curve
372 | """
373 | x1 = np.arange(self.init_epoch, self.params.num_epoch+1, dtype=np.int).tolist()
374 | x2 = np.linspace(self.init_epoch, self.epoch,
375 | num=(self.epoch-self.init_epoch)//self.params.val_every+1, dtype=np.int64)
376 | plt.plot(x1, self.train_loss, label='train_loss')
377 | plt.plot(x2, self.val_loss, label='val_loss')
378 | plt.legend(loc='best')
379 | plt.title('Train/Val loss')
380 | plt.grid()
381 | plt.xlabel('Epoch')
382 | plt.ylabel('Loss')
383 | plt.show()
384 |
385 |
386 | # """ TEST """
387 | # if __name__ == '__main__':
388 | # params = CIFAR100_params()
389 | # params.dataset_root = '/home/ubuntu/cifar100'
390 | # net = MobileNetv2(params)
391 | # net.save_checkpoint()
--------------------------------------------------------------------------------
/progressbar.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
4 | class bar(object):
5 | def __init__(self):
6 | self.start_time = None
7 | self.iter_per_sec = 0
8 | self.time = None
9 |
10 | def click(self, current_idx, max_idx, total_length=40):
11 | """
12 | Each click is a draw procedure of progressbar
13 | :param current_idx: range from 0 to max_idx-1
14 | :param max_idx: maximum iteration
15 | :param total_length: length of progressbar
16 | """
17 | if self.start_time is None:
18 | self.start_time = time.time()
19 | else:
20 | self.time = time.time()-self.start_time
21 | self.iter_per_sec = 1/self.time
22 | perc = current_idx * total_length // max_idx
23 | # print progress bar
24 | print('\r|'+'='*perc+'>'+' '*(total_length-1-perc)+'| %d/%d (%.2f iter/s)' % (current_idx+1,
25 | max_idx,
26 | self.iter_per_sec), end='')
27 | self.start_time = time.time()
28 |
29 | def close(self):
30 | self.__init__()
31 | print('')
32 |
33 | if __name__ == '__main__':
34 | pb = bar()
35 | for i in range(10):
36 | pb.click(i, 10)
37 | time.sleep(0.5)
38 | print(pb.time)
39 | pb.close()
40 |
--------------------------------------------------------------------------------
/transfer_weights.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from config import Params
3 | from network import MobileNetv2_DeepLabv3, LOG
4 | from utils import create_dataset
5 |
6 |
7 | source_weight = '/home/ubuntu/Downloads/model_best.pth.tar'
8 |
9 | weight = torch.load(source_weight, map_location='cuda:0')
10 | state_dict = weight['state_dict']
11 |
12 | params = Params()
13 | params.dataset_root = '/media/ubuntu/disk/cityscapes'
14 | datasets = create_dataset(params)
15 | LOG('Creation Succeed.\n')
16 |
17 | # create model
18 | LOG('Initializing MobileNet and DeepLab......')
19 | net = MobileNetv2_DeepLabv3(params, datasets)
20 |
21 | index = 0
22 | my_net_keys = list(net.state_dict().keys())
23 | my_net_weights = list(net.state_dict().values())
24 | for w in state_dict.values():
25 | if my_net_weights[index].shape == w.shape:
26 |         # copy into the live tensor: assigning into the dict returned by
27 |         # net.state_dict() would not modify the network itself
28 |         with torch.no_grad():
29 |             my_net_weights[index].copy_(w)
30 |         print('Store weight in %s layer' % my_net_keys[index])
31 |     index += 1
32 | torch.save(net.state_dict(), './ImageNet_pretrain.pth')
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | import torch
5 | from torchvision import transforms
6 | from torch.utils.data import Dataset
7 | from torch.utils.data import DataLoader
8 | import zipfile
9 |
10 |
11 | def calc_dataset_stats(dataset, axis=0, ep=1e-7):
12 | return (np.mean(dataset, axis=axis) / 255.0).tolist(), (np.std(dataset + ep, axis=axis) / 255.0).tolist()
13 |
14 |
15 | def create_train_dir(params):
16 | """
17 | Create folder used in training, folder hierarchy:
18 | current folder--exp_folder
19 | |
20 | --summaries
21 | --checkpoints
22 | """
23 | experiment = params.model + '_' + params.dataset
24 | exp_dir = os.path.join(os.getcwd(), experiment)
25 | summary_dir = os.path.join(exp_dir, 'summaries/')
26 | checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
27 |
28 | dir = [exp_dir, summary_dir, checkpoint_dir]
29 | for dir_ in dir:
30 | if not os.path.exists(dir_):
31 | os.mkdir(dir_)
32 |
33 | return summary_dir, checkpoint_dir
34 |
35 |
36 | def create_dataset(params):
37 | """
38 | Create datasets for training, testing and validating
39 |     :return datasets: a python dictionary holding three datasets,
40 |         datasets['train'], datasets['val'] and datasets['test']
41 | """
42 | phase = ['train', 'val', 'test']
43 | # if params.dataset_root is not None and not os.path.exists(params.dataset_root):
44 | # raise ValueError('Dataset not exists!')
45 |
46 | transform = {'train': transforms.Compose([Rescale(params.rescale_size),
47 | RandomCrop(params.image_size),
48 | RandomHorizontalFlip(),
49 | ToTensor()
50 | ]),
51 | 'val' : transforms.Compose([Rescale(params.image_size),
52 | ToTensor()
53 | ]),
54 | 'test' : transforms.Compose([Rescale(params.image_size),
55 | ToTensor()
56 | ])}
57 |
58 | # file_dir = {p: os.path.join(params.dataset_root, p) for p in phase}
59 |
60 | # datasets = {Cityscapes(file_dir[p], mode=p, transforms=transform[p]) for p in phase}
61 | datasets = {p: Cityscapes(params.dataset_root, mode=p, transforms=transform[p]) for p in phase}
62 |
63 | return datasets
64 |
65 |
66 | class Cityscapes(Dataset):
67 | def __init__(self, dataset_dir, mode='train', transforms=None):
68 | """
69 | Create Dataset subclass on cityscapes dataset
70 | :param dataset_dir: the path to dataset root, eg. '/media/ubuntu/disk/cityscapes'
71 |         :param mode: phase, 'train', 'val' or 'test'
72 | :param transforms: transformation
73 | """
74 | self.dataset = dataset_dir
75 | self.transforms = transforms
76 | require_file = ['trainImages.txt', 'trainLabels.txt',
77 | 'valImages.txt', 'valLabels.txt',
78 | 'testImages.txt', 'testLabels.txt']
79 |
80 | # check requirement
81 | if mode not in ['train', 'test', 'val']:
82 | raise ValueError('Unsupported mode %s' % mode)
83 |
84 | if not os.path.exists(self.dataset):
85 | raise ValueError('Dataset not exists at %s' % self.dataset)
86 |
87 | for file in require_file:
88 | if file not in os.listdir(self.dataset):
89 | # raise ValueError('Cannot find file %s in dataset root folder!' % file)
90 | generate_txt(self.dataset, file)
91 |
92 | # create image and label list
93 | self.image_list = []
94 | self.label_list = []
95 | if mode == 'train':
96 | for line in open(os.path.join(self.dataset, 'trainImages.txt')):
97 | self.image_list.append(line.strip())
98 | for line in open(os.path.join(self.dataset, 'trainLabels.txt')):
99 | self.label_list.append(line.strip())
100 | elif mode == 'val':
101 | for line in open(os.path.join(self.dataset, 'valImages.txt')):
102 | self.image_list.append(line.strip())
103 | for line in open(os.path.join(self.dataset, 'valLabels.txt')):
104 | self.label_list.append(line.strip())
105 | else:
106 | for line in open(os.path.join(self.dataset, 'testImages.txt')):
107 | self.image_list.append(line.strip())
108 | for line in open(os.path.join(self.dataset, 'testLabels.txt')):
109 | self.label_list.append(line.strip())
110 |
111 | def __len__(self):
112 | return len(self.image_list)
113 |
114 | def __getitem__(self, index):
115 | """
116 | Overrides default method
117 | tips: 3 channels of label image are the same
118 | """
119 | image = cv2.imread(os.path.join(self.dataset, self.image_list[index]))
120 | label = cv2.imread(os.path.join(self.dataset, self.label_list[index])) # label.size (1024, 2048, 3)
121 | image_name = self.image_list[index]
122 | label_name = self.label_list[index]
123 | if label.min() == -1:
124 | raise ValueError
125 |
126 | sample = {'image': image, 'label': label[:, :, 0],
127 | 'image_name': image_name, 'label_name': label_name}
128 |
129 | if self.transforms:
130 | sample = self.transforms(sample)
131 |
132 | return sample
133 |
134 |
135 | class Rescale(object):
136 | """
137 | Rescale the image in a sample to a given size.
138 | :param output_size (tuple or int): Desired output size. If tuple, output is
139 | matched to output_size. If int, smaller of image edges is matched
140 | to output_size keeping aspect ratio the same.
141 | """
142 | def __init__(self, output_size):
143 | assert isinstance(output_size, (int, tuple))
144 | self.output_size = output_size
145 |
146 | def __call__(self, sample):
147 | image, label = sample['image'], sample['label']
148 |
149 | h, w = image.shape[:2]
150 | if isinstance(self.output_size, int):
151 | if h > w:
152 | new_h, new_w = self.output_size * h / w, self.output_size
153 | else:
154 | new_h, new_w = self.output_size, self.output_size * w / h
155 | else:
156 | new_h, new_w = self.output_size
157 |
158 | new_h, new_w = int(new_h), int(new_w)
159 |
160 | image = cv2.resize(image, (new_w, new_h))
161 | label = cv2.resize(label, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
162 |
163 | sample['image'], sample['label'] = image, label
164 |
165 | return sample
166 |
167 |
168 | class ToTensor(object):
169 | """
170 | Convert ndarrays in sample to Tensors with normalization.
171 | """
172 | def __init__(self, output_stride=16):
173 | self.output_stride = output_stride
174 |
175 | def __call__(self, sample):
176 | image, label = sample['image'], sample['label']
177 | # swap color axis because
178 | # numpy image: H x W x C
179 | # torch image: C X H X W
180 | image = image.transpose((2, 0, 1)).astype(np.float32)
181 |
182 | # reset label shape
183 | # w, h = label.shape[0]//self.output_stride, label.shape[1]//self.output_stride
184 | # label = cv2.resize(label, (h, w), interpolation=cv2.INTER_NEAREST).astype(np.int64)
185 | # label[label == 255] = 19
186 | label = label.astype(np.int64)
187 |
188 | # normalize image
189 | image /= 255
190 |
191 | sample['image'], sample['label'] = torch.from_numpy(image), torch.from_numpy(label)
192 |
193 | return sample
194 |
195 |
196 | class RandomHorizontalFlip(object):
197 | """
198 | Random flip image and label horizontally
199 | """
200 | def __call__(self, sample, p=0.5):
201 | image, label = sample['image'], sample['label']
202 | if np.random.uniform(0, 1) < p:
203 | image = cv2.flip(image, 1)
204 | label = cv2.flip(label, 1)
205 |
206 | sample['image'], sample['label'] = image, label
207 |
208 | return sample
209 |
210 |
211 | class RandomCrop(object):
212 | """
213 | Crop randomly the image in a sample.
214 |
215 | :param output_size (tuple or int): Desired output size. If int, square crop
216 | is made.
217 | """
218 |
219 | def __init__(self, output_size):
220 | assert isinstance(output_size, (int, tuple))
221 | if isinstance(output_size, int):
222 | self.output_size = (output_size, output_size)
223 | else:
224 | assert len(output_size) == 2
225 | self.output_size = output_size
226 |
227 | def __call__(self, sample):
228 | image, label = sample['image'], sample['label']
229 |
230 | h, w = image.shape[:2]
231 | new_h, new_w = self.output_size
232 |
233 | top = np.random.randint(0, h - new_h)
234 | left = np.random.randint(0, w - new_w)
235 |
236 | image = image[top: top + new_h, left: left + new_w, :]
237 |
238 | label = label[top: top + new_h, left: left + new_w]
239 |
240 | sample['image'], sample['label'] = image, label
241 |
242 | return sample
243 |
244 |
245 | def print_config(params):
246 | for name, value in sorted(vars(params).items()):
247 | print('\t%-20s:%s' % (name, str(value)))
248 | print('')
249 |
250 |
251 | def generate_txt(dataset_root, file):
252 | """
253 |     Generate required txt files that do not exist yet
254 | :param dataset_root: the path to dataset root, eg. '/media/ubuntu/disk/cityscapes'
255 | :param file: txt file need to generate
256 | """
257 | with open(os.path.join(dataset_root, file), 'w') as f:
258 | # get mode and folder
259 | if 'train' in file:
260 | mode = 'train'
261 | elif 'test' in file:
262 | mode = 'test'
263 | else:
264 | mode = 'val'
265 | folder = 'leftImg8bit' if 'Image' in file else 'gtFine'
266 |
267 | path = os.path.join(os.path.join(dataset_root, folder), mode)
268 |
269 | assert os.path.exists(path), 'Cannot find %s set in folder %s' % (mode, folder)
270 |
271 | # collect images or labels
272 | if 'Images' in file:
273 | cities = os.listdir(path)
274 | for city in cities:
275 | # write them into txt
276 | for image in os.listdir(os.path.join(path, city)):
277 | print(folder + '/' + mode + '/' + city + '/' + image, file=f)
278 | else:
279 | image_txt = mode+'Images.txt'
280 | if image_txt in os.listdir(dataset_root):
281 | for line in open(os.path.join(dataset_root, image_txt)):
282 | line = line.strip()
283 | line = line.replace('leftImg8bit/', 'gtFine/')
284 | line = line.replace('_leftImg8bit', '_gtFine_labelTrainIds')
285 | print(line, file=f)
286 | else:
287 | generate_txt(dataset_root, image_txt)
288 |
289 |
290 | def generate_zip(dataset_root):
291 | azip = zipfile.ZipFile('submit.zip', 'w')
292 | txt = os.path.join(dataset_root, 'testLabels.txt')
293 | if os.path.exists(txt):
294 | for line in open(txt):
295 | line = line.strip()
296 | line = line.replace('labelTrainIds', 'labelIds')
297 | azip.write(os.path.join(dataset_root, line), arcname=line)
298 | azip.close()
299 | else:
300 | generate_txt(dataset_root, 'testLabels.txt')
301 |
302 |
303 | if __name__ == '__main__':
304 | dir = '/media/ubuntu/disk/cityscapes'
305 | # dataset = Cityscapes(dir)
306 | # loader = DataLoader(dataset,
307 | # batch_size=10,
308 | # shuffle=True,
309 | # num_workers=8)
310 | # for idx, batch in enumerate(loader):
311 | # img = batch['image']
312 | # lb = batch['label']
313 | # print(idx, img.shape)
314 | generate_zip(dir)
315 | # tips: the last batch may not be as big as batch_size
316 |
317 |
--------------------------------------------------------------------------------