# Repository layout:
#   README.md
#   csnet.py
#   eval.py
#   train.py
#
# /README.md
#   # Compressed-sensing-CSNet
#   A re-implementation of "Deep Networks for Compressed Image Sensing".
#   It is a CNN-based method for compressed sensing.
#
# /csnet.py
"""CSNet model definition and block-rearrangement helpers."""

import numpy as np
import torch
import torch.nn as nn

# Side length of one sensing block. The sampling convolution uses it as both
# kernel size and stride, and PixelShuffle uses it as its upscale factor.
BLOCK_SIZE = 32


def tensor_concat(f, n, batch_size, ngpu):
    """Concatenate eight consecutive channel maps of ``f`` along dim -2.

    Args:
        f: tensor indexed as ``f[:, channel, :, :]``.
        n: index of the first channel to take; channels ``n .. n + 7`` are used.
        batch_size: unused; kept so existing callers keep working.
        ngpu: unused; kept so existing callers keep working.

    Returns:
        ``f[:, n], f[:, n + 1], ..., f[:, n + 7]`` joined along dim -2.
    """
    # The original pre-allocated a FloatTensor here and overwrote it on the
    # next line; the allocation had no effect and has been dropped.
    return torch.cat([f[:, i, :, :] for i in range(n, n + 8)], dim=-2)


def block2image(initail_map, batch_size, ngpu):
    """Tile 64 maps of 32x32 values into one 256x256 single-channel image.

    The input is viewed as ``(batch_size // ngpu, 32, 32, 64)`` and the last
    axis is moved forward so that it indexes 64 maps of 32x32. Map
    ``8 * j + i`` ends up, transposed, in the output tile covering rows
    ``32 * j .. 32 * j + 31`` and columns ``32 * i .. 32 * i + 31``.

    Args:
        initail_map: tensor holding ``(batch_size // ngpu) * 1024 * 64``
            elements in a layout ``view`` can reinterpret.
        batch_size: total batch size.
        ngpu: number of GPUs the batch is split over.

    Returns:
        Tensor of shape ``(batch_size // ngpu, 1, 256, 256)``.
    """
    local_batch = batch_size // ngpu
    f = initail_map.view(local_batch, 1024, 64)
    f = f.view(local_batch, 32, 32, 64).permute(0, 3, 1, 2)

    # Eight strips of eight maps each: maps 0-7, 8-15, ..., 56-63.
    strips = [tensor_concat(f, start, local_batch, ngpu)
              for start in range(0, 64, 8)]
    x = torch.cat(strips, dim=-1)
    return x.unsqueeze(1).permute(0, 1, 3, 2)


class CSNET(nn.Module):
    """Block-based compressed-sensing network.

    ``sample`` reduces every non-overlapping 32x32 block of the input to
    ``fcr`` values, ``initial`` plus ``pixelshuffle`` expand those values back
    to an image of the input size, and three 3x3 convolutions refine it.

    Args:
        channels: number of image channels.
        cr: compression ratio; each block keeps
            ``channels * 32 * 32 // cr`` measurements. With ``channels=3`` and
            ``cr=20`` this is 153, the value that used to be hard-coded.

    Raises:
        ValueError: if ``channels`` or ``cr`` is not positive, or if ``cr`` is
            so large that no measurement would be left.
    """

    def __init__(self, channels, cr):
        super().__init__()

        if channels < 1:
            raise ValueError("channels must be >= 1, got %r" % (channels,))
        if cr <= 0:
            raise ValueError("cr must be positive, got %r" % (cr,))

        block_values = channels * BLOCK_SIZE * BLOCK_SIZE
        measurements = int(block_values // cr)
        if measurements < 1:
            raise ValueError(
                "cr=%r leaves no measurements for a block of %d values"
                % (cr, block_values))

        self.channels = channels
        self.fcr = measurements
        self.base = 1

        # Kernel size == stride, so each block is sampled independently.
        self.sample = nn.Conv2d(self.channels, self.fcr, kernel_size=BLOCK_SIZE,
                                padding=0, stride=BLOCK_SIZE, bias=False)
        # The output width must be channels * 32 * 32 so that PixelShuffle
        # returns exactly `channels` planes. The old constant 3072 only
        # satisfied this for 3-channel input.
        self.initial = nn.Conv2d(self.fcr, block_values, kernel_size=1,
                                 padding=0, stride=1, bias=False)
        self.pixelshuffle = nn.PixelShuffle(BLOCK_SIZE)
        self.conv1 = nn.Conv2d(self.channels, self.base, kernel_size=3,
                               padding=1, stride=1, bias=False)
        self.conv2 = nn.Conv2d(self.base, self.base, kernel_size=3,
                               padding=1, stride=1, bias=False)
        self.conv3 = nn.Conv2d(self.base, self.channels, kernel_size=3,
                               padding=1, stride=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input, batch_size, ngpu):
        """Sample ``input`` block-wise and reconstruct it.

        Args:
            input: image batch of shape ``(N, channels, H, W)``. The sampling
                convolution drops any remainder when H or W is not a multiple
                of 32, so the output is then smaller than the input.
            batch_size: unused; kept for interface compatibility.
            ngpu: unused; kept for interface compatibility.

        Returns:
            Reconstruction with ``channels`` planes.
        """
        output = self.sample(input)
        output = self.initial(output)
        output = self.pixelshuffle(output)
        output = self.relu(self.conv1(output))
        # conv2 is applied three times, so these layers share one weight set.
        # Kept unchanged so existing checkpoints still load.
        for _ in range(3):
            output = self.relu(self.conv2(output))
        return self.conv3(output)
# /eval.py
"""Evaluate a trained CSNet checkpoint on a folder of images."""

import argparse
import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.utils as vutils
from torch.autograd import Variable
from torch.nn import init
from torchvision import datasets, transforms

import csnet

parser = argparse.ArgumentParser(description='CSNet')
parser.add_argument('--dataset', default='own_image')
parser.add_argument('--textpath', help='path to textset', default='test_img/')
parser.add_argument('--batch-size', type=int, default=1, metavar='N')
parser.add_argument('--image-size', type=int, default=256, metavar='N')
# NOTE: store_true combined with default=True means this flag is always on.
parser.add_argument('--cuda', action='store_true', default=True)
parser.add_argument('--ngpu', type=int, default=1, metavar='N')
parser.add_argument('--seed', type=int, default=1, metavar='S')
parser.add_argument('--save_path', default='./test')
parser.add_argument('--log-interval', type=int, default=100, metavar='N')
parser.add_argument('--outf', default='./results')
parser.add_argument('--cr', type=int, default=20)
opt = parser.parse_args()

if torch.cuda.is_available() and not opt.cuda:
    print("please run with GPU")
if opt.seed is None:
    opt.seed = np.random.randint(1, 10000)
print('Random seed: ', opt.seed)
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)
if opt.cuda:
    torch.cuda.manual_seed(opt.seed)
