├── CNNs ├── logger.py ├── msssim.py ├── mymodel.py ├── myutils.py ├── test.py └── train.py ├── GAN ├── GAN_model.py ├── loss.py ├── myutils.py └── train_GAN.py ├── Multi-Scale ├── MS_model.py ├── test_MS.py ├── test_MS_single.py └── train_multi_scale.py ├── README.md └── model_parameters ├── ARCNN_sao-qp42-90-0.001434-28.4783-0.8401param.pth ├── ARDenseNet_qp42-28-0.001368-28.6754-0.8436param.pth ├── BEST-MS21-0.001316-29.0139-0.8454param.pth ├── L1-164-0.024904-28.6810-0.8440param.pth ├── L8_qp42-99-0.001374-28.6387-0.8429param.pth ├── MSqp42-124-0.001293-29.0237-0.8455param.pth ├── edar2_qp42-63-0.001339-28.6824-0.8445param.pth └── vdar_qp42-51-0.001339-28.6784-0.8443param.pth /CNNs/logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | import tensorflow as tf 3 | import numpy as np 4 | import scipy.misc 5 | try: 6 | from StringIO import StringIO # Python 2.7 7 | except ImportError: 8 | from io import BytesIO # Python 3.x 9 | 10 | 11 | class Logger(object): 12 | 13 | def __init__(self, log_dir): 14 | """Create a summary writer logging to log_dir.""" 15 | self.writer = tf.summary.FileWriter(log_dir) 16 | 17 | def scalar_summary(self, tag, value, step): 18 | """Log a scalar variable.""" 19 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 20 | self.writer.add_summary(summary, step) 21 | 22 | def image_summary(self, tag, images, step): 23 | """Log a list of images.""" 24 | 25 | img_summaries = [] 26 | for i, img in enumerate(images): 27 | # Write the image to a string 28 | try: 29 | s = StringIO() 30 | except: 31 | s = BytesIO() 32 | scipy.misc.toimage(img).save(s, format="png") 33 | 34 | # Create an Image object 35 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 36 | height=img.shape[0], 37 | width=img.shape[1]) 38 | # Create a Summary value 39 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 40 | 41 | # Create and write Summary 42 | summary = tf.Summary(value=img_summaries) 43 | self.writer.add_summary(summary, step) 44 | 45 | def histo_summary(self, tag, values, step, bins=1000): 46 | """Log a histogram of the tensor of values.""" 47 | 48 | # Create a histogram using numpy 49 | counts, bin_edges = np.histogram(values, bins=bins) 50 | 51 | # Fill the fields of the histogram proto 52 | hist = tf.HistogramProto() 53 | hist.min = float(np.min(values)) 54 | hist.max = float(np.max(values)) 55 | hist.num = int(np.prod(values.shape)) 56 | hist.sum = float(np.sum(values)) 57 | hist.sum_squares = float(np.sum(values**2)) 58 | 59 | # Drop the start of the first bin 60 | bin_edges = bin_edges[1:] 61 | 62 | # Add bin edges and counts 63 | for edge in bin_edges: 64 | hist.bucket_limit.append(edge) 65 | for c in counts: 66 | hist.bucket.append(c) 67 | 68 | # Create and write Summary 69 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 70 | self.writer.add_summary(summary, step) 71 | self.writer.flush() 72 | 73 | -------------------------------------------------------------------------------- /CNNs/msssim.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | 18 | """Python implementation of MS-SSIM. 19 | 20 | Usage: 21 | 22 | python msssim.py --original_image=original.png --compared_image=distorted.png 23 | """ 24 | import numpy as np 25 | from scipy import signal 26 | from scipy.ndimage.filters import convolve 27 | import tensorflow as tf 28 | 29 | 30 | tf.flags.DEFINE_string('original_image', None, 'Path to PNG image.') 31 | tf.flags.DEFINE_string('compared_image', None, 'Path to PNG image.') 32 | tf.flags.DEFINE_string('msssim_write_path', None, 'Path to write MS-SSIM.') 33 | FLAGS = tf.flags.FLAGS 34 | 35 | 36 | def _FSpecialGauss(size, sigma): 37 | """Function to mimic the 'fspecial' gaussian MATLAB function.""" 38 | radius = size // 2 39 | offset = 0.0 40 | start, stop = -radius, radius + 1 41 | if size % 2 == 0: 42 | offset = 0.5 43 | stop -= 1 44 | x, y = np.mgrid[offset + start:stop, offset + start:stop] 45 | assert len(x) == size 46 | g = np.exp(-((x**2 + y**2)/(2.0 * sigma**2))) 47 | return g / g.sum() 48 | 49 | 50 | def _SSIMForMultiScale(img1, img2, max_val=255, filter_size=11, 51 | filter_sigma=1.5, k1=0.01, k2=0.03): 52 | """Return the Structural Similarity Map between `img1` and `img2`. 53 | 54 | This function attempts to match the functionality of ssim_index_new.m by 55 | Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 56 | 57 | Arguments: 58 | img1: Numpy array holding the first RGB image batch. 59 | img2: Numpy array holding the second RGB image batch. 60 | max_val: the dynamic range of the images (i.e., the difference between the 61 | maximum the and minimum allowed values). 62 | filter_size: Size of blur kernel to use (will be reduced for small images). 63 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 64 | for small images). 65 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 66 | the original paper). 67 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 68 | the original paper). 69 | 70 | Returns: 71 | Pair containing the mean SSIM and contrast sensitivity between `img1` and 72 | `img2`. 73 | 74 | Raises: 75 | RuntimeError: If input images don't have the same shape or don't have four 76 | dimensions: [batch_size, height, width, depth]. 77 | """ 78 | if img1.shape != img2.shape: 79 | raise RuntimeError('Input images must have the same shape (%s vs. %s).', 80 | img1.shape, img2.shape) 81 | if img1.ndim != 4: 82 | raise RuntimeError('Input images must have four dimensions, not %d', 83 | img1.ndim) 84 | 85 | img1 = img1.astype(np.float64) 86 | img2 = img2.astype(np.float64) 87 | _, height, width, _ = img1.shape 88 | 89 | # Filter size can't be larger than height or width of images. 90 | size = min(filter_size, height, width) 91 | 92 | # Scale down sigma if a smaller filter size is used. 93 | sigma = size * filter_sigma / filter_size if filter_size else 0 94 | 95 | if filter_size: 96 | window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1)) 97 | mu1 = signal.fftconvolve(img1, window, mode='valid') 98 | mu2 = signal.fftconvolve(img2, window, mode='valid') 99 | sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid') 100 | sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid') 101 | sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') 102 | else: 103 | # Empty blur kernel so no need to convolve. 104 | mu1, mu2 = img1, img2 105 | sigma11 = img1 * img1 106 | sigma22 = img2 * img2 107 | sigma12 = img1 * img2 108 | 109 | mu11 = mu1 * mu1 110 | mu22 = mu2 * mu2 111 | mu12 = mu1 * mu2 112 | sigma11 -= mu11 113 | sigma22 -= mu22 114 | sigma12 -= mu12 115 | 116 | # Calculate intermediate values used by both ssim and cs_map. 117 | c1 = (k1 * max_val) ** 2 118 | c2 = (k2 * max_val) ** 2 119 | v1 = 2.0 * sigma12 + c2 120 | v2 = sigma11 + sigma22 + c2 121 | ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2))) 122 | cs = np.mean(v1 / v2) 123 | return ssim, cs 124 | 125 | 126 | def MultiScaleSSIM(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5, 127 | k1=0.01, k2=0.03, weights=None): 128 | """Return the MS-SSIM score between `img1` and `img2`. 129 | 130 | This function implements Multi-Scale Structural Similarity (MS-SSIM) Image 131 | Quality Assessment according to Zhou Wang's paper, "Multi-scale structural 132 | similarity for image quality assessment" (2003). 133 | Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf 134 | 135 | Author's MATLAB implementation: 136 | http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 137 | 138 | Arguments: 139 | img1: Numpy array holding the first RGB image batch. 140 | img2: Numpy array holding the second RGB image batch. 141 | max_val: the dynamic range of the images (i.e., the difference between the 142 | maximum the and minimum allowed values). 143 | filter_size: Size of blur kernel to use (will be reduced for small images). 144 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 145 | for small images). 146 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 147 | the original paper). 148 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 149 | the original paper). 150 | weights: List of weights for each level; if none, use five levels and the 151 | weights from the original paper. 152 | 153 | Returns: 154 | MS-SSIM score between `img1` and `img2`. 155 | 156 | Raises: 157 | RuntimeError: If input images don't have the same shape or don't have four 158 | dimensions: [batch_size, height, width, depth]. 159 | """ 160 | if img1.shape != img2.shape: 161 | raise RuntimeError('Input images must have the same shape (%s vs. %s).', 162 | img1.shape, img2.shape) 163 | if img1.ndim != 4: 164 | raise RuntimeError('Input images must have four dimensions, not %d', 165 | img1.ndim) 166 | 167 | # Note: default weights don't sum to 1.0 but do match the paper / matlab code. 168 | weights = np.array(weights if weights else 169 | [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) 170 | levels = weights.size 171 | downsample_filter = np.ones((1, 2, 2, 1)) / 4.0 172 | im1, im2 = [x.astype(np.float64) for x in [img1, img2]] 173 | mssim = np.array([]) 174 | mcs = np.array([]) 175 | for _ in range(levels): 176 | ssim, cs = _SSIMForMultiScale( 177 | im1, im2, max_val=max_val, filter_size=filter_size, 178 | filter_sigma=filter_sigma, k1=k1, k2=k2) 179 | mssim = np.append(mssim, ssim) 180 | mcs = np.append(mcs, cs) 181 | filtered = [convolve(im, downsample_filter, mode='reflect') 182 | for im in [im1, im2]] 183 | im1, im2 = [x[:, ::2, ::2, :] for x in filtered] 184 | return (np.prod(mcs[0:levels-1] ** weights[0:levels-1]) * 185 | (mssim[levels-1] ** weights[levels-1])) 186 | -------------------------------------------------------------------------------- /CNNs/mymodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import namedtuple 5 | from torchvision import models 6 | import math 7 | from torch.autograd import Variable 8 | 9 | 10 | 11 | #model_1 12 | class ARCNN(nn.Module): 13 | 14 | def __init__(self): 15 | super(ARCNN, self).__init__() 16 | #feature extract 17 | self.conv1=nn.Conv2d(3,64,9,1,(9-1)/2) 18 | #nolinear mapping 19 | self.conv2=nn.Conv2d(64,32,7,1,(7-1)/2) 20 | self.conv3=nn.Conv2d(32,16,1,1,0) 21 | #reconstruct 22 | self.conv4=nn.Conv2d(16,3,5,1,(5-1)/2) 23 | #self.relu=nn.ReLU(inplace=False)#not sure 24 | 25 | for m in self.modules(): 26 | if isinstance(m, nn.Conv2d): 27 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 28 | m.weight.data.normal_(0, math.sqrt(2. / n)) 29 | if m.bias is not None: 30 | m.bias.data.zero_() 31 | 32 | def forward(self,x):# 33 | x=F.relu(self.conv1(x)) 34 | x=F.relu(self.conv2(x)) 35 | x=F.relu(self.conv3(x)) 36 | x=F.relu(self.conv4(x)) 37 | return x 38 | 39 | 40 | #model_2 41 | class L8(nn.Module): 42 | 43 | def __init__(self): 44 | super(L8, self).__init__() 45 | self.conv1=nn.Conv2d( 46 | in_channels=3, 47 | out_channels=32, 48 | kernel_size=11, 49 | stride=1, 50 | padding=5# if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1 51 | ) 52 | self.conv2=nn.Conv2d(32,64,3,1,1) 53 | self.conv3=nn.Conv2d(64,64,3,1,1) 54 | self.conv4=nn.Conv2d(64,64,3,1,1) 55 | self.conv5=nn.Conv2d(32+64,64,1,1,0) 56 | self.conv6=nn.Conv2d(64,64,5,1,2) 57 | self.conv7=nn.Conv2d(64+32,128,1,1,0) 58 | self.conv8=nn.Conv2d(128,3,5,1,2) 59 | #self.relu=nn.ReLU(inplace=False) 60 | 61 | for m in self.modules(): 62 | if isinstance(m, nn.Conv2d): 63 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 64 | m.weight.data.normal_(0, math.sqrt(2. / n)) 65 | if m.bias is not None: 66 | m.bias.data.zero_() 67 | 68 | def forward(self,x):#not sure what is x? 69 | x1=F.relu(self.conv1(x)) 70 | x=F.relu(self.conv2(x1)) 71 | x=F.relu(self.conv3(x)) 72 | x=F.relu(self.conv4(x)) 73 | x=torch.cat([x, x1], 1)#the dimensionality of Variable is [number,channel,height,width] 74 | x=F.relu(self.conv5(x)) 75 | x=F.relu(self.conv6(x)) 76 | x=torch.cat([x, x1], 1) 77 | x=F.relu(self.conv7(x)) 78 | x=F.relu(self.conv8(x)) 79 | return x 80 | 81 | 82 | #model_3:refer to vdsr 83 | class Conv_ReLU_Block(nn.Module): 84 | def __init__(self): 85 | super(Conv_ReLU_Block, self).__init__() 86 | self.conv = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 87 | self.relu = nn.ReLU(inplace=True) 88 | 89 | def forward(self, x): 90 | return self.relu(self.conv(x)) 91 | 92 | class vdar(nn.Module): 93 | def __init__(self): 94 | super(vdar, self).__init__() 95 | self.residual_layer = self.make_layer(Conv_ReLU_Block, 18) 96 | self.input = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 97 | self.output = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False) 98 | self.relu = nn.ReLU(inplace=True) 99 | 100 | #weights initialization by normal(Gaussinn) distribution:normal_(mean=0, std=1 , gengerator=None*) 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d): 103 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 104 | m.weight.data.normal_(0, math.sqrt(2. / n)) 105 | 106 | def make_layer(self, block, num_of_layer): 107 | layers = [] 108 | for _ in range(num_of_layer): 109 | layers.append(block()) 110 | return nn.Sequential(*layers) 111 | 112 | def forward(self, x): 113 | residual = x 114 | out = self.relu(self.input(x)) 115 | out = self.residual_layer(out) 116 | out = self.output(out) 117 | out = torch.add(out,residual)#global residual 118 | return out 119 | 120 | 121 | 122 | class ResidualBlock(nn.Module): 123 | def __init__(self, channels=64): 124 | super(ResidualBlock, self).__init__() 125 | self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) 126 | self.relu = nn.ReLU(inplace=True)#attrntion:relu not prelu 127 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) 128 | 129 | def forward(self, x): 130 | residual = self.conv1(x) 131 | residual = self.relu(residual) 132 | residual = self.conv2(residual)#the original model multipy a value 0.1 here,which is to be analyzed in the feature 133 | out = torch.add(x,residual) 134 | 135 | return out 136 | 137 | 138 | #model_4_1 refer to edsr 139 | class edar(nn.Module): 140 | def __init__(self): 141 | super(edar, self).__init__() 142 | self.head= nn.Conv2d(3,64,3,1,1) 143 | #self.body = nn.ModuleList([ResidualBlock(64) for i in range(8)]) 144 | self.body = self.make_layer(ResidualBlock, 8) 145 | self.tail=nn.Conv2d(64,64,3,1,1) 146 | self.reconstruct=nn.Conv2d(64,3,3,1,1) 147 | self.relu= nn.ReLU(inplace=True) 148 | 149 | #weights initialization by normal(Gaussinn) distribution:normal_(mean=0, std=1 , gengerator=None*) 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 153 | m.weight.data.normal_(0, math.sqrt(2. / n)) 154 | 155 | def make_layer(self, block, num_of_layer): 156 | layers = [] 157 | for _ in range(num_of_layer): 158 | layers.append(block()) 159 | return nn.Sequential(*layers) 160 | 161 | def forward(self, x): 162 | x = self.head(x) 163 | residual=x#i'm afraid this may cause residual is x 164 | for layer in self.body: 165 | residual = layer(residual) 166 | residual = self.tail(residual)# 167 | out = torch.add(x,residual)#global residual 168 | #out=self.tail(out) 169 | out = self.reconstruct(out) 170 | return out 171 | 172 | 173 | #model_4_2 refer to edsr 174 | class edar2(nn.Module): 175 | def __init__(self): 176 | super(edar2, self).__init__() 177 | self.head= nn.Conv2d(3,64,3,1,1) 178 | #self.body = nn.ModuleList([ResidualBlock(64) for i in range(8)]) 179 | self.body = self.make_layer(ResidualBlock, 8) 180 | self.tail=nn.Conv2d(64,64,3,1,1) 181 | self.reconstruct=nn.Conv2d(128,3,3,1,1) 182 | self.relu= nn.ReLU(inplace=True) 183 | 184 | #weights initialization by normal(Gaussinn) distribution:normal_(mean=0, std=1 , gengerator=None*) 185 | for m in self.modules(): 186 | if isinstance(m, nn.Conv2d): 187 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 188 | m.weight.data.normal_(0, math.sqrt(2. / n)) 189 | 190 | def make_layer(self, block, num_of_layer): 191 | layers = [] 192 | for _ in range(num_of_layer): 193 | layers.append(block()) 194 | return nn.Sequential(*layers) 195 | 196 | def forward(self, x): 197 | x = self.head(x) 198 | residual=x#i'm afraid this may cause residual is x 199 | residual = self.body(residual) 200 | residual = self.relu(self.tail(residual))# 201 | out = torch.cat([x,residual],1)#global residual 202 | #out=self.tail(out) 203 | out = self.reconstruct(out) 204 | return out 205 | 206 | 207 | #model_5 refer to SRDenseNet 208 | class _Dense_Block(nn.Module): 209 | def __init__(self, channel_in): 210 | super(_Dense_Block, self).__init__() 211 | 212 | self.relu = nn.PReLU() 213 | self.conv1 = nn.Conv2d(in_channels=channel_in, out_channels=16, kernel_size=3, stride=1, padding=1) 214 | self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1) 215 | self.conv3 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1) 216 | self.conv4 = nn.Conv2d(in_channels=48, out_channels=16, kernel_size=3, stride=1, padding=1) 217 | self.conv5 = nn.Conv2d(in_channels=64, out_channels=16, kernel_size=3, stride=1, padding=1) 218 | self.conv6 = nn.Conv2d(in_channels=80, out_channels=16, kernel_size=3, stride=1, padding=1) 219 | self.conv7 = nn.Conv2d(in_channels=96, out_channels=16, kernel_size=3, stride=1, padding=1) 220 | self.conv8 = nn.Conv2d(in_channels=112, out_channels=16, kernel_size=3, stride=1, padding=1) 221 | 222 | def forward(self, x): 223 | conv1 = self.relu(self.conv1(x)) 224 | 225 | conv2 = self.relu(self.conv2(conv1)) 226 | cout2_dense = self.relu(torch.cat([conv1,conv2], 1)) 227 | 228 | conv3 = self.relu(self.conv3(cout2_dense)) 229 | cout3_dense = self.relu(torch.cat([conv1,conv2,conv3], 1)) 230 | 231 | conv4 = self.relu(self.conv4(cout3_dense)) 232 | cout4_dense = self.relu(torch.cat([conv1,conv2,conv3,conv4], 1)) 233 | 234 | conv5 = self.relu(self.conv5(cout4_dense)) 235 | cout5_dense = self.relu(torch.cat([conv1,conv2,conv3,conv4,conv5], 1)) 236 | 237 | conv6 = self.relu(self.conv6(cout5_dense)) 238 | cout6_dense = self.relu(torch.cat([conv1,conv2,conv3,conv4,conv5,conv6], 1)) 239 | 240 | conv7 = self.relu(self.conv7(cout6_dense)) 241 | cout7_dense = self.relu(torch.cat([conv1,conv2,conv3,conv4,conv5,conv6,conv7], 1)) 242 | 243 | conv8 = self.relu(self.conv8(cout7_dense)) 244 | cout8_dense = self.relu(torch.cat([conv1,conv2,conv3,conv4,conv5,conv6,conv7,conv8], 1)) 245 | 246 | return cout8_dense 247 | 248 | class ARDenseNet(nn.Module): 249 | def __init__(self): 250 | super(ARDenseNet, self).__init__() 251 | 252 | self.relu = nn.PReLU() 253 | self.lowlevel = nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3, stride=1, padding=1) 254 | self.bottleneck = nn.Conv2d(in_channels=640, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False) 255 | self.reconstruction = nn.Conv2d(in_channels=256, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False) 256 | self.denseblock1 = self.make_layer(_Dense_Block, 128) 257 | self.denseblock2 = self.make_layer(_Dense_Block, 256) 258 | self.denseblock3 = self.make_layer(_Dense_Block, 384) 259 | self.denseblock4 = self.make_layer(_Dense_Block, 512) 260 | 261 | for m in self.modules(): 262 | if isinstance(m, nn.Conv2d): 263 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 264 | m.weight.data.normal_(0, math.sqrt(2. / n)) 265 | if m.bias is not None: 266 | m.bias.data.zero_() 267 | 268 | def make_layer(self, block, channel_in): 269 | layers = [] 270 | layers.append(block(channel_in)) 271 | return nn.Sequential(*layers) 272 | 273 | def forward(self, x): 274 | residual = self.relu(self.lowlevel(x)) 275 | 276 | out = self.denseblock1(residual) 277 | concat = torch.cat([residual,out], 1) 278 | 279 | out = self.denseblock2(concat) 280 | concat = torch.cat([concat,out], 1) 281 | 282 | out = self.denseblock3(concat) 283 | concat = torch.cat([concat,out], 1) 284 | 285 | out = self.denseblock4(concat) 286 | concat = torch.cat([concat,out], 1) 287 | 288 | out = self.bottleneck(concat) 289 | 290 | out = self.reconstruction(out) 291 | 292 | return out 293 | 294 | 295 | class Vgg16(nn.Module): 296 | def __init__(self, requires_grad=False): 297 | super(Vgg16, self).__init__() 298 | vgg_pretrained_features = models.vgg16(pretrained=True).features#It will dawnload the parameters,which will spend some time 299 | self.slice1 = torch.nn.Sequential() 300 | self.slice2 = torch.nn.Sequential() 301 | self.slice3 = torch.nn.Sequential() 302 | self.slice4 = torch.nn.Sequential() 303 | for x in range(4): 304 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 305 | for x in range(4, 9): 306 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 307 | for x in range(9, 16): 308 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 309 | for x in range(16, 23): 310 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 311 | if not requires_grad: 312 | for param in self.parameters(): 313 | param.requires_grad = False 314 | 315 | def forward(self, X): 316 | h = self.slice1(X) 317 | h_relu1_2 = h 318 | h = self.slice2(h) 319 | h_relu2_2 = h 320 | h = self.slice3(h) 321 | h_relu3_3 = h 322 | h = self.slice4(h) 323 | h_relu4_3 = h 324 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3']) 325 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3) 326 | return out 327 | 328 | 329 | 330 | 331 | if __name__== '__main__': 332 | model=ARCNN().cuda() 333 | print('Model Structure:',model) 334 | 335 | params = list(model.parameters()) 336 | for i in range(len(params)): 337 | print('layer:',i+1,params[i].size()) 338 | print('parameters:', sum(param.numel() for param in model.parameters())) 339 | 340 | 341 | -------------------------------------------------------------------------------- /CNNs/myutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from math import exp,log10 6 | 7 | #mormalize input batch for vgg 8 | def normalize_batch(batch): 9 | # normalize using imagenet mean and std 10 | mean = batch.data.new(batch.data.size()) 11 | std = batch.data.new(batch.data.size()) 12 | mean[:, 0, :, :] = 0.485 13 | mean[:, 1, :, :] = 0.456 14 | mean[:, 2, :, :] = 0.406 15 | std[:, 0, :, :] = 0.229 16 | std[:, 1, :, :] = 0.224 17 | std[:, 2, :, :] = 0.225 18 | batch = torch.div(batch, 1.0)#Attention!BUG 19 | batch -= Variable(mean) 20 | batch = batch / Variable(std) 21 | return batch 22 | 23 | #caculate ssim 24 | def gaussian(window_size, sigma): 25 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 26 | return gauss/gauss.sum() 27 | 28 | def create_window(window_size, channel): 29 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 30 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 31 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 32 | return window 33 | 34 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 35 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 36 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 37 | 38 | mu1_sq = mu1.pow(2) 39 | mu2_sq = mu2.pow(2) 40 | mu1_mu2 = mu1*mu2 41 | 42 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 43 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 44 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 45 | 46 | C1 = 0.01**2 47 | C2 = 0.03**2 48 | 49 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 50 | 51 | if size_average: 52 | return ssim_map.mean() 53 | else: 54 | return ssim_map.mean(1).mean(1).mean(1) 55 | 56 | def ssim(img1, img2, window_size = 11, size_average = True): 57 | (_, channel, _, _) = img1.size() 58 | window = create_window(window_size, channel) 59 | 60 | if img1.is_cuda: 61 | window = window.cuda(img1.get_device()) 62 | window = window.type_as(img1) 63 | 64 | return _ssim(img1, img2, window, window_size, channel, size_average) 65 | 66 | #caculate psnr 67 | def psnr(tesnor1, tensor2): 68 | mse = nn.MSELoss()(tesnor1, tensor2) 69 | psnr = 10 * log10(1 / mse.data[0]) 70 | return psnr -------------------------------------------------------------------------------- /CNNs/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | #from PIL import Image 3 | import cv2 4 | import numpy 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | from torch.autograd import Variable 8 | import mymodel 9 | import MS_model 10 | import myutils 11 | import msssim 12 | import torch.nn as nn 13 | 14 | 15 | class ImageDataset(Dataset): 16 | def __init__(self, root_dir, transform=None): 17 | 18 | self.input_dir = os.path.join(root_dir,'output_qp_42_no_sao','320x480') 19 | self.label_dir = os.path.join(root_dir,'images_png','320x480') 20 | self.transform = transform 21 | 22 | def __len__(self): 23 | return os.listdir(self.input_dir).__len__() 24 | 25 | def __getitem__(self, idx): 26 | input_names = sorted(os.listdir(self.input_dir)) 27 | label_names = sorted(os.listdir(self.label_dir)) 28 | 29 | input_name = os.path.join(self.input_dir,input_names[idx]) 30 | #input_image =Image.open(input_name)#Image get jpg 31 | input_image=cv2.imread(input_name) 32 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) 33 | 34 | label_name = os.path.join(self.label_dir,label_names[idx]) 35 | #label_image =Image.open(label_name) 36 | label_image=cv2.imread(label_name) 37 | label_image = cv2.cvtColor(label_image, cv2.COLOR_BGR2RGB) 38 | 39 | sample = {'input_image': input_image, 'label_image': label_image, 'name': input_names[idx]} 40 | 41 | if self.transform: 42 | sample = self.transform(sample) 43 | 44 | return sample 45 | 46 | def edge_clip(image): 47 | if image.shape[0]%2==1: 48 | image=image[:-1,:,:] 49 | if image.shape[1]%2==1: 50 | image=image[:,:-1,:] 51 | return image 52 | 53 | 54 | class mytransform(object): 55 | def __call__(self, sample): 56 | input_image, label_image,name= sample['input_image'], sample['label_image'],sample['name'] 57 | # swap color axis because 58 | # numpy image: H x W x C 59 | # torch image: C X H X W 60 | #input_image = numpy.asarray(input_image).transpose(2, 0, 1)/255.0 61 | #label_image = numpy.asarray(label_image).transpose(2, 0, 1)/255.0 62 | input_image=edge_clip(input_image) 63 | label_image=edge_clip(label_image) 64 | 65 | input_image = input_image.transpose(2, 0, 1)/255.0 66 | label_image = label_image.transpose(2, 0, 1)/255.0 67 | 68 | 69 | return {'input_image': torch.from_numpy(input_image).float(), 70 | 'label_image': torch.from_numpy(label_image).float(), 71 | 'name':name} 72 | 73 | def wrap_variable(input_batch, label_batch, use_gpu,flag): 74 | if use_gpu: 75 | input_batch, label_batch = (Variable(input_batch.cuda(),volatile=flag), Variable(label_batch.cuda(),volatile=flag)) 76 | 77 | else: 78 | input_batch, label_batch = (Variable(input_batch,volatile=flag),Variable(label_batch,volatile=flag)) 79 | return input_batch, label_batch 80 | 81 | 82 | 83 | def checkpoint(name,psnr1,psnr2,ssim1,ssim2,msssim1,msssim2): 84 | print('{},psnr:{:.4f}->{:.4f},ssim:{:.4f}->{:.4f},msssim:{:.4f}-{:.4f}'.format(name,psnr1,psnr2,ssim1,ssim2,msssim1,msssim2)) 85 | #write to text 86 | output = open(os.path.join(Image_folder,'test_result.txt'),'a+') 87 | output.write(('{} {:.4f}->{:.4f} {:.4f}->{:.4f},{:.4f}->{:.4f}'.format(name,psnr1,psnr2,ssim1,ssim2,msssim1,msssim2))+'\r\n') 88 | output.close() 89 | 90 | def save(output_image,name): 91 | output_data=output_image.data[0] 92 | if use_gpu: 93 | img=255.0*output_data.clone().cpu().numpy() 94 | else: 95 | img=255.0*output_data.clone().numpy() 96 | img = img.transpose(1, 2, 0).astype("uint8") 97 | #img = Image.fromarray(img) 98 | #img.save(os.path.join(Image_folder,'output','{}.jpg'.format(name[:-4]))) 99 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 100 | cv2.imwrite(os.path.join(Image_folder,'output','{}.png'.format(name[:-4])),img) 101 | 102 | downsampling=nn.AvgPool2d(2) 103 | 104 | 105 | def test(): 106 | model.eval() 107 | #input and label 108 | avg_psnr1 = 0 109 | avg_ssim1 = 0 110 | avg_msssim1 = 0 111 | #output and label 112 | avg_psnr2 = 0 113 | avg_ssim2 = 0 114 | avg_msssim2 = 0 115 | 116 | for i, sample in enumerate(dataloader): 117 | input_image,label_image,name=sample['input_image'],sample['label_image'],sample['name'][0]#tuple to str 118 | 119 | 120 | #Wrap with torch Variable 121 | input_image,label_image=wrap_variable(input_image,label_image, use_gpu,True) 122 | #predict 123 | output_image = model(input_image) 124 | #clamp in[0,1] 125 | 126 | 127 | ''' 128 | ################## 129 | #Wrap with torch Variable 130 | input_image,label_image=wrap_variable(input_image,label_image, use_gpu,True) 131 | inputs_1=downsampling(input_image) 132 | label_1=downsampling(label_image) 133 | inputs_2=downsampling(inputs_1) 134 | label_2=downsampling(label_1) 135 | 136 | output_image, outputs_1, outputs_2 = model(input_image,inputs_1,inputs_2) 137 | ################# 138 | ''' 139 | 140 | 141 | 142 | output_image=output_image.clamp(0.0, 1.0) 143 | 144 | 145 | #calculate psnr 146 | psnr1 =myutils.psnr(input_image, label_image) 147 | psnr2 =myutils.psnr(output_image, label_image) 148 | 149 | # ssim is calculated with the normalize (range [0, 1]) image 150 | ssim1 = torch.sum((myutils.ssim(input_image, label_image, size_average=False)).data)/1.0#batch_size 151 | ssim2 = torch.sum((myutils.ssim(output_image, label_image, size_average=False)).data)/1.0 152 | 153 | #msssim 154 | msssim1 = numpy.sum((msssim.MultiScaleSSIM(numpy.expand_dims(input_image.data[0].clone().cpu().numpy(),axis=0), numpy.expand_dims(label_image.data[0].clone().cpu().numpy(),axis=0), max_val=1.0)))/1.0#batch_size 155 | msssim2 = numpy.sum((msssim.MultiScaleSSIM(numpy.expand_dims(output_image.data[0].clone().cpu().numpy(),axis=0), numpy.expand_dims(label_image.data[0].clone().cpu().numpy(),axis=0), max_val=1.0)))/1.0 156 | 157 | avg_ssim1 += ssim1 158 | avg_psnr1 += psnr1 159 | avg_ssim2 += ssim2 160 | avg_psnr2 += psnr2 161 | avg_msssim1 += msssim1 162 | avg_msssim2 += msssim2 163 | 164 | #save output and record 165 | checkpoint(name,psnr1,psnr2,ssim1,ssim2,msssim1,msssim2) 166 | save(output_image,name) 167 | 168 | #print and save 169 | avg_psnr1 = avg_psnr1/len(dataloader) 170 | avg_ssim1 = avg_ssim1/len(dataloader) 171 | avg_psnr2 = avg_psnr2/len(dataloader) 172 | avg_ssim2 = avg_ssim2/len(dataloader) 173 | avg_msssim1 = avg_msssim1/len(dataloader) 174 | avg_msssim2 = avg_msssim2/len(dataloader) 175 | 176 | print('Avg. PSNR: {:.4f}->{:.4f} Avg. SSIM: {:.4f}->{:.4f} Avg. MSSSIM: {:.4f}->{:.4f}'.format(avg_psnr1,avg_psnr2,avg_ssim1,avg_ssim2,avg_msssim1,avg_msssim2)) 177 | output = open(os.path.join(Image_folder,'test_result.txt'),'a+') 178 | output.write('Avg. PSNR: {:.4f}->{:.4f} Avg. SSIM: {:.4f}->{:.4f}Avg. MSSSIM: {:.4f}->{:.4f}'.format(avg_psnr1,avg_psnr2,avg_ssim1,avg_ssim2,avg_msssim1,avg_msssim2)+'\r\n') 179 | output.close() 180 | 181 | 182 | #------------------------------------------------------------------ 183 | #cuda 184 | use_gpu=torch.cuda.is_available() 185 | 186 | #set path 187 | root_dir=os.getcwd() 188 | Image_folder=os.path.join(root_dir) 189 | 190 | #model_weights_file=os.path.join(root_dir,'parameters_nosao','MS_47-0.001301-28.9525-0.8454param.pth') 191 | #model_weights_file=os.path.join(root_dir,'parameters_nosao','ARCNN_sao-qp42-90-0.001434-28.4783-0.8401param.pth') 192 | #model_weights_file=os.path.join(root_dir,'parameters_nosao','L8_qp42-99-0.001374-28.6387-0.8429param.pth') 193 | #model_weights_file=os.path.join(root_dir,'parameters_nosao','vdar_qp42-51-0.001339-28.6784-0.8443param.pth') 194 | #model_weights_file=os.path.join(root_dir,'parameters_nosao','ARDenseNet_qp42-28-0.001368-28.6754-0.8436param.pth') 195 | #model_weights_file=os.path.join(root_dir,'parameters_nosao','edar2_qp42-63-0.001339-28.6824-0.8445param.pth') 196 | #model_weights_file=os.path.join(root_dir,'parameters_nosao','vgg-10-0.001407-1.416533-28.476954-0.835027param.pth') 197 | #model_weights_file=os.path.join(root_dir,'parameters_nosao','L1-14-0.024889-28.6602-0.8444param.pth') 198 | model_weights_file=os.path.join(root_dir,'parameters_nosao','qp42-122--0.807578-28.5970-0.8461-0.9645param.pth') 199 | 200 | #model_weights_file=os.path.join(root_dir,'parameters_nosao37','ARCNN_qp37-6-0.000795-31.0831-0.9016param.pth') 201 | #model_weights_file=os.path.join(root_dir,'parameters_nosao37','VDAR_qp37-12-0.000741-31.3404-0.9050param.pth') 202 | #model_weights_file=os.path.join(root_dir,'parameters_nosao37','EDAR2qp37-23-0.000742-31.3593-0.9053param.pth') 203 | #model_weights_file=os.path.join(root_dir,'parameters_nosao37','L8qp37-45-0.000756-31.2899-0.9044param.pth') 204 | #model_weights_file=os.path.join(root_dir,'parameters_nosao37','MSqp37-3-0.000722-31.5615-0.9060param.pth') 205 | #model_weights_file=os.path.join(root_dir,'parameters_nosao37','BestDenseqp37-24-0.000747-31.3328-0.9049param.pth') 206 | #model_weights_file=os.path.join(root_dir,'parameters_nosao','msssimqp42-10--0.807487-28.5841-0.8471-0.9645param.pth') 207 | #model_weights_file=os.path.join(root_dir,'parameters_nosao','msssimqp42-30--0.807743-28.5685-0.8467-0.9645param.pth') 208 | 209 | #set model 210 | model=mymodel.edar2() 211 | #model=MS_model.IntraDeblocking() 212 | 213 | #model=torch.load(model_weights_file) 214 | #vgg=mymodel_new.Vgg16(requires_grad=False) 215 | if use_gpu: 216 | model = model.cuda() 217 | #vgg = vgg.cuda() 218 | 219 | model.load_state_dict(torch.load(model_weights_file)) 220 | 221 | 222 | #set dataset and dataloader 223 | mydataset = ImageDataset(root_dir=Image_folder, transform=mytransform()) 224 | dataloader = DataLoader(mydataset, batch_size=1,shuffle=False, num_workers=0) 225 | 226 | 227 | 228 | def main(): 229 | test() 230 | 231 | if __name__=='__main__': 232 | main() 233 | -------------------------------------------------------------------------------- /CNNs/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import h5py 3 | import numpy 4 | from pathlib import Path 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import Dataset,DataLoader,ConcatDataset 8 | from torch.autograd import Variable 9 | 10 | import mymodel 11 | import myutils 12 | 13 | from logger import Logger 14 | 15 | logger=Logger('./logs/edar2') 16 | 17 | #prepare data 18 | class MyDataset(Dataset): 19 | def __init__(self,data_file,n): 20 | self.file=h5py.File(str(data_file),'r') 21 | #self.inputs=self.file['data'][:].astype(numpy.float32)/255.0#simple normalization in[0,1] 22 | #self.label=self.file['label'][:].astype(numpy.float32)/255.0 23 | self.n=n 24 | 25 | def __len__(self): 26 | #return self.inputs.shape[0] 27 | return self.n 28 | 29 | def __getitem__(self,idx): 30 | inputs=self.file['data'][idx,:,:,:].astype(numpy.float32).transpose(2,0,1)/255.0 31 | label=self.file['label'][idx,:,:,:].astype(numpy.float32).transpose(2,0,1)/255.0 32 | #label=self.label[idx,:,:,:].transpose(2,0,1) 33 | inputs=torch.Tensor(inputs) 34 | label=torch.Tensor(label) 35 | sample={'inputs':inputs,'label':label} 36 | return sample 37 | 38 | 39 | def checkpoint(epoch,loss,psnr,ssim,mse): 40 | model.eval() 41 | model_path1 = str(checkpoint_dir/'qp37-{}-{:.6f}-{:.4f}-{:.4f}.pth'.format(epoch,loss,psnr,ssim)) 42 | torch.save(model,model_path1) 43 | 44 | if use_gpu: 45 | model.cpu()#you should save weights on cpu not on gpu 46 | 47 | #save weights 48 | model_path = str(checkpoint_dir/'qp37-{}-{:.6f}-{:.4f}-{:.4f}param.pth'.format(epoch,loss,psnr,ssim)) 49 | 50 | torch.save(model.state_dict(),model_path) 51 | 52 | #print and save record 53 | print('Epoch {} : Avg.loss:{:.6f}'.format(epoch,loss)) 54 | print("Test Avg. PSNR: {:.4f} Avg. SSIM: {:.4f} Avg.MSE{:.6f} ".format(psnr,ssim,mse)) 55 | print("Checkpoint saved to {}".format(model_path)) 56 | 57 | output = open(str(checkpoint_dir/'train_result.txt'),'a+') 58 | output.write(('{} {:.6f} {:.4f} {:.4f}'.format(epoch,loss,psnr,ssim))+'\r\n') 59 | output.close() 60 | 61 | if use_gpu: 62 | model.cuda()#don't forget return to gpu 63 | #model.train() 64 | 65 | 66 | def wrap_variable(input_batch, label_batch, use_gpu,flag): 67 | if use_gpu: 68 | input_batch, label_batch = (Variable(input_batch.cuda(),volatile=flag), Variable(label_batch.cuda(),volatile=flag)) 69 | 70 | else: 71 | input_batch, label_batch = (Variable(input_batch,volatile=flag),Variable(label_batch,volatile=flag)) 72 | return input_batch, label_batch 73 | 74 | 75 | 76 | def train(epoch): 77 | model.train() 78 | sum_loss=0.0 79 | 80 | for iteration, sample in enumerate(dataloader):#difference between (dataloader) &(dataloader,1) 81 | inputs,label=sample['inputs'],sample['label'] 82 | 83 | #Wrap with torch Variable 84 | inputs,label=wrap_variable(inputs, label, use_gpu, False) 85 | 86 | #clear the optimizer 87 | optimizer.zero_grad() 88 | 89 | # forward propagation 90 | outputs = model(inputs) 91 | 92 | #get the loss for backward 93 | loss =criterion(outputs, label) 94 | 95 | #backward propagation and optimize 96 | loss.backward() 97 | optimizer.step() 98 | 99 | if iteration%100==0: 100 | print("===> Epoch[{}]({}/{}):loss: {:.6f}".format(epoch, iteration, len(dataloader), loss.data[0])) 101 | #if iteration==101: 102 | # break 103 | 104 | info={'edar2_loss':loss.data[0]} 105 | for tag,value in info.items(): 106 | logger.scalar_summary(tag,value,iteration+epoch*len(dataloader)) 107 | 108 | #caculate the average loss 109 | sum_loss += loss.data[0] 110 | 111 | return sum_loss/len(dataloader) 112 | 113 | 114 | def test(): 115 | model.eval() 116 | avg_psnr = 0 117 | avg_ssim = 0 118 | avg_mse = 0 119 | for iteration, sample in enumerate(test_dataloader): 120 | inputs,label=sample['inputs'],sample['label'] 121 | #Wrap with torch Variable 122 | inputs,label=wrap_variable(inputs, label, use_gpu, True) 123 | 124 | outputs = model(inputs) 125 | mse = criterion(outputs,label).data[0] 126 | psnr = myutils.psnr(outputs, label) 127 | ssim = torch.sum((myutils.ssim(outputs, label, size_average=False)).data)/args.testbatchsize 128 | avg_ssim += ssim 129 | avg_psnr += psnr 130 | avg_mse += mse 131 | return (avg_psnr / len(test_dataloader)),(avg_ssim / len(test_dataloader)),(avg_mse/len(test_dataloader)) 132 | 133 | 134 | 135 | def main(): 136 | #train & test & record 137 | for epoch in range(args.epochs): 138 | loss=train(epoch) 139 | psnr,ssim,mse = test() 140 | checkpoint(epoch,loss,psnr,ssim,mse) 141 | info1={'edar2_avg_loss':loss,'edar2_psnr':psnr,'edar2_ssim':ssim,'edar2_mse':mse} 142 | for tag,value in info1.items(): 143 | logger.scalar_summary(tag,value,epoch) 144 | 145 | 146 | 147 | 148 | #--------------------------------------------------------------------------------------------------- 149 | # Training settings 150 | parser = argparse.ArgumentParser(description='ARCNN') 151 | parser.add_argument('--batchsize', type=int, default=64, help='training batch size') 152 | parser.add_argument('--testbatchsize', type=int, default=16, help='testing batch size') 153 | parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train for') 154 | parser.add_argument('--lr', type=float, default=0.0001, help='Learning Rate. Default=0.0001') 155 | args = parser.parse_args() 156 | 157 | print(args) 158 | 159 | #---------------------------------------------------------------------------------------------------- 160 | #set other parameters 161 | #1.set cuda 162 | use_gpu=torch.cuda.is_available() 163 | 164 | 165 | #2.set path and file 166 | save_dir = Path('.') 167 | checkpoint_dir = Path('.') / 'Checkpoints_edar2_L1'#save model parameters and train record 168 | if checkpoint_dir.exists(): 169 | print 'folder esxited' 170 | else: 171 | checkpoint_dir.mkdir() 172 | 173 | model_weights_file=checkpoint_dir/'edar2_qp42-0-0.001520-28.6083-0.8426param.pth' 174 | 175 | 176 | #3.set dataset and dataloader 177 | dataset=MyDataset(data_file=save_dir/'TrainData_37_nosao.h5',n=43848)#you need to obtain the number from dataprocess 178 | dataset2K=MyDataset(data_file=save_dir/'Data2K_37_nosao.h5',n=38478) 179 | test_dataset=MyDataset(data_file=save_dir/'ValData_37_nosao.h5',n=5320) 180 | 181 | dataloader=DataLoader(ConcatDataset([dataset,dataset2K]),batch_size=args.batchsize,shuffle=True,num_workers=0) 182 | #dataloader=DataLoader(dataset,batch_size=args.batchsize,shuffle=True,num_workers=0) 183 | test_dataloader=DataLoader(test_dataset,batch_size=args.testbatchsize,shuffle=False,num_workers=0) 184 | 185 | 186 | 187 | #4.set model& criterion& optimizer 188 | model=mymodel.edar2() 189 | 190 | criterion = nn.MSELoss() 191 | optimizer=torch.optim.Adam(model.parameters(), lr=args.lr) 192 | 193 | if use_gpu: 194 | model = model.cuda() 195 | criterion = criterion.cuda() 196 | 197 | #load parameters 198 | if not use_gpu: 199 | model.load_state_dict(torch.load(str(model_weights_file), map_location=lambda storage, loc: storage)) 200 | #model=torch.load(str(model_weights_file), map_location=lambda storage, loc: storage) 201 | else: 202 | model.load_state_dict(torch.load(str(model_weights_file))) 203 | #model=torch.load(str(model_weights_file)) 204 | 205 | 206 | #show mdoel¶meters&dataset 207 | print('Model Structure:',model) 208 | print('parameters:', sum(param.numel() for param in model.parameters())) 209 | params = list(model.parameters()) 210 | for i in range(len(params)): 211 | print('layer:',i+1,params[i].size()) 212 | 213 | #print('length of dataset:',len(dataset)) 214 | 215 | 216 | if __name__=='__main__': 217 | main() 218 | -------------------------------------------------------------------------------- /GAN/GAN_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, channels=64): 8 | super(ResidualBlock, self).__init__() 9 | self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) 10 | self.relu = nn.ReLU(inplace=True)#attrntion:relu not prelu 11 | 12 | def forward(self, x): 13 | residual = self.conv(x) 14 | residual = self.relu(residual) 15 | residual = self.conv(residual)#the original model multipy a value 0.1 here,which is to be analyzed in the feature 16 | out = torch.add(x,residual) 17 | 18 | return out 19 | 20 | class Generator(nn.Module): 21 | def __init__(self): 22 | super(Generator, self).__init__() 23 | self.head= nn.Conv2d(3,64,3,1,1) 24 | self.body = nn.ModuleList([ResidualBlock(64) for i in range(8)]) 25 | self.tail=nn.Conv2d(64,64,3,1,1) 26 | self.reconstruct=nn.Conv2d(64,3,3,1,1) 27 | self.relu= nn.ReLU(inplace=True) 28 | 29 | #weights initialization by normal(Gaussinn) distribution:normal_(mean=0, std=1 , gengerator=None*) 30 | for m in self.modules(): 31 | if isinstance(m, nn.Conv2d): 32 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 33 | m.weight.data.normal_(0, math.sqrt(2. / n)) 34 | 35 | def forward(self, x): 36 | x = self.head(x) 37 | residual=x#i'm afraid this may cause residual is x 38 | for layer in self.body: 39 | residual = layer(residual) 40 | residual = self.tail(residual)# 41 | out = torch.add(x,residual)#global residual 42 | #out=self.tail(out) 43 | out = self.reconstruct(out) 44 | return out 45 | 46 | 47 | class Discriminator(nn.Module): 48 | def __init__(self): 49 | super(Discriminator, self).__init__() 50 | self.net = nn.Sequential( 51 | nn.Conv2d(3, 64, kernel_size=3, padding=1), 52 | nn.LeakyReLU(0.2), 53 | 54 | nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1), 55 | nn.BatchNorm2d(64), 56 | nn.LeakyReLU(0.2), 57 | 58 | nn.Conv2d(64, 128, kernel_size=3, padding=1), 59 | nn.BatchNorm2d(128), 60 | nn.LeakyReLU(0.2), 61 | 62 | nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1), 63 | nn.BatchNorm2d(128), 64 | nn.LeakyReLU(0.2), 65 | 66 | nn.Conv2d(128, 256, kernel_size=3, padding=1), 67 | nn.BatchNorm2d(256), 68 | nn.LeakyReLU(0.2), 69 | 70 | nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), 71 | nn.BatchNorm2d(256), 72 | nn.LeakyReLU(0.2), 73 | 74 | nn.Conv2d(256, 512, kernel_size=3, padding=1), 75 | nn.BatchNorm2d(512), 76 | nn.LeakyReLU(0.2), 77 | 78 | nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), 79 | nn.BatchNorm2d(512), 80 | nn.LeakyReLU(0.2), 81 | 82 | nn.AdaptiveAvgPool2d(1), 83 | nn.Conv2d(512, 1024, kernel_size=1), 84 | nn.LeakyReLU(0.2), 85 | nn.Conv2d(1024, 1, kernel_size=1) 86 | ) 87 | 88 | def forward(self, x): 89 | batch_size = x.size(0) 90 | return F.sigmoid(self.net(x).view(batch_size)) 91 | 92 | 93 | 94 | ''' 95 | class Generator(nn.Module): 96 | def __init__(self, scale_factor): 97 | upsample_block_num = int(math.log(scale_factor, 2)) 98 | 99 | super(Generator, self).__init__() 100 | self.block1 = nn.Sequential( 101 | nn.Conv2d(3, 64, kernel_size=9, padding=4), 102 | nn.PReLU() 103 | ) 104 | self.block2 = ResidualBlock(64) 105 | self.block3 = ResidualBlock(64) 106 | self.block4 = ResidualBlock(64) 107 | self.block5 = ResidualBlock(64) 108 | self.block6 = ResidualBlock(64) 109 | self.block7 = nn.Sequential( 110 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 111 | nn.PReLU() 112 | ) 113 | block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)] 114 | block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4)) 115 | self.block8 = nn.Sequential(*block8) 116 | 117 | def forward(self, x): 118 | block1 = self.block1(x) 119 | block2 = self.block2(block1) 120 | block3 = self.block3(block2) 121 | block4 = self.block4(block3) 122 | block5 = self.block5(block4) 123 | block6 = self.block6(block5) 124 | block7 = self.block7(block6) 125 | block8 = self.block8(block1 + block7) 126 | 127 | return (F.tanh(block8) + 1) / 2#? 128 | ''' -------------------------------------------------------------------------------- /GAN/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision.models.vgg import vgg16 4 | 5 | 6 | class GeneratorLoss(nn.Module): 7 | def __init__(self): 8 | super(GeneratorLoss, self).__init__() 9 | vgg = vgg16(pretrained=True) 10 | loss_network = nn.Sequential(*list(vgg.features)[:31]).eval() 11 | for param in loss_network.parameters(): 12 | param.requires_grad = False 13 | self.loss_network = loss_network 14 | self.mse_loss = nn.MSELoss() 15 | self.tv_loss = TVLoss() 16 | 17 | def forward(self, out_labels, out_images, target_images): 18 | # Adversarial Loss 19 | adversarial_loss = torch.mean(1 - out_labels) 20 | # Perception Loss 21 | perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images)) 22 | # Image Loss 23 | image_loss = self.mse_loss(out_images, target_images) 24 | # TV Loss 25 | tv_loss = self.tv_loss(out_images) 26 | return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss 27 | 28 | 29 | class TVLoss(nn.Module): 30 | def __init__(self, tv_loss_weight=1): 31 | super(TVLoss, self).__init__() 32 | self.tv_loss_weight = tv_loss_weight 33 | 34 | def forward(self, x): 35 | batch_size = x.size()[0] 36 | h_x = x.size()[2] 37 | w_x = x.size()[3] 38 | count_h = self.tensor_size(x[:, :, 1:, :]) 39 | count_w = self.tensor_size(x[:, :, :, 1:]) 40 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 41 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 42 | return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size 43 | 44 | @staticmethod 45 | def tensor_size(t): 46 | return t.size()[1] * t.size()[2] * t.size()[3] 47 | 48 | 49 | if __name__ == "__main__": 50 | g_loss = GeneratorLoss() 51 | print(g_loss) 52 | -------------------------------------------------------------------------------- /GAN/myutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from math import exp,log10 6 | 7 | #mormalize input batch for vgg 8 | def normalize_batch(batch): 9 | # normalize using imagenet mean and std 10 | mean = batch.data.new(batch.data.size()) 11 | std = batch.data.new(batch.data.size()) 12 | mean[:, 0, :, :] = 0.485 13 | mean[:, 1, :, :] = 0.456 14 | mean[:, 2, :, :] = 0.406 15 | std[:, 0, :, :] = 0.229 16 | std[:, 1, :, :] = 0.224 17 | std[:, 2, :, :] = 0.225 18 | batch = torch.div(batch, 1.0)#Attention!BUG 19 | batch -= Variable(mean) 20 | batch = batch / Variable(std) 21 | return batch 22 | 23 | #caculate ssim 24 | def gaussian(window_size, sigma): 25 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 26 | return gauss/gauss.sum() 27 | 28 | def create_window(window_size, channel): 29 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 30 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 31 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 32 | return window 33 | 34 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 35 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 36 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 37 | 38 | mu1_sq = mu1.pow(2) 39 | mu2_sq = mu2.pow(2) 40 | mu1_mu2 = mu1*mu2 41 | 42 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 43 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 44 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 45 | 46 | C1 = 0.01**2 47 | C2 = 0.03**2 48 | 49 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 50 | 51 | if size_average: 52 | return ssim_map.mean() 53 | else: 54 | return ssim_map.mean(1).mean(1).mean(1) 55 | 56 | def ssim(img1, img2, window_size = 11, size_average = True): 57 | (_, channel, _, _) = img1.size() 58 | window = create_window(window_size, channel) 59 | 60 | if img1.is_cuda: 61 | window = window.cuda(img1.get_device()) 62 | window = window.type_as(img1) 63 | 64 | return _ssim(img1, img2, window, window_size, channel, size_average) 65 | 66 | #caculate psnr 67 | def psnr(tesnor1, tensor2): 68 | mse = nn.MSELoss()(tesnor1, tensor2) 69 | psnr = 10 * log10(1 / mse.data[0]) 70 | return psnr -------------------------------------------------------------------------------- /GAN/train_GAN.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import h5py 3 | import numpy 4 | from pathlib import Path 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import Dataset,DataLoader 8 | from torch.autograd import Variable 9 | from loss import GeneratorLoss 10 | from math import log10 11 | import GAN_model 12 | import myutils 13 | 14 | #prepare data 15 | class MyDataset(Dataset): 16 | def __init__(self,data_file): 17 | self.file=h5py.File(str(data_file),'r') 18 | self.inputs=self.file['data'][:].astype(numpy.float32)/255.0#simple normalization in[0,1] 19 | self.label=self.file['label'][:].astype(numpy.float32)/255.0 20 | #BUG! 21 | 22 | 23 | def __len__(self): 24 | return self.inputs.shape[0] 25 | 26 | def __getitem__(self,idx): 27 | inputs=self.inputs[idx,:,:,:].transpose(2,0,1) 28 | label=self.label[idx,:,:,:].transpose(2,0,1) 29 | inputs=torch.Tensor(inputs) 30 | label=torch.Tensor(label) 31 | sample={'inputs':inputs,'label':label} 32 | return sample 33 | 34 | 35 | 36 | def checkpoint(epoch,d_loss,g_loss,d_score,g_score,psnr,ssim,mse): 37 | netG.eval() 38 | netD.eval() 39 | model_pathG = str(checkpoint_dir/'netG{}-{:.6f}-{:.4f}-{:.4f}.pth'.format(epoch,g_loss,psnr,ssim)) 40 | torch.save(netG,model_pathG) 41 | 42 | model_pathD = str(checkpoint_dir/'netD{}-{:.6f}-{:.6f}-{:.6f}-{:.6f}.pth'.format(epoch,d_loss,g_loss,d_score,g_score)) 43 | torch.save(netD,model_pathD) 44 | 45 | if use_gpu: 46 | netG.cpu()#you should save weights on cpu not on gpu 47 | netD.cpu() 48 | 49 | #save weights 50 | model_path_G_params = str(checkpoint_dir/'netG{}-{:.6f}-{:.4f}-{:.4f}params.pth'.format(epoch,g_loss,psnr,ssim)) 51 | torch.save(netG.state_dict(),model_path_G_params) 52 | 53 | model_path_D_params = str(checkpoint_dir/'netD{}-{:.6f}-{:.6f}-{:.6f}-{:.6f}params.pth'.format(epoch,d_loss,g_loss,d_score,g_score)) 54 | torch.save(netD.state_dict(),model_path_D_params) 55 | 56 | #print and save record 57 | print('Epoch {} : Avg.d_loss:{:.6f},g_loss:{:.6f},d_score:{:.6f},g_score:{:.6f},'.format(epoch,d_loss,g_loss,d_score,g_score)) 58 | print("Test Avg. PSNR: {:.4f} Avg. SSIM: {:.4f} Avg.MSE{:.6f} ".format(psnr,ssim,mse)) 59 | print("Checkpoint saved to {}".format(model_path_G_params)) 60 | 61 | output = open(str(checkpoint_dir/'train_result.txt'),'a+') 62 | output.write(('{} {:.6f} {:.6f} {:.6f} {:.6f} {:.4f} {:.4f} {:.6f}'.format(epoch,d_loss,g_loss,d_score,g_score,psnr,ssim,mse))+'\r\n') 63 | output.close() 64 | 65 | if use_gpu: 66 | netG.cuda()#don't forget return to gpu 67 | netD.cuda() 68 | 69 | 70 | 71 | def wrap_variable(input_batch, label_batch, use_gpu,flag): 72 | if use_gpu: 73 | input_batch, label_batch = (Variable(input_batch.cuda(),volatile=flag), Variable(label_batch.cuda(),volatile=flag)) 74 | 75 | else: 76 | input_batch, label_batch = (Variable(input_batch,volatile=flag),Variable(label_batch,volatile=flag)) 77 | return input_batch, label_batch 78 | 79 | 80 | 81 | def train(epoch): 82 | netG.train() 83 | netD.train() 84 | 85 | running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0} 86 | 87 | for iteration, sample in enumerate(dataloader):#difference between (dataloader) &(dataloader,1) 88 | inputs,label=sample['inputs'],sample['label'] 89 | batch_size = inputs.size(0) 90 | running_results['batch_sizes'] += batch_size 91 | 92 | ############################ 93 | # (1) Update D network: maximize D(x)-1-D(G(z)) 94 | ########################### 95 | #Wrap with torch Variable 96 | inputs,label=wrap_variable(inputs, label, use_gpu, False) 97 | 98 | # forward propagation in netG to get the fake_img 99 | fake_img=netG(inputs) 100 | 101 | #clear the optimizer of netD first! 102 | netD.zero_grad() 103 | #optimizerD.zero_grad() 104 | 105 | #forward propagation in netD 106 | real_out = netD(label).mean() 107 | fake_out = netD(fake_img).mean() 108 | 109 | #get the loss of Discriminator for backward 110 | d_loss = 1 - real_out + fake_out 111 | 112 | #another methond to calculate the d_loss,which use cross entropy 113 | ''' 114 | #generator a target first:real is high,fake is low 115 | target_real = Variable(torch.rand(opt.batchSize,1)*0.5 + 0.7)#target of real should be high 116 | target_fake = Variable(torch.rand(opt.batchSize,1)*0.3)#target of fake should be low 117 | 118 | #use BCE loss(cross entroy:loss(o,t)=-\frac{1}{n}\sum_i(t[i] log(o[i])+(1-t[i]) log(1-o[i]))) 119 | adversarial_criterion = nn.BCELoss() 120 | discriminator_loss = adversarial_criterion(discriminator(high_res_real), target_real) + \ 121 | adversarial_criterion(discriminator(Variable(high_res_fake.data)), target_fake) 122 | ''' 123 | 124 | 125 | #backward propagation and optimize the netD 126 | d_loss.backward(retain_graph=True) 127 | optimizerD.step()#Attention! use optimizerD means only optimize netD 128 | 129 | ############################ 130 | # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss 131 | ########################### 132 | 133 | #clear the optimizer of netG first! 134 | netG.zero_grad() 135 | 136 | #get the loss from fake_out & fake_img,which is generate from netG,so netD will get the gradient 137 | g_loss = generator_criterion(fake_out, fake_img, label) 138 | 139 | #backward propagation and optimize the netG 140 | g_loss.backward() 141 | optimizerG.step() 142 | 143 | ############################ 144 | # (3) record and show the loss&score 145 | ########################### 146 | #calculate the loss 147 | fake_img = netG(inputs) 148 | fake_out = netD(fake_img).mean() 149 | 150 | g_loss = generator_criterion(fake_out, fake_img, label)#the lower, the better 151 | running_results['g_loss'] += g_loss.data[0] * batch_size 152 | d_loss = 1 - real_out + fake_out 153 | running_results['d_loss'] += d_loss.data[0] * batch_size#the lower,the better,which means discriminator is right 154 | running_results['d_score'] += real_out.data[0] * batch_size#the higher,means discriminator the better 155 | running_results['g_score'] += fake_out.data[0] * batch_size#the higher,means generator the better(no more than 0.5) 156 | 157 | 158 | #which is used for monitor 159 | if iteration%100==0: 160 | print("===> Epoch[{}]({}/{}):loss_d: {:.6f} Loss_G: {:.6f} D(label)/d_score: {:.6f} D(G(inputs))/g_score: {:.6f}".format( 161 | epoch, iteration, len(dataloader),running_results['d_loss'] / running_results['batch_sizes'], 162 | running_results['g_loss'] / running_results['batch_sizes'], 163 | running_results['d_score'] / running_results['batch_sizes'], 164 | running_results['g_score'] / running_results['batch_sizes'])) 165 | #if iteration==30: 166 | # break 167 | 168 | return running_results['d_loss'] / running_results['batch_sizes'], running_results['g_loss'] / running_results['batch_sizes'], running_results['d_score'] / running_results['batch_sizes'],running_results['g_score'] / running_results['batch_sizes'] 169 | 170 | 171 | 172 | def test(): 173 | netG.eval() 174 | valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0} 175 | 176 | for iteration, sample in enumerate(test_dataloader): 177 | inputs,label=sample['inputs'],sample['label'] 178 | batch_size = inputs.size(0) 179 | valing_results['batch_sizes'] += batch_size 180 | 181 | #Wrap with torch Variable 182 | inputs,label=wrap_variable(inputs, label, use_gpu, True) 183 | #get the output of netG 184 | outputs = netG(inputs) 185 | #calcualte metrics 186 | batch_mse = ((outputs - label) ** 2).data.mean() 187 | valing_results['mse'] += batch_mse * batch_size 188 | batch_ssim = myutils.ssim(outputs, label).data[0] 189 | valing_results['ssims'] += batch_ssim * batch_size#sum each patch,you must divide batch sizes 190 | 191 | valing_results['psnr'] = 10 * log10(1 / (valing_results['mse'] / valing_results['batch_sizes'])) 192 | valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes'] 193 | 194 | return valing_results['psnr'], valing_results['ssim'], valing_results['mse'] / valing_results['batch_sizes'] 195 | 196 | 197 | 198 | def main(): 199 | #train & test & record 200 | for epoch in range(args.epochs): 201 | d_loss,g_loss,d_score,g_score=train(epoch) 202 | psnr,ssim,mse = test() 203 | checkpoint(epoch,d_loss,g_loss,d_score,g_score,psnr,ssim,mse) 204 | 205 | 206 | 207 | #--------------------------------------------------------------------------------------------------- 208 | # Training settings 209 | parser = argparse.ArgumentParser(description='GAN') 210 | parser.add_argument('--batchsize', type=int, default=64, help='training batch size') 211 | parser.add_argument('--testbatchsize', type=int, default=16, help='testing batch size') 212 | parser.add_argument('--epochs', type=int, default=200, help='number of epochs to train for') 213 | parser.add_argument('--g_lr', type=float, default=0.0001, help='Learning Rate. Default=0.0001') 214 | parser.add_argument('--d_lr', type=float, default=0.001, help='Learning Rate. Default=0.0001') 215 | args = parser.parse_args() 216 | 217 | print(args) 218 | 219 | #---------------------------------------------------------------------------------------------------- 220 | #set other parameters 221 | #1.set cuda 222 | use_gpu=torch.cuda.is_available() 223 | 224 | 225 | #2.set path and file 226 | save_dir = Path('.') 227 | checkpoint_dir = Path('.') / 'Checkpoints_GAN'#save model parameters and train record 228 | if checkpoint_dir.exists(): 229 | print 'folder esxited' 230 | else: 231 | checkpoint_dir.mkdir() 232 | 233 | generator_weights_file=checkpoint_dir/'18-0.001067-30.8295-0.9167param.pth' 234 | discriminator_weights_file=checkpoint_dir/'netD10-0.925502-0.002677-0.543894-0.469396params.pth' 235 | 236 | 237 | #4.set model& criterion& optimizer 238 | netG=GAN_model.Generator() 239 | netD=GAN_model.Discriminator() 240 | 241 | 242 | generator_criterion = GeneratorLoss() 243 | 244 | optimizerG = torch.optim.Adam(netG.parameters(), lr=args.g_lr) 245 | optimizerD = torch.optim.Adam(netD.parameters(), lr=args.d_lr) 246 | 247 | 248 | ''' 249 | criterion = nn.MSELoss() 250 | optimizer=torch.optim.Adam(model.parameters(), lr=args.lr) 251 | ''' 252 | if use_gpu: 253 | netG.cuda() 254 | netD.cuda() 255 | generator_criterion.cuda() 256 | 257 | #load parameters 258 | if not use_gpu: 259 | netG.load_state_dict(torch.load(str(generator_weights_file), map_location=lambda storage, loc: storage)) 260 | netD.load_state_dict(torch.load(str(discriminator_weights_file), map_location=lambda storage, loc: storage)) 261 | #model=torch.load(str(model_weights_file), map_location=lambda storage, loc: storage) 262 | else: 263 | netG.load_state_dict(torch.load(str(generator_weights_file))) 264 | netD.load_state_dict(torch.load(str(discriminator_weights_file))) 265 | #model=torch.load(str(model_weights_file)) 266 | 267 | 268 | 269 | 270 | #3.set dataset and dataloader 271 | dataset=MyDataset(data_file=save_dir/'TrainData32.h5') 272 | test_dataset=MyDataset(data_file=save_dir/'TestData32.h5') 273 | 274 | dataloader=DataLoader(dataset,batch_size=args.batchsize,shuffle=True,num_workers=0) 275 | test_dataloader=DataLoader(test_dataset,batch_size=args.testbatchsize,shuffle=False,num_workers=0) 276 | 277 | 278 | 279 | 280 | #show mdoel¶meters&dataset 281 | print('NetG Structure:',netG) 282 | print('# generator parameters:', sum(param.numel() for param in netG.parameters())) 283 | params = list(netG.parameters()) 284 | for i in range(len(params)): 285 | print('layer:',i+1,params[i].size()) 286 | 287 | print('NetD Structure:',netD) 288 | print('# generator parameters:', sum(param.numel() for param in netD.parameters())) 289 | params = list(netD.parameters()) 290 | for i in range(len(params)): 291 | print('layer:',i+1,params[i].size()) 292 | 293 | print('length of dataset:',len(dataset)) 294 | 295 | 296 | if __name__=='__main__': 297 | main() 298 | 299 | 300 | netG = Generator() -------------------------------------------------------------------------------- /Multi-Scale/MS_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | import torch 4 | import torch.nn.functional as f 5 | 6 | 7 | 8 | ## 1)all relu =>> prelu 9 | 10 | class Residual_Block(nn.Module): 11 | def __init__(self, Channel): 12 | super(Residual_Block, self).__init__() 13 | self.bn1 = nn.BatchNorm2d(Channel) 14 | self.bn2 = nn.BatchNorm2d(Channel) 15 | self.conv1 = nn.Conv2d(Channel, Channel, 3, 1, 1) 16 | self.conv2 = nn.Conv2d(Channel, Channel, 3, 1, 1) 17 | self.prelu = nn.PReLU() 18 | def forward(self, x): 19 | x1 = self.conv1(x) 20 | x2 = self.prelu(self.bn1(x1)) 21 | x3 = self.conv2(x2) 22 | x4 = self.bn2(x3) 23 | x5 = torch.add(x, x4) 24 | return x5 25 | 26 | 27 | class IntraDeblocking(nn.Module): 28 | def __init__(self): 29 | super(IntraDeblocking,self).__init__() 30 | self.conv1 = nn.Conv2d(3,64,7,1,3) 31 | self.prelu1 = nn.PReLU() 32 | self.downsample1 = nn.Conv2d(64,64,4,2,1) 33 | self.downsample2 = nn.Conv2d(64,64, 4, 2, 1) 34 | self.block0 = Residual_Block(64) 35 | self.block1 = Residual_Block(64) 36 | self.block2 = Residual_Block(64) 37 | self.block3 = Residual_Block(64) 38 | self.conv2 = nn.Conv2d(64,3,3,1,1) 39 | self.prelu2 = nn.PReLU() 40 | 41 | self.block4 = Residual_Block(64) 42 | self.block5 = Residual_Block(64) 43 | self.block6 = Residual_Block(64) 44 | self.block7 = Residual_Block(64) 45 | self.conv3 = nn.Conv2d(64,3,5,1,2) 46 | self.prelu3 = nn.PReLU() 47 | 48 | self.block8 = Residual_Block(64) 49 | self.block9 = Residual_Block(64) 50 | self.block10 = Residual_Block(64) 51 | self.block11 = Residual_Block(64) 52 | self.conv4 = nn.Conv2d(64,3,7,1,3) 53 | self.prelu4 = nn.PReLU() 54 | 55 | self.conv5 = nn.Conv2d(64,64,3,1,1) 56 | self.up1 = nn.PixelShuffle(2) 57 | self.prelu5 =nn.PReLU() 58 | 59 | self.conv6 = nn.Conv2d(64, 64, 3, 1, 1) 60 | self.up2 = nn.PixelShuffle(2) 61 | self.prelu6 = nn.PReLU() 62 | 63 | self.conv7 = nn.Conv2d(80,64,3,1,1) 64 | self.prelu7 = nn.PReLU() 65 | self.conv8 = nn.Conv2d(80, 64, 3, 1, 1) 66 | self.prelu8 = nn.PReLU() 67 | self.conv9 = nn.Conv2d(3, 64, 5, 1, 2) 68 | self.prelu9 = nn.PReLU() 69 | self.conv10 = nn.Conv2d(3, 64, 3, 1, 1) 70 | self.prelu10 = nn.PReLU() 71 | def forward(self,x,x1,x2): 72 | x0 = self.prelu1(self.conv1(x)) 73 | ## decompose the feature into 3 scales 74 | x1 = self.prelu9(self.conv9(x1)) 75 | x2 = self.prelu10(self.conv10(x2)) 76 | ## the third scale 77 | x = self.block0(x2) 78 | x = self.block1(x) 79 | x = self.block2(x) 80 | xt1 = self.block3(x) 81 | x = self.conv5(xt1) 82 | x_cat1 = self.prelu5(self.up1(x)) 83 | 84 | output3 = self.prelu2(self.conv2(xt1)) 85 | ## the second scale 86 | x = torch.cat((x_cat1,x1),dim=1) 87 | x = self.prelu7(self.conv7(x)) 88 | x = self.block4(x) 89 | x = self.block5(x) 90 | x = self.block6(x) 91 | xt2 = self.block7(x) 92 | output2 = self.prelu3(self.conv3(xt2)) 93 | x_cat2 = self.prelu6(self.up2(self.conv6(xt2))) 94 | ## the first scale 95 | x = torch.cat((x_cat2,x0), dim=1) 96 | x = self.prelu8(self.conv8(x)) 97 | x = self.block8(x) 98 | x = self.block9(x) 99 | x = self.block10(x) 100 | x = self.block11(x) 101 | output1 = self.prelu4(self.conv4(x)) 102 | 103 | 104 | return output1,output2,output3 105 | if __name__== '__main__': 106 | x = Variable(torch.rand(8,3,64,64)).cuda() 107 | deblock = IntraDeblocking().cuda() 108 | y1,y2,y3 = deblock(x) 109 | print y1,y2,y3 -------------------------------------------------------------------------------- /Multi-Scale/test_MS.py: -------------------------------------------------------------------------------- 1 | import os 2 | #from PIL import Image 3 | import cv2 4 | import numpy 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.autograd import Variable 9 | import MS_model 10 | import myutils 11 | import msssim 12 | 13 | 14 | class ImageDataset(Dataset): 15 | def __init__(self, root_dir, transform=None): 16 | 17 | self.input_dir = os.path.join(root_dir,'output_qp_42_no_sao','320x480') 18 | self.label_dir = os.path.join(root_dir,'images_png','320x480') 19 | self.transform = transform 20 | 21 | def __len__(self): 22 | return os.listdir(self.input_dir).__len__() 23 | 24 | def __getitem__(self, idx): 25 | input_names = sorted(os.listdir(self.input_dir)) 26 | label_names = sorted(os.listdir(self.label_dir)) 27 | 28 | input_name = os.path.join(self.input_dir,input_names[idx]) 29 | #input_image =Image.open(input_name)#Image get jpg 30 | input_image=cv2.imread(input_name) 31 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) 32 | 33 | label_name = os.path.join(self.label_dir,label_names[idx]) 34 | #label_image =Image.open(label_name) 35 | label_image=cv2.imread(label_name) 36 | label_image = cv2.cvtColor(label_image, cv2.COLOR_BGR2RGB) 37 | 38 | sample = {'input_image': input_image, 'label_image': label_image, 'name': input_names[idx]} 39 | 40 | if self.transform: 41 | sample = self.transform(sample) 42 | 43 | return sample 44 | 45 | def edge_clip(image): 46 | if image.shape[0]%2==1: 47 | image=image[:-1,:,:] 48 | if image.shape[1]%2==1: 49 | image=image[:,:-1,:] 50 | return image 51 | 52 | 53 | class mytransform(object): 54 | def __call__(self, sample): 55 | input_image, label_image,name= sample['input_image'], sample['label_image'],sample['name'] 56 | # swap color axis because 57 | # numpy image: H x W x C 58 | # torch image: C X H X W 59 | #input_image = numpy.asarray(input_image).transpose(2, 0, 1)/255.0 60 | #label_image = numpy.asarray(label_image).transpose(2, 0, 1)/255.0 61 | input_image=edge_clip(input_image) 62 | label_image=edge_clip(label_image) 63 | 64 | input_image = input_image.transpose(2, 0, 1)/255.0 65 | label_image = label_image.transpose(2, 0, 1)/255.0 66 | 67 | 68 | return {'input_image': torch.from_numpy(input_image).float(), 69 | 'label_image': torch.from_numpy(label_image).float(), 70 | 'name':name} 71 | 72 | def wrap_variable(input_batch, label_batch, use_gpu,flag): 73 | if use_gpu: 74 | input_batch, label_batch = (Variable(input_batch.cuda(),volatile=flag), Variable(label_batch.cuda(),volatile=flag)) 75 | 76 | else: 77 | input_batch, label_batch = (Variable(input_batch,volatile=flag),Variable(label_batch,volatile=flag)) 78 | return input_batch, label_batch 79 | 80 | 81 | def checkpoint(name,psnr1,psnr2,ssim1,ssim2,msssim1,msssim2): 82 | print('{},psnr:{:.4f}->{:.4f},ssim:{:.4f}->{:.4f},msssim:{:.4f}-{:.4f}'.format(name,psnr1,psnr2,ssim1,ssim2,msssim1,msssim2)) 83 | #write to text 84 | output = open(os.path.join(Image_folder,'test_result.txt'),'a+') 85 | output.write(('{} {:.4f}->{:.4f} {:.4f}->{:.4f},{:.4f}->{:.4f}'.format(name,psnr1,psnr2,ssim1,ssim2,msssim1,msssim2))+'\r\n') 86 | output.close() 87 | def save(output_image,name): 88 | output_data=output_image.data[0] 89 | if use_gpu: 90 | img=255.0*output_data.clone().cpu().numpy() 91 | else: 92 | img=255.0*output_data.clone().numpy() 93 | img = img.transpose(1, 2, 0).astype("uint8") 94 | #img = Image.fromarray(img) 95 | #img.save(os.path.join(Image_folder,'output','{}.jpg'.format(name[:-4]))) 96 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 97 | cv2.imwrite(os.path.join(Image_folder,'output','{}.png'.format(name[:-4])),img) 98 | 99 | downsampling=nn.AvgPool2d(2) 100 | 101 | def test(): 102 | model.eval() 103 | avg_psnr1 = 0 104 | avg_ssim1 = 0 105 | avg_msssim1 = 0 106 | #output and label 107 | avg_psnr2 = 0 108 | avg_ssim2 = 0 109 | avg_msssim2 = 0 110 | 111 | 112 | for i, sample in enumerate(dataloader): 113 | input_image,label_image,name=sample['input_image'],sample['label_image'],sample['name'][0]#tuple to str 114 | 115 | #Wrap with torch Variable 116 | input_image,label_image=wrap_variable(input_image,label_image, use_gpu,True) 117 | inputs_1=downsampling(input_image) 118 | label_1=downsampling(label_image) 119 | inputs_2=downsampling(inputs_1) 120 | label_2=downsampling(label_1) 121 | 122 | output_image, outputs_1, outputs_2 = model(input_image,inputs_1,inputs_2) 123 | 124 | 125 | #predict 126 | #output_image = model(input_image) 127 | #clamp in[0,1] 128 | output_image=output_image.clamp(0.0, 1.0) 129 | 130 | 131 | #calculate psnr 132 | psnr1 =myutils.psnr(input_image, label_image) 133 | psnr2 =myutils.psnr(output_image, label_image) 134 | # ssim is calculated with the normalize (range [0, 1]) image 135 | ssim1 = torch.sum((myutils.ssim(input_image, label_image, size_average=False)).data)/1.0#batch_size 136 | ssim2 = torch.sum((myutils.ssim(output_image, label_image, size_average=False)).data)/1.0 137 | 138 | #msssim 139 | msssim1 = numpy.sum((msssim.MultiScaleSSIM(numpy.expand_dims(input_image.data[0].clone().cpu().numpy(),axis=0), numpy.expand_dims(label_image.data[0].clone().cpu().numpy(),axis=0), max_val=1.0)))/1.0#batch_size 140 | msssim2 = numpy.sum((msssim.MultiScaleSSIM(numpy.expand_dims(output_image.data[0].clone().cpu().numpy(),axis=0), numpy.expand_dims(label_image.data[0].clone().cpu().numpy(),axis=0), max_val=1.0)))/1.0 141 | 142 | avg_ssim1 += ssim1 143 | avg_psnr1 += psnr1 144 | avg_ssim2 += ssim2 145 | avg_psnr2 += psnr2 146 | avg_msssim1 += msssim1 147 | avg_msssim2 += msssim2 148 | 149 | #save output and record 150 | checkpoint(name,psnr1,psnr2,ssim1,ssim2,msssim1,msssim2) 151 | save(output_image,name) 152 | 153 | #print and save 154 | avg_psnr1 = avg_psnr1/len(dataloader) 155 | avg_ssim1 = avg_ssim1/len(dataloader) 156 | avg_psnr2 = avg_psnr2/len(dataloader) 157 | avg_ssim2 = avg_ssim2/len(dataloader) 158 | avg_msssim1 = avg_msssim1/len(dataloader) 159 | avg_msssim2 = avg_msssim2/len(dataloader) 160 | 161 | print('Avg. PSNR: {:.4f}->{:.4f} Avg. SSIM: {:.4f}->{:.4f} Avg. MSSSIM: {:.4f}->{:.4f}'.format(avg_psnr1,avg_psnr2,avg_ssim1,avg_ssim2,avg_msssim1,avg_msssim2)) 162 | output = open(os.path.join(Image_folder,'test_result.txt'),'a+') 163 | output.write('Avg. PSNR: {:.4f}->{:.4f} Avg. SSIM: {:.4f}->{:.4f}Avg. MSSSIM: {:.4f}->{:.4f}'.format(avg_psnr1,avg_psnr2,avg_ssim1,avg_ssim2,avg_msssim1,avg_msssim2)+'\r\n') 164 | output.close() 165 | 166 | 167 | #------------------------------------------------------------------ 168 | #cuda 169 | use_gpu=torch.cuda.is_available() 170 | 171 | #set path 172 | root_dir=os.getcwd() 173 | Image_folder=os.path.join(root_dir) 174 | model_weights_file=os.path.join(root_dir,'parameters_nosao','MSqp42-124-0.001293-29.0237-0.8455param.pth') 175 | 176 | 177 | #set model 178 | model=MS_model.IntraDeblocking() 179 | #vgg=mymodel_new.Vgg16(requires_grad=False) 180 | if use_gpu: 181 | model = model.cuda() 182 | #vgg = vgg.cuda() 183 | 184 | model.load_state_dict(torch.load(model_weights_file)) 185 | #model=torch.load(model_weights_file) 186 | 187 | #set dataset and dataloader 188 | mydataset = ImageDataset(root_dir=Image_folder, transform=mytransform()) 189 | dataloader = DataLoader(mydataset, batch_size=1,shuffle=False, num_workers=0) 190 | 191 | 192 | 193 | def main(): 194 | test() 195 | 196 | if __name__=='__main__': 197 | main() 198 | -------------------------------------------------------------------------------- /Multi-Scale/test_MS_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | #from PIL import Image 3 | import cv2 4 | import numpy 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.autograd import Variable 9 | import MS_model 10 | import myutils 11 | import msssim 12 | 13 | 14 | class ImageDataset(Dataset): 15 | def __init__(self, root_dir, transform=None): 16 | 17 | self.input_dir = os.path.join(root_dir,'output_qp_42_no_sao','320x480') 18 | self.label_dir = os.path.join(root_dir,'images_png','320x480') 19 | self.transform = transform 20 | 21 | def __len__(self): 22 | return os.listdir(self.input_dir).__len__() 23 | 24 | def __getitem__(self, idx): 25 | input_names = sorted(os.listdir(self.input_dir)) 26 | label_names = sorted(os.listdir(self.label_dir)) 27 | 28 | input_name = os.path.join(self.input_dir,input_names[idx]) 29 | #input_image =Image.open(input_name)#Image get jpg 30 | input_image=cv2.imread(input_name) 31 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) 32 | 33 | label_name = os.path.join(self.label_dir,label_names[idx]) 34 | #label_image =Image.open(label_name) 35 | label_image=cv2.imread(label_name) 36 | label_image = cv2.cvtColor(label_image, cv2.COLOR_BGR2RGB) 37 | 38 | sample = {'input_image': input_image, 'label_image': label_image, 'name': input_names[idx]} 39 | 40 | if self.transform: 41 | sample = self.transform(sample) 42 | 43 | return sample 44 | 45 | def edge_clip(image): 46 | if image.shape[0]%2==1: 47 | image=image[:-1,:,:] 48 | if image.shape[1]%2==1: 49 | image=image[:,:-1,:] 50 | return image 51 | 52 | 53 | class mytransform(object): 54 | def __call__(self, sample): 55 | input_image, label_image,name= sample['input_image'], sample['label_image'],sample['name'] 56 | # swap color axis because 57 | # numpy image: H x W x C 58 | # torch image: C X H X W 59 | #input_image = numpy.asarray(input_image).transpose(2, 0, 1)/255.0 60 | #label_image = numpy.asarray(label_image).transpose(2, 0, 1)/255.0 61 | input_image=edge_clip(input_image) 62 | label_image=edge_clip(label_image) 63 | 64 | input_image = input_image.transpose(2, 0, 1)/255.0 65 | label_image = label_image.transpose(2, 0, 1)/255.0 66 | 67 | 68 | return {'input_image': torch.from_numpy(input_image).float(), 69 | 'label_image': torch.from_numpy(label_image).float(), 70 | 'name':name} 71 | 72 | def wrap_variable(input_batch, label_batch, use_gpu,flag): 73 | if use_gpu: 74 | input_batch, label_batch = (Variable(input_batch.cuda(),volatile=flag), Variable(label_batch.cuda(),volatile=flag)) 75 | 76 | else: 77 | input_batch, label_batch = (Variable(input_batch,volatile=flag),Variable(label_batch,volatile=flag)) 78 | return input_batch, label_batch 79 | 80 | 81 | def checkpoint(name,psnr1,psnr2,ssim1,ssim2,msssim1,msssim2): 82 | print('{},psnr:{:.4f}->{:.4f},ssim:{:.4f}->{:.4f},msssim:{:.4f}-{:.4f}'.format(name,psnr1,psnr2,ssim1,ssim2,msssim1,msssim2)) 83 | #write to text 84 | output = open(os.path.join(Image_folder,'test_result.txt'),'a+') 85 | output.write(('{} {:.4f}->{:.4f} {:.4f}->{:.4f},{:.4f}->{:.4f}'.format(name,psnr1,psnr2,ssim1,ssim2,msssim1,msssim2))+'\r\n') 86 | output.close() 87 | def save(output_image,name): 88 | output_data=output_image.data[0] 89 | if use_gpu: 90 | img=255.0*output_data.clone().cpu().numpy() 91 | else: 92 | img=255.0*output_data.clone().numpy() 93 | img = img.transpose(1, 2, 0).astype("uint8") 94 | #img = Image.fromarray(img) 95 | #img.save(os.path.join(Image_folder,'output','{}.jpg'.format(name[:-4]))) 96 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 97 | cv2.imwrite(os.path.join(Image_folder,'output','{}.png'.format(name[:-4])),img) 98 | 99 | downsampling=nn.AvgPool2d(2) 100 | 101 | def MS(ms_model,inputs): 102 | inputs_1=downsampling(inputs) 103 | inputs_2=downsampling(inputs_1) 104 | outputs, outputs_1, outputs_2 = ms_model(inputs,inputs_1,inputs_2) 105 | return outputs 106 | 107 | def test(): 108 | model.eval() 109 | avg_psnr1 = 0 110 | avg_ssim1 = 0 111 | avg_msssim1 = 0 112 | #output and label 113 | avg_psnr2 = 0 114 | avg_ssim2 = 0 115 | avg_msssim2 = 0 116 | 117 | 118 | for i, sample in enumerate(dataloader): 119 | input_image,label_image,name=sample['input_image'],sample['label_image'],sample['name'][0]#tuple to str 120 | 121 | 122 | 123 | #################################### 124 | #Wrap with torch Variable 125 | input_image,label_image=wrap_variable(input_image,label_image, use_gpu,True) 126 | 127 | ############################### 128 | 129 | #some parameters 130 | patch_size=120 131 | #print(input_image.size()) 132 | [batch_number,channel,height,width]=input_image.size() 133 | row_numbers=int(height/patch_size) 134 | col_numbers=int(width/patch_size) 135 | #1 create the output tensor 136 | output_image=Variable(torch.zeros(batch_number,channel,height,width).cuda(),volatile=True) 137 | 138 | for row_number in range(row_numbers): 139 | for col_number in range(col_numbers): 140 | row_start=patch_size*row_number 141 | row_end =patch_size*(row_number+1) 142 | col_start=patch_size*col_number 143 | col_end =patch_size*(col_number+1) 144 | 145 | input_patch=input_image[:,:,row_start:row_end,col_start:col_end] 146 | 147 | #estimate the output 148 | output_patch=MS(model,input_patch) 149 | 150 | output_image[:,:,row_start:row_end,col_start:col_end]=output_patch 151 | 152 | for row_number in range(row_numbers): 153 | row_start=patch_size*row_number 154 | row_end =patch_size*(row_number+1) 155 | col_start=width - patch_size 156 | col_end =width 157 | input_patch=input_image[:,:,row_start:row_end,col_start:col_end] 158 | 159 | #estimate the output 160 | output_patch=MS(model,input_patch) 161 | output_image[:,:,row_start:row_end,col_start:col_end]=output_patch 162 | 163 | for col_number in range(col_numbers): 164 | row_start=height- patch_size 165 | row_end =height 166 | col_start=patch_size*col_number 167 | col_end =patch_size*(col_number+1) 168 | input_patch=input_image[:,:,row_start:row_end,col_start:col_end] 169 | 170 | #estimate the output 171 | output_patch=MS(model,input_patch) 172 | output_image[:,:,row_start:row_end,col_start:col_end]=output_patch 173 | 174 | #last one 175 | row_start=height- patch_size 176 | row_end =height 177 | col_start=width- patch_size 178 | col_end =width 179 | input_patch=input_image[:,:,row_start:row_end,col_start:col_end] 180 | #estimate the output 181 | output_patch=MS(model,input_patch) 182 | output_image[:,:,row_start:row_end,col_start:col_end]=output_patch 183 | ####################################### 184 | 185 | ################################## 186 | 187 | #predict 188 | #output_image = model(input_image) 189 | #clamp in[0,1] 190 | output_image=output_image.clamp(0.0, 1.0) 191 | 192 | 193 | #calculate psnr 194 | psnr1 =myutils.psnr(input_image, label_image) 195 | psnr2 =myutils.psnr(output_image, label_image) 196 | # ssim is calculated with the normalize (range [0, 1]) image 197 | ssim1 = torch.sum((myutils.ssim(input_image, label_image, size_average=False)).data)/1.0#batch_size 198 | ssim2 = torch.sum((myutils.ssim(output_image, label_image, size_average=False)).data)/1.0 199 | 200 | #msssim 201 | msssim1 = numpy.sum((msssim.MultiScaleSSIM(numpy.expand_dims(input_image.data[0].clone().cpu().numpy(),axis=0), numpy.expand_dims(label_image.data[0].clone().cpu().numpy(),axis=0), max_val=1.0)))/1.0#batch_size 202 | msssim2 = numpy.sum((msssim.MultiScaleSSIM(numpy.expand_dims(output_image.data[0].clone().cpu().numpy(),axis=0), numpy.expand_dims(label_image.data[0].clone().cpu().numpy(),axis=0), max_val=1.0)))/1.0 203 | 204 | avg_ssim1 += ssim1 205 | avg_psnr1 += psnr1 206 | avg_ssim2 += ssim2 207 | avg_psnr2 += psnr2 208 | avg_msssim1 += msssim1 209 | avg_msssim2 += msssim2 210 | 211 | #save output and record 212 | checkpoint(name,psnr1,psnr2,ssim1,ssim2,msssim1,msssim2) 213 | save(output_image,name) 214 | 215 | #print and save 216 | avg_psnr1 = avg_psnr1/len(dataloader) 217 | avg_ssim1 = avg_ssim1/len(dataloader) 218 | avg_psnr2 = avg_psnr2/len(dataloader) 219 | avg_ssim2 = avg_ssim2/len(dataloader) 220 | avg_msssim1 = avg_msssim1/len(dataloader) 221 | avg_msssim2 = avg_msssim2/len(dataloader) 222 | 223 | print('Avg. PSNR: {:.4f}->{:.4f} Avg. SSIM: {:.4f}->{:.4f} Avg. MSSSIM: {:.4f}->{:.4f}'.format(avg_psnr1,avg_psnr2,avg_ssim1,avg_ssim2,avg_msssim1,avg_msssim2)) 224 | output = open(os.path.join(Image_folder,'test_result.txt'),'a+') 225 | output.write('Avg. PSNR: {:.4f}->{:.4f} Avg. SSIM: {:.4f}->{:.4f}Avg. MSSSIM: {:.4f}->{:.4f}'.format(avg_psnr1,avg_psnr2,avg_ssim1,avg_ssim2,avg_msssim1,avg_msssim2)+'\r\n') 226 | output.close() 227 | 228 | 229 | #------------------------------------------------------------------ 230 | #cuda 231 | use_gpu=torch.cuda.is_available() 232 | 233 | #set path 234 | root_dir=os.getcwd() 235 | Image_folder=os.path.join(root_dir) 236 | model_weights_file=os.path.join(root_dir,'parameters_nosao','MS_47-0.001301-28.9525-0.8454param.pth') 237 | 238 | 239 | #set model 240 | model=MS_model.IntraDeblocking() 241 | #vgg=mymodel_new.Vgg16(requires_grad=False) 242 | if use_gpu: 243 | model = model.cuda() 244 | #vgg = vgg.cuda() 245 | 246 | model.load_state_dict(torch.load(model_weights_file)) 247 | 248 | #set dataset and dataloader 249 | mydataset = ImageDataset(root_dir=Image_folder, transform=mytransform()) 250 | dataloader = DataLoader(mydataset, batch_size=1,shuffle=False, num_workers=0) 251 | 252 | 253 | 254 | def main(): 255 | test() 256 | 257 | if __name__=='__main__': 258 | main() 259 | -------------------------------------------------------------------------------- /Multi-Scale/train_multi_scale.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import h5py 3 | import numpy 4 | from pathlib import Path 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import Dataset,DataLoader 8 | from torch.autograd import Variable 9 | 10 | import MS_model 11 | import myutils 12 | 13 | #prepare data 14 | class MyDataset(Dataset): 15 | def __init__(self,data_file): 16 | self.file=h5py.File(str(data_file),'r') 17 | self.inputs=self.file['data'][:].astype(numpy.float32)/255.0#simple normalization in[0,1] 18 | self.label=self.file['label'][:].astype(numpy.float32)/255.0 19 | 20 | def __len__(self): 21 | return self.inputs.shape[0] 22 | 23 | def __getitem__(self,idx): 24 | inputs=self.inputs[idx,:,:,:].transpose(2,0,1) 25 | label=self.label[idx,:,:,:].transpose(2,0,1) 26 | inputs=torch.Tensor(inputs) 27 | label=torch.Tensor(label) 28 | sample={'inputs':inputs,'label':label} 29 | return sample 30 | 31 | 32 | #def checkpoint(epoch,loss,psnr,ssim,mse): 33 | def checkpoint(epoch,loss,loss_0,loss_1,loss_2,psnr,ssim,mse,mse_0,mse_1,mse_2): 34 | model.eval() 35 | model_path1 = str(checkpoint_dir/'{}-{:.6f}-{:.4f}-{:.4f}.pth'.format(epoch,loss,psnr,ssim)) 36 | torch.save(model,model_path1) 37 | 38 | if use_gpu: 39 | model.cpu()#you should save weights on cpu not on gpu 40 | 41 | #save weights 42 | model_path = str(checkpoint_dir/'{}-{:.6f}-{:.4f}-{:.4f}param.pth'.format(epoch,loss,psnr,ssim)) 43 | 44 | torch.save(model.state_dict(),model_path) 45 | 46 | #print and save record 47 | print('Epoch {} : Avg.loss:{:.6f}'.format(epoch,loss)) 48 | print("Test Avg. PSNR: {:.4f} Avg. SSIM: {:.4f} Avg.MSE{:.6f} ".format(psnr,ssim,mse)) 49 | print("Checkpoint saved to {}".format(model_path)) 50 | 51 | output = open(str(checkpoint_dir/'train_result.txt'),'a+') 52 | output.write(('{} {:.4f} {:.4f} {:.4f}'.format(epoch,loss,psnr,ssim))+'\r\n') 53 | output.write(('{} {:.6f} {:.6f} {:.6f} {:.6f} {:.4f} {:.4f} {:.6f} {:.6f} {:.6f} {:.6f}'.format(epoch,loss,loss_0,loss_1,loss_2,psnr,ssim,mse,mse_0,mse_1,mse_2))+'\r\n') 54 | output.close() 55 | 56 | if use_gpu: 57 | model.cuda()#don't forget return to gpu 58 | #model.train() 59 | 60 | 61 | def wrap_variable(input_batch, label_batch, use_gpu,flag): 62 | if use_gpu: 63 | input_batch, label_batch = (Variable(input_batch.cuda(),volatile=flag), Variable(label_batch.cuda(),volatile=flag)) 64 | 65 | else: 66 | input_batch, label_batch = (Variable(input_batch,volatile=flag),Variable(label_batch,volatile=flag)) 67 | return input_batch, label_batch 68 | 69 | downsampling=nn.AvgPool2d(2) 70 | 71 | def train(epoch): 72 | model.train() 73 | sum_loss=0.0 74 | sum_loss_0=0.0 75 | sum_loss_1=0.0 76 | sum_loss_2=0.0 77 | 78 | 79 | for iteration, sample in enumerate(dataloader):#difference between (dataloader) &(dataloader,1) 80 | inputs,label=sample['inputs'],sample['label'] 81 | 82 | 83 | #Wrap with torch Variable 84 | inputs,label=wrap_variable(inputs, label, use_gpu, False) 85 | inputs_1=downsampling(inputs) 86 | label_1=downsampling(label) 87 | inputs_2=downsampling(inputs_1) 88 | label_2=downsampling(label_1) 89 | #clear the optimizer 90 | optimizer.zero_grad() 91 | 92 | # forward propagation 93 | outputs, outputs_1, outputs_2 = model(inputs,inputs_1,inputs_2) 94 | 95 | #get the loss for backward 96 | loss_0 =criterion(outputs, label) 97 | loss_1 =criterion(outputs_1, label_1) 98 | loss_2 =criterion(outputs_2, label_2) 99 | 100 | #backward propagation and optimize 101 | loss=loss_0*1+loss_1*0.5+loss_2*0.25 102 | 103 | loss.backward() 104 | optimizer.step() 105 | 106 | if iteration%100==0: 107 | print("===> Epoch[{}]({}/{}):loss: {:.6f}-{:.6f}-{:.6f}-{:.6f}".format(epoch, iteration, len(dataloader), loss.data[0],loss_0.data[0],loss_1.data[0],loss_2.data[0])) 108 | #if iteration==101: 109 | # break 110 | 111 | #caculate the average loss 112 | sum_loss_0 += loss_0.data[0] 113 | sum_loss_1 += loss_1.data[0] 114 | sum_loss_2 += loss_2.data[0] 115 | sum_loss += loss.data[0] 116 | 117 | return sum_loss/len(dataloader),sum_loss_0/len(dataloader),sum_loss_1/len(dataloader),sum_loss_2/len(dataloader) 118 | 119 | 120 | def test(): 121 | model.eval() 122 | avg_psnr = 0 123 | avg_ssim = 0 124 | avg_mse = 0 125 | avg_mse_0 = 0 126 | avg_mse_1 = 0 127 | avg_mse_2 = 0 128 | 129 | for iteration, sample in enumerate(test_dataloader): 130 | inputs,label=sample['inputs'],sample['label'] 131 | #Wrap with torch Variable 132 | inputs,label=wrap_variable(inputs, label, use_gpu, True) 133 | 134 | inputs_1=downsampling(inputs) 135 | label_1=downsampling(label) 136 | inputs_2=downsampling(inputs_1) 137 | label_2=downsampling(label_1) 138 | 139 | outputs, outputs_1, outputs_2 = model(inputs,inputs_1,inputs_2) 140 | 141 | mse_0 = criterion(outputs,label).data[0] 142 | mse_1 = criterion(outputs_1,label_1).data[0] 143 | mse_2 = criterion(outputs_2,label_2).data[0] 144 | psnr = myutils.psnr(outputs, label) 145 | ssim = torch.sum((myutils.ssim(outputs, label, size_average=False)).data)/args.testbatchsize 146 | avg_ssim += ssim 147 | avg_psnr += psnr 148 | avg_mse += mse_0+mse_1*0.5+mse_2*0.25 149 | avg_mse_0 += mse_0 150 | avg_mse_1 += mse_1 151 | avg_mse_2 += mse_2 152 | 153 | return (avg_psnr / len(test_dataloader)),(avg_ssim / len(test_dataloader)),(avg_mse/len(test_dataloader)),(avg_mse_0/len(test_dataloader)),(avg_mse_1/len(test_dataloader)),(avg_mse_2/len(test_dataloader)) 154 | 155 | 156 | 157 | def main(): 158 | #train & test & record 159 | for epoch in range(args.epochs): 160 | loss,loss_0,loss_1,loss_2=train(epoch+200) 161 | psnr,ssim,mse,mse_0,mse_1,mse_2 = test() 162 | checkpoint(epoch+200,loss,loss_0,loss_1,loss_2,psnr,ssim,mse,mse_0,mse_1,mse_2) 163 | 164 | 165 | 166 | #--------------------------------------------------------------------------------------------------- 167 | # Training settings 168 | parser = argparse.ArgumentParser(description='ARCNN') 169 | parser.add_argument('--batchsize', type=int, default=64, help='training batch size') 170 | parser.add_argument('--testbatchsize', type=int, default=16, help='testing batch size') 171 | parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train for') 172 | parser.add_argument('--lr', type=float, default=0.0001, help='Learning Rate. Default=0.0001') 173 | args = parser.parse_args() 174 | 175 | print(args) 176 | 177 | #---------------------------------------------------------------------------------------------------- 178 | #set other parameters 179 | #1.set cuda 180 | use_gpu=torch.cuda.is_available() 181 | 182 | 183 | #2.set path and file 184 | save_dir = Path('.') 185 | checkpoint_dir = Path('.') / 'Checkpoints_multi_scale'#save model parameters and train record 186 | if checkpoint_dir.exists(): 187 | print 'folder esxited' 188 | else: 189 | checkpoint_dir.mkdir() 190 | 191 | model_weights_file=checkpoint_dir/'99-0.000451-34.2162-0.9449param.pth' 192 | 193 | 194 | #3.set dataset and dataloader 195 | dataset=MyDataset(data_file=save_dir/'TrainData32.h5') 196 | test_dataset=MyDataset(data_file=save_dir/'TestData32.h5') 197 | 198 | dataloader=DataLoader(dataset,batch_size=args.batchsize,shuffle=True,num_workers=0) 199 | test_dataloader=DataLoader(test_dataset,batch_size=args.testbatchsize,shuffle=False,num_workers=0) 200 | 201 | 202 | #4.set model& criterion& optimizer 203 | model=MS_model.IntraDeblocking() 204 | 205 | criterion = nn.MSELoss() 206 | optimizer=torch.optim.Adam(model.parameters(), lr=args.lr) 207 | 208 | if use_gpu: 209 | model = model.cuda() 210 | criterion = criterion.cuda() 211 | 212 | #load parameters 213 | if not use_gpu: 214 | model.load_state_dict(torch.load(str(model_weights_file), map_location=lambda storage, loc: storage)) 215 | #model=torch.load(str(model_weights_file), map_location=lambda storage, loc: storage) 216 | else: 217 | model.load_state_dict(torch.load(str(model_weights_file))) 218 | #model=torch.load(str(model_weights_file)) 219 | 220 | 221 | #show mdoel¶meters&dataset 222 | print('Model Structure:',model) 223 | params = list(model.parameters()) 224 | for i in range(len(params)): 225 | print('layer:',i+1,params[i].size()) 226 | 227 | print('length of dataset:',len(dataset)) 228 | 229 | 230 | if __name__=='__main__': 231 | main() 232 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Image deblocking using Convolutional Neural Networks 2 | 3 | ### Experiment Result: 4 | 5 | model | PSNR|SSIM|ranking| 6 | ---|---|---|---| 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /model_parameters/ARCNN_sao-qp42-90-0.001434-28.4783-0.8401param.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yydlmzyz/Image-deblocking-using-deep-learning/879a83f6eb4a6724ed633aa86f0956993a808ad0/model_parameters/ARCNN_sao-qp42-90-0.001434-28.4783-0.8401param.pth -------------------------------------------------------------------------------- /model_parameters/ARDenseNet_qp42-28-0.001368-28.6754-0.8436param.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yydlmzyz/Image-deblocking-using-deep-learning/879a83f6eb4a6724ed633aa86f0956993a808ad0/model_parameters/ARDenseNet_qp42-28-0.001368-28.6754-0.8436param.pth -------------------------------------------------------------------------------- /model_parameters/BEST-MS21-0.001316-29.0139-0.8454param.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yydlmzyz/Image-deblocking-using-deep-learning/879a83f6eb4a6724ed633aa86f0956993a808ad0/model_parameters/BEST-MS21-0.001316-29.0139-0.8454param.pth -------------------------------------------------------------------------------- /model_parameters/L1-164-0.024904-28.6810-0.8440param.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yydlmzyz/Image-deblocking-using-deep-learning/879a83f6eb4a6724ed633aa86f0956993a808ad0/model_parameters/L1-164-0.024904-28.6810-0.8440param.pth -------------------------------------------------------------------------------- /model_parameters/L8_qp42-99-0.001374-28.6387-0.8429param.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yydlmzyz/Image-deblocking-using-deep-learning/879a83f6eb4a6724ed633aa86f0956993a808ad0/model_parameters/L8_qp42-99-0.001374-28.6387-0.8429param.pth -------------------------------------------------------------------------------- /model_parameters/MSqp42-124-0.001293-29.0237-0.8455param.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yydlmzyz/Image-deblocking-using-deep-learning/879a83f6eb4a6724ed633aa86f0956993a808ad0/model_parameters/MSqp42-124-0.001293-29.0237-0.8455param.pth -------------------------------------------------------------------------------- /model_parameters/edar2_qp42-63-0.001339-28.6824-0.8445param.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yydlmzyz/Image-deblocking-using-deep-learning/879a83f6eb4a6724ed633aa86f0956993a808ad0/model_parameters/edar2_qp42-63-0.001339-28.6824-0.8445param.pth -------------------------------------------------------------------------------- /model_parameters/vdar_qp42-51-0.001339-28.6784-0.8443param.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yydlmzyz/Image-deblocking-using-deep-learning/879a83f6eb4a6724ed633aa86f0956993a808ad0/model_parameters/vdar_qp42-51-0.001339-28.6784-0.8443param.pth --------------------------------------------------------------------------------