├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
│   └── experiment_name
│       ├── latest_net_D.pth
│       ├── latest_net_G.pth
│       ├── opt.txt
│       └── web
│           └── index.html
├── data
│   ├── __init__.py
│   ├── aligned_dataset.py
│   ├── base_data_loader.py
│   ├── base_dataset.py
│   ├── custom_dataset_data_loader.py
│   ├── data_loader.py
│   ├── image_folder.py
│   ├── single_dataset.py
│   └── unaligned_dataset.py
├── datasets
│   ├── combine_A_and_B.py
│   └── helper functions
│       └── grayscale.py
├── images
│   ├── animation1.gif
│   ├── animation2.gif
│   ├── animation3.gif
│   ├── animation4.gif
│   ├── results.png
│   ├── test1_blur.jpg
│   ├── test1_restored.jpg
│   ├── test1_sharp.jpg
│   ├── yolo_b.jpg
│   ├── yolo_o.jpg
│   └── yolo_s.jpg
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── conditional_gan_model.py
│   ├── losses.py
│   ├── models.py
│   ├── networks.py
│   └── test_model.py
├── motion_blur
│   ├── __init__.py
│   ├── blur_image.py
│   ├── generate_PSF.py
│   └── generate_trajectory.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── test.py
├── train.py
└── util
    ├── __init__.py
    ├── get_data.py
    ├── html.py
    ├── image_pool.py
    ├── metrics.py
    ├── png.py
    ├── util.py
    └── visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.iml
2 | *.xml
3 | *.pyc
4 | *.png
5 | *.txt
6 | *.pth
7 | *.html
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
26 | --------------------------- LICENSE FOR pix2pix --------------------------------
27 | BSD License
28 |
29 | For pix2pix software
30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
31 | All rights reserved.
32 |
33 | Redistribution and use in source and binary forms, with or without
34 | modification, are permitted provided that the following conditions are met:
35 |
36 | * Redistributions of source code must retain the above copyright notice, this
37 | list of conditions and the following disclaimer.
38 |
39 | * Redistributions in binary form must reproduce the above copyright notice,
40 | this list of conditions and the following disclaimer in the documentation
41 | and/or other materials provided with the distribution.
42 |
43 | ----------------------------- LICENSE FOR DCGAN --------------------------------
44 | BSD License
45 |
46 | For dcgan.torch software
47 |
48 | Copyright (c) 2015, Facebook, Inc. All rights reserved.
49 |
50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
51 |
52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
53 |
54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
55 |
56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
57 |
58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
59 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeblurGAN
2 | [arXiv Paper Version](https://arxiv.org/pdf/1711.07064.pdf)
3 |
4 | A PyTorch implementation of the paper DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks.
5 |
6 | Our network takes a blurry image as input and produces the corresponding sharp estimate, as in the example:
7 |
8 |
9 |
10 | The model is a conditional Wasserstein GAN with gradient penalty, combined with a perceptual loss based on VGG-19 activations. This architecture also gives good results on other image-to-image translation problems (super-resolution, colorization, inpainting, dehazing, etc.).
11 |
12 | ## How to run
13 |
14 | ### Prerequisites
15 | - NVIDIA GPU + CUDA + cuDNN (CPU untested, feedback appreciated)
16 | - PyTorch
17 |
18 | Download the weights from [Google Drive](https://drive.google.com/file/d/1liKzdjMRHZ-i5MWhC72EL7UZLNPj5_8Y/view?usp=sharing). Note that for inference you only need the generator weights.
19 |
20 | Put the weights into
21 | ```bash
22 | ./checkpoints/experiment_name
23 | ```
24 | To test a model, put your blurry images into a folder and run:
25 | ```bash
26 | python test.py --dataroot /path/to/your/data --model test --dataset_mode single --learn_residual
27 | ```
28 | ## Data
29 | Download dataset for Object Detection benchmark from [Google Drive](https://drive.google.com/file/d/1CPMBmRj-jBDO2ax4CxkBs9iczIFrs8VA/view?usp=sharing)
30 |
31 | ## Train
32 |
33 | If you want to train the model on your own data, run the following command to create image pairs:
34 | ```bash
35 | python datasets/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data
36 | ```
37 | Then run the following command to train the model:
38 |
39 | ```bash
40 | python train.py --dataroot /path/to/your/data --learn_residual --resize_or_crop crop --fineSize CROP_SIZE  # we used CROP_SIZE=256
41 | ```
42 |
43 | ## Other Implementations
44 |
45 | [Keras Blog](https://blog.sicara.com/keras-generative-adversarial-networks-image-deblurring-45e3ab6977b5)
46 |
47 | [Keras Repository](https://github.com/RaphaelMeudec/deblur-gan)
48 |
49 |
50 |
51 | ## Citation
52 |
53 | If you find our code helpful in your research or work, please cite our paper.
54 |
55 | ```
56 | @article{DeblurGAN,
57 | title = {DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks},
58 | author = {Kupyn, Orest and Budzan, Volodymyr and Mykhailych, Mykola and Mishkin, Dmytro and Matas, Jiri},
59 | journal = {ArXiv e-prints},
60 | eprint = {1711.07064},
61 | year = 2017
62 | }
63 | ```
64 |
65 | ## Acknowledgments
66 | Code borrows heavily from [pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). The images were taken from the GoPro test dataset of [DeepDeblur](https://github.com/SeungjunNah/DeepDeblur_release).
67 |
68 |
69 |
--------------------------------------------------------------------------------
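For orientation, here is a minimal sketch of how the pieces described above fit together at inference time. It mirrors what `test.py` presumably does with the modules shown later in this dump (`data/data_loader.py`, `models/models.py`); `TestOptions` lives in `options/test_options.py`, whose contents are not included here, so its interface is assumed.

```python
# Hypothetical minimal inference loop; the real entry point is test.py.
from options.test_options import TestOptions   # interface assumed, file not shown
from data.data_loader import CreateDataLoader
from models.models import create_model

opt = TestOptions().parse()          # e.g. --model test --dataset_mode single
data_loader = CreateDataLoader(opt)  # wraps CustomDatasetDataLoader
dataset = data_loader.load_data()    # a torch DataLoader
model = create_model(opt)            # TestModel when opt.model == 'test'

for data in dataset:
    model.set_input(data)            # {'A': blurry image, 'A_paths': path}
    model.test()                     # forward pass without gradients
    visuals = model.get_current_visuals()
```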
/checkpoints/experiment_name/latest_net_D.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/checkpoints/experiment_name/latest_net_D.pth
--------------------------------------------------------------------------------
/checkpoints/experiment_name/latest_net_G.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/checkpoints/experiment_name/latest_net_G.pth
--------------------------------------------------------------------------------
/checkpoints/experiment_name/opt.txt:
--------------------------------------------------------------------------------
1 | ------------ Options -------------
2 | batchSize: 1
3 | beta1: 0.5
4 | checkpoints_dir: ./checkpoints
5 | continue_train: False
6 | dataroot: D:\Photos\TrainingData\BlurredSharp\combined
7 | dataset_mode: aligned
8 | display_freq: 100
9 | display_id: 1
10 | display_port: 8097
11 | display_single_pane_ncols: 0
12 | display_winsize: 256
13 | epoch_count: 1
14 | fineSize: 256
15 | gan_type: wgan-gp
16 | gpu_ids: [0]
17 | identity: 0.0
18 | input_nc: 3
19 | isTrain: True
20 | lambda_A: 100.0
21 | lambda_B: 10.0
22 | learn_residual: False
23 | loadSizeX: 640
24 | loadSizeY: 360
25 | lr: 0.0001
26 | max_dataset_size: inf
27 | model: content_gan
28 | nThreads: 2
29 | n_layers_D: 3
30 | name: experiment_name
31 | ndf: 64
32 | ngf: 64
33 | niter: 150
34 | niter_decay: 150
35 | no_dropout: False
36 | no_flip: False
37 | no_html: False
38 | norm: instance
39 | output_nc: 3
40 | phase: train
41 | pool_size: 50
42 | print_freq: 100
43 | resize_or_crop: resize_and_crop
44 | save_epoch_freq: 5
45 | save_latest_freq: 5000
46 | serial_batches: False
47 | which_direction: AtoB
48 | which_epoch: latest
49 | which_model_netD: basic
50 | which_model_netG: resnet_9blocks
51 | -------------- End ----------------
52 |
--------------------------------------------------------------------------------
/checkpoints/experiment_name/web/index.html:
--------------------------------------------------------------------------------
1 | [Markup stripped during extraction. This is the training-progress page
2 | saved under checkpoints/experiment_name/web/: a header reading
3 | "Experiment name = experiment_name", followed by one "Results of
4 | Epoch [N]" section per epoch, from epoch 33 down to epoch 1, each
5 | showing three images labeled Blurred_Train, Restored_Train and
6 | Sharp_Train.]
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/data/__init__.py
--------------------------------------------------------------------------------
/data/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 | import torchvision.transforms as transforms
4 | import torch
5 | from data.base_dataset import BaseDataset
6 | from data.image_folder import make_dataset
7 | from PIL import Image
8 |
9 |
10 | class AlignedDataset(BaseDataset):
11 | def __init__(self, opt):
12 | # super(AlignedDataset,self).__init__(opt)
13 | self.opt = opt
14 | self.root = opt.dataroot
15 | self.dir_AB = os.path.join(opt.dataroot, opt.phase)
16 |
17 | self.AB_paths = sorted(make_dataset(self.dir_AB))
18 |
19 | #assert(opt.resize_or_crop == 'resize_and_crop')
20 |
21 | transform_list = [transforms.ToTensor(),
22 | transforms.Normalize((0.5, 0.5, 0.5),
23 | (0.5, 0.5, 0.5))]
24 |
25 | self.transform = transforms.Compose(transform_list)
26 |
27 | def __getitem__(self, index):
28 | AB_path = self.AB_paths[index]
29 | AB = Image.open(AB_path).convert('RGB')
30 |         AB = AB.resize((self.opt.loadSizeX * 2, self.opt.loadSizeY), Image.BICUBIC)  # AB is a side-by-side [A | B] pair, each half loadSizeX wide
31 | AB = self.transform(AB)
32 |
33 | w_total = AB.size(2)
34 | w = int(w_total / 2)
35 | h = AB.size(1)
36 | w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
37 | h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
38 |
39 | A = AB[:, h_offset:h_offset + self.opt.fineSize,
40 | w_offset:w_offset + self.opt.fineSize]
41 | B = AB[:, h_offset:h_offset + self.opt.fineSize,
42 | w + w_offset:w + w_offset + self.opt.fineSize]
43 |
44 | if (not self.opt.no_flip) and random.random() < 0.5:
45 | idx = [i for i in range(A.size(2) - 1, -1, -1)]
46 | idx = torch.LongTensor(idx)
47 | A = A.index_select(2, idx)
48 | B = B.index_select(2, idx)
49 |
50 | return {'A': A, 'B': B,
51 | 'A_paths': AB_path, 'B_paths': AB_path}
52 |
53 | def __len__(self):
54 | return len(self.AB_paths)
55 |
56 | def name(self):
57 | return 'AlignedDataset'
58 |
--------------------------------------------------------------------------------
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | class BaseDataLoader():
3 | def __init__(self):
4 | pass
5 |
6 | def initialize(self, opt):
7 | self.opt = opt
8 | pass
9 |
10 |     def load_data(self):
11 |         return None
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 |
5 | class BaseDataset(data.Dataset):
6 | def __init__(self):
7 | super(BaseDataset, self).__init__()
8 |
9 | def name(self):
10 | return 'BaseDataset'
11 |
12 | # def initialize(self, opt):
13 | # pass
14 |
15 | def get_transform(opt):
16 | transform_list = []
17 | if opt.resize_or_crop == 'resize_and_crop':
18 | osize = [opt.loadSizeX, opt.loadSizeY]
19 | transform_list.append(transforms.Resize(osize, Image.BICUBIC))
20 | transform_list.append(transforms.RandomCrop(opt.fineSize))
21 | elif opt.resize_or_crop == 'crop':
22 | transform_list.append(transforms.RandomCrop(opt.fineSize))
23 | elif opt.resize_or_crop == 'scale_width':
24 | transform_list.append(transforms.Lambda(
25 | lambda img: __scale_width(img, opt.fineSize)))
26 | elif opt.resize_or_crop == 'scale_width_and_crop':
27 | transform_list.append(transforms.Lambda(
28 | lambda img: __scale_width(img, opt.loadSizeX)))
29 | transform_list.append(transforms.RandomCrop(opt.fineSize))
30 |
31 | if opt.isTrain and not opt.no_flip:
32 | transform_list.append(transforms.RandomHorizontalFlip())
33 |
34 | transform_list += [transforms.ToTensor(),
35 | transforms.Normalize((0.5, 0.5, 0.5),
36 | (0.5, 0.5, 0.5))]
37 | return transforms.Compose(transform_list)
38 |
39 | def __scale_width(img, target_width):
40 | ow, oh = img.size
41 | if (ow == target_width):
42 | return img
43 | w = target_width
44 | h = int(target_width * oh / ow)
45 | return img.resize((w, h), Image.BICUBIC)
46 |
--------------------------------------------------------------------------------
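As a quick illustration of `get_transform`, here is the pipeline it builds for the `--resize_or_crop crop` setting used in the README's training command (a sketch; `SimpleNamespace` just stands in for the parsed options object):

```python
from types import SimpleNamespace
from data.base_dataset import get_transform

# Only the fields that get_transform actually reads are provided.
opt = SimpleNamespace(resize_or_crop='crop', isTrain=True, no_flip=False, fineSize=256)
tf = get_transform(opt)
# -> RandomCrop(256), RandomHorizontalFlip, ToTensor, Normalize(mean=0.5, std=0.5)
```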
/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 |
4 |
5 | def CreateDataset(opt):
6 | dataset = None
7 | if opt.dataset_mode == 'aligned':
8 | from data.aligned_dataset import AlignedDataset
9 | dataset = AlignedDataset(opt)
10 |     elif opt.dataset_mode == 'unaligned':
11 |         from data.unaligned_dataset import UnalignedDataset
12 |         dataset = UnalignedDataset()
13 |         dataset.initialize(opt)
14 |     elif opt.dataset_mode == 'single':
15 |         from data.single_dataset import SingleDataset
16 |         dataset = SingleDataset()
17 |         dataset.initialize(opt)
18 |     else:
19 |         raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
20 | 
21 |     print("dataset [%s] was created" % (dataset.name()))
22 |     return dataset
23 | 
24 | 
25 | class CustomDatasetDataLoader(BaseDataLoader):
26 |     def name(self):
27 |         return 'CustomDatasetDataLoader'
28 | 
29 |     def __init__(self, opt):
30 |         super(CustomDatasetDataLoader, self).initialize(opt)
31 |         print("Opt.nThreads = ", opt.nThreads)
32 |         self.dataset = CreateDataset(opt)
33 |         self.dataloader = torch.utils.data.DataLoader(
34 |             self.dataset,
35 |             batch_size=opt.batchSize,
36 |             shuffle=not opt.serial_batches,
37 |             num_workers=int(opt.nThreads)
38 |         )
39 | 
40 |     def load_data(self):
41 |         return self.dataloader
42 | 
43 |     def __len__(self):
44 |         return min(len(self.dataset), self.opt.max_dataset_size)
45 | 
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | def CreateDataLoader(opt):
3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader
4 | data_loader = CustomDatasetDataLoader(opt)
5 | print(data_loader.name())
6 | # data_loader.initialize(opt)
7 | return data_loader
8 |
--------------------------------------------------------------------------------
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it also loads images from the current
5 | # directory as well as the subdirectories
6 | ###############################################################################
7 |
8 | import torch.utils.data as data
9 |
10 | from PIL import Image
11 | import os
12 | import os.path
13 |
14 | IMG_EXTENSIONS = [
15 | '.jpg', '.JPG', '.jpeg', '.JPEG',
16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
17 | ]
18 |
19 |
20 | def is_image_file(filename):
21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir):
25 | images = []
26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
27 |
28 | for root, _, fnames in sorted(os.walk(dir)):
29 | for fname in fnames:
30 | if is_image_file(fname):
31 | path = os.path.join(root, fname)
32 | images.append(path)
33 |
34 | return images
35 |
36 |
37 | def default_loader(path):
38 | return Image.open(path).convert('RGB')
39 |
40 |
41 | class ImageFolder(data.Dataset):
42 |
43 | def __init__(self, root, transform=None, return_paths=False,
44 | loader=default_loader):
45 | imgs = make_dataset(root)
46 | if len(imgs) == 0:
47 | raise(RuntimeError("Found 0 images in: " + root + "\n"
48 | "Supported image extensions are: " +
49 | ",".join(IMG_EXTENSIONS)))
50 |
51 | self.root = root
52 | self.imgs = imgs
53 | self.transform = transform
54 | self.return_paths = return_paths
55 | self.loader = loader
56 |
57 | def __getitem__(self, index):
58 | path = self.imgs[index]
59 | img = self.loader(path)
60 | if self.transform is not None:
61 | img = self.transform(img)
62 | if self.return_paths:
63 | return img, path
64 | else:
65 | return img
66 |
67 | def __len__(self):
68 | return len(self.imgs)
69 |
--------------------------------------------------------------------------------
/data/single_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import torchvision.transforms as transforms
3 | from data.base_dataset import BaseDataset, get_transform
4 | from data.image_folder import make_dataset
5 | from PIL import Image
6 |
7 |
8 | class SingleDataset(BaseDataset):
9 | def initialize(self, opt):
10 | self.opt = opt
11 | self.root = opt.dataroot
12 | self.dir_A = os.path.join(opt.dataroot)
13 |
14 | self.A_paths = make_dataset(self.dir_A)
15 |
16 | self.A_paths = sorted(self.A_paths)
17 |
18 | self.transform = get_transform(opt)
19 |
20 | def __getitem__(self, index):
21 | A_path = self.A_paths[index]
22 |
23 | A_img = Image.open(A_path).convert('RGB')
24 |
25 | A_img = self.transform(A_img)
26 |
27 | return {'A': A_img, 'A_paths': A_path}
28 |
29 | def __len__(self):
30 | return len(self.A_paths)
31 |
32 | def name(self):
33 | return 'SingleImageDataset'
34 |
--------------------------------------------------------------------------------
/data/unaligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import torchvision.transforms as transforms
3 | from data.base_dataset import BaseDataset, get_transform
4 | from data.image_folder import make_dataset
5 | from PIL import Image
6 | import PIL
7 | from pdb import set_trace as st
8 | import random
9 | import cv2
10 |
11 | class UnalignedDataset(BaseDataset):
12 | def initialize(self, opt):
13 | self.opt = opt
14 | self.root = opt.dataroot
15 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
16 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
17 |
18 | self.A_paths = make_dataset(self.dir_A)
19 | self.B_paths = make_dataset(self.dir_B)
20 |
21 | self.A_paths = sorted(self.A_paths)
22 | self.B_paths = sorted(self.B_paths)
23 | self.A_size = len(self.A_paths)
24 | self.B_size = len(self.B_paths)
25 | self.transform = get_transform(opt)
26 |
27 | def __getitem__(self, index):
28 |         A_path = self.A_paths[index % self.A_size]
29 |         index_B = index % self.B_size
30 |         B_path = self.B_paths[index_B]
31 |         # print('(A, B) = (%d, %d)' % (index % self.A_size, index_B))
32 |         A_img = Image.open(A_path).convert('L')  # A stays grayscale here, seemingly for the colorization workflow (see datasets/helper functions/grayscale.py)
33 | B_img = Image.open(B_path).convert('RGB')
34 |
35 | A_img = self.transform(A_img)
36 | B_img = self.transform(B_img)
37 |
38 | return {'A': A_img, 'B': B_img,
39 | 'A_paths': A_path, 'B_paths': B_path}
40 |
41 | def __len__(self):
42 | return max(self.A_size, self.B_size)
43 |
44 | def name(self):
45 | return 'UnalignedDataset'
46 |
--------------------------------------------------------------------------------
/datasets/combine_A_and_B.py:
--------------------------------------------------------------------------------
1 | from pdb import set_trace as st
2 | import os
3 | import numpy as np
4 | import cv2
5 | import argparse
6 |
7 | parser = argparse.ArgumentParser('create image pairs')
8 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
9 | parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
10 | parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
11 | parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000)
12 | parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true')
13 | args = parser.parse_args()
14 |
15 | for arg in vars(args):
16 | print('[%s] = ' % arg, getattr(args, arg))
17 |
18 | splits = os.listdir(args.fold_A)
19 |
20 | for sp in splits:
21 | img_fold_A = os.path.join(args.fold_A, sp)
22 | img_fold_B = os.path.join(args.fold_B, sp)
23 | img_list = os.listdir(img_fold_A)
24 | if args.use_AB:
25 | img_list = [img_path for img_path in img_list if '_A.' in img_path]
26 |
27 | num_imgs = min(args.num_imgs, len(img_list))
28 | print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
29 | img_fold_AB = os.path.join(args.fold_AB, sp)
30 | if not os.path.isdir(img_fold_AB):
31 | os.makedirs(img_fold_AB)
32 | print('split = %s, number of images = %d' % (sp, num_imgs))
33 | for n in range(num_imgs):
34 | name_A = img_list[n]
35 | path_A = os.path.join(img_fold_A, name_A)
36 | if args.use_AB:
37 | name_B = name_A.replace('_A.', '_B.')
38 | else:
39 | name_B = name_A
40 | path_B = os.path.join(img_fold_B, name_B)
41 | if os.path.isfile(path_A) and os.path.isfile(path_B):
42 | name_AB = name_A
43 | if args.use_AB:
44 | name_AB = name_AB.replace('_A.', '.') # remove _A
45 | path_AB = os.path.join(img_fold_AB, name_AB)
46 | im_A = cv2.imread(path_A, cv2.IMREAD_COLOR)
47 | im_B = cv2.imread(path_B, cv2.IMREAD_COLOR)
48 | im_AB = np.concatenate([im_A, im_B], 1)
49 | cv2.imwrite(path_AB, im_AB)
50 |
--------------------------------------------------------------------------------
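The script expects `fold_A` and `fold_B` to contain identically named split subdirectories (e.g. `train/`) with matching filenames. A toy setup to try it out might look like this (all paths here are illustrative):

```python
import os
import numpy as np
import cv2

# One matching 64x64 black image in each fold's train split.
for fold in ('A', 'B'):
    split_dir = os.path.join('toy', fold, 'train')
    os.makedirs(split_dir, exist_ok=True)
    cv2.imwrite(os.path.join(split_dir, '0001.png'), np.zeros((64, 64, 3), np.uint8))

# Then: python datasets/combine_A_and_B.py --fold_A toy/A --fold_B toy/B --fold_AB toy/AB
# produces toy/AB/train/0001.png with A and B concatenated side by side.
```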
/datasets/helper functions/grayscale.py:
--------------------------------------------------------------------------------
1 | from pdb import set_trace as st
2 | import os
3 | import numpy as np
4 | import cv2
5 | import argparse
6 |
7 | # Helper script to create dataset for image colorization
8 |
9 | parser = argparse.ArgumentParser('create image pairs')
10 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for images', type=str, default='../dataset/50kshoes_edges')
11 | parser.add_argument('--fold_B', dest='fold_B', help='output directory', type=str, default='../dataset/test_B')
12 | parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000)
13 | args = parser.parse_args()
14 |
15 | for arg in vars(args):
16 | print('[%s] = ' % arg, getattr(args, arg))
17 |
18 | splits = os.listdir(args.fold_A)
19 |
20 | for sp in splits:
21 | img_fold_A = os.path.join(args.fold_A, sp)
22 | img_list = os.listdir(img_fold_A)
23 |
24 | num_imgs = min(args.num_imgs, len(img_list))
25 | print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
26 | img_fold_B = os.path.join(args.fold_B, sp)
27 | if not os.path.isdir(img_fold_B):
28 | os.makedirs(img_fold_B)
29 | print('split = %s, number of images = %d' % (sp, num_imgs))
30 | for n in range(num_imgs):
31 | name_A = img_list[n]
32 | path_A = os.path.join(img_fold_A, name_A)
33 |
34 | if os.path.isfile(path_A):
35 | name_B = name_A
36 |
37 | path_B = os.path.join(img_fold_B, name_B)
38 | im_A = cv2.imread(path_A, 0)
39 | cv2.imwrite(path_B, im_A)
40 |
--------------------------------------------------------------------------------
/images/animation1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/animation1.gif
--------------------------------------------------------------------------------
/images/animation2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/animation2.gif
--------------------------------------------------------------------------------
/images/animation3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/animation3.gif
--------------------------------------------------------------------------------
/images/animation4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/animation4.gif
--------------------------------------------------------------------------------
/images/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/results.png
--------------------------------------------------------------------------------
/images/test1_blur.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/test1_blur.jpg
--------------------------------------------------------------------------------
/images/test1_restored.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/test1_restored.jpg
--------------------------------------------------------------------------------
/images/test1_sharp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/test1_sharp.jpg
--------------------------------------------------------------------------------
/images/yolo_b.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/yolo_b.jpg
--------------------------------------------------------------------------------
/images/yolo_o.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/yolo_o.jpg
--------------------------------------------------------------------------------
/images/yolo_s.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/yolo_s.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/models/__init__.py
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 |
5 | class BaseModel():
6 | def name(self):
7 | return 'BaseModel'
8 |
9 | def __init__(self, opt):
10 | self.opt = opt
11 | self.gpu_ids = opt.gpu_ids
12 | self.isTrain = opt.isTrain
13 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
14 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
15 |
16 | def set_input(self, input):
17 | self.input = input
18 |
19 | def forward(self):
20 | pass
21 |
22 | # used in test time, no backprop
23 | def test(self):
24 | pass
25 |
26 | def get_image_paths(self):
27 | pass
28 |
29 | def optimize_parameters(self):
30 | pass
31 |
32 | def get_current_visuals(self):
33 | return self.input
34 |
35 | def get_current_errors(self):
36 | return {}
37 |
38 | def save(self, label):
39 | pass
40 |
41 | # helper saving function that can be used by subclasses
42 | def save_network(self, network, network_label, epoch_label, gpu_ids):
43 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
44 | save_path = os.path.join(self.save_dir, save_filename)
45 | torch.save(network.cpu().state_dict(), save_path)
46 | if len(gpu_ids) and torch.cuda.is_available():
47 | network.cuda(device=gpu_ids[0])
48 |
49 |
50 | # helper loading function that can be used by subclasses
51 | def load_network(self, network, network_label, epoch_label):
52 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
53 | save_path = os.path.join(self.save_dir, save_filename)
54 | network.load_state_dict(torch.load(save_path))
55 |
56 |     def update_learning_rate(self):
57 | pass
58 |
--------------------------------------------------------------------------------
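Note the checkpoint naming convention in `save_network`/`load_network`: it is what produced the files under `checkpoints/experiment_name/` shown earlier.

```python
# '%s_net_%s.pth' % (epoch_label, network_label), so for example:
#   model.save('latest')  ->  latest_net_G.pth and latest_net_D.pth
#   model.save('5')       ->  5_net_G.pth and 5_net_D.pth
# load_network(netG, 'G', opt.which_epoch) reads the matching file back.
```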
/models/conditional_gan_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | import util.util as util
7 | from util.image_pool import ImagePool
8 | from .base_model import BaseModel
9 | from . import networks
10 | from .losses import init_loss
11 |
12 | try:
13 | xrange # Python2
14 | except NameError:
15 | xrange = range # Python 3
16 |
17 | class ConditionalGAN(BaseModel):
18 | def name(self):
19 | return 'ConditionalGANModel'
20 |
21 | def __init__(self, opt):
22 | super(ConditionalGAN, self).__init__(opt)
23 | self.isTrain = opt.isTrain
24 | # define tensors
25 | self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
26 | self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
27 |
28 | # load/define networks
29 |         # Temp fix: nn.parallel crashes when calculating the gradient penalty, so it is disabled for wgan-gp
30 | use_parallel = not opt.gan_type == 'wgan-gp'
31 | print("Use Parallel = ", "True" if use_parallel else "False")
32 | self.netG = networks.define_G(
33 | opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm,
34 | not opt.no_dropout, self.gpu_ids, use_parallel, opt.learn_residual
35 | )
36 | if self.isTrain:
37 | use_sigmoid = opt.gan_type == 'gan'
38 | self.netD = networks.define_D(
39 | opt.output_nc, opt.ndf, opt.which_model_netD,
40 | opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids, use_parallel
41 | )
42 | if not self.isTrain or opt.continue_train:
43 | self.load_network(self.netG, 'G', opt.which_epoch)
44 | if self.isTrain:
45 | self.load_network(self.netD, 'D', opt.which_epoch)
46 |
47 | if self.isTrain:
48 | self.fake_AB_pool = ImagePool(opt.pool_size)
49 | self.old_lr = opt.lr
50 |
51 | # initialize optimizers
52 | self.optimizer_G = torch.optim.Adam( self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999) )
53 | self.optimizer_D = torch.optim.Adam( self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999) )
54 |
55 | self.criticUpdates = 5 if opt.gan_type == 'wgan-gp' else 1
56 |
57 | # define loss functions
58 | self.discLoss, self.contentLoss = init_loss(opt, self.Tensor)
59 |
60 | print('---------- Networks initialized -------------')
61 | networks.print_network(self.netG)
62 | if self.isTrain:
63 | networks.print_network(self.netD)
64 | print('-----------------------------------------------')
65 |
66 | def set_input(self, input):
67 | AtoB = self.opt.which_direction == 'AtoB'
68 | inputA = input['A' if AtoB else 'B']
69 | inputB = input['B' if AtoB else 'A']
70 | self.input_A.resize_(inputA.size()).copy_(inputA)
71 | self.input_B.resize_(inputB.size()).copy_(inputB)
72 | self.image_paths = input['A_paths' if AtoB else 'B_paths']
73 |
74 | def forward(self):
75 | self.real_A = Variable(self.input_A)
76 | self.fake_B = self.netG.forward(self.real_A)
77 | self.real_B = Variable(self.input_B)
78 |
79 | # no backprop gradients
80 | def test(self):
81 | self.real_A = Variable(self.input_A, volatile=True)
82 | self.fake_B = self.netG.forward(self.real_A)
83 | self.real_B = Variable(self.input_B, volatile=True)
84 |
85 | # get image paths
86 | def get_image_paths(self):
87 | return self.image_paths
88 |
89 | def backward_D(self):
90 | self.loss_D = self.discLoss.get_loss(self.netD, self.real_A, self.fake_B, self.real_B)
91 |
92 | self.loss_D.backward(retain_graph=True)
93 |
94 | def backward_G(self):
95 | self.loss_G_GAN = self.discLoss.get_g_loss(self.netD, self.real_A, self.fake_B)
96 |         # Second, G(A) should match B (content/perceptual term)
97 | self.loss_G_Content = self.contentLoss.get_loss(self.fake_B, self.real_B) * self.opt.lambda_A
98 |
99 | self.loss_G = self.loss_G_GAN + self.loss_G_Content
100 |
101 | self.loss_G.backward()
102 |
103 | def optimize_parameters(self):
104 | self.forward()
105 |
106 | for iter_d in xrange(self.criticUpdates):
107 | self.optimizer_D.zero_grad()
108 | self.backward_D()
109 | self.optimizer_D.step()
110 |
111 | self.optimizer_G.zero_grad()
112 | self.backward_G()
113 | self.optimizer_G.step()
114 |
115 | def get_current_errors(self):
116 | return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
117 | ('G_L1', self.loss_G_Content.item()),
118 | ('D_real+fake', self.loss_D.item())
119 | ])
120 |
121 | def get_current_visuals(self):
122 | real_A = util.tensor2im(self.real_A.data)
123 | fake_B = util.tensor2im(self.fake_B.data)
124 | real_B = util.tensor2im(self.real_B.data)
125 | return OrderedDict([('Blurred_Train', real_A), ('Restored_Train', fake_B), ('Sharp_Train', real_B)])
126 |
127 | def save(self, label):
128 | self.save_network(self.netG, 'G', label, self.gpu_ids)
129 | self.save_network(self.netD, 'D', label, self.gpu_ids)
130 |
131 | def update_learning_rate(self):
132 | lrd = self.opt.lr / self.opt.niter_decay
133 | lr = self.old_lr - lrd
134 | for param_group in self.optimizer_D.param_groups:
135 | param_group['lr'] = lr
136 | for param_group in self.optimizer_G.param_groups:
137 | param_group['lr'] = lr
138 | print('update learning rate: %f -> %f' % (self.old_lr, lr))
139 | self.old_lr = lr
140 |
--------------------------------------------------------------------------------
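`optimize_parameters` performs `criticUpdates` discriminator steps (5 for wgan-gp) per generator step. A sketch of the outer loop, approximating what `train.py` likely looks like (the real script, not shown in this dump, also handles visualization and checkpoint frequencies; the option names are taken from opt.txt, and `TrainOptions` from `options/train_options.py` is assumed):

```python
from options.train_options import TrainOptions   # interface assumed, file not shown
from data.data_loader import CreateDataLoader
from models.models import create_model

opt = TrainOptions().parse()
dataset = CreateDataLoader(opt).load_data()
model = create_model(opt)                # ConditionalGAN for opt.model == 'content_gan'

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    for data in dataset:
        model.set_input(data)
        model.optimize_parameters()      # 5 D updates + 1 G update for wgan-gp
    if epoch > opt.niter:
        model.update_learning_rate()     # linear decay over niter_decay epochs
    model.save('latest')
```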
/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | import torch.autograd as autograd
6 | import numpy as np
7 | import torchvision.models as models
8 | import util.util as util
9 | from util.image_pool import ImagePool
10 | from torch.autograd import Variable
11 | ###############################################################################
12 | # Functions
13 | ###############################################################################
14 |
15 | class ContentLoss:
16 | def __init__(self, loss):
17 | self.criterion = loss
18 |
19 | def get_loss(self, fakeIm, realIm):
20 | return self.criterion(fakeIm, realIm)
21 |
22 | class PerceptualLoss():
23 |
24 | def contentFunc(self):
25 | conv_3_3_layer = 14
26 | cnn = models.vgg19(pretrained=True).features
27 | cnn = cnn.cuda()
28 | model = nn.Sequential()
29 | model = model.cuda()
30 | for i,layer in enumerate(list(cnn)):
31 | model.add_module(str(i),layer)
32 | if i == conv_3_3_layer:
33 | break
34 | return model
35 |
36 | def __init__(self, loss):
37 | self.criterion = loss
38 | self.contentFunc = self.contentFunc()
39 |
40 | def get_loss(self, fakeIm, realIm):
41 | f_fake = self.contentFunc.forward(fakeIm)
42 | f_real = self.contentFunc.forward(realIm)
43 | f_real_no_grad = f_real.detach()
44 | loss = self.criterion(f_fake, f_real_no_grad)
45 | return loss
46 |
47 | class GANLoss(nn.Module):
48 | def __init__(
49 | self, use_l1=True, target_real_label=1.0,
50 | target_fake_label=0.0, tensor=torch.FloatTensor):
51 | super(GANLoss, self).__init__()
52 | self.real_label = target_real_label
53 | self.fake_label = target_fake_label
54 | self.real_label_var = None
55 | self.fake_label_var = None
56 | self.Tensor = tensor
57 | if use_l1:
58 | self.loss = nn.L1Loss()
59 | else:
60 | self.loss = nn.BCELoss()
61 |
62 | def get_target_tensor(self, input, target_is_real):
63 | target_tensor = None
64 | if target_is_real:
65 | create_label = ((self.real_label_var is None) or
66 | (self.real_label_var.numel() != input.numel()))
67 | if create_label:
68 | real_tensor = self.Tensor(input.size()).fill_(self.real_label)
69 | self.real_label_var = Variable(real_tensor, requires_grad=False)
70 | target_tensor = self.real_label_var
71 | else:
72 | create_label = ((self.fake_label_var is None) or
73 | (self.fake_label_var.numel() != input.numel()))
74 | if create_label:
75 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
76 | self.fake_label_var = Variable(fake_tensor, requires_grad=False)
77 | target_tensor = self.fake_label_var
78 | return target_tensor
79 |
80 | def __call__(self, input, target_is_real):
81 | target_tensor = self.get_target_tensor(input, target_is_real)
82 | return self.loss(input, target_tensor)
83 |
84 | class DiscLoss:
85 | def name(self):
86 | return 'DiscLoss'
87 |
88 | def __init__(self, opt, tensor):
89 | self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)
90 | self.fake_AB_pool = ImagePool(opt.pool_size)
91 |
92 | def get_g_loss(self,net, realA, fakeB):
93 | # First, G(A) should fake the discriminator
94 | pred_fake = net.forward(fakeB)
95 | return self.criterionGAN(pred_fake, 1)
96 |
97 | def get_loss(self, net, realA, fakeB, realB):
98 | # Fake
99 | # stop backprop to the generator by detaching fake_B
100 | # Generated Image Disc Output should be close to zero
101 | self.pred_fake = net.forward(fakeB.detach())
102 | self.loss_D_fake = self.criterionGAN(self.pred_fake, 0)
103 |
104 | # Real
105 | self.pred_real = net.forward(realB)
106 | self.loss_D_real = self.criterionGAN(self.pred_real, 1)
107 |
108 | # Combined loss
109 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
110 | return self.loss_D
111 |
112 | class DiscLossLS(DiscLoss):
113 | def name(self):
114 | return 'DiscLossLS'
115 |
116 | def __init__(self, opt, tensor):
117 |         super(DiscLossLS, self).__init__(opt, tensor)
118 | # DiscLoss.initialize(self, opt, tensor)
119 | self.criterionGAN = GANLoss(use_l1=True, tensor=tensor)
120 |
121 | def get_g_loss(self,net, realA, fakeB):
122 | return DiscLoss.get_g_loss(self,net, realA, fakeB)
123 |
124 | def get_loss(self, net, realA, fakeB, realB):
125 | return DiscLoss.get_loss(self, net, realA, fakeB, realB)
126 |
127 | class DiscLossWGANGP(DiscLossLS):
128 | def name(self):
129 | return 'DiscLossWGAN-GP'
130 |
131 | def __init__(self, opt, tensor):
132 | super(DiscLossWGANGP, self).__init__(opt, tensor)
133 | # DiscLossLS.initialize(self, opt, tensor)
134 | self.LAMBDA = 10
135 |
136 | def get_g_loss(self, net, realA, fakeB):
137 | # First, G(A) should fake the discriminator
138 | self.D_fake = net.forward(fakeB)
139 | return -self.D_fake.mean()
140 |
141 | def calc_gradient_penalty(self, netD, real_data, fake_data):
142 | alpha = torch.rand(1, 1)
143 | alpha = alpha.expand(real_data.size())
144 | alpha = alpha.cuda()
145 |
146 | interpolates = alpha * real_data + ((1 - alpha) * fake_data)
147 |
148 | interpolates = interpolates.cuda()
149 | interpolates = Variable(interpolates, requires_grad=True)
150 |
151 | disc_interpolates = netD.forward(interpolates)
152 |
153 | gradients = autograd.grad(
154 | outputs=disc_interpolates, inputs=interpolates, grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
155 | create_graph=True, retain_graph=True, only_inputs=True
156 | )[0]
157 |
158 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
159 | return gradient_penalty
160 |
161 | def get_loss(self, net, realA, fakeB, realB):
162 | self.D_fake = net.forward(fakeB.detach())
163 | self.D_fake = self.D_fake.mean()
164 |
165 | # Real
166 | self.D_real = net.forward(realB)
167 | self.D_real = self.D_real.mean()
168 | # Combined loss
169 | self.loss_D = self.D_fake - self.D_real
170 | gradient_penalty = self.calc_gradient_penalty(net, realB.data, fakeB.data)
171 | return self.loss_D + gradient_penalty
172 |
173 |
174 | def init_loss(opt, tensor):
175 | # disc_loss = None
176 | # content_loss = None
177 |
178 | if opt.model == 'content_gan':
179 | content_loss = PerceptualLoss(nn.MSELoss())
180 | # content_loss.initialize(nn.MSELoss())
181 | elif opt.model == 'pix2pix':
182 | content_loss = ContentLoss(nn.L1Loss())
183 | # content_loss.initialize(nn.L1Loss())
184 | else:
185 | raise ValueError("Model [%s] not recognized." % opt.model)
186 |
187 | if opt.gan_type == 'wgan-gp':
188 | disc_loss = DiscLossWGANGP(opt, tensor)
189 | elif opt.gan_type == 'lsgan':
190 | disc_loss = DiscLossLS(opt, tensor)
191 | elif opt.gan_type == 'gan':
192 | disc_loss = DiscLoss(opt, tensor)
193 | else:
194 | raise ValueError("GAN [%s] not recognized." % opt.gan_type)
195 | # disc_loss.initialize(opt, tensor)
196 | return disc_loss, content_loss
--------------------------------------------------------------------------------
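For reference, the WGAN-GP critic objective implemented above is `E[D(fake)] - E[D(real)] + LAMBDA * E[(||grad D(x_hat)||_2 - 1)^2]`, where `x_hat` is a random interpolate of real and fake samples. A quick smoke test of `calc_gradient_penalty` on random tensors (this assumes a CUDA device, since losses.py hard-codes `.cuda()`; the one-layer discriminator is just a stand-in, not the repo's NLayerDiscriminator):

```python
import torch
import torch.nn as nn
from types import SimpleNamespace
from models.losses import DiscLossWGANGP

opt = SimpleNamespace(pool_size=50)            # the only option field DiscLoss reads
loss = DiscLossWGANGP(opt, torch.cuda.FloatTensor)
netD = nn.Conv2d(3, 1, kernel_size=3, padding=1).cuda()  # stand-in critic
real = torch.randn(4, 3, 64, 64).cuda()
fake = torch.randn(4, 3, 64, 64).cuda()
gp = loss.calc_gradient_penalty(netD, real, fake)
print(gp.item())                               # LAMBDA * mean squared (grad norm - 1)
```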
/models/models.py:
--------------------------------------------------------------------------------
1 | from .conditional_gan_model import ConditionalGAN
2 |
3 | def create_model(opt):
4 | model = None
5 | if opt.model == 'test':
6 | assert (opt.dataset_mode == 'single')
7 | from .test_model import TestModel
8 | model = TestModel( opt )
9 | else:
10 | model = ConditionalGAN(opt)
11 | # model.initialize(opt)
12 | print("model [%s] was created" % (model.name()))
13 | return model
14 |
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | # from torch.nn import init
4 | import functools
5 | # from torch.autograd import Variable
6 | import numpy as np
7 |
8 |
9 | ###############################################################################
10 | # Functions
11 | ###############################################################################
12 |
13 |
14 | def weights_init(m):
15 | classname = m.__class__.__name__
16 | if classname.find('Conv') != -1:
17 | m.weight.data.normal_(0.0, 0.02)
18 | if hasattr(m.bias, 'data'):
19 | m.bias.data.fill_(0)
20 | elif classname.find('BatchNorm2d') != -1:
21 | m.weight.data.normal_(1.0, 0.02)
22 | m.bias.data.fill_(0)
23 |
24 |
25 | def get_norm_layer(norm_type='instance'):
26 | if norm_type == 'batch':
27 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
28 | elif norm_type == 'instance':
29 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
30 | else:
31 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
32 | return norm_layer
33 |
34 |
35 | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[], use_parallel=True,
36 | learn_residual=False):
37 | netG = None
38 | use_gpu = len(gpu_ids) > 0
39 | norm_layer = get_norm_layer(norm_type=norm)
40 |
41 | if use_gpu:
42 | assert (torch.cuda.is_available())
43 |
44 | if which_model_netG == 'resnet_9blocks':
45 | netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9,
46 | gpu_ids=gpu_ids, use_parallel=use_parallel, learn_residual=learn_residual)
47 | elif which_model_netG == 'resnet_6blocks':
48 | netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6,
49 | gpu_ids=gpu_ids, use_parallel=use_parallel, learn_residual=learn_residual)
50 | elif which_model_netG == 'unet_128':
51 | netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
52 | gpu_ids=gpu_ids, use_parallel=use_parallel, learn_residual=learn_residual)
53 | elif which_model_netG == 'unet_256':
54 | netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
55 | gpu_ids=gpu_ids, use_parallel=use_parallel, learn_residual=learn_residual)
56 | else:
57 | raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
58 | if len(gpu_ids) > 0:
59 | netG.cuda(gpu_ids[0])
60 | netG.apply(weights_init)
61 | return netG
62 |
63 |
64 | def define_D(input_nc, ndf, which_model_netD, n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[],
65 | use_parallel=True):
66 | netD = None
67 | use_gpu = len(gpu_ids) > 0
68 | norm_layer = get_norm_layer(norm_type=norm)
69 |
70 | if use_gpu:
71 | assert (torch.cuda.is_available())
72 | if which_model_netD == 'basic':
73 | netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid,
74 | gpu_ids=gpu_ids, use_parallel=use_parallel)
75 | elif which_model_netD == 'n_layers':
76 | netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid,
77 | gpu_ids=gpu_ids, use_parallel=use_parallel)
78 | else:
79 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD)
80 | if use_gpu:
81 | netD.cuda(gpu_ids[0])
82 | netD.apply(weights_init)
83 | return netD
84 |
85 |
86 | def print_network(net):
87 | num_params = 0
88 | for param in net.parameters():
89 | num_params += param.numel()
90 | print(net)
91 | print('Total number of parameters: %d' % num_params)
92 |
93 |
94 | ##############################################################################
95 | # Classes
96 | ##############################################################################
97 |
98 |
99 | # Defines the generator that consists of Resnet blocks between a few
100 | # downsampling/upsampling operations.
101 | # Code and idea originally from Justin Johnson's architecture.
102 | # https://github.com/jcjohnson/fast-neural-style/
103 | class ResnetGenerator(nn.Module):
104 | def __init__(
105 | self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
106 | n_blocks=6, gpu_ids=[], use_parallel=True, learn_residual=False, padding_type='reflect'):
107 | assert (n_blocks >= 0)
108 | super(ResnetGenerator, self).__init__()
109 | self.input_nc = input_nc
110 | self.output_nc = output_nc
111 | self.ngf = ngf
112 | self.gpu_ids = gpu_ids
113 | self.use_parallel = use_parallel
114 | self.learn_residual = learn_residual
115 |
116 | if type(norm_layer) == functools.partial:
117 | use_bias = norm_layer.func == nn.InstanceNorm2d
118 | else:
119 | use_bias = norm_layer == nn.InstanceNorm2d
120 |
121 | model = [
122 | nn.ReflectionPad2d(3),
123 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
124 | norm_layer(ngf),
125 | nn.ReLU(True)
126 | ]
127 |
128 |         n_downsampling = 2
129 |
130 |         # Downsampling: each stride-2 conv halves the spatial resolution and
131 |         # doubles the channel count (ngf -> 2*ngf -> 4*ngf)
132 |         for i in range(n_downsampling):  # [0, 1]
133 |             mult = 2 ** i
134 |             model += [
135 |                 nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2,
136 |                           padding=1, bias=use_bias),
137 |                 norm_layer(ngf * mult * 2),
138 |                 nn.ReLU(True)
139 |             ]
149 |
150 |         # Residual blocks at the bottleneck resolution
151 |         mult = 2 ** n_downsampling
152 |         for i in range(n_blocks):
153 |             model += [
154 |                 ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
155 |                             use_dropout=use_dropout, use_bias=use_bias)
156 |             ]
161 |
162 |         # Upsampling: mirror of the downsampling path, halving the channel
163 |         # count back down to ngf
164 |         for i in range(n_downsampling):
165 |             mult = 2 ** (n_downsampling - i)
166 |             model += [
167 |                 nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3,
168 |                                    stride=2, padding=1, output_padding=1, bias=use_bias),
169 |                 norm_layer(int(ngf * mult / 2)),
170 |                 nn.ReLU(True)
171 |             ]
182 |
183 | model += [
184 | nn.ReflectionPad2d(3),
185 |             nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
186 | nn.Tanh()
187 | ]
188 |
189 | self.model = nn.Sequential(*model)
190 |
191 | def forward(self, input):
192 | if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor) and self.use_parallel:
193 | output = nn.parallel.data_parallel(self.model, input, self.gpu_ids)
194 | else:
195 | output = self.model(input)
196 | if self.learn_residual:
197 | # output = input + output
198 | output = torch.clamp(input + output, min=-1, max=1)
199 | return output
200 |
201 |
202 | # Define a resnet block
203 | class ResnetBlock(nn.Module):
204 |
205 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
206 | super(ResnetBlock, self).__init__()
207 |
208 |         def build_pad_conv(ptype):
209 |             # Return a fresh pad+conv pair on each call so the two
210 |             # convolutions in the block are distinct modules with their
211 |             # own weights (reusing one Conv2d would tie them together).
212 |             if ptype == 'reflect':
213 |                 return [nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, kernel_size=3, bias=use_bias)]
214 |             elif ptype == 'replicate':
215 |                 return [nn.ReplicationPad2d(1), nn.Conv2d(dim, dim, kernel_size=3, bias=use_bias)]
216 |             elif ptype == 'zero':
217 |                 return [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias)]
218 |             raise NotImplementedError('padding [%s] is not implemented' % ptype)
219 |
220 |         # pad+conv -> norm -> ReLU -> [dropout] -> pad+conv -> norm
221 |         blocks = build_pad_conv(padding_type) + [
222 |             norm_layer(dim),
223 |             nn.ReLU(True)
224 |         ]
225 |         if use_dropout:
226 |             blocks += [nn.Dropout(0.5)]
227 |         blocks += build_pad_conv(padding_type) + [
228 |             norm_layer(dim)
229 |         ]
230 |
231 | self.conv_block = nn.Sequential(*blocks)
232 |
288 | def forward(self, x):
289 | out = x + self.conv_block(x)
290 | return out
291 |
292 |
293 | # Defines the Unet generator.
294 | # |num_downs|: number of downsamplings in UNet. For example,
295 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1
296 | # at the bottleneck
297 | class UnetGenerator(nn.Module):
298 | def __init__(
299 | self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d,
300 | use_dropout=False, gpu_ids=[], use_parallel=True, learn_residual=False):
301 | super(UnetGenerator, self).__init__()
302 | self.gpu_ids = gpu_ids
303 | self.use_parallel = use_parallel
304 | self.learn_residual = learn_residual
305 | # currently support only input_nc == output_nc
306 | assert (input_nc == output_nc)
307 |
308 | # construct unet structure
309 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True)
310 | for i in range(num_downs - 5):
311 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer,
312 | use_dropout=use_dropout)
313 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
314 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
315 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer)
316 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer)
317 |
318 | self.model = unet_block
319 |
320 | def forward(self, input):
321 | if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor) and self.use_parallel:
322 | output = nn.parallel.data_parallel(self.model, input, self.gpu_ids)
323 | else:
324 | output = self.model(input)
325 | if self.learn_residual:
326 | output = input + output
327 | output = torch.clamp(output, min=-1, max=1)
328 | return output
329 |
330 |
331 | # Defines the submodule with skip connection.
332 | # X -------------------identity---------------------- X
333 | # |-- downsampling -- |submodule| -- upsampling --|
334 | class UnetSkipConnectionBlock(nn.Module):
335 | def __init__(
336 | self, outer_nc, inner_nc, submodule=None,
337 | outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
338 | super(UnetSkipConnectionBlock, self).__init__()
339 | self.outermost = outermost
340 | if type(norm_layer) == functools.partial:
341 | use_bias = norm_layer.func == nn.InstanceNorm2d
342 | else:
343 | use_bias = norm_layer == nn.InstanceNorm2d
344 |
345 | dConv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
346 | dRelu = nn.LeakyReLU(0.2, True)
347 | dNorm = norm_layer(inner_nc)
348 | uRelu = nn.ReLU(True)
349 | uNorm = norm_layer(outer_nc)
350 |
351 | if outermost:
352 | uConv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
353 | dModel = [dConv]
354 | uModel = [uRelu, uConv, nn.Tanh()]
355 |             model = dModel + [submodule] + uModel
370 | elif innermost:
371 | uConv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
372 | dModel = [dRelu, dConv]
373 | uModel = [uRelu, uConv, uNorm]
374 |             model = dModel + uModel
386 | else:
387 | uConv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
388 | dModel = [dRelu, dConv, dNorm]
389 | uModel = [uRelu, uConv, uNorm]
390 |
391 |             model = dModel + [submodule] + uModel
396 | model += [nn.Dropout(0.5)] if use_dropout else []
397 |
403 | self.model = nn.Sequential(*model)
404 |
405 | def forward(self, x):
406 | if self.outermost:
407 | return self.model(x)
408 | else:
409 | return torch.cat([self.model(x), x], 1)
410 |
411 |
412 | # Defines the PatchGAN discriminator with the specified arguments.
413 | class NLayerDiscriminator(nn.Module):
414 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[],
415 | use_parallel=True):
416 | super(NLayerDiscriminator, self).__init__()
417 | self.gpu_ids = gpu_ids
418 | self.use_parallel = use_parallel
419 |
420 | if type(norm_layer) == functools.partial:
421 | use_bias = norm_layer.func == nn.InstanceNorm2d
422 | else:
423 | use_bias = norm_layer == nn.InstanceNorm2d
424 |
425 | kw = 4
426 |         padw = (kw - 1) // 2  # = 1 for kw = 4, the padding the PatchGAN layout assumes
427 | sequence = [
428 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
429 | nn.LeakyReLU(0.2, True)
430 | ]
431 |
432 | nf_mult = 1
433 | nf_mult_prev = 1
434 | for n in range(1, n_layers):
435 | nf_mult_prev = nf_mult
436 | nf_mult = min(2 ** n, 8)
437 | sequence += [
438 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
439 | kernel_size=kw, stride=2, padding=padw, bias=use_bias),
440 | norm_layer(ndf * nf_mult),
441 | nn.LeakyReLU(0.2, True)
442 | ]
443 |
444 | nf_mult_prev = nf_mult
445 | nf_mult = min(2 ** n_layers, 8)
446 | sequence += [
447 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
448 | norm_layer(ndf * nf_mult),
449 | nn.LeakyReLU(0.2, True)
450 | ]
451 |
452 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
453 |
454 | if use_sigmoid:
455 | sequence += [nn.Sigmoid()]
456 |
457 | self.model = nn.Sequential(*sequence)
458 |
459 | def forward(self, input):
460 | if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor) and self.use_parallel:
461 | return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
462 | else:
463 | return self.model(input)
464 |
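465 |
466 | if __name__ == '__main__':
467 |     # Minimal usage sketch (illustrative, with assumptions): build the default
468 |     # generator/discriminator on CPU and run a forward pass. Feeding the
469 |     # discriminator a single 3-channel image assumes the unconditional critic
470 |     # setup; a conditional setup would concatenate input and output first.
471 |     netG = define_G(3, 3, 64, 'resnet_9blocks', 'instance', False, [], False, True)
472 |     netD = define_D(3, 64, 'basic', norm='instance', gpu_ids=[])
473 |     print_network(netG)
474 |     x = torch.randn(1, 3, 256, 256)       # a blurred input scaled to [-1, 1]
475 |     restored = netG(x)                    # same shape as x
476 |     print(restored.size(), netD(restored).size())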
--------------------------------------------------------------------------------
/models/test_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from collections import OrderedDict
4 | import util.util as util
5 | from .base_model import BaseModel
6 | from . import networks
7 |
8 |
9 | class TestModel(BaseModel):
10 | def name(self):
11 | return 'TestModel'
12 |
13 | def __init__(self, opt):
14 | assert(not opt.isTrain)
15 | super(TestModel, self).__init__(opt)
16 | self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
17 |
18 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
19 | opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, False,
20 | opt.learn_residual)
21 | which_epoch = opt.which_epoch
22 | self.load_network(self.netG, 'G', which_epoch)
23 |
24 | print('---------- Networks initialized -------------')
25 | networks.print_network(self.netG)
26 | print('-----------------------------------------------')
27 |
28 | def set_input(self, input):
29 | # we need to use single_dataset mode
30 | input_A = input['A']
31 | temp = self.input_A.clone()
32 | temp.resize_(input_A.size()).copy_(input_A)
33 | self.input_A = temp
34 | self.image_paths = input['A_paths']
35 |
36 | def test(self):
37 | with torch.no_grad():
38 | self.real_A = Variable(self.input_A)
39 |             self.fake_B = self.netG(self.real_A)
40 |
41 | # get image paths
42 | def get_image_paths(self):
43 | return self.image_paths
44 |
45 | def get_current_visuals(self):
46 | real_A = util.tensor2im(self.real_A.data)
47 | fake_B = util.tensor2im(self.fake_B.data)
48 | return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
49 |
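50 | # Usage sketch (illustrative; the driver below is hypothetical, the real one
51 | # lives in test.py): set_input expects the single_dataset format, a dict with
52 | # a batched image tensor and its path(s).
53 | #
54 | #   model = TestModel(opt)                        # opt from TestOptions().parse()
55 | #   model.set_input({'A': blurred, 'A_paths': ['img.png']})
56 | #   model.test()
57 | #   visuals = model.get_current_visuals()         # keys: 'real_A', 'fake_B'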
--------------------------------------------------------------------------------
/motion_blur/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/motion_blur/__init__.py
--------------------------------------------------------------------------------
/motion_blur/blur_image.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import cv2
4 | import os
5 | from scipy import signal
6 | from scipy import misc
7 | from motion_blur.generate_PSF import PSF
8 | from motion_blur.generate_trajectory import Trajectory
9 |
10 |
11 | class BlurImage(object):
12 |
13 |     def __init__(self, image_path, PSFs=None, part=None, path_to_save=None):
14 |         """
15 |         Wrap an image with a set of PSF kernels for motion blurring.
16 |         :param image_path: path to a square RGB image.
17 |         :param PSFs: array of kernels.
18 |         :param part: index of the kernel to use.
19 |         :param path_to_save: folder in which to save results.
20 |         """
21 | if os.path.isfile(image_path):
22 | self.image_path = image_path
23 | self.original = misc.imread(self.image_path)
24 | self.shape = self.original.shape
25 |             if len(self.shape) < 3:
26 |                 raise Exception('Only RGB images are supported for now.')
27 |             elif self.shape[0] != self.shape[1]:
28 |                 raise Exception('Only square images are supported for now.')
29 |         else:
30 |             raise Exception('Incorrect path to image.')
31 |         self.path_to_save = path_to_save
32 | if PSFs is None:
33 | if self.path_to_save is None:
34 | self.PSFs = PSF(canvas=self.shape[0]).fit()
35 | else:
36 | self.PSFs = PSF(canvas=self.shape[0], path_to_save=os.path.join(self.path_to_save,
37 | 'PSFs.png')).fit(save=True)
38 | else:
39 | self.PSFs = PSFs
40 |
41 | self.part = part
42 | self.result = []
43 |
44 | def blur_image(self, save=False, show=False):
45 | if self.part is None:
46 | psf = self.PSFs
47 | else:
48 | psf = [self.PSFs[self.part]]
49 | yN, xN, channel = self.shape
50 | key, kex = self.PSFs[0].shape
51 | delta = yN - key
52 | assert delta >= 0, 'resolution of image should be higher than kernel'
53 |         result = []
54 | if len(psf) > 1:
55 | for p in psf:
56 | tmp = np.pad(p, delta // 2, 'constant')
57 | cv2.normalize(tmp, tmp, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
58 | # blured = np.zeros(self.shape)
59 | blured = cv2.normalize(self.original, self.original, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX,
60 | dtype=cv2.CV_32F)
61 | blured[:, :, 0] = np.array(signal.fftconvolve(blured[:, :, 0], tmp, 'same'))
62 | blured[:, :, 1] = np.array(signal.fftconvolve(blured[:, :, 1], tmp, 'same'))
63 | blured[:, :, 2] = np.array(signal.fftconvolve(blured[:, :, 2], tmp, 'same'))
64 | blured = cv2.normalize(blured, blured, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
65 | blured = cv2.cvtColor(blured, cv2.COLOR_RGB2BGR)
66 | result.append(np.abs(blured))
67 | else:
68 | psf = psf[0]
69 | tmp = np.pad(psf, delta // 2, 'constant')
70 | cv2.normalize(tmp, tmp, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
71 | blured = cv2.normalize(self.original, self.original, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX,
72 | dtype=cv2.CV_32F)
73 | blured[:, :, 0] = np.array(signal.fftconvolve(blured[:, :, 0], tmp, 'same'))
74 | blured[:, :, 1] = np.array(signal.fftconvolve(blured[:, :, 1], tmp, 'same'))
75 | blured[:, :, 2] = np.array(signal.fftconvolve(blured[:, :, 2], tmp, 'same'))
76 | blured = cv2.normalize(blured, blured, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
77 | blured = cv2.cvtColor(blured, cv2.COLOR_RGB2BGR)
78 | result.append(np.abs(blured))
79 | self.result = result
80 | if show or save:
81 | self.__plot_canvas(show, save)
82 |
83 | def __plot_canvas(self, show, save):
84 | if len(self.result) == 0:
85 | raise Exception('Please run blur_image() method first.')
86 | else:
87 | plt.close()
88 | plt.axis('off')
89 | fig, axes = plt.subplots(1, len(self.result), figsize=(10, 10))
90 | if len(self.result) > 1:
91 | for i in range(len(self.result)):
92 | axes[i].imshow(self.result[i])
93 | else:
94 | plt.axis('off')
95 |
96 | plt.imshow(self.result[0])
97 | if show and save:
98 | if self.path_to_save is None:
99 |                     raise Exception('Please create BlurImage instance with path_to_save')
100 | cv2.imwrite(os.path.join(self.path_to_save, self.image_path.split('/')[-1]), self.result[0] * 255)
101 | plt.show()
102 | elif save:
103 | if self.path_to_save is None:
104 |                 raise Exception('Please create BlurImage instance with path_to_save')
105 | cv2.imwrite(os.path.join(self.path_to_save, self.image_path.split('/')[-1]), self.result[0] * 255)
106 | elif show:
107 | plt.show()
108 |
109 |
110 | if __name__ == '__main__':
111 | folder = '/Users/mykolam/PycharmProjects/University/DeblurGAN2/results_sharp'
112 | folder_to_save = '/Users/mykolam/PycharmProjects/University/DeblurGAN2/blured'
113 | params = [0.01, 0.009, 0.008, 0.007, 0.005, 0.003]
114 | for path in os.listdir(folder):
115 | print(path)
116 | trajectory = Trajectory(canvas=64, max_len=60, expl=np.random.choice(params)).fit()
117 | psf = PSF(canvas=64, trajectory=trajectory).fit()
118 | BlurImage(os.path.join(folder, path), PSFs=psf,
119 |                   path_to_save=folder_to_save, part=np.random.choice([1, 2, 3])).\
120 | blur_image(save=True)
121 |
--------------------------------------------------------------------------------
/motion_blur/generate_PSF.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from math import ceil
3 | import matplotlib.pyplot as plt
4 | from motion_blur.generate_trajectory import Trajectory
5 |
6 |
7 | class PSF(object):
8 |     def __init__(self, canvas=None, trajectory=None, fraction=None, path_to_save=None):
9 |         # default to Trajectory's default canvas size when none is given
10 |         self.canvas = (64, 64) if canvas is None else (canvas, canvas)
11 |         if trajectory is None:
12 |             # fit() returns the Trajectory object; keep its coordinate vector x
13 |             self.trajectory = Trajectory(canvas=self.canvas[0], expl=0.005).fit(show=False, save=False).x
14 |         else:
15 |             self.trajectory = trajectory.x
16 |
17 | if fraction is None:
18 | self.fraction = [1/100, 1/10, 1/2, 1]
19 | else:
20 | self.fraction = fraction
21 | self.path_to_save = path_to_save
22 | self.PSFnumber = len(self.fraction)
23 | self.iters = len(self.trajectory)
24 | self.PSFs = []
25 |
26 | def fit(self, show=False, save=False):
27 | PSF = np.zeros(self.canvas)
28 |
29 | triangle_fun = lambda x: np.maximum(0, (1 - np.abs(x)))
30 | triangle_fun_prod = lambda x, y: np.multiply(triangle_fun(x), triangle_fun(y))
31 | for j in range(self.PSFnumber):
32 | if j == 0:
33 | prevT = 0
34 | else:
35 | prevT = self.fraction[j - 1]
36 |
37 | for t in range(len(self.trajectory)):
38 | # print(j, t)
39 | if (self.fraction[j] * self.iters >= t) and (prevT * self.iters < t - 1):
40 | t_proportion = 1
41 | elif (self.fraction[j] * self.iters >= t - 1) and (prevT * self.iters < t - 1):
42 | t_proportion = self.fraction[j] * self.iters - (t - 1)
43 | elif (self.fraction[j] * self.iters >= t) and (prevT * self.iters < t):
44 | t_proportion = t - (prevT * self.iters)
45 | elif (self.fraction[j] * self.iters >= t - 1) and (prevT * self.iters < t):
46 | t_proportion = (self.fraction[j] - prevT) * self.iters
47 | else:
48 | t_proportion = 0
49 |
50 |                 m2 = int(np.minimum(self.canvas[1] - 1, np.maximum(1, np.floor(self.trajectory[t].real))))
51 |                 M2 = int(m2 + 1)
52 |                 m1 = int(np.minimum(self.canvas[0] - 1, np.maximum(1, np.floor(self.trajectory[t].imag))))
53 | M1 = int(m1 + 1)
54 |
55 | PSF[m1, m2] += t_proportion * triangle_fun_prod(
56 | self.trajectory[t].real - m2, self.trajectory[t].imag - m1
57 | )
58 | PSF[m1, M2] += t_proportion * triangle_fun_prod(
59 | self.trajectory[t].real - M2, self.trajectory[t].imag - m1
60 | )
61 | PSF[M1, m2] += t_proportion * triangle_fun_prod(
62 | self.trajectory[t].real - m2, self.trajectory[t].imag - M1
63 | )
64 | PSF[M1, M2] += t_proportion * triangle_fun_prod(
65 | self.trajectory[t].real - M2, self.trajectory[t].imag - M1
66 | )
67 |
68 | self.PSFs.append(PSF / (self.iters))
69 | if show or save:
70 | self.__plot_canvas(show, save)
71 |
72 | return self.PSFs
73 |
74 | def __plot_canvas(self, show, save):
75 | if len(self.PSFs) == 0:
76 | raise Exception("Please run fit() method first.")
77 | else:
78 | plt.close()
79 | fig, axes = plt.subplots(1, self.PSFnumber, figsize=(10, 10))
80 | for i in range(self.PSFnumber):
81 | axes[i].imshow(self.PSFs[i], cmap='gray')
82 | if show and save:
83 | if self.path_to_save is None:
84 |                 raise Exception('Please create PSF instance with path_to_save')
85 | plt.savefig(self.path_to_save)
86 | plt.show()
87 | elif save:
88 | if self.path_to_save is None:
89 |                 raise Exception('Please create PSF instance with path_to_save')
90 | plt.savefig(self.path_to_save)
91 | elif show:
92 | plt.show()
93 |
94 |
95 | if __name__ == '__main__':
96 | psf = PSF(canvas=128, path_to_save='/Users/mykolam/PycharmProjects/University/RandomMotionBlur/psf.png')
97 | psf.fit(show=True, save=True)
--------------------------------------------------------------------------------
/motion_blur/generate_trajectory.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from math import ceil
4 |
5 |
6 | class Trajectory(object):
7 | def __init__(self, canvas=64, iters=2000, max_len=60, expl=None, path_to_save=None):
8 | """
9 | Generates a variety of random motion trajectories in continuous domain as in [Boracchi and Foi 2012]. Each
10 | trajectory consists of a complex-valued vector determining the discrete positions of a particle following a
11 | 2-D random motion in continuous domain. The particle has an initial velocity vector which, at each iteration,
12 | is affected by a Gaussian perturbation and by a deterministic inertial component, directed toward the
13 | previous particle position. In addition, with a small probability, an impulsive (abrupt) perturbation aiming
14 | at inverting the particle velocity may arises, mimicking a sudden movement that occurs when the user presses
15 | the camera button or tries to compensate the camera shake. At each step, the velocity is normalized to
16 | guarantee that trajectories corresponding to equal exposures have the same length. Each perturbation (
17 | Gaussian, inertial, and impulsive) is ruled by its own parameter. Rectilinear Blur as in [Boracchi and Foi
18 | 2011] can be obtained by setting anxiety to 0 (when no impulsive changes occurs
19 | :param canvas: size of domain where our trajectory os defined.
20 | :param iters: number of iterations for definition of our trajectory.
21 | :param max_len: maximum length of our trajectory.
22 | :param expl: this param helps to define probability of big shake. Recommended expl = 0.005.
23 | :param path_to_save: where to save if you need.
24 | """
25 | self.canvas = canvas
26 | self.iters = iters
27 | self.max_len = max_len
28 | if expl is None:
29 | self.expl = 0.1 * np.random.uniform(0, 1)
30 | else:
31 | self.expl = expl
32 |         # always set the attribute so __plot_canvas can check it safely
33 |         self.path_to_save = path_to_save
36 | self.tot_length = None
37 | self.big_expl_count = None
38 | self.x = None
39 |
40 | def fit(self, show=False, save=False):
41 | """
42 |         Generate the motion; you can save or plot it. The coordinates are stored in the x property,
43 |         along with the tot_length and big_expl_count properties.
44 | :param show: default False.
45 | :param save: default False.
46 | :return: x (vector of motion).
47 | """
48 | tot_length = 0
49 | big_expl_count = 0
50 | # how to be near the previous position
51 |         # TODO: changing this parameter to 0.1 would spread the kernel over the whole image
52 | centripetal = 0.7 * np.random.uniform(0, 1)
53 | # probability of big shake
54 | prob_big_shake = 0.2 * np.random.uniform(0, 1)
55 | # term determining, at each sample, the random component of the new direction
56 | gaussian_shake = 10 * np.random.uniform(0, 1)
57 | init_angle = 360 * np.random.uniform(0, 1)
58 |
59 | img_v0 = np.sin(np.deg2rad(init_angle))
60 | real_v0 = np.cos(np.deg2rad(init_angle))
61 |
62 | v0 = complex(real=real_v0, imag=img_v0)
63 | v = v0 * self.max_len / (self.iters - 1)
64 |
65 | if self.expl > 0:
66 | v = v0 * self.expl
67 |
68 | x = np.array([complex(real=0, imag=0)] * (self.iters))
69 |
70 | for t in range(0, self.iters - 1):
71 | if np.random.uniform() < prob_big_shake * self.expl:
72 | next_direction = 2 * v * (np.exp(complex(real=0, imag=np.pi + (np.random.uniform() - 0.5))))
73 | big_expl_count += 1
74 | else:
75 | next_direction = 0
76 |
77 | dv = next_direction + self.expl * (
78 | gaussian_shake * complex(real=np.random.randn(), imag=np.random.randn()) - centripetal * x[t]) * (
79 | self.max_len / (self.iters - 1))
80 |
81 | v += dv
82 | v = (v / float(np.abs(v))) * (self.max_len / float((self.iters - 1)))
83 | x[t + 1] = x[t] + v
84 | tot_length = tot_length + abs(x[t + 1] - x[t])
85 |
86 |         # center the motion
87 | x += complex(real=-np.min(x.real), imag=-np.min(x.imag))
88 | x = x - complex(real=x[0].real % 1., imag=x[0].imag % 1.) + complex(1, 1)
89 | x += complex(real=ceil((self.canvas - max(x.real)) / 2), imag=ceil((self.canvas - max(x.imag)) / 2))
90 |
91 | self.tot_length = tot_length
92 | self.big_expl_count = big_expl_count
93 | self.x = x
94 |
95 | if show or save:
96 | self.__plot_canvas(show, save)
97 | return self
98 |
99 | def __plot_canvas(self, show, save):
100 | if self.x is None:
101 | raise Exception("Please run fit() method first")
102 | else:
103 | plt.close()
104 | plt.plot(self.x.real, self.x.imag, '-', color='blue')
105 |
106 | plt.xlim((0, self.canvas))
107 | plt.ylim((0, self.canvas))
108 |         # save (with a mandatory path) and/or show the trajectory plot
109 |         if save:
110 |             if self.path_to_save is None:
111 |                 raise Exception('Please create Trajectory instance with path_to_save')
112 |             plt.savefig(self.path_to_save)
113 |             if show:
114 |                 plt.show()
115 |         elif show:
116 |             plt.show()
117 |
118 |
119 | if __name__ == '__main__':
120 | trajectory = Trajectory(expl=0.005,
121 | path_to_save='/Users/mykolam/PycharmProjects/University/RandomMotionBlur/main.png')
122 | trajectory.fit(True, False)
123 |
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 |
6 | class BaseOptions():
7 | def __init__(self):
8 | self.parser = argparse.ArgumentParser()
9 | self.initialized = False
10 |
11 | def initialize(self):
12 |         self.parser.add_argument('--dataroot', type=str, default=r"D:\Photos\TrainingData\BlurredSharp\combined", help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
13 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
14 | self.parser.add_argument('--loadSizeX', type=int, default=640, help='scale images to this size')
15 | self.parser.add_argument('--loadSizeY', type=int, default=360, help='scale images to this size')
16 | self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
17 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
18 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
19 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
20 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
21 | self.parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD')
22 | self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG')
23 | self.parser.add_argument('--learn_residual', action='store_true', help='if specified, model would learn only the residual to the input')
24 |         self.parser.add_argument('--gan_type', type=str, default='wgan-gp', help='wgan-gp : Wasserstein GAN with Gradient Penalty, lsgan : Least Squares GAN, gan : Vanilla GAN')
25 | self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
26 |         self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0, or 0,1,2, or 0,2. Use -1 for CPU')
27 | self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
28 | self.parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [unaligned | aligned | single]')
29 | self.parser.add_argument('--model', type=str, default='content_gan', help='chooses which model to use. pix2pix, test, content_gan')
30 | self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
31 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
32 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
33 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
34 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
35 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
36 | self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
37 | self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
38 | self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
39 | self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
40 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
41 | self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
42 | self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
43 |
44 | self.initialized = True
45 |
46 | def parse(self):
47 | if not self.initialized:
48 | self.initialize()
49 | self.opt = self.parser.parse_args()
50 | self.opt.isTrain = self.isTrain # train or test
51 |
52 | str_ids = self.opt.gpu_ids.split(',')
53 | self.opt.gpu_ids = []
54 | for str_id in str_ids:
55 | id = int(str_id)
56 | if id >= 0:
57 | self.opt.gpu_ids.append(id)
58 |
59 | # set gpu ids
60 | if len(self.opt.gpu_ids) > 0:
61 | torch.cuda.set_device(self.opt.gpu_ids[0])
62 |
63 | args = vars(self.opt)
64 |
65 | print('------------ Options -------------')
66 | for k, v in sorted(args.items()):
67 | print('%s: %s' % (str(k), str(v)))
68 | print('-------------- End ----------------')
69 |
70 | # save to the disk
71 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
72 | util.mkdirs(expr_dir)
73 | file_name = os.path.join(expr_dir, 'opt.txt')
74 | with open(file_name, 'wt') as opt_file:
75 | opt_file.write('------------ Options -------------\n')
76 | for k, v in sorted(args.items()):
77 | opt_file.write('%s: %s\n' % (str(k), str(v)))
78 | opt_file.write('-------------- End ----------------\n')
79 | return self.opt
80 |
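81 | # Usage note (illustrative): parse() expects self.isTrain, which the
82 | # TrainOptions/TestOptions subclasses set. A typical training invocation
83 | # (paths are placeholders):
84 | #
85 | #   python train.py --dataroot ./path_to_your_data --learn_residual \
86 | #       --resize_or_crop crop --fineSize 256 --gpu_ids 0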
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | def initialize(self):
6 | BaseOptions.initialize(self)
7 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
8 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
9 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
10 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
11 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
12 | self.parser.add_argument('--how_many', type=int, default=5000, help='how many test images to run')
13 | self.isTrain = False
14 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | def initialize(self):
6 | BaseOptions.initialize(self)
7 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
8 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
9 | self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
10 | self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
11 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
12 |         self.parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
13 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
14 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
15 | self.parser.add_argument('--niter', type=int, default=150, help='# of iter at starting learning rate')
16 | self.parser.add_argument('--niter_decay', type=int, default=150, help='# of iter to linearly decay learning rate to zero')
17 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
18 | self.parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
19 | self.parser.add_argument('--lambda_A', type=float, default=100.0, help='weight for cycle loss (A -> B -> A)')
20 | self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
21 |         self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 0 scales the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, set --identity 0.1')
22 | self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
23 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
24 | self.isTrain = True
25 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | from options.test_options import TestOptions
4 | from data.data_loader import CreateDataLoader
5 | from models.models import create_model
6 | from util.visualizer import Visualizer
8 | from util import html
9 | from util.metrics import PSNR
10 | from ssim import SSIM
11 | from PIL import Image
12 |
13 | opt = TestOptions().parse()
14 | opt.nThreads = 1 # test code only supports nThreads = 1
15 | opt.batchSize = 1 # test code only supports batchSize = 1
16 | opt.serial_batches = True # no shuffle
17 | opt.no_flip = True # no flip
18 |
19 | data_loader = CreateDataLoader(opt)
20 | dataset = data_loader.load_data()
21 | model = create_model(opt)
22 | visualizer = Visualizer(opt)
23 | # create website
24 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
25 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
26 | # test
27 | avgPSNR = 0.0
28 | avgSSIM = 0.0
29 | counter = 0
30 |
31 | for i, data in enumerate(dataset):
32 | if i >= opt.how_many:
33 | break
34 | counter = i
35 | model.set_input(data)
36 | model.test()
37 | visuals = model.get_current_visuals()
38 | #avgPSNR += PSNR(visuals['fake_B'],visuals['real_B'])
39 | #pilFake = Image.fromarray(visuals['fake_B'])
40 | #pilReal = Image.fromarray(visuals['real_B'])
41 | #avgSSIM += SSIM(pilFake).cw_ssim_value(pilReal)
42 | img_path = model.get_image_paths()
43 | print('process image... %s' % img_path)
44 | visualizer.save_images(webpage, visuals, img_path)
45 |
46 | #avgPSNR /= counter
47 | #avgSSIM /= counter
48 | #print('PSNR = %f, SSIM = %f' %
49 | # (avgPSNR, avgSSIM))
50 |
51 | webpage.save()
52 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from data.data_loader import CreateDataLoader
4 | from models.models import create_model
5 | from util.visualizer import Visualizer
6 | from util.metrics import PSNR, SSIM
7 | from multiprocessing import freeze_support
8 |
9 | def train(opt, data_loader, model, visualizer):
10 | dataset = data_loader.load_data()
11 | dataset_size = len(data_loader)
12 | print('#training images = %d' % dataset_size)
13 | total_steps = 0
14 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
15 | epoch_start_time = time.time()
16 | epoch_iter = 0
17 | for i, data in enumerate(dataset):
18 | iter_start_time = time.time()
19 | total_steps += opt.batchSize
20 | epoch_iter += opt.batchSize
21 | model.set_input(data)
22 | model.optimize_parameters()
23 |
24 | if total_steps % opt.display_freq == 0:
25 | results = model.get_current_visuals()
26 | psnrMetric = PSNR(results['Restored_Train'], results['Sharp_Train'])
27 | print('PSNR on Train = %f' % psnrMetric)
28 | visualizer.display_current_results(results, epoch)
29 |
30 | if total_steps % opt.print_freq == 0:
31 | errors = model.get_current_errors()
32 | t = (time.time() - iter_start_time) / opt.batchSize
33 | visualizer.print_current_errors(epoch, epoch_iter, errors, t)
34 | if opt.display_id > 0:
35 | visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)
36 |
37 | if total_steps % opt.save_latest_freq == 0:
38 | print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
39 | model.save('latest')
40 |
41 | if epoch % opt.save_epoch_freq == 0:
42 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
43 | model.save('latest')
44 | model.save(epoch)
45 |
46 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
47 |
48 | if epoch > opt.niter:
49 | model.update_learning_rate()
50 |
51 |
52 | if __name__ == '__main__':
53 | freeze_support()
54 |
55 |     # python train.py --dataroot ./path_to_your_data --learn_residual --resize_or_crop crop --fineSize CROP_SIZE (we used 256)
56 |
57 | opt = TrainOptions().parse()
58 |     opt.dataroot = r'D:\Photos\TrainingData\BlurredSharp\combined'
59 | opt.learn_residual = True
60 | opt.resize_or_crop = "crop"
61 | opt.fineSize = 256
62 | opt.gan_type = "gan"
63 | # opt.which_model_netG = "unet_256"
64 |
65 | # default = 5000
66 | opt.save_latest_freq = 100
67 |
68 | # default = 100
69 | opt.print_freq = 20
70 |
71 | data_loader = CreateDataLoader(opt)
72 | model = create_model(opt)
73 | visualizer = Visualizer(opt)
74 | train(opt, data_loader, model, visualizer)
75 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/util/__init__.py
--------------------------------------------------------------------------------
/util/get_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import tarfile
4 | import requests
5 | from warnings import warn
6 | from zipfile import ZipFile
7 | from bs4 import BeautifulSoup
8 | from os.path import abspath, isdir, join, basename
9 |
10 |
11 | class GetData(object):
12 | """
13 |
14 | Download CycleGAN or Pix2Pix Data.
15 |
16 | Args:
17 | technique : str
18 | One of: 'cyclegan' or 'pix2pix'.
19 | verbose : bool
20 | If True, print additional information.
21 |
22 | Examples:
23 | >>> from util.get_data import GetData
24 | >>> gd = GetData(technique='cyclegan')
25 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
26 |
27 | """
28 |
29 | def __init__(self, technique='cyclegan', verbose=True):
30 | url_dict = {
31 | 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
32 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
33 | }
34 | self.url = url_dict.get(technique.lower())
35 | self._verbose = verbose
36 |
37 | def _print(self, text):
38 | if self._verbose:
39 | print(text)
40 |
41 | @staticmethod
42 | def _get_options(r):
43 | soup = BeautifulSoup(r.text, 'lxml')
44 | options = [h.text for h in soup.find_all('a', href=True)
45 | if h.text.endswith(('.zip', 'tar.gz'))]
46 | return options
47 |
48 | def _present_options(self):
49 | r = requests.get(self.url)
50 | options = self._get_options(r)
51 | print('Options:\n')
52 | for i, o in enumerate(options):
53 | print("{0}: {1}".format(i, o))
54 | choice = input("\nPlease enter the number of the "
55 | "dataset above you wish to download:")
56 | return options[int(choice)]
57 |
58 | def _download_data(self, dataset_url, save_path):
59 | if not isdir(save_path):
60 | os.makedirs(save_path)
61 |
62 | base = basename(dataset_url)
63 | temp_save_path = join(save_path, base)
64 |
65 | with open(temp_save_path, "wb") as f:
66 | r = requests.get(dataset_url)
67 | f.write(r.content)
68 |
69 | if base.endswith('.tar.gz'):
70 | obj = tarfile.open(temp_save_path)
71 | elif base.endswith('.zip'):
72 | obj = ZipFile(temp_save_path, 'r')
73 | else:
74 | raise ValueError("Unknown File Type: {0}.".format(base))
75 |
76 | self._print("Unpacking Data...")
77 | obj.extractall(save_path)
78 | obj.close()
79 | os.remove(temp_save_path)
80 |
81 | def get(self, save_path, dataset=None):
82 | """
83 |
84 | Download a dataset.
85 |
86 | Args:
87 | save_path : str
88 | A directory to save the data to.
89 | dataset : str, optional
90 | A specific dataset to download.
91 | Note: this must include the file extension.
92 | If None, options will be presented for you
93 | to choose from.
94 |
95 | Returns:
96 | save_path_full : str
97 | The absolute path to the downloaded data.
98 |
99 | """
100 | if dataset is None:
101 | selected_dataset = self._present_options()
102 | else:
103 | selected_dataset = dataset
104 |
105 | save_path_full = join(save_path, selected_dataset.split('.')[0])
106 |
107 | if isdir(save_path_full):
108 |             warn("\n'{0}' already exists. Skipping download.".format(
109 |                 save_path_full))
110 | else:
111 | self._print('Downloading Data...')
112 | url = "{0}/{1}".format(self.url, selected_dataset)
113 | self._download_data(url, save_path=save_path)
114 |
115 | return abspath(save_path_full)
116 |
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 |     def __init__(self, web_dir, title, refresh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 | # print(self.img_dir)
16 |
17 | self.doc = dominate.document(title=title)
18 |         if refresh > 0:
19 |             with self.doc.head:
20 |                 meta(http_equiv="refresh", content=str(refresh))
21 |
22 | def get_image_dir(self):
23 | return self.img_dir
24 |
25 | def add_header(self, str):
26 | with self.doc:
27 | h3(str)
28 |
29 | def add_table(self, border=1):
30 | self.t = table(border=border, style="table-layout: fixed;")
31 | self.doc.add(self.t)
32 |
33 | def add_images(self, ims, txts, links, width=400):
34 | self.add_table()
35 | with self.t:
36 | with tr():
37 | for im, txt, link in zip(ims, txts, links):
38 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
39 | with p():
40 | with a(href=os.path.join('images', link)):
41 | img(style="width:%dpx" % width, src=os.path.join('images', im))
42 | br()
43 | p(txt)
44 |
45 | def save(self):
46 | html_file = '%s/index.html' % self.web_dir
47 |         with open(html_file, 'wt') as f:
48 |             f.write(self.doc.render())
50 |
51 |
52 | if __name__ == '__main__':
53 | html = HTML('web/', 'test_html')
54 | html.add_header('hello world')
55 |
56 | ims = []
57 | txts = []
58 | links = []
59 | for n in range(4):
60 | ims.append('image_%d.png' % n)
61 | txts.append('text_%d' % n)
62 | links.append('image_%d.png' % n)
63 | html.add_images(ims, txts, links)
64 | html.save()
65 |
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | from torch.autograd import Variable
5 | class ImagePool():
6 | def __init__(self, pool_size):
7 | self.pool_size = pool_size
8 | if self.pool_size > 0:
9 | self.num_imgs = 0
10 | self.images = []
11 |
12 | def query(self, images):
13 | if self.pool_size == 0:
14 | return images
15 | return_images = []
16 | for image in images.data:
17 | image = torch.unsqueeze(image, 0)
18 | if self.num_imgs < self.pool_size:
19 | self.num_imgs = self.num_imgs + 1
20 | self.images.append(image)
21 | return_images.append(image)
22 | else:
23 | p = random.uniform(0, 1)
24 | if p > 0.5:
25 | random_id = random.randint(0, self.pool_size-1)
26 | tmp = self.images[random_id].clone()
27 | self.images[random_id] = image
28 | return_images.append(tmp)
29 | else:
30 | return_images.append(image)
31 | return_images = Variable(torch.cat(return_images, 0))
32 | return return_images
33 |
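34 |
35 | if __name__ == '__main__':
36 |     # Minimal sketch (illustrative): once the pool is full, query() returns a
37 |     # 50/50 mix of fresh and historical fakes, which keeps the discriminator
38 |     # from chasing only the latest generator outputs.
39 |     pool = ImagePool(pool_size=50)
40 |     fake = Variable(torch.randn(4, 3, 64, 64))
41 |     mixed = pool.query(fake)
42 |     print(mixed.size())                       # torch.Size([4, 3, 64, 64])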
--------------------------------------------------------------------------------
/util/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 | import math
7 |
8 | def gaussian(window_size, sigma):
9 | gauss = torch.Tensor([exp(-(x - window_size/2)**2/float(2*sigma**2)) for x in range(window_size)])
10 | return gauss/gauss.sum()
11 |
12 | def create_window(window_size, channel):
13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size))
16 | return window
17 |
18 | def SSIM(img1, img2):
19 | (_, channel, _, _) = img1.size()
20 | window_size = 11
21 | window = create_window(window_size, channel)
22 |     mu1 = F.conv2d(img1, window, padding = window_size // 2, groups = channel)
23 |     mu2 = F.conv2d(img2, window, padding = window_size // 2, groups = channel)
24 |
25 | mu1_sq = mu1.pow(2)
26 | mu2_sq = mu2.pow(2)
27 | mu1_mu2 = mu1*mu2
28 |
29 |     sigma1_sq = F.conv2d(img1*img1, window, padding = window_size // 2, groups = channel) - mu1_sq
30 |     sigma2_sq = F.conv2d(img2*img2, window, padding = window_size // 2, groups = channel) - mu2_sq
31 |     sigma12 = F.conv2d(img1*img2, window, padding = window_size // 2, groups = channel) - mu1_mu2
32 |
33 | C1 = 0.01**2
34 | C2 = 0.03**2
35 |
36 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
37 | return ssim_map.mean()
38 |
39 | def PSNR(img1, img2):
40 | mse = np.mean( (img1/255. - img2/255.) ** 2 )
41 | if mse == 0:
42 | return 100
43 | PIXEL_MAX = 1
44 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
45 |
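46 |
47 | if __name__ == '__main__':
48 |     # Minimal sketch (illustrative): SSIM expects 4-D tensors (N, C, H, W) in
49 |     # a common range; PSNR expects arrays scaled to [0, 255].
50 |     a = torch.rand(1, 3, 64, 64)
51 |     b = torch.rand(1, 3, 64, 64)
52 |     print('SSIM = %.4f' % SSIM(a, b).item())
53 |     print('PSNR = %.2f dB' % PSNR(a.numpy() * 255, b.numpy() * 255))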
--------------------------------------------------------------------------------
/util/png.py:
--------------------------------------------------------------------------------
1 | import struct
2 | import zlib
3 |
4 | def encode(buf, width, height):
5 | """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """
6 | assert (width * height * 3 == len(buf))
7 | bpp = 3
8 |
9 | def raw_data():
10 | # reverse the vertical line order and add null bytes at the start
11 | row_bytes = width * bpp
12 | for row_start in range((height - 1) * width * bpp, -1, -row_bytes):
13 | yield b'\x00'
14 | yield buf[row_start:row_start + row_bytes]
15 |
16 | def chunk(tag, data):
17 | return [
18 | struct.pack("!I", len(data)),
19 | tag,
20 | data,
21 | struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag)))
22 | ]
23 |
24 | SIGNATURE = b'\x89PNG\r\n\x1a\n'
25 | COLOR_TYPE_RGB = 2
26 | COLOR_TYPE_RGBA = 6
27 | bit_depth = 8
28 | return b''.join(
29 | [ SIGNATURE ] +
30 | chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) +
31 | chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) +
32 | chunk(b'IEND', b'')
33 | )
34 |
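35 |
36 | if __name__ == '__main__':
37 |     # Minimal sketch (illustrative): encode a 2x2 solid-red image from raw
38 |     # row-major RGB bytes and write it to disk.
39 |     buf = bytes([255, 0, 0] * 4)              # 2*2 pixels * 3 bytes each
40 |     with open('red.png', 'wb') as f:
41 |         f.write(encode(buf, 2, 2))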
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import inspect, re
6 | import numpy as np
7 | import os
8 | import collections
9 |
10 | # Converts a Tensor into a Numpy array
11 | # |imtype|: the desired type of the converted numpy array
12 | def tensor2im(image_tensor, imtype=np.uint8):
13 | image_numpy = image_tensor[0].cpu().float().numpy()
14 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
15 | return image_numpy.astype(imtype)
16 |
17 |
18 | def diagnose_network(net, name='network'):
19 | mean = 0.0
20 | count = 0
21 | for param in net.parameters():
22 | if param.grad is not None:
23 | mean += torch.mean(torch.abs(param.grad.data))
24 | count += 1
25 | if count > 0:
26 | mean = mean / count
27 | print(name)
28 | print(mean)
29 |
30 |
31 | def save_image(image_numpy, image_path):
32 | image_pil = None
33 | if image_numpy.shape[2] == 1:
34 | image_numpy = np.reshape(image_numpy, (image_numpy.shape[0],image_numpy.shape[1]))
35 | image_pil = Image.fromarray(image_numpy, 'L')
36 | else:
37 | image_pil = Image.fromarray(image_numpy)
38 | image_pil.save(image_path)
39 |
40 | def info(object, spacing=10, collapse=1):
41 | """Print methods and doc strings.
42 | Takes module, class, list, dictionary, or string."""
43 |     methodList = [e for e in dir(object) if callable(getattr(object, e))]
44 | processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
45 | print( "\n".join(["%s %s" %
46 | (method.ljust(spacing),
47 | processFunc(str(getattr(object, method).__doc__)))
48 | for method in methodList]) )
49 |
50 | def varname(p):
51 | for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
52 | m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
53 | if m:
54 | return m.group(1)
55 |
56 | def print_numpy(x, val=True, shp=False):
57 | x = x.astype(np.float64)
58 | if shp:
59 | print('shape,', x.shape)
60 | if val:
61 | x = x.flatten()
62 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
63 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
64 |
65 |
66 | def mkdirs(paths):
67 | if isinstance(paths, list) and not isinstance(paths, str):
68 | for path in paths:
69 | mkdir(path)
70 | else:
71 | mkdir(paths)
72 |
73 |
74 | def mkdir(path):
75 | if not os.path.exists(path):
76 | os.makedirs(path)
77 |
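78 |
79 | if __name__ == '__main__':
80 |     # Minimal sketch (illustrative): round-trip a fake generator output in
81 |     # [-1, 1] through tensor2im to an 8-bit PNG on disk.
82 |     t = torch.rand(1, 3, 8, 8) * 2 - 1
83 |     arr = tensor2im(t)                        # (8, 8, 3) uint8 in [0, 255]
84 |     save_image(arr, 'tensor2im_demo.png')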
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import ntpath
4 | import time
5 | from . import util
6 | from . import html
7 |
8 | class Visualizer():
9 | def __init__(self, opt):
10 | # self.opt = opt
11 | self.display_id = opt.display_id
12 | self.use_html = opt.isTrain and not opt.no_html
13 | self.win_size = opt.display_winsize
14 | self.name = opt.name
15 | if self.display_id > 0:
16 | import visdom
17 | self.vis = visdom.Visdom(port = opt.display_port)
18 | self.display_single_pane_ncols = opt.display_single_pane_ncols
19 |
20 | if self.use_html:
21 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
22 | self.img_dir = os.path.join(self.web_dir, 'images')
23 | print('create web directory %s...' % self.web_dir)
24 | util.mkdirs([self.web_dir, self.img_dir])
25 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
26 | with open(self.log_name, "a") as log_file:
27 | now = time.strftime("%c")
28 | log_file.write('================ Training Loss (%s) ================\n' % now)
29 |
30 | # |visuals|: dictionary of images to display or save
31 | def display_current_results(self, visuals, epoch):
32 | if self.display_id > 0: # show images in the browser
33 | if self.display_single_pane_ncols > 0:
34 | h, w = next(iter(visuals.values())).shape[:2]
35 | table_css = """""" % (w, h)
39 | ncols = self.display_single_pane_ncols
40 | title = self.name
41 | label_html = ''
42 | label_html_row = ''
43 | nrows = int(np.ceil(len(visuals.items()) / ncols))
44 | images = []
45 | idx = 0
46 | for label, image_numpy in visuals.items():
47 |                     label_html_row += '<td>%s</td>' % label
48 | images.append(image_numpy.transpose([2, 0, 1]))
49 | idx += 1
50 | if idx % ncols == 0:
51 |                         label_html += '<tr>%s</tr>' % label_html_row
52 |                         label_html_row = ''
53 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
54 | while idx % ncols != 0:
55 | images.append(white_image)
56 |                     label_html_row += '<td></td>'
57 | idx += 1
58 | if label_html_row != '':
59 |                     label_html += '<tr>%s</tr>' % label_html_row
60 | # pane col = image row
61 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
62 | padding=2, opts=dict(title=title + ' images'))
63 |                 label_html = '<table>%s</table>' % label_html
64 | self.vis.text(table_css + label_html, win = self.display_id + 2,
65 | opts=dict(title=title + ' labels'))
66 | else:
67 | idx = 1
68 | for label, image_numpy in visuals.items():
69 | #image_numpy = np.flipud(image_numpy)
70 | self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
71 | win=self.display_id + idx)
72 | idx += 1
73 |
74 | if self.use_html: # save images to a html file
75 | for label, image_numpy in visuals.items():
76 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
77 | util.save_image(image_numpy, img_path)
78 | # update website
79 |             webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
80 | for n in range(epoch, 0, -1):
81 | webpage.add_header('Results of Epoch [%d]' % n)
82 | ims = []
83 | txts = []
84 | links = []
85 |
86 | for label, image_numpy in visuals.items():
87 | img_path = 'epoch%.3d_%s.png' % (n, label)
88 | ims.append(img_path)
89 | txts.append(label)
90 | links.append(img_path)
91 | webpage.add_images(ims, txts, links, width=self.win_size)
92 | webpage.save()
93 |
94 | # errors: dictionary of error labels and values
95 | def plot_current_errors(self, epoch, counter_ratio, opt, errors):
96 | if not hasattr(self, 'plot_data'):
97 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
98 | self.plot_data['X'].append(epoch + counter_ratio)
99 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
100 | self.vis.line(
101 | X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
102 | Y=np.array(self.plot_data['Y']),
103 | opts={
104 | 'title': self.name + ' loss over time',
105 | 'legend': self.plot_data['legend'],
106 | 'xlabel': 'epoch',
107 | 'ylabel': 'loss'},
108 | win=self.display_id)
109 |
110 | # errors: same format as |errors| of plotCurrentErrors
111 | def print_current_errors(self, epoch, i, errors, t):
112 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
113 | for k, v in errors.items():
114 | message += '%s: %.3f ' % (k, v)
115 |
116 | print(message)
117 | with open(self.log_name, "a") as log_file:
118 | log_file.write('%s\n' % message)
119 |
120 | # save image to the disk
121 | def save_images(self, webpage, visuals, image_path):
122 | image_dir = webpage.get_image_dir()
123 | short_path = ntpath.basename(image_path[0])
124 | name = os.path.splitext(short_path)[0]
125 |
126 | webpage.add_header(name)
127 | ims = []
128 | txts = []
129 | links = []
130 |
131 | for label, image_numpy in visuals.items():
132 | image_name = '%s_%s.png' % (name, label)
133 | save_path = os.path.join(image_dir, image_name)
134 | util.save_image(image_numpy, save_path)
135 |
136 | ims.append(image_name)
137 | txts.append(label)
138 | links.append(image_name)
139 | webpage.add_images(ims, txts, links, width=self.win_size)
140 |
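141 |
142 | # Usage note (illustrative): save_images is driven by test.py, which passes
143 | # an html.HTML page, the visuals dict from model.get_current_visuals(), and
144 | # the image paths from model.get_image_paths().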
--------------------------------------------------------------------------------