criterion_mse = nn.MSELoss()
cudnn.benchmark = True


def data_loader():
    """Build the test DataLoader over the image folder ``opt.textpath``.

    Images are resized with ``Resize(opt.image_size)``, converted to tensors
    and normalised with mean 0.5 and std 0.5 per channel.

    Returns:
        A non-shuffling DataLoader yielding ``(image_batch, label)`` pairs.
    """
    kwopt = {'num_workers': 8, 'pin_memory': True} if opt.cuda else {}
    # Named `preprocess` so it does not shadow the imported `transforms` module.
    # NOTE(review): Resize with an int keeps the aspect ratio, so sides are
    # not guaranteed to be multiples of the model's 32-pixel block; confirm
    # the test images are compatible.
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.Resize(opt.image_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = torchvision.datasets.ImageFolder(opt.textpath, transform=preprocess)
    return torch.utils.data.DataLoader(
        dataset, batch_size=opt.batch_size, shuffle=False, **kwopt)


def _denormalize(batch):
    """Undo the mean-0.5 / std-0.5 normalisation, returning values in [0, 1]."""
    return (batch * 0.5 + 0.5).clamp(0, 1)


def evaluation(testloader):
    """Reconstruct every test batch, report the MSE and save image pairs.

    Loads ``<outf>/cr<cr>/model/CSnet.pth`` and, for every full batch, writes
    ``orig_<idx>.bmp`` and ``recon_<idx>.bmp`` to ``opt.save_path``. The
    printed average is taken over the batches actually evaluated.

    Args:
        testloader: DataLoader yielding ``(image_batch, label)`` pairs.
    """
    # Peek at one batch only to learn the channel count.
    first_batch, _ = next(iter(testloader))
    channels = first_batch.size(1)

    CSnet = csnet.CSNET(channels, opt.cr)
    if opt.cuda:
        CSnet = nn.DataParallel(CSnet.cuda(), device_ids=[0])
        criterion_mse.cuda()

    CSnet_path = '%s/cr%s/model/CSnet.pth' % (opt.outf, opt.cr)
    CSnet.load_state_dict(torch.load(CSnet_path))
    CSnet.eval()

    # save_image does not create missing directories.
    os.makedirs(opt.save_path, exist_ok=True)

    mse_total = 0.0
    evaluated = 0
    for idx, (input, _) in enumerate(testloader, 0):
        if input.size(0) != opt.batch_size:
            continue

        target = input.cuda() if opt.cuda else input
        with torch.no_grad():
            output = CSnet(target, opt.batch_size, opt.ngpu)
            csnet_mse = criterion_mse(output, target).item()

        mse_total += csnet_mse
        evaluated += 1

        if idx % 20 == 0:
            print('Test:[%d/%d] mse:%.4f \n' % (idx, len(testloader), csnet_mse))

        # Tensors are normalised to [-1, 1] and save_image expects [0, 1], so
        # map them back first; otherwise every negative value is clipped to 0.
        vutils.save_image(_denormalize(input),
                          '%s/orig_%d.bmp' % (opt.save_path, idx), padding=0)
        vutils.save_image(_denormalize(output),
                          '%s/recon_%d.bmp' % (opt.save_path, idx), padding=0)

    if evaluated == 0:
        print('Test: no batch of size %d was available' % opt.batch_size)
        return
    print('Test: average mse: %.4f,' % (mse_total / evaluated))


def main():
    """Build the test loader and run the evaluation."""
    test_loader = data_loader()
    evaluation(test_loader)


if __name__ == '__main__':
    main()
# /train.py
"""Train CSNet on an image folder and keep the best model by validation MSE."""

import argparse
import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.nn import init
from torchvision import datasets, transforms

import csnet

parser = argparse.ArgumentParser(description='CSNet')
parser.add_argument('--dataset', default='own_image')
parser.add_argument('--trainpath', default='../Image/train/')
parser.add_argument('--valpath', default='../Image/test/')
parser.add_argument('--batch-size', type=int, default=1, metavar='N')
parser.add_argument('--image-size', type=int, default=256, metavar='N')
# Replaced by the stored epoch + 1 when a checkpoint is loaded.
parser.add_argument('--start_epoch', type=int, default=0, metavar='N')
parser.add_argument('--epochs', type=int, default=100, metavar='N')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR')
# NOTE: store_true combined with default=True means this flag is always on.
parser.add_argument('--cuda', action='store_true', default=True)
parser.add_argument('--ngpu', type=int, default=1, metavar='N')
parser.add_argument('--seed', type=int, default=1, metavar='S')
parser.add_argument('--log-interval', type=int, default=100, metavar='N')
parser.add_argument('--outf', default='./results')
parser.add_argument('--cr', type=int, default=20)
parser.add_argument('--resume', action='store_true', default=True)
# --resume defaults to True, so an explicit opt-out is the only way to skip it.
parser.add_argument('--no-resume', dest='resume', action='store_false')
opt = parser.parse_args()

if torch.cuda.is_available() and not opt.cuda:
    print("please run with GPU")
if opt.seed is None:
    opt.seed = np.random.randint(1, 10000)
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)
if opt.cuda:
    torch.cuda.manual_seed(opt.seed)
