├── README.md ├── Test ├── Set11_mat │ ├── Monarch.mat │ ├── Parrots.mat │ ├── barbara.mat │ ├── boats.mat │ ├── cameraman.mat │ ├── fingerprint.mat │ ├── flinstones.mat │ ├── foreman.mat │ ├── house.mat │ ├── lena256.mat │ └── peppers256.mat ├── Set14_mat │ ├── baboon.mat │ ├── barbara.mat │ ├── bridge.mat │ ├── coastguard.mat │ ├── comic.mat │ ├── face.mat │ ├── flowers.mat │ ├── foreman.mat │ ├── lenna.mat │ ├── man.mat │ ├── monarch.mat │ ├── pepper.mat │ ├── ppt3.mat │ └── zebra.mat ├── Set5_mat │ ├── baby_GT.mat │ ├── bird_GT.mat │ ├── butterfly_GT.mat │ ├── head_GT.mat │ └── woman_GT.mat ├── generate_testset_mat.m └── modcrop.m ├── data_utils.py ├── datasets ├── README.md ├── baby_GT.bmp ├── bird_GT.bmp ├── butterfly_GT.bmp ├── head_GT.bmp └── woman_GT.bmp ├── images ├── README.md ├── framework.jpg ├── results.jpg ├── results1.jpg ├── results2.jpg ├── table.jpg ├── table1.jpg ├── table2.jpg └── table3.jpg ├── lib ├── README.md └── network.py ├── models_subrate_0.1_blocksize_32 └── README.md ├── results_subrate_0.1_blocksize_32 └── README.md ├── test.py ├── test_new.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # CSNet-Pytorch 2 | 3 | PyTorch code for the papers 4 | 5 | * "Deep Networks for Compressed Image Sensing", ICME 2017 6 | 7 | * "Image Compressed Sensing Using Convolutional Neural Network", TIP 2019 8 | 9 | ## Requirements and Dependencies 10 | 11 | * Ubuntu 16.04, CUDA 10.0 12 | * Python 3 (tested with Python 3.5) 13 | * PyTorch 1.1.0 14 | * Torchvision 0.2.2 15 | 16 | ## Details of Implementations 17 | 18 | Two model versions are included in our code: 19 | 20 | * a simple version of CSNet (similar to the ICME 2017 paper) 21 | * an enhanced version of CSNet (local skip connection + global skip connection + residual learning) 22 | 23 | ## How to Run 24 | 25 | ### Training CSNet 26 | * Prepare the dataset for training. 27 | 28 | * Edit the path of the training data in `train.py`. 29 | 30 | * To train CSNet with subrate=0.1: 31 | 32 | ```python train.py --sub_rate=0.1 --block_size=32``` 33 | 34 | ### Testing CSNet 35 | * Prepare the dataset for testing. 36 | 37 | * Edit the path of the trained model in `test.py` and `test_new.py`. 38 | 39 | * To test CSNet with subrate=0.1: (**Note: the results of this testing code show a large gap from those reported in the published paper, and I am not sure why. If you know the reason, please let me know. Thanks very much!**) 40 | 41 | ```python test.py --sub_rate=0.1 --block_size=32``` 42 | 43 | * To test CSNet with the new testing code at subrate=0.1: 44 | 45 | ```python test_new.py --cuda --sub_rate=0.1 --block_size=32``` 46 | 47 | ## CSNet results 48 | ### Subjective results 49 | 50 | ![image](https://github.com/WenxueCui/CSNet-Pytorch/raw/master/images/results.jpg) 51 | 52 | ### Objective results 53 | ![image](https://github.com/WenxueCui/CSNet-Pytorch/raw/master/images/table.jpg) 54 | 55 | ## Additional instructions 56 | 57 | * For training data, you can choose any natural image dataset. 58 | * The training data matters: if you cannot achieve ideal results, consider augmenting the training data or adjusting the network structure. 59 | * If you like this repo, please Star or Fork it to support my work. Thank you. 60 | * If you have any problems with this code, please email: wxcui@hit.edu.cn
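 61 | 62 | ## Example: loading the models 63 | 64 | A minimal sketch of how the two model versions defined in `lib/network.py` can be instantiated and run (it assumes a CUDA-capable GPU, since the block reshape layer in `lib/network.py` allocates CUDA tensors as written): 65 | 66 | ```python 67 | import torch 68 | from lib.network import CSNet, CSNet_Enhanced 69 | 70 | # blocksize/subrate correspond to the --block_size/--sub_rate options 71 | net = CSNet(blocksize=32, subrate=0.1).cuda() 72 | # net = CSNet_Enhanced(blocksize=32, subrate=0.1).cuda()  # enhanced version 73 | 74 | # one grayscale patch whose sides are multiples of the block size 75 | img = torch.randn(1, 1, 96, 96).cuda() 76 | out = net(img)      # reconstruction with the same shape as the input 77 | print(out.size())   # torch.Size([1, 1, 96, 96]) 78 | ```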
79 | 80 | -------------------------------------------------------------------------------- /Test/Set11_mat/Monarch.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set11_mat/Monarch.mat -------------------------------------------------------------------------------- /Test/Set11_mat/Parrots.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set11_mat/Parrots.mat -------------------------------------------------------------------------------- /Test/Set11_mat/barbara.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set11_mat/barbara.mat -------------------------------------------------------------------------------- /Test/Set11_mat/boats.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set11_mat/boats.mat -------------------------------------------------------------------------------- /Test/Set11_mat/cameraman.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set11_mat/cameraman.mat -------------------------------------------------------------------------------- /Test/Set11_mat/fingerprint.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set11_mat/fingerprint.mat -------------------------------------------------------------------------------- /Test/Set11_mat/flinstones.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set11_mat/flinstones.mat -------------------------------------------------------------------------------- /Test/Set11_mat/foreman.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set11_mat/foreman.mat -------------------------------------------------------------------------------- /Test/Set11_mat/house.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set11_mat/house.mat -------------------------------------------------------------------------------- /Test/Set11_mat/lena256.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set11_mat/lena256.mat -------------------------------------------------------------------------------- /Test/Set11_mat/peppers256.mat: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set11_mat/peppers256.mat -------------------------------------------------------------------------------- /Test/Set14_mat/baboon.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/baboon.mat -------------------------------------------------------------------------------- /Test/Set14_mat/barbara.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/barbara.mat -------------------------------------------------------------------------------- /Test/Set14_mat/bridge.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/bridge.mat -------------------------------------------------------------------------------- /Test/Set14_mat/coastguard.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/coastguard.mat -------------------------------------------------------------------------------- /Test/Set14_mat/comic.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/comic.mat -------------------------------------------------------------------------------- /Test/Set14_mat/face.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/face.mat -------------------------------------------------------------------------------- /Test/Set14_mat/flowers.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/flowers.mat -------------------------------------------------------------------------------- /Test/Set14_mat/foreman.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/foreman.mat -------------------------------------------------------------------------------- /Test/Set14_mat/lenna.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/lenna.mat -------------------------------------------------------------------------------- /Test/Set14_mat/man.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/man.mat -------------------------------------------------------------------------------- /Test/Set14_mat/monarch.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/monarch.mat -------------------------------------------------------------------------------- /Test/Set14_mat/pepper.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/pepper.mat -------------------------------------------------------------------------------- /Test/Set14_mat/ppt3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/ppt3.mat -------------------------------------------------------------------------------- /Test/Set14_mat/zebra.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set14_mat/zebra.mat -------------------------------------------------------------------------------- /Test/Set5_mat/baby_GT.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set5_mat/baby_GT.mat -------------------------------------------------------------------------------- /Test/Set5_mat/bird_GT.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set5_mat/bird_GT.mat -------------------------------------------------------------------------------- /Test/Set5_mat/butterfly_GT.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set5_mat/butterfly_GT.mat -------------------------------------------------------------------------------- /Test/Set5_mat/head_GT.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set5_mat/head_GT.mat -------------------------------------------------------------------------------- /Test/Set5_mat/woman_GT.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/Test/Set5_mat/woman_GT.mat -------------------------------------------------------------------------------- /Test/generate_testset_mat.m: -------------------------------------------------------------------------------- 1 | clear;close all; 2 | %% settings 3 | folder = 'Set5'; 4 | scale = 1; 5 | blocksize = 32; 6 | 7 | %% generate data 8 | filepaths = dir(fullfile(folder,'*.bmp')); 9 | 10 | for i = 1 : length(filepaths) 11 | im_gt = imread(fullfile(folder,filepaths(i).name)); 12 | if size(im_gt, 3) > 1 13 | im_gt = modcrop(im_gt, blocksize); 14 | im_gt = double(im_gt); 15 | im_gt_ycbcr = rgb2ycbcr(im_gt / 255.0); 16 | im_gt_y = im_gt_ycbcr(:,:,1) * 255.0; 17 | im_l_ycbcr = imresize(im_gt_ycbcr, 1/scale, 'bicubic'); 18 | im_b_ycbcr = imresize(im_l_ycbcr, scale, 'bicubic'); 19 | im_l_y = im_l_ycbcr(:,:,1) * 255.0; 20 | im_l = ycbcr2rgb(im_l_ycbcr) * 255.0; 21 | im_b_y = im_b_ycbcr(:,:,1) * 255.0; 22 | im_b = 
ycbcr2rgb(im_b_ycbcr) * 255.0; 23 | 24 | mat_output = [folder, '_mat2']; 25 | if exist(mat_output, 'dir') == 0 26 | mkdir(mat_output); 27 | end 28 | filename = [mat_output,'/',filepaths(i).name(1:end-4),'.mat']; 29 | save(filename, 'im_gt_y', 'im_b_y', 'im_l_y'); 30 | else 31 | 32 | im_gt_y = im_gt; 33 | im_b_y = im_gt; 34 | im_l_y = im_gt; 35 | mat_output = [folder, '_mat2']; 36 | if exist(mat_output, 'dir') == 0 37 | mkdir(mat_output); 38 | end 39 | filename = [mat_output,'/',filepaths(i).name(1:end-4),'.mat']; 40 | save(filename, 'im_gt_y', 'im_b_y', 'im_l_y'); 41 | end 42 | end 43 | -------------------------------------------------------------------------------- /Test/modcrop.m: -------------------------------------------------------------------------------- 1 | function imgs = modcrop(imgs, modulo) 2 | if size(imgs,3)==1 3 | sz = size(imgs); 4 | sz = sz - mod(sz, modulo); 5 | imgs = imgs(1:sz(1), 1:sz(2)); 6 | else 7 | tmpsz = size(imgs); 8 | sz = tmpsz(1:2); 9 | sz = sz - mod(sz, modulo); 10 | imgs = imgs(1:sz(1), 1:sz(2),:); 11 | end 12 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | from os import listdir 2 | from os.path import join 3 | 4 | from PIL import Image 5 | from torch.utils.data.dataset import Dataset 6 | from torchvision.transforms import Compose, RandomResizedCrop, RandomHorizontalFlip, RandomVerticalFlip, RandomRotation, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize, Grayscale 7 | 8 | import random 9 | import math 10 | from torch.autograd import Variable 11 | import torch 12 | 13 | import torchvision.transforms as transforms 14 | 15 | 16 | import numpy as np 17 | 18 | def is_image_file(filename): 19 | return any(filename.endswith(extension) for extension in ['.png', '.bmp', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG']) 20 | 21 | 22 | def calculate_valid_crop_size(crop_size, blocksize): 23 | return crop_size - (crop_size % blocksize) 24 | 25 | 26 | def train_hr_transform(crop_size): 27 | return Compose([ 28 | RandomCrop(crop_size), 29 | RandomHorizontalFlip(p=0.5), 30 | RandomVerticalFlip(p=0.5), 31 | Grayscale(), 32 | ToTensor(), 33 | ]) 34 | 35 | 36 | 37 | def psnr(img1, img2): 38 | mse = torch.mean((img1 - img2) ** 2) 39 | if mse < 1.0e-10: 40 | return 100 41 | PIXEL_MAX = 1.0 42 | return 20 * math.log10(PIXEL_MAX/math.sqrt(mse)) 43 | 44 | 45 | class TrainDatasetFromFolder(Dataset): 46 | def __init__(self, dataset_dir, crop_size, blocksize): 47 | super(TrainDatasetFromFolder, self).__init__() 48 | self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)] 49 | crop_size = calculate_valid_crop_size(crop_size, blocksize) 50 | self.hr_transform = train_hr_transform(crop_size) 51 | 52 | def __getitem__(self, index): 53 | try: 54 | hr_image = self.hr_transform(Image.open(self.image_filenames[index])) 55 | return hr_image, hr_image 56 | except Exception: # unreadable image: fall back to the next one, wrapping at the end of the list 57 | hr_image = self.hr_transform(Image.open(self.image_filenames[(index + 1) % len(self.image_filenames)])) 58 | return hr_image, hr_image 59 | 60 | def __len__(self): 61 | return len(self.image_filenames) 62 | 63 | 64 | class TestDatasetFromFolder(Dataset): 65 | def __init__(self, dataset_dir, blocksize): 66 | super(TestDatasetFromFolder, self).__init__() 67 | self.blocksize = blocksize 68 | self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)] 69 | 70 | def __getitem__(self, index): 71 | hr_image = 
Image.open(self.image_filenames[index]) 72 | 73 | w, h = hr_image.size 74 | w = int(np.floor(w/self.blocksize)*self.blocksize) 75 | h = int(np.floor(h/self.blocksize)*self.blocksize) 76 | crop_size = (h, w) 77 | 78 | hr_image = CenterCrop(crop_size)(hr_image) 79 | hr_image = Grayscale()(hr_image) 80 | 81 | return ToTensor()(hr_image), ToTensor()(hr_image) 82 | 83 | def __len__(self): 84 | return len(self.image_filenames) 85 | 86 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | Testing dataset for CSNet. 2 | -------------------------------------------------------------------------------- /datasets/baby_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/datasets/baby_GT.bmp -------------------------------------------------------------------------------- /datasets/bird_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/datasets/bird_GT.bmp -------------------------------------------------------------------------------- /datasets/butterfly_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/datasets/butterfly_GT.bmp -------------------------------------------------------------------------------- /datasets/head_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/datasets/head_GT.bmp -------------------------------------------------------------------------------- /datasets/woman_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/datasets/woman_GT.bmp -------------------------------------------------------------------------------- /images/README.md: -------------------------------------------------------------------------------- 1 | The results of CSNet. 
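framework.jpg shows the network framework, results*.jpg the subjective comparisons, and table*.jpg the objective comparisons.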
2 | -------------------------------------------------------------------------------- /images/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/images/framework.jpg -------------------------------------------------------------------------------- /images/results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/images/results.jpg -------------------------------------------------------------------------------- /images/results1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/images/results1.jpg -------------------------------------------------------------------------------- /images/results2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/images/results2.jpg -------------------------------------------------------------------------------- /images/table.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/images/table.jpg -------------------------------------------------------------------------------- /images/table1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/images/table1.jpg -------------------------------------------------------------------------------- /images/table2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/images/table2.jpg -------------------------------------------------------------------------------- /images/table3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenxueCui/CSNet-Pytorch/b8c15d920fadee81b39992e980f241056480212c/images/table3.jpg -------------------------------------------------------------------------------- /lib/README.md: -------------------------------------------------------------------------------- 1 | CSNet model is defined in network.py. 2 | 3 | In network.py, two versions are included: 4 | 5 | 1. The simple version with multiple conv layers (CSNet) 6 | 7 | 2. 
Enhanced version of CSNet (skip connection + residual block) 8 | -------------------------------------------------------------------------------- /lib/network.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | from torch.autograd import Variable 8 | 9 | 10 | # Reshape + Concat layer 11 | 12 | class Reshape_Concat_Adap(torch.autograd.Function): 13 | blocksize = 0 14 | 15 | def __init__(self, block_size): 16 | # super(Reshape_Concat_Adap, self).__init__() 17 | Reshape_Concat_Adap.blocksize = block_size 18 | 19 | @staticmethod 20 | def forward(ctx, input_, ): 21 | ctx.save_for_backward(input_) 22 | 23 | data = torch.clone(input_.data) 24 | b_ = data.shape[0] 25 | c_ = data.shape[1] 26 | w_ = data.shape[2] 27 | h_ = data.shape[3] 28 | 29 | output = torch.zeros((b_, int(c_ / Reshape_Concat_Adap.blocksize / Reshape_Concat_Adap.blocksize), 30 | int(w_ * Reshape_Concat_Adap.blocksize), int(h_ * Reshape_Concat_Adap.blocksize))).cuda() 31 | 32 | for i in range(0, w_): 33 | for j in range(0, h_): 34 | data_temp = data[:, :, i, j] 35 | # data_temp = torch.zeros(data_t.shape).cuda() + data_t 36 | # data_temp = data_temp.contiguous() 37 | data_temp = data_temp.view((b_, int(c_ / Reshape_Concat_Adap.blocksize / Reshape_Concat_Adap.blocksize), 38 | Reshape_Concat_Adap.blocksize, Reshape_Concat_Adap.blocksize)) 39 | # print data_temp.shape 40 | output[:, :, i * Reshape_Concat_Adap.blocksize:(i + 1) * Reshape_Concat_Adap.blocksize, 41 | j * Reshape_Concat_Adap.blocksize:(j + 1) * Reshape_Concat_Adap.blocksize] += data_temp 42 | 43 | return output 44 | 45 | @staticmethod 46 | def backward(ctx, grad_output): 47 | inp, = ctx.saved_tensors 48 | input_ = torch.clone(inp.data) 49 | grad_input = torch.clone(grad_output.data) 50 | 51 | b_ = input_.shape[0] 52 | c_ = input_.shape[1] 53 | w_ = input_.shape[2] 54 | h_ = input_.shape[3] 55 | 56 | output = torch.zeros((b_, c_, w_, h_)).cuda() 57 | output = output.view(b_, c_, w_, h_) 58 | for i in range(0, w_): 59 | for j in range(0, h_): 60 | data_temp = grad_input[:, :, i * Reshape_Concat_Adap.blocksize:(i + 1) * Reshape_Concat_Adap.blocksize, 61 | j * Reshape_Concat_Adap.blocksize:(j + 1) * Reshape_Concat_Adap.blocksize] 62 | # data_temp = torch.zeros(data_t.shape).cuda() + data_t 63 | data_temp = data_temp.contiguous() 64 | data_temp = data_temp.view((b_, c_, 1, 1)) 65 | output[:, :, i, j] += torch.squeeze(data_temp) 66 | 67 | return Variable(output) 68 | 69 | 70 | def My_Reshape_Adap(input, blocksize): 71 | return Reshape_Concat_Adap(blocksize).apply(input) 72 | 73 | 74 | # The residualblock for reconstruction network 75 | class ResidualBlock(nn.Module): 76 | def __init__(self, channels, has_BN = False): 77 | super(ResidualBlock, self).__init__() 78 | self.has_BN = has_BN 79 | self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) 80 | if has_BN: 81 | self.bn1 = nn.BatchNorm2d(channels) 82 | self.prelu = nn.PReLU() 83 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) 84 | if has_BN: 85 | self.bn2 = nn.BatchNorm2d(channels) 86 | 87 | def forward(self, x): 88 | residual = self.conv1(x) 89 | if self.has_BN: 90 | residual = self.bn1(residual) 91 | residual = self.prelu(residual) 92 | residual = self.conv2(residual) 93 | if self.has_BN: 94 | residual = self.bn2(residual) 95 | 96 | return x + residual 97 | 98 | 99 | # code of CSNet 100 | class CSNet(nn.Module): 101 | def __init__(self, 
blocksize=32, subrate=0.1): 102 | 103 | super(CSNet, self).__init__() 104 | self.blocksize = blocksize 105 | 106 | # for sampling 107 | self.sampling = nn.Conv2d(1, int(np.round(blocksize*blocksize*subrate)), blocksize, stride=blocksize, padding=0, bias=False) 108 | # upsampling 109 | self.upsampling = nn.Conv2d(int(np.round(blocksize*blocksize*subrate)), blocksize*blocksize, 1, stride=1, padding=0) 110 | 111 | # reconstruction network 112 | self.conv1 = nn.Sequential( 113 | nn.Conv2d(1, 64, kernel_size=3, padding=1), 114 | nn.PReLU() 115 | ) 116 | self.conv2 = nn.Sequential( 117 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 118 | nn.PReLU() 119 | ) 120 | self.conv3 = nn.Sequential( 121 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 122 | nn.PReLU() 123 | ) 124 | self.conv4 = nn.Sequential( 125 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 126 | nn.PReLU() 127 | ) 128 | self.conv5 = nn.Conv2d(64, 1, kernel_size=3, padding=1) 129 | 130 | def forward(self, x): 131 | x = self.sampling(x) 132 | x = self.upsampling(x) 133 | x = My_Reshape_Adap(x, self.blocksize) # Reshape + Concat 134 | 135 | block1 = self.conv1(x) 136 | block2 = self.conv2(block1) 137 | block3 = self.conv3(block2) 138 | block4 = self.conv4(block3) 139 | block5 = self.conv5(block4) 140 | 141 | return block5 142 | 143 | 144 | # code of CSNet_Enhanced (Enhanced version of CSNet) 145 | class CSNet_Enhanced(nn.Module): 146 | def __init__(self, blocksize=32, subrate=0.1): 147 | 148 | super(CSNet_Enhanced, self).__init__() 149 | self.blocksize = blocksize 150 | 151 | # for sampling 152 | self.sampling = nn.Conv2d(1, int(np.round(blocksize*blocksize*subrate)), blocksize, stride=blocksize, padding=0, bias=False) 153 | # upsampling 154 | self.upsampling = nn.Conv2d(int(np.round(blocksize*blocksize*subrate)), blocksize*blocksize, 1, stride=1, padding=0) 155 | 156 | # reconstruction network 157 | self.block1 = nn.Sequential( 158 | nn.Conv2d(1, 64, kernel_size=7, padding=3), 159 | nn.PReLU() 160 | ) 161 | self.block2 = ResidualBlock(64, has_BN=True) 162 | self.block3 = ResidualBlock(64, has_BN=True) 163 | self.block4 = ResidualBlock(64, has_BN=True) 164 | self.block5 = ResidualBlock(64, has_BN=True) 165 | self.block6 = ResidualBlock(64, has_BN=True) 166 | self.block7 = nn.Sequential( 167 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 168 | nn.PReLU() 169 | ) 170 | self.block8 = nn.Conv2d(64, 1, kernel_size=3, padding=1) 171 | 172 | def forward(self, x): 173 | x = self.sampling(x) 174 | x = self.upsampling(x) 175 | x = My_Reshape_Adap(x, self.blocksize) # Reshape + Concat 176 | 177 | block1 = self.block1(x) 178 | block2 = self.block2(block1) 179 | block3 = self.block3(block2) 180 | block4 = self.block4(block3) 181 | block5 = self.block5(block4) 182 | block6 = self.block6(block5) 183 | block7 = self.block7(block6) 184 | block8 = self.block8(block1 + block7) 185 | 186 | return block8 187 | 188 | 189 | 190 | if __name__ == '__main__': 191 | import torch 192 | 193 | img = torch.randn(1, 1, 32, 32) 194 | net = CSNet() 195 | out = net(img) 196 | print(out.size()) 197 | 198 | -------------------------------------------------------------------------------- /models_subrate_0.1_blocksize_32/README.md: -------------------------------------------------------------------------------- 1 | In the training stage, the produced model (subrate=0.1, blocksize=32) is saved in this folder. 
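Note that, as written, train.py saves its checkpoints to epochs_subrate_0.1_blocksize_32/ (one net_epoch_*.pth file every 5 epochs); copy the desired checkpoint into this folder, or point --NetWeights (test.py) / --model (test_new.py) at it directly.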
2 | -------------------------------------------------------------------------------- /results_subrate_0.1_blocksize_32/README.md: -------------------------------------------------------------------------------- 1 | In the testing stage, the reconstructed images are saved in this folder. 2 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as Data 3 | import torchvision 4 | import torch.optim as optim 5 | from torch.utils.data import DataLoader 6 | from lib.network import CSNet 7 | from torch import nn 8 | import time 9 | import os 10 | 11 | import argparse 12 | from tqdm import tqdm 13 | 14 | from data_utils import TestDatasetFromFolder, psnr 15 | import torchvision.transforms as transforms 16 | from torch.autograd import Variable 17 | from torchvision.transforms import ToPILImage 18 | 19 | 20 | parser = argparse.ArgumentParser(description='Test CSNet (image compressed sensing)') 21 | parser.add_argument('--block_size', default=32, type=int, help='CS block size') 22 | parser.add_argument('--save_img', default=1, type=int, help='save the reconstructed images (1) or not (0)') 23 | 24 | parser.add_argument('--sub_rate', default=0.1, type=float, help='sampling sub rate') 25 | 26 | parser.add_argument('--NetWeights', type=str, default='epochs_subrate_0.1_blocksize_32/net_epoch_200_0.001724.pth', help="path of CSNet weights for testing") 27 | 28 | opt = parser.parse_args() 29 | 30 | BLOCK_SIZE = opt.block_size 31 | 32 | val_set = TestDatasetFromFolder('/media/gdh-95/data/Set14', blocksize=BLOCK_SIZE) 33 | val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False) 34 | 35 | net = CSNet(BLOCK_SIZE, opt.sub_rate) 36 | mse_loss = nn.MSELoss() 37 | 38 | if opt.NetWeights != '': 39 | net.load_state_dict(torch.load(opt.NetWeights)) 40 | 41 | if torch.cuda.is_available(): 42 | net.cuda() 43 | mse_loss.cuda() 44 | 45 | for epoch in range(1, 1+1): # a single evaluation pass 46 | train_bar = tqdm(val_loader) 47 | running_results = {'batch_sizes': 0, 'g_loss': 0, } 48 | 49 | save_dir = 'results' + '_subrate_' + str(opt.sub_rate) + '_blocksize_' + str( 50 | BLOCK_SIZE) 51 | if not os.path.exists(save_dir): 52 | os.makedirs(save_dir) 53 | 54 | net.eval() 55 | psnrs = 0.0 56 | img_id = 0 57 | 58 | for data, target in train_bar: 59 | batch_size = data.size(0) 60 | if batch_size <= 0: 61 | continue 62 | 63 | running_results['batch_sizes'] += batch_size 64 | img_id += 1 65 | 66 | real_img = Variable(target) 67 | if torch.cuda.is_available(): 68 | real_img = real_img.cuda() 69 | z = Variable(data) 70 | if torch.cuda.is_available(): 71 | z = z.cuda() 72 | fake_img = net(z) 73 | fake_img[fake_img>1] = 1 74 | fake_img[fake_img<0] = 0 75 | 76 | psnr_t = psnr(fake_img.data.cpu(), real_img.data.cpu()) 77 | psnrs += psnr_t 78 | 79 | g_loss = mse_loss(fake_img, real_img) 80 | 81 | running_results['g_loss'] += g_loss.item() * batch_size 82 | 83 | train_bar.set_description(desc='[%d] Loss_G: %.4f' % ( 84 | epoch, running_results['g_loss'] / running_results['batch_sizes'])) 85 | 86 | if opt.save_img > 0: 87 | res = fake_img.data.cpu() 88 | res = torch.squeeze(res, 0) 89 | res = ToPILImage()(res) 90 | res.save(save_dir + '/res_'+str(img_id)+'_'+str(psnr_t)+'.png') 91 | 92 | print("average psnr is: ", psnrs/img_id) 93 | -------------------------------------------------------------------------------- /test_new.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 
import torch 3 | from torch.autograd import Variable 4 | import numpy as np 5 | import time, math, glob 6 | import scipy.io as sio 7 | from lib.network import CSNet 8 | 9 | 10 | parser = argparse.ArgumentParser(description="PyTorch CSNet Eval") 11 | parser.add_argument("--cuda", action="store_true", help="use cuda?") 12 | parser.add_argument("--model", default="epochs_subrate_0.1_blocksize_32/net_epoch_195_0.001642.pth", type=str, help="model path") 13 | parser.add_argument("--dataset", default="Test/Set5_mat", type=str, help="dataset folder, Default: Test/Set5_mat") 14 | parser.add_argument('--block_size', default=32, type=int, help='CS block size') 15 | parser.add_argument('--sub_rate', default=0.1, type=float, help='sampling sub rate') 16 | 17 | 18 | def PSNR(pred, gt, shave_border=0): 19 | height, width = pred.shape[:2] 20 | pred = pred[shave_border:height - shave_border, shave_border:width - shave_border] 21 | gt = gt[shave_border:height - shave_border, shave_border:width - shave_border] 22 | imdff = pred - gt 23 | rmse = math.sqrt(np.mean(imdff ** 2)) 24 | if rmse == 0: 25 | return 100 26 | return 20 * math.log10(255.0 / rmse) 27 | 28 | opt = parser.parse_args() 29 | cuda = opt.cuda 30 | 31 | if cuda and not torch.cuda.is_available(): 32 | raise Exception("No GPU found, please run without --cuda") 33 | 34 | model = CSNet(opt.block_size, opt.sub_rate) 35 | 36 | if opt.model != '': 37 | model.load_state_dict(torch.load(opt.model)) 38 | model.eval() # inference mode (matters for the BN layers of the enhanced model) 39 | 40 | image_list = glob.glob(opt.dataset+"/*.*") 41 | 42 | avg_psnr_predicted = 0.0 43 | avg_elapsed_time = 0.0 44 | 45 | for image_name in image_list: 46 | print("Processing ", image_name) 47 | im_gt_y = sio.loadmat(image_name)['im_gt_y'] 48 | 49 | im_gt_y = im_gt_y.astype(float) 50 | 51 | im_input = im_gt_y/255. 52 | 53 | im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1]) 54 | 55 | if cuda: 56 | model = model.cuda() 57 | im_input = im_input.cuda() 58 | else: 59 | model = model.cpu() 60 | 61 | start_time = time.time() 62 | res = model(im_input) 63 | elapsed_time = time.time() - start_time 64 | avg_elapsed_time += elapsed_time 65 | 66 | res = res.cpu() 67 | 68 | im_res_y = res.data[0].numpy().astype(np.float32) 69 | 70 | im_res_y = im_res_y*255. 71 | im_res_y[im_res_y<0] = 0 72 | im_res_y[im_res_y>255.] = 255. 
73 | im_res_y = im_res_y[0,:,:] 74 | 75 | psnr_predicted = PSNR(im_gt_y, im_res_y, shave_border=0) 76 | print(psnr_predicted) 77 | avg_psnr_predicted += psnr_predicted 78 | 79 | print("Dataset=", opt.dataset) 80 | print("PSNR_predicted=", avg_psnr_predicted/len(image_list)) 81 | print("Average processing time: {}s".format(avg_elapsed_time/len(image_list))) 82 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as Data 3 | import torchvision 4 | import torch.optim as optim 5 | from torch.utils.data import DataLoader 6 | from lib.network import CSNet 7 | from torch import nn 8 | import time 9 | import os 10 | 11 | import argparse 12 | from tqdm import tqdm 13 | 14 | from data_utils import TrainDatasetFromFolder 15 | import torchvision.transforms as transforms 16 | from torch.autograd import Variable 17 | 18 | 19 | parser = argparse.ArgumentParser(description='Train CSNet (image compressed sensing)') 20 | parser.add_argument('--crop_size', default=96, type=int, help='training images crop size') 21 | parser.add_argument('--block_size', default=32, type=int, help='CS block size') 22 | parser.add_argument('--pre_epochs', default=200, type=int, help='pre train epoch number') 23 | parser.add_argument('--num_epochs', default=300, type=int, help='train epoch number') 24 | 25 | parser.add_argument('--batchSize', default=64, type=int, help='train batch size') 26 | parser.add_argument('--sub_rate', default=0.1, type=float, help='sampling sub rate') 27 | 28 | parser.add_argument('--loadEpoch', default=0, type=int, help='load epoch number') 29 | parser.add_argument('--generatorWeights', type=str, default='', help="path to CSNet weights (to continue training)") 30 | 31 | opt = parser.parse_args() 32 | 33 | CROP_SIZE = opt.crop_size 34 | BLOCK_SIZE = opt.block_size 35 | NUM_EPOCHS = opt.num_epochs 36 | PRE_EPOCHS = opt.pre_epochs 37 | LOAD_EPOCH = 0 38 | 39 | 40 | train_set = TrainDatasetFromFolder('/media/gdh-95/data/Train', crop_size=CROP_SIZE, blocksize=BLOCK_SIZE) 41 | train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=opt.batchSize, shuffle=True) 42 | 43 | net = CSNet(BLOCK_SIZE, opt.sub_rate) 44 | 45 | mse_loss = nn.MSELoss() 46 | 47 | if opt.generatorWeights != '': 48 | net.load_state_dict(torch.load(opt.generatorWeights)) 49 | LOAD_EPOCH = opt.loadEpoch 50 | 51 | if torch.cuda.is_available(): 52 | net.cuda() 53 | mse_loss.cuda() 54 | 55 | optimizer = optim.Adam(net.parameters(), lr=0.0004, betas=(0.9, 0.999)) 56 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5) 57 | 58 | for epoch in range(LOAD_EPOCH, NUM_EPOCHS + 1): 59 | train_bar = tqdm(train_loader) 60 | running_results = {'batch_sizes': 0, 'g_loss': 0, } 61 | 62 | net.train() 63 | 64 | 65 | for data, target in train_bar: 66 | batch_size = data.size(0) 67 | if batch_size <= 0: 68 | continue 69 | 70 | running_results['batch_sizes'] += batch_size 71 | 72 | real_img = Variable(target) 73 | if torch.cuda.is_available(): 74 | real_img = real_img.cuda() 75 | z = Variable(data) 76 | if torch.cuda.is_available(): 77 | z = z.cuda() 78 | fake_img = net(z) 79 | optimizer.zero_grad() 80 | g_loss = mse_loss(fake_img, real_img) 81 | 82 | g_loss.backward() 83 | optimizer.step() 84 | 85 | running_results['g_loss'] += g_loss.item() * batch_size 86 | 87 | train_bar.set_description(desc='[%d] Loss_G: %.4f lr: %.7f' % ( 88 | epoch, 
running_results['g_loss'] / running_results['batch_sizes'], optimizer.param_groups[0]['lr'])) 89 | 90 | # step the learning-rate schedule once per epoch, after the optimizer updates 91 | scheduler.step() 92 | 93 | # save the model every 5 epochs 94 | save_dir = 'epochs' + '_subrate_' + str(opt.sub_rate) + '_blocksize_' + str(BLOCK_SIZE) 95 | if not os.path.exists(save_dir): 96 | os.makedirs(save_dir) 97 | if epoch % 5 == 0: 98 | torch.save(net.state_dict(), save_dir + '/net_epoch_%d_%6f.pth' % (epoch, running_results['g_loss']/running_results['batch_sizes'])) 99 | 100 | --------------------------------------------------------------------------------