├── Environment.py ├── LICENSE ├── Loader_batch.py ├── README.md └── PPO_batch.py /Environment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | class Cusenv(): 5 | 6 | def __init__(self): 7 | super(Cusenv,self).__init__() 8 | self.img = [] 9 | self.prev = [] 10 | 11 | def reset(self, raw_x , raw_n):#This reset() is specific for denoising 12 | 13 | self.img = raw_x + raw_n 14 | self.prev = raw_x + raw_n 15 | self.ground_truth = raw_x 16 | 17 | return self.img 18 | 19 | def step(self, action, t): #This step() is specific for denoising 20 | move = action.astype(np.float32) 21 | move = (move[:,np.newaxis,:,:]-13.0)/255 22 | self.img = self.img+move 23 | r = 255*np.square(self.prev - self.ground_truth)-255*np.square(self.img-self.ground_truth) # Reward 24 | s_prime = self.img.copy() 25 | self.prev = self.img.copy() 26 | 27 | if t>4: 28 | done = True 29 | else: 30 | done = False 31 | 32 | return s_prime, r, done 33 | 34 | 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 RongkaiZhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Loader_batch.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | 5 | def load_images_from_folder(folder): 6 | images = [] 7 | for filename in os.listdir(folder): 8 | img = cv2.imread(os.path.join(folder,filename),0) 9 | if img is not None: 10 | images.append(img) 11 | 12 | return images 13 | 14 | def data_crop(img,crop_size): 15 | ims = [] 16 | for im in img: 17 | h, w = im.shape 18 | rand_range_h = h-crop_size 19 | rand_range_w = w-crop_size 20 | x_offset = np.random.randint(rand_range_w) 21 | y_offset = np.random.randint(rand_range_h) 22 | im = im[y_offset:y_offset+crop_size, x_offset:x_offset+crop_size] 23 | ims.append(im) 24 | 25 | return ims 26 | 27 | def data_augment(img): 28 | ims = [] 29 | for im in img: 30 | h, w = im.shape 31 | if np.random.rand() > 0.5: 32 | im = np.fliplr(im) 33 | if np.random.rand() > 0.5: 34 | angle = 10*np.random.rand() 35 | if np.random.rand() > 0.5: 36 | angle *= -1 37 | M = cv2.getRotationMatrix2D((w/2,h/2),angle,1) 38 | im = cv2.warpAffine(im,M,(w,h)) 39 | ims.append(im) 40 | 41 | return ims 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep-Reinforcement-Learning-for-Image-Denoising-via-Residual-Recovery (R3L) in pytorch. 2 | 1. This is a simple pytorch implementation of DRL (PPO is used) for image denoising via residual recovery. 3 | 2. Detailed illustration can be found in our paper [R3L: Connecting Deep Reinforcement Learning To Recurrent Neural Networks For Image Denoising Via Residual Recovery](https://arxiv.org/abs/2107.05318) (accepted by ICIP 2021). 4 | 3. Although this project is for a specific task, this framework is designed ASAP (as simple as possible) to be applied for different tasks trained in "Batch Environment" (Batch * Channel * Height * Width) by slightly modifing the corresponding network and envrionment. 5 | # Introduction to this implementation: 6 | 1. Current implementations of PPO usually focus on environments with states in shape of (Height * Width) raising a gap for implementations in CV where (Channel * Height * Width) is needed. 7 | 2. This implementation aims for an easy-to-modify PPO framework for CV tasks. 8 | 3. The PPO used here is [PPO-clip](https://spinningup.openai.com/en/latest/algorithms/ppo.html). 9 | # How to apply to other tasks: 10 | 1. Customize the environment by setting task specific reset(), step() in environment.py. 11 | 2. Customize the data file paths in PPO_batch.py. 12 | 3. Customize data argumentation in Load_batch.py. 13 | # Dependance: 14 | 1. pytorch >= 1.6 15 | 2. opencv 16 | # Citation 17 | 18 | In case of use, please cite our publication: 19 | 20 | Rongkai Zhang, Jiang Zhu, Zhiyuan Zha, Justin Dauwels, Bihan Wen, "R3L: Connecting Deep Reinforcement Learning to Recurrent Neural Networks for Image Denoising via Residual Recovery," ICIP 2021. 21 | 22 | Bibtex: 23 | ``` 24 | @inproceedings{zhang2021r3l, 25 | title={R3L: Connecting Deep Reinforcement Learning to Recurrent Neural Networks for Image Denoising via Residual Recovery}, 26 | author={Zhang, Rongkai and Zhu, Jiang and Zha, Zhiyuan and Dauwels, Justin and Wen, Bihan}, 27 | booktitle={2021 IEEE International Conference on Image Processing (ICIP)}, 28 | pages={1624--1628}, 29 | year={2021}, 30 | organization={IEEE} 31 | } 32 | ``` 33 | -------------------------------------------------------------------------------- /PPO_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torch.distributions import Categorical 6 | from Loader_batch import * 7 | import cv2 8 | from Environment import Cusenv 9 | 10 | #Hyperparameters for PPO 11 | gamma = 0.95 12 | lmbda = 1 13 | eps_clip = 0.5 14 | K_epoch = 15 # Minibatch_size in PPO 15 | Train_Folder = 'BSD68/gray/train/' 16 | Test_Folder = 'BSD68/gray/test/' 17 | Crop_Size = 70 18 | 19 | class PPO(nn.Module): 20 | def __init__(self, device, batch_size): 21 | super(PPO, self).__init__() 22 | self.data = [] # data buffer is built in PPO class 23 | self.batch_size = batch_size # input data is [batch_size, channel, H, W] 24 | #Load DnCNN weights and bias for initialization 25 | net = torch.load('dncnn_25.pth') 26 | #The layers to build network (FCN is used here and can be replaced for different tasks) 27 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, dilation=1, bias=True, padding_mode='zeros') 28 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=2, dilation=2, bias=True, padding_mode='zeros') 29 | self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=3, dilation=3, bias=True, padding_mode='zeros') 30 | self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=4, dilation=4, bias=True, padding_mode='zeros') 31 | self.conv5_pi = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=3, dilation=3, bias=True, padding_mode='zeros') 32 | self.conv6_pi = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=2, dilation=2, bias=True, padding_mode='zeros') 33 | self.conv7_pi = nn.Conv2d(in_channels=64, out_channels=27, kernel_size=3, stride=1, padding=1, dilation=1, bias=True, padding_mode='zeros') 34 | self.conv5_v = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=3, dilation=3, bias=True, padding_mode='zeros') 35 | self.conv6_v = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=2, dilation=2, bias=True, padding_mode='zeros') 36 | self.conv7_v = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, dilation=1, bias=True, padding_mode='zeros') 37 | self.relu = nn.ReLU(inplace= True) 38 | #Initialize the weights and bias 39 | nn.init.kaiming_normal_(self.conv1.weight) 40 | self.conv2.weight.data = net['model.6.weight'] 41 | self.conv2.bias.data = net['model.6.bias'] 42 | self.conv3.weight.data = net['model.12.weight'] 43 | self.conv3.bias.data = net['model.12.bias'] 44 | self.conv4.weight.data = net['model.18.weight'] 45 | self.conv4.bias.data = net['model.18.bias'] 46 | self.conv5_pi.weight.data = net['model.24.weight'] 47 | self.conv5_pi.bias.data = net['model.24.bias'] 48 | self.conv6_pi.weight.data = net['model.30.weight'] 49 | self.conv6_pi.bias.data = net['model.30.bias'] 50 | nn.init.kaiming_normal_(self.conv7_pi.weight) 51 | self.conv5_v.weight.data = net['model.24.weight'] 52 | self.conv5_v.bias.data = net['model.24.bias'] 53 | self.conv6_v.weight.data = net['model.30.weight'] 54 | self.conv6_v.bias.data = net['model.30.bias'] 55 | nn.init.kaiming_normal_(self.conv7_v.weight) 56 | self.optimizer = optim.Adam(self.parameters(), lr=1e-3) 57 | self.device = device 58 | self.train = True 59 | 60 | def pi(self, x, softmax_dim = 1): 61 | x = self.relu(self.conv1(x)) 62 | x = self.relu(self.conv2(x)) 63 | x = self.relu(self.conv3(x)) 64 | x = self.relu(self.conv4(x)) 65 | x = self.relu(self.conv5_pi(x)) 66 | x = self.relu(self.conv6_pi(x)) 67 | x = self.conv7_pi(x) 68 | prob = F.softmax(x, dim=softmax_dim) 69 | return prob 70 | 71 | def v(self, x): 72 | x = self.relu(self.conv1(x)) 73 | x = self.relu(self.conv2(x)) 74 | x = self.relu(self.conv3(x)) 75 | x = self.relu(self.conv4(x)) 76 | x = self.relu(self.conv5_v(x)) 77 | x = self.relu(self.conv6_v(x)) 78 | v = self.conv7_v(x) 79 | return v 80 | 81 | def put_data(self, transition): 82 | self.data.append(transition) 83 | 84 | def make_batch(self):# Convert list of transactions into tensors 85 | s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, done_lst= [], [], [], [], [], [] 86 | for transition in self.data: 87 | s, a, r, s_prime, prob_a, done = transition 88 | 89 | s_lst.extend(s) 90 | a_lst.extend(a) 91 | r_lst.extend(r) 92 | s_prime_lst.extend(s_prime) 93 | prob_a_lst.extend(prob_a) 94 | done_mask = 0 if done else 1 95 | done_lst.extend([done_mask]) 96 | 97 | 98 | s, a, r, s_prime, prob_a, done_mask = torch.tensor(s_lst, dtype=torch.float).to(self.device), torch.tensor(a_lst).to(self.device), \ 99 | torch.tensor(r_lst, dtype=torch.float).to(self.device), torch.tensor(s_prime_lst, dtype=torch.float).to(self.device), \ 100 | torch.tensor(prob_a_lst).to(self.device), torch.tensor(done_lst, dtype=torch.float).to(self.device) 101 | self.data = [] 102 | a = a.unsqueeze(1) 103 | prob_a = prob_a.unsqueeze(1) 104 | #s = s.squeeze(1) 105 | #r = r.squeeze(1) 106 | #s_prime = s_prime.squeeze(1) 107 | 108 | 109 | return s, a, r, s_prime, prob_a, done_mask.repeat_interleave(self.batch_size) 110 | 111 | def train_net(self):# Train the network 112 | s, a, r, s_prime, prob_a, done_mask = self.make_batch() 113 | 114 | for i in range(K_epoch): 115 | td_target = r + gamma * torch.mul(self.v(s_prime), done_mask.view(r.shape[0], 1, 1, 1)) 116 | delta = td_target - self.v(s) 117 | delta = delta.detach().cpu().numpy() 118 | 119 | advantage_lst = [] 120 | advantage = 0.0 121 | # reorder the transactions by batch 122 | for time in range(int(len(delta) / self.batch_size) - 1, -1, -1): 123 | advantage = gamma * lmbda * advantage + delta[time*self.batch_size:(time+1)*self.batch_size] 124 | for adv in advantage[::-1]: 125 | advantage_lst.insert(0,adv) 126 | advantage = torch.tensor(advantage_lst, dtype=torch.float).to(self.device) 127 | 128 | pi = self.pi(s) 129 | #If use entropy in loss 130 | #pi = pi+1e-13 #1e-13 for stable training 131 | #entropy = (- pi*torch.log(pi)).sum(1).mean() 132 | pi_a = pi.gather(1, a) 133 | ratio = (pi_a / prob_a) # a/b == exp(log(a)-log(b)) 134 | surr1 = ratio * advantage 135 | surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage 136 | loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s), td_target.detach()) #-0.01*entropy if use entropy 137 | #loss = -torch.min(torch.mean(surr1), torch.mean(surr2)) + F.smooth_l1_loss(self.v(s), td_target.detach()) #-0.01*entropy if use entropy 138 | loss = loss.mean() 139 | self.optimizer.zero_grad() 140 | loss.mean().backward() 141 | self.optimizer.step() 142 | #Train and validate in main() 143 | def main(): 144 | #Define hyperparameters for training 145 | device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") 146 | batch_size = 32 147 | print_interval = 1 148 | test_interval = 1000 149 | learning_rate = 0.001 150 | save_interval = 3000 151 | #Setup env. Cusenv() is a customized enironemnt which takes in the current state S and outputs the new state S_prime, Reward and Done (for episodic task). 152 | env = Cusenv() 153 | #Use PPO model. 154 | model = PPO(device,batch_size).to(device) 155 | #Load all data from files directly 156 | training_data = load_images_from_folder(Train_Folder) 157 | train_data_size = len(training_data) 158 | test_data = load_images_from_folder(Test_Folder) 159 | #Start to train 30001 episodes 160 | i = 0 161 | for n_epi in range(30001):# 162 | score = 0.0 163 | #Data preprocessing (augmentationm, crop, resize to (B,C,H,W), normalize) 164 | data = training_data[i:i + batch_size] 165 | data = data_augment(data) 166 | data = data_crop(data, Crop_Size) 167 | data = np.array(data) 168 | data = data[:, np.newaxis, :, :] 169 | raw_x = data / 255 170 | #Generate noise 171 | raw_n = np.random.normal(0, 25, raw_x.shape).astype(raw_x.dtype) / 255 172 | #Reset env (here is just add noise to image) 173 | s = env.reset(raw_x, raw_n) 174 | done = False 175 | t_info = 0 176 | #Rollout 177 | while not done: 178 | prob = model.pi(torch.from_numpy(s).float().to(device)).detach().cpu() 179 | prob = prob.permute(0,2,3,1) 180 | m = Categorical(prob) 181 | a = m.sample() 182 | prob_a = torch.exp(m.log_prob(a)).clone().numpy() 183 | action = a.numpy() 184 | s_prime, r, done = env.step(action, t_info) 185 | #if done: 186 | #print('Epoch', n_epi,'Image',n, 'process', t_info, 'steps') 187 | t_info += 1 188 | model.put_data((s, action, r, s_prime, prob_a, done)) #All (s, action, r, s_prime, prob_a, done) are saved as numpy for convenience 189 | s = s_prime 190 | score += r 191 | #Train model after one rollout with batch_size images. 192 | model.train_net() 193 | 194 | #Show score (for image denoising with sigma = 25, the max score (when all noise is removed) should be ) 195 | if n_epi%print_interval==0 and n_epi!=0: 196 | print("# of episode :{}, avg score : {:.3f}".format(n_epi, np.mean(score)*255)) 197 | score = 0.0 198 | 199 | #Validate the model 200 | if n_epi%test_interval==0 and n_epi!=0: 201 | test_result = 0 202 | test_id = 0 203 | img_id = 0 204 | input_psnr = 0 205 | for im in test_data: 206 | data = im[np.newaxis, np.newaxis, :, :] 207 | raw_x = data / 255 208 | raw_n = np.random.normal(0, 25, raw_x.shape).astype(raw_x.dtype) / 255 209 | s = env.reset(raw_x, raw_n) 210 | I = np.maximum(0, raw_x) 211 | I = np.minimum(1, I) 212 | N = np.maximum(0, raw_x + raw_n) 213 | N = np.minimum(1, N) 214 | I = (I[0] * 255+0.5).astype(np.uint8) 215 | N = (N[0] * 255+0.5).astype(np.uint8) 216 | I = np.transpose(I, (1, 2, 0)) 217 | N = np.transpose(N, (1, 2, 0)) 218 | cv2.imwrite('result_noentropy/' + str(test_id) + '_input.png', N) 219 | psnr1_cv = cv2.PSNR(N, I) 220 | #psnr1 = np.mean(10 * np.log10(1 / (s - env.ground_truth) ** 2)) 221 | #for t in range(T_horizon): 222 | done = False 223 | t = 0 224 | while not done: 225 | prob = model.pi(torch.from_numpy(s).float().to(device)) 226 | #prob = prob.permute(0,2,3,1).detach().cpu() 227 | #m = Categorical(prob) 228 | _, a = torch.max(prob, 1) 229 | action = a.cpu().numpy() 230 | s_prime, r, done = env.step(action,t) 231 | if done: 232 | print('test image', img_id, 'process', t, 'steps') 233 | s = s_prime 234 | t += 1 235 | img_id += 1 236 | p = np.maximum(0, s) 237 | p = np.minimum(1, p) 238 | p = (p[0] * 255+0.5).astype(np.uint8) 239 | p = np.transpose(p, (1, 2, 0)) 240 | #p = cv2.blur(p,(3,3)) 241 | cv2.imwrite('result_noentropy/' + str(test_id) + '_output.png', p) 242 | #psnr2 = np.mean(10*np.log10(1/(s-env.ground_truth)**2)) 243 | psnr2_cv = cv2.PSNR(p, I) 244 | test_result += psnr2_cv 245 | print('test: PSNR_CV before:',psnr1_cv, 'PSNR_CV after:', psnr2_cv) 246 | test_id +=1 247 | print('Overall performance:', test_result/len(test_data)) 248 | 249 | 250 | if n_epi%save_interval==0 and n_epi!=0: 251 | torch.save(model,'PPO_model_{}.pt'.format(n_epi)) 252 | score = 0.0 253 | 254 | if i + batch_size >= train_data_size: 255 | i = 0 256 | 257 | else: 258 | i += batch_size 259 | 260 | if i + 2 * batch_size >= train_data_size: 261 | i = train_data_size - batch_size 262 | 263 | model.optimizer = optim.Adam(model.parameters(), lr=learning_rate*((1-n_epi/30001)**0.9)) #Lr decay for stable training 264 | 265 | if __name__ == '__main__': 266 | main() 267 | --------------------------------------------------------------------------------