├── .gitignore ├── LICENSE ├── README.md ├── checkpoints └── experiment_name │ ├── latest_net_D.pth │ ├── latest_net_G.pth │ ├── opt.txt │ └── web │ └── index.html ├── data ├── __init__.py ├── aligned_dataset.py ├── base_data_loader.py ├── base_dataset.py ├── custom_dataset_data_loader.py ├── data_loader.py ├── image_folder.py ├── single_dataset.py └── unaligned_dataset.py ├── datasets ├── combine_A_and_B.py └── helper functions │ └── grayscale.py ├── images ├── animation1.gif ├── animation2.gif ├── animation3.gif ├── animation4.gif ├── results.png ├── test1_blur.jpg ├── test1_restored.jpg ├── test1_sharp.jpg ├── yolo_b.jpg ├── yolo_o.jpg └── yolo_s.jpg ├── models ├── __init__.py ├── base_model.py ├── conditional_gan_model.py ├── losses.py ├── models.py ├── networks.py └── test_model.py ├── motion_blur ├── __init__.py ├── blur_image.py ├── generate_PSF.py └── generate_trajectory.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── test.py ├── train.py └── util ├── __init__.py ├── get_data.py ├── html.py ├── image_pool.py ├── metrics.py ├── png.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.iml 2 | *.xml 3 | *.pyc 4 | *.png 5 | *.txt 6 | *.pth 7 | *.html 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 
9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | 26 | --------------------------- LICENSE FOR pix2pix -------------------------------- 27 | BSD License 28 | 29 | For pix2pix software 30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 
49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeblurGAN 2 | [arXiv Paper Version](https://arxiv.org/pdf/1711.07064.pdf) 3 | 4 | Pytorch implementation of the paper DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks. 
5 | 6 | Our network takes a blurry image as input and produces the corresponding sharp estimate, as in the example: 7 | 8 | 9 | 10 | The model we use is a Conditional Wasserstein GAN with Gradient Penalty + Perceptual loss based on VGG-19 activations. Such an architecture also gives good results on other image-to-image translation problems (super resolution, colorization, inpainting, dehazing, etc.). 11 | 12 | ## How to run 13 | 14 | ### Prerequisites 15 | - NVIDIA GPU + CUDA CuDNN (CPU untested, feedback appreciated) 16 | - Pytorch 17 | 18 | Download the weights from [Google Drive](https://drive.google.com/file/d/1liKzdjMRHZ-i5MWhC72EL7UZLNPj5_8Y/view?usp=sharing). Note that for inference you only need to keep the Generator weights. 19 | 20 | Put the weights into 21 | ```bash 22 | ./checkpoints/experiment_name 23 | ``` 24 | To test a model, put your blurry images into a folder and run: 25 | ```bash 26 | python test.py --dataroot /.path_to_your_data --model test --dataset_mode single --learn_residual 27 | ``` 28 | ## Data 29 | Download the dataset for the Object Detection benchmark from [Google Drive](https://drive.google.com/file/d/1CPMBmRj-jBDO2ax4CxkBs9iczIFrs8VA/view?usp=sharing) 30 | 31 | ## Train 32 | 33 | If you want to train the model on your own data, run the following command to create image pairs: 34 | ```bash 35 | python datasets/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data 36 | ``` 37 | Then run the following command to train the model: 38 | 39 | ```bash 40 | python train.py --dataroot /.path_to_your_data --learn_residual --resize_or_crop crop --fineSize CROP_SIZE (we used 256) 41 | ``` 42 | 43 | ## Other Implementations 44 | 45 | [Keras Blog](https://blog.sicara.com/keras-generative-adversarial-networks-image-deblurring-45e3ab6977b5) 46 | 47 | [Keras Repository](https://github.com/RaphaelMeudec/deblur-gan) 48 | 49 | 50 | 51 | ## Citation 52 | 53 | If you find our code helpful in your research or work, please cite our paper. 
54 | 55 | ``` 56 | @article{DeblurGAN, 57 | title = {DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks}, 58 | author = {Kupyn, Orest and Budzan, Volodymyr and Mykhailych, Mykola and Mishkin, Dmytro and Matas, Jiri}, 59 | journal = {ArXiv e-prints}, 60 | eprint = {1711.07064}, 61 | year = 2017 62 | } 63 | ``` 64 | 65 | ## Acknowledgments 66 | Code borrows heavily from [pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). The images were taken from GoPRO test dataset - [DeepDeblur](https://github.com/SeungjunNah/DeepDeblur_release) 67 | 68 | 69 | -------------------------------------------------------------------------------- /checkpoints/experiment_name/latest_net_D.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/checkpoints/experiment_name/latest_net_D.pth -------------------------------------------------------------------------------- /checkpoints/experiment_name/latest_net_G.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/checkpoints/experiment_name/latest_net_G.pth -------------------------------------------------------------------------------- /checkpoints/experiment_name/opt.txt: -------------------------------------------------------------------------------- 1 | ------------ Options ------------- 2 | batchSize: 1 3 | beta1: 0.5 4 | checkpoints_dir: ./checkpoints 5 | continue_train: False 6 | dataroot: D:\Photos\TrainingData\BlurredSharp\combined 7 | dataset_mode: aligned 8 | display_freq: 100 9 | display_id: 1 10 | display_port: 8097 11 | display_single_pane_ncols: 0 12 | display_winsize: 256 13 | epoch_count: 1 14 | fineSize: 256 15 | gan_type: wgan-gp 16 | gpu_ids: [0] 17 | identity: 0.0 18 | input_nc: 3 19 | isTrain: True 20 | lambda_A: 100.0 21 | 
lambda_B: 10.0 22 | learn_residual: False 23 | loadSizeX: 640 24 | loadSizeY: 360 25 | lr: 0.0001 26 | max_dataset_size: inf 27 | model: content_gan 28 | nThreads: 2 29 | n_layers_D: 3 30 | name: experiment_name 31 | ndf: 64 32 | ngf: 64 33 | niter: 150 34 | niter_decay: 150 35 | no_dropout: False 36 | no_flip: False 37 | no_html: False 38 | norm: instance 39 | output_nc: 3 40 | phase: train 41 | pool_size: 50 42 | print_freq: 100 43 | resize_or_crop: resize_and_crop 44 | save_epoch_freq: 5 45 | save_latest_freq: 5000 46 | serial_batches: False 47 | which_direction: AtoB 48 | which_epoch: latest 49 | which_model_netD: basic 50 | which_model_netG: resnet_9blocks 51 | -------------- End ---------------- 52 | -------------------------------------------------------------------------------- /checkpoints/experiment_name/web/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Experiment name = experiment_name 5 | 6 | 7 | 8 |

Results of Epoch [33]

9 | 10 | 11 | 19 | 27 | 35 | 36 |
12 |

13 | 14 | 15 |
16 |

Blurred_Train

17 |

18 |
20 |

21 | 22 | 23 |
24 |

Restored_Train

25 |

26 |
28 |

29 | 30 | 31 |
32 |

Sharp_Train

33 |

34 |
37 |

Results of Epoch [32]

38 | 39 | 40 | 48 | 56 | 64 | 65 |
41 |

42 | 43 | 44 |
45 |

Blurred_Train

46 |

47 |
49 |

50 | 51 | 52 |
53 |

Restored_Train

54 |

55 |
57 |

58 | 59 | 60 |
61 |

Sharp_Train

62 |

63 |
66 |

Results of Epoch [31]

67 | 68 | 69 | 77 | 85 | 93 | 94 |
70 |

71 | 72 | 73 |
74 |

Blurred_Train

75 |

76 |
78 |

79 | 80 | 81 |
82 |

Restored_Train

83 |

84 |
86 |

87 | 88 | 89 |
90 |

Sharp_Train

91 |

92 |
95 |

Results of Epoch [30]

96 | 97 | 98 | 106 | 114 | 122 | 123 |
99 |

100 | 101 | 102 |
103 |

Blurred_Train

104 |

105 |
107 |

108 | 109 | 110 |
111 |

Restored_Train

112 |

113 |
115 |

116 | 117 | 118 |
119 |

Sharp_Train

120 |

121 |
124 |

Results of Epoch [29]

125 | 126 | 127 | 135 | 143 | 151 | 152 |
128 |

129 | 130 | 131 |
132 |

Blurred_Train

133 |

134 |
136 |

137 | 138 | 139 |
140 |

Restored_Train

141 |

142 |
144 |

145 | 146 | 147 |
148 |

Sharp_Train

149 |

150 |
153 |

Results of Epoch [28]

154 | 155 | 156 | 164 | 172 | 180 | 181 |
157 |

158 | 159 | 160 |
161 |

Blurred_Train

162 |

163 |
165 |

166 | 167 | 168 |
169 |

Restored_Train

170 |

171 |
173 |

174 | 175 | 176 |
177 |

Sharp_Train

178 |

179 |
182 |

Results of Epoch [27]

183 | 184 | 185 | 193 | 201 | 209 | 210 |
186 |

187 | 188 | 189 |
190 |

Blurred_Train

191 |

192 |
194 |

195 | 196 | 197 |
198 |

Restored_Train

199 |

200 |
202 |

203 | 204 | 205 |
206 |

Sharp_Train

207 |

208 |
211 |

Results of Epoch [26]

212 | 213 | 214 | 222 | 230 | 238 | 239 |
215 |

216 | 217 | 218 |
219 |

Blurred_Train

220 |

221 |
223 |

224 | 225 | 226 |
227 |

Restored_Train

228 |

229 |
231 |

232 | 233 | 234 |
235 |

Sharp_Train

236 |

237 |
240 |

Results of Epoch [25]

241 | 242 | 243 | 251 | 259 | 267 | 268 |
244 |

245 | 246 | 247 |
248 |

Blurred_Train

249 |

250 |
252 |

253 | 254 | 255 |
256 |

Restored_Train

257 |

258 |
260 |

261 | 262 | 263 |
264 |

Sharp_Train

265 |

266 |
269 |

Results of Epoch [24]

270 | 271 | 272 | 280 | 288 | 296 | 297 |
273 |

274 | 275 | 276 |
277 |

Blurred_Train

278 |

279 |
281 |

282 | 283 | 284 |
285 |

Restored_Train

286 |

287 |
289 |

290 | 291 | 292 |
293 |

Sharp_Train

294 |

295 |
298 |

Results of Epoch [23]

299 | 300 | 301 | 309 | 317 | 325 | 326 |
302 |

303 | 304 | 305 |
306 |

Blurred_Train

307 |

308 |
310 |

311 | 312 | 313 |
314 |

Restored_Train

315 |

316 |
318 |

319 | 320 | 321 |
322 |

Sharp_Train

323 |

324 |
327 |

Results of Epoch [22]

328 | 329 | 330 | 338 | 346 | 354 | 355 |
331 |

332 | 333 | 334 |
335 |

Blurred_Train

336 |

337 |
339 |

340 | 341 | 342 |
343 |

Restored_Train

344 |

345 |
347 |

348 | 349 | 350 |
351 |

Sharp_Train

352 |

353 |
356 |

Results of Epoch [21]

357 | 358 | 359 | 367 | 375 | 383 | 384 |
360 |

361 | 362 | 363 |
364 |

Blurred_Train

365 |

366 |
368 |

369 | 370 | 371 |
372 |

Restored_Train

373 |

374 |
376 |

377 | 378 | 379 |
380 |

Sharp_Train

381 |

382 |
385 |

Results of Epoch [20]

386 | 387 | 388 | 396 | 404 | 412 | 413 |
389 |

390 | 391 | 392 |
393 |

Blurred_Train

394 |

395 |
397 |

398 | 399 | 400 |
401 |

Restored_Train

402 |

403 |
405 |

406 | 407 | 408 |
409 |

Sharp_Train

410 |

411 |
414 |

Results of Epoch [19]

415 | 416 | 417 | 425 | 433 | 441 | 442 |
418 |

419 | 420 | 421 |
422 |

Blurred_Train

423 |

424 |
426 |

427 | 428 | 429 |
430 |

Restored_Train

431 |

432 |
434 |

435 | 436 | 437 |
438 |

Sharp_Train

439 |

440 |
443 |

Results of Epoch [18]

444 | 445 | 446 | 454 | 462 | 470 | 471 |
447 |

448 | 449 | 450 |
451 |

Blurred_Train

452 |

453 |
455 |

456 | 457 | 458 |
459 |

Restored_Train

460 |

461 |
463 |

464 | 465 | 466 |
467 |

Sharp_Train

468 |

469 |
472 |

Results of Epoch [17]

473 | 474 | 475 | 483 | 491 | 499 | 500 |
476 |

477 | 478 | 479 |
480 |

Blurred_Train

481 |

482 |
484 |

485 | 486 | 487 |
488 |

Restored_Train

489 |

490 |
492 |

493 | 494 | 495 |
496 |

Sharp_Train

497 |

498 |
501 |

Results of Epoch [16]

502 | 503 | 504 | 512 | 520 | 528 | 529 |
505 |

506 | 507 | 508 |
509 |

Blurred_Train

510 |

511 |
513 |

514 | 515 | 516 |
517 |

Restored_Train

518 |

519 |
521 |

522 | 523 | 524 |
525 |

Sharp_Train

526 |

527 |
530 |

Results of Epoch [15]

531 | 532 | 533 | 541 | 549 | 557 | 558 |
534 |

535 | 536 | 537 |
538 |

Blurred_Train

539 |

540 |
542 |

543 | 544 | 545 |
546 |

Restored_Train

547 |

548 |
550 |

551 | 552 | 553 |
554 |

Sharp_Train

555 |

556 |
559 |

Results of Epoch [14]

560 | 561 | 562 | 570 | 578 | 586 | 587 |
563 |

564 | 565 | 566 |
567 |

Blurred_Train

568 |

569 |
571 |

572 | 573 | 574 |
575 |

Restored_Train

576 |

577 |
579 |

580 | 581 | 582 |
583 |

Sharp_Train

584 |

585 |
588 |

Results of Epoch [13]

589 | 590 | 591 | 599 | 607 | 615 | 616 |
592 |

593 | 594 | 595 |
596 |

Blurred_Train

597 |

598 |
600 |

601 | 602 | 603 |
604 |

Restored_Train

605 |

606 |
608 |

609 | 610 | 611 |
612 |

Sharp_Train

613 |

614 |
617 |

Results of Epoch [12]

618 | 619 | 620 | 628 | 636 | 644 | 645 |
621 |

622 | 623 | 624 |
625 |

Blurred_Train

626 |

627 |
629 |

630 | 631 | 632 |
633 |

Restored_Train

634 |

635 |
637 |

638 | 639 | 640 |
641 |

Sharp_Train

642 |

643 |
646 |

Results of Epoch [11]

647 | 648 | 649 | 657 | 665 | 673 | 674 |
650 |

651 | 652 | 653 |
654 |

Blurred_Train

655 |

656 |
658 |

659 | 660 | 661 |
662 |

Restored_Train

663 |

664 |
666 |

667 | 668 | 669 |
670 |

Sharp_Train

671 |

672 |
675 |

Results of Epoch [10]

676 | 677 | 678 | 686 | 694 | 702 | 703 |
679 |

680 | 681 | 682 |
683 |

Blurred_Train

684 |

685 |
687 |

688 | 689 | 690 |
691 |

Restored_Train

692 |

693 |
695 |

696 | 697 | 698 |
699 |

Sharp_Train

700 |

701 |
704 |

Results of Epoch [9]

705 | 706 | 707 | 715 | 723 | 731 | 732 |
708 |

709 | 710 | 711 |
712 |

Blurred_Train

713 |

714 |
716 |

717 | 718 | 719 |
720 |

Restored_Train

721 |

722 |
724 |

725 | 726 | 727 |
728 |

Sharp_Train

729 |

730 |
733 |

Results of Epoch [8]

734 | 735 | 736 | 744 | 752 | 760 | 761 |
737 |

738 | 739 | 740 |
741 |

Blurred_Train

742 |

743 |
745 |

746 | 747 | 748 |
749 |

Restored_Train

750 |

751 |
753 |

754 | 755 | 756 |
757 |

Sharp_Train

758 |

759 |
762 |

Results of Epoch [7]

763 | 764 | 765 | 773 | 781 | 789 | 790 |
766 |

767 | 768 | 769 |
770 |

Blurred_Train

771 |

772 |
774 |

775 | 776 | 777 |
778 |

Restored_Train

779 |

780 |
782 |

783 | 784 | 785 |
786 |

Sharp_Train

787 |

788 |
791 |

Results of Epoch [6]

792 | 793 | 794 | 802 | 810 | 818 | 819 |
795 |

796 | 797 | 798 |
799 |

Blurred_Train

800 |

801 |
803 |

804 | 805 | 806 |
807 |

Restored_Train

808 |

809 |
811 |

812 | 813 | 814 |
815 |

Sharp_Train

816 |

817 |
820 |

Results of Epoch [5]

821 | 822 | 823 | 831 | 839 | 847 | 848 |
824 |

825 | 826 | 827 |
828 |

Blurred_Train

829 |

830 |
832 |

833 | 834 | 835 |
836 |

Restored_Train

837 |

838 |
840 |

841 | 842 | 843 |
844 |

Sharp_Train

845 |

846 |
849 |

Results of Epoch [4]

850 | 851 | 852 | 860 | 868 | 876 | 877 |
853 |

854 | 855 | 856 |
857 |

Blurred_Train

858 |

859 |
861 |

862 | 863 | 864 |
865 |

Restored_Train

866 |

867 |
869 |

870 | 871 | 872 |
873 |

Sharp_Train

874 |

875 |
878 |

Results of Epoch [3]

879 | 880 | 881 | 889 | 897 | 905 | 906 |
882 |

883 | 884 | 885 |
886 |

Blurred_Train

887 |

888 |
890 |

891 | 892 | 893 |
894 |

Restored_Train

895 |

896 |
898 |

899 | 900 | 901 |
902 |

Sharp_Train

903 |

904 |
907 |

Results of Epoch [2]

908 | 909 | 910 | 918 | 926 | 934 | 935 |
911 |

912 | 913 | 914 |
915 |

Blurred_Train

916 |

917 |
919 |

920 | 921 | 922 |
923 |

Restored_Train

924 |

925 |
927 |

928 | 929 | 930 |
931 |

Sharp_Train

932 |

933 |
936 |

Results of Epoch [1]

937 | 938 | 939 | 947 | 955 | 963 | 964 |
940 |

941 | 942 | 943 |
944 |

Blurred_Train

945 |

946 |
948 |

949 | 950 | 951 |
952 |

Restored_Train

953 |

954 |
956 |

957 | 958 | 959 |
960 |

Sharp_Train

961 |

962 |
965 | 966 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/data/__init__.py -------------------------------------------------------------------------------- /data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | import torchvision.transforms as transforms 4 | import torch 5 | from data.base_dataset import BaseDataset 6 | from data.image_folder import make_dataset 7 | from PIL import Image 8 | 9 | 10 | class AlignedDataset(BaseDataset): 11 | def __init__(self, opt): 12 | # super(AlignedDataset,self).__init__(opt) 13 | self.opt = opt 14 | self.root = opt.dataroot 15 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) 16 | 17 | self.AB_paths = sorted(make_dataset(self.dir_AB)) 18 | 19 | #assert(opt.resize_or_crop == 'resize_and_crop') 20 | 21 | transform_list = [transforms.ToTensor(), 22 | transforms.Normalize((0.5, 0.5, 0.5), 23 | (0.5, 0.5, 0.5))] 24 | 25 | self.transform = transforms.Compose(transform_list) 26 | 27 | def __getitem__(self, index): 28 | AB_path = self.AB_paths[index] 29 | AB = Image.open(AB_path).convert('RGB') 30 | AB = AB.resize((self.opt.loadSizeX * 2, self.opt.loadSizeY), Image.BICUBIC) 31 | AB = self.transform(AB) 32 | 33 | w_total = AB.size(2) 34 | w = int(w_total / 2) 35 | h = AB.size(1) 36 | w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1)) 37 | h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1)) 38 | 39 | A = AB[:, h_offset:h_offset + self.opt.fineSize, 40 | w_offset:w_offset + self.opt.fineSize] 41 | B = AB[:, h_offset:h_offset + self.opt.fineSize, 42 | w + w_offset:w + w_offset + self.opt.fineSize] 43 | 44 | if (not self.opt.no_flip) and random.random() < 0.5: 45 | idx = [i for i in 
range(A.size(2) - 1, -1, -1)] 46 | idx = torch.LongTensor(idx) 47 | A = A.index_select(2, idx) 48 | B = B.index_select(2, idx) 49 | 50 | return {'A': A, 'B': B, 51 | 'A_paths': AB_path, 'B_paths': AB_path} 52 | 53 | def __len__(self): 54 | return len(self.AB_paths) 55 | 56 | def name(self): 57 | return 'AlignedDataset' 58 | -------------------------------------------------------------------------------- /data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseDataLoader(): 3 | def __init__(self): 4 | pass 5 | 6 | def initialize(self, opt): 7 | self.opt = opt 8 | pass 9 | 10 | def load_data(): 11 | return None 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | class BaseDataset(data.Dataset): 6 | def __init__(self): 7 | super(BaseDataset, self).__init__() 8 | 9 | def name(self): 10 | return 'BaseDataset' 11 | 12 | # def initialize(self, opt): 13 | # pass 14 | 15 | def get_transform(opt): 16 | transform_list = [] 17 | if opt.resize_or_crop == 'resize_and_crop': 18 | osize = [opt.loadSizeX, opt.loadSizeY] 19 | transform_list.append(transforms.Resize(osize, Image.BICUBIC)) 20 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 21 | elif opt.resize_or_crop == 'crop': 22 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 23 | elif opt.resize_or_crop == 'scale_width': 24 | transform_list.append(transforms.Lambda( 25 | lambda img: __scale_width(img, opt.fineSize))) 26 | elif opt.resize_or_crop == 'scale_width_and_crop': 27 | transform_list.append(transforms.Lambda( 28 | lambda img: __scale_width(img, opt.loadSizeX))) 29 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 30 | 31 | if opt.isTrain and not 
opt.no_flip: 32 | transform_list.append(transforms.RandomHorizontalFlip()) 33 | 34 | transform_list += [transforms.ToTensor(), 35 | transforms.Normalize((0.5, 0.5, 0.5), 36 | (0.5, 0.5, 0.5))] 37 | return transforms.Compose(transform_list) 38 | 39 | def __scale_width(img, target_width): 40 | ow, oh = img.size 41 | if (ow == target_width): 42 | return img 43 | w = target_width 44 | h = int(target_width * oh / ow) 45 | return img.resize((w, h), Image.BICUBIC) 46 | -------------------------------------------------------------------------------- /data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_data_loader import BaseDataLoader 3 | 4 | 5 | def CreateDataset(opt): 6 | dataset = None 7 | if opt.dataset_mode == 'aligned': 8 | from data.aligned_dataset import AlignedDataset 9 | dataset = AlignedDataset(opt) 10 | elif opt.dataset_mode == 'unaligned': 11 | from data.unaligned_dataset import UnalignedDataset 12 | dataset = UnalignedDataset() 13 | elif opt.dataset_mode == 'single': 14 | from data.single_dataset import SingleDataset 15 | dataset = SingleDataset() 16 | dataset.initialize(opt) 17 | else: 18 | raise ValueError("Dataset [%s] not recognized." 
% opt.dataset_mode) 19 | 20 | print("dataset [%s] was created" % (dataset.name())) 21 | # dataset.initialize(opt) 22 | return dataset 23 | 24 | 25 | class CustomDatasetDataLoader(BaseDataLoader): 26 | def name(self): 27 | return 'CustomDatasetDataLoader' 28 | 29 | def __init__(self, opt): 30 | super(CustomDatasetDataLoader,self).initialize(opt) 31 | print("Opt.nThreads = ", opt.nThreads) 32 | self.dataset = CreateDataset(opt) 33 | self.dataloader = torch.utils.data.DataLoader( 34 | self.dataset, 35 | batch_size=opt.batchSize, 36 | shuffle=not opt.serial_batches, 37 | num_workers=int(opt.nThreads) 38 | ) 39 | 40 | def load_data(self): 41 | return self.dataloader 42 | 43 | def __len__(self): 44 | return min(len(self.dataset), self.opt.max_dataset_size) 45 | -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | def CreateDataLoader(opt): 3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 4 | data_loader = CustomDatasetDataLoader(opt) 5 | print(data_loader.name()) 6 | # data_loader.initialize(opt) 7 | return data_loader 8 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 
17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir): 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | 34 | return images 35 | 36 | 37 | def default_loader(path): 38 | return Image.open(path).convert('RGB') 39 | 40 | 41 | class ImageFolder(data.Dataset): 42 | 43 | def __init__(self, root, transform=None, return_paths=False, 44 | loader=default_loader): 45 | imgs = make_dataset(root) 46 | if len(imgs) == 0: 47 | raise(RuntimeError("Found 0 images in: " + root + "\n" 48 | "Supported image extensions are: " + 49 | ",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | -------------------------------------------------------------------------------- /data/single_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torchvision.transforms as transforms 3 | from data.base_dataset import BaseDataset, get_transform 4 | from data.image_folder import make_dataset 5 | from PIL import Image 6 | 7 | 8 | class SingleDataset(BaseDataset): 9 | def initialize(self, opt): 10 | self.opt = opt 11 | self.root = opt.dataroot 12 | self.dir_A = os.path.join(opt.dataroot) 13 | 14 | self.A_paths = make_dataset(self.dir_A) 15 | 16 | self.A_paths = sorted(self.A_paths) 17 | 18 | self.transform 
= get_transform(opt) 19 | 20 | def __getitem__(self, index): 21 | A_path = self.A_paths[index] 22 | 23 | A_img = Image.open(A_path).convert('RGB') 24 | 25 | A_img = self.transform(A_img) 26 | 27 | return {'A': A_img, 'A_paths': A_path} 28 | 29 | def __len__(self): 30 | return len(self.A_paths) 31 | 32 | def name(self): 33 | return 'SingleImageDataset' 34 | -------------------------------------------------------------------------------- /data/unaligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torchvision.transforms as transforms 3 | from data.base_dataset import BaseDataset, get_transform 4 | from data.image_folder import make_dataset 5 | from PIL import Image 6 | import PIL 7 | from pdb import set_trace as st 8 | import random 9 | import cv2 10 | 11 | class UnalignedDataset(BaseDataset): 12 | def initialize(self, opt): 13 | self.opt = opt 14 | self.root = opt.dataroot 15 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') 16 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') 17 | 18 | self.A_paths = make_dataset(self.dir_A) 19 | self.B_paths = make_dataset(self.dir_B) 20 | 21 | self.A_paths = sorted(self.A_paths) 22 | self.B_paths = sorted(self.B_paths) 23 | self.A_size = len(self.A_paths) 24 | self.B_size = len(self.B_paths) 25 | self.transform = get_transform(opt) 26 | 27 | def __getitem__(self, index): 28 | A_path = self.A_paths[index % self.A_size] 29 | index_A = index % self.A_size 30 | B_path = self.B_paths[index % self.A_size] 31 | # print('(A, B) = (%d, %d)' % (index_A, index_B)) 32 | A_img = Image.open(A_path).convert('L') 33 | B_img = Image.open(B_path).convert('RGB') 34 | 35 | A_img = self.transform(A_img) 36 | B_img = self.transform(B_img) 37 | 38 | return {'A': A_img, 'B': B_img, 39 | 'A_paths': A_path, 'B_paths': B_path} 40 | 41 | def __len__(self): 42 | return max(self.A_size, self.B_size) 43 | 44 | def name(self): 45 | return 'UnalignedDataset' 46 | 
-------------------------------------------------------------------------------- /datasets/combine_A_and_B.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as st 2 | import os 3 | import numpy as np 4 | import cv2 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser('create image pairs') 8 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges') 9 | parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg') 10 | parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB') 11 | parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000) 12 | parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true') 13 | args = parser.parse_args() 14 | 15 | for arg in vars(args): 16 | print('[%s] = ' % arg, getattr(args, arg)) 17 | 18 | splits = os.listdir(args.fold_A) 19 | 20 | for sp in splits: 21 | img_fold_A = os.path.join(args.fold_A, sp) 22 | img_fold_B = os.path.join(args.fold_B, sp) 23 | img_list = os.listdir(img_fold_A) 24 | if args.use_AB: 25 | img_list = [img_path for img_path in img_list if '_A.' 
in img_path] 26 | 27 | num_imgs = min(args.num_imgs, len(img_list)) 28 | print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list))) 29 | img_fold_AB = os.path.join(args.fold_AB, sp) 30 | if not os.path.isdir(img_fold_AB): 31 | os.makedirs(img_fold_AB) 32 | print('split = %s, number of images = %d' % (sp, num_imgs)) 33 | for n in range(num_imgs): 34 | name_A = img_list[n] 35 | path_A = os.path.join(img_fold_A, name_A) 36 | if args.use_AB: 37 | name_B = name_A.replace('_A.', '_B.') 38 | else: 39 | name_B = name_A 40 | path_B = os.path.join(img_fold_B, name_B) 41 | if os.path.isfile(path_A) and os.path.isfile(path_B): 42 | name_AB = name_A 43 | if args.use_AB: 44 | name_AB = name_AB.replace('_A.', '.') # remove _A 45 | path_AB = os.path.join(img_fold_AB, name_AB) 46 | im_A = cv2.imread(path_A, cv2.IMREAD_COLOR) 47 | im_B = cv2.imread(path_B, cv2.IMREAD_COLOR) 48 | im_AB = np.concatenate([im_A, im_B], 1) 49 | cv2.imwrite(path_AB, im_AB) 50 | -------------------------------------------------------------------------------- /datasets/helper functions/grayscale.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as st 2 | import os 3 | import numpy as np 4 | import cv2 5 | import argparse 6 | 7 | # Helper script to create dataset for image colorization 8 | 9 | parser = argparse.ArgumentParser('create image pairs') 10 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for images', type=str, default='../dataset/50kshoes_edges') 11 | parser.add_argument('--fold_B', dest='fold_B', help='output directory', type=str, default='../dataset/test_B') 12 | parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000) 13 | args = parser.parse_args() 14 | 15 | for arg in vars(args): 16 | print('[%s] = ' % arg, getattr(args, arg)) 17 | 18 | splits = os.listdir(args.fold_A) 19 | 20 | for sp in splits: 21 | img_fold_A = os.path.join(args.fold_A, sp) 22 | 
img_list = os.listdir(img_fold_A) 23 | 24 | num_imgs = min(args.num_imgs, len(img_list)) 25 | print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list))) 26 | img_fold_B = os.path.join(args.fold_B, sp) 27 | if not os.path.isdir(img_fold_B): 28 | os.makedirs(img_fold_B) 29 | print('split = %s, number of images = %d' % (sp, num_imgs)) 30 | for n in range(num_imgs): 31 | name_A = img_list[n] 32 | path_A = os.path.join(img_fold_A, name_A) 33 | 34 | if os.path.isfile(path_A): 35 | name_B = name_A 36 | 37 | path_B = os.path.join(img_fold_B, name_B) 38 | im_A = cv2.imread(path_A, 0) 39 | cv2.imwrite(path_B, im_A) 40 | -------------------------------------------------------------------------------- /images/animation1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/animation1.gif -------------------------------------------------------------------------------- /images/animation2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/animation2.gif -------------------------------------------------------------------------------- /images/animation3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/animation3.gif -------------------------------------------------------------------------------- /images/animation4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/animation4.gif -------------------------------------------------------------------------------- /images/results.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/results.png -------------------------------------------------------------------------------- /images/test1_blur.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/test1_blur.jpg -------------------------------------------------------------------------------- /images/test1_restored.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/test1_restored.jpg -------------------------------------------------------------------------------- /images/test1_sharp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/test1_sharp.jpg -------------------------------------------------------------------------------- /images/yolo_b.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/yolo_b.jpg -------------------------------------------------------------------------------- /images/yolo_o.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KupynOrest/DeblurGAN/2499a87f96f34f262ea294c7e4cd3fd7e90251f8/images/yolo_o.jpg -------------------------------------------------------------------------------- /images/yolo_s.jpg: -------------------------------------------------------------------------------- 
import os
import torch


