├── README.md
├── code
│   ├── ablation
│   │   ├── r1
│   │   │   ├── config
│   │   │   │   ├── cal_ssim.py
│   │   │   │   ├── clean.sh
│   │   │   │   ├── compile.py
│   │   │   │   ├── dataset.py
│   │   │   │   ├── eval.py
│   │   │   │   ├── model.py
│   │   │   │   ├── settings.py
│   │   │   │   ├── show.py
│   │   │   │   ├── tensorboard.sh
│   │   │   │   └── train.py
│   │   │   └── models
│   │   │       └── README.md
│   │   └── r2
│   │       ├── config
│   │       │   ├── .DS_Store
│   │       │   ├── cal_ssim.py
│   │       │   ├── clean.sh
│   │       │   ├── compile.py
│   │       │   ├── dataset.py
│   │       │   ├── eval.py
│   │       │   ├── model.py
│   │       │   ├── settings.py
│   │       │   ├── show.py
│   │       │   ├── tensorboard.sh
│   │       │   └── train.py
│   │       └── models
│   │           ├── .DS_Store
│   │           └── README.md
│   ├── base
│   │   ├── .DS_Store
│   │   ├── rain100H
│   │   │   ├── .DS_Store
│   │   │   ├── config
│   │   │   │   ├── .DS_Store
│   │   │   │   ├── cal_ssim.py
│   │   │   │   ├── clean.sh
│   │   │   │   ├── compile.py
│   │   │   │   ├── dataset.py
│   │   │   │   ├── eval.py
│   │   │   │   ├── function
│   │   │   │   │   ├── __pycache__
│   │   │   │   │   │   └── functional.cpython-37.pyc
│   │   │   │   │   ├── functional.py
│   │   │   │   │   ├── functions
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── __pycache__
│   │   │   │   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   │   │   │   ├── aggregation_refpad.cpython-37.pyc
│   │   │   │   │   │   │   ├── aggregation_zeropad.cpython-37.pyc
│   │   │   │   │   │   │   ├── subtraction2_refpad.cpython-37.pyc
│   │   │   │   │   │   │   ├── subtraction2_zeropad.cpython-37.pyc
│   │   │   │   │   │   │   ├── subtraction_refpad.cpython-37.pyc
│   │   │   │   │   │   │   ├── subtraction_zeropad.cpython-37.pyc
│   │   │   │   │   │   │   └── utils.cpython-37.pyc
│   │   │   │   │   │   ├── aggregation_refpad.py
│   │   │   │   │   │   ├── aggregation_zeropad.py
│   │   │   │   │   │   ├── subtraction2_refpad.py
│   │   │   │   │   │   ├── subtraction2_zeropad.py
│   │   │   │   │   │   ├── subtraction_refpad.py
│   │   │   │   │   │   ├── subtraction_zeropad.py
│   │   │   │   │   │   └── utils.py
│   │   │   │   │   └── modules
│   │   │   │   │       ├── __init__.py
│   │   │   │   │       ├── __pycache__
│   │   │   │   │       │   ├── __init__.cpython-37.pyc
│   │   │   │   │       │   ├── aggregation.cpython-37.pyc
│   │   │   │   │       │   ├── subtraction.cpython-37.pyc
│   │   │   │   │       │   └── subtraction2.cpython-37.pyc
│   │   │   │   │       ├── aggregation.py
│   │   │   │   │       ├── subtraction.py
│   │   │   │   │       └── subtraction2.py
│   │   │   │   ├── model.py
│   │   │   │   ├── settings.py
│   │   │   │   ├── show.py
│   │   │   │   ├── tensorboard.sh
│   │   │   │   └── train.py
│   │   │   └── models
│   │   │       ├── .DS_Store
│   │   │       └── README.md
│   │   └── rain100L
│   │       ├── .DS_Store
│   │       ├── config
│   │       │   ├── .DS_Store
│   │       │   ├── cal_ssim.py
│   │       │   ├── clean.sh
│   │       │   ├── compile.py
│   │       │   ├── dataset.py
│   │       │   ├── eval.py
│   │       │   ├── function
│   │       │   │   ├── __pycache__
│   │       │   │   │   └── functional.cpython-37.pyc
│   │       │   │   ├── functional.py
│   │       │   │   ├── functions
│   │       │   │   │   ├── __init__.py
│   │       │   │   │   ├── __pycache__
│   │       │   │   │   │   ├── __init__.cpython-37.pyc
│   │       │   │   │   │   ├── aggregation_refpad.cpython-37.pyc
│   │       │   │   │   │   ├── aggregation_zeropad.cpython-37.pyc
│   │       │   │   │   │   ├── subtraction2_refpad.cpython-37.pyc
│   │       │   │   │   │   ├── subtraction2_zeropad.cpython-37.pyc
│   │       │   │   │   │   ├── subtraction_refpad.cpython-37.pyc
│   │       │   │   │   │   ├── subtraction_zeropad.cpython-37.pyc
│   │       │   │   │   │   └── utils.cpython-37.pyc
│   │       │   │   │   ├── aggregation_refpad.py
│   │       │   │   │   ├── aggregation_zeropad.py
│   │       │   │   │   ├── subtraction2_refpad.py
│   │       │   │   │   ├── subtraction2_zeropad.py
│   │       │   │   │   ├── subtraction_refpad.py
│   │       │   │   │   ├── subtraction_zeropad.py
│   │       │   │   │   └── utils.py
│   │       │   │   └── modules
│   │       │   │       ├── __init__.py
│   │       │   │       ├── __pycache__
│   │       │   │       │   ├── __init__.cpython-37.pyc
│   │       │   │       │   ├── aggregation.cpython-37.pyc
│   │       │   │       │   ├── subtraction.cpython-37.pyc
│   │       │   │       │   └── subtraction2.cpython-37.pyc
│   │       │   │       ├── aggregation.py
│   │       │   │       ├── subtraction.py
│   │       │   │       └── subtraction2.py
│   │       │   ├── model.py
│   │       │   ├── settings.py
│   │       │   ├── show.py
│   │       │   ├── tensorboard.sh
│   │       │   └── train.py
│   │       └── models
│   │           ├── .DS_Store
│   │           └── README.md
│   └── diff_loss
│       ├── .DS_Store
│       ├── mae
│       │   ├── .DS_Store
│       │   ├── config
│       │   │   ├── .DS_Store
│       │   │   ├── cal_ssim.py
│       │   │   ├── clean.sh
│       │   │   ├── compile.py
│       │   │   ├── dataset.py
│       │   │   ├── eval.py
│       │   │   ├── function
│       │   │   │   ├── __pycache__
│       │   │   │   │   ├── functional.cpython-36.pyc
│       │   │   │   │   └── functional.cpython-37.pyc
│       │   │   │   ├── functional.py
│       │   │   │   ├── functions
│       │   │   │   │   ├── __init__.py
│       │   │   │   │   ├── __pycache__
│       │   │   │   │   │   ├── __init__.cpython-36.pyc
│       │   │   │   │   │   ├── __init__.cpython-37.pyc
│       │   │   │   │   │   ├── aggregation_refpad.cpython-36.pyc
│       │   │   │   │   │   ├── aggregation_refpad.cpython-37.pyc
│       │   │   │   │   │   ├── aggregation_zeropad.cpython-36.pyc
│       │   │   │   │   │   ├── aggregation_zeropad.cpython-37.pyc
│       │   │   │   │   │   ├── subtraction2_refpad.cpython-36.pyc
│       │   │   │   │   │   ├── subtraction2_refpad.cpython-37.pyc
│       │   │   │   │   │   ├── subtraction2_zeropad.cpython-36.pyc
│       │   │   │   │   │   ├── subtraction2_zeropad.cpython-37.pyc
│       │   │   │   │   │   ├── subtraction_refpad.cpython-36.pyc
│       │   │   │   │   │   ├── subtraction_refpad.cpython-37.pyc
│       │   │   │   │   │   ├── subtraction_zeropad.cpython-36.pyc
│       │   │   │   │   │   ├── subtraction_zeropad.cpython-37.pyc
│       │   │   │   │   │   ├── utils.cpython-36.pyc
│       │   │   │   │   │   └── utils.cpython-37.pyc
│       │   │   │   │   ├── aggregation_refpad.py
│       │   │   │   │   ├── aggregation_zeropad.py
│       │   │   │   │   ├── subtraction2_refpad.py
│       │   │   │   │   ├── subtraction2_zeropad.py
│       │   │   │   │   ├── subtraction_refpad.py
│       │   │   │   │   ├── subtraction_zeropad.py
│       │   │   │   │   └── utils.py
│       │   │   │   └── modules
│       │   │   │       ├── __init__.py
│       │   │   │       ├── __pycache__
│       │   │   │       │   ├── __init__.cpython-36.pyc
│       │   │   │       │   ├── __init__.cpython-37.pyc
│       │   │   │       │   ├── aggregation.cpython-36.pyc
│       │   │   │       │   ├── aggregation.cpython-37.pyc
│       │   │   │       │   ├── subtraction.cpython-36.pyc
│       │   │   │       │   ├── subtraction.cpython-37.pyc
│       │   │   │       │   ├── subtraction2.cpython-36.pyc
│       │   │   │       │   └── subtraction2.cpython-37.pyc
│       │   │   │       ├── aggregation.py
│       │   │   │       ├── subtraction.py
│       │   │   │       └── subtraction2.py
│       │   │   ├── model.py
│       │   │   ├── settings.py
│       │   │   ├── show.py
│       │   │   ├── tensorboard.sh
│       │   │   └── train.py
│       │   └── models
│       │       ├── .DS_Store
│       │       └── README.md
│       └── mse
│           ├── .DS_Store
│           ├── config
│           │   ├── .DS_Store
│           │   ├── cal_ssim.py
│           │   ├── clean.sh
│           │   ├── compile.py
│           │   ├── dataset.py
│           │   ├── eval.py
│           │   ├── function
│           │   │   ├── __pycache__
│           │   │   │   └── functional.cpython-37.pyc
│           │   │   ├── functional.py
│           │   │   ├── functions
│           │   │   │   ├── __init__.py
│           │   │   │   ├── __pycache__
│           │   │   │   │   ├── __init__.cpython-37.pyc
│           │   │   │   │   ├── aggregation_refpad.cpython-37.pyc
│           │   │   │   │   ├── aggregation_zeropad.cpython-37.pyc
│           │   │   │   │   ├── subtraction2_refpad.cpython-37.pyc
│           │   │   │   │   ├── subtraction2_zeropad.cpython-37.pyc
│           │   │   │   │   ├── subtraction_refpad.cpython-37.pyc
│           │   │   │   │   ├── subtraction_zeropad.cpython-37.pyc
│           │   │   │   │   └── utils.cpython-37.pyc
│           │   │   │   ├── aggregation_refpad.py
│           │   │   │   ├── aggregation_zeropad.py
│           │   │   │   ├── subtraction2_refpad.py
│           │   │   │   ├── subtraction2_zeropad.py
│           │   │   │   ├── subtraction_refpad.py
│           │   │   │   ├── subtraction_zeropad.py
│           │   │   │   └── utils.py
│           │   │   └── modules
│           │   │       ├── __init__.py
│           │   │       ├── __pycache__
│           │   │       │   ├── __init__.cpython-37.pyc
│           │   │       │   ├── aggregation.cpython-37.pyc
│           │   │       │   ├── subtraction.cpython-37.pyc
│           │   │       │   └── subtraction2.cpython-37.pyc
│           │   │       ├── aggregation.py
│           │   │       ├── subtraction.py
│           │   │       └── subtraction2.py
│           │   ├── model.py
│           │   ├── settings.py
│           │   ├── show.py
│           │   ├── tensorboard.sh
│           │   └── train.py
│           └── models
│               ├── .DS_Store
│               └── README.md
└── fig
    ├── .DS_Store
    ├── ex_pair.png
    ├── ex_unpair.png
    ├── overall.png
    └── result.png
/README.md:
--------------------------------------------------------------------------------
1 | # JDNet: Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network
2 |
3 | [Cong Wang](https://supercong94.wixsite.com/supercong94)\*, [Yutong Wu](https://ohraincu.github.io/)\*, [Zhixun Su](http://faculty.dlut.edu.cn/ZhixunSu/zh_CN/index/759047/list/index.htm) †, Junyang Chen
4 |
5 | \* Both authors contributed equally to this research. † Corresponding author.
6 |
7 | This work has been accepted to ACM MM 2020. [\[Arxiv\]](https://arxiv.org/abs/2008.02763)
8 |
9 |
10 |
11 |
12 | Fig 1: An example from real-world datasets.
13 |
14 |
15 | ## Abstract
16 | In this paper, we propose an effective algorithm, called JDNet, to solve the single image deraining problem and conduct segmentation and detection tasks for applications. Specifically, considering the importance of multi-scale features, we propose a Scale-Aggregation module to learn features at different scales. Simultaneously, a Self-Attention module is introduced to match or outperform its convolutional counterparts, allowing the feature aggregation to adapt to each channel. Furthermore, to improve the basic convolutional feature transformation process of Convolutional Neural Networks (CNNs), Self-Calibrated convolution is applied to build long-range spatial and inter-channel dependencies around each spatial location, which explicitly expands the field-of-view of each convolutional layer through internal communications and hence enriches the output features. By skillfully designing the Scale-Aggregation and Self-Attention modules with Self-Calibrated convolution, the proposed model achieves better deraining results on both real-world and synthetic datasets. Extensive experiments are conducted to demonstrate the superiority of our method compared with state-of-the-art methods.
17 |
18 |
19 |
20 |
21 | Fig 2: The architecture of the Joint Network for Deraining (JDNet).
22 |
23 |
24 | ## Requirements
25 | - CUDA 9.0 (or later)
26 | - Python 3.6 (or later)
27 | - PyTorch 1.1.0
28 | - torchvision 0.3.0
29 | - OpenCV
30 |
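A quick way to confirm that the environment matches the versions above (a minimal sketch; run it in your Python interpreter):

```
import torch, torchvision, cv2
print(torch.__version__)          # expect 1.1.0
print(torchvision.__version__)    # expect 0.3.0
print(torch.cuda.is_available())  # True requires a CUDA 9.0+ capable setup
print(cv2.__version__)
```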
31 | ## Dataset
32 | Please download the following datasets:
33 |
34 | * Rain12 [[dataset](http://yu-li.github.io/paper/li_cvpr16_rain.zip)]
35 | * Rain100L [[dataset](http://www.icst.pku.edu.cn/struct/Projects/joint_rain_removal.html)]
36 | * Rain100H [[dataset](http://www.icst.pku.edu.cn/struct/Projects/joint_rain_removal.html)]
37 | * Real-world images [[BaiduYun](https://pan.baidu.com/s/1gpuB6NUHPnQEtgRn4evr5A)](password:vvuk) / [[GoogleDrive](https://drive.google.com/drive/folders/1sSfm-HplPO3FLKR3wAH3iWBmYsD_UAy_?usp=sharing)]
38 |
39 | ## Setup
40 | Please download this project with the `git` command.
41 | All training and testing experiments live in a folder called [code](https://github.com/Ohraincu/JDNet/tree/master/code):
42 | ```
43 | $ git clone https://github.com/Ohraincu/JDNet.git
44 | $ cd JDNet/code
45 | ```
46 | Inside, the folders are organized as follows:
47 |
48 | |--ablation
49 |    |--r1
50 |    |--r2
51 |
52 | |--base
53 |    |--rain100H
54 |    |--rain100L
55 |
56 | |--diff_loss
57 |    |--mae
58 |    |--mse
59 |
60 | The implementation code in our work is in the [base](https://github.com/Ohraincu/JDNet/tree/master/code/base) folder.
61 |
62 | In the [ablation](https://github.com/Ohraincu/JDNet/tree/master/code/ablation) folder, there are two experimental setups for the ablation studies mentioned in the paper.
63 |
64 | The only difference between the code in the [diff_loss](https://github.com/Ohraincu/JDNet/tree/master/code/diff_loss) folder and the [base](https://github.com/Ohraincu/JDNet/tree/master/code/base) folder is the loss function used during training.
65 |
66 | Our code is adapted from [the RESCAN code by Li et al.](https://xialipku.github.io/RESCAN/); many thanks to the authors.
67 |
68 | ## Training
69 | After downloading the datasets above, run the following commands to start training:
70 | ```
71 | $ cd config
72 | $ python train.py
73 | ```
74 | You can pause and resume training at any time, because checkpoints (named 'latest_net' or 'net_x_epoch') are saved periodically; to resume from a specific checkpoint, pass its name, e.g. `python train.py -m1 net_x_epoch`.
75 |
76 | ## Testing
77 | ### Pre-trained Models
78 | [[BaiduYun](https://pan.baidu.com/s/1MHVnTTN-kjeDQ-x9m_zr0Q)](password:9sbj) /
79 | [[GoogleDrive](https://drive.google.com/drive/folders/1rzidCqbhgnKEGNOgFLqSUynhux4SfURI)]
80 |
81 | ### Quantitative and Qualitative Results
82 | Run eval.py to obtain the corresponding numerical results (PSNR and SSIM):
83 | ```
84 | $ python eval.py
85 | ```
86 | To save the visual results on the datasets, run show.py:
87 | ```
88 | $ python show.py
89 | ```
90 | The pre-trained model for each case is placed in the corresponding 'models' folder, and the 'latest_net' checkpoint is loaded by default. If you want to generate results from an intermediate checkpoint of your own training, change the default in eval.py or show.py:
91 | ```
92 | parser.add_argument('-m', '--model', default='net_x_epoch')
93 | ```
94 |
95 | ## Supplement to settings.py
96 | Note that our training and testing support two types of input.
97 |
98 | One takes a paired image (clean and rainy side by side) as input; the other takes a single rainy image, as shown in Fig 3 and Fig 4 respectively.
99 |
100 |
107 | Fig 3: Paired image.
108 | Fig 4: Unpaired image.
112 |
113 | For these different input forms, we need to modify a few lines in settings.py:
114 |
115 | ```
116 | pic_is_pair = True  # input picture is pair or single
117 | data_dir = '/Your_paired_datadir'
118 | if pic_is_pair is False:
119 |     data_dir = '/Your_unpaired_datadir'
120 | ```
121 | Then, if your train/eval/show.py contains the following line:
122 |
123 | ```
124 | dt = sess.get_dataloader('test')
125 | ```
126 | The dataset path that will be read is "/Your_paired_datadir/test".
127 |
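For reference, the dataset classes build this path with a plain `os.path.join`; the sketch below mirrors what dataset.py does ('/Your_paired_datadir' is the placeholder from the snippet above):

```
import os
data_dir = '/Your_paired_datadir'     # settings.data_dir
name = 'test'                         # the argument to sess.get_dataloader(...)
print(os.path.join(data_dir, name))   # /Your_paired_datadir/test
```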
128 | ## Citation
129 | To be updated after the final publication:
130 | ```
131 | @inproceedings{acmmm20_jdnet,
132 |   author = {Cong Wang and Yutong Wu and Zhixun Su and Junyang Chen},
133 |   title = {Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network},
134 |   booktitle = {ACM International Conference on Multimedia},
135 |   year = {2020},
136 | }
137 | ```
138 |
139 | ## Contact
140 |
141 | If you are interested in our work or have any questions, please contact us directly on GitHub or by email.
142 | Email: ytongwu@mail.dlut.edu.cn / supercong94@gmail.com
143 |
--------------------------------------------------------------------------------
/code/ablation/r1/config/cal_ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 |     gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 |     return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 |     window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 |     return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 |     mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 |     mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 |     mu1_sq = mu1.pow(2)
22 |     mu2_sq = mu2.pow(2)
23 |     mu1_mu2 = mu1*mu2
24 |
25 |     sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 |     sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 |     sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 |     C1 = 0.01**2
30 |     C2 = 0.03**2
31 |
32 |     ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 |     if size_average:
35 |         return ssim_map.mean()
36 |     else:
37 |         return ssim_map.mean(1).mean(1).mean(1)
38 |
39 | class SSIM(torch.nn.Module):
40 |     def __init__(self, window_size = 11, size_average = True):
41 |         super(SSIM, self).__init__()
42 |         self.window_size = window_size
43 |         self.size_average = size_average
44 |         self.channel = 3
45 |         self.window = create_window(window_size, self.channel)
46 |
47 |     def forward(self, img1, img2):
48 |         #print(img1.size())
49 |         (_, channel, _, _) = img1.size()
50 |
51 |         if channel == self.channel and self.window.data.type() == img1.data.type():
52 |             window = self.window
53 |         else:
54 |             window = create_window(self.window_size, channel)
55 |
56 |             if img1.is_cuda:
57 |                 window = window.cuda(img1.get_device())
58 |             window = window.type_as(img1)
59 |
60 |             self.window = window
61 |             self.channel = channel
62 |
63 |
64 |         return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
65 |
66 | def ssim(img1, img2, window_size = 11, size_average = True):
67 |     (_, channel, _, _) = img1.size()
68 |     window = create_window(window_size, channel)
69 |
70 |     if img1.is_cuda:
71 |         window = window.cuda(img1.get_device())
72 |     window = window.type_as(img1)
73 |
74 |     return _ssim(img1, img2, window, window_size, channel, size_average)
75 |
--------------------------------------------------------------------------------
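A minimal usage sketch for the SSIM module above (illustrative only, not part of the repository; tensor sizes are arbitrary):

```
# Compare two images with the SSIM module defined in cal_ssim.py (CPU is fine).
import torch
from cal_ssim import SSIM

criterion = SSIM()                              # window_size=11, batch-averaged
a = torch.rand(1, 3, 64, 64)
print(criterion(a, a.clone()).item())           # identical images -> 1.0
print(criterion(a, torch.rand_like(a)).item())  # uncorrelated images -> near 0
```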
/code/ablation/r1/config/clean.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | rm -rf __pycache__
3 | rm \.*\.swp
4 | rm -R ../logdir/*
5 | rm -R ../showdir/*
6 | rm ../models/*
7 |
--------------------------------------------------------------------------------
/code/ablation/r1/config/compile.py:
--------------------------------------------------------------------------------
1 | import re
2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1"
3 | msgidRegex = re.compile(r',(\d)+,')
4 | mo = msgidRegex.search(rec_data)
5 | print(mo.group())
--------------------------------------------------------------------------------
/code/ablation/r1/config/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from numpy.random import RandomState
5 | from torch.utils.data import Dataset
6 |
7 | import settings
8 |
9 |
10 | class TrainValDataset(Dataset):
11 |     def __init__(self, name):
12 |         super().__init__()
13 |         self.rand_state = RandomState(66)
14 |         self.root_dir = os.path.join(settings.data_dir, name)
15 |         self.mat_files = os.listdir(self.root_dir)
16 |         self.patch_size = settings.patch_size
17 |         self.file_num = len(self.mat_files)
18 |
19 |     def __len__(self):
20 |         return self.file_num * 100
21 |
22 |     def __getitem__(self, idx):
23 |         file_name = self.mat_files[idx % self.file_num]
24 |         img_file = os.path.join(self.root_dir, file_name)
25 |         img_pair = cv2.imread(img_file).astype(np.float32) / 255
26 |
27 |         if settings.aug_data:
28 |             O, B = self.crop(img_pair, aug=True)
29 |             O, B = self.flip(O, B)
30 |             O, B = self.rotate(O, B)
31 |         else:
32 |             O, B = self.crop(img_pair, aug=False)
33 |
34 |         O = np.transpose(O, (2, 0, 1))
35 |         B = np.transpose(B, (2, 0, 1))
36 |         sample = {'O': O, 'B': B}
37 |
38 |         return sample
39 |
40 |     def crop(self, img_pair, aug):
41 |         patch_size = self.patch_size
42 |         h, ww, c = img_pair.shape
43 |         w = int(ww / 2)
44 |
45 |         if aug:
46 |             mini = - 1 / 4 * self.patch_size
47 |             maxi = 1 / 4 * self.patch_size + 1
48 |             p_h = patch_size + self.rand_state.randint(mini, maxi)
49 |             p_w = patch_size + self.rand_state.randint(mini, maxi)
50 |         else:
51 |             p_h, p_w = patch_size, patch_size
52 |
53 |         r = self.rand_state.randint(0, h - p_h)
54 |         c = self.rand_state.randint(0, w - p_w)
55 |         O = img_pair[r: r+p_h, c+w: c+p_w+w]
56 |         B = img_pair[r: r+p_h, c: c+p_w]
57 |
58 |         if aug:
59 |             O = cv2.resize(O, (patch_size, patch_size))
60 |             B = cv2.resize(B, (patch_size, patch_size))
61 |
62 |         return O, B
63 |
64 |     def flip(self, O, B):
65 |         if self.rand_state.rand() > 0.5:
66 |             O = np.flip(O, axis=1)
67 |             B = np.flip(B, axis=1)
68 |         return O, B
69 |
70 |     def rotate(self, O, B):
71 |         angle = self.rand_state.randint(-30, 30)
72 |         patch_size = self.patch_size
73 |         center = (int(patch_size / 2), int(patch_size / 2))
74 |         M = cv2.getRotationMatrix2D(center, angle, 1)
75 |         O = cv2.warpAffine(O, M, (patch_size, patch_size))
76 |         B = cv2.warpAffine(B, M, (patch_size, patch_size))
77 |         return O, B
78 |
79 |
80 | class TestDataset(Dataset):
81 |     def __init__(self, name):
82 |         super().__init__()
83 |         self.rand_state = RandomState(66)
84 |         self.root_dir = os.path.join(settings.data_dir, name)
85 |         self.mat_files = os.listdir(self.root_dir)
86 |         self.patch_size = settings.patch_size
87 |         self.file_num = len(self.mat_files)
88 |
89 |     def __len__(self):
90 |         return self.file_num
91 |
92 |     def __getitem__(self, idx):
93 |         file_name = self.mat_files[idx % self.file_num]
94 |         img_file = os.path.join(self.root_dir, file_name)
95 |         img_pair = cv2.imread(img_file).astype(np.float32) / 255
96 |         h, ww, c = img_pair.shape
97 |         w = int(ww / 2)
98 |         #h_8=h%8
99 |         #w_8=w%8
100 |         O = np.transpose(img_pair[:, w:], (2, 0, 1))
101 |         B = np.transpose(img_pair[:, :w], (2, 0, 1))
102 |         sample = {'O': O, 'B': B}
103 |
104 |         return sample
105 |
106 |
107 | class ShowDataset(Dataset):
108 |     def __init__(self, name):
109 |         super().__init__()
110 |         self.rand_state = RandomState(66)
111 |         self.root_dir = os.path.join(settings.data_dir, name)
112 |         self.img_files = sorted(os.listdir(self.root_dir))
113 |         self.file_num = len(self.img_files)
114 |
115 |     def __len__(self):
116 |         return self.file_num
117 |
118 |     def __getitem__(self, idx):
119 |         file_name = self.img_files[idx % self.file_num]
120 |         img_file = os.path.join(self.root_dir, file_name)
121 |         print(img_file)
122 |         img_pair = cv2.imread(img_file).astype(np.float32) / 255
123 |
124 |         h, ww, c = img_pair.shape
125 |         w = int(ww / 2)
126 |
127 |         #h_8 = h % 8
128 |         #w_8 = w % 8
129 |         if settings.pic_is_pair:
130 |             O = np.transpose(img_pair[:, w:], (2, 0, 1))
131 |             B = np.transpose(img_pair[:, :w], (2, 0, 1))
132 |         else:
133 |             O = np.transpose(img_pair[:, :], (2, 0, 1))
134 |             B = np.transpose(img_pair[:, :], (2, 0, 1))
135 |         sample = {'O': O, 'B': B, 'file_name': file_name[:-4]}
136 |
137 |         return sample
138 |
139 |
140 | if __name__ == '__main__':
141 |     dt = TrainValDataset('val')
142 |     print('TrainValDataset')
143 |     for i in range(10):
144 |         smp = dt[i]
145 |         for k, v in smp.items():
146 |             print(k, v.shape, v.dtype, v.mean())
147 |
148 |     print()
149 |     dt = TestDataset('test')
150 |     print('TestDataset')
151 |     for i in range(10):
152 |         smp = dt[i]
153 |         for k, v in smp.items():
154 |             print(k, v.shape, v.dtype, v.mean())
155 |
156 |     print()
157 |     print('ShowDataset')
158 |     dt = ShowDataset('test')
159 |     for i in range(10):
160 |         smp = dt[i]
161 |         for k, v in smp.items():
162 |             print(k, v.shape, v.dtype, v.mean())
163 |
--------------------------------------------------------------------------------
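The dataset classes above assume each paired file stores the clean background (B, left half) and the rainy observation (O, right half) side by side; `__getitem__` splits the image at `ww // 2`. A sketch of producing such a pair (the file name is hypothetical):

```
# Build a side-by-side pair in the layout dataset.py expects: B on the left, O on the right.
import cv2
import numpy as np

clean = np.zeros((100, 100, 3), np.uint8)       # background image B
rainy = np.full((100, 100, 3), 255, np.uint8)   # rainy observation O
cv2.imwrite('pair_example.png', np.concatenate([clean, rainy], axis=1))
```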
/code/ablation/r1/config/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import numpy as np
6 | import math
7 | import torch
8 | from torch import nn
9 | from torch.nn import MSELoss
10 | from torch.optim import Adam
11 | from torch.optim.lr_scheduler import MultiStepLR
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader
14 | from tensorboardX import SummaryWriter
15 |
16 | import settings
17 | from dataset import TestDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
22 | logger = settings.logger
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | #torch.cuda.set_device(settings.device_id)
26 |
27 |
28 | def ensure_dir(dir_path):
29 |     if not os.path.isdir(dir_path):
30 |         os.makedirs(dir_path)
31 | def PSNR(img1, img2):
32 |     b,_,_,_=img1.shape
33 |     #mse=0
34 |     #for i in range(b):
35 |     img1=np.clip(img1,0,255)
36 |     img2=np.clip(img2,0,255)
37 |     mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse
38 |     if mse == 0:
39 |         return 100
40 |     #mse=mse/b
41 |     PIXEL_MAX = 1
42 |     return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
43 | class Session:
44 |     def __init__(self):
45 |         self.log_dir = settings.log_dir
46 |         self.model_dir = settings.model_dir
47 |         ensure_dir(settings.log_dir)
48 |         ensure_dir(settings.model_dir)
49 |         logger.info('set log dir as %s' % settings.log_dir)
50 |         logger.info('set model dir as %s' % settings.model_dir)
51 |         if len(settings.device_id) > 1:
52 |             self.net = nn.DataParallel(ODE_DerainNet()).cuda()
53 |         else:
54 |             torch.cuda.set_device(settings.device_id[0])
55 |             self.net = ODE_DerainNet().cuda()
56 |         self.l2 = MSELoss().cuda()
57 |         self.l1 = nn.L1Loss().cuda()
58 |         self.ssim = SSIM().cuda()
59 |         self.dataloaders = {}
60 |
61 |     def get_dataloader(self, dataset_name):
62 |         dataset = TestDataset(dataset_name)
63 |         if not dataset_name in self.dataloaders:
64 |             self.dataloaders[dataset_name] = \
65 |                 DataLoader(dataset, batch_size=1,
66 |                            shuffle=False, num_workers=1, drop_last=False)
67 |         return self.dataloaders[dataset_name]
68 |
69 |     def load_checkpoints(self, name):
70 |         ckp_path = os.path.join(self.model_dir, name)
71 |         try:
72 |             obj = torch.load(ckp_path)
73 |             logger.info('Load checkpoint %s' % ckp_path)
74 |         except FileNotFoundError:
75 |             logger.info('No checkpoint %s!!' % ckp_path)
76 |             return
77 |         self.net.load_state_dict(obj['net'])
78 |
79 |     def inf_batch(self, name, batch):
80 |         O, B = batch['O'].cuda(), batch['B'].cuda()
81 |         O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
82 |
83 |         with torch.no_grad():
84 |             derain = self.net(O)
85 |
86 |         l1_loss = self.l1(derain, B)
87 |         ssim = self.ssim(derain, B)
88 |         psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
89 |         losses = { 'L1 loss' : l1_loss }
90 |         ssimes = { 'ssim' : ssim }
91 |         losses.update(ssimes)
92 |
93 |         return losses, psnr
94 |
95 |
96 | def run_test(ckp_name):
97 |     sess = Session()
98 |     sess.net.eval()
99 |     sess.load_checkpoints(ckp_name)
100 |     dt = sess.get_dataloader('test')
101 |     psnr_all = 0
102 |     all_num = 0
103 |     all_losses = {}
104 |     for i, batch in enumerate(dt):
105 |         losses, psnr = sess.inf_batch('test', batch)
106 |         psnr_all = psnr_all + psnr
107 |         batch_size = batch['O'].size(0)
108 |         all_num += batch_size
109 |         for key, val in losses.items():
110 |             if i == 0:
111 |                 all_losses[key] = 0.
112 |             all_losses[key] += val * batch_size
113 |             logger.info('batch %d mse %s: %f' % (i, key, val))
114 |
115 |     for key, val in all_losses.items():
116 |         logger.info('total mse %s: %f' % (key, val / all_num))
117 |     #psnr=sum(psnr_all)
118 |     #print(psnr)
119 |     print('psnr_ll:%8f' % (psnr_all / all_num))
120 |
121 | if __name__ == '__main__':
122 |     parser = argparse.ArgumentParser()
123 |     parser.add_argument('-m', '--model', default='latest_net')
124 |
125 |     args = parser.parse_args(sys.argv[1:])
126 |     run_test(args.model)
127 |
128 |
--------------------------------------------------------------------------------
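The PSNR helper above computes 20·log10(1/√MSE) after rescaling the 0-255 inputs to [0, 1], and returns 100 as a sentinel when the images are identical. A small worked example (illustrative, not part of the repository):

```
import numpy as np, math

img = np.full((1, 3, 4, 4), 128.0)                # values in [0, 255], as in eval.py
noisy = img + 16.0
mse = np.mean((img / 255. - noisy / 255.) ** 2)   # (16/255)**2
print(20 * math.log10(1 / math.sqrt(mse)))        # = 20*log10(255/16), about 24.05 dB
```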
/code/ablation/r1/config/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import settings
5 | from itertools import combinations,product
6 | import math
7 |
8 | class Residual_Block(nn.Module):
9 |     def __init__(self, in_ch, out_ch, stride):
10 |         super(Residual_Block, self).__init__()
11 |         self.channel_num = settings.channel
12 |         self.convs = nn.ModuleList()
13 |         self.relus = nn.ModuleList()
14 |         self.convert = nn.Sequential(
15 |             nn.Conv2d(in_ch, out_ch, 3, stride, 1),
16 |             nn.LeakyReLU(0.2)
17 |         )
18 |         self.res = nn.Sequential(
19 |             nn.Conv2d(out_ch, out_ch, 3, 1, 1),
20 |             nn.LeakyReLU(0.2),
21 |             nn.Conv2d(out_ch, out_ch, 3, 1, 1),
22 |         )
23 |
24 |     def forward(self, x):
25 |         convert = self.convert(x)
26 |         out = convert + self.res(convert)
27 |         return out
28 |
29 | class Scale_attention(nn.Module):
30 |     def __init__(self):
31 |         super(Scale_attention, self).__init__()
32 |         self.scale_attention = nn.ModuleList()
33 |         self.res_list = nn.ModuleList()
34 |         self.channel = settings.channel
35 |         if settings.scale_attention is True:
36 |             for i in range(settings.num_scale_attention):
37 |                 self.scale_attention.append(
38 |                     nn.Sequential(
39 |                         nn.MaxPool2d(2 ** (i + 1), 2 ** (i + 1)),
40 |                         nn.Conv2d(self.channel, self.channel, 1, 1),
41 |                         nn.Sigmoid()
42 |                     )
43 |                 )
44 |         for i in range(settings.num_scale_attention):
45 |             self.res_list.append(
46 |                 Residual_Block(self.channel, self.channel, 2)
47 |             )
48 |
49 |         self.conv11 = nn.Sequential(
50 |             nn.Conv2d((settings.num_scale_attention + 1) * self.channel, self.channel, 1, 1),
51 |             nn.LeakyReLU(0.2)
52 |         )
53 |
54 |     def forward(self, x):
55 |         b, c, h, w = x.size()
56 |         temp = x
57 |         out = []
58 |         out.append(temp)
59 |         if settings.scale_attention is True:
60 |             for i in range(settings.num_scale_attention):
61 |                 temp = self.res_list[i](temp)
62 |                 b0, c0, h0, w0 = temp.size()
63 |                 temp = temp * F.upsample(self.scale_attention[i](x), [h0, w0])
64 |                 up = temp
65 |                 out.append(F.upsample(up, [h, w]))
66 |             fusion = self.conv11(torch.cat(out, dim=1))
67 |
68 |         else:
69 |             for i in range(settings.num_scale_attention):
70 |                 temp = self.res_list[i](temp)
71 |                 up = temp
72 |                 out.append(F.upsample(up, [h, w]))
73 |             fusion = self.conv11(torch.cat(out, dim=1))
74 |         return fusion + x
75 |
76 | class DenseConnection(nn.Module):
77 |     def __init__(self, unit, unit_num):
78 |         super(DenseConnection, self).__init__()
79 |         self.unit_num = unit_num
80 |         self.channel = settings.channel
81 |         self.units = nn.ModuleList()
82 |         self.conv1x1 = nn.ModuleList()
83 |         for i in range(self.unit_num):
84 |             self.units.append(unit())
85 |             self.conv1x1.append(nn.Sequential(nn.Conv2d((i+2)*self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2)))
86 |
87 |     def forward(self, x):
88 |         cat = []
89 |         cat.append(x)
90 |         out = x
91 |         for i in range(self.unit_num):
92 |             tmp = self.units[i](out)
93 |             cat.append(tmp)
94 |             out = self.conv1x1[i](torch.cat(cat, dim=1))
95 |         return out
96 |
97 |
98 | class ODE_DerainNet(nn.Module):
99 |     def __init__(self):
100 |         super(ODE_DerainNet, self).__init__()
101 |         self.channel = settings.channel
102 |         self.unit_num = settings.unit_num
103 |         self.enterBlock = nn.Sequential(nn.Conv2d(3, self.channel, 3, 1, 1), nn.LeakyReLU(0.2))
104 |         self.derain_net = DenseConnection(Scale_attention, self.unit_num)
105 |         self.exitBlock = nn.Sequential(nn.Conv2d(self.channel, 3, 3, 1, 1), nn.LeakyReLU(0.2))
106 |
107 |
108 |     def forward(self, x):
109 |         image_feature = self.enterBlock(x)
110 |         rain_feature = self.derain_net(image_feature)
111 |         rain = self.exitBlock(rain_feature)
112 |         derain = x - rain
113 |         return derain
114 |
--------------------------------------------------------------------------------
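A quick shape check for ODE_DerainNet above (illustrative; note that importing model also imports settings, which lists data_dir/train at import time, so that directory must exist):

```
# The network predicts the rain layer and subtracts it from the input.
import torch
from model import ODE_DerainNet

net = ODE_DerainNet()
x = torch.rand(1, 3, 64, 64)   # one patch, matching patch_size = 64 in settings.py
print(net(x).shape)            # torch.Size([1, 3, 64, 64])
```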
/code/ablation/r1/config/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 |
5 | ####################### Network parameter settings ############################################################
6 | channel = 32
7 | feature_map_num = 32
8 | res_conv_num = 4
9 | unit_num = 32
10 | #scale_num = 4
11 | num_scale_attention = 4
12 | scale_attention = False
13 | ssim_loss = True
14 | ########################################################################################
15 | aug_data = False # Set as False for fair comparison
16 |
17 | patch_size = 64
18 | pic_is_pair = True #input picture is pair or single
19 |
20 | lr = 0.0005
21 |
22 | data_dir = '/data1/wangcong/dataset/rain12'
23 | if pic_is_pair is False:
24 |     data_dir = '/data1/wangcong/dataset/real-world-images'
25 | log_dir = '../logdir'
26 | show_dir = '../showdir'
27 | model_dir = '../models'
28 | show_dir_feature = '../showdir_feature'
29 |
30 | log_level = 'info'
31 | model_path = os.path.join(model_dir, 'latest_net')
32 | save_steps = 400
33 |
34 | num_workers = 8
35 |
36 | num_GPU = 2
37 |
38 | device_id = '2,1'
39 |
40 | epoch = 1000
41 | batch_size = 28
42 |
43 | if pic_is_pair:
44 |     root_dir = os.path.join(data_dir, 'train')
45 |     mat_files = os.listdir(root_dir)
46 |     num_datasets = len(mat_files)
47 |     l1 = int(3/5 * epoch * num_datasets / batch_size)
48 |     l2 = int(4/5 * epoch * num_datasets / batch_size)
49 |     one_epoch = int(num_datasets/batch_size)
50 |     total_step = int((epoch * num_datasets)/batch_size)
51 |
52 | logger = logging.getLogger('train')
53 | logger.setLevel(logging.INFO)
54 |
55 | ch = logging.StreamHandler()
56 | ch.setLevel(logging.INFO)
57 |
58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
59 | ch.setFormatter(formatter)
60 | logger.addHandler(ch)
61 |
62 |
63 |
--------------------------------------------------------------------------------
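The milestones l1 and l2 above decay the Adam learning rate by 10x at 3/5 and 4/5 of total training; a worked example with a hypothetical training-set size:

```
# Illustrative arithmetic for the settings above (num_datasets is hypothetical).
epoch, batch_size = 1000, 28
num_datasets = 700                                   # len(os.listdir(data_dir/train))
one_epoch = int(num_datasets / batch_size)           # 25 steps per epoch
total_step = int(epoch * num_datasets / batch_size)  # 25000 steps overall
l1 = int(3/5 * epoch * num_datasets / batch_size)    # first 10x decay at step 15000
l2 = int(4/5 * epoch * num_datasets / batch_size)    # second 10x decay at step 20000
print(one_epoch, total_step, l1, l2)
```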
/code/ablation/r1/config/show.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import math
6 | import numpy as np
7 | import itertools
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import DataParallel
12 | from torch.optim import Adam
13 | from torch.autograd import Variable
14 | from torch.utils.data import DataLoader
15 |
16 | import settings
17 | from dataset import ShowDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
22 | logger = settings.logger
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | #torch.cuda.set_device(settings.device_id)
26 |
27 |
28 | def ensure_dir(dir_path):
29 |     if not os.path.isdir(dir_path):
30 |         os.makedirs(dir_path)
31 |
32 | def PSNR(img1, img2):
33 |     b,_,_,_=img1.shape
34 |     #mse=0
35 |     #for i in range(b):
36 |     img1=np.clip(img1,0,255)
37 |     img2=np.clip(img2,0,255)
38 |     mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse
39 |     if mse == 0:
40 |         return 100
41 |     #mse=mse/b
42 |     PIXEL_MAX = 1
43 |     return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
44 |
45 | class Session:
46 |     def __init__(self):
47 |         self.show_dir = settings.show_dir
48 |         self.model_dir = settings.model_dir
49 |         ensure_dir(settings.show_dir)
50 |         ensure_dir(settings.model_dir)
51 |         logger.info('set show dir as %s' % settings.show_dir)
52 |         logger.info('set model dir as %s' % settings.model_dir)
53 |
54 |         if len(settings.device_id) > 1:
55 |             self.net = nn.DataParallel(ODE_DerainNet()).cuda()
56 |             #self.l2 = nn.DataParallel(MSELoss(),settings.device_id)
57 |             #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id)
58 |             #self.ssim = nn.DataParallel(SSIM(),settings.device_id)
59 |             #self.vgg = nn.DataParallel(VGG(),settings.device_id)
60 |         else:
61 |             torch.cuda.set_device(settings.device_id[0])
62 |             self.net = ODE_DerainNet().cuda()
63 |             self.ssim = SSIM().cuda()
64 |         self.dataloaders = {}
65 |         self.ssim = SSIM().cuda()
66 |         self.a = 0
67 |         self.t = 0
68 |     def get_dataloader(self, dataset_name):
69 |         dataset = ShowDataset(dataset_name)
70 |         self.dataloaders[dataset_name] = \
71 |             DataLoader(dataset, batch_size=1,
72 |                        shuffle=False, num_workers=1)
73 |         return self.dataloaders[dataset_name]
74 |
75 |     def load_checkpoints(self, name):
76 |         ckp_path = os.path.join(self.model_dir, name)
77 |         try:
78 |             obj = torch.load(ckp_path)
79 |             logger.info('Load checkpoint %s' % ckp_path)
80 |         except FileNotFoundError:
81 |             logger.info('No checkpoint %s!!' % ckp_path)
82 |             return
83 |         self.net.load_state_dict(obj['net'])
84 |     def inf_batch(self, name, batch, i):
85 |         O, B, file_name = batch['O'].cuda(), batch['B'].cuda(), batch['file_name']
86 |         file_name = str(file_name[0])
87 |         O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
88 |         with torch.no_grad():
89 |             import time
90 |             t0 = time.time()
91 |             derain = self.net(O)
92 |             t1 = time.time()
93 |             comput_time = t1 - t0
94 |             print(comput_time)
95 |         ssim = self.ssim(derain, B).data.cpu().numpy()
96 |         psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
97 |         print('psnr:%4f-------------ssim:%4f' % (psnr, ssim))
98 |         return derain, psnr, ssim, file_name
99 |
100 |     def save_image(self, No, imgs, name, psnr, ssim, file_name):
101 |         for i, img in enumerate(imgs):
102 |             img = (img.cpu().data * 255).numpy()
103 |             img = np.clip(img, 0, 255)
104 |             img = np.transpose(img, (1, 2, 0))
105 |             h, w, c = img.shape
106 |
107 |             img_file = os.path.join(self.show_dir, '%s.png' % (file_name))
108 |             print(img_file)
109 |             cv2.imwrite(img_file, img)
110 |
111 |
112 | def run_show(ckp_name):
113 |     sess = Session()
114 |     sess.load_checkpoints(ckp_name)
115 |     sess.net.eval()
116 |     dataset = 'test'
117 |     if settings.pic_is_pair is False:
118 |         dataset = 'train-small'
119 |     dt = sess.get_dataloader(dataset)
120 |
121 |     for i, batch in enumerate(dt):
122 |         logger.info(i)
123 |         if i > 423:
124 |             imgs, psnr, ssim, file_name = sess.inf_batch('test', batch, i)
125 |             sess.save_image(i, imgs, dataset, psnr, ssim, file_name)
126 |
127 |
128 |
129 | if __name__ == '__main__':
130 |     parser = argparse.ArgumentParser()
131 |     parser.add_argument('-m', '--model', default='latest_net')
132 |
133 |     args = parser.parse_args(sys.argv[1:])
134 |
135 |     run_show(args.model)
136 |
137 |
--------------------------------------------------------------------------------
/code/ablation/r1/config/tensorboard.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | function rand() {
3 |     min=$1
4 |     max=$(($2-$min+1))
5 |     num=$(($RANDOM+1000000000000))
6 |     echo $(($num%$max+$min))
7 | }
8 |
9 | rnd=$(rand 3000 12000)
10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3
11 |
--------------------------------------------------------------------------------
/code/ablation/r1/config/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import math
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import MSELoss
10 | from torch.optim import Adam
11 | from torch.optim.lr_scheduler import MultiStepLR
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader
14 | from tensorboardX import SummaryWriter
15 |
16 | import settings
17 | from dataset import TrainValDataset, TestDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | logger = settings.logger
22 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | import numpy as np
26 |
27 |
28 | def ensure_dir(dir_path):
29 |     if not os.path.isdir(dir_path):
30 |         os.makedirs(dir_path)
31 |
32 | def PSNR(img1, img2):
33 |     b,_,_,_=img1.shape
34 |     img1=np.clip(img1,0,255)
35 |     img2=np.clip(img2,0,255)
36 |     mse = np.mean((img1/ 255. - img2/ 255.) ** 2)
37 |     if mse == 0:
38 |         return 100
39 |
40 |     PIXEL_MAX = 1
41 |     return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
42 |
43 | class Session:
44 |     def __init__(self):
45 |         self.log_dir = settings.log_dir
46 |         self.model_dir = settings.model_dir
47 |         self.ssim_loss = settings.ssim_loss
48 |         ensure_dir(settings.log_dir)
49 |         ensure_dir(settings.model_dir)
50 |         ensure_dir('../log_test')
51 |         logger.info('set log dir as %s' % settings.log_dir)
52 |         logger.info('set model dir as %s' % settings.model_dir)
53 |         if len(settings.device_id) > 1:
54 |             self.net = nn.DataParallel(ODE_DerainNet()).cuda()
55 |         else:
56 |             torch.cuda.set_device(settings.device_id[0])
57 |             self.net = ODE_DerainNet().cuda()
58 |         self.l1 = nn.L1Loss().cuda()
59 |         self.ssim = SSIM().cuda()
60 |         self.step = 0
61 |         self.save_steps = settings.save_steps
62 |         self.num_workers = settings.num_workers
63 |         self.batch_size = settings.batch_size
64 |         self.writers = {}
65 |         self.dataloaders = {}
66 |         self.opt_net = Adam(self.net.parameters(), lr=settings.lr)
67 |         self.sche_net = MultiStepLR(self.opt_net, milestones=[settings.l1, settings.l2], gamma=0.1)
68 |
69 |     def tensorboard(self, name):
70 |         self.writers[name] = SummaryWriter(os.path.join(self.log_dir, name + '.events'))
71 |         return self.writers[name]
72 |
73 |     def write(self, name, out):
74 |         for k, v in out.items():
75 |             self.writers[name].add_scalar(k, v, self.step)
76 |         out['lr'] = self.opt_net.param_groups[0]['lr']
77 |         out['step'] = self.step
78 |         outputs = [
79 |             "{}:{:.4g}".format(k, v)
80 |             for k, v in out.items()
81 |         ]
82 |         logger.info(name + '--' + ' '.join(outputs))
83 |
84 |     def get_dataloader(self, dataset_name):
85 |         dataset = TrainValDataset(dataset_name)
86 |         if not dataset_name in self.dataloaders:
87 |             self.dataloaders[dataset_name] = \
88 |                 DataLoader(dataset, batch_size=self.batch_size,
89 |                            shuffle=True, num_workers=self.num_workers, drop_last=True)
90 |         return iter(self.dataloaders[dataset_name])
91 |
92 |     def get_test_dataloader(self, dataset_name):
93 |         dataset = TestDataset(dataset_name)
94 |         if not dataset_name in self.dataloaders:
95 |             self.dataloaders[dataset_name] = \
96 |                 DataLoader(dataset, batch_size=1,
97 |                            shuffle=False, num_workers=1, drop_last=False)
98 |         return self.dataloaders[dataset_name]
99 |
100 |     def save_checkpoints_net(self, name):
101 |         ckp_path = os.path.join(self.model_dir, name)
102 |         obj = {
103 |             'net': self.net.state_dict(),
104 |             'clock_net': self.step,
105 |             'opt_net': self.opt_net.state_dict(),
106 |         }
107 |         torch.save(obj, ckp_path)
108 |
109 |     def load_checkpoints_net(self, name):
110 |         ckp_path = os.path.join(self.model_dir, name)
111 |         try:
112 |             logger.info('Load checkpoint %s' % ckp_path)
113 |             obj = torch.load(ckp_path)
114 |         except FileNotFoundError:
115 |             logger.info('No checkpoint %s!!' % ckp_path)
116 |             return
117 |         self.net.load_state_dict(obj['net'])
118 |         self.opt_net.load_state_dict(obj['opt_net'])
119 |         self.step = obj['clock_net']
120 |         self.sche_net.last_epoch = self.step
121 |
122 |     def print_network(self, model):
123 |         num_params = 0
124 |         for p in model.parameters():
125 |             num_params += p.numel()
126 |         print(model)
127 |         print("The number of parameters: {}".format(num_params))
128 |
129 |     def inf_batch(self, name, batch):
130 |         if name == 'train':
131 |             self.net.zero_grad()
132 |         if self.step == 0:
133 |             self.print_network(self.net)
134 |
135 |         O, B = batch['O'].cuda(), batch['B'].cuda()
136 |         O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
137 |         derain = self.net(O)
138 |         l1_loss = self.l1(derain, B)
139 |         ssim = self.ssim(derain, B)
140 |         if self.ssim_loss == True:
141 |             loss = -ssim
142 |         else:
143 |             loss = l1_loss
144 |         if name == 'train':
145 |             loss.backward()
146 |             self.opt_net.step()
147 |         losses = {'L1loss': l1_loss}
148 |         ssimes = {'ssim': ssim}
149 |         losses.update(ssimes)
150 |         self.write(name, losses)
151 |
152 |         return derain
153 |
154 |     def save_image(self, name, img_lists):
155 |         data, pred, label = img_lists
156 |         pred = pred.cpu().data
157 |
158 |         data, label, pred = data * 255, label * 255, pred * 255
159 |         pred = np.clip(pred, 0, 255)
160 |         h, w = pred.shape[-2:]
161 |         gen_num = (1, 1)
162 |         img = np.zeros((gen_num[0] * h, gen_num[1] * 3 * w, 3))
163 |         for img_list in img_lists:
164 |             for i in range(gen_num[0]):
165 |                 row = i * h
166 |                 for j in range(gen_num[1]):
167 |                     idx = i * gen_num[1] + j
168 |                     tmp_list = [data[idx], pred[idx], label[idx]]
169 |                     for k in range(3):
170 |                         col = (j * 3 + k) * w
171 |                         tmp = np.transpose(tmp_list[k], (1, 2, 0))
172 |                         img[row: row+h, col: col+w] = tmp
173 |         img_file = os.path.join(self.log_dir, '%d_%s.jpg' % (self.step, name))
174 |         cv2.imwrite(img_file, img)
175 |
176 |     def inf_batch_test(self, name, batch):
177 |         O, B = batch['O'].cuda(), batch['B'].cuda()
178 |         O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
179 |
180 |         with torch.no_grad():
181 |             derain = self.net(O)
182 |
183 |         l1_loss = self.l1(derain, B)
184 |         ssim = self.ssim(derain, B)
185 |         psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
186 |         losses = { 'L1 loss' : l1_loss }
187 |         ssimes = { 'ssim' : ssim }
188 |         losses.update(ssimes)
189 |
190 |         return l1_loss.data.cpu().numpy(), ssim.data.cpu().numpy(), psnr
191 |
192 |
193 | def run_train_val(ckp_name_net='latest_net'):
194 |     sess = Session()
195 |     sess.load_checkpoints_net(ckp_name_net)
196 |     sess.tensorboard('train')
197 |     dt_train = sess.get_dataloader('train')
198 |     while sess.step < settings.total_step + 1:
199 |         sess.sche_net.step()
200 |         sess.net.train()
201 |         try:
202 |             batch_t = next(dt_train)
203 |         except StopIteration:
204 |             dt_train = sess.get_dataloader('train')
205 |             batch_t = next(dt_train)
206 |         pred_t = sess.inf_batch('train', batch_t)
207 |         if sess.step % int(sess.save_steps / 16) == 0:
208 |             sess.save_checkpoints_net('latest_net')
209 |         if sess.step % sess.save_steps == 0:
210 |             sess.save_image('train', [batch_t['O'], pred_t, batch_t['B']])
211 |
212 |         # observe tendency of ssim, psnr and loss
213 |         ssim_all = 0
214 |         psnr_all = 0
215 |         loss_all = 0
216 |         num_all = 0
217 |         if sess.step % (settings.one_epoch * 20) == 0:
218 |             dt_val = sess.get_test_dataloader('test')
219 |             sess.net.eval()
220 |             for i, batch_v in enumerate(dt_val):
221 |                 loss, ssim, psnr = sess.inf_batch_test('test', batch_v)
222 |                 print(i)
223 |                 ssim_all = ssim_all + ssim
224 |                 psnr_all = psnr_all + psnr
225 |                 loss_all = loss_all + loss
226 |                 num_all = num_all + 1
227 |             print('num_all:', num_all)
228 |             loss_avg = loss_all / num_all
229 |             ssim_avg = ssim_all / num_all
230 |             psnr_avg = psnr_all / num_all
231 |             logfile = open('../log_test/' + 'val' + '.txt', 'a+')
232 |             epoch = int(sess.step / settings.one_epoch)
233 |             logfile.write(
234 |                 'step = ' + str(sess.step) + '\t'
235 |                 'epoch = ' + str(epoch) + '\t'
236 |                 'loss = ' + str(loss_avg) + '\t'
237 |                 'ssim = ' + str(ssim_avg) + '\t'
238 |                 'psnr = ' + str(psnr_avg) + '\t'
239 |                 '\n\n'
240 |             )
241 |             logfile.close()
242 |         if sess.step % (settings.one_epoch * 10) == 0:
243 |             sess.save_checkpoints_net('net_%d_epoch' % int(sess.step / settings.one_epoch))
244 |             logger.info('save model as net_%d_epoch' % int(sess.step / settings.one_epoch))
245 |         sess.step += 1
246 |
247 |
248 |
249 | if __name__ == '__main__':
250 |     parser = argparse.ArgumentParser()
251 |     parser.add_argument('-m1', '--model_1', default='latest_net')
252 |
253 |     args = parser.parse_args(sys.argv[1:])
254 |     run_train_val(args.model_1)
255 |
256 |
--------------------------------------------------------------------------------
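Checkpoint cadence implied by the loop above, using the values from settings.py (save_steps = 400): 'latest_net' is refreshed every save_steps/16 steps, a preview grid is written every save_steps steps, and a 'net_%d_epoch' snapshot is saved every 10 epochs. A small sanity check (one_epoch is hypothetical, matching the settings.py sketch earlier):

```
# Cadence sanity check for the training loop above.
save_steps = 400
one_epoch = 25                 # hypothetical steps per epoch
print(int(save_steps / 16))    # 25  -> steps between 'latest_net' refreshes
print(one_epoch * 20)          # 500 -> steps between validation passes
print(one_epoch * 10)          # 250 -> steps between 'net_%d_epoch' snapshots
```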
/code/ablation/r1/models/README.md:
--------------------------------------------------------------------------------
1 | # JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network
2 |
3 | The training model or pre-training model will be placed here.
4 |
--------------------------------------------------------------------------------
/code/ablation/r2/config/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/ablation/r2/config/.DS_Store
--------------------------------------------------------------------------------
/code/ablation/r2/config/cal_ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 |     gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 |     return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 |     window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 |     return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 |     mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 |     mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 |     mu1_sq = mu1.pow(2)
22 |     mu2_sq = mu2.pow(2)
23 |     mu1_mu2 = mu1*mu2
24 |
25 |     sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 |     sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 |     sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 |     C1 = 0.01**2
30 |     C2 = 0.03**2
31 |
32 |     ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 |     if size_average:
35 |         return ssim_map.mean()
36 |     else:
37 |         return ssim_map.mean(1).mean(1).mean(1)
38 |
39 | class SSIM(torch.nn.Module):
40 |     def __init__(self, window_size = 11, size_average = True):
41 |         super(SSIM, self).__init__()
42 |         self.window_size = window_size
43 |         self.size_average = size_average
44 |         self.channel = 3
45 |         self.window = create_window(window_size, self.channel)
46 |
47 |     def forward(self, img1, img2):
48 |         #print(img1.size())
49 |         (_, channel, _, _) = img1.size()
50 |
51 |         if channel == self.channel and self.window.data.type() == img1.data.type():
52 |             window = self.window
53 |         else:
54 |             window = create_window(self.window_size, channel)
55 |
56 |             if img1.is_cuda:
57 |                 window = window.cuda(img1.get_device())
58 |             window = window.type_as(img1)
59 |
60 |             self.window = window
61 |             self.channel = channel
62 |
63 |
64 |         return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
65 |
66 | def ssim(img1, img2, window_size = 11, size_average = True):
67 |     (_, channel, _, _) = img1.size()
68 |     window = create_window(window_size, channel)
69 |
70 |     if img1.is_cuda:
71 |         window = window.cuda(img1.get_device())
72 |     window = window.type_as(img1)
73 |
74 |     return _ssim(img1, img2, window, window_size, channel, size_average)
75 |
--------------------------------------------------------------------------------
/code/ablation/r2/config/clean.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | rm -rf __pycache__
3 | rm \.*\.swp
4 | rm -R ../logdir/*
5 | rm -R ../showdir/*
6 | rm ../models/*
7 |
--------------------------------------------------------------------------------
/code/ablation/r2/config/compile.py:
--------------------------------------------------------------------------------
1 | import re
2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1"
3 | msgidRegex = re.compile(r',(\d)+,')
4 | mo = msgidRegex.search(rec_data)
5 | print(mo.group())
--------------------------------------------------------------------------------
/code/ablation/r2/config/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from numpy.random import RandomState
5 | from torch.utils.data import Dataset
6 |
7 | import settings
8 |
9 |
10 | class TrainValDataset(Dataset):
11 |     def __init__(self, name):
12 |         super().__init__()
13 |         self.rand_state = RandomState(66)
14 |         self.root_dir = os.path.join(settings.data_dir, name)
15 |         self.mat_files = os.listdir(self.root_dir)
16 |         self.patch_size = settings.patch_size
17 |         self.file_num = len(self.mat_files)
18 |
19 |     def __len__(self):
20 |         return self.file_num * 100
21 |
22 |     def __getitem__(self, idx):
23 |         file_name = self.mat_files[idx % self.file_num]
24 |         img_file = os.path.join(self.root_dir, file_name)
25 |         img_pair = cv2.imread(img_file).astype(np.float32) / 255
26 |
27 |         if settings.aug_data:
28 |             O, B = self.crop(img_pair, aug=True)
29 |             O, B = self.flip(O, B)
30 |             O, B = self.rotate(O, B)
31 |         else:
32 |             O, B = self.crop(img_pair, aug=False)
33 |
34 |         O = np.transpose(O, (2, 0, 1))
35 |         B = np.transpose(B, (2, 0, 1))
36 |         sample = {'O': O, 'B': B}
37 |
38 |         return sample
39 |
40 |     def crop(self, img_pair, aug):
41 |         patch_size = self.patch_size
42 |         h, ww, c = img_pair.shape
43 |         w = int(ww / 2)
44 |
45 |         if aug:
46 |             mini = - 1 / 4 * self.patch_size
47 |             maxi = 1 / 4 * self.patch_size + 1
48 |             p_h = patch_size + self.rand_state.randint(mini, maxi)
49 |             p_w = patch_size + self.rand_state.randint(mini, maxi)
50 |         else:
51 |             p_h, p_w = patch_size, patch_size
52 |
53 |         r = self.rand_state.randint(0, h - p_h)
54 |         c = self.rand_state.randint(0, w - p_w)
55 |         O = img_pair[r: r+p_h, c+w: c+p_w+w]
56 |         B = img_pair[r: r+p_h, c: c+p_w]
57 |
58 |         if aug:
59 |             O = cv2.resize(O, (patch_size, patch_size))
60 |             B = cv2.resize(B, (patch_size, patch_size))
61 |
62 |         return O, B
63 |
64 |     def flip(self, O, B):
65 |         if self.rand_state.rand() > 0.5:
66 |             O = np.flip(O, axis=1)
67 |             B = np.flip(B, axis=1)
68 |         return O, B
69 |
70 |     def rotate(self, O, B):
71 |         angle = self.rand_state.randint(-30, 30)
72 |         patch_size = self.patch_size
73 |         center = (int(patch_size / 2), int(patch_size / 2))
74 |         M = cv2.getRotationMatrix2D(center, angle, 1)
75 |         O = cv2.warpAffine(O, M, (patch_size, patch_size))
76 |         B = cv2.warpAffine(B, M, (patch_size, patch_size))
77 |         return O, B
78 |
79 |
80 | class TestDataset(Dataset):
81 |     def __init__(self, name):
82 |         super().__init__()
83 |         self.rand_state = RandomState(66)
84 |         self.root_dir = os.path.join(settings.data_dir, name)
85 |         self.mat_files = os.listdir(self.root_dir)
86 |         self.patch_size = settings.patch_size
87 |         self.file_num = len(self.mat_files)
88 |
89 |     def __len__(self):
90 |         return self.file_num
91 |
92 |     def __getitem__(self, idx):
93 |         file_name = self.mat_files[idx % self.file_num]
94 |         img_file = os.path.join(self.root_dir, file_name)
95 |         img_pair = cv2.imread(img_file).astype(np.float32) / 255
96 |         h, ww, c = img_pair.shape
97 |         w = int(ww / 2)
98 |         #h_8=h%8
99 |         #w_8=w%8
100 |         O = np.transpose(img_pair[:, w:], (2, 0, 1))
101 |         B = np.transpose(img_pair[:, :w], (2, 0, 1))
102 |         sample = {'O': O, 'B': B}
103 |
104 |         return sample
105 |
106 |
107 | class ShowDataset(Dataset):
108 |     def __init__(self, name):
109 |         super().__init__()
110 |         self.rand_state = RandomState(66)
111 |         self.root_dir = os.path.join(settings.data_dir, name)
112 |         self.img_files = sorted(os.listdir(self.root_dir))
113 |         self.file_num = len(self.img_files)
114 |
115 |     def __len__(self):
116 |         return self.file_num
117 |
118 |     def __getitem__(self, idx):
119 |         file_name = self.img_files[idx % self.file_num]
120 |         img_file = os.path.join(self.root_dir, file_name)
121 |         print(img_file)
122 |         img_pair = cv2.imread(img_file).astype(np.float32) / 255
123 |
124 |         h, ww, c = img_pair.shape
125 |         w = int(ww / 2)
126 |
127 |         #h_8 = h % 8
128 |         #w_8 = w % 8
129 |         if settings.pic_is_pair:
130 |             O = np.transpose(img_pair[:, w:], (2, 0, 1))
131 |             B = np.transpose(img_pair[:, :w], (2, 0, 1))
132 |         else:
133 |             O = np.transpose(img_pair[:, :], (2, 0, 1))
134 |             B = np.transpose(img_pair[:, :], (2, 0, 1))
135 |         sample = {'O': O, 'B': B, 'file_name': file_name[:-4]}
136 |
137 |         return sample
138 |
139 |
140 | if __name__ == '__main__':
141 |     dt = TrainValDataset('val')
142 |     print('TrainValDataset')
143 |     for i in range(10):
144 |         smp = dt[i]
145 |         for k, v in smp.items():
146 |             print(k, v.shape, v.dtype, v.mean())
147 |
148 |     print()
149 |     dt = TestDataset('test')
150 |     print('TestDataset')
151 |     for i in range(10):
152 |         smp = dt[i]
153 |         for k, v in smp.items():
154 |             print(k, v.shape, v.dtype, v.mean())
155 |
156 |     print()
157 |     print('ShowDataset')
158 |     dt = ShowDataset('test')
159 |     for i in range(10):
160 |         smp = dt[i]
161 |         for k, v in smp.items():
162 |             print(k, v.shape, v.dtype, v.mean())
163 |
--------------------------------------------------------------------------------
/code/ablation/r2/config/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import numpy as np
6 | import math
7 | import torch
8 | from torch import nn
9 | from torch.nn import MSELoss
10 | from torch.optim import Adam
11 | from torch.optim.lr_scheduler import MultiStepLR
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader
14 | from tensorboardX import SummaryWriter
15 |
16 | import settings
17 | from dataset import TestDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
22 | logger = settings.logger
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | #torch.cuda.set_device(settings.device_id)
26 |
27 |
28 | def ensure_dir(dir_path):
29 | if not os.path.isdir(dir_path):
30 | os.makedirs(dir_path)
31 | def PSNR(img1, img2):
32 | b,_,_,_=img1.shape
33 | #mse=0
34 | #for i in range(b):
35 | img1=np.clip(img1,0,255)
36 | img2=np.clip(img2,0,255)
37 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse
38 | if mse == 0:
39 | return 100
40 | #mse=mse/b
41 | PIXEL_MAX = 1
42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
43 | class Session:
44 | def __init__(self):
45 | self.log_dir = settings.log_dir
46 | self.model_dir = settings.model_dir
47 | ensure_dir(settings.log_dir)
48 | ensure_dir(settings.model_dir)
49 | logger.info('set log dir as %s' % settings.log_dir)
50 | logger.info('set model dir as %s' % settings.model_dir)
51 | if len(settings.device_id) >1:
52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda()
53 | else:
54 | torch.cuda.set_device(settings.device_id[0])
55 | self.net = ODE_DerainNet().cuda()
56 | self.l2 = MSELoss().cuda()
57 | self.l1 = nn.L1Loss().cuda()
58 | self.ssim = SSIM().cuda()
59 | self.dataloaders = {}
60 |
61 | def get_dataloader(self, dataset_name):
62 | dataset = TestDataset(dataset_name)
63 | if not dataset_name in self.dataloaders:
64 | self.dataloaders[dataset_name] = \
65 | DataLoader(dataset, batch_size=1,
66 | shuffle=False, num_workers=1, drop_last=False)
67 | return self.dataloaders[dataset_name]
68 |
69 | def load_checkpoints(self, name):
70 | ckp_path = os.path.join(self.model_dir, name)
71 | try:
72 | obj = torch.load(ckp_path)
73 | logger.info('Load checkpoint %s' % ckp_path)
74 | except FileNotFoundError:
75 | logger.info('No checkpoint %s!!' % ckp_path)
76 | return
77 | self.net.load_state_dict(obj['net'])
78 |
79 | def inf_batch(self, name, batch):
80 | O, B = batch['O'].cuda(), batch['B'].cuda()
81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
82 |
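   |         # metrics only: run the network without building an autograd graph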
83 | with torch.no_grad():
84 | derain = self.net(O)
85 |
86 | l1_loss = self.l1(derain, B)
87 | ssim = self.ssim(derain, B)
88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
89 | losses = { 'L1 loss' : l1_loss }
90 | ssimes = { 'ssim' : ssim }
91 | losses.update(ssimes)
92 |
93 | return losses, psnr
94 |
95 |
96 | def run_test(ckp_name):
97 | sess = Session()
98 | sess.net.eval()
99 | sess.load_checkpoints(ckp_name)
100 | dt = sess.get_dataloader('test')
101 | psnr_all = 0
102 | all_num = 0
103 | all_losses = {}
104 | for i, batch in enumerate(dt):
105 | losses,psnr= sess.inf_batch('test', batch)
106 | psnr_all=psnr_all+psnr
107 | batch_size = batch['O'].size(0)
108 | all_num += batch_size
109 |             for key, val in losses.items():
110 |                 if i == 0:
111 |                     all_losses[key] = 0.
112 |                 all_losses[key] += val * batch_size
113 |                 logger.info('batch %d %s: %f' % (i, key, val))
114 | 
115 |     for key, val in all_losses.items():
116 |         logger.info('average %s: %f' % (key, val / all_num))
117 |     print('average psnr: %f' % (psnr_all / all_num))
120 |
121 | if __name__ == '__main__':
122 | parser = argparse.ArgumentParser()
123 | parser.add_argument('-m', '--model', default='latest_net')
124 |
125 | args = parser.parse_args(sys.argv[1:])
126 | run_test(args.model)
127 |
128 |
--------------------------------------------------------------------------------
/code/ablation/r2/config/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import settings
5 | from itertools import combinations,product
6 | import math
7 |
8 | class Residual_Block(nn.Module):
9 | def __init__(self, in_ch, out_ch, stride):
10 | super(Residual_Block, self).__init__()
11 |         # a strided 3x3 conv converts channels, then a two-layer
12 |         # residual branch is added back onto the converted features
14 | self.convert = nn.Sequential(
15 | nn.Conv2d(in_ch, out_ch, 3, stride, 1),
16 | nn.LeakyReLU(0.2)
17 | )
18 | self.res = nn.Sequential(
19 | nn.Conv2d(out_ch, out_ch, 3, 1, 1),
20 | nn.LeakyReLU(0.2),
21 | nn.Conv2d(out_ch, out_ch, 3, 1, 1),
22 | )
23 |
24 | def forward(self, x):
25 | convert = self.convert(x)
26 | out = convert + self.res(convert)
27 | return out
28 |
29 | class SCConv(nn.Module):
30 | def __init__(self, planes, pooling_r):
31 | super(SCConv, self).__init__()
32 | self.k2 = nn.Sequential(
33 | nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r),
34 | nn.Conv2d(planes, planes, 3, 1, 1),
35 | )
36 | self.k3 = nn.Sequential(
37 | nn.Conv2d(planes, planes, 3, 1, 1),
38 | )
39 | self.k4 = nn.Sequential(
40 | nn.Conv2d(planes, planes, 3, 1, 1),
41 | nn.LeakyReLU(0.2),
42 | )
43 |
44 | def forward(self, x):
45 | identity = x
46 |
47 | out = torch.sigmoid(torch.add(identity, F.interpolate(self.k2(x), identity.size()[2:]))) # sigmoid(identity + k2)
48 | out = torch.mul(self.k3(x), out) # k3 * sigmoid(identity + k2)
49 | out = self.k4(out) # k4
50 |
51 | return out
52 |
53 | class SCBottleneck(nn.Module):
54 |     # down-sampling rate of the avg pooling layer in the self-calibration branch of SC-Conv
55 |     pooling_r = 4
56 |
57 | def __init__(self, in_planes, planes):
58 | super(SCBottleneck, self).__init__()
59 | planes = int(planes / 2)
60 |
61 | self.conv1_a = nn.Conv2d(in_planes, planes, 1, 1)
62 | self.k1 = nn.Sequential(
63 | nn.Conv2d(planes, planes, 3, 1, 1),
64 | nn.LeakyReLU(0.2),
65 | )
66 |
67 | self.conv1_b = nn.Conv2d(in_planes, planes, 1, 1)
68 |
69 | self.scconv = SCConv(planes, self.pooling_r)
70 |
71 | self.conv3 = nn.Conv2d(planes * 2, planes * 2, 1, 1)
72 | self.relu = nn.LeakyReLU(0.2)
73 |
74 | def forward(self, x):
75 | residual = x
76 |
77 | out_a= self.conv1_a(x)
78 | out_a = self.relu(out_a)
79 |
80 | out_a = self.k1(out_a)
81 |
82 | out_b = self.conv1_b(x)
83 | out_b = self.relu(out_b)
84 |
85 | out_b = self.scconv(out_b)
86 |
87 | out = self.conv3(torch.cat([out_a, out_b], dim=1))
88 |
89 | out += residual
90 | out = self.relu(out)
91 |
92 | return out
93 |
94 | class Scale_attention(nn.Module):
95 | def __init__(self):
96 | super(Scale_attention, self).__init__()
97 | self.scale_attention = nn.ModuleList()
98 | self.res_list = nn.ModuleList()
99 | self.channel = settings.channel
100 | if settings.scale_attention is True:
101 | for i in range(settings.num_scale_attention):
102 | self.scale_attention.append(
103 | nn.Sequential(
104 | nn.MaxPool2d(2 ** (i + 1), 2 ** (i + 1)),
105 | nn.Conv2d(self.channel, self.channel, 1, 1),
106 | nn.Sigmoid()
107 | )
108 | )
109 | for i in range(settings.num_scale_attention):
110 | self.res_list.append(
111 | Residual_Block(self.channel, self.channel, 2)
112 | )
113 |
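    |         # 1x1 conv fuses the input features with all pyramid levels back to 'channel' maps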
114 | self.conv11 = nn.Sequential(
115 | nn.Conv2d((settings.num_scale_attention + 1) * self.channel, self.channel, 1, 1),
116 | nn.LeakyReLU(0.2)
117 | )
118 | self.scn = SCBottleneck(self.channel, self.channel)
119 |
120 |     def forward(self, x):
121 |         b, c, h, w = x.size()
122 |         temp = x
123 |         out = [temp]
124 |         if settings.scale_attention:
125 |             for i in range(settings.num_scale_attention):
126 |                 temp = self.res_list[i](temp)
127 |                 b0, c0, h0, w0 = temp.size()
128 |                 # F.interpolate replaces the deprecated F.upsample
129 |                 temp = temp * F.interpolate(self.scale_attention[i](x), size=[h0, w0])
130 |                 out.append(F.interpolate(temp, size=[h, w]))
131 |         else:
132 |             for i in range(settings.num_scale_attention):
133 |                 temp = self.res_list[i](temp)
134 |                 out.append(F.interpolate(temp, size=[h, w]))
135 |         fusion = self.conv11(torch.cat(out, dim=1))
136 |         out = self.scn(fusion + x)
137 |         return out
142 |
143 | class DenseConnection(nn.Module):
144 | def __init__(self, unit, unit_num):
145 | super(DenseConnection, self).__init__()
146 | self.unit_num = unit_num
147 | self.channel = settings.channel
148 | self.units = nn.ModuleList()
149 | self.conv1x1 = nn.ModuleList()
150 | for i in range(self.unit_num):
151 | self.units.append(unit())
152 | self.conv1x1.append(nn.Sequential(nn.Conv2d((i+2)*self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2)))
153 |
154 | def forward(self, x):
155 | cat = []
156 | cat.append(x)
157 | out = x
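    |         # dense connectivity: each unit consumes the concat of all previous outputs,
    |         # reduced back to the base channel width by a 1x1 conv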
158 | for i in range(self.unit_num):
159 | tmp = self.units[i](out)
160 | cat.append(tmp)
161 | out = self.conv1x1[i](torch.cat(cat,dim=1))
162 | return out
163 |
164 |
165 | class ODE_DerainNet(nn.Module):
166 | def __init__(self):
167 | super(ODE_DerainNet, self).__init__()
168 | self.channel = settings.channel
169 | self.unit_num = settings.unit_num
170 | self.enterBlock = nn.Sequential(nn.Conv2d(3, self.channel, 3, 1, 1), nn.LeakyReLU(0.2))
171 | self.derain_net = DenseConnection(Scale_attention, self.unit_num)
172 | self.exitBlock = nn.Sequential(nn.Conv2d(self.channel, 3, 3, 1, 1), nn.LeakyReLU(0.2))
173 |
174 |
175 | def forward(self, x):
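    |         # residual formulation: estimate the rain layer and subtract it from the input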
176 | image_feature = self.enterBlock(x)
177 | rain_feature = self.derain_net(image_feature)
178 | rain = self.exitBlock(rain_feature)
179 | derain = x - rain
180 | return derain
181 |
--------------------------------------------------------------------------------
/code/ablation/r2/config/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 |
5 | ####################### Network hyper-parameter settings ############################################################
6 | channel = 32
7 | feature_map_num = 32
8 | res_conv_num = 4
9 | unit_num = 32
10 | #scale_num = 4
11 | num_scale_attention = 4
12 | scale_attention = False
13 | ssim_loss = True
14 | ########################################################################################
15 | aug_data = False # Set as False for fair comparison
16 |
17 | patch_size = 64
18 | pic_is_pair = True  # whether each input image is a side-by-side rain/clean pair or a single image
19 |
20 | lr = 0.0005
21 |
22 | data_dir = '/data1/wangcong/dataset/rain12'
23 | if pic_is_pair is False:
24 | data_dir = '/data1/wangcong/dataset/real-world-images'
25 | log_dir = '../logdir'
26 | show_dir = '../showdir'
27 | model_dir = '../models'
28 | show_dir_feature = '../showdir_feature'
29 |
30 | log_level = 'info'
31 | model_path = os.path.join(model_dir, 'latest_net')
32 | save_steps = 400
33 |
34 | num_workers = 8
35 |
36 | num_GPU = 2
37 |
38 | device_id = '2,1'
39 |
40 | epoch = 1000
41 | batch_size = 24
42 |
43 | if pic_is_pair:
44 | root_dir = os.path.join(data_dir, 'train')
45 | mat_files = os.listdir(root_dir)
46 | num_datasets = len(mat_files)
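    |     # LR decay milestones (in steps), reached at 3/5 and 4/5 of total training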
47 | l1 = int(3/5 * epoch * num_datasets / batch_size)
48 | l2 = int(4/5 * epoch * num_datasets / batch_size)
49 | one_epoch = int(num_datasets/batch_size)
50 | total_step = int((epoch * num_datasets)/batch_size)
51 |
52 | logger = logging.getLogger('train')
53 | logger.setLevel(logging.INFO)
54 |
55 | ch = logging.StreamHandler()
56 | ch.setLevel(logging.INFO)
57 |
58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
59 | ch.setFormatter(formatter)
60 | logger.addHandler(ch)
61 |
62 |
63 |
--------------------------------------------------------------------------------
/code/ablation/r2/config/show.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import math
6 | import numpy as np
7 | import itertools
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import DataParallel
12 | from torch.optim import Adam
13 | from torch.autograd import Variable
14 | from torch.utils.data import DataLoader
15 |
16 | import settings
17 | from dataset import ShowDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
22 | logger = settings.logger
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | #torch.cuda.set_device(settings.device_id)
26 |
27 |
28 | def ensure_dir(dir_path):
29 | if not os.path.isdir(dir_path):
30 | os.makedirs(dir_path)
31 |
32 | def PSNR(img1, img2):
33 |     # expects numpy arrays scaled to [0, 255]
34 |     img1 = np.clip(img1, 0, 255)
35 |     img2 = np.clip(img2, 0, 255)
36 |     mse = np.mean((img1 / 255. - img2 / 255.) ** 2)
37 |     if mse == 0:
38 |         # identical images: return a large sentinel value instead of infinity
39 |         return 100
40 |     PIXEL_MAX = 1
41 |     return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
44 |
45 | class Session:
46 | def __init__(self):
47 | self.show_dir = settings.show_dir
48 | self.model_dir = settings.model_dir
49 | ensure_dir(settings.show_dir)
50 | ensure_dir(settings.model_dir)
51 | logger.info('set show dir as %s' % settings.show_dir)
52 | logger.info('set model dir as %s' % settings.model_dir)
53 |
54 |         if len(settings.device_id) > 1:
55 |             self.net = nn.DataParallel(ODE_DerainNet()).cuda()
56 |         else:
57 |             torch.cuda.set_device(int(settings.device_id[0]))  # device_id is a string such as '2'
58 |             self.net = ODE_DerainNet().cuda()
59 |         self.ssim = SSIM().cuda()
60 |         self.dataloaders = {}
61 | 
68 | def get_dataloader(self, dataset_name):
69 | dataset = ShowDataset(dataset_name)
70 | self.dataloaders[dataset_name] = \
71 | DataLoader(dataset, batch_size=1,
72 | shuffle=False, num_workers=1)
73 | return self.dataloaders[dataset_name]
74 |
75 | def load_checkpoints(self, name):
76 | ckp_path = os.path.join(self.model_dir, name)
77 | try:
78 | obj = torch.load(ckp_path)
79 | logger.info('Load checkpoint %s' % ckp_path)
80 | except FileNotFoundError:
81 | logger.info('No checkpoint %s!!' % ckp_path)
82 | return
83 | self.net.load_state_dict(obj['net'])
84 |     def inf_batch(self, name, batch, i):
85 |         import time
86 |         O, B, file_name = batch['O'].cuda(), batch['B'].cuda(), batch['file_name']
87 |         file_name = str(file_name[0])
88 |         O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
89 |         with torch.no_grad():
90 |             t0 = time.time()
91 |             derain = self.net(O)
92 |             t1 = time.time()
93 |             print('inference time: %f s' % (t1 - t0))
94 |         ssim = self.ssim(derain, B).data.cpu().numpy()
95 |         psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
96 |         print('psnr: %f  ssim: %f' % (psnr, ssim))
97 |         return derain, psnr, ssim, file_name
99 |
100 | def save_image(self, No, imgs, name, psnr, ssim, file_name):
101 | for i, img in enumerate(imgs):
102 | img = (img.cpu().data * 255).numpy()
103 | img = np.clip(img, 0, 255)
104 | img = np.transpose(img, (1, 2, 0))
105 | h, w, c = img.shape
106 |
107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name))
108 | print(img_file)
109 | cv2.imwrite(img_file, img)
110 |
111 |
112 | def run_show(ckp_name):
113 | sess = Session()
114 | sess.load_checkpoints(ckp_name)
115 | sess.net.eval()
116 | dataset = 'test'
117 | if settings.pic_is_pair is False:
118 | dataset = 'train-small'
119 | dt = sess.get_dataloader(dataset)
120 |
121 |     for i, batch in enumerate(dt):
122 |         logger.info(i)
123 |         imgs, psnr, ssim, file_name = sess.inf_batch('test', batch, i)
124 |         sess.save_image(i, imgs, dataset, psnr, ssim, file_name)
126 |
127 |
128 |
129 | if __name__ == '__main__':
130 | parser = argparse.ArgumentParser()
131 | parser.add_argument('-m', '--model', default='latest_net')
132 |
133 | args = parser.parse_args(sys.argv[1:])
134 |
135 | run_show(args.model)
136 |
137 |
--------------------------------------------------------------------------------
/code/ablation/r2/config/tensorboard.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | function rand() {
3 | min=$1
4 | max=$(($2-$min+1))
5 | num=$(($RANDOM+1000000000000))
6 | echo $(($num%$max+$min))
7 | }
8 |
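  | # choose a random port in [3000, 12000] so several TensorBoard instances can coexist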
9 | rnd=$(rand 3000 12000)
10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3
11 |
--------------------------------------------------------------------------------
/code/ablation/r2/config/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import math
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import MSELoss
10 | from torch.optim import Adam
11 | from torch.optim.lr_scheduler import MultiStepLR
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader
14 | from tensorboardX import SummaryWriter
15 |
16 | import settings
17 | from dataset import TrainValDataset, TestDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | logger = settings.logger
22 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | import numpy as np
26 |
27 |
28 | def ensure_dir(dir_path):
29 | if not os.path.isdir(dir_path):
30 | os.makedirs(dir_path)
31 |
32 | def PSNR(img1, img2):
33 |     # expects numpy arrays scaled to [0, 255]
34 | img1=np.clip(img1,0,255)
35 | img2=np.clip(img2,0,255)
36 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)
37 | if mse == 0:
38 | return 100
39 |
40 | PIXEL_MAX = 1
41 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
42 |
43 | class Session:
44 | def __init__(self):
45 | self.log_dir = settings.log_dir
46 | self.model_dir = settings.model_dir
47 | self.ssim_loss = settings.ssim_loss
48 | ensure_dir(settings.log_dir)
49 | ensure_dir(settings.model_dir)
50 | ensure_dir('../log_test')
51 | logger.info('set log dir as %s' % settings.log_dir)
52 | logger.info('set model dir as %s' % settings.model_dir)
53 |         if len(settings.device_id) > 1:
54 |             self.net = nn.DataParallel(ODE_DerainNet()).cuda()
55 |         else:
56 |             torch.cuda.set_device(int(settings.device_id[0]))  # device_id is a string such as '2'
57 |             self.net = ODE_DerainNet().cuda()
58 | self.l1 = nn.L1Loss().cuda()
59 | self.ssim = SSIM().cuda()
60 | self.step = 0
61 | self.save_steps = settings.save_steps
62 | self.num_workers = settings.num_workers
63 | self.batch_size = settings.batch_size
64 | self.writers = {}
65 | self.dataloaders = {}
66 | self.opt_net = Adam(self.net.parameters(), lr=settings.lr)
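    |         # LR drops by 10x at the settings.l1 and settings.l2 step milestones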
67 | self.sche_net = MultiStepLR(self.opt_net, milestones=[settings.l1, settings.l2], gamma=0.1)
68 |
69 | def tensorboard(self, name):
70 | self.writers[name] = SummaryWriter(os.path.join(self.log_dir, name + '.events'))
71 | return self.writers[name]
72 |
73 | def write(self, name, out):
74 | for k, v in out.items():
75 | self.writers[name].add_scalar(k, v, self.step)
76 | out['lr'] = self.opt_net.param_groups[0]['lr']
77 | out['step'] = self.step
78 | outputs = [
79 | "{}:{:.4g}".format(k, v)
80 | for k, v in out.items()
81 | ]
82 | logger.info(name + '--' + ' '.join(outputs))
83 |
84 |     def get_dataloader(self, dataset_name):
85 |         # cache the training loader; return a fresh iterator each call
86 |         if dataset_name not in self.dataloaders:
87 |             dataset = TrainValDataset(dataset_name)
88 |             self.dataloaders[dataset_name] = DataLoader(
89 |                 dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True)
90 |         return iter(self.dataloaders[dataset_name])
91 |
92 |     def get_test_dataloader(self, dataset_name):
93 |         if dataset_name not in self.dataloaders:
94 |             dataset = TestDataset(dataset_name)
95 |             self.dataloaders[dataset_name] = DataLoader(
96 |                 dataset, batch_size=1, shuffle=False, num_workers=1, drop_last=False)
97 |         return self.dataloaders[dataset_name]
99 |
100 | def save_checkpoints_net(self, name):
101 | ckp_path = os.path.join(self.model_dir, name)
102 | obj = {
103 | 'net': self.net.state_dict(),
104 | 'clock_net': self.step,
105 | 'opt_net': self.opt_net.state_dict(),
106 | }
107 | torch.save(obj, ckp_path)
108 |
109 | def load_checkpoints_net(self, name):
110 | ckp_path = os.path.join(self.model_dir, name)
111 |         try:
112 |             obj = torch.load(ckp_path)
113 |             logger.info('Load checkpoint %s' % ckp_path)
114 |         except FileNotFoundError:
115 |             logger.info('No checkpoint %s!!' % ckp_path)
116 |             return
117 | self.net.load_state_dict(obj['net'])
118 | self.opt_net.load_state_dict(obj['opt_net'])
119 | self.step = obj['clock_net']
120 | self.sche_net.last_epoch = self.step
121 |
122 | def print_network(self, model):
123 | num_params = 0
124 | for p in model.parameters():
125 | num_params += p.numel()
126 | print(model)
127 | print("The number of parameters: {}".format(num_params))
128 |
129 | def inf_batch(self, name, batch):
130 | if name == 'train':
131 | self.net.zero_grad()
132 | if self.step==0:
133 | self.print_network(self.net)
134 |
135 | O, B = batch['O'].cuda(), batch['B'].cuda()
136 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
137 | derain = self.net(O)
138 | l1_loss = self.l1(derain, B)
139 | ssim = self.ssim(derain, B)
140 |         if self.ssim_loss:
141 |             loss = -ssim  # maximize SSIM by minimizing its negative
142 |         else:
143 |             loss = l1_loss
144 | if name == 'train':
145 | loss.backward()
146 | self.opt_net.step()
147 | losses = {'L1loss' : l1_loss}
148 | ssimes = {'ssim' : ssim}
149 | losses.update(ssimes)
150 | self.write(name, losses)
151 |
152 | return derain
153 |
154 | def save_image(self, name, img_lists):
155 | data, pred, label = img_lists
156 | pred = pred.cpu().data
157 |
158 | data, label, pred = data * 255, label * 255, pred * 255
159 | pred = np.clip(pred, 0, 255)
160 | h, w = pred.shape[-2:]
161 | gen_num = (1, 1)
162 | img = np.zeros((gen_num[0] * h, gen_num[1] * 3 * w, 3))
163 |         # tile [input | prediction | ground truth] into one row per sample
164 |         for i in range(gen_num[0]):
165 |             row = i * h
166 |             for j in range(gen_num[1]):
167 |                 idx = i * gen_num[1] + j
168 |                 tmp_list = [data[idx], pred[idx], label[idx]]
169 |                 for k in range(3):
170 |                     col = (j * 3 + k) * w
171 |                     tmp = np.transpose(tmp_list[k], (1, 2, 0))
172 |                     img[row: row+h, col: col+w] = tmp
173 | img_file = os.path.join(self.log_dir, '%d_%s.jpg' % (self.step, name))
174 | cv2.imwrite(img_file, img)
175 |
176 | def inf_batch_test(self, name, batch):
177 | O, B = batch['O'].cuda(), batch['B'].cuda()
178 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
179 |
180 | with torch.no_grad():
181 | derain = self.net(O)
182 |
183 | l1_loss = self.l1(derain, B)
184 | ssim = self.ssim(derain, B)
185 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
186 | losses = { 'L1 loss' : l1_loss }
187 | ssimes = { 'ssim' : ssim }
188 | losses.update(ssimes)
189 |
190 | return l1_loss.data.cpu().numpy(), ssim.data.cpu().numpy(), psnr
191 |
192 |
193 | def run_train_val(ckp_name_net='latest_net'):
194 | sess = Session()
195 | sess.load_checkpoints_net(ckp_name_net)
196 | sess.tensorboard('train')
197 | dt_train = sess.get_dataloader('train')
198 | while sess.step < settings.total_step+1:
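    |         # step-based training: one batch per step; the scheduler advances every step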
199 | sess.sche_net.step()
200 | sess.net.train()
201 | try:
202 | batch_t = next(dt_train)
203 | except StopIteration:
204 | dt_train = sess.get_dataloader('train')
205 | batch_t = next(dt_train)
206 | pred_t = sess.inf_batch('train', batch_t)
207 | if sess.step % int(sess.save_steps / 16) == 0:
208 | sess.save_checkpoints_net('latest_net')
209 | if sess.step % sess.save_steps == 0:
210 | sess.save_image('train', [batch_t['O'], pred_t, batch_t['B']])
211 |
212 | #observe tendency of ssim, psnr and loss
213 | ssim_all = 0
214 | psnr_all = 0
215 | loss_all = 0
216 | num_all = 0
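    |         # run a full pass over the test set every 20 epochs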
217 | if sess.step % (settings.one_epoch * 20) == 0:
218 | dt_val = sess.get_test_dataloader('test')
219 | sess.net.eval()
220 | for i, batch_v in enumerate(dt_val):
221 | loss, ssim, psnr = sess.inf_batch_test('test', batch_v)
222 | print(i)
223 | ssim_all = ssim_all + ssim
224 | psnr_all = psnr_all + psnr
225 | loss_all = loss_all + loss
226 | num_all = num_all + 1
227 | print('num_all:',num_all)
228 | loss_avg = loss_all / num_all
229 | ssim_avg = ssim_all / num_all
230 | psnr_avg = psnr_all / num_all
231 | logfile = open('../log_test/' + 'val' + '.txt','a+')
232 | epoch = int(sess.step / settings.one_epoch)
233 | logfile.write(
234 | 'step = ' + str(sess.step) + '\t'
235 | 'epoch = ' + str(epoch) + '\t'
236 | 'loss = ' + str(loss_avg) + '\t'
237 | 'ssim = ' + str(ssim_avg) + '\t'
238 |                 'psnr = ' + str(psnr_avg) + '\t'
239 | '\n\n'
240 | )
241 | logfile.close()
242 | if sess.step % (settings.one_epoch*10) == 0:
243 | sess.save_checkpoints_net('net_%d_epoch' % int(sess.step / settings.one_epoch))
244 | logger.info('save model as net_%d_epoch' % int(sess.step / settings.one_epoch))
245 | sess.step += 1
246 |
247 |
248 |
249 | if __name__ == '__main__':
250 | parser = argparse.ArgumentParser()
251 | parser.add_argument('-m1', '--model_1', default='latest_net')
252 |
253 | args = parser.parse_args(sys.argv[1:])
254 | run_train_val(args.model_1)
255 |
256 |
--------------------------------------------------------------------------------
/code/ablation/r2/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/ablation/r2/models/.DS_Store
--------------------------------------------------------------------------------
/code/ablation/r2/models/README.md:
--------------------------------------------------------------------------------
1 | # JDNet: Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network
2 | 
3 | Trained and pre-trained models will be placed here.
4 |
--------------------------------------------------------------------------------
/code/base/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/.DS_Store
--------------------------------------------------------------------------------
/code/base/rain100H/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/.DS_Store
--------------------------------------------------------------------------------
/code/base/rain100H/config/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/.DS_Store
--------------------------------------------------------------------------------
/code/base/rain100H/config/cal_ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 | return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 | mu1_sq = mu1.pow(2)
22 | mu2_sq = mu2.pow(2)
23 | mu1_mu2 = mu1*mu2
24 |
25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 | C1 = 0.01**2
30 | C2 = 0.03**2
31 |
32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 | if size_average:
35 | return ssim_map.mean()
36 | else:
37 | return ssim_map.mean(1).mean(1).mean(1)
38 |
39 | class SSIM(torch.nn.Module):
40 | def __init__(self, window_size = 11, size_average = True):
41 | super(SSIM, self).__init__()
42 | self.window_size = window_size
43 | self.size_average = size_average
44 | self.channel = 3
45 | self.window = create_window(window_size, self.channel)
46 |
47 | def forward(self, img1, img2):
48 | #print(img1.size())
49 | (_, channel, _, _) = img1.size()
50 |
51 | if channel == self.channel and self.window.data.type() == img1.data.type():
52 | window = self.window
53 | else:
54 | window = create_window(self.window_size, channel)
55 |
56 | if img1.is_cuda:
57 | window = window.cuda(img1.get_device())
58 | window = window.type_as(img1)
59 |
60 | self.window = window
61 | self.channel = channel
62 |
63 |
64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
65 |
66 | def ssim(img1, img2, window_size = 11, size_average = True):
67 | (_, channel, _, _) = img1.size()
68 | window = create_window(window_size, channel)
69 |
70 | if img1.is_cuda:
71 | window = window.cuda(img1.get_device())
72 | window = window.type_as(img1)
73 |
74 | return _ssim(img1, img2, window, window_size, channel, size_average)
75 |
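   | # usage sketch (hypothetical names; assumes [N, 3, H, W] tensors with values in [0, 1]):
   | #   score = SSIM()(derain, gt)                       # scalar mean SSIM
   | #   per_img = ssim(derain, gt, size_average=False)   # one SSIM value per image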
--------------------------------------------------------------------------------
/code/base/rain100H/config/clean.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | rm -rf __pycache__
3 | rm \.*\.swp
4 | rm -R ../logdir/*
5 | rm -R ../showdir/*
6 | rm ../models/*
7 |
--------------------------------------------------------------------------------
/code/base/rain100H/config/compile.py:
--------------------------------------------------------------------------------
1 | import re
2 | 
3 | # standalone regex check: pull the message id out of an AT-command style response
4 | rec_data = "+MIPLOBSERVE:0,68220,1,3303,0,-1"
5 | msgidRegex = re.compile(r',(\d+),')
6 | mo = msgidRegex.search(rec_data)
7 | print(mo.group(1))  # 68220
--------------------------------------------------------------------------------
/code/base/rain100H/config/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from numpy.random import RandomState
5 | from torch.utils.data import Dataset
6 |
7 | import settings
8 |
9 |
10 | class TrainValDataset(Dataset):
11 | def __init__(self, name):
12 | super().__init__()
13 | self.rand_state = RandomState(66)
14 | self.root_dir = os.path.join(settings.data_dir, name)
15 | self.mat_files = os.listdir(self.root_dir)
16 | self.patch_size = settings.patch_size
17 | self.file_num = len(self.mat_files)
18 |
19 | def __len__(self):
20 | return self.file_num * 100
21 |
22 | def __getitem__(self, idx):
23 | file_name = self.mat_files[idx % self.file_num]
24 | img_file = os.path.join(self.root_dir, file_name)
25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255
26 |
27 | if settings.aug_data:
28 | O, B = self.crop(img_pair, aug=True)
29 | O, B = self.flip(O, B)
30 | O, B = self.rotate(O, B)
31 | else:
32 | O, B = self.crop(img_pair, aug=False)
33 |
34 | O = np.transpose(O, (2, 0, 1))
35 | B = np.transpose(B, (2, 0, 1))
36 | sample = {'O': O, 'B': B}
37 |
38 | return sample
39 |
40 | def crop(self, img_pair, aug):
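    |         # sample the same random window from both halves: right half is rainy O, left half is clean B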
41 | patch_size = self.patch_size
42 | h, ww, c = img_pair.shape
43 | w = int(ww / 2)
44 |
45 | if aug:
46 |             mini = -patch_size // 4
47 |             maxi = patch_size // 4 + 1
48 | p_h = patch_size + self.rand_state.randint(mini, maxi)
49 | p_w = patch_size + self.rand_state.randint(mini, maxi)
50 | else:
51 | p_h, p_w = patch_size, patch_size
52 |
53 | r = self.rand_state.randint(0, h - p_h)
54 | c = self.rand_state.randint(0, w - p_w)
55 | O = img_pair[r: r+p_h, c+w: c+p_w+w]
56 | B = img_pair[r: r+p_h, c: c+p_w]
57 |
58 | if aug:
59 | O = cv2.resize(O, (patch_size, patch_size))
60 | B = cv2.resize(B, (patch_size, patch_size))
61 |
62 | return O, B
63 |
64 | def flip(self, O, B):
65 | if self.rand_state.rand() > 0.5:
66 | O = np.flip(O, axis=1)
67 | B = np.flip(B, axis=1)
68 | return O, B
69 |
70 | def rotate(self, O, B):
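    |         # rotate the rain/clean patches together by the same random angle in [-30, 30)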
71 | angle = self.rand_state.randint(-30, 30)
72 | patch_size = self.patch_size
73 | center = (int(patch_size / 2), int(patch_size / 2))
74 | M = cv2.getRotationMatrix2D(center, angle, 1)
75 | O = cv2.warpAffine(O, M, (patch_size, patch_size))
76 | B = cv2.warpAffine(B, M, (patch_size, patch_size))
77 | return O, B
78 |
79 |
80 | class TestDataset(Dataset):
81 | def __init__(self, name):
82 | super().__init__()
83 | self.rand_state = RandomState(66)
84 | self.root_dir = os.path.join(settings.data_dir, name)
85 | self.mat_files = os.listdir(self.root_dir)
86 | self.patch_size = settings.patch_size
87 | self.file_num = len(self.mat_files)
88 |
89 | def __len__(self):
90 | return self.file_num
91 |
92 | def __getitem__(self, idx):
93 | file_name = self.mat_files[idx % self.file_num]
94 | img_file = os.path.join(self.root_dir, file_name)
95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255
96 | h, ww, c = img_pair.shape
97 | w = int(ww / 2)
98 |         # rainy image on the right half, clean image on the left
99 |         O = np.transpose(img_pair[:, w:], (2, 0, 1))
100 |         B = np.transpose(img_pair[:, :w], (2, 0, 1))
102 | sample = {'O': O, 'B': B}
103 |
104 | return sample
105 |
106 |
107 | class ShowDataset(Dataset):
108 | def __init__(self, name):
109 | super().__init__()
110 | self.rand_state = RandomState(66)
111 | self.root_dir = os.path.join(settings.data_dir, name)
112 | self.img_files = sorted(os.listdir(self.root_dir))
113 | self.file_num = len(self.img_files)
114 |
115 | def __len__(self):
116 | return self.file_num
117 |
118 | def __getitem__(self, idx):
119 | file_name = self.img_files[idx % self.file_num]
120 | img_file = os.path.join(self.root_dir, file_name)
121 | print(img_file)
122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255
123 |
124 | h, ww, c = img_pair.shape
125 | w = int(ww / 2)
126 |
127 |         if settings.pic_is_pair:
128 |             # paired data: right half is the rainy image O, left half is the clean image B
129 |             O = np.transpose(img_pair[:, w:], (2, 0, 1))
130 |             B = np.transpose(img_pair[:, :w], (2, 0, 1))
131 |         else:
132 |             # single real-world image: no ground truth, so B duplicates O
133 |             O = np.transpose(img_pair[:, :], (2, 0, 1))
134 |             B = np.transpose(img_pair[:, :], (2, 0, 1))
135 |         sample = {'O': O, 'B': B, 'file_name': file_name[:-4]}
136 |
137 | return sample
138 |
139 |
140 | if __name__ == '__main__':
141 | dt = TrainValDataset('val')
142 | print('TrainValDataset')
143 | for i in range(10):
144 | smp = dt[i]
145 | for k, v in smp.items():
146 | print(k, v.shape, v.dtype, v.mean())
147 |
148 | print()
149 | dt = TestDataset('test')
150 | print('TestDataset')
151 | for i in range(10):
152 | smp = dt[i]
153 | for k, v in smp.items():
154 | print(k, v.shape, v.dtype, v.mean())
155 |
156 | print()
157 | print('ShowDataset')
158 | dt = ShowDataset('test')
159 | for i in range(10):
160 | smp = dt[i]
161 | for k, v in smp.items():
162 | print(k, v.shape, v.dtype, v.mean())
163 |
--------------------------------------------------------------------------------
/code/base/rain100H/config/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import numpy as np
6 | import math
7 | import torch
8 | from torch import nn
9 | from torch.nn import MSELoss
10 | from torch.optim import Adam
11 | from torch.optim.lr_scheduler import MultiStepLR
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader
14 | from tensorboardX import SummaryWriter
15 |
16 | import settings
17 | from dataset import TestDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
22 | logger = settings.logger
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | #torch.cuda.set_device(settings.device_id)
26 |
27 |
28 | def ensure_dir(dir_path):
29 | if not os.path.isdir(dir_path):
30 | os.makedirs(dir_path)
31 | def PSNR(img1, img2):
32 |     # expects numpy arrays scaled to [0, 255]
33 |     img1 = np.clip(img1, 0, 255)
34 |     img2 = np.clip(img2, 0, 255)
35 |     mse = np.mean((img1 / 255. - img2 / 255.) ** 2)
36 |     if mse == 0:
37 |         # identical images: return a large sentinel value instead of infinity
38 |         return 100
39 |     PIXEL_MAX = 1
40 |     return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
41 | 
42 | 
43 | class Session:
44 | def __init__(self):
45 | self.log_dir = settings.log_dir
46 | self.model_dir = settings.model_dir
47 | ensure_dir(settings.log_dir)
48 | ensure_dir(settings.model_dir)
49 | logger.info('set log dir as %s' % settings.log_dir)
50 | logger.info('set model dir as %s' % settings.model_dir)
51 |         if len(settings.device_id) > 1:
52 |             self.net = nn.DataParallel(ODE_DerainNet()).cuda()
53 |         else:
54 |             torch.cuda.set_device(int(settings.device_id[0]))  # device_id is a string such as '2'
55 |             self.net = ODE_DerainNet().cuda()
56 | self.l2 = MSELoss().cuda()
57 | self.l1 = nn.L1Loss().cuda()
58 | self.ssim = SSIM().cuda()
59 | self.dataloaders = {}
60 |
61 |     def get_dataloader(self, dataset_name):
62 |         if dataset_name not in self.dataloaders:
63 |             dataset = TestDataset(dataset_name)
64 |             self.dataloaders[dataset_name] = DataLoader(
65 |                 dataset, batch_size=1, shuffle=False, num_workers=1, drop_last=False)
66 |         return self.dataloaders[dataset_name]
68 |
69 | def load_checkpoints(self, name):
70 | ckp_path = os.path.join(self.model_dir, name)
71 | try:
72 | obj = torch.load(ckp_path)
73 | logger.info('Load checkpoint %s' % ckp_path)
74 | except FileNotFoundError:
75 | logger.info('No checkpoint %s!!' % ckp_path)
76 | return
77 | self.net.load_state_dict(obj['net'])
78 |
79 | def inf_batch(self, name, batch):
80 | O, B = batch['O'].cuda(), batch['B'].cuda()
81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
82 |
83 | with torch.no_grad():
84 | derain = self.net(O)
85 |
86 | l1_loss = self.l1(derain, B)
87 | ssim = self.ssim(derain, B)
88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
89 | losses = { 'L1 loss' : l1_loss }
90 | ssimes = { 'ssim' : ssim }
91 | losses.update(ssimes)
92 |
93 | return losses, psnr
94 |
95 |
96 | def run_test(ckp_name):
97 | sess = Session()
98 | sess.net.eval()
99 | sess.load_checkpoints(ckp_name)
100 | dt = sess.get_dataloader('test')
101 | psnr_all = 0
102 | all_num = 0
103 | all_losses = {}
104 | for i, batch in enumerate(dt):
105 | losses,psnr= sess.inf_batch('test', batch)
106 | psnr_all=psnr_all+psnr
107 | batch_size = batch['O'].size(0)
108 | all_num += batch_size
109 |             for key, val in losses.items():
110 |                 if i == 0:
111 |                     all_losses[key] = 0.
112 |                 all_losses[key] += val * batch_size
113 |                 logger.info('batch %d %s: %f' % (i, key, val))
114 | 
115 |     for key, val in all_losses.items():
116 |         logger.info('average %s: %f' % (key, val / all_num))
117 |     print('average psnr: %f' % (psnr_all / all_num))
120 |
121 | if __name__ == '__main__':
122 | parser = argparse.ArgumentParser()
123 | parser.add_argument('-m', '--model', default='latest_net')
124 |
125 | args = parser.parse_args(sys.argv[1:])
126 | run_test(args.model)
127 |
128 |
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/__pycache__/functional.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/__pycache__/functional.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/functional.py:
--------------------------------------------------------------------------------
1 | from . import functions
2 |
3 |
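  | # thin Python wrappers that dispatch to the CUDA kernels in .functions;
  | # non-CUDA inputs fall through to NotImplementedError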
4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1):
5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1]
6 | if input.is_cuda:
7 | if pad_mode == 0:
8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation)
9 | elif pad_mode == 1:
10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation)
11 | else:
12 | raise NotImplementedError
13 | return out
14 |
15 |
16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1):
17 | assert input.dim() == 4 and pad_mode in [0, 1]
18 | if input.is_cuda:
19 | if pad_mode == 0:
20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation)
21 | elif pad_mode == 1:
22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation)
23 | else:
24 | raise NotImplementedError
25 | return out
26 |
27 |
28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1):
29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1]
30 | if input1.is_cuda:
31 | if pad_mode == 0:
32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation)
33 | elif pad_mode == 1:
34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation)
35 | else:
36 | raise NotImplementedError
37 | return out
38 |
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .aggregation_zeropad import *
2 | from .aggregation_refpad import *
3 | from .subtraction_zeropad import *
4 | from .subtraction_refpad import *
5 | from .subtraction2_zeropad import *
6 | from .subtraction2_refpad import *
7 | from .utils import *
8 |
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/functions/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/functions/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/functions/utils.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from string import Template
3 | import cupy
4 | import torch
5 |
6 |
7 | Stream = namedtuple('Stream', ['ptr'])
8 |
9 |
10 | def Dtype(t):
11 | if isinstance(t, torch.cuda.FloatTensor):
12 | return 'float'
13 | elif isinstance(t, torch.cuda.DoubleTensor):
14 | return 'double'
15 |
16 |
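  | # compile the CUDA source string once per (kernel, device) and cache the function handle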
17 | @cupy.util.memoize(for_each_device=True)
18 | def load_kernel(kernel_name, code, **kwargs):
19 | code = Template(code).substitute(**kwargs)
20 | kernel_code = cupy.cuda.compile_with_cache(code)
21 | return kernel_code.get_function(kernel_name)
22 |
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .aggregation import *
2 | from .subtraction import *
3 | from .subtraction2 import *
4 |
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/modules/__pycache__/aggregation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/modules/__pycache__/aggregation.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/modules/__pycache__/subtraction.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/modules/__pycache__/subtraction.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/modules/__pycache__/subtraction2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/modules/__pycache__/subtraction2.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/modules/aggregation.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn.modules.utils import _pair
3 |
4 | from .. import functional as F
5 |
6 |
7 | class Aggregation(nn.Module):
8 |
9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
10 | super(Aggregation, self).__init__()
11 | self.kernel_size = _pair(kernel_size)
12 | self.stride = _pair(stride)
13 | self.padding = _pair(padding)
14 | self.dilation = _pair(dilation)
15 | self.pad_mode = pad_mode
16 |
17 | def forward(self, input, weight):
18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)
19 |
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/modules/subtraction.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn.modules.utils import _pair
3 |
4 | from .. import functional as F
5 |
6 |
7 | class Subtraction(nn.Module):
8 |
9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
10 | super(Subtraction, self).__init__()
11 | self.kernel_size = _pair(kernel_size)
12 | self.stride = _pair(stride)
13 | self.padding = _pair(padding)
14 | self.dilation = _pair(dilation)
15 | self.pad_mode = pad_mode
16 |
17 | def forward(self, input):
18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)
19 |
--------------------------------------------------------------------------------
/code/base/rain100H/config/function/modules/subtraction2.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn.modules.utils import _pair
3 |
4 | from .. import functional as F
5 |
6 |
7 | class Subtraction2(nn.Module):
8 |
9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
10 | super(Subtraction2, self).__init__()
11 | self.kernel_size = _pair(kernel_size)
12 | self.stride = _pair(stride)
13 | self.padding = _pair(padding)
14 | self.dilation = _pair(dilation)
15 | self.pad_mode = pad_mode
16 |
17 | def forward(self, input1, input2):
18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)
19 |
--------------------------------------------------------------------------------
/code/base/rain100H/config/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 |
5 | ####################### Network hyper-parameter settings ############################################################
6 | channel = 32
7 | feature_map_num = 32
8 | res_conv_num = 4
9 | unit_num = 32
10 | #scale_num = 4
11 | num_scale_attention = 4
12 | scale_attention = False
13 | ssim_loss = True
14 | ########################################################################################
15 | aug_data = False # Set as False for fair comparison
16 |
17 | patch_size = 64
18 | pic_is_pair = True  # whether each input image is a side-by-side rain/clean pair or a single image
19 |
20 | lr = 0.0005
21 |
22 | data_dir = '/data1/wangcong/dataset/rain12'
23 | if pic_is_pair is False:
24 | data_dir = '/data1/wangcong/dataset/real-world-images'
25 | log_dir = '../logdir'
26 | show_dir = '../showdir'
27 | model_dir = '../models'
28 | show_dir_feature = '../showdir_feature'
29 |
30 | log_level = 'info'
31 | model_path = os.path.join(model_dir, 'latest_net')
32 | save_steps = 400
33 |
34 | num_workers = 8
35 |
36 | num_GPU = 2
37 |
38 | device_id = '1,2'
39 |
40 | epoch = 1000
41 | batch_size = 12
42 |
43 | if pic_is_pair:
44 | root_dir = os.path.join(data_dir, 'train')
45 | mat_files = os.listdir(root_dir)
46 | num_datasets = len(mat_files)
47 | l1 = int(3/5 * epoch * num_datasets / batch_size)
48 | l2 = int(4/5 * epoch * num_datasets / batch_size)
49 | one_epoch = int(num_datasets/batch_size)
50 | total_step = int((epoch * num_datasets)/batch_size)
51 |
52 | logger = logging.getLogger('train')
53 | logger.setLevel(logging.INFO)
54 |
55 | ch = logging.StreamHandler()
56 | ch.setLevel(logging.INFO)
57 |
58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
59 | ch.setFormatter(formatter)
60 | logger.addHandler(ch)
61 |
62 |
63 |
--------------------------------------------------------------------------------
/code/base/rain100H/config/show.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import math
6 | import numpy as np
7 | import itertools
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import DataParallel
12 | from torch.optim import Adam
13 | from torch.autograd import Variable
14 | from torch.utils.data import DataLoader
15 |
16 | import settings
17 | from dataset import ShowDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
22 | logger = settings.logger
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | #torch.cuda.set_device(settings.device_id)
26 |
27 |
28 | def ensure_dir(dir_path):
29 | if not os.path.isdir(dir_path):
30 | os.makedirs(dir_path)
31 |
32 | def PSNR(img1, img2):
33 |     # expects numpy arrays scaled to [0, 255]
34 |     img1 = np.clip(img1, 0, 255)
35 |     img2 = np.clip(img2, 0, 255)
36 |     mse = np.mean((img1 / 255. - img2 / 255.) ** 2)
37 |     if mse == 0:
38 |         # identical images: return a large sentinel value instead of infinity
39 |         return 100
40 |     PIXEL_MAX = 1
41 |     return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
44 |
45 | class Session:
46 | def __init__(self):
47 | self.show_dir = settings.show_dir
48 | self.model_dir = settings.model_dir
49 | ensure_dir(settings.show_dir)
50 | ensure_dir(settings.model_dir)
51 | logger.info('set show dir as %s' % settings.show_dir)
52 | logger.info('set model dir as %s' % settings.model_dir)
53 |
54 |         if len(settings.device_id) > 1:
55 |             self.net = nn.DataParallel(ODE_DerainNet()).cuda()
56 |         else:
57 |             torch.cuda.set_device(int(settings.device_id[0]))  # device_id is a string such as '2'
58 |             self.net = ODE_DerainNet().cuda()
59 |         self.ssim = SSIM().cuda()
60 |         self.dataloaders = {}
61 | 
68 | def get_dataloader(self, dataset_name):
69 | dataset = ShowDataset(dataset_name)
70 | self.dataloaders[dataset_name] = \
71 | DataLoader(dataset, batch_size=1,
72 | shuffle=False, num_workers=1)
73 | return self.dataloaders[dataset_name]
74 |
75 | def load_checkpoints(self, name):
76 | ckp_path = os.path.join(self.model_dir, name)
77 | try:
78 | obj = torch.load(ckp_path)
79 | logger.info('Load checkpoint %s' % ckp_path)
80 | except FileNotFoundError:
81 | logger.info('No checkpoint %s!!' % ckp_path)
82 | return
83 | self.net.load_state_dict(obj['net'])
84 |     def inf_batch(self, name, batch, i):
85 |         import time
86 |         O, B, file_name = batch['O'].cuda(), batch['B'].cuda(), batch['file_name']
87 |         file_name = str(file_name[0])
88 |         O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
89 |         with torch.no_grad():
90 |             t0 = time.time()
91 |             derain = self.net(O)
92 |             t1 = time.time()
93 |             print('inference time: %f s' % (t1 - t0))
94 |         ssim = self.ssim(derain, B).data.cpu().numpy()
95 |         psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
96 |         print('psnr: %f  ssim: %f' % (psnr, ssim))
97 |         return derain, psnr, ssim, file_name
99 |
100 | def save_image(self, No, imgs, name, psnr, ssim, file_name):
101 | for i, img in enumerate(imgs):
102 | img = (img.cpu().data * 255).numpy()
103 | img = np.clip(img, 0, 255)
104 | img = np.transpose(img, (1, 2, 0))
105 | h, w, c = img.shape
106 |
107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name))
108 | print(img_file)
109 | cv2.imwrite(img_file, img)
110 |
111 |
112 | def run_show(ckp_name):
113 | sess = Session()
114 | sess.load_checkpoints(ckp_name)
115 | sess.net.eval()
116 | dataset = 'test'
117 | if settings.pic_is_pair is False:
118 | dataset = 'train-small'
119 | dt = sess.get_dataloader(dataset)
120 |
121 |     for i, batch in enumerate(dt):
122 |         logger.info(i)
123 |         imgs, psnr, ssim, file_name = sess.inf_batch('test', batch, i)
124 |         sess.save_image(i, imgs, dataset, psnr, ssim, file_name)
126 |
127 |
128 |
129 | if __name__ == '__main__':
130 | parser = argparse.ArgumentParser()
131 | parser.add_argument('-m', '--model', default='latest_net')
132 |
133 | args = parser.parse_args(sys.argv[1:])
134 |
135 | run_show(args.model)
136 |
137 |
--------------------------------------------------------------------------------
/code/base/rain100H/config/tensorboard.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | function rand() {
3 | min=$1
4 | max=$(($2-$min+1))
5 | num=$(($RANDOM+1000000000000))
6 | echo $(($num%$max+$min))
7 | }
8 |
9 | rnd=$(rand 3000 12000)
10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3
11 |
--------------------------------------------------------------------------------
/code/base/rain100H/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/models/.DS_Store
--------------------------------------------------------------------------------
/code/base/rain100H/models/README.md:
--------------------------------------------------------------------------------
1 | # JDNet: Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network
2 | 
3 | Trained and pre-trained models will be placed here.
4 |
--------------------------------------------------------------------------------
/code/base/rain100L/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/.DS_Store
--------------------------------------------------------------------------------
/code/base/rain100L/config/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/.DS_Store
--------------------------------------------------------------------------------
/code/base/rain100L/config/cal_ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 | return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 | mu1_sq = mu1.pow(2)
22 | mu2_sq = mu2.pow(2)
23 | mu1_mu2 = mu1*mu2
24 |
25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 | C1 = 0.01**2
30 | C2 = 0.03**2
31 |
32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 | if size_average:
35 | return ssim_map.mean()
36 | else:
37 | return ssim_map.mean(1).mean(1).mean(1)
38 |
39 | class SSIM(torch.nn.Module):
40 | def __init__(self, window_size = 11, size_average = True):
41 | super(SSIM, self).__init__()
42 | self.window_size = window_size
43 | self.size_average = size_average
44 | self.channel = 3
45 | self.window = create_window(window_size, self.channel)
46 |
47 | def forward(self, img1, img2):
48 | #print(img1.size())
49 | (_, channel, _, _) = img1.size()
50 |
51 | if channel == self.channel and self.window.data.type() == img1.data.type():
52 | window = self.window
53 | else:
54 | window = create_window(self.window_size, channel)
55 |
56 | if img1.is_cuda:
57 | window = window.cuda(img1.get_device())
58 | window = window.type_as(img1)
59 |
60 | self.window = window
61 | self.channel = channel
62 |
63 |
64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
65 |
66 | def ssim(img1, img2, window_size = 11, size_average = True):
67 | (_, channel, _, _) = img1.size()
68 | window = create_window(window_size, channel)
69 |
70 | if img1.is_cuda:
71 | window = window.cuda(img1.get_device())
72 | window = window.type_as(img1)
73 |
74 | return _ssim(img1, img2, window, window_size, channel, size_average)
75 |
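A minimal usage sketch for the SSIM module above (shapes and values here are illustrative; inputs are NCHW tensors scaled to [0, 1]):

    import torch
    from cal_ssim import SSIM, ssim

    a = torch.rand(1, 3, 64, 64)
    b = a.clone()

    print(SSIM()(a, b).item())      # ~1.0 for identical images
    print(ssim(a, a * 0.5).item())  # lower for mismatched images

The class caches its Gaussian window and rebuilds it only when the channel count or dtype changes, which is why forward() updates self.window and self.channel.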
--------------------------------------------------------------------------------
/code/base/rain100L/config/clean.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | rm -rf __pycache__
3 | rm \.*\.swp
4 | rm -R ../logdir/*
5 | rm -R ../showdir/*
6 | rm ../models/*
7 |
--------------------------------------------------------------------------------
/code/base/rain100L/config/compile.py:
--------------------------------------------------------------------------------
1 | import re
2 | rec_data = "+MIPLOBSERVE:0,68220,1,3303,0,-1"
3 | msgidRegex = re.compile(r',(\d+),')  # capture the whole digit run, not just its last digit
4 | mo = msgidRegex.search(rec_data)
5 | print(mo.group())  # ',68220,'
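Note the difference between the full match and the captured group here (quick illustrative check):

    import re

    mo = re.search(r',(\d+),', "+MIPLOBSERVE:0,68220,1,3303,0,-1")
    print(mo.group())   # ',68220,'  whole match, including the delimiting commas
    print(mo.group(1))  # '68220'   the captured message id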
--------------------------------------------------------------------------------
/code/base/rain100L/config/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from numpy.random import RandomState
5 | from torch.utils.data import Dataset
6 |
7 | import settings
8 |
9 |
10 | class TrainValDataset(Dataset):
11 | def __init__(self, name):
12 | super().__init__()
13 | self.rand_state = RandomState(66)
14 | self.root_dir = os.path.join(settings.data_dir, name)
15 | self.mat_files = os.listdir(self.root_dir)
16 | self.patch_size = settings.patch_size
17 | self.file_num = len(self.mat_files)
18 |
19 | def __len__(self):
20 | return self.file_num * 100
21 |
22 | def __getitem__(self, idx):
23 | file_name = self.mat_files[idx % self.file_num]
24 | img_file = os.path.join(self.root_dir, file_name)
25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255
26 |
27 | if settings.aug_data:
28 | O, B = self.crop(img_pair, aug=True)
29 | O, B = self.flip(O, B)
30 | O, B = self.rotate(O, B)
31 | else:
32 | O, B = self.crop(img_pair, aug=False)
33 |
34 | O = np.transpose(O, (2, 0, 1))
35 | B = np.transpose(B, (2, 0, 1))
36 | sample = {'O': O, 'B': B}
37 |
38 | return sample
39 |
40 | def crop(self, img_pair, aug):
41 | patch_size = self.patch_size
42 | h, ww, c = img_pair.shape
43 | w = int(ww / 2)
44 |
45 | if aug:
46 |             mini = -self.patch_size // 4  # integer bounds: randint rejects floats in recent numpy
47 |             maxi = self.patch_size // 4 + 1
48 | p_h = patch_size + self.rand_state.randint(mini, maxi)
49 | p_w = patch_size + self.rand_state.randint(mini, maxi)
50 | else:
51 | p_h, p_w = patch_size, patch_size
52 |
53 | r = self.rand_state.randint(0, h - p_h)
54 | c = self.rand_state.randint(0, w - p_w)
55 | O = img_pair[r: r+p_h, c+w: c+p_w+w]
56 | B = img_pair[r: r+p_h, c: c+p_w]
57 |
58 | if aug:
59 | O = cv2.resize(O, (patch_size, patch_size))
60 | B = cv2.resize(B, (patch_size, patch_size))
61 |
62 | return O, B
63 |
64 | def flip(self, O, B):
65 | if self.rand_state.rand() > 0.5:
66 |             O = np.flip(O, axis=1).copy()  # copy: negative-stride views cannot be converted to tensors
67 |             B = np.flip(B, axis=1).copy()
68 | return O, B
69 |
70 | def rotate(self, O, B):
71 | angle = self.rand_state.randint(-30, 30)
72 | patch_size = self.patch_size
73 | center = (int(patch_size / 2), int(patch_size / 2))
74 | M = cv2.getRotationMatrix2D(center, angle, 1)
75 | O = cv2.warpAffine(O, M, (patch_size, patch_size))
76 | B = cv2.warpAffine(B, M, (patch_size, patch_size))
77 | return O, B
78 |
79 |
80 | class TestDataset(Dataset):
81 | def __init__(self, name):
82 | super().__init__()
83 | self.rand_state = RandomState(66)
84 | self.root_dir = os.path.join(settings.data_dir, name)
85 | self.mat_files = os.listdir(self.root_dir)
86 | self.patch_size = settings.patch_size
87 | self.file_num = len(self.mat_files)
88 |
89 | def __len__(self):
90 | return self.file_num
91 |
92 | def __getitem__(self, idx):
93 | file_name = self.mat_files[idx % self.file_num]
94 | img_file = os.path.join(self.root_dir, file_name)
95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255
96 | h, ww, c = img_pair.shape
97 | w = int(ww / 2)
98 | #h_8=h%8
99 | #w_8=w%8
100 | O = np.transpose(img_pair[:, w:], (2, 0, 1))
101 | B = np.transpose(img_pair[:, :w], (2, 0, 1))
102 | sample = {'O': O, 'B': B}
103 |
104 | return sample
105 |
106 |
107 | class ShowDataset(Dataset):
108 | def __init__(self, name):
109 | super().__init__()
110 | self.rand_state = RandomState(66)
111 | self.root_dir = os.path.join(settings.data_dir, name)
112 | self.img_files = sorted(os.listdir(self.root_dir))
113 | self.file_num = len(self.img_files)
114 |
115 | def __len__(self):
116 | return self.file_num
117 |
118 | def __getitem__(self, idx):
119 | file_name = self.img_files[idx % self.file_num]
120 | img_file = os.path.join(self.root_dir, file_name)
121 | print(img_file)
122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255
123 |
124 | h, ww, c = img_pair.shape
125 | w = int(ww / 2)
126 |
127 | #h_8 = h % 8
128 | #w_8 = w % 8
129 | if settings.pic_is_pair:
130 | O = np.transpose(img_pair[:, w:], (2, 0, 1))
131 | B = np.transpose(img_pair[:, :w], (2, 0, 1))
132 | else:
133 | O = np.transpose(img_pair[:, :], (2, 0, 1))
134 | B = np.transpose(img_pair[:, :], (2, 0, 1))
135 |         sample = {'O': O, 'B': B, 'file_name': file_name[:-4]}
136 |
137 | return sample
138 |
139 |
140 | if __name__ == '__main__':
141 | dt = TrainValDataset('val')
142 | print('TrainValDataset')
143 | for i in range(10):
144 | smp = dt[i]
145 | for k, v in smp.items():
146 | print(k, v.shape, v.dtype, v.mean())
147 |
148 | print()
149 | dt = TestDataset('test')
150 | print('TestDataset')
151 | for i in range(10):
152 | smp = dt[i]
153 | for k, v in smp.items():
154 | print(k, v.shape, v.dtype, v.mean())
155 |
156 | print()
157 | print('ShowDataset')
158 | dt = ShowDataset('test')
159 | for i in range(10):
160 | smp = dt[i]
161 | for k, v in smp.items():
162 | print(k, v.shape, v.dtype, v.mean())
163 |
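The three datasets above assume each image file stores the pair side by side as [clean B | rainy O], split at the horizontal midpoint; a tiny synthetic check of that layout convention:

    import numpy as np

    h, w = 4, 6
    B = np.zeros((h, w, 3), np.float32)    # clean half
    O = np.ones((h, w, 3), np.float32)     # rainy half
    pair = np.concatenate([B, O], axis=1)  # shape (h, 2*w, 3), as stored on disk

    half = pair.shape[1] // 2
    assert (pair[:, half:] == O).all() and (pair[:, :half] == B).all()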
--------------------------------------------------------------------------------
/code/base/rain100L/config/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import numpy as np
6 | import math
7 | import torch
8 | from torch import nn
9 | from torch.nn import MSELoss
10 | from torch.optim import Adam
11 | from torch.optim.lr_scheduler import MultiStepLR
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader
14 | from tensorboardX import SummaryWriter
15 |
16 | import settings
17 | from dataset import TestDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
22 | logger = settings.logger
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | #torch.cuda.set_device(settings.device_id)
26 |
27 |
28 | def ensure_dir(dir_path):
29 | if not os.path.isdir(dir_path):
30 | os.makedirs(dir_path)
31 | def PSNR(img1, img2):
32 | b,_,_,_=img1.shape
33 | #mse=0
34 | #for i in range(b):
35 | img1=np.clip(img1,0,255)
36 | img2=np.clip(img2,0,255)
37 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse
38 | if mse == 0:
39 | return 100
40 | #mse=mse/b
41 | PIXEL_MAX = 1
42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
43 | class Session:
44 | def __init__(self):
45 | self.log_dir = settings.log_dir
46 | self.model_dir = settings.model_dir
47 | ensure_dir(settings.log_dir)
48 | ensure_dir(settings.model_dir)
49 | logger.info('set log dir as %s' % settings.log_dir)
50 | logger.info('set model dir as %s' % settings.model_dir)
51 | if len(settings.device_id) >1:
52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda()
53 | else:
54 | torch.cuda.set_device(settings.device_id[0])
55 | self.net = ODE_DerainNet().cuda()
56 | self.l2 = MSELoss().cuda()
57 | self.l1 = nn.L1Loss().cuda()
58 | self.ssim = SSIM().cuda()
59 | self.dataloaders = {}
60 |
61 | def get_dataloader(self, dataset_name):
62 | dataset = TestDataset(dataset_name)
63 |         if dataset_name not in self.dataloaders:
64 | self.dataloaders[dataset_name] = \
65 | DataLoader(dataset, batch_size=1,
66 | shuffle=False, num_workers=1, drop_last=False)
67 | return self.dataloaders[dataset_name]
68 |
69 | def load_checkpoints(self, name):
70 | ckp_path = os.path.join(self.model_dir, name)
71 | try:
72 | obj = torch.load(ckp_path)
73 | logger.info('Load checkpoint %s' % ckp_path)
74 | except FileNotFoundError:
75 | logger.info('No checkpoint %s!!' % ckp_path)
76 | return
77 | self.net.load_state_dict(obj['net'])
78 |     def loss_vgg(self, input, groundtruth):  # note: relies on self.vgg, which this Session never initializes
79 |         vgg_gt = self.vgg.forward(groundtruth)
80 |         vgg_pred = self.vgg.forward(input)
81 |         loss_vgg = [self.l1(vgg_pred[m], vgg_gt[m]) for m in range(len(vgg_gt))]
82 |         loss = sum(loss_vgg)
83 |         return loss
84 |
85 | def inf_batch(self, name, batch):
86 | O, B = batch['O'].cuda(), batch['B'].cuda()
87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
88 |
89 | with torch.no_grad():
90 | derain = self.net(O)
91 |
92 | l1_loss = self.l1(derain, B)
93 | ssim = self.ssim(derain, B)
94 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
95 | losses = { 'L1 loss' : l1_loss }
96 | ssimes = { 'ssim' : ssim }
97 | losses.update(ssimes)
98 |
99 | return losses, psnr
100 |
101 |
102 | def run_test(ckp_name):
103 | sess = Session()
104 | sess.net.eval()
105 | sess.load_checkpoints(ckp_name)
106 | dt = sess.get_dataloader('test')
107 | psnr_all = 0
108 | all_num = 0
109 | all_losses = {}
110 |     for i, batch in enumerate(dt):
111 |         losses, psnr = sess.inf_batch('test', batch)
112 |         psnr_all = psnr_all + psnr
113 |         batch_size = batch['O'].size(0)
114 |         all_num += batch_size
115 |         for key, val in losses.items():
116 |             if i == 0:
117 |                 all_losses[key] = 0.
118 |             all_losses[key] += val * batch_size
119 |             logger.info('batch %d %s: %f' % (i, key, val))
120 |
121 |     for key, val in all_losses.items():
122 |         logger.info('total %s: %f' % (key, val / all_num))
123 |     #psnr=sum(psnr_all)
124 |     #print(psnr)
125 |     print('avg psnr: %.8f' % (psnr_all / all_num))
126 |
127 | if __name__ == '__main__':
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument('-m', '--model', default='latest_net')
130 |
131 | args = parser.parse_args(sys.argv[1:])
132 | run_test(args.model)
133 |
134 |
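For intuition, the PSNR helper above reduces to 20*log10(MAX/sqrt(MSE)) with MAX = 1 after rescaling to [0, 1]; a small worked check with made-up pixel values:

    import math
    import numpy as np

    img1 = np.full((1, 3, 4, 4), 120.0)  # pretend 0-255 predictions
    img2 = np.full((1, 3, 4, 4), 130.0)  # pretend ground truth
    mse = np.mean((img1 / 255. - img2 / 255.) ** 2)  # (10/255)^2
    print(20 * math.log10(1 / math.sqrt(mse)))       # ~28.13 dB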
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/__pycache__/functional.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/__pycache__/functional.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/functional.py:
--------------------------------------------------------------------------------
1 | from . import functions
2 |
3 |
4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1):
5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1]
6 | if input.is_cuda:
7 | if pad_mode == 0:
8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation)
9 | elif pad_mode == 1:
10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation)
11 | else:
12 | raise NotImplementedError
13 | return out
14 |
15 |
16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1):
17 | assert input.dim() == 4 and pad_mode in [0, 1]
18 | if input.is_cuda:
19 | if pad_mode == 0:
20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation)
21 | elif pad_mode == 1:
22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation)
23 | else:
24 | raise NotImplementedError
25 | return out
26 |
27 |
28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1):
29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1]
30 | if input1.is_cuda:
31 | if pad_mode == 0:
32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation)
33 | elif pad_mode == 1:
34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation)
35 | else:
36 | raise NotImplementedError
37 | return out
38 |
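The three wrappers above only dispatch to custom CUDA kernels; there is no CPU path. As a rough reference for what subtraction computes, here is a plain-PyTorch sketch of the zero-padded variant, assuming stride 1 and 'same' padding so the output grid matches the input (an illustration, not a drop-in replacement for the kernels):

    import torch
    import torch.nn.functional as F

    def subtraction_zeropad_ref(x, kernel_size=3, dilation=1):
        # x: (n, c, h, w) -> (n, c, k*k, h*w): center value minus each neighbor
        n, c, h, w = x.shape
        pad = dilation * (kernel_size - 1) // 2
        cols = F.unfold(x, kernel_size, dilation=dilation, padding=pad)  # (n, c*k*k, h*w)
        cols = cols.view(n, c, kernel_size * kernel_size, h * w)
        center = x.view(n, c, 1, h * w)
        return center - cols

    out = subtraction_zeropad_ref(torch.randn(2, 4, 8, 8))
    print(out.shape)  # torch.Size([2, 4, 9, 64])

subtraction2 is the two-input analogue (centers from input1, neighborhoods from input2), and aggregation is the matching weighted sum over the neighborhood.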
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .aggregation_zeropad import *
2 | from .aggregation_refpad import *
3 | from .subtraction_zeropad import *
4 | from .subtraction_refpad import *
5 | from .subtraction2_zeropad import *
6 | from .subtraction2_refpad import *
7 | from .utils import *
8 |
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/functions/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/functions/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/functions/utils.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from string import Template
3 | import cupy
4 | import torch
5 |
6 |
7 | Stream = namedtuple('Stream', ['ptr'])
8 |
9 |
10 | def Dtype(t):
11 | if isinstance(t, torch.cuda.FloatTensor):
12 | return 'float'
13 | elif isinstance(t, torch.cuda.DoubleTensor):
14 | return 'double'
15 |
16 |
17 | @cupy.util.memoize(for_each_device=True)
18 | def load_kernel(kernel_name, code, **kwargs):
19 | code = Template(code).substitute(**kwargs)
20 | kernel_code = cupy.cuda.compile_with_cache(code)
21 | return kernel_code.get_function(kernel_name)
22 |
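load_kernel compiles a templated CUDA source with cupy and memoizes the compiled function per device; Dtype maps a tensor to the matching C type name for template substitution. A hypothetical sketch of the pattern (the kernel name and source below are invented for illustration; running it requires a CUDA-capable cupy):

    # hypothetical kernel: add 1 to every element of a float buffer
    _add_one_src = '''
    extern "C" __global__
    void add_one(${Dtype}* x, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) x[i] += 1;
    }
    '''

    # f = load_kernel('add_one', _add_one_src, Dtype='float')
    # f(args=[x.data_ptr(), n], block=(256, 1, 1), grid=((n + 255) // 256, 1, 1),
    #   stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))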
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .aggregation import *
2 | from .subtraction import *
3 | from .subtraction2 import *
4 |
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/modules/__pycache__/aggregation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/modules/__pycache__/aggregation.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/modules/__pycache__/subtraction.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/modules/__pycache__/subtraction.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/modules/__pycache__/subtraction2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/modules/__pycache__/subtraction2.cpython-37.pyc
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/modules/aggregation.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn.modules.utils import _pair
3 |
4 | from .. import functional as F
5 |
6 |
7 | class Aggregation(nn.Module):
8 |
9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
10 | super(Aggregation, self).__init__()
11 | self.kernel_size = _pair(kernel_size)
12 | self.stride = _pair(stride)
13 | self.padding = _pair(padding)
14 | self.dilation = _pair(dilation)
15 | self.pad_mode = pad_mode
16 |
17 | def forward(self, input, weight):
18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)
19 |
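For intuition, Aggregation computes a weighted sum over each position's k x k neighborhood; per the assert in functional.py, input.shape[1] must be divisible by weight.shape[1], i.e. the CUDA kernels share each weight map across a group of channels. A plain-PyTorch sketch of the simplest, ungrouped zero-padded case, assuming stride 1 and 'same' padding (illustrative, not the CUDA path):

    import torch
    import torch.nn.functional as F

    def aggregation_zeropad_ref(x, weight, kernel_size=3, dilation=1):
        # ungrouped case: one weight map per input channel
        # x: (n, c, h, w); weight: (n, c, k*k, h*w)
        n, c, h, w = x.shape
        pad = dilation * (kernel_size - 1) // 2
        cols = F.unfold(x, kernel_size, dilation=dilation, padding=pad)  # (n, c*k*k, h*w)
        cols = cols.view(n, c, kernel_size * kernel_size, h * w)
        return (cols * weight).sum(dim=2).view(n, c, h, w)

    x = torch.randn(2, 4, 6, 6)
    wgt = torch.softmax(torch.randn(2, 4, 9, 36), dim=2)  # e.g. attention over each 3x3 window
    print(aggregation_zeropad_ref(x, wgt).shape)  # torch.Size([2, 4, 6, 6])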
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/modules/subtraction.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn.modules.utils import _pair
3 |
4 | from .. import functional as F
5 |
6 |
7 | class Subtraction(nn.Module):
8 |
9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
10 | super(Subtraction, self).__init__()
11 | self.kernel_size = _pair(kernel_size)
12 | self.stride = _pair(stride)
13 | self.padding = _pair(padding)
14 | self.dilation = _pair(dilation)
15 | self.pad_mode = pad_mode
16 |
17 | def forward(self, input):
18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)
19 |
--------------------------------------------------------------------------------
/code/base/rain100L/config/function/modules/subtraction2.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn.modules.utils import _pair
3 |
4 | from .. import functional as F
5 |
6 |
7 | class Subtraction2(nn.Module):
8 |
9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
10 | super(Subtraction2, self).__init__()
11 | self.kernel_size = _pair(kernel_size)
12 | self.stride = _pair(stride)
13 | self.padding = _pair(padding)
14 | self.dilation = _pair(dilation)
15 | self.pad_mode = pad_mode
16 |
17 | def forward(self, input1, input2):
18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)
19 |
--------------------------------------------------------------------------------
/code/base/rain100L/config/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 |
5 | ####################### Network architecture settings ############################################################
6 | channel = 32
7 | feature_map_num = 32
8 | res_conv_num = 4
9 | unit_num = 32
10 | #scale_num = 4
11 | num_scale_attention = 4
12 | scale_attention = False
13 | ssim_loss = True
14 | ########################################################################################
15 | aug_data = False # Set as False for fair comparison
16 |
17 | patch_size = 64
18 | pic_is_pair = True  # True: input images are side-by-side pairs; False: single real-world images
19 |
20 | lr = 0.0005
21 |
22 | data_dir = '/data1/dataset/haze_blur_rain/rain100L'
23 | if pic_is_pair is False:
24 | data_dir = '/data1/wangcong/dataset/real-world-images'
25 | log_dir = '../logdir'
26 | show_dir = '../showdir'
27 | model_dir = '../models'
28 | show_dir_feature = '../showdir_feature'
29 |
30 | log_level = 'info'
31 | model_path = os.path.join(model_dir, 'latest_net')
32 | save_steps = 400
33 |
34 | num_workers = 8
35 |
36 | num_GPU = 4
37 |
38 | device_id = '0,1,2,3'
39 |
40 | epoch = 1000
41 | batch_size = 32
42 |
43 | if pic_is_pair:
44 | root_dir = os.path.join(data_dir, 'train')
45 | mat_files = os.listdir(root_dir)
46 | num_datasets = len(mat_files)
47 | l1 = int(3/5 * epoch * num_datasets / batch_size)
48 | l2 = int(4/5 * epoch * num_datasets / batch_size)
49 | one_epoch = int(num_datasets/batch_size)
50 | total_step = int((epoch * num_datasets)/batch_size)
51 |
52 | logger = logging.getLogger('train')
53 | logger.setLevel(logging.INFO)
54 |
55 | ch = logging.StreamHandler()
56 | ch.setLevel(logging.INFO)
57 |
58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
59 | ch.setFormatter(formatter)
60 | logger.addHandler(ch)
61 |
62 |
63 |
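The milestone arithmetic above converts epoch counts into optimizer-step counts for the LR schedule: l1 and l2 mark 60% and 80% of training. A quick illustration with a hypothetical dataset size (num_datasets = 1800 is made up for the example):

    epoch, batch_size = 1000, 32
    num_datasets = 1800                                    # hypothetical train-set size
    one_epoch = num_datasets // batch_size                 # 56 steps per epoch
    total_step = epoch * num_datasets // batch_size        # 56250 steps in total
    l1 = int(3 / 5 * epoch * num_datasets / batch_size)    # first LR decay: 33750
    l2 = int(4 / 5 * epoch * num_datasets / batch_size)    # second LR decay: 45000
    print(one_epoch, total_step, l1, l2)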
--------------------------------------------------------------------------------
/code/base/rain100L/config/show.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import math
6 | import numpy as np
7 | import itertools
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import DataParallel
12 | from torch.optim import Adam
13 | from torch.autograd import Variable
14 | from torch.utils.data import DataLoader
15 |
16 | import settings
17 | from dataset import ShowDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
22 | logger = settings.logger
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | #torch.cuda.set_device(settings.device_id)
26 |
27 |
28 | def ensure_dir(dir_path):
29 | if not os.path.isdir(dir_path):
30 | os.makedirs(dir_path)
31 |
32 | def PSNR(img1, img2):
33 | b,_,_,_=img1.shape
34 | #mse=0
35 | #for i in range(b):
36 | img1=np.clip(img1,0,255)
37 | img2=np.clip(img2,0,255)
38 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse
39 | if mse == 0:
40 | return 100
41 | #mse=mse/b
42 | PIXEL_MAX = 1
43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
44 |
45 | class Session:
46 | def __init__(self):
47 | self.show_dir = settings.show_dir
48 | self.model_dir = settings.model_dir
49 | ensure_dir(settings.show_dir)
50 | ensure_dir(settings.model_dir)
51 | logger.info('set show dir as %s' % settings.show_dir)
52 | logger.info('set model dir as %s' % settings.model_dir)
53 |
54 | if len(settings.device_id) >1:
55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda()
56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id)
57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id)
58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id)
59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id)
60 | else:
61 | torch.cuda.set_device(settings.device_id[0])
62 | self.net = ODE_DerainNet().cuda()
63 | self.ssim = SSIM().cuda()
64 | self.dataloaders = {}
65 |         self.ssim = SSIM().cuda()
66 |         self.a = 0
67 |         self.t = 0
68 | def get_dataloader(self, dataset_name):
69 | dataset = ShowDataset(dataset_name)
70 | self.dataloaders[dataset_name] = \
71 | DataLoader(dataset, batch_size=1,
72 | shuffle=False, num_workers=1)
73 | return self.dataloaders[dataset_name]
74 |
75 | def load_checkpoints(self, name):
76 | ckp_path = os.path.join(self.model_dir, name)
77 | try:
78 | obj = torch.load(ckp_path)
79 | logger.info('Load checkpoint %s' % ckp_path)
80 | except FileNotFoundError:
81 | logger.info('No checkpoint %s!!' % ckp_path)
82 | return
83 | self.net.load_state_dict(obj['net'])
84 |     def inf_batch(self, name, batch, i):
85 |         O, B, file_name = batch['O'].cuda(), batch['B'].cuda(), batch['file_name']
86 |         file_name = str(file_name[0])
87 |         O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
88 |         with torch.no_grad():
89 |             import time
90 |             t0 = time.time()
91 |             derain = self.net(O)
92 |             t1 = time.time()
93 |             compute_time = t1 - t0
94 |             print(compute_time)
95 |         ssim = self.ssim(derain, B).data.cpu().numpy()
96 |         psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
97 |         print('psnr:%.4f-------------ssim:%.4f' % (psnr, ssim))
98 |         return derain, psnr, ssim, file_name
99 |
100 | def save_image(self, No, imgs, name, psnr, ssim, file_name):
101 | for i, img in enumerate(imgs):
102 | img = (img.cpu().data * 255).numpy()
103 | img = np.clip(img, 0, 255)
104 | img = np.transpose(img, (1, 2, 0))
105 | h, w, c = img.shape
106 |
107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name))
108 | print(img_file)
109 | cv2.imwrite(img_file, img)
110 |
111 |
112 | def run_show(ckp_name):
113 | sess = Session()
114 | sess.load_checkpoints(ckp_name)
115 | sess.net.eval()
116 | dataset = 'test'
117 |     if not settings.pic_is_pair:
118 |         dataset = 'train-w'
119 | dt = sess.get_dataloader(dataset)
120 |
121 | for i, batch in enumerate(dt):
122 | logger.info(i)
123 |         imgs, psnr, ssim, file_name = sess.inf_batch('test', batch, i)
124 |         sess.save_image(i, imgs, dataset, psnr, ssim, file_name)
125 |
126 |
127 |
128 |
129 | if __name__ == '__main__':
130 | parser = argparse.ArgumentParser()
131 | parser.add_argument('-m', '--model', default='latest_net')
132 |
133 | args = parser.parse_args(sys.argv[1:])
134 |
135 | run_show(args.model)
136 |
137 |
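One caveat about the timing printed in inf_batch above: CUDA kernels launch asynchronously, so wrapping self.net(O) in time.time() can under-report the real forward time. A sketch of a more faithful measurement (illustrative):

    import time
    import torch

    def timed_forward(net, x):
        torch.cuda.synchronize()  # make sure prior GPU work has finished
        t0 = time.time()
        with torch.no_grad():
            y = net(x)
        torch.cuda.synchronize()  # wait for the forward pass itself to finish
        return y, time.time() - t0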
--------------------------------------------------------------------------------
/code/base/rain100L/config/tensorboard.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | function rand() {
3 | min=$1
4 | max=$(($2-$min+1))
5 | num=$(($RANDOM+1000000000000))
6 | echo $(($num%$max+$min))
7 | }
8 |
9 | rnd=$(rand 3000 12000)
10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3
11 |
--------------------------------------------------------------------------------
/code/base/rain100L/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/models/.DS_Store
--------------------------------------------------------------------------------
/code/base/rain100L/models/README.md:
--------------------------------------------------------------------------------
1 | # JDNet: Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network
2 |
3 | Trained models and pre-trained models will be placed here.
4 |
--------------------------------------------------------------------------------
/code/diff_loss/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/.DS_Store
--------------------------------------------------------------------------------
/code/diff_loss/mae/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/.DS_Store
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/.DS_Store
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/cal_ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 | return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 | mu1_sq = mu1.pow(2)
22 | mu2_sq = mu2.pow(2)
23 | mu1_mu2 = mu1*mu2
24 |
25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 | C1 = 0.01**2
30 | C2 = 0.03**2
31 |
32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 | if size_average:
35 | return ssim_map.mean()
36 | else:
37 | return ssim_map.mean(1).mean(1).mean(1)
38 |
39 | class SSIM(torch.nn.Module):
40 | def __init__(self, window_size = 11, size_average = True):
41 | super(SSIM, self).__init__()
42 | self.window_size = window_size
43 | self.size_average = size_average
44 | self.channel = 3
45 | self.window = create_window(window_size, self.channel)
46 |
47 | def forward(self, img1, img2):
48 | #print(img1.size())
49 | (_, channel, _, _) = img1.size()
50 |
51 | if channel == self.channel and self.window.data.type() == img1.data.type():
52 | window = self.window
53 | else:
54 | window = create_window(self.window_size, channel)
55 |
56 | if img1.is_cuda:
57 | window = window.cuda(img1.get_device())
58 | window = window.type_as(img1)
59 |
60 | self.window = window
61 | self.channel = channel
62 |
63 |
64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
65 |
66 | def ssim(img1, img2, window_size = 11, size_average = True):
67 | (_, channel, _, _) = img1.size()
68 | window = create_window(window_size, channel)
69 |
70 | if img1.is_cuda:
71 | window = window.cuda(img1.get_device())
72 | window = window.type_as(img1)
73 |
74 | return _ssim(img1, img2, window, window_size, channel, size_average)
75 |
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/clean.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | rm -rf __pycache__
3 | rm \.*\.swp
4 | rm -R ../logdir/*
5 | rm -R ../showdir/*
6 | rm ../models/*
7 |
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/compile.py:
--------------------------------------------------------------------------------
1 | import re
2 | rec_data = "+MIPLOBSERVE:0,68220,1,3303,0,-1"
3 | msgidRegex = re.compile(r',(\d+),')  # capture the whole digit run, not just its last digit
4 | mo = msgidRegex.search(rec_data)
5 | print(mo.group())  # ',68220,'
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from numpy.random import RandomState
5 | from torch.utils.data import Dataset
6 |
7 | import settings
8 |
9 |
10 | class TrainValDataset(Dataset):
11 | def __init__(self, name):
12 | super().__init__()
13 | self.rand_state = RandomState(66)
14 | self.root_dir = os.path.join(settings.data_dir, name)
15 | self.mat_files = os.listdir(self.root_dir)
16 | self.patch_size = settings.patch_size
17 | self.file_num = len(self.mat_files)
18 |
19 | def __len__(self):
20 | return self.file_num * 100
21 |
22 | def __getitem__(self, idx):
23 | file_name = self.mat_files[idx % self.file_num]
24 | img_file = os.path.join(self.root_dir, file_name)
25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255
26 |
27 | if settings.aug_data:
28 | O, B = self.crop(img_pair, aug=True)
29 | O, B = self.flip(O, B)
30 | O, B = self.rotate(O, B)
31 | else:
32 | O, B = self.crop(img_pair, aug=False)
33 |
34 | O = np.transpose(O, (2, 0, 1))
35 | B = np.transpose(B, (2, 0, 1))
36 | sample = {'O': O, 'B': B}
37 |
38 | return sample
39 |
40 | def crop(self, img_pair, aug):
41 | patch_size = self.patch_size
42 | h, ww, c = img_pair.shape
43 | w = int(ww / 2)
44 |
45 | if aug:
46 |             mini = -self.patch_size // 4  # integer bounds: randint rejects floats in recent numpy
47 |             maxi = self.patch_size // 4 + 1
48 | p_h = patch_size + self.rand_state.randint(mini, maxi)
49 | p_w = patch_size + self.rand_state.randint(mini, maxi)
50 | else:
51 | p_h, p_w = patch_size, patch_size
52 |
53 | r = self.rand_state.randint(0, h - p_h)
54 | c = self.rand_state.randint(0, w - p_w)
55 | O = img_pair[r: r+p_h, c+w: c+p_w+w]
56 | B = img_pair[r: r+p_h, c: c+p_w]
57 |
58 | if aug:
59 | O = cv2.resize(O, (patch_size, patch_size))
60 | B = cv2.resize(B, (patch_size, patch_size))
61 |
62 | return O, B
63 |
64 | def flip(self, O, B):
65 | if self.rand_state.rand() > 0.5:
66 |             O = np.flip(O, axis=1).copy()  # copy: negative-stride views cannot be converted to tensors
67 |             B = np.flip(B, axis=1).copy()
68 | return O, B
69 |
70 | def rotate(self, O, B):
71 | angle = self.rand_state.randint(-30, 30)
72 | patch_size = self.patch_size
73 | center = (int(patch_size / 2), int(patch_size / 2))
74 | M = cv2.getRotationMatrix2D(center, angle, 1)
75 | O = cv2.warpAffine(O, M, (patch_size, patch_size))
76 | B = cv2.warpAffine(B, M, (patch_size, patch_size))
77 | return O, B
78 |
79 |
80 | class TestDataset(Dataset):
81 | def __init__(self, name):
82 | super().__init__()
83 | self.rand_state = RandomState(66)
84 | self.root_dir = os.path.join(settings.data_dir, name)
85 | self.mat_files = os.listdir(self.root_dir)
86 | self.patch_size = settings.patch_size
87 | self.file_num = len(self.mat_files)
88 |
89 | def __len__(self):
90 | return self.file_num
91 |
92 | def __getitem__(self, idx):
93 | file_name = self.mat_files[idx % self.file_num]
94 | img_file = os.path.join(self.root_dir, file_name)
95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255
96 | h, ww, c = img_pair.shape
97 | w = int(ww / 2)
98 | #h_8=h%8
99 | #w_8=w%8
100 | O = np.transpose(img_pair[:, w:], (2, 0, 1))
101 | B = np.transpose(img_pair[:, :w], (2, 0, 1))
102 | sample = {'O': O, 'B': B}
103 |
104 | return sample
105 |
106 |
107 | class ShowDataset(Dataset):
108 | def __init__(self, name):
109 | super().__init__()
110 | self.rand_state = RandomState(66)
111 | self.root_dir = os.path.join(settings.data_dir, name)
112 | self.img_files = sorted(os.listdir(self.root_dir))
113 | self.file_num = len(self.img_files)
114 |
115 | def __len__(self):
116 | return self.file_num
117 |
118 | def __getitem__(self, idx):
119 | file_name = self.img_files[idx % self.file_num]
120 | img_file = os.path.join(self.root_dir, file_name)
121 | print(img_file)
122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255
123 |
124 | h, ww, c = img_pair.shape
125 | w = int(ww / 2)
126 |
127 | #h_8 = h % 8
128 | #w_8 = w % 8
129 | if settings.pic_is_pair:
130 | O = np.transpose(img_pair[:, w:], (2, 0, 1))
131 | B = np.transpose(img_pair[:, :w], (2, 0, 1))
132 | else:
133 | O = np.transpose(img_pair[:, :], (2, 0, 1))
134 | B = np.transpose(img_pair[:, :], (2, 0, 1))
135 |         sample = {'O': O, 'B': B, 'file_name': file_name[:-4]}
136 |
137 | return sample
138 |
139 |
140 | if __name__ == '__main__':
141 | dt = TrainValDataset('val')
142 | print('TrainValDataset')
143 | for i in range(10):
144 | smp = dt[i]
145 | for k, v in smp.items():
146 | print(k, v.shape, v.dtype, v.mean())
147 |
148 | print()
149 | dt = TestDataset('test')
150 | print('TestDataset')
151 | for i in range(10):
152 | smp = dt[i]
153 | for k, v in smp.items():
154 | print(k, v.shape, v.dtype, v.mean())
155 |
156 | print()
157 | print('ShowDataset')
158 | dt = ShowDataset('test')
159 | for i in range(10):
160 | smp = dt[i]
161 | for k, v in smp.items():
162 | print(k, v.shape, v.dtype, v.mean())
163 |
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import numpy as np
6 | import math
7 | import torch
8 | from torch import nn
9 | from torch.nn import MSELoss
10 | from torch.optim import Adam
11 | from torch.optim.lr_scheduler import MultiStepLR
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader
14 | from tensorboardX import SummaryWriter
15 |
16 | import settings
17 | from dataset import TestDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
22 | logger = settings.logger
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | #torch.cuda.set_device(settings.device_id)
26 |
27 |
28 | def ensure_dir(dir_path):
29 | if not os.path.isdir(dir_path):
30 | os.makedirs(dir_path)
31 | def PSNR(img1, img2):
32 | b,_,_,_=img1.shape
33 | #mse=0
34 | #for i in range(b):
35 | img1=np.clip(img1,0,255)
36 | img2=np.clip(img2,0,255)
37 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse
38 | if mse == 0:
39 | return 100
40 | #mse=mse/b
41 | PIXEL_MAX = 1
42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
43 | class Session:
44 | def __init__(self):
45 | self.log_dir = settings.log_dir
46 | self.model_dir = settings.model_dir
47 | ensure_dir(settings.log_dir)
48 | ensure_dir(settings.model_dir)
49 | logger.info('set log dir as %s' % settings.log_dir)
50 | logger.info('set model dir as %s' % settings.model_dir)
51 | if len(settings.device_id) >1:
52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda()
53 | else:
54 | torch.cuda.set_device(settings.device_id[0])
55 | self.net = ODE_DerainNet().cuda()
56 | self.l2 = MSELoss().cuda()
57 | self.l1 = nn.L1Loss().cuda()
58 | self.ssim = SSIM().cuda()
59 | self.dataloaders = {}
60 |
61 | def get_dataloader(self, dataset_name):
62 | dataset = TestDataset(dataset_name)
63 |         if dataset_name not in self.dataloaders:
64 | self.dataloaders[dataset_name] = \
65 | DataLoader(dataset, batch_size=1,
66 | shuffle=False, num_workers=1, drop_last=False)
67 | return self.dataloaders[dataset_name]
68 |
69 | def load_checkpoints(self, name):
70 | ckp_path = os.path.join(self.model_dir, name)
71 | try:
72 | obj = torch.load(ckp_path)
73 | logger.info('Load checkpoint %s' % ckp_path)
74 | except FileNotFoundError:
75 | logger.info('No checkpoint %s!!' % ckp_path)
76 | return
77 | self.net.load_state_dict(obj['net'])
78 |
79 | def inf_batch(self, name, batch):
80 | O, B = batch['O'].cuda(), batch['B'].cuda()
81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
82 |
83 | with torch.no_grad():
84 | derain = self.net(O)
85 |
86 | l1_loss = self.l1(derain, B)
87 | ssim = self.ssim(derain, B)
88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
89 | losses = { 'L1 loss' : l1_loss }
90 | ssimes = { 'ssim' : ssim }
91 | losses.update(ssimes)
92 |
93 | return losses, psnr
94 |
95 |
96 | def run_test(ckp_name):
97 | sess = Session()
98 | sess.net.eval()
99 | sess.load_checkpoints(ckp_name)
100 | dt = sess.get_dataloader('test')
101 | psnr_all = 0
102 | all_num = 0
103 | all_losses = {}
104 |     for i, batch in enumerate(dt):
105 |         losses, psnr = sess.inf_batch('test', batch)
106 |         psnr_all = psnr_all + psnr
107 |         batch_size = batch['O'].size(0)
108 |         all_num += batch_size
109 |         for key, val in losses.items():
110 |             if i == 0:
111 |                 all_losses[key] = 0.
112 |             all_losses[key] += val * batch_size
113 |             logger.info('batch %d %s: %f' % (i, key, val))
114 |
115 |     for key, val in all_losses.items():
116 |         logger.info('total %s: %f' % (key, val / all_num))
117 |     #psnr=sum(psnr_all)
118 |     #print(psnr)
119 |     print('avg psnr: %.8f' % (psnr_all / all_num))
120 |
121 | if __name__ == '__main__':
122 | parser = argparse.ArgumentParser()
123 | parser.add_argument('-m', '--model', default='latest_net')
124 |
125 | args = parser.parse_args(sys.argv[1:])
126 | run_test(args.model)
127 |
128 |
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/__pycache__/functional.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/__pycache__/functional.cpython-36.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/__pycache__/functional.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/__pycache__/functional.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functional.py:
--------------------------------------------------------------------------------
1 | from . import functions
2 |
3 |
4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1):
5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1]
6 | if input.is_cuda:
7 | if pad_mode == 0:
8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation)
9 | elif pad_mode == 1:
10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation)
11 | else:
12 | raise NotImplementedError
13 | return out
14 |
15 |
16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1):
17 | assert input.dim() == 4 and pad_mode in [0, 1]
18 | if input.is_cuda:
19 | if pad_mode == 0:
20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation)
21 | elif pad_mode == 1:
22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation)
23 | else:
24 | raise NotImplementedError
25 | return out
26 |
27 |
28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1):
29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1]
30 | if input1.is_cuda:
31 | if pad_mode == 0:
32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation)
33 | elif pad_mode == 1:
34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation)
35 | else:
36 | raise NotImplementedError
37 | return out
38 |
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .aggregation_zeropad import *
2 | from .aggregation_refpad import *
3 | from .subtraction_zeropad import *
4 | from .subtraction_refpad import *
5 | from .subtraction2_zeropad import *
6 | from .subtraction2_refpad import *
7 | from .utils import *
8 |
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_refpad.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_refpad.cpython-36.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_zeropad.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_zeropad.cpython-36.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_refpad.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_refpad.cpython-36.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_zeropad.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_zeropad.cpython-36.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_refpad.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_refpad.cpython-36.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_zeropad.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_zeropad.cpython-36.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/functions/utils.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from string import Template
3 | import cupy
4 | import torch
5 |
6 |
7 | Stream = namedtuple('Stream', ['ptr'])
8 |
9 |
10 | def Dtype(t):
11 | if isinstance(t, torch.cuda.FloatTensor):
12 | return 'float'
13 | elif isinstance(t, torch.cuda.DoubleTensor):
14 | return 'double'
15 |
16 |
17 | @cupy.util.memoize(for_each_device=True)
18 | def load_kernel(kernel_name, code, **kwargs):
19 | code = Template(code).substitute(**kwargs)
20 | kernel_code = cupy.cuda.compile_with_cache(code)
21 | return kernel_code.get_function(kernel_name)
22 |
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .aggregation import *
2 | from .subtraction import *
3 | from .subtraction2 import *
4 |
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/modules/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/modules/__pycache__/aggregation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/aggregation.cpython-36.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/modules/__pycache__/aggregation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/aggregation.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/modules/__pycache__/subtraction.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/subtraction.cpython-36.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/modules/__pycache__/subtraction.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/subtraction.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/modules/__pycache__/subtraction2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/subtraction2.cpython-36.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/modules/__pycache__/subtraction2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/subtraction2.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/modules/aggregation.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn.modules.utils import _pair
3 |
4 | from .. import functional as F
5 |
6 |
7 | class Aggregation(nn.Module):
8 |
9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
10 | super(Aggregation, self).__init__()
11 | self.kernel_size = _pair(kernel_size)
12 | self.stride = _pair(stride)
13 | self.padding = _pair(padding)
14 | self.dilation = _pair(dilation)
15 | self.pad_mode = pad_mode
16 |
17 | def forward(self, input, weight):
18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)
19 |
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/modules/subtraction.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn.modules.utils import _pair
3 |
4 | from .. import functional as F
5 |
6 |
7 | class Subtraction(nn.Module):
8 |
9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
10 | super(Subtraction, self).__init__()
11 | self.kernel_size = _pair(kernel_size)
12 | self.stride = _pair(stride)
13 | self.padding = _pair(padding)
14 | self.dilation = _pair(dilation)
15 | self.pad_mode = pad_mode
16 |
17 | def forward(self, input):
18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)
19 |
--------------------------------------------------------------------------------
/code/diff_loss/mae/config/function/modules/subtraction2.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn.modules.utils import _pair
3 |
4 | from .. import functional as F
5 |
6 |
7 | class Subtraction2(nn.Module):
8 |
9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
10 | super(Subtraction2, self).__init__()
11 | self.kernel_size = _pair(kernel_size)
12 | self.stride = _pair(stride)
13 | self.padding = _pair(padding)
14 | self.dilation = _pair(dilation)
15 | self.pad_mode = pad_mode
16 |
17 | def forward(self, input1, input2):
18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)
19 |
--------------------------------------------------------------------------------
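Note: the three wrappers above (Aggregation, Subtraction, Subtraction2) only normalize scalar hyper-parameters to pairs via `_pair` and forward to the CUDA-only functional ops, so calling them needs a GPU plus CuPy. A hedged construction sketch (the import path and the commented-out shapes are assumptions, not repo-verified):

import torch
from function.modules import Aggregation, Subtraction, Subtraction2

agg = Aggregation(kernel_size=3, stride=1, padding=1, dilation=1, pad_mode=1)
print(agg.kernel_size)   # (3, 3): _pair expands scalars to 2-tuples

# Forward passes delegate to functional.py and have no CPU path:
# x = torch.randn(2, 8, 32, 32).cuda()
# d = Subtraction(3, 1, 1, 1, pad_mode=1)(x)   # windowed pairwise differences
# y = agg(x, torch.softmax(d, dim=2))          # illustrative weighted aggregation

--------------------------------------------------------------------------------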
/code/diff_loss/mae/config/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 |
5 | ####################### Network parameter settings ############################################################
6 | channel = 32
7 | feature_map_num = 32
8 | res_conv_num = 4
9 | unit_num = 32
10 | #scale_num = 4
11 | num_scale_attention = 4
12 | scale_attention = False
13 | ssim_loss = True
14 | ########################################################################################
15 | aug_data = False # Set as False for fair comparison
16 |
17 | patch_size = 64
18 | pic_is_pair = True  # whether inputs are side-by-side rainy/clean pairs or single rainy images
19 |
20 | lr = 0.0005
21 |
22 | data_dir = '/data1/dataset/haze_blur_rain/rain100H'
23 | if pic_is_pair is False:
24 | data_dir = '/data1/wangcong/dataset/real-world-images'
25 | log_dir = '../logdir'
26 | show_dir = '../showdir'
27 | model_dir = '../models'
28 | show_dir_feature = '../showdir_feature'
29 |
30 | log_level = 'info'
31 | model_path = os.path.join(model_dir, 'latest_net')
32 | save_steps = 400
33 |
34 | num_workers = 8
35 |
36 | num_GPU = 2
37 |
38 | device_id = '0,1'
39 |
40 | epoch = 1000
41 | batch_size = 12
42 |
43 | if pic_is_pair:
44 | root_dir = os.path.join(data_dir, 'train')
45 | mat_files = os.listdir(root_dir)
46 | num_datasets = len(mat_files)
47 | l1 = int(3/5 * epoch * num_datasets / batch_size)
48 | l2 = int(4/5 * epoch * num_datasets / batch_size)
49 | one_epoch = int(num_datasets/batch_size)
50 | total_step = int((epoch * num_datasets)/batch_size)
51 |
52 | logger = logging.getLogger('train')
53 | logger.setLevel(logging.INFO)
54 |
55 | ch = logging.StreamHandler()
56 | ch.setLevel(logging.INFO)
57 |
58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
59 | ch.setFormatter(formatter)
60 | logger.addHandler(ch)
61 |
62 |
63 |
--------------------------------------------------------------------------------
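Note: `one_epoch`, `total_step`, `l1` and `l2` above are iteration counts derived from the training-set size; `l1`/`l2` mark the 3/5 and 4/5 points of the run (used as learning-rate milestones by the training script). A worked example with a hypothetical training folder of 1800 image pairs (the real count depends on the local rain100H copy):

epoch, batch_size, num_datasets = 1000, 12, 1800   # num_datasets assumed

one_epoch  = num_datasets // batch_size            # 150 iterations per epoch
total_step = epoch * num_datasets // batch_size    # 150000 iterations overall
l1 = int(3 / 5 * total_step)                       # 90000: first milestone
l2 = int(4 / 5 * total_step)                       # 120000: second milestone

--------------------------------------------------------------------------------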
/code/diff_loss/mae/config/show.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import math
6 | import numpy as np
7 | import itertools
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import DataParallel
12 | from torch.optim import Adam
13 | from torch.autograd import Variable
14 | from torch.utils.data import DataLoader
15 |
16 | import settings
17 | from dataset import ShowDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
22 | logger = settings.logger
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | #torch.cuda.set_device(settings.device_id)
26 |
27 |
28 | def ensure_dir(dir_path):
29 | if not os.path.isdir(dir_path):
30 | os.makedirs(dir_path)
31 |
32 | def PSNR(img1, img2):
33 | b,_,_,_=img1.shape
34 | #mse=0
35 | #for i in range(b):
36 | img1=np.clip(img1,0,255)
37 | img2=np.clip(img2,0,255)
38 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse
39 | if mse == 0:
40 | return 100
41 | #mse=mse/b
42 | PIXEL_MAX = 1
43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
44 |
45 | class Session:
46 | def __init__(self):
47 | self.show_dir = settings.show_dir
48 | self.model_dir = settings.model_dir
49 | ensure_dir(settings.show_dir)
50 | ensure_dir(settings.model_dir)
51 | logger.info('set show dir as %s' % settings.show_dir)
52 | logger.info('set model dir as %s' % settings.model_dir)
53 |
54 | if len(settings.device_id) >1:
55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda()
56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id)
57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id)
58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id)
59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id)
60 | else:
61 | torch.cuda.set_device(int(settings.device_id[0]))
62 | self.net = ODE_DerainNet().cuda()
63 |
64 | self.dataloaders = {}
65 | self.ssim = SSIM().cuda()
66 | self.a = 0
67 | self.t = 0
68 | def get_dataloader(self, dataset_name):
69 | dataset = ShowDataset(dataset_name)
70 | self.dataloaders[dataset_name] = \
71 | DataLoader(dataset, batch_size=1,
72 | shuffle=False, num_workers=1)
73 | return self.dataloaders[dataset_name]
74 |
75 | def load_checkpoints(self, name):
76 | ckp_path = os.path.join(self.model_dir, name)
77 | try:
78 | obj = torch.load(ckp_path)
79 | logger.info('Load checkpoint %s' % ckp_path)
80 | except FileNotFoundError:
81 | logger.info('No checkpoint %s!!' % ckp_path)
82 | return
83 | self.net.load_state_dict(obj['net'])
84 | def inf_batch(self, name, batch, i):
85 | O, B, file_name = batch['O'].cuda(), batch['B'].cuda(), batch['file_name']
86 | file_name = str(file_name[0])
87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
88 | with torch.no_grad():
89 | import time  # local import; used only for timing
90 | t0 = time.time()
91 | derain = self.net(O)
92 | t1 = time.time()
93 | compute_time = t1 - t0
94 | print('inference time: %f s' % compute_time)
95 | ssim = self.ssim(derain, B).data.cpu().numpy()
96 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
97 | print('psnr:%4f-------------ssim:%4f'%(psnr, ssim))
98 | return derain, psnr, ssim, file_name
99 |
100 | def save_image(self, No, imgs, name, psnr, ssim, file_name):
101 | for i, img in enumerate(imgs):
102 | img = (img.cpu().data * 255).numpy()
103 | img = np.clip(img, 0, 255)
104 | img = np.transpose(img, (1, 2, 0))
105 | h, w, c = img.shape
106 |
107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name))
108 | print(img_file)
109 | cv2.imwrite(img_file, img)
110 |
111 |
112 | def run_show(ckp_name):
113 | sess = Session()
114 | sess.load_checkpoints(ckp_name)
115 | sess.net.eval()
116 | dataset = 'test'
117 | if settings.pic_is_pair is False:
118 | dataset = 'train-small'
119 | dt = sess.get_dataloader(dataset)
120 |
121 | for i, batch in enumerate(dt):
122 | logger.info(i)
123 | if i > -1:  # always true; kept as a hook for singling out samples
124 | imgs, psnr, ssim, file_name = sess.inf_batch('test', batch, i)
125 | sess.save_image(i, imgs, dataset, psnr, ssim, file_name)
126 |
127 |
128 |
129 | if __name__ == '__main__':
130 | parser = argparse.ArgumentParser()
131 | parser.add_argument('-m', '--model', default='latest_net')
132 |
133 | args = parser.parse_args(sys.argv[1:])
134 |
135 | run_show(args.model)
136 |
137 |
--------------------------------------------------------------------------------
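Note: show.py is invoked as `python show.py -m latest_net`; it loads the named checkpoint, derains every test image, and writes the results as PNGs into `settings.show_dir`. In the PSNR helper, inputs arrive on a [0, 255] scale and are divided by 255, which is why `PIXEL_MAX` is 1. A standalone CPU sanity check of the same arithmetic (restated here rather than imported, since importing show.py pulls in settings and CUDA):

import math
import numpy as np

def psnr(img1, img2):
    # Same math as the PSNR helper above, for arrays on a [0, 255] scale.
    img1 = np.clip(img1, 0, 255)
    img2 = np.clip(img2, 0, 255)
    mse = np.mean((img1 / 255. - img2 / 255.) ** 2)
    return 100 if mse == 0 else 20 * math.log10(1 / math.sqrt(mse))

a = np.full((1, 3, 8, 8), 255.0)
b = np.full((1, 3, 8, 8), 250.0)
print(round(psnr(a, b), 2))   # 34.15 dB for a uniform 5-level error

--------------------------------------------------------------------------------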
/code/diff_loss/mae/config/tensorboard.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | function rand() {
3 | min=$1
4 | max=$(($2-$min+1))
5 | num=$(($RANDOM+1000000000000))
6 | echo $(($num%$max+$min))
7 | }
8 |
9 | rnd=$(rand 3000 12000)
10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3
11 |
--------------------------------------------------------------------------------
/code/diff_loss/mae/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/models/.DS_Store
--------------------------------------------------------------------------------
/code/diff_loss/mae/models/README.md:
--------------------------------------------------------------------------------
1 | # JDNet: Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network
2 |
3 | Trained or pre-trained models will be placed here.
4 |
--------------------------------------------------------------------------------
/code/diff_loss/mse/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/.DS_Store
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/.DS_Store
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/cal_ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 | return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 | mu1_sq = mu1.pow(2)
22 | mu2_sq = mu2.pow(2)
23 | mu1_mu2 = mu1*mu2
24 |
25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 | C1 = 0.01**2
30 | C2 = 0.03**2
31 |
32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 | if size_average:
35 | return ssim_map.mean()
36 | else:
37 | return ssim_map.mean(1).mean(1).mean(1)
38 |
39 | class SSIM(torch.nn.Module):
40 | def __init__(self, window_size = 11, size_average = True):
41 | super(SSIM, self).__init__()
42 | self.window_size = window_size
43 | self.size_average = size_average
44 | self.channel = 3
45 | self.window = create_window(window_size, self.channel)
46 |
47 | def forward(self, img1, img2):
48 | #print(img1.size())
49 | (_, channel, _, _) = img1.size()
50 |
51 | if channel == self.channel and self.window.data.type() == img1.data.type():
52 | window = self.window
53 | else:
54 | window = create_window(self.window_size, channel)
55 |
56 | if img1.is_cuda:
57 | window = window.cuda(img1.get_device())
58 | window = window.type_as(img1)
59 |
60 | self.window = window
61 | self.channel = channel
62 |
63 |
64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
65 |
66 | def ssim(img1, img2, window_size = 11, size_average = True):
67 | (_, channel, _, _) = img1.size()
68 | window = create_window(window_size, channel)
69 |
70 | if img1.is_cuda:
71 | window = window.cuda(img1.get_device())
72 | window = window.type_as(img1)
73 |
74 | return _ssim(img1, img2, window, window_size, channel, size_average)
75 |
--------------------------------------------------------------------------------
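Note: the SSIM module above caches its Gaussian window and rebuilds it whenever the channel count or dtype of the inputs changes, so a single instance handles any input. A minimal CPU sketch (assumes cal_ssim.py is on the import path):

import torch
from cal_ssim import SSIM, ssim

x = torch.rand(1, 3, 64, 64)
print(SSIM()(x, x).item())       # ~1.0 for identical images
print(ssim(x, x * 0.5).item())   # < 1.0 once the images differ

--------------------------------------------------------------------------------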
/code/diff_loss/mse/config/clean.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | rm -rf __pycache__
3 | rm .*.swp  # remove vim swap files
4 | rm -R ../logdir/*
5 | rm -R ../showdir/*
6 | rm ../models/*
7 |
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/compile.py:
--------------------------------------------------------------------------------
1 | import re
2 | rec_data = "+MIPLOBSERVE:0,68220,1,3303,0,-1"
3 | msgidRegex = re.compile(r',(\d+),')  # capture the full message id, not just its last digit
4 | mo = msgidRegex.search(rec_data)
5 | print(mo.group(1))  # -> 68220
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from numpy.random import RandomState
5 | from torch.utils.data import Dataset
6 |
7 | import settings
8 |
9 |
10 | class TrainValDataset(Dataset):
11 | def __init__(self, name):
12 | super().__init__()
13 | self.rand_state = RandomState(66)
14 | self.root_dir = os.path.join(settings.data_dir, name)
15 | self.mat_files = os.listdir(self.root_dir)
16 | self.patch_size = settings.patch_size
17 | self.file_num = len(self.mat_files)
18 |
19 | def __len__(self):
20 | return self.file_num * 100
21 |
22 | def __getitem__(self, idx):
23 | file_name = self.mat_files[idx % self.file_num]
24 | img_file = os.path.join(self.root_dir, file_name)
25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255
26 |
27 | if settings.aug_data:
28 | O, B = self.crop(img_pair, aug=True)
29 | O, B = self.flip(O, B)
30 | O, B = self.rotate(O, B)
31 | else:
32 | O, B = self.crop(img_pair, aug=False)
33 |
34 | O = np.transpose(O, (2, 0, 1))
35 | B = np.transpose(B, (2, 0, 1))
36 | sample = {'O': O, 'B': B}
37 |
38 | return sample
39 |
40 | def crop(self, img_pair, aug):
41 | patch_size = self.patch_size
42 | h, ww, c = img_pair.shape
43 | w = int(ww / 2)
44 |
45 | if aug:
46 | mini = - 1 / 4 * self.patch_size
47 | maxi = 1 / 4 * self.patch_size + 1
48 | p_h = patch_size + self.rand_state.randint(mini, maxi)
49 | p_w = patch_size + self.rand_state.randint(mini, maxi)
50 | else:
51 | p_h, p_w = patch_size, patch_size
52 |
53 | r = self.rand_state.randint(0, h - p_h)
54 | c = self.rand_state.randint(0, w - p_w)
55 | O = img_pair[r: r+p_h, c+w: c+p_w+w]
56 | B = img_pair[r: r+p_h, c: c+p_w]
57 |
58 | if aug:
59 | O = cv2.resize(O, (patch_size, patch_size))
60 | B = cv2.resize(B, (patch_size, patch_size))
61 |
62 | return O, B
63 |
64 | def flip(self, O, B):
65 | if self.rand_state.rand() > 0.5:
66 | O = np.flip(O, axis=1)
67 | B = np.flip(B, axis=1)
68 | return O, B
69 |
70 | def rotate(self, O, B):
71 | angle = self.rand_state.randint(-30, 30)
72 | patch_size = self.patch_size
73 | center = (int(patch_size / 2), int(patch_size / 2))
74 | M = cv2.getRotationMatrix2D(center, angle, 1)
75 | O = cv2.warpAffine(O, M, (patch_size, patch_size))
76 | B = cv2.warpAffine(B, M, (patch_size, patch_size))
77 | return O, B
78 |
79 |
80 | class TestDataset(Dataset):
81 | def __init__(self, name):
82 | super().__init__()
83 | self.rand_state = RandomState(66)
84 | self.root_dir = os.path.join(settings.data_dir, name)
85 | self.mat_files = os.listdir(self.root_dir)
86 | self.patch_size = settings.patch_size
87 | self.file_num = len(self.mat_files)
88 |
89 | def __len__(self):
90 | return self.file_num
91 |
92 | def __getitem__(self, idx):
93 | file_name = self.mat_files[idx % self.file_num]
94 | img_file = os.path.join(self.root_dir, file_name)
95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255
96 | h, ww, c = img_pair.shape
97 | w = int(ww / 2)
98 | #h_8=h%8
99 | #w_8=w%8
100 | O = np.transpose(img_pair[:, w:], (2, 0, 1))
101 | B = np.transpose(img_pair[:, :w], (2, 0, 1))
102 | sample = {'O': O, 'B': B}
103 |
104 | return sample
105 |
106 |
107 | class ShowDataset(Dataset):
108 | def __init__(self, name):
109 | super().__init__()
110 | self.rand_state = RandomState(66)
111 | self.root_dir = os.path.join(settings.data_dir, name)
112 | self.img_files = sorted(os.listdir(self.root_dir))
113 | self.file_num = len(self.img_files)
114 |
115 | def __len__(self):
116 | return self.file_num
117 |
118 | def __getitem__(self, idx):
119 | file_name = self.img_files[idx % self.file_num]
120 | img_file = os.path.join(self.root_dir, file_name)
121 | print(img_file)
122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255
123 |
124 | h, ww, c = img_pair.shape
125 | w = int(ww / 2)
126 |
127 | #h_8 = h % 8
128 | #w_8 = w % 8
129 | if settings.pic_is_pair:
130 | O = np.transpose(img_pair[:, w:], (2, 0, 1))
131 | B = np.transpose(img_pair[:, :w], (2, 0, 1))
132 | else:
133 | O = np.transpose(img_pair[:, :], (2, 0, 1))
134 | B = np.transpose(img_pair[:, :], (2, 0, 1))
135 | sample = {'O': O, 'B': B,'file_name':file_name[:-4]}
136 |
137 | return sample
138 |
139 |
140 | if __name__ == '__main__':
141 | dt = TrainValDataset('val')
142 | print('TrainValDataset')
143 | for i in range(10):
144 | smp = dt[i]
145 | for k, v in smp.items():
146 | print(k, v.shape, v.dtype, v.mean())
147 |
148 | print()
149 | dt = TestDataset('test')
150 | print('TestDataset')
151 | for i in range(10):
152 | smp = dt[i]
153 | for k, v in smp.items():
154 | print(k, v.shape, v.dtype, v.mean())
155 |
156 | print()
157 | print('ShowDataset')
158 | dt = ShowDataset('test')
159 | for i in range(10):
160 | smp = dt[i]
161 | for k, v in smp.items():
162 | print(k, v.shape, v.dtype, v.mean())
163 |
--------------------------------------------------------------------------------
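Note: all three datasets above expect side-by-side pairs in one file: the clean image B on the left half and the rainy image O on the right (see the crop indices). TrainValDataset also reports 100x its file count, so one pass of the loader re-crops each image many times. Hypothetical wiring (assumes settings.data_dir points at a rain100H-style folder that exists locally):

from torch.utils.data import DataLoader
from dataset import TrainValDataset

loader = DataLoader(TrainValDataset('train'), batch_size=12,
                    shuffle=True, num_workers=8)
batch = next(iter(loader))
O, B = batch['O'], batch['B']   # rainy / clean crops, each (12, 3, 64, 64) float32

--------------------------------------------------------------------------------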
/code/diff_loss/mse/config/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import numpy as np
6 | import math
7 | import torch
8 | from torch import nn
9 | from torch.nn import MSELoss
10 | from torch.optim import Adam
11 | from torch.optim.lr_scheduler import MultiStepLR
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader
14 | from tensorboardX import SummaryWriter
15 |
16 | import settings
17 | from dataset import TestDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
22 | logger = settings.logger
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | #torch.cuda.set_device(settings.device_id)
26 |
27 |
28 | def ensure_dir(dir_path):
29 | if not os.path.isdir(dir_path):
30 | os.makedirs(dir_path)
31 | def PSNR(img1, img2):
32 | b,_,_,_=img1.shape
33 | #mse=0
34 | #for i in range(b):
35 | img1=np.clip(img1,0,255)
36 | img2=np.clip(img2,0,255)
37 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse
38 | if mse == 0:
39 | return 100
40 | #mse=mse/b
41 | PIXEL_MAX = 1
42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
43 | class Session:
44 | def __init__(self):
45 | self.log_dir = settings.log_dir
46 | self.model_dir = settings.model_dir
47 | ensure_dir(settings.log_dir)
48 | ensure_dir(settings.model_dir)
49 | logger.info('set log dir as %s' % settings.log_dir)
50 | logger.info('set model dir as %s' % settings.model_dir)
51 | if len(settings.device_id) >1:
52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda()
53 | else:
54 | torch.cuda.set_device(int(settings.device_id[0]))
55 | self.net = ODE_DerainNet().cuda()
56 | self.l2 = MSELoss().cuda()
57 | self.l1 = nn.L1Loss().cuda()
58 | self.ssim = SSIM().cuda()
59 | self.dataloaders = {}
60 |
61 | def get_dataloader(self, dataset_name):
62 | dataset = TestDataset(dataset_name)
63 | if dataset_name not in self.dataloaders:
64 | self.dataloaders[dataset_name] = \
65 | DataLoader(dataset, batch_size=1,
66 | shuffle=False, num_workers=1, drop_last=False)
67 | return self.dataloaders[dataset_name]
68 |
69 | def load_checkpoints(self, name):
70 | ckp_path = os.path.join(self.model_dir, name)
71 | try:
72 | obj = torch.load(ckp_path)
73 | logger.info('Load checkpoint %s' % ckp_path)
74 | except FileNotFoundError:
75 | logger.info('No checkpoint %s!!' % ckp_path)
76 | return
77 | self.net.load_state_dict(obj['net'])
78 |
79 | def inf_batch(self, name, batch):
80 | O, B = batch['O'].cuda(), batch['B'].cuda()
81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
82 |
83 | with torch.no_grad():
84 | derain = self.net(O)
85 |
86 | l1_loss = self.l1(derain, B)
87 | ssim = self.ssim(derain, B)
88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
89 | losses = { 'L1 loss' : l1_loss }
90 | ssimes = { 'ssim' : ssim }
91 | losses.update(ssimes)
92 |
93 | return losses, psnr
94 |
95 |
96 | def run_test(ckp_name):
97 | sess = Session()
98 | sess.net.eval()
99 | sess.load_checkpoints(ckp_name)
100 | dt = sess.get_dataloader('test')
101 | psnr_all = 0
102 | all_num = 0
103 | all_losses = {}
104 | for i, batch in enumerate(dt):
105 | losses, psnr = sess.inf_batch('test', batch)
106 | psnr_all = psnr_all + psnr
107 | batch_size = batch['O'].size(0)
108 | all_num += batch_size
109 | for key, val in losses.items():
110 | if i == 0:
111 | all_losses[key] = 0.
112 | all_losses[key] += val * batch_size
113 | logger.info('batch %d %s: %f' % (i, key, val))
114 |
115 | for key, val in all_losses.items():
116 | logger.info('total %s: %f' % (key, val / all_num))
117 | #psnr=sum(psnr_all)
118 | #print(psnr)
119 | print('avg psnr: %8f' % (psnr_all / all_num))
120 |
121 | if __name__ == '__main__':
122 | parser = argparse.ArgumentParser()
123 | parser.add_argument('-m', '--model', default='latest_net')
124 |
125 | args = parser.parse_args(sys.argv[1:])
126 | run_test(args.model)
127 |
128 |
--------------------------------------------------------------------------------
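Note: run_test above accumulates each metric as `val * batch_size` and divides by the total image count, i.e. an image-weighted mean; with batch_size=1 and drop_last=False this reduces to a plain average over the test set. It is invoked as `python eval.py -m latest_net`. The same averaging in isolation (hypothetical values):

totals, count = {}, 0
for batch_size, losses in [(1, {'ssim': 0.80}), (1, {'ssim': 0.90})]:
    count += batch_size
    for k, v in losses.items():
        totals[k] = totals.get(k, 0.0) + v * batch_size
print({k: v / count for k, v in totals.items()})   # ~0.85

--------------------------------------------------------------------------------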
/code/diff_loss/mse/config/function/__pycache__/functional.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/__pycache__/functional.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/functional.py:
--------------------------------------------------------------------------------
1 | from . import functions
2 |
3 |
4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1):
5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1]
6 | if input.is_cuda:
7 | if pad_mode == 0:
8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation)
9 | elif pad_mode == 1:
10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation)
11 | else:
12 | raise NotImplementedError
13 | return out
14 |
15 |
16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1):
17 | assert input.dim() == 4 and pad_mode in [0, 1]
18 | if input.is_cuda:
19 | if pad_mode == 0:
20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation)
21 | elif pad_mode == 1:
22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation)
23 | else:
24 | raise NotImplementedError
25 | return out
26 |
27 |
28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1):
29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1]
30 | if input1.is_cuda:
31 | if pad_mode == 0:
32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation)
33 | elif pad_mode == 1:
34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation)
35 | else:
36 | raise NotImplementedError
37 | return out
38 |
--------------------------------------------------------------------------------
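Note: in the functional layer above, pad_mode=0 dispatches to the zero-padded CUDA kernels and pad_mode=1 to the reflection-padded ones; there is no CPU fallback, hence the NotImplementedError branch. A hypothetical call (import path assumed; requires a CUDA device and CuPy):

import torch
from function import functional as F

x = torch.randn(2, 8, 32, 32).cuda()
d0 = F.subtraction(x, kernel_size=3, stride=1, padding=1, dilation=1, pad_mode=0)  # zero pad
d1 = F.subtraction(x, kernel_size=3, stride=1, padding=1, dilation=1, pad_mode=1)  # reflect pad

--------------------------------------------------------------------------------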
/code/diff_loss/mse/config/function/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .aggregation_zeropad import *
2 | from .aggregation_refpad import *
3 | from .subtraction_zeropad import *
4 | from .subtraction_refpad import *
5 | from .subtraction2_zeropad import *
6 | from .subtraction2_refpad import *
7 | from .utils import *
8 |
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/functions/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/functions/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/functions/utils.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from string import Template
3 | import cupy
4 | import torch
5 |
6 |
7 | Stream = namedtuple('Stream', ['ptr'])
8 |
9 |
10 | def Dtype(t):
11 | if isinstance(t, torch.cuda.FloatTensor):
12 | return 'float'
13 | elif isinstance(t, torch.cuda.DoubleTensor):
14 | return 'double'
15 |
16 |
17 | @cupy.util.memoize(for_each_device=True)
18 | def load_kernel(kernel_name, code, **kwargs):
19 | code = Template(code).substitute(**kwargs)
20 | kernel_code = cupy.cuda.compile_with_cache(code)
21 | return kernel_code.get_function(kernel_name)
22 |
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .aggregation import *
2 | from .subtraction import *
3 | from .subtraction2 import *
4 |
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/modules/__pycache__/aggregation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/modules/__pycache__/aggregation.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/modules/__pycache__/subtraction.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/modules/__pycache__/subtraction.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/modules/__pycache__/subtraction2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/modules/__pycache__/subtraction2.cpython-37.pyc
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/modules/aggregation.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn.modules.utils import _pair
3 |
4 | from .. import functional as F
5 |
6 |
7 | class Aggregation(nn.Module):
8 |
9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
10 | super(Aggregation, self).__init__()
11 | self.kernel_size = _pair(kernel_size)
12 | self.stride = _pair(stride)
13 | self.padding = _pair(padding)
14 | self.dilation = _pair(dilation)
15 | self.pad_mode = pad_mode
16 |
17 | def forward(self, input, weight):
18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)
19 |
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/modules/subtraction.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn.modules.utils import _pair
3 |
4 | from .. import functional as F
5 |
6 |
7 | class Subtraction(nn.Module):
8 |
9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
10 | super(Subtraction, self).__init__()
11 | self.kernel_size = _pair(kernel_size)
12 | self.stride = _pair(stride)
13 | self.padding = _pair(padding)
14 | self.dilation = _pair(dilation)
15 | self.pad_mode = pad_mode
16 |
17 | def forward(self, input):
18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)
19 |
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/function/modules/subtraction2.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn.modules.utils import _pair
3 |
4 | from .. import functional as F
5 |
6 |
7 | class Subtraction2(nn.Module):
8 |
9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
10 | super(Subtraction2, self).__init__()
11 | self.kernel_size = _pair(kernel_size)
12 | self.stride = _pair(stride)
13 | self.padding = _pair(padding)
14 | self.dilation = _pair(dilation)
15 | self.pad_mode = pad_mode
16 |
17 | def forward(self, input1, input2):
18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)
19 |
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 |
5 | ####################### Network parameter settings ############################################################
6 | channel = 32
7 | feature_map_num = 32
8 | res_conv_num = 4
9 | unit_num = 32
10 | #scale_num = 4
11 | num_scale_attention = 4
12 | scale_attention = False
13 | ssim_loss = True
14 | ########################################################################################
15 | aug_data = False # Set as False for fair comparison
16 |
17 | patch_size = 64
18 | pic_is_pair = True  # whether inputs are side-by-side rainy/clean pairs or single rainy images
19 |
20 | lr = 0.0005
21 |
22 | data_dir = '/data1/datasets/haze_blur_rain/rain100H'
23 | if pic_is_pair is False:
24 | data_dir = '/data1/wangcong/dataset/real-world-images'
25 | log_dir = '../logdir'
26 | show_dir = '../showdir'
27 | model_dir = '../models'
28 | show_dir_feature = '../showdir_feature'
29 |
30 | log_level = 'info'
31 | model_path = os.path.join(model_dir, 'latest_net')
32 | save_steps = 400
33 |
34 | num_workers = 8
35 |
36 | num_GPU = 2
37 |
38 | device_id = '1,2'
39 |
40 | epoch = 1000
41 | batch_size = 12
42 |
43 | if pic_is_pair:
44 | root_dir = os.path.join(data_dir, 'train')
45 | mat_files = os.listdir(root_dir)
46 | num_datasets = len(mat_files)
47 | l1 = int(3/5 * epoch * num_datasets / batch_size)
48 | l2 = int(4/5 * epoch * num_datasets / batch_size)
49 | one_epoch = int(num_datasets/batch_size)
50 | total_step = int((epoch * num_datasets)/batch_size)
51 |
52 | logger = logging.getLogger('train')
53 | logger.setLevel(logging.INFO)
54 |
55 | ch = logging.StreamHandler()
56 | ch.setLevel(logging.INFO)
57 |
58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
59 | ch.setFormatter(formatter)
60 | logger.addHandler(ch)
61 |
62 |
63 |
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/show.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import math
6 | import numpy as np
7 | import itertools
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import DataParallel
12 | from torch.optim import Adam
13 | from torch.autograd import Variable
14 | from torch.utils.data import DataLoader
15 |
16 | import settings
17 | from dataset import ShowDataset
18 | from model import ODE_DerainNet
19 | from cal_ssim import SSIM
20 |
21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
22 | logger = settings.logger
23 | torch.cuda.manual_seed_all(66)
24 | torch.manual_seed(66)
25 | #torch.cuda.set_device(settings.device_id)
26 |
27 |
28 | def ensure_dir(dir_path):
29 | if not os.path.isdir(dir_path):
30 | os.makedirs(dir_path)
31 |
32 | def PSNR(img1, img2):
33 | b,_,_,_=img1.shape
34 | #mse=0
35 | #for i in range(b):
36 | img1=np.clip(img1,0,255)
37 | img2=np.clip(img2,0,255)
38 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse
39 | if mse == 0:
40 | return 100
41 | #mse=mse/b
42 | PIXEL_MAX = 1
43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
44 |
45 | class Session:
46 | def __init__(self):
47 | self.show_dir = settings.show_dir
48 | self.model_dir = settings.model_dir
49 | ensure_dir(settings.show_dir)
50 | ensure_dir(settings.model_dir)
51 | logger.info('set show dir as %s' % settings.show_dir)
52 | logger.info('set model dir as %s' % settings.model_dir)
53 |
54 | if len(settings.device_id) >1:
55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda()
56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id)
57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id)
58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id)
59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id)
60 | else:
61 | torch.cuda.set_device(int(settings.device_id[0]))
62 | self.net = ODE_DerainNet().cuda()
63 |
64 | self.dataloaders = {}
65 | self.ssim = SSIM().cuda()
66 | self.a = 0
67 | self.t = 0
68 | def get_dataloader(self, dataset_name):
69 | dataset = ShowDataset(dataset_name)
70 | self.dataloaders[dataset_name] = \
71 | DataLoader(dataset, batch_size=1,
72 | shuffle=False, num_workers=1)
73 | return self.dataloaders[dataset_name]
74 |
75 | def load_checkpoints(self, name):
76 | ckp_path = os.path.join(self.model_dir, name)
77 | try:
78 | obj = torch.load(ckp_path)
79 | logger.info('Load checkpoint %s' % ckp_path)
80 | except FileNotFoundError:
81 | logger.info('No checkpoint %s!!' % ckp_path)
82 | return
83 | self.net.load_state_dict(obj['net'])
84 | def inf_batch(self, name, batch, i):
85 | O, B, file_name = batch['O'].cuda(), batch['B'].cuda(), batch['file_name']
86 | file_name = str(file_name[0])
87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
88 | with torch.no_grad():
89 | import time  # local import; used only for timing
90 | t0 = time.time()
91 | derain = self.net(O)
92 | t1 = time.time()
93 | compute_time = t1 - t0
94 | print('inference time: %f s' % compute_time)
95 | ssim = self.ssim(derain, B).data.cpu().numpy()
96 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
97 | print('psnr:%4f-------------ssim:%4f'%(psnr, ssim))
98 | return derain, psnr, ssim, file_name
99 |
100 | def save_image(self, No, imgs, name, psnr, ssim, file_name):
101 | for i, img in enumerate(imgs):
102 | img = (img.cpu().data * 255).numpy()
103 | img = np.clip(img, 0, 255)
104 | img = np.transpose(img, (1, 2, 0))
105 | h, w, c = img.shape
106 |
107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name))
108 | print(img_file)
109 | cv2.imwrite(img_file, img)
110 |
111 |
112 | def run_show(ckp_name):
113 | sess = Session()
114 | sess.load_checkpoints(ckp_name)
115 | sess.net.eval()
116 | dataset = 'test'
117 | if settings.pic_is_pair is False:
118 | dataset = 'train-small'
119 | dt = sess.get_dataloader(dataset)
120 |
121 | for i, batch in enumerate(dt):
122 | logger.info(i)
123 | if i > -1:  # always true; kept as a hook for singling out samples
124 | imgs, psnr, ssim, file_name = sess.inf_batch('test', batch, i)
125 | sess.save_image(i, imgs, dataset, psnr, ssim, file_name)
126 |
127 |
128 |
129 | if __name__ == '__main__':
130 | parser = argparse.ArgumentParser()
131 | parser.add_argument('-m', '--model', default='latest_net')
132 |
133 | args = parser.parse_args(sys.argv[1:])
134 |
135 | run_show(args.model)
136 |
137 |
--------------------------------------------------------------------------------
/code/diff_loss/mse/config/tensorboard.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | function rand() {
3 | min=$1
4 | max=$(($2-$min+1))
5 | num=$(($RANDOM+1000000000000))
6 | echo $(($num%$max+$min))
7 | }
8 |
9 | rnd=$(rand 3000 12000)
10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3
11 |
--------------------------------------------------------------------------------
/code/diff_loss/mse/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/models/.DS_Store
--------------------------------------------------------------------------------
/code/diff_loss/mse/models/README.md:
--------------------------------------------------------------------------------
1 | # JDNet: Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network
2 |
3 | Trained or pre-trained models will be placed here.
4 |
--------------------------------------------------------------------------------
/fig/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/fig/.DS_Store
--------------------------------------------------------------------------------
/fig/ex_pair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/fig/ex_pair.png
--------------------------------------------------------------------------------
/fig/ex_unpair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/fig/ex_unpair.png
--------------------------------------------------------------------------------
/fig/overall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/fig/overall.png
--------------------------------------------------------------------------------
/fig/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/fig/result.png
--------------------------------------------------------------------------------