cudnn.benchmark = True

for _subdir in ('model', 'image', 'log'):
    os.makedirs('%s/cr%s/%s' % (opt.outf, opt.cr, _subdir), exist_ok=True)
log_dir = '%s/cr%s/log' % (opt.outf, opt.cr)
writer = SummaryWriter(log_dir=log_dir)


def data_loader():
    """Build the training and validation DataLoaders.

    Training images get a random crop of ``opt.image_size`` plus random
    horizontal and vertical flips. Validation images get a centre crop and
    are not shuffled, so the validation MSE used for model selection is
    comparable from one epoch to the next.

    Returns:
        ``(train_loader, val_loader)``.
    """
    kwopt = {'num_workers': 4, 'pin_memory': True} if opt.cuda else {}
    normalize = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_tf = torchvision.transforms.Compose([
        torchvision.transforms.RandomCrop(opt.image_size),    # random crop
        torchvision.transforms.RandomHorizontalFlip(),        # random horizontal flip
        torchvision.transforms.RandomVerticalFlip(),          # random vertical flip
        torchvision.transforms.ToTensor(),                    # convert to tensor
        normalize,
    ])
    val_tf = torchvision.transforms.Compose([
        torchvision.transforms.CenterCrop(opt.image_size),
        torchvision.transforms.ToTensor(),
        normalize,
    ])
    train_dataset = torchvision.datasets.ImageFolder(opt.trainpath, transform=train_tf)
    val_dataset = torchvision.datasets.ImageFolder(opt.valpath, transform=val_tf)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True, **kwopt)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=opt.batch_size, shuffle=False, **kwopt)
    return train_loader, val_loader