class BaseModel():
    """Common base for the models in this package.

    Stores the option object, picks the tensor constructor matching the
    requested devices and provides checkpoint save/load helpers.  All
    training/inference hooks are no-ops meant to be overridden.

    The constructor reads ``opt.gpu_ids``, ``opt.isTrain``,
    ``opt.checkpoints_dir`` and ``opt.name``.
    """

    def name(self):
        return 'BaseModel'

    def __init__(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # CUDA float tensors when any GPU id is configured, CPU tensors otherwise.
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        # Checkpoints of this experiment live in <checkpoints_dir>/<name>.
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        """Save ``network``'s weights to ``<epoch_label>_net_<network_label>.pth``.

        The network is moved to the CPU before its state dict is written, so
        the checkpoint holds CPU tensors; it is moved back to ``gpu_ids[0]``
        afterwards when GPUs are configured and CUDA is available.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda(device=gpu_ids[0])

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        """Load ``<epoch_label>_net_<network_label>.pth`` into ``network``."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        # Takes ``self`` like every other hook; without it, calling the hook
        # on an instance raised TypeError.
        pass
'D', opt.which_epoch) 46 | 47 | if self.isTrain: 48 | self.fake_AB_pool = ImagePool(opt.pool_size) 49 | self.old_lr = opt.lr 50 | 51 | # initialize optimizers 52 | self.optimizer_G = torch.optim.Adam( self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999) ) 53 | self.optimizer_D = torch.optim.Adam( self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999) ) 54 | 55 | self.criticUpdates = 5 if opt.gan_type == 'wgan-gp' else 1 56 | 57 | # define loss functions 58 | self.discLoss, self.contentLoss = init_loss(opt, self.Tensor) 59 | 60 | print('---------- Networks initialized -------------') 61 | networks.print_network(self.netG) 62 | if self.isTrain: 63 | networks.print_network(self.netD) 64 | print('-----------------------------------------------') 65 | 66 | def set_input(self, input): 67 | AtoB = self.opt.which_direction == 'AtoB' 68 | inputA = input['A' if AtoB else 'B'] 69 | inputB = input['B' if AtoB else 'A'] 70 | self.input_A.resize_(inputA.size()).copy_(inputA) 71 | self.input_B.resize_(inputB.size()).copy_(inputB) 72 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] 73 | 74 | def forward(self): 75 | self.real_A = Variable(self.input_A) 76 | self.fake_B = self.netG.forward(self.real_A) 77 | self.real_B = Variable(self.input_B) 78 | 79 | # no backprop gradients 80 | def test(self): 81 | self.real_A = Variable(self.input_A, volatile=True) 82 | self.fake_B = self.netG.forward(self.real_A) 83 | self.real_B = Variable(self.input_B, volatile=True) 84 | 85 | # get image paths 86 | def get_image_paths(self): 87 | return self.image_paths 88 | 89 | def backward_D(self): 90 | self.loss_D = self.discLoss.get_loss(self.netD, self.real_A, self.fake_B, self.real_B) 91 | 92 | self.loss_D.backward(retain_graph=True) 93 | 94 | def backward_G(self): 95 | self.loss_G_GAN = self.discLoss.get_g_loss(self.netD, self.real_A, self.fake_B) 96 | # Second, G(A) = B 97 | self.loss_G_Content = self.contentLoss.get_loss(self.fake_B, self.real_B) * self.opt.lambda_A 98 | 99 
| self.loss_G = self.loss_G_GAN + self.loss_G_Content 100 | 101 | self.loss_G.backward() 102 | 103 | def optimize_parameters(self): 104 | self.forward() 105 | 106 | for iter_d in xrange(self.criticUpdates): 107 | self.optimizer_D.zero_grad() 108 | self.backward_D() 109 | self.optimizer_D.step() 110 | 111 | self.optimizer_G.zero_grad() 112 | self.backward_G() 113 | self.optimizer_G.step() 114 | 115 | def get_current_errors(self): 116 | return OrderedDict([('G_GAN', self.loss_G_GAN.item()), 117 | ('G_L1', self.loss_G_Content.item()), 118 | ('D_real+fake', self.loss_D.item()) 119 | ]) 120 | 121 | def get_current_visuals(self): 122 | real_A = util.tensor2im(self.real_A.data) 123 | fake_B = util.tensor2im(self.fake_B.data) 124 | real_B = util.tensor2im(self.real_B.data) 125 | return OrderedDict([('Blurred_Train', real_A), ('Restored_Train', fake_B), ('Sharp_Train', real_B)]) 126 | 127 | def save(self, label): 128 | self.save_network(self.netG, 'G', label, self.gpu_ids) 129 | self.save_network(self.netD, 'D', label, self.gpu_ids) 130 | 131 | def update_learning_rate(self): 132 | lrd = self.opt.lr / self.opt.niter_decay 133 | lr = self.old_lr - lrd 134 | for param_group in self.optimizer_D.param_groups: 135 | param_group['lr'] = lr 136 | for param_group in self.optimizer_G.param_groups: 137 | param_group['lr'] = lr 138 | print('update learning rate: %f -> %f' % (self.old_lr, lr)) 139 | self.old_lr = lr 140 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | import torch.autograd as autograd 6 | import numpy as np 7 | import torchvision.models as models 8 | import util.util as util 9 | from util.image_pool import ImagePool 10 | from torch.autograd import Variable 11 | ############################################################################### 
class PerceptualLoss():
    """Content loss computed between VGG-19 feature maps of two images.

    ``loss`` is the criterion applied to the feature maps (``init_loss``
    passes ``nn.MSELoss()``).  After construction ``self.contentFunc`` holds
    the truncated VGG-19 feature extractor, shadowing the factory method of
    the same name, exactly as before.
    """

    def contentFunc(self):
        """Build the feature extractor: VGG-19 ``features`` up to index 14."""
        conv_3_3_layer = 14
        cnn = models.vgg19(pretrained=True).features
        model = nn.Sequential()
        for i, layer in enumerate(list(cnn)):
            model.add_module(str(i), layer)
            if i == conv_3_3_layer:
                break
        # The extractor is a fixed metric, never trained: freezing it avoids
        # accumulating useless weight gradients on every backward pass.
        for param in model.parameters():
            param.requires_grad = False
        model.eval()
        # Only move to the GPU when one exists, so CPU-only runs no longer
        # crash here; get_loss re-aligns the device with its inputs anyway.
        if torch.cuda.is_available():
            model = model.cuda()
        return model

    def __init__(self, loss):
        self.criterion = loss
        self.contentFunc = self.contentFunc()

    def get_loss(self, fakeIm, realIm):
        """Return ``criterion(features(fakeIm), features(realIm))``.

        Gradients flow only through ``fakeIm``; the target features are
        computed without building a graph.
        """
        first_param = next(self.contentFunc.parameters(), None)
        if first_param is not None and first_param.device != fakeIm.device:
            # Keep the extractor on the same device as the images.
            self.contentFunc.to(fakeIm.device)
        f_fake = self.contentFunc.forward(fakeIm)
        with torch.no_grad():
            f_real_no_grad = self.contentFunc.forward(realIm)
        loss = self.criterion(f_fake, f_real_no_grad)
        return loss
