├── .gitattributes
├── .idea
│   ├── .gitignore
│   ├── .name
│   ├── SRCNN-pytorch-master.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── __pycache__
│   ├── datasets.cpython-36.pyc
│   ├── models.cpython-36.pyc
│   └── utils.cpython-36.pyc
├── datasets.py
├── img
│   ├── 2.tif
│   ├── 2_bicubic_x3.tif
│   ├── 2_srcnn_x3.tif
│   ├── butterfly_GT.bmp
│   ├── butterfly_GT_bicubic_x3.bmp
│   ├── butterfly_GT_srcnn_x3.bmp
│   ├── kenan.jpeg
│   ├── kenan_bicubic_x3.jpeg
│   ├── kenan_srcnn_x3.jpeg
│   ├── zebra.bmp
│   ├── zebra_bicubic_x3.bmp
│   └── zebra_srcnn_x3.bmp
├── models.py
├── outputs
│   └── x3
│       ├── best.pth
│       └── epoch_399.pth
├── prepare.py
├── test.py
├── thumbnails
│   └── fig1.png
├── train.py
└── utils.py
/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
--------------------------------------------------------------------------------
/.idea/.name:
--------------------------------------------------------------------------------
prepare.py
--------------------------------------------------------------------------------
/.idea/SRCNN-pytorch-master.iml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/.idea/SRCNN-pytorch-master.iml
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/.idea/inspectionProfiles/profiles_settings.xml
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/.idea/misc.xml
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/.idea/modules.xml
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/.idea/vcs.xml
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SRCNN_Pytorch
Code walkthrough (CSDN blog, in Chinese): https://blog.csdn.net/weixin_52261094/article/details/128389448


## Requirements
- PyTorch 1.0.0
- Numpy 1.15.4
- Pillow 5.4.1
- h5py 2.8.0
- tqdm 4.30.0


### Step 1: Download the training and test sets
#### You can also build the training and test sets yourself with prepare.py (see the example below the table)

| Dataset | Scale | Type | Link |
|---------|-------|------|------|
| 91-image | 2 | Train | [Download](https://github.com/learner-lu/image-super-resolution/releases/download/v0.0.1/91-image_x2.h5) |
| 91-image | 3 | Train | [Download](https://github.com/learner-lu/image-super-resolution/releases/download/v0.0.1/91-image_x3.h5) |
| 91-image | 4 | Train | [Download](https://github.com/learner-lu/image-super-resolution/releases/download/v0.0.1/91-image_x4.h5) |
| Set5 | 2 | Eval | [Download](https://github.com/learner-lu/image-super-resolution/releases/download/v0.0.1/Set5_x2.h5) |
| Set5 | 3 | Eval | [Download](https://github.com/learner-lu/image-super-resolution/releases/download/v0.0.1/Set5_x3.h5) |
| Set5 | 4 | Eval | [Download](https://github.com/learner-lu/image-super-resolution/releases/download/v0.0.1/Set5_x4.h5) |

Download one 91-image file and one Set5 file at the same scale, then **move them under `./datasets` as `./datasets/91-image_x2.h5` and `./datasets/Set5_x2.h5`**

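If you would rather generate the `.h5` files yourself, `prepare.py` accepts the flags defined in its argument parser (`--images-dir`, `--output-path`, `--patch-size`, `--stride`, `--scale`, `--eval`). A minimal sketch, assuming the raw images sit in hypothetical local folders `T91/` and `Set5/`:

```
# training set: overlapping Y-channel patches packed into one h5 file
python prepare.py --images-dir T91 --output-path datasets/91-image_x3.h5 --scale 3

# evaluation set: whole images, one h5 dataset per image (note --eval)
python prepare.py --images-dir Set5 --output-path datasets/Set5_x3.h5 --scale 3 --eval
```
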
### Step 2: Train the model

    python train.py --train-file "path_to_train_file" \
                    --eval-file "path_to_eval_file" \
                    --outputs-dir "path_to_outputs_file" \
                    --scale 3 \
                    --lr 1e-4 \
                    --batch-size 16 \
                    --num-epochs 400 \
                    --num-workers 0 \
                    --seed 123


### Step 3: Best weights obtained after 400 epochs of training

- [trained by x3](https://pan.baidu.com/s/1sLGsDPuC7BCUaVMDRv013A?pwd=1234)

**Move the weights under `./outputs` as `./outputs/x3/best.pth`**

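To super-resolve a single image with the trained weights, run `test.py`; the flags and defaults (`--weights-file`, `--image-file`, `--scale`) are taken from its argument parser. For example:

```
python test.py --weights-file outputs/x3/best.pth --image-file img/butterfly_GT.bmp --scale 3
```

This prints the Y-channel PSNR and saves `img/butterfly_GT_bicubic_x3.bmp` (the bicubic baseline) and `img/butterfly_GT_srcnn_x3.bmp` (the SRCNN output) next to the input.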
--------------------------------------------------------------------------------
/__pycache__/datasets.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/__pycache__/datasets.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/models.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/__pycache__/models.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
import h5py  # an h5 file is a container for both "datasets" and "groups"
import numpy as np
from torch.utils.data import Dataset

'''Wrap the data in reader classes that torch's DataLoader can consume. DataLoader works
on Dataset objects, so each reader subclasses Dataset and implements the two required
member methods, __getitem__ and __len__.
'''

class TrainDataset(Dataset):  # training set: pairs of lr (low-resolution) and hr (high-resolution) patches from the h5 file
    def __init__(self, h5_file):
        super(TrainDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):  # np.expand_dims adds a channel dimension; pixels are scaled to [0, 1]
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)

    def __len__(self):  # number of samples
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])

# Analogous to TrainDataset
class EvalDataset(Dataset):  # evaluation set: here 'lr' and 'hr' are h5 groups holding one whole image per string index
    def __init__(self, h5_file):
        super(EvalDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])
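
# A minimal usage sketch (assuming a training file datasets/91-image_x3.h5, as in the README):
#
#   from torch.utils.data import DataLoader
#   train_set = TrainDataset('datasets/91-image_x3.h5')
#   loader = DataLoader(train_set, batch_size=16, shuffle=True)
#   lr, hr = next(iter(loader))  # float tensors of shape (16, 1, patch_size, patch_size)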
--------------------------------------------------------------------------------
/img/2.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/img/2.tif
--------------------------------------------------------------------------------
/img/2_bicubic_x3.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/img/2_bicubic_x3.tif
--------------------------------------------------------------------------------
/img/2_srcnn_x3.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/img/2_srcnn_x3.tif
--------------------------------------------------------------------------------
/img/butterfly_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/img/butterfly_GT.bmp
--------------------------------------------------------------------------------
/img/butterfly_GT_bicubic_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/img/butterfly_GT_bicubic_x3.bmp
--------------------------------------------------------------------------------
/img/butterfly_GT_srcnn_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/img/butterfly_GT_srcnn_x3.bmp
--------------------------------------------------------------------------------
/img/kenan.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/img/kenan.jpeg
--------------------------------------------------------------------------------
/img/kenan_bicubic_x3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/img/kenan_bicubic_x3.jpeg
--------------------------------------------------------------------------------
/img/kenan_srcnn_x3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/img/kenan_srcnn_x3.jpeg
--------------------------------------------------------------------------------
/img/zebra.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/img/zebra.bmp
--------------------------------------------------------------------------------
/img/zebra_bicubic_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/img/zebra_bicubic_x3.bmp
--------------------------------------------------------------------------------
/img/zebra_srcnn_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/img/zebra_srcnn_x3.bmp
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
from torch import nn

class SRCNN(nn.Module):  # the 3-layer convolutional SRCNN; nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)  # no activation after the final layer
        return x

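# Quick shape check (a sketch): padding = kernel_size // 2 with stride 1 keeps height and
# width unchanged, so the network maps a bicubically upscaled image to an output of the same size.
#
#   import torch
#   model = SRCNN(num_channels=1)
#   out = model(torch.randn(1, 1, 64, 64))  # (batch, channels, height, width)
#   assert out.shape == (1, 1, 64, 64)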
--------------------------------------------------------------------------------
/outputs/x3/best.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/outputs/x3/best.pth
--------------------------------------------------------------------------------
/outputs/x3/epoch_399.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/outputs/x3/epoch_399.pth
--------------------------------------------------------------------------------
/prepare.py:
--------------------------------------------------------------------------------
import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import convert_rgb_to_y

'''
The training set can be generated by hand. With an upscaling factor `scale`, the original
image dimensions may not be divisible by scale, so the image size has to be re-planned
first. Generating the training set therefore takes three steps:
1. Resize the original image with bicubic interpolation so it is divisible by scale; this
   is the high-resolution data HR.
2. Shrink HR by a factor of scale with bicubic interpolation; this is the raw low-resolution image.
3. Enlarge that image by scale with bicubic interpolation so it matches HR's dimensions;
   this is the low-resolution data LR.
Finally, h5py splits the training data into patches and packs them into one file.
'''
# build the training set
def train(args):

    """
    `def` is the Python keyword for defining a function; train takes args, which carries the
    options passed on the command line. h5py.File() in 'w' mode creates the file at
    args.output_path for writing, overwriting it if it already exists.
    """
    h5_file = h5py.File(args.output_path, 'w')
    # lists holding the low- and high-resolution patches
    lr_patches = []
    hr_patches = []

    for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):
        '''
        This line collects and sorts the files under the given folder:
        1. '{}'.format(): builds the search pattern from args.images_dir
        2. glob.glob(): returns the list of all file paths matching that pattern
        3. sorted(): orders the list (ascending by default)
        4. for x in ...: iterates over the result
        '''
        # load the photo as RGB
        hr = pil_image.open(image_path).convert('RGB')
        '''
        1. Image.open(): PIL function that loads the image at image_path
        2. Image.convert(): PIL function that converts the image mode
        '''
        # round down to multiples of scale: width and height divisible by scale
        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale
        # resize to obtain the high-resolution image HR
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        # shrink by scale
        lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
        # enlarge back by scale to obtain the low-resolution image LR
        lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
        # convert to float and keep the Y channel of YCbCr
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)
        '''
        np.array(): converts the PIL image to an ndarray
        astype(): converts the array's dtype
        convert_rgb_to_y(): converts an RGB image to its Y channel
        Example with an original input of shape (321, 481, 3) -- height, width, channels -- and scale 3:
        1. the image is first resized to a scale-divisible size, so hr has shape (320, 480, 3)
        2. hr is bicubically shrunk by scale, then re-enlarged by scale, giving lr with hr's dimensions
        3. both are then converted to the Y channel and to float32
        '''
        # split the data into patches
        for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride):
            for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride):
                '''
                For an ndarray image, shape[0] is the height (320 in the example), shape[1]
                is the width (480), and shape[2] is the number of channels.
                '''
                lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])
                hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size])

    lr_patches = np.array(lr_patches)
    hr_patches = np.array(hr_patches)
    # create the datasets; the collected patches become arrays
    h5_file.create_dataset('lr', data=lr_patches)
    h5_file.create_dataset('hr', data=hr_patches)
    h5_file.close()

# same idea as above, but building the test set
def eval(args):
    h5_file = h5py.File(args.output_path, 'w')

    lr_group = h5_file.create_group('lr')
    hr_group = h5_file.create_group('hr')

    for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):
        hr = pil_image.open(image_path).convert('RGB')
        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
        lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        lr_group.create_dataset(str(i), data=lr)
        hr_group.create_dataset(str(i), data=hr)

    h5_file.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--images-dir', type=str, required=True)
    parser.add_argument('--output-path', type=str, required=True)
    parser.add_argument('--patch-size', type=int, default=32)
    parser.add_argument('--stride', type=int, default=14)
    parser.add_argument('--scale', type=int, default=4)
    parser.add_argument('--eval', action='store_true')  # store_true: the flag defaults to False and becomes True when passed
    args = parser.parse_args()

    # decide which function builds the h5 file; train() and eval() produce different layouts
    if not args.eval:
        train(args)
    else:
        eval(args)

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse

import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

from models import SRCNN
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr


if __name__ == '__main__':
    # weights file, input image, and upscaling factor
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', default='outputs/x3/best.pth', type=str)
    parser.add_argument('--image-file', default='img/kenan.jpeg', type=str)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()
    # benchmark mode speeds up computation
    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN().to(device)  # build a fresh model

    state_dict = model.state_dict()  # model.state_dict() lists the model's parameters and persistent buffers
    # torch.load(..., map_location=lambda storage, loc: storage) loads all tensors onto the CPU
    # (useful when a GPU-trained model is loaded on a CPU-only machine)
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():  # load the best weights
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()  # switch to evaluation mode (disables training-only behavior such as dropout)

    image = pil_image.open(args.image_file).convert('RGB')  # load the image as RGB

    # first resize the original image so its dimensions are divisible by scale, then shrink
    # and re-enlarge with bicubic interpolation to obtain the low-resolution input Lr;
    # save this bicubic baseline alongside the input
    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale
    image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
    image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)
    image.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))
    # convert to an array, then from RGB to YCbCr
    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)
    # take the Y channel of YCbCr
    y = ycbcr[..., 0]
    y /= 255.  # normalize to [0, 1]
    y = torch.from_numpy(y).to(device)  # converts the array to a tensor sharing its memory (modifying one changes the other), and moves it to the device
    y = y.unsqueeze(0).unsqueeze(0)  # add batch and channel dimensions
    # torch.no_grad() sets requires_grad to False, disabling autograd;
    # clamp limits the outputs to the [0, 1] range
    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    psnr = calc_psnr(y, preds)  # PSNR on the Y channel
    print('PSNR: {:.2f}'.format(psnr))

    # 1. mul works element-wise (like matrix .*): every element is multiplied by 255
    # 2. .cpu().numpy() moves the data from the GPU (if any) to the CPU, then converts the tensor to an ndarray
    # 3. .squeeze(0).squeeze(0) drops the batch and channel dimensions
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)  # the model's Y-channel output, scaled back to [0, 255]

    # reorder from (channels, height, width) to (height, width, channels) for saving/display
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])

    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)  # YCbCr back to RGB, clipped to [0, 255] as uint8
    output = pil_image.fromarray(output)  # array back to a PIL image
    output.save(args.image_file.replace('.', '_srcnn_x{}.'.format(args.scale)))  # save the result

--------------------------------------------------------------------------------
/thumbnails/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1990571096/SRCNN_Pytorch/29bdbc07430a869fbfcf7d8711d3191c5c4a57e6/thumbnails/fig1.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import copy

import numpy as np
from torch import Tensor
import torch
from torch import nn
import torch.optim as optim

# GPU acceleration library
import torch.backends.cudnn as cudnn

from torch.utils.data.dataloader import DataLoader

# progress bar
from tqdm import tqdm

from models import SRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr

## names you may need to change:
# epoch.pth
# losslog
# psnrlog
# best.pth

'''
python train.py --train-file "path_to_train_file" \
                --eval-file "path_to_eval_file" \
                --outputs-dir "path_to_outputs_file" \
                --scale 3 \
                --lr 1e-4 \
                --batch-size 16 \
                --num-epochs 400 \
                --num-workers 0 \
                --seed 123
'''
if __name__ == '__main__':

    # initial parameter setup
    parser = argparse.ArgumentParser()  # argparse is Python's standard module for parsing command-line options
    parser.add_argument('--train-file', type=str, required=True)  # path to the training h5 file
    parser.add_argument('--eval-file', type=str, required=True)  # path to the evaluation h5 file
    parser.add_argument('--outputs-dir', type=str, required=True)  # directory for the saved .pth models
    parser.add_argument('--scale', type=int, default=3)  # upscaling factor
    parser.add_argument('--lr', type=float, default=1e-4)  # learning rate
    parser.add_argument('--batch-size', type=int, default=16)  # number of patches processed per batch
    parser.add_argument('--num-workers', type=int, default=0)  # number of data-loading workers
    parser.add_argument('--num-epochs', type=int, default=400)  # number of training epochs
    parser.add_argument('--seed', type=int, default=123)  # random seed
    args = parser.parse_args()

    # put the outputs in a scale-specific subfolder
    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))
    # create the folder if it does not exist
    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    # benchmark mode speeds up computation by searching for the best convolution algorithms,
    # at the cost of slightly non-deterministic forward results
    cudnn.benchmark = True

    # GPU or CPU mode, depending on whether CUDA is available
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # fix the random numbers generated on each run
    torch.manual_seed(args.seed)

    # build the SRCNN model and train it on the device
    model = SRCNN().to(device)

    # to resume training from where a previous run stopped:
    # model.load_state_dict(torch.load('outputs/x3/epoch_173.pth'))

    # MSE loss
    criterion = nn.MSELoss()

    # Adam optimizer; lr is the learning rate, and the last layer uses a 10x smaller one
    optimizer = optim.Adam([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    # training set and loader
    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(
        # the data
        dataset=train_dataset,
        # samples per batch
        batch_size=args.batch_size,
        # shuffle the data before drawing each batch
        shuffle=True,
        # worker processes for loading
        num_workers=args.num_workers,
        # pinned (page-locked) memory: the produced tensors live in memory that cannot be
        # swapped out, which speeds up host-to-GPU transfers
        pin_memory=True,
        # drop the final batch if it is smaller than batch_size
        drop_last=True)
    # evaluation set and loader
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    # keep a copy of the best weights
    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    # logs for plotting
    lossLog = []
    psnrLog = []

    # to resume training, adjust the range below
    # for epoch in range(args.num_epochs):
    for epoch in range(1, args.num_epochs + 1):
    # for epoch in range(174, 400):
        # training mode
        model.train()

        # running average of the epoch loss
        epoch_losses = AverageMeter()

        # progress bar over the full batches (the remainder smaller than batch_size is dropped)
        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
            # t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs - 1))
            t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs))

            # one step per batch
            for data in train_dataloader:
                # matches __getitem__ in datasets.py: lr and hr images
                inputs, labels = data

                inputs = inputs.to(device)
                labels = labels.to(device)
                # forward pass through the model
                preds = model(inputs)

                # compute the loss
                loss = criterion(preds, labels)

                # record this batch's loss and size
                epoch_losses.update(loss.item(), len(inputs))

                # clear the gradients
                optimizer.zero_grad()

                # backpropagate
                loss.backward()

                # update the parameters
                optimizer.step()

                # update the progress bar
                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))
        # record lossLog for plotting
        lossLog.append(np.array(epoch_losses.avg))
        # a directory prefix can be added to the path below
        np.savetxt("lossLog.txt", lossLog)

        # save this epoch's model
        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        # evaluate, then decide whether to update the current best weights
        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            # no gradients needed for validation
            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        # record the PSNR
        psnrLog.append(Tensor.cpu(epoch_psnr.avg))
        np.savetxt('psnrLog.txt', psnrLog)
        # keep the weights if they beat the previous best
        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))

    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import numpy as np

"""
Only the Y channel is processed.
We are not interested in the color changes (the information stored in the Cb and Cr
channels), only in the luminance (the Y channel); the underlying reason is that human
vision is far more sensitive to changes in brightness than to color differences.
"""
def convert_rgb_to_y(img):
    if type(img) == np.ndarray:
        return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        return 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
    else:
        raise Exception('Unknown Type', type(img))

"""
RGB to YCbCr:
Y  =  0.257*R + 0.564*G + 0.098*B + 16
Cb = -0.148*R - 0.291*G + 0.439*B + 128
Cr =  0.439*R - 0.368*G - 0.071*B + 128
"""
def convert_rgb_to_ycbcr(img):
    if type(img) == np.ndarray:
        y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
        cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
        cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
        return torch.cat([y, cb, cr], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))

"""
YCbCr to RGB:
R = 1.164*(Y-16) + 1.596*(Cr-128)
G = 1.164*(Y-16) - 0.392*(Cb-128) - 0.813*(Cr-128)
B = 1.164*(Y-16) + 2.017*(Cb-128)
"""
def convert_ycbcr_to_rgb(img):
    if type(img) == np.ndarray:
        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
        g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
        b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
        return torch.cat([r, g, b], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))

# PSNR = 10 * log10(MAX^2 / MSE); images here are normalized to [0, 1], so MAX = 1
def calc_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))

# tracks a running value, sum, count, and average
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

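# A worked example of calc_psnr (illustrative values):
#
#   a = torch.zeros(1, 1, 4, 4)
#   b = torch.full((1, 1, 4, 4), 0.1)  # mean squared error = 0.01
#   calc_psnr(a, b)                    # 10 * log10(1 / 0.01) = 20.0 dB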
--------------------------------------------------------------------------------