├── color model ├── dataLoadess.py ├── largescale_rgb │ ├── data_gene.m │ └── mask.mat ├── model │ └── color_model │ │ └── RevSCInet_model_epoch_20.pth ├── models.py ├── my_tools.py ├── readme.md ├── test_large.py ├── train.py └── utils.py ├── dataLoadess.py ├── model └── model │ └── RevSCInet_model_epoch_100.pth ├── models.py ├── my_tools.py ├── readme.md ├── test.py ├── test ├── aerial32_cacti.mat ├── crash32_cacti.mat ├── drop8_cacti.mat ├── kobe_cacti.mat ├── runner8_cacti.mat └── traffic_cacti.mat ├── train.py ├── train └── mask.mat └── utils.py /color model/dataLoadess.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.data import Dataset 3 | import os 4 | import torch 5 | import scipy.io as scio 6 | 7 | 8 | class Imgdataset(Dataset): 9 | 10 | def __init__(self, path): 11 | super(Imgdataset, self).__init__() 12 | self.data = [] 13 | if os.path.exists(path): 14 | dir_list = os.listdir(path) 15 | groung_truth_path = path + '/gt' 16 | measurement_path = path + '/measurement' 17 | 18 | if os.path.exists(groung_truth_path): 19 | groung_truth = os.listdir(groung_truth_path) 20 | # measurement = os.listdir(measurement_path) 21 | self.data = [{'groung_truth': groung_truth_path + '/' + groung_truth[i]} for i in 22 | range(len(groung_truth))] 23 | else: 24 | raise FileNotFoundError('path doesnt exist!') 25 | else: 26 | raise FileNotFoundError('path doesnt exist!') 27 | 28 | def __getitem__(self, index): 29 | # print(index) 30 | groung_truth = self.data[index]["groung_truth"] 31 | 32 | gt = scio.loadmat(groung_truth) 33 | # meas = scio.loadmat(measurement) 34 | if "patch_save" in gt: 35 | gt = torch.from_numpy(gt['patch_save'] / 255) 36 | elif "p1" in gt: 37 | gt = torch.from_numpy(gt['p1'] / 255) 38 | elif "p2" in gt: 39 | gt = torch.from_numpy(gt['p2'] / 255) 40 | elif "p3" in gt: 41 | gt = torch.from_numpy(gt['p3'] / 255) 42 | 43 | # meas = torch.from_numpy(meas['meas'] / 255) 44 | 45 | gt = gt.permute(2, 3, 0, 1) 46 | 47 | # print(tran(img).shape) 48 | 49 | return gt 50 | 51 | def __len__(self): 52 | 53 | return len(self.data) 54 | -------------------------------------------------------------------------------- /color model/largescale_rgb/data_gene.m: -------------------------------------------------------------------------------- 1 | clc,clear 2 | close all 3 | 4 | % set DAVIS original video path 5 | video_path='E:\czh\data\sci\video dataset\DAVIS1080p'; 6 | % set your saving path 7 | save_path='.\'; 8 | % load the mask 9 | load('mask.mat') 10 | mask=double(mask); 11 | 12 | resolution='Full-Resolution/'; 13 | block_size=[1080,1920]; 14 | compress_frame=24; 15 | 16 | r=[1,0;0,0];g1=[0,1;0,0];g2=[0,0;1,0];b=[0,0;0,1]; 17 | rggb=cat(3,r,g1+g2,b); 18 | rgb2raw=repmat(rggb,block_size(1)/2,block_size(2)/2); 19 | 20 | gt_save_path=[save_path,'gt/']; 21 | meas_save_path=[save_path,'measurement/']; 22 | if exist(gt_save_path,'dir')==0 23 | mkdir(gt_save_path); 24 | end 25 | if exist(meas_save_path,'dir')==0 26 | mkdir(meas_save_path); 27 | end 28 | save(strcat(save_path,'mask.mat'),'mask') 29 | 30 | num_yb=1; 31 | name_obj=dir([video_path,'/JPEGImages/',resolution]); 32 | for ii=3:length(name_obj) 33 | path=[video_path,'/JPEGImages/',resolution,name_obj(ii).name]; 34 | name_frame=dir(path); 35 | pic1=imread([path,'/',name_frame(3).name]); 36 | w=size(pic1); 37 | 38 | if w(1) Epoch {} Complete: Avg. Loss: {:.7f}".format(epoch, epoch_loss / len(train_data_loader)), 171 | " time: {:.2f}".format(end - begin)) 172 | 173 | 174 | def checkpoint(epoch, model_path): 175 | model_out_path = './' + model_path + '/' + "RevSCInet_model_epoch_{}.pth".format(epoch) 176 | torch.save(rev_net, model_out_path) 177 | print("Checkpoint saved to {}".format(model_out_path)) 178 | 179 | 180 | def main(model, args): 181 | date_time = str(datetime.datetime.now()) 182 | date_time = time2file_name(date_time) 183 | result_path = 'recon' + '/' + date_time 184 | model_path = 'model' + '/' + date_time 185 | if not os.path.exists(result_path): 186 | os.makedirs(result_path) 187 | if not os.path.exists(model_path): 188 | os.makedirs(model_path) 189 | 190 | r = np.array([[1, 0], [0, 0]]) 191 | g1 = np.array([[0, 1], [0, 0]]) 192 | g2 = np.array([[0, 0], [1, 0]]) 193 | b = np.array([[0, 0], [0, 1]]) 194 | rgb2raw = np.zeros([3, args.size[0], args.size[1]]) 195 | rgb2raw[0, :, :] = np.tile(r, (args.size[0] // 2, args.size[1] // 2)) 196 | rgb2raw[1, :, :] = np.tile(g1, (args.size[0] // 2, args.size[1] // 2)) + np.tile(g2, ( 197 | args.size[0] // 2, args.size[1] // 2)) 198 | rgb2raw[2, :, :] = np.tile(b, (args.size[0] // 2, args.size[1] // 2)) 199 | rgb2raw = torch.from_numpy(rgb2raw).cuda().float() 200 | 201 | for epoch in range(args.last_train + 1, args.last_train + args.max_iter + 1): 202 | train(epoch, result_path, model, args, rgb2raw) 203 | if (epoch % 5 == 0) and (epoch < 150): 204 | args.learning_rate = args.learning_rate * 0.95 205 | print(args.learning_rate) 206 | if (epoch % 5 == 0 or epoch > 0): 207 | model = model.module if hasattr(model, "module") else model 208 | checkpoint(epoch, model_path) 209 | if n_gpu > 1: 210 | model = torch.nn.DataParallel(model) 211 | 212 | 213 | if __name__ == '__main__': 214 | print(args.mode) 215 | print(args.learning_rate) 216 | 217 | rev_net = re_3dcnn1(18).cuda() 218 | rev_net.mask = mask 219 | if n_gpu > 1: 220 | rev_net = torch.nn.DataParallel(rev_net) 221 | if args.last_train != 0: 222 | rev_net = torch.load( 223 | './model/' + args.model_save_filename + "/RevSCInet_model_epoch_{}.pth".format(args.last_train)) 224 | rev_net = rev_net.module if hasattr(rev_net, "module") else rev_net 225 | 226 | main(rev_net, args) 227 | -------------------------------------------------------------------------------- /color model/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import scipy.io as scio 4 | import numpy as np 5 | 6 | 7 | def generate_masks(mask_path): 8 | mask = scio.loadmat(mask_path + '/mask.mat') 9 | mask = mask['mask'] 10 | mask = np.transpose(mask, [2, 0, 1]) 11 | mask_s = np.sum(mask, axis=0) 12 | index = np.where(mask_s == 0) 13 | mask_s[index] = 1 14 | mask_s = mask_s.astype(np.uint8) 15 | mask = torch.from_numpy(mask) 16 | mask = mask.float() 17 | mask = mask.cuda() 18 | mask_s = torch.from_numpy(mask_s) 19 | mask_s = mask_s.float() 20 | mask_s = mask_s.cuda() 21 | return mask, mask_s 22 | 23 | 24 | def split_masks(mask, scale, args): 25 | mask_list = list() 26 | for i in range(scale): 27 | for j in range(scale): 28 | if len(mask.shape) == 3: 29 | mask_list.append(mask[:, j * args.size[0] // scale:(j + 1) * args.size[0] // scale, 30 | i * args.size[1] // scale:(i + 1) * args.size[1] // scale]) 31 | 32 | elif len(mask.shape) == 2: 33 | mask_list.append(mask[j * args.size[0] // scale:(j + 1) * args.size[0] // scale, 34 | i * args.size[1] // scale:(i + 1) * args.size[1] // scale]) 35 | elif len(mask.shape) == 4: 36 | mask_list.append(mask[:, :, j * args.size[0] // scale:(j + 1) * args.size[0] // scale, 37 | i * args.size[1] // scale:(i + 1) * args.size[1] // scale]) 38 | 39 | return mask_list 40 | 41 | 42 | def time2file_name(time): 43 | year = time[0:4] 44 | month = time[5:7] 45 | day = time[8:10] 46 | hour = time[11:13] 47 | minute = time[14:16] 48 | second = time[17:19] 49 | time_filename = year + '_' + month + '_' + day + '_' + hour + '_' + minute + '_' + second 50 | return time_filename 51 | -------------------------------------------------------------------------------- /dataLoadess.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import torch 4 | import scipy.io as scio 5 | 6 | 7 | class Imgdataset(Dataset): 8 | 9 | def __init__(self, path): 10 | super(Imgdataset, self).__init__() 11 | self.data = [] 12 | if os.path.exists(path): 13 | groung_truth_path = path + '/gt' 14 | 15 | if os.path.exists(groung_truth_path): 16 | groung_truth = os.listdir(groung_truth_path) 17 | self.data = [{'groung_truth': groung_truth_path + '/' + groung_truth[i]} for i in 18 | range(len(groung_truth))] 19 | else: 20 | raise FileNotFoundError('path doesnt exist!') 21 | else: 22 | raise FileNotFoundError('path doesnt exist!') 23 | 24 | def __getitem__(self, index): 25 | groung_truth = self.data[index]["groung_truth"] 26 | 27 | gt = scio.loadmat(groung_truth) 28 | if "patch_save" in gt: 29 | gt = torch.from_numpy(gt['patch_save'] / 255) 30 | elif "p1" in gt: 31 | gt = torch.from_numpy(gt['p1'] / 255) 32 | elif "p2" in gt: 33 | gt = torch.from_numpy(gt['p2'] / 255) 34 | elif "p3" in gt: 35 | gt = torch.from_numpy(gt['p3'] / 255) 36 | 37 | gt = gt.permute(2, 0, 1) 38 | 39 | return gt 40 | 41 | def __len__(self): 42 | 43 | return len(self.data) 44 | -------------------------------------------------------------------------------- /model/model/RevSCInet_model_epoch_100.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenGroup/RevSCI-net/71ac125ab47dce2e4c091936e3a659900b7da258/model/model/RevSCInet_model_epoch_100.pth -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from my_tools import * 2 | 3 | 4 | class re_3dcnn(nn.Module): 5 | 6 | def __init__(self, args): 7 | super(re_3dcnn, self).__init__() 8 | self.conv1 = nn.Sequential( 9 | nn.Conv3d(1, 16, kernel_size=5, stride=1, padding=2), 10 | nn.LeakyReLU(inplace=True), 11 | nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1), 12 | nn.LeakyReLU(inplace=True), 13 | nn.Conv3d(32, 32, kernel_size=1, stride=1), 14 | nn.LeakyReLU(inplace=True), 15 | nn.Conv3d(32, 64, kernel_size=3, stride=(1, 2, 2), padding=1), 16 | nn.LeakyReLU(inplace=True), 17 | ) 18 | self.conv2 = nn.Sequential( 19 | nn.ConvTranspose3d(64, 32, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), 20 | output_padding=(0, 1, 1)), 21 | nn.LeakyReLU(inplace=True), 22 | nn.Conv3d(32, 32, kernel_size=3, stride=1, padding=1), 23 | nn.LeakyReLU(inplace=True), 24 | nn.Conv3d(32, 16, kernel_size=1, stride=1), 25 | nn.LeakyReLU(inplace=True), 26 | nn.Conv3d(16, 1, kernel_size=3, stride=1, padding=1), 27 | ) 28 | 29 | self.layers = nn.ModuleList() 30 | for i in range(args.num_block): 31 | self.layers.append(rev_3d_part1(64, args.num_group)) 32 | 33 | def forward(self, meas_re, args): 34 | 35 | batch_size = meas_re.shape[0] 36 | mask = self.mask.to(meas_re.device) 37 | maskt = mask.expand([batch_size, args.B, args.size[0], args.size[1]]) 38 | maskt = maskt.mul(meas_re) 39 | data = meas_re + maskt 40 | out = self.conv1(torch.unsqueeze(data, 1)) 41 | 42 | for layer in self.layers: 43 | out = layer(out) 44 | 45 | out = self.conv2(out) 46 | 47 | return out 48 | 49 | def for_backward(self, mask, meas_re, gt, loss, opt, args): 50 | batch_size = meas_re.shape[0] 51 | maskt = mask.expand([batch_size, args.B, args.size[0], args.size[1]]) 52 | maskt = maskt.mul(meas_re) 53 | data = meas_re + maskt 54 | data = torch.unsqueeze(data, 1) 55 | 56 | with torch.no_grad(): 57 | out1 = self.conv1(data) 58 | out2 = out1 59 | for layer in self.layers: 60 | out2 = layer(out2) 61 | out3 = out2.requires_grad_() 62 | out4 = self.conv2(out3) 63 | 64 | loss1 = loss(torch.squeeze(out4), gt) 65 | loss1.backward() 66 | current_state_grad = out3.grad 67 | 68 | out_current = out3 69 | for layer in reversed(self.layers): 70 | with torch.no_grad(): 71 | out_pre = layer.reverse(out_current) 72 | out_pre.requires_grad_() 73 | out_cur = layer(out_pre) 74 | torch.autograd.backward(out_cur, grad_tensors=current_state_grad) 75 | current_state_grad = out_pre.grad 76 | out_current = out_pre 77 | 78 | out1 = self.conv1(data) 79 | out1.requires_grad_() 80 | torch.autograd.backward(out1, grad_tensors=current_state_grad) 81 | if opt != 0: 82 | opt.step() 83 | 84 | return out4, loss1 85 | 86 | class re_3dcnn1(nn.Module): 87 | 88 | def __init__(self, args): 89 | super(re_3dcnn1, self).__init__() 90 | self.conv1 = nn.Sequential( 91 | nn.Conv3d(1, 16, kernel_size=5, stride=1, padding=2), 92 | nn.LeakyReLU(inplace=True), 93 | nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1), 94 | nn.LeakyReLU(inplace=True), 95 | nn.Conv3d(32, 32, kernel_size=1, stride=1), 96 | nn.LeakyReLU(inplace=True), 97 | nn.Conv3d(32, 64, kernel_size=3, stride=(1, 2, 2), padding=1), 98 | nn.LeakyReLU(inplace=True), 99 | ) 100 | self.conv2 = nn.Sequential( 101 | nn.ConvTranspose3d(64, 32, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), 102 | output_padding=(0, 1, 1)), 103 | nn.LeakyReLU(inplace=True), 104 | nn.Conv3d(32, 32, kernel_size=3, stride=1, padding=1), 105 | nn.LeakyReLU(inplace=True), 106 | nn.Conv3d(32, 16, kernel_size=1, stride=1), 107 | nn.LeakyReLU(inplace=True), 108 | nn.Conv3d(16, 1, kernel_size=3, stride=1, padding=1), 109 | ) 110 | 111 | self.layers = nn.ModuleList() 112 | for i in range(args.num_block): 113 | self.layers.append(rev_3d_part(32)) 114 | 115 | def forward(self, meas_re, args): 116 | 117 | batch_size = meas_re.shape[0] 118 | mask = self.mask.to(meas_re.device) 119 | maskt = mask.expand([batch_size, args.B, args.size[0], args.size[1]]) 120 | maskt = maskt.mul(meas_re) 121 | data = meas_re + maskt 122 | out = self.conv1(torch.unsqueeze(data, 1)) 123 | 124 | for layer in self.layers: 125 | out = layer(out) 126 | 127 | out = self.conv2(out) 128 | 129 | return out 130 | 131 | def for_backward(self, mask, meas_re, gt, loss, opt, args): 132 | batch_size = meas_re.shape[0] 133 | maskt = mask.expand([batch_size, args.B, args.size[0], args.size[1]]) 134 | maskt = maskt.mul(meas_re) 135 | data = meas_re + maskt 136 | data = torch.unsqueeze(data, 1) 137 | 138 | with torch.no_grad(): 139 | out1 = self.conv1(data) 140 | out2 = out1 141 | for layer in self.layers: 142 | out2 = layer(out2) 143 | out3 = out2.requires_grad_() 144 | out4 = self.conv2(out3) 145 | 146 | loss1 = loss(torch.squeeze(out4), gt) 147 | loss1.backward() 148 | current_state_grad = out3.grad 149 | 150 | out_current = out3 151 | for layer in reversed(self.layers): 152 | with torch.no_grad(): 153 | out_pre = layer.reverse(out_current) 154 | out_pre.requires_grad_() 155 | out_cur = layer(out_pre) 156 | torch.autograd.backward(out_cur, grad_tensors=current_state_grad) 157 | current_state_grad = out_pre.grad 158 | out_current = out_pre 159 | 160 | out1 = self.conv1(data) 161 | out1.requires_grad_() 162 | torch.autograd.backward(out1, grad_tensors=current_state_grad) 163 | if opt != 0: 164 | opt.step() 165 | 166 | return out4, loss1 167 | -------------------------------------------------------------------------------- /my_tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def split_feature(x): 6 | l = x.shape[1] 7 | x1 = x[:, 0:l // 2, ::] 8 | x2 = x[:, l // 2:, ::] 9 | return x1, x2 10 | 11 | 12 | def split_n_features(x, n): 13 | x_list = list(torch.chunk(x, n, dim=1)) 14 | return x_list 15 | 16 | 17 | class rev_part(nn.Module): 18 | 19 | def __init__(self, in_ch): 20 | super(rev_part, self).__init__() 21 | self.f1 = nn.Sequential( 22 | nn.Conv2d(in_ch, in_ch, 3, padding=1), 23 | nn.LeakyReLU(inplace=True), 24 | nn.Conv2d(in_ch, in_ch, 3, padding=1), 25 | ) 26 | self.g1 = nn.Sequential( 27 | nn.Conv2d(in_ch, in_ch, 3, padding=1), 28 | nn.LeakyReLU(inplace=True), 29 | nn.Conv2d(in_ch, in_ch, 3, padding=1), 30 | ) 31 | 32 | def forward(self, x): 33 | x1, x2 = split_feature(x) 34 | y1 = x1 + self.f1(x2) 35 | y2 = x2 + self.g1(y1) 36 | y = torch.cat([y1, y2], dim=1) 37 | return y 38 | 39 | def reverse(self, y): 40 | y1, y2 = split_feature(y) 41 | x2 = y2 - self.g1(y1) 42 | x1 = y1 - self.f1(x2) 43 | x = torch.cat([x1, x2], dim=1) 44 | return x 45 | 46 | 47 | class f_g_layer(nn.Module): 48 | def __init__(self, ch): 49 | super(f_g_layer, self).__init__() 50 | self.nn_layer = nn.Sequential( 51 | nn.Conv3d(ch, ch, 3, padding=1), 52 | nn.LeakyReLU(inplace=True), 53 | nn.Conv3d(ch, ch, 3, padding=1), 54 | ) 55 | 56 | def forward(self, x): 57 | x = self.nn_layer(x) 58 | return x 59 | 60 | 61 | class rev_3d_part1(nn.Module): 62 | 63 | def __init__(self, in_ch, n): 64 | super(rev_3d_part1, self).__init__() 65 | self.f = nn.ModuleList() 66 | self.n = n 67 | self.ch = in_ch 68 | for i in range(n): 69 | self.f.append(f_g_layer(in_ch // n)) 70 | 71 | def forward(self, x): 72 | x = split_n_features(x, self.n) 73 | y1 = x[-1] + self.f[0](x[0]) 74 | y = y1 75 | for i in range(1, self.n): 76 | y1 = x[(self.n - 1 - i)] + self.f[i](y1) 77 | y = torch.cat([y, y1], dim=1) 78 | return y 79 | 80 | def reverse(self, y): 81 | y = split_n_features(y, self.n) 82 | for i in range(1, self.n): 83 | x1 = y[self.n - i] - self.f[self.n - i](y[self.n - i - 1]) 84 | if i == 1: 85 | x = x1 86 | else: 87 | x = torch.cat([x, x1], dim=1) 88 | x1 = y[0] - self.f[0](x[:, 0:(self.ch // self.n), ::]) 89 | x = torch.cat([x, x1], dim=1) 90 | return x 91 | 92 | 93 | class rev_3d_part(nn.Module): 94 | 95 | def __init__(self, in_ch): 96 | super(rev_3d_part, self).__init__() 97 | self.f1 = nn.Sequential( 98 | nn.Conv3d(in_ch, in_ch, 3, padding=1), 99 | nn.LeakyReLU(inplace=True), 100 | nn.Conv3d(in_ch, in_ch, 3, padding=1), 101 | ) 102 | self.g1 = nn.Sequential( 103 | nn.Conv3d(in_ch, in_ch, 3, padding=1), 104 | nn.LeakyReLU(inplace=True), 105 | nn.Conv3d(in_ch, in_ch, 3, padding=1), 106 | ) 107 | 108 | def forward(self, x): 109 | x1, x2 = split_feature(x) 110 | y1 = x1 + self.f1(x2) 111 | y2 = x2 + self.g1(y1) 112 | y = torch.cat([y1, y2], dim=1) 113 | return y 114 | 115 | def reverse(self, y): 116 | y1, y2 = split_feature(y) 117 | x2 = y2 - self.g1(y1) 118 | x1 = y1 - self.f1(x2) 119 | x = torch.cat([x1, x2], dim=1) 120 | return x 121 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Memory-Efficient Network for Large-scale Video Compressive Sensing This repository contains the code for the paper [**Memory-Efficient Network for Large-scale Video Compressive Sensing**](https://arxiv.org/abs/2103.03089) (***CVPR 2021***) by [Ziheng Cheng](https://github.com/zihengcheng), [Bo Chen](https://web.xidian.edu.cn/bchen/), Guanliang Liu, Hao Zhang, Ruiying Lu, Zhengjue Wang and [Xin Yuan](https://www.bell-labs.com/usr/x.yuan). The large-scale color video reconstruction version are added in the "color model" floder. ## Requirements ``` PyTorch > 1.3.0 numpy scipy skimage ``` ## Data The training data for RevSCI-net is the same as the previous work [BIRNAT](https://github.com/BoChenGroup/BIRNAT). Please see the above link to generate the training set. To train the RevSCI-net, should generate the data in ```train/```. ## Train Reversible training: ``` python train.py --mode reverse --num_block 18 --num_group 2 ``` Normal training (automatic differentiation routine): ``` python train.py --mode normal --num_block 18 --num_group 2 ``` If the GPU memory is enough, recommend using normal training (about 1.5x faster than reversible training). Change the number of blocks and groups to different numbers to support different models. ## Test Run ``` python test.py ``` where will evaluate the preformance on simulation data using the pre-trained model in ```model/```. ## Contact [Ziheng Cheng, Xidian University](mailto:zhcheng@stu.xidian.edu.cn "Ziheng Cheng, Xidian University") [Bo Chen, Xidian University](mailto:bchen@mail.xidian.edu.cn "Bo Chen, Xidian University") [Xin Yuan, Bell Labs](mailto:xyuan@bell-labs.com "Xin Yuan, Bell labs") -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from utils import generate_masks, time2file_name 2 | import torch.nn as nn 3 | import torch 4 | import scipy.io as scio 5 | import datetime 6 | import os 7 | import numpy as np 8 | import argparse 9 | from utils import compare_ssim, compare_psnr 10 | 11 | if not torch.cuda.is_available(): 12 | raise Exception('NO GPU!') 13 | 14 | data_path = "./train" 15 | test_path1 = "./test" 16 | 17 | parser = argparse.ArgumentParser(description='Setting, compressive rate, size, and mode') 18 | 19 | parser.add_argument('--last_train', default=100, type=int, help='pretrain model') 20 | parser.add_argument('--model_save_filename', default='./model/', type=str, help='pretrain model save folder name') 21 | parser.add_argument('--max_iter', default=100, type=int, help='max epoch') 22 | parser.add_argument('--learning_rate', default=0.0002, type=float) 23 | parser.add_argument('--batch_size', default=3, type=int) 24 | parser.add_argument('--B', default=8, type=int, help='compressive rate') 25 | parser.add_argument('--num_block', default=18, type=int, help='the number of reversible blocks') 26 | parser.add_argument('--num_group', default=2, type=int, help='the number of groups') 27 | parser.add_argument('--size', default=[256, 256], type=int, help='input image resolution') 28 | parser.add_argument('--mode', default='reverse', type=str, help='training mode: reverse or normal') 29 | 30 | 31 | args = parser.parse_args() 32 | mask, mask_s = generate_masks(data_path) 33 | 34 | loss = nn.MSELoss() 35 | loss.cuda() 36 | 37 | 38 | def test(test_path, epoch, result_path, model, args): 39 | test_list = os.listdir(test_path) 40 | psnr_cnn, ssim_cnn = torch.zeros(len(test_list)), torch.zeros(len(test_list)) 41 | for i in range(len(test_list)): 42 | pic = scio.loadmat(test_path + '/' + test_list[i]) 43 | 44 | if "orig" in pic: 45 | pic = pic['orig'] 46 | pic = pic / 255 47 | 48 | pic_gt = np.zeros([pic.shape[2] // args.B, args.B, args.size[0], args.size[1]]) 49 | for jj in range(pic.shape[2]): 50 | if jj % args.B == 0: 51 | meas_t = np.zeros([args.size[0], args.size[1]]) 52 | n = 0 53 | pic_t = pic[:, :, jj] 54 | mask_t = mask[n, :, :] 55 | 56 | mask_t = mask_t.cpu() 57 | pic_gt[jj // args.B, n, :, :] = pic_t 58 | n += 1 59 | meas_t = meas_t + np.multiply(mask_t.numpy(), pic_t) 60 | 61 | if jj == args.B - 1: 62 | meas_t = np.expand_dims(meas_t, 0) 63 | meas = meas_t 64 | elif (jj + 1) % args.B == 0 and jj != args.B - 1: 65 | meas_t = np.expand_dims(meas_t, 0) 66 | meas = np.concatenate((meas, meas_t), axis=0) 67 | meas = torch.from_numpy(meas).cuda().float() 68 | pic_gt = torch.from_numpy(pic_gt).cuda().float() 69 | 70 | meas_re = torch.div(meas, mask_s) 71 | meas_re = torch.unsqueeze(meas_re, 1) 72 | 73 | out_save1 = torch.zeros([meas.shape[0], args.B, args.size[0], args.size[1]]).cuda() 74 | with torch.no_grad(): 75 | 76 | psnr_1, ssim_1 = 0, 0 77 | for ii in range(meas.shape[0]): 78 | out_pic1 = model(meas_re[ii:ii + 1, ::], args) 79 | out_pic1 = out_pic1[0, ::] 80 | out_save1[ii, :, :, :] = out_pic1[0, :, :, :] 81 | for jj in range(args.B): 82 | out_pic_CNN = out_pic1[0, jj, :, :] 83 | gt_t = pic_gt[ii, jj, :, :] 84 | psnr_1 += compare_psnr(gt_t.cpu().numpy() * 255, out_pic_CNN.cpu().numpy() * 255) 85 | ssim_1 += compare_ssim(gt_t.cpu().numpy() * 255, out_pic_CNN.cpu().numpy() * 255) 86 | 87 | psnr_cnn[i] = psnr_1 / (meas.shape[0] * args.B) 88 | ssim_cnn[i] = ssim_1 / (meas.shape[0] * args.B) 89 | 90 | a = test_list[i] 91 | name1 = result_path + '/RevSCInet_' + a[0:len(a) - 4] + '{}_{:.4f}'.format(epoch, psnr_cnn[i]) + '.mat' 92 | out_save1 = out_save1.cpu() 93 | scio.savemat(name1, {'pic': out_save1.numpy()}) 94 | print("RevSCInet result: PSNR -- {:.4f}, SSIM -- {:.4f}".format(torch.mean(psnr_cnn), torch.mean(ssim_cnn))) 95 | 96 | 97 | if __name__ == '__main__': 98 | 99 | date_time = str(datetime.datetime.now()) 100 | date_time = time2file_name(date_time) 101 | result_path = 'recon' + '/' + date_time 102 | model_path = 'model' + '/' + date_time 103 | if not os.path.exists(result_path): 104 | os.makedirs(result_path) 105 | 106 | if args.last_train != 0: 107 | rev_net = torch.load( 108 | './model/' + args.model_save_filename + "/RevSCInet_model_epoch_{}.pth".format(args.last_train)) 109 | rev_net = rev_net.module if hasattr(rev_net, "module") else rev_net 110 | test(test_path1, args.last_train, result_path, rev_net.eval(), args) 111 | -------------------------------------------------------------------------------- /test/aerial32_cacti.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenGroup/RevSCI-net/71ac125ab47dce2e4c091936e3a659900b7da258/test/aerial32_cacti.mat -------------------------------------------------------------------------------- /test/crash32_cacti.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenGroup/RevSCI-net/71ac125ab47dce2e4c091936e3a659900b7da258/test/crash32_cacti.mat -------------------------------------------------------------------------------- /test/drop8_cacti.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenGroup/RevSCI-net/71ac125ab47dce2e4c091936e3a659900b7da258/test/drop8_cacti.mat -------------------------------------------------------------------------------- /test/kobe_cacti.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenGroup/RevSCI-net/71ac125ab47dce2e4c091936e3a659900b7da258/test/kobe_cacti.mat -------------------------------------------------------------------------------- /test/runner8_cacti.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenGroup/RevSCI-net/71ac125ab47dce2e4c091936e3a659900b7da258/test/runner8_cacti.mat -------------------------------------------------------------------------------- /test/traffic_cacti.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenGroup/RevSCI-net/71ac125ab47dce2e4c091936e3a659900b7da258/test/traffic_cacti.mat -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from dataLoadess import Imgdataset 2 | from torch.utils.data import DataLoader 3 | from models import re_3dcnn 4 | from utils import generate_masks, time2file_name 5 | import torch.optim as optim 6 | import torch.nn as nn 7 | import torch 8 | import scipy.io as scio 9 | import time 10 | import datetime 11 | import os 12 | import numpy as np 13 | import argparse 14 | import random 15 | from torch.autograd import Variable 16 | from tqdm import tqdm 17 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr 18 | from skimage.metrics import structural_similarity as compare_ssim 19 | 20 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 21 | n_gpu = torch.cuda.device_count() 22 | print('The number of GPU is {}'.format(n_gpu)) 23 | 24 | data_path = "./train" 25 | test_path1 = "./test" 26 | 27 | mask, mask_s = generate_masks(data_path) 28 | 29 | parser = argparse.ArgumentParser(description='Setting, compressive rate, size, and mode') 30 | 31 | parser.add_argument('--last_train', default=0, type=int, help='pretrain model') 32 | parser.add_argument('--model_save_filename', default='', type=str, help='pretrain model save folder name') 33 | parser.add_argument('--max_iter', default=100, type=int, help='max epoch') 34 | parser.add_argument('--learning_rate', default=0.0002, type=float) 35 | parser.add_argument('--batch_size', default=3, type=int) 36 | parser.add_argument('--B', default=8, type=int, help='compressive rate') 37 | parser.add_argument('--num_block', default=18, type=int, help='the number of reversible blocks') 38 | parser.add_argument('--num_group', default=2, type=int, help='the number of groups') 39 | parser.add_argument('--size', default=[256, 256], type=int, help='input image resolution') 40 | parser.add_argument('--mode', default='normal', type=str, help='training mode: reverse or normal') 41 | 42 | 43 | args = parser.parse_args() 44 | 45 | dataset = Imgdataset(data_path) 46 | 47 | train_data_loader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True) 48 | 49 | loss = nn.MSELoss() 50 | loss.cuda() 51 | 52 | 53 | def test(test_path, epoch, result_path, model, args): 54 | test_list = os.listdir(test_path) 55 | psnr_cnn, ssim_cnn = torch.zeros(len(test_list)), torch.zeros(len(test_list)) 56 | for i in range(len(test_list)): 57 | pic = scio.loadmat(test_path + '/' + test_list[i]) 58 | 59 | if "orig" in pic: 60 | pic = pic['orig'] 61 | pic = pic / 255 62 | 63 | pic_gt = np.zeros([pic.shape[2] // args.B, args.B, args.size[0], args.size[1]]) 64 | for jj in range(pic.shape[2]): 65 | if jj % args.B == 0: 66 | meas_t = np.zeros([args.size[0], args.size[1]]) 67 | n = 0 68 | pic_t = pic[:, :, jj] 69 | mask_t = mask[n, :, :] 70 | 71 | mask_t = mask_t.cpu() 72 | pic_gt[jj // args.B, n, :, :] = pic_t 73 | n += 1 74 | meas_t = meas_t + np.multiply(mask_t.numpy(), pic_t) 75 | 76 | if jj == args.B - 1: 77 | meas_t = np.expand_dims(meas_t, 0) 78 | meas = meas_t 79 | elif (jj + 1) % args.B == 0 and jj != args.B - 1: 80 | meas_t = np.expand_dims(meas_t, 0) 81 | meas = np.concatenate((meas, meas_t), axis=0) 82 | meas = torch.from_numpy(meas).cuda().float() 83 | pic_gt = torch.from_numpy(pic_gt).cuda().float() 84 | 85 | meas_re = torch.div(meas, mask_s) 86 | meas_re = torch.unsqueeze(meas_re, 1) 87 | 88 | out_save1 = torch.zeros([meas.shape[0], args.B, args.size[0], args.size[1]]).cuda() 89 | with torch.no_grad(): 90 | 91 | psnr_1, ssim_1 = 0, 0 92 | for ii in range(meas.shape[0]): 93 | out_pic1 = model(meas_re[ii:ii + 1, ::], args) 94 | out_pic1 = out_pic1[0, ::] 95 | out_save1[ii, :, :, :] = out_pic1[0, :, :, :] 96 | for jj in range(args.B): 97 | out_pic_CNN = out_pic1[0, jj, :, :] 98 | gt_t = pic_gt[ii, jj, :, :] 99 | psnr_1 += compare_psnr(gt_t.cpu().numpy(), out_pic_CNN.cpu().numpy()) 100 | ssim_1 += compare_ssim(gt_t.cpu().numpy(), out_pic_CNN.cpu().numpy()) 101 | 102 | psnr_cnn[i] = psnr_1 / (meas.shape[0] * args.B) 103 | ssim_cnn[i] = ssim_1 / (meas.shape[0] * args.B) 104 | 105 | a = test_list[i] 106 | name1 = result_path + '/RevSCInet_' + a[0:len(a) - 4] + '{}_{:.4f}'.format(epoch, psnr_cnn[i]) + '.mat' 107 | out_save1 = out_save1.cpu() 108 | scio.savemat(name1, {'pic': out_save1.numpy()}) 109 | print("RevSCInet result: PSNR -- {:.4f}, SSIM -- {:.4f}".format(torch.mean(psnr_cnn), torch.mean(ssim_cnn))) 110 | 111 | 112 | def train(epoch, result_path, model, args): 113 | epoch_loss = 0 114 | begin = time.time() 115 | 116 | optimizer_g = optim.Adam([{'params': model.parameters()}], lr=args.learning_rate) 117 | 118 | for iteration, batch in tqdm(enumerate(train_data_loader)): 119 | gt = Variable(batch) 120 | gt = gt.cuda().float() # [batch,8,256,256] 121 | 122 | maskt = mask.expand([gt.shape[0], args.B, args.size[0], args.size[1]]) 123 | meas = torch.mul(maskt, gt) 124 | meas = torch.sum(meas, dim=1) 125 | 126 | meas = meas.cuda().float() # [batch,256 256] 127 | 128 | meas_re = torch.div(meas, mask_s) 129 | meas_re = torch.unsqueeze(meas_re, 1) 130 | 131 | optimizer_g.zero_grad() 132 | 133 | if args.mode == 'normal': 134 | xt1 = model(meas_re, args) 135 | Loss1 = loss(torch.squeeze(xt1), gt) 136 | Loss1.backward() 137 | optimizer_g.step() 138 | elif args.mode == 'reverse': 139 | xt1, Loss1 = model.for_backward(mask, meas_re, gt, loss, optimizer_g, args) 140 | 141 | epoch_loss += Loss1.data 142 | 143 | model = model.module if hasattr(model, "module") else model 144 | test(test_path1, epoch, result_path, model.eval(), args) 145 | end = time.time() 146 | print("===> Epoch {} Complete: Avg. Loss: {:.7f}".format(epoch, epoch_loss / len(train_data_loader)), 147 | " time: {:.2f}".format(end - begin)) 148 | 149 | 150 | def checkpoint(epoch, model_path): 151 | model_out_path = './' + model_path + '/' + "RevSCInet_model_epoch_{}.pth".format(epoch) 152 | torch.save(rev_net, model_out_path) 153 | print("Checkpoint saved to {}".format(model_out_path)) 154 | 155 | 156 | def main(model, args): 157 | date_time = str(datetime.datetime.now()) 158 | date_time = time2file_name(date_time) 159 | result_path = 'recon' + '/' + date_time 160 | model_path = 'model' + '/' + date_time 161 | if not os.path.exists(result_path): 162 | os.makedirs(result_path) 163 | if not os.path.exists(model_path): 164 | os.makedirs(model_path) 165 | for epoch in range(args.last_train + 1, args.last_train + args.max_iter + 1): 166 | train(epoch, result_path, model, args) 167 | if (epoch % 5 == 0) and (epoch < 150): 168 | args.learning_rate = args.learning_rate * 0.95 169 | print(args.learning_rate) 170 | if (epoch % 5 == 0 or epoch > 50): 171 | model = model.module if hasattr(model, "module") else model 172 | checkpoint(epoch, model_path) 173 | if n_gpu > 1: 174 | model = torch.nn.DataParallel(model) 175 | 176 | 177 | if __name__ == '__main__': 178 | print(args.mode) 179 | print(args.learning_rate) 180 | 181 | rev_net = re_3dcnn(args).cuda() 182 | rev_net.mask = mask 183 | if n_gpu > 1: 184 | rev_net = torch.nn.DataParallel(rev_net) 185 | if args.last_train != 0: 186 | rev_net = torch.load( 187 | './model/' + args.model_save_filename + "/RevSCInet_model_epoch_{}.pth".format(args.last_train)) 188 | rev_net = rev_net.module if hasattr(rev_net, "module") else rev_net 189 | main(rev_net, args) 190 | -------------------------------------------------------------------------------- /train/mask.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenGroup/RevSCI-net/71ac125ab47dce2e4c091936e3a659900b7da258/train/mask.mat -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import scipy.io as scio 3 | import numpy as np 4 | import cv2 5 | import math 6 | 7 | 8 | def generate_masks(mask_path): 9 | mask = scio.loadmat(mask_path + '/mask.mat') 10 | mask = mask['mask'] 11 | mask = np.transpose(mask, [2, 0, 1]) 12 | mask_s = np.sum(mask, axis=0) 13 | index = np.where(mask_s == 0) 14 | mask_s[index] = 1 15 | mask_s = mask_s.astype(np.float32) 16 | mask = torch.from_numpy(mask) 17 | mask = mask.float() 18 | mask = mask.cuda() 19 | mask_s = torch.from_numpy(mask_s) 20 | mask_s = mask_s.float() 21 | mask_s = mask_s.cuda() 22 | return mask, mask_s 23 | 24 | 25 | def time2file_name(time): 26 | year = time[0:4] 27 | month = time[5:7] 28 | day = time[8:10] 29 | hour = time[11:13] 30 | minute = time[14:16] 31 | second = time[17:19] 32 | time_filename = year + '_' + month + '_' + day + '_' + hour + '_' + minute + '_' + second 33 | return time_filename 34 | 35 | 36 | def ssim(img1, img2): 37 | C1 = (0.01 * 255) ** 2 38 | C2 = (0.03 * 255) ** 2 39 | 40 | img1 = img1.astype(np.float64) 41 | img2 = img2.astype(np.float64) 42 | kernel = cv2.getGaussianKernel(11, 1.5) 43 | window = np.outer(kernel, kernel.transpose()) 44 | 45 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 46 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 47 | mu1_sq = mu1 ** 2 48 | mu2_sq = mu2 ** 2 49 | mu1_mu2 = mu1 * mu2 50 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 51 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 52 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 53 | 54 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 55 | (sigma1_sq + sigma2_sq + C2)) 56 | return ssim_map.mean() 57 | 58 | 59 | def compare_ssim(img1, img2): 60 | '''calculate SSIM 61 | the same outputs as MATLAB's 62 | img1, img2: [0, 255] 63 | ''' 64 | if not img1.shape == img2.shape: 65 | raise ValueError('Input images must have the same dimensions.') 66 | if img1.ndim == 2: 67 | return ssim(img1, img2) 68 | elif img1.ndim == 3: 69 | if img1.shape[2] == 3: 70 | ssims = [] 71 | for i in range(3): 72 | ssims.append(ssim(img1, img2)) 73 | return np.array(ssims).mean() 74 | elif img1.shape[2] == 1: 75 | return ssim(np.squeeze(img1), np.squeeze(img2)) 76 | 77 | 78 | 79 | def compare_psnr(img1, img2, shave_border=0): 80 | height, width = img1.shape[:2] 81 | img1 = img1[shave_border:height - shave_border, shave_border:width - shave_border] 82 | img2 = img2[shave_border:height - shave_border, shave_border:width - shave_border] 83 | imdff = img1 - img2 84 | rmse = math.sqrt(np.mean(imdff ** 2)) 85 | if rmse == 0: 86 | return 100 87 | return 20 * math.log10(255.0 / rmse) 88 | --------------------------------------------------------------------------------