class DiscLossLS(DiscLoss):
    """Discriminator loss selected by ``gan_type == 'lsgan'``.

    Identical to ``DiscLoss`` except that the adversarial criterion is built
    with ``GANLoss(use_l1=True)`` instead of ``use_l1=False``.
    """

    def name(self):
        return 'DiscLossLS'

    def __init__(self, opt, tensor):
        # Must name this class: super(DiscLoss, self) skipped
        # DiscLoss.__init__ and forwarded (opt, tensor) to object.__init__,
        # which raises TypeError and left fake_AB_pool unset.
        super(DiscLossLS, self).__init__(opt, tensor)
        self.criterionGAN = GANLoss(use_l1=True, tensor=tensor)

    def get_g_loss(self, net, realA, fakeB):
        """Generator-side adversarial loss; delegates to ``DiscLoss``."""
        return DiscLoss.get_g_loss(self, net, realA, fakeB)

    def get_loss(self, net, realA, fakeB, realB):
        """Discriminator loss; delegates to ``DiscLoss``."""
        return DiscLoss.get_loss(self, net, realA, fakeB, realB)
def calc_gradient_penalty(self, netD, real_data, fake_data):
    """Return the WGAN-GP gradient penalty for one batch.

    Random points are sampled on the segments between each real sample and
    its fake counterpart, the critic ``netD`` is evaluated there, and the
    squared deviation of the per-sample gradient L2 norm from 1 is averaged
    and scaled by ``self.LAMBDA``.

    Args:
        netD: critic; only its ``forward`` is called.
        real_data: batch of real samples, first dimension is the batch.
        fake_data: batch of generated samples with the same shape.

    Returns:
        Scalar tensor that can be back-propagated into ``netD``.
    """
    batch_size = real_data.size(0)
    # One interpolation coefficient per sample, broadcast over the remaining
    # dimensions, created directly on the data's device (no forced .cuda()).
    alpha = torch.rand(
        [batch_size] + [1] * (real_data.dim() - 1),
        device=real_data.device, dtype=real_data.dtype)

    # Detach so the penalty differentiates w.r.t. the interpolates only.
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).detach()
    interpolates.requires_grad_(True)

    disc_interpolates = netD.forward(interpolates)

    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]

    # Flatten so the norm covers every input element of a sample; taking it
    # over dim=1 of the 4-D gradient only measured the channel axis per pixel.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
    return gradient_penalty
