├── README.md ├── code ├── ablation │ ├── r1 │ │ ├── config │ │ │ ├── cal_ssim.py │ │ │ ├── clean.sh │ │ │ ├── compile.py │ │ │ ├── dataset.py │ │ │ ├── eval.py │ │ │ ├── model.py │ │ │ ├── settings.py │ │ │ ├── show.py │ │ │ ├── tensorboard.sh │ │ │ └── train.py │ │ └── models │ │ │ └── README.md │ └── r2 │ │ ├── config │ │ ├── .DS_Store │ │ ├── cal_ssim.py │ │ ├── clean.sh │ │ ├── compile.py │ │ ├── dataset.py │ │ ├── eval.py │ │ ├── model.py │ │ ├── settings.py │ │ ├── show.py │ │ ├── tensorboard.sh │ │ └── train.py │ │ └── models │ │ ├── .DS_Store │ │ └── README.md ├── base │ ├── .DS_Store │ ├── rain100H │ │ ├── .DS_Store │ │ ├── config │ │ │ ├── .DS_Store │ │ │ ├── cal_ssim.py │ │ │ ├── clean.sh │ │ │ ├── compile.py │ │ │ ├── dataset.py │ │ │ ├── eval.py │ │ │ ├── function │ │ │ │ ├── __pycache__ │ │ │ │ │ └── functional.cpython-37.pyc │ │ │ │ ├── functional.py │ │ │ │ ├── functions │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── __pycache__ │ │ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ │ │ ├── aggregation_refpad.cpython-37.pyc │ │ │ │ │ │ ├── aggregation_zeropad.cpython-37.pyc │ │ │ │ │ │ ├── subtraction2_refpad.cpython-37.pyc │ │ │ │ │ │ ├── subtraction2_zeropad.cpython-37.pyc │ │ │ │ │ │ ├── subtraction_refpad.cpython-37.pyc │ │ │ │ │ │ ├── subtraction_zeropad.cpython-37.pyc │ │ │ │ │ │ └── utils.cpython-37.pyc │ │ │ │ │ ├── aggregation_refpad.py │ │ │ │ │ ├── aggregation_zeropad.py │ │ │ │ │ ├── subtraction2_refpad.py │ │ │ │ │ ├── subtraction2_zeropad.py │ │ │ │ │ ├── subtraction_refpad.py │ │ │ │ │ ├── subtraction_zeropad.py │ │ │ │ │ └── utils.py │ │ │ │ └── modules │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── __pycache__ │ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ │ ├── aggregation.cpython-37.pyc │ │ │ │ │ ├── subtraction.cpython-37.pyc │ │ │ │ │ └── subtraction2.cpython-37.pyc │ │ │ │ │ ├── aggregation.py │ │ │ │ │ ├── subtraction.py │ │ │ │ │ └── subtraction2.py │ │ │ ├── model.py │ │ │ ├── settings.py │ │ │ ├── show.py │ │ │ ├── tensorboard.sh │ │ │ └── 
train.py │ │ └── models │ │ │ ├── .DS_Store │ │ │ └── README.md │ └── rain100L │ │ ├── .DS_Store │ │ ├── config │ │ ├── .DS_Store │ │ ├── cal_ssim.py │ │ ├── clean.sh │ │ ├── compile.py │ │ ├── dataset.py │ │ ├── eval.py │ │ ├── function │ │ │ ├── __pycache__ │ │ │ │ └── functional.cpython-37.pyc │ │ │ ├── functional.py │ │ │ ├── functions │ │ │ │ ├── __init__.py │ │ │ │ ├── __pycache__ │ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ │ ├── aggregation_refpad.cpython-37.pyc │ │ │ │ │ ├── aggregation_zeropad.cpython-37.pyc │ │ │ │ │ ├── subtraction2_refpad.cpython-37.pyc │ │ │ │ │ ├── subtraction2_zeropad.cpython-37.pyc │ │ │ │ │ ├── subtraction_refpad.cpython-37.pyc │ │ │ │ │ ├── subtraction_zeropad.cpython-37.pyc │ │ │ │ │ └── utils.cpython-37.pyc │ │ │ │ ├── aggregation_refpad.py │ │ │ │ ├── aggregation_zeropad.py │ │ │ │ ├── subtraction2_refpad.py │ │ │ │ ├── subtraction2_zeropad.py │ │ │ │ ├── subtraction_refpad.py │ │ │ │ ├── subtraction_zeropad.py │ │ │ │ └── utils.py │ │ │ └── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ ├── aggregation.cpython-37.pyc │ │ │ │ ├── subtraction.cpython-37.pyc │ │ │ │ └── subtraction2.cpython-37.pyc │ │ │ │ ├── aggregation.py │ │ │ │ ├── subtraction.py │ │ │ │ └── subtraction2.py │ │ ├── model.py │ │ ├── settings.py │ │ ├── show.py │ │ ├── tensorboard.sh │ │ └── train.py │ │ └── models │ │ ├── .DS_Store │ │ └── README.md └── diff_loss │ ├── .DS_Store │ ├── mae │ ├── .DS_Store │ ├── config │ │ ├── .DS_Store │ │ ├── cal_ssim.py │ │ ├── clean.sh │ │ ├── compile.py │ │ ├── dataset.py │ │ ├── eval.py │ │ ├── function │ │ │ ├── __pycache__ │ │ │ │ ├── functional.cpython-36.pyc │ │ │ │ └── functional.cpython-37.pyc │ │ │ ├── functional.py │ │ │ ├── functions │ │ │ │ ├── __init__.py │ │ │ │ ├── __pycache__ │ │ │ │ │ ├── __init__.cpython-36.pyc │ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ │ ├── aggregation_refpad.cpython-36.pyc │ │ │ │ │ ├── aggregation_refpad.cpython-37.pyc │ │ │ │ │ 
├── aggregation_zeropad.cpython-36.pyc │ │ │ │ │ ├── aggregation_zeropad.cpython-37.pyc │ │ │ │ │ ├── subtraction2_refpad.cpython-36.pyc │ │ │ │ │ ├── subtraction2_refpad.cpython-37.pyc │ │ │ │ │ ├── subtraction2_zeropad.cpython-36.pyc │ │ │ │ │ ├── subtraction2_zeropad.cpython-37.pyc │ │ │ │ │ ├── subtraction_refpad.cpython-36.pyc │ │ │ │ │ ├── subtraction_refpad.cpython-37.pyc │ │ │ │ │ ├── subtraction_zeropad.cpython-36.pyc │ │ │ │ │ ├── subtraction_zeropad.cpython-37.pyc │ │ │ │ │ ├── utils.cpython-36.pyc │ │ │ │ │ └── utils.cpython-37.pyc │ │ │ │ ├── aggregation_refpad.py │ │ │ │ ├── aggregation_zeropad.py │ │ │ │ ├── subtraction2_refpad.py │ │ │ │ ├── subtraction2_zeropad.py │ │ │ │ ├── subtraction_refpad.py │ │ │ │ ├── subtraction_zeropad.py │ │ │ │ └── utils.py │ │ │ └── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-36.pyc │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ ├── aggregation.cpython-36.pyc │ │ │ │ ├── aggregation.cpython-37.pyc │ │ │ │ ├── subtraction.cpython-36.pyc │ │ │ │ ├── subtraction.cpython-37.pyc │ │ │ │ ├── subtraction2.cpython-36.pyc │ │ │ │ └── subtraction2.cpython-37.pyc │ │ │ │ ├── aggregation.py │ │ │ │ ├── subtraction.py │ │ │ │ └── subtraction2.py │ │ ├── model.py │ │ ├── settings.py │ │ ├── show.py │ │ ├── tensorboard.sh │ │ └── train.py │ └── models │ │ ├── .DS_Store │ │ └── README.md │ └── mse │ ├── .DS_Store │ ├── config │ ├── .DS_Store │ ├── cal_ssim.py │ ├── clean.sh │ ├── compile.py │ ├── dataset.py │ ├── eval.py │ ├── function │ │ ├── __pycache__ │ │ │ └── functional.cpython-37.pyc │ │ ├── functional.py │ │ ├── functions │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ ├── aggregation_refpad.cpython-37.pyc │ │ │ │ ├── aggregation_zeropad.cpython-37.pyc │ │ │ │ ├── subtraction2_refpad.cpython-37.pyc │ │ │ │ ├── subtraction2_zeropad.cpython-37.pyc │ │ │ │ ├── subtraction_refpad.cpython-37.pyc │ │ │ │ ├── subtraction_zeropad.cpython-37.pyc │ │ │ │ └── 
utils.cpython-37.pyc │ │ │ ├── aggregation_refpad.py │ │ │ ├── aggregation_zeropad.py │ │ │ ├── subtraction2_refpad.py │ │ │ ├── subtraction2_zeropad.py │ │ │ ├── subtraction_refpad.py │ │ │ ├── subtraction_zeropad.py │ │ │ └── utils.py │ │ └── modules │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── aggregation.cpython-37.pyc │ │ │ ├── subtraction.cpython-37.pyc │ │ │ └── subtraction2.cpython-37.pyc │ │ │ ├── aggregation.py │ │ │ ├── subtraction.py │ │ │ └── subtraction2.py │ ├── model.py │ ├── settings.py │ ├── show.py │ ├── tensorboard.sh │ └── train.py │ └── models │ ├── .DS_Store │ └── README.md └── fig ├── .DS_Store ├── ex_pair.png ├── ex_unpair.png ├── overall.png └── result.png /README.md: -------------------------------------------------------------------------------- 1 | # JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network 2 | 3 | [Cong Wang](https://supercong94.wixsite.com/supercong94)\*, [Yutong Wu](https://ohraincu.github.io/)\*, [Zhixun Su](http://faculty.dlut.edu.cn/ZhixunSu/zh_CN/index/759047/list/index.htm) †, Junyang Chen 4 | 5 | <\* Both authors contributed equally to this research. † Corresponding author.> 6 | 7 | This work has been accepted by ACM'MM 2020. [\[Arxiv\]](https://arxiv.org/abs/2008.02763) 8 | 9 |
10 | 11 | 12 | Fig1: An example from real-world datasets. 13 |
14 | 15 | ## Abstract 16 | In this paper, we propose an effective algorithm, called JDNet, to solve the single-image deraining problem, and we apply it to segmentation and detection tasks as downstream applications. Specifically, considering the importance of multi-scale features, we propose a Scale-Aggregation module to learn features at different scales. Simultaneously, a Self-Attention module, which can match or outperform its convolutional counterparts, is introduced to allow the feature aggregation to adapt to each channel. Furthermore, to improve the basic convolutional feature transformation process of Convolutional Neural Networks (CNNs), Self-Calibrated convolution is applied to build long-range spatial and inter-channel dependencies around each spatial location; this explicitly expands the field of view of each convolutional layer through internal communications and hence enriches the output features. By skillfully combining the Scale-Aggregation and Self-Attention modules with Self-Calibrated convolution, the proposed model achieves better deraining results on both real-world and synthetic datasets. Extensive experiments are conducted to demonstrate the superiority of our method compared with state-of-the-art methods. 17 | 18 |
19 | 20 | 21 | Fig2: The architecture of the Joint Network for Deraining (JDNet). 22 |
23 | 24 | ## Requirements 25 | - CUDA 9.0 (or later) 26 | - Python 3.6 (or later) 27 | - PyTorch 1.1.0 28 | - Torchvision 0.3.0 29 | - OpenCV 30 | 31 | ## Dataset 32 | Please download the following datasets: 33 | 34 | * Rain12 [[dataset](http://yu-li.github.io/paper/li_cvpr16_rain.zip)] 35 | * Rain100L [[dataset](http://www.icst.pku.edu.cn/struct/Projects/joint_rain_removal.html)] 36 | * Rain100H [[dataset](http://www.icst.pku.edu.cn/struct/Projects/joint_rain_removal.html)] 37 | * Real-world images [[BaiduYun](https://pan.baidu.com/s/1gpuB6NUHPnQEtgRn4evr5A)](password:vvuk) / [[GoogleDrive](https://drive.google.com/drive/folders/1sSfm-HplPO3FLKR3wAH3iWBmYsD_UAy_?usp=sharing)] 38 | 39 | ## Setup 40 | Please download this project with the `git` command. 41 | All training and testing experiments are in a folder called [code](https://github.com/Ohraincu/JDNet/tree/master/code): 42 | ``` 43 | $ git clone https://github.com/Ohraincu/JDNet.git 44 | $ cd JDNet/code 45 | ``` 46 | You will then see the following folder structure: 47 | 48 | |--ablation 49 | |--r1 50 | |--r2 51 | 52 | |--base 53 | |--rain100H 54 | |--rain100L 55 | 56 | |--diff_loss 57 | |--mae 58 | |--mse 59 | 60 | The implementation of our method is in the [base](https://github.com/Ohraincu/JDNet/tree/master/code/base) folder. 61 | 62 | The [ablation](https://github.com/Ohraincu/JDNet/tree/master/code/ablation) folder contains the code for the two ablation experiments mentioned in the paper. 63 | 64 | The only difference between the code in the [diff_loss](https://github.com/Ohraincu/JDNet/tree/master/code/diff_loss) folder and that in the [base](https://github.com/Ohraincu/JDNet/tree/master/code/base) folder is the loss function used during training. 65 | 66 | Our code is adapted from [the code by Li et al.](https://xialipku.github.io/RESCAN/); many thanks to the authors.
67 | 68 | ## Training 69 | After downloading the above datasets, go to the `config` folder of the experiment you want to run (e.g. `base/rain100H/config`) and start training: 70 | ``` 71 | $ cd config 72 | $ python train.py 73 | ``` 74 | You can pause or start training at any time, because checkpoints (named 'latest_net' or 'net_x_epoch') are saved periodically during training. 75 | 76 | ## Testing 77 | ### Pre-trained Models 78 | [[BaiduYun](https://pan.baidu.com/s/1MHVnTTN-kjeDQ-x9m_zr0Q)](password:9sbj) / 79 | [[GoogleDrive](https://drive.google.com/drive/folders/1rzidCqbhgnKEGNOgFLqSUynhux4SfURI)] 80 | 81 | ### Quantitative and Qualitative Results 82 | Run eval.py to obtain the quantitative results (PSNR and SSIM): 83 | ``` 84 | $ python eval.py 85 | ``` 86 | To save the visual (derained) results for a dataset, run show.py: 87 | ``` 88 | $ python show.py 89 | ``` 90 | The pre-trained model for each experiment is placed in the corresponding 'models' folder, and the 'latest_net' checkpoint is loaded by default. To generate results with an intermediate checkpoint from your own training, pass its name with `-m` (e.g. `python eval.py -m net_x_epoch`) or change the default in eval.py or show.py as follows: 91 | ``` 92 | parser.add_argument('-m', '--model', default='net_x_epoch') 93 | ``` 94 | 95 | ## Supplement to settings.py 96 | Note that there are two types of input for training and testing. 97 | 98 | One takes a paired image as input, and the other uses a single rainy image, as shown in Fig3 and Fig4 respectively. 99 | 100 |
101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 |
Fig3: Paired image. Fig4: Unpaired image.
111 |
112 | 113 | Depending on the form of the input images, modify the following settings in settings.py: 114 | 115 | ``` 116 | pic_is_pair = True #input picture is pair or single 117 | data_dir = '/Your_paired_datadir' 118 | if pic_is_pair is False: 119 | data_dir = '/Your_unpaired_datadir' 120 | ``` 121 | With these settings, if your train/eval/show.py contains the following call: 122 | 123 | ``` 124 | dt = sess.get_dataloader('test') 125 | ``` 126 | the dataset is loaded from "/Your_paired_datadir/test" (i.e. `data_dir` joined with the split name). 127 | 128 | ## Citation 129 | To be updated. 130 | ``` 131 | @inproceedings{acmmm20_jdnet, 132 | author = {Cong Wang and Yutong Wu and Zhixun Su and Junyang Chen}, 133 | title = {Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network}, 134 | booktitle = {ACM International Conference on Multimedia}, 135 | year = {2020}, 136 | } 137 | ``` 138 | 139 | ## Contact 140 | 141 | If you are interested in our work or have any questions, please contact us directly via GitHub or by email.
142 | Email: ytongwu@mail.dlut.edu.cn / supercong94@gmail.com 143 | -------------------------------------------------------------------------------- /code/ablation/r1/config/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 3 
45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | #print(img1.size()) 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- /code/ablation/r1/config/clean.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | rm -rf __pycache__ 3 | rm \.*\.swp 4 | rm -R ../logdir/* 5 | rm -R ../showdir/* 6 | rm ../models/* 7 | -------------------------------------------------------------------------------- /code/ablation/r1/config/compile.py: -------------------------------------------------------------------------------- 1 | import re 2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1" 3 | msgidRegex = re.compile(r',(\d)+,') 4 | mo = msgidRegex.search(rec_data) 5 | print(mo.group()) -------------------------------------------------------------------------------- /code/ablation/r1/config/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from torch.utils.data import Dataset 6 | 7 | import settings 8 | 9 | 10 | class TrainValDataset(Dataset): 11 | def __init__(self, name): 12 | super().__init__() 13 | self.rand_state = RandomState(66) 14 | self.root_dir = os.path.join(settings.data_dir, name) 15 | self.mat_files = os.listdir(self.root_dir) 16 | self.patch_size = settings.patch_size 17 | self.file_num = len(self.mat_files) 18 | 19 | def __len__(self): 20 | return self.file_num * 100 21 | 22 | def __getitem__(self, idx): 23 | file_name = self.mat_files[idx % self.file_num] 24 | img_file = os.path.join(self.root_dir, file_name) 25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 26 | 27 | if settings.aug_data: 28 | O, B = self.crop(img_pair, aug=True) 29 | O, B = self.flip(O, B) 30 | O, B = self.rotate(O, B) 31 | else: 32 | O, B = self.crop(img_pair, aug=False) 33 | 34 | O = np.transpose(O, (2, 0, 1)) 35 | B = np.transpose(B, (2, 0, 1)) 36 | sample = {'O': O, 'B': B} 37 | 38 | return sample 39 | 40 | def crop(self, img_pair, aug): 41 | patch_size = self.patch_size 42 | h, ww, c = img_pair.shape 43 | w = int(ww / 2) 44 | 45 | if aug: 46 | mini = - 1 / 4 * self.patch_size 47 | maxi = 1 / 4 * self.patch_size + 1 48 | p_h = patch_size + 
self.rand_state.randint(mini, maxi) 49 | p_w = patch_size + self.rand_state.randint(mini, maxi) 50 | else: 51 | p_h, p_w = patch_size, patch_size 52 | 53 | r = self.rand_state.randint(0, h - p_h) 54 | c = self.rand_state.randint(0, w - p_w) 55 | O = img_pair[r: r+p_h, c+w: c+p_w+w] 56 | B = img_pair[r: r+p_h, c: c+p_w] 57 | 58 | if aug: 59 | O = cv2.resize(O, (patch_size, patch_size)) 60 | B = cv2.resize(B, (patch_size, patch_size)) 61 | 62 | return O, B 63 | 64 | def flip(self, O, B): 65 | if self.rand_state.rand() > 0.5: 66 | O = np.flip(O, axis=1) 67 | B = np.flip(B, axis=1) 68 | return O, B 69 | 70 | def rotate(self, O, B): 71 | angle = self.rand_state.randint(-30, 30) 72 | patch_size = self.patch_size 73 | center = (int(patch_size / 2), int(patch_size / 2)) 74 | M = cv2.getRotationMatrix2D(center, angle, 1) 75 | O = cv2.warpAffine(O, M, (patch_size, patch_size)) 76 | B = cv2.warpAffine(B, M, (patch_size, patch_size)) 77 | return O, B 78 | 79 | 80 | class TestDataset(Dataset): 81 | def __init__(self, name): 82 | super().__init__() 83 | self.rand_state = RandomState(66) 84 | self.root_dir = os.path.join(settings.data_dir, name) 85 | self.mat_files = os.listdir(self.root_dir) 86 | self.patch_size = settings.patch_size 87 | self.file_num = len(self.mat_files) 88 | 89 | def __len__(self): 90 | return self.file_num 91 | 92 | def __getitem__(self, idx): 93 | file_name = self.mat_files[idx % self.file_num] 94 | img_file = os.path.join(self.root_dir, file_name) 95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 96 | h, ww, c = img_pair.shape 97 | w = int(ww / 2) 98 | #h_8=h%8 99 | #w_8=w%8 100 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 101 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 102 | sample = {'O': O, 'B': B} 103 | 104 | return sample 105 | 106 | 107 | class ShowDataset(Dataset): 108 | def __init__(self, name): 109 | super().__init__() 110 | self.rand_state = RandomState(66) 111 | self.root_dir = os.path.join(settings.data_dir, name) 112 | 
self.img_files = sorted(os.listdir(self.root_dir)) 113 | self.file_num = len(self.img_files) 114 | 115 | def __len__(self): 116 | return self.file_num 117 | 118 | def __getitem__(self, idx): 119 | file_name = self.img_files[idx % self.file_num] 120 | img_file = os.path.join(self.root_dir, file_name) 121 | print(img_file) 122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 123 | 124 | h, ww, c = img_pair.shape 125 | w = int(ww / 2) 126 | 127 | #h_8 = h % 8 128 | #w_8 = w % 8 129 | if settings.pic_is_pair: 130 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 131 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 132 | else: 133 | O = np.transpose(img_pair[:, :], (2, 0, 1)) 134 | B = np.transpose(img_pair[:, :], (2, 0, 1)) 135 | sample = {'O': O, 'B': B,'file_name':file_name[:-4]} 136 | 137 | return sample 138 | 139 | 140 | if __name__ == '__main__': 141 | dt = TrainValDataset('val') 142 | print('TrainValDataset') 143 | for i in range(10): 144 | smp = dt[i] 145 | for k, v in smp.items(): 146 | print(k, v.shape, v.dtype, v.mean()) 147 | 148 | print() 149 | dt = TestDataset('test') 150 | print('TestDataset') 151 | for i in range(10): 152 | smp = dt[i] 153 | for k, v in smp.items(): 154 | print(k, v.shape, v.dtype, v.mean()) 155 | 156 | print() 157 | print('ShowDataset') 158 | dt = ShowDataset('test') 159 | for i in range(10): 160 | smp = dt[i] 161 | for k, v in smp.items(): 162 | print(k, v.shape, v.dtype, v.mean()) 163 | -------------------------------------------------------------------------------- /code/ablation/r1/config/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | 
from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | def PSNR(img1, img2): 32 | b,_,_,_=img1.shape 33 | #mse=0 34 | #for i in range(b): 35 | img1=np.clip(img1,0,255) 36 | img2=np.clip(img2,0,255) 37 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse 38 | if mse == 0: 39 | return 100 40 | #mse=mse/b 41 | PIXEL_MAX = 1 42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | ensure_dir(settings.log_dir) 48 | ensure_dir(settings.model_dir) 49 | logger.info('set log dir as %s' % settings.log_dir) 50 | logger.info('set model dir as %s' % settings.model_dir) 51 | if len(settings.device_id) >1: 52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 53 | else: 54 | torch.cuda.set_device(settings.device_id[0]) 55 | self.net = ODE_DerainNet().cuda() 56 | self.l2 = MSELoss().cuda() 57 | self.l1 = nn.L1Loss().cuda() 58 | self.ssim = SSIM().cuda() 59 | self.dataloaders = {} 60 | 61 | def get_dataloader(self, dataset_name): 62 | dataset = TestDataset(dataset_name) 63 | if not dataset_name in self.dataloaders: 64 | self.dataloaders[dataset_name] = \ 65 | DataLoader(dataset, batch_size=1, 66 | shuffle=False, num_workers=1, drop_last=False) 67 | return self.dataloaders[dataset_name] 68 | 69 | def load_checkpoints(self, name): 70 | ckp_path = os.path.join(self.model_dir, name) 71 | try: 72 | obj = torch.load(ckp_path) 73 | logger.info('Load checkpoint %s' % ckp_path) 74 | except FileNotFoundError: 75 | logger.info('No checkpoint 
%s!!' % ckp_path) 76 | return 77 | self.net.load_state_dict(obj['net']) 78 | 79 | def inf_batch(self, name, batch): 80 | O, B = batch['O'].cuda(), batch['B'].cuda() 81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 82 | 83 | with torch.no_grad(): 84 | derain = self.net(O) 85 | 86 | l1_loss = self.l1(derain, B) 87 | ssim = self.ssim(derain, B) 88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 89 | losses = { 'L1 loss' : l1_loss } 90 | ssimes = { 'ssim' : ssim } 91 | losses.update(ssimes) 92 | 93 | return losses, psnr 94 | 95 | 96 | def run_test(ckp_name): 97 | sess = Session() 98 | sess.net.eval() 99 | sess.load_checkpoints(ckp_name) 100 | dt = sess.get_dataloader('test') 101 | psnr_all = 0 102 | all_num = 0 103 | all_losses = {} 104 | for i, batch in enumerate(dt): 105 | losses,psnr= sess.inf_batch('test', batch) 106 | psnr_all=psnr_all+psnr 107 | batch_size = batch['O'].size(0) 108 | all_num += batch_size 109 | for key, val in losses.items(): 110 | if i == 0: 111 | all_losses[key] = 0. 
112 | all_losses[key] += val * batch_size 113 | logger.info('batch %d mse %s: %f' % (i, key, val)) 114 | 115 | for key, val in all_losses.items(): 116 | logger.info('total mse %s: %f' % (key, val / all_num)) 117 | #psnr=sum(psnr_all) 118 | #print(psnr) 119 | print('psnr_ll:%8f'%(psnr_all/all_num)) 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('-m', '--model', default='latest_net') 124 | 125 | args = parser.parse_args(sys.argv[1:]) 126 | run_test(args.model) 127 | 128 | -------------------------------------------------------------------------------- /code/ablation/r1/config/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import settings 5 | from itertools import combinations,product 6 | import math 7 | 8 | class Residual_Block(nn.Module): 9 | def __init__(self, in_ch, out_ch, stride): 10 | super(Residual_Block, self).__init__() 11 | self.channel_num = settings.channel 12 | self.convs = nn.ModuleList() 13 | self.relus = nn.ModuleList() 14 | self.convert = nn.Sequential( 15 | nn.Conv2d(in_ch, out_ch, 3, stride, 1), 16 | nn.LeakyReLU(0.2) 17 | ) 18 | self.res = nn.Sequential( 19 | nn.Conv2d(out_ch, out_ch, 3, 1, 1), 20 | nn.LeakyReLU(0.2), 21 | nn.Conv2d(out_ch, out_ch, 3, 1, 1), 22 | ) 23 | 24 | def forward(self, x): 25 | convert = self.convert(x) 26 | out = convert + self.res(convert) 27 | return out 28 | 29 | class Scale_attention(nn.Module): 30 | def __init__(self): 31 | super(Scale_attention, self).__init__() 32 | self.scale_attention = nn.ModuleList() 33 | self.res_list = nn.ModuleList() 34 | self.channel = settings.channel 35 | if settings.scale_attention is True: 36 | for i in range(settings.num_scale_attention): 37 | self.scale_attention.append( 38 | nn.Sequential( 39 | nn.MaxPool2d(2 ** (i + 1), 2 ** (i + 1)), 40 | nn.Conv2d(self.channel, self.channel, 1, 1), 41 | nn.Sigmoid() 
42 | ) 43 | ) 44 | for i in range(settings.num_scale_attention): 45 | self.res_list.append( 46 | Residual_Block(self.channel, self.channel, 2) 47 | ) 48 | 49 | self.conv11 = nn.Sequential( 50 | nn.Conv2d((settings.num_scale_attention + 1) * self.channel, self.channel, 1, 1), 51 | nn.LeakyReLU(0.2) 52 | ) 53 | 54 | def forward(self, x): 55 | b, c, h, w = x.size() 56 | temp = x 57 | out = [] 58 | out.append(temp) 59 | if settings.scale_attention is True: 60 | for i in range(settings.num_scale_attention): 61 | temp = self.res_list[i](temp) 62 | b0,c0,h0,w0 = temp.size() 63 | temp = temp * F.upsample(self.scale_attention[i](x), [h0, w0]) 64 | up = temp 65 | out.append(F.upsample(up, [h, w])) 66 | fusion = self.conv11(torch.cat(out, dim=1)) 67 | 68 | else: 69 | for i in range(settings.num_scale_attention): 70 | temp = self.res_list[i](temp) 71 | up = temp 72 | out.append(F.upsample(up, [h, w])) 73 | fusion = self.conv11(torch.cat(out, dim=1)) 74 | return fusion + x 75 | 76 | class DenseConnection(nn.Module): 77 | def __init__(self, unit, unit_num): 78 | super(DenseConnection, self).__init__() 79 | self.unit_num = unit_num 80 | self.channel = settings.channel 81 | self.units = nn.ModuleList() 82 | self.conv1x1 = nn.ModuleList() 83 | for i in range(self.unit_num): 84 | self.units.append(unit()) 85 | self.conv1x1.append(nn.Sequential(nn.Conv2d((i+2)*self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2))) 86 | 87 | def forward(self, x): 88 | cat = [] 89 | cat.append(x) 90 | out = x 91 | for i in range(self.unit_num): 92 | tmp = self.units[i](out) 93 | cat.append(tmp) 94 | out = self.conv1x1[i](torch.cat(cat,dim=1)) 95 | return out 96 | 97 | 98 | class ODE_DerainNet(nn.Module): 99 | def __init__(self): 100 | super(ODE_DerainNet, self).__init__() 101 | self.channel = settings.channel 102 | self.unit_num = settings.unit_num 103 | self.enterBlock = nn.Sequential(nn.Conv2d(3, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)) 104 | self.derain_net = DenseConnection(Scale_attention, 
self.unit_num) 105 | self.exitBlock = nn.Sequential(nn.Conv2d(self.channel, 3, 3, 1, 1), nn.LeakyReLU(0.2)) 106 | 107 | 108 | def forward(self, x): 109 | image_feature = self.enterBlock(x) 110 | rain_feature = self.derain_net(image_feature) 111 | rain = self.exitBlock(rain_feature) 112 | derain = x - rain 113 | return derain 114 | -------------------------------------------------------------------------------- /code/ablation/r1/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | ####################### 网络参数设置  ############################################################ 6 | channel = 32 7 | feature_map_num = 32 8 | res_conv_num = 4 9 | unit_num = 32 10 | #scale_num = 4 11 | num_scale_attention = 4 12 | scale_attention = False 13 | ssim_loss = True 14 | ######################################################################################## 15 | aug_data = False # Set as False for fair comparison 16 | 17 | patch_size = 64 18 | pic_is_pair = True #input picture is pair or single 19 | 20 | lr = 0.0005 21 | 22 | data_dir = '/data1/wangcong/dataset/rain12' 23 | if pic_is_pair is False: 24 | data_dir = '/data1/wangcong/dataset/real-world-images' 25 | log_dir = '../logdir' 26 | show_dir = '../showdir' 27 | model_dir = '../models' 28 | show_dir_feature = '../showdir_feature' 29 | 30 | log_level = 'info' 31 | model_path = os.path.join(model_dir, 'latest_net') 32 | save_steps = 400 33 | 34 | num_workers = 8 35 | 36 | num_GPU = 2 37 | 38 | device_id = '2,1' 39 | 40 | epoch = 1000 41 | batch_size = 28 42 | 43 | if pic_is_pair: 44 | root_dir = os.path.join(data_dir, 'train') 45 | mat_files = os.listdir(root_dir) 46 | num_datasets = len(mat_files) 47 | l1 = int(3/5 * epoch * num_datasets / batch_size) 48 | l2 = int(4/5 * epoch * num_datasets / batch_size) 49 | one_epoch = int(num_datasets/batch_size) 50 | total_step = int((epoch * num_datasets)/batch_size) 51 | 52 | logger = 
logging.getLogger('train') 53 | logger.setLevel(logging.INFO) 54 | 55 | ch = logging.StreamHandler() 56 | ch.setLevel(logging.INFO) 57 | 58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 59 | ch.setFormatter(formatter) 60 | logger.addHandler(ch) 61 | 62 | 63 | -------------------------------------------------------------------------------- /code/ablation/r1/config/show.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import math 6 | import numpy as np 7 | import itertools 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import DataParallel 12 | from torch.optim import Adam 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | 16 | import settings 17 | from dataset import ShowDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | 32 | def PSNR(img1, img2): 33 | b,_,_,_=img1.shape 34 | #mse=0 35 | #for i in range(b): 36 | img1=np.clip(img1,0,255) 37 | img2=np.clip(img2,0,255) 38 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2)#+mse 39 | if mse == 0: 40 | return 100 41 | #mse=mse/b 42 | PIXEL_MAX = 1 43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 44 | 45 | class Session: 46 | def __init__(self): 47 | self.show_dir = settings.show_dir 48 | self.model_dir = settings.model_dir 49 | ensure_dir(settings.show_dir) 50 | ensure_dir(settings.model_dir) 51 | logger.info('set show dir as %s' % settings.show_dir) 52 | logger.info('set model dir as %s' % settings.model_dir) 53 | 54 | if len(settings.device_id) >1: 55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id) 57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id) 58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id) 59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id) 60 | else: 61 | torch.cuda.set_device(settings.device_id[0]) 62 | self.net = ODE_DerainNet().cuda() 63 | self.ssim = SSIM().cuda() 64 | self.dataloaders = {} 65 | self.ssim=SSIM().cuda() 66 | self.a=0 67 | self.t=0 68 | def get_dataloader(self, dataset_name): 69 | dataset = ShowDataset(dataset_name) 70 | self.dataloaders[dataset_name] = \ 71 | DataLoader(dataset, batch_size=1, 72 | shuffle=False, num_workers=1) 73 | return self.dataloaders[dataset_name] 74 | 75 | def load_checkpoints(self, name): 76 | ckp_path = os.path.join(self.model_dir, name) 77 | try: 78 | obj = torch.load(ckp_path) 79 | logger.info('Load checkpoint %s' % ckp_path) 80 | except FileNotFoundError: 81 | logger.info('No checkpoint %s!!' 
% ckp_path) 82 | return 83 | self.net.load_state_dict(obj['net']) 84 | def inf_batch(self, name, batch,i): 85 | O, B, file_name= batch['O'].cuda(), batch['B'].cuda(), batch['file_name'] 86 | file_name = str(file_name[0]) 87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 88 | with torch.no_grad(): 89 | import time 90 | t0=time.time() 91 | derain = self.net(O) 92 | t1 = time.time() 93 | comput_time=t1-t0 94 | print(comput_time) 95 | ssim = self.ssim(derain, B).data.cpu().numpy() 96 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 97 | print('psnr:%4f-------------ssim:%4f'%(psnr, ssim)) 98 | return derain, psnr, ssim, file_name 99 | 100 | def save_image(self, No, imgs, name, psnr, ssim, file_name): 101 | for i, img in enumerate(imgs): 102 | img = (img.cpu().data * 255).numpy() 103 | img = np.clip(img, 0, 255) 104 | img = np.transpose(img, (1, 2, 0)) 105 | h, w, c = img.shape 106 | 107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name)) 108 | print(img_file) 109 | cv2.imwrite(img_file, img) 110 | 111 | 112 | def run_show(ckp_name): 113 | sess = Session() 114 | sess.load_checkpoints(ckp_name) 115 | sess.net.eval() 116 | dataset = 'test' 117 | if settings.pic_is_pair is False: 118 | dataset = 'train-small' 119 | dt = sess.get_dataloader(dataset) 120 | 121 | for i, batch in enumerate(dt): 122 | logger.info(i) 123 | if i>423: 124 | imgs,psnr,ssim, file_name= sess.inf_batch('test', batch,i) 125 | sess.save_image(i, imgs, dataset, psnr, ssim, file_name) 126 | 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('-m', '--model', default='latest_net') 132 | 133 | args = parser.parse_args(sys.argv[1:]) 134 | 135 | run_show(args.model) 136 | 137 | -------------------------------------------------------------------------------- /code/ablation/r1/config/tensorboard.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | function rand() { 3 | min=$1 4 | max=$(($2-$min+1)) 5 | num=$(($RANDOM+1000000000000)) 6 | echo $(($num%$max+$min)) 7 | } 8 | 9 | rnd=$(rand 3000 12000) 10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3 11 | -------------------------------------------------------------------------------- /code/ablation/r1/config/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import math 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TrainValDataset, TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | logger = settings.logger 22 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | import numpy as np 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | 32 | def PSNR(img1, img2): 33 | b,_,_,_=img1.shape 34 | img1=np.clip(img1,0,255) 35 | img2=np.clip(img2,0,255) 36 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2) 37 | if mse == 0: 38 | return 100 39 | 40 | PIXEL_MAX = 1 41 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 42 | 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | self.ssim_loss = settings.ssim_loss 48 | ensure_dir(settings.log_dir) 49 | ensure_dir(settings.model_dir) 50 | ensure_dir('../log_test') 51 | logger.info('set log dir as %s' % settings.log_dir) 52 | logger.info('set model dir as %s' % settings.model_dir) 53 | if len(settings.device_id) >1: 54 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 55 | else: 56 | torch.cuda.set_device(settings.device_id[0]) 57 | self.net = ODE_DerainNet().cuda() 58 | self.l1 = nn.L1Loss().cuda() 59 | self.ssim = SSIM().cuda() 60 | self.step = 0 61 | self.save_steps = settings.save_steps 62 | self.num_workers = settings.num_workers 63 | self.batch_size = settings.batch_size 64 | self.writers = {} 65 | self.dataloaders = {} 66 | self.opt_net = Adam(self.net.parameters(), lr=settings.lr) 67 | self.sche_net = MultiStepLR(self.opt_net, milestones=[settings.l1, settings.l2], gamma=0.1) 68 | 69 | def tensorboard(self, name): 70 | self.writers[name] = SummaryWriter(os.path.join(self.log_dir, name + '.events')) 71 | return self.writers[name] 72 | 73 | def write(self, name, out): 74 | for k, v in out.items(): 75 | self.writers[name].add_scalar(k, v, self.step) 76 | out['lr'] = self.opt_net.param_groups[0]['lr'] 77 | out['step'] = self.step 78 | outputs = [ 79 | "{}:{:.4g}".format(k, v) 80 | for k, v in out.items() 81 | ] 82 | logger.info(name + '--' + ' '.join(outputs)) 83 | 84 | def get_dataloader(self, dataset_name): 85 | dataset = TrainValDataset(dataset_name) 86 | if not dataset_name in self.dataloaders: 87 | self.dataloaders[dataset_name] = \ 88 | DataLoader(dataset, batch_size=self.batch_size, 89 | shuffle=True, num_workers=self.num_workers, drop_last=True) 90 | return iter(self.dataloaders[dataset_name]) 91 | 92 | def 
get_test_dataloader(self, dataset_name): 93 | dataset = TestDataset(dataset_name) 94 | if not dataset_name in self.dataloaders: 95 | self.dataloaders[dataset_name] = \ 96 | DataLoader(dataset, batch_size=1, 97 | shuffle=False, num_workers=1, drop_last=False) 98 | return self.dataloaders[dataset_name] 99 | 100 | def save_checkpoints_net(self, name): 101 | ckp_path = os.path.join(self.model_dir, name) 102 | obj = { 103 | 'net': self.net.state_dict(), 104 | 'clock_net': self.step, 105 | 'opt_net': self.opt_net.state_dict(), 106 | } 107 | torch.save(obj, ckp_path) 108 | 109 | def load_checkpoints_net(self, name): 110 | ckp_path = os.path.join(self.model_dir, name) 111 | try: 112 | logger.info('Load checkpoint %s' % ckp_path) 113 | obj = torch.load(ckp_path) 114 | except FileNotFoundError: 115 | logger.info('No checkpoint %s!!' % ckp_path) 116 | return 117 | self.net.load_state_dict(obj['net']) 118 | self.opt_net.load_state_dict(obj['opt_net']) 119 | self.step = obj['clock_net'] 120 | self.sche_net.last_epoch = self.step 121 | 122 | def print_network(self, model): 123 | num_params = 0 124 | for p in model.parameters(): 125 | num_params += p.numel() 126 | print(model) 127 | print("The number of parameters: {}".format(num_params)) 128 | 129 | def inf_batch(self, name, batch): 130 | if name == 'train': 131 | self.net.zero_grad() 132 | if self.step==0: 133 | self.print_network(self.net) 134 | 135 | O, B = batch['O'].cuda(), batch['B'].cuda() 136 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 137 | derain = self.net(O) 138 | l1_loss = self.l1(derain, B) 139 | ssim = self.ssim(derain, B) 140 | if self.ssim_loss == True : 141 | loss = -ssim 142 | else: 143 | loss = l1_loss 144 | if name == 'train': 145 | loss.backward() 146 | self.opt_net.step() 147 | losses = {'L1loss' : l1_loss} 148 | ssimes = {'ssim' : ssim} 149 | losses.update(ssimes) 150 | self.write(name, losses) 151 | 152 | return derain 153 | 154 | def save_image(self, name, img_lists): 155 
| data, pred, label = img_lists 156 | pred = pred.cpu().data 157 | 158 | data, label, pred = data * 255, label * 255, pred * 255 159 | pred = np.clip(pred, 0, 255) 160 | h, w = pred.shape[-2:] 161 | gen_num = (1, 1) 162 | img = np.zeros((gen_num[0] * h, gen_num[1] * 3 * w, 3)) 163 | for img_list in img_lists: 164 | for i in range(gen_num[0]): 165 | row = i * h 166 | for j in range(gen_num[1]): 167 | idx = i * gen_num[1] + j 168 | tmp_list = [data[idx], pred[idx], label[idx]] 169 | for k in range(3): 170 | col = (j * 3 + k) * w 171 | tmp = np.transpose(tmp_list[k], (1, 2, 0)) 172 | img[row: row+h, col: col+w] = tmp 173 | img_file = os.path.join(self.log_dir, '%d_%s.jpg' % (self.step, name)) 174 | cv2.imwrite(img_file, img) 175 | 176 | def inf_batch_test(self, name, batch): 177 | O, B = batch['O'].cuda(), batch['B'].cuda() 178 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 179 | 180 | with torch.no_grad(): 181 | derain = self.net(O) 182 | 183 | l1_loss = self.l1(derain, B) 184 | ssim = self.ssim(derain, B) 185 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 186 | losses = { 'L1 loss' : l1_loss } 187 | ssimes = { 'ssim' : ssim } 188 | losses.update(ssimes) 189 | 190 | return l1_loss.data.cpu().numpy(), ssim.data.cpu().numpy(), psnr 191 | 192 | 193 | def run_train_val(ckp_name_net='latest_net'): 194 | sess = Session() 195 | sess.load_checkpoints_net(ckp_name_net) 196 | sess.tensorboard('train') 197 | dt_train = sess.get_dataloader('train') 198 | while sess.step < settings.total_step+1: 199 | sess.sche_net.step() 200 | sess.net.train() 201 | try: 202 | batch_t = next(dt_train) 203 | except StopIteration: 204 | dt_train = sess.get_dataloader('train') 205 | batch_t = next(dt_train) 206 | pred_t = sess.inf_batch('train', batch_t) 207 | if sess.step % int(sess.save_steps / 16) == 0: 208 | sess.save_checkpoints_net('latest_net') 209 | if sess.step % sess.save_steps == 0: 210 | sess.save_image('train', [batch_t['O'], 
pred_t, batch_t['B']]) 211 | 212 | #observe tendency of ssim, psnr and loss 213 | ssim_all = 0 214 | psnr_all = 0 215 | loss_all = 0 216 | num_all = 0 217 | if sess.step % (settings.one_epoch * 20) == 0: 218 | dt_val = sess.get_test_dataloader('test') 219 | sess.net.eval() 220 | for i, batch_v in enumerate(dt_val): 221 | loss, ssim, psnr = sess.inf_batch_test('test', batch_v) 222 | print(i) 223 | ssim_all = ssim_all + ssim 224 | psnr_all = psnr_all + psnr 225 | loss_all = loss_all + loss 226 | num_all = num_all + 1 227 | print('num_all:',num_all) 228 | loss_avg = loss_all / num_all 229 | ssim_avg = ssim_all / num_all 230 | psnr_avg = psnr_all / num_all 231 | logfile = open('../log_test/' + 'val' + '.txt','a+') 232 | epoch = int(sess.step / settings.one_epoch) 233 | logfile.write( 234 | 'step = ' + str(sess.step) + '\t' 235 | 'epoch = ' + str(epoch) + '\t' 236 | 'loss = ' + str(loss_avg) + '\t' 237 | 'ssim = ' + str(ssim_avg) + '\t' 238 | 'pnsr = ' + str(psnr_avg) + '\t' 239 | '\n\n' 240 | ) 241 | logfile.close() 242 | if sess.step % (settings.one_epoch*10) == 0: 243 | sess.save_checkpoints_net('net_%d_epoch' % int(sess.step / settings.one_epoch)) 244 | logger.info('save model as net_%d_epoch' % int(sess.step / settings.one_epoch)) 245 | sess.step += 1 246 | 247 | 248 | 249 | if __name__ == '__main__': 250 | parser = argparse.ArgumentParser() 251 | parser.add_argument('-m1', '--model_1', default='latest_net') 252 | 253 | args = parser.parse_args(sys.argv[1:]) 254 | run_train_val(args.model_1) 255 | 256 | -------------------------------------------------------------------------------- /code/ablation/r1/models/README.md: -------------------------------------------------------------------------------- 1 | # JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network 2 | 3 | The training model or pre-training model will be placed here. 
4 | -------------------------------------------------------------------------------- /code/ablation/r2/config/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/ablation/r2/config/.DS_Store -------------------------------------------------------------------------------- /code/ablation/r2/config/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): # Return a 1-D Gaussian of length window_size centred at index window_size//2, normalized to sum to 1. 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): # Return the 2-D Gaussian window (sigma=1.5, outer product of the 1-D kernel) expanded to shape (channel, 1, window_size, window_size) for the groups=channel convolutions in _ssim. 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): # SSIM of two 4-D image batches from Gaussian-weighted local statistics (depthwise conv, zero padding). Returns one scalar averaged over everything, or one value per sample when size_average is False. 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 # stabilizing constants (0.01*L)**2 and (0.03*L)**2 with data range L = 1, matching the /255 scaling in dataset.py 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return
ssim_map.mean(1).mean(1).mean(1) # per-sample mean over the channel, height and width axes 38 | 39 | class SSIM(torch.nn.Module): # nn.Module form of _ssim that caches its Gaussian window between calls. 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 3 45 | self.window = create_window(window_size, self.channel) # plain tensor attribute, not a registered buffer: Module.cuda() does not move it, so the first CUDA call rebuilds it in forward 46 | 47 | def forward(self, img1, img2): 48 | #print(img1.size()) 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): # reuse the cache only if channel count and tensor type match; NOTE(review): type() does not encode the GPU index, so a window cached on one GPU would be reused for inputs on another 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): # Functional variant of SSIM.forward that builds a fresh window on every call. 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- /code/ablation/r2/config/clean.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | rm -rf __pycache__ 3 | rm \.*\.swp 4 | rm -R ../logdir/* 5 | rm -R ../showdir/* 6 | rm ../models/* # WARNING: this also deletes every checkpoint saved under ../models 7 | -------------------------------------------------------------------------------- /code/ablation/r2/config/compile.py: -------------------------------------------------------------------------------- 1 | import re # stand-alone script: it imports nothing from this project 2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1" 3 | msgidRegex = re.compile(r',(\d)+,') 4 | mo = msgidRegex.search(rec_data) 5 | print(mo.group()) # prints the whole match ',68220,'; mo.group(1) would be only '0' because a repeated group keeps just its last repetition -------------------------------------------------------------------------------- /code/ablation/r2/config/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from torch.utils.data import Dataset 6 | 7 | import settings 8 | 9 | 10 | class TrainValDataset(Dataset): # Random patches cut from side-by-side image files: target B from the left half, network input O from the right half. 11 | def __init__(self, name): 12 | super().__init__() 13 | self.rand_state = RandomState(66) # NOTE(review): each DataLoader worker (num_workers = 8 in settings) presumably gets its own copy of this generator in the same state, so workers would draw identical crop sequences; confirm whether intended 14 | self.root_dir = os.path.join(settings.data_dir, name) 15 | self.mat_files = os.listdir(self.root_dir) 16 | self.patch_size = settings.patch_size 17 | self.file_num = len(self.mat_files) 18 | 19 | def __len__(self): 20 | return self.file_num * 100 # each file is served 100 times per pass, each time with a new random crop 21 | 22 | def __getitem__(self, idx): 23 | file_name = self.mat_files[idx % self.file_num] 24 | img_file = os.path.join(self.root_dir, file_name) 25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 # BGR, float32 in [0, 1]; NOTE(review): cv2.imread returns None for an unreadable file, which surfaces here as an AttributeError 26 | 27 | if settings.aug_data: 28 | O, B = self.crop(img_pair, aug=True) 29 | O, B = self.flip(O, B) 30 | O, B = self.rotate(O, B) 31 | else: 32 | O, B = self.crop(img_pair, aug=False) 33 | 34 | O = np.transpose(O, (2, 0, 1)) # HWC -> CHW 35 | B = np.transpose(B, (2, 0, 1)) 36 | sample = {'O': O, 'B': B} 37 | 38 | return sample 39 | 40 | def crop(self, img_pair, aug): # Cut the same random window from both halves; with aug the window side is jittered by up to patch_size/4 and then resized back to patch_size. 41 | patch_size = self.patch_size 42 | h, ww, c = img_pair.shape 43 | w = int(ww / 2) 44 | 45 | if aug: 46 | mini = - 1 / 4 * self.patch_size 47 | maxi = 1 / 4 * self.patch_size + 1 48 | p_h = patch_size +
self.rand_state.randint(mini, maxi) 49 | p_w = patch_size + self.rand_state.randint(mini, maxi) 50 | else: 51 | p_h, p_w = patch_size, patch_size 52 | 53 | r = self.rand_state.randint(0, h - p_h) # NOTE(review): randint excludes its upper bound, so offset h - p_h is never drawn and this raises ValueError when h <= p_h; h - p_h + 1 would cover every valid offset 54 | c = self.rand_state.randint(0, w - p_w) # same exclusive-bound caveat for the column offset 55 | O = img_pair[r: r+p_h, c+w: c+p_w+w] # input patch from the right half 56 | B = img_pair[r: r+p_h, c: c+p_w] # target patch at the same position in the left half 57 | 58 | if aug: 59 | O = cv2.resize(O, (patch_size, patch_size)) 60 | B = cv2.resize(B, (patch_size, patch_size)) 61 | 62 | return O, B 63 | 64 | def flip(self, O, B): # Flip both patches horizontally with probability 0.5. 65 | if self.rand_state.rand() > 0.5: 66 | O = np.flip(O, axis=1) 67 | B = np.flip(B, axis=1) 68 | return O, B 69 | 70 | def rotate(self, O, B): # Rotate both patches by the same random integer angle in [-30, 29] degrees about the patch centre. 71 | angle = self.rand_state.randint(-30, 30) 72 | patch_size = self.patch_size 73 | center = (int(patch_size / 2), int(patch_size / 2)) 74 | M = cv2.getRotationMatrix2D(center, angle, 1) 75 | O = cv2.warpAffine(O, M, (patch_size, patch_size)) 76 | B = cv2.warpAffine(B, M, (patch_size, patch_size)) 77 | return O, B 78 | 79 | 80 | class TestDataset(Dataset): # Whole, uncropped image pairs, one sample per file. 81 | def __init__(self, name): 82 | super().__init__() 83 | self.rand_state = RandomState(66) # not used by TestDataset 84 | self.root_dir = os.path.join(settings.data_dir, name) 85 | self.mat_files = os.listdir(self.root_dir) 86 | self.patch_size = settings.patch_size 87 | self.file_num = len(self.mat_files) 88 | 89 | def __len__(self): 90 | return self.file_num 91 | 92 | def __getitem__(self, idx): 93 | file_name = self.mat_files[idx % self.file_num] 94 | img_file = os.path.join(self.root_dir, file_name) 95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 96 | h, ww, c = img_pair.shape 97 | w = int(ww / 2) 98 | #h_8=h%8 99 | #w_8=w%8 100 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 101 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 102 | sample = {'O': O, 'B': B} 103 | 104 | return sample 105 | 106 | 107 | class ShowDataset(Dataset): # Whole images in sorted file-name order, returned together with their names for saving results. 108 | def __init__(self, name): 109 | super().__init__() 110 | self.rand_state = RandomState(66) # not used by ShowDataset 111 | self.root_dir = os.path.join(settings.data_dir, name) 112 |
self.img_files = sorted(os.listdir(self.root_dir)) 113 | self.file_num = len(self.img_files) 114 | 115 | def __len__(self): 116 | return self.file_num 117 | 118 | def __getitem__(self, idx): 119 | file_name = self.img_files[idx % self.file_num] 120 | img_file = os.path.join(self.root_dir, file_name) 121 | print(img_file) 122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 123 | 124 | h, ww, c = img_pair.shape 125 | w = int(ww / 2) 126 | 127 | #h_8 = h % 8 128 | #w_8 = w % 8 129 | if settings.pic_is_pair: # paired file: input O from the right half, target B from the left; a single image is used whole for both 130 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 131 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 132 | else: 133 | O = np.transpose(img_pair[:, :], (2, 0, 1)) 134 | B = np.transpose(img_pair[:, :], (2, 0, 1)) 135 | sample = {'O': O, 'B': B,'file_name':file_name[:-4]} # [:-4] drops a 3-character extension plus its dot, e.g. '.png' or '.jpg' 136 | 137 | return sample 138 | 139 | 140 | if __name__ == '__main__': # smoke test: print shape, dtype and mean of the first 10 samples of each dataset 141 | dt = TrainValDataset('val') 142 | print('TrainValDataset') 143 | for i in range(10): 144 | smp = dt[i] 145 | for k, v in smp.items(): 146 | print(k, v.shape, v.dtype, v.mean()) 147 | 148 | print() 149 | dt = TestDataset('test') 150 | print('TestDataset') 151 | for i in range(10): 152 | smp = dt[i] 153 | for k, v in smp.items(): 154 | print(k, v.shape, v.dtype, v.mean()) 155 | 156 | print() 157 | print('ShowDataset') 158 | dt = ShowDataset('test') # NOTE(review): ShowDataset samples also carry the 'file_name' string, which has no .shape, so the loop below raises AttributeError after printing O and B 159 | for i in range(10): 160 | smp = dt[i] 161 | for k, v in smp.items(): 162 | print(k, v.shape, v.dtype, v.mean()) 163 | -------------------------------------------------------------------------------- /code/ablation/r2/config/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 |
from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | def PSNR(img1, img2): 32 | b,_,_,_=img1.shape 33 | #mse=0 34 | #for i in range(b): 35 | img1=np.clip(img1,0,255) 36 | img2=np.clip(img2,0,255) 37 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse 38 | if mse == 0: 39 | return 100 40 | #mse=mse/b 41 | PIXEL_MAX = 1 42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | ensure_dir(settings.log_dir) 48 | ensure_dir(settings.model_dir) 49 | logger.info('set log dir as %s' % settings.log_dir) 50 | logger.info('set model dir as %s' % settings.model_dir) 51 | if len(settings.device_id) >1: 52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 53 | else: 54 | torch.cuda.set_device(settings.device_id[0]) 55 | self.net = ODE_DerainNet().cuda() 56 | self.l2 = MSELoss().cuda() 57 | self.l1 = nn.L1Loss().cuda() 58 | self.ssim = SSIM().cuda() 59 | self.dataloaders = {} 60 | 61 | def get_dataloader(self, dataset_name): 62 | dataset = TestDataset(dataset_name) 63 | if not dataset_name in self.dataloaders: 64 | self.dataloaders[dataset_name] = \ 65 | DataLoader(dataset, batch_size=1, 66 | shuffle=False, num_workers=1, drop_last=False) 67 | return self.dataloaders[dataset_name] 68 | 69 | def load_checkpoints(self, name): 70 | ckp_path = os.path.join(self.model_dir, name) 71 | try: 72 | obj = torch.load(ckp_path) 73 | logger.info('Load checkpoint %s' % ckp_path) 74 | except FileNotFoundError: 75 | logger.info('No checkpoint 
%s!!' % ckp_path) 76 | return 77 | self.net.load_state_dict(obj['net']) 78 | 79 | def inf_batch(self, name, batch): 80 | O, B = batch['O'].cuda(), batch['B'].cuda() 81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 82 | 83 | with torch.no_grad(): 84 | derain = self.net(O) 85 | 86 | l1_loss = self.l1(derain, B) 87 | ssim = self.ssim(derain, B) 88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 89 | losses = { 'L1 loss' : l1_loss } 90 | ssimes = { 'ssim' : ssim } 91 | losses.update(ssimes) 92 | 93 | return losses, psnr 94 | 95 | 96 | def run_test(ckp_name): 97 | sess = Session() 98 | sess.net.eval() 99 | sess.load_checkpoints(ckp_name) 100 | dt = sess.get_dataloader('test') 101 | psnr_all = 0 102 | all_num = 0 103 | all_losses = {} 104 | for i, batch in enumerate(dt): 105 | losses,psnr= sess.inf_batch('test', batch) 106 | psnr_all=psnr_all+psnr 107 | batch_size = batch['O'].size(0) 108 | all_num += batch_size 109 | for key, val in losses.items(): 110 | if i == 0: 111 | all_losses[key] = 0. 
112 | all_losses[key] += val * batch_size 113 | logger.info('batch %d mse %s: %f' % (i, key, val)) 114 | 115 | for key, val in all_losses.items(): 116 | logger.info('total mse %s: %f' % (key, val / all_num)) 117 | #psnr=sum(psnr_all) 118 | #print(psnr) 119 | print('psnr_ll:%8f'%(psnr_all/all_num)) 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('-m', '--model', default='latest_net') 124 | 125 | args = parser.parse_args(sys.argv[1:]) 126 | run_test(args.model) 127 | 128 | -------------------------------------------------------------------------------- /code/ablation/r2/config/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import settings 5 | from itertools import combinations,product 6 | import math 7 | 8 | class Residual_Block(nn.Module): 9 | def __init__(self, in_ch, out_ch, stride): 10 | super(Residual_Block, self).__init__() 11 | self.channel_num = settings.channel 12 | self.convs = nn.ModuleList() 13 | self.relus = nn.ModuleList() 14 | self.convert = nn.Sequential( 15 | nn.Conv2d(in_ch, out_ch, 3, stride, 1), 16 | nn.LeakyReLU(0.2) 17 | ) 18 | self.res = nn.Sequential( 19 | nn.Conv2d(out_ch, out_ch, 3, 1, 1), 20 | nn.LeakyReLU(0.2), 21 | nn.Conv2d(out_ch, out_ch, 3, 1, 1), 22 | ) 23 | 24 | def forward(self, x): 25 | convert = self.convert(x) 26 | out = convert + self.res(convert) 27 | return out 28 | 29 | class SCConv(nn.Module): 30 | def __init__(self, planes, pooling_r): 31 | super(SCConv, self).__init__() 32 | self.k2 = nn.Sequential( 33 | nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r), 34 | nn.Conv2d(planes, planes, 3, 1, 1), 35 | ) 36 | self.k3 = nn.Sequential( 37 | nn.Conv2d(planes, planes, 3, 1, 1), 38 | ) 39 | self.k4 = nn.Sequential( 40 | nn.Conv2d(planes, planes, 3, 1, 1), 41 | nn.LeakyReLU(0.2), 42 | ) 43 | 44 | def forward(self, x): 45 | identity = x 46 | 47 | 
out = torch.sigmoid(torch.add(identity, F.interpolate(self.k2(x), identity.size()[2:]))) # sigmoid(identity + k2) 48 | out = torch.mul(self.k3(x), out) # k3 * sigmoid(identity + k2) 49 | out = self.k4(out) # k4 50 | 51 | return out 52 | 53 | class SCBottleneck(nn.Module): 54 | #expansion = 4 55 | pooling_r = 4 # down-sampling rate of the avg pooling layer in the K3 path of SC-Conv. 56 | 57 | def __init__(self, in_planes, planes): 58 | super(SCBottleneck, self).__init__() 59 | planes = int(planes / 2) 60 | 61 | self.conv1_a = nn.Conv2d(in_planes, planes, 1, 1) 62 | self.k1 = nn.Sequential( 63 | nn.Conv2d(planes, planes, 3, 1, 1), 64 | nn.LeakyReLU(0.2), 65 | ) 66 | 67 | self.conv1_b = nn.Conv2d(in_planes, planes, 1, 1) 68 | 69 | self.scconv = SCConv(planes, self.pooling_r) 70 | 71 | self.conv3 = nn.Conv2d(planes * 2, planes * 2, 1, 1) 72 | self.relu = nn.LeakyReLU(0.2) 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out_a= self.conv1_a(x) 78 | out_a = self.relu(out_a) 79 | 80 | out_a = self.k1(out_a) 81 | 82 | out_b = self.conv1_b(x) 83 | out_b = self.relu(out_b) 84 | 85 | out_b = self.scconv(out_b) 86 | 87 | out = self.conv3(torch.cat([out_a, out_b], dim=1)) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | class Scale_attention(nn.Module): 95 | def __init__(self): 96 | super(Scale_attention, self).__init__() 97 | self.scale_attention = nn.ModuleList() 98 | self.res_list = nn.ModuleList() 99 | self.channel = settings.channel 100 | if settings.scale_attention is True: 101 | for i in range(settings.num_scale_attention): 102 | self.scale_attention.append( 103 | nn.Sequential( 104 | nn.MaxPool2d(2 ** (i + 1), 2 ** (i + 1)), 105 | nn.Conv2d(self.channel, self.channel, 1, 1), 106 | nn.Sigmoid() 107 | ) 108 | ) 109 | for i in range(settings.num_scale_attention): 110 | self.res_list.append( 111 | Residual_Block(self.channel, self.channel, 2) 112 | ) 113 | 114 | self.conv11 = nn.Sequential( 115 | 
nn.Conv2d((settings.num_scale_attention + 1) * self.channel, self.channel, 1, 1), 116 | nn.LeakyReLU(0.2) 117 | ) 118 | self.scn = SCBottleneck(self.channel, self.channel) 119 | 120 | def forward(self, x): 121 | b, c, h, w = x.size() 122 | temp = x 123 | out = [] 124 | out.append(temp) 125 | if settings.scale_attention is True: 126 | for i in range(settings.num_scale_attention): 127 | temp = self.res_list[i](temp) 128 | b0,c0,h0,w0 = temp.size() 129 | temp = temp * F.upsample(self.scale_attention[i](x), [h0, w0]) 130 | up = temp 131 | out.append(F.upsample(up, [h, w])) 132 | fusion = self.conv11(torch.cat(out, dim=1)) 133 | 134 | else: 135 | for i in range(settings.num_scale_attention): 136 | temp = self.res_list[i](temp) 137 | up = temp 138 | out.append(F.upsample(up, [h, w])) 139 | fusion = self.conv11(torch.cat(out, dim=1)) 140 | out = self.scn(fusion + x) 141 | return out 142 | 143 | class DenseConnection(nn.Module): 144 | def __init__(self, unit, unit_num): 145 | super(DenseConnection, self).__init__() 146 | self.unit_num = unit_num 147 | self.channel = settings.channel 148 | self.units = nn.ModuleList() 149 | self.conv1x1 = nn.ModuleList() 150 | for i in range(self.unit_num): 151 | self.units.append(unit()) 152 | self.conv1x1.append(nn.Sequential(nn.Conv2d((i+2)*self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2))) 153 | 154 | def forward(self, x): 155 | cat = [] 156 | cat.append(x) 157 | out = x 158 | for i in range(self.unit_num): 159 | tmp = self.units[i](out) 160 | cat.append(tmp) 161 | out = self.conv1x1[i](torch.cat(cat,dim=1)) 162 | return out 163 | 164 | 165 | class ODE_DerainNet(nn.Module): 166 | def __init__(self): 167 | super(ODE_DerainNet, self).__init__() 168 | self.channel = settings.channel 169 | self.unit_num = settings.unit_num 170 | self.enterBlock = nn.Sequential(nn.Conv2d(3, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)) 171 | self.derain_net = DenseConnection(Scale_attention, self.unit_num) 172 | self.exitBlock = 
nn.Sequential(nn.Conv2d(self.channel, 3, 3, 1, 1), nn.LeakyReLU(0.2)) 173 | 174 | 175 | def forward(self, x): 176 | image_feature = self.enterBlock(x) 177 | rain_feature = self.derain_net(image_feature) 178 | rain = self.exitBlock(rain_feature) 179 | derain = x - rain 180 | return derain 181 | -------------------------------------------------------------------------------- /code/ablation/r2/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | ####################### 网络参数设置  ############################################################ 6 | channel = 32 7 | feature_map_num = 32 8 | res_conv_num = 4 9 | unit_num = 32 10 | #scale_num = 4 11 | num_scale_attention = 4 12 | scale_attention = False 13 | ssim_loss = True 14 | ######################################################################################## 15 | aug_data = False # Set as False for fair comparison 16 | 17 | patch_size = 64 18 | pic_is_pair = True #input picture is pair or single 19 | 20 | lr = 0.0005 21 | 22 | data_dir = '/data1/wangcong/dataset/rain12' 23 | if pic_is_pair is False: 24 | data_dir = '/data1/wangcong/dataset/real-world-images' 25 | log_dir = '../logdir' 26 | show_dir = '../showdir' 27 | model_dir = '../models' 28 | show_dir_feature = '../showdir_feature' 29 | 30 | log_level = 'info' 31 | model_path = os.path.join(model_dir, 'latest_net') 32 | save_steps = 400 33 | 34 | num_workers = 8 35 | 36 | num_GPU = 2 37 | 38 | device_id = '2,1' 39 | 40 | epoch = 1000 41 | batch_size = 24 42 | 43 | if pic_is_pair: 44 | root_dir = os.path.join(data_dir, 'train') 45 | mat_files = os.listdir(root_dir) 46 | num_datasets = len(mat_files) 47 | l1 = int(3/5 * epoch * num_datasets / batch_size) 48 | l2 = int(4/5 * epoch * num_datasets / batch_size) 49 | one_epoch = int(num_datasets/batch_size) 50 | total_step = int((epoch * num_datasets)/batch_size) 51 | 52 | logger = logging.getLogger('train') 53 | 
logger.setLevel(logging.INFO) 54 | 55 | ch = logging.StreamHandler() 56 | ch.setLevel(logging.INFO) 57 | 58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 59 | ch.setFormatter(formatter) 60 | logger.addHandler(ch) 61 | 62 | 63 | -------------------------------------------------------------------------------- /code/ablation/r2/config/show.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import math 6 | import numpy as np 7 | import itertools 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import DataParallel 12 | from torch.optim import Adam 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | 16 | import settings 17 | from dataset import ShowDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | 32 | def PSNR(img1, img2): 33 | b,_,_,_=img1.shape 34 | #mse=0 35 | #for i in range(b): 36 | img1=np.clip(img1,0,255) 37 | img2=np.clip(img2,0,255) 38 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2)#+mse 39 | if mse == 0: 40 | return 100 41 | #mse=mse/b 42 | PIXEL_MAX = 1 43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 44 | 45 | class Session: 46 | def __init__(self): 47 | self.show_dir = settings.show_dir 48 | self.model_dir = settings.model_dir 49 | ensure_dir(settings.show_dir) 50 | ensure_dir(settings.model_dir) 51 | logger.info('set show dir as %s' % settings.show_dir) 52 | logger.info('set model dir as %s' % settings.model_dir) 53 | 54 | if len(settings.device_id) >1: 55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id) 57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id) 58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id) 59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id) 60 | else: 61 | torch.cuda.set_device(settings.device_id[0]) 62 | self.net = ODE_DerainNet().cuda() 63 | self.ssim = SSIM().cuda() 64 | self.dataloaders = {} 65 | self.ssim=SSIM().cuda() 66 | self.a=0 67 | self.t=0 68 | def get_dataloader(self, dataset_name): 69 | dataset = ShowDataset(dataset_name) 70 | self.dataloaders[dataset_name] = \ 71 | DataLoader(dataset, batch_size=1, 72 | shuffle=False, num_workers=1) 73 | return self.dataloaders[dataset_name] 74 | 75 | def load_checkpoints(self, name): 76 | ckp_path = os.path.join(self.model_dir, name) 77 | try: 78 | obj = torch.load(ckp_path) 79 | logger.info('Load checkpoint %s' % ckp_path) 80 | except FileNotFoundError: 81 | logger.info('No checkpoint %s!!' 
% ckp_path) 82 | return 83 | self.net.load_state_dict(obj['net']) 84 | def inf_batch(self, name, batch,i): 85 | O, B, file_name= batch['O'].cuda(), batch['B'].cuda(), batch['file_name'] 86 | file_name = str(file_name[0]) 87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 88 | with torch.no_grad(): 89 | import time 90 | t0=time.time() 91 | derain = self.net(O) 92 | t1 = time.time() 93 | comput_time=t1-t0 94 | print(comput_time) 95 | ssim = self.ssim(derain, B).data.cpu().numpy() 96 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 97 | print('psnr:%4f-------------ssim:%4f'%(psnr, ssim)) 98 | return derain, psnr, ssim, file_name 99 | 100 | def save_image(self, No, imgs, name, psnr, ssim, file_name): 101 | for i, img in enumerate(imgs): 102 | img = (img.cpu().data * 255).numpy() 103 | img = np.clip(img, 0, 255) 104 | img = np.transpose(img, (1, 2, 0)) 105 | h, w, c = img.shape 106 | 107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name)) 108 | print(img_file) 109 | cv2.imwrite(img_file, img) 110 | 111 | 112 | def run_show(ckp_name): 113 | sess = Session() 114 | sess.load_checkpoints(ckp_name) 115 | sess.net.eval() 116 | dataset = 'test' 117 | if settings.pic_is_pair is False: 118 | dataset = 'train-small' 119 | dt = sess.get_dataloader(dataset) 120 | 121 | for i, batch in enumerate(dt): 122 | logger.info(i) 123 | if i>-1: 124 | imgs,psnr,ssim, file_name= sess.inf_batch('test', batch,i) 125 | sess.save_image(i, imgs, dataset, psnr, ssim, file_name) 126 | 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('-m', '--model', default='latest_net') 132 | 133 | args = parser.parse_args(sys.argv[1:]) 134 | 135 | run_show(args.model) 136 | 137 | -------------------------------------------------------------------------------- /code/ablation/r2/config/tensorboard.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | function rand() { 3 | min=$1 4 | max=$(($2-$min+1)) 5 | num=$(($RANDOM+1000000000000)) 6 | echo $(($num%$max+$min)) 7 | } 8 | 9 | rnd=$(rand 3000 12000) 10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3 11 | -------------------------------------------------------------------------------- /code/ablation/r2/config/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import math 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TrainValDataset, TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | logger = settings.logger 22 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | import numpy as np 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | 32 | def PSNR(img1, img2): 33 | b,_,_,_=img1.shape 34 | img1=np.clip(img1,0,255) 35 | img2=np.clip(img2,0,255) 36 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2) 37 | if mse == 0: 38 | return 100 39 | 40 | PIXEL_MAX = 1 41 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 42 | 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | self.ssim_loss = settings.ssim_loss 48 | ensure_dir(settings.log_dir) 49 | ensure_dir(settings.model_dir) 50 | ensure_dir('../log_test') 51 | logger.info('set log dir as %s' % settings.log_dir) 52 | logger.info('set model dir as %s' % settings.model_dir) 53 | if len(settings.device_id) >1: 54 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 55 | else: 56 | torch.cuda.set_device(settings.device_id[0]) 57 | self.net = ODE_DerainNet().cuda() 58 | self.l1 = nn.L1Loss().cuda() 59 | self.ssim = SSIM().cuda() 60 | self.step = 0 61 | self.save_steps = settings.save_steps 62 | self.num_workers = settings.num_workers 63 | self.batch_size = settings.batch_size 64 | self.writers = {} 65 | self.dataloaders = {} 66 | self.opt_net = Adam(self.net.parameters(), lr=settings.lr) 67 | self.sche_net = MultiStepLR(self.opt_net, milestones=[settings.l1, settings.l2], gamma=0.1) 68 | 69 | def tensorboard(self, name): 70 | self.writers[name] = SummaryWriter(os.path.join(self.log_dir, name + '.events')) 71 | return self.writers[name] 72 | 73 | def write(self, name, out): 74 | for k, v in out.items(): 75 | self.writers[name].add_scalar(k, v, self.step) 76 | out['lr'] = self.opt_net.param_groups[0]['lr'] 77 | out['step'] = self.step 78 | outputs = [ 79 | "{}:{:.4g}".format(k, v) 80 | for k, v in out.items() 81 | ] 82 | logger.info(name + '--' + ' '.join(outputs)) 83 | 84 | def get_dataloader(self, dataset_name): 85 | dataset = TrainValDataset(dataset_name) 86 | if not dataset_name in self.dataloaders: 87 | self.dataloaders[dataset_name] = \ 88 | DataLoader(dataset, batch_size=self.batch_size, 89 | shuffle=True, num_workers=self.num_workers, drop_last=True) 90 | return iter(self.dataloaders[dataset_name]) 91 | 92 | def 
get_test_dataloader(self, dataset_name): 93 | dataset = TestDataset(dataset_name) 94 | if not dataset_name in self.dataloaders: 95 | self.dataloaders[dataset_name] = \ 96 | DataLoader(dataset, batch_size=1, 97 | shuffle=False, num_workers=1, drop_last=False) 98 | return self.dataloaders[dataset_name] 99 | 100 | def save_checkpoints_net(self, name): 101 | ckp_path = os.path.join(self.model_dir, name) 102 | obj = { 103 | 'net': self.net.state_dict(), 104 | 'clock_net': self.step, 105 | 'opt_net': self.opt_net.state_dict(), 106 | } 107 | torch.save(obj, ckp_path) 108 | 109 | def load_checkpoints_net(self, name): 110 | ckp_path = os.path.join(self.model_dir, name) 111 | try: 112 | logger.info('Load checkpoint %s' % ckp_path) 113 | obj = torch.load(ckp_path) 114 | except FileNotFoundError: 115 | logger.info('No checkpoint %s!!' % ckp_path) 116 | return 117 | self.net.load_state_dict(obj['net']) 118 | self.opt_net.load_state_dict(obj['opt_net']) 119 | self.step = obj['clock_net'] 120 | self.sche_net.last_epoch = self.step 121 | 122 | def print_network(self, model): 123 | num_params = 0 124 | for p in model.parameters(): 125 | num_params += p.numel() 126 | print(model) 127 | print("The number of parameters: {}".format(num_params)) 128 | 129 | def inf_batch(self, name, batch): 130 | if name == 'train': 131 | self.net.zero_grad() 132 | if self.step==0: 133 | self.print_network(self.net) 134 | 135 | O, B = batch['O'].cuda(), batch['B'].cuda() 136 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 137 | derain = self.net(O) 138 | l1_loss = self.l1(derain, B) 139 | ssim = self.ssim(derain, B) 140 | if self.ssim_loss == True : 141 | loss = -ssim 142 | else: 143 | loss = l1_loss 144 | if name == 'train': 145 | loss.backward() 146 | self.opt_net.step() 147 | losses = {'L1loss' : l1_loss} 148 | ssimes = {'ssim' : ssim} 149 | losses.update(ssimes) 150 | self.write(name, losses) 151 | 152 | return derain 153 | 154 | def save_image(self, name, img_lists): 155 
| data, pred, label = img_lists 156 | pred = pred.cpu().data 157 | 158 | data, label, pred = data * 255, label * 255, pred * 255 159 | pred = np.clip(pred, 0, 255) 160 | h, w = pred.shape[-2:] 161 | gen_num = (1, 1) 162 | img = np.zeros((gen_num[0] * h, gen_num[1] * 3 * w, 3)) 163 | for img_list in img_lists: 164 | for i in range(gen_num[0]): 165 | row = i * h 166 | for j in range(gen_num[1]): 167 | idx = i * gen_num[1] + j 168 | tmp_list = [data[idx], pred[idx], label[idx]] 169 | for k in range(3): 170 | col = (j * 3 + k) * w 171 | tmp = np.transpose(tmp_list[k], (1, 2, 0)) 172 | img[row: row+h, col: col+w] = tmp 173 | img_file = os.path.join(self.log_dir, '%d_%s.jpg' % (self.step, name)) 174 | cv2.imwrite(img_file, img) 175 | 176 | def inf_batch_test(self, name, batch): 177 | O, B = batch['O'].cuda(), batch['B'].cuda() 178 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 179 | 180 | with torch.no_grad(): 181 | derain = self.net(O) 182 | 183 | l1_loss = self.l1(derain, B) 184 | ssim = self.ssim(derain, B) 185 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 186 | losses = { 'L1 loss' : l1_loss } 187 | ssimes = { 'ssim' : ssim } 188 | losses.update(ssimes) 189 | 190 | return l1_loss.data.cpu().numpy(), ssim.data.cpu().numpy(), psnr 191 | 192 | 193 | def run_train_val(ckp_name_net='latest_net'): 194 | sess = Session() 195 | sess.load_checkpoints_net(ckp_name_net) 196 | sess.tensorboard('train') 197 | dt_train = sess.get_dataloader('train') 198 | while sess.step < settings.total_step+1: 199 | sess.sche_net.step() 200 | sess.net.train() 201 | try: 202 | batch_t = next(dt_train) 203 | except StopIteration: 204 | dt_train = sess.get_dataloader('train') 205 | batch_t = next(dt_train) 206 | pred_t = sess.inf_batch('train', batch_t) 207 | if sess.step % int(sess.save_steps / 16) == 0: 208 | sess.save_checkpoints_net('latest_net') 209 | if sess.step % sess.save_steps == 0: 210 | sess.save_image('train', [batch_t['O'], 
pred_t, batch_t['B']]) 211 | 212 | #observe tendency of ssim, psnr and loss 213 | ssim_all = 0 214 | psnr_all = 0 215 | loss_all = 0 216 | num_all = 0 217 | if sess.step % (settings.one_epoch * 20) == 0: 218 | dt_val = sess.get_test_dataloader('test') 219 | sess.net.eval() 220 | for i, batch_v in enumerate(dt_val): 221 | loss, ssim, psnr = sess.inf_batch_test('test', batch_v) 222 | print(i) 223 | ssim_all = ssim_all + ssim 224 | psnr_all = psnr_all + psnr 225 | loss_all = loss_all + loss 226 | num_all = num_all + 1 227 | print('num_all:',num_all) 228 | loss_avg = loss_all / num_all 229 | ssim_avg = ssim_all / num_all 230 | psnr_avg = psnr_all / num_all 231 | logfile = open('../log_test/' + 'val' + '.txt','a+') 232 | epoch = int(sess.step / settings.one_epoch) 233 | logfile.write( 234 | 'step = ' + str(sess.step) + '\t' 235 | 'epoch = ' + str(epoch) + '\t' 236 | 'loss = ' + str(loss_avg) + '\t' 237 | 'ssim = ' + str(ssim_avg) + '\t' 238 | 'pnsr = ' + str(psnr_avg) + '\t' 239 | '\n\n' 240 | ) 241 | logfile.close() 242 | if sess.step % (settings.one_epoch*10) == 0: 243 | sess.save_checkpoints_net('net_%d_epoch' % int(sess.step / settings.one_epoch)) 244 | logger.info('save model as net_%d_epoch' % int(sess.step / settings.one_epoch)) 245 | sess.step += 1 246 | 247 | 248 | 249 | if __name__ == '__main__': 250 | parser = argparse.ArgumentParser() 251 | parser.add_argument('-m1', '--model_1', default='latest_net') 252 | 253 | args = parser.parse_args(sys.argv[1:]) 254 | run_train_val(args.model_1) 255 | 256 | -------------------------------------------------------------------------------- /code/ablation/r2/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/ablation/r2/models/.DS_Store -------------------------------------------------------------------------------- /code/ablation/r2/models/README.md: 
-------------------------------------------------------------------------------- 1 | # JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network 2 | 3 | The training model or pre-training model will be placed here. 4 | -------------------------------------------------------------------------------- /code/base/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/.DS_Store -------------------------------------------------------------------------------- /code/base/rain100H/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/.DS_Store -------------------------------------------------------------------------------- /code/base/rain100H/config/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/.DS_Store -------------------------------------------------------------------------------- /code/base/rain100H/config/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return 
window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 3 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | #print(img1.size()) 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | 
return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- /code/base/rain100H/config/clean.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | rm -rf __pycache__ 3 | rm \.*\.swp 4 | rm -R ../logdir/* 5 | rm -R ../showdir/* 6 | rm ../models/* 7 | -------------------------------------------------------------------------------- /code/base/rain100H/config/compile.py: -------------------------------------------------------------------------------- 1 | import re 2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1" 3 | msgidRegex = re.compile(r',(\d)+,') 4 | mo = msgidRegex.search(rec_data) 5 | print(mo.group()) -------------------------------------------------------------------------------- /code/base/rain100H/config/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from torch.utils.data import Dataset 6 | 7 | import settings 8 | 9 | 10 | class TrainValDataset(Dataset): 11 | def __init__(self, name): 12 | super().__init__() 13 | self.rand_state = RandomState(66) 14 | self.root_dir = os.path.join(settings.data_dir, name) 15 | self.mat_files = os.listdir(self.root_dir) 16 | self.patch_size = settings.patch_size 17 | self.file_num = len(self.mat_files) 18 | 19 | def __len__(self): 20 | return self.file_num * 100 21 | 22 | def __getitem__(self, idx): 23 | file_name = self.mat_files[idx % self.file_num] 24 | img_file = os.path.join(self.root_dir, file_name) 25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 26 | 27 | if settings.aug_data: 28 | O, B = self.crop(img_pair, aug=True) 29 | O, B = self.flip(O, B) 30 | O, B = self.rotate(O, B) 31 | else: 32 | O, B = self.crop(img_pair, aug=False) 33 | 34 | O = np.transpose(O, (2, 0, 1)) 35 | B = np.transpose(B, (2, 0, 1)) 36 | sample = {'O': 
O, 'B': B} 37 | 38 | return sample 39 | 40 | def crop(self, img_pair, aug): 41 | patch_size = self.patch_size 42 | h, ww, c = img_pair.shape 43 | w = int(ww / 2) 44 | 45 | if aug: 46 | mini = - 1 / 4 * self.patch_size 47 | maxi = 1 / 4 * self.patch_size + 1 48 | p_h = patch_size + self.rand_state.randint(mini, maxi) 49 | p_w = patch_size + self.rand_state.randint(mini, maxi) 50 | else: 51 | p_h, p_w = patch_size, patch_size 52 | 53 | r = self.rand_state.randint(0, h - p_h) 54 | c = self.rand_state.randint(0, w - p_w) 55 | O = img_pair[r: r+p_h, c+w: c+p_w+w] 56 | B = img_pair[r: r+p_h, c: c+p_w] 57 | 58 | if aug: 59 | O = cv2.resize(O, (patch_size, patch_size)) 60 | B = cv2.resize(B, (patch_size, patch_size)) 61 | 62 | return O, B 63 | 64 | def flip(self, O, B): 65 | if self.rand_state.rand() > 0.5: 66 | O = np.flip(O, axis=1) 67 | B = np.flip(B, axis=1) 68 | return O, B 69 | 70 | def rotate(self, O, B): 71 | angle = self.rand_state.randint(-30, 30) 72 | patch_size = self.patch_size 73 | center = (int(patch_size / 2), int(patch_size / 2)) 74 | M = cv2.getRotationMatrix2D(center, angle, 1) 75 | O = cv2.warpAffine(O, M, (patch_size, patch_size)) 76 | B = cv2.warpAffine(B, M, (patch_size, patch_size)) 77 | return O, B 78 | 79 | 80 | class TestDataset(Dataset): 81 | def __init__(self, name): 82 | super().__init__() 83 | self.rand_state = RandomState(66) 84 | self.root_dir = os.path.join(settings.data_dir, name) 85 | self.mat_files = os.listdir(self.root_dir) 86 | self.patch_size = settings.patch_size 87 | self.file_num = len(self.mat_files) 88 | 89 | def __len__(self): 90 | return self.file_num 91 | 92 | def __getitem__(self, idx): 93 | file_name = self.mat_files[idx % self.file_num] 94 | img_file = os.path.join(self.root_dir, file_name) 95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 96 | h, ww, c = img_pair.shape 97 | w = int(ww / 2) 98 | #h_8=h%8 99 | #w_8=w%8 100 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 101 | B = np.transpose(img_pair[:, :w], 
(2, 0, 1)) 102 | sample = {'O': O, 'B': B} 103 | 104 | return sample 105 | 106 | 107 | class ShowDataset(Dataset): 108 | def __init__(self, name): 109 | super().__init__() 110 | self.rand_state = RandomState(66) 111 | self.root_dir = os.path.join(settings.data_dir, name) 112 | self.img_files = sorted(os.listdir(self.root_dir)) 113 | self.file_num = len(self.img_files) 114 | 115 | def __len__(self): 116 | return self.file_num 117 | 118 | def __getitem__(self, idx): 119 | file_name = self.img_files[idx % self.file_num] 120 | img_file = os.path.join(self.root_dir, file_name) 121 | print(img_file) 122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 123 | 124 | h, ww, c = img_pair.shape 125 | w = int(ww / 2) 126 | 127 | #h_8 = h % 8 128 | #w_8 = w % 8 129 | if settings.pic_is_pair: 130 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 131 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 132 | else: 133 | O = np.transpose(img_pair[:, :], (2, 0, 1)) 134 | B = np.transpose(img_pair[:, :], (2, 0, 1)) 135 | sample = {'O': O, 'B': B,'file_name':file_name[:-4]} 136 | 137 | return sample 138 | 139 | 140 | if __name__ == '__main__': 141 | dt = TrainValDataset('val') 142 | print('TrainValDataset') 143 | for i in range(10): 144 | smp = dt[i] 145 | for k, v in smp.items(): 146 | print(k, v.shape, v.dtype, v.mean()) 147 | 148 | print() 149 | dt = TestDataset('test') 150 | print('TestDataset') 151 | for i in range(10): 152 | smp = dt[i] 153 | for k, v in smp.items(): 154 | print(k, v.shape, v.dtype, v.mean()) 155 | 156 | print() 157 | print('ShowDataset') 158 | dt = ShowDataset('test') 159 | for i in range(10): 160 | smp = dt[i] 161 | for k, v in smp.items(): 162 | print(k, v.shape, v.dtype, v.mean()) 163 | -------------------------------------------------------------------------------- /code/base/rain100H/config/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import 
numpy as np 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | def PSNR(img1, img2): 32 | b,_,_,_=img1.shape 33 | #mse=0 34 | #for i in range(b): 35 | img1=np.clip(img1,0,255) 36 | img2=np.clip(img2,0,255) 37 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse 38 | if mse == 0: 39 | return 100 40 | #mse=mse/b 41 | PIXEL_MAX = 1 42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | ensure_dir(settings.log_dir) 48 | ensure_dir(settings.model_dir) 49 | logger.info('set log dir as %s' % settings.log_dir) 50 | logger.info('set model dir as %s' % settings.model_dir) 51 | if len(settings.device_id) >1: 52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 53 | else: 54 | torch.cuda.set_device(settings.device_id[0]) 55 | self.net = ODE_DerainNet().cuda() 56 | self.l2 = MSELoss().cuda() 57 | self.l1 = nn.L1Loss().cuda() 58 | self.ssim = SSIM().cuda() 59 | self.dataloaders = {} 60 | 61 | def get_dataloader(self, dataset_name): 62 | dataset = TestDataset(dataset_name) 63 | if not dataset_name in self.dataloaders: 64 | self.dataloaders[dataset_name] = \ 65 | DataLoader(dataset, batch_size=1, 66 | shuffle=False, num_workers=1, drop_last=False) 67 | return 
self.dataloaders[dataset_name] 68 | 69 | def load_checkpoints(self, name): 70 | ckp_path = os.path.join(self.model_dir, name) 71 | try: 72 | obj = torch.load(ckp_path) 73 | logger.info('Load checkpoint %s' % ckp_path) 74 | except FileNotFoundError: 75 | logger.info('No checkpoint %s!!' % ckp_path) 76 | return 77 | self.net.load_state_dict(obj['net']) 78 | 79 | def inf_batch(self, name, batch): 80 | O, B = batch['O'].cuda(), batch['B'].cuda() 81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 82 | 83 | with torch.no_grad(): 84 | derain = self.net(O) 85 | 86 | l1_loss = self.l1(derain, B) 87 | ssim = self.ssim(derain, B) 88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 89 | losses = { 'L1 loss' : l1_loss } 90 | ssimes = { 'ssim' : ssim } 91 | losses.update(ssimes) 92 | 93 | return losses, psnr 94 | 95 | 96 | def run_test(ckp_name): 97 | sess = Session() 98 | sess.net.eval() 99 | sess.load_checkpoints(ckp_name) 100 | dt = sess.get_dataloader('test') 101 | psnr_all = 0 102 | all_num = 0 103 | all_losses = {} 104 | for i, batch in enumerate(dt): 105 | losses,psnr= sess.inf_batch('test', batch) 106 | psnr_all=psnr_all+psnr 107 | batch_size = batch['O'].size(0) 108 | all_num += batch_size 109 | for key, val in losses.items(): 110 | if i == 0: 111 | all_losses[key] = 0. 
112 | all_losses[key] += val * batch_size 113 | logger.info('batch %d mse %s: %f' % (i, key, val)) 114 | 115 | for key, val in all_losses.items(): 116 | logger.info('total mse %s: %f' % (key, val / all_num)) 117 | #psnr=sum(psnr_all) 118 | #print(psnr) 119 | print('psnr_ll:%8f'%(psnr_all/all_num)) 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('-m', '--model', default='latest_net') 124 | 125 | args = parser.parse_args(sys.argv[1:]) 126 | run_test(args.model) 127 | 128 | -------------------------------------------------------------------------------- /code/base/rain100H/config/function/__pycache__/functional.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/__pycache__/functional.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/config/function/functional.py: -------------------------------------------------------------------------------- 1 | from . 
import functions 2 | 3 | 4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1] 6 | if input.is_cuda: 7 | if pad_mode == 0: 8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation) 9 | elif pad_mode == 1: 10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation) 11 | else: 12 | raise NotImplementedError 13 | return out 14 | 15 | 16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 17 | assert input.dim() == 4 and pad_mode in [0, 1] 18 | if input.is_cuda: 19 | if pad_mode == 0: 20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation) 21 | elif pad_mode == 1: 22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation) 23 | else: 24 | raise NotImplementedError 25 | return out 26 | 27 | 28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1] 30 | if input1.is_cuda: 31 | if pad_mode == 0: 32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation) 33 | elif pad_mode == 1: 34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation) 35 | else: 36 | raise NotImplementedError 37 | return out 38 | -------------------------------------------------------------------------------- /code/base/rain100H/config/function/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation_zeropad import * 2 | from .aggregation_refpad import * 3 | from .subtraction_zeropad import * 4 | from .subtraction_refpad import * 5 | from .subtraction2_zeropad import * 6 | from .subtraction2_refpad import * 7 | from .utils import * 8 
# ---- /code/base/rain100H/config/function/functions/__pycache__/__init__.cpython-37.pyc ----
# https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/__init__.cpython-37.pyc
# ---- /code/base/rain100H/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc ----
# https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc
# ---- /code/base/rain100H/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc ----
# https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc
# ---- /code/base/rain100H/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc ----
# https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc
# ---- /code/base/rain100H/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc ----
# https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc
# ---- /code/base/rain100H/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc ----
# https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc
# ---- /code/base/rain100H/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc ----
# https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc
# ---- /code/base/rain100H/config/function/functions/__pycache__/utils.cpython-37.pyc ----
# https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/functions/__pycache__/utils.cpython-37.pyc

# ---- /code/base/rain100H/config/function/functions/utils.py ----
from collections import namedtuple
from string import Template
import cupy
import torch


Stream = namedtuple('Stream', ['ptr'])


def Dtype(t):
    """Return the CUDA C scalar type name ('float' or 'double') for tensor ``t``.

    Raises:
        TypeError: if ``t`` is not a CUDA float32/float64 tensor. Previously
            this returned None, which would only surface later as a confusing
            kernel-compilation error.
    """
    if isinstance(t, torch.cuda.FloatTensor):
        return 'float'
    elif isinstance(t, torch.cuda.DoubleTensor):
        return 'double'
    desc = t.type() if torch.is_tensor(t) else type(t).__name__
    raise TypeError('Unsupported tensor type for CUDA kernel: %s' % desc)


# Newer CuPy releases expose ``memoize`` at top level (``cupy.util`` was
# deprecated); fall back to ``cupy.util.memoize`` for older releases.
_memoize = getattr(cupy, 'memoize', None) or cupy.util.memoize


@_memoize(for_each_device=True)
def load_kernel(kernel_name, code, **kwargs):
    """Fill ``$``-placeholders in ``code`` from ``kwargs``, compile it, and return ``kernel_name``.

    Results are memoized per device, so each (name, code, kwargs) combination
    is compiled once per GPU.
    """
    code = Template(code).substitute(**kwargs)
    # NOTE(review): compile_with_cache is deprecated in recent CuPy in favour
    # of cupy.RawModule; kept for compatibility with the pinned version.
    kernel_code = cupy.cuda.compile_with_cache(code)
    return kernel_code.get_function(kernel_name)

# ---- /code/base/rain100H/config/function/modules/__init__.py ----
from .aggregation import *
from .subtraction import *
from .subtraction2 import *

# ---- /code/base/rain100H/config/function/modules/__pycache__/__init__.cpython-37.pyc ----
# https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/modules/__pycache__/__init__.cpython-37.pyc
# ---- /code/base/rain100H/config/function/modules/__pycache__/aggregation.cpython-37.pyc ----
# https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/modules/__pycache__/aggregation.cpython-37.pyc
# ---- /code/base/rain100H/config/function/modules/__pycache__/subtraction.cpython-37.pyc ----
# https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/modules/__pycache__/subtraction.cpython-37.pyc
# ---- /code/base/rain100H/config/function/modules/__pycache__/subtraction2.cpython-37.pyc ----
# https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/config/function/modules/__pycache__/subtraction2.cpython-37.pyc

# ---- /code/base/rain100H/config/function/modules/aggregation.py ----
from torch import nn
from torch.nn.modules.utils import _pair

from .. import functional as F


class Aggregation(nn.Module):
    """nn.Module wrapper around ``functional.aggregation`` with fixed hyper-parameters."""

    def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
        super(Aggregation, self).__init__()
        # _pair expands a scalar to an (h, w) tuple; tuples pass through.
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.pad_mode = pad_mode

    def forward(self, input, weight):
        return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)

# ---- /code/base/rain100H/config/function/modules/subtraction.py ----
from torch import nn
from torch.nn.modules.utils import _pair

from .. import functional as F


class Subtraction(nn.Module):
    """nn.Module wrapper around ``functional.subtraction`` with fixed hyper-parameters."""

    def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
        super(Subtraction, self).__init__()
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.pad_mode = pad_mode

    def forward(self, input):
        return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)

# ---- /code/base/rain100H/config/function/modules/subtraction2.py ----
from torch import nn
from torch.nn.modules.utils import _pair

from .. import functional as F


class Subtraction2(nn.Module):
    """nn.Module wrapper around ``functional.subtraction2`` with fixed hyper-parameters."""

    def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
        super(Subtraction2, self).__init__()
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.pad_mode = pad_mode

    def forward(self, input1, input2):
        return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)

# ---- /code/base/rain100H/config/settings.py ----
import os
import logging


####################### Network parameter settings ############################################################
channel = 32
feature_map_num = 32
res_conv_num = 4
unit_num = 32
#scale_num = 4
num_scale_attention = 4
scale_attention = False
ssim_loss = True
########################################################################################
aug_data = False # Set as False for fair comparison

patch_size = 64
pic_is_pair = True #input picture is pair or single

lr = 0.0005

data_dir = '/data1/wangcong/dataset/rain12'
if pic_is_pair is False:
    data_dir = '/data1/wangcong/dataset/real-world-images'
log_dir = '../logdir'
show_dir = '../showdir'
model_dir = '../models'
show_dir_feature = '../showdir_feature'

log_level = 'info'
model_path = os.path.join(model_dir, 'latest_net')
save_steps = 400

num_workers = 8

num_GPU = 2

# Comma-separated physical GPU ids; the scripts export this as CUDA_VISIBLE_DEVICES.
device_id = '1,2'

epoch = 1000
batch_size = 12

if pic_is_pair:
    # NOTE: lists the training directory at import time.
    root_dir = os.path.join(data_dir, 'train')
    mat_files = os.listdir(root_dir)
    num_datasets = len(mat_files)
    # Presumably LR-decay milestones (in steps) at 3/5 and 4/5 of training;
    # consumed by train.py, which is not visible here.
    l1 = int(3/5 * epoch * num_datasets / batch_size)
    l2 = int(4/5 * epoch * num_datasets / batch_size)
    one_epoch = int(num_datasets/batch_size)
    total_step = int((epoch * num_datasets)/batch_size)

logger = logging.getLogger('train')
logger.setLevel(logging.INFO)

ch = logging.StreamHandler()
ch.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)

# ---- /code/base/rain100H/config/show.py ----
import os
import sys
import cv2
import time
import argparse
import math
import numpy as np
import itertools

import torch
from torch import nn
from torch.nn import DataParallel
from torch.optim import Adam
from torch.autograd import Variable
from torch.utils.data import DataLoader

import settings
from dataset import ShowDataset
from model import ODE_DerainNet
from cal_ssim import SSIM

os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id
logger = settings.logger
torch.cuda.manual_seed_all(66)
torch.manual_seed(66)


def ensure_dir(dir_path):
    """Create ``dir_path`` (including parents) if it does not exist yet."""
    os.makedirs(dir_path, exist_ok=True)


def PSNR(img1, img2):
    """Return the PSNR in dB between two arrays given on a 0-255 scale.

    Both inputs are clipped to [0, 255] and rescaled to [0, 1]. The MSE is
    averaged over every element, so a batch yields one pooled value.
    Identical inputs return 100 instead of infinity.
    """
    img1 = np.clip(img1, 0, 255)
    img2 = np.clip(img2, 0, 255)
    mse = np.mean((img1 / 255. - img2 / 255.) ** 2)
    if mse == 0:
        return 100
    PIXEL_MAX = 1
    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))


class Session:
    """Holds the deraining network, the SSIM metric and data loaders for visualisation."""

    def __init__(self):
        self.show_dir = settings.show_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.show_dir)
        ensure_dir(settings.model_dir)
        logger.info('set show dir as %s' % settings.show_dir)
        logger.info('set model dir as %s' % settings.model_dir)

        # device_id is a comma-separated id list ('1,2'): count ids, not characters.
        if len(settings.device_id.split(',')) > 1:
            self.net = nn.DataParallel(ODE_DerainNet()).cuda()
        else:
            # CUDA_VISIBLE_DEVICES (set above) re-indexes the single visible
            # GPU as 0; passing the physical id string here was invalid.
            torch.cuda.set_device(0)
            self.net = ODE_DerainNet().cuda()
        self.ssim = SSIM().cuda()
        self.dataloaders = {}
        self.a = 0
        self.t = 0

    def get_dataloader(self, dataset_name):
        """Build a sequential, batch-size-1 loader over ``ShowDataset(dataset_name)``."""
        dataset = ShowDataset(dataset_name)
        self.dataloaders[dataset_name] = \
            DataLoader(dataset, batch_size=1,
                       shuffle=False, num_workers=1)
        return self.dataloaders[dataset_name]

    def load_checkpoints(self, name):
        """Load weights from ``<model_dir>/<name>``; a missing file is only logged."""
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
            logger.info('Load checkpoint %s' % ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.net.load_state_dict(obj['net'])

    def inf_batch(self, name, batch, i):
        """Derain one batch and print the forward time, PSNR and SSIM.

        Returns:
            (derain, psnr, ssim, file_name), where ``file_name`` is the first
            entry of ``batch['file_name']``.
        """
        O, B, file_name = batch['O'].cuda(), batch['B'].cuda(), batch['file_name']
        file_name = str(file_name[0])
        O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False)
        with torch.no_grad():
            # CUDA runs asynchronously: synchronise on both sides so the
            # printed time covers the forward pass, not just kernel launch.
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            derain = self.net(O)
            torch.cuda.synchronize()
            comput_time = time.perf_counter() - t0
            print(comput_time)
            ssim = self.ssim(derain, B).data.cpu().numpy()
            psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
            print('psnr:%4f-------------ssim:%4f' % (psnr, ssim))
        return derain, psnr, ssim, file_name

    def save_image(self, No, imgs, name, psnr, ssim, file_name):
        """Write images from ``imgs`` (CHW, values in [0, 1]) to ``<show_dir>/<file_name>.png``.

        Every image in the batch uses the same path, so only the last one
        survives for batches larger than one (get_dataloader uses batch_size=1).
        """
        for img in imgs:
            img = (img.cpu().data * 255).numpy()
            img = np.clip(img, 0, 255)
            img = np.transpose(img, (1, 2, 0))

            img_file = os.path.join(self.show_dir, '%s.png' % (file_name))
            print(img_file)
            cv2.imwrite(img_file, img)


def run_show(ckp_name):
    """Derain every image of the show split with checkpoint ``ckp_name`` and save the results."""
    sess = Session()
    sess.load_checkpoints(ckp_name)
    sess.net.eval()
    dataset = 'test'
    if settings.pic_is_pair is False:
        dataset = 'train-small'
    dt = sess.get_dataloader(dataset)

    for i, batch in enumerate(dt):
        logger.info(i)
        imgs, psnr, ssim, file_name = sess.inf_batch('test', batch, i)
        sess.save_image(i, imgs, dataset, psnr, ssim, file_name)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model', default='latest_net')

    args = parser.parse_args(sys.argv[1:])

    run_show(args.model)

# ---- /code/base/rain100H/config/tensorboard.sh ----
#! /bin/bash
# Start TensorBoard on ../logdir, listening on all interfaces at a
# pseudo-random port in [3000, 12000].

# rand MIN MAX: print a pseudo-random integer in [MIN, MAX].
# $RANDOM only spans 0..32767, so the result is not perfectly uniform; the
# large constant added to it only shifts the value and adds no randomness.
function rand() {
    min=$1
    max=$(($2-$min+1))
    num=$(($RANDOM+1000000000000))
    echo $(($num%$max+$min))
}

rnd=$(rand 3000 12000)
tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3

# ---- /code/base/rain100H/models/.DS_Store ----
# https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100H/models/.DS_Store
# ---- /code/base/rain100H/models/README.md ----
# JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network

The training model or pre-training model will be placed here.
4 | -------------------------------------------------------------------------------- /code/base/rain100L/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/.DS_Store -------------------------------------------------------------------------------- /code/base/rain100L/config/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/.DS_Store -------------------------------------------------------------------------------- /code/base/rain100L/config/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = 
F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 3 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | #print(img1.size()) 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- /code/base/rain100L/config/clean.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | rm -rf __pycache__ 3 | rm \.*\.swp 4 | rm -R ../logdir/* 5 | rm -R ../showdir/* 6 | rm ../models/* 7 | -------------------------------------------------------------------------------- /code/base/rain100L/config/compile.py: -------------------------------------------------------------------------------- 1 | import re 2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1" 3 | msgidRegex = re.compile(r',(\d)+,') 4 | mo = msgidRegex.search(rec_data) 5 | print(mo.group()) -------------------------------------------------------------------------------- /code/base/rain100L/config/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from torch.utils.data import Dataset 6 | 7 | import settings 8 | 9 | 10 | class TrainValDataset(Dataset): 11 | def __init__(self, name): 12 | super().__init__() 13 | self.rand_state = RandomState(66) 14 | self.root_dir = os.path.join(settings.data_dir, name) 15 | self.mat_files = os.listdir(self.root_dir) 16 | self.patch_size = settings.patch_size 17 | self.file_num = len(self.mat_files) 18 | 19 | def __len__(self): 20 | return self.file_num * 100 21 | 22 | def __getitem__(self, idx): 23 | file_name = self.mat_files[idx % self.file_num] 24 | img_file = os.path.join(self.root_dir, file_name) 25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 26 | 27 | if settings.aug_data: 28 | O, B = self.crop(img_pair, aug=True) 29 | O, B = self.flip(O, B) 30 | O, B = self.rotate(O, B) 31 | else: 32 | O, B = self.crop(img_pair, aug=False) 33 | 34 | O = np.transpose(O, (2, 0, 1)) 35 | B = np.transpose(B, (2, 0, 1)) 36 | sample = {'O': O, 'B': B} 37 | 38 | return sample 39 | 40 | def crop(self, img_pair, aug): 41 | patch_size = self.patch_size 42 | h, ww, c = img_pair.shape 43 | w = int(ww / 2) 44 | 45 | if aug: 46 | mini = - 1 / 4 * self.patch_size 47 | maxi = 1 / 4 * self.patch_size + 1 48 | p_h = patch_size 
+ self.rand_state.randint(mini, maxi) 49 | p_w = patch_size + self.rand_state.randint(mini, maxi) 50 | else: 51 | p_h, p_w = patch_size, patch_size 52 | 53 | r = self.rand_state.randint(0, h - p_h) 54 | c = self.rand_state.randint(0, w - p_w) 55 | O = img_pair[r: r+p_h, c+w: c+p_w+w] 56 | B = img_pair[r: r+p_h, c: c+p_w] 57 | 58 | if aug: 59 | O = cv2.resize(O, (patch_size, patch_size)) 60 | B = cv2.resize(B, (patch_size, patch_size)) 61 | 62 | return O, B 63 | 64 | def flip(self, O, B): 65 | if self.rand_state.rand() > 0.5: 66 | O = np.flip(O, axis=1) 67 | B = np.flip(B, axis=1) 68 | return O, B 69 | 70 | def rotate(self, O, B): 71 | angle = self.rand_state.randint(-30, 30) 72 | patch_size = self.patch_size 73 | center = (int(patch_size / 2), int(patch_size / 2)) 74 | M = cv2.getRotationMatrix2D(center, angle, 1) 75 | O = cv2.warpAffine(O, M, (patch_size, patch_size)) 76 | B = cv2.warpAffine(B, M, (patch_size, patch_size)) 77 | return O, B 78 | 79 | 80 | class TestDataset(Dataset): 81 | def __init__(self, name): 82 | super().__init__() 83 | self.rand_state = RandomState(66) 84 | self.root_dir = os.path.join(settings.data_dir, name) 85 | self.mat_files = os.listdir(self.root_dir) 86 | self.patch_size = settings.patch_size 87 | self.file_num = len(self.mat_files) 88 | 89 | def __len__(self): 90 | return self.file_num 91 | 92 | def __getitem__(self, idx): 93 | file_name = self.mat_files[idx % self.file_num] 94 | img_file = os.path.join(self.root_dir, file_name) 95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 96 | h, ww, c = img_pair.shape 97 | w = int(ww / 2) 98 | #h_8=h%8 99 | #w_8=w%8 100 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 101 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 102 | sample = {'O': O, 'B': B} 103 | 104 | return sample 105 | 106 | 107 | class ShowDataset(Dataset): 108 | def __init__(self, name): 109 | super().__init__() 110 | self.rand_state = RandomState(66) 111 | self.root_dir = os.path.join(settings.data_dir, name) 112 | 
self.img_files = sorted(os.listdir(self.root_dir)) 113 | self.file_num = len(self.img_files) 114 | 115 | def __len__(self): 116 | return self.file_num 117 | 118 | def __getitem__(self, idx): 119 | file_name = self.img_files[idx % self.file_num] 120 | img_file = os.path.join(self.root_dir, file_name) 121 | print(img_file) 122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 123 | 124 | h, ww, c = img_pair.shape 125 | w = int(ww / 2) 126 | 127 | #h_8 = h % 8 128 | #w_8 = w % 8 129 | if settings.pic_is_pair: 130 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 131 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 132 | else: 133 | O = np.transpose(img_pair[:, :], (2, 0, 1)) 134 | B = np.transpose(img_pair[:, :], (2, 0, 1)) 135 | sample = {'O': O, 'B': B,'file_name':file_name[:-4]} 136 | 137 | return sample 138 | 139 | 140 | if __name__ == '__main__': 141 | dt = TrainValDataset('val') 142 | print('TrainValDataset') 143 | for i in range(10): 144 | smp = dt[i] 145 | for k, v in smp.items(): 146 | print(k, v.shape, v.dtype, v.mean()) 147 | 148 | print() 149 | dt = TestDataset('test') 150 | print('TestDataset') 151 | for i in range(10): 152 | smp = dt[i] 153 | for k, v in smp.items(): 154 | print(k, v.shape, v.dtype, v.mean()) 155 | 156 | print() 157 | print('ShowDataset') 158 | dt = ShowDataset('test') 159 | for i in range(10): 160 | smp = dt[i] 161 | for k, v in smp.items(): 162 | print(k, v.shape, v.dtype, v.mean()) 163 | -------------------------------------------------------------------------------- /code/base/rain100L/config/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 
| from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | def PSNR(img1, img2): 32 | b,_,_,_=img1.shape 33 | #mse=0 34 | #for i in range(b): 35 | img1=np.clip(img1,0,255) 36 | img2=np.clip(img2,0,255) 37 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse 38 | if mse == 0: 39 | return 100 40 | #mse=mse/b 41 | PIXEL_MAX = 1 42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | ensure_dir(settings.log_dir) 48 | ensure_dir(settings.model_dir) 49 | logger.info('set log dir as %s' % settings.log_dir) 50 | logger.info('set model dir as %s' % settings.model_dir) 51 | if len(settings.device_id) >1: 52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 53 | else: 54 | torch.cuda.set_device(settings.device_id[0]) 55 | self.net = ODE_DerainNet().cuda() 56 | self.l2 = MSELoss().cuda() 57 | self.l1 = nn.L1Loss().cuda() 58 | self.ssim = SSIM().cuda() 59 | self.dataloaders = {} 60 | 61 | def get_dataloader(self, dataset_name): 62 | dataset = TestDataset(dataset_name) 63 | if not dataset_name in self.dataloaders: 64 | self.dataloaders[dataset_name] = \ 65 | DataLoader(dataset, batch_size=1, 66 | shuffle=False, num_workers=1, drop_last=False) 67 | return self.dataloaders[dataset_name] 68 | 69 | def load_checkpoints(self, name): 70 | ckp_path = os.path.join(self.model_dir, name) 71 | try: 72 | obj = torch.load(ckp_path) 73 | logger.info('Load checkpoint %s' % ckp_path) 74 | except FileNotFoundError: 75 | logger.info('No 
checkpoint %s!!' % ckp_path) 76 | return 77 | self.net.load_state_dict(obj['net']) 78 | def loss_vgg(self,input,groundtruth): 79 | vgg_gt = self.vgg.forward(groundtruth) 80 | eval = self.vgg.forward(input) 81 | loss_vgg = [self.l1(eval[m], vgg_gt[m]) for m in range(len(vgg_gt))] 82 | loss = sum(loss_vgg) 83 | return loss 84 | 85 | def inf_batch(self, name, batch): 86 | O, B = batch['O'].cuda(), batch['B'].cuda() 87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 88 | 89 | with torch.no_grad(): 90 | derain = self.net(O) 91 | 92 | l1_loss = self.l1(derain, B) 93 | ssim = self.ssim(derain, B) 94 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 95 | losses = { 'L1 loss' : l1_loss } 96 | ssimes = { 'ssim' : ssim } 97 | losses.update(ssimes) 98 | 99 | return losses, psnr 100 | 101 | 102 | def run_test(ckp_name): 103 | sess = Session() 104 | sess.net.eval() 105 | sess.load_checkpoints(ckp_name) 106 | dt = sess.get_dataloader('test') 107 | psnr_all = 0 108 | all_num = 0 109 | all_losses = {} 110 | for i, batch in enumerate(dt): 111 | losses,psnr= sess.inf_batch('test', batch) 112 | psnr_all=psnr_all+psnr 113 | batch_size = batch['O'].size(0) 114 | all_num += batch_size 115 | for key, val in losses.items(): 116 | if i == 0: 117 | all_losses[key] = 0. 
118 | all_losses[key] += val * batch_size 119 | logger.info('batch %d mse %s: %f' % (i, key, val)) 120 | 121 | for key, val in all_losses.items(): 122 | logger.info('total mse %s: %f' % (key, val / all_num)) 123 | #psnr=sum(psnr_all) 124 | #print(psnr) 125 | print('psnr_ll:%8f'%(psnr_all/all_num)) 126 | 127 | if __name__ == '__main__': 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('-m', '--model', default='latest_net') 130 | 131 | args = parser.parse_args(sys.argv[1:]) 132 | run_test(args.model) 133 | 134 | -------------------------------------------------------------------------------- /code/base/rain100L/config/function/__pycache__/functional.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/__pycache__/functional.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functional.py: -------------------------------------------------------------------------------- 1 | from . 
import functions 2 | 3 | 4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1] 6 | if input.is_cuda: 7 | if pad_mode == 0: 8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation) 9 | elif pad_mode == 1: 10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation) 11 | else: 12 | raise NotImplementedError 13 | return out 14 | 15 | 16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 17 | assert input.dim() == 4 and pad_mode in [0, 1] 18 | if input.is_cuda: 19 | if pad_mode == 0: 20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation) 21 | elif pad_mode == 1: 22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation) 23 | else: 24 | raise NotImplementedError 25 | return out 26 | 27 | 28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1] 30 | if input1.is_cuda: 31 | if pad_mode == 0: 32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation) 33 | elif pad_mode == 1: 34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation) 35 | else: 36 | raise NotImplementedError 37 | return out 38 | -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation_zeropad import * 2 | from .aggregation_refpad import * 3 | from .subtraction_zeropad import * 4 | from .subtraction_refpad import * 5 | from .subtraction2_zeropad import * 6 | from .subtraction2_refpad import * 7 | from .utils import * 8 
| -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/functions/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from string import Template 3 | import cupy 4 | import torch 5 | 6 | 7 | Stream = namedtuple('Stream', ['ptr']) 8 | 9 | 10 | def Dtype(t): 11 | if isinstance(t, torch.cuda.FloatTensor): 12 | return 'float' 13 | elif isinstance(t, torch.cuda.DoubleTensor): 14 | return 
'double' 15 | 16 | 17 | @cupy.util.memoize(for_each_device=True) 18 | def load_kernel(kernel_name, code, **kwargs): 19 | code = Template(code).substitute(**kwargs) 20 | kernel_code = cupy.cuda.compile_with_cache(code) 21 | return kernel_code.get_function(kernel_name) 22 | -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import * 2 | from .subtraction import * 3 | from .subtraction2 import * 4 | -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/__pycache__/aggregation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/modules/__pycache__/aggregation.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/__pycache__/subtraction.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/modules/__pycache__/subtraction.cpython-37.pyc -------------------------------------------------------------------------------- 
/code/base/rain100L/config/function/modules/__pycache__/subtraction2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/config/function/modules/__pycache__/subtraction2.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/aggregation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Aggregation(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Aggregation, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input, weight): 18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/subtraction.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. 
import functional as F 5 | 6 | 7 | class Subtraction(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input): 18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/subtraction2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Subtraction2(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction2, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input1, input2): 18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/base/rain100L/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | ####################### 网络参数设置  ############################################################ 6 | channel = 32 7 | feature_map_num = 32 8 | res_conv_num = 4 9 | unit_num = 32 10 | #scale_num = 4 11 | num_scale_attention = 4 12 | scale_attention = False 13 | ssim_loss = True 14 | ######################################################################################## 15 | aug_data = False # Set as 
False for fair comparison 16 | 17 | patch_size = 64 18 | pic_is_pair = True #input picture is pair or single 19 | 20 | lr = 0.0005 21 | 22 | data_dir = '/data1/dataset/haze_blur_rain/rain100L' 23 | if pic_is_pair is False: 24 | data_dir = '/data1/wangcong/dataset/real-world-images' 25 | log_dir = '../logdir' 26 | show_dir = '../showdir' 27 | model_dir = '../models' 28 | show_dir_feature = '../showdir_feature' 29 | 30 | log_level = 'info' 31 | model_path = os.path.join(model_dir, 'latest_net') 32 | save_steps = 400 33 | 34 | num_workers = 8 35 | 36 | num_GPU = 4 37 | 38 | device_id = '0,1,2,3' 39 | 40 | epoch = 1000 41 | batch_size = 32 42 | 43 | if pic_is_pair: 44 | root_dir = os.path.join(data_dir, 'train') 45 | mat_files = os.listdir(root_dir) 46 | num_datasets = len(mat_files) 47 | l1 = int(3/5 * epoch * num_datasets / batch_size) 48 | l2 = int(4/5 * epoch * num_datasets / batch_size) 49 | one_epoch = int(num_datasets/batch_size) 50 | total_step = int((epoch * num_datasets)/batch_size) 51 | 52 | logger = logging.getLogger('train') 53 | logger.setLevel(logging.INFO) 54 | 55 | ch = logging.StreamHandler() 56 | ch.setLevel(logging.INFO) 57 | 58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 59 | ch.setFormatter(formatter) 60 | logger.addHandler(ch) 61 | 62 | 63 | -------------------------------------------------------------------------------- /code/base/rain100L/config/show.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import math 6 | import numpy as np 7 | import itertools 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import DataParallel 12 | from torch.optim import Adam 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | 16 | import settings 17 | from dataset import ShowDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | 
os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | 32 | def PSNR(img1, img2): 33 | b,_,_,_=img1.shape 34 | #mse=0 35 | #for i in range(b): 36 | img1=np.clip(img1,0,255) 37 | img2=np.clip(img2,0,255) 38 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse 39 | if mse == 0: 40 | return 100 41 | #mse=mse/b 42 | PIXEL_MAX = 1 43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 44 | 45 | class Session: 46 | def __init__(self): 47 | self.show_dir = settings.show_dir 48 | self.model_dir = settings.model_dir 49 | ensure_dir(settings.show_dir) 50 | ensure_dir(settings.model_dir) 51 | logger.info('set show dir as %s' % settings.show_dir) 52 | logger.info('set model dir as %s' % settings.model_dir) 53 | 54 | if len(settings.device_id) >1: 55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id) 57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id) 58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id) 59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id) 60 | else: 61 | torch.cuda.set_device(settings.device_id[0]) 62 | self.net = ODE_DerainNet().cuda() 63 | self.ssim = SSIM().cuda() 64 | self.dataloaders = {} 65 | self.ssim=SSIM().cuda() 66 | self.a=0 67 | self.t=0 68 | def get_dataloader(self, dataset_name): 69 | dataset = ShowDataset(dataset_name) 70 | self.dataloaders[dataset_name] = \ 71 | DataLoader(dataset, batch_size=1, 72 | shuffle=False, num_workers=1) 73 | return self.dataloaders[dataset_name] 74 | 75 | def load_checkpoints(self, name): 76 | ckp_path = os.path.join(self.model_dir, name) 77 | try: 78 | obj = torch.load(ckp_path) 79 | logger.info('Load checkpoint %s' % ckp_path) 80 | except FileNotFoundError: 81 | logger.info('No 
checkpoint %s!!' % ckp_path) 82 | return 83 | self.net.load_state_dict(obj['net']) 84 | def inf_batch(self, name, batch,i): 85 | O, B, file_name= batch['O'].cuda(), batch['B'].cuda(), batch['file_name'] 86 | file_name = str(file_name[0]) 87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 88 | with torch.no_grad(): 89 | import time 90 | t0=time.time() 91 | derain = self.net(O) 92 | t1 = time.time() 93 | comput_time=t1-t0 94 | print(comput_time) 95 | ssim = self.ssim(derain, B).data.cpu().numpy() 96 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 97 | print('psnr:%4f-------------ssim:%4f'%(psnr, ssim)) 98 | return derain, psnr, ssim, file_name 99 | 100 | def save_image(self, No, imgs, name, psnr, ssim, file_name): 101 | for i, img in enumerate(imgs): 102 | img = (img.cpu().data * 255).numpy() 103 | img = np.clip(img, 0, 255) 104 | img = np.transpose(img, (1, 2, 0)) 105 | h, w, c = img.shape 106 | 107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name)) 108 | print(img_file) 109 | cv2.imwrite(img_file, img) 110 | 111 | 112 | def run_show(ckp_name): 113 | sess = Session() 114 | sess.load_checkpoints(ckp_name) 115 | sess.net.eval() 116 | dataset = 'test' 117 | if settings.pic_is_pair is False: 118 | dataset = 'train-w' 119 | dt = sess.get_dataloader(dataset) 120 | 121 | for i, batch in enumerate(dt): 122 | logger.info(i) 123 | if i>-1: 124 | imgs,psnr,ssim, file_name= sess.inf_batch('test', batch,i) 125 | sess.save_image(i, imgs, dataset, psnr, ssim, file_name) 126 | 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('-m', '--model', default='latest_net') 132 | 133 | args = parser.parse_args(sys.argv[1:]) 134 | 135 | run_show(args.model) 136 | 137 | -------------------------------------------------------------------------------- /code/base/rain100L/config/tensorboard.sh: 
-------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | function rand() { 3 | min=$1 4 | max=$(($2-$min+1)) 5 | num=$(($RANDOM+1000000000000)) 6 | echo $(($num%$max+$min)) 7 | } 8 | 9 | rnd=$(rand 3000 12000) 10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3 11 | -------------------------------------------------------------------------------- /code/base/rain100L/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/base/rain100L/models/.DS_Store -------------------------------------------------------------------------------- /code/base/rain100L/models/README.md: -------------------------------------------------------------------------------- 1 | # JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network 2 | 3 | The training model or pre-training model will be placed here. 
4 | -------------------------------------------------------------------------------- /code/diff_loss/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/.DS_Store -------------------------------------------------------------------------------- /code/diff_loss/mae/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/.DS_Store -------------------------------------------------------------------------------- /code/diff_loss/mae/config/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/.DS_Store -------------------------------------------------------------------------------- /code/diff_loss/mae/config/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = 
channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 3 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | #print(img1.size()) 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/clean.sh: 
-------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | rm -rf __pycache__ 3 | rm \.*\.swp 4 | rm -R ../logdir/* 5 | rm -R ../showdir/* 6 | rm ../models/* 7 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/compile.py: -------------------------------------------------------------------------------- 1 | import re 2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1" 3 | msgidRegex = re.compile(r',(\d)+,') 4 | mo = msgidRegex.search(rec_data) 5 | print(mo.group()) -------------------------------------------------------------------------------- /code/diff_loss/mae/config/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from torch.utils.data import Dataset 6 | 7 | import settings 8 | 9 | 10 | class TrainValDataset(Dataset): 11 | def __init__(self, name): 12 | super().__init__() 13 | self.rand_state = RandomState(66) 14 | self.root_dir = os.path.join(settings.data_dir, name) 15 | self.mat_files = os.listdir(self.root_dir) 16 | self.patch_size = settings.patch_size 17 | self.file_num = len(self.mat_files) 18 | 19 | def __len__(self): 20 | return self.file_num * 100 21 | 22 | def __getitem__(self, idx): 23 | file_name = self.mat_files[idx % self.file_num] 24 | img_file = os.path.join(self.root_dir, file_name) 25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 26 | 27 | if settings.aug_data: 28 | O, B = self.crop(img_pair, aug=True) 29 | O, B = self.flip(O, B) 30 | O, B = self.rotate(O, B) 31 | else: 32 | O, B = self.crop(img_pair, aug=False) 33 | 34 | O = np.transpose(O, (2, 0, 1)) 35 | B = np.transpose(B, (2, 0, 1)) 36 | sample = {'O': O, 'B': B} 37 | 38 | return sample 39 | 40 | def crop(self, img_pair, aug): 41 | patch_size = self.patch_size 42 | h, ww, c = img_pair.shape 43 | w = int(ww / 2) 44 | 45 | if aug: 46 | mini = 
- 1 / 4 * self.patch_size 47 | maxi = 1 / 4 * self.patch_size + 1 48 | p_h = patch_size + self.rand_state.randint(mini, maxi) 49 | p_w = patch_size + self.rand_state.randint(mini, maxi) 50 | else: 51 | p_h, p_w = patch_size, patch_size 52 | 53 | r = self.rand_state.randint(0, h - p_h) 54 | c = self.rand_state.randint(0, w - p_w) 55 | O = img_pair[r: r+p_h, c+w: c+p_w+w] 56 | B = img_pair[r: r+p_h, c: c+p_w] 57 | 58 | if aug: 59 | O = cv2.resize(O, (patch_size, patch_size)) 60 | B = cv2.resize(B, (patch_size, patch_size)) 61 | 62 | return O, B 63 | 64 | def flip(self, O, B): 65 | if self.rand_state.rand() > 0.5: 66 | O = np.flip(O, axis=1) 67 | B = np.flip(B, axis=1) 68 | return O, B 69 | 70 | def rotate(self, O, B): 71 | angle = self.rand_state.randint(-30, 30) 72 | patch_size = self.patch_size 73 | center = (int(patch_size / 2), int(patch_size / 2)) 74 | M = cv2.getRotationMatrix2D(center, angle, 1) 75 | O = cv2.warpAffine(O, M, (patch_size, patch_size)) 76 | B = cv2.warpAffine(B, M, (patch_size, patch_size)) 77 | return O, B 78 | 79 | 80 | class TestDataset(Dataset): 81 | def __init__(self, name): 82 | super().__init__() 83 | self.rand_state = RandomState(66) 84 | self.root_dir = os.path.join(settings.data_dir, name) 85 | self.mat_files = os.listdir(self.root_dir) 86 | self.patch_size = settings.patch_size 87 | self.file_num = len(self.mat_files) 88 | 89 | def __len__(self): 90 | return self.file_num 91 | 92 | def __getitem__(self, idx): 93 | file_name = self.mat_files[idx % self.file_num] 94 | img_file = os.path.join(self.root_dir, file_name) 95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 96 | h, ww, c = img_pair.shape 97 | w = int(ww / 2) 98 | #h_8=h%8 99 | #w_8=w%8 100 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 101 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 102 | sample = {'O': O, 'B': B} 103 | 104 | return sample 105 | 106 | 107 | class ShowDataset(Dataset): 108 | def __init__(self, name): 109 | super().__init__() 110 | self.rand_state 
= RandomState(66) 111 | self.root_dir = os.path.join(settings.data_dir, name) 112 | self.img_files = sorted(os.listdir(self.root_dir)) 113 | self.file_num = len(self.img_files) 114 | 115 | def __len__(self): 116 | return self.file_num 117 | 118 | def __getitem__(self, idx): 119 | file_name = self.img_files[idx % self.file_num] 120 | img_file = os.path.join(self.root_dir, file_name) 121 | print(img_file) 122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 123 | 124 | h, ww, c = img_pair.shape 125 | w = int(ww / 2) 126 | 127 | #h_8 = h % 8 128 | #w_8 = w % 8 129 | if settings.pic_is_pair: 130 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 131 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 132 | else: 133 | O = np.transpose(img_pair[:, :], (2, 0, 1)) 134 | B = np.transpose(img_pair[:, :], (2, 0, 1)) 135 | sample = {'O': O, 'B': B,'file_name':file_name[:-4]} 136 | 137 | return sample 138 | 139 | 140 | if __name__ == '__main__': 141 | dt = TrainValDataset('val') 142 | print('TrainValDataset') 143 | for i in range(10): 144 | smp = dt[i] 145 | for k, v in smp.items(): 146 | print(k, v.shape, v.dtype, v.mean()) 147 | 148 | print() 149 | dt = TestDataset('test') 150 | print('TestDataset') 151 | for i in range(10): 152 | smp = dt[i] 153 | for k, v in smp.items(): 154 | print(k, v.shape, v.dtype, v.mean()) 155 | 156 | print() 157 | print('ShowDataset') 158 | dt = ShowDataset('test') 159 | for i in range(10): 160 | smp = dt[i] 161 | for k, v in smp.items(): 162 | print(k, v.shape, v.dtype, v.mean()) 163 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | 
from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | def PSNR(img1, img2): 32 | b,_,_,_=img1.shape 33 | #mse=0 34 | #for i in range(b): 35 | img1=np.clip(img1,0,255) 36 | img2=np.clip(img2,0,255) 37 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse 38 | if mse == 0: 39 | return 100 40 | #mse=mse/b 41 | PIXEL_MAX = 1 42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | ensure_dir(settings.log_dir) 48 | ensure_dir(settings.model_dir) 49 | logger.info('set log dir as %s' % settings.log_dir) 50 | logger.info('set model dir as %s' % settings.model_dir) 51 | if len(settings.device_id) >1: 52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 53 | else: 54 | torch.cuda.set_device(settings.device_id[0]) 55 | self.net = ODE_DerainNet().cuda() 56 | self.l2 = MSELoss().cuda() 57 | self.l1 = nn.L1Loss().cuda() 58 | self.ssim = SSIM().cuda() 59 | self.dataloaders = {} 60 | 61 | def get_dataloader(self, dataset_name): 62 | dataset = TestDataset(dataset_name) 63 | if not dataset_name in self.dataloaders: 64 | self.dataloaders[dataset_name] = \ 65 | DataLoader(dataset, batch_size=1, 66 | shuffle=False, num_workers=1, drop_last=False) 67 | return self.dataloaders[dataset_name] 68 | 69 | def load_checkpoints(self, name): 70 | ckp_path = os.path.join(self.model_dir, name) 71 | try: 72 | obj = torch.load(ckp_path) 73 | logger.info('Load 
checkpoint %s' % ckp_path) 74 | except FileNotFoundError: 75 | logger.info('No checkpoint %s!!' % ckp_path) 76 | return 77 | self.net.load_state_dict(obj['net']) 78 | 79 | def inf_batch(self, name, batch): 80 | O, B = batch['O'].cuda(), batch['B'].cuda() 81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 82 | 83 | with torch.no_grad(): 84 | derain = self.net(O) 85 | 86 | l1_loss = self.l1(derain, B) 87 | ssim = self.ssim(derain, B) 88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 89 | losses = { 'L1 loss' : l1_loss } 90 | ssimes = { 'ssim' : ssim } 91 | losses.update(ssimes) 92 | 93 | return losses, psnr 94 | 95 | 96 | def run_test(ckp_name): 97 | sess = Session() 98 | sess.net.eval() 99 | sess.load_checkpoints(ckp_name) 100 | dt = sess.get_dataloader('test') 101 | psnr_all = 0 102 | all_num = 0 103 | all_losses = {} 104 | for i, batch in enumerate(dt): 105 | losses,psnr= sess.inf_batch('test', batch) 106 | psnr_all=psnr_all+psnr 107 | batch_size = batch['O'].size(0) 108 | all_num += batch_size 109 | for key, val in losses.items(): 110 | if i == 0: 111 | all_losses[key] = 0. 
112 | all_losses[key] += val * batch_size 113 | logger.info('batch %d mse %s: %f' % (i, key, val)) 114 | 115 | for key, val in all_losses.items(): 116 | logger.info('total mse %s: %f' % (key, val / all_num)) 117 | #psnr=sum(psnr_all) 118 | #print(psnr) 119 | print('psnr_ll:%8f'%(psnr_all/all_num)) 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('-m', '--model', default='latest_net') 124 | 125 | args = parser.parse_args(sys.argv[1:]) 126 | run_test(args.model) 127 | 128 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/__pycache__/functional.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/__pycache__/functional.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/__pycache__/functional.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/__pycache__/functional.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functional.py: -------------------------------------------------------------------------------- 1 | from . 
import functions 2 | 3 | 4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1] 6 | if input.is_cuda: 7 | if pad_mode == 0: 8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation) 9 | elif pad_mode == 1: 10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation) 11 | else: 12 | raise NotImplementedError 13 | return out 14 | 15 | 16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 17 | assert input.dim() == 4 and pad_mode in [0, 1] 18 | if input.is_cuda: 19 | if pad_mode == 0: 20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation) 21 | elif pad_mode == 1: 22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation) 23 | else: 24 | raise NotImplementedError 25 | return out 26 | 27 | 28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1] 30 | if input1.is_cuda: 31 | if pad_mode == 0: 32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation) 33 | elif pad_mode == 1: 34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation) 35 | else: 36 | raise NotImplementedError 37 | return out 38 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation_zeropad import * 2 | from .aggregation_refpad import * 3 | from .subtraction_zeropad import * 4 | from .subtraction_refpad import * 5 | from .subtraction2_zeropad import * 6 | from .subtraction2_refpad import * 7 | from .utils import * 8 
| -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/aggregation_refpad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_refpad.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/aggregation_zeropad.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_zeropad.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_refpad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_refpad.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_zeropad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_zeropad.cpython-36.pyc 
-------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction_refpad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_refpad.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction_zeropad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_zeropad.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/functions/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from string import Template 3 | import cupy 4 | import torch 5 | 6 | 7 | Stream = namedtuple('Stream', ['ptr']) 8 | 9 | 10 | def Dtype(t): 11 | if isinstance(t, torch.cuda.FloatTensor): 12 | return 'float' 13 | elif isinstance(t, torch.cuda.DoubleTensor): 14 | return 'double' 15 | 16 | 17 | @cupy.util.memoize(for_each_device=True) 18 | def load_kernel(kernel_name, code, **kwargs): 19 | code = Template(code).substitute(**kwargs) 20 | kernel_code = cupy.cuda.compile_with_cache(code) 21 | return kernel_code.get_function(kernel_name) 22 | -------------------------------------------------------------------------------- 
/code/diff_loss/mae/config/function/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import * 2 | from .subtraction import * 3 | from .subtraction2 import * 4 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/aggregation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/aggregation.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/aggregation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/aggregation.cpython-37.pyc -------------------------------------------------------------------------------- 
/code/diff_loss/mae/config/function/modules/__pycache__/subtraction.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/subtraction.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/subtraction.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/subtraction.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/subtraction2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/subtraction2.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/subtraction2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/config/function/modules/__pycache__/subtraction2.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/aggregation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. 
import functional as F 5 | 6 | 7 | class Aggregation(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Aggregation, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input, weight): 18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/subtraction.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Subtraction(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input): 18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/subtraction2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. 
import functional as F 5 | 6 | 7 | class Subtraction2(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction2, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input1, input2): 18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | ####################### 网络参数设置  ############################################################ 6 | channel = 32 7 | feature_map_num = 32 8 | res_conv_num = 4 9 | unit_num = 32 10 | #scale_num = 4 11 | num_scale_attention = 4 12 | scale_attention = False 13 | ssim_loss = True 14 | ######################################################################################## 15 | aug_data = False # Set as False for fair comparison 16 | 17 | patch_size = 64 18 | pic_is_pair = True #input picture is pair or single 19 | 20 | lr = 0.0005 21 | 22 | data_dir = '/data1/dataset/haze_blur_rain/rain100H' 23 | if pic_is_pair is False: 24 | data_dir = '/data1/wangcong/dataset/real-world-images' 25 | log_dir = '../logdir' 26 | show_dir = '../showdir' 27 | model_dir = '../models' 28 | show_dir_feature = '../showdir_feature' 29 | 30 | log_level = 'info' 31 | model_path = os.path.join(model_dir, 'latest_net') 32 | save_steps = 400 33 | 34 | num_workers = 8 35 | 36 | num_GPU = 2 37 | 38 | device_id = '0,1' 39 | 40 | epoch = 1000 41 | batch_size = 12 42 | 43 | if pic_is_pair: 44 | root_dir = os.path.join(data_dir, 'train') 45 | mat_files = os.listdir(root_dir) 46 | num_datasets = len(mat_files) 47 | l1 = int(3/5 * epoch * 
num_datasets / batch_size) 48 | l2 = int(4/5 * epoch * num_datasets / batch_size) 49 | one_epoch = int(num_datasets/batch_size) 50 | total_step = int((epoch * num_datasets)/batch_size) 51 | 52 | logger = logging.getLogger('train') 53 | logger.setLevel(logging.INFO) 54 | 55 | ch = logging.StreamHandler() 56 | ch.setLevel(logging.INFO) 57 | 58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 59 | ch.setFormatter(formatter) 60 | logger.addHandler(ch) 61 | 62 | 63 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/show.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import math 6 | import numpy as np 7 | import itertools 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import DataParallel 12 | from torch.optim import Adam 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | 16 | import settings 17 | from dataset import ShowDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | 32 | def PSNR(img1, img2): 33 | b,_,_,_=img1.shape 34 | #mse=0 35 | #for i in range(b): 36 | img1=np.clip(img1,0,255) 37 | img2=np.clip(img2,0,255) 38 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2)#+mse 39 | if mse == 0: 40 | return 100 41 | #mse=mse/b 42 | PIXEL_MAX = 1 43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 44 | 45 | class Session: 46 | def __init__(self): 47 | self.show_dir = settings.show_dir 48 | self.model_dir = settings.model_dir 49 | ensure_dir(settings.show_dir) 50 | ensure_dir(settings.model_dir) 51 | logger.info('set show dir as %s' % settings.show_dir) 52 | logger.info('set model dir as %s' % settings.model_dir) 53 | 54 | if len(settings.device_id) >1: 55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id) 57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id) 58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id) 59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id) 60 | else: 61 | torch.cuda.set_device(settings.device_id[0]) 62 | self.net = ODE_DerainNet().cuda() 63 | self.ssim = SSIM().cuda() 64 | self.dataloaders = {} 65 | self.ssim=SSIM().cuda() 66 | self.a=0 67 | self.t=0 68 | def get_dataloader(self, dataset_name): 69 | dataset = ShowDataset(dataset_name) 70 | self.dataloaders[dataset_name] = \ 71 | DataLoader(dataset, batch_size=1, 72 | shuffle=False, num_workers=1) 73 | return self.dataloaders[dataset_name] 74 | 75 | def load_checkpoints(self, name): 76 | ckp_path = os.path.join(self.model_dir, name) 77 | try: 78 | obj = torch.load(ckp_path) 79 | logger.info('Load checkpoint %s' % ckp_path) 80 | except FileNotFoundError: 81 | logger.info('No checkpoint %s!!' 
% ckp_path) 82 | return 83 | self.net.load_state_dict(obj['net']) 84 | def inf_batch(self, name, batch,i): 85 | O, B, file_name= batch['O'].cuda(), batch['B'].cuda(), batch['file_name'] 86 | file_name = str(file_name[0]) 87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 88 | with torch.no_grad(): 89 | import time 90 | t0=time.time() 91 | derain = self.net(O) 92 | t1 = time.time() 93 | comput_time=t1-t0 94 | print(comput_time) 95 | ssim = self.ssim(derain, B).data.cpu().numpy() 96 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 97 | print('psnr:%4f-------------ssim:%4f'%(psnr, ssim)) 98 | return derain, psnr, ssim, file_name 99 | 100 | def save_image(self, No, imgs, name, psnr, ssim, file_name): 101 | for i, img in enumerate(imgs): 102 | img = (img.cpu().data * 255).numpy() 103 | img = np.clip(img, 0, 255) 104 | img = np.transpose(img, (1, 2, 0)) 105 | h, w, c = img.shape 106 | 107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name)) 108 | print(img_file) 109 | cv2.imwrite(img_file, img) 110 | 111 | 112 | def run_show(ckp_name): 113 | sess = Session() 114 | sess.load_checkpoints(ckp_name) 115 | sess.net.eval() 116 | dataset = 'test' 117 | if settings.pic_is_pair is False: 118 | dataset = 'train-small' 119 | dt = sess.get_dataloader(dataset) 120 | 121 | for i, batch in enumerate(dt): 122 | logger.info(i) 123 | if i>-1: 124 | imgs,psnr,ssim, file_name= sess.inf_batch('test', batch,i) 125 | sess.save_image(i, imgs, dataset, psnr, ssim, file_name) 126 | 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('-m', '--model', default='latest_net') 132 | 133 | args = parser.parse_args(sys.argv[1:]) 134 | 135 | run_show(args.model) 136 | 137 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/tensorboard.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | function rand() { 3 | min=$1 4 | max=$(($2-$min+1)) 5 | num=$(($RANDOM+1000000000000)) 6 | echo $(($num%$max+$min)) 7 | } 8 | 9 | rnd=$(rand 3000 12000) 10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3 11 | -------------------------------------------------------------------------------- /code/diff_loss/mae/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mae/models/.DS_Store -------------------------------------------------------------------------------- /code/diff_loss/mae/models/README.md: -------------------------------------------------------------------------------- 1 | # JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network 2 | 3 | The training model or pre-training model will be placed here. 4 | -------------------------------------------------------------------------------- /code/diff_loss/mse/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/.DS_Store -------------------------------------------------------------------------------- /code/diff_loss/mse/config/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/.DS_Store -------------------------------------------------------------------------------- /code/diff_loss/mse/config/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = 
torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 3 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | #print(img1.size()) 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | 
self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/clean.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | rm -rf __pycache__ 3 | rm \.*\.swp 4 | rm -R ../logdir/* 5 | rm -R ../showdir/* 6 | rm ../models/* 7 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/compile.py: -------------------------------------------------------------------------------- 1 | import re 2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1" 3 | msgidRegex = re.compile(r',(\d)+,') 4 | mo = msgidRegex.search(rec_data) 5 | print(mo.group()) -------------------------------------------------------------------------------- /code/diff_loss/mse/config/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from torch.utils.data import Dataset 6 | 7 | import settings 8 | 9 | 10 | class TrainValDataset(Dataset): 11 | def __init__(self, name): 12 | super().__init__() 13 | self.rand_state = RandomState(66) 14 | self.root_dir = os.path.join(settings.data_dir, name) 15 | self.mat_files = os.listdir(self.root_dir) 16 | self.patch_size = settings.patch_size 17 | self.file_num = len(self.mat_files) 18 | 19 | def __len__(self): 20 | return self.file_num * 100 21 | 22 | def __getitem__(self, idx): 23 | file_name = 
self.mat_files[idx % self.file_num] 24 | img_file = os.path.join(self.root_dir, file_name) 25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 26 | 27 | if settings.aug_data: 28 | O, B = self.crop(img_pair, aug=True) 29 | O, B = self.flip(O, B) 30 | O, B = self.rotate(O, B) 31 | else: 32 | O, B = self.crop(img_pair, aug=False) 33 | 34 | O = np.transpose(O, (2, 0, 1)) 35 | B = np.transpose(B, (2, 0, 1)) 36 | sample = {'O': O, 'B': B} 37 | 38 | return sample 39 | 40 | def crop(self, img_pair, aug): 41 | patch_size = self.patch_size 42 | h, ww, c = img_pair.shape 43 | w = int(ww / 2) 44 | 45 | if aug: 46 | mini = - 1 / 4 * self.patch_size 47 | maxi = 1 / 4 * self.patch_size + 1 48 | p_h = patch_size + self.rand_state.randint(mini, maxi) 49 | p_w = patch_size + self.rand_state.randint(mini, maxi) 50 | else: 51 | p_h, p_w = patch_size, patch_size 52 | 53 | r = self.rand_state.randint(0, h - p_h) 54 | c = self.rand_state.randint(0, w - p_w) 55 | O = img_pair[r: r+p_h, c+w: c+p_w+w] 56 | B = img_pair[r: r+p_h, c: c+p_w] 57 | 58 | if aug: 59 | O = cv2.resize(O, (patch_size, patch_size)) 60 | B = cv2.resize(B, (patch_size, patch_size)) 61 | 62 | return O, B 63 | 64 | def flip(self, O, B): 65 | if self.rand_state.rand() > 0.5: 66 | O = np.flip(O, axis=1) 67 | B = np.flip(B, axis=1) 68 | return O, B 69 | 70 | def rotate(self, O, B): 71 | angle = self.rand_state.randint(-30, 30) 72 | patch_size = self.patch_size 73 | center = (int(patch_size / 2), int(patch_size / 2)) 74 | M = cv2.getRotationMatrix2D(center, angle, 1) 75 | O = cv2.warpAffine(O, M, (patch_size, patch_size)) 76 | B = cv2.warpAffine(B, M, (patch_size, patch_size)) 77 | return O, B 78 | 79 | 80 | class TestDataset(Dataset): 81 | def __init__(self, name): 82 | super().__init__() 83 | self.rand_state = RandomState(66) 84 | self.root_dir = os.path.join(settings.data_dir, name) 85 | self.mat_files = os.listdir(self.root_dir) 86 | self.patch_size = settings.patch_size 87 | self.file_num = len(self.mat_files) 
88 | 89 | def __len__(self): 90 | return self.file_num 91 | 92 | def __getitem__(self, idx): 93 | file_name = self.mat_files[idx % self.file_num] 94 | img_file = os.path.join(self.root_dir, file_name) 95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 96 | h, ww, c = img_pair.shape 97 | w = int(ww / 2) 98 | #h_8=h%8 99 | #w_8=w%8 100 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 101 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 102 | sample = {'O': O, 'B': B} 103 | 104 | return sample 105 | 106 | 107 | class ShowDataset(Dataset): 108 | def __init__(self, name): 109 | super().__init__() 110 | self.rand_state = RandomState(66) 111 | self.root_dir = os.path.join(settings.data_dir, name) 112 | self.img_files = sorted(os.listdir(self.root_dir)) 113 | self.file_num = len(self.img_files) 114 | 115 | def __len__(self): 116 | return self.file_num 117 | 118 | def __getitem__(self, idx): 119 | file_name = self.img_files[idx % self.file_num] 120 | img_file = os.path.join(self.root_dir, file_name) 121 | print(img_file) 122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 123 | 124 | h, ww, c = img_pair.shape 125 | w = int(ww / 2) 126 | 127 | #h_8 = h % 8 128 | #w_8 = w % 8 129 | if settings.pic_is_pair: 130 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 131 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 132 | else: 133 | O = np.transpose(img_pair[:, :], (2, 0, 1)) 134 | B = np.transpose(img_pair[:, :], (2, 0, 1)) 135 | sample = {'O': O, 'B': B,'file_name':file_name[:-4]} 136 | 137 | return sample 138 | 139 | 140 | if __name__ == '__main__': 141 | dt = TrainValDataset('val') 142 | print('TrainValDataset') 143 | for i in range(10): 144 | smp = dt[i] 145 | for k, v in smp.items(): 146 | print(k, v.shape, v.dtype, v.mean()) 147 | 148 | print() 149 | dt = TestDataset('test') 150 | print('TestDataset') 151 | for i in range(10): 152 | smp = dt[i] 153 | for k, v in smp.items(): 154 | print(k, v.shape, v.dtype, v.mean()) 155 | 156 | print() 157 | 
print('ShowDataset') 158 | dt = ShowDataset('test') 159 | for i in range(10): 160 | smp = dt[i] 161 | for k, v in smp.items(): 162 | print(k, v.shape, v.dtype, v.mean()) 163 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | def PSNR(img1, img2): 32 | b,_,_,_=img1.shape 33 | #mse=0 34 | #for i in range(b): 35 | img1=np.clip(img1,0,255) 36 | img2=np.clip(img2,0,255) 37 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2)#+mse 38 | if mse == 0: 39 | return 100 40 | #mse=mse/b 41 | PIXEL_MAX = 1 42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | ensure_dir(settings.log_dir) 48 | ensure_dir(settings.model_dir) 49 | logger.info('set log dir as %s' % settings.log_dir) 50 | logger.info('set model dir as %s' % settings.model_dir) 51 | if len(settings.device_id) >1: 52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 53 | else: 54 | torch.cuda.set_device(settings.device_id[0]) 55 | self.net = ODE_DerainNet().cuda() 56 | self.l2 = MSELoss().cuda() 57 | self.l1 = nn.L1Loss().cuda() 58 | self.ssim = SSIM().cuda() 59 | self.dataloaders = {} 60 | 61 | def get_dataloader(self, dataset_name): 62 | dataset = TestDataset(dataset_name) 63 | if not dataset_name in self.dataloaders: 64 | self.dataloaders[dataset_name] = \ 65 | DataLoader(dataset, batch_size=1, 66 | shuffle=False, num_workers=1, drop_last=False) 67 | return self.dataloaders[dataset_name] 68 | 69 | def load_checkpoints(self, name): 70 | ckp_path = os.path.join(self.model_dir, name) 71 | try: 72 | obj = torch.load(ckp_path) 73 | logger.info('Load checkpoint %s' % ckp_path) 74 | except FileNotFoundError: 75 | logger.info('No checkpoint %s!!' 
% ckp_path) 76 | return 77 | self.net.load_state_dict(obj['net']) 78 | 79 | def inf_batch(self, name, batch): 80 | O, B = batch['O'].cuda(), batch['B'].cuda() 81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 82 | 83 | with torch.no_grad(): 84 | derain = self.net(O) 85 | 86 | l1_loss = self.l1(derain, B) 87 | ssim = self.ssim(derain, B) 88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 89 | losses = { 'L1 loss' : l1_loss } 90 | ssimes = { 'ssim' : ssim } 91 | losses.update(ssimes) 92 | 93 | return losses, psnr 94 | 95 | 96 | def run_test(ckp_name): 97 | sess = Session() 98 | sess.net.eval() 99 | sess.load_checkpoints(ckp_name) 100 | dt = sess.get_dataloader('test') 101 | psnr_all = 0 102 | all_num = 0 103 | all_losses = {} 104 | for i, batch in enumerate(dt): 105 | losses,psnr= sess.inf_batch('test', batch) 106 | psnr_all=psnr_all+psnr 107 | batch_size = batch['O'].size(0) 108 | all_num += batch_size 109 | for key, val in losses.items(): 110 | if i == 0: 111 | all_losses[key] = 0. 
112 | all_losses[key] += val * batch_size 113 | logger.info('batch %d mse %s: %f' % (i, key, val)) 114 | 115 | for key, val in all_losses.items(): 116 | logger.info('total mse %s: %f' % (key, val / all_num)) 117 | #psnr=sum(psnr_all) 118 | #print(psnr) 119 | print('psnr_ll:%8f'%(psnr_all/all_num)) 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('-m', '--model', default='latest_net') 124 | 125 | args = parser.parse_args(sys.argv[1:]) 126 | run_test(args.model) 127 | 128 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/__pycache__/functional.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/__pycache__/functional.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functional.py: -------------------------------------------------------------------------------- 1 | from . 
import functions 2 | 3 | 4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1] 6 | if input.is_cuda: 7 | if pad_mode == 0: 8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation) 9 | elif pad_mode == 1: 10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation) 11 | else: 12 | raise NotImplementedError 13 | return out 14 | 15 | 16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 17 | assert input.dim() == 4 and pad_mode in [0, 1] 18 | if input.is_cuda: 19 | if pad_mode == 0: 20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation) 21 | elif pad_mode == 1: 22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation) 23 | else: 24 | raise NotImplementedError 25 | return out 26 | 27 | 28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1] 30 | if input1.is_cuda: 31 | if pad_mode == 0: 32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation) 33 | elif pad_mode == 1: 34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation) 35 | else: 36 | raise NotImplementedError 37 | return out 38 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation_zeropad import * 2 | from .aggregation_refpad import * 3 | from .subtraction_zeropad import * 4 | from .subtraction_refpad import * 5 | from .subtraction2_zeropad import * 6 | from .subtraction2_refpad import * 7 | from .utils import * 8 
| -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/functions/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from string import Template 3 | import cupy 4 | import torch 5 | 6 | 7 | Stream = namedtuple('Stream', ['ptr']) 8 | 9 | 10 | def Dtype(t): 11 | if isinstance(t, torch.cuda.FloatTensor): 12 | return 'float' 13 | elif isinstance(t, torch.cuda.DoubleTensor): 14 | return 
'double' 15 | 16 | 17 | @cupy.util.memoize(for_each_device=True) 18 | def load_kernel(kernel_name, code, **kwargs): 19 | code = Template(code).substitute(**kwargs) 20 | kernel_code = cupy.cuda.compile_with_cache(code) 21 | return kernel_code.get_function(kernel_name) 22 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import * 2 | from .subtraction import * 3 | from .subtraction2 import * 4 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/__pycache__/aggregation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/modules/__pycache__/aggregation.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/__pycache__/subtraction.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/modules/__pycache__/subtraction.cpython-37.pyc -------------------------------------------------------------------------------- 
/code/diff_loss/mse/config/function/modules/__pycache__/subtraction2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/config/function/modules/__pycache__/subtraction2.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/aggregation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Aggregation(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Aggregation, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input, weight): 18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/subtraction.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. 
import functional as F 5 | 6 | 7 | class Subtraction(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input): 18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/subtraction2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Subtraction2(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction2, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input1, input2): 18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | ####################### 网络参数设置  ############################################################ 6 | channel = 32 7 | feature_map_num = 32 8 | res_conv_num = 4 9 | unit_num = 32 10 | #scale_num = 4 11 | num_scale_attention = 4 12 | scale_attention = False 13 | ssim_loss = True 14 | ######################################################################################## 15 | aug_data = False # Set as 
False for fair comparison 16 | 17 | patch_size = 64 18 | pic_is_pair = True #input picture is pair or single 19 | 20 | lr = 0.0005 21 | 22 | data_dir = '/data1/datasets/haze_blur_rain/rain100H' 23 | if pic_is_pair is False: 24 | data_dir = '/data1/wangcong/dataset/real-world-images' 25 | log_dir = '../logdir' 26 | show_dir = '../showdir' 27 | model_dir = '../models' 28 | show_dir_feature = '../showdir_feature' 29 | 30 | log_level = 'info' 31 | model_path = os.path.join(model_dir, 'latest_net') 32 | save_steps = 400 33 | 34 | num_workers = 8 35 | 36 | num_GPU = 2 37 | 38 | device_id = '1,2' 39 | 40 | epoch = 1000 41 | batch_size = 12 42 | 43 | if pic_is_pair: 44 | root_dir = os.path.join(data_dir, 'train') 45 | mat_files = os.listdir(root_dir) 46 | num_datasets = len(mat_files) 47 | l1 = int(3/5 * epoch * num_datasets / batch_size) 48 | l2 = int(4/5 * epoch * num_datasets / batch_size) 49 | one_epoch = int(num_datasets/batch_size) 50 | total_step = int((epoch * num_datasets)/batch_size) 51 | 52 | logger = logging.getLogger('train') 53 | logger.setLevel(logging.INFO) 54 | 55 | ch = logging.StreamHandler() 56 | ch.setLevel(logging.INFO) 57 | 58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 59 | ch.setFormatter(formatter) 60 | logger.addHandler(ch) 61 | 62 | 63 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/show.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import math 6 | import numpy as np 7 | import itertools 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import DataParallel 12 | from torch.optim import Adam 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | 16 | import settings 17 | from dataset import ShowDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | 
os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | 32 | def PSNR(img1, img2): 33 | b,_,_,_=img1.shape 34 | #mse=0 35 | #for i in range(b): 36 | img1=np.clip(img1,0,255) 37 | img2=np.clip(img2,0,255) 38 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse 39 | if mse == 0: 40 | return 100 41 | #mse=mse/b 42 | PIXEL_MAX = 1 43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 44 | 45 | class Session: 46 | def __init__(self): 47 | self.show_dir = settings.show_dir 48 | self.model_dir = settings.model_dir 49 | ensure_dir(settings.show_dir) 50 | ensure_dir(settings.model_dir) 51 | logger.info('set show dir as %s' % settings.show_dir) 52 | logger.info('set model dir as %s' % settings.model_dir) 53 | 54 | if len(settings.device_id) >1: 55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id) 57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id) 58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id) 59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id) 60 | else: 61 | torch.cuda.set_device(settings.device_id[0]) 62 | self.net = ODE_DerainNet().cuda() 63 | self.ssim = SSIM().cuda() 64 | self.dataloaders = {} 65 | self.ssim=SSIM().cuda() 66 | self.a=0 67 | self.t=0 68 | def get_dataloader(self, dataset_name): 69 | dataset = ShowDataset(dataset_name) 70 | self.dataloaders[dataset_name] = \ 71 | DataLoader(dataset, batch_size=1, 72 | shuffle=False, num_workers=1) 73 | return self.dataloaders[dataset_name] 74 | 75 | def load_checkpoints(self, name): 76 | ckp_path = os.path.join(self.model_dir, name) 77 | try: 78 | obj = torch.load(ckp_path) 79 | logger.info('Load checkpoint %s' % ckp_path) 80 | except FileNotFoundError: 81 | logger.info('No 
checkpoint %s!!' % ckp_path) 82 | return 83 | self.net.load_state_dict(obj['net']) 84 | def inf_batch(self, name, batch,i): 85 | O, B, file_name= batch['O'].cuda(), batch['B'].cuda(), batch['file_name'] 86 | file_name = str(file_name[0]) 87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 88 | with torch.no_grad(): 89 | import time 90 | t0=time.time() 91 | derain = self.net(O) 92 | t1 = time.time() 93 | comput_time=t1-t0 94 | print(comput_time) 95 | ssim = self.ssim(derain, B).data.cpu().numpy() 96 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 97 | print('psnr:%4f-------------ssim:%4f'%(psnr, ssim)) 98 | return derain, psnr, ssim, file_name 99 | 100 | def save_image(self, No, imgs, name, psnr, ssim, file_name): 101 | for i, img in enumerate(imgs): 102 | img = (img.cpu().data * 255).numpy() 103 | img = np.clip(img, 0, 255) 104 | img = np.transpose(img, (1, 2, 0)) 105 | h, w, c = img.shape 106 | 107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name)) 108 | print(img_file) 109 | cv2.imwrite(img_file, img) 110 | 111 | 112 | def run_show(ckp_name): 113 | sess = Session() 114 | sess.load_checkpoints(ckp_name) 115 | sess.net.eval() 116 | dataset = 'test' 117 | if settings.pic_is_pair is False: 118 | dataset = 'train-small' 119 | dt = sess.get_dataloader(dataset) 120 | 121 | for i, batch in enumerate(dt): 122 | logger.info(i) 123 | if i>-1: 124 | imgs,psnr,ssim, file_name= sess.inf_batch('test', batch,i) 125 | sess.save_image(i, imgs, dataset, psnr, ssim, file_name) 126 | 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('-m', '--model', default='latest_net') 132 | 133 | args = parser.parse_args(sys.argv[1:]) 134 | 135 | run_show(args.model) 136 | 137 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/tensorboard.sh: 
-------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | function rand() { 3 | min=$1 4 | max=$(($2-$min+1)) 5 | num=$(($RANDOM+1000000000000)) 6 | echo $(($num%$max+$min)) 7 | } 8 | 9 | rnd=$(rand 3000 12000) 10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3 11 | -------------------------------------------------------------------------------- /code/diff_loss/mse/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/code/diff_loss/mse/models/.DS_Store -------------------------------------------------------------------------------- /code/diff_loss/mse/models/README.md: -------------------------------------------------------------------------------- 1 | # JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network 2 | 3 | The training model or pre-training model will be placed here. 
4 | -------------------------------------------------------------------------------- /fig/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/fig/.DS_Store -------------------------------------------------------------------------------- /fig/ex_pair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/fig/ex_pair.png -------------------------------------------------------------------------------- /fig/ex_unpair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/fig/ex_unpair.png -------------------------------------------------------------------------------- /fig/overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/fig/overall.png -------------------------------------------------------------------------------- /fig/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/c1cc5c8f6b19d52feae7a60e8ecdf96107526ff5/fig/result.png --------------------------------------------------------------------------------