def train(start_epoch, epochs, trainloader, valloader):
    """Train CSNet and save the weights with the lowest validation MSE.

    Args:
        start_epoch: first epoch index to run; replaced by the stored epoch
            + 1 when ``opt.resume`` is set and ``<outf>/checkpoint`` exists.
        epochs: epoch index at which training stops (exclusive).
        trainloader: DataLoader yielding ``(image_batch, label)`` pairs.
        valloader: DataLoader used by ``val`` after every epoch.

    Raises:
        RuntimeError: if an epoch contains no batch of size ``opt.batch_size``.
    """
    # Peek at one batch only to learn the channel count.
    first_batch, _ = next(iter(trainloader))
    channels = first_batch.size(1)

    CSnet = csnet.CSNET(channels, opt.cr)
    for m in CSnet.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')

    optimizer = optim.Adam(CSnet.parameters(), lr=opt.lr, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [4000], gamma=0.1, last_epoch=-1)

    criterion_mse = nn.MSELoss()
    cudnn.benchmark = True

    if opt.cuda:
        CSnet = nn.DataParallel(CSnet.cuda(), device_ids=[0])
        criterion_mse.cuda()

    if opt.resume:
        checkpoint_path = '%s/checkpoint' % (opt.outf)
        if os.path.isfile(checkpoint_path):
            checkpoint = torch.load(checkpoint_path)
            start_epoch = checkpoint['epoch'] + 1
            # These used to reference undefined names `G` and `optimizer_G`.
            # NOTE(review): assumes the checkpoint was saved from a model
            # wrapped the same way as here; nothing in this file writes it.
            CSnet.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
        else:
            print("=> no checkpoint found")

    min_loss = float('inf')
    for epoch in range(start_epoch, epochs):
        CSnet.train()
        last_step = None
        for idx, (input, _) in enumerate(trainloader, 0):
            if input.size(0) != opt.batch_size:
                continue

            target = input.cuda() if opt.cuda else input
            CSnet.zero_grad()
            output = CSnet(target, opt.batch_size, opt.ngpu)
            csnet_mse = criterion_mse(output, target)
            csnet_mse.backward()
            optimizer.step()
            # Stepped once per batch, so the milestone 4000 counts optimizer
            # steps, not epochs. NOTE(review): confirm this is intended.
            scheduler.step()

            last_step = (input, output.detach(), csnet_mse.item())

            if idx % opt.log_interval == 0:
                print('[%d/%d][%d/%d] mse:%.4f' % (epoch, epochs, idx, len(trainloader), last_step[2]))

        if last_step is None:
            raise RuntimeError(
                'epoch %d had no batch of size %d to train on' % (epoch, opt.batch_size))
        last_input, last_output, last_mse = last_step

        # Epoch-level logging uses the last processed batch of the epoch.
        writer.add_scalar('train/mse', last_mse, epoch)
        a = vutils.make_grid(last_input[:1], normalize=True, scale_each=True)
        b = vutils.make_grid(last_output[:1], normalize=True, scale_each=True)
        writer.add_image('orin', a, epoch)
        writer.add_image('recon', b, epoch)

        CSnet.eval()
        average_mse = val(epoch, channels, valloader, last_input, CSnet, criterion_mse)

        if average_mse < min_loss:
            min_loss = average_mse
            print("save model")
            torch.save(CSnet.state_dict(), '%s/cr%s/model/CSnet.pth' % (opt.outf, opt.cr))


def val(epoch, channels, valloader, input, CSnet, criterion_mse):
    """Compute the mean reconstruction MSE over the validation loader.

    Args:
        epoch: epoch index, used in log lines and as the TensorBoard step.
        channels: unused; kept for interface compatibility.
        valloader: DataLoader yielding ``(image_batch, label)`` pairs.
        input: unused; kept for interface compatibility.
        CSnet: model to evaluate; the caller puts it in eval mode.
        criterion_mse: loss module returning a scalar tensor.

    Returns:
        Mean MSE over the evaluated batches, or ``inf`` when no batch had
        size ``opt.batch_size``, so the caller never saves a model for it.
    """
    mse_total = 0.0
    evaluated = 0
    for idx, (batch, _) in enumerate(valloader, 0):
        if batch.size(0) != opt.batch_size:
            continue

        target = batch.cuda() if opt.cuda else batch
        with torch.no_grad():
            output = CSnet(target, opt.batch_size, opt.ngpu)
            csnet_mse = criterion_mse(output, target).item()

        mse_total += csnet_mse
        evaluated += 1

        if idx % 20 == 0:
            print('Test:[%d][%d/%d] mse:%.4f \n' % (epoch, idx, len(valloader), csnet_mse))

    if evaluated == 0:
        print('Test:[%d] no batch of size %d was available' % (epoch, opt.batch_size))
        return float('inf')

    average_mse = mse_total / evaluated
    print('Test:[%d] average mse:%.4f,' % (epoch, average_mse))
    writer.add_scalar('test/mse_loss_epoch', average_mse, epoch)
    return average_mse


def main():
    """Build the loaders, run training and close the TensorBoard writer."""
    train_loader, val_loader = data_loader()
    try:
        train(opt.start_epoch, opt.epochs, train_loader, val_loader)
    finally:
        # Close the writer so buffered events reach disk even on failure.
        writer.close()


if __name__ == '__main__':
    main()