def weights_init(m):
    """Initialise one module in place; meant for ``net.apply(weights_init)``.

    Modules whose class name contains ``Conv`` get N(0, 0.02) weights and a
    zero bias (when a bias tensor exists).  ``BatchNorm2d`` modules get
    N(1, 0.02) weights and a zero bias.  Every other module is left as is.
    """
    kind = type(m).__name__

    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
        bias = m.bias
        # bias is None when the layer was created with bias=False
        if hasattr(bias, 'data'):
            bias.data.fill_(0)
        return

    if 'BatchNorm2d' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[], use_parallel=True,
             learn_residual=False):
    """Build, place and initialise the generator named by ``which_model_netG``.

    Supported names: ``resnet_9blocks``, ``resnet_6blocks``, ``unet_128``
    (7 downsamplings) and ``unet_256`` (8 downsamplings).  The network is
    moved to ``gpu_ids[0]`` when any GPU id is given and initialised with
    ``weights_init``.

    Raises:
        NotImplementedError: for an unknown generator name.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    on_gpu = len(gpu_ids) > 0

    if on_gpu:
        assert (torch.cuda.is_available())

    # keyword arguments shared by every generator constructor
    shared = dict(norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids,
                  use_parallel=use_parallel, learn_residual=learn_residual)
    resnet_depth = {'resnet_9blocks': 9, 'resnet_6blocks': 6}
    unet_downs = {'unet_128': 7, 'unet_256': 8}

    if which_model_netG in resnet_depth:
        netG = ResnetGenerator(input_nc, output_nc, ngf, n_blocks=resnet_depth[which_model_netG], **shared)
    elif which_model_netG in unet_downs:
        netG = UnetGenerator(input_nc, output_nc, unet_downs[which_model_netG], ngf, **shared)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)

    if on_gpu:
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
def print_network(net):
    """Print the module structure of ``net`` followed by its parameter count."""
    total = sum(tensor.numel() for tensor in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
102 | # https://github.com/jcjohnson/fast-neural-style/ 103 | class ResnetGenerator(nn.Module): 104 | def __init__( 105 | self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, 106 | n_blocks=6, gpu_ids=[], use_parallel=True, learn_residual=False, padding_type='reflect'): 107 | assert (n_blocks >= 0) 108 | super(ResnetGenerator, self).__init__() 109 | self.input_nc = input_nc 110 | self.output_nc = output_nc 111 | self.ngf = ngf 112 | self.gpu_ids = gpu_ids 113 | self.use_parallel = use_parallel 114 | self.learn_residual = learn_residual 115 | 116 | if type(norm_layer) == functools.partial: 117 | use_bias = norm_layer.func == nn.InstanceNorm2d 118 | else: 119 | use_bias = norm_layer == nn.InstanceNorm2d 120 | 121 | model = [ 122 | nn.ReflectionPad2d(3), 123 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 124 | norm_layer(ngf), 125 | nn.ReLU(True) 126 | ] 127 | 128 | n_downsampling = 2 129 | 130 | # 下采样 131 | # for i in range(n_downsampling): # [0,1] 132 | # mult = 2**i 133 | # 134 | # model += [ 135 | # nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 136 | # norm_layer(ngf * mult * 2), 137 | # nn.ReLU(True) 138 | # ] 139 | 140 | model += [ 141 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=use_bias), 142 | norm_layer(128), 143 | nn.ReLU(True), 144 | 145 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=use_bias), 146 | norm_layer(256), 147 | nn.ReLU(True) 148 | ] 149 | 150 | # 中间的残差网络 151 | # mult = 2**n_downsampling 152 | for i in range(n_blocks): 153 | # model += [ 154 | # ResnetBlock( 155 | # ngf * mult, padding_type=padding_type, norm_layer=norm_layer, 156 | # use_dropout=use_dropout, use_bias=use_bias) 157 | # ] 158 | model += [ 159 | ResnetBlock(256, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias) 160 | ] 161 | 162 | # 上采样 163 | # for i in range(n_downsampling): 164 | # mult = 2**(n_downsampling - 
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block: ``x + F(x)``.

    ``F`` is pad+conv3x3 -> norm -> ReLU -> [Dropout(0.5)] -> pad+conv3x3 ->
    norm, with both convolutions keeping ``dim`` channels and the spatial
    size.

    Args:
        dim: number of input/output channels.
        padding_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
        norm_layer: callable returning a normalisation module for ``dim``
            channels.
        use_dropout: insert ``nn.Dropout(0.5)`` after the first ReLU.
        use_bias: whether the convolutions carry a bias.

    Raises:
        NotImplementedError: for an unknown ``padding_type``.

    NOTE(review): the previous expression
    ``a + b + c if use_dropout else [] + d + e`` parsed as
    ``(a + b + c) if use_dropout else (d + e)``, so each block contained only
    one of its two halves.  Checkpoints trained with that layout have
    different ``conv_block`` keys and will not load into this one.
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()

        blocks = self._pad_and_conv(dim, padding_type, use_bias)
        blocks += [norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            blocks.append(nn.Dropout(0.5))
        # A second, independent convolution: reusing the first list would
        # share one Conv2d instance (and its weights) between both stages.
        blocks += self._pad_and_conv(dim, padding_type, use_bias)
        blocks.append(norm_layer(dim))

        self.conv_block = nn.Sequential(*blocks)

    @staticmethod
    def _pad_and_conv(dim, padding_type, use_bias):
        """Return fresh [padding, 3x3 conv] layers that preserve spatial size."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, kernel_size=3, bias=use_bias)]
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1), nn.Conv2d(dim, dim, kernel_size=3, bias=use_bias)]
        if padding_type == 'zero':
            return [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias)]
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One nesting level of the U-Net generator.

    The block downsamples by 2, runs ``submodule`` (absent for the innermost
    block), upsamples by 2 and, unless it is the outermost block,
    concatenates its input to its output along the channel axis.  That is why
    the up-convolution of non-innermost blocks takes ``inner_nc * 2``
    channels.

    Args:
        outer_nc: channels of the block's input and of its un-concatenated
            output.
        inner_nc: channels after the down-convolution.
        submodule: the nested block; unused when ``innermost``.
        outermost: no normalisation, ``Tanh`` output, no skip concatenation.
        innermost: no submodule and no normalisation after the down-conv.
        norm_layer: normalisation factory; convolutions next to a norm get a
            bias only for ``nn.InstanceNorm2d``.
        use_dropout: append ``nn.Dropout(0.5)`` (intermediate blocks only).
    """

    def __init__(
            self, outer_nc, inner_nc, submodule=None,
            outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        dConv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        dRelu = nn.LeakyReLU(0.2, True)
        uRelu = nn.ReLU(True)

        # The layer lists are concatenated into one flat list: passing the
        # lists themselves to nn.Sequential raised TypeError because a list
        # is not an nn.Module.
        if outermost:
            uConv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            dModel = [dConv]
            uModel = [uRelu, uConv, nn.Tanh()]
            model = dModel + [submodule] + uModel
        elif innermost:
            uConv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            dModel = [dRelu, dConv]
            uModel = [uRelu, uConv, norm_layer(outer_nc)]
            model = dModel + uModel
        else:
            uConv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            dModel = [dRelu, dConv, norm_layer(inner_nc)]
            uModel = [uRelu, uConv, norm_layer(outer_nc)]
            model = dModel + [submodule] + uModel
            if use_dropout:
                model.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            # skip connection: stack the block output with its input
            return torch.cat([self.model(x), x], 1)
class TestModel(BaseModel):
    """Inference-only model: loads a saved generator and runs it on 'A' images."""

    def name(self):
        return 'TestModel'

    def __init__(self, opt):
        """Allocate the input buffer, build the generator and load its weights.

        Only valid for test-time options (``opt.isTrain`` must be falsy).
        """
        assert not opt.isTrain
        super(TestModel, self).__init__(opt)
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)

        # The 8th positional argument is passed as False
        # (presumably use_parallel - confirm against networks.define_G).
        self.netG = networks.define_G(
            opt.input_nc, opt.output_nc, opt.ngf,
            opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, False,
            opt.learn_residual)
        self.load_network(self.netG, 'G', opt.which_epoch)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy ``input['A']`` into a fresh buffer and remember ``input['A_paths']``."""
        # we need to use single_dataset mode
        source = input['A']
        buffer = self.input_A.clone()
        buffer.resize_(source.size())
        buffer.copy_(source)
        self.input_A = buffer
        self.image_paths = input['A_paths']

    def test(self):
        """Run the generator on the current input with gradients disabled."""
        with torch.no_grad():
            real = Variable(self.input_A)
            self.real_A = real
            self.fake_B = self.netG.forward(real)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def get_current_visuals(self):
        """Return an ordered mapping with the input ('real_A') and output ('fake_B') images."""
        visuals = OrderedDict()
        visuals['real_A'] = util.tensor2im(self.real_A.data)
        visuals['fake_B'] = util.tensor2im(self.fake_B.data)
        return visuals
20 | """ 21 | if os.path.isfile(image_path): 22 | self.image_path = image_path 23 | self.original = misc.imread(self.image_path) 24 | self.shape = self.original.shape 25 | if len(self.shape) < 3: 26 | raise Exception('We support only RGB images yet.') 27 | elif self.shape[0] != self.shape[1]: 28 | raise Exception('We support only square images yet.') 29 | else: 30 | raise Exception('Not correct path to image.') 31 | self.path_to_save = path__to_save 32 | if PSFs is None: 33 | if self.path_to_save is None: 34 | self.PSFs = PSF(canvas=self.shape[0]).fit() 35 | else: 36 | self.PSFs = PSF(canvas=self.shape[0], path_to_save=os.path.join(self.path_to_save, 37 | 'PSFs.png')).fit(save=True) 38 | else: 39 | self.PSFs = PSFs 40 | 41 | self.part = part 42 | self.result = [] 43 | 44 | def blur_image(self, save=False, show=False): 45 | if self.part is None: 46 | psf = self.PSFs 47 | else: 48 | psf = [self.PSFs[self.part]] 49 | yN, xN, channel = self.shape 50 | key, kex = self.PSFs[0].shape 51 | delta = yN - key 52 | assert delta >= 0, 'resolution of image should be higher than kernel' 53 | result=[] 54 | if len(psf) > 1: 55 | for p in psf: 56 | tmp = np.pad(p, delta // 2, 'constant') 57 | cv2.normalize(tmp, tmp, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F) 58 | # blured = np.zeros(self.shape) 59 | blured = cv2.normalize(self.original, self.original, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, 60 | dtype=cv2.CV_32F) 61 | blured[:, :, 0] = np.array(signal.fftconvolve(blured[:, :, 0], tmp, 'same')) 62 | blured[:, :, 1] = np.array(signal.fftconvolve(blured[:, :, 1], tmp, 'same')) 63 | blured[:, :, 2] = np.array(signal.fftconvolve(blured[:, :, 2], tmp, 'same')) 64 | blured = cv2.normalize(blured, blured, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F) 65 | blured = cv2.cvtColor(blured, cv2.COLOR_RGB2BGR) 66 | result.append(np.abs(blured)) 67 | else: 68 | psf = psf[0] 69 | tmp = np.pad(psf, delta // 2, 'constant') 70 | cv2.normalize(tmp, tmp, alpha=0, 
beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F) 71 | blured = cv2.normalize(self.original, self.original, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, 72 | dtype=cv2.CV_32F) 73 | blured[:, :, 0] = np.array(signal.fftconvolve(blured[:, :, 0], tmp, 'same')) 74 | blured[:, :, 1] = np.array(signal.fftconvolve(blured[:, :, 1], tmp, 'same')) 75 | blured[:, :, 2] = np.array(signal.fftconvolve(blured[:, :, 2], tmp, 'same')) 76 | blured = cv2.normalize(blured, blured, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F) 77 | blured = cv2.cvtColor(blured, cv2.COLOR_RGB2BGR) 78 | result.append(np.abs(blured)) 79 | self.result = result 80 | if show or save: 81 | self.__plot_canvas(show, save) 82 | 83 | def __plot_canvas(self, show, save): 84 | if len(self.result) == 0: 85 | raise Exception('Please run blur_image() method first.') 86 | else: 87 | plt.close() 88 | plt.axis('off') 89 | fig, axes = plt.subplots(1, len(self.result), figsize=(10, 10)) 90 | if len(self.result) > 1: 91 | for i in range(len(self.result)): 92 | axes[i].imshow(self.result[i]) 93 | else: 94 | plt.axis('off') 95 | 96 | plt.imshow(self.result[0]) 97 | if show and save: 98 | if self.path_to_save is None: 99 | raise Exception('Please create Trajectory instance with path_to_save') 100 | cv2.imwrite(os.path.join(self.path_to_save, self.image_path.split('/')[-1]), self.result[0] * 255) 101 | plt.show() 102 | elif save: 103 | if self.path_to_save is None: 104 | raise Exception('Please create Trajectory instance with path_to_save') 105 | cv2.imwrite(os.path.join(self.path_to_save, self.image_path.split('/')[-1]), self.result[0] * 255) 106 | elif show: 107 | plt.show() 108 | 109 | 110 | if __name__ == '__main__': 111 | folder = '/Users/mykolam/PycharmProjects/University/DeblurGAN2/results_sharp' 112 | folder_to_save = '/Users/mykolam/PycharmProjects/University/DeblurGAN2/blured' 113 | params = [0.01, 0.009, 0.008, 0.007, 0.005, 0.003] 114 | for path in os.listdir(folder): 115 | print(path) 116 | 
class PSF(object):
    """Builds point-spread-function (blur kernel) images from a motion trajectory.

    One kernel is produced per entry of ``fraction``. The kernels are
    cumulative: each one contains everything accumulated for the previous
    fractions plus its own share of the trajectory.
    """

    def __init__(self, canvas=None, trajectory=None, fraction=None, path_to_save=None):
        """
        :param canvas: side length of the square kernel; defaults to 64.
        :param trajectory: fitted object exposing the complex positions as ``.x``;
            when None a new ``Trajectory`` (expl=0.005) is generated.
        :param fraction: increasing exposure fractions, one kernel per entry;
            defaults to [1/100, 1/10, 1/2, 1].
        :param path_to_save: file used by ``fit(save=True)`` to store the plot.
        """
        if canvas is None:
            # The original fell through with (None, None) and crashed later;
            # 64 is the default canvas used by Trajectory.
            canvas = 64
        self.canvas = (canvas, canvas)
        if trajectory is None:
            # Trajectory.fit() returns the Trajectory itself; the positions live in .x
            self.trajectory = Trajectory(canvas=canvas, expl=0.005).fit(show=False, save=False).x
        else:
            self.trajectory = trajectory.x
        if fraction is None:
            self.fraction = [1/100, 1/10, 1/2, 1]
        else:
            self.fraction = fraction
        self.path_to_save = path_to_save
        self.PSFnumber = len(self.fraction)
        self.iters = len(self.trajectory)
        self.PSFs = []

    def fit(self, show=False, save=False):
        """Compute the kernels and return them as a list of 2-D arrays.

        :param show: display the kernels with matplotlib.
        :param save: save the plot to ``path_to_save``.
        :return: list with ``len(fraction)`` arrays of shape ``canvas``.
        """
        rows, cols = self.canvas
        accumulated = np.zeros(self.canvas)
        # Start from a clean list so that calling fit() twice does not duplicate kernels.
        self.PSFs = []

        def triangle_fun(d):
            return np.maximum(0, (1 - np.abs(d)))

        def triangle_fun_prod(dx, dy):
            return np.multiply(triangle_fun(dx), triangle_fun(dy))

        for j in range(self.PSFnumber):
            prevT = 0 if j == 0 else self.fraction[j - 1]
            upper = self.fraction[j] * self.iters
            lower = prevT * self.iters

            for t in range(len(self.trajectory)):
                # Share of sample t that falls inside the exposure window (lower, upper].
                if (upper >= t) and (lower < t - 1):
                    t_proportion = 1
                elif (upper >= t - 1) and (lower < t - 1):
                    t_proportion = upper - (t - 1)
                elif (upper >= t) and (lower < t):
                    t_proportion = t - lower
                elif (upper >= t - 1) and (lower < t):
                    t_proportion = (self.fraction[j] - prevT) * self.iters
                else:
                    t_proportion = 0

                point = self.trajectory[t]
                # Lower cell index of the 2x2 neighbourhood; clamped to size - 2
                # so that the upper index (m + 1) is still a valid index.
                m2 = int(min(cols - 2, max(1, np.floor(point.real))))
                M2 = m2 + 1
                m1 = int(min(rows - 2, max(1, np.floor(point.imag))))
                M1 = m1 + 1

                # Bilinear splat of the sample onto its four neighbouring cells.
                accumulated[m1, m2] += t_proportion * triangle_fun_prod(point.real - m2, point.imag - m1)
                accumulated[m1, M2] += t_proportion * triangle_fun_prod(point.real - M2, point.imag - m1)
                accumulated[M1, m2] += t_proportion * triangle_fun_prod(point.real - m2, point.imag - M1)
                accumulated[M1, M2] += t_proportion * triangle_fun_prod(point.real - M2, point.imag - M1)

            # Division creates a new array, so each stored kernel is a snapshot.
            self.PSFs.append(accumulated / (self.iters))
        if show or save:
            self.__plot_canvas(show, save)

        return self.PSFs

    def __plot_canvas(self, show, save):
        """Plot the kernels side by side; optionally save and/or show the figure."""
        if len(self.PSFs) == 0:
            raise Exception("Please run fit() method first.")
        # Validate before drawing anything.
        if save and self.path_to_save is None:
            raise Exception('Please create PSF instance with path_to_save')
        plt.close()
        # squeeze=False keeps a 2-D axes array even when there is a single kernel.
        _, axes = plt.subplots(1, self.PSFnumber, figsize=(10, 10), squeeze=False)
        for i in range(self.PSFnumber):
            axes[0][i].imshow(self.PSFs[i], cmap='gray')
        if save:
            plt.savefig(self.path_to_save)
        if show:
            plt.show()
class Trajectory(object):
    """Random 2-D motion trajectory stored as a vector of complex positions."""

    def __init__(self, canvas=64, iters=2000, max_len=60, expl=None, path_to_save=None):
        """
        Generates a variety of random motion trajectories in continuous domain as in [Boracchi and Foi 2012]. Each
        trajectory consists of a complex-valued vector determining the discrete positions of a particle following a
        2-D random motion in continuous domain. The particle has an initial velocity vector which, at each iteration,
        is affected by a Gaussian perturbation and by a deterministic inertial component, directed toward the
        previous particle position. In addition, with a small probability, an impulsive (abrupt) perturbation aiming
        at inverting the particle velocity may arise, mimicking a sudden movement that occurs when the user presses
        the camera button or tries to compensate the camera shake. At each step, the velocity is normalized to
        guarantee that trajectories corresponding to equal exposures have the same length. With expl set to 0 no
        perturbation is applied and the trajectory is a straight line (rectilinear blur).
        :param canvas: size of domain where our trajectory is defined.
        :param iters: number of iterations for definition of our trajectory.
        :param max_len: maximum length of our trajectory.
        :param expl: this param helps to define probability of big shake. Recommended expl = 0.005.
            When None, a random value in [0, 0.1) is drawn.
        :param path_to_save: where to save if you need.
        """
        self.canvas = canvas
        self.iters = iters
        self.max_len = max_len
        if expl is None:
            self.expl = 0.1 * np.random.uniform(0, 1)
        else:
            self.expl = expl
        # Always define the attribute (even as None) so the "is None" check
        # in __plot_canvas cannot fail with AttributeError.
        self.path_to_save = path_to_save
        self.tot_length = None
        self.big_expl_count = None
        self.x = None

    def fit(self, show=False, save=False):
        """
        Generate motion, you can save or plot, coordinates of motion you can find in x property.
        Also you can find properties tot_length, big_expl_count.
        :param show: default False.
        :param save: default False.
        :return: self (the positions are stored in the x attribute).
        """
        tot_length = 0
        big_expl_count = 0
        # how to be near the previous position
        # TODO: I can change this parameter for 0.1 and make kernel at all image
        centripetal = 0.7 * np.random.uniform(0, 1)
        # probability of big shake
        prob_big_shake = 0.2 * np.random.uniform(0, 1)
        # term determining, at each sample, the random component of the new direction
        gaussian_shake = 10 * np.random.uniform(0, 1)
        init_angle = 360 * np.random.uniform(0, 1)

        img_v0 = np.sin(np.deg2rad(init_angle))
        real_v0 = np.cos(np.deg2rad(init_angle))

        # Length of one step; every velocity is rescaled to it below.
        step = self.max_len / float(self.iters - 1)

        v0 = complex(real=real_v0, imag=img_v0)
        v = v0 * step

        if self.expl > 0:
            v = v0 * self.expl

        x = np.zeros(self.iters, dtype=complex)

        for t in range(0, self.iters - 1):
            if np.random.uniform() < prob_big_shake * self.expl:
                # Abrupt shake: roughly invert the current velocity.
                next_direction = 2 * v * (np.exp(complex(real=0, imag=np.pi + (np.random.uniform() - 0.5))))
                big_expl_count += 1
            else:
                next_direction = 0

            dv = next_direction + self.expl * (
                gaussian_shake * complex(real=np.random.randn(), imag=np.random.randn()) - centripetal * x[t]) * step

            v += dv
            # Normalize so that every step has the same length.
            v = (v / float(np.abs(v))) * step
            x[t + 1] = x[t] + v
            tot_length = tot_length + abs(x[t + 1] - x[t])

        # center the motion inside the canvas
        x += complex(real=-np.min(x.real), imag=-np.min(x.imag))
        x = x - complex(real=x[0].real % 1., imag=x[0].imag % 1.) + complex(1, 1)
        x += complex(real=ceil((self.canvas - max(x.real)) / 2), imag=ceil((self.canvas - max(x.imag)) / 2))

        self.tot_length = tot_length
        self.big_expl_count = big_expl_count
        self.x = x

        if show or save:
            self.__plot_canvas(show, save)
        return self

    def __plot_canvas(self, show, save):
        """Plot the fitted trajectory; optionally save and/or show the figure."""
        if self.x is None:
            raise Exception("Please run fit() method first")
        # Validate before drawing so a missing path fails with a clear message
        # in every branch that saves.
        if save and self.path_to_save is None:
            raise Exception('Please create Trajectory instance with path_to_save')
        plt.close()
        plt.plot(self.x.real, self.x.imag, '-', color='blue')

        plt.xlim((0, self.canvas))
        plt.ylim((0, self.canvas))
        if save:
            plt.savefig(self.path_to_save)
        if show:
            plt.show()
argparse.ArgumentParser() 9 | self.initialized = False 10 | 11 | def initialize(self): 12 | self.parser.add_argument('--dataroot', type=str, default="D:\Photos\TrainingData\BlurredSharp\combined", help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 13 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size') 14 | self.parser.add_argument('--loadSizeX', type=int, default=640, help='scale images to this size') 15 | self.parser.add_argument('--loadSizeY', type=int, default=360, help='scale images to this size') 16 | self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size') 17 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 18 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 19 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 20 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 21 | self.parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD') 22 | self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG') 23 | self.parser.add_argument('--learn_residual', action='store_true', help='if specified, model would learn only the residual to the input') 24 | self.parser.add_argument('--gan_type', type=str, default='wgan-gp', help='wgan-gp : Wasserstein GAN with Gradient Penalty, lsgan : Least Sqaures GAN, gan : Vanilla GAN') 25 | self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') 26 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 27 | self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. 
It decides where to store samples and models') 28 | self.parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [unaligned | aligned | single]') 29 | self.parser.add_argument('--model', type=str, default='content_gan', help='chooses which model to use. pix2pix, test, content_gan') 30 | self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA') 31 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data') 32 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 33 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 34 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 35 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 36 | self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 37 | self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 38 | self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 39 | self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 40 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
class TestOptions(BaseOptions):
    """Options that are only meaningful at test time, layered on top of BaseOptions."""

    def initialize(self):
        """Register the base options, then the test-only ones, and flag the run as non-training."""
        BaseOptions.initialize(self)
        # (flag, keyword arguments) pairs, registered in this order.
        test_arguments = (
            ('--ntest', dict(type=int, default=float("inf"), help='# of test examples.')),
            ('--results_dir', dict(type=str, default='./results/', help='saves results here.')),
            ('--aspect_ratio', dict(type=float, default=1.0, help='aspect ratio of result images')),
            ('--phase', dict(type=str, default='test', help='train, val, test, etc')),
            ('--which_epoch', dict(type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')),
            ('--how_many', dict(type=int, default=5000, help='how many test images to run')),
        )
        for flag, spec in test_arguments:
            self.parser.add_argument(flag, **spec)
        self.isTrain = False
# Parse the command-line options, then pin the values the test loop relies on.
opt = TestOptions().parse()
for option_name, forced_value in (
        ('nThreads', 1),            # test code only supports nThreads = 1
        ('batchSize', 1),           # test code only supports batchSize = 1
        ('serial_batches', True),   # no shuffle
        ('no_flip', True),          # no flip
):
    setattr(opt, option_name, forced_value)

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model(opt)
visualizer = Visualizer(opt)

# create website: <results_dir>/<name>/<phase>_<epoch>
run_tag = '%s_%s' % (opt.phase, opt.which_epoch)
web_dir = os.path.join(opt.results_dir, opt.name, run_tag)
page_title = 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)
webpage = html.HTML(web_dir, page_title)

# test
# NOTE: PSNR / SSIM averaging is currently disabled, so the accumulators
# below are initialised but never updated.
avgPSNR = 0.0
avgSSIM = 0.0
counter = 0

for i, data in enumerate(dataset):
    # Stop once the requested number of images has been processed.
    if i >= opt.how_many:
        break
    counter = i
    model.set_input(data)
    model.test()
    visuals = model.get_current_visuals()
    img_path = model.get_image_paths()
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)

webpage.save()
class GetData(object):
    """Download and unpack a CycleGAN or pix2pix dataset archive.

    Args:
        technique : str
            One of: 'cyclegan' or 'pix2pix' (case-insensitive).
        verbose : bool
            If True, print additional information.

    Raises:
        ValueError: if ``technique`` is not one of the supported names.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.

    """

    def __init__(self, technique='cyclegan', verbose=True):
        url_dict = {
            'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
        }
        key = technique.lower()
        # Fail here rather than later with a confusing "None/<dataset>" URL.
        if key not in url_dict:
            raise ValueError(
                "Unknown technique: {0!r}. Expected one of: {1}.".format(
                    technique, ', '.join(sorted(url_dict))))
        self.url = url_dict[key]
        self._verbose = verbose

    def _print(self, text):
        """Print ``text`` only when the instance was created with verbose=True."""
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
        """Return the link texts in response ``r`` that name .zip / tar.gz archives."""
        soup = BeautifulSoup(r.text, 'lxml')
        options = [h.text for h in soup.find_all('a', href=True)
                   if h.text.endswith(('.zip', 'tar.gz'))]
        return options

    def _present_options(self):
        """List the archives found at ``self.url`` and return the one the user picks.

        Raises:
            IndexError: if the entered number is not one of the listed indices.
            ValueError: if the entered text is not an integer.
        """
        r = requests.get(self.url)
        # Do not parse an HTTP error page as if it were the dataset index.
        r.raise_for_status()
        options = self._get_options(r)
        print('Options:\n')
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        choice = input("\nPlease enter the number of the "
                       "dataset above you wish to download:")
        index = int(choice)
        # A negative number would otherwise silently index from the end.
        if not 0 <= index < len(options):
            raise IndexError("No dataset numbered {0}.".format(index))
        return options[index]

    def _download_data(self, dataset_url, save_path):
        """Download the archive at ``dataset_url`` and extract it into ``save_path``.

        The archive is written to a temporary file inside ``save_path`` which is
        removed afterwards, whether or not extraction succeeded.

        Raises:
            ValueError: if the URL does not end in '.tar.gz' or '.zip'.
        """
        if not isdir(save_path):
            os.makedirs(save_path)

        base = basename(dataset_url)
        # Reject unsupported archive types before spending time on the download.
        if not base.endswith(('.tar.gz', '.zip')):
            raise ValueError("Unknown File Type: {0}.".format(base))
        temp_save_path = join(save_path, base)

        try:
            r = requests.get(dataset_url, stream=True)
            try:
                r.raise_for_status()
                with open(temp_save_path, "wb") as f:
                    # Stream in chunks instead of holding the archive in memory.
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            finally:
                r.close()

            self._print("Unpacking Data...")
            if base.endswith('.tar.gz'):
                obj = tarfile.open(temp_save_path)
            else:
                obj = ZipFile(temp_save_path, 'r')
            # NOTE(review): extractall trusts the member paths in the archive;
            # only use this with archives from a trusted host.
            try:
                obj.extractall(save_path)
            finally:
                obj.close()
        finally:
            if os.path.exists(temp_save_path):
                os.remove(temp_save_path)

    def get(self, save_path, dataset=None):
        """

        Download a dataset.

        Args:
            save_path : str
                A directory to save the data to.
            dataset : str, optional
                A specific dataset to download.
                Note: this must include the file extension.
                If None, options will be presented for you
                to choose from.

        Returns:
            save_path_full : str
                The absolute path to the downloaded data.

        """
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset

        # The target directory is named after the text before the first '.'.
        save_path_full = join(save_path, selected_dataset.split('.')[0])

        if isdir(save_path_full):
            warn("\n'{0}' already exists. Voiding Download.".format(
                save_path_full))
        else:
            self._print('Downloading Data...')
            url = "{0}/{1}".format(self.url, selected_dataset)
            self._download_data(url, save_path=save_path)

        return abspath(save_path_full)
class ImagePool():
    """Fixed-capacity store of previously queried images.

    With ``pool_size == 0`` the pool is disabled and ``query`` hands its
    input straight back.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        # The storage attributes only exist when the pool is enabled.
        if pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch built from ``images`` and, once the pool is full, stored ones.

        While the pool still has room every incoming image is stored and
        returned unchanged. Once it is full, each incoming image is, with
        probability one half, swapped with a randomly chosen stored image
        (a clone of the stored image is returned and the new one takes its
        slot); otherwise the incoming image is returned and not stored.
        """
        if self.pool_size == 0:
            return images

        picked = []
        for sample in images.data:
            # Restore the leading batch axis so results can be concatenated.
            sample = torch.unsqueeze(sample, 0)
            if self.num_imgs >= self.pool_size:
                if random.uniform(0, 1) > 0.5:
                    slot = random.randint(0, self.pool_size - 1)
                    previous = self.images[slot].clone()
                    self.images[slot] = sample
                    picked.append(previous)
                else:
                    picked.append(sample)
                continue
            # Pool not yet full: remember the image and pass it through.
            self.num_imgs += 1
            self.images.append(sample)
            picked.append(sample)

        return Variable(torch.cat(picked, 0))
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` that sums to 1.

    The kernel is centred on index ``window_size // 2``. Integer division is
    required: on Python 3 ``window_size / 2`` is a float (5.5 for a window of
    11), which shifts the centre and makes the kernel asymmetric.
    """
    center = window_size // 2
    gauss = torch.Tensor([exp(-(x - center)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()


def create_window(window_size, channel):
    """Return a ``(channel, 1, window_size, window_size)`` Gaussian window (sigma 1.5).

    The 2-D kernel is the outer product of the 1-D kernel with itself,
    expanded so the same kernel is used for every channel.
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size))
    return window


def SSIM(img1, img2):
    """Return the mean structural similarity between two image batches.

    Args:
        img1, img2: 4-D tensors whose second axis is the channel axis.
            NOTE(review): the constants C1/C2 below look chosen for values
            scaled to [0, 1] -- confirm against callers.

    Returns:
        A scalar tensor: the mean of the per-pixel SSIM map.
    """
    (_, channel, _, _) = img1.size()
    window_size = 11
    # conv2d needs an integer padding; ``window_size/2`` is a float on Python 3.
    pad = window_size // 2
    # Cast the window to the input's tensor type so conv2d accepts the pair.
    window = create_window(window_size, channel).type_as(img1)

    # Local means (per-channel filtering via groups=channel).
    mu1 = F.conv2d(img1, window, padding=pad, groups=channel)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    # Local variances and covariance: E[xy] - E[x]E[y].
    sigma1_sq = F.conv2d(img1*img1, window, padding=pad, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding=pad, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding=pad, groups=channel) - mu1_mu2

    # Small constants that keep the division stable.
    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()
def encode(buf, width, height):
    """Serialize a raw RGB pixel buffer as the bytes of a PNG file.

    buf: must be bytes or a bytearray in py3, a regular string in py2,
    formatted RGBRGB... with exactly ``width * height * 3`` entries.
    The rows of ``buf`` are written in reverse order (last row first).
    """
    assert (width * height * 3 == len(buf))
    channels = 3
    stride = width * channels
    signature = b'\x89PNG\r\n\x1a\n'
    color_type_rgb = 2
    depth = 8

    def scanlines():
        # Walk the rows from the last one to the first; every scanline is
        # preceded by a zero filter-type byte.
        for offset in reversed(range(0, height * stride, stride)):
            yield b'\x00'
            yield buf[offset:offset + stride]

    def make_chunk(tag, payload):
        # Chunk layout: length, tag, payload, CRC computed over tag + payload.
        checksum = zlib.crc32(payload, zlib.crc32(tag)) & 0xFFFFFFFF
        return b''.join([
            struct.pack("!I", len(payload)),
            tag,
            payload,
            struct.pack("!I", checksum),
        ])

    header = struct.pack("!2I5B", width, height, depth, color_type_rgb, 0, 0, 0)
    return b''.join([
        signature,
        make_chunk(b'IHDR', header),
        make_chunk(b'IDAT', zlib.compress(b''.join(scanlines()), 9)),
        make_chunk(b'IEND', b''),
    ])
def info(object, spacing=10, collapse=1):
    """Print methods and doc strings.

    Takes module, class, list, dictionary, or string.

    Args:
        object: the object whose callable attributes are listed.
        spacing: width the attribute names are left-justified to.
        collapse: when truthy, runs of whitespace in each doc string are
            collapsed to single spaces; otherwise doc strings are printed
            as they are.
    """
    # The built-in callable() replaces collections.Callable, which was
    # removed from the collections namespace in Python 3.10.
    methodList = [e for e in dir(object) if callable(getattr(object, e))]
    if collapse:
        processFunc = lambda s: " ".join(s.split())
    else:
        processFunc = lambda s: s
    print("\n".join(
        "%s %s" % (method.ljust(spacing),
                   processFunc(str(getattr(object, method).__doc__)))
        for method in methodList))
class Visualizer():
    """Reports results to visdom, an HTML results page and a text loss log.

    Which outputs are active is decided by the options passed to
    ``__init__``: visdom is used when ``opt.display_id > 0``; image files
    and the HTML page are written when ``opt.isTrain and not opt.no_html``.
    The loss log file is always opened.
    """

    def __init__(self, opt):
        # self.opt = opt
        self.display_id = opt.display_id
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        if self.display_id > 0:
            # Imported lazily so visdom is only needed when it is used.
            import visdom
            self.vis = visdom.Visdom(port=opt.display_port)
            self.display_single_pane_ncols = opt.display_single_pane_ncols

        if self.use_html:
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, epoch):
        """Show ``visuals`` in visdom and/or save them to the HTML results page.

        Args:
            visuals: mapping of label -> image array; the arrays are
                transposed with ``[2, 0, 1]`` before being sent to visdom,
                so they are indexed (row, column, channel) here.
            epoch: current epoch number, used in saved file names.
        """
        if self.display_id > 0:  # show images in the browser
            self._show_in_visdom(visuals)
        if self.use_html:  # save images to a html file
            self._save_to_html(visuals, epoch)

    def _show_in_visdom(self, visuals):
        """Send the images to visdom, in one pane or one window per image."""
        if self.display_single_pane_ncols > 0:
            h, w = next(iter(visuals.values())).shape[:2]
            table_css = """<style>
    table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
    table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
</style>""" % (w, h)
            ncols = self.display_single_pane_ncols
            title = self.name
            label_html = ''
            label_html_row = ''
            images = []
            idx = 0
            for label, image_numpy in visuals.items():
                label_html_row += '<td>%s</td>' % label
                images.append(image_numpy.transpose([2, 0, 1]))
                idx += 1
                if idx % ncols == 0:
                    label_html += '<tr>%s</tr>' % label_html_row
                    label_html_row = ''
            # Pad the last row with white images so the grid is rectangular.
            white_image = np.ones_like(images[-1]) * 255
            while idx % ncols != 0:
                images.append(white_image)
                label_html_row += '<td></td>'
                idx += 1
            if label_html_row != '':
                label_html += '<tr>%s</tr>' % label_html_row
            # pane col = image row
            self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                            padding=2, opts=dict(title=title + ' images'))
            label_html = '<table>%s</table>' % label_html
            self.vis.text(table_css + label_html, win=self.display_id + 2,
                          opts=dict(title=title + ' labels'))
        else:
            idx = 1
            for label, image_numpy in visuals.items():
                self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
                               win=self.display_id + idx)
                idx += 1

    def _save_to_html(self, visuals, epoch):
        """Save this epoch's images and rebuild the page for epochs ``epoch``..1."""
        for label, image_numpy in visuals.items():
            img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
            util.save_image(image_numpy, img_path)
        # update website
        webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
        for n in range(epoch, 0, -1):
            webpage.add_header('Results of Epoch [%d]' % n)
            ims = []
            txts = []
            links = []

            for label in visuals:
                img_path = 'epoch%.3d_%s.png' % (n, label)
                ims.append(img_path)
                txts.append(label)
                links.append(img_path)
            webpage.add_images(ims, txts, links, width=self.win_size)
        webpage.save()

    # errors: dictionary of error labels and values
    def plot_current_errors(self, epoch, counter_ratio, opt, errors):
        """Append the current errors to the plot history and redraw the line plot.

        The legend is fixed by the keys of ``errors`` on the first call.
        Uses ``self.vis``, which ``__init__`` only creates when
        ``display_id > 0``.
        """
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
        self.vis.line(
            # One copy of the x values per plotted series.
            X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']), 1),
            Y=np.array(self.plot_data['Y']),
            opts={
                'title': self.name + ' loss over time',
                'legend': self.plot_data['legend'],
                'xlabel': 'epoch',
                'ylabel': 'loss'},
            win=self.display_id)

    # errors: same format as |errors| of plotCurrentErrors
    def print_current_errors(self, epoch, i, errors, t):
        """Print one line of error values and append it to the loss log file."""
        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
        message += ''.join('%s: %.3f ' % (k, v) for k, v in errors.items())

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)

    # save image to the disk
    def save_images(self, webpage, visuals, image_path):
        """Save every image in ``visuals`` into the web page's image directory.

        Files are named ``<name>_<label>.png`` where ``<name>`` is the base
        name of ``image_path[0]`` without its extension; a matching row of
        images is added to ``webpage``.
        """
        image_dir = webpage.get_image_dir()
        short_path = ntpath.basename(image_path[0])
        name = os.path.splitext(short_path)[0]

        webpage.add_header(name)
        ims = []
        txts = []
        links = []

        for label, image_numpy in visuals.items():
            image_name = '%s_%s.png' % (name, label)
            save_path = os.path.join(image_dir, image_name)
            util.save_image(image_numpy, save_path